Skip to main content

contextdb_cli/
repl.rs

1use crate::formatter::{format_query_result, format_query_result_with_empty_headers};
2use contextdb_core::{Error, TableMeta};
3use contextdb_engine::Database;
4use contextdb_engine::sync_types::{ConflictPolicy, SyncDirection};
5use contextdb_server::{SyncClient, SyncPlugin};
6use rustyline::DefaultEditor;
7use rustyline::error::ReadlineError;
8use std::collections::HashMap;
9use std::io::{BufRead, IsTerminal};
10use std::sync::Arc;
11
/// Per-line metadata describing where an input line came from.
#[derive(Copy, Clone)]
struct InputContext {
    /// `true` when the line was typed into an interactive terminal session.
    interactive: bool,
    /// 1-based physical line number within a piped script; `None` in interactive mode.
    script_line: Option<usize>,
}
17
18/// Run the REPL loop. Returns `true` if all commands succeeded, `false` if any error occurred.
19pub fn run(
20    db: Arc<Database>,
21    sync_client: Option<&SyncClient>,
22    rt: Option<&tokio::runtime::Runtime>,
23    sync_plugin: Option<&SyncPlugin>,
24) -> bool {
25    let interactive = std::io::stdin().is_terminal();
26    if interactive {
27        eprintln!("ContextDB v{}", env!("CARGO_PKG_VERSION"));
28        eprintln!("Enter .help for usage hints.");
29        return run_interactive(&db, sync_client, rt, sync_plugin);
30    }
31
32    run_scripted(&db, sync_client, rt, sync_plugin)
33}
34
35fn run_interactive(
36    db: &Database,
37    sync_client: Option<&SyncClient>,
38    rt: Option<&tokio::runtime::Runtime>,
39    sync_plugin: Option<&SyncPlugin>,
40) -> bool {
41    let mut rl = DefaultEditor::new().expect("failed to initialize readline");
42    let mut had_error = false;
43
44    loop {
45        let readline = rl.readline("contextdb> ");
46        match readline {
47            Ok(line) => {
48                let line = line.trim();
49                if line.is_empty() {
50                    continue;
51                }
52                let _ = rl.add_history_entry(line);
53                if !process_input_line(
54                    db,
55                    sync_client,
56                    rt,
57                    line,
58                    InputContext {
59                        interactive: true,
60                        script_line: None,
61                    },
62                    sync_plugin,
63                    &mut had_error,
64                ) {
65                    break;
66                }
67            }
68            Err(ReadlineError::Interrupted | ReadlineError::Eof) => break,
69            Err(_) => break,
70        }
71    }
72
73    !had_error
74}
75
76fn run_scripted(
77    db: &Database,
78    sync_client: Option<&SyncClient>,
79    rt: Option<&tokio::runtime::Runtime>,
80    sync_plugin: Option<&SyncPlugin>,
81) -> bool {
82    let mut had_error = false;
83    let stdin = std::io::stdin();
84    for (line_idx, line) in stdin.lock().lines().enumerate() {
85        let line = match line {
86            Ok(line) => line,
87            Err(_) => break,
88        };
89        let line = line.trim();
90        if line.is_empty() || line.starts_with("--") {
91            continue;
92        }
93        if !process_input_line(
94            db,
95            sync_client,
96            rt,
97            line,
98            InputContext {
99                interactive: false,
100                script_line: Some(line_idx + 1),
101            },
102            sync_plugin,
103            &mut had_error,
104        ) {
105            break;
106        }
107    }
108    !had_error
109}
110
111fn process_input_line(
112    db: &Database,
113    sync_client: Option<&SyncClient>,
114    rt: Option<&tokio::runtime::Runtime>,
115    line: &str,
116    input: InputContext,
117    sync_plugin: Option<&SyncPlugin>,
118    had_error: &mut bool,
119) -> bool {
120    if line.starts_with(".sync") || line.starts_with("\\sync") {
121        let mut parts = line.splitn(2, ' ');
122        let _cmd = parts.next();
123        let rest = parts.next().unwrap_or("").trim();
124        let outcome = run_sync_command(sync_client, rt, rest, sync_plugin);
125        println!("{}", outcome.message);
126        if !outcome.ok {
127            *had_error = true;
128        }
129    } else if line.starts_with('.') || line.starts_with('\\') {
130        if !handle_meta_command(db, sync_client, rt, line, sync_plugin) {
131            return false;
132        }
133    } else {
134        let upper = line.trim_start().to_uppercase();
135        if !input.interactive && upper.starts_with("INSERT") {
136            println!("{line}");
137        }
138        if !execute_sql(db, line, input.script_line) {
139            *had_error = true;
140        }
141    }
142    true
143}
144
/// Result of running a `.sync` sub-command.
///
/// Pairs the text to show the user with a flag saying whether the command should count
/// as a failure.
struct SyncCommandOutcome {
    /// Human-readable text printed after the command runs.
    message: String,
    /// `false` when the command failed and the session should be marked as errored.
    ok: bool,
}
149
150pub(crate) fn handle_meta_command(
151    db: &Database,
152    sync_client: Option<&SyncClient>,
153    rt: Option<&tokio::runtime::Runtime>,
154    line: &str,
155    sync_plugin: Option<&SyncPlugin>,
156) -> bool {
157    let mut parts = line.splitn(2, ' ');
158    let cmd = parts.next().unwrap_or("");
159    let rest = parts.next().unwrap_or("").trim();
160
161    match cmd {
162        ".quit" | ".exit" | "\\q" => return false,
163        ".help" | "\\?" => {
164            if rest.eq_ignore_ascii_case("vector") {
165                println!("VECTOR(N) WITH (quantization = 'F32'|'SQ8'|'SQ4')");
166                println!("SHOW VECTOR_INDEXES");
167                println!(
168                    "Vector errors include LegacyVectorStoreDetected, StoreCorrupted, VectorIndexDimensionMismatch, and UnknownVectorIndex"
169                );
170                return true;
171            }
172            println!(".help / \\?          Show this message");
173            println!(".help vector        Show vector index syntax and errors");
174            println!(".quit/.exit / \\q    Exit REPL");
175            println!(".tables / \\dt       List tables");
176            println!(".schema / \\d <tbl>  Show table schema and constraints");
177            println!(".explain <sql>      Show execution plan");
178            println!(".sync status              Show sync connection info");
179            println!(".sync push                Push local changes to server");
180            println!(".sync pull                Pull remote changes from server");
181            println!(".sync reconnect           Reconnect to NATS");
182            println!(".sync direction <t> <d>   Set table sync direction (Push|Pull|Both|None)");
183            println!(
184                ".sync policy <t> <p>      Set table conflict policy (InsertIfNotExists|ServerWins|EdgeWins|LatestWins)"
185            );
186            println!(".sync policy default <p>  Set default conflict policy");
187            println!(".sync auto [on|off]       Toggle auto-sync after DML");
188        }
189        ".tables" | "\\dt" => {
190            for t in db.table_names() {
191                println!("{t}");
192            }
193        }
194        ".schema" | "\\d" => {
195            if rest.is_empty() {
196                eprintln!("Usage: .schema <table> or \\d <table>");
197            } else if let Some(meta) = db.table_meta(rest) {
198                print_table_meta(rest, &meta);
199            } else {
200                eprintln!("Table not found: {rest}");
201            }
202        }
203        ".explain" => {
204            if rest.is_empty() {
205                eprintln!("Usage: .explain <sql>");
206            } else {
207                match db.explain(rest) {
208                    Ok(plan) => println!("{}", plan),
209                    Err(e) => {
210                        if is_fatal_cli_error(&e) {
211                            eprintln!("Error: {}", e);
212                        } else {
213                            println!("Error: {}", e);
214                        }
215                    }
216                }
217            }
218        }
219        ".sync" | "\\sync" => {
220            println!(
221                "{}",
222                handle_sync_command(sync_client, rt, rest, sync_plugin)
223            );
224        }
225        _ => println!("Unknown command: {}. Type \\? for help.", cmd),
226    }
227
228    true
229}
230
231fn handle_sync_command(
232    sync_client: Option<&SyncClient>,
233    rt: Option<&tokio::runtime::Runtime>,
234    args: &str,
235    sync_plugin: Option<&SyncPlugin>,
236) -> String {
237    run_sync_command(sync_client, rt, args, sync_plugin).message
238}
239
240fn run_sync_command(
241    sync_client: Option<&SyncClient>,
242    rt: Option<&tokio::runtime::Runtime>,
243    args: &str,
244    sync_plugin: Option<&SyncPlugin>,
245) -> SyncCommandOutcome {
246    let (Some(client), Some(rt)) = (sync_client, rt) else {
247        return SyncCommandOutcome {
248            message: "Sync not configured. Start with --tenant-id to enable.".to_string(),
249            ok: true,
250        };
251    };
252
253    let parts: Vec<&str> = args.split_whitespace().collect();
254    let sub = parts.first().copied().unwrap_or("status");
255
256    match sub {
257        "status" => {
258            let connected = rt.block_on(client.ensure_connected()).is_ok();
259            let status = if connected {
260                "connected"
261            } else {
262                "unreachable"
263            };
264            let base = format!(
265                "Sync: tenant={}, url={}\nNATS: {status}\nDatabase LSN: {}\nPush watermark: LSN {}\nPull watermark: LSN {}",
266                client.tenant_id(),
267                client.nats_url(),
268                client.db().current_lsn(),
269                client.push_watermark(),
270                client.pull_watermark()
271            );
272            let render = contextdb_engine::cli_render::render_sync_status(client.db());
273            SyncCommandOutcome {
274                message: format!("{base}\n{render}"),
275                ok: true,
276            }
277        }
278        "push" => match rt.block_on(client.push()) {
279            Ok(result) => {
280                let mut msg = format!(
281                    "Pushed: {} applied, {} skipped, {} conflicts",
282                    result.applied_rows,
283                    result.skipped_rows,
284                    result.conflicts.len()
285                );
286                for conflict in &result.conflicts {
287                    if let Some(reason) = &conflict.reason {
288                        msg.push_str(&format!("\n  conflict: {}", reason));
289                    }
290                }
291                SyncCommandOutcome {
292                    message: msg,
293                    ok: true,
294                }
295            }
296            Err(e) => SyncCommandOutcome {
297                message: format!("Push failed: {e}"),
298                ok: false,
299            },
300        },
301        "pull" => match rt.block_on(client.pull_default()) {
302            Ok(result) => {
303                let mut msg = format!(
304                    "Pulled: {} applied, {} skipped, {} conflicts",
305                    result.applied_rows,
306                    result.skipped_rows,
307                    result.conflicts.len()
308                );
309                for conflict in &result.conflicts {
310                    if let Some(reason) = &conflict.reason {
311                        msg.push_str(&format!("\n  conflict: {}", reason));
312                    }
313                }
314                SyncCommandOutcome {
315                    message: msg,
316                    ok: true,
317                }
318            }
319            Err(e) => SyncCommandOutcome {
320                message: format!("Pull failed: {e}"),
321                ok: false,
322            },
323        },
324        "reconnect" => {
325            rt.block_on(client.reconnect());
326            let connected = rt.block_on(client.is_connected());
327            if connected {
328                SyncCommandOutcome {
329                    message: "Reconnected to NATS".to_string(),
330                    ok: true,
331                }
332            } else {
333                SyncCommandOutcome {
334                    message: "Reconnection failed — NATS unreachable".to_string(),
335                    ok: false,
336                }
337            }
338        }
339        "direction" => {
340            if parts.len() != 3 {
341                return SyncCommandOutcome {
342                    message: "Usage: .sync direction <table> <Push|Pull|Both|None>".to_string(),
343                    ok: true,
344                };
345            }
346            let table = parts[1];
347            let dir = match parts[2] {
348                "Push" | "push" => SyncDirection::Push,
349                "Pull" | "pull" => SyncDirection::Pull,
350                "Both" | "both" => SyncDirection::Both,
351                "None" | "none" => SyncDirection::None,
352                other => {
353                    return SyncCommandOutcome {
354                        message: format!("Unknown direction: {other}. Use: Push, Pull, Both, None"),
355                        ok: true,
356                    };
357                }
358            };
359            client.set_table_direction(table, dir);
360            SyncCommandOutcome {
361                message: format!("{table} -> {dir:?}"),
362                ok: true,
363            }
364        }
365        "policy" => {
366            if parts.len() != 3 {
367                return SyncCommandOutcome {
368                    message: "Usage: .sync policy <table> <InsertIfNotExists|ServerWins|EdgeWins|LatestWins>\n       .sync policy default <policy>".to_string(),
369                    ok: true,
370                };
371            }
372            let policy = match parts[2] {
373                "InsertIfNotExists" => ConflictPolicy::InsertIfNotExists,
374                "ServerWins" => ConflictPolicy::ServerWins,
375                "EdgeWins" => ConflictPolicy::EdgeWins,
376                "LatestWins" => ConflictPolicy::LatestWins,
377                other => {
378                    return SyncCommandOutcome {
379                        message: format!(
380                            "Unknown policy: {other}. Use: InsertIfNotExists, ServerWins, EdgeWins, LatestWins"
381                        ),
382                        ok: true,
383                    };
384                }
385            };
386            if parts[1] == "default" {
387                client.set_default_conflict_policy(policy);
388                SyncCommandOutcome {
389                    message: format!("Default conflict policy -> {policy:?}"),
390                    ok: true,
391                }
392            } else {
393                client.set_conflict_policy(parts[1], policy);
394                SyncCommandOutcome {
395                    message: format!("{} -> {policy:?}", parts[1]),
396                    ok: true,
397                }
398            }
399        }
400        "auto" => {
401            let Some(plugin) = sync_plugin else {
402                return SyncCommandOutcome {
403                    message: "Auto-sync not available (no sync plugin)".to_string(),
404                    ok: true,
405                };
406            };
407            let toggle = parts.get(1).copied().unwrap_or("");
408            match toggle {
409                "on" => {
410                    plugin.set_auto(true);
411                    SyncCommandOutcome {
412                        message: "Auto-sync enabled".to_string(),
413                        ok: true,
414                    }
415                }
416                "off" => {
417                    plugin.set_auto(false);
418                    SyncCommandOutcome {
419                        message: "Auto-sync disabled".to_string(),
420                        ok: true,
421                    }
422                }
423                "" => {
424                    let state = if plugin.is_auto() { "on" } else { "off" };
425                    SyncCommandOutcome {
426                        message: format!("Auto-sync: {state}"),
427                        ok: true,
428                    }
429                }
430                other => SyncCommandOutcome {
431                    message: format!("Unknown auto-sync option: {other}. Use: on, off"),
432                    ok: true,
433                },
434            }
435        }
436        _ => SyncCommandOutcome {
437            message: format!(
438                "Unknown sync command: {sub}. Try: status, push, pull, reconnect, direction, policy, auto"
439            ),
440            ok: true,
441        },
442    }
443}
444
445fn print_table_meta(table: &str, meta: &TableMeta) {
446    print!(
447        "{}",
448        contextdb_engine::cli_render::render_table_meta(table, meta)
449    );
450}
451
452/// Execute a SQL statement and print the result. Returns `true` on success, `false` on error.
453fn execute_sql(db: &Database, sql: &str, script_line: Option<usize>) -> bool {
454    match db.execute(sql, &HashMap::new()) {
455        Ok(result) => {
456            if result.columns.is_empty() {
457                println!("ok (rows_affected={})", result.rows_affected);
458            } else {
459                let show_empty_headers = sql
460                    .trim()
461                    .trim_end_matches(';')
462                    .eq_ignore_ascii_case("SHOW VECTOR_INDEXES");
463                let formatted = if show_empty_headers {
464                    format_query_result_with_empty_headers(&result, true)
465                } else {
466                    format_query_result(&result)
467                };
468                println!("{formatted}");
469            }
470            true
471        }
472        Err(e) => {
473            let rendered = if matches!(e, Error::ParseError(_)) {
474                if let Some(line) = script_line {
475                    format!("line {line}: {e}")
476                } else {
477                    e.to_string()
478                }
479            } else {
480                e.to_string()
481            }
482            .replace('\n', " ");
483            if is_fatal_cli_error(&e) {
484                eprintln!("Error: {rendered}");
485                false
486            } else {
487                println!("Error: {rendered}");
488                true
489            }
490        }
491    }
492}
493
494pub fn is_fatal_cli_error_public(error: &Error) -> bool {
495    is_fatal_cli_error(error)
496}
497
498fn is_fatal_cli_error(error: &Error) -> bool {
499    match error {
500        Error::ParseError(_)
501        | Error::TableNotFound(_)
502        | Error::NotFound(_)
503        | Error::BfsDepthExceeded(_)
504        | Error::RecursiveCteNotSupported
505        | Error::WindowFunctionNotSupported
506        | Error::FullTextSearchNotSupported
507        | Error::VectorIndexDimensionMismatch { .. }
508        | Error::UnknownVectorIndex { .. }
509        | Error::StoreCorrupted { .. }
510        | Error::LegacyVectorStoreDetected { .. } => true,
511        Error::MemoryBudgetExceeded { operation, .. } => operation.contains('@'),
512        _ => false,
513    }
514}
515
#[cfg(test)]
mod tests {
    use super::*;
    use contextdb_engine::sync_types::{ConflictPolicy, SyncDirection};
    use contextdb_parser::{Statement, parse};

    #[test]
    fn test_backslash_dt() {
        let db = Database::open_memory();
        db.execute("CREATE TABLE t (id UUID PRIMARY KEY)", &HashMap::new())
            .unwrap();
        assert!(handle_meta_command(&db, None, None, "\\dt", None));
    }

    // B1: Existing \dt works with new handle_meta_command signature
    #[test]
    fn b1_existing_dt_works_with_new_signature() {
        let db = Database::open_memory();
        db.execute("CREATE TABLE t (id UUID PRIMARY KEY)", &HashMap::new())
            .unwrap();
        // Pass None for sync_client and rt — existing commands must work without sync
        assert!(handle_meta_command(&db, None, None, "\\dt", None));
        // Also verify .tables works
        assert!(handle_meta_command(&db, None, None, ".tables", None));
        // .quit returns false
        assert!(!handle_meta_command(&db, None, None, ".quit", None));
    }

    // B2: .sync subcommands handle missing sync configuration
    #[test]
    fn b2_sync_not_configured_message() {
        for subcmd in [
            "status",
            "push",
            "pull",
            "reconnect",
            "direction t Push",
            "policy t ServerWins",
        ] {
            let result = handle_sync_command(None, None, subcmd, None);
            assert!(
                result.contains("Sync not configured"),
                "subcmd '{}' should return 'Sync not configured', got: {}",
                subcmd,
                result
            );
        }
    }

    // B3: .sync direction parses all four direction values (including a lowercase
    // spelling) and echoes the parsed value back.
    #[test]
    fn b3_sync_direction_parsing() {
        let db = Arc::new(Database::open_memory());
        let rt = tokio::runtime::Runtime::new().unwrap();
        let client =
            rt.block_on(async { SyncClient::new(db, "nats://localhost:19999", "b3-test") });

        for (table, dir, expected) in [
            ("observations", "Push", SyncDirection::Push),
            ("patterns", "pull", SyncDirection::Pull),
            ("decisions", "Both", SyncDirection::Both),
            ("scratch", "None", SyncDirection::None),
        ] {
            let args = format!("direction {table} {dir}");
            let result = handle_sync_command(Some(&client), Some(&rt), &args, None);
            assert_eq!(
                result,
                format!("{table} -> {expected:?}"),
                "direction command for '{table}' should echo the parsed direction"
            );
        }
    }

    // B4: .sync policy parses all four policies + default
    #[test]
    fn b4_sync_policy_parsing() {
        let db = Arc::new(Database::open_memory());
        let rt = tokio::runtime::Runtime::new().unwrap();
        let client =
            rt.block_on(async { SyncClient::new(db, "nats://localhost:19999", "b4-test") });

        for (policy_name, expected) in [
            ("InsertIfNotExists", ConflictPolicy::InsertIfNotExists),
            ("ServerWins", ConflictPolicy::ServerWins),
            ("EdgeWins", ConflictPolicy::EdgeWins),
            ("LatestWins", ConflictPolicy::LatestWins),
        ] {
            let args = format!("policy obs {policy_name}");
            let result = handle_sync_command(Some(&client), Some(&rt), &args, None);
            assert_eq!(
                result,
                format!("obs -> {expected:?}"),
                "policy command for '{policy_name}' should echo the parsed policy"
            );
        }

        let result =
            handle_sync_command(Some(&client), Some(&rt), "policy default ServerWins", None);
        assert_eq!(
            result,
            format!("Default conflict policy -> {:?}", ConflictPolicy::ServerWins),
            "default policy command should report the new default"
        );
    }

    // B5: .sync invalid arguments produce the matching usage / unknown-value message
    #[test]
    fn b5_sync_invalid_args() {
        let db = Arc::new(Database::open_memory());
        let rt = tokio::runtime::Runtime::new().unwrap();
        let client =
            rt.block_on(async { SyncClient::new(db, "nats://localhost:19999", "b5-test") });

        for (bad_input, expected_prefix) in [
            ("bogus", "Unknown sync command: bogus"),
            ("direction", "Usage: .sync direction"),
            ("direction table_only", "Usage: .sync direction"),
            ("direction t InvalidDir", "Unknown direction: InvalidDir"),
            ("policy", "Usage: .sync policy"),
            ("policy table_only", "Usage: .sync policy"),
            ("policy t InvalidPolicy", "Unknown policy: InvalidPolicy"),
        ] {
            let result = handle_sync_command(Some(&client), Some(&rt), bad_input, None);
            assert!(
                result.starts_with(expected_prefix),
                "bad input '{bad_input}' should start with '{expected_prefix}', got: {result}"
            );
        }
    }

    #[test]
    fn rt2_repl_schema_display_round_trip_parse() {
        let db = Database::open_memory();
        db.execute(
            "CREATE TABLE repl_rt_sm (id UUID PRIMARY KEY, status TEXT) STATE MACHINE (status: pending -> [done])",
            &HashMap::new(),
        )
        .unwrap();

        let meta = db.table_meta("repl_rt_sm").expect("table meta");
        let rendered = contextdb_engine::cli_render::render_table_meta("repl_rt_sm", &meta);
        assert!(matches!(parse(&rendered), Ok(Statement::CreateTable(_))));
    }
}