mod highlight;
use std::collections::HashMap;
use std::io::{self, IsTerminal, Write};
use std::path::PathBuf;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use arrow::util::pretty::print_batches;
use datafusion::execution::session_state::SessionStateBuilder;
use datafusion::physical_plan::{collect, ExecutionPlan};
use datafusion::prelude::{SessionConfig, SessionContext};
use highlight::SqlHelper;
use rustyline::error::ReadlineError;
use rustyline::Editor;
use tracing_subscriber::EnvFilter;
use zarr_datafusion::datasource::factory::ZarrTableFactory;
use zarr_datafusion::optimizer::{CountStatisticsRule, MinMaxStatisticsRule};
use zarr_datafusion::physical_plan::zarr_exec::ZarrExec;
use zarr_datafusion::reader::stats::{format_bytes, SharedIoStats};
/// File name, placed in the user's home directory, that stores REPL line history.
const HISTORY_FILE: &str = ".zarr_cli_history";

/// Returns the path of the REPL history file.
///
/// The home directory is taken from `HOME`, then `USERPROFILE` (the Windows
/// equivalent), skipping variables that are unset or empty. If neither yields
/// a directory, the file is placed in the current directory.
///
/// `var_os` is used instead of `var` so that a home directory whose path is
/// not valid UTF-8 is still honoured rather than silently ignored.
fn get_history_path() -> PathBuf {
    ["HOME", "USERPROFILE"]
        .iter()
        .filter_map(|key| std::env::var_os(key))
        .find(|dir| !dir.is_empty())
        .map(PathBuf::from)
        .unwrap_or_else(|| PathBuf::from("."))
        .join(HISTORY_FILE)
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
tracing_subscriber::fmt()
.with_env_filter(EnvFilter::from_default_env())
.with_target(true)
.with_line_number(true)
.init();
let config = SessionConfig::new().with_information_schema(true);
let state = SessionStateBuilder::new()
.with_default_features()
.with_config(config)
.with_table_factories(HashMap::from([(
"ZARR".to_string(),
Arc::new(ZarrTableFactory) as _,
)]))
.with_optimizer_rule(Arc::new(CountStatisticsRule::new()))
.with_optimizer_rule(Arc::new(MinMaxStatisticsRule::new()))
.build();
let ctx = SessionContext::new_with_state(state);
println!("Zarr-DataFusion CLI");
println!("\nType SQL queries or 'help' for commands.\n");
let helper = SqlHelper::new();
let mut rl = Editor::new()?;
rl.set_helper(Some(helper));
let history_path = get_history_path();
let _ = rl.load_history(&history_path);
loop {
match rl.readline("zarr> ") {
Ok(line) => {
let line = line.trim();
if line.is_empty() {
continue;
}
let _ = rl.add_history_entry(line);
if line.eq_ignore_ascii_case("quit") || line.eq_ignore_ascii_case("exit") {
break;
}
if line.eq_ignore_ascii_case("help") {
print_help();
continue;
}
if line.starts_with("\\d") || line.eq_ignore_ascii_case("show tables") {
match ctx.sql("SHOW TABLES").await {
Ok(df) => {
if let Err(e) = df.show().await {
eprintln!("Error: {e}");
}
}
Err(e) => eprintln!("Error: {e}"),
}
continue;
}
let start = Instant::now();
match ctx.sql(line).await {
Ok(df) => {
let line_upper = line.to_uppercase();
let is_ddl = line_upper.starts_with("CREATE ")
|| line_upper.starts_with("DROP ")
|| line_upper.starts_with("ALTER ");
if is_ddl {
if let Err(e) = df.collect().await {
eprintln!("Error: {e}");
} else {
let elapsed = start.elapsed();
println!("OK ({:.3}s)", elapsed.as_secs_f64());
}
} else {
match df.create_physical_plan().await {
Ok(plan) => {
let io_stats = find_zarr_exec_stats(&plan);
let stop_flag = Arc::new(AtomicBool::new(false));
let is_tty = io::stdout().is_terminal();
let live_task = if is_tty {
io_stats.as_ref().map(|stats| {
spawn_live_stats(stats.clone(), stop_flag.clone())
})
} else {
None
};
let task_ctx = ctx.task_ctx();
let result = collect(plan, task_ctx).await;
stop_flag.store(true, Ordering::Relaxed);
if let Some(task) = live_task {
let _ = task.await;
}
match result {
Ok(batches) => {
let elapsed = start.elapsed();
let row_count: usize =
batches.iter().map(|b| b.num_rows()).sum();
if is_tty && io_stats.is_some() {
print!("\r\x1b[K");
let _ = io::stdout().flush();
}
if let Err(e) = print_batches(&batches) {
eprintln!("Error displaying results: {e}");
} else {
print_stats_line(
row_count,
elapsed.as_secs_f64(),
io_stats.as_ref(),
);
}
}
Err(e) => eprintln!("Error executing query: {e}"),
}
}
Err(e) => eprintln!("Error creating plan: {e}"),
}
}
}
Err(e) => eprintln!("SQL Error: {e}"),
}
}
Err(ReadlineError::Interrupted) => {
println!("^C");
continue;
}
Err(ReadlineError::Eof) => {
break;
}
Err(err) => {
eprintln!("Error: {err}");
break;
}
}
}
if let Err(e) = rl.save_history(&history_path) {
eprintln!("Warning: Could not save history: {e}");
}
println!("Goodbye!");
Ok(())
}
/// Writes the REPL command reference to stdout, including the syntax for
/// registering and dropping a Zarr-backed table and a short example session.
fn print_help() {
    const HELP_TEXT: &str = r#"
Zarr-DataFusion CLI Commands:
<SQL> Execute a SQL query
show tables List registered tables
\d List registered tables
help Show this help
quit/exit Exit the CLI
Loading data:
CREATE EXTERNAL TABLE <name> STORED AS ZARR LOCATION '<path>';
DROP TABLE <name>;
Example:
CREATE EXTERNAL TABLE weather STORED AS ZARR LOCATION 'data/synthetic.zarr';
SELECT * FROM weather LIMIT 10;
SELECT AVG(temperature) FROM weather GROUP BY lat, lon;
DROP TABLE weather;
"#;
    println!("{HELP_TEXT}");
}
/// Searches `plan` depth-first for a `ZarrExec` node and returns a handle to
/// its I/O statistics.
///
/// A node is checked before its children, and children are visited in order,
/// so the first `ZarrExec` encountered wins. Returns `None` when the plan
/// contains no `ZarrExec`.
fn find_zarr_exec_stats(plan: &Arc<dyn ExecutionPlan>) -> Option<SharedIoStats> {
    match plan.as_any().downcast_ref::<ZarrExec>() {
        Some(zarr_exec) => Some(zarr_exec.io_stats()),
        None => plan
            .children()
            .into_iter()
            .find_map(|child| find_zarr_exec_stats(child)),
    }
}
/// Prints a one-line query summary, preceded by a blank line.
///
/// The line always contains the row count and the elapsed seconds (three
/// decimals). When `io_stats` is provided it also includes the number of
/// arrays read (coordinate + data) and the disk and in-memory byte totals.
/// Segments are joined with " · ".
fn print_stats_line(row_count: usize, elapsed_secs: f64, io_stats: Option<&SharedIoStats>) {
    let mut segments: Vec<String> = Vec::new();

    let row_suffix = if row_count == 1 { "" } else { "s" };
    segments.push(format!("{row_count} row{row_suffix}"));

    if let Some(stats) = io_stats {
        let array_count = stats.coord_arrays.load(Ordering::Relaxed)
            + stats.data_arrays.load(Ordering::Relaxed);
        let on_disk = stats.total_disk_bytes();
        let in_memory = stats.total_bytes();
        let array_suffix = if array_count == 1 { "" } else { "s" };
        segments.push(format!("{array_count} array{array_suffix}"));
        segments.push(format!("{} disk", format_bytes(on_disk)));
        segments.push(format!("{} mem", format_bytes(in_memory)));
    }

    segments.push(format!("{elapsed_secs:.3}s"));
    println!("\n{}", segments.join(" · "));
}
/// Spawns a tokio task that redraws a progress line on stdout until `stop` is set.
///
/// The line shows the running number of arrays read and the bytes read from
/// disk. It is rewritten in place every 50 ms: each redraw starts with `\r`
/// and ends with `\x1b[K` (erase to end of line). The only await point is the
/// sleep between redraws.
fn spawn_live_stats(stats: SharedIoStats, stop: Arc<AtomicBool>) -> tokio::task::JoinHandle<()> {
    const REFRESH_INTERVAL: Duration = Duration::from_millis(50);

    let ticker = async move {
        loop {
            if stop.load(Ordering::Relaxed) {
                break;
            }
            let array_count = stats.coord_arrays.load(Ordering::Relaxed)
                + stats.data_arrays.load(Ordering::Relaxed);
            let suffix = if array_count == 1 { "" } else { "s" };
            print!(
                "\r{} array{} · {} disk...\x1b[K",
                array_count,
                suffix,
                format_bytes(stats.total_disk_bytes())
            );
            let _ = io::stdout().flush();
            tokio::time::sleep(REFRESH_INTERVAL).await;
        }
    };
    tokio::spawn(ticker)
}