1use std::collections::HashSet;
2
3use serde::{Deserialize, Serialize};
4use uuid::Uuid;
5
6use crate::error::Result;
7use crate::hash::compute_content_hash;
8use crate::model::event::{AgentEvent, EventType};
9use crate::model::memory::{MemoryRecord, MemoryType, Scope};
10use crate::query::MnemoEngine;
11use crate::storage::MemoryFilter;
12#[allow(unused_imports)]
13use base64::Engine as _;
14
/// Optional creation-time window applied to recall candidates.
///
/// Both bounds are strings that `passes_filters` parses as RFC 3339 and
/// compares against a record's `created_at`. A bound that fails to parse is
/// silently ignored rather than rejecting the record.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TemporalRange {
    /// Lower bound: records created strictly before this instant are dropped.
    pub after: Option<String>,
    /// Upper bound: records created strictly after this instant are dropped.
    pub before: Option<String>,
}
20
21impl TemporalRange {
22 pub fn new() -> Self {
23 Self::default()
24 }
25}
26
/// Parameters for a recall (memory retrieval) query handled by `execute`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RecallRequest {
    /// Query text; embedded for vector search and passed verbatim to full-text search.
    pub query: String,
    /// Agent performing the recall; `execute` falls back to the engine's default agent id.
    pub agent_id: Option<String>,
    /// Maximum number of results; `execute` defaults this to 10 and caps it at 100.
    pub limit: Option<usize>,
    /// Restrict results to one memory type. The candidate post-filter consults
    /// this only when `memory_types` is `None`.
    pub memory_type: Option<MemoryType>,
    /// Restrict results to any of these memory types.
    pub memory_types: Option<Vec<MemoryType>>,
    /// Restrict results to records with exactly this scope.
    pub scope: Option<Scope>,
    /// Drop records whose importance is below this value.
    pub min_importance: Option<f32>,
    /// Keep only records carrying at least one of these tags.
    pub tags: Option<Vec<String>>,
    /// Organisation id; in this file it is only forwarded to the storage filter
    /// of the "exact" strategy.
    pub org_id: Option<String>,
    /// Strategy name: "lexical", "semantic", "graph" or "exact". Any other value,
    /// including the default "auto", selects the hybrid path. Ignored when `mode` is set.
    pub strategy: Option<String>,
    /// Optional creation-time window.
    pub temporal_range: Option<TemporalRange>,
    /// Half-life handed to the recency scorer in the hybrid path (default 168.0).
    pub recency_half_life_hours: Option<f64>,
    /// When set, rank fusion uses the weighted variant with these weights.
    pub hybrid_weights: Option<Vec<f32>>,
    /// `k` constant passed to reciprocal rank fusion (default 60.0).
    pub rrf_k: Option<f32>,
    /// Point-in-time view (RFC 3339): records created after it, or deleted at or
    /// before it, are dropped; soft-deleted records are otherwise not excluded.
    pub as_of: Option<String>,
    /// When `Some(true)`, the hybrid path attaches a `ScoreBreakdown` to each result.
    pub explain: Option<bool>,
    /// When `Some(true)` and the engine has a provenance signer, the response is signed.
    pub with_provenance: Option<bool>,
    /// Typed retrieval mode; when present it overrides `strategy`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub mode: Option<crate::retrieval::RetrievalMode>,
}
63
64impl RecallRequest {
65 pub fn new(query: String) -> Self {
66 Self {
67 query,
68 agent_id: None,
69 limit: None,
70 memory_type: None,
71 memory_types: None,
72 scope: None,
73 min_importance: None,
74 tags: None,
75 org_id: None,
76 strategy: None,
77 temporal_range: None,
78 recency_half_life_hours: None,
79 hybrid_weights: None,
80 rrf_k: None,
81 as_of: None,
82 explain: None,
83 with_provenance: None,
84 mode: None,
85 }
86 }
87}
88
/// Per-signal scores for one result, filled in by the hybrid path of `execute`
/// when the request sets `explain`. A signal that did not see the record is 0.0.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ScoreBreakdown {
    /// Vector similarity, computed as `1.0 - distance` from the index search.
    pub vector: f32,
    /// Score reported by the full-text search.
    pub bm25: f32,
    /// Graph proximity: 1.0 for a seed, halved for each expansion hop.
    pub graph: f32,
    /// Recency score derived from the record's `created_at`.
    pub recency: f32,
    /// Zero-based position in the fused ranking, before filters are applied.
    pub rrf_rank: u32,
}
103
/// Result of a recall query.
#[non_exhaustive]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RecallResponse {
    /// Matching memories; `execute` returns them sorted by descending score.
    pub memories: Vec<ScoredMemory>,
    /// Count supplied by the producer; `execute` sets it to the number of
    /// memories returned after truncating to the limit.
    pub total: usize,
    /// Signed read provenance, present only when requested and successfully signed.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub provenance: Option<crate::provenance::ReadProvenance>,
}
116
117impl RecallResponse {
118 pub fn new(memories: Vec<ScoredMemory>, total: usize) -> Self {
119 Self {
120 memories,
121 total,
122 provenance: None,
123 }
124 }
125}
126
/// A memory record paired with its relevance score, as returned to callers.
///
/// All fields except `score` and `score_breakdown` are copied from the
/// underlying `MemoryRecord` by the `From<(MemoryRecord, f32)>` impl.
#[non_exhaustive]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoredMemory {
    pub id: Uuid,
    pub content: String,
    pub agent_id: String,
    pub memory_type: MemoryType,
    pub scope: Scope,
    pub importance: f32,
    pub tags: Vec<String>,
    pub metadata: serde_json::Value,
    /// Relevance score assigned by the retrieval strategy; higher ranks first.
    pub score: f32,
    pub access_count: u64,
    pub created_at: String,
    pub updated_at: String,
    /// Per-signal explanation; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub score_breakdown: Option<ScoreBreakdown>,
}
145
146impl From<(MemoryRecord, f32)> for ScoredMemory {
147 fn from((record, score): (MemoryRecord, f32)) -> Self {
148 Self {
149 id: record.id,
150 content: record.content,
151 agent_id: record.agent_id,
152 memory_type: record.memory_type,
153 scope: record.scope,
154 importance: record.importance,
155 tags: record.tags,
156 metadata: record.metadata,
157 score,
158 access_count: record.access_count,
159 created_at: record.created_at,
160 updated_at: record.updated_at,
161 score_breakdown: None,
162 }
163 }
164}
165
166async fn get_memory_cached(engine: &MnemoEngine, id: Uuid) -> Result<Option<MemoryRecord>> {
168 if let Some(ref cache) = engine.cache
169 && let Some(record) = cache.get(id)
170 {
171 return Ok(Some(record));
172 }
173 let result = engine.storage.get_memory(id).await?;
174 if let Some(ref record) = result
175 && let Some(ref cache) = engine.cache
176 {
177 cache.put(record.clone());
178 }
179 Ok(result)
180}
181
182pub async fn execute(engine: &MnemoEngine, request: RecallRequest) -> Result<RecallResponse> {
183 let limit = request.limit.unwrap_or(10).min(100);
184 let agent_id = request
185 .agent_id
186 .clone()
187 .unwrap_or_else(|| engine.default_agent_id.clone());
188 super::validate_agent_id(&agent_id)?;
189
190 let strategy = if let Some(ref mode) = request.mode {
195 mode.to_strategy_str()
196 } else {
197 request.strategy.as_deref().unwrap_or("auto")
198 };
199
200 let query_embedding = engine.embedding.embed(&request.query).await?;
202
203 let accessible_ids: HashSet<Uuid> = engine
205 .storage
206 .list_accessible_memory_ids(&agent_id, super::MAX_BATCH_QUERY_LIMIT)
207 .await?
208 .into_iter()
209 .collect();
210 let perm_filter = |id: Uuid| accessible_ids.contains(&id);
211
212 let mut scored_memories: Vec<(MemoryRecord, f32)> = Vec::new();
213 let mut breakdowns: std::collections::HashMap<Uuid, ScoreBreakdown> =
214 std::collections::HashMap::new();
215
216 match strategy {
217 "lexical" => {
218 if let Some(ref ft) = engine.full_text {
220 let bm25_results = ft.search(&request.query, limit * 3)?;
221 for (id, score) in bm25_results {
222 if let Some(record) = get_memory_cached(engine, id).await?
223 && passes_filters(&record, &request, &agent_id, engine).await
224 {
225 scored_memories.push((record, score));
226 }
227 }
228 }
229 }
230 "semantic" => {
231 let search_results =
233 engine
234 .index
235 .filtered_search(&query_embedding, limit * 3, &perm_filter)?;
236 for (id, distance) in search_results {
237 if let Some(record) = get_memory_cached(engine, id).await?
238 && passes_filters(&record, &request, &agent_id, engine).await
239 {
240 let score = 1.0 - distance;
241 scored_memories.push((record, score));
242 }
243 }
244 }
245 "graph" => {
246 let search_results =
248 engine
249 .index
250 .filtered_search(&query_embedding, limit * 3, &perm_filter)?;
251 let mut seeds: Vec<(Uuid, f32)> = Vec::new();
252 for (id, distance) in &search_results {
253 if let Some(record) = get_memory_cached(engine, *id).await?
254 && passes_filters(&record, &request, &agent_id, engine).await
255 {
256 seeds.push((*id, 1.0 - distance));
257 }
258 }
259
260 let max_hops = 2;
262 let mut seen: HashSet<Uuid> = seeds.iter().map(|(id, _)| *id).collect();
263 let mut graph_ranked: Vec<(Uuid, f32)> = Vec::new();
264
265 for &(id, _) in &seeds {
267 graph_ranked.push((id, 1.0));
268 }
269
270 let mut frontier: Vec<Uuid> = seeds.iter().map(|(id, _)| *id).collect();
272 let mut decay = 0.5_f32;
273 for _hop in 0..max_hops {
274 let mut next_frontier: Vec<Uuid> = Vec::new();
275 for &id in &frontier {
276 let from_rels = engine.storage.get_relations_from(id).await?;
277 let to_rels = engine.storage.get_relations_to(id).await?;
278 for rel in from_rels.iter().chain(to_rels.iter()) {
279 let related_id = if rel.source_id == id {
280 rel.target_id
281 } else {
282 rel.source_id
283 };
284 if seen.insert(related_id)
285 && let Some(record) = get_memory_cached(engine, related_id).await?
286 && passes_filters(&record, &request, &agent_id, engine).await
287 {
288 graph_ranked.push((related_id, decay));
289 next_frontier.push(related_id);
290 }
291 }
292 }
293 frontier = next_frontier;
294 decay *= 0.5;
295 }
296
297 let mut v_sorted: Vec<(Uuid, f32)> = seeds.clone();
299 v_sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
300 graph_ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
301
302 let ranked_lists = vec![v_sorted, graph_ranked];
303 let rrf_k = request.rrf_k.unwrap_or(60.0);
304 let fused = if let Some(ref weights) = request.hybrid_weights {
305 crate::query::retrieval::weighted_reciprocal_rank_fusion(
306 &ranked_lists,
307 rrf_k,
308 weights,
309 )
310 } else {
311 crate::query::retrieval::reciprocal_rank_fusion(&ranked_lists, rrf_k)
312 };
313
314 for (id, score) in fused {
315 if let Some(record) = get_memory_cached(engine, id).await?
316 && passes_filters(&record, &request, &agent_id, engine).await
317 {
318 scored_memories.push((record, score));
319 }
320 }
321 }
322 "exact" => {
323 let filter = MemoryFilter {
326 agent_id: Some(agent_id.clone()),
327 memory_type: request.memory_type,
328 scope: request.scope,
329 tags: request.tags.clone(),
330 min_importance: request.min_importance,
331 org_id: request.org_id.clone(),
332 thread_id: None,
333 include_deleted: request.as_of.is_some(),
334 };
335 let memories = engine.storage.list_memories(&filter, limit, 0).await?;
336 for record in memories {
337 if passes_filters(&record, &request, &agent_id, engine).await {
338 scored_memories.push((record, 1.0));
339 }
340 }
341 }
342 _ => {
343 let vector_results =
345 engine
346 .index
347 .filtered_search(&query_embedding, limit * 3, &perm_filter)?;
348 let mut vector_ranked: Vec<(Uuid, f32)> = Vec::new();
349 for (id, distance) in vector_results {
350 vector_ranked.push((id, 1.0 - distance));
351 }
352
353 if let Some(ref ft) = engine.full_text {
354 let bm25_results = ft.search(&request.query, limit * 3)?;
356
357 let mut recency_ranked: Vec<(Uuid, f32)> = Vec::new();
359 for &(id, _) in &vector_ranked {
360 if let Some(record) = get_memory_cached(engine, id).await? {
361 let r_score = crate::query::retrieval::recency_score(
362 &record.created_at,
363 request.recency_half_life_hours.unwrap_or(168.0),
364 );
365 recency_ranked.push((id, r_score));
366 }
367 }
368 for &(id, _) in &bm25_results {
370 if !recency_ranked.iter().any(|(rid, _)| *rid == id)
371 && let Some(record) = get_memory_cached(engine, id).await?
372 {
373 let r_score = crate::query::retrieval::recency_score(
374 &record.created_at,
375 request.recency_half_life_hours.unwrap_or(168.0),
376 );
377 recency_ranked.push((id, r_score));
378 }
379 }
380
381 let mut v_sorted = vector_ranked.clone();
383 v_sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
384 let mut b_sorted = bm25_results;
385 b_sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
386 recency_ranked
387 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
388
389 let max_hops = 2;
391 let mut graph_ranked: Vec<(Uuid, f32)> = Vec::new();
392 let top_seeds: Vec<Uuid> =
393 vector_ranked.iter().take(10).map(|(id, _)| *id).collect();
394 let mut graph_seen: HashSet<Uuid> = top_seeds.iter().copied().collect();
395 for &seed_id in &top_seeds {
396 graph_ranked.push((seed_id, 1.0));
397 }
398 let mut frontier: Vec<Uuid> = top_seeds;
399 let mut decay = 0.5_f32;
400 for _hop in 0..max_hops {
401 let mut next_frontier: Vec<Uuid> = Vec::new();
402 for &fid in &frontier {
403 match engine.storage.get_relations_from(fid).await {
404 Ok(from_rels) => {
405 for rel in &from_rels {
406 if graph_seen.insert(rel.target_id) {
407 graph_ranked.push((rel.target_id, decay));
408 next_frontier.push(rel.target_id);
409 }
410 }
411 }
412 Err(e) => {
413 tracing::warn!(memory_id = %fid, error = %e, "graph expansion: failed to get outgoing relations");
414 }
415 }
416 match engine.storage.get_relations_to(fid).await {
417 Ok(to_rels) => {
418 for rel in &to_rels {
419 if graph_seen.insert(rel.source_id) {
420 graph_ranked.push((rel.source_id, decay));
421 next_frontier.push(rel.source_id);
422 }
423 }
424 }
425 Err(e) => {
426 tracing::warn!(memory_id = %fid, error = %e, "graph expansion: failed to get incoming relations");
427 }
428 }
429 }
430 frontier = next_frontier;
431 decay *= 0.5;
432 }
433 graph_ranked
434 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
435
436 let explain = request.explain.unwrap_or(false);
440 type SignalMap = std::collections::HashMap<Uuid, f32>;
441 let (vector_map, bm25_map, recency_map, graph_map): (
442 SignalMap,
443 SignalMap,
444 SignalMap,
445 SignalMap,
446 ) = if explain {
447 (
448 v_sorted.iter().copied().collect(),
449 b_sorted.iter().copied().collect(),
450 recency_ranked.iter().copied().collect(),
451 graph_ranked.iter().copied().collect(),
452 )
453 } else {
454 Default::default()
455 };
456
457 let ranked_lists = vec![v_sorted, b_sorted, recency_ranked, graph_ranked];
458 let rrf_k = request.rrf_k.unwrap_or(60.0);
459 let fused = if let Some(ref weights) = request.hybrid_weights {
460 crate::query::retrieval::weighted_reciprocal_rank_fusion(
461 &ranked_lists,
462 rrf_k,
463 weights,
464 )
465 } else {
466 crate::query::retrieval::reciprocal_rank_fusion(&ranked_lists, rrf_k)
467 };
468
469 for (rank, (id, score)) in fused.into_iter().enumerate() {
470 if let Some(record) = get_memory_cached(engine, id).await?
471 && passes_filters(&record, &request, &agent_id, engine).await
472 {
473 scored_memories.push((record, score));
474 if explain {
475 breakdowns.insert(
476 id,
477 ScoreBreakdown {
478 vector: vector_map.get(&id).copied().unwrap_or(0.0),
479 bm25: bm25_map.get(&id).copied().unwrap_or(0.0),
480 graph: graph_map.get(&id).copied().unwrap_or(0.0),
481 recency: recency_map.get(&id).copied().unwrap_or(0.0),
482 rrf_rank: rank as u32,
483 },
484 );
485 }
486 }
487 }
488 } else {
489 for (id, score) in vector_ranked {
491 if let Some(record) = get_memory_cached(engine, id).await?
492 && passes_filters(&record, &request, &agent_id, engine).await
493 {
494 scored_memories.push((record, score));
495 }
496 }
497 }
498 }
499 }
500
501 scored_memories.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
503 scored_memories.truncate(limit);
504
505 let total = scored_memories.len();
506
507 for (record, _) in &scored_memories {
509 if let Err(e) = engine.storage.touch_memory(record.id).await {
510 tracing::warn!(memory_id = %record.id, error = %e, "failed to update access timestamp");
511 }
512 }
513
514 if let Some(ref enc) = engine.encryption {
516 for (record, _) in &mut scored_memories {
517 match base64::engine::general_purpose::STANDARD.decode(&record.content) {
518 Ok(encrypted_bytes) => match enc.decrypt(&encrypted_bytes) {
519 Ok(decrypted) => match String::from_utf8(decrypted) {
520 Ok(plaintext) => record.content = plaintext,
521 Err(e) => {
522 tracing::error!(memory_id = %record.id, error = %e, "decrypted content is not valid UTF-8");
523 record.content = "[content unavailable: decryption error]".to_string();
524 }
525 },
526 Err(e) => {
527 tracing::error!(memory_id = %record.id, error = %e, "failed to decrypt memory content");
528 record.content = "[content unavailable: decryption error]".to_string();
529 }
530 },
531 Err(e) => {
532 tracing::error!(memory_id = %record.id, error = %e, "failed to decode encrypted content");
533 record.content = "[content unavailable: decryption error]".to_string();
534 }
535 }
536 }
537 }
538
539 let provenance_records: Option<Vec<MemoryRecord>> =
544 if request.with_provenance == Some(true) && engine.provenance_signer.is_some() {
545 Some(scored_memories.iter().map(|(r, _)| r.clone()).collect())
546 } else {
547 None
548 };
549
550 let memories: Vec<ScoredMemory> = scored_memories
551 .into_iter()
552 .map(|(record, score)| {
553 let id = record.id;
554 let mut scored = ScoredMemory::from((record, score));
555 if let Some(breakdown) = breakdowns.remove(&id) {
556 scored.score_breakdown = Some(breakdown);
557 }
558 scored
559 })
560 .collect();
561
562 let now = chrono::Utc::now().to_rfc3339();
564 let event_content_hash = compute_content_hash(&request.query, &agent_id, &now);
565 let prev_event_hash = match engine.storage.get_latest_event_hash(&agent_id, None).await {
566 Ok(hash) => hash,
567 Err(e) => {
568 tracing::warn!(error = %e, "failed to get latest event hash, starting new chain segment");
569 None
570 }
571 };
572 let event_prev_hash = Some(crate::hash::compute_chain_hash(
573 &event_content_hash,
574 prev_event_hash.as_deref(),
575 ));
576 let mut event = AgentEvent {
577 id: Uuid::now_v7(),
578 agent_id: agent_id.clone(),
579 thread_id: None,
580 run_id: None,
581 parent_event_id: None,
582 event_type: EventType::MemoryRead,
583 payload: serde_json::json!({
584 "query": request.query,
585 "results": total,
586 "strategy": strategy,
587 }),
588 trace_id: None,
589 span_id: None,
590 model: None,
591 tokens_input: None,
592 tokens_output: None,
593 latency_ms: None,
594 cost_usd: None,
595 timestamp: now.clone(),
596 logical_clock: 0,
597 content_hash: event_content_hash,
598 prev_hash: event_prev_hash,
599 embedding: None,
600 };
601 if engine.embed_events
603 && let Ok(emb) = engine.embedding.embed(&event.payload.to_string()).await
604 {
605 event.embedding = Some(emb);
606 }
607 if let Err(e) = engine.storage.insert_event(&event).await {
608 tracing::error!(event_id = %event.id, error = %e, "failed to insert audit event");
609 }
610
611 let provenance = if let (Some(records), Some(signer)) =
616 (provenance_records, engine.provenance_signer.as_ref())
617 {
618 match signer.sign(&agent_id, &request.query, &records) {
619 Ok(p) => Some(p),
620 Err(e) => {
621 tracing::warn!(error = %e, "failed to sign read provenance; degrading to no-provenance response");
622 None
623 }
624 }
625 } else {
626 None
627 };
628
629 Ok(RecallResponse {
630 memories,
631 total,
632 provenance,
633 })
634}
635
636async fn passes_filters(
637 record: &MemoryRecord,
638 request: &RecallRequest,
639 agent_id: &str,
640 engine: &MnemoEngine,
641) -> bool {
642 if request.as_of.is_none() && record.is_deleted() {
644 return false;
645 }
646
647 if let Some(ref expires_at) = record.expires_at
649 && let Ok(exp) = chrono::DateTime::parse_from_rfc3339(expires_at)
650 && exp < chrono::Utc::now()
651 {
652 return false;
653 }
654
655 if record.quarantined {
657 return false;
658 }
659
660 if let Some(ref s) = request.scope
662 && record.scope != *s
663 {
664 return false;
665 }
666
667 if let Some(ref mts) = request.memory_types {
669 if !mts.contains(&record.memory_type) {
670 return false;
671 }
672 } else if let Some(ref mt) = request.memory_type
673 && record.memory_type != *mt
674 {
675 return false;
676 }
677
678 if let Some(min_imp) = request.min_importance
680 && record.importance < min_imp
681 {
682 return false;
683 }
684
685 if let Some(ref req_tags) = request.tags
687 && !req_tags.iter().any(|t| record.tags.contains(t))
688 {
689 return false;
690 }
691
692 if let Some(ref tr) = request.temporal_range {
694 if let Some(ref after) = tr.after
695 && let (Ok(after_dt), Ok(record_dt)) = (
696 chrono::DateTime::parse_from_rfc3339(after),
697 chrono::DateTime::parse_from_rfc3339(&record.created_at),
698 )
699 && record_dt < after_dt
700 {
701 return false;
702 }
703 if let Some(ref before) = tr.before
704 && let (Ok(before_dt), Ok(record_dt)) = (
705 chrono::DateTime::parse_from_rfc3339(before),
706 chrono::DateTime::parse_from_rfc3339(&record.created_at),
707 )
708 && record_dt > before_dt
709 {
710 return false;
711 }
712 }
713
714 if let Some(ref as_of) = request.as_of {
716 if let (Ok(as_of_dt), Ok(record_dt)) = (
717 chrono::DateTime::parse_from_rfc3339(as_of),
718 chrono::DateTime::parse_from_rfc3339(&record.created_at),
719 ) && record_dt > as_of_dt
720 {
721 return false;
723 }
724 if let Some(ref deleted_at) = record.deleted_at
726 && let (Ok(del_dt), Ok(as_of_dt)) = (
727 chrono::DateTime::parse_from_rfc3339(deleted_at),
728 chrono::DateTime::parse_from_rfc3339(as_of),
729 )
730 && del_dt <= as_of_dt
731 {
732 return false;
733 }
734 }
735
736 match record.scope {
738 Scope::Public | Scope::Global => true,
739 Scope::Shared => {
740 record.agent_id == agent_id
741 || engine
742 .storage
743 .check_permission(
744 record.id,
745 agent_id,
746 crate::model::acl::Permission::Read,
747 )
748 .await
749 .unwrap_or_else(|e| {
750 tracing::warn!(memory_id = %record.id, error = %e, "permission check failed, denying access");
751 false
752 })
753 }
754 Scope::Private => record.agent_id == agent_id,
755 }
756}