Skip to main content

contextdb_cli/
repl.rs

1use crate::formatter::format_query_result;
2use contextdb_core::{Error, TableMeta};
3use contextdb_engine::Database;
4use contextdb_engine::sync_types::{ConflictPolicy, SyncDirection};
5use contextdb_server::{SyncClient, SyncPlugin};
6use rustyline::DefaultEditor;
7use rustyline::error::ReadlineError;
8use std::collections::HashMap;
9use std::io::{BufRead, IsTerminal};
10use std::sync::Arc;
11
12/// Run the REPL loop. Returns `true` if all commands succeeded, `false` if any error occurred.
13pub fn run(
14    db: Arc<Database>,
15    sync_client: Option<&SyncClient>,
16    rt: Option<&tokio::runtime::Runtime>,
17    sync_plugin: Option<&SyncPlugin>,
18) -> bool {
19    let interactive = std::io::stdin().is_terminal();
20    if interactive {
21        eprintln!("ContextDB v{}", env!("CARGO_PKG_VERSION"));
22        eprintln!("Enter .help for usage hints.");
23        return run_interactive(&db, sync_client, rt, sync_plugin);
24    }
25
26    run_scripted(&db, sync_client, rt, sync_plugin)
27}
28
29fn run_interactive(
30    db: &Database,
31    sync_client: Option<&SyncClient>,
32    rt: Option<&tokio::runtime::Runtime>,
33    sync_plugin: Option<&SyncPlugin>,
34) -> bool {
35    let mut rl = DefaultEditor::new().expect("failed to initialize readline");
36    let mut had_error = false;
37
38    loop {
39        let readline = rl.readline("contextdb> ");
40        match readline {
41            Ok(line) => {
42                let line = line.trim();
43                if line.is_empty() {
44                    continue;
45                }
46                let _ = rl.add_history_entry(line);
47                if !process_input_line(db, sync_client, rt, line, true, sync_plugin, &mut had_error)
48                {
49                    break;
50                }
51            }
52            Err(ReadlineError::Interrupted | ReadlineError::Eof) => break,
53            Err(_) => break,
54        }
55    }
56
57    !had_error
58}
59
60fn run_scripted(
61    db: &Database,
62    sync_client: Option<&SyncClient>,
63    rt: Option<&tokio::runtime::Runtime>,
64    sync_plugin: Option<&SyncPlugin>,
65) -> bool {
66    let mut had_error = false;
67    let stdin = std::io::stdin();
68    for line in stdin.lock().lines() {
69        let line = match line {
70            Ok(line) => line,
71            Err(_) => break,
72        };
73        let line = line.trim();
74        if line.is_empty() {
75            continue;
76        }
77        if !process_input_line(
78            db,
79            sync_client,
80            rt,
81            line,
82            false,
83            sync_plugin,
84            &mut had_error,
85        ) {
86            break;
87        }
88    }
89    !had_error
90}
91
92fn process_input_line(
93    db: &Database,
94    sync_client: Option<&SyncClient>,
95    rt: Option<&tokio::runtime::Runtime>,
96    line: &str,
97    interactive: bool,
98    sync_plugin: Option<&SyncPlugin>,
99    had_error: &mut bool,
100) -> bool {
101    if line.starts_with(".sync") || line.starts_with("\\sync") {
102        let mut parts = line.splitn(2, ' ');
103        let _cmd = parts.next();
104        let rest = parts.next().unwrap_or("").trim();
105        let outcome = run_sync_command(sync_client, rt, rest, sync_plugin);
106        println!("{}", outcome.message);
107        if !outcome.ok {
108            *had_error = true;
109        }
110    } else if line.starts_with('.') || line.starts_with('\\') {
111        if !handle_meta_command(db, sync_client, rt, line, sync_plugin) {
112            return false;
113        }
114    } else {
115        let upper = line.trim_start().to_uppercase();
116        if !interactive && upper.starts_with("INSERT") {
117            println!("{line}");
118        }
119        if !execute_sql(db, line) {
120            *had_error = true;
121        }
122    }
123    true
124}
125
/// Result of running a `.sync` subcommand.
struct SyncCommandOutcome {
    /// Human-readable text for the caller to print.
    message: String,
    /// `false` when the command failed; `process_input_line` then records a
    /// session error.
    ok: bool,
}
130
131pub(crate) fn handle_meta_command(
132    db: &Database,
133    sync_client: Option<&SyncClient>,
134    rt: Option<&tokio::runtime::Runtime>,
135    line: &str,
136    sync_plugin: Option<&SyncPlugin>,
137) -> bool {
138    let mut parts = line.splitn(2, ' ');
139    let cmd = parts.next().unwrap_or("");
140    let rest = parts.next().unwrap_or("").trim();
141
142    match cmd {
143        ".quit" | ".exit" | "\\q" => return false,
144        ".help" | "\\?" => {
145            println!(".help / \\?          Show this message");
146            println!(".quit/.exit / \\q    Exit REPL");
147            println!(".tables / \\dt       List tables");
148            println!(".schema / \\d <tbl>  Show table schema and constraints");
149            println!(".explain <sql>      Show execution plan");
150            println!(".sync status              Show sync connection info");
151            println!(".sync push                Push local changes to server");
152            println!(".sync pull                Pull remote changes from server");
153            println!(".sync reconnect           Reconnect to NATS");
154            println!(".sync direction <t> <d>   Set table sync direction (Push|Pull|Both|None)");
155            println!(
156                ".sync policy <t> <p>      Set table conflict policy (InsertIfNotExists|ServerWins|EdgeWins|LatestWins)"
157            );
158            println!(".sync policy default <p>  Set default conflict policy");
159            println!(".sync auto [on|off]       Toggle auto-sync after DML");
160        }
161        ".tables" | "\\dt" => {
162            for t in db.table_names() {
163                println!("{t}");
164            }
165        }
166        ".schema" | "\\d" => {
167            if rest.is_empty() {
168                eprintln!("Usage: .schema <table> or \\d <table>");
169            } else if let Some(meta) = db.table_meta(rest) {
170                print_table_meta(rest, &meta);
171            } else {
172                eprintln!("Table not found: {rest}");
173            }
174        }
175        ".explain" => {
176            if rest.is_empty() {
177                eprintln!("Usage: .explain <sql>");
178            } else {
179                match db.explain(rest) {
180                    Ok(plan) => println!("{}", plan),
181                    Err(e) => {
182                        if is_fatal_cli_error(&e) {
183                            eprintln!("Error: {}", e);
184                        } else {
185                            println!("Error: {}", e);
186                        }
187                    }
188                }
189            }
190        }
191        ".sync" | "\\sync" => {
192            println!(
193                "{}",
194                handle_sync_command(sync_client, rt, rest, sync_plugin)
195            );
196        }
197        _ => println!("Unknown command: {}. Type \\? for help.", cmd),
198    }
199
200    true
201}
202
203fn handle_sync_command(
204    sync_client: Option<&SyncClient>,
205    rt: Option<&tokio::runtime::Runtime>,
206    args: &str,
207    sync_plugin: Option<&SyncPlugin>,
208) -> String {
209    run_sync_command(sync_client, rt, args, sync_plugin).message
210}
211
212fn run_sync_command(
213    sync_client: Option<&SyncClient>,
214    rt: Option<&tokio::runtime::Runtime>,
215    args: &str,
216    sync_plugin: Option<&SyncPlugin>,
217) -> SyncCommandOutcome {
218    let (Some(client), Some(rt)) = (sync_client, rt) else {
219        return SyncCommandOutcome {
220            message: "Sync not configured. Start with --tenant-id to enable.".to_string(),
221            ok: true,
222        };
223    };
224
225    let parts: Vec<&str> = args.split_whitespace().collect();
226    let sub = parts.first().copied().unwrap_or("status");
227
228    match sub {
229        "status" => {
230            let connected = rt.block_on(client.ensure_connected()).is_ok();
231            let status = if connected {
232                "connected"
233            } else {
234                "unreachable"
235            };
236            let base = format!(
237                "Sync: tenant={}, url={}\nNATS: {status}\nDatabase LSN: {}\nPush watermark: LSN {}\nPull watermark: LSN {}",
238                client.tenant_id(),
239                client.nats_url(),
240                client.db().current_lsn(),
241                client.push_watermark(),
242                client.pull_watermark()
243            );
244            let render = contextdb_engine::cli_render::render_sync_status(client.db());
245            SyncCommandOutcome {
246                message: format!("{base}\n{render}"),
247                ok: true,
248            }
249        }
250        "push" => match rt.block_on(client.push()) {
251            Ok(result) => {
252                let mut msg = format!(
253                    "Pushed: {} applied, {} skipped, {} conflicts",
254                    result.applied_rows,
255                    result.skipped_rows,
256                    result.conflicts.len()
257                );
258                for conflict in &result.conflicts {
259                    if let Some(reason) = &conflict.reason {
260                        msg.push_str(&format!("\n  conflict: {}", reason));
261                    }
262                }
263                SyncCommandOutcome {
264                    message: msg,
265                    ok: true,
266                }
267            }
268            Err(e) => SyncCommandOutcome {
269                message: format!("Push failed: {e}"),
270                ok: false,
271            },
272        },
273        "pull" => match rt.block_on(client.pull_default()) {
274            Ok(result) => {
275                let mut msg = format!(
276                    "Pulled: {} applied, {} skipped, {} conflicts",
277                    result.applied_rows,
278                    result.skipped_rows,
279                    result.conflicts.len()
280                );
281                for conflict in &result.conflicts {
282                    if let Some(reason) = &conflict.reason {
283                        msg.push_str(&format!("\n  conflict: {}", reason));
284                    }
285                }
286                SyncCommandOutcome {
287                    message: msg,
288                    ok: true,
289                }
290            }
291            Err(e) => SyncCommandOutcome {
292                message: format!("Pull failed: {e}"),
293                ok: false,
294            },
295        },
296        "reconnect" => {
297            rt.block_on(client.reconnect());
298            let connected = rt.block_on(client.is_connected());
299            if connected {
300                SyncCommandOutcome {
301                    message: "Reconnected to NATS".to_string(),
302                    ok: true,
303                }
304            } else {
305                SyncCommandOutcome {
306                    message: "Reconnection failed — NATS unreachable".to_string(),
307                    ok: false,
308                }
309            }
310        }
311        "direction" => {
312            if parts.len() != 3 {
313                return SyncCommandOutcome {
314                    message: "Usage: .sync direction <table> <Push|Pull|Both|None>".to_string(),
315                    ok: true,
316                };
317            }
318            let table = parts[1];
319            let dir = match parts[2] {
320                "Push" | "push" => SyncDirection::Push,
321                "Pull" | "pull" => SyncDirection::Pull,
322                "Both" | "both" => SyncDirection::Both,
323                "None" | "none" => SyncDirection::None,
324                other => {
325                    return SyncCommandOutcome {
326                        message: format!("Unknown direction: {other}. Use: Push, Pull, Both, None"),
327                        ok: true,
328                    };
329                }
330            };
331            client.set_table_direction(table, dir);
332            SyncCommandOutcome {
333                message: format!("{table} -> {dir:?}"),
334                ok: true,
335            }
336        }
337        "policy" => {
338            if parts.len() != 3 {
339                return SyncCommandOutcome {
340                    message: "Usage: .sync policy <table> <InsertIfNotExists|ServerWins|EdgeWins|LatestWins>\n       .sync policy default <policy>".to_string(),
341                    ok: true,
342                };
343            }
344            let policy = match parts[2] {
345                "InsertIfNotExists" => ConflictPolicy::InsertIfNotExists,
346                "ServerWins" => ConflictPolicy::ServerWins,
347                "EdgeWins" => ConflictPolicy::EdgeWins,
348                "LatestWins" => ConflictPolicy::LatestWins,
349                other => {
350                    return SyncCommandOutcome {
351                        message: format!(
352                            "Unknown policy: {other}. Use: InsertIfNotExists, ServerWins, EdgeWins, LatestWins"
353                        ),
354                        ok: true,
355                    };
356                }
357            };
358            if parts[1] == "default" {
359                client.set_default_conflict_policy(policy);
360                SyncCommandOutcome {
361                    message: format!("Default conflict policy -> {policy:?}"),
362                    ok: true,
363                }
364            } else {
365                client.set_conflict_policy(parts[1], policy);
366                SyncCommandOutcome {
367                    message: format!("{} -> {policy:?}", parts[1]),
368                    ok: true,
369                }
370            }
371        }
372        "auto" => {
373            let Some(plugin) = sync_plugin else {
374                return SyncCommandOutcome {
375                    message: "Auto-sync not available (no sync plugin)".to_string(),
376                    ok: true,
377                };
378            };
379            let toggle = parts.get(1).copied().unwrap_or("");
380            match toggle {
381                "on" => {
382                    plugin.set_auto(true);
383                    SyncCommandOutcome {
384                        message: "Auto-sync enabled".to_string(),
385                        ok: true,
386                    }
387                }
388                "off" => {
389                    plugin.set_auto(false);
390                    SyncCommandOutcome {
391                        message: "Auto-sync disabled".to_string(),
392                        ok: true,
393                    }
394                }
395                "" => {
396                    let state = if plugin.is_auto() { "on" } else { "off" };
397                    SyncCommandOutcome {
398                        message: format!("Auto-sync: {state}"),
399                        ok: true,
400                    }
401                }
402                other => SyncCommandOutcome {
403                    message: format!("Unknown auto-sync option: {other}. Use: on, off"),
404                    ok: true,
405                },
406            }
407        }
408        _ => SyncCommandOutcome {
409            message: format!(
410                "Unknown sync command: {sub}. Try: status, push, pull, reconnect, direction, policy, auto"
411            ),
412            ok: true,
413        },
414    }
415}
416
417fn print_table_meta(table: &str, meta: &TableMeta) {
418    print!(
419        "{}",
420        contextdb_engine::cli_render::render_table_meta(table, meta)
421    );
422}
423
424/// Execute a SQL statement and print the result. Returns `true` on success, `false` on error.
425fn execute_sql(db: &Database, sql: &str) -> bool {
426    match db.execute(sql, &HashMap::new()) {
427        Ok(result) => {
428            if result.columns.is_empty() {
429                println!("ok (rows_affected={})", result.rows_affected);
430            } else {
431                println!("{}", format_query_result(&result));
432            }
433            true
434        }
435        Err(e) => {
436            if is_fatal_cli_error(&e) {
437                eprintln!("Error: {}", e);
438                false
439            } else {
440                println!("Error: {}", e);
441                true
442            }
443        }
444    }
445}
446
/// Public wrapper around the private `is_fatal_cli_error` classifier.
///
/// Returns `true` when `error` is one of the variants this CLI treats as
/// fatal.
pub fn is_fatal_cli_error_public(error: &Error) -> bool {
    is_fatal_cli_error(error)
}
450
451fn is_fatal_cli_error(error: &Error) -> bool {
452    matches!(
453        error,
454        Error::ParseError(_)
455            | Error::TableNotFound(_)
456            | Error::NotFound(_)
457            | Error::BfsDepthExceeded(_)
458            | Error::RecursiveCteNotSupported
459            | Error::WindowFunctionNotSupported
460            | Error::FullTextSearchNotSupported
461    )
462}
463
/// Unit tests for meta-command dispatch and `.sync` argument handling.
///
/// The sync tests build a `SyncClient` pointing at `nats://localhost:19999`
/// and only exercise subcommands that parse arguments (`direction`, `policy`,
/// and invalid input).
#[cfg(test)]
mod tests {
    use super::*;
    use contextdb_engine::sync_types::{ConflictPolicy, SyncDirection};
    use contextdb_parser::{Statement, parse};

    // `\dt` is accepted and does not ask the REPL to stop.
    #[test]
    fn test_backslash_dt() {
        let db = Database::open_memory();
        db.execute("CREATE TABLE t (id UUID PRIMARY KEY)", &HashMap::new())
            .unwrap();
        assert!(handle_meta_command(&db, None, None, "\\dt", None));
    }

    // B1: Existing \dt works with new handle_meta_command signature
    #[test]
    fn b1_existing_dt_works_with_new_signature() {
        let db = Database::open_memory();
        db.execute("CREATE TABLE t (id UUID PRIMARY KEY)", &HashMap::new())
            .unwrap();
        // Pass None for sync_client and rt — existing commands must work without sync
        assert!(handle_meta_command(&db, None, None, "\\dt", None));
        // Also verify .tables works
        assert!(handle_meta_command(&db, None, None, ".tables", None));
        // .quit returns false
        assert!(!handle_meta_command(&db, None, None, ".quit", None));
    }

    // B2: .sync subcommands handle missing sync configuration
    #[test]
    fn b2_sync_not_configured_message() {
        for subcmd in [
            "status",
            "push",
            "pull",
            "reconnect",
            "direction t Push",
            "policy t ServerWins",
        ] {
            // With no client and no runtime every subcommand short-circuits.
            let result = handle_sync_command(None, None, subcmd, None);
            assert!(
                result.contains("Sync not configured"),
                "subcmd '{}' should return 'Sync not configured', got: {}",
                subcmd,
                result
            );
        }
    }

    // B3: .sync direction parses all four direction values
    #[test]
    fn b3_sync_direction_parsing() {
        let db = Arc::new(Database::open_memory());
        let rt = tokio::runtime::Runtime::new().unwrap();
        let client =
            rt.block_on(async { SyncClient::new(db, "nats://localhost:19999", "b3-test") });

        // Suppress unused import warning — SyncDirection is used to verify the API exists
        let _directions = [
            SyncDirection::Push,
            SyncDirection::Pull,
            SyncDirection::Both,
            SyncDirection::None,
        ];

        // Mixed-case inputs ("pull") are part of the accepted spellings.
        for (table, dir) in [
            ("observations", "Push"),
            ("patterns", "pull"),
            ("decisions", "Both"),
            ("scratch", "None"),
        ] {
            let args = format!("direction {} {}", table, dir);
            let result = handle_sync_command(Some(&client), Some(&rt), &args, None);
            assert!(
                result.contains(table),
                "direction command for '{}' should contain table name, got: {}",
                table,
                result
            );
        }
    }

    // B4: .sync policy parses all four policies + default
    #[test]
    fn b4_sync_policy_parsing() {
        let db = Arc::new(Database::open_memory());
        let rt = tokio::runtime::Runtime::new().unwrap();
        let client =
            rt.block_on(async { SyncClient::new(db, "nats://localhost:19999", "b4-test") });

        // Suppress unused import warning — ConflictPolicy is used to verify the API exists
        let _policies = [
            ConflictPolicy::InsertIfNotExists,
            ConflictPolicy::ServerWins,
            ConflictPolicy::EdgeWins,
            ConflictPolicy::LatestWins,
        ];

        // Per-table policy: the confirmation message echoes the policy name.
        let result = handle_sync_command(
            Some(&client),
            Some(&rt),
            "policy obs InsertIfNotExists",
            None,
        );
        assert!(
            result.contains("InsertIfNotExists"),
            "policy command should contain 'InsertIfNotExists', got: {}",
            result
        );

        // The table name `default` targets the default policy instead.
        let result =
            handle_sync_command(Some(&client), Some(&rt), "policy default ServerWins", None);
        assert!(
            result.contains("Default") || result.contains("default"),
            "default policy command should reference 'default', got: {}",
            result
        );
    }

    // B5: .sync invalid arguments handled gracefully
    #[test]
    fn b5_sync_invalid_args() {
        let db = Arc::new(Database::open_memory());
        let rt = tokio::runtime::Runtime::new().unwrap();
        let client =
            rt.block_on(async { SyncClient::new(db, "nats://localhost:19999", "b5-test") });

        // Only checks that each bad input produces some message other than a
        // "not implemented" placeholder (and does not panic).
        for bad_input in [
            "bogus",
            "direction",
            "direction table_only",
            "direction t InvalidDir",
            "policy",
            "policy table_only",
            "policy t InvalidPolicy",
        ] {
            let result = handle_sync_command(Some(&client), Some(&rt), bad_input, None);
            assert!(
                !result.contains("not implemented"),
                "bad input '{}' should not return 'not implemented', got: {}",
                bad_input,
                result
            );
        }
    }

    // The schema text rendered for a table must parse back as CREATE TABLE.
    #[test]
    fn rt2_repl_schema_display_round_trip_parse() {
        let db = Database::open_memory();
        db.execute(
            "CREATE TABLE repl_rt_sm (id UUID PRIMARY KEY, status TEXT) STATE MACHINE (status: pending -> [done])",
            &HashMap::new(),
        )
        .unwrap();

        let meta = db.table_meta("repl_rt_sm").expect("table meta");
        let rendered = contextdb_engine::cli_render::render_table_meta("repl_rt_sm", &meta);
        assert!(matches!(parse(&rendered), Ok(Statement::CreateTable(_))));
    }
}
623}