use std::collections::HashSet;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::error::Result;
use crate::hash::compute_content_hash;
use crate::model::event::{AgentEvent, EventType};
use crate::model::memory::{MemoryRecord, MemoryType, Scope};
use crate::query::MnemoEngine;
use crate::storage::MemoryFilter;
#[allow(unused_imports)]
use base64::Engine as _;
/// Optional creation-time window applied to recalled memories.
///
/// Both bounds are RFC 3339 timestamps compared against a memory's
/// `created_at`, and both are inclusive. An unset bound — or one that fails to
/// parse — leaves that side of the window open.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TemporalRange {
    /// Lower bound: memories created strictly before this instant are excluded.
    pub after: Option<String>,
    /// Upper bound: memories created strictly after this instant are excluded.
    pub before: Option<String>,
}
impl TemporalRange {
    /// Returns a fully open range (neither bound set).
    pub fn new() -> Self {
        Self {
            after: None,
            before: None,
        }
    }
}
/// Parameters for a recall (memory read) request, consumed by [`execute`].
///
/// Every field except `query` is optional. The type derives [`Default`], so
/// callers can write
/// `RecallRequest { query, limit: Some(5), ..Default::default() }`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RecallRequest {
    /// Free-text query: embedded for vector search and passed to full-text search.
    pub query: String,
    /// Agent the recall runs on behalf of; defaults to the engine's default agent.
    pub agent_id: Option<String>,
    /// Maximum number of results; defaults to 10 and is capped at 100.
    pub limit: Option<usize>,
    /// Restrict results to one memory type; ignored when `memory_types` is set.
    pub memory_type: Option<MemoryType>,
    /// Restrict results to any of these memory types; takes precedence over `memory_type`.
    pub memory_types: Option<Vec<MemoryType>>,
    /// Restrict results to a single scope.
    pub scope: Option<Scope>,
    /// Drop memories whose importance is below this value.
    pub min_importance: Option<f32>,
    /// Keep only memories that carry at least one of these tags.
    pub tags: Option<Vec<String>>,
    /// Organisation filter; in this module it is only forwarded to storage by
    /// the `"exact"` strategy.
    pub org_id: Option<String>,
    /// Retrieval strategy: `"lexical"`, `"semantic"`, `"graph"`, `"exact"`, or any
    /// other value (the default is `"auto"`) for hybrid retrieval.
    pub strategy: Option<String>,
    /// Inclusive RFC 3339 window on the memory's creation time.
    pub temporal_range: Option<TemporalRange>,
    /// Half-life, in hours, of the recency signal used by hybrid retrieval;
    /// defaults to 168.
    pub recency_half_life_hours: Option<f64>,
    /// Optional per-ranking weights for reciprocal rank fusion (applied positionally).
    pub hybrid_weights: Option<Vec<f32>>,
    /// Reciprocal-rank-fusion constant `k`; defaults to 60.
    pub rrf_k: Option<f32>,
    /// RFC 3339 snapshot time. When set, soft-deleted memories become eligible,
    /// but memories created after it or deleted at/before it are hidden.
    pub as_of: Option<String>,
    /// Attach per-signal score breakdowns (hybrid retrieval only).
    pub explain: Option<bool>,
    /// Attach signed read provenance when the engine has a signer configured.
    pub with_provenance: Option<bool>,
}
impl RecallRequest {
    /// Creates a request for `query` with every optional field left unset.
    pub fn new(query: String) -> Self {
        Self {
            query,
            ..Self::default()
        }
    }
}
/// Per-signal contributions behind a hybrid recall score, filled in by
/// `execute` when `RecallRequest::explain` is set.
///
/// Each signal field holds the raw score that signal assigned to the memory,
/// or 0.0 when the memory did not appear in that signal's ranking.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ScoreBreakdown {
    /// Vector signal, computed as `1.0 - distance`.
    pub vector: f32,
    /// Full-text (BM25) score.
    pub bm25: f32,
    /// Graph-expansion score: 1.0 for seed hits, halved for each relation hop.
    pub graph: f32,
    /// Recency score derived from the memory's creation time.
    pub recency: f32,
    /// 0-based position in the fused ranking, counted before filtering.
    pub rrf_rank: u32,
}
/// Result of a recall request.
#[non_exhaustive]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RecallResponse {
    /// Matching memories; `execute` returns them highest score first.
    pub memories: Vec<ScoredMemory>,
    /// As set by `execute`: the number of entries in `memories` after truncation
    /// to the limit (not a count of all matches).
    pub total: usize,
    /// Signed read provenance; omitted from serialized output when absent.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub provenance: Option<crate::provenance::ReadProvenance>,
}
impl RecallResponse {
    /// Builds a response carrying `memories` and `total`, with no read provenance.
    pub fn new(memories: Vec<ScoredMemory>, total: usize) -> Self {
        let provenance = None;
        Self {
            provenance,
            total,
            memories,
        }
    }
}
/// A memory returned by recall, together with its relevance score.
#[non_exhaustive]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoredMemory {
    /// Memory id.
    pub id: Uuid,
    /// Memory content; `execute` decrypts it when the engine has encryption
    /// configured (or substitutes a placeholder if decryption fails).
    pub content: String,
    /// Agent id stored on the memory (compared against the caller for
    /// private/shared access checks).
    pub agent_id: String,
    /// Memory type.
    pub memory_type: MemoryType,
    /// Visibility scope.
    pub scope: Scope,
    /// Importance as stored on the memory.
    pub importance: f32,
    /// Tags as stored on the memory.
    pub tags: Vec<String>,
    /// Arbitrary JSON metadata as stored on the memory.
    pub metadata: serde_json::Value,
    /// Strategy-dependent relevance score; higher is better.
    pub score: f32,
    /// Access count as loaded from storage/cache.
    /// NOTE(review): presumably does not yet include this recall's access — confirm.
    pub access_count: u64,
    /// Creation timestamp; the recall filters parse it as RFC 3339.
    pub created_at: String,
    /// Last-update timestamp as stored.
    pub updated_at: String,
    /// Per-signal breakdown; only present for hybrid recalls with `explain` set.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub score_breakdown: Option<ScoreBreakdown>,
}
impl From<(MemoryRecord, f32)> for ScoredMemory {
fn from((record, score): (MemoryRecord, f32)) -> Self {
Self {
id: record.id,
content: record.content,
agent_id: record.agent_id,
memory_type: record.memory_type,
scope: record.scope,
importance: record.importance,
tags: record.tags,
metadata: record.metadata,
score,
access_count: record.access_count,
created_at: record.created_at,
updated_at: record.updated_at,
score_breakdown: None,
}
}
}
/// Loads a memory by id, consulting the engine's cache first when one is configured.
///
/// On a cache miss the record is read from storage and, if found, a copy is
/// written to the cache. Returns `Ok(None)` when storage has no such record.
///
/// # Errors
///
/// Propagates any error returned by the storage lookup.
async fn get_memory_cached(engine: &MnemoEngine, id: Uuid) -> Result<Option<MemoryRecord>> {
    let cache = engine.cache.as_ref();
    if let Some(hit) = cache.and_then(|c| c.get(id)) {
        return Ok(Some(hit));
    }
    let fetched = engine.storage.get_memory(id).await?;
    if let (Some(c), Some(record)) = (cache, fetched.as_ref()) {
        c.put(record.clone());
    }
    Ok(fetched)
}
/// Executes a recall (memory read) request for one agent and returns the
/// best-scoring memories.
///
/// Flow, as implemented below:
/// 1. Resolve `limit` (default 10, capped at 100), the agent id (falling back
///    to `engine.default_agent_id`, then validated) and the strategy name
///    (default `"auto"`).
/// 2. Embed the query and load the ids of memories the agent may access (at
///    most `super::MAX_BATCH_QUERY_LIMIT`); vector searches are filtered through
///    that set.
/// 3. Gather candidates according to the strategy:
///    - `"lexical"`: full-text (BM25) search only; yields nothing when the
///      engine has no full-text index.
///    - `"semantic"`: vector search, scored as `1.0 - distance`.
///    - `"graph"`: vector-search seeds expanded up to two relation hops, then
///      fused with the seed ranking via reciprocal rank fusion (RRF).
///    - `"exact"`: a storage listing built from the request's fields, every hit
///      scored `1.0`.
///    - anything else (including `"auto"`): RRF over vector, BM25, recency and
///      graph rankings when a full-text index exists, otherwise vector search.
///
///    Every accepted candidate must pass `passes_filters`.
/// 4. Sort by score (descending), truncate to `limit`, touch each result's
///    access timestamp, and decrypt content when encryption is configured.
/// 5. Append a hash-chained `MemoryRead` audit event and, if requested and a
///    signer is configured, attach signed read provenance.
///
/// # Errors
///
/// Returns an error if any of the following fails:
/// - agent-id validation;
/// - query embedding;
/// - the accessible-id lookup;
/// - a vector or full-text search;
/// - a memory read;
/// - a relation lookup (`"graph"` strategy only).
///
/// Some failures never fail the call:
/// - touching access timestamps, decrypting content, reading the previous event
///   hash, inserting the audit event, or signing provenance are logged;
/// - a failure to embed the audit event is ignored silently.
pub async fn execute(engine: &MnemoEngine, request: RecallRequest) -> Result<RecallResponse> {
    // Default to 10 results and never return more than 100.
    let limit = request.limit.unwrap_or(10).min(100);
    let agent_id = request
        .agent_id
        .clone()
        .unwrap_or_else(|| engine.default_agent_id.clone());
    super::validate_agent_id(&agent_id)?;
    let strategy = request.strategy.as_deref().unwrap_or("auto");
    // NOTE(review): the query is embedded even for "lexical" and "exact", which
    // never read `query_embedding` — confirm whether that cost is intended.
    let query_embedding = engine.embedding.embed(&request.query).await?;
    // Ids this agent may read; used below as the vector-search filter.
    // NOTE(review): the lookup is capped at MAX_BATCH_QUERY_LIMIT, so accessible
    // memories beyond the cap cannot surface through vector search — confirm.
    let accessible_ids: HashSet<Uuid> = engine
        .storage
        .list_accessible_memory_ids(&agent_id, super::MAX_BATCH_QUERY_LIMIT)
        .await?
        .into_iter()
        .collect();
    let perm_filter = |id: Uuid| accessible_ids.contains(&id);
    let mut scored_memories: Vec<(MemoryRecord, f32)> = Vec::new();
    // Per-memory signal breakdowns; only the hybrid path with `explain` fills this.
    let mut breakdowns: std::collections::HashMap<Uuid, ScoreBreakdown> =
        std::collections::HashMap::new();
    match strategy {
        "lexical" => {
            // Over-fetch (3x `limit`): candidates are filtered here and the final
            // list is truncated to `limit` afterwards.
            if let Some(ref ft) = engine.full_text {
                let bm25_results = ft.search(&request.query, limit * 3)?;
                for (id, score) in bm25_results {
                    if let Some(record) = get_memory_cached(engine, id).await?
                        && passes_filters(&record, &request, &agent_id, engine).await
                    {
                        scored_memories.push((record, score));
                    }
                }
            }
        }
        "semantic" => {
            let search_results =
                engine
                    .index
                    .filtered_search(&query_embedding, limit * 3, &perm_filter)?;
            for (id, distance) in search_results {
                if let Some(record) = get_memory_cached(engine, id).await?
                    && passes_filters(&record, &request, &agent_id, engine).await
                {
                    // The index reports a distance; turn it into a higher-is-better score.
                    // NOTE(review): assumes distances roughly lie in [0, 1] — confirm.
                    let score = 1.0 - distance;
                    scored_memories.push((record, score));
                }
            }
        }
        "graph" => {
            // Seeds: vector hits that pass the request filters.
            let search_results =
                engine
                    .index
                    .filtered_search(&query_embedding, limit * 3, &perm_filter)?;
            let mut seeds: Vec<(Uuid, f32)> = Vec::new();
            for (id, distance) in &search_results {
                if let Some(record) = get_memory_cached(engine, *id).await?
                    && passes_filters(&record, &request, &agent_id, engine).await
                {
                    seeds.push((*id, 1.0 - distance));
                }
            }
            // Breadth-first expansion over relations in both directions. Seeds score
            // 1.0, first-hop neighbours 0.5, second-hop neighbours 0.25.
            let max_hops = 2;
            let mut seen: HashSet<Uuid> = seeds.iter().map(|(id, _)| *id).collect();
            let mut graph_ranked: Vec<(Uuid, f32)> = Vec::new();
            for &(id, _) in &seeds {
                graph_ranked.push((id, 1.0));
            }
            let mut frontier: Vec<Uuid> = seeds.iter().map(|(id, _)| *id).collect();
            let mut decay = 0.5_f32;
            for _hop in 0..max_hops {
                let mut next_frontier: Vec<Uuid> = Vec::new();
                for &id in &frontier {
                    // NOTE(review): unlike the hybrid path below, a relation lookup
                    // error aborts the whole recall here instead of being logged.
                    let from_rels = engine.storage.get_relations_from(id).await?;
                    let to_rels = engine.storage.get_relations_to(id).await?;
                    for rel in from_rels.iter().chain(to_rels.iter()) {
                        // Pick the endpoint that is not the current node.
                        let related_id = if rel.source_id == id {
                            rel.target_id
                        } else {
                            rel.source_id
                        };
                        // `seen.insert` runs first, so a neighbour that fails the filters
                        // is still marked seen and is neither re-fetched nor expanded.
                        if seen.insert(related_id)
                            && let Some(record) = get_memory_cached(engine, related_id).await?
                            && passes_filters(&record, &request, &agent_id, engine).await
                        {
                            graph_ranked.push((related_id, decay));
                            next_frontier.push(related_id);
                        }
                    }
                }
                frontier = next_frontier;
                decay *= 0.5;
            }
            // Fuse the seeds' vector ranking with the graph ranking (both sorted
            // descending by score; NaN compares as equal).
            let mut v_sorted: Vec<(Uuid, f32)> = seeds.clone();
            v_sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
            graph_ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
            let ranked_lists = vec![v_sorted, graph_ranked];
            let rrf_k = request.rrf_k.unwrap_or(60.0);
            // NOTE(review): `hybrid_weights` is applied positionally to these two lists
            // (vector, graph), whereas the hybrid path uses four lists
            // (vector, bm25, recency, graph) — confirm callers expect this.
            let fused = if let Some(ref weights) = request.hybrid_weights {
                crate::query::retrieval::weighted_reciprocal_rank_fusion(
                    &ranked_lists,
                    rrf_k,
                    weights,
                )
            } else {
                crate::query::retrieval::reciprocal_rank_fusion(&ranked_lists, rrf_k)
            };
            // Each fused id is fetched and filtered again before being accepted.
            for (id, score) in fused {
                if let Some(record) = get_memory_cached(engine, id).await?
                    && passes_filters(&record, &request, &agent_id, engine).await
                {
                    scored_memories.push((record, score));
                }
            }
        }
        "exact" => {
            // Storage-side listing; the query text is not used for matching.
            // Soft-deleted rows are requested only for `as_of` (time-travel) reads.
            // NOTE(review): `agent_id` presumably restricts this to the caller's own
            // memories, and only `limit` rows are fetched before `passes_filters`,
            // so fewer than `limit` results may come back — confirm.
            let filter = MemoryFilter {
                agent_id: Some(agent_id.clone()),
                memory_type: request.memory_type,
                scope: request.scope,
                tags: request.tags.clone(),
                min_importance: request.min_importance,
                org_id: request.org_id.clone(),
                thread_id: None,
                include_deleted: request.as_of.is_some(),
            };
            let memories = engine.storage.list_memories(&filter, limit, 0).await?;
            for record in memories {
                if passes_filters(&record, &request, &agent_id, engine).await {
                    scored_memories.push((record, 1.0));
                }
            }
        }
        _ => {
            // Hybrid retrieval: "auto" and any unrecognised strategy name.
            let vector_results =
                engine
                    .index
                    .filtered_search(&query_embedding, limit * 3, &perm_filter)?;
            let mut vector_ranked: Vec<(Uuid, f32)> = Vec::new();
            for (id, distance) in vector_results {
                vector_ranked.push((id, 1.0 - distance));
            }
            if let Some(ref ft) = engine.full_text {
                let bm25_results = ft.search(&request.query, limit * 3)?;
                // Recency ranking over the union of vector and BM25 candidates
                // (default half-life 168 hours).
                let mut recency_ranked: Vec<(Uuid, f32)> = Vec::new();
                for &(id, _) in &vector_ranked {
                    if let Some(record) = get_memory_cached(engine, id).await? {
                        let r_score = crate::query::retrieval::recency_score(
                            &record.created_at,
                            request.recency_half_life_hours.unwrap_or(168.0),
                        );
                        recency_ranked.push((id, r_score));
                    }
                }
                for &(id, _) in &bm25_results {
                    // Skip ids already scored from the vector candidates.
                    // NOTE(review): linear `any` scan per id (O(n*m)); a HashSet would avoid it.
                    if !recency_ranked.iter().any(|(rid, _)| *rid == id)
                        && let Some(record) = get_memory_cached(engine, id).await?
                    {
                        let r_score = crate::query::retrieval::recency_score(
                            &record.created_at,
                            request.recency_half_life_hours.unwrap_or(168.0),
                        );
                        recency_ranked.push((id, r_score));
                    }
                }
                // Sort each ranking descending by score (NaN compares as equal).
                let mut v_sorted = vector_ranked.clone();
                v_sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
                let mut b_sorted = bm25_results;
                b_sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
                recency_ranked
                    .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
                // Graph ranking: the first 10 vector hits (in index order), expanded up
                // to two hops in both directions. Seeds score 1.0, then 0.5, then 0.25.
                // Neighbours are not filtered here (that happens after fusion), and
                // relation lookup failures are logged and skipped.
                let max_hops = 2;
                let mut graph_ranked: Vec<(Uuid, f32)> = Vec::new();
                let top_seeds: Vec<Uuid> =
                    vector_ranked.iter().take(10).map(|(id, _)| *id).collect();
                let mut graph_seen: HashSet<Uuid> = top_seeds.iter().copied().collect();
                for &seed_id in &top_seeds {
                    graph_ranked.push((seed_id, 1.0));
                }
                let mut frontier: Vec<Uuid> = top_seeds;
                let mut decay = 0.5_f32;
                for _hop in 0..max_hops {
                    let mut next_frontier: Vec<Uuid> = Vec::new();
                    for &fid in &frontier {
                        match engine.storage.get_relations_from(fid).await {
                            Ok(from_rels) => {
                                for rel in &from_rels {
                                    if graph_seen.insert(rel.target_id) {
                                        graph_ranked.push((rel.target_id, decay));
                                        next_frontier.push(rel.target_id);
                                    }
                                }
                            }
                            Err(e) => {
                                tracing::warn!(memory_id = %fid, error = %e, "graph expansion: failed to get outgoing relations");
                            }
                        }
                        match engine.storage.get_relations_to(fid).await {
                            Ok(to_rels) => {
                                for rel in &to_rels {
                                    if graph_seen.insert(rel.source_id) {
                                        graph_ranked.push((rel.source_id, decay));
                                        next_frontier.push(rel.source_id);
                                    }
                                }
                            }
                            Err(e) => {
                                tracing::warn!(memory_id = %fid, error = %e, "graph expansion: failed to get incoming relations");
                            }
                        }
                    }
                    frontier = next_frontier;
                    decay *= 0.5;
                }
                graph_ranked
                    .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
                let explain = request.explain.unwrap_or(false);
                // Raw per-signal scores, captured only when an explanation was requested.
                type SignalMap = std::collections::HashMap<Uuid, f32>;
                let (vector_map, bm25_map, recency_map, graph_map): (
                    SignalMap,
                    SignalMap,
                    SignalMap,
                    SignalMap,
                ) = if explain {
                    (
                        v_sorted.iter().copied().collect(),
                        b_sorted.iter().copied().collect(),
                        recency_ranked.iter().copied().collect(),
                        graph_ranked.iter().copied().collect(),
                    )
                } else {
                    Default::default()
                };
                // List order matters for `hybrid_weights`: vector, bm25, recency, graph.
                let ranked_lists = vec![v_sorted, b_sorted, recency_ranked, graph_ranked];
                let rrf_k = request.rrf_k.unwrap_or(60.0);
                let fused = if let Some(ref weights) = request.hybrid_weights {
                    crate::query::retrieval::weighted_reciprocal_rank_fusion(
                        &ranked_lists,
                        rrf_k,
                        weights,
                    )
                } else {
                    crate::query::retrieval::reciprocal_rank_fusion(&ranked_lists, rrf_k)
                };
                // `rank` is the 0-based position in the fused list, counted before filtering.
                for (rank, (id, score)) in fused.into_iter().enumerate() {
                    if let Some(record) = get_memory_cached(engine, id).await?
                        && passes_filters(&record, &request, &agent_id, engine).await
                    {
                        scored_memories.push((record, score));
                        if explain {
                            breakdowns.insert(
                                id,
                                ScoreBreakdown {
                                    vector: vector_map.get(&id).copied().unwrap_or(0.0),
                                    bm25: bm25_map.get(&id).copied().unwrap_or(0.0),
                                    graph: graph_map.get(&id).copied().unwrap_or(0.0),
                                    recency: recency_map.get(&id).copied().unwrap_or(0.0),
                                    rrf_rank: rank as u32,
                                },
                            );
                        }
                    }
                }
            } else {
                // No full-text index: fall back to vector similarity alone.
                for (id, score) in vector_ranked {
                    if let Some(record) = get_memory_cached(engine, id).await?
                        && passes_filters(&record, &request, &agent_id, engine).await
                    {
                        scored_memories.push((record, score));
                    }
                }
            }
        }
    }
    // Highest score first (NaN compares as equal), then keep at most `limit`.
    scored_memories.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    scored_memories.truncate(limit);
    // Number of memories actually returned, after truncation.
    let total = scored_memories.len();
    // Best-effort access bookkeeping; failures are only logged.
    for (record, _) in &scored_memories {
        if let Err(e) = engine.storage.touch_memory(record.id).await {
            tracing::warn!(memory_id = %record.id, error = %e, "failed to update access timestamp");
        }
    }
    // With encryption configured, content is base64-decoded, decrypted and read
    // as UTF-8. Any failure replaces the content with a placeholder instead of
    // failing the recall.
    if let Some(ref enc) = engine.encryption {
        for (record, _) in &mut scored_memories {
            match base64::engine::general_purpose::STANDARD.decode(&record.content) {
                Ok(encrypted_bytes) => match enc.decrypt(&encrypted_bytes) {
                    Ok(decrypted) => match String::from_utf8(decrypted) {
                        Ok(plaintext) => record.content = plaintext,
                        Err(e) => {
                            tracing::error!(memory_id = %record.id, error = %e, "decrypted content is not valid UTF-8");
                            record.content = "[content unavailable: decryption error]".to_string();
                        }
                    },
                    Err(e) => {
                        tracing::error!(memory_id = %record.id, error = %e, "failed to decrypt memory content");
                        record.content = "[content unavailable: decryption error]".to_string();
                    }
                },
                Err(e) => {
                    tracing::error!(memory_id = %record.id, error = %e, "failed to decode encrypted content");
                    record.content = "[content unavailable: decryption error]".to_string();
                }
            }
        }
    }
    // Snapshot the (already decrypted) records for signing before they are moved
    // into the response.
    let provenance_records: Option<Vec<MemoryRecord>> =
        if request.with_provenance == Some(true) && engine.provenance_signer.is_some() {
            Some(scored_memories.iter().map(|(r, _)| r.clone()).collect())
        } else {
            None
        };
    let memories: Vec<ScoredMemory> = scored_memories
        .into_iter()
        .map(|(record, score)| {
            let id = record.id;
            let mut scored = ScoredMemory::from((record, score));
            if let Some(breakdown) = breakdowns.remove(&id) {
                scored.score_breakdown = Some(breakdown);
            }
            scored
        })
        .collect();
    // Audit trail. The event's `prev_hash` field stores
    // compute_chain_hash(content_hash, latest event hash for this agent). If the
    // latest hash cannot be read, a new chain segment starts.
    let now = chrono::Utc::now().to_rfc3339();
    let event_content_hash = compute_content_hash(&request.query, &agent_id, &now);
    let prev_event_hash = match engine.storage.get_latest_event_hash(&agent_id, None).await {
        Ok(hash) => hash,
        Err(e) => {
            tracing::warn!(error = %e, "failed to get latest event hash, starting new chain segment");
            None
        }
    };
    let event_prev_hash = Some(crate::hash::compute_chain_hash(
        &event_content_hash,
        prev_event_hash.as_deref(),
    ));
    // NOTE(review): the raw query text is stored in the payload even when content
    // encryption is enabled — confirm this is acceptable.
    let mut event = AgentEvent {
        id: Uuid::now_v7(),
        agent_id: agent_id.clone(),
        thread_id: None,
        run_id: None,
        parent_event_id: None,
        event_type: EventType::MemoryRead,
        payload: serde_json::json!({
            "query": request.query,
            "results": total,
            "strategy": strategy,
        }),
        trace_id: None,
        span_id: None,
        model: None,
        tokens_input: None,
        tokens_output: None,
        latency_ms: None,
        cost_usd: None,
        timestamp: now.clone(),
        logical_clock: 0,
        content_hash: event_content_hash,
        prev_hash: event_prev_hash,
        embedding: None,
    };
    // Optional payload embedding; an embedding error is ignored silently.
    if engine.embed_events
        && let Ok(emb) = engine.embedding.embed(&event.payload.to_string()).await
    {
        event.embedding = Some(emb);
    }
    // Audit insertion is best-effort: failure is logged, not returned.
    if let Err(e) = engine.storage.insert_event(&event).await {
        tracing::error!(event_id = %event.id, error = %e, "failed to insert audit event");
    }
    // A signing failure degrades to a response without provenance.
    let provenance = if let (Some(records), Some(signer)) =
        (provenance_records, engine.provenance_signer.as_ref())
    {
        match signer.sign(&agent_id, &request.query, &records) {
            Ok(p) => Some(p),
            Err(e) => {
                tracing::warn!(error = %e, "failed to sign read provenance; degrading to no-provenance response");
                None
            }
        }
    } else {
        None
    };
    Ok(RecallResponse {
        memories,
        total,
        provenance,
    })
}
async fn passes_filters(
record: &MemoryRecord,
request: &RecallRequest,
agent_id: &str,
engine: &MnemoEngine,
) -> bool {
if request.as_of.is_none() && record.is_deleted() {
return false;
}
if let Some(ref expires_at) = record.expires_at
&& let Ok(exp) = chrono::DateTime::parse_from_rfc3339(expires_at)
&& exp < chrono::Utc::now()
{
return false;
}
if record.quarantined {
return false;
}
if let Some(ref s) = request.scope
&& record.scope != *s
{
return false;
}
if let Some(ref mts) = request.memory_types {
if !mts.contains(&record.memory_type) {
return false;
}
} else if let Some(ref mt) = request.memory_type
&& record.memory_type != *mt
{
return false;
}
if let Some(min_imp) = request.min_importance
&& record.importance < min_imp
{
return false;
}
if let Some(ref req_tags) = request.tags
&& !req_tags.iter().any(|t| record.tags.contains(t))
{
return false;
}
if let Some(ref tr) = request.temporal_range {
if let Some(ref after) = tr.after
&& let (Ok(after_dt), Ok(record_dt)) = (
chrono::DateTime::parse_from_rfc3339(after),
chrono::DateTime::parse_from_rfc3339(&record.created_at),
)
&& record_dt < after_dt
{
return false;
}
if let Some(ref before) = tr.before
&& let (Ok(before_dt), Ok(record_dt)) = (
chrono::DateTime::parse_from_rfc3339(before),
chrono::DateTime::parse_from_rfc3339(&record.created_at),
)
&& record_dt > before_dt
{
return false;
}
}
if let Some(ref as_of) = request.as_of {
if let (Ok(as_of_dt), Ok(record_dt)) = (
chrono::DateTime::parse_from_rfc3339(as_of),
chrono::DateTime::parse_from_rfc3339(&record.created_at),
) && record_dt > as_of_dt
{
return false;
}
if let Some(ref deleted_at) = record.deleted_at
&& let (Ok(del_dt), Ok(as_of_dt)) = (
chrono::DateTime::parse_from_rfc3339(deleted_at),
chrono::DateTime::parse_from_rfc3339(as_of),
)
&& del_dt <= as_of_dt
{
return false;
}
}
match record.scope {
Scope::Public | Scope::Global => true,
Scope::Shared => {
record.agent_id == agent_id
|| engine
.storage
.check_permission(
record.id,
agent_id,
crate::model::acl::Permission::Read,
)
.await
.unwrap_or_else(|e| {
tracing::warn!(memory_id = %record.id, error = %e, "permission check failed, denying access");
false
})
}
Scope::Private => record.agent_id == agent_id,
}
}