1use crate::formatter::{format_query_result, format_query_result_with_empty_headers};
2use contextdb_core::{Error, TableMeta};
3use contextdb_engine::Database;
4use contextdb_engine::sync_types::{ConflictPolicy, SyncDirection};
5use contextdb_server::{SyncClient, SyncPlugin};
6use rustyline::DefaultEditor;
7use rustyline::error::ReadlineError;
8use std::collections::HashMap;
9use std::io::{BufRead, IsTerminal};
10use std::sync::Arc;
11
/// Describes where the line currently being processed came from.
///
/// The value is small and `Copy`, so it is passed by value to the line
/// processor for every input line.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct InputContext {
    /// `true` when the line was typed at the interactive prompt.
    interactive: bool,
    /// 1-based line number within scripted (piped) input; `None` at the prompt.
    script_line: Option<usize>,
}
17
18pub fn run(
20 db: Arc<Database>,
21 sync_client: Option<&SyncClient>,
22 rt: Option<&tokio::runtime::Runtime>,
23 sync_plugin: Option<&SyncPlugin>,
24) -> bool {
25 let interactive = std::io::stdin().is_terminal();
26 if interactive {
27 eprintln!("ContextDB v{}", env!("CARGO_PKG_VERSION"));
28 eprintln!("Enter .help for usage hints.");
29 return run_interactive(&db, sync_client, rt, sync_plugin);
30 }
31
32 run_scripted(&db, sync_client, rt, sync_plugin)
33}
34
35fn run_interactive(
36 db: &Database,
37 sync_client: Option<&SyncClient>,
38 rt: Option<&tokio::runtime::Runtime>,
39 sync_plugin: Option<&SyncPlugin>,
40) -> bool {
41 let mut rl = DefaultEditor::new().expect("failed to initialize readline");
42 let mut had_error = false;
43
44 loop {
45 let readline = rl.readline("contextdb> ");
46 match readline {
47 Ok(line) => {
48 let line = line.trim();
49 if line.is_empty() {
50 continue;
51 }
52 let _ = rl.add_history_entry(line);
53 if !process_input_line(
54 db,
55 sync_client,
56 rt,
57 line,
58 InputContext {
59 interactive: true,
60 script_line: None,
61 },
62 sync_plugin,
63 &mut had_error,
64 ) {
65 break;
66 }
67 }
68 Err(ReadlineError::Interrupted | ReadlineError::Eof) => break,
69 Err(_) => break,
70 }
71 }
72
73 !had_error
74}
75
76fn run_scripted(
77 db: &Database,
78 sync_client: Option<&SyncClient>,
79 rt: Option<&tokio::runtime::Runtime>,
80 sync_plugin: Option<&SyncPlugin>,
81) -> bool {
82 let mut had_error = false;
83 let stdin = std::io::stdin();
84 for (line_idx, line) in stdin.lock().lines().enumerate() {
85 let line = match line {
86 Ok(line) => line,
87 Err(_) => break,
88 };
89 let line = line.trim();
90 if line.is_empty() || line.starts_with("--") {
91 continue;
92 }
93 if !process_input_line(
94 db,
95 sync_client,
96 rt,
97 line,
98 InputContext {
99 interactive: false,
100 script_line: Some(line_idx + 1),
101 },
102 sync_plugin,
103 &mut had_error,
104 ) {
105 break;
106 }
107 }
108 !had_error
109}
110
111fn process_input_line(
112 db: &Database,
113 sync_client: Option<&SyncClient>,
114 rt: Option<&tokio::runtime::Runtime>,
115 line: &str,
116 input: InputContext,
117 sync_plugin: Option<&SyncPlugin>,
118 had_error: &mut bool,
119) -> bool {
120 if line.starts_with(".sync") || line.starts_with("\\sync") {
121 let mut parts = line.splitn(2, ' ');
122 let _cmd = parts.next();
123 let rest = parts.next().unwrap_or("").trim();
124 let outcome = run_sync_command(sync_client, rt, rest, sync_plugin);
125 println!("{}", outcome.message);
126 if !outcome.ok {
127 *had_error = true;
128 }
129 } else if line.starts_with('.') || line.starts_with('\\') {
130 if !handle_meta_command(db, sync_client, rt, line, sync_plugin) {
131 return false;
132 }
133 } else {
134 let upper = line.trim_start().to_uppercase();
135 if !input.interactive && upper.starts_with("INSERT") {
136 println!("{line}");
137 }
138 if !execute_sql(db, line, input.script_line) {
139 *had_error = true;
140 }
141 }
142 true
143}
144
/// Result of running a sync subcommand.
#[derive(Debug, Clone, PartialEq, Eq)]
struct SyncCommandOutcome {
    /// Text to show to the user.
    message: String,
    /// `false` when the command failed; callers use this to record an error.
    ok: bool,
}
149
150pub(crate) fn handle_meta_command(
151 db: &Database,
152 sync_client: Option<&SyncClient>,
153 rt: Option<&tokio::runtime::Runtime>,
154 line: &str,
155 sync_plugin: Option<&SyncPlugin>,
156) -> bool {
157 let mut parts = line.splitn(2, ' ');
158 let cmd = parts.next().unwrap_or("");
159 let rest = parts.next().unwrap_or("").trim();
160
161 match cmd {
162 ".quit" | ".exit" | "\\q" => return false,
163 ".help" | "\\?" => {
164 if rest.eq_ignore_ascii_case("vector") {
165 println!("VECTOR(N) WITH (quantization = 'F32'|'SQ8'|'SQ4')");
166 println!("SHOW VECTOR_INDEXES");
167 println!(
168 "Vector errors include LegacyVectorStoreDetected, StoreCorrupted, VectorIndexDimensionMismatch, and UnknownVectorIndex"
169 );
170 return true;
171 }
172 println!(".help / \\? Show this message");
173 println!(".help vector Show vector index syntax and errors");
174 println!(".quit/.exit / \\q Exit REPL");
175 println!(".tables / \\dt List tables");
176 println!(".schema / \\d <tbl> Show table schema and constraints");
177 println!(".explain <sql> Show execution plan");
178 println!(".sync status Show sync connection info");
179 println!(".sync push Push local changes to server");
180 println!(".sync pull Pull remote changes from server");
181 println!(".sync reconnect Reconnect to NATS");
182 println!(".sync direction <t> <d> Set table sync direction (Push|Pull|Both|None)");
183 println!(
184 ".sync policy <t> <p> Set table conflict policy (InsertIfNotExists|ServerWins|EdgeWins|LatestWins)"
185 );
186 println!(".sync policy default <p> Set default conflict policy");
187 println!(".sync auto [on|off] Toggle auto-sync after DML");
188 }
189 ".tables" | "\\dt" => {
190 for t in db.table_names() {
191 println!("{t}");
192 }
193 }
194 ".schema" | "\\d" => {
195 if rest.is_empty() {
196 eprintln!("Usage: .schema <table> or \\d <table>");
197 } else if let Some(meta) = db.table_meta(rest) {
198 print_table_meta(rest, &meta);
199 } else {
200 eprintln!("Table not found: {rest}");
201 }
202 }
203 ".explain" => {
204 if rest.is_empty() {
205 eprintln!("Usage: .explain <sql>");
206 } else {
207 match db.explain(rest) {
208 Ok(plan) => println!("{}", plan),
209 Err(e) => {
210 if is_fatal_cli_error(&e) {
211 eprintln!("Error: {}", e);
212 } else {
213 println!("Error: {}", e);
214 }
215 }
216 }
217 }
218 }
219 ".sync" | "\\sync" => {
220 println!(
221 "{}",
222 handle_sync_command(sync_client, rt, rest, sync_plugin)
223 );
224 }
225 _ => println!("Unknown command: {}. Type \\? for help.", cmd),
226 }
227
228 true
229}
230
231fn handle_sync_command(
232 sync_client: Option<&SyncClient>,
233 rt: Option<&tokio::runtime::Runtime>,
234 args: &str,
235 sync_plugin: Option<&SyncPlugin>,
236) -> String {
237 run_sync_command(sync_client, rt, args, sync_plugin).message
238}
239
240fn run_sync_command(
241 sync_client: Option<&SyncClient>,
242 rt: Option<&tokio::runtime::Runtime>,
243 args: &str,
244 sync_plugin: Option<&SyncPlugin>,
245) -> SyncCommandOutcome {
246 let (Some(client), Some(rt)) = (sync_client, rt) else {
247 return SyncCommandOutcome {
248 message: "Sync not configured. Start with --tenant-id to enable.".to_string(),
249 ok: true,
250 };
251 };
252
253 let parts: Vec<&str> = args.split_whitespace().collect();
254 let sub = parts.first().copied().unwrap_or("status");
255
256 match sub {
257 "status" => {
258 let connected = rt.block_on(client.ensure_connected()).is_ok();
259 let status = if connected {
260 "connected"
261 } else {
262 "unreachable"
263 };
264 let base = format!(
265 "Sync: tenant={}, url={}\nNATS: {status}\nDatabase LSN: {}\nPush watermark: LSN {}\nPull watermark: LSN {}",
266 client.tenant_id(),
267 client.nats_url(),
268 client.db().current_lsn(),
269 client.push_watermark(),
270 client.pull_watermark()
271 );
272 let render = contextdb_engine::cli_render::render_sync_status(client.db());
273 SyncCommandOutcome {
274 message: format!("{base}\n{render}"),
275 ok: true,
276 }
277 }
278 "push" => match rt.block_on(client.push()) {
279 Ok(result) => {
280 let mut msg = format!(
281 "Pushed: {} applied, {} skipped, {} conflicts",
282 result.applied_rows,
283 result.skipped_rows,
284 result.conflicts.len()
285 );
286 for conflict in &result.conflicts {
287 if let Some(reason) = &conflict.reason {
288 msg.push_str(&format!("\n conflict: {}", reason));
289 }
290 }
291 SyncCommandOutcome {
292 message: msg,
293 ok: true,
294 }
295 }
296 Err(e) => SyncCommandOutcome {
297 message: format!("Push failed: {e}"),
298 ok: false,
299 },
300 },
301 "pull" => match rt.block_on(client.pull_default()) {
302 Ok(result) => {
303 let mut msg = format!(
304 "Pulled: {} applied, {} skipped, {} conflicts",
305 result.applied_rows,
306 result.skipped_rows,
307 result.conflicts.len()
308 );
309 for conflict in &result.conflicts {
310 if let Some(reason) = &conflict.reason {
311 msg.push_str(&format!("\n conflict: {}", reason));
312 }
313 }
314 SyncCommandOutcome {
315 message: msg,
316 ok: true,
317 }
318 }
319 Err(e) => SyncCommandOutcome {
320 message: format!("Pull failed: {e}"),
321 ok: false,
322 },
323 },
324 "reconnect" => {
325 rt.block_on(client.reconnect());
326 let connected = rt.block_on(client.is_connected());
327 if connected {
328 SyncCommandOutcome {
329 message: "Reconnected to NATS".to_string(),
330 ok: true,
331 }
332 } else {
333 SyncCommandOutcome {
334 message: "Reconnection failed — NATS unreachable".to_string(),
335 ok: false,
336 }
337 }
338 }
339 "direction" => {
340 if parts.len() != 3 {
341 return SyncCommandOutcome {
342 message: "Usage: .sync direction <table> <Push|Pull|Both|None>".to_string(),
343 ok: true,
344 };
345 }
346 let table = parts[1];
347 let dir = match parts[2] {
348 "Push" | "push" => SyncDirection::Push,
349 "Pull" | "pull" => SyncDirection::Pull,
350 "Both" | "both" => SyncDirection::Both,
351 "None" | "none" => SyncDirection::None,
352 other => {
353 return SyncCommandOutcome {
354 message: format!("Unknown direction: {other}. Use: Push, Pull, Both, None"),
355 ok: true,
356 };
357 }
358 };
359 client.set_table_direction(table, dir);
360 SyncCommandOutcome {
361 message: format!("{table} -> {dir:?}"),
362 ok: true,
363 }
364 }
365 "policy" => {
366 if parts.len() != 3 {
367 return SyncCommandOutcome {
368 message: "Usage: .sync policy <table> <InsertIfNotExists|ServerWins|EdgeWins|LatestWins>\n .sync policy default <policy>".to_string(),
369 ok: true,
370 };
371 }
372 let policy = match parts[2] {
373 "InsertIfNotExists" => ConflictPolicy::InsertIfNotExists,
374 "ServerWins" => ConflictPolicy::ServerWins,
375 "EdgeWins" => ConflictPolicy::EdgeWins,
376 "LatestWins" => ConflictPolicy::LatestWins,
377 other => {
378 return SyncCommandOutcome {
379 message: format!(
380 "Unknown policy: {other}. Use: InsertIfNotExists, ServerWins, EdgeWins, LatestWins"
381 ),
382 ok: true,
383 };
384 }
385 };
386 if parts[1] == "default" {
387 client.set_default_conflict_policy(policy);
388 SyncCommandOutcome {
389 message: format!("Default conflict policy -> {policy:?}"),
390 ok: true,
391 }
392 } else {
393 client.set_conflict_policy(parts[1], policy);
394 SyncCommandOutcome {
395 message: format!("{} -> {policy:?}", parts[1]),
396 ok: true,
397 }
398 }
399 }
400 "auto" => {
401 let Some(plugin) = sync_plugin else {
402 return SyncCommandOutcome {
403 message: "Auto-sync not available (no sync plugin)".to_string(),
404 ok: true,
405 };
406 };
407 let toggle = parts.get(1).copied().unwrap_or("");
408 match toggle {
409 "on" => {
410 plugin.set_auto(true);
411 SyncCommandOutcome {
412 message: "Auto-sync enabled".to_string(),
413 ok: true,
414 }
415 }
416 "off" => {
417 plugin.set_auto(false);
418 SyncCommandOutcome {
419 message: "Auto-sync disabled".to_string(),
420 ok: true,
421 }
422 }
423 "" => {
424 let state = if plugin.is_auto() { "on" } else { "off" };
425 SyncCommandOutcome {
426 message: format!("Auto-sync: {state}"),
427 ok: true,
428 }
429 }
430 other => SyncCommandOutcome {
431 message: format!("Unknown auto-sync option: {other}. Use: on, off"),
432 ok: true,
433 },
434 }
435 }
436 _ => SyncCommandOutcome {
437 message: format!(
438 "Unknown sync command: {sub}. Try: status, push, pull, reconnect, direction, policy, auto"
439 ),
440 ok: true,
441 },
442 }
443}
444
445fn print_table_meta(table: &str, meta: &TableMeta) {
446 print!(
447 "{}",
448 contextdb_engine::cli_render::render_table_meta(table, meta)
449 );
450}
451
452fn execute_sql(db: &Database, sql: &str, script_line: Option<usize>) -> bool {
454 match db.execute(sql, &HashMap::new()) {
455 Ok(result) => {
456 if result.columns.is_empty() {
457 println!("ok (rows_affected={})", result.rows_affected);
458 } else {
459 let show_empty_headers = sql
460 .trim()
461 .trim_end_matches(';')
462 .eq_ignore_ascii_case("SHOW VECTOR_INDEXES");
463 let formatted = if show_empty_headers {
464 format_query_result_with_empty_headers(&result, true)
465 } else {
466 format_query_result(&result)
467 };
468 println!("{formatted}");
469 }
470 true
471 }
472 Err(e) => {
473 let rendered = if matches!(e, Error::ParseError(_)) {
474 if let Some(line) = script_line {
475 format!("line {line}: {e}")
476 } else {
477 e.to_string()
478 }
479 } else {
480 e.to_string()
481 }
482 .replace('\n', " ");
483 if is_fatal_cli_error(&e) {
484 eprintln!("Error: {rendered}");
485 false
486 } else {
487 println!("Error: {rendered}");
488 true
489 }
490 }
491 }
492}
493
494pub fn is_fatal_cli_error_public(error: &Error) -> bool {
495 is_fatal_cli_error(error)
496}
497
498fn is_fatal_cli_error(error: &Error) -> bool {
499 match error {
500 Error::ParseError(_)
501 | Error::TableNotFound(_)
502 | Error::NotFound(_)
503 | Error::BfsDepthExceeded(_)
504 | Error::RecursiveCteNotSupported
505 | Error::WindowFunctionNotSupported
506 | Error::FullTextSearchNotSupported
507 | Error::VectorIndexDimensionMismatch { .. }
508 | Error::UnknownVectorIndex { .. }
509 | Error::StoreCorrupted { .. }
510 | Error::LegacyVectorStoreDetected { .. } => true,
511 Error::MemoryBudgetExceeded { operation, .. } => operation.contains('@'),
512 _ => false,
513 }
514}
515
#[cfg(test)]
mod tests {
    use super::*;
    use contextdb_engine::sync_types::{ConflictPolicy, SyncDirection};
    use contextdb_parser::{Statement, parse};

    /// Builds a sync client pointing at a port where no server is expected.
    /// The commands exercised with it below do not require a live connection.
    fn offline_client(rt: &tokio::runtime::Runtime, tenant: &str) -> SyncClient {
        let db = Arc::new(Database::open_memory());
        rt.block_on(async { SyncClient::new(db, "nats://localhost:19999", tenant) })
    }

    /// `\dt` is accepted and keeps the REPL running.
    #[test]
    fn test_backslash_dt() {
        let db = Database::open_memory();
        db.execute("CREATE TABLE t (id UUID PRIMARY KEY)", &HashMap::new())
            .unwrap();
        assert!(handle_meta_command(&db, None, None, "\\dt", None));
    }

    /// Listing commands keep the REPL running; `.quit` ends it.
    #[test]
    fn b1_existing_dt_works_with_new_signature() {
        let db = Database::open_memory();
        db.execute("CREATE TABLE t (id UUID PRIMARY KEY)", &HashMap::new())
            .unwrap();
        assert!(handle_meta_command(&db, None, None, "\\dt", None));
        assert!(handle_meta_command(&db, None, None, ".tables", None));
        assert!(!handle_meta_command(&db, None, None, ".quit", None));
    }

    /// Without a client and runtime every subcommand reports "not configured".
    #[test]
    fn b2_sync_not_configured_message() {
        for subcmd in [
            "status",
            "push",
            "pull",
            "reconnect",
            "direction t Push",
            "policy t ServerWins",
        ] {
            let result = handle_sync_command(None, None, subcmd, None);
            assert!(
                result.contains("Sync not configured"),
                "subcmd '{}' should return 'Sync not configured', got: {}",
                subcmd,
                result
            );
        }
    }

    /// Each accepted direction spelling is parsed to the matching variant.
    #[test]
    fn b3_sync_direction_parsing() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        let client = offline_client(&rt, "b3-test");

        for (table, dir_arg, expected) in [
            ("observations", "Push", SyncDirection::Push),
            ("patterns", "pull", SyncDirection::Pull),
            ("decisions", "Both", SyncDirection::Both),
            ("scratch", "None", SyncDirection::None),
        ] {
            let args = format!("direction {table} {dir_arg}");
            let result = handle_sync_command(Some(&client), Some(&rt), &args, None);
            assert!(
                result.contains(table),
                "direction command for '{}' should contain table name, got: {}",
                table,
                result
            );
            assert!(
                result.contains(&format!("{expected:?}")),
                "direction '{}' should be reported as {:?}, got: {}",
                dir_arg,
                expected,
                result
            );
            assert!(
                !result.contains("Unknown") && !result.contains("Usage"),
                "direction '{}' should be accepted, got: {}",
                dir_arg,
                result
            );
        }
    }

    /// Each policy name is parsed to the matching variant, per table and as default.
    #[test]
    fn b4_sync_policy_parsing() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        let client = offline_client(&rt, "b4-test");

        for (name, expected) in [
            ("InsertIfNotExists", ConflictPolicy::InsertIfNotExists),
            ("ServerWins", ConflictPolicy::ServerWins),
            ("EdgeWins", ConflictPolicy::EdgeWins),
            ("LatestWins", ConflictPolicy::LatestWins),
        ] {
            let args = format!("policy obs {name}");
            let result = handle_sync_command(Some(&client), Some(&rt), &args, None);
            assert!(
                result.contains("obs") && result.contains(&format!("{expected:?}")),
                "policy '{}' should be reported for table 'obs' as {:?}, got: {}",
                name,
                expected,
                result
            );
            assert!(
                !result.contains("Unknown") && !result.contains("Usage"),
                "policy '{}' should be accepted, got: {}",
                name,
                result
            );
        }

        let result =
            handle_sync_command(Some(&client), Some(&rt), "policy default ServerWins", None);
        assert!(
            result.contains("Default") || result.contains("default"),
            "default policy command should reference 'default', got: {}",
            result
        );
        assert!(
            result.contains(&format!("{:?}", ConflictPolicy::ServerWins)),
            "default policy command should report the policy, got: {}",
            result
        );
    }

    /// Malformed sync commands are answered with a usage or "unknown" hint.
    #[test]
    fn b5_sync_invalid_args() {
        let rt = tokio::runtime::Runtime::new().unwrap();
        let client = offline_client(&rt, "b5-test");

        for bad_input in [
            "bogus",
            "direction",
            "direction table_only",
            "direction t InvalidDir",
            "policy",
            "policy table_only",
            "policy t InvalidPolicy",
        ] {
            let result = handle_sync_command(Some(&client), Some(&rt), bad_input, None);
            assert!(
                !result.contains("not implemented"),
                "bad input '{}' should not return 'not implemented', got: {}",
                bad_input,
                result
            );
            assert!(
                result.contains("Usage") || result.contains("Unknown"),
                "bad input '{}' should produce a usage or unknown-command hint, got: {}",
                bad_input,
                result
            );
        }
    }

    /// The rendered schema of a state-machine table parses back as CREATE TABLE.
    #[test]
    fn rt2_repl_schema_display_round_trip_parse() {
        let db = Database::open_memory();
        db.execute(
            "CREATE TABLE repl_rt_sm (id UUID PRIMARY KEY, status TEXT) STATE MACHINE (status: pending -> [done])",
            &HashMap::new(),
        )
        .unwrap();

        let meta = db.table_meta("repl_rt_sm").expect("table meta");
        let rendered = contextdb_engine::cli_render::render_table_meta("repl_rt_sm", &meta);
        assert!(matches!(parse(&rendered), Ok(Statement::CreateTable(_))));
    }
}