1use std::collections::HashSet;
2
3use serde::{Deserialize, Serialize};
4use uuid::Uuid;
5
6use crate::error::Result;
7use crate::hash::compute_content_hash;
8use crate::model::event::{AgentEvent, EventType};
9use crate::model::memory::{MemoryRecord, MemoryType, Scope};
10use crate::query::MnemoEngine;
11use crate::storage::MemoryFilter;
12#[allow(unused_imports)]
13use base64::Engine as _;
14
/// Optional creation-time window applied to recall candidates.
///
/// Both bounds are RFC 3339 strings that the recall filter compares against a
/// record's `created_at`. A bound that is absent, or that fails to parse, is
/// skipped rather than rejecting the record.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TemporalRange {
    /// Lower bound: records created strictly before this instant are dropped.
    pub after: Option<String>,
    /// Upper bound: records created strictly after this instant are dropped.
    pub before: Option<String>,
}
20
21impl TemporalRange {
22 pub fn new() -> Self {
23 Self::default()
24 }
25}
26
/// Parameters for a recall (memory retrieval) query.
///
/// Only `query` is required; every other field is an optional filter or
/// tuning knob whose default is applied by `execute`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RecallRequest {
    /// Free-text query; embedded for vector search and passed to full-text search.
    pub query: String,
    /// Agent performing the recall; falls back to the engine's default agent id.
    pub agent_id: Option<String>,
    /// Maximum number of results; defaults to 10 and is capped at 100.
    pub limit: Option<usize>,
    /// Single memory type to match. Ignored by the post-retrieval filter when
    /// `memory_types` is set.
    pub memory_type: Option<MemoryType>,
    /// Set of acceptable memory types; takes precedence over `memory_type`.
    pub memory_types: Option<Vec<MemoryType>>,
    /// When set, only records whose scope equals this value are returned.
    pub scope: Option<Scope>,
    /// Records with an importance below this value are dropped.
    pub min_importance: Option<f32>,
    /// A record matches when it carries at least one of these tags.
    pub tags: Option<Vec<String>>,
    /// Forwarded to the storage filter by the `"exact"` strategy; the other
    /// strategies in this file do not consult it.
    pub org_id: Option<String>,
    /// Retrieval strategy: `"lexical"`, `"semantic"`, `"graph"` or `"exact"`.
    /// `None` or any other value selects the hybrid (fused) strategy.
    pub strategy: Option<String>,
    /// Optional creation-time window (see [`TemporalRange`]).
    pub temporal_range: Option<TemporalRange>,
    /// Half-life, in hours, handed to the recency scorer in the hybrid
    /// strategy; defaults to 168.0.
    pub recency_half_life_hours: Option<f64>,
    /// Weights handed to weighted reciprocal rank fusion in the graph and
    /// hybrid strategies; when `None`, unweighted fusion is used.
    pub hybrid_weights: Option<Vec<f32>>,
    /// Reciprocal-rank-fusion `k` constant; defaults to 60.0.
    pub rrf_k: Option<f32>,
    /// RFC 3339 point in time. Records created after it, or deleted at or
    /// before it, are excluded; soft-deleted records are otherwise eligible
    /// when this is set.
    pub as_of: Option<String>,
    /// When `Some(true)`, hybrid-strategy results carry a [`ScoreBreakdown`].
    pub explain: Option<bool>,
}
49
50impl RecallRequest {
51 pub fn new(query: String) -> Self {
52 Self {
53 query,
54 agent_id: None,
55 limit: None,
56 memory_type: None,
57 memory_types: None,
58 scope: None,
59 min_importance: None,
60 tags: None,
61 org_id: None,
62 strategy: None,
63 temporal_range: None,
64 recency_half_life_hours: None,
65 hybrid_weights: None,
66 rrf_k: None,
67 as_of: None,
68 explain: None,
69 }
70 }
71}
72
/// Per-signal scores behind a fused hybrid result, attached when the request
/// sets `explain`. A signal that did not rank the memory is reported as 0.0.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ScoreBreakdown {
    /// Vector similarity, computed as `1.0 - distance` from the index search.
    pub vector: f32,
    /// Score returned by the full-text search.
    pub bm25: f32,
    /// Graph proximity: 1.0 for seed memories, halved for each relation hop.
    pub graph: f32,
    /// Recency score derived from the record's `created_at`.
    pub recency: f32,
    /// Zero-based position of the memory in the fused ranking.
    pub rrf_rank: u32,
}
87
/// Result of a recall query.
#[non_exhaustive]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RecallResponse {
    /// Matching memories, best score first.
    pub memories: Vec<ScoredMemory>,
    /// Result count; `execute` sets this to the number of returned memories
    /// (i.e. after the limit has been applied).
    pub total: usize,
}
94
95impl RecallResponse {
96 pub fn new(memories: Vec<ScoredMemory>, total: usize) -> Self {
97 Self { memories, total }
98 }
99}
100
/// A memory returned by recall together with its relevance score.
///
/// All fields other than `score` and `score_breakdown` are copied from the
/// underlying `MemoryRecord`.
#[non_exhaustive]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoredMemory {
    /// Identifier of the memory record.
    pub id: Uuid,
    /// Memory content (decrypted by `execute` when encryption is configured).
    pub content: String,
    /// Agent that owns the record.
    pub agent_id: String,
    /// Type of the memory.
    pub memory_type: MemoryType,
    /// Visibility scope of the memory.
    pub scope: Scope,
    /// Importance value stored on the record.
    pub importance: f32,
    /// Tags attached to the record.
    pub tags: Vec<String>,
    /// Arbitrary JSON metadata stored on the record.
    pub metadata: serde_json::Value,
    /// Relevance score assigned by the retrieval strategy; higher is better.
    pub score: f32,
    /// Access counter stored on the record.
    pub access_count: u64,
    /// Creation timestamp as stored on the record.
    pub created_at: String,
    /// Last-update timestamp as stored on the record.
    pub updated_at: String,
    /// Per-signal scores; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub score_breakdown: Option<ScoreBreakdown>,
}
119
120impl From<(MemoryRecord, f32)> for ScoredMemory {
121 fn from((record, score): (MemoryRecord, f32)) -> Self {
122 Self {
123 id: record.id,
124 content: record.content,
125 agent_id: record.agent_id,
126 memory_type: record.memory_type,
127 scope: record.scope,
128 importance: record.importance,
129 tags: record.tags,
130 metadata: record.metadata,
131 score,
132 access_count: record.access_count,
133 created_at: record.created_at,
134 updated_at: record.updated_at,
135 score_breakdown: None,
136 }
137 }
138}
139
140async fn get_memory_cached(engine: &MnemoEngine, id: Uuid) -> Result<Option<MemoryRecord>> {
142 if let Some(ref cache) = engine.cache
143 && let Some(record) = cache.get(id)
144 {
145 return Ok(Some(record));
146 }
147 let result = engine.storage.get_memory(id).await?;
148 if let Some(ref record) = result
149 && let Some(ref cache) = engine.cache
150 {
151 cache.put(record.clone());
152 }
153 Ok(result)
154}
155
156pub async fn execute(engine: &MnemoEngine, request: RecallRequest) -> Result<RecallResponse> {
157 let limit = request.limit.unwrap_or(10).min(100);
158 let agent_id = request
159 .agent_id
160 .clone()
161 .unwrap_or_else(|| engine.default_agent_id.clone());
162 super::validate_agent_id(&agent_id)?;
163
164 let strategy = request.strategy.as_deref().unwrap_or("auto");
166
167 let query_embedding = engine.embedding.embed(&request.query).await?;
169
170 let accessible_ids: HashSet<Uuid> = engine
172 .storage
173 .list_accessible_memory_ids(&agent_id, super::MAX_BATCH_QUERY_LIMIT)
174 .await?
175 .into_iter()
176 .collect();
177 let perm_filter = |id: Uuid| accessible_ids.contains(&id);
178
179 let mut scored_memories: Vec<(MemoryRecord, f32)> = Vec::new();
180 let mut breakdowns: std::collections::HashMap<Uuid, ScoreBreakdown> =
181 std::collections::HashMap::new();
182
183 match strategy {
184 "lexical" => {
185 if let Some(ref ft) = engine.full_text {
187 let bm25_results = ft.search(&request.query, limit * 3)?;
188 for (id, score) in bm25_results {
189 if let Some(record) = get_memory_cached(engine, id).await?
190 && passes_filters(&record, &request, &agent_id, engine).await
191 {
192 scored_memories.push((record, score));
193 }
194 }
195 }
196 }
197 "semantic" => {
198 let search_results =
200 engine
201 .index
202 .filtered_search(&query_embedding, limit * 3, &perm_filter)?;
203 for (id, distance) in search_results {
204 if let Some(record) = get_memory_cached(engine, id).await?
205 && passes_filters(&record, &request, &agent_id, engine).await
206 {
207 let score = 1.0 - distance;
208 scored_memories.push((record, score));
209 }
210 }
211 }
212 "graph" => {
213 let search_results =
215 engine
216 .index
217 .filtered_search(&query_embedding, limit * 3, &perm_filter)?;
218 let mut seeds: Vec<(Uuid, f32)> = Vec::new();
219 for (id, distance) in &search_results {
220 if let Some(record) = get_memory_cached(engine, *id).await?
221 && passes_filters(&record, &request, &agent_id, engine).await
222 {
223 seeds.push((*id, 1.0 - distance));
224 }
225 }
226
227 let max_hops = 2;
229 let mut seen: HashSet<Uuid> = seeds.iter().map(|(id, _)| *id).collect();
230 let mut graph_ranked: Vec<(Uuid, f32)> = Vec::new();
231
232 for &(id, _) in &seeds {
234 graph_ranked.push((id, 1.0));
235 }
236
237 let mut frontier: Vec<Uuid> = seeds.iter().map(|(id, _)| *id).collect();
239 let mut decay = 0.5_f32;
240 for _hop in 0..max_hops {
241 let mut next_frontier: Vec<Uuid> = Vec::new();
242 for &id in &frontier {
243 let from_rels = engine.storage.get_relations_from(id).await?;
244 let to_rels = engine.storage.get_relations_to(id).await?;
245 for rel in from_rels.iter().chain(to_rels.iter()) {
246 let related_id = if rel.source_id == id {
247 rel.target_id
248 } else {
249 rel.source_id
250 };
251 if seen.insert(related_id)
252 && let Some(record) = get_memory_cached(engine, related_id).await?
253 && passes_filters(&record, &request, &agent_id, engine).await
254 {
255 graph_ranked.push((related_id, decay));
256 next_frontier.push(related_id);
257 }
258 }
259 }
260 frontier = next_frontier;
261 decay *= 0.5;
262 }
263
264 let mut v_sorted: Vec<(Uuid, f32)> = seeds.clone();
266 v_sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
267 graph_ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
268
269 let ranked_lists = vec![v_sorted, graph_ranked];
270 let rrf_k = request.rrf_k.unwrap_or(60.0);
271 let fused = if let Some(ref weights) = request.hybrid_weights {
272 crate::query::retrieval::weighted_reciprocal_rank_fusion(
273 &ranked_lists,
274 rrf_k,
275 weights,
276 )
277 } else {
278 crate::query::retrieval::reciprocal_rank_fusion(&ranked_lists, rrf_k)
279 };
280
281 for (id, score) in fused {
282 if let Some(record) = get_memory_cached(engine, id).await?
283 && passes_filters(&record, &request, &agent_id, engine).await
284 {
285 scored_memories.push((record, score));
286 }
287 }
288 }
289 "exact" => {
290 let filter = MemoryFilter {
293 agent_id: Some(agent_id.clone()),
294 memory_type: request.memory_type,
295 scope: request.scope,
296 tags: request.tags.clone(),
297 min_importance: request.min_importance,
298 org_id: request.org_id.clone(),
299 thread_id: None,
300 include_deleted: request.as_of.is_some(),
301 };
302 let memories = engine.storage.list_memories(&filter, limit, 0).await?;
303 for record in memories {
304 if passes_filters(&record, &request, &agent_id, engine).await {
305 scored_memories.push((record, 1.0));
306 }
307 }
308 }
309 _ => {
310 let vector_results =
312 engine
313 .index
314 .filtered_search(&query_embedding, limit * 3, &perm_filter)?;
315 let mut vector_ranked: Vec<(Uuid, f32)> = Vec::new();
316 for (id, distance) in vector_results {
317 vector_ranked.push((id, 1.0 - distance));
318 }
319
320 if let Some(ref ft) = engine.full_text {
321 let bm25_results = ft.search(&request.query, limit * 3)?;
323
324 let mut recency_ranked: Vec<(Uuid, f32)> = Vec::new();
326 for &(id, _) in &vector_ranked {
327 if let Some(record) = get_memory_cached(engine, id).await? {
328 let r_score = crate::query::retrieval::recency_score(
329 &record.created_at,
330 request.recency_half_life_hours.unwrap_or(168.0),
331 );
332 recency_ranked.push((id, r_score));
333 }
334 }
335 for &(id, _) in &bm25_results {
337 if !recency_ranked.iter().any(|(rid, _)| *rid == id)
338 && let Some(record) = get_memory_cached(engine, id).await?
339 {
340 let r_score = crate::query::retrieval::recency_score(
341 &record.created_at,
342 request.recency_half_life_hours.unwrap_or(168.0),
343 );
344 recency_ranked.push((id, r_score));
345 }
346 }
347
348 let mut v_sorted = vector_ranked.clone();
350 v_sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
351 let mut b_sorted = bm25_results;
352 b_sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
353 recency_ranked
354 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
355
356 let max_hops = 2;
358 let mut graph_ranked: Vec<(Uuid, f32)> = Vec::new();
359 let top_seeds: Vec<Uuid> =
360 vector_ranked.iter().take(10).map(|(id, _)| *id).collect();
361 let mut graph_seen: HashSet<Uuid> = top_seeds.iter().copied().collect();
362 for &seed_id in &top_seeds {
363 graph_ranked.push((seed_id, 1.0));
364 }
365 let mut frontier: Vec<Uuid> = top_seeds;
366 let mut decay = 0.5_f32;
367 for _hop in 0..max_hops {
368 let mut next_frontier: Vec<Uuid> = Vec::new();
369 for &fid in &frontier {
370 match engine.storage.get_relations_from(fid).await {
371 Ok(from_rels) => {
372 for rel in &from_rels {
373 if graph_seen.insert(rel.target_id) {
374 graph_ranked.push((rel.target_id, decay));
375 next_frontier.push(rel.target_id);
376 }
377 }
378 }
379 Err(e) => {
380 tracing::warn!(memory_id = %fid, error = %e, "graph expansion: failed to get outgoing relations");
381 }
382 }
383 match engine.storage.get_relations_to(fid).await {
384 Ok(to_rels) => {
385 for rel in &to_rels {
386 if graph_seen.insert(rel.source_id) {
387 graph_ranked.push((rel.source_id, decay));
388 next_frontier.push(rel.source_id);
389 }
390 }
391 }
392 Err(e) => {
393 tracing::warn!(memory_id = %fid, error = %e, "graph expansion: failed to get incoming relations");
394 }
395 }
396 }
397 frontier = next_frontier;
398 decay *= 0.5;
399 }
400 graph_ranked
401 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
402
403 let explain = request.explain.unwrap_or(false);
407 type SignalMap = std::collections::HashMap<Uuid, f32>;
408 let (vector_map, bm25_map, recency_map, graph_map): (
409 SignalMap,
410 SignalMap,
411 SignalMap,
412 SignalMap,
413 ) = if explain {
414 (
415 v_sorted.iter().copied().collect(),
416 b_sorted.iter().copied().collect(),
417 recency_ranked.iter().copied().collect(),
418 graph_ranked.iter().copied().collect(),
419 )
420 } else {
421 Default::default()
422 };
423
424 let ranked_lists = vec![v_sorted, b_sorted, recency_ranked, graph_ranked];
425 let rrf_k = request.rrf_k.unwrap_or(60.0);
426 let fused = if let Some(ref weights) = request.hybrid_weights {
427 crate::query::retrieval::weighted_reciprocal_rank_fusion(
428 &ranked_lists,
429 rrf_k,
430 weights,
431 )
432 } else {
433 crate::query::retrieval::reciprocal_rank_fusion(&ranked_lists, rrf_k)
434 };
435
436 for (rank, (id, score)) in fused.into_iter().enumerate() {
437 if let Some(record) = get_memory_cached(engine, id).await?
438 && passes_filters(&record, &request, &agent_id, engine).await
439 {
440 scored_memories.push((record, score));
441 if explain {
442 breakdowns.insert(
443 id,
444 ScoreBreakdown {
445 vector: vector_map.get(&id).copied().unwrap_or(0.0),
446 bm25: bm25_map.get(&id).copied().unwrap_or(0.0),
447 graph: graph_map.get(&id).copied().unwrap_or(0.0),
448 recency: recency_map.get(&id).copied().unwrap_or(0.0),
449 rrf_rank: rank as u32,
450 },
451 );
452 }
453 }
454 }
455 } else {
456 for (id, score) in vector_ranked {
458 if let Some(record) = get_memory_cached(engine, id).await?
459 && passes_filters(&record, &request, &agent_id, engine).await
460 {
461 scored_memories.push((record, score));
462 }
463 }
464 }
465 }
466 }
467
468 scored_memories.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
470 scored_memories.truncate(limit);
471
472 let total = scored_memories.len();
473
474 for (record, _) in &scored_memories {
476 if let Err(e) = engine.storage.touch_memory(record.id).await {
477 tracing::warn!(memory_id = %record.id, error = %e, "failed to update access timestamp");
478 }
479 }
480
481 if let Some(ref enc) = engine.encryption {
483 for (record, _) in &mut scored_memories {
484 match base64::engine::general_purpose::STANDARD.decode(&record.content) {
485 Ok(encrypted_bytes) => match enc.decrypt(&encrypted_bytes) {
486 Ok(decrypted) => match String::from_utf8(decrypted) {
487 Ok(plaintext) => record.content = plaintext,
488 Err(e) => {
489 tracing::error!(memory_id = %record.id, error = %e, "decrypted content is not valid UTF-8");
490 record.content = "[content unavailable: decryption error]".to_string();
491 }
492 },
493 Err(e) => {
494 tracing::error!(memory_id = %record.id, error = %e, "failed to decrypt memory content");
495 record.content = "[content unavailable: decryption error]".to_string();
496 }
497 },
498 Err(e) => {
499 tracing::error!(memory_id = %record.id, error = %e, "failed to decode encrypted content");
500 record.content = "[content unavailable: decryption error]".to_string();
501 }
502 }
503 }
504 }
505
506 let memories: Vec<ScoredMemory> = scored_memories
507 .into_iter()
508 .map(|(record, score)| {
509 let id = record.id;
510 let mut scored = ScoredMemory::from((record, score));
511 if let Some(breakdown) = breakdowns.remove(&id) {
512 scored.score_breakdown = Some(breakdown);
513 }
514 scored
515 })
516 .collect();
517
518 let now = chrono::Utc::now().to_rfc3339();
520 let event_content_hash = compute_content_hash(&request.query, &agent_id, &now);
521 let prev_event_hash = match engine.storage.get_latest_event_hash(&agent_id, None).await {
522 Ok(hash) => hash,
523 Err(e) => {
524 tracing::warn!(error = %e, "failed to get latest event hash, starting new chain segment");
525 None
526 }
527 };
528 let event_prev_hash = Some(crate::hash::compute_chain_hash(
529 &event_content_hash,
530 prev_event_hash.as_deref(),
531 ));
532 let mut event = AgentEvent {
533 id: Uuid::now_v7(),
534 agent_id: agent_id.clone(),
535 thread_id: None,
536 run_id: None,
537 parent_event_id: None,
538 event_type: EventType::MemoryRead,
539 payload: serde_json::json!({
540 "query": request.query,
541 "results": total,
542 "strategy": strategy,
543 }),
544 trace_id: None,
545 span_id: None,
546 model: None,
547 tokens_input: None,
548 tokens_output: None,
549 latency_ms: None,
550 cost_usd: None,
551 timestamp: now.clone(),
552 logical_clock: 0,
553 content_hash: event_content_hash,
554 prev_hash: event_prev_hash,
555 embedding: None,
556 };
557 if engine.embed_events
559 && let Ok(emb) = engine.embedding.embed(&event.payload.to_string()).await
560 {
561 event.embedding = Some(emb);
562 }
563 if let Err(e) = engine.storage.insert_event(&event).await {
564 tracing::error!(event_id = %event.id, error = %e, "failed to insert audit event");
565 }
566
567 Ok(RecallResponse { memories, total })
568}
569
570async fn passes_filters(
571 record: &MemoryRecord,
572 request: &RecallRequest,
573 agent_id: &str,
574 engine: &MnemoEngine,
575) -> bool {
576 if request.as_of.is_none() && record.is_deleted() {
578 return false;
579 }
580
581 if let Some(ref expires_at) = record.expires_at
583 && let Ok(exp) = chrono::DateTime::parse_from_rfc3339(expires_at)
584 && exp < chrono::Utc::now()
585 {
586 return false;
587 }
588
589 if record.quarantined {
591 return false;
592 }
593
594 if let Some(ref s) = request.scope
596 && record.scope != *s
597 {
598 return false;
599 }
600
601 if let Some(ref mts) = request.memory_types {
603 if !mts.contains(&record.memory_type) {
604 return false;
605 }
606 } else if let Some(ref mt) = request.memory_type
607 && record.memory_type != *mt
608 {
609 return false;
610 }
611
612 if let Some(min_imp) = request.min_importance
614 && record.importance < min_imp
615 {
616 return false;
617 }
618
619 if let Some(ref req_tags) = request.tags
621 && !req_tags.iter().any(|t| record.tags.contains(t))
622 {
623 return false;
624 }
625
626 if let Some(ref tr) = request.temporal_range {
628 if let Some(ref after) = tr.after
629 && let (Ok(after_dt), Ok(record_dt)) = (
630 chrono::DateTime::parse_from_rfc3339(after),
631 chrono::DateTime::parse_from_rfc3339(&record.created_at),
632 )
633 && record_dt < after_dt
634 {
635 return false;
636 }
637 if let Some(ref before) = tr.before
638 && let (Ok(before_dt), Ok(record_dt)) = (
639 chrono::DateTime::parse_from_rfc3339(before),
640 chrono::DateTime::parse_from_rfc3339(&record.created_at),
641 )
642 && record_dt > before_dt
643 {
644 return false;
645 }
646 }
647
648 if let Some(ref as_of) = request.as_of {
650 if let (Ok(as_of_dt), Ok(record_dt)) = (
651 chrono::DateTime::parse_from_rfc3339(as_of),
652 chrono::DateTime::parse_from_rfc3339(&record.created_at),
653 ) && record_dt > as_of_dt
654 {
655 return false;
657 }
658 if let Some(ref deleted_at) = record.deleted_at
660 && let (Ok(del_dt), Ok(as_of_dt)) = (
661 chrono::DateTime::parse_from_rfc3339(deleted_at),
662 chrono::DateTime::parse_from_rfc3339(as_of),
663 )
664 && del_dt <= as_of_dt
665 {
666 return false;
667 }
668 }
669
670 match record.scope {
672 Scope::Public | Scope::Global => true,
673 Scope::Shared => {
674 record.agent_id == agent_id
675 || engine
676 .storage
677 .check_permission(
678 record.id,
679 agent_id,
680 crate::model::acl::Permission::Read,
681 )
682 .await
683 .unwrap_or_else(|e| {
684 tracing::warn!(memory_id = %record.id, error = %e, "permission check failed, denying access");
685 false
686 })
687 }
688 Scope::Private => record.agent_id == agent_id,
689 }
690}