use crate::cli::CliOutput;
use crate::cli::helpers::{human_age, id_short};
use crate::config::AppConfig;
use crate::{color, daemon_runtime, db, embeddings, hnsw, reranker, validate};
use anyhow::Result;
use clap::Args;
use std::path::Path;
// Command-line arguments for the `recall` subcommand, parsed by clap and
// consumed by `run` below.
#[derive(Args)]
pub struct RecallArgs {
    /// Query text to recall memories for (may begin with a hyphen).
    #[arg(allow_hyphen_values = true)]
    pub context: String,
    /// Only recall memories from this namespace.
    #[arg(long, short)]
    pub namespace: Option<String>,
    /// Maximum number of memories to return (hybrid recall caps this at 50).
    #[arg(long, default_value_t = 10)]
    pub limit: usize,
    /// Only recall memories matching this tag filter.
    #[arg(long)]
    pub tags: Option<String>,
    /// Start of the time window to recall from, e.g. `1970-01-01T00:00:00Z`.
    #[arg(long)]
    pub since: Option<String>,
    /// End of the time window to recall from, e.g. `1970-01-02T00:00:00Z`.
    #[arg(long)]
    pub until: Option<String>,
    /// Feature tier requested for this invocation, e.g. `keyword`.
    #[arg(long, short = 'T')]
    pub tier: Option<String>,
    /// Agent to recall as; must pass namespace validation.
    #[arg(long)]
    pub as_agent: Option<String>,
    /// Token budget applied to the recalled memories.
    #[arg(long)]
    pub budget_tokens: Option<usize>,
    /// Comma-separated extra context terms fused into the query embedding.
    #[arg(long, value_delimiter = ',')]
    pub context_tokens: Option<Vec<String>>,
}
#[allow(clippy::too_many_lines)]
pub fn run(
db_path: &Path,
args: &RecallArgs,
json_out: bool,
app_config: &AppConfig,
out: &mut CliOutput<'_>,
) -> Result<()> {
if let Some(ref a) = args.as_agent {
validate::validate_namespace(a)?;
}
let conn = db::open(db_path)?;
let _ = db::gc_if_needed(&conn, app_config.effective_archive_on_gc());
let feature_tier = app_config.effective_tier(args.tier.as_deref());
let tier_config = feature_tier.config();
let embedder = {
if let Ok(handle) = tokio::runtime::Handle::try_current() {
tokio::task::block_in_place(|| {
handle.block_on(daemon_runtime::build_embedder(feature_tier, app_config))
})
} else {
tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()?
.block_on(daemon_runtime::build_embedder(feature_tier, app_config))
}
};
if let Some(ref emb) = embedder {
writeln!(
out.stderr,
"ai-memory: embedder loaded ({})",
emb.model_description()
)?;
} else if tier_config.embedding_model.is_some() {
writeln!(
out.stderr,
"ai-memory: embedder failed to load, falling back to keyword"
)?;
}
if let Some(ref emb) = embedder
&& let Ok(unembedded) = db::get_unembedded_ids(&conn)
&& !unembedded.is_empty()
{
writeln!(
out.stderr,
"ai-memory: backfilling {} memories...",
unembedded.len()
)?;
let mut ok = 0usize;
for (id, title, content) in &unembedded {
let text = format!("{title} {content}");
if let Ok(embedding) = emb.embed(&text)
&& db::set_embedding(&conn, id, &embedding).is_ok()
{
ok += 1;
}
}
writeln!(
out.stderr,
"ai-memory: backfilled {}/{}",
ok,
unembedded.len()
)?;
}
let vector_index = if embedder.is_some() {
match db::get_all_embeddings(&conn) {
Ok(entries) if !entries.is_empty() => Some(hnsw::VectorIndex::build(entries)),
_ => Some(hnsw::VectorIndex::empty()),
}
} else {
None
};
let reranker = if tier_config.cross_encoder {
Some(reranker::CrossEncoder::new_neural())
} else {
None
};
let resolved_ttl = app_config.effective_ttl();
let resolved_scoring = app_config.effective_scoring();
let (results, outcome, mode) = if let Some(ref emb) = embedder {
match emb.embed(&args.context) {
Ok(primary_emb) => {
let query_emb = match args.context_tokens.as_deref() {
Some(tokens) if !tokens.is_empty() => {
let joined = tokens.join(" ");
match emb.embed(&joined) {
Ok(ctx_emb) => embeddings::Embedder::fuse(&primary_emb, &ctx_emb, 0.7),
Err(e) => {
writeln!(
out.stderr,
"ai-memory: context_tokens embed failed: {e}, using primary only"
)?;
primary_emb
}
}
}
_ => primary_emb,
};
let (results, outcome) = db::recall_hybrid(
&conn,
&args.context,
&query_emb,
args.namespace.as_deref(),
args.limit.min(50),
args.tags.as_deref(),
args.since.as_deref(),
args.until.as_deref(),
vector_index.as_ref(),
resolved_ttl.short_extend_secs,
resolved_ttl.mid_extend_secs,
args.as_agent.as_deref(),
args.budget_tokens,
&resolved_scoring,
)?;
if let Some(ref ce) = reranker {
(ce.rerank(&args.context, results), outcome, "hybrid+rerank")
} else {
(results, outcome, "hybrid")
}
}
Err(e) => {
writeln!(
out.stderr,
"ai-memory: embedding query failed: {e}, falling back to keyword"
)?;
let (results, outcome) = db::recall(
&conn,
&args.context,
args.namespace.as_deref(),
args.limit,
args.tags.as_deref(),
args.since.as_deref(),
args.until.as_deref(),
resolved_ttl.short_extend_secs,
resolved_ttl.mid_extend_secs,
args.as_agent.as_deref(),
args.budget_tokens,
)?;
(results, outcome, "keyword")
}
}
} else {
let (results, outcome) = db::recall(
&conn,
&args.context,
args.namespace.as_deref(),
args.limit,
args.tags.as_deref(),
args.since.as_deref(),
args.until.as_deref(),
resolved_ttl.short_extend_secs,
resolved_ttl.mid_extend_secs,
args.as_agent.as_deref(),
args.budget_tokens,
)?;
(results, outcome, "keyword")
};
if json_out {
let scored: Vec<serde_json::Value> = results
.iter()
.map(|(m, s)| {
let mut v = serde_json::to_value(m).unwrap_or_default();
if let Some(obj) = v.as_object_mut() {
obj.insert(
"score".to_string(),
serde_json::json!((s * 1000.0).round() / 1000.0),
);
}
v
})
.collect();
let mut body = serde_json::json!({
"memories": scored,
"count": results.len(),
"mode": mode,
"tokens_used": outcome.tokens_used,
});
if let Some(b) = args.budget_tokens {
body["budget_tokens"] = serde_json::json!(b);
body["meta"] = serde_json::json!({
"budget_tokens_used": outcome.tokens_used,
"budget_tokens_remaining": outcome.tokens_remaining.unwrap_or(0),
"memories_dropped": outcome.memories_dropped,
"budget_overflow": outcome.budget_overflow,
});
}
writeln!(out.stdout, "{}", serde_json::to_string(&body)?)?;
return Ok(());
}
if results.is_empty() {
writeln!(out.stderr, "no memories found for: {}", args.context)?;
return Ok(());
}
for (mem, score) in &results {
let age = human_age(&mem.updated_at);
let config = if mem.confidence < 1.0 {
format!(" conf={:.0}%", mem.confidence * 100.0)
} else {
String::new()
};
writeln!(
out.stdout,
"[{}] {} {} score={:.2} (ns={}, {}x, {}{})",
color::tier_color(
mem.tier.as_str(),
&format!("{}/{}", mem.tier, id_short(&mem.id))
),
color::bold(&mem.title),
color::priority_bar(mem.priority),
score,
color::cyan(&mem.namespace),
mem.access_count,
color::dim(&age),
config
)?;
let preview: String = mem.content.chars().take(200).collect();
writeln!(out.stdout, " {}\n", color::dim(&preview))?;
}
writeln!(
out.stdout,
"{} memory(ies) recalled [{}]",
results.len(),
mode
)?;
Ok(())
}
/// Unit tests for `run`. Every test uses the `keyword` tier (see
/// `default_args`), so the embedder, backfill and hybrid code paths are not
/// exercised here.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cli::test_utils::{TestEnv, seed_memory};
    use crate::config::FeatureTier;

    /// Baseline arguments: query `needle`, limit 10, `keyword` tier, no filters.
    fn default_args() -> RecallArgs {
        RecallArgs {
            context: "needle".to_string(),
            namespace: None,
            limit: 10,
            tags: None,
            since: None,
            until: None,
            tier: Some("keyword".to_string()),
            as_agent: None,
            budget_tokens: None,
            context_tokens: None,
        }
    }

    /// Text output on the keyword tier lists the seeded memory and reports `[keyword]`.
    #[test]
    fn test_recall_keyword_tier_no_embedder() {
        let mut env = TestEnv::fresh();
        let db = env.db_path.clone();
        seed_memory(&db, "test", "needle title", "haystack content");
        let args = default_args();
        let cfg = AppConfig::default();
        {
            let mut out = env.output();
            run(&db, &args, false, &cfg, &mut out).unwrap();
        }
        let stdout = env.stdout_str();
        assert!(stdout.contains("needle title"), "got: {stdout}");
        assert!(stdout.contains("[keyword]"), "got: {stdout}");
    }

    /// With nothing seeded, stdout stays empty and the notice goes to stderr.
    #[test]
    fn test_recall_keyword_empty_results() {
        let mut env = TestEnv::fresh();
        let db = env.db_path.clone();
        let args = default_args();
        let cfg = AppConfig::default();
        {
            let mut out = env.output();
            run(&db, &args, false, &cfg, &mut out).unwrap();
        }
        assert_eq!(env.stdout_str(), "");
        assert!(
            env.stderr_str().contains("no memories found for: needle"),
            "got: {}",
            env.stderr_str()
        );
    }

    /// A namespace filter returns matches from that namespace only.
    #[test]
    fn test_recall_keyword_with_namespace_filter() {
        let mut env = TestEnv::fresh();
        let db = env.db_path.clone();
        seed_memory(&db, "ns-a", "needle in a", "content a");
        seed_memory(&db, "ns-b", "needle in b", "content b");
        let mut args = default_args();
        args.namespace = Some("ns-a".to_string());
        let cfg = AppConfig::default();
        {
            let mut out = env.output();
            run(&db, &args, true, &cfg, &mut out).unwrap();
        }
        let v: serde_json::Value = serde_json::from_str(env.stdout_str().trim()).unwrap();
        let mems = v["memories"].as_array().unwrap();
        // Without this the loop below would pass vacuously on an empty result.
        assert!(!mems.is_empty(), "expected at least one match in ns-a");
        for m in mems {
            assert_eq!(m["namespace"].as_str().unwrap(), "ns-a");
        }
    }

    /// A tag filter that matches nothing yields a zero count.
    #[test]
    fn test_recall_keyword_with_tags_filter() {
        let mut env = TestEnv::fresh();
        let db = env.db_path.clone();
        seed_memory(&db, "test", "needle title", "content");
        let mut args = default_args();
        args.tags = Some("nonexistent".to_string());
        let cfg = AppConfig::default();
        {
            let mut out = env.output();
            run(&db, &args, true, &cfg, &mut out).unwrap();
        }
        let v: serde_json::Value = serde_json::from_str(env.stdout_str().trim()).unwrap();
        assert_eq!(v["count"].as_u64().unwrap(), 0);
    }

    /// A since/until window in 1970 excludes the freshly seeded memory.
    #[test]
    fn test_recall_keyword_with_since_until_window() {
        let mut env = TestEnv::fresh();
        let db = env.db_path.clone();
        seed_memory(&db, "test", "needle title", "content");
        let mut args = default_args();
        args.since = Some("1970-01-01T00:00:00Z".to_string());
        args.until = Some("1970-01-02T00:00:00Z".to_string());
        let cfg = AppConfig::default();
        {
            let mut out = env.output();
            run(&db, &args, true, &cfg, &mut out).unwrap();
        }
        let v: serde_json::Value = serde_json::from_str(env.stdout_str().trim()).unwrap();
        assert_eq!(v["count"].as_u64().unwrap(), 0);
    }

    /// A valid `as_agent` is accepted and the JSON body still has a `memories` array.
    #[test]
    fn test_recall_with_as_agent_scope_filter() {
        let mut env = TestEnv::fresh();
        let db = env.db_path.clone();
        seed_memory(&db, "test", "needle title", "content");
        let mut args = default_args();
        args.as_agent = Some("test".to_string());
        let cfg = AppConfig::default();
        {
            let mut out = env.output();
            run(&db, &args, true, &cfg, &mut out).unwrap();
        }
        let v: serde_json::Value = serde_json::from_str(env.stdout_str().trim()).unwrap();
        assert!(v["memories"].is_array());
    }

    /// The requested budget is echoed back in the JSON body. This only checks
    /// the echo, not that any memory was actually dropped.
    #[test]
    fn test_recall_with_budget_tokens_caps_results() {
        let mut env = TestEnv::fresh();
        let db = env.db_path.clone();
        seed_memory(&db, "test", "needle one", "content one");
        seed_memory(&db, "test", "needle two", "content two");
        let mut args = default_args();
        args.budget_tokens = Some(64);
        let cfg = AppConfig::default();
        {
            let mut out = env.output();
            run(&db, &args, true, &cfg, &mut out).unwrap();
        }
        let v: serde_json::Value = serde_json::from_str(env.stdout_str().trim()).unwrap();
        assert_eq!(v["budget_tokens"].as_u64().unwrap(), 64);
    }

    /// JSON output carries `mode`, `tokens_used` and a numeric `score` per memory.
    #[test]
    fn test_recall_json_output_includes_score_mode_tokens() {
        let mut env = TestEnv::fresh();
        let db = env.db_path.clone();
        seed_memory(&db, "test", "needle title", "haystack content");
        let args = default_args();
        let cfg = AppConfig::default();
        {
            let mut out = env.output();
            run(&db, &args, true, &cfg, &mut out).unwrap();
        }
        let v: serde_json::Value = serde_json::from_str(env.stdout_str().trim()).unwrap();
        assert_eq!(v["mode"].as_str().unwrap(), "keyword");
        assert!(v["tokens_used"].is_number());
        let mems = v["memories"].as_array().unwrap();
        assert!(!mems.is_empty(), "expected at least one match");
        for m in mems {
            assert!(m["score"].is_number());
        }
    }

    /// Text output contains the title, the `ns=`/`score=` fields and the summary line.
    #[test]
    fn test_recall_text_output_formats_correctly() {
        let mut env = TestEnv::fresh();
        let db = env.db_path.clone();
        seed_memory(&db, "test-ns", "needle title", "haystack content");
        let args = default_args();
        let cfg = AppConfig::default();
        {
            let mut out = env.output();
            run(&db, &args, false, &cfg, &mut out).unwrap();
        }
        let stdout = env.stdout_str();
        assert!(stdout.contains("needle title"));
        assert!(stdout.contains("ns="));
        assert!(stdout.contains("score="));
        assert!(stdout.contains("memory(ies) recalled"));
    }

    /// An empty `as_agent` is rejected before any recall work happens.
    #[test]
    fn test_recall_invalid_as_agent_namespace_validation_error() {
        let mut env = TestEnv::fresh();
        let db = env.db_path.clone();
        let mut args = default_args();
        args.as_agent = Some(String::new());
        let cfg = AppConfig::default();
        let mut out = env.output();
        let res = run(&db, &args, false, &cfg, &mut out);
        assert!(res.is_err(), "expected validate_namespace to reject");
    }

    /// Supplying context tokens on the keyword tier is tolerated and recall
    /// stays in keyword mode. The fusion code itself is not reached here
    /// because this tier expects no embedder.
    #[test]
    fn test_recall_with_context_tokens_fusion() {
        let mut env = TestEnv::fresh();
        let db = env.db_path.clone();
        seed_memory(&db, "test", "needle title", "content");
        let mut args = default_args();
        args.context_tokens = Some(vec!["recent".to_string(), "talk".to_string()]);
        let cfg = AppConfig::default();
        {
            let mut out = env.output();
            run(&db, &args, true, &cfg, &mut out).unwrap();
        }
        let v: serde_json::Value = serde_json::from_str(env.stdout_str().trim()).unwrap();
        assert_eq!(v["mode"].as_str().unwrap(), "keyword");
    }

    /// On the keyword tier the mode is `keyword` and no "embedder loaded"
    /// notice is printed. This checks the no-embedder path rather than a
    /// failed embedder load.
    #[test]
    fn test_recall_embedder_failure_falls_back_to_keyword() {
        let mut env = TestEnv::fresh();
        let db = env.db_path.clone();
        seed_memory(&db, "test", "needle title", "content");
        let args = default_args();
        let cfg = AppConfig::default();
        {
            let mut out = env.output();
            run(&db, &args, true, &cfg, &mut out).unwrap();
        }
        let v: serde_json::Value = serde_json::from_str(env.stdout_str().trim()).unwrap();
        assert_eq!(v["mode"].as_str().unwrap(), "keyword");
        let stderr = env.stderr_str();
        assert!(
            !stderr.contains("embedder loaded"),
            "no embedder should be loaded on keyword tier"
        );
    }

    /// The shared embedder builder returns `None` for the keyword tier.
    #[tokio::test]
    async fn test_shared_build_embedder_keyword_returns_none() {
        let cfg = AppConfig::default();
        let res = daemon_runtime::build_embedder(FeatureTier::Keyword, &cfg).await;
        assert!(res.is_none(), "keyword tier must not build an embedder");
    }
}