use crate::embeddings::Embed;
use crate::hnsw::VectorIndex;
use crate::mcp::param_names;
use crate::mcp::registry::McpTool;
use crate::models::{
AttestLevel, CandidateCounts, ConfidenceTier, Memory, MemoryKind, RecallMeta, RecallTelemetry,
};
use crate::observations;
use crate::reranker::BatchedReranker;
use crate::{db, validate};
use serde_json::{Value, json};
#[allow(unused_imports)]
pub use crate::models::recall_request::{KindsFilter, RecallRequest};
/// Zero-sized handle for the recall MCP tool.
///
/// It carries no state; the tool's name, description, docs, input schema and
/// family are exposed through its `McpTool` implementation.
#[allow(dead_code)]
pub struct RecallTool;
// Static MCP metadata for the recall tool. Every method returns a constant or
// a value derived from `RecallRequest`; none of them touch runtime state.
impl McpTool for RecallTool {
// Wire name of the tool, taken from the shared registry constant.
fn name() -> &'static str {
crate::mcp::registry::tool_names::MEMORY_RECALL
}
// One-line summary of the tool.
fn description() -> &'static str {
"Recall memories relevant to a context (ranked)."
}
// Longer help text enumerating the optional request parameters.
fn docs() -> &'static str {
"Fuzzy OR recall ranked by relevance + priority + access + tier. Optional: budget_tokens (cl100k cap), context_tokens (query-embed bias), session_id (+0.05 recency boost per #518), session_default (splice [agents.defaults.recall_scope]), include_archived, kinds filter. Default format toon_compact (~79% smaller)."
}
// JSON input schema, derived from the `RecallRequest` type by the registry helper.
fn input_schema() -> Value {
crate::mcp::registry::input_schema_for::<RecallRequest>()
}
// Tool family this tool is grouped under (the `Core` family's name).
fn family() -> &'static str {
crate::profile::Family::Core.name()
}
}
// Parity tests: the schema derived from `RecallRequest` must expose the same
// property names and descriptions as the entry in the legacy tool catalog
// returned by `tool_definitions()`.
#[cfg(test)]
mod d1_3_984_tests {
use super::*;
// Returns a clone of `inputSchema.properties` for `tool_name` from the legacy
// catalog. Panics if the tool or its properties object is missing.
fn legacy_props(tool_name: &str) -> serde_json::Map<String, Value> {
let defs = crate::mcp::registry::tool_definitions();
let tools = defs
.get("tools")
.and_then(Value::as_array)
.expect("tool_definitions emits `tools` array");
let entry = tools
.iter()
.find(|t| t.get("name").and_then(Value::as_str) == Some(tool_name))
.unwrap_or_else(|| panic!("{tool_name} must be in legacy catalog"));
entry
.pointer("/inputSchema/properties")
.and_then(Value::as_object)
.unwrap_or_else(|| panic!("{tool_name}.inputSchema.properties must be object"))
.clone()
}
// Returns the `properties` object of the schemars-generated schema for `T`.
// Looks at the top-level `properties` first, then falls back to
// `/definitions/<short type name>/properties`; panics if neither exists.
fn derived_props_for<T: schemars::JsonSchema>() -> serde_json::Map<String, Value> {
let schema = schemars::schema_for!(T);
let v = serde_json::to_value(schema).expect("schema → value");
v.get("properties")
.and_then(Value::as_object)
.or_else(|| {
v.pointer(&format!(
"/definitions/{}/properties",
// Last path segment of the fully-qualified type name.
std::any::type_name::<T>().rsplit("::").next().unwrap_or("")
))
.and_then(Value::as_object)
})
.cloned()
.expect("schemars schema must have properties at a known path")
}
// Asserts that legacy and derived schemas have exactly the same property names;
// the failure message lists the symmetric difference.
fn assert_property_set_parity(tool_name: &str, derived: &serde_json::Map<String, Value>) {
let legacy = legacy_props(tool_name);
let legacy_keys: std::collections::BTreeSet<&str> =
legacy.keys().map(String::as_str).collect();
let derived_keys: std::collections::BTreeSet<&str> =
derived.keys().map(String::as_str).collect();
assert_eq!(
legacy_keys,
derived_keys,
"{tool_name}: property set drift; diff = {:?}",
legacy_keys
.symmetric_difference(&derived_keys)
.collect::<Vec<_>>()
);
}
// For every legacy property that has a string `description`, asserts that the
// derived property carries the identical description. Legacy properties
// without a description are not checked.
fn assert_descriptions_match(tool_name: &str, derived: &serde_json::Map<String, Value>) {
let legacy = legacy_props(tool_name);
for (name, legacy_prop) in &legacy {
if let Some(want) = legacy_prop.get("description").and_then(Value::as_str) {
let got = derived
.get(name)
.and_then(|p| p.get("description"))
.and_then(Value::as_str);
assert_eq!(
got,
Some(want),
"{tool_name}.{name}: description must match legacy byte-for-byte"
);
}
}
}
// `RecallRequest`'s derived schema matches the legacy `memory_recall` entry.
#[test]
fn recall_parity_984() {
let derived = derived_props_for::<RecallRequest>();
assert_property_set_parity("memory_recall", &derived);
assert_descriptions_match("memory_recall", &derived);
}
// The tool reports the expected wire name and family.
#[test]
fn recall_tool_metadata_984() {
assert_eq!(RecallTool::name(), "memory_recall");
assert_eq!(RecallTool::family(), "core");
}
}
/// Keeps only the scored memories whose `memory_kind` is in `kinds`.
///
/// `None` means "no filter": the input is returned untouched. The relative
/// order of the surviving rows is preserved.
fn apply_kinds_filter(
    results: Vec<(Memory, f64)>,
    kinds: Option<&[MemoryKind]>,
) -> Vec<(Memory, f64)> {
    let Some(allowed) = kinds else {
        return results;
    };
    let mut kept = results;
    kept.retain(|(memory, _)| allowed.contains(&memory.memory_kind));
    kept
}
/// Runs the pre-recall hook chain and then performs the recall.
///
/// The hook chain is given the raw `context`, `namespace` (empty string when
/// absent) and `limit` (10 when absent, saturated to `u32::MAX`). Depending on
/// its outcome:
///
/// * **Denied** – no recall is performed; an empty result with
///   `mode = "denied_by_hook"` and the hook's `reason`/`code` under
///   `meta.diagnostic.pre_recall_denied` is returned as `Ok`.
/// * **Modified** – `context` and `limit` are overwritten with the hook's
///   values; `namespace` is overwritten only when the hook returned a
///   non-empty one. The patched parameters are forwarded to
///   [`handle_recall_caller`].
/// * anything else – the original parameters are forwarded unchanged.
///
/// # Errors
///
/// Returns the `CONTEXT_REQUIRED` message when `params["context"]` is not a
/// string, and propagates any error from [`handle_recall_caller`].
#[allow(clippy::too_many_arguments)]
pub async fn handle_recall_with_pre_recall_hook(
    conn: &rusqlite::Connection,
    params: &Value,
    embedder: Option<&dyn Embed>,
    vector_index: Option<&VectorIndex>,
    reranker: Option<&BatchedReranker>,
    archive_on_gc: bool,
    resolved_ttl: &crate::config::ResolvedTtl,
    resolved_scoring: &crate::config::ResolvedScoring,
    chain: &crate::hooks::HookChain,
    registry: &mut crate::hooks::ExecutorRegistry,
    recall_scope: Option<&crate::config::RecallScope>,
    caller: Option<&str>,
) -> Result<Value, String> {
    let context = params["context"]
        .as_str()
        .ok_or(crate::errors::msg::CONTEXT_REQUIRED)?;
    let namespace = params["namespace"].as_str().unwrap_or("");
    let k = u32::try_from(params["limit"].as_u64().unwrap_or(10)).unwrap_or(u32::MAX);
    let outcome =
        crate::hooks::apply_pre_recall_expand(context, namespace, k, chain, registry).await;

    if let crate::hooks::PreRecallOutcome::Denied { reason, code } = &outcome {
        // Built as a single literal so there is no fallible "is it an object?"
        // step between creating the response and attaching the diagnostic.
        return Ok(json!({
            "memories": [],
            "count": 0,
            "mode": "denied_by_hook",
            "meta": {
                "diagnostic": {
                    "pre_recall_denied": {
                        "reason": reason,
                        "code": code,
                    }
                }
            },
        }));
    }

    // Only pay for a deep clone of the parameters when the hook rewrote them.
    let effective: std::borrow::Cow<'_, Value> = if let crate::hooks::PreRecallOutcome::Modified {
        query: q,
        namespace: ns,
        k: nk,
    } = outcome
    {
        let mut patched = params.clone();
        if let Some(obj) = patched.as_object_mut() {
            obj.insert("context".to_string(), json!(q));
            // An empty namespace from the hook means "leave the caller's value alone".
            if !ns.is_empty() {
                obj.insert("namespace".to_string(), json!(ns));
            }
            obj.insert("limit".to_string(), json!(u64::from(nk)));
        }
        std::borrow::Cow::Owned(patched)
    } else {
        std::borrow::Cow::Borrowed(params)
    };

    handle_recall_caller(
        conn,
        &effective,
        embedder,
        vector_index,
        reranker,
        archive_on_gc,
        resolved_ttl,
        resolved_scoring,
        recall_scope,
        caller,
    )
}
pub(crate) fn decorate_memory(
mem: &Memory,
score: f64,
verbose_provenance: bool,
conn: &rusqlite::Connection,
) -> Value {
let mut val = serde_json::to_value(mem).unwrap_or_default();
let Some(obj) = val.as_object_mut() else {
return val;
};
obj.insert(
"score".to_string(),
json!(
(score * crate::SCORE_DISPLAY_ROUND_FACTOR).round() / crate::SCORE_DISPLAY_ROUND_FACTOR
),
);
if !verbose_provenance {
return val;
}
obj.insert(
"confidence_tier".to_string(),
json!(mem.confidence_tier().as_str()),
);
obj.insert("freshness_state".to_string(), json!(freshness_state(mem)));
let latest_attest = latest_link_attest_level(conn, &mem.id);
if let Some(level) = latest_attest {
obj.insert("latest_link_attest_level".to_string(), json!(level));
}
val
}
/// Classifies a memory as `"expired"`, `"stale"`, `"fresh"` or `"warm"`.
///
/// * `"expired"` – `expires_at` parses as RFC 3339 and lies in the past.
/// * `"stale"` – more than 30 whole days since `last_accessed_at`
///   (or `created_at` when the memory was never accessed).
/// * `"fresh"` – less than one whole day old by that same timestamp and
///   `access_count` is zero.
/// * `"warm"` – everything else, including an unparseable reference timestamp.
///
/// An unparseable `expires_at` is ignored rather than treated as expired.
pub(crate) fn freshness_state(mem: &Memory) -> &'static str {
    let now = chrono::Utc::now();

    let already_expired = mem
        .expires_at
        .as_deref()
        .and_then(|raw| chrono::DateTime::parse_from_rfc3339(raw).ok())
        .is_some_and(|deadline| deadline < now);
    if already_expired {
        return "expired";
    }

    let reference = mem.last_accessed_at.as_deref().unwrap_or(&mem.created_at);
    match chrono::DateTime::parse_from_rfc3339(reference) {
        Err(_) => "warm",
        Ok(stamp) => {
            let age_days = (now - stamp.with_timezone(&chrono::Utc)).num_days();
            match age_days {
                days if days > 30 => "stale",
                days if days < 1 && mem.access_count == 0 => "fresh",
                _ => "warm",
            }
        }
    }
}
/// Returns the highest-ranked attestation level among a memory's links.
///
/// Links are loaded with `db::get_links`; a load failure yields `None`. Links
/// with no attestation level, or one that `AttestLevel::from_str` rejects, are
/// skipped. Levels are compared with [`attest_rank`]; on a tie the level seen
/// first is kept. Returns `None` when no link carries a usable level.
pub(crate) fn latest_link_attest_level(
    conn: &rusqlite::Connection,
    memory_id: &str,
) -> Option<String> {
    let links = db::get_links(conn, memory_id).ok()?;
    links
        .iter()
        .filter_map(|link| AttestLevel::from_str(link.attest_level.as_deref()?))
        .fold(None, |best: Option<AttestLevel>, level| match best {
            // Only a strictly higher rank replaces the current winner.
            Some(current) if attest_rank(level) <= attest_rank(current) => Some(current),
            _ => Some(level),
        })
        .map(|winner| winner.as_str().to_string())
}
/// Maps an attestation level to a numeric rank used to pick the strongest one.
///
/// Higher is stronger: unsigned (0) < self/daemon-signed (1) < peer-attested /
/// signed-by-peer (2). Levels sharing a rank compare as equal.
const fn attest_rank(level: AttestLevel) -> u8 {
match level {
AttestLevel::Unsigned => 0,
AttestLevel::SelfSigned | AttestLevel::DaemonSigned => 1,
AttestLevel::PeerAttested | AttestLevel::SignedByPeer => 2,
}
}
pub(crate) fn latest_link_attest_level_many(
conn: &rusqlite::Connection,
ids: &[&str],
) -> std::collections::HashMap<String, String> {
let mut out: std::collections::HashMap<String, String> = std::collections::HashMap::new();
if ids.is_empty() {
return out;
}
const CHUNK: usize = 250;
let mut best_by_id: std::collections::HashMap<String, AttestLevel> =
std::collections::HashMap::new();
for chunk in ids.chunks(CHUNK) {
let placeholders = std::iter::repeat("?")
.take(chunk.len())
.collect::<Vec<_>>()
.join(",");
let sql = format!(
"SELECT source_id, target_id, attest_level \
FROM memory_links \
WHERE source_id IN ({placeholders}) OR target_id IN ({placeholders})"
);
let mut params: Vec<&str> = Vec::with_capacity(chunk.len() * 2);
params.extend_from_slice(chunk);
params.extend_from_slice(chunk);
let Ok(mut stmt) = conn.prepare(&sql) else {
return out;
};
let Ok(rows) = stmt.query_map(rusqlite::params_from_iter(params.iter()), |row| {
let source_id: String = row.get(0)?;
let target_id: String = row.get(1)?;
let level: Option<String> = row.get(2)?;
Ok((source_id, target_id, level))
}) else {
return out;
};
let in_batch: std::collections::HashSet<&str> = chunk.iter().copied().collect();
for r in rows {
let Ok((source_id, target_id, level_opt)) = r else {
continue;
};
let Some(level_str) = level_opt else { continue };
let Some(level) = AttestLevel::from_str(&level_str) else {
continue;
};
let rank = attest_rank(level);
for endpoint in [&source_id, &target_id] {
if !in_batch.contains(endpoint.as_str()) {
continue;
}
match best_by_id.get(endpoint) {
None => {
best_by_id.insert(endpoint.clone(), level);
}
Some(curr) if rank > attest_rank(*curr) => {
best_by_id.insert(endpoint.clone(), level);
}
_ => {}
}
}
}
}
for (id, level) in best_by_id {
out.insert(id, level.as_str().to_string());
}
out
}
pub fn decorate_memory_many(
rows: &[(Memory, f64)],
verbose_provenance: bool,
conn: &rusqlite::Connection,
) -> Vec<Value> {
if !verbose_provenance {
return rows
.iter()
.map(|(mem, score)| {
let mut val = serde_json::to_value(mem).unwrap_or_default();
if let Some(obj) = val.as_object_mut() {
obj.insert(
"score".to_string(),
json!(
(score * crate::SCORE_DISPLAY_ROUND_FACTOR).round()
/ crate::SCORE_DISPLAY_ROUND_FACTOR
),
);
}
val
})
.collect();
}
let ids: Vec<&str> = rows.iter().map(|(m, _)| m.id.as_str()).collect();
let attest_map = latest_link_attest_level_many(conn, &ids);
rows.iter()
.map(|(mem, score)| {
let mut val = serde_json::to_value(mem).unwrap_or_default();
let Some(obj) = val.as_object_mut() else {
return val;
};
obj.insert(
"score".to_string(),
json!(
(score * crate::SCORE_DISPLAY_ROUND_FACTOR).round()
/ crate::SCORE_DISPLAY_ROUND_FACTOR
),
);
obj.insert(
"confidence_tier".to_string(),
json!(mem.confidence_tier().as_str()),
);
obj.insert("freshness_state".to_string(), json!(freshness_state(mem)));
if let Some(level) = attest_map.get(&mem.id) {
obj.insert("latest_link_attest_level".to_string(), json!(level));
}
val
})
.collect()
}
/// Records which memories a recall returned, for later observation analysis.
///
/// Does nothing when the observations table does not exist. Each entry of
/// `memories_json` that has a string id (key `param_names::ID`) becomes one
/// candidate; its rank is its 1-based position in `memories_json` (entries
/// without an id still consume a position) and its score is the entry's
/// `score` field, or `0.0` when that is missing or not a number.
///
/// Recording is best-effort: a failure is logged as a warning and otherwise
/// ignored.
fn record_recall_observations(
    conn: &rusqlite::Connection,
    recall_id: &str,
    memories_json: &[Value],
    retriever: &str,
) {
    if !observations::table_exists(conn) {
        return;
    }
    // The ids are borrowed straight from `memories_json`, which outlives the
    // candidate list, so no intermediate holder is needed.
    let candidates: Vec<observations::Candidate<'_>> = memories_json
        .iter()
        .enumerate()
        .filter_map(|(idx, memory)| {
            let memory_id = memory.get(param_names::ID).and_then(Value::as_str)?;
            Some(observations::Candidate {
                memory_id,
                retriever,
                rank: i64::try_from(idx + 1).unwrap_or(i64::MAX),
                score: memory.get("score").and_then(Value::as_f64).unwrap_or(0.0),
            })
        })
        .collect();
    if let Err(e) = observations::record_recall(conn, recall_id, &candidates) {
        tracing::warn!(
            target: "observations",
            recall_id = %recall_id,
            "record_recall failed (non-fatal): {e}"
        );
    }
}
/// Recall entry point without a caller identity.
///
/// Forwards every argument to [`handle_recall_caller`] with `caller = None`,
/// which disables the per-caller visibility filter.
///
/// # Errors
///
/// Propagates any error returned by [`handle_recall_caller`].
#[allow(clippy::too_many_arguments)]
pub fn handle_recall(
conn: &rusqlite::Connection,
params: &Value,
embedder: Option<&dyn Embed>,
vector_index: Option<&VectorIndex>,
reranker: Option<&BatchedReranker>,
archive_on_gc: bool,
resolved_ttl: &crate::config::ResolvedTtl,
resolved_scoring: &crate::config::ResolvedScoring,
recall_scope: Option<&crate::config::RecallScope>,
) -> Result<Value, String> {
handle_recall_caller(
conn,
params,
embedder,
vector_index,
reranker,
archive_on_gc,
resolved_ttl,
resolved_scoring,
recall_scope,
// No caller identity: results are not visibility-filtered.
None,
)
}
/// Recall entry point taking raw MCP parameters and an optional caller identity.
///
/// Parses `params` into a [`RecallRequest`] and delegates to
/// [`handle_recall_dto`], passing all other arguments through unchanged.
///
/// # Errors
///
/// Returns the parse error from `RecallRequest::from_mcp_params`, or any error
/// from [`handle_recall_dto`].
#[allow(clippy::too_many_arguments)]
pub fn handle_recall_caller(
conn: &rusqlite::Connection,
params: &Value,
embedder: Option<&dyn Embed>,
vector_index: Option<&VectorIndex>,
reranker: Option<&BatchedReranker>,
archive_on_gc: bool,
resolved_ttl: &crate::config::ResolvedTtl,
resolved_scoring: &crate::config::ResolvedScoring,
recall_scope: Option<&crate::config::RecallScope>,
caller: Option<&str>,
) -> Result<Value, String> {
let req = RecallRequest::from_mcp_params(params)?;
handle_recall_dto(
conn,
&req,
embedder,
vector_index,
reranker,
archive_on_gc,
resolved_ttl,
resolved_scoring,
recall_scope,
caller,
)
}
/// Executes a recall for an already-parsed [`RecallRequest`] and builds the
/// JSON response.
///
/// # Retrieval
///
/// * With an `embedder` whose primary query embedding succeeds, hybrid recall
///   is used (`mode = "hybrid"`). Non-empty `context_tokens` are joined with
///   spaces, embedded, and fused into the query embedding; if that second
///   embedding fails, the primary embedding is used alone.
/// * If a `reranker` is also supplied, the hybrid results are reranked and the
///   mode becomes `crate::models::RECALL_MODE_HYBRID_RERANK`.
/// * Without an embedder, or when the primary embedding fails, keyword recall
///   is used (`mode = "keyword"`, `meta.recall_mode = "keyword_only"`).
///
/// The effective limit is the request's positive `limit`, else the scope's
/// limit, else 10, and is always capped at 50. When `session_default` is set,
/// `recall_scope` supplies fallbacks for namespace (its first entry), `since`
/// and `limit`; explicit request values win.
///
/// # Post-processing (identical for every retrieval path)
///
/// citation / source-URI filters → kinds filter → confidence-tier filter →
/// caller visibility filter (only when `caller` is `Some`) → session recency
/// boost → JSON decoration → observation recording → budget and `meta`
/// fields → namespace-standard injection.
///
/// A `confidence_tier` value that is blank or that `ConfidenceTier::parse`
/// rejects results in no tier filtering. `verbose_provenance` defaults to
/// `true`.
///
/// # Errors
///
/// Returns an error string when `context` is empty, when `as_agent` fails
/// namespace validation, or when the underlying database recall fails.
#[allow(clippy::too_many_arguments)]
#[allow(clippy::too_many_lines)]
pub fn handle_recall_dto(
    conn: &rusqlite::Connection,
    req: &RecallRequest,
    embedder: Option<&dyn Embed>,
    vector_index: Option<&VectorIndex>,
    reranker: Option<&BatchedReranker>,
    archive_on_gc: bool,
    resolved_ttl: &crate::config::ResolvedTtl,
    resolved_scoring: &crate::config::ResolvedScoring,
    recall_scope: Option<&crate::config::RecallScope>,
    caller: Option<&str>,
) -> Result<Value, String> {
    let verbose_provenance = req.verbose_provenance.unwrap_or(true);
    let recall_id = uuid::Uuid::new_v4().to_string();
    let confidence_tier_filter: Option<ConfidenceTier> = req
        .confidence_tier
        .as_deref()
        .map(str::trim)
        .filter(|s| !s.is_empty())
        .and_then(ConfidenceTier::parse);

    // NOTE(review): rows are decorated one at a time, so verbose provenance
    // issues one link lookup per memory. `decorate_memory_many` batches that
    // lookup; confirm it matches `db::get_links` semantics before switching.
    let scored_memories =
        |results: Vec<(Memory, f64)>, conn: &rusqlite::Connection| -> Vec<Value> {
            results
                .into_iter()
                .map(|(mem, score)| decorate_memory(&mem, score, verbose_provenance, conn))
                .collect()
        };
    let apply_confidence_tier_filter = |results: Vec<(Memory, f64)>| -> Vec<(Memory, f64)> {
        match confidence_tier_filter {
            None => results,
            Some(target) => results
                .into_iter()
                .filter(|(m, _)| m.confidence_tier() == target)
                .collect(),
        }
    };
    let apply_visibility_filter = |results: Vec<(Memory, f64)>| -> Vec<(Memory, f64)> {
        match caller {
            None => results,
            Some(c) => results
                .into_iter()
                .filter(|(m, _)| crate::visibility::is_visible_to_caller(m, c))
                .collect(),
        }
    };

    // Opportunistic GC; its outcome never affects the recall itself.
    let _ = db::gc_if_needed(conn, archive_on_gc);

    let context = req.context.as_str();
    if context.is_empty() {
        return Err(crate::errors::msg::CONTEXT_REQUIRED.to_string());
    }

    // Scope defaults only apply when the request opted in via `session_default`.
    let session_default = req.session_default.unwrap_or(false);
    let scope = if session_default { recall_scope } else { None };
    let scope_namespace: Option<String> = scope
        .and_then(|s| s.namespaces.as_ref())
        .and_then(|v| v.first())
        .cloned();
    let scope_since: Option<String> = scope.and_then(|s| {
        s.since.as_deref().and_then(|d| {
            crate::config::parse_duration_string(d).map(|dur| {
                let cutoff = chrono::Utc::now() - dur;
                cutoff.to_rfc3339()
            })
        })
    });
    let explicit_namespace = req.namespace.as_deref();
    let explicit_since = req.since.as_deref();
    let namespace: Option<&str> = explicit_namespace.or(scope_namespace.as_deref());
    let limit = if let Some(v) = req.limit
        && v > 0
    {
        usize::try_from(v).unwrap_or(usize::MAX)
    } else if let Some(v) = scope.and_then(|s| s.limit) {
        usize::try_from(v).unwrap_or(usize::MAX)
    } else {
        10
    };
    // Hard ceiling on how many rows a single recall may request.
    let capped_limit = limit.min(50);
    let tags = req.tags.as_deref();
    let since: Option<&str> = explicit_since.or(scope_since.as_deref());
    let until = req.until.as_deref();
    let as_agent = req.as_agent.as_deref();
    if let Some(a) = as_agent {
        validate::validate_namespace(a).map_err(|e| e.to_string())?;
    }
    let budget_tokens = req.resolved_budget_tokens();
    let kinds_filter = req.kinds.as_ref().and_then(KindsFilter::parse);
    let include_archived = req.include_archived.unwrap_or(false);
    let has_citations_filter = req.has_citations.unwrap_or(false);
    let source_uri_prefix: Option<String> = req.source_uri_prefix.clone();
    let session_id: Option<String> = req
        .session_id
        .as_deref()
        .map(str::trim)
        .filter(|s| !s.is_empty())
        .map(std::string::ToString::to_string);
    let session_tracker = crate::reranker::global_session_recall_tracker();
    let context_tokens: Vec<String> = req
        .context_tokens
        .as_ref()
        .map(|arr| arr.iter().filter(|s| !s.is_empty()).cloned().collect())
        .unwrap_or_default();

    // `tokens_used` is always reported; the `meta.budget_*` fields only when a
    // budget was actually requested.
    let decorate_budget = |resp: &mut Value, outcome: &db::BudgetOutcome| {
        resp["tokens_used"] = json!(outcome.tokens_used);
        if let Some(b) = budget_tokens {
            resp["budget_tokens"] = json!(b);
            let meta = resp
                .as_object_mut()
                .expect("recall response is always a JSON object")
                .entry("meta".to_string())
                .or_insert_with(|| json!({}));
            meta["budget_tokens_used"] = json!(outcome.tokens_used);
            meta["budget_tokens_remaining"] = json!(outcome.tokens_remaining.unwrap_or(0));
            meta["memories_dropped"] = json!(outcome.memories_dropped);
            meta["budget_overflow"] = json!(outcome.budget_overflow);
        }
    };
    let reranker_used = match reranker {
        Some(ce) if ce.is_neural() => "neural",
        Some(ce) if ce.is_degraded_lexical() => "degraded_lexical",
        Some(_) => "lexical",
        None => "none",
    };
    // Merges the telemetry-derived fields into `meta`, keeping anything
    // (e.g. budget fields) that is already there.
    let attach_meta = |resp: &mut Value, recall_mode: &str, telemetry: &RecallTelemetry| {
        let blend_weight = (telemetry.blend_weight_avg * crate::SCORE_DISPLAY_ROUND_FACTOR).round()
            / crate::SCORE_DISPLAY_ROUND_FACTOR;
        let meta = RecallMeta {
            recall_mode: recall_mode.to_string(),
            reranker_used: reranker_used.to_string(),
            candidate_counts: CandidateCounts {
                fts: telemetry.fts_candidates,
                hnsw: telemetry.hnsw_candidates,
            },
            blend_weight,
        };
        if let Ok(Value::Object(p3_fields)) = serde_json::to_value(&meta) {
            let meta_obj = resp
                .as_object_mut()
                .expect("recall response is always a JSON object")
                .entry("meta".to_string())
                .or_insert_with(|| json!({}));
            if let Some(existing) = meta_obj.as_object_mut() {
                for (k, v) in p3_fields {
                    existing.insert(k, v);
                }
            }
        }
    };

    // A failing primary embedding degrades to keyword recall instead of
    // failing the whole request.
    let primary_embedding = embedder.and_then(|emb| match emb.embed_query(context) {
        Ok(vector) => Some((emb, vector)),
        Err(e) => {
            tracing::warn!("embedding failed, falling back to FTS: {}", e);
            None
        }
    });

    // Retrieval: each branch yields the raw scored rows plus the labels that
    // distinguish the paths in the response; everything afterwards is shared.
    let (results, outcome, telemetry, mode, meta_recall_mode) =
        if let Some((emb, primary_emb)) = primary_embedding {
            let query_emb = if context_tokens.is_empty() {
                primary_emb
            } else {
                let joined = context_tokens.join(" ");
                match emb.embed_query(&joined) {
                    Ok(ctx_emb) => crate::embeddings::Embedder::fuse(
                        &primary_emb,
                        &ctx_emb,
                        crate::RECALL_PRIMARY_CTX_BLEND,
                    ),
                    Err(e) => {
                        tracing::warn!("context_tokens embed failed, using primary only: {e}");
                        primary_emb
                    }
                }
            };
            let (results, outcome, telemetry) = db::recall_hybrid_with_telemetry(
                conn,
                context,
                &query_emb,
                namespace,
                capped_limit,
                tags,
                since,
                until,
                vector_index,
                resolved_ttl.short_extend_secs,
                resolved_ttl.mid_extend_secs,
                as_agent,
                budget_tokens,
                resolved_scoring,
                include_archived,
                source_uri_prefix.as_deref(),
            )
            .map_err(|e| e.to_string())?;
            let results = crate::cli::recall::apply_form4_recall_filters(
                results,
                has_citations_filter,
                source_uri_prefix.as_deref(),
            );
            // Reranking changes the reported mode but not `meta.recall_mode`.
            let (results, mode) = match reranker {
                Some(ce) => (
                    ce.rerank(context, results),
                    crate::models::RECALL_MODE_HYBRID_RERANK,
                ),
                None => (results, "hybrid"),
            };
            (results, outcome, telemetry, mode, "hybrid")
        } else {
            let (results, outcome, telemetry) = db::recall_with_telemetry(
                conn,
                context,
                namespace,
                capped_limit,
                tags,
                since,
                until,
                resolved_ttl.short_extend_secs,
                resolved_ttl.mid_extend_secs,
                as_agent,
                budget_tokens,
                include_archived,
                source_uri_prefix.as_deref(),
            )
            .map_err(|e| e.to_string())?;
            let results = crate::cli::recall::apply_form4_recall_filters(
                results,
                has_citations_filter,
                source_uri_prefix.as_deref(),
            );
            (results, outcome, telemetry, "keyword", "keyword_only")
        };

    // Shared post-processing pipeline.
    let results = apply_kinds_filter(results, kinds_filter.as_deref());
    let results = apply_confidence_tier_filter(results);
    let results = apply_visibility_filter(results);
    let results = crate::reranker::apply_session_recency_boost(
        results,
        session_id.as_deref(),
        session_tracker,
    );
    let memories = scored_memories(results, conn);
    record_recall_observations(conn, &recall_id, &memories, mode);
    let mut resp = json!({
        "recall_id": recall_id,
        "memories": memories,
        "count": memories.len(),
        "mode": mode,
    });
    decorate_budget(&mut resp, &outcome);
    attach_meta(&mut resp, meta_recall_mode, &telemetry);
    super::inject_namespace_standard(conn, namespace, &mut resp);
    Ok(resp)
}
// Integration-style tests for the recall handlers. Each test opens a fresh
// in-memory database, seeds it, and exercises one recall path or parameter.
#[cfg(test)]
mod tests {
use super::*;
use crate::config::{RecallScope, ResolvedScoring, ResolvedTtl};
use crate::embeddings::test_support::MockEmbedder;
use crate::hnsw::VectorIndex;
use crate::models::{Memory, Tier};
use crate::reranker::{BatchedReranker, CrossEncoder};
use crate::storage as db;
// Opens a new in-memory database.
fn fresh_conn() -> rusqlite::Connection {
db::open(std::path::Path::new(":memory:")).expect("open in-memory db")
}
// Builds a long-tier observation memory with a random id, the current time as
// created/updated timestamps, and `ai:test` as the owning agent in metadata.
fn make_mem(title: &str, content: &str, ns: &str) -> Memory {
let now = chrono::Utc::now().to_rfc3339();
Memory {
id: uuid::Uuid::new_v4().to_string(),
tier: Tier::Long,
namespace: ns.to_string(),
title: title.to_string(),
content: content.to_string(),
tags: vec![],
priority: 5,
confidence: 1.0,
source: "test".to_string(),
access_count: 0,
created_at: now.clone(),
updated_at: now,
last_accessed_at: None,
expires_at: None,
metadata: json!({"agent_id": "ai:test"}),
reflection_depth: 0,
memory_kind: crate::models::MemoryKind::Observation,
entity_id: None,
persona_version: None,
citations: Vec::new(),
source_uri: None,
source_span: None,
confidence_source: crate::models::ConfidenceSource::CallerProvided,
confidence_signals: None,
confidence_decayed_at: None,
version: 1,
}
}
// Inserts two memories into namespace `test` and one into namespace `other`.
fn seed(conn: &rusqlite::Connection) {
db::insert(
conn,
&make_mem(
"Rust ownership",
"Rust ownership rules prevent data races",
"test",
),
)
.unwrap();
db::insert(
conn,
&make_mem(
"Python typing",
"Python typing is dynamic with hints",
"test",
),
)
.unwrap();
db::insert(conn, &make_mem("Other topic", "Unrelated content", "other")).unwrap();
}
// A missing `context` parameter is rejected with an error mentioning "context".
#[test]
fn missing_context_errors() {
let conn = fresh_conn();
let ttl = ResolvedTtl::default();
let scoring = ResolvedScoring::default();
let err = handle_recall(
&conn,
&json!({}),
None,
None,
None,
false,
&ttl,
&scoring,
None,
)
.unwrap_err();
assert!(err.contains("context"));
}
// Without an embedder the handler reports keyword mode.
#[test]
fn keyword_only_path() {
let conn = fresh_conn();
seed(&conn);
let ttl = ResolvedTtl::default();
let scoring = ResolvedScoring::default();
let resp = handle_recall(
&conn,
&json!({"context": "ownership", "namespace": "test"}),
None,
None,
None,
false,
&ttl,
&scoring,
None,
)
.expect("ok");
assert_eq!(resp["mode"].as_str(), Some("keyword"));
assert_eq!(resp["meta"]["recall_mode"].as_str(), Some("keyword_only"));
}
// Builds a memory in namespace `vis` owned by `agent`, optionally tagged with
// an explicit scope in its metadata.
fn owned_mem(title: &str, agent: &str, scope: Option<&str>) -> Memory {
let mut m = make_mem(title, "shared ownership keyword content", "vis");
m.metadata = match scope {
Some(s) => json!({crate::META_KEY_AGENT_ID: agent, crate::META_KEY_SCOPE: s}),
None => json!({crate::META_KEY_AGENT_ID: agent}),
};
m
}
// Seeds one memory owned by alice with no scope and one owned by bob with the
// collective scope.
fn seed_vis(conn: &rusqlite::Connection) {
use crate::models::namespace::MemoryScope;
db::insert(conn, &owned_mem("priv", "ai:alice", None)).expect("ins");
db::insert(
conn,
&owned_mem("shared", "ai:bob", Some(MemoryScope::Collective.as_str())),
)
.expect("ins");
}
// Extracts the `title` of every returned memory, in response order.
fn recall_titles(resp: &Value) -> Vec<String> {
resp["memories"]
.as_array()
.map(|a| {
a.iter()
.filter_map(|m| m["title"].as_str().map(str::to_string))
.collect()
})
.unwrap_or_default()
}
// With no caller identity, both seeded memories are returned.
#[test]
fn recall_caller_none_returns_all() {
let conn = fresh_conn();
seed_vis(&conn);
let ttl = ResolvedTtl::default();
let scoring = ResolvedScoring::default();
let resp = handle_recall_caller(
&conn,
&json!({"context": "ownership", "namespace": "vis"}),
None,
None,
None,
false,
&ttl,
&scoring,
None,
None,
)
.expect("ok");
assert_eq!(resp["count"].as_u64(), Some(2));
}
// A third-party caller only sees the collective-scoped memory.
#[test]
fn recall_non_owner_excludes_cross_agent_private() {
let conn = fresh_conn();
seed_vis(&conn);
let ttl = ResolvedTtl::default();
let scoring = ResolvedScoring::default();
let resp = handle_recall_caller(
&conn,
&json!({"context": "ownership", "namespace": "vis"}),
None,
None,
None,
false,
&ttl,
&scoring,
None,
Some("ai:carol"),
)
.expect("ok");
assert_eq!(resp["count"].as_u64(), Some(1));
assert_eq!(recall_titles(&resp), vec!["shared".to_string()]);
}
// The owner sees their own unscoped memory as well as the collective one.
#[test]
fn recall_owner_sees_own_private_and_shared() {
let conn = fresh_conn();
seed_vis(&conn);
let ttl = ResolvedTtl::default();
let scoring = ResolvedScoring::default();
let resp = handle_recall_caller(
&conn,
&json!({"context": "ownership", "namespace": "vis"}),
None,
None,
None,
false,
&ttl,
&scoring,
None,
Some("ai:alice"),
)
.expect("ok");
assert_eq!(resp["count"].as_u64(), Some(2));
}
// Supplying an embedder switches the handler to hybrid mode.
#[test]
fn hybrid_path_with_embedder() {
let conn = fresh_conn();
seed(&conn);
let ttl = ResolvedTtl::default();
let scoring = ResolvedScoring::default();
let mock = MockEmbedder::new_local().expect("mock");
let resp = handle_recall(
&conn,
&json!({"context": "ownership rules", "namespace": "test"}),
Some(&mock as &dyn crate::embeddings::Embed),
None,
None,
false,
&ttl,
&scoring,
None,
)
.expect("ok");
assert_eq!(resp["mode"].as_str(), Some("hybrid"));
assert_eq!(resp["meta"]["recall_mode"].as_str(), Some("hybrid"));
}
// Embedder plus reranker yields the hybrid+rerank mode and reports the
// reranker built from `CrossEncoder::new()` as "lexical".
#[test]
fn hybrid_with_reranker_path() {
let conn = fresh_conn();
seed(&conn);
let ttl = ResolvedTtl::default();
let scoring = ResolvedScoring::default();
let mock = MockEmbedder::new_local().expect("mock");
let lex = CrossEncoder::new();
let batched = BatchedReranker::new(lex);
let resp = handle_recall(
&conn,
&json!({"context": "ownership rules", "namespace": "test"}),
Some(&mock as &dyn crate::embeddings::Embed),
None,
Some(&batched),
false,
&ttl,
&scoring,
None,
)
.expect("ok");
assert_eq!(resp["mode"].as_str(), Some("hybrid+rerank"));
assert_eq!(resp["meta"]["reranker_used"].as_str(), Some("lexical"));
}
// Hybrid recall also works when an (empty) vector index is supplied.
#[test]
fn hybrid_with_vector_index() {
let conn = fresh_conn();
seed(&conn);
let ttl = ResolvedTtl::default();
let scoring = ResolvedScoring::default();
let mock = MockEmbedder::new_local().expect("mock");
let idx = VectorIndex::empty();
let resp = handle_recall(
&conn,
&json!({"context": "ownership", "namespace": "test"}),
Some(&mock as &dyn crate::embeddings::Embed),
Some(&idx),
None,
false,
&ttl,
&scoring,
None,
)
.expect("ok");
assert_eq!(resp["mode"].as_str(), Some("hybrid"));
}
// A requested budget is echoed back and budget usage appears under `meta`.
#[test]
fn budget_tokens_meta_emitted() {
let conn = fresh_conn();
seed(&conn);
let ttl = ResolvedTtl::default();
let scoring = ResolvedScoring::default();
let resp = handle_recall(
&conn,
&json!({"context": "ownership", "namespace": "test", "budget_tokens": 100u64}),
None,
None,
None,
false,
&ttl,
&scoring,
None,
)
.expect("ok");
assert!(resp["meta"]["budget_tokens_used"].is_number());
assert_eq!(resp["budget_tokens"].as_u64(), Some(100));
}
// A zero budget still succeeds and emits a boolean `budget_overflow` flag.
// NOTE(review): despite the name, this does not assert the result is empty.
#[test]
fn budget_tokens_zero_returns_empty() {
let conn = fresh_conn();
seed(&conn);
let ttl = ResolvedTtl::default();
let scoring = ResolvedScoring::default();
let resp = handle_recall(
&conn,
&json!({"context": "ownership", "namespace": "test", "budget_tokens": 0u64}),
None,
None,
None,
false,
&ttl,
&scoring,
None,
)
.expect("ok");
assert!(resp["meta"]["budget_overflow"].is_boolean());
}
// With `session_default`, the scope's limit of 2 bounds the result count.
#[test]
fn session_default_recall_scope_splices_defaults() {
let conn = fresh_conn();
seed(&conn);
let ttl = ResolvedTtl::default();
let scoring = ResolvedScoring::default();
let scope = RecallScope {
namespaces: Some(vec!["test".to_string()]),
since: Some("24h".to_string()),
tier: None,
limit: Some(2),
};
let resp = handle_recall(
&conn,
&json!({"context": "ownership", "session_default": true}),
None,
None,
None,
false,
&ttl,
&scoring,
Some(&scope),
)
.expect("ok");
assert!(resp["count"].as_u64().unwrap() <= 2);
}
// Supplying `context_tokens` with a working embedder stays on the hybrid path.
#[test]
fn context_tokens_fusion_path() {
let conn = fresh_conn();
seed(&conn);
let ttl = ResolvedTtl::default();
let scoring = ResolvedScoring::default();
let mock = MockEmbedder::new_local().expect("mock");
let resp = handle_recall(
&conn,
&json!({
"context": "ownership",
"namespace": "test",
"context_tokens": ["rust", "memory"]
}),
Some(&mock as &dyn crate::embeddings::Embed),
None,
None,
false,
&ttl,
&scoring,
None,
)
.expect("ok");
assert_eq!(resp["mode"].as_str(), Some("hybrid"));
}
// A well-formed `as_agent` value passes validation.
#[test]
fn as_agent_validated() {
let conn = fresh_conn();
seed(&conn);
let ttl = ResolvedTtl::default();
let scoring = ResolvedScoring::default();
let resp = handle_recall(
&conn,
&json!({"context": "ownership", "namespace": "test", "as_agent": "ai:viewer"}),
None,
None,
None,
false,
&ttl,
&scoring,
None,
)
.expect("ok");
assert!(resp["count"].is_number());
}
// An `as_agent` value containing a space is rejected.
#[test]
fn as_agent_invalid_errors() {
let conn = fresh_conn();
let ttl = ResolvedTtl::default();
let scoring = ResolvedScoring::default();
let err = handle_recall(
&conn,
&json!({"context": "ownership", "as_agent": "has space"}),
None,
None,
None,
false,
&ttl,
&scoring,
None,
)
.unwrap_err();
assert!(!err.is_empty());
}
// Recall still succeeds when `archive_on_gc` is enabled.
#[test]
fn archive_on_gc_true_runs_gc() {
let conn = fresh_conn();
seed(&conn);
let ttl = ResolvedTtl::default();
let scoring = ResolvedScoring::default();
let resp = handle_recall(
&conn,
&json!({"context": "ownership", "namespace": "test"}),
None,
None,
None,
true,
&ttl,
&scoring,
None,
)
.expect("ok");
assert!(resp["memories"].is_array());
}
// `since`, `until` and `tags` parameters are accepted together.
#[test]
fn since_until_filters_applied() {
let conn = fresh_conn();
seed(&conn);
let ttl = ResolvedTtl::default();
let scoring = ResolvedScoring::default();
let resp = handle_recall(
&conn,
&json!({
"context": "ownership",
"namespace": "test",
"since": "2000-01-01T00:00:00Z",
"until": "2100-01-01T00:00:00Z",
"tags": "rust",
}),
None,
None,
None,
false,
&ttl,
&scoring,
None,
)
.expect("ok");
assert!(resp["memories"].is_array());
}
// A `limit` of `u64::MAX` does not error or panic.
#[test]
fn limit_overflow_saturates() {
let conn = fresh_conn();
seed(&conn);
let ttl = ResolvedTtl::default();
let scoring = ResolvedScoring::default();
let resp = handle_recall(
&conn,
&json!({"context": "ownership", "namespace": "test", "limit": u64::MAX}),
None,
None,
None,
false,
&ttl,
&scoring,
None,
)
.expect("ok");
assert!(resp["memories"].is_array());
}
// Test embedder that fails on selected calls, keyed by a call counter:
// `fail_first` fails call 0, `fail_second` fails every call after that.
struct FailEmbedder {
fail_first: bool,
fail_second: bool,
calls: std::sync::atomic::AtomicUsize,
}
impl FailEmbedder {
// Fails only the first embed call.
fn primary_fail() -> Self {
Self {
fail_first: true,
fail_second: false,
calls: std::sync::atomic::AtomicUsize::new(0),
}
}
// Succeeds on the first embed call and fails all later ones.
fn secondary_fail() -> Self {
Self {
fail_first: false,
fail_second: true,
calls: std::sync::atomic::AtomicUsize::new(0),
}
}
}
impl crate::embeddings::Embed for FailEmbedder {
// Returns a constant 384-dimensional vector unless this call is configured to fail.
fn embed(&self, _: &str) -> anyhow::Result<Vec<f32>> {
let n = self.calls.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
if (n == 0 && self.fail_first) || (n >= 1 && self.fail_second) {
anyhow::bail!("FailEmbedder: synthetic failure on call {n}");
}
Ok(vec![0.1_f32; 384])
}
}
// When the primary query embedding fails, recall falls back to keyword mode.
#[test]
fn primary_embedder_error_falls_back_to_keyword() {
let conn = fresh_conn();
seed(&conn);
let ttl = ResolvedTtl::default();
let scoring = ResolvedScoring::default();
let fe = FailEmbedder::primary_fail();
let resp = handle_recall(
&conn,
&json!({"context": "ownership", "namespace": "test"}),
Some(&fe as &dyn crate::embeddings::Embed),
None,
None,
false,
&ttl,
&scoring,
None,
)
.expect("ok");
assert_eq!(resp["mode"].as_str(), Some("keyword"));
assert_eq!(resp["meta"]["recall_mode"].as_str(), Some("keyword_only"));
}
// When only the `context_tokens` embedding fails, recall stays hybrid.
#[test]
fn context_tokens_embedder_error_uses_primary_only() {
let conn = fresh_conn();
seed(&conn);
let ttl = ResolvedTtl::default();
let scoring = ResolvedScoring::default();
let fe = FailEmbedder::secondary_fail();
let resp = handle_recall(
&conn,
&json!({
"context": "ownership",
"namespace": "test",
"context_tokens": ["rust", "memory"]
}),
Some(&fe as &dyn crate::embeddings::Embed),
None,
None,
false,
&ttl,
&scoring,
None,
)
.expect("ok");
assert_eq!(resp["mode"].as_str(), Some("hybrid"));
}
// An empty hook chain leaves the request untouched and recall proceeds.
#[tokio::test]
async fn pre_recall_hook_empty_chain_passes_through() {
let conn = fresh_conn();
seed(&conn);
let ttl = ResolvedTtl::default();
let scoring = ResolvedScoring::default();
let chain = crate::hooks::HookChain::new(vec![]);
let mut registry = crate::hooks::ExecutorRegistry::default();
let resp = handle_recall_with_pre_recall_hook(
&conn,
&json!({"context": "ownership", "namespace": "test"}),
None,
None,
None,
false,
&ttl,
&scoring,
&chain,
&mut registry,
None,
None,
)
.await
.expect("ok");
assert_eq!(resp["mode"].as_str(), Some("keyword"));
}
// The hook wrapper rejects a missing `context` before running the chain.
#[tokio::test]
async fn pre_recall_hook_missing_context_errors() {
let conn = fresh_conn();
let ttl = ResolvedTtl::default();
let scoring = ResolvedScoring::default();
let chain = crate::hooks::HookChain::new(vec![]);
let mut registry = crate::hooks::ExecutorRegistry::default();
let err = handle_recall_with_pre_recall_hook(
&conn,
&json!({}),
None,
None,
None,
false,
&ttl,
&scoring,
&chain,
&mut registry,
None,
None,
)
.await
.unwrap_err();
assert!(err.contains("context"));
}
}