Skip to main content

citadel_cli/
lib.rs

1mod commands;
2mod formatter;
3mod helper;
4mod repl;
5
6use std::io::IsTerminal;
7use std::path::PathBuf;
8
9use clap::Parser;
10
11use crate::formatter::OutputMode;
12
13#[derive(Parser)]
14#[command(
15    name = "citadel",
16    about = "Interactive SQL shell for Citadel encrypted database"
17)]
18#[command(version)]
19struct Cli {
20    /// Path to database file
21    database: Option<PathBuf>,
22
23    /// SQL to execute (non-interactive mode)
24    sql: Option<String>,
25
26    /// Create a new database
27    #[arg(long)]
28    create: bool,
29
30    /// Passphrase (prompted if omitted)
31    #[arg(long)]
32    passphrase: Option<String>,
33
34    /// Output mode: box, table, csv, json, line
35    #[arg(long, default_value = "box")]
36    mode: String,
37
38    /// Show column headers
39    #[arg(long, default_value = "on")]
40    header: String,
41
42    /// NULL display string
43    #[arg(long, default_value = "NULL")]
44    null_value: String,
45
46    /// Disable colors
47    #[arg(long)]
48    no_color: bool,
49
50    /// Read/execute commands from FILE on startup
51    #[arg(long)]
52    init: Option<PathBuf>,
53
54    /// Execute TEXT before interactive input
55    #[arg(long)]
56    cmd: Option<String>,
57}
58
59/// Run the CLI with the given argv (argv[0] is the program name); returns exit code.
60pub fn run(args: Vec<String>) -> i32 {
61    let cli = match Cli::try_parse_from(args) {
62        Ok(cli) => cli,
63        Err(e) => {
64            let _ = e.print();
65            // clap sends help/version to stdout (exit 0), errors to stderr (exit 2).
66            return if e.use_stderr() { 2 } else { 0 };
67        }
68    };
69
70    let db_path = match &cli.database {
71        Some(p) => p.clone(),
72        None => {
73            eprintln!("Error: database path is required");
74            eprintln!("Usage: citadel [OPTIONS] <DATABASE> [SQL]");
75            return 1;
76        }
77    };
78
79    let passphrase = match &cli.passphrase {
80        Some(p) => p.clone(),
81        None => {
82            if !std::io::stdin().is_terminal() {
83                eprintln!("Error: passphrase required (use --passphrase in non-interactive mode)");
84                return 1;
85            }
86            match rpassword::prompt_password("Enter passphrase: ") {
87                Ok(p) => p,
88                Err(e) => {
89                    eprintln!("Error reading passphrase: {e}");
90                    return 1;
91                }
92            }
93        }
94    };
95
96    let db = if cli.create {
97        match citadel::DatabaseBuilder::new(&db_path)
98            .passphrase(passphrase.as_bytes())
99            .create()
100        {
101            Ok(db) => db,
102            Err(e) => {
103                eprintln!("Error creating database: {e}");
104                return 1;
105            }
106        }
107    } else {
108        match citadel::DatabaseBuilder::new(&db_path)
109            .passphrase(passphrase.as_bytes())
110            .open()
111        {
112            Ok(db) => db,
113            Err(e) => {
114                eprintln!("Error opening database: {e}");
115                return 1;
116            }
117        }
118    };
119
120    let output_mode = match cli.mode.as_str() {
121        "box" => OutputMode::Box,
122        "table" => OutputMode::Table,
123        "csv" => OutputMode::Csv,
124        "json" => OutputMode::Json,
125        "line" => OutputMode::Line,
126        other => {
127            eprintln!("Error: unknown output mode '{other}'. Use: box, table, csv, json, line");
128            return 1;
129        }
130    };
131
132    let is_interactive = cli.sql.is_none() && std::io::stdin().is_terminal();
133    let use_color = is_interactive && !cli.no_color;
134
135    let mut settings = repl::Settings {
136        mode: output_mode,
137        show_headers: cli.header != "off",
138        null_display: cli.null_value.clone(),
139        timer: false,
140        show_changes: false,
141        use_color,
142        column_widths: Vec::new(),
143        output_file: None,
144    };
145
146    if let Some(ref sql) = cli.sql {
147        return run_batch(&db, sql, &mut settings);
148    }
149
150    if !is_interactive {
151        return run_piped(&db, &mut settings);
152    }
153
154    repl::run_interactive(db, db_path, passphrase, settings, cli.init, cli.cmd);
155    0
156}
157
158fn run_batch(db: &citadel::Database, sql: &str, settings: &mut repl::Settings) -> i32 {
159    use std::time::Instant;
160
161    let conn = match citadel_sql::Connection::open(db) {
162        Ok(c) => c,
163        Err(e) => {
164            eprintln!("Error: {e}");
165            return 1;
166        }
167    };
168
169    let start = Instant::now();
170    match conn.execute(sql) {
171        Ok(result) => {
172            let output = formatter::format_result(&result, settings);
173            if !output.is_empty() {
174                settings.write_output(&output);
175            }
176            if settings.timer {
177                settings.write_output(&format!("Run Time: {:.3}s", start.elapsed().as_secs_f64()));
178            }
179            0
180        }
181        Err(e) => {
182            eprintln!("Error: {e}");
183            1
184        }
185    }
186}
187
188fn run_piped(db: &citadel::Database, settings: &mut repl::Settings) -> i32 {
189    use std::io::{self, BufRead};
190
191    let conn = match citadel_sql::Connection::open(db) {
192        Ok(c) => c,
193        Err(e) => {
194            eprintln!("Error: {e}");
195            return 1;
196        }
197    };
198
199    let mut buf = String::new();
200    let stdin = io::stdin();
201
202    for line in stdin.lock().lines() {
203        let line = match line {
204            Ok(l) => l,
205            Err(e) => {
206                eprintln!("Error reading stdin: {e}");
207                return 1;
208            }
209        };
210
211        let trimmed = line.trim();
212        if trimmed.is_empty() {
213            continue;
214        }
215
216        if trimmed.starts_with('.') {
217            commands::execute_dot_command_mut(trimmed, db, &conn, settings, &mut io::stdout());
218            continue;
219        }
220
221        buf.push_str(&line);
222        buf.push(' ');
223
224        if has_complete_statement(&buf) {
225            let sql = buf.trim();
226            if !sql.is_empty() {
227                execute_and_display(&conn, sql, &mut *settings);
228            }
229            buf.clear();
230        }
231    }
232
233    if !buf.trim().is_empty() {
234        execute_and_display(&conn, buf.trim(), settings);
235    }
236
237    0
238}
239
240fn execute_and_display(
241    conn: &citadel_sql::Connection<'_>,
242    sql: &str,
243    settings: &mut repl::Settings,
244) {
245    use std::time::Instant;
246
247    let start = Instant::now();
248    match conn.execute(sql) {
249        Ok(result) => {
250            let output = formatter::format_result(&result, settings);
251            if !output.is_empty() {
252                settings.write_output(&output);
253            }
254            if settings.timer {
255                settings.write_output(&format!("Run Time: {:.3}s", start.elapsed().as_secs_f64()));
256            }
257        }
258        Err(e) => {
259            eprintln!("Error: {e}");
260        }
261    }
262}
263
/// True when `s` holds a complete SQL statement.
///
/// "Complete" means every quoted string and block comment is closed, and the
/// text (after trimming whitespace) ends with `;`. Empty or all-whitespace
/// input is never complete.
///
/// Scanning rules, applied left to right:
/// - `'...'` and `"..."` delimit quoted text. Inside them a backslash escapes
///   the next character, so `\'` does not close the string and `\\` is a
///   literal backslash that does *not* escape a following quote. A doubled
///   quote (`''`) closes and reopens the string, so it stays balanced.
/// - Outside quotes, `--` starts a comment that runs to the next newline, and
///   `/* ... */` is a block comment. Quote characters inside comments are
///   ignored, so an apostrophe in `-- don't` cannot leave the statement open.
///   An unterminated block comment makes the statement incomplete.
pub(crate) fn has_complete_statement(s: &str) -> bool {
    /// Lexical context of the scanner at the current character.
    #[derive(Clone, Copy, PartialEq, Eq)]
    enum Scan {
        Code,
        Quoted(char),
        LineComment,
        BlockComment,
    }

    let trimmed = s.trim();
    if trimmed.is_empty() {
        return false;
    }

    let mut scan = Scan::Code;
    // Set after a backslash inside quotes; the next character is taken literally.
    let mut escaped = false;
    let mut chars = trimmed.chars().peekable();

    while let Some(ch) = chars.next() {
        scan = match scan {
            Scan::Code => match ch {
                '\'' | '"' => Scan::Quoted(ch),
                '-' if chars.peek() == Some(&'-') => {
                    chars.next();
                    Scan::LineComment
                }
                '/' if chars.peek() == Some(&'*') => {
                    chars.next();
                    Scan::BlockComment
                }
                _ => Scan::Code,
            },
            Scan::Quoted(quote) => {
                if escaped {
                    escaped = false;
                    scan
                } else if ch == '\\' {
                    escaped = true;
                    scan
                } else if ch == quote {
                    Scan::Code
                } else {
                    scan
                }
            }
            Scan::LineComment if ch == '\n' => Scan::Code,
            Scan::BlockComment if ch == '*' && chars.peek() == Some(&'/') => {
                chars.next();
                Scan::Code
            }
            other => other,
        };
    }

    // A trailing `--` comment needs no terminator; open quotes or an open block
    // comment mean more input is required.
    let balanced = matches!(scan, Scan::Code | Scan::LineComment);
    balanced && trimmed.ends_with(';')
}