1use std::collections::HashSet;
2
3use serde::{Deserialize, Serialize};
4use uuid::Uuid;
5
6use crate::error::Result;
7use crate::hash::compute_content_hash;
8use crate::model::event::{AgentEvent, EventType};
9use crate::model::memory::{MemoryRecord, MemoryType, Scope};
10use crate::query::MnemoEngine;
11use crate::storage::MemoryFilter;
12#[allow(unused_imports)]
13use base64::Engine as _;
14
15#[derive(Debug, Clone, Default, Serialize, Deserialize)]
16pub struct TemporalRange {
17 pub after: Option<String>,
18 pub before: Option<String>,
19}
20
21impl TemporalRange {
22 pub fn new() -> Self {
23 Self::default()
24 }
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct RecallRequest {
29 pub query: String,
30 pub agent_id: Option<String>,
31 pub limit: Option<usize>,
32 pub memory_type: Option<MemoryType>,
33 pub memory_types: Option<Vec<MemoryType>>,
34 pub scope: Option<Scope>,
35 pub min_importance: Option<f32>,
36 pub tags: Option<Vec<String>>,
37 pub org_id: Option<String>,
38 pub strategy: Option<String>,
39 pub temporal_range: Option<TemporalRange>,
40 pub recency_half_life_hours: Option<f64>,
41 pub hybrid_weights: Option<Vec<f32>>,
42 pub rrf_k: Option<f32>,
43 pub as_of: Option<String>,
44 pub explain: Option<bool>,
48 pub with_provenance: Option<bool>,
55}
56
57impl RecallRequest {
58 pub fn new(query: String) -> Self {
59 Self {
60 query,
61 agent_id: None,
62 limit: None,
63 memory_type: None,
64 memory_types: None,
65 scope: None,
66 min_importance: None,
67 tags: None,
68 org_id: None,
69 strategy: None,
70 temporal_range: None,
71 recency_half_life_hours: None,
72 hybrid_weights: None,
73 rrf_k: None,
74 as_of: None,
75 explain: None,
76 with_provenance: None,
77 }
78 }
79}
80
81#[derive(Debug, Clone, Default, Serialize, Deserialize)]
87pub struct ScoreBreakdown {
88 pub vector: f32,
89 pub bm25: f32,
90 pub graph: f32,
91 pub recency: f32,
92 pub rrf_rank: u32,
94}
95
96#[non_exhaustive]
97#[derive(Debug, Clone, Serialize, Deserialize)]
98pub struct RecallResponse {
99 pub memories: Vec<ScoredMemory>,
100 pub total: usize,
101 #[serde(skip_serializing_if = "Option::is_none", default)]
106 pub provenance: Option<crate::provenance::ReadProvenance>,
107}
108
109impl RecallResponse {
110 pub fn new(memories: Vec<ScoredMemory>, total: usize) -> Self {
111 Self {
112 memories,
113 total,
114 provenance: None,
115 }
116 }
117}
118
119#[non_exhaustive]
120#[derive(Debug, Clone, Serialize, Deserialize)]
121pub struct ScoredMemory {
122 pub id: Uuid,
123 pub content: String,
124 pub agent_id: String,
125 pub memory_type: MemoryType,
126 pub scope: Scope,
127 pub importance: f32,
128 pub tags: Vec<String>,
129 pub metadata: serde_json::Value,
130 pub score: f32,
131 pub access_count: u64,
132 pub created_at: String,
133 pub updated_at: String,
134 #[serde(skip_serializing_if = "Option::is_none")]
135 pub score_breakdown: Option<ScoreBreakdown>,
136}
137
138impl From<(MemoryRecord, f32)> for ScoredMemory {
139 fn from((record, score): (MemoryRecord, f32)) -> Self {
140 Self {
141 id: record.id,
142 content: record.content,
143 agent_id: record.agent_id,
144 memory_type: record.memory_type,
145 scope: record.scope,
146 importance: record.importance,
147 tags: record.tags,
148 metadata: record.metadata,
149 score,
150 access_count: record.access_count,
151 created_at: record.created_at,
152 updated_at: record.updated_at,
153 score_breakdown: None,
154 }
155 }
156}
157
158async fn get_memory_cached(engine: &MnemoEngine, id: Uuid) -> Result<Option<MemoryRecord>> {
160 if let Some(ref cache) = engine.cache
161 && let Some(record) = cache.get(id)
162 {
163 return Ok(Some(record));
164 }
165 let result = engine.storage.get_memory(id).await?;
166 if let Some(ref record) = result
167 && let Some(ref cache) = engine.cache
168 {
169 cache.put(record.clone());
170 }
171 Ok(result)
172}
173
174pub async fn execute(engine: &MnemoEngine, request: RecallRequest) -> Result<RecallResponse> {
175 let limit = request.limit.unwrap_or(10).min(100);
176 let agent_id = request
177 .agent_id
178 .clone()
179 .unwrap_or_else(|| engine.default_agent_id.clone());
180 super::validate_agent_id(&agent_id)?;
181
182 let strategy = request.strategy.as_deref().unwrap_or("auto");
184
185 let query_embedding = engine.embedding.embed(&request.query).await?;
187
188 let accessible_ids: HashSet<Uuid> = engine
190 .storage
191 .list_accessible_memory_ids(&agent_id, super::MAX_BATCH_QUERY_LIMIT)
192 .await?
193 .into_iter()
194 .collect();
195 let perm_filter = |id: Uuid| accessible_ids.contains(&id);
196
197 let mut scored_memories: Vec<(MemoryRecord, f32)> = Vec::new();
198 let mut breakdowns: std::collections::HashMap<Uuid, ScoreBreakdown> =
199 std::collections::HashMap::new();
200
201 match strategy {
202 "lexical" => {
203 if let Some(ref ft) = engine.full_text {
205 let bm25_results = ft.search(&request.query, limit * 3)?;
206 for (id, score) in bm25_results {
207 if let Some(record) = get_memory_cached(engine, id).await?
208 && passes_filters(&record, &request, &agent_id, engine).await
209 {
210 scored_memories.push((record, score));
211 }
212 }
213 }
214 }
215 "semantic" => {
216 let search_results =
218 engine
219 .index
220 .filtered_search(&query_embedding, limit * 3, &perm_filter)?;
221 for (id, distance) in search_results {
222 if let Some(record) = get_memory_cached(engine, id).await?
223 && passes_filters(&record, &request, &agent_id, engine).await
224 {
225 let score = 1.0 - distance;
226 scored_memories.push((record, score));
227 }
228 }
229 }
230 "graph" => {
231 let search_results =
233 engine
234 .index
235 .filtered_search(&query_embedding, limit * 3, &perm_filter)?;
236 let mut seeds: Vec<(Uuid, f32)> = Vec::new();
237 for (id, distance) in &search_results {
238 if let Some(record) = get_memory_cached(engine, *id).await?
239 && passes_filters(&record, &request, &agent_id, engine).await
240 {
241 seeds.push((*id, 1.0 - distance));
242 }
243 }
244
245 let max_hops = 2;
247 let mut seen: HashSet<Uuid> = seeds.iter().map(|(id, _)| *id).collect();
248 let mut graph_ranked: Vec<(Uuid, f32)> = Vec::new();
249
250 for &(id, _) in &seeds {
252 graph_ranked.push((id, 1.0));
253 }
254
255 let mut frontier: Vec<Uuid> = seeds.iter().map(|(id, _)| *id).collect();
257 let mut decay = 0.5_f32;
258 for _hop in 0..max_hops {
259 let mut next_frontier: Vec<Uuid> = Vec::new();
260 for &id in &frontier {
261 let from_rels = engine.storage.get_relations_from(id).await?;
262 let to_rels = engine.storage.get_relations_to(id).await?;
263 for rel in from_rels.iter().chain(to_rels.iter()) {
264 let related_id = if rel.source_id == id {
265 rel.target_id
266 } else {
267 rel.source_id
268 };
269 if seen.insert(related_id)
270 && let Some(record) = get_memory_cached(engine, related_id).await?
271 && passes_filters(&record, &request, &agent_id, engine).await
272 {
273 graph_ranked.push((related_id, decay));
274 next_frontier.push(related_id);
275 }
276 }
277 }
278 frontier = next_frontier;
279 decay *= 0.5;
280 }
281
282 let mut v_sorted: Vec<(Uuid, f32)> = seeds.clone();
284 v_sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
285 graph_ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
286
287 let ranked_lists = vec![v_sorted, graph_ranked];
288 let rrf_k = request.rrf_k.unwrap_or(60.0);
289 let fused = if let Some(ref weights) = request.hybrid_weights {
290 crate::query::retrieval::weighted_reciprocal_rank_fusion(
291 &ranked_lists,
292 rrf_k,
293 weights,
294 )
295 } else {
296 crate::query::retrieval::reciprocal_rank_fusion(&ranked_lists, rrf_k)
297 };
298
299 for (id, score) in fused {
300 if let Some(record) = get_memory_cached(engine, id).await?
301 && passes_filters(&record, &request, &agent_id, engine).await
302 {
303 scored_memories.push((record, score));
304 }
305 }
306 }
307 "exact" => {
308 let filter = MemoryFilter {
311 agent_id: Some(agent_id.clone()),
312 memory_type: request.memory_type,
313 scope: request.scope,
314 tags: request.tags.clone(),
315 min_importance: request.min_importance,
316 org_id: request.org_id.clone(),
317 thread_id: None,
318 include_deleted: request.as_of.is_some(),
319 };
320 let memories = engine.storage.list_memories(&filter, limit, 0).await?;
321 for record in memories {
322 if passes_filters(&record, &request, &agent_id, engine).await {
323 scored_memories.push((record, 1.0));
324 }
325 }
326 }
327 _ => {
328 let vector_results =
330 engine
331 .index
332 .filtered_search(&query_embedding, limit * 3, &perm_filter)?;
333 let mut vector_ranked: Vec<(Uuid, f32)> = Vec::new();
334 for (id, distance) in vector_results {
335 vector_ranked.push((id, 1.0 - distance));
336 }
337
338 if let Some(ref ft) = engine.full_text {
339 let bm25_results = ft.search(&request.query, limit * 3)?;
341
342 let mut recency_ranked: Vec<(Uuid, f32)> = Vec::new();
344 for &(id, _) in &vector_ranked {
345 if let Some(record) = get_memory_cached(engine, id).await? {
346 let r_score = crate::query::retrieval::recency_score(
347 &record.created_at,
348 request.recency_half_life_hours.unwrap_or(168.0),
349 );
350 recency_ranked.push((id, r_score));
351 }
352 }
353 for &(id, _) in &bm25_results {
355 if !recency_ranked.iter().any(|(rid, _)| *rid == id)
356 && let Some(record) = get_memory_cached(engine, id).await?
357 {
358 let r_score = crate::query::retrieval::recency_score(
359 &record.created_at,
360 request.recency_half_life_hours.unwrap_or(168.0),
361 );
362 recency_ranked.push((id, r_score));
363 }
364 }
365
366 let mut v_sorted = vector_ranked.clone();
368 v_sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
369 let mut b_sorted = bm25_results;
370 b_sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
371 recency_ranked
372 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
373
374 let max_hops = 2;
376 let mut graph_ranked: Vec<(Uuid, f32)> = Vec::new();
377 let top_seeds: Vec<Uuid> =
378 vector_ranked.iter().take(10).map(|(id, _)| *id).collect();
379 let mut graph_seen: HashSet<Uuid> = top_seeds.iter().copied().collect();
380 for &seed_id in &top_seeds {
381 graph_ranked.push((seed_id, 1.0));
382 }
383 let mut frontier: Vec<Uuid> = top_seeds;
384 let mut decay = 0.5_f32;
385 for _hop in 0..max_hops {
386 let mut next_frontier: Vec<Uuid> = Vec::new();
387 for &fid in &frontier {
388 match engine.storage.get_relations_from(fid).await {
389 Ok(from_rels) => {
390 for rel in &from_rels {
391 if graph_seen.insert(rel.target_id) {
392 graph_ranked.push((rel.target_id, decay));
393 next_frontier.push(rel.target_id);
394 }
395 }
396 }
397 Err(e) => {
398 tracing::warn!(memory_id = %fid, error = %e, "graph expansion: failed to get outgoing relations");
399 }
400 }
401 match engine.storage.get_relations_to(fid).await {
402 Ok(to_rels) => {
403 for rel in &to_rels {
404 if graph_seen.insert(rel.source_id) {
405 graph_ranked.push((rel.source_id, decay));
406 next_frontier.push(rel.source_id);
407 }
408 }
409 }
410 Err(e) => {
411 tracing::warn!(memory_id = %fid, error = %e, "graph expansion: failed to get incoming relations");
412 }
413 }
414 }
415 frontier = next_frontier;
416 decay *= 0.5;
417 }
418 graph_ranked
419 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
420
421 let explain = request.explain.unwrap_or(false);
425 type SignalMap = std::collections::HashMap<Uuid, f32>;
426 let (vector_map, bm25_map, recency_map, graph_map): (
427 SignalMap,
428 SignalMap,
429 SignalMap,
430 SignalMap,
431 ) = if explain {
432 (
433 v_sorted.iter().copied().collect(),
434 b_sorted.iter().copied().collect(),
435 recency_ranked.iter().copied().collect(),
436 graph_ranked.iter().copied().collect(),
437 )
438 } else {
439 Default::default()
440 };
441
442 let ranked_lists = vec![v_sorted, b_sorted, recency_ranked, graph_ranked];
443 let rrf_k = request.rrf_k.unwrap_or(60.0);
444 let fused = if let Some(ref weights) = request.hybrid_weights {
445 crate::query::retrieval::weighted_reciprocal_rank_fusion(
446 &ranked_lists,
447 rrf_k,
448 weights,
449 )
450 } else {
451 crate::query::retrieval::reciprocal_rank_fusion(&ranked_lists, rrf_k)
452 };
453
454 for (rank, (id, score)) in fused.into_iter().enumerate() {
455 if let Some(record) = get_memory_cached(engine, id).await?
456 && passes_filters(&record, &request, &agent_id, engine).await
457 {
458 scored_memories.push((record, score));
459 if explain {
460 breakdowns.insert(
461 id,
462 ScoreBreakdown {
463 vector: vector_map.get(&id).copied().unwrap_or(0.0),
464 bm25: bm25_map.get(&id).copied().unwrap_or(0.0),
465 graph: graph_map.get(&id).copied().unwrap_or(0.0),
466 recency: recency_map.get(&id).copied().unwrap_or(0.0),
467 rrf_rank: rank as u32,
468 },
469 );
470 }
471 }
472 }
473 } else {
474 for (id, score) in vector_ranked {
476 if let Some(record) = get_memory_cached(engine, id).await?
477 && passes_filters(&record, &request, &agent_id, engine).await
478 {
479 scored_memories.push((record, score));
480 }
481 }
482 }
483 }
484 }
485
486 scored_memories.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
488 scored_memories.truncate(limit);
489
490 let total = scored_memories.len();
491
492 for (record, _) in &scored_memories {
494 if let Err(e) = engine.storage.touch_memory(record.id).await {
495 tracing::warn!(memory_id = %record.id, error = %e, "failed to update access timestamp");
496 }
497 }
498
499 if let Some(ref enc) = engine.encryption {
501 for (record, _) in &mut scored_memories {
502 match base64::engine::general_purpose::STANDARD.decode(&record.content) {
503 Ok(encrypted_bytes) => match enc.decrypt(&encrypted_bytes) {
504 Ok(decrypted) => match String::from_utf8(decrypted) {
505 Ok(plaintext) => record.content = plaintext,
506 Err(e) => {
507 tracing::error!(memory_id = %record.id, error = %e, "decrypted content is not valid UTF-8");
508 record.content = "[content unavailable: decryption error]".to_string();
509 }
510 },
511 Err(e) => {
512 tracing::error!(memory_id = %record.id, error = %e, "failed to decrypt memory content");
513 record.content = "[content unavailable: decryption error]".to_string();
514 }
515 },
516 Err(e) => {
517 tracing::error!(memory_id = %record.id, error = %e, "failed to decode encrypted content");
518 record.content = "[content unavailable: decryption error]".to_string();
519 }
520 }
521 }
522 }
523
524 let provenance_records: Option<Vec<MemoryRecord>> =
529 if request.with_provenance == Some(true) && engine.provenance_signer.is_some() {
530 Some(scored_memories.iter().map(|(r, _)| r.clone()).collect())
531 } else {
532 None
533 };
534
535 let memories: Vec<ScoredMemory> = scored_memories
536 .into_iter()
537 .map(|(record, score)| {
538 let id = record.id;
539 let mut scored = ScoredMemory::from((record, score));
540 if let Some(breakdown) = breakdowns.remove(&id) {
541 scored.score_breakdown = Some(breakdown);
542 }
543 scored
544 })
545 .collect();
546
547 let now = chrono::Utc::now().to_rfc3339();
549 let event_content_hash = compute_content_hash(&request.query, &agent_id, &now);
550 let prev_event_hash = match engine.storage.get_latest_event_hash(&agent_id, None).await {
551 Ok(hash) => hash,
552 Err(e) => {
553 tracing::warn!(error = %e, "failed to get latest event hash, starting new chain segment");
554 None
555 }
556 };
557 let event_prev_hash = Some(crate::hash::compute_chain_hash(
558 &event_content_hash,
559 prev_event_hash.as_deref(),
560 ));
561 let mut event = AgentEvent {
562 id: Uuid::now_v7(),
563 agent_id: agent_id.clone(),
564 thread_id: None,
565 run_id: None,
566 parent_event_id: None,
567 event_type: EventType::MemoryRead,
568 payload: serde_json::json!({
569 "query": request.query,
570 "results": total,
571 "strategy": strategy,
572 }),
573 trace_id: None,
574 span_id: None,
575 model: None,
576 tokens_input: None,
577 tokens_output: None,
578 latency_ms: None,
579 cost_usd: None,
580 timestamp: now.clone(),
581 logical_clock: 0,
582 content_hash: event_content_hash,
583 prev_hash: event_prev_hash,
584 embedding: None,
585 };
586 if engine.embed_events
588 && let Ok(emb) = engine.embedding.embed(&event.payload.to_string()).await
589 {
590 event.embedding = Some(emb);
591 }
592 if let Err(e) = engine.storage.insert_event(&event).await {
593 tracing::error!(event_id = %event.id, error = %e, "failed to insert audit event");
594 }
595
596 let provenance = if let (Some(records), Some(signer)) =
601 (provenance_records, engine.provenance_signer.as_ref())
602 {
603 match signer.sign(&agent_id, &request.query, &records) {
604 Ok(p) => Some(p),
605 Err(e) => {
606 tracing::warn!(error = %e, "failed to sign read provenance; degrading to no-provenance response");
607 None
608 }
609 }
610 } else {
611 None
612 };
613
614 Ok(RecallResponse {
615 memories,
616 total,
617 provenance,
618 })
619}
620
621async fn passes_filters(
622 record: &MemoryRecord,
623 request: &RecallRequest,
624 agent_id: &str,
625 engine: &MnemoEngine,
626) -> bool {
627 if request.as_of.is_none() && record.is_deleted() {
629 return false;
630 }
631
632 if let Some(ref expires_at) = record.expires_at
634 && let Ok(exp) = chrono::DateTime::parse_from_rfc3339(expires_at)
635 && exp < chrono::Utc::now()
636 {
637 return false;
638 }
639
640 if record.quarantined {
642 return false;
643 }
644
645 if let Some(ref s) = request.scope
647 && record.scope != *s
648 {
649 return false;
650 }
651
652 if let Some(ref mts) = request.memory_types {
654 if !mts.contains(&record.memory_type) {
655 return false;
656 }
657 } else if let Some(ref mt) = request.memory_type
658 && record.memory_type != *mt
659 {
660 return false;
661 }
662
663 if let Some(min_imp) = request.min_importance
665 && record.importance < min_imp
666 {
667 return false;
668 }
669
670 if let Some(ref req_tags) = request.tags
672 && !req_tags.iter().any(|t| record.tags.contains(t))
673 {
674 return false;
675 }
676
677 if let Some(ref tr) = request.temporal_range {
679 if let Some(ref after) = tr.after
680 && let (Ok(after_dt), Ok(record_dt)) = (
681 chrono::DateTime::parse_from_rfc3339(after),
682 chrono::DateTime::parse_from_rfc3339(&record.created_at),
683 )
684 && record_dt < after_dt
685 {
686 return false;
687 }
688 if let Some(ref before) = tr.before
689 && let (Ok(before_dt), Ok(record_dt)) = (
690 chrono::DateTime::parse_from_rfc3339(before),
691 chrono::DateTime::parse_from_rfc3339(&record.created_at),
692 )
693 && record_dt > before_dt
694 {
695 return false;
696 }
697 }
698
699 if let Some(ref as_of) = request.as_of {
701 if let (Ok(as_of_dt), Ok(record_dt)) = (
702 chrono::DateTime::parse_from_rfc3339(as_of),
703 chrono::DateTime::parse_from_rfc3339(&record.created_at),
704 ) && record_dt > as_of_dt
705 {
706 return false;
708 }
709 if let Some(ref deleted_at) = record.deleted_at
711 && let (Ok(del_dt), Ok(as_of_dt)) = (
712 chrono::DateTime::parse_from_rfc3339(deleted_at),
713 chrono::DateTime::parse_from_rfc3339(as_of),
714 )
715 && del_dt <= as_of_dt
716 {
717 return false;
718 }
719 }
720
721 match record.scope {
723 Scope::Public | Scope::Global => true,
724 Scope::Shared => {
725 record.agent_id == agent_id
726 || engine
727 .storage
728 .check_permission(
729 record.id,
730 agent_id,
731 crate::model::acl::Permission::Read,
732 )
733 .await
734 .unwrap_or_else(|e| {
735 tracing::warn!(memory_id = %record.id, error = %e, "permission check failed, denying access");
736 false
737 })
738 }
739 Scope::Private => record.agent_id == agent_id,
740 }
741}