1use crate::cli::CliOutput;
13use crate::cli::helpers::{human_age, id_short};
14use crate::config::AppConfig;
15use crate::{color, daemon_runtime, db, embeddings, hnsw, reranker, validate};
16use anyhow::Result;
17use clap::Args;
18use std::path::Path;
19
20#[derive(Args)]
23pub struct RecallArgs {
24 #[arg(allow_hyphen_values = true)]
25 pub context: String,
26 #[arg(long, short)]
27 pub namespace: Option<String>,
28 #[arg(long, default_value_t = 10)]
29 pub limit: usize,
30 #[arg(long)]
31 pub tags: Option<String>,
32 #[arg(long)]
33 pub since: Option<String>,
34 #[arg(long)]
35 pub until: Option<String>,
36 #[arg(long, short = 'T')]
38 pub tier: Option<String>,
39 #[arg(long)]
42 pub as_agent: Option<String>,
43 #[arg(long)]
47 pub budget_tokens: Option<usize>,
48 #[arg(long, value_delimiter = ',')]
54 pub context_tokens: Option<Vec<String>>,
55}
56
57#[allow(clippy::too_many_lines)]
63pub fn run(
64 db_path: &Path,
65 args: &RecallArgs,
66 json_out: bool,
67 app_config: &AppConfig,
68 out: &mut CliOutput<'_>,
69) -> Result<()> {
70 if let Some(ref a) = args.as_agent {
72 validate::validate_namespace(a)?;
73 }
74 let conn = db::open(db_path)?;
75 let _ = db::gc_if_needed(&conn, app_config.effective_archive_on_gc());
76
77 let feature_tier = app_config.effective_tier(args.tier.as_deref());
79 let tier_config = feature_tier.config();
80
81 let embedder = {
88 if let Ok(handle) = tokio::runtime::Handle::try_current() {
95 tokio::task::block_in_place(|| {
96 handle.block_on(daemon_runtime::build_embedder(feature_tier, app_config))
97 })
98 } else {
99 tokio::runtime::Builder::new_current_thread()
100 .enable_all()
101 .build()?
102 .block_on(daemon_runtime::build_embedder(feature_tier, app_config))
103 }
104 };
105 if let Some(ref emb) = embedder {
106 writeln!(
107 out.stderr,
108 "ai-memory: embedder loaded ({})",
109 emb.model_description()
110 )?;
111 } else if tier_config.embedding_model.is_some() {
112 writeln!(
113 out.stderr,
114 "ai-memory: embedder failed to load, falling back to keyword"
115 )?;
116 }
117
118 if let Some(ref emb) = embedder
120 && let Ok(unembedded) = db::get_unembedded_ids(&conn)
121 && !unembedded.is_empty()
122 {
123 writeln!(
124 out.stderr,
125 "ai-memory: backfilling {} memories...",
126 unembedded.len()
127 )?;
128 let mut ok = 0usize;
129 for (id, title, content) in &unembedded {
130 let text = format!("{title} {content}");
131 if let Ok(embedding) = emb.embed(&text)
132 && db::set_embedding(&conn, id, &embedding).is_ok()
133 {
134 ok += 1;
135 }
136 }
137 writeln!(
138 out.stderr,
139 "ai-memory: backfilled {}/{}",
140 ok,
141 unembedded.len()
142 )?;
143 }
144
145 let vector_index = if embedder.is_some() {
147 match db::get_all_embeddings(&conn) {
148 Ok(entries) if !entries.is_empty() => Some(hnsw::VectorIndex::build(entries)),
149 _ => Some(hnsw::VectorIndex::empty()),
150 }
151 } else {
152 None
153 };
154
155 let reranker = if tier_config.cross_encoder {
157 Some(reranker::CrossEncoder::new_neural())
158 } else {
159 None
160 };
161
162 let resolved_ttl = app_config.effective_ttl();
163 let resolved_scoring = app_config.effective_scoring();
164
165 let (results, outcome, mode) = if let Some(ref emb) = embedder {
167 match emb.embed(&args.context) {
168 Ok(primary_emb) => {
169 let query_emb = match args.context_tokens.as_deref() {
175 Some(tokens) if !tokens.is_empty() => {
176 let joined = tokens.join(" ");
177 match emb.embed(&joined) {
178 Ok(ctx_emb) => embeddings::Embedder::fuse(&primary_emb, &ctx_emb, 0.7),
179 Err(e) => {
180 writeln!(
181 out.stderr,
182 "ai-memory: context_tokens embed failed: {e}, using primary only"
183 )?;
184 primary_emb
185 }
186 }
187 }
188 _ => primary_emb,
189 };
190 let (results, outcome) = db::recall_hybrid(
191 &conn,
192 &args.context,
193 &query_emb,
194 args.namespace.as_deref(),
195 args.limit.min(50),
196 args.tags.as_deref(),
197 args.since.as_deref(),
198 args.until.as_deref(),
199 vector_index.as_ref(),
200 resolved_ttl.short_extend_secs,
201 resolved_ttl.mid_extend_secs,
202 args.as_agent.as_deref(),
203 args.budget_tokens,
204 &resolved_scoring,
205 )?;
206 if let Some(ref ce) = reranker {
207 (ce.rerank(&args.context, results), outcome, "hybrid+rerank")
208 } else {
209 (results, outcome, "hybrid")
210 }
211 }
212 Err(e) => {
213 writeln!(
214 out.stderr,
215 "ai-memory: embedding query failed: {e}, falling back to keyword"
216 )?;
217 let (results, outcome) = db::recall(
218 &conn,
219 &args.context,
220 args.namespace.as_deref(),
221 args.limit,
222 args.tags.as_deref(),
223 args.since.as_deref(),
224 args.until.as_deref(),
225 resolved_ttl.short_extend_secs,
226 resolved_ttl.mid_extend_secs,
227 args.as_agent.as_deref(),
228 args.budget_tokens,
229 )?;
230 (results, outcome, "keyword")
231 }
232 }
233 } else {
234 let (results, outcome) = db::recall(
235 &conn,
236 &args.context,
237 args.namespace.as_deref(),
238 args.limit,
239 args.tags.as_deref(),
240 args.since.as_deref(),
241 args.until.as_deref(),
242 resolved_ttl.short_extend_secs,
243 resolved_ttl.mid_extend_secs,
244 args.as_agent.as_deref(),
245 args.budget_tokens,
246 )?;
247 (results, outcome, "keyword")
248 };
249
250 if json_out {
251 let scored: Vec<serde_json::Value> = results
252 .iter()
253 .map(|(m, s)| {
254 let mut v = serde_json::to_value(m).unwrap_or_default();
255 if let Some(obj) = v.as_object_mut() {
256 obj.insert(
257 "score".to_string(),
258 serde_json::json!((s * 1000.0).round() / 1000.0),
259 );
260 }
261 v
262 })
263 .collect();
264 let mut body = serde_json::json!({
265 "memories": scored,
266 "count": results.len(),
267 "mode": mode,
268 "tokens_used": outcome.tokens_used,
269 });
270 if let Some(b) = args.budget_tokens {
271 body["budget_tokens"] = serde_json::json!(b);
272 body["meta"] = serde_json::json!({
274 "budget_tokens_used": outcome.tokens_used,
275 "budget_tokens_remaining": outcome.tokens_remaining.unwrap_or(0),
276 "memories_dropped": outcome.memories_dropped,
277 "budget_overflow": outcome.budget_overflow,
278 });
279 }
280 writeln!(out.stdout, "{}", serde_json::to_string(&body)?)?;
281 return Ok(());
282 }
283 if results.is_empty() {
284 writeln!(out.stderr, "no memories found for: {}", args.context)?;
285 return Ok(());
286 }
287 for (mem, score) in &results {
288 let age = human_age(&mem.updated_at);
289 let config = if mem.confidence < 1.0 {
290 format!(" conf={:.0}%", mem.confidence * 100.0)
291 } else {
292 String::new()
293 };
294 writeln!(
295 out.stdout,
296 "[{}] {} {} score={:.2} (ns={}, {}x, {}{})",
297 color::tier_color(
298 mem.tier.as_str(),
299 &format!("{}/{}", mem.tier, id_short(&mem.id))
300 ),
301 color::bold(&mem.title),
302 color::priority_bar(mem.priority),
303 score,
304 color::cyan(&mem.namespace),
305 mem.access_count,
306 color::dim(&age),
307 config
308 )?;
309 let preview: String = mem.content.chars().take(200).collect();
310 writeln!(out.stdout, " {}\n", color::dim(&preview))?;
311 }
312 writeln!(
313 out.stdout,
314 "{} memory(ies) recalled [{}]",
315 results.len(),
316 mode
317 )?;
318 Ok(())
319}
320
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cli::test_utils::{TestEnv, seed_memory};
    use crate::config::FeatureTier;

    /// Baseline arguments: query "needle", limit 10, keyword tier, no filters.
    fn default_args() -> RecallArgs {
        RecallArgs {
            context: "needle".to_string(),
            namespace: None,
            limit: 10,
            tags: None,
            since: None,
            until: None,
            tier: Some("keyword".to_string()),
            as_agent: None,
            budget_tokens: None,
            context_tokens: None,
        }
    }

    /// Runs `run` against `env`'s database with a default config and expects success.
    /// The output handle is dropped before returning so the captured streams can be read.
    fn exec(env: &mut TestEnv, args: &RecallArgs, json_out: bool) {
        let db = env.db_path.clone();
        let cfg = AppConfig::default();
        let mut out = env.output();
        run(&db, args, json_out, &cfg, &mut out).unwrap();
    }

    /// Runs in JSON mode and parses the captured stdout as a single JSON value.
    fn exec_json(env: &mut TestEnv, args: &RecallArgs) -> serde_json::Value {
        exec(env, args, true);
        serde_json::from_str(env.stdout_str().trim()).unwrap()
    }

    #[test]
    fn test_recall_keyword_tier_no_embedder() {
        let mut env = TestEnv::fresh();
        seed_memory(&env.db_path, "test", "needle title", "haystack content");
        exec(&mut env, &default_args(), false);
        let stdout = env.stdout_str();
        assert!(stdout.contains("needle title"), "got: {stdout}");
        assert!(stdout.contains("[keyword]"), "got: {stdout}");
    }

    #[test]
    fn test_recall_keyword_empty_results() {
        // Nothing is seeded: stdout stays empty and the notice goes to stderr.
        let mut env = TestEnv::fresh();
        exec(&mut env, &default_args(), false);
        assert_eq!(env.stdout_str(), "");
        assert!(
            env.stderr_str().contains("no memories found for: needle"),
            "got: {}",
            env.stderr_str()
        );
    }

    #[test]
    fn test_recall_keyword_with_namespace_filter() {
        let mut env = TestEnv::fresh();
        seed_memory(&env.db_path, "ns-a", "needle in a", "content a");
        seed_memory(&env.db_path, "ns-b", "needle in b", "content b");
        let args = RecallArgs {
            namespace: Some("ns-a".to_string()),
            ..default_args()
        };
        let v = exec_json(&mut env, &args);
        // Every memory that comes back must belong to the requested namespace.
        for m in v["memories"].as_array().unwrap() {
            assert_eq!(m["namespace"].as_str().unwrap(), "ns-a");
        }
    }

    #[test]
    fn test_recall_keyword_with_tags_filter() {
        let mut env = TestEnv::fresh();
        seed_memory(&env.db_path, "test", "needle title", "content");
        let args = RecallArgs {
            tags: Some("nonexistent".to_string()),
            ..default_args()
        };
        let v = exec_json(&mut env, &args);
        assert_eq!(v["count"].as_u64().unwrap(), 0);
    }

    #[test]
    fn test_recall_keyword_with_since_until_window() {
        let mut env = TestEnv::fresh();
        seed_memory(&env.db_path, "test", "needle title", "content");
        // A one-day window in 1970; the test expects the seeded memory to fall outside it.
        let args = RecallArgs {
            since: Some("1970-01-01T00:00:00Z".to_string()),
            until: Some("1970-01-02T00:00:00Z".to_string()),
            ..default_args()
        };
        let v = exec_json(&mut env, &args);
        assert_eq!(v["count"].as_u64().unwrap(), 0);
    }

    #[test]
    fn test_recall_with_as_agent_scope_filter() {
        let mut env = TestEnv::fresh();
        seed_memory(&env.db_path, "test", "needle title", "content");
        let args = RecallArgs {
            as_agent: Some("test".to_string()),
            ..default_args()
        };
        let v = exec_json(&mut env, &args);
        assert!(v["memories"].is_array());
    }

    #[test]
    fn test_recall_with_budget_tokens_caps_results() {
        let mut env = TestEnv::fresh();
        seed_memory(&env.db_path, "test", "needle one", "content one");
        seed_memory(&env.db_path, "test", "needle two", "content two");
        let args = RecallArgs {
            budget_tokens: Some(64),
            ..default_args()
        };
        let v = exec_json(&mut env, &args);
        // The requested budget is echoed back in the JSON body.
        assert_eq!(v["budget_tokens"].as_u64().unwrap(), 64);
    }

    #[test]
    fn test_recall_json_output_includes_score_mode_tokens() {
        let mut env = TestEnv::fresh();
        seed_memory(&env.db_path, "test", "needle title", "haystack content");
        let v = exec_json(&mut env, &default_args());
        assert_eq!(v["mode"].as_str().unwrap(), "keyword");
        assert!(v["tokens_used"].is_number());
        let mems = v["memories"].as_array().unwrap();
        assert!(!mems.is_empty(), "expected at least one match");
        for m in mems {
            assert!(m["score"].is_number());
        }
    }

    #[test]
    fn test_recall_text_output_formats_correctly() {
        let mut env = TestEnv::fresh();
        seed_memory(&env.db_path, "test-ns", "needle title", "haystack content");
        exec(&mut env, &default_args(), false);
        let stdout = env.stdout_str();
        assert!(stdout.contains("needle title"));
        assert!(stdout.contains("ns="));
        assert!(stdout.contains("score="));
        assert!(stdout.contains("memory(ies) recalled"));
    }

    #[test]
    fn test_recall_invalid_as_agent_namespace_validation_error() {
        let mut env = TestEnv::fresh();
        let db = env.db_path.clone();
        // An empty agent id is expected to fail namespace validation.
        let args = RecallArgs {
            as_agent: Some(String::new()),
            ..default_args()
        };
        let cfg = AppConfig::default();
        let mut out = env.output();
        let res = run(&db, &args, false, &cfg, &mut out);
        assert!(res.is_err(), "expected validate_namespace to reject");
    }

    #[test]
    fn test_recall_with_context_tokens_fusion() {
        let mut env = TestEnv::fresh();
        seed_memory(&env.db_path, "test", "needle title", "content");
        // The default args select the keyword tier, so this checks that supplying
        // context tokens still yields a keyword-mode result.
        let args = RecallArgs {
            context_tokens: Some(vec!["recent".to_string(), "talk".to_string()]),
            ..default_args()
        };
        let v = exec_json(&mut env, &args);
        assert_eq!(v["mode"].as_str().unwrap(), "keyword");
    }

    #[test]
    fn test_recall_embedder_failure_falls_back_to_keyword() {
        let mut env = TestEnv::fresh();
        seed_memory(&env.db_path, "test", "needle title", "content");
        let v = exec_json(&mut env, &default_args());
        assert_eq!(v["mode"].as_str().unwrap(), "keyword");
        let stderr = env.stderr_str();
        assert!(
            !stderr.contains("embedder loaded"),
            "no embedder should be loaded on keyword tier"
        );
    }

    #[tokio::test]
    async fn test_shared_build_embedder_keyword_returns_none() {
        // Calls the shared builder directly instead of going through `run`.
        let cfg = AppConfig::default();
        let built = daemon_runtime::build_embedder(FeatureTier::Keyword, &cfg).await;
        assert!(built.is_none(), "keyword tier must not build an embedder");
    }
}