//! `mnemo_core/query/recall.rs` — recall request/response types and the
//! recall execution path (lexical, semantic, graph, exact, and hybrid
//! retrieval strategies).

1use std::collections::HashSet;
2
3use serde::{Deserialize, Serialize};
4use uuid::Uuid;
5
6use crate::error::Result;
7use crate::hash::compute_content_hash;
8use crate::model::event::{AgentEvent, EventType};
9use crate::model::memory::{MemoryRecord, MemoryType, Scope};
10use crate::query::MnemoEngine;
11use crate::storage::MemoryFilter;
12#[allow(unused_imports)]
13use base64::Engine as _;
14
/// RFC 3339 time window used to constrain recall results by creation time.
///
/// Either bound may be omitted. Per the filter logic in `passes_filters`,
/// records created strictly before `after` or strictly after `before` are
/// excluded (records exactly on a bound are kept).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TemporalRange {
    /// Lower bound (RFC 3339); memories created before this are excluded.
    pub after: Option<String>,
    /// Upper bound (RFC 3339); memories created after this are excluded.
    pub before: Option<String>,
}
20
21impl TemporalRange {
22    pub fn new() -> Self {
23        Self::default()
24    }
25}
26
/// Parameters for a recall (memory retrieval) query.
///
/// Only `query` is required; every other field falls back to an engine
/// default when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RecallRequest {
    /// Free-text query; embedded for vector search and passed to BM25.
    pub query: String,
    /// Acting agent; defaults to the engine's `default_agent_id`.
    pub agent_id: Option<String>,
    /// Max results to return; defaults to 10, capped at 100.
    pub limit: Option<usize>,
    /// Single-type filter; ignored when `memory_types` is set.
    pub memory_type: Option<MemoryType>,
    /// Multi-type filter; takes precedence over `memory_type`.
    pub memory_types: Option<Vec<MemoryType>>,
    /// Explicit scope filter (separate from scope-based visibility checks).
    pub scope: Option<Scope>,
    /// Minimum importance; records below this are excluded.
    pub min_importance: Option<f32>,
    /// Tag filter; a record matches if it carries any of these tags.
    pub tags: Option<Vec<String>>,
    /// Organization filter, applied on the storage query in the "exact" path.
    pub org_id: Option<String>,
    /// Retrieval strategy: "lexical", "semantic", "graph", "exact", or any
    /// other value (including `None`) for auto/hybrid.
    pub strategy: Option<String>,
    /// Restrict results to memories created within this time window.
    pub temporal_range: Option<TemporalRange>,
    /// Half-life for the recency signal in hours; defaults to 168 (one week).
    pub recency_half_life_hours: Option<f64>,
    /// Per-list weights for weighted reciprocal-rank fusion; when absent,
    /// unweighted RRF is used.
    pub hybrid_weights: Option<Vec<f32>>,
    /// RRF smoothing constant `k`; defaults to 60.0.
    pub rrf_k: Option<f32>,
    /// Point-in-time view (RFC 3339): exclude memories created after this
    /// instant and memories already deleted at it.
    pub as_of: Option<String>,
    /// When set, each `ScoredMemory` is augmented with a `score_breakdown`
    /// that reports the per-signal score contributions (vector, bm25, graph,
    /// recency) and final RRF rank.
    pub explain: Option<bool>,
}
49
50impl RecallRequest {
51    pub fn new(query: String) -> Self {
52        Self {
53            query,
54            agent_id: None,
55            limit: None,
56            memory_type: None,
57            memory_types: None,
58            scope: None,
59            min_importance: None,
60            tags: None,
61            org_id: None,
62            strategy: None,
63            temporal_range: None,
64            recency_half_life_hours: None,
65            hybrid_weights: None,
66            rrf_k: None,
67            as_of: None,
68            explain: None,
69        }
70    }
71}
72
/// Per-signal score contributions for a single recall hit.
///
/// Emitted when `RecallRequest.explain = Some(true)`. Each field is the
/// raw signal score used as input to reciprocal-rank fusion (0 when the
/// memory didn't appear in that list).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ScoreBreakdown {
    /// Vector-index similarity, computed as `1.0 - distance`.
    pub vector: f32,
    /// BM25 relevance score from the full-text index.
    pub bm25: f32,
    /// Graph-expansion score: 1.0 for vector seeds, halved per hop.
    pub graph: f32,
    /// Recency score (exponential decay over `created_at` with the
    /// configured half-life).
    pub recency: f32,
    /// 0-based position of the memory in the fused ranking.
    pub rrf_rank: u32,
}
87
/// Result of a recall query.
#[non_exhaustive]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RecallResponse {
    /// Hits sorted by descending score, truncated to the requested limit.
    pub memories: Vec<ScoredMemory>,
    /// Number of entries in `memories` (the post-truncation count, not the
    /// size of the full candidate pool).
    pub total: usize,
}
94
95impl RecallResponse {
96    pub fn new(memories: Vec<ScoredMemory>, total: usize) -> Self {
97        Self { memories, total }
98    }
99}
100
/// A memory record paired with its retrieval score, as returned to callers.
#[non_exhaustive]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoredMemory {
    pub id: Uuid,
    /// Memory content; decrypted before return when engine encryption is
    /// configured (replaced with a placeholder on decryption failure).
    pub content: String,
    /// Owning agent of the underlying record.
    pub agent_id: String,
    pub memory_type: MemoryType,
    pub scope: Scope,
    pub importance: f32,
    pub tags: Vec<String>,
    pub metadata: serde_json::Value,
    /// Final retrieval score; meaning depends on strategy (similarity, BM25,
    /// fused RRF score, or constant 1.0 for "exact").
    pub score: f32,
    pub access_count: u64,
    pub created_at: String,
    pub updated_at: String,
    /// Per-signal contributions; populated only on the hybrid path when
    /// `explain = true`, and omitted from serialization otherwise.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub score_breakdown: Option<ScoreBreakdown>,
}
119
120impl From<(MemoryRecord, f32)> for ScoredMemory {
121    fn from((record, score): (MemoryRecord, f32)) -> Self {
122        Self {
123            id: record.id,
124            content: record.content,
125            agent_id: record.agent_id,
126            memory_type: record.memory_type,
127            scope: record.scope,
128            importance: record.importance,
129            tags: record.tags,
130            metadata: record.metadata,
131            score,
132            access_count: record.access_count,
133            created_at: record.created_at,
134            updated_at: record.updated_at,
135            score_breakdown: None,
136        }
137    }
138}
139
140/// Get a memory by ID, checking cache first then falling back to storage.
141async fn get_memory_cached(engine: &MnemoEngine, id: Uuid) -> Result<Option<MemoryRecord>> {
142    if let Some(ref cache) = engine.cache
143        && let Some(record) = cache.get(id)
144    {
145        return Ok(Some(record));
146    }
147    let result = engine.storage.get_memory(id).await?;
148    if let Some(ref record) = result
149        && let Some(ref cache) = engine.cache
150    {
151        cache.put(record.clone());
152    }
153    Ok(result)
154}
155
/// Execute a recall query against the engine.
///
/// Strategy selection (`request.strategy`, default `"auto"`):
/// - `"lexical"`:  BM25 full-text search only (no results if no full-text index).
/// - `"semantic"`: vector ANN search only, permission pre-filtered.
/// - `"graph"`:    vector seeds expanded over memory relations (2 hops,
///   exponential decay), fused with the vector list via RRF.
/// - `"exact"`:    storage-level filter listing; every hit scored 1.0.
/// - anything else (`"auto"`/`"hybrid"`): RRF fusion of vector + BM25 +
///   recency + graph signals when full-text is configured, else semantic-only.
///
/// Every candidate is re-validated with `passes_filters` (deletion, expiry,
/// quarantine, request filters, scope visibility). Results are sorted by
/// score and truncated to `limit`, access timestamps are touched, content is
/// decrypted when encryption is configured, and a `MemoryRead` audit event is
/// appended to the agent's hash chain (fire-and-forget).
pub async fn execute(engine: &MnemoEngine, request: RecallRequest) -> Result<RecallResponse> {
    // Default page size 10, hard cap 100.
    let limit = request.limit.unwrap_or(10).min(100);
    let agent_id = request
        .agent_id
        .clone()
        .unwrap_or_else(|| engine.default_agent_id.clone());
    super::validate_agent_id(&agent_id)?;

    // Determine strategy
    let strategy = request.strategy.as_deref().unwrap_or("auto");

    // Compute query embedding (needed for semantic/hybrid/auto)
    // NOTE(review): also computed for "lexical"/"exact", which never use it —
    // consider deferring the embed call to the branches that need it.
    let query_embedding = engine.embedding.embed(&request.query).await?;

    // Pre-compute accessible memory IDs for permission-safe ANN pre-filtering
    let accessible_ids: HashSet<Uuid> = engine
        .storage
        .list_accessible_memory_ids(&agent_id, super::MAX_BATCH_QUERY_LIMIT)
        .await?
        .into_iter()
        .collect();
    let perm_filter = |id: Uuid| accessible_ids.contains(&id);

    let mut scored_memories: Vec<(MemoryRecord, f32)> = Vec::new();
    // Per-memory explain data; only filled on the hybrid path when
    // `request.explain` is true.
    let mut breakdowns: std::collections::HashMap<Uuid, ScoreBreakdown> =
        std::collections::HashMap::new();

    match strategy {
        "lexical" => {
            // BM25-only path. Over-fetch 3x the limit since post-filtering
            // (permissions, expiry, request filters) will drop candidates.
            if let Some(ref ft) = engine.full_text {
                let bm25_results = ft.search(&request.query, limit * 3)?;
                for (id, score) in bm25_results {
                    if let Some(record) = get_memory_cached(engine, id).await?
                        && passes_filters(&record, &request, &agent_id, engine).await
                    {
                        scored_memories.push((record, score));
                    }
                }
            }
        }
        "semantic" => {
            // Vector-only path with permission pre-filtering
            let search_results =
                engine
                    .index
                    .filtered_search(&query_embedding, limit * 3, &perm_filter)?;
            for (id, distance) in search_results {
                if let Some(record) = get_memory_cached(engine, id).await?
                    && passes_filters(&record, &request, &agent_id, engine).await
                {
                    // Convert distance to a similarity-style score.
                    let score = 1.0 - distance;
                    scored_memories.push((record, score));
                }
            }
        }
        "graph" => {
            // Seed from vector results with permission pre-filtering, then expand via graph relations
            let search_results =
                engine
                    .index
                    .filtered_search(&query_embedding, limit * 3, &perm_filter)?;
            let mut seeds: Vec<(Uuid, f32)> = Vec::new();
            for (id, distance) in &search_results {
                if let Some(record) = get_memory_cached(engine, *id).await?
                    && passes_filters(&record, &request, &agent_id, engine).await
                {
                    seeds.push((*id, 1.0 - distance));
                }
            }

            // Collect graph-expanded results with configurable multi-hop traversal
            let max_hops = 2;
            let mut seen: HashSet<Uuid> = seeds.iter().map(|(id, _)| *id).collect();
            let mut graph_ranked: Vec<(Uuid, f32)> = Vec::new();

            // Seeds get score 1.0
            for &(id, _) in &seeds {
                graph_ranked.push((id, 1.0));
            }

            // Multi-hop expansion with exponential decay (0.5 per hop).
            let mut frontier: Vec<Uuid> = seeds.iter().map(|(id, _)| *id).collect();
            let mut decay = 0.5_f32;
            for _hop in 0..max_hops {
                let mut next_frontier: Vec<Uuid> = Vec::new();
                for &id in &frontier {
                    // Follow relations in both directions.
                    let from_rels = engine.storage.get_relations_from(id).await?;
                    let to_rels = engine.storage.get_relations_to(id).await?;
                    for rel in from_rels.iter().chain(to_rels.iter()) {
                        let related_id = if rel.source_id == id {
                            rel.target_id
                        } else {
                            rel.source_id
                        };
                        // `insert` returning true means this node is new.
                        if seen.insert(related_id)
                            && let Some(record) = get_memory_cached(engine, related_id).await?
                            && passes_filters(&record, &request, &agent_id, engine).await
                        {
                            graph_ranked.push((related_id, decay));
                            next_frontier.push(related_id);
                        }
                    }
                }
                frontier = next_frontier;
                decay *= 0.5;
            }

            // Use RRF fusion with vector + graph lists
            let mut v_sorted: Vec<(Uuid, f32)> = seeds.clone();
            v_sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
            graph_ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

            let ranked_lists = vec![v_sorted, graph_ranked];
            let rrf_k = request.rrf_k.unwrap_or(60.0);
            let fused = if let Some(ref weights) = request.hybrid_weights {
                crate::query::retrieval::weighted_reciprocal_rank_fusion(
                    &ranked_lists,
                    rrf_k,
                    weights,
                )
            } else {
                crate::query::retrieval::reciprocal_rank_fusion(&ranked_lists, rrf_k)
            };

            for (id, score) in fused {
                if let Some(record) = get_memory_cached(engine, id).await?
                    && passes_filters(&record, &request, &agent_id, engine).await
                {
                    scored_memories.push((record, score));
                }
            }
        }
        "exact" => {
            // Filter-based exact matching, no embedding needed
            // When as_of is set, include deleted records so the as_of filter can evaluate them
            let filter = MemoryFilter {
                agent_id: Some(agent_id.clone()),
                memory_type: request.memory_type,
                scope: request.scope,
                tags: request.tags.clone(),
                min_importance: request.min_importance,
                org_id: request.org_id.clone(),
                thread_id: None,
                include_deleted: request.as_of.is_some(),
            };
            // NOTE(review): only `limit` rows are fetched before
            // `passes_filters`; post-filter rejections can return fewer than
            // `limit` results even when more matches exist — confirm intended.
            let memories = engine.storage.list_memories(&filter, limit, 0).await?;
            for record in memories {
                if passes_filters(&record, &request, &agent_id, engine).await {
                    scored_memories.push((record, 1.0));
                }
            }
        }
        _ => {
            // "auto" or "hybrid" — use hybrid if full_text available, else semantic
            let vector_results =
                engine
                    .index
                    .filtered_search(&query_embedding, limit * 3, &perm_filter)?;
            let mut vector_ranked: Vec<(Uuid, f32)> = Vec::new();
            for (id, distance) in vector_results {
                vector_ranked.push((id, 1.0 - distance));
            }

            if let Some(ref ft) = engine.full_text {
                // Hybrid: RRF fusion of vector + BM25 + recency
                let bm25_results = ft.search(&request.query, limit * 3)?;

                // Build recency-scored list from vector candidates
                let mut recency_ranked: Vec<(Uuid, f32)> = Vec::new();
                for &(id, _) in &vector_ranked {
                    if let Some(record) = get_memory_cached(engine, id).await? {
                        let r_score = crate::query::retrieval::recency_score(
                            &record.created_at,
                            // Default half-life: 168 hours (one week).
                            request.recency_half_life_hours.unwrap_or(168.0),
                        );
                        recency_ranked.push((id, r_score));
                    }
                }
                // Also add BM25 candidates to recency
                for &(id, _) in &bm25_results {
                    if !recency_ranked.iter().any(|(rid, _)| *rid == id)
                        && let Some(record) = get_memory_cached(engine, id).await?
                    {
                        let r_score = crate::query::retrieval::recency_score(
                            &record.created_at,
                            request.recency_half_life_hours.unwrap_or(168.0),
                        );
                        recency_ranked.push((id, r_score));
                    }
                }

                // Sort each list by score descending
                let mut v_sorted = vector_ranked.clone();
                v_sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
                let mut b_sorted = bm25_results;
                b_sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
                recency_ranked
                    .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

                // Graph expansion signal: from top-10 vector results, multi-hop expansion
                let max_hops = 2;
                let mut graph_ranked: Vec<(Uuid, f32)> = Vec::new();
                let top_seeds: Vec<Uuid> =
                    vector_ranked.iter().take(10).map(|(id, _)| *id).collect();
                let mut graph_seen: HashSet<Uuid> = top_seeds.iter().copied().collect();
                for &seed_id in &top_seeds {
                    graph_ranked.push((seed_id, 1.0));
                }
                let mut frontier: Vec<Uuid> = top_seeds;
                let mut decay = 0.5_f32;
                for _hop in 0..max_hops {
                    let mut next_frontier: Vec<Uuid> = Vec::new();
                    for &fid in &frontier {
                        // Unlike the "graph" strategy, relation lookups here
                        // are best-effort: failures are logged, not propagated.
                        match engine.storage.get_relations_from(fid).await {
                            Ok(from_rels) => {
                                for rel in &from_rels {
                                    if graph_seen.insert(rel.target_id) {
                                        graph_ranked.push((rel.target_id, decay));
                                        next_frontier.push(rel.target_id);
                                    }
                                }
                            }
                            Err(e) => {
                                tracing::warn!(memory_id = %fid, error = %e, "graph expansion: failed to get outgoing relations");
                            }
                        }
                        match engine.storage.get_relations_to(fid).await {
                            Ok(to_rels) => {
                                for rel in &to_rels {
                                    if graph_seen.insert(rel.source_id) {
                                        graph_ranked.push((rel.source_id, decay));
                                        next_frontier.push(rel.source_id);
                                    }
                                }
                            }
                            Err(e) => {
                                tracing::warn!(memory_id = %fid, error = %e, "graph expansion: failed to get incoming relations");
                            }
                        }
                    }
                    frontier = next_frontier;
                    decay *= 0.5;
                }
                graph_ranked
                    .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

                // Capture per-signal score maps before moving the ranked lists
                // into the fusion call, so `explain=true` can surface each
                // signal's contribution in the response.
                let explain = request.explain.unwrap_or(false);
                type SignalMap = std::collections::HashMap<Uuid, f32>;
                let (vector_map, bm25_map, recency_map, graph_map): (
                    SignalMap,
                    SignalMap,
                    SignalMap,
                    SignalMap,
                ) = if explain {
                    (
                        v_sorted.iter().copied().collect(),
                        b_sorted.iter().copied().collect(),
                        recency_ranked.iter().copied().collect(),
                        graph_ranked.iter().copied().collect(),
                    )
                } else {
                    Default::default()
                };

                let ranked_lists = vec![v_sorted, b_sorted, recency_ranked, graph_ranked];
                let rrf_k = request.rrf_k.unwrap_or(60.0);
                let fused = if let Some(ref weights) = request.hybrid_weights {
                    crate::query::retrieval::weighted_reciprocal_rank_fusion(
                        &ranked_lists,
                        rrf_k,
                        weights,
                    )
                } else {
                    crate::query::retrieval::reciprocal_rank_fusion(&ranked_lists, rrf_k)
                };

                for (rank, (id, score)) in fused.into_iter().enumerate() {
                    if let Some(record) = get_memory_cached(engine, id).await?
                        && passes_filters(&record, &request, &agent_id, engine).await
                    {
                        scored_memories.push((record, score));
                        if explain {
                            breakdowns.insert(
                                id,
                                ScoreBreakdown {
                                    vector: vector_map.get(&id).copied().unwrap_or(0.0),
                                    bm25: bm25_map.get(&id).copied().unwrap_or(0.0),
                                    graph: graph_map.get(&id).copied().unwrap_or(0.0),
                                    recency: recency_map.get(&id).copied().unwrap_or(0.0),
                                    rrf_rank: rank as u32,
                                },
                            );
                        }
                    }
                }
            } else {
                // Fallback to semantic-only
                for (id, score) in vector_ranked {
                    if let Some(record) = get_memory_cached(engine, id).await?
                        && passes_filters(&record, &request, &agent_id, engine).await
                    {
                        scored_memories.push((record, score));
                    }
                }
            }
        }
    }

    // Sort by score descending
    scored_memories.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    scored_memories.truncate(limit);

    // Post-truncation count: `total` equals the number of returned memories,
    // not the size of the full candidate pool.
    let total = scored_memories.len();

    // Touch accessed memories (best-effort; failures are only logged)
    for (record, _) in &scored_memories {
        if let Err(e) = engine.storage.touch_memory(record.id).await {
            tracing::warn!(memory_id = %record.id, error = %e, "failed to update access timestamp");
        }
    }

    // Decrypt content if encryption is configured. Any failure (bad base64,
    // decryption error, invalid UTF-8) replaces content with a placeholder
    // rather than failing the whole request.
    if let Some(ref enc) = engine.encryption {
        for (record, _) in &mut scored_memories {
            match base64::engine::general_purpose::STANDARD.decode(&record.content) {
                Ok(encrypted_bytes) => match enc.decrypt(&encrypted_bytes) {
                    Ok(decrypted) => match String::from_utf8(decrypted) {
                        Ok(plaintext) => record.content = plaintext,
                        Err(e) => {
                            tracing::error!(memory_id = %record.id, error = %e, "decrypted content is not valid UTF-8");
                            record.content = "[content unavailable: decryption error]".to_string();
                        }
                    },
                    Err(e) => {
                        tracing::error!(memory_id = %record.id, error = %e, "failed to decrypt memory content");
                        record.content = "[content unavailable: decryption error]".to_string();
                    }
                },
                Err(e) => {
                    tracing::error!(memory_id = %record.id, error = %e, "failed to decode encrypted content");
                    record.content = "[content unavailable: decryption error]".to_string();
                }
            }
        }
    }

    // Attach explain breakdowns (hybrid path only) while converting to the
    // public response type.
    let memories: Vec<ScoredMemory> = scored_memories
        .into_iter()
        .map(|(record, score)| {
            let id = record.id;
            let mut scored = ScoredMemory::from((record, score));
            if let Some(breakdown) = breakdowns.remove(&id) {
                scored.score_breakdown = Some(breakdown);
            }
            scored
        })
        .collect();

    // Emit MemoryRead event with hash chain linking (fire-and-forget)
    let now = chrono::Utc::now().to_rfc3339();
    let event_content_hash = compute_content_hash(&request.query, &agent_id, &now);
    let prev_event_hash = match engine.storage.get_latest_event_hash(&agent_id, None).await {
        Ok(hash) => hash,
        Err(e) => {
            tracing::warn!(error = %e, "failed to get latest event hash, starting new chain segment");
            None
        }
    };
    let event_prev_hash = Some(crate::hash::compute_chain_hash(
        &event_content_hash,
        prev_event_hash.as_deref(),
    ));
    let mut event = AgentEvent {
        id: Uuid::now_v7(),
        agent_id: agent_id.clone(),
        thread_id: None,
        run_id: None,
        parent_event_id: None,
        event_type: EventType::MemoryRead,
        payload: serde_json::json!({
            "query": request.query,
            "results": total,
            "strategy": strategy,
        }),
        trace_id: None,
        span_id: None,
        model: None,
        tokens_input: None,
        tokens_output: None,
        latency_ms: None,
        cost_usd: None,
        timestamp: now.clone(),
        logical_clock: 0,
        content_hash: event_content_hash,
        prev_hash: event_prev_hash,
        embedding: None,
    };
    // Optionally embed the event payload (best-effort: embed errors are ignored)
    if engine.embed_events
        && let Ok(emb) = engine.embedding.embed(&event.payload.to_string()).await
    {
        event.embedding = Some(emb);
    }
    if let Err(e) = engine.storage.insert_event(&event).await {
        tracing::error!(event_id = %event.id, error = %e, "failed to insert audit event");
    }

    Ok(RecallResponse { memories, total })
}
569
/// Decide whether `record` is visible and eligible for this recall request.
///
/// Checks, in order: soft-delete (skipped when `as_of` is set so the
/// point-in-time filter can evaluate deleted records itself), expiry,
/// quarantine, explicit scope/type/importance/tag filters, the temporal
/// range, the `as_of` point-in-time view, and finally scope-based
/// visibility — which may perform a storage ACL lookup for `Shared` records.
async fn passes_filters(
    record: &MemoryRecord,
    request: &RecallRequest,
    agent_id: &str,
    engine: &MnemoEngine,
) -> bool {
    // Skip deleted (unless as_of is set — the as_of filter handles deleted records)
    if request.as_of.is_none() && record.is_deleted() {
        return false;
    }

    // Skip expired (unparseable expires_at timestamps are treated as non-expired)
    if let Some(ref expires_at) = record.expires_at
        && let Ok(exp) = chrono::DateTime::parse_from_rfc3339(expires_at)
        && exp < chrono::Utc::now()
    {
        return false;
    }

    // Skip quarantined
    if record.quarantined {
        return false;
    }

    // Scope filter (explicit request scope filter, separate from visibility below)
    if let Some(ref s) = request.scope
        && record.scope != *s
    {
        return false;
    }

    // Type filter: memory_types (multi) takes precedence over memory_type (single)
    if let Some(ref mts) = request.memory_types {
        if !mts.contains(&record.memory_type) {
            return false;
        }
    } else if let Some(ref mt) = request.memory_type
        && record.memory_type != *mt
    {
        return false;
    }

    // Importance filter
    if let Some(min_imp) = request.min_importance
        && record.importance < min_imp
    {
        return false;
    }

    // Tags filter: record matches if it carries ANY of the requested tags
    if let Some(ref req_tags) = request.tags
        && !req_tags.iter().any(|t| record.tags.contains(t))
    {
        return false;
    }

    // Temporal range filter (parse to DateTime for correct comparison).
    // Unparseable timestamps make the bound a no-op rather than an exclusion.
    if let Some(ref tr) = request.temporal_range {
        if let Some(ref after) = tr.after
            && let (Ok(after_dt), Ok(record_dt)) = (
                chrono::DateTime::parse_from_rfc3339(after),
                chrono::DateTime::parse_from_rfc3339(&record.created_at),
            )
            && record_dt < after_dt
        {
            return false;
        }
        if let Some(ref before) = tr.before
            && let (Ok(before_dt), Ok(record_dt)) = (
                chrono::DateTime::parse_from_rfc3339(before),
                chrono::DateTime::parse_from_rfc3339(&record.created_at),
            )
            && record_dt > before_dt
        {
            return false;
        }
    }

    // Point-in-time as_of filter: show memory state at time T
    if let Some(ref as_of) = request.as_of {
        if let (Ok(as_of_dt), Ok(record_dt)) = (
            chrono::DateTime::parse_from_rfc3339(as_of),
            chrono::DateTime::parse_from_rfc3339(&record.created_at),
        ) && record_dt > as_of_dt
        {
            // Exclude memories created after as_of
            return false;
        }
        // Exclude memories already deleted at as_of
        if let Some(ref deleted_at) = record.deleted_at
            && let (Ok(del_dt), Ok(as_of_dt)) = (
                chrono::DateTime::parse_from_rfc3339(deleted_at),
                chrono::DateTime::parse_from_rfc3339(as_of),
            )
            && del_dt <= as_of_dt
        {
            return false;
        }
    }

    // Scope-based visibility
    match record.scope {
        Scope::Public | Scope::Global => true,
        Scope::Shared => {
            // Owner always sees the record; otherwise fall back to an ACL
            // check. A failed permission lookup denies access (fail closed).
            record.agent_id == agent_id
                || engine
                    .storage
                    .check_permission(
                        record.id,
                        agent_id,
                        crate::model::acl::Permission::Read,
                    )
                    .await
                    .unwrap_or_else(|e| {
                        tracing::warn!(memory_id = %record.id, error = %e, "permission check failed, denying access");
                        false
                    })
        }
        Scope::Private => record.agent_id == agent_id,
    }
}