1use std::collections::HashMap;
2
3use chrono::Utc;
4
5use crate::embed::{cosine_similarity, Embedder};
6use crate::memory::{MemoryKind, MemoryRecord};
7use crate::store::MemoryStore;
8
/// Minimum cosine similarity (exclusive) between a stored embedding and the query
/// embedding for a memory to enter the vector ranking at all.
const VECTOR_SIMILARITY_THRESHOLD: f32 = 0.3;

/// Smoothing constant `k` of reciprocal rank fusion: a hit at 1-based rank `r` in
/// either ranking contributes `1 / (k + r)`. Larger values flatten the difference
/// between top-ranked and lower-ranked hits.
const RRF_K: f64 = 60.0;
14
/// One recalled memory together with its final ranking score.
#[derive(Debug, Clone, serde::Serialize)]
pub struct RecallResult {
    /// The recalled memory. When produced by the recall functions in this module it is
    /// a snapshot cloned before the recall reinforces the memory in the store, so
    /// `strength`/`access_count` here are the pre-recall values.
    pub memory: MemoryRecord,
    /// Final ranking score (higher is better). Not normalised: only meaningful relative
    /// to other results of the same recall call.
    pub score: f64,
}
21
/// Per-day exponential decay rate of a fact's strength while it is not accessed.
/// Divided by `1 + importance` before use, so more important memories decay slower.
const FACT_DECAY_LAMBDA_PER_DAY: f64 = 0.02;
/// Per-day decay rate for episodes (same importance damping); three times the fact
/// rate, so episodic memories fade faster than facts.
const EPISODE_DECAY_LAMBDA_PER_DAY: f64 = 0.06;

/// Strength added to a fact's decayed strength each time a recall returns it
/// (the sum is capped at 1.0).
const FACT_TOUCH_BOOST: f64 = 0.10;
/// Strength added to a recalled episode's decayed strength (also capped at 1.0).
const EPISODE_TOUCH_BOOST: f64 = 0.20;

/// Fused candidates re-ranked per requested result (`limit * CANDIDATE_MULTIPLIER`).
const CANDIDATE_MULTIPLIER: usize = 10;
/// Lower bound on the number of fused candidates re-ranked, whatever `limit` is.
const MIN_CANDIDATES: usize = 50;

/// Fraction of a fact's score passed to every other recalled fact sharing an entity.
const SPREAD_FACTOR: f64 = 0.15;

/// Age since creation (in hours; 168 h = one week) at which the recency multiplier
/// falls to 0.5, before the floor is applied.
const RECENCY_HALF_LIFE_HOURS: f64 = 168.0;

/// Lower bound of the recency multiplier: old memories are damped, never suppressed.
const RECENCY_FLOOR: f64 = 0.3;
46
47pub fn recall(
61 store: &MemoryStore,
62 query: &str,
63 embedder: &dyn Embedder,
64 limit: usize,
65) -> Result<Vec<RecallResult>, RecallError> {
66 recall_with_tag_filter(store, query, embedder, limit, None)
67}
68
69pub fn recall_with_tag_filter(
72 store: &MemoryStore,
73 query: &str,
74 embedder: &dyn Embedder,
75 limit: usize,
76 tag_filter: Option<&str>,
77) -> Result<Vec<RecallResult>, RecallError> {
78 recall_with_tag_filter_ns(store, query, embedder, limit, tag_filter, "default")
79}
80
81pub fn recall_with_tag_filter_ns(
82 store: &MemoryStore,
83 query: &str,
84 embedder: &dyn Embedder,
85 limit: usize,
86 tag_filter: Option<&str>,
87 namespace: &str,
88) -> Result<Vec<RecallResult>, RecallError> {
89 let mut all_memories = store.all_memories_with_text_ns(namespace).map_err(RecallError::Db)?;
90
91 if let Some(tag) = tag_filter {
93 let tag_lower = tag.to_lowercase();
94 all_memories.retain(|(mem, _)| {
95 mem.tags.iter().any(|t| t.to_lowercase() == tag_lower)
96 });
97 }
98
99 if all_memories.is_empty() {
100 return Ok(vec![]);
101 }
102
103 let now = Utc::now();
104
105 let max_access = all_memories
107 .iter()
108 .map(|(m, _)| m.access_count)
109 .max()
110 .unwrap_or(0);
111
112 let bm25_ranked = bm25_search(query, &all_memories);
114
115 let query_embedding = embedder
117 .embed_one(query)
118 .map_err(|e| RecallError::Embedding(e.to_string()))?;
119 let vector_ranked = vector_search(&query_embedding, &all_memories);
120
121 let fused = rrf(&bm25_ranked, &vector_ranked);
123
124 let candidate_count = (limit.saturating_mul(CANDIDATE_MULTIPLIER)).max(MIN_CANDIDATES);
127 let candidates = fused.into_iter().take(candidate_count);
128
129 let mut results: Vec<RecallResult> = candidates
131 .map(|(idx, rrf_score)| {
132 let mem = &all_memories[idx].0;
133 let decayed_strength = effective_strength(mem, now);
134 let recency = recency_boost(mem, now);
135 let access = access_weight(mem, max_access);
136 RecallResult {
137 memory: mem.clone(),
138 score: rrf_score * decayed_strength * recency * access,
139 }
140 })
141 .collect();
142
143 spread_activation(&mut results, SPREAD_FACTOR);
147
148 temporal_cooccurrence_boost(&mut results);
153
154 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
155 results.truncate(limit);
156
157 for result in &results {
159 let mem = &result.memory;
160 let decayed = effective_strength(mem, now);
161 let boosted = (decayed + touch_boost(mem)).min(1.0);
162 store
163 .touch_memory_with_strength(mem.id, boosted, now)
164 .map_err(RecallError::Db)?;
165 }
166
167 Ok(results)
168}
169
170fn recency_boost(mem: &MemoryRecord, now: chrono::DateTime<Utc>) -> f64 {
176 let hours_ago = (now - mem.created_at).num_seconds().max(0) as f64 / 3600.0;
177 let raw = 1.0 / (1.0 + (hours_ago / RECENCY_HALF_LIFE_HOURS).powf(0.8));
178 raw.max(RECENCY_FLOOR)
179}
180
181fn access_weight(mem: &MemoryRecord, max_access: i64) -> f64 {
187 if max_access <= 0 {
188 return 1.0;
189 }
190 let norm = (mem.access_count as f64 + 1.0).log2() / (max_access as f64 + 1.0).log2();
191 1.0 + norm }
193
194fn spread_activation(results: &mut Vec<RecallResult>, factor: f64) {
201 let mut entity_index: HashMap<String, Vec<usize>> = HashMap::new();
203 for (i, r) in results.iter().enumerate() {
204 if let MemoryKind::Fact(f) = &r.memory.kind {
205 let subj = f.subject.to_lowercase();
206 let obj = f.object.to_lowercase();
207 entity_index.entry(subj).or_default().push(i);
208 entity_index.entry(obj).or_default().push(i);
209 }
210 }
211
212 let mut boosts: HashMap<usize, f64> = HashMap::new();
214 for (i, r) in results.iter().enumerate() {
215 if let MemoryKind::Fact(f) = &r.memory.kind {
216 let entities = [f.subject.to_lowercase(), f.object.to_lowercase()];
217 for entity in &entities {
218 if let Some(neighbors) = entity_index.get(entity) {
219 for &ni in neighbors {
220 if ni != i {
221 *boosts.entry(ni).or_insert(0.0) += r.score * factor;
222 }
223 }
224 }
225 }
226 }
227 }
228
229 for (idx, boost) in boosts {
231 if idx < results.len() {
232 results[idx].score += boost;
233 }
234 }
235}
236
237fn temporal_cooccurrence_boost(results: &mut Vec<RecallResult>) {
241 if results.len() < 2 {
242 return;
243 }
244
245 let mut sorted_indices: Vec<usize> = (0..results.len()).collect();
247 sorted_indices.sort_by(|&a, &b| {
248 results[b]
249 .score
250 .partial_cmp(&results[a].score)
251 .unwrap_or(std::cmp::Ordering::Equal)
252 });
253 let anchor_count = sorted_indices.len().min(5);
254 let anchors: Vec<(usize, f64, chrono::DateTime<Utc>)> = sorted_indices[..anchor_count]
255 .iter()
256 .map(|&i| (i, results[i].score, results[i].memory.created_at))
257 .collect();
258
259 let mut boosts: HashMap<usize, f64> = HashMap::new();
260 for (ai, a_score, a_time) in &anchors {
261 for (j, r) in results.iter().enumerate() {
262 if j == *ai {
263 continue;
264 }
265 let gap_minutes = (*a_time - r.memory.created_at)
266 .num_minutes()
267 .unsigned_abs() as f64;
268 if gap_minutes < 30.0 {
269 let proximity = 0.1 * (1.0 - gap_minutes / 30.0);
270 *boosts.entry(j).or_insert(0.0) += a_score * proximity;
271 }
272 }
273 }
274
275 for (idx, boost) in boosts {
276 if idx < results.len() {
277 results[idx].score += boost;
278 }
279 }
280}
281
282fn kind_decay_lambda_per_day(mem: &MemoryRecord) -> f64 {
283 match &mem.kind {
284 MemoryKind::Fact(_) => FACT_DECAY_LAMBDA_PER_DAY,
285 MemoryKind::Episode(_) => EPISODE_DECAY_LAMBDA_PER_DAY,
286 }
287}
288
289fn touch_boost(mem: &MemoryRecord) -> f64 {
290 match &mem.kind {
291 MemoryKind::Fact(_) => FACT_TOUCH_BOOST,
292 MemoryKind::Episode(_) => EPISODE_TOUCH_BOOST,
293 }
294}
295
296fn effective_strength(mem: &MemoryRecord, now: chrono::DateTime<Utc>) -> f64 {
297 let elapsed_secs = (now - mem.last_accessed_at).num_seconds().max(0) as f64;
298 let elapsed_days = elapsed_secs / 86_400.0;
299 let lambda = kind_decay_lambda_per_day(mem);
300 let effective_lambda = lambda / (1.0 + mem.importance);
303 (mem.strength * (-effective_lambda * elapsed_days).exp()).clamp(0.0, 1.0)
304}
305
306fn bm25_search(query: &str, memories: &[(MemoryRecord, String)]) -> Vec<(usize, f32)> {
307 use bm25::{Document, Language, SearchEngineBuilder};
308
309 let documents: Vec<Document<usize>> = memories
310 .iter()
311 .enumerate()
312 .map(|(i, (_, text))| Document {
313 id: i,
314 contents: text.clone(),
315 })
316 .collect();
317
318 let engine: bm25::SearchEngine<usize> =
319 SearchEngineBuilder::with_documents(Language::English, documents)
320 .b(0.5)
321 .build();
322
323 engine
324 .search(query, memories.len())
325 .into_iter()
326 .map(|r| (r.document.id, r.score))
327 .collect()
328}
329
330fn vector_search(query_emb: &[f32], memories: &[(MemoryRecord, String)]) -> Vec<(usize, f32)> {
331 let mut scored: Vec<(usize, f32)> = memories
332 .iter()
333 .enumerate()
334 .filter_map(|(i, (mem, _))| {
335 let emb = mem.embedding.as_ref()?;
336 let sim = cosine_similarity(query_emb, emb);
337 if sim > VECTOR_SIMILARITY_THRESHOLD {
338 Some((i, sim))
339 } else {
340 None
341 }
342 })
343 .collect();
344
345 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
346 scored
347}
348
349fn rrf(list_a: &[(usize, f32)], list_b: &[(usize, f32)]) -> Vec<(usize, f64)> {
350 let mut scores: HashMap<usize, f64> = HashMap::new();
351
352 for (rank, &(idx, _)) in list_a.iter().enumerate() {
353 *scores.entry(idx).or_insert(0.0) += 1.0 / (RRF_K + rank as f64 + 1.0);
354 }
355 for (rank, &(idx, _)) in list_b.iter().enumerate() {
356 *scores.entry(idx).or_insert(0.0) += 1.0 / (RRF_K + rank as f64 + 1.0);
357 }
358
359 let mut results: Vec<(usize, f64)> = scores.into_iter().collect();
360 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
361 results
362}
363
364#[derive(Debug, thiserror::Error)]
365pub enum RecallError {
366 #[error("database error: {0}")]
367 Db(rusqlite::Error),
368 #[error("embedding error: {0}")]
369 Embedding(String),
370}
371
372impl From<rusqlite::Error> for RecallError {
373 fn from(e: rusqlite::Error) -> Self {
374 RecallError::Db(e)
375 }
376}
377
#[cfg(test)]
mod tests {
    use super::*;
    use crate::embed::{EmbedError, Embedding};

    /// Deterministic 2-D embedder: text containing "alpha" maps to `[1.0, 0.0]`,
    /// anything else to `[0.0, 1.0]`. Memories stored with `[1.0, 0.0]` are therefore
    /// a perfect vector match for any query containing "alpha".
    struct MockEmbedder;

    impl Embedder for MockEmbedder {
        fn embed(&self, texts: &[&str]) -> Result<Vec<Embedding>, EmbedError> {
            Ok(texts
                .iter()
                .map(|t| {
                    if t.contains("alpha") {
                        vec![1.0, 0.0]
                    } else {
                        vec![0.0, 1.0]
                    }
                })
                .collect())
        }

        fn dimension(&self) -> usize {
            2
        }
    }

    #[test]
    fn effective_strength_decays_by_kind() {
        let store = MemoryStore::open_in_memory().unwrap();
        let fact_id = store.remember_fact("Jared", "builds", "Gen", Some(&[1.0, 0.0])).unwrap();
        let ep_id = store
            .remember_episode("alpha project context", Some(&[1.0, 0.0]))
            .unwrap();

        // Age both memories identically so only the kind-specific decay rate differs.
        let old_time = (Utc::now() - chrono::Duration::days(10)).to_rfc3339();
        store
            .conn()
            .execute(
                "UPDATE memories SET last_accessed_at = ?1 WHERE id IN (?2, ?3)",
                rusqlite::params![old_time, fact_id, ep_id],
            )
            .unwrap();

        let fact = store.get_memory(fact_id).unwrap().unwrap();
        let episode = store.get_memory(ep_id).unwrap().unwrap();

        let sf = effective_strength(&fact, Utc::now());
        let se = effective_strength(&episode, Utc::now());
        assert!(sf > se, "facts should decay slower than episodes");
    }

    #[test]
    fn recency_boost_favors_recent_over_old() {
        let store = MemoryStore::open_in_memory().unwrap();
        let now = Utc::now();

        // Two memories with identical text and embedding: only age can separate them.
        let recent_id = store
            .remember_episode("alpha project is great", Some(&[1.0, 0.0]))
            .unwrap();
        let old_id = store
            .remember_episode("alpha project is great", Some(&[1.0, 0.0]))
            .unwrap();

        // Backdate both creation and last access of the second copy by 30 days.
        let old_time = (now - chrono::Duration::days(30)).to_rfc3339();
        store
            .conn()
            .execute(
                "UPDATE memories SET created_at = ?1, last_accessed_at = ?1 WHERE id = ?2",
                rusqlite::params![old_time, old_id],
            )
            .unwrap();

        let results = recall(&store, "alpha project", &MockEmbedder, 2).unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].memory.id, recent_id);
        assert!(results[0].score > results[1].score);
    }

    #[test]
    fn recency_boost_has_floor_old_memories_still_appear() {
        let store = MemoryStore::open_in_memory().unwrap();
        let now = Utc::now();

        let id = store
            .remember_episode("alpha ancient knowledge", Some(&[1.0, 0.0]))
            .unwrap();

        // One year old: far past the half-life, so the raw curve is below the floor.
        let ancient_time = (now - chrono::Duration::days(365)).to_rfc3339();
        store
            .conn()
            .execute(
                "UPDATE memories SET created_at = ?1, last_accessed_at = ?1 WHERE id = ?2",
                rusqlite::params![ancient_time, id],
            )
            .unwrap();

        let mem = store.get_memory(id).unwrap().unwrap();
        let boost = recency_boost(&mem, now);
        assert!(boost >= RECENCY_FLOOR, "recency boost {} should be >= floor {}", boost, RECENCY_FLOOR);
    }

    #[test]
    fn access_weight_boosts_frequently_recalled_memories() {
        let store = MemoryStore::open_in_memory().unwrap();

        let hot_id = store
            .remember_episode("alpha hot memory", Some(&[1.0, 0.0]))
            .unwrap();
        let cold_id = store
            .remember_episode("alpha cold memory", Some(&[1.0, 0.0]))
            .unwrap();

        // Make one memory the most-accessed in the corpus (max_access = 20 below).
        store
            .conn()
            .execute(
                "UPDATE memories SET access_count = 20 WHERE id = ?1",
                rusqlite::params![hot_id],
            )
            .unwrap();

        let hot = store.get_memory(hot_id).unwrap().unwrap();
        let cold = store.get_memory(cold_id).unwrap().unwrap();

        let hot_w = access_weight(&hot, 20);
        let cold_w = access_weight(&cold, 20);

        assert!(hot_w > cold_w, "hot ({}) should weigh more than cold ({})", hot_w, cold_w);
        assert!(hot_w >= 1.0 && hot_w <= 2.0, "access weight should be in [1.0, 2.0], got {}", hot_w);
        assert!(cold_w >= 1.0, "cold access weight should be >= 1.0, got {}", cold_w);
    }

    #[test]
    fn access_weight_is_bounded() {
        let store = MemoryStore::open_in_memory().unwrap();

        let id = store
            .remember_episode("alpha bounded test", Some(&[1.0, 0.0]))
            .unwrap();

        // The memory is itself the corpus maximum, i.e. the upper end of the range.
        store
            .conn()
            .execute(
                "UPDATE memories SET access_count = 1000 WHERE id = ?1",
                rusqlite::params![id],
            )
            .unwrap();

        let mem = store.get_memory(id).unwrap().unwrap();
        let w = access_weight(&mem, 1000);
        assert!(w <= 2.0, "access weight should never exceed 2.0, got {}", w);
    }

    #[test]
    fn spreading_activation_boosts_related_facts() {
        // Facts 1 and 2 share the entity "Tortellini"; fact 3 shares nothing.
        let mut results = vec![
            RecallResult {
                memory: make_fact_record(1, "Jared", "has_pet", "Tortellini"),
                score: 1.0,
            },
            RecallResult {
                memory: make_fact_record(2, "Tortellini", "is_a", "dog"),
                score: 0.1,
            },
            RecallResult {
                memory: make_fact_record(3, "Abby", "likes", "cats"),
                score: 0.1,
            },
        ];

        let original_related = results[1].score;
        let original_unrelated = results[2].score;

        spread_activation(&mut results, SPREAD_FACTOR);

        assert!(
            results[1].score > original_related,
            "related fact should be boosted: {} > {}",
            results[1].score,
            original_related
        );
        assert_eq!(
            results[2].score, original_unrelated,
            "unrelated fact should not be boosted"
        );
    }

    #[test]
    fn spreading_activation_is_bidirectional() {
        // The shared entity is the object of one fact and the subject of the other.
        let mut results = vec![
            RecallResult {
                memory: make_fact_record(1, "Jared", "works_at", "Microsoft"),
                score: 0.8,
            },
            RecallResult {
                memory: make_fact_record(2, "Microsoft", "located_in", "Seattle"),
                score: 0.3,
            },
        ];

        let score_a_before = results[0].score;
        let score_b_before = results[1].score;

        spread_activation(&mut results, SPREAD_FACTOR);

        // Both directions: the stronger fact lifts the weaker one and vice versa.
        assert!(results[1].score > score_b_before);
        assert!(results[0].score > score_a_before);
    }

    #[test]
    fn spreading_activation_does_not_self_boost() {
        let mut results = vec![
            RecallResult {
                memory: make_fact_record(1, "Jared", "builds", "Gen"),
                score: 1.0,
            },
        ];

        spread_activation(&mut results, SPREAD_FACTOR);
        // A lone fact has no neighbours, so its score must be unchanged.
        assert!((results[0].score - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn temporal_cooccurrence_boosts_same_session_memories() {
        let now = Utc::now();

        // Anchor at `now`, one memory 5 minutes earlier (inside the 30-minute window)
        // and one 3 hours earlier (outside it).
        let mut results = vec![
            RecallResult {
                memory: make_timed_episode(1, "alpha anchor memory", now),
                score: 1.0,
            },
            RecallResult {
                memory: make_timed_episode(2, "alpha nearby memory", now - chrono::Duration::minutes(5)),
                score: 0.2,
            },
            RecallResult {
                memory: make_timed_episode(3, "alpha distant memory", now - chrono::Duration::hours(3)),
                score: 0.2,
            },
        ];

        let nearby_before = results[1].score;
        let distant_before = results[2].score;

        temporal_cooccurrence_boost(&mut results);

        assert!(
            results[1].score > nearby_before,
            "nearby memory should be boosted: {} > {}",
            results[1].score,
            nearby_before
        );
        assert_eq!(
            results[2].score, distant_before,
            "distant memory (>30min) should not be boosted"
        );
    }

    #[test]
    fn temporal_cooccurrence_scales_with_proximity() {
        let now = Utc::now();

        // Equal starting scores; only the distance from the anchor differs (2 vs 25 min).
        let mut results = vec![
            RecallResult {
                memory: make_timed_episode(1, "alpha anchor", now),
                score: 1.0,
            },
            RecallResult {
                memory: make_timed_episode(2, "alpha very close", now - chrono::Duration::minutes(2)),
                score: 0.1,
            },
            RecallResult {
                memory: make_timed_episode(3, "alpha further", now - chrono::Duration::minutes(25)),
                score: 0.1,
            },
        ];

        temporal_cooccurrence_boost(&mut results);

        assert!(
            results[1].score > results[2].score,
            "closer memory ({}) should score higher than further one ({})",
            results[1].score,
            results[2].score
        );
    }

    #[test]
    fn full_recall_pipeline_ranks_recent_accessed_related_higher() {
        let store = MemoryStore::open_in_memory().unwrap();
        let now = Utc::now();

        // Two related, fresh facts sharing "Tortellini", both a perfect vector match.
        store.remember_fact("Jared", "has_pet", "Tortellini", Some(&[1.0, 0.0])).unwrap();
        store.remember_fact("Tortellini", "is_a", "dog", Some(&[1.0, 0.0])).unwrap();

        // An unrelated, weaker vector match, backdated by 60 days.
        let old_id = store.remember_fact("weather", "is", "sunny", Some(&[0.5, 0.5])).unwrap();
        let old_time = (now - chrono::Duration::days(60)).to_rfc3339();
        store
            .conn()
            .execute(
                "UPDATE memories SET created_at = ?1, last_accessed_at = ?1 WHERE id = ?2",
                rusqlite::params![old_time, old_id],
            )
            .unwrap();

        let results = recall(&store, "alpha", &MockEmbedder, 10).unwrap();

        // NOTE(review): both checks are conditional, so this test passes vacuously if
        // fewer than three results come back or the weather fact is absent.
        if results.len() >= 3 {
            let weather_pos = results.iter().position(|r| r.memory.id == old_id);
            if let Some(pos) = weather_pos {
                assert!(pos >= 2, "old unrelated memory should rank below related recent ones, was at position {}", pos);
            }
        }
    }

    /// Builds a fresh, full-strength fact record (no embedding, no tags) for tests
    /// that call the ranking helpers directly without a store.
    fn make_fact_record(id: i64, subj: &str, rel: &str, obj: &str) -> MemoryRecord {
        MemoryRecord {
            id,
            kind: MemoryKind::Fact(crate::memory::Fact {
                subject: subj.to_string(),
                relation: rel.to_string(),
                object: obj.to_string(),
            }),
            strength: 1.0,
            created_at: Utc::now(),
            last_accessed_at: Utc::now(),
            access_count: 0,
            embedding: None,
            tags: vec![],
            source: None,
            session_id: None,
            channel: None,
            importance: 0.5,
            namespace: "default".to_string(),
            checksum: None,
        }
    }

    /// Builds a full-strength episode record created (and last accessed) at `time`.
    fn make_timed_episode(id: i64, text: &str, time: chrono::DateTime<Utc>) -> MemoryRecord {
        MemoryRecord {
            id,
            kind: MemoryKind::Episode(crate::memory::Episode {
                text: text.to_string(),
            }),
            strength: 1.0,
            created_at: time,
            last_accessed_at: time,
            access_count: 0,
            embedding: None,
            tags: vec![],
            source: None,
            session_id: None,
            channel: None,
            importance: 0.5,
            namespace: "default".to_string(),
            checksum: None,
        }
    }

    #[test]
    fn recall_touch_applies_decay_then_reinforcement() {
        let store = MemoryStore::open_in_memory().unwrap();
        let id = store
            .remember_episode("alpha memory to recall", Some(&[1.0, 0.0]))
            .unwrap();

        // Full strength, but untouched for 30 days.
        let old_time = (Utc::now() - chrono::Duration::days(30)).to_rfc3339();
        store
            .conn()
            .execute(
                "UPDATE memories SET strength = 1.0, last_accessed_at = ?1 WHERE id = ?2",
                rusqlite::params![old_time, id],
            )
            .unwrap();

        let results = recall(&store, "alpha", &MockEmbedder, 1).unwrap();
        assert_eq!(results.len(), 1);

        // Recall settles the accrued decay first and then adds the touch boost, so the
        // stored strength ends below full strength but clearly above a fully decayed value.
        let after = store.get_memory(id).unwrap().unwrap();
        assert!(after.access_count >= 1);
        assert!(after.strength < 1.0);
        assert!(after.strength > 0.2);
    }

    #[test]
    fn recall_with_tag_filter_returns_only_tagged_memories() {
        let store = MemoryStore::open_in_memory().unwrap();

        // Three equally matching facts: one tagged "preference", one "technical", one untagged.
        store.remember_fact_with_tags("Jared", "likes", "alpha", Some(&[1.0, 0.0]), &["preference".to_string()]).unwrap();
        store.remember_fact_with_tags("Jared", "uses", "alpha", Some(&[1.0, 0.0]), &["technical".to_string()]).unwrap();
        store.remember_fact("weather", "is", "alpha", Some(&[1.0, 0.0])).unwrap();

        // Unfiltered recall sees all three.
        let all_results = recall(&store, "alpha", &MockEmbedder, 10).unwrap();
        assert_eq!(all_results.len(), 3);

        let filtered = recall_with_tag_filter(&store, "alpha", &MockEmbedder, 10, Some("preference")).unwrap();
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].memory.tags, vec!["preference"]);

        let filtered = recall_with_tag_filter(&store, "alpha", &MockEmbedder, 10, Some("technical")).unwrap();
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].memory.tags, vec!["technical"]);
    }

    #[test]
    fn recall_with_tag_filter_is_case_insensitive() {
        let store = MemoryStore::open_in_memory().unwrap();
        // Stored tag is capitalised; the filter below is lower-case.
        store.remember_fact_with_tags("Jared", "likes", "alpha", Some(&[1.0, 0.0]), &["Preference".to_string()]).unwrap();

        let results = recall_with_tag_filter(&store, "alpha", &MockEmbedder, 10, Some("preference")).unwrap();
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn recall_with_no_tag_filter_returns_all() {
        let store = MemoryStore::open_in_memory().unwrap();
        store.remember_fact_with_tags("Jared", "likes", "alpha", Some(&[1.0, 0.0]), &["preference".to_string()]).unwrap();
        store.remember_fact("weather", "is", "alpha", Some(&[1.0, 0.0])).unwrap();

        // `None` disables filtering: tagged and untagged memories are both returned.
        let results = recall_with_tag_filter(&store, "alpha", &MockEmbedder, 10, None).unwrap();
        assert_eq!(results.len(), 2);
    }
}