Skip to main content

ai_memory/cli/
shell.rs

1// Copyright 2026 AlphaOne LLC
2// SPDX-License-Identifier: Apache-2.0
3
//! `cmd_shell` REPL migration. The line-handling logic is extracted into
//! `handle_command(parts, conn, out)` so unit tests can drive command
//! parsing/dispatch without spawning a subprocess. The outer stdin loop
//! is intentionally minimal and is **not** covered by unit tests — it
//! reads the process's real stdin via `read_line`, which a buffer-driven
//! test fixture cannot supply.
9
10use crate::cli::CliOutput;
11use crate::cli::helpers::human_age;
12use crate::{color, db, models, validate};
13use anyhow::Result;
14use rusqlite::Connection;
15use std::path::Path;
16
/// Outcome of [`handle_command`]: tells the REPL loop whether to read
/// another line or stop.
///
/// The enum is fieldless, so it is `Copy` and can be compared and passed by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ShellAction {
    /// Keep going: prompt for and read the next line.
    Continue,
    /// Leave the REPL cleanly (`quit`, `exit`, or `q`).
    Quit,
}
26
27/// REPL command dispatcher. Splits its input into a command + tail and
28/// emits all output through `out`. Returns `Quit` on `quit/exit/q`,
29/// `Continue` otherwise.
30#[allow(clippy::too_many_lines)]
31pub fn handle_command(parts: &[&str], conn: &Connection, out: &mut CliOutput<'_>) -> ShellAction {
32    if parts.is_empty() {
33        return ShellAction::Continue;
34    }
35    match parts[0] {
36        "quit" | "exit" | "q" => return ShellAction::Quit,
37        "help" | "h" => {
38            let _ = writeln!(out.stdout, "  recall <context>    — fuzzy recall");
39            let _ = writeln!(out.stdout, "  search <query>      — keyword search");
40            let _ = writeln!(out.stdout, "  list [namespace]    — list memories");
41            let _ = writeln!(out.stdout, "  get <id>            — show memory details");
42            let _ = writeln!(out.stdout, "  stats               — show statistics");
43            let _ = writeln!(out.stdout, "  namespaces          — list namespaces");
44            let _ = writeln!(out.stdout, "  delete <id>         — delete a memory");
45            let _ = writeln!(out.stdout, "  quit                — exit shell");
46        }
47        "recall" | "r" => {
48            let ctx = parts[1..].join(" ");
49            if ctx.is_empty() {
50                let _ = writeln!(out.stderr, "usage: recall <context>");
51                return ShellAction::Continue;
52            }
53            match db::recall(
54                conn,
55                &ctx,
56                None,
57                10,
58                None,
59                None,
60                None,
61                models::SHORT_TTL_EXTEND_SECS,
62                models::MID_TTL_EXTEND_SECS,
63                None,
64                None,
65            ) {
66                Ok((results, _outcome)) => {
67                    for (mem, score) in &results {
68                        let _ = writeln!(
69                            out.stdout,
70                            "  [{}] {} {} score={:.2}",
71                            color::tier_color(mem.tier.as_str(), mem.tier.as_str()),
72                            color::bold(&mem.title),
73                            color::priority_bar(mem.priority),
74                            score
75                        );
76                        let preview: String = mem.content.chars().take(100).collect();
77                        let _ = writeln!(out.stdout, "    {}", color::dim(&preview));
78                    }
79                    let _ = writeln!(out.stdout, "  {} result(s)", results.len());
80                }
81                Err(e) => {
82                    let _ = writeln!(out.stderr, "error: {e}");
83                }
84            }
85        }
86        "search" | "s" => {
87            let q = parts[1..].join(" ");
88            if q.is_empty() {
89                let _ = writeln!(out.stderr, "usage: search <query>");
90                return ShellAction::Continue;
91            }
92            match db::search(conn, &q, None, None, 20, None, None, None, None, None, None) {
93                Ok(results) => {
94                    for mem in &results {
95                        let _ = writeln!(
96                            out.stdout,
97                            "  [{}] {} (p={})",
98                            color::tier_color(mem.tier.as_str(), mem.tier.as_str()),
99                            mem.title,
100                            mem.priority
101                        );
102                    }
103                    let _ = writeln!(out.stdout, "  {} result(s)", results.len());
104                }
105                Err(e) => {
106                    let _ = writeln!(out.stderr, "error: {e}");
107                }
108            }
109        }
110        "list" | "ls" => {
111            let ns = parts.get(1).copied();
112            match db::list(conn, ns, None, 20, 0, None, None, None, None, None) {
113                Ok(results) => {
114                    for mem in &results {
115                        let age = human_age(&mem.updated_at);
116                        let _ = writeln!(
117                            out.stdout,
118                            "  [{}] {} (ns={}, {})",
119                            color::tier_color(mem.tier.as_str(), mem.tier.as_str()),
120                            mem.title,
121                            mem.namespace,
122                            color::dim(&age)
123                        );
124                    }
125                    let _ = writeln!(out.stdout, "  {} memory(ies)", results.len());
126                }
127                Err(e) => {
128                    let _ = writeln!(out.stderr, "error: {e}");
129                }
130            }
131        }
132        "get" => {
133            let id = parts.get(1).copied().unwrap_or("");
134            if id.is_empty() {
135                let _ = writeln!(out.stderr, "usage: get <id>");
136                return ShellAction::Continue;
137            }
138            if let Err(e) = validate::validate_id(id) {
139                let _ = writeln!(out.stderr, "invalid id: {e}");
140                return ShellAction::Continue;
141            }
142            match db::get(conn, id) {
143                Ok(Some(mem)) => {
144                    let _ = writeln!(
145                        out.stdout,
146                        "{}",
147                        serde_json::to_string_pretty(&mem).unwrap_or_default()
148                    );
149                }
150                Ok(None) => {
151                    let _ = writeln!(out.stderr, "not found");
152                }
153                Err(e) => {
154                    let _ = writeln!(out.stderr, "error: {e}");
155                }
156            }
157        }
158        "stats" => match db::stats(conn, Path::new(":memory:")) {
159            Ok(s) => {
160                let _ = writeln!(out.stdout, "  total: {}, links: {}", s.total, s.links_count);
161                for t in &s.by_tier {
162                    let _ = writeln!(
163                        out.stdout,
164                        "    {}: {}",
165                        color::tier_color(&t.tier, &t.tier),
166                        t.count
167                    );
168                }
169            }
170            Err(e) => {
171                let _ = writeln!(out.stderr, "error: {e}");
172            }
173        },
174        "namespaces" | "ns" => match db::list_namespaces(conn) {
175            Ok(ns) => {
176                for n in &ns {
177                    let _ = writeln!(out.stdout, "  {}: {}", color::cyan(&n.namespace), n.count);
178                }
179            }
180            Err(e) => {
181                let _ = writeln!(out.stderr, "error: {e}");
182            }
183        },
184        "delete" | "del" | "rm" => {
185            let id = parts.get(1).copied().unwrap_or("");
186            if id.is_empty() {
187                let _ = writeln!(out.stderr, "usage: delete <id>");
188                return ShellAction::Continue;
189            }
190            if let Err(e) = validate::validate_id(id) {
191                let _ = writeln!(out.stderr, "invalid id: {e}");
192                return ShellAction::Continue;
193            }
194            match db::delete(conn, id) {
195                Ok(true) => {
196                    let _ = writeln!(out.stdout, "  deleted");
197                }
198                Ok(false) => {
199                    let _ = writeln!(out.stderr, "  not found");
200                }
201                Err(e) => {
202                    let _ = writeln!(out.stderr, "error: {e}");
203                }
204            }
205        }
206        unknown => {
207            let _ = writeln!(
208                out.stderr,
209                "unknown command: {unknown}. Type 'help' for commands."
210            );
211        }
212    }
213    ShellAction::Continue
214}
215
216/// `shell` handler. Outer stdin loop. Not unit-tested — the blocking
217/// `read_line` call would deadlock a `Vec<u8>` test fixture; the line
218/// handler logic lives in `handle_command`, which is exhaustively tested.
219pub fn run(db_path: &Path) -> Result<()> {
220    let conn = db::open(db_path)?;
221    println!(
222        "{}",
223        color::bold("ai-memory shell — type 'help' for commands, 'quit' to exit")
224    );
225    let stdin = std::io::stdin();
226    let stdout_handle = std::io::stdout();
227    let stderr_handle = std::io::stderr();
228    loop {
229        eprint!("{} ", color::cyan("memory>"));
230        let mut line = String::new();
231        if stdin.read_line(&mut line)? == 0 {
232            break;
233        }
234        let parts: Vec<&str> = line.split_whitespace().collect();
235        let mut so = stdout_handle.lock();
236        let mut se = stderr_handle.lock();
237        let mut out = CliOutput::from_std(&mut so, &mut se);
238        let action = handle_command(&parts, &conn, &mut out);
239        drop(out);
240        if action == ShellAction::Quit {
241            break;
242        }
243    }
244    println!("goodbye");
245    Ok(())
246}
247
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cli::test_utils::{TestEnv, seed_memory};

    fn fresh_conn(env: &TestEnv) -> Connection {
        // Seed at least once so the schema is materialised, then reopen.
        seed_memory(&env.db_path, "shell-ns", "seed", "seed-content");
        db::open(&env.db_path).unwrap()
    }

    /// Runs one command through `handle_command` with in-memory sinks and
    /// returns `(action, stdout, stderr)`.
    fn exec(conn: &Connection, parts: &[&str]) -> (ShellAction, String, String) {
        let mut stdout: Vec<u8> = Vec::new();
        let mut stderr: Vec<u8> = Vec::new();
        let action = {
            let mut out = CliOutput::from_std(&mut stdout, &mut stderr);
            handle_command(parts, conn, &mut out)
        };
        (
            action,
            String::from_utf8(stdout).unwrap(),
            String::from_utf8(stderr).unwrap(),
        )
    }

    #[test]
    fn test_shell_quit_command_returns_quit() {
        let env = TestEnv::fresh();
        let conn = fresh_conn(&env);
        for cmd in ["quit", "exit", "q"] {
            let (action, _, _) = exec(&conn, &[cmd]);
            assert_eq!(action, ShellAction::Quit, "command: {cmd}");
        }
    }

    #[test]
    fn test_shell_recall_runs_recall() {
        let env = TestEnv::fresh();
        let conn = fresh_conn(&env);
        let (action, stdout, _) = exec(&conn, &["recall", "seed"]);
        assert_eq!(action, ShellAction::Continue);
        assert!(stdout.contains("result(s)"));
    }

    #[test]
    fn test_shell_recall_empty_args_writes_usage() {
        let env = TestEnv::fresh();
        let conn = fresh_conn(&env);
        let (action, stdout, stderr) = exec(&conn, &["recall"]);
        assert_eq!(action, ShellAction::Continue);
        assert!(stderr.contains("usage: recall"));
        assert!(stdout.is_empty(), "stdout: {stdout}");
    }

    #[test]
    fn test_shell_search_runs_search() {
        let env = TestEnv::fresh();
        let conn = fresh_conn(&env);
        let (action, stdout, _) = exec(&conn, &["search", "seed"]);
        assert_eq!(action, ShellAction::Continue);
        assert!(stdout.contains("result(s)"));
    }

    #[test]
    fn test_shell_search_empty_args_writes_usage() {
        let env = TestEnv::fresh();
        let conn = fresh_conn(&env);
        let (action, stdout, stderr) = exec(&conn, &["search"]);
        assert_eq!(action, ShellAction::Continue);
        assert!(stderr.contains("usage: search"));
        assert!(stdout.is_empty(), "stdout: {stdout}");
    }

    #[test]
    fn test_shell_help_writes_help_text() {
        let env = TestEnv::fresh();
        let conn = fresh_conn(&env);
        for cmd in ["help", "h"] {
            let (action, stdout, _) = exec(&conn, &[cmd]);
            assert_eq!(action, ShellAction::Continue);
            assert!(stdout.contains("recall"), "command: {cmd}");
            assert!(stdout.contains("search"), "command: {cmd}");
            assert!(stdout.contains("quit"), "command: {cmd}");
        }
    }

    #[test]
    fn test_shell_unknown_command_writes_error() {
        let env = TestEnv::fresh();
        let conn = fresh_conn(&env);
        let (action, stdout, stderr) = exec(&conn, &["frobnicate"]);
        assert_eq!(action, ShellAction::Continue);
        assert!(stderr.contains("unknown command"));
        assert!(stdout.is_empty(), "stdout: {stdout}");
    }

    #[test]
    fn test_shell_empty_parts_continues() {
        let env = TestEnv::fresh();
        let conn = fresh_conn(&env);
        let (action, stdout, stderr) = exec(&conn, &[]);
        assert_eq!(action, ShellAction::Continue);
        assert!(stdout.is_empty());
        assert!(stderr.is_empty());
    }

    #[test]
    fn test_shell_list_runs_list() {
        let env = TestEnv::fresh();
        let conn = fresh_conn(&env);
        let (action, stdout, _) = exec(&conn, &["list"]);
        assert_eq!(action, ShellAction::Continue);
        assert!(stdout.contains("memory(ies)"));
    }

    #[test]
    fn test_shell_list_with_namespace_shows_namespace() {
        let env = TestEnv::fresh();
        let conn = fresh_conn(&env);
        let (action, stdout, _) = exec(&conn, &["list", "shell-ns"]);
        assert_eq!(action, ShellAction::Continue);
        assert!(stdout.contains("ns=shell-ns"), "stdout: {stdout}");
        assert!(stdout.contains("memory(ies)"));
    }

    #[test]
    fn test_shell_namespaces_runs() {
        let env = TestEnv::fresh();
        let conn = fresh_conn(&env);
        let (action, stdout, _) = exec(&conn, &["namespaces"]);
        assert_eq!(action, ShellAction::Continue);
        assert!(stdout.contains("shell-ns"));
    }

    #[test]
    fn test_shell_aliases_dispatch() {
        let env = TestEnv::fresh();
        let conn = fresh_conn(&env);
        let (_, stdout, _) = exec(&conn, &["r", "seed"]);
        assert!(stdout.contains("result(s)"), "r: {stdout}");
        let (_, stdout, _) = exec(&conn, &["s", "seed"]);
        assert!(stdout.contains("result(s)"), "s: {stdout}");
        let (_, stdout, _) = exec(&conn, &["ls"]);
        assert!(stdout.contains("memory(ies)"), "ls: {stdout}");
        let (_, stdout, _) = exec(&conn, &["ns"]);
        assert!(stdout.contains("shell-ns"), "ns: {stdout}");
    }

    #[test]
    fn test_shell_get_invalid_id_writes_error() {
        let env = TestEnv::fresh();
        let conn = fresh_conn(&env);
        // Trigger "id contains invalid characters" via a control character.
        let (_, _, stderr) = exec(&conn, &["get", "bad\x07id"]);
        assert!(stderr.contains("invalid id"), "stderr: {stderr}");
    }

    #[test]
    fn test_shell_get_missing_arg_writes_usage() {
        let env = TestEnv::fresh();
        let conn = fresh_conn(&env);
        let (_, _, stderr) = exec(&conn, &["get"]);
        assert!(stderr.contains("usage: get"));
    }

    #[test]
    fn test_shell_delete_missing_arg() {
        let env = TestEnv::fresh();
        let conn = fresh_conn(&env);
        let (_, _, stderr) = exec(&conn, &["delete"]);
        assert!(stderr.contains("usage: delete"));
    }

    #[test]
    fn test_shell_delete_invalid_id_writes_error() {
        let env = TestEnv::fresh();
        let conn = fresh_conn(&env);
        for cmd in ["delete", "del", "rm"] {
            let (action, stdout, stderr) = exec(&conn, &[cmd, "bad\x07id"]);
            assert_eq!(action, ShellAction::Continue);
            assert!(stderr.contains("invalid id"), "{cmd}: {stderr}");
            assert!(!stdout.contains("deleted"), "{cmd}: {stdout}");
        }
    }
}