#![feature(div_duration)]
use std::fs::File;
use std::sync::Arc;
use std::time::{Duration, Instant};
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use clap::Parser;
use humantime::format_duration;
use itertools::Itertools;
use minitrace::prelude::*;
use risinglight::array::{datachunk_to_sqllogictest_string, Chunk};
use risinglight::storage::SecondaryStorageOptions;
use risinglight::utils::time::RoundingDuration;
use risinglight::Database;
use rustyline::error::ReadlineError;
use rustyline::Editor;
use tokio::{select, signal};
use tracing::{info, warn, Level};
use tracing_subscriber::prelude::*;
// Command-line arguments of the RisingLight CLI.
// NOTE: plain `//` comments are used on purpose: `///` doc comments would be
// picked up by clap's derive and change the generated `--help` text.
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
    // Script to execute instead of starting the interactive REPL. Files ending
    // in `.slt` are run as sqllogictest; anything else is run as plain SQL
    // (with a warning if the suffix is not `.sql`).
    #[clap(short, long)]
    file: Option<String>,
    // Use the in-memory engine instead of the on-disk secondary storage engine.
    #[clap(long)]
    memory: bool,
    // How result chunks are printed: `human` (also used when unset) or `text`.
    // Any other value makes `print_chunk` panic.
    #[clap(long)]
    output_format: Option<String>,
    // Run each statement under a minitrace root span and print the collected span records.
    #[clap(long)]
    enable_tracing: bool,
    // Install `console_subscriber` (tokio-console) instead of the fmt/EnvFilter tracing subscriber.
    #[clap(long)]
    tokio_console: bool,
}
/// Prints one result `chunk` to stdout according to `output_format`.
///
/// * `None` / `Some("human")`: chunks whose first header column is one of the
///   special tags `$insert.row_counts`, `$delete.row_counts`, `$create`,
///   `$drop` or `$explain` are rendered as short status messages; every other
///   chunk (including one with a missing or empty header) is printed with its
///   `Display` implementation.
/// * `Some("text")`: sqllogictest-style output, one row per line with the
///   values separated by commas.
///
/// # Panics
///
/// Panics if `output_format` is any other value.
fn print_chunk(chunk: &Chunk, output_format: &Option<String>) {
    match output_format.as_deref() {
        Some("human") | None => {
            // Status-style results carry their payload in the first cell of the
            // first data chunk.
            let first_value = || chunk.get_first_data_chunk().array_at(0).get_to_string(0);
            match chunk.header() {
                // `first()` instead of `[0]`: an empty header falls back to the
                // generic rendering rather than panicking on the index.
                Some(header) => match header.first().map(|name| name.as_str()) {
                    Some("$insert.row_counts") => println!("{} rows inserted", first_value()),
                    Some("$delete.row_counts") => println!("{} rows deleted", first_value()),
                    Some("$create") => println!("created"),
                    Some("$drop") => println!("dropped"),
                    Some("$explain") => println!("{}", first_value()),
                    _ => println!("{}", chunk),
                },
                None => println!("{}", chunk),
            }
        }
        Some("text") => println!(
            "{}",
            datachunk_to_sqllogictest_string(chunk)
                .iter()
                .format_with("\n", |row, f| f(&row.iter().format(","))),
        ),
        Some(format) => panic!("unsupported output format: {}", format),
    }
}
/// Prints the time elapsed since `start_time` in seconds with millisecond
/// precision (e.g. `in 0.012s`).
///
/// When more than one second has elapsed, the line also includes, in
/// parentheses, the `humantime` rendering of the duration after
/// `round_to_seconds()`.
fn print_execution_time(start_time: Instant) {
    let duration = start_time.elapsed();
    // `as_secs_f64` (stable since Rust 1.38) is the idiomatic way to get
    // fractional seconds and does not need the `div_duration` feature gate.
    let duration_in_seconds = duration.as_secs_f64();
    // Compare `Duration`s directly rather than the converted float.
    if duration > Duration::from_secs(1) {
        println!(
            "in {:.3}s ({})",
            duration_in_seconds,
            format_duration(duration.round_to_seconds())
        );
    } else {
        println!("in {:.3}s", duration_in_seconds);
    }
}
/// Runs `sql` against `db` and prints the outcome, racing it against Ctrl-C.
///
/// On success every returned chunk is printed with `print_chunk`, followed by
/// the elapsed time; on failure the error is printed. If Ctrl-C arrives first,
/// "Interrupted" is printed and the in-flight query future is dropped.
/// With `enable_tracing`, the query runs under a minitrace root span and the
/// collected span records are dumped (whether or not the query succeeded).
async fn run_query_in_background(
    db: Arc<Database>,
    sql: String,
    output_format: Option<String>,
    enable_tracing: bool,
) {
    let started = Instant::now();
    let query = async move {
        if !enable_tracing {
            return db.run(&sql).await;
        }
        let (root, collector) = Span::root("root");
        let outcome = db.run(&sql).in_span(root).await;
        let records: Vec<SpanRecord> = collector.collect().await;
        println!("{records:#?}");
        outcome
    };
    select! {
        _ = signal::ctrl_c() => println!("Interrupted"),
        outcome = query => match outcome {
            Err(err) => println!("{}", err),
            Ok(chunks) => {
                chunks
                    .iter()
                    .for_each(|chunk| print_chunk(chunk, &output_format));
                print_execution_time(started);
            }
        },
    }
}
/// Reads one command from the interactive prompt.
///
/// Lines are accumulated (joined with `\n`) until one ends with `;`; trailing
/// whitespace after the semicolon is ignored. A line starting with `\` typed at
/// the primary prompt (`> `) is returned immediately as a meta-command.
/// Blank or whitespace-only lines are skipped. Any readline error (Ctrl-C,
/// EOF, ...) is propagated, and the partially typed statement is discarded.
fn read_sql(rl: &mut Editor<()>) -> Result<String, ReadlineError> {
    let mut sql = String::new();
    loop {
        let prompt = if sql.is_empty() { "> " } else { "? " };
        let line = rl.readline(prompt)?;
        // Whitespace-only lines carry no SQL. Appending them would switch the
        // prompt to "? " and hide a following `\` meta-command.
        if line.trim().is_empty() {
            continue;
        }
        if sql.is_empty() && line.starts_with('\\') {
            return Ok(line);
        }
        sql.push_str(&line);
        // Trim first so `select 1;  ` counts as terminated instead of leaving
        // the statement open indefinitely.
        if line.trim_end().ends_with(';') {
            return Ok(sql);
        }
        sql.push('\n');
    }
}
/// Runs the interactive REPL until EOF or an unrecoverable readline error.
///
/// Command history is loaded from and saved to `risinglight/history.txt` under
/// the directory returned by `dirs::cache_dir()`, creating the directory and an
/// empty file first (best effort, errors ignored). Each non-blank command is
/// added to the history and executed with `run_query_in_background`. Ctrl-C at
/// the prompt only prints "Interrupted" and the loop continues.
async fn interactive(
    db: Database,
    output_format: Option<String>,
    enable_tracing: bool,
) -> Result<()> {
    let mut rl = Editor::<()>::new()?;

    let history_path = dirs::cache_dir().map(|cache_root| {
        let dir = cache_root.join("risinglight");
        std::fs::create_dir_all(&dir).ok();
        let file = dir.join("history.txt");
        if !file.exists() {
            File::create(&file).ok();
        }
        file.into_boxed_path()
    });
    if let Some(path) = history_path.as_deref() {
        if let Err(err) = rl.load_history(path) {
            println!("No previous history. {err}");
        }
    }

    let db = Arc::new(db);
    loop {
        match read_sql(&mut rl) {
            Ok(sql) if sql.trim().is_empty() => {}
            Ok(sql) => {
                rl.add_history_entry(sql.as_str());
                run_query_in_background(
                    Arc::clone(&db),
                    sql,
                    output_format.clone(),
                    enable_tracing,
                )
                .await;
            }
            Err(ReadlineError::Interrupted) => println!("Interrupted"),
            Err(ReadlineError::Eof) => {
                println!("Exited");
                break;
            }
            Err(err) => {
                println!("Error: {:?}", err);
                break;
            }
        }
    }

    if let Some(path) = history_path.as_deref() {
        if let Err(err) = rl.save_history(path) {
            println!("Save history failed, {err}");
        }
    }
    Ok(())
}
async fn run_sql(
db: Database,
path: &str,
output_format: Option<String>,
enable_tracing: bool,
) -> Result<()> {
let lines = std::fs::read_to_string(path)?;
info!("{}", lines);
let chunks = if enable_tracing {
let (root, collector) = Span::root("root");
let chunk = db.run(&lines).in_span(root).await?;
let records: Vec<SpanRecord> = collector.collect().await;
println!("{records:#?}");
chunk
} else {
db.run(&lines).await?
};
for chunk in chunks {
print_chunk(&chunk, &output_format);
}
Ok(())
}
/// Adapter that lets the `sqllogictest` runner drive a RisingLight `Database`.
struct DatabaseWrapper {
    // Database every test record is executed against.
    db: Database,
    // Format handed to `print_chunk` when echoing each statement's result.
    output_format: Option<String>,
    // When set, each statement runs under a minitrace root span and the span records are printed.
    enable_tracing: bool,
}
#[async_trait]
impl sqllogictest::AsyncDB for DatabaseWrapper {
    type Error = risinglight::Error;

    /// Executes one sqllogictest record and converts its result for the runner.
    ///
    /// Every result chunk is also echoed with `print_chunk`. If no chunk carries
    /// data, the record is reported as an empty row set when it looks like a
    /// query (starts with `select`, `values`, `show`, `with` or `describe`,
    /// case-insensitively) and as `StatementComplete(0)` otherwise. Column
    /// types are always reported as `ColumnType::Any`.
    async fn run(&mut self, sql: &str) -> Result<sqllogictest::DBOutput, Self::Error> {
        use sqllogictest::{ColumnType, DBOutput};

        let is_query_sql = {
            let lower_sql = sql.trim_start().to_ascii_lowercase();
            lower_sql.starts_with("select")
                || lower_sql.starts_with("values")
                || lower_sql.starts_with("show")
                || lower_sql.starts_with("with")
                || lower_sql.starts_with("describe")
        };
        info!("{}", sql);
        let result = if self.enable_tracing {
            let (root, collector) = Span::root("root");
            let result = self.db.run(sql).in_span(root).await;
            // Collect before propagating errors so failing statements still get traced.
            let records: Vec<SpanRecord> = collector.collect().await;
            println!("{records:#?}");
            result
        } else {
            self.db.run(sql).await
        };
        let chunks = result?;
        for chunk in &chunks {
            print_chunk(chunk, &self.output_format);
        }

        // Take the column count from the first chunk that actually carries data.
        // Indexing `chunks[0]` could land on a chunk with no data while a later
        // one has some (assumes `get_first_data_chunk` needs at least one).
        let first_with_data = match chunks.iter().find(|c| !c.data_chunks().is_empty()) {
            Some(chunk) => chunk,
            None if is_query_sql => {
                return Ok(DBOutput::Rows {
                    types: vec![],
                    rows: vec![],
                });
            }
            None => return Ok(DBOutput::StatementComplete(0)),
        };
        let types = vec![ColumnType::Any; first_with_data.get_first_data_chunk().column_count()];
        let rows = chunks
            .iter()
            .flat_map(datachunk_to_sqllogictest_string)
            .collect();
        Ok(DBOutput::Rows { types, rows })
    }
}
/// Runs the sqllogictest script at `path` against `db`.
///
/// Each statement's result is echoed via `DatabaseWrapper` in `output_format`.
/// Any runner failure is turned into an `anyhow::Error` that carries the
/// failure's `Debug` representation.
async fn run_sqllogictest(
    db: Database,
    path: &str,
    output_format: Option<String>,
    enable_tracing: bool,
) -> Result<()> {
    let wrapper = DatabaseWrapper {
        db,
        output_format,
        enable_tracing,
    };
    let mut runner = sqllogictest::Runner::new(wrapper);
    if let Err(err) = runner.run_file_async(path.to_string()).await {
        return Err(anyhow!("{:?}", err));
    }
    Ok(())
}
/// CLI entry point: installs logging, opens the database and then either runs
/// the script given by `--file` or starts the interactive REPL.
#[tokio::main]
async fn main() -> Result<()> {
    let args = Args::parse();

    if args.tokio_console {
        console_subscriber::init();
    } else {
        let filter = tracing_subscriber::EnvFilter::from_default_env()
            .add_directive(Level::INFO.into())
            .add_directive("egg=warn".parse()?);
        tracing_subscriber::registry()
            .with(filter)
            .with(tracing_subscriber::fmt::layer().compact())
            .init();
    }

    info!("using query engine v2. type '\\v1' to use the legacy engine");
    let db = if args.memory {
        info!("using memory engine");
        Database::new_in_memory()
    } else {
        info!("using Secondary engine");
        Database::new_on_disk(SecondaryStorageOptions::default_for_cli()).await
    };

    match args.file {
        Some(file) if file.ends_with(".slt") => {
            run_sqllogictest(db, &file, args.output_format, args.enable_tracing).await?
        }
        Some(file) => {
            // Anything that is not `.slt` runs as plain SQL; only warn when the
            // suffix is not the expected `.sql`.
            if !file.ends_with(".sql") {
                warn!("No suffix detected, assume sql file");
            }
            run_sql(db, &file, args.output_format, args.enable_tracing).await?
        }
        None => interactive(db, args.output_format, args.enable_tracing).await?,
    }
    Ok(())
}