use std::collections::HashMap;
use chrono::Utc;
use tracing::info;
use uuid::Uuid;
use crate::activation;
use crate::config::RetrievalConfig;
use crate::decay;
use crate::embedding::OnnxEmbedder;
use crate::graph::MemoryGraph;
use crate::reranker::Reranker;
use crate::rrf;
use crate::storage::StorageTrait;
use crate::types::Memory;
use crate::vector::VectorIndex;
/// Candidate pool paired with per-id vector-similarity scores:
/// `(memory id -> Memory, memory id -> vector score)`.
type CandidateMaps = (HashMap<Uuid, Memory>, HashMap<Uuid, f32>);
/// Coarse intent category inferred from the raw query text by
/// [`classify_intent`]; used to bias ranking toward the memory types
/// that typically answer that kind of query.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QueryIntent {
    /// Factual "what / who / where ..." style queries.
    Question,
    /// Task-oriented "how do I / deploy / fix ..." queries.
    Action,
    /// Queries referring back to earlier conversation ("remember", "last time").
    Recall,
    /// Programming / API / schema related queries.
    Code,
    /// Image / UI / diagram related queries.
    Visual,
    /// Fallback when no keyword list matches.
    General,
}
/// Substrings marking a query as referring back to earlier conversation.
/// Matched against the lowercased query; entries with a trailing space
/// (e.g. "past ") avoid matching inside longer words.
const RECALL_KEYWORDS: &[&str] = &[
    "remember",
    "recall",
    "told me",
    "said that",
    "mentioned",
    "last time",
    "previously",
    "earlier",
    "before",
    "history",
    "past ",
    "talked about",
    "discussed",
    "you said",
    "i said",
    "we discussed",
];
/// Substrings marking a task-oriented ("how do I ...") query.
/// Matched against the lowercased query; trailing spaces (e.g. "run ")
/// prevent matches inside longer words such as "running shoes" vs "rung".
const ACTION_KEYWORDS: &[&str] = &[
    "how do i",
    "how to",
    "steps to",
    "run ",
    "execute",
    "deploy",
    "install",
    "build ",
    "create ",
    "fix ",
    "solve",
    "implement",
    "configure",
    "setup",
    "set up",
    "start ",
    "stop ",
    "restart",
    "update ",
    "upgrade",
    "debug",
    "troubleshoot",
];
/// Substrings marking an interrogative query. Matched against the
/// lowercased query; checked last among the keyword lists, so a "?" only
/// classifies as Question when no more specific intent matched.
const QUESTION_KEYWORDS: &[&str] = &[
    "what ",
    "what's",
    "who ",
    "who's",
    "where ",
    "where's",
    "when ",
    "when's",
    "why ",
    "which ",
    "is it",
    "are there",
    "does ",
    "do ",
    "can ",
    "could ",
    "should ",
    "would ",
    "will ",
    "?",
];
/// Substrings marking a programming/technical query.
///
/// Matched against the *lowercased* query in [`classify_intent`], so every
/// entry must be lowercase: the previous "API" and "SQL" entries were
/// uppercase and could never match.
const CODE_KEYWORDS: &[&str] = &[
    "code",
    "function",
    "class",
    "import",
    "def ",
    "fn ",
    "struct ",
    "implement",
    "syntax",
    "compile",
    "runtime",
    "error in",
    "stack trace",
    "exception",
    "variable",
    "method",
    "api",
    "endpoint",
    "schema",
    "migration",
    "query",
    "sql",
];
/// Substrings marking an image/UI-related query.
///
/// Matched against the *lowercased* query in [`classify_intent`], so every
/// entry must be lowercase: the previous "UI" entry was uppercase and could
/// never match. The replacement keeps a leading space (" ui") so it does not
/// fire inside words such as "build" or "guide".
const VISUAL_KEYWORDS: &[&str] = &[
    "image",
    "picture",
    "photo",
    "screenshot",
    "diagram",
    "chart",
    "graph",
    "visual",
    "looks like",
    "shown in",
    "display",
    " ui",
    "interface",
    "design",
    "layout",
];
/// Returns `true` if `text` contains any of the given keyword substrings.
/// An empty keyword list never matches.
fn matches_any(text: &str, keywords: &[&str]) -> bool {
    for keyword in keywords {
        if text.contains(keyword) {
            return true;
        }
    }
    false
}
/// Classifies a raw query string into a [`QueryIntent`].
///
/// Matching is case-insensitive (the query is lowercased first) and
/// priority-ordered: recall cues win over code cues, then visual, action,
/// and question cues. A query with no keyword hits is
/// [`QueryIntent::General`].
pub fn classify_intent(query: &str) -> QueryIntent {
    let lower = query.to_lowercase();
    if matches_any(&lower, RECALL_KEYWORDS) {
        QueryIntent::Recall
    } else if matches_any(&lower, CODE_KEYWORDS) {
        QueryIntent::Code
    } else if matches_any(&lower, VISUAL_KEYWORDS) {
        QueryIntent::Visual
    } else if matches_any(&lower, ACTION_KEYWORDS) {
        QueryIntent::Action
    } else if matches_any(&lower, QUESTION_KEYWORDS) {
        QueryIntent::Question
    } else {
        QueryIntent::General
    }
}
/// Prior relevance of a memory type ("episodic" / "semantic" /
/// "procedural") for a given query intent, on a 0.0–1.0 scale.
///
/// Unknown type names and the `General` intent fall back to a neutral 0.5.
pub fn intent_score_for_type(intent: &QueryIntent, memory_type: &str) -> f32 {
    match (intent, memory_type) {
        // Questions are best answered by concrete episodes, then facts.
        (QueryIntent::Question, "episodic") => 0.8,
        (QueryIntent::Question, "semantic") => 0.6,
        (QueryIntent::Question, "procedural") => 0.2,
        // Action queries want how-to knowledge.
        (QueryIntent::Action, "procedural") => 0.9,
        (QueryIntent::Action, "semantic") => 0.3,
        (QueryIntent::Action, "episodic") => 0.1,
        // Recall queries favor consolidated facts over raw episodes.
        (QueryIntent::Recall, "semantic") => 0.8,
        (QueryIntent::Recall, "episodic") => 0.6,
        (QueryIntent::Recall, "procedural") => 0.3,
        // Code queries lean procedural, then semantic.
        (QueryIntent::Code, "procedural") => 0.8,
        (QueryIntent::Code, "semantic") => 0.6,
        (QueryIntent::Code, "episodic") => 0.3,
        // Visual queries favor episodic context.
        (QueryIntent::Visual, "episodic") => 0.8,
        (QueryIntent::Visual, "procedural") => 0.2,
        // Everything else — General intent, unknown type names, and the
        // per-intent defaults — is neutral (all fallbacks were 0.5).
        _ => 0.5,
    }
}
/// Errors surfaced by [`RecallEngine`]; each subsystem error converts in
/// via `#[from]`, so `?` propagates cleanly through the pipeline.
#[derive(Debug, thiserror::Error)]
pub enum RecallError {
    #[error("Embedding error: {0}")]
    Embedding(#[from] crate::embedding::EmbeddingError),
    #[error("Storage error: {0}")]
    Storage(#[from] crate::storage::StorageError),
    #[error("Vector error: {0}")]
    Vector(#[from] crate::vector::VectorError),
    #[error("Reranker error: {0}")]
    Reranker(#[from] crate::reranker::RerankerError),
    #[error("RRF error: {0}")]
    Rrf(#[from] crate::rrf::RrfError),
    // Payload is the configured timeout in seconds (recall_timeout_secs).
    #[error("Recall timed out after {0} seconds")]
    Timeout(u64),
}
/// One recall result with the full per-signal score breakdown, useful for
/// debugging and downstream explanation of ranking decisions.
#[derive(Debug, Clone)]
pub struct ScoredCandidate {
    /// Id of the underlying memory row.
    pub memory_id: Uuid,
    /// The memory itself (cloned out of the candidate pool).
    pub memory: Memory,
    /// Vector similarity, clamped to [0, 1].
    pub vector_score: f32,
    /// Rank-derived full-text score (BM25 proxy).
    pub bm25_score: f32,
    /// Graph spreading-activation score; currently always 0.0 in the live
    /// pipeline (the graph signal participates via RRF instead).
    pub graph_score: f32,
    /// Intent/type prior from [`intent_score_for_type`].
    pub intent_score: f32,
    /// Decay-based retrievability at query time.
    pub recency_score: f32,
    /// Log-normalized access frequency in [0, 1].
    pub access_score: f32,
    /// Type-specific confidence/reliability (episodic is fixed at 1.0).
    pub confidence_score: f32,
    /// Affinity to the target entity (1.0 subject/about, 0.8 source).
    pub entity_score: f32,
    /// Per-type boost multiplier (currently always 1.0).
    pub type_boost: f32,
    /// Fused score: RRF output, or reranker score when reranking ran.
    pub final_score: f32,
}
/// Result of a recall: candidates ordered best-first, already truncated to
/// the caller's limit.
#[derive(Debug)]
pub struct RecallResult {
    pub memories: Vec<ScoredCandidate>,
}
/// Hybrid recall pipeline over storage, an ONNX embedder, and a vector
/// index, with optional graph traversal and cross-encoder reranking.
/// Borrows all collaborators, so it is cheap to construct per request.
pub struct RecallEngine<'a> {
    storage: &'a dyn StorageTrait,
    embedder: &'a OnnxEmbedder,
    vector_index: &'a VectorIndex,
    config: &'a RetrievalConfig,
    // Optional memory graph enabling spreading-activation ranking.
    graph: Option<&'a MemoryGraph>,
    // Optional cross-encoder applied to the top candidates.
    reranker: Option<&'a Reranker>,
}
/// Number of top fused candidates handed to the cross-encoder reranker.
const RERANK_TOP_N: usize = 20;
impl<'a> RecallEngine<'a> {
/// Builds an engine from its required collaborators; the memory graph and
/// reranker stay disabled until added via [`Self::with_graph`] /
/// [`Self::with_reranker`].
pub fn new(
    storage: &'a dyn StorageTrait,
    embedder: &'a OnnxEmbedder,
    vector_index: &'a VectorIndex,
    config: &'a RetrievalConfig,
) -> Self {
    Self {
        storage,
        embedder,
        vector_index,
        config,
        graph: None,
        reranker: None,
    }
}
/// Enables graph spreading-activation ranking (builder style).
#[must_use]
pub fn with_graph(mut self, graph: &'a MemoryGraph) -> Self {
    self.graph = Some(graph);
    self
}
/// Enables cross-encoder reranking of the top candidates (builder style).
#[must_use]
pub fn with_reranker(mut self, reranker: &'a Reranker) -> Self {
    self.reranker = Some(reranker);
    self
}
/// Convenience wrapper: recall without entity targeting.
///
/// # Errors
/// Propagates any [`RecallError`] from the underlying pipeline.
pub fn recall(
    &self,
    query: &str,
    namespace_id: Uuid,
    limit: usize,
) -> Result<RecallResult, RecallError> {
    self.recall_with_entity(query, namespace_id, limit, None)
}
/// Recall that reuses a precomputed query embedding when supplied,
/// avoiding a redundant embedder call.
///
/// # Errors
/// Propagates any [`RecallError`] from the underlying pipeline.
pub fn recall_with_embedding(
    &self,
    query: &str,
    query_embedding: Option<&[f32]>,
    namespace_id: Uuid,
    limit: usize,
    target_entity: Option<Uuid>,
) -> Result<RecallResult, RecallError> {
    self.recall_inner(query, query_embedding, namespace_id, limit, target_entity)
}
/// Runs a plain recall, then groups the results by session and attaches
/// related observations from storage.
///
/// # Errors
/// Propagates any [`RecallError`] from the underlying recall.
pub fn recall_grouped(
    &self,
    query: &str,
    namespace_id: Uuid,
    config: &crate::recall_grouped::RecallGroupedConfig,
) -> Result<Vec<crate::recall_grouped::SessionGroup>, RecallError> {
    let result = self.recall(query, namespace_id, config.limit)?;
    let groups = crate::recall_grouped::group_by_session(
        result.memories,
        config.order,
        config.max_groups,
    );
    Ok(crate::recall_grouped::attach_observations_to_groups(
        self.storage,
        groups,
    ))
}
/// Recall with an optional target entity used for entity-scoped (dual-path)
/// candidate gathering; embeds the query internally.
///
/// # Errors
/// Propagates any [`RecallError`] from the underlying pipeline.
// NOTE: the stale `#[allow(clippy::too_many_lines)]` was removed — the body
// was long before it was extracted into `recall_inner`, which carries the
// allow itself now.
#[tracing::instrument(skip_all, fields(query, namespace_id = %namespace_id, limit))]
pub fn recall_with_entity(
    &self,
    query: &str,
    namespace_id: Uuid,
    limit: usize,
    target_entity: Option<Uuid>,
) -> Result<RecallResult, RecallError> {
    self.recall_inner(query, None, namespace_id, limit, target_entity)
}
#[allow(clippy::too_many_lines)]
fn recall_inner(
&self,
query: &str,
pre_embedding: Option<&[f32]>,
namespace_id: Uuid,
limit: usize,
target_entity: Option<Uuid>,
) -> Result<RecallResult, RecallError> {
let start = std::time::Instant::now();
let timeout = std::time::Duration::from_secs(self.config.recall_timeout_secs);
let max_candidates = self.config.max_candidates;
let (candidates, vector_map) = if let Some(emb) = pre_embedding {
if target_entity.is_some() {
self.gather_candidates_dual_path(
emb,
query,
namespace_id,
max_candidates,
target_entity,
)?
} else {
self.gather_candidates_with_embedding(emb, query, namespace_id, max_candidates)?
}
} else {
if target_entity.is_some() {
let query_embedding = self.embedder.embed(query)?;
self.gather_candidates_dual_path(
&query_embedding,
query,
namespace_id,
max_candidates,
target_entity,
)?
} else {
self.gather_candidates(query, namespace_id, max_candidates)?
}
};
if candidates.is_empty() {
return Ok(RecallResult { memories: vec![] });
}
if start.elapsed() > timeout {
return Err(RecallError::Timeout(self.config.recall_timeout_secs));
}
let bm25_map = self.build_bm25_map(query, namespace_id, max_candidates)?;
let intent = classify_intent(query);
let candidates_found = candidates.len();
let now = Utc::now();
let mut ranking_vec: Vec<(Uuid, f32)> =
vector_map.iter().map(|(&id, &score)| (id, score)).collect();
ranking_vec.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
let mut ranking_bm25: Vec<(Uuid, f32)> =
bm25_map.iter().map(|(&id, &score)| (id, score)).collect();
ranking_bm25.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
let mut ranking_activation: Vec<(Uuid, f32)> = candidates
.iter()
.map(|(&id, mem)| {
let b = match mem {
Memory::Episodic(e) => {
let count = e.access_count.max(1);
let last = e.last_accessed.unwrap_or(e.timestamp).timestamp() as f64;
let times: Vec<f64> = (0..count.min(20))
.map(|i| last - (f64::from(i) * 3600.0))
.collect();
activation::base_level_activation(×, now.timestamp() as f64, 0.5)
}
Memory::Semantic(_) | Memory::Procedural(_) | Memory::Observation(_) => 0.0,
};
(id, b)
})
.collect();
ranking_activation
.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
let ranking_spread: Vec<(Uuid, f32)> = match (self.graph, target_entity) {
(Some(g), Some(entity_id)) => {
let intent_str = match &intent {
QueryIntent::Question => "question",
QueryIntent::Action => "action",
QueryIntent::Recall => "recall",
QueryIntent::Code => "code",
QueryIntent::Visual => "visual",
QueryIntent::General => "general",
};
g.beam_search(
entity_id,
intent_str,
self.config.beam_width,
self.config.max_depth,
)
}
_ => Vec::new(),
};
let mut ranking_intent: Vec<(Uuid, f32)> = candidates
.iter()
.map(|(&id, mem)| {
let score = intent_score_for_type(&intent, mem.type_name());
(id, score)
})
.collect();
ranking_intent.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
let mut ranking_confidence: Vec<(Uuid, f32)> = candidates
.iter()
.map(|(&id, mem)| {
let conf = match mem {
Memory::Episodic(_) => 1.0,
Memory::Semantic(s) => s.confidence,
Memory::Procedural(p) => p.reliability,
Memory::Observation(o) => o.confidence,
};
(id, conf)
})
.collect();
ranking_confidence
.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
let mut ranking_entity: Vec<(Uuid, f32)> = if let Some(entity_id) = target_entity {
candidates
.iter()
.map(|(&id, mem)| {
let affinity = match mem {
Memory::Semantic(s) if s.subject == entity_id => 1.0,
Memory::Episodic(e) if e.about_entity == entity_id => 1.0,
Memory::Episodic(e) if e.source_entity == entity_id => 0.8,
_ => 0.0,
};
(id, affinity)
})
.collect()
} else {
Vec::new()
};
ranking_entity.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
let all_rankings = vec![
(ranking_vec, self.config.rrf_weights[0]),
(ranking_bm25, self.config.rrf_weights[1]),
(ranking_activation, self.config.rrf_weights[2]),
(ranking_spread, self.config.rrf_weights[3]),
(ranking_intent, self.config.rrf_weights[4]),
(ranking_confidence, self.config.rrf_weights[5]),
(ranking_entity, self.config.rrf_weights[6]),
];
let (rankings, rrf_weights): (Vec<_>, Vec<_>) = all_rankings
.into_iter()
.filter(|(ranking, _)| has_discriminative_signal(ranking))
.unzip();
let effective_k = rrf::adaptive_k(candidates.len(), self.config.rrf_k);
let rrf_results = rrf::reciprocal_rank_fusion(&rankings, &rrf_weights, effective_k)?;
if start.elapsed() > timeout {
return Err(RecallError::Timeout(self.config.recall_timeout_secs));
}
let max_access = candidates
.values()
.map(|m| match m {
Memory::Episodic(e) => e.access_count,
Memory::Semantic(_) | Memory::Procedural(_) | Memory::Observation(_) => 0,
})
.max()
.unwrap_or(0);
let mut scored: Vec<ScoredCandidate> = rrf_results
.iter()
.filter_map(|&(id, rrf_score)| {
candidates.get(&id).map(|mem| {
let vector_score = vector_map.get(&id).copied().unwrap_or(0.0).clamp(0.0, 1.0);
let bm25_score = bm25_map.get(&id).copied().unwrap_or(0.0);
let recency_score = match mem {
Memory::Episodic(e) => decay::retrievability(
e.stability,
decay::elapsed_days(e.timestamp, now),
),
Memory::Semantic(s) => {
decay::retrievability(s.stability, decay::elapsed_days(s.valid_at, now))
}
Memory::Procedural(p) => decay::retrievability(
p.reliability,
decay::elapsed_days(p.created_at, now),
),
Memory::Observation(o) => decay::retrievability(
o.stability,
decay::elapsed_days(o.created_at, now),
),
};
let confidence_score = match mem {
Memory::Episodic(_) => 1.0_f32,
Memory::Semantic(s) => s.confidence,
Memory::Procedural(p) => p.reliability,
Memory::Observation(o) => o.confidence,
};
let intent_score = intent_score_for_type(&intent, mem.type_name());
let access_count = match mem {
Memory::Episodic(e) => e.access_count,
Memory::Semantic(_) | Memory::Procedural(_) | Memory::Observation(_) => 0,
};
let access_score = if max_access == 0 {
0.0_f32
} else {
((access_count + 1) as f32).ln() / ((max_access + 1) as f32).ln()
};
let entity_score = if let Some(entity_id) = target_entity {
match mem {
Memory::Semantic(s) if s.subject == entity_id => 1.0,
Memory::Episodic(e) if e.about_entity == entity_id => 1.0,
Memory::Episodic(e) if e.source_entity == entity_id => 0.8,
_ => 0.0,
}
} else {
0.0
};
ScoredCandidate {
memory_id: id,
memory: mem.clone(),
vector_score,
bm25_score,
graph_score: 0.0, intent_score,
recency_score,
access_score,
confidence_score,
entity_score,
type_boost: 1.0,
final_score: rrf_score,
}
})
})
.collect();
if let Some(reranker) = self.reranker {
scored = apply_reranking(scored, reranker, query)?;
}
scored.truncate(limit);
self.apply_reinforcement(&scored);
info!(
event = "recall_decision",
query = %query,
intent = ?intent,
candidates_found = candidates_found,
results_returned = scored.len(),
duration_ms = start.elapsed().as_millis() as u64,
"recall completed"
);
Ok(RecallResult { memories: scored })
}
/// Embeds `query` and delegates to
/// [`Self::gather_candidates_with_embedding`].
fn gather_candidates(
    &self,
    query: &str,
    namespace_id: Uuid,
    max_candidates: usize,
) -> Result<CandidateMaps, RecallError> {
    let query_embedding = self.embedder.embed(query)?;
    self.gather_candidates_with_embedding(&query_embedding, query, namespace_id, max_candidates)
}
/// Gathers the candidate pool from both retrieval paths: full-text search
/// results are taken as-is, then any vector hits missing from the pool are
/// hydrated from storage (trying episodic, then semantic, then procedural;
/// hydration failures are silently skipped). Returns the pool together
/// with the per-id vector scores.
fn gather_candidates_with_embedding(
    &self,
    query_embedding: &[f32],
    query: &str,
    namespace_id: Uuid,
    max_candidates: usize,
) -> Result<CandidateMaps, RecallError> {
    let vector_hits = self.vector_index.search(query_embedding, max_candidates)?;
    let vector_map: HashMap<Uuid, f32> = vector_hits.iter().copied().collect();
    let fts_memories = self
        .storage
        .search_fts(query, namespace_id, max_candidates)?;
    let mut candidates: HashMap<Uuid, Memory> = HashMap::new();
    for mem in fts_memories {
        candidates.entry(mem.id()).or_insert(mem);
    }
    // Hydrate vector-only hits; a miss in all three tables (or a storage
    // error) drops the hit from the pool.
    for (id, _) in &vector_hits {
        if !candidates.contains_key(id) {
            if let Ok(Some(m)) = self.storage.get_episodic(*id) {
                candidates.insert(*id, Memory::Episodic(m));
            } else if let Ok(Some(m)) = self.storage.get_semantic(*id) {
                candidates.insert(*id, Memory::Semantic(m));
            } else if let Ok(Some(m)) = self.storage.get_procedural(*id) {
                candidates.insert(*id, Memory::Procedural(m));
            }
        }
    }
    Ok((candidates, vector_map))
}
/// Dual-path candidate gathering for entity-targeted recall: first an
/// entity-scoped pass (filtered vector search + scoped FTS) at the full
/// candidate budget, then a broad unscoped pass at a quarter of the
/// budget. Scoped vector scores take precedence over broad ones for the
/// same id; observations are never added via this path (Memory::Observation
/// has no hydration branch here).
///
/// NOTE(review): `broad_limit` is `max_candidates / 4` and becomes 0 when
/// `max_candidates < 4`, silently disabling the broad pass — confirm the
/// config enforces a sensible minimum.
fn gather_candidates_dual_path(
    &self,
    query_embedding: &[f32],
    query: &str,
    namespace_id: Uuid,
    max_candidates: usize,
    target_entity: Option<Uuid>,
) -> Result<CandidateMaps, RecallError> {
    let mut candidates: HashMap<Uuid, Memory> = HashMap::new();
    let mut vector_map: HashMap<Uuid, f32> = HashMap::new();
    if let Some(entity_id) = target_entity {
        // Scoped pass: only vector hits linked to the target entity.
        let entity_map_ref = &self.vector_index;
        let entity_hits =
            entity_map_ref.filtered_search(query_embedding, max_candidates, |id| {
                entity_map_ref.entity_for(id) == Some(entity_id)
            })?;
        for &(id, score) in &entity_hits {
            vector_map.insert(id, score);
        }
        let scoped_fts =
            self.storage
                .search_fts_scoped(query, namespace_id, entity_id, max_candidates)?;
        for mem in scoped_fts {
            candidates.entry(mem.id()).or_insert(mem);
        }
        // Hydrate scoped vector-only hits from storage.
        for (id, _) in &entity_hits {
            if !candidates.contains_key(id) {
                if let Ok(Some(m)) = self.storage.get_episodic(*id) {
                    candidates.insert(*id, Memory::Episodic(m));
                } else if let Ok(Some(m)) = self.storage.get_semantic(*id) {
                    candidates.insert(*id, Memory::Semantic(m));
                } else if let Ok(Some(m)) = self.storage.get_procedural(*id) {
                    candidates.insert(*id, Memory::Procedural(m));
                }
            }
        }
    }
    // Broad pass: unscoped search at a reduced budget; existing scoped
    // scores/candidates are not overwritten (entry().or_insert()).
    let broad_limit = max_candidates / 4;
    let broad_vector_hits = self.vector_index.search(query_embedding, broad_limit)?;
    for &(id, score) in &broad_vector_hits {
        vector_map.entry(id).or_insert(score);
    }
    let broad_fts = self.storage.search_fts(query, namespace_id, broad_limit)?;
    for mem in broad_fts {
        candidates.entry(mem.id()).or_insert(mem);
    }
    for (id, _) in &broad_vector_hits {
        if !candidates.contains_key(id) {
            if let Ok(Some(m)) = self.storage.get_episodic(*id) {
                candidates.insert(*id, Memory::Episodic(m));
            } else if let Ok(Some(m)) = self.storage.get_semantic(*id) {
                candidates.insert(*id, Memory::Semantic(m));
            } else if let Ok(Some(m)) = self.storage.get_procedural(*id) {
                candidates.insert(*id, Memory::Procedural(m));
            }
        }
    }
    Ok((candidates, vector_map))
}
/// Builds a rank-derived full-text score map: the FTS backend returns
/// results already ordered by relevance, so position `p` out of `n` hits
/// is scored `(n - p) / n` (linear falloff, best hit near 1.0). A single
/// hit scores exactly 1.0.
fn build_bm25_map(
    &self,
    query: &str,
    namespace_id: Uuid,
    max_candidates: usize,
) -> Result<HashMap<Uuid, f32>, RecallError> {
    let hits = self
        .storage
        .search_fts(query, namespace_id, max_candidates)?;
    let total = hits.len();
    let mut map = HashMap::with_capacity(total);
    for (rank, mem) in hits.iter().enumerate() {
        let score = if total == 1 {
            1.0_f32
        } else {
            (total - rank) as f32 / total as f32
        };
        map.insert(mem.id(), score);
    }
    Ok(map)
}
/// Reinforces retrieved episodic memories: recomputes stability from the
/// candidate's recency score and writes the refreshed stability and
/// retrievability back to storage. Non-episodic results are untouched and
/// write failures are deliberately ignored (best effort).
fn apply_reinforcement(&self, scored: &[ScoredCandidate]) {
    scored
        .iter()
        .filter_map(|candidate| match &candidate.memory {
            Memory::Episodic(e) => Some((candidate, e)),
            _ => None,
        })
        .for_each(|(candidate, episodic)| {
            let new_stability = decay::reinforce(episodic.stability, candidate.recency_score, 5);
            let new_retrievability = decay::retrievability(new_stability, 0.0);
            let _ = self.storage.update_episodic_access(
                candidate.memory_id,
                new_stability,
                new_retrievability,
            );
        });
}
}
/// Returns `true` when a ranking carries useful ordering information:
/// non-empty, and (with two or more entries) not all scores identical
/// within a 1e-6 tolerance. Flat rankings are filtered out before RRF
/// fusion so they cannot add noise.
// NOTE: removed the stale `#[allow(dead_code)]` (this function is used by
// `recall_inner`) and the `#[allow(clippy::too_many_arguments)]` that made
// no sense on a single-argument function.
fn has_discriminative_signal(ranking: &[(Uuid, f32)]) -> bool {
    if ranking.len() < 2 {
        // A lone entry still identifies a best candidate; only an empty
        // ranking is useless.
        return !ranking.is_empty();
    }
    let first = ranking[0].1;
    ranking
        .iter()
        .any(|(_, score)| (score - first).abs() > 1e-6)
}
/// Legacy weighted-sum scorer kept for reference/experiments (currently
/// unused — the live pipeline in `recall_inner` fuses rankings with RRF
/// instead). Combines eight signals with the supplied `weights` into a
/// single `final_score`.
#[allow(dead_code, clippy::too_many_arguments)]
fn score_candidate(
    id: Uuid,
    memory: Memory,
    vector_map: &HashMap<Uuid, f32>,
    bm25_map: &HashMap<Uuid, f32>,
    graph_map: &HashMap<Uuid, f32>,
    intent: &QueryIntent,
    max_access: u32,
    now: chrono::DateTime<Utc>,
    weights: &[f32; 8],
) -> ScoredCandidate {
    let vector_score = vector_map.get(&id).copied().unwrap_or(0.0).clamp(0.0, 1.0);
    let bm25_score = bm25_map.get(&id).copied().unwrap_or(0.0);
    // Retrievability decay anchored on the per-type reference timestamp.
    let recency_score = match &memory {
        Memory::Episodic(e) => {
            decay::retrievability(e.stability, decay::elapsed_days(e.timestamp, now))
        }
        Memory::Semantic(s) => {
            decay::retrievability(s.stability, decay::elapsed_days(s.valid_at, now))
        }
        Memory::Procedural(p) => {
            decay::retrievability(p.reliability, decay::elapsed_days(p.created_at, now))
        }
        Memory::Observation(o) => {
            decay::retrievability(o.stability, decay::elapsed_days(o.created_at, now))
        }
    };
    // Only episodic memories carry access counts.
    let access_count = match &memory {
        Memory::Episodic(e) => e.access_count,
        Memory::Semantic(_) | Memory::Procedural(_) | Memory::Observation(_) => 0,
    };
    // Log-normalized access frequency in [0, 1].
    let access_score = if max_access == 0 {
        0.0_f32
    } else {
        ((access_count + 1) as f32).ln() / ((max_access + 1) as f32).ln()
    };
    let confidence_score = match &memory {
        Memory::Episodic(_) => 1.0_f32,
        Memory::Semantic(s) => s.confidence,
        Memory::Procedural(p) => p.reliability,
        Memory::Observation(o) => o.confidence,
    };
    // Graph score: best of a direct hit on the memory id or a hit on the
    // entity the memory is linked to.
    let direct = graph_map.get(&id).copied().unwrap_or(0.0);
    let entity_linked = match &memory {
        Memory::Episodic(e) => graph_map.get(&e.about_entity).copied().unwrap_or(0.0),
        Memory::Semantic(s) => graph_map.get(&s.subject).copied().unwrap_or(0.0),
        Memory::Procedural(_) | Memory::Observation(_) => 0.0,
    };
    let graph_score = direct.max(entity_linked);
    let intent_score = intent_score_for_type(intent, memory.type_name());
    let type_boost = 1.0_f32;
    // Weighted linear fusion of all signals.
    let final_score = weights[0] * vector_score
        + weights[1] * bm25_score
        + weights[2] * graph_score
        + weights[3] * intent_score
        + weights[4] * recency_score
        + weights[5] * access_score
        + weights[6] * confidence_score
        + weights[7] * type_boost;
    ScoredCandidate {
        memory_id: id,
        memory,
        vector_score,
        bm25_score,
        graph_score,
        intent_score,
        recency_score,
        access_score,
        confidence_score,
        entity_score: 0.0,
        type_boost,
        final_score,
    }
}
/// Reranks the top [`RERANK_TOP_N`] candidates with the cross-encoder,
/// replacing each head candidate's `final_score` with the reranker score;
/// candidates beyond the head keep their fused scores and are appended
/// unchanged (note: tail scores are on the RRF scale, not the reranker
/// scale).
///
/// NOTE(review): `scored[r.index]` assumes the reranker only returns
/// indices below `rerank_count`; an out-of-range index would panic —
/// confirm the reranker contract guarantees this.
fn apply_reranking(
    mut scored: Vec<ScoredCandidate>,
    reranker: &crate::reranker::Reranker,
    query: &str,
) -> Result<Vec<ScoredCandidate>, crate::reranker::RerankerError> {
    let rerank_count = scored.len().min(RERANK_TOP_N);
    // Split off the tail so only the head is sent to the reranker.
    let tail = scored.split_off(rerank_count);
    // Flatten each memory into the text representation the cross-encoder
    // scores against the query.
    let texts: Vec<String> = scored
        .iter()
        .map(|c| match &c.memory {
            Memory::Episodic(e) => e.content.clone(),
            Memory::Semantic(s) => format!("{} {} {}", s.subject, s.predicate, s.object),
            Memory::Procedural(p) => format!("trigger: {} action: {}", p.trigger, p.action),
            Memory::Observation(o) => o.content.clone(),
        })
        .collect();
    let text_refs: Vec<&str> = texts.iter().map(String::as_str).collect();
    let rerank_results = reranker.rerank(query, &text_refs, rerank_count)?;
    let mut sorted_by_reranker: Vec<ScoredCandidate> = rerank_results
        .into_iter()
        .map(|r| {
            let mut cand = scored[r.index].clone();
            cand.final_score = r.score;
            cand
        })
        .collect();
    sorted_by_reranker.extend(tail);
    Ok(sorted_by_reranker)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::RetrievalConfig;
use crate::embedding::OnnxEmbedder;
use crate::storage::sqlite::SqliteBackend;
use crate::types::{Entity, EntityKind, Episode, EpisodicMemory, Namespace};
use crate::vector::VectorIndex;
/// Fixed 8-signal weights used by the scoring test below.
const TEST_WEIGHTS: [f32; 8] = [0.25, 0.10, 0.15, 0.05, 0.20, 0.10, 0.10, 0.05];
/// Builds a small retrieval configuration suitable for fast tests.
fn test_config() -> RetrievalConfig {
    RetrievalConfig {
        default_limit: 5,
        max_candidates: 50,
        weights: TEST_WEIGHTS,
        recall_timeout_secs: 5,
        rrf_k: 60,
        rrf_weights: [1.0, 0.8, 1.0, 0.8, 0.5, 0.5, 1.2],
        beam_width: 10,
        max_depth: 4,
    }
}
/// Persists an entity, an episode, and one episodic memory with its
/// embedding filled in; returns the saved memory. Order matters: the
/// entity and episode rows must exist before the memory references them.
fn setup_episodic(
    storage: &SqliteBackend,
    embedder: &OnnxEmbedder,
    ns: &Namespace,
    content: &str,
) -> EpisodicMemory {
    let mut entity = Entity::new("agent", EntityKind::Agent);
    entity.namespace_id = ns.id;
    storage.save_entity(&entity).unwrap();
    let episode = Episode::new(ns.id, vec![entity.id]);
    storage.save_episode(&episode).unwrap();
    let mut mem = EpisodicMemory::new(ns.id, episode.id, entity.id, entity.id, content);
    mem.embedding = embedder.embed(content).unwrap();
    storage.save_episodic(&mem).unwrap();
    mem
}
/// Sanity-checks the weighted-sum scoring model: a candidate with strong
/// vector/BM25/recency signals must outscore one with weak signals under
/// `TEST_WEIGHTS` (graph, intent, and access terms omitted as zero).
#[test]
fn test_fusion_scoring_ranks_relevant_higher() {
    let dummy_id_a = Uuid::new_v4();
    let dummy_id_b = Uuid::new_v4();
    let make_mem = |ns_id: Uuid| -> Memory {
        let ep_id = Uuid::new_v4();
        let ent = Uuid::new_v4();
        Memory::Episodic(EpisodicMemory::new(ns_id, ep_id, ent, ent, "dummy"))
    };
    let ns_id = Uuid::new_v4();
    let weights = TEST_WEIGHTS;
    // Candidate A: strong on every included signal.
    let a_vector = 0.95f32;
    let a_bm25 = 0.90f32;
    let a_recency = 0.80f32;
    let a_confidence = 1.0f32;
    let a_type_boost = 1.0f32;
    let score_a = weights[0] * a_vector
        + weights[1] * a_bm25
        + weights[4] * a_recency
        + weights[6] * a_confidence
        + weights[7] * a_type_boost;
    // Candidate B: weak vector/BM25/recency, same confidence and boost.
    let b_vector = 0.10f32;
    let b_bm25 = 0.05f32;
    let b_recency = 0.50f32;
    let b_confidence = 1.0f32;
    let b_type_boost = 1.0f32;
    let score_b = weights[0] * b_vector
        + weights[1] * b_bm25
        + weights[4] * b_recency
        + weights[6] * b_confidence
        + weights[7] * b_type_boost;
    assert!(
        score_a > score_b,
        "High-signal candidate A ({score_a}) should outrank B ({score_b})"
    );
    // Silence unused-variable warnings for the fixture values above.
    let _ = (dummy_id_a, dummy_id_b, ns_id, make_mem(Uuid::new_v4()));
}
/// Full pipeline smoke test: a stored memory must come back from recall
/// when queried with overlapping text.
#[test]
fn test_recall_end_to_end() {
    let dir = tempfile::tempdir().unwrap();
    let storage = SqliteBackend::open(dir.path()).unwrap();
    let embedder = OnnxEmbedder::new_mock(64);
    let mut vector_index = VectorIndex::new(64, 16);
    let config = test_config();
    let ns = Namespace::new("test-ns");
    storage.save_namespace(&ns).unwrap();
    let mem = setup_episodic(&storage, &embedder, &ns, "Rust memory engine test content");
    vector_index.add(mem.id, &mem.embedding).unwrap();
    let engine = RecallEngine::new(&storage, &embedder, &vector_index, &config);
    let result = engine.recall("Rust memory engine", ns.id, 5).unwrap();
    assert!(!result.memories.is_empty(), "Expected at least one result");
    let found = result.memories.iter().any(|c| c.memory_id == mem.id);
    assert!(found, "Inserted memory should appear in recall results");
}
/// With several stored memories, a topical query should not surface the
/// clearly off-topic memory as the top result.
#[test]
fn test_recall_with_multiple_memories() {
    let dir = tempfile::tempdir().unwrap();
    let storage = SqliteBackend::open(dir.path()).unwrap();
    let embedder = OnnxEmbedder::new_mock(64);
    let mut vector_index = VectorIndex::new(64, 16);
    let config = test_config();
    let ns = Namespace::new("multi-ns");
    storage.save_namespace(&ns).unwrap();
    let mem_a = setup_episodic(
        &storage,
        &embedder,
        &ns,
        "quantum physics relativity theory",
    );
    let mem_b = setup_episodic(
        &storage,
        &embedder,
        &ns,
        "cooking pasta recipe Italian food",
    );
    let mem_c = setup_episodic(
        &storage,
        &embedder,
        &ns,
        "quantum entanglement superposition",
    );
    vector_index.add(mem_a.id, &mem_a.embedding).unwrap();
    vector_index.add(mem_b.id, &mem_b.embedding).unwrap();
    vector_index.add(mem_c.id, &mem_c.embedding).unwrap();
    let engine = RecallEngine::new(&storage, &embedder, &vector_index, &config);
    let result = engine.recall("quantum physics", ns.id, 3).unwrap();
    assert!(!result.memories.is_empty());
    // Only assert ordering when there is actually a ranking to inspect.
    if result.memories.len() >= 2 {
        let top_id = result.memories[0].memory_id;
        assert_ne!(
            top_id, mem_b.id,
            "Cooking memory should not be top result for quantum physics query"
        );
    }
}
/// An empty vector index and empty storage must yield zero results rather
/// than erroring.
#[test]
fn test_recall_empty_index() {
    let dir = tempfile::tempdir().unwrap();
    let storage = SqliteBackend::open(dir.path()).unwrap();
    let embedder = OnnxEmbedder::new_mock(64);
    let vector_index = VectorIndex::new(64, 16);
    let config = test_config();
    let ns = Namespace::new("empty-ns");
    storage.save_namespace(&ns).unwrap();
    let engine = RecallEngine::new(&storage, &embedder, &vector_index, &config);
    let result = engine.recall("anything", ns.id, 5).unwrap();
    assert!(
        result.memories.is_empty(),
        "Empty index should return no results"
    );
}
/// Recalling an episodic memory should reinforce it: the stored
/// access_count must be higher after retrieval than before.
#[test]
fn test_retrieval_reinforcement() {
    let dir = tempfile::tempdir().unwrap();
    let storage = SqliteBackend::open(dir.path()).unwrap();
    let embedder = OnnxEmbedder::new_mock(64);
    let mut vector_index = VectorIndex::new(64, 16);
    let config = test_config();
    let ns = Namespace::new("reinforce-ns");
    storage.save_namespace(&ns).unwrap();
    let mem = setup_episodic(
        &storage,
        &embedder,
        &ns,
        "reinforcement learning access count",
    );
    vector_index.add(mem.id, &mem.embedding).unwrap();
    let initial_access = mem.access_count;
    let engine = RecallEngine::new(&storage, &embedder, &vector_index, &config);
    let result = engine.recall("reinforcement learning", ns.id, 5).unwrap();
    assert!(!result.memories.is_empty());
    // Re-read from storage to observe the write made by recall.
    let updated = storage.get_episodic(mem.id).unwrap();
    let updated_access = updated.map(|m| m.access_count).unwrap_or(0);
    assert!(
        updated_access > initial_access,
        "access_count should increase after retrieval (was {initial_access}, now {updated_access})"
    );
}
/// Interrogative phrasings must classify as `Question`.
#[test]
fn test_classify_intent_question() {
    let queries = [
        "What is Rust?",
        "Who wrote this library?",
        "Where is the config file?",
    ];
    for query in queries {
        assert_eq!(classify_intent(query), QueryIntent::Question);
    }
}
/// Task-oriented phrasings must classify as `Action`.
#[test]
fn test_classify_intent_action() {
    let queries = [
        "How to build the project",
        "Deploy the application to prod",
        "Fix the broken test",
    ];
    for query in queries {
        assert_eq!(classify_intent(query), QueryIntent::Action);
    }
}
/// Phrasings that refer back to prior conversation must classify as
/// `Recall` (recall keywords outrank question keywords).
#[test]
fn test_classify_intent_recall() {
    let queries = [
        "Do you remember our talk?",
        "What did we discuss last time?",
        "You mentioned something previously",
    ];
    for query in queries {
        assert_eq!(classify_intent(query), QueryIntent::Recall);
    }
}
/// Keyword-free queries fall through to `General`.
#[test]
fn test_classify_intent_general() {
    for query in ["Rust", "hello world", "pensyve core"] {
        assert_eq!(classify_intent(query), QueryIntent::General);
    }
}
/// `Question` intent must rank episodic > semantic > procedural.
#[test]
fn test_intent_score_question_favors_episodic() {
    let score = |memory_type| intent_score_for_type(&QueryIntent::Question, memory_type);
    assert!(
        score("episodic") > score("semantic"),
        "Question should favor episodic over semantic"
    );
    assert!(
        score("semantic") > score("procedural"),
        "Question should favor semantic over procedural"
    );
}
/// `Action` intent must rank procedural > semantic > episodic, with the
/// procedural prior pinned at exactly 0.9.
#[test]
fn test_intent_score_action_favors_procedural() {
    let score = |memory_type| intent_score_for_type(&QueryIntent::Action, memory_type);
    let (procedural, semantic, episodic) =
        (score("procedural"), score("semantic"), score("episodic"));
    assert!(
        procedural > semantic,
        "Action should favor procedural over semantic"
    );
    assert!(
        semantic > episodic,
        "Action should favor semantic over episodic"
    );
    assert!(
        (procedural - 0.9).abs() < f32::EPSILON,
        "Action+procedural should be 0.9"
    );
}
/// Programming-related phrasings must classify as `Code` (code keywords
/// outrank question keywords).
#[test]
fn test_classify_intent_code() {
    let queries = [
        "Show me the function definition",
        "What's the API endpoint for users?",
    ];
    for query in queries {
        assert_eq!(classify_intent(query), QueryIntent::Code);
    }
}
/// Image-related phrasings must classify as `Visual`.
#[test]
fn test_classify_intent_visual() {
    let queries = ["What does the image show?", "Describe the screenshot"];
    for query in queries {
        assert_eq!(classify_intent(query), QueryIntent::Visual);
    }
}
/// `Code` intent must rank procedural above semantic.
#[test]
fn test_intent_score_code_favors_procedural() {
    let procedural = intent_score_for_type(&QueryIntent::Code, "procedural");
    let semantic = intent_score_for_type(&QueryIntent::Code, "semantic");
    assert!(procedural > semantic);
}
/// `Visual` intent must rank episodic above semantic.
#[test]
fn test_intent_score_visual_favors_episodic() {
    let episodic = intent_score_for_type(&QueryIntent::Visual, "episodic");
    let semantic = intent_score_for_type(&QueryIntent::Visual, "semantic");
    assert!(episodic > semantic);
}
/// With a mock reranker attached, recall must still return results and
/// every reranked final_score must land in (0, 1].
#[test]
fn test_recall_with_mock_reranker() {
    let dir = tempfile::tempdir().unwrap();
    let storage = SqliteBackend::open(dir.path()).unwrap();
    let embedder = OnnxEmbedder::new_mock(64);
    let mut vector_index = VectorIndex::new(64, 16);
    let config = test_config();
    let reranker = crate::reranker::Reranker::new_mock();
    let ns = Namespace::new("reranker-ns");
    storage.save_namespace(&ns).unwrap();
    let mem_a = setup_episodic(
        &storage,
        &embedder,
        &ns,
        "Rust programming language systems",
    );
    let mem_b = setup_episodic(
        &storage,
        &embedder,
        &ns,
        "cooking delicious pasta with garlic",
    );
    vector_index.add(mem_a.id, &mem_a.embedding).unwrap();
    vector_index.add(mem_b.id, &mem_b.embedding).unwrap();
    let engine =
        RecallEngine::new(&storage, &embedder, &vector_index, &config).with_reranker(&reranker);
    let result = engine.recall("Rust systems programming", ns.id, 5).unwrap();
    assert!(
        !result.memories.is_empty(),
        "Expected results with reranker attached"
    );
    for cand in &result.memories {
        assert!(
            cand.final_score > 0.0 && cand.final_score <= 1.0,
            "Mock reranker score out of range: {}",
            cand.final_score
        );
    }
}
}