Skip to main content

ai_memory/cli/
recall.rs

1// Copyright 2026 AlphaOne LLC
2// SPDX-License-Identifier: Apache-2.0
3
4//! `cmd_recall` migration. See `cli::store` for the design pattern.
5//!
6//! W6 (v0.6.3) — embedder construction was unified into
7//! [`crate::daemon_runtime::build_embedder`]. Both `serve()` and this
8//! handler now call the same builder, killing the per-call-site
9//! duplication that the original W5b note flagged. The TestHelper that
10//! used to live here (`build_embedder_for_recall`) is gone.
11
12use crate::cli::CliOutput;
13use crate::cli::helpers::{human_age, id_short};
14use crate::config::AppConfig;
15use crate::{color, daemon_runtime, db, embeddings, hnsw, reranker, validate};
16use anyhow::Result;
17use clap::Args;
18use std::path::Path;
19
20/// Clap-derived arg shape for the `recall` subcommand. Definition moved
21/// from `main.rs` verbatim in W5b — fields and attrs unchanged.
22#[derive(Args)]
23pub struct RecallArgs {
24    #[arg(allow_hyphen_values = true)]
25    pub context: String,
26    #[arg(long, short)]
27    pub namespace: Option<String>,
28    #[arg(long, default_value_t = 10)]
29    pub limit: usize,
30    #[arg(long)]
31    pub tags: Option<String>,
32    #[arg(long)]
33    pub since: Option<String>,
34    #[arg(long)]
35    pub until: Option<String>,
36    /// Feature tier for recall: keyword, semantic, smart, autonomous
37    #[arg(long, short = 'T')]
38    pub tier: Option<String>,
39    /// Task 1.5: querying agent's namespace position. Enables scope-based
40    /// visibility filtering (private/team/unit/org/collective).
41    #[arg(long)]
42    pub as_agent: Option<String>,
43    /// Task 1.11: context-budget-aware recall. Return the top-ranked
44    /// memories whose cumulative estimated tokens fit within N. Omit
45    /// for unlimited (limit-based only).
46    #[arg(long)]
47    pub budget_tokens: Option<usize>,
48    /// v0.6.0.0 contextual recall. Comma-separated list of recent
49    /// conversation tokens used to bias the query embedding at 70/30
50    /// (primary/context). Shifts the recall towards memories that
51    /// match both the explicit query and the conversation's nearby
52    /// topics.
53    #[arg(long, value_delimiter = ',')]
54    pub context_tokens: Option<Vec<String>>,
55}
56
57/// `recall` handler. Mirrors `cmd_recall` from the pre-W5b `main.rs`
58/// verbatim except every emit routes through `out.stdout` / `out.stderr`
59/// instead of `println!` / `eprintln!`. The embedder is built via the
60/// shared [`crate::daemon_runtime::build_embedder`] helper so the offline
61/// recall path and the HTTP daemon use identical construction logic.
62#[allow(clippy::too_many_lines)]
63pub fn run(
64    db_path: &Path,
65    args: &RecallArgs,
66    json_out: bool,
67    app_config: &AppConfig,
68    out: &mut CliOutput<'_>,
69) -> Result<()> {
70    // #151: validate --as-agent namespace
71    if let Some(ref a) = args.as_agent {
72        validate::validate_namespace(a)?;
73    }
74    let conn = db::open(db_path)?;
75    let _ = db::gc_if_needed(&conn, app_config.effective_archive_on_gc());
76
77    // Resolve feature tier
78    let feature_tier = app_config.effective_tier(args.tier.as_deref());
79    let tier_config = feature_tier.config();
80
81    // Initialize embedder if tier supports it. Use the shared builder so
82    // recall and the HTTP daemon agree on tier→embedder semantics
83    // (embed_url, model selection, error fallback). The shared builder
84    // is async; we drive it on a small inline runtime to keep `run()`
85    // sync. Tier=Keyword short-circuits inside the builder before any
86    // tokio work happens, so the runtime's only cost is the keyword path.
87    let embedder = {
88        // Bridge sync→async: build a single-threaded runtime just for
89        // this call. Cheap on the Keyword path (no tasks spawned), and
90        // safe because `run()` is itself called from `main.rs` which is
91        // already inside `#[tokio::main]` only when invoked through
92        // `daemon_runtime::run` — the inner runtime is never nested
93        // because we use `Handle::try_current()` to detect that case.
94        if let Ok(handle) = tokio::runtime::Handle::try_current() {
95            tokio::task::block_in_place(|| {
96                handle.block_on(daemon_runtime::build_embedder(feature_tier, app_config))
97            })
98        } else {
99            tokio::runtime::Builder::new_current_thread()
100                .enable_all()
101                .build()?
102                .block_on(daemon_runtime::build_embedder(feature_tier, app_config))
103        }
104    };
105    if let Some(ref emb) = embedder {
106        writeln!(
107            out.stderr,
108            "ai-memory: embedder loaded ({})",
109            emb.model_description()
110        )?;
111    } else if tier_config.embedding_model.is_some() {
112        writeln!(
113            out.stderr,
114            "ai-memory: embedder failed to load, falling back to keyword"
115        )?;
116    }
117
118    // Backfill embeddings for memories that don't have them
119    if let Some(ref emb) = embedder
120        && let Ok(unembedded) = db::get_unembedded_ids(&conn)
121        && !unembedded.is_empty()
122    {
123        writeln!(
124            out.stderr,
125            "ai-memory: backfilling {} memories...",
126            unembedded.len()
127        )?;
128        let mut ok = 0usize;
129        for (id, title, content) in &unembedded {
130            let text = format!("{title} {content}");
131            if let Ok(embedding) = emb.embed(&text)
132                && db::set_embedding(&conn, id, &embedding).is_ok()
133            {
134                ok += 1;
135            }
136        }
137        writeln!(
138            out.stderr,
139            "ai-memory: backfilled {}/{}",
140            ok,
141            unembedded.len()
142        )?;
143    }
144
145    // Build HNSW vector index if embedder is available
146    let vector_index = if embedder.is_some() {
147        match db::get_all_embeddings(&conn) {
148            Ok(entries) if !entries.is_empty() => Some(hnsw::VectorIndex::build(entries)),
149            _ => Some(hnsw::VectorIndex::empty()),
150        }
151    } else {
152        None
153    };
154
155    // Initialize cross-encoder reranker for autonomous tier
156    let reranker = if tier_config.cross_encoder {
157        Some(reranker::CrossEncoder::new_neural())
158    } else {
159        None
160    };
161
162    let resolved_ttl = app_config.effective_ttl();
163    let resolved_scoring = app_config.effective_scoring();
164
165    // Perform recall: hybrid if embedder available, keyword otherwise
166    let (results, tokens_used, mode) = if let Some(ref emb) = embedder {
167        match emb.embed(&args.context) {
168            Ok(primary_emb) => {
169                // v0.6.0.0 contextual recall. Fuse the primary query
170                // embedding with an embedding over recent conversation
171                // tokens (caller-supplied) at 70/30. Fusion is done
172                // caller-side so recall_hybrid stays unaware of the bias —
173                // the vector it receives is the final query direction.
174                let query_emb = match args.context_tokens.as_deref() {
175                    Some(tokens) if !tokens.is_empty() => {
176                        let joined = tokens.join(" ");
177                        match emb.embed(&joined) {
178                            Ok(ctx_emb) => embeddings::Embedder::fuse(&primary_emb, &ctx_emb, 0.7),
179                            Err(e) => {
180                                writeln!(
181                                    out.stderr,
182                                    "ai-memory: context_tokens embed failed: {e}, using primary only"
183                                )?;
184                                primary_emb
185                            }
186                        }
187                    }
188                    _ => primary_emb,
189                };
190                let (results, tokens_used) = db::recall_hybrid(
191                    &conn,
192                    &args.context,
193                    &query_emb,
194                    args.namespace.as_deref(),
195                    args.limit.min(50),
196                    args.tags.as_deref(),
197                    args.since.as_deref(),
198                    args.until.as_deref(),
199                    vector_index.as_ref(),
200                    resolved_ttl.short_extend_secs,
201                    resolved_ttl.mid_extend_secs,
202                    args.as_agent.as_deref(),
203                    args.budget_tokens,
204                    &resolved_scoring,
205                )?;
206                if let Some(ref ce) = reranker {
207                    (
208                        ce.rerank(&args.context, results),
209                        tokens_used,
210                        "hybrid+rerank",
211                    )
212                } else {
213                    (results, tokens_used, "hybrid")
214                }
215            }
216            Err(e) => {
217                writeln!(
218                    out.stderr,
219                    "ai-memory: embedding query failed: {e}, falling back to keyword"
220                )?;
221                let (results, tokens_used) = db::recall(
222                    &conn,
223                    &args.context,
224                    args.namespace.as_deref(),
225                    args.limit,
226                    args.tags.as_deref(),
227                    args.since.as_deref(),
228                    args.until.as_deref(),
229                    resolved_ttl.short_extend_secs,
230                    resolved_ttl.mid_extend_secs,
231                    args.as_agent.as_deref(),
232                    args.budget_tokens,
233                )?;
234                (results, tokens_used, "keyword")
235            }
236        }
237    } else {
238        let (results, tokens_used) = db::recall(
239            &conn,
240            &args.context,
241            args.namespace.as_deref(),
242            args.limit,
243            args.tags.as_deref(),
244            args.since.as_deref(),
245            args.until.as_deref(),
246            resolved_ttl.short_extend_secs,
247            resolved_ttl.mid_extend_secs,
248            args.as_agent.as_deref(),
249            args.budget_tokens,
250        )?;
251        (results, tokens_used, "keyword")
252    };
253
254    if json_out {
255        let scored: Vec<serde_json::Value> = results
256            .iter()
257            .map(|(m, s)| {
258                let mut v = serde_json::to_value(m).unwrap_or_default();
259                if let Some(obj) = v.as_object_mut() {
260                    obj.insert(
261                        "score".to_string(),
262                        serde_json::json!((s * 1000.0).round() / 1000.0),
263                    );
264                }
265                v
266            })
267            .collect();
268        let mut body = serde_json::json!({
269            "memories": scored,
270            "count": results.len(),
271            "mode": mode,
272            "tokens_used": tokens_used,
273        });
274        if let Some(b) = args.budget_tokens {
275            body["budget_tokens"] = serde_json::json!(b);
276        }
277        writeln!(out.stdout, "{}", serde_json::to_string(&body)?)?;
278        return Ok(());
279    }
280    if results.is_empty() {
281        writeln!(out.stderr, "no memories found for: {}", args.context)?;
282        return Ok(());
283    }
284    for (mem, score) in &results {
285        let age = human_age(&mem.updated_at);
286        let config = if mem.confidence < 1.0 {
287            format!(" conf={:.0}%", mem.confidence * 100.0)
288        } else {
289            String::new()
290        };
291        writeln!(
292            out.stdout,
293            "[{}] {} {} score={:.2} (ns={}, {}x, {}{})",
294            color::tier_color(
295                mem.tier.as_str(),
296                &format!("{}/{}", mem.tier, id_short(&mem.id))
297            ),
298            color::bold(&mem.title),
299            color::priority_bar(mem.priority),
300            score,
301            color::cyan(&mem.namespace),
302            mem.access_count,
303            color::dim(&age),
304            config
305        )?;
306        let preview: String = mem.content.chars().take(200).collect();
307        writeln!(out.stdout, "  {}\n", color::dim(&preview))?;
308    }
309    writeln!(
310        out.stdout,
311        "{} memory(ies) recalled [{}]",
312        results.len(),
313        mode
314    )?;
315    Ok(())
316}
317
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cli::test_utils::{TestEnv, seed_memory};
    use crate::config::FeatureTier;

    /// Baseline args: query "needle" on the keyword tier with no filters,
    /// so no embedder is ever built.
    fn default_args() -> RecallArgs {
        RecallArgs {
            context: "needle".to_string(),
            namespace: None,
            limit: 10,
            tags: None,
            since: None,
            until: None,
            tier: Some("keyword".to_string()),
            as_agent: None,
            budget_tokens: None,
            context_tokens: None,
        }
    }

    #[test]
    fn test_recall_keyword_tier_no_embedder() {
        // Keyword tier => no embedder; the keyword branch must run
        // happily and find the seeded title.
        let mut env = TestEnv::fresh();
        let db = env.db_path.clone();
        seed_memory(&db, "test", "needle title", "haystack content");
        let args = default_args();
        let cfg = AppConfig::default();
        {
            let mut out = env.output();
            run(&db, &args, false, &cfg, &mut out).unwrap();
        }
        let stdout = env.stdout_str();
        assert!(stdout.contains("needle title"), "got: {stdout}");
        assert!(stdout.contains("[keyword]"), "got: {stdout}");
    }

    #[test]
    fn test_recall_keyword_empty_results() {
        // No seeded rows => empty results => stderr emits "no memories
        // found for: ..." and stdout stays empty (text mode).
        let mut env = TestEnv::fresh();
        let db = env.db_path.clone();
        let args = default_args();
        let cfg = AppConfig::default();
        {
            let mut out = env.output();
            run(&db, &args, false, &cfg, &mut out).unwrap();
        }
        assert_eq!(env.stdout_str(), "");
        assert!(
            env.stderr_str().contains("no memories found for: needle"),
            "got: {}",
            env.stderr_str()
        );
    }

    #[test]
    fn test_recall_keyword_with_namespace_filter() {
        let mut env = TestEnv::fresh();
        let db = env.db_path.clone();
        seed_memory(&db, "ns-a", "needle in a", "content a");
        seed_memory(&db, "ns-b", "needle in b", "content b");
        let mut args = default_args();
        args.namespace = Some("ns-a".to_string());
        let cfg = AppConfig::default();
        {
            let mut out = env.output();
            run(&db, &args, true, &cfg, &mut out).unwrap();
        }
        // JSON mode — parse and verify only the ns-a row came back.
        let v: serde_json::Value = serde_json::from_str(env.stdout_str().trim()).unwrap();
        let mems = v["memories"].as_array().unwrap();
        // Without this, an empty result set would pass the loop below
        // vacuously and the namespace filter would go untested.
        assert!(!mems.is_empty(), "expected the ns-a row to be recalled");
        for m in mems {
            assert_eq!(m["namespace"].as_str().unwrap(), "ns-a");
        }
    }

    #[test]
    fn test_recall_keyword_with_tags_filter() {
        // The seeded row carries no tags, so filtering on a tag nobody has
        // must succeed and return zero rows.
        let mut env = TestEnv::fresh();
        let db = env.db_path.clone();
        seed_memory(&db, "test", "needle title", "content");
        let mut args = default_args();
        args.tags = Some("nonexistent".to_string());
        let cfg = AppConfig::default();
        {
            let mut out = env.output();
            run(&db, &args, true, &cfg, &mut out).unwrap();
        }
        let v: serde_json::Value = serde_json::from_str(env.stdout_str().trim()).unwrap();
        // No row has the "nonexistent" tag => 0 results.
        assert_eq!(v["count"].as_u64().unwrap(), 0);
    }

    #[test]
    fn test_recall_keyword_with_since_until_window() {
        let mut env = TestEnv::fresh();
        let db = env.db_path.clone();
        seed_memory(&db, "test", "needle title", "content");
        let mut args = default_args();
        // A date range that excludes the just-now timestamp.
        args.since = Some("1970-01-01T00:00:00Z".to_string());
        args.until = Some("1970-01-02T00:00:00Z".to_string());
        let cfg = AppConfig::default();
        {
            let mut out = env.output();
            run(&db, &args, true, &cfg, &mut out).unwrap();
        }
        let v: serde_json::Value = serde_json::from_str(env.stdout_str().trim()).unwrap();
        assert_eq!(v["count"].as_u64().unwrap(), 0);
    }

    #[test]
    fn test_recall_with_as_agent_scope_filter() {
        // --as-agent must validate as a namespace; passing a real
        // namespace exercises the validation branch and succeeds.
        let mut env = TestEnv::fresh();
        let db = env.db_path.clone();
        seed_memory(&db, "test", "needle title", "content");
        let mut args = default_args();
        args.as_agent = Some("test".to_string());
        let cfg = AppConfig::default();
        {
            let mut out = env.output();
            run(&db, &args, true, &cfg, &mut out).unwrap();
        }
        // No assertion error; JSON shape comes through.
        let v: serde_json::Value = serde_json::from_str(env.stdout_str().trim()).unwrap();
        assert!(v["memories"].is_array());
    }

    #[test]
    fn test_recall_with_budget_tokens_caps_results() {
        // budget_tokens flows through into recall(); the JSON envelope
        // echoes the budget when set. (Only the echo is asserted here, not
        // the capping itself.)
        let mut env = TestEnv::fresh();
        let db = env.db_path.clone();
        seed_memory(&db, "test", "needle one", "content one");
        seed_memory(&db, "test", "needle two", "content two");
        let mut args = default_args();
        args.budget_tokens = Some(64);
        let cfg = AppConfig::default();
        {
            let mut out = env.output();
            run(&db, &args, true, &cfg, &mut out).unwrap();
        }
        let v: serde_json::Value = serde_json::from_str(env.stdout_str().trim()).unwrap();
        assert_eq!(v["budget_tokens"].as_u64().unwrap(), 64);
    }

    #[test]
    fn test_recall_json_output_includes_score_mode_tokens() {
        let mut env = TestEnv::fresh();
        let db = env.db_path.clone();
        seed_memory(&db, "test", "needle title", "haystack content");
        let args = default_args();
        let cfg = AppConfig::default();
        {
            let mut out = env.output();
            run(&db, &args, true, &cfg, &mut out).unwrap();
        }
        let v: serde_json::Value = serde_json::from_str(env.stdout_str().trim()).unwrap();
        assert_eq!(v["mode"].as_str().unwrap(), "keyword");
        assert!(v["tokens_used"].is_number());
        let mems = v["memories"].as_array().unwrap();
        assert!(!mems.is_empty(), "expected at least one match");
        for m in mems {
            assert!(m["score"].is_number());
        }
    }

    #[test]
    fn test_recall_text_output_formats_correctly() {
        let mut env = TestEnv::fresh();
        let db = env.db_path.clone();
        seed_memory(&db, "test-ns", "needle title", "haystack content");
        let args = default_args();
        let cfg = AppConfig::default();
        {
            let mut out = env.output();
            run(&db, &args, false, &cfg, &mut out).unwrap();
        }
        let stdout = env.stdout_str();
        // Header line: tier/short-id, title, score, namespace.
        assert!(stdout.contains("needle title"));
        assert!(stdout.contains("ns="));
        assert!(stdout.contains("score="));
        assert!(stdout.contains("memory(ies) recalled"));
    }

    #[test]
    fn test_recall_invalid_as_agent_namespace_validation_error() {
        let mut env = TestEnv::fresh();
        let db = env.db_path.clone();
        let mut args = default_args();
        // An empty --as-agent must be rejected by validate_namespace.
        args.as_agent = Some(String::new());
        let cfg = AppConfig::default();
        let mut out = env.output();
        let res = run(&db, &args, false, &cfg, &mut out);
        assert!(res.is_err(), "expected validate_namespace to reject");
    }

    #[test]
    fn test_recall_with_context_tokens_fusion() {
        // With tier=keyword, no embedder is built, so the fusion path
        // is skipped entirely and the call falls through the keyword
        // branch. This proves the fall-through path exists when an
        // embedder is absent. The actual fusion path requires a real
        // embedder and is exercised under feature = "test-with-models".
        let mut env = TestEnv::fresh();
        let db = env.db_path.clone();
        seed_memory(&db, "test", "needle title", "content");
        let mut args = default_args();
        args.context_tokens = Some(vec!["recent".to_string(), "talk".to_string()]);
        let cfg = AppConfig::default();
        {
            let mut out = env.output();
            run(&db, &args, true, &cfg, &mut out).unwrap();
        }
        let v: serde_json::Value = serde_json::from_str(env.stdout_str().trim()).unwrap();
        assert_eq!(v["mode"].as_str().unwrap(), "keyword");
    }

    #[test]
    fn test_recall_embedder_failure_falls_back_to_keyword() {
        // Keyword tier through the shared `daemon_runtime::build_embedder`
        // path: no embedder is produced, recall runs in keyword mode and
        // stderr never announces a loaded embedder. (A genuine embedder
        // *load failure* needs a non-keyword tier and is not covered here.)
        let mut env = TestEnv::fresh();
        let db = env.db_path.clone();
        seed_memory(&db, "test", "needle title", "content");
        let args = default_args();
        let cfg = AppConfig::default();
        {
            let mut out = env.output();
            run(&db, &args, true, &cfg, &mut out).unwrap();
        }
        let v: serde_json::Value = serde_json::from_str(env.stdout_str().trim()).unwrap();
        assert_eq!(v["mode"].as_str().unwrap(), "keyword");
        // No embedder messages on stderr in the keyword branch.
        let stderr = env.stderr_str();
        assert!(
            !stderr.contains("embedder loaded"),
            "no embedder should be loaded on keyword tier"
        );
    }

    #[tokio::test]
    async fn test_shared_build_embedder_keyword_returns_none() {
        // W6 — recall now delegates embedder construction to
        // `daemon_runtime::build_embedder`. Smoke-test that the keyword
        // tier short-circuit still yields `None` (no model load attempt,
        // no panic).
        let cfg = AppConfig::default();
        let res = daemon_runtime::build_embedder(FeatureTier::Keyword, &cfg).await;
        assert!(res.is_none(), "keyword tier must not build an embedder");
    }
}