Skip to main content

ai_memory/cli/
recall.rs

1// Copyright 2026 AlphaOne LLC
2// SPDX-License-Identifier: Apache-2.0
3
//! `cmd_recall` migration. See `cli::store` for the design pattern.
//!
//! W6 (v0.6.3) — embedder construction was unified into
//! [`crate::daemon_runtime::build_embedder`]. Both `serve()` and this
//! handler now call the same builder, eliminating the per-call-site
//! duplication that the original W5b note flagged. The test helper that
//! used to live here (`build_embedder_for_recall`) has been removed.
11
12use crate::cli::CliOutput;
13use crate::cli::helpers::{human_age, id_short};
14use crate::config::AppConfig;
15use crate::{color, daemon_runtime, db, embeddings, hnsw, reranker, validate};
16use anyhow::Result;
17use clap::Args;
18use std::path::Path;
19
20/// Clap-derived arg shape for the `recall` subcommand. Definition moved
21/// from `main.rs` verbatim in W5b — fields and attrs unchanged.
22#[derive(Args)]
23pub struct RecallArgs {
24    #[arg(allow_hyphen_values = true)]
25    pub context: String,
26    #[arg(long, short)]
27    pub namespace: Option<String>,
28    #[arg(long, default_value_t = 10)]
29    pub limit: usize,
30    #[arg(long)]
31    pub tags: Option<String>,
32    #[arg(long)]
33    pub since: Option<String>,
34    #[arg(long)]
35    pub until: Option<String>,
36    /// Feature tier for recall: keyword, semantic, smart, autonomous
37    #[arg(long, short = 'T')]
38    pub tier: Option<String>,
39    /// Task 1.5: querying agent's namespace position. Enables scope-based
40    /// visibility filtering (private/team/unit/org/collective).
41    #[arg(long)]
42    pub as_agent: Option<String>,
43    /// Task 1.11: context-budget-aware recall. Return the top-ranked
44    /// memories whose cumulative estimated tokens fit within N. Omit
45    /// for unlimited (limit-based only).
46    #[arg(long)]
47    pub budget_tokens: Option<usize>,
48    /// v0.6.0.0 contextual recall. Comma-separated list of recent
49    /// conversation tokens used to bias the query embedding at 70/30
50    /// (primary/context). Shifts the recall towards memories that
51    /// match both the explicit query and the conversation's nearby
52    /// topics.
53    #[arg(long, value_delimiter = ',')]
54    pub context_tokens: Option<Vec<String>>,
55}
56
57/// `recall` handler. Mirrors `cmd_recall` from the pre-W5b `main.rs`
58/// verbatim except every emit routes through `out.stdout` / `out.stderr`
59/// instead of `println!` / `eprintln!`. The embedder is built via the
60/// shared [`crate::daemon_runtime::build_embedder`] helper so the offline
61/// recall path and the HTTP daemon use identical construction logic.
62#[allow(clippy::too_many_lines)]
63pub fn run(
64    db_path: &Path,
65    args: &RecallArgs,
66    json_out: bool,
67    app_config: &AppConfig,
68    out: &mut CliOutput<'_>,
69) -> Result<()> {
70    // #151: validate --as-agent namespace
71    if let Some(ref a) = args.as_agent {
72        validate::validate_namespace(a)?;
73    }
74    let conn = db::open(db_path)?;
75    let _ = db::gc_if_needed(&conn, app_config.effective_archive_on_gc());
76
77    // Resolve feature tier
78    let feature_tier = app_config.effective_tier(args.tier.as_deref());
79    let tier_config = feature_tier.config();
80
81    // Initialize embedder if tier supports it. Use the shared builder so
82    // recall and the HTTP daemon agree on tier→embedder semantics
83    // (embed_url, model selection, error fallback). The shared builder
84    // is async; we drive it on a small inline runtime to keep `run()`
85    // sync. Tier=Keyword short-circuits inside the builder before any
86    // tokio work happens, so the runtime's only cost is the keyword path.
87    let embedder = {
88        // Bridge sync→async: build a single-threaded runtime just for
89        // this call. Cheap on the Keyword path (no tasks spawned), and
90        // safe because `run()` is itself called from `main.rs` which is
91        // already inside `#[tokio::main]` only when invoked through
92        // `daemon_runtime::run` — the inner runtime is never nested
93        // because we use `Handle::try_current()` to detect that case.
94        if let Ok(handle) = tokio::runtime::Handle::try_current() {
95            tokio::task::block_in_place(|| {
96                handle.block_on(daemon_runtime::build_embedder(feature_tier, app_config))
97            })
98        } else {
99            tokio::runtime::Builder::new_current_thread()
100                .enable_all()
101                .build()?
102                .block_on(daemon_runtime::build_embedder(feature_tier, app_config))
103        }
104    };
105    if let Some(ref emb) = embedder {
106        writeln!(
107            out.stderr,
108            "ai-memory: embedder loaded ({})",
109            emb.model_description()
110        )?;
111    } else if tier_config.embedding_model.is_some() {
112        writeln!(
113            out.stderr,
114            "ai-memory: embedder failed to load, falling back to keyword"
115        )?;
116    }
117
118    // Backfill embeddings for memories that don't have them
119    if let Some(ref emb) = embedder
120        && let Ok(unembedded) = db::get_unembedded_ids(&conn)
121        && !unembedded.is_empty()
122    {
123        writeln!(
124            out.stderr,
125            "ai-memory: backfilling {} memories...",
126            unembedded.len()
127        )?;
128        let mut ok = 0usize;
129        for (id, title, content) in &unembedded {
130            let text = format!("{title} {content}");
131            if let Ok(embedding) = emb.embed(&text)
132                && db::set_embedding(&conn, id, &embedding).is_ok()
133            {
134                ok += 1;
135            }
136        }
137        writeln!(
138            out.stderr,
139            "ai-memory: backfilled {}/{}",
140            ok,
141            unembedded.len()
142        )?;
143    }
144
145    // Build HNSW vector index if embedder is available
146    let vector_index = if embedder.is_some() {
147        match db::get_all_embeddings(&conn) {
148            Ok(entries) if !entries.is_empty() => Some(hnsw::VectorIndex::build(entries)),
149            _ => Some(hnsw::VectorIndex::empty()),
150        }
151    } else {
152        None
153    };
154
155    // Initialize cross-encoder reranker for autonomous tier
156    let reranker = if tier_config.cross_encoder {
157        Some(reranker::CrossEncoder::new_neural())
158    } else {
159        None
160    };
161
162    let resolved_ttl = app_config.effective_ttl();
163    let resolved_scoring = app_config.effective_scoring();
164
165    // Perform recall: hybrid if embedder available, keyword otherwise
166    let (results, outcome, mode) = if let Some(ref emb) = embedder {
167        match emb.embed(&args.context) {
168            Ok(primary_emb) => {
169                // v0.6.0.0 contextual recall. Fuse the primary query
170                // embedding with an embedding over recent conversation
171                // tokens (caller-supplied) at 70/30. Fusion is done
172                // caller-side so recall_hybrid stays unaware of the bias —
173                // the vector it receives is the final query direction.
174                let query_emb = match args.context_tokens.as_deref() {
175                    Some(tokens) if !tokens.is_empty() => {
176                        let joined = tokens.join(" ");
177                        match emb.embed(&joined) {
178                            Ok(ctx_emb) => embeddings::Embedder::fuse(&primary_emb, &ctx_emb, 0.7),
179                            Err(e) => {
180                                writeln!(
181                                    out.stderr,
182                                    "ai-memory: context_tokens embed failed: {e}, using primary only"
183                                )?;
184                                primary_emb
185                            }
186                        }
187                    }
188                    _ => primary_emb,
189                };
190                let (results, outcome) = db::recall_hybrid(
191                    &conn,
192                    &args.context,
193                    &query_emb,
194                    args.namespace.as_deref(),
195                    args.limit.min(50),
196                    args.tags.as_deref(),
197                    args.since.as_deref(),
198                    args.until.as_deref(),
199                    vector_index.as_ref(),
200                    resolved_ttl.short_extend_secs,
201                    resolved_ttl.mid_extend_secs,
202                    args.as_agent.as_deref(),
203                    args.budget_tokens,
204                    &resolved_scoring,
205                )?;
206                if let Some(ref ce) = reranker {
207                    (ce.rerank(&args.context, results), outcome, "hybrid+rerank")
208                } else {
209                    (results, outcome, "hybrid")
210                }
211            }
212            Err(e) => {
213                writeln!(
214                    out.stderr,
215                    "ai-memory: embedding query failed: {e}, falling back to keyword"
216                )?;
217                let (results, outcome) = db::recall(
218                    &conn,
219                    &args.context,
220                    args.namespace.as_deref(),
221                    args.limit,
222                    args.tags.as_deref(),
223                    args.since.as_deref(),
224                    args.until.as_deref(),
225                    resolved_ttl.short_extend_secs,
226                    resolved_ttl.mid_extend_secs,
227                    args.as_agent.as_deref(),
228                    args.budget_tokens,
229                )?;
230                (results, outcome, "keyword")
231            }
232        }
233    } else {
234        let (results, outcome) = db::recall(
235            &conn,
236            &args.context,
237            args.namespace.as_deref(),
238            args.limit,
239            args.tags.as_deref(),
240            args.since.as_deref(),
241            args.until.as_deref(),
242            resolved_ttl.short_extend_secs,
243            resolved_ttl.mid_extend_secs,
244            args.as_agent.as_deref(),
245            args.budget_tokens,
246        )?;
247        (results, outcome, "keyword")
248    };
249
250    if json_out {
251        let scored: Vec<serde_json::Value> = results
252            .iter()
253            .map(|(m, s)| {
254                let mut v = serde_json::to_value(m).unwrap_or_default();
255                if let Some(obj) = v.as_object_mut() {
256                    obj.insert(
257                        "score".to_string(),
258                        serde_json::json!((s * 1000.0).round() / 1000.0),
259                    );
260                }
261                v
262            })
263            .collect();
264        let mut body = serde_json::json!({
265            "memories": scored,
266            "count": results.len(),
267            "mode": mode,
268            "tokens_used": outcome.tokens_used,
269        });
270        if let Some(b) = args.budget_tokens {
271            body["budget_tokens"] = serde_json::json!(b);
272            // Phase P6 (R1) meta block — same shape as MCP / HTTP paths.
273            body["meta"] = serde_json::json!({
274                "budget_tokens_used": outcome.tokens_used,
275                "budget_tokens_remaining": outcome.tokens_remaining.unwrap_or(0),
276                "memories_dropped": outcome.memories_dropped,
277                "budget_overflow": outcome.budget_overflow,
278            });
279        }
280        writeln!(out.stdout, "{}", serde_json::to_string(&body)?)?;
281        return Ok(());
282    }
283    if results.is_empty() {
284        writeln!(out.stderr, "no memories found for: {}", args.context)?;
285        return Ok(());
286    }
287    for (mem, score) in &results {
288        let age = human_age(&mem.updated_at);
289        let config = if mem.confidence < 1.0 {
290            format!(" conf={:.0}%", mem.confidence * 100.0)
291        } else {
292            String::new()
293        };
294        writeln!(
295            out.stdout,
296            "[{}] {} {} score={:.2} (ns={}, {}x, {}{})",
297            color::tier_color(
298                mem.tier.as_str(),
299                &format!("{}/{}", mem.tier, id_short(&mem.id))
300            ),
301            color::bold(&mem.title),
302            color::priority_bar(mem.priority),
303            score,
304            color::cyan(&mem.namespace),
305            mem.access_count,
306            color::dim(&age),
307            config
308        )?;
309        let preview: String = mem.content.chars().take(200).collect();
310        writeln!(out.stdout, "  {}\n", color::dim(&preview))?;
311    }
312    writeln!(
313        out.stdout,
314        "{} memory(ies) recalled [{}]",
315        results.len(),
316        mode
317    )?;
318    Ok(())
319}
320
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cli::test_utils::{TestEnv, seed_memory};
    use crate::config::FeatureTier;

    /// Baseline arguments: query `needle`, keyword tier, no filters.
    /// Pinning the tier to `keyword` keeps every test off the model path.
    fn default_args() -> RecallArgs {
        RecallArgs {
            context: "needle".to_string(),
            namespace: None,
            limit: 10,
            tags: None,
            since: None,
            until: None,
            tier: Some("keyword".to_string()),
            as_agent: None,
            budget_tokens: None,
            context_tokens: None,
        }
    }

    /// Invokes `run` against `env`'s database with a default `AppConfig`,
    /// capturing output in `env`. The output handle is dropped before
    /// returning so the captured buffers can be inspected afterwards.
    fn run_with(env: &mut TestEnv, args: &RecallArgs, json_out: bool) -> Result<()> {
        let db = env.db_path.clone();
        let cfg = AppConfig::default();
        let mut out = env.output();
        run(&db, args, json_out, &cfg, &mut out)
    }

    /// Runs `run` in JSON mode and parses the captured stdout.
    fn run_json(env: &mut TestEnv, args: &RecallArgs) -> serde_json::Value {
        run_with(env, args, true).unwrap();
        serde_json::from_str(env.stdout_str().trim()).unwrap()
    }

    #[test]
    fn test_recall_keyword_tier_no_embedder() {
        // Keyword tier => no embedder; the keyword branch must run
        // happily and find the seeded title.
        let mut env = TestEnv::fresh();
        seed_memory(&env.db_path, "test", "needle title", "haystack content");
        run_with(&mut env, &default_args(), false).unwrap();
        let stdout = env.stdout_str();
        assert!(stdout.contains("needle title"), "got: {stdout}");
        assert!(stdout.contains("[keyword]"), "got: {stdout}");
    }

    #[test]
    fn test_recall_keyword_empty_results() {
        // No seeded rows => empty results => stderr emits "no memories
        // found for: ..." and stdout stays empty (text mode).
        let mut env = TestEnv::fresh();
        run_with(&mut env, &default_args(), false).unwrap();
        assert_eq!(env.stdout_str(), "");
        assert!(
            env.stderr_str().contains("no memories found for: needle"),
            "got: {}",
            env.stderr_str()
        );
    }

    #[test]
    fn test_recall_keyword_with_namespace_filter() {
        let mut env = TestEnv::fresh();
        seed_memory(&env.db_path, "ns-a", "needle in a", "content a");
        seed_memory(&env.db_path, "ns-b", "needle in b", "content b");
        let args = RecallArgs {
            namespace: Some("ns-a".to_string()),
            ..default_args()
        };
        // JSON mode — parse and verify only the ns-a row came back.
        let v = run_json(&mut env, &args);
        let mems = v["memories"].as_array().unwrap();
        // Guard against a vacuous pass: the loop below proves nothing if
        // the filter (wrongly) returned no rows at all.
        assert!(!mems.is_empty(), "expected the ns-a row to match");
        for m in mems {
            assert_eq!(m["namespace"].as_str().unwrap(), "ns-a");
        }
    }

    #[test]
    fn test_recall_keyword_with_tags_filter() {
        // tags filter takes a string; absence of tags on seeded rows
        // means the filter excludes them. Just verify the call shape
        // doesn't error when a tags filter is supplied.
        let mut env = TestEnv::fresh();
        seed_memory(&env.db_path, "test", "needle title", "content");
        let args = RecallArgs {
            tags: Some("nonexistent".to_string()),
            ..default_args()
        };
        let v = run_json(&mut env, &args);
        // No row has the "nonexistent" tag => 0 results.
        assert_eq!(v["count"].as_u64().unwrap(), 0);
    }

    #[test]
    fn test_recall_keyword_with_since_until_window() {
        let mut env = TestEnv::fresh();
        seed_memory(&env.db_path, "test", "needle title", "content");
        // A date range that excludes the just-now timestamp.
        let args = RecallArgs {
            since: Some("1970-01-01T00:00:00Z".to_string()),
            until: Some("1970-01-02T00:00:00Z".to_string()),
            ..default_args()
        };
        let v = run_json(&mut env, &args);
        assert_eq!(v["count"].as_u64().unwrap(), 0);
    }

    #[test]
    fn test_recall_with_as_agent_scope_filter() {
        // --as-agent must validate as a namespace; passing a real
        // namespace exercises the validation branch and succeeds.
        let mut env = TestEnv::fresh();
        seed_memory(&env.db_path, "test", "needle title", "content");
        let args = RecallArgs {
            as_agent: Some("test".to_string()),
            ..default_args()
        };
        // No validation error; JSON shape comes through.
        let v = run_json(&mut env, &args);
        assert!(v["memories"].is_array());
    }

    #[test]
    fn test_recall_with_budget_tokens_caps_results() {
        // budget_tokens flows through into recall(); the JSON envelope
        // echoes the budget and carries the `meta` block when it is set.
        let mut env = TestEnv::fresh();
        seed_memory(&env.db_path, "test", "needle one", "content one");
        seed_memory(&env.db_path, "test", "needle two", "content two");
        let args = RecallArgs {
            budget_tokens: Some(64),
            ..default_args()
        };
        let v = run_json(&mut env, &args);
        assert_eq!(v["budget_tokens"].as_u64().unwrap(), 64);
        assert!(v["meta"].is_object(), "meta block missing: {v}");
    }

    #[test]
    fn test_recall_json_output_includes_score_mode_tokens() {
        let mut env = TestEnv::fresh();
        seed_memory(&env.db_path, "test", "needle title", "haystack content");
        let v = run_json(&mut env, &default_args());
        assert_eq!(v["mode"].as_str().unwrap(), "keyword");
        assert!(v["tokens_used"].is_number());
        let mems = v["memories"].as_array().unwrap();
        assert!(!mems.is_empty(), "expected at least one match");
        for m in mems {
            assert!(m["score"].is_number());
        }
    }

    #[test]
    fn test_recall_text_output_formats_correctly() {
        let mut env = TestEnv::fresh();
        seed_memory(&env.db_path, "test-ns", "needle title", "haystack content");
        run_with(&mut env, &default_args(), false).unwrap();
        let stdout = env.stdout_str();
        // Header line: tier/short-id, title, score, namespace.
        assert!(stdout.contains("needle title"));
        assert!(stdout.contains("ns="));
        assert!(stdout.contains("score="));
        assert!(stdout.contains("memory(ies) recalled"));
    }

    #[test]
    fn test_recall_invalid_as_agent_namespace_validation_error() {
        let mut env = TestEnv::fresh();
        // An empty string is not a valid namespace.
        let args = RecallArgs {
            as_agent: Some(String::new()),
            ..default_args()
        };
        let res = run_with(&mut env, &args, false);
        assert!(res.is_err(), "expected validate_namespace to reject");
    }

    #[test]
    fn test_recall_with_context_tokens_fusion() {
        // With tier=keyword, no embedder is built, so the fusion path
        // is skipped entirely and the call falls through the keyword
        // branch. This proves the fall-through path exists when an
        // embedder is absent. The actual fusion path requires a real
        // embedder and is exercised under feature = "test-with-models".
        let mut env = TestEnv::fresh();
        seed_memory(&env.db_path, "test", "needle title", "content");
        let args = RecallArgs {
            context_tokens: Some(vec!["recent".to_string(), "talk".to_string()]),
            ..default_args()
        };
        let v = run_json(&mut env, &args);
        assert_eq!(v["mode"].as_str().unwrap(), "keyword");
    }

    #[test]
    fn test_recall_embedder_failure_falls_back_to_keyword() {
        // Same shape as the no-embedder test, but focused on stderr:
        // on the keyword tier the shared `build_embedder` yields no
        // embedder, and `run` must stay silent about it — neither the
        // "loaded" nor the "failed to load" message may appear.
        let mut env = TestEnv::fresh();
        seed_memory(&env.db_path, "test", "needle title", "content");
        let v = run_json(&mut env, &default_args());
        assert_eq!(v["mode"].as_str().unwrap(), "keyword");
        // No embedder messages on stderr in the keyword branch.
        let stderr = env.stderr_str();
        assert!(
            !stderr.contains("embedder loaded"),
            "no embedder should be loaded on keyword tier"
        );
        assert!(
            !stderr.contains("embedder failed to load"),
            "keyword tier must not report an embedder load failure"
        );
    }

    #[tokio::test]
    async fn test_shared_build_embedder_keyword_returns_none() {
        // W6 — recall now delegates embedder construction to
        // `daemon_runtime::build_embedder`. Smoke-test that the keyword
        // tier short-circuit still yields `None` (no model load attempt,
        // no panic).
        let cfg = AppConfig::default();
        let res = daemon_runtime::build_embedder(FeatureTier::Keyword, &cfg).await;
        assert!(res.is_none(), "keyword tier must not build an embedder");
    }
}