use clap::Parser;
use comfy_table::{Attribute, Cell, Color, ContentArrangement, Table, presets::UTF8_FULL_CONDENSED};
use dbcli::SqlResult;
use dbcli::column_info::ColumnBaseInfo;
use rustyline::DefaultEditor;
use rustyline::error::ReadlineError;
use std::io::IsTerminal;
// Command-line interface, parsed by clap's derive API.
//
// NOTE: fields deliberately carry plain `//` comments instead of `///` doc
// comments — clap turns doc comments into per-flag help text, which would
// change the program's --help output.
#[derive(Parser)]
#[command(
    name = "dbcli",
    version,
    about = "Execute SQL and output results as JSON or formatted table",
    long_about = "Execute one or more SQL statements against MySQL, PostgreSQL, or SQLite.\n\
Output is auto-detected: table when stdout is a TTY, JSON when piped.\n\
Use --json or --table to override.",
    after_help = "EXAMPLES:\n \
dbcli -u postgres://user:pass@localhost/mydb -s \"SELECT * FROM users\"\n \
dbcli --connect (test database connection only)\n \
dbcli (enter interactive REPL mode)\n \
echo \"SELECT 1\" | dbcli -u sqlite://test.db --json"
)]
struct Cli {
    // Database URL; falls back to the DATABASE_URL environment variable.
    #[arg(short, long, env = "DATABASE_URL")]
    url: String,
    // SQL to execute once; when omitted, the tool enters the interactive REPL.
    #[arg(short, long)]
    sql: Option<String>,
    // Only test the connection, then exit.
    #[arg(long, visible_alias = "conn", visible_alias = "con")]
    connect: bool,
    // Force JSON output regardless of TTY detection.
    #[arg(long)]
    json: bool,
    // Force table output regardless of TTY detection.
    #[arg(long)]
    table: bool,
    // Maximum rows shown in table mode (0 = unlimited; see print_table).
    #[arg(short, long, default_value_t = 1000)]
    limit: usize,
    // Maximum table cell width in characters (0 = auto from terminal width).
    #[arg(long, visible_alias = "mcw", default_value_t = 0)]
    max_col_width: usize,
}
/// How query results are rendered on stdout.
enum OutputMode {
    /// Human-readable table (default when stdout is a TTY).
    Table,
    /// Pretty-printed JSON (default when stdout is piped).
    Json,
}
/// Pick the output mode: explicit flags win (--json over --table when both
/// are given), otherwise table for a TTY and JSON for a pipe.
fn detect_output_mode(cli: &Cli) -> OutputMode {
    match (cli.json, cli.table) {
        (true, _) => OutputMode::Json,
        (false, true) => OutputMode::Table,
        _ if std::io::stdout().is_terminal() => OutputMode::Table,
        _ => OutputMode::Json,
    }
}
/// True when any of LC_ALL / LANG / LANGUAGE selects a Chinese locale.
///
/// A variable counts when it is set and starts with "zh" (case-insensitive);
/// unset variables are simply skipped.
fn is_chinese_locale() -> bool {
    ["LC_ALL", "LANG", "LANGUAGE"].iter().any(|name| {
        std::env::var(name)
            .map(|value| value.to_lowercase().starts_with("zh"))
            .unwrap_or(false)
    })
}

/// All user-facing strings, localized once at startup to Chinese or English.
struct Messages {
    // True when a Chinese locale was detected from the environment.
    chinese: bool,
}

impl Messages {
    /// Detect the locale from the environment and pick the language.
    fn new() -> Self {
        Self {
            chinese: is_chinese_locale(),
        }
    }

    fn connection_successful(&self) -> &str {
        match self.chinese {
            true => "连接成功。",
            false => "Connection successful.",
        }
    }

    fn backend_label(&self) -> &str {
        match self.chinese {
            true => "数据库后端",
            false => "Backend",
        }
    }

    fn repl_welcome(&self) -> &str {
        match self.chinese {
            true => "dbcli 交互模式。输入 SQL 语句以 `;` 结尾执行。",
            false => "dbcli interactive mode. Type SQL followed by `;` to execute.",
        }
    }

    fn repl_quit_hint(&self) -> &str {
        match self.chinese {
            true => "输入 exit、quit 或 \\q 退出。",
            false => "Commands: exit, quit, \\q to quit.",
        }
    }

    fn supported_backends(&self) -> &str {
        match self.chinese {
            true => "支持的后端",
            false => "Supported backends",
        }
    }

    fn no_backend_warning(&self) -> &str {
        match self.chinese {
            true => "警告:未启用任何数据库后端。请在编译时启用 mysql、postgres、sqlite 或 odbc feature。",
            false => "Warning: No database backend enabled. Enable mysql, postgres, sqlite, or odbc feature.",
        }
    }

    fn affected_rows(&self, n: u64) -> String {
        match self.chinese {
            true => format!("影响 {} 行", n),
            false => format!("Affected {} rows", n),
        }
    }

    fn rows_count(&self, n: usize) -> String {
        match self.chinese {
            true => format!("({} 行)", n),
            false => format!("({} rows)", n),
        }
    }

    fn error_prefix(&self) -> &str {
        match self.chinese {
            true => "错误",
            false => "Error",
        }
    }

    fn connecting(&self) -> &str {
        match self.chinese {
            true => "正在连接数据库...",
            false => "Connecting to database...",
        }
    }

    fn connection_failed(&self) -> &str {
        match self.chinese {
            true => "连接失败",
            false => "Connection failed",
        }
    }

    fn rows_truncated(&self, shown: usize, total: usize) -> String {
        match self.chinese {
            true => format!("... 已省略 {} 行(共 {} 行,使用 --limit 调整或 --json 获取全部)", total - shown, total),
            false => format!("... {} more rows omitted ({} total, use --limit to adjust or --json for all)", total - shown, total),
        }
    }
}
/// Case-fold the process arguments so flags are case-insensitive
/// (see `normalize_arg_list` for the exact rules).
fn normalize_args() -> Vec<String> {
    normalize_arg_list(std::env::args().collect())
}

/// Fold flag tokens to lowercase without ever touching user-supplied values.
///
/// BUG FIX: the previous version lowercased *any* argument starting with
/// `-`, which corrupted attached short-option values such as
/// `-uPostgres://User:Secret@host/db` or `-sSELECT …`. Now only `--long`
/// flags (with an optional `=value` kept intact) and bare two-character
/// short flags are folded; everything else passes through unchanged.
/// Arguments following value-taking flags are skipped entirely.
fn normalize_arg_list(mut args: Vec<String>) -> Vec<String> {
    let mut i = 1;
    while i < args.len() {
        if args[i].starts_with('-') {
            args[i] = normalize_flag(&args[i]);
            // These flags consume the next argument as a value (URL, SQL,
            // number); skip it so it is never case-folded.
            if matches!(
                args[i].as_str(),
                "-u" | "--url" | "-s" | "--sql" | "-l" | "--limit" | "--max-col-width" | "--mcw"
            ) {
                i += 1;
            }
        }
        i += 1;
    }
    args
}

/// Lowercase a single flag token, preserving any attached value.
///
/// `--FLAG` / `--FLAG=Value` fold only the flag name; a bare two-character
/// short flag (`-U`) folds too, with `-v`/`-V` mapped to clap's `-V`
/// (version). Anything else is returned unchanged.
fn normalize_flag(arg: &str) -> String {
    let (flag, value) = match arg.find('=') {
        Some(eq_pos) => arg.split_at(eq_pos),
        None => (arg, ""),
    };
    if flag.starts_with("--") {
        format!("{}{}", flag.to_lowercase(), value)
    } else if flag.len() == 2 {
        let folded = flag.to_lowercase();
        if folded == "-v" && value.is_empty() {
            // clap's version flag is uppercase -V; accept -v as an alias.
            "-V".to_string()
        } else {
            format!("{}{}", folded, value)
        }
    } else {
        // Short flag with an attached value (e.g. -uURL): leave untouched.
        arg.to_string()
    }
}
/// Effective per-cell character cap for table rendering.
///
/// A non-zero `max_col_width` is used verbatim; 0 means "auto": half the
/// terminal width (120 assumed when the size probe fails), floored at 30.
fn max_cell_width(max_col_width: usize) -> usize {
    if max_col_width != 0 {
        return max_col_width;
    }
    let term_width = match crossterm::terminal::size() {
        Ok((w, _)) => w as usize,
        Err(_) => 120,
    };
    std::cmp::max(term_width / 2, 30)
}
/// Render a single JSON value for table display.
///
/// Missing and explicit-null values both become the literal "NULL"; strings
/// are shown without quotes; numbers, booleans and compound values use their
/// JSON representation (identical to the per-type to_string output).
fn format_cell(value: Option<&serde_json::Value>) -> String {
    match value {
        None | Some(serde_json::Value::Null) => "NULL".to_string(),
        Some(serde_json::Value::String(s)) => s.clone(),
        Some(other) => other.to_string(),
    }
}
/// Cap a cell's display text at `max_width` characters, appending "..."
/// when truncated. A `max_width` of 0 disables truncation. Width is counted
/// in Unicode scalar values, not bytes, so multibyte text is never split.
fn truncate_cell(s: &str, max_width: usize) -> String {
    if max_width == 0 {
        return s.to_string();
    }
    if s.chars().count() <= max_width {
        return s.to_string();
    }
    // Reserve room for the ellipsis inside the cap.
    let keep = max_width.saturating_sub(3);
    let mut out: String = s.chars().take(keep).collect();
    out.push_str("...");
    out
}
/// Render query rows as a UTF-8 table on stdout.
///
/// Shows at most `limit` rows (0 = unlimited); the truncation notice goes to
/// stderr so piped stdout stays clean. `max_col_width` of 0 means "auto"
/// (see `max_cell_width`). Rows that are not JSON objects are skipped.
fn print_table(data: &[serde_json::Value], columns: &[ColumnBaseInfo], msg: &Messages, limit: usize, max_col_width: usize) {
    if columns.is_empty() {
        // Statement produced no result columns; just report zero rows.
        println!("{}", msg.rows_count(0));
        return;
    }
    let display_data = if limit > 0 && data.len() > limit {
        &data[..limit]
    } else {
        data
    };
    let term_width = crossterm::terminal::size()
        .map(|(w, _)| w as usize)
        .unwrap_or(120);
    let cell_max = max_cell_width(max_col_width);
    let mut table = Table::new();
    table
        .load_preset(UTF8_FULL_CONDENSED)
        .set_content_arrangement(ContentArrangement::Dynamic)
        .set_width(term_width as u16);
    // Header: a row-number column followed by the query's column names.
    let mut header_cells = vec![
        Cell::new("#")
            .fg(Color::DarkGrey)
            .add_attribute(Attribute::Bold),
    ];
    for col in columns {
        header_cells.push(
            Cell::new(&col.name)
                .fg(Color::Cyan)
                .add_attribute(Attribute::Bold),
        );
    }
    table.set_header(header_cells);
    for (idx, row) in display_data.iter().enumerate() {
        if let Some(obj) = row.as_object() {
            // 1-based row numbers, dimmed like the header's "#".
            let mut cells = vec![Cell::new(idx + 1).fg(Color::DarkGrey)];
            for col in columns.iter() {
                let raw = obj.get(&col.name);
                // BUG FIX: decide NULL styling from the JSON value itself,
                // not from the formatted text — previously a string column
                // literally containing "NULL" was dimmed as if it were null.
                let is_null = matches!(raw, None | Some(serde_json::Value::Null));
                let val = truncate_cell(&format_cell(raw), cell_max);
                let cell = if is_null {
                    Cell::new(val).fg(Color::DarkGrey)
                } else {
                    Cell::new(val)
                };
                cells.push(cell);
            }
            table.add_row(cells);
        }
    }
    println!("{table}");
    if limit > 0 && data.len() > limit {
        eprintln!("{}", msg.rows_truncated(limit, data.len()));
    }
    println!("{}", msg.rows_count(data.len()));
}
/// Serialize all statement results as pretty-printed JSON on stdout.
/// A serialization failure is reported as a small JSON error object instead
/// of panicking.
fn print_json(results: &[SqlResult]) {
    let rendered = match serde_json::to_string_pretty(results) {
        Ok(json) => json,
        Err(e) => format!("{{\"error\": \"{}\"}}", e),
    };
    println!("{}", rendered);
}
/// Print every statement result in the selected output mode.
///
/// JSON mode emits one document covering all results; table mode prints one
/// table (or an affected-row count) per statement, blank-line separated.
fn print_results(results: &[SqlResult], mode: &OutputMode, msg: &Messages, limit: usize, max_col_width: usize) {
    if let OutputMode::Json = mode {
        print_json(results);
        return;
    }
    for (i, result) in results.iter().enumerate() {
        // Blank line between consecutive statement results.
        if i > 0 {
            println!();
        }
        match result {
            SqlResult::Query { data, columns } => {
                print_table(data, columns, msg, limit, max_col_width)
            }
            SqlResult::Execute { rows_affected } => {
                println!("{}", msg.affected_rows(*rows_affected))
            }
        }
    }
}
/// Decide whether `url` should be handled by the ODBC backend and, if so,
/// produce the connection string to use.
///
/// Anything containing `driver=` (case-insensitive) is treated as a raw
/// ODBC connection string and returned verbatim. An `odbc://` URL has the
/// prefix stripped; a bare path after the prefix gets a driver inferred
/// from its file extension. Everything else returns None.
#[cfg(feature = "odbc")]
fn detect_odbc_connection_string(url: &str) -> Option<String> {
    let lowered = url.to_lowercase();
    if lowered.contains("driver=") {
        return Some(url.to_owned());
    }
    if !lowered.starts_with("odbc://") {
        return None;
    }
    let remainder = &url["odbc://".len()..];
    if remainder.to_lowercase().contains("driver=") {
        Some(remainder.to_owned())
    } else {
        // Bare path after odbc:// — derive a driver from the extension.
        Some(odbc_connection_string_from_path(remainder))
    }
}
/// Build an ODBC connection string for a bare file path based on extension.
///
/// Excel workbooks (.xls/.xlsx/.xlsb/.xlsm) get the Excel driver; anything
/// else — including .mdb/.accdb — gets the Access driver. The previous
/// version had the .mdb/.accdb branch and the fallback branch duplicated
/// byte-for-byte; they are collapsed here (behavior unchanged).
#[cfg(feature = "odbc")]
fn odbc_connection_string_from_path(path: &str) -> String {
    let path_lower = path.to_lowercase();
    let excel_exts = [".xls", ".xlsx", ".xlsb", ".xlsm"];
    if excel_exts.iter().any(|ext| path_lower.ends_with(ext)) {
        format!(
            "Driver={{Microsoft Excel Driver (*.xls, *.xlsx, *.xlsm, *.xlsb)}};DBQ={};ReadOnly=0;",
            path
        )
    } else {
        // Access handles .mdb/.accdb and doubles as the default driver for
        // unrecognized extensions.
        format!(
            "Driver={{Microsoft Access Driver (*.mdb, *.accdb)}};DBQ={};",
            path
        )
    }
}
/// Interactive REPL over an ODBC connection string.
///
/// Lines accumulate into one statement until a line ends with `;`, which
/// executes it. `exit`, `quit` or `\q` (only at the start of a statement)
/// leave the loop; Ctrl-C clears a partial statement or quits when the
/// buffer is empty; Ctrl-D quits. History is persisted via
/// `history_file_path`.
#[cfg(feature = "odbc")]
async fn run_repl_odbc(
    connection_string: &str,
    mode: &OutputMode,
    msg: &Messages,
    limit: usize,
    max_col_width: usize,
) -> anyhow::Result<()> {
    let mut rl = DefaultEditor::new()?;
    let history_file = history_file_path();
    // A missing history file on first run is expected; ignore load errors.
    let _ = rl.load_history(&history_file);
    eprintln!("{}", msg.repl_welcome());
    eprintln!("{}", msg.repl_quit_hint());
    let mut backends = Vec::new();
    backends.push("ODBC");
    eprintln!("{}: {}", msg.supported_backends(), backends.join(", "));
    eprintln!();
    // Accumulates a possibly multi-line statement until a trailing `;`.
    let mut sql_buffer = String::new();
    loop {
        // Continuation lines use a distinct prompt.
        let prompt = if sql_buffer.is_empty() {
            "dbcli> "
        } else {
            " -> "
        };
        match rl.readline(prompt) {
            Ok(line) => {
                let trimmed = line.trim();
                // Quit commands are recognized only at the start of a statement.
                if sql_buffer.is_empty() {
                    match trimmed.to_lowercase().as_str() {
                        "exit" | "quit" | "\\q" | "exit;" | "quit;" => break,
                        "" => continue,
                        _ => {}
                    }
                }
                sql_buffer.push_str(trimmed);
                sql_buffer.push('\n');
                // A trailing semicolon completes and executes the statement.
                if trimmed.ends_with(';') {
                    let sql = sql_buffer.trim().to_string();
                    sql_buffer.clear();
                    let _ = rl.add_history_entry(&sql);
                    match dbcli::execute::odbc::execute_raw_sql(connection_string, &sql).await {
                        Ok(results) => {
                            print_results(&results, mode, msg, limit, max_col_width);
                            eprintln!();
                        }
                        Err(e) => {
                            // Statement errors keep the REPL alive.
                            eprintln!("{}: {}", msg.error_prefix(), e);
                            eprintln!();
                        }
                    }
                }
            }
            // Ctrl-C: discard a partial statement, or quit if none pending.
            Err(ReadlineError::Interrupted) => {
                if !sql_buffer.is_empty() {
                    sql_buffer.clear();
                    eprintln!();
                } else {
                    break;
                }
            }
            // Ctrl-D always quits.
            Err(ReadlineError::Eof) => {
                eprintln!();
                break;
            }
            Err(err) => {
                eprintln!("{}: {}", msg.error_prefix(), err);
                break;
            }
        }
    }
    // Best-effort save; a read-only home directory should not abort the run.
    let _ = rl.save_history(&history_file);
    Ok(())
}
/// A connected sqlx pool, tagged by backend. Variants exist only for the
/// backends compiled in via cargo features.
#[cfg(any(feature = "postgres", feature = "mysql", feature = "sqlite"))]
enum DbPool {
    #[cfg(feature = "postgres")]
    Postgres(sqlx::postgres::PgPool),
    #[cfg(feature = "mysql")]
    Mysql(sqlx::mysql::MySqlPool),
    #[cfg(feature = "sqlite")]
    Sqlite(sqlx::sqlite::SqlitePool),
}
/// Open a single-connection pool for the backend selected by URL scheme.
///
/// postgres:// / postgresql:// go to PostgreSQL, mysql:// to MySQL. When
/// the sqlite feature is on, every remaining URL falls through to SQLite:
/// a bare path (no `scheme://`) gets a sqlite:// prefix, an unknown
/// `scheme://` is rejected. The trailing bail is only reachable when the
/// sqlite feature is compiled out.
#[cfg(any(feature = "postgres", feature = "mysql", feature = "sqlite"))]
async fn connect(url: &str) -> anyhow::Result<DbPool> {
    #[cfg(feature = "postgres")]
    if url.starts_with("postgres://") || url.starts_with("postgresql://") {
        let pool = sqlx::postgres::PgPoolOptions::new()
            .max_connections(1)
            .connect(url)
            .await?;
        return Ok(DbPool::Postgres(pool));
    }
    #[cfg(feature = "mysql")]
    if url.starts_with("mysql://") {
        let pool = sqlx::mysql::MySqlPoolOptions::new()
            .max_connections(1)
            .connect(url)
            .await?;
        return Ok(DbPool::Mysql(pool));
    }
    #[cfg(feature = "sqlite")]
    {
        // SQLite acts as the catch-all backend: accept sqlite:// URLs and
        // bare file paths; any other scheme is an error.
        let sqlite_url = if url.starts_with("sqlite://") {
            url.to_string()
        } else if !url.contains("://") {
            format!("sqlite://{}", url)
        } else {
            anyhow::bail!("Unsupported database URL scheme: {}", url);
        };
        let pool = sqlx::sqlite::SqlitePoolOptions::new()
            .max_connections(1)
            .connect(&sqlite_url)
            .await?;
        return Ok(DbPool::Sqlite(pool));
    }
    // Reachable only when the sqlite feature is disabled and no other
    // backend matched the URL.
    #[allow(unreachable_code)]
    {
        anyhow::bail!(
            "Unsupported or not compiled database URL: {}. \
            Enable the corresponding feature (mysql, postgres, sqlite) when building.",
            url
        );
    }
}
/// Run raw SQL against whichever backend the pool wraps and collect the
/// per-statement results.
#[cfg(any(feature = "postgres", feature = "mysql", feature = "sqlite"))]
async fn execute_sql(pool: &DbPool, sql: &str) -> anyhow::Result<Vec<SqlResult>> {
    match pool {
        #[cfg(feature = "postgres")]
        DbPool::Postgres(p) => dbcli::execute::postgres::execute_raw_sql(p, sql).await,
        #[cfg(feature = "mysql")]
        DbPool::Mysql(p) => dbcli::execute::mysql::execute_raw_sql(p, sql).await,
        #[cfg(feature = "sqlite")]
        DbPool::Sqlite(p) => dbcli::execute::sqlite::execute_raw_sql(p, sql).await,
    }
}
/// Location of the REPL history file: `$HOME/.db-json_history`, falling
/// back to USERPROFILE (Windows) and finally the current directory.
fn history_file_path() -> std::path::PathBuf {
    let home = std::env::var("HOME")
        .or_else(|_| std::env::var("USERPROFILE"))
        .unwrap_or_else(|_| String::from("."));
    let mut path = std::path::PathBuf::from(home);
    path.push(".db-json_history");
    path
}
/// Interactive REPL over a sqlx connection pool.
///
/// Lines accumulate into one statement until a line ends with `;`, which
/// executes it. `exit`, `quit` or `\q` (only at the start of a statement)
/// leave the loop; Ctrl-C clears a partial statement or quits when the
/// buffer is empty; Ctrl-D quits. History is persisted via
/// `history_file_path`.
#[cfg(any(feature = "postgres", feature = "mysql", feature = "sqlite"))]
async fn run_repl(pool: &DbPool, mode: &OutputMode, msg: &Messages, limit: usize, max_col_width: usize) -> anyhow::Result<()> {
    let mut rl = DefaultEditor::new()?;
    let history_file = history_file_path();
    // A missing history file on first run is expected; ignore load errors.
    let _ = rl.load_history(&history_file);
    eprintln!("{}", msg.repl_welcome());
    eprintln!("{}", msg.repl_quit_hint());
    // Announce which backends this binary was compiled with.
    let mut backends = Vec::new();
    #[cfg(feature = "postgres")]
    backends.push("PostgreSQL");
    #[cfg(feature = "mysql")]
    backends.push("MySQL");
    #[cfg(feature = "sqlite")]
    backends.push("SQLite");
    if backends.is_empty() {
        eprintln!("{}", msg.no_backend_warning());
    } else {
        eprintln!("{}: {}", msg.supported_backends(), backends.join(", "));
    }
    eprintln!();
    // Accumulates a possibly multi-line statement until a trailing `;`.
    let mut sql_buffer = String::new();
    loop {
        // Continuation lines use a distinct prompt.
        let prompt = if sql_buffer.is_empty() {
            "dbcli> "
        } else {
            " -> "
        };
        match rl.readline(prompt) {
            Ok(line) => {
                let trimmed = line.trim();
                // Quit commands are recognized only at the start of a statement.
                if sql_buffer.is_empty() {
                    match trimmed.to_lowercase().as_str() {
                        "exit" | "quit" | "\\q" | "exit;" | "quit;" => break,
                        "" => continue,
                        _ => {}
                    }
                }
                sql_buffer.push_str(trimmed);
                sql_buffer.push('\n');
                // A trailing semicolon completes and executes the statement.
                if trimmed.ends_with(';') {
                    let sql = sql_buffer.trim().to_string();
                    sql_buffer.clear();
                    let _ = rl.add_history_entry(&sql);
                    match execute_sql(pool, &sql).await {
                        Ok(results) => {
                            print_results(&results, mode, msg, limit, max_col_width);
                            eprintln!();
                        }
                        Err(e) => {
                            // Statement errors keep the REPL alive.
                            eprintln!("{}: {}", msg.error_prefix(), e);
                            eprintln!();
                        }
                    }
                }
            }
            // Ctrl-C: discard a partial statement, or quit if none pending.
            Err(ReadlineError::Interrupted) => {
                if !sql_buffer.is_empty() {
                    sql_buffer.clear();
                    eprintln!();
                } else {
                    break;
                }
            }
            // Ctrl-D always quits.
            Err(ReadlineError::Eof) => {
                eprintln!();
                break;
            }
            Err(err) => {
                eprintln!("{}: {}", msg.error_prefix(), err);
                break;
            }
        }
    }
    // Best-effort save; a read-only home directory should not abort the run.
    let _ = rl.save_history(&history_file);
    Ok(())
}
/// Connect via sqlx and dispatch: connection test (--connect), one-shot SQL
/// (--sql), or the interactive REPL.
///
/// Exits the whole process with status 1 when the connection fails.
#[cfg(any(feature = "postgres", feature = "mysql", feature = "sqlite"))]
async fn run_sqlx(cli: &Cli, mode: &OutputMode, msg: &Messages) -> anyhow::Result<()> {
    eprintln!("{}", msg.connecting());
    let pool = match connect(&cli.url).await {
        Ok(pool) => {
            eprintln!("{}", msg.connection_successful());
            let backend = match &pool {
                #[cfg(feature = "postgres")]
                DbPool::Postgres(_) => "PostgreSQL",
                #[cfg(feature = "mysql")]
                DbPool::Mysql(_) => "MySQL",
                #[cfg(feature = "sqlite")]
                DbPool::Sqlite(_) => "SQLite",
            };
            eprintln!("{}: {}", msg.backend_label(), backend);
            pool
        }
        Err(e) => {
            eprintln!("{}: {}", msg.connection_failed(), e);
            // Connection failure is fatal for every mode, including --connect.
            std::process::exit(1);
        }
    };
    // --connect: only verify the connection, nothing else.
    if cli.connect {
        return Ok(());
    }
    if let Some(sql) = &cli.sql {
        // One-shot execution of the SQL passed on the command line.
        let results = execute_sql(&pool, sql).await?;
        print_results(&results, mode, msg, cli.limit, cli.max_col_width);
    } else {
        // No SQL given: drop into the interactive REPL.
        run_repl(&pool, mode, msg, cli.limit, cli.max_col_width).await?;
    }
    Ok(())
}
/// Entry logic: parse arguments, pick the output mode, route to a backend.
///
/// ODBC connection strings / odbc:// URLs (when compiled in) are handled
/// first and never reach the sqlx backends. The trailing bail is only
/// reachable when no sqlx backend feature was compiled in.
async fn run(msg: &Messages) -> anyhow::Result<()> {
    // Flags are case-folded up front so the CLI is case-insensitive.
    let args = normalize_args();
    let cli = Cli::parse_from(args);
    let mode = detect_output_mode(&cli);
    #[cfg(feature = "odbc")]
    if let Some(odbc_conn_str) = detect_odbc_connection_string(&cli.url) {
        eprintln!("{}", msg.connecting());
        // NOTE(review): "connection successful" is printed before any ODBC
        // connection is actually attempted — the real connect happens below
        // (or inside execute_raw_sql). Consider moving these messages.
        eprintln!("{}", msg.connection_successful());
        eprintln!("{}: ODBC", msg.backend_label());
        if cli.connect {
            let conn_str = odbc_conn_str.clone();
            // ODBC calls block; keep them off the async runtime's core threads.
            tokio::task::spawn_blocking(move || -> anyhow::Result<()> {
                let env = odbc_api::Environment::new().map_err(|e| anyhow::anyhow!("{}", e))?;
                let _conn = env
                    .connect_with_connection_string(&conn_str, odbc_api::ConnectionOptions::default())
                    .map_err(|e| anyhow::anyhow!("{}", e))?;
                Ok(())
            })
            .await??;
            return Ok(());
        }
        if let Some(sql) = &cli.sql {
            // One-shot execution over ODBC.
            let results = dbcli::execute::odbc::execute_raw_sql(&odbc_conn_str, sql).await?;
            print_results(&results, &mode, msg, cli.limit, cli.max_col_width);
        } else {
            // No SQL on the command line: interactive ODBC REPL.
            run_repl_odbc(&odbc_conn_str, &mode, msg, cli.limit, cli.max_col_width).await?;
        }
        return Ok(());
    }
    // Everything that is not ODBC goes through the sqlx backends.
    #[cfg(any(feature = "postgres", feature = "mysql", feature = "sqlite"))]
    {
        return run_sqlx(&cli, &mode, msg).await;
    }
    // Reachable only when no sqlx backend feature is enabled.
    #[allow(unreachable_code)]
    {
        anyhow::bail!(
            "Unsupported or not compiled database URL: {}. \
            Enable the corresponding feature (mysql, postgres, sqlite, odbc) when building.",
            cli.url
        );
    }
}
/// Process entry point: build the locale-aware message table once, run the
/// CLI, and exit with status 1 on any error (after printing it to stderr).
#[tokio::main]
async fn main() {
    let msg = Messages::new();
    match run(&msg).await {
        Ok(()) => {}
        Err(e) => {
            eprintln!("{}: {}", msg.error_prefix(), e);
            std::process::exit(1);
        }
    }
}