1use std::collections::{HashMap, VecDeque};
35use std::sync::Arc;
36use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
37
38use parking_lot::RwLock;
39use serde::{Deserialize, Serialize};
40use tracing::{debug, info};
41
/// Capacity, expiry and matching parameters for the short-term, long-term and entity stores.
#[derive(Clone, Debug)]
pub struct MemoryConfig {
    /// Maximum number of items held in short-term memory at once.
    pub short_term_capacity: usize,
    /// How long a short-term item stays visible after it was added.
    pub short_term_ttl: Duration,
    /// Maximum number of entries held in long-term memory.
    pub long_term_capacity: usize,
    /// Minimum cosine similarity a long-term entry needs to be returned by a search.
    pub similarity_threshold: f32,
    /// Maximum number of tracked entities.
    pub entity_capacity: usize,
    /// Maximum number of distinct attribute keys a single entity may carry.
    pub max_attributes_per_entity: usize,
}

impl Default for MemoryConfig {
    /// 100 short-term items kept for one hour, 10 000 long-term entries with a 0.7
    /// similarity cut-off, and 1 000 entities with at most 50 attributes each.
    fn default() -> Self {
        const ONE_HOUR: Duration = Duration::from_secs(60 * 60);

        MemoryConfig {
            short_term_capacity: 100,
            short_term_ttl: ONE_HOUR,
            long_term_capacity: 10_000,
            similarity_threshold: 0.7,
            entity_capacity: 1_000,
            max_attributes_per_entity: 50,
        }
    }
}
71
72#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct MemoryItem {
75 pub id: String,
77 pub content: String,
79 pub created_at: u64,
81 pub last_accessed: u64,
83 pub access_count: u32,
85 pub importance: f32,
87 pub source: MemorySource,
89 pub metadata: HashMap<String, String>,
91}
92
93impl MemoryItem {
94 pub fn new(content: impl Into<String>, source: MemorySource) -> Self {
95 let now = SystemTime::now()
96 .duration_since(UNIX_EPOCH)
97 .unwrap_or_default()
98 .as_secs();
99
100 Self {
101 id: uuid::Uuid::new_v4().to_string(),
102 content: content.into(),
103 created_at: now,
104 last_accessed: now,
105 access_count: 0,
106 importance: 0.5,
107 source,
108 metadata: HashMap::new(),
109 }
110 }
111
112 pub fn with_importance(mut self, importance: f32) -> Self {
113 self.importance = importance.clamp(0.0, 1.0);
114 self
115 }
116
117 pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
118 self.metadata.insert(key.into(), value.into());
119 self
120 }
121
122 fn touch(&mut self) {
123 self.last_accessed = SystemTime::now()
124 .duration_since(UNIX_EPOCH)
125 .unwrap_or_default()
126 .as_secs();
127 self.access_count += 1;
128 }
129}
130
131#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
133pub enum MemorySource {
134 User,
136 Agent,
138 Tool,
140 External,
142 System,
144}
145
146pub struct ShortTermMemory {
152 items: RwLock<VecDeque<(MemoryItem, Instant)>>,
153 capacity: usize,
154 ttl: Duration,
155}
156
157impl ShortTermMemory {
158 pub fn new(capacity: usize, ttl: Duration) -> Self {
159 Self {
160 items: RwLock::new(VecDeque::with_capacity(capacity)),
161 capacity,
162 ttl,
163 }
164 }
165
166 pub fn add(&self, content: impl Into<String>, source: MemorySource) {
168 let item = MemoryItem::new(content, source);
169 self.add_item(item);
170 }
171
172 pub fn add_item(&self, item: MemoryItem) {
174 let mut items = self.items.write();
175
176 let now = Instant::now();
178 while let Some((_, created)) = items.front() {
179 if now.duration_since(*created) > self.ttl {
180 items.pop_front();
181 } else {
182 break;
183 }
184 }
185
186 while items.len() >= self.capacity {
188 items.pop_front();
189 }
190
191 items.push_back((item, now));
192 }
193
194 pub fn get_all(&self) -> Vec<MemoryItem> {
196 let items = self.items.read();
197 let now = Instant::now();
198
199 items
200 .iter()
201 .filter(|(_, created)| now.duration_since(*created) <= self.ttl)
202 .map(|(item, _)| item.clone())
203 .collect()
204 }
205
206 pub fn get_recent(&self, n: usize) -> Vec<MemoryItem> {
208 let items = self.items.read();
209 let now = Instant::now();
210
211 items
212 .iter()
213 .rev()
214 .filter(|(_, created)| now.duration_since(*created) <= self.ttl)
215 .take(n)
216 .map(|(item, _)| item.clone())
217 .collect()
218 }
219
220 pub fn search(&self, query: &str) -> Vec<MemoryItem> {
222 let query_lower = query.to_lowercase();
223 let items = self.items.read();
224 let now = Instant::now();
225
226 items
227 .iter()
228 .filter(|(_, created)| now.duration_since(*created) <= self.ttl)
229 .filter(|(item, _)| item.content.to_lowercase().contains(&query_lower))
230 .map(|(item, _)| item.clone())
231 .collect()
232 }
233
234 pub fn clear(&self) {
236 self.items.write().clear();
237 }
238
239 pub fn len(&self) -> usize {
241 let items = self.items.read();
242 let now = Instant::now();
243 items
244 .iter()
245 .filter(|(_, created)| now.duration_since(*created) <= self.ttl)
246 .count()
247 }
248
249 pub fn is_empty(&self) -> bool {
250 self.len() == 0
251 }
252
253 pub fn as_context(&self, max_items: usize) -> String {
255 let items = self.get_recent(max_items);
256 items
257 .iter()
258 .map(|item| format!("[{}] {}", format_source(&item.source), item.content))
259 .collect::<Vec<_>>()
260 .join("\n")
261 }
262}
263
264fn format_source(source: &MemorySource) -> &'static str {
265 match source {
266 MemorySource::User => "User",
267 MemorySource::Agent => "Agent",
268 MemorySource::Tool => "Tool",
269 MemorySource::External => "External",
270 MemorySource::System => "System",
271 }
272}
273
274#[derive(Debug, Clone, Serialize, Deserialize)]
280pub struct LongTermEntry {
281 pub item: MemoryItem,
282 pub embedding: Vec<f32>,
283}
284
285pub struct LongTermMemory {
287 entries: RwLock<Vec<LongTermEntry>>,
288 capacity: usize,
289 similarity_threshold: f32,
290}
291
292impl LongTermMemory {
293 pub fn new(capacity: usize, similarity_threshold: f32) -> Self {
294 Self {
295 entries: RwLock::new(Vec::with_capacity(capacity)),
296 capacity,
297 similarity_threshold,
298 }
299 }
300
301 pub fn store(
303 &self,
304 content: impl Into<String>,
305 embedding: Vec<f32>,
306 source: MemorySource,
307 ) -> String {
308 let item = MemoryItem::new(content, source);
309 self.store_item(item, embedding)
310 }
311
312 pub fn store_item(&self, item: MemoryItem, embedding: Vec<f32>) -> String {
314 let id = item.id.clone();
315 let entry = LongTermEntry { item, embedding };
316
317 let mut entries = self.entries.write();
318
319 while entries.len() >= self.capacity {
321 if let Some(idx) = entries
322 .iter()
323 .enumerate()
324 .min_by_key(|(_, e)| e.item.last_accessed)
325 .map(|(idx, _)| idx)
326 {
327 entries.remove(idx);
328 }
329 }
330
331 entries.push(entry);
332 debug!(id = %id, "Long-term memory stored");
333 id
334 }
335
336 pub fn search(&self, query_embedding: &[f32], limit: usize) -> Vec<(MemoryItem, f32)> {
338 let mut entries = self.entries.write();
339
340 let mut results: Vec<_> = entries
341 .iter_mut()
342 .map(|entry| {
343 let similarity = cosine_similarity(&entry.embedding, query_embedding);
344 (entry, similarity)
345 })
346 .filter(|(_, sim)| *sim >= self.similarity_threshold)
347 .collect();
348
349 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
351
352 results
354 .into_iter()
355 .take(limit)
356 .map(|(entry, sim)| {
357 entry.item.touch();
358 (entry.item.clone(), sim)
359 })
360 .collect()
361 }
362
363 pub fn get_all(&self) -> Vec<LongTermEntry> {
365 self.entries.read().clone()
366 }
367
368 pub fn remove(&self, id: &str) -> bool {
370 let mut entries = self.entries.write();
371 let initial_len = entries.len();
372 entries.retain(|e| e.item.id != id);
373 entries.len() < initial_len
374 }
375
376 pub fn clear(&self) {
378 self.entries.write().clear();
379 }
380
381 pub fn len(&self) -> usize {
382 self.entries.read().len()
383 }
384
385 pub fn is_empty(&self) -> bool {
386 self.entries.read().is_empty()
387 }
388
389 pub fn get_important(&self, min_importance: f32) -> Vec<MemoryItem> {
391 self.entries
392 .read()
393 .iter()
394 .filter(|e| e.item.importance >= min_importance)
395 .map(|e| e.item.clone())
396 .collect()
397 }
398}
399
/// Cosine similarity `dot(a, b) / (|a| * |b|)` of two vectors.
///
/// Returns `0.0` when the slices differ in length, are empty, or either has zero magnitude.
/// For finite inputs the result lies in `[-1.0, 1.0]` up to rounding.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let comparable = !a.is_empty() && a.len() == b.len();
    if !comparable {
        return 0.0;
    }

    // Accumulate the dot product and both squared norms in a single pass.
    let (dot, sq_norm_a, sq_norm_b) = a.iter().zip(b).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, sa, sb), (&x, &y)| (dot + x * y, sa + x * x, sb + y * y),
    );

    let norm_a = sq_norm_a.sqrt();
    let norm_b = sq_norm_b.sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    dot / (norm_a * norm_b)
}
415
416#[derive(Debug, Clone, Serialize, Deserialize)]
422pub struct Entity {
423 pub name: String,
425 pub entity_type: String,
427 pub attributes: HashMap<String, String>,
429 pub created_at: u64,
431 pub updated_at: u64,
433 pub confidence: f32,
435 pub relations: Vec<(String, String)>, }
438
439impl Entity {
440 pub fn new(name: impl Into<String>, entity_type: impl Into<String>) -> Self {
441 let now = SystemTime::now()
442 .duration_since(UNIX_EPOCH)
443 .unwrap_or_default()
444 .as_secs();
445
446 Self {
447 name: name.into(),
448 entity_type: entity_type.into(),
449 attributes: HashMap::new(),
450 created_at: now,
451 updated_at: now,
452 confidence: 1.0,
453 relations: Vec::new(),
454 }
455 }
456
457 pub fn set_attribute(&mut self, key: impl Into<String>, value: impl Into<String>) {
458 self.attributes.insert(key.into(), value.into());
459 self.touch();
460 }
461
462 pub fn get_attribute(&self, key: &str) -> Option<&String> {
463 self.attributes.get(key)
464 }
465
466 pub fn add_relation(&mut self, relation_type: impl Into<String>, target: impl Into<String>) {
467 self.relations.push((relation_type.into(), target.into()));
468 self.touch();
469 }
470
471 fn touch(&mut self) {
472 self.updated_at = SystemTime::now()
473 .duration_since(UNIX_EPOCH)
474 .unwrap_or_default()
475 .as_secs();
476 }
477}
478
479pub struct EntityMemory {
481 entities: RwLock<HashMap<String, Entity>>,
482 capacity: usize,
483 max_attributes: usize,
484}
485
486impl EntityMemory {
487 pub fn new(capacity: usize, max_attributes: usize) -> Self {
488 Self {
489 entities: RwLock::new(HashMap::with_capacity(capacity)),
490 capacity,
491 max_attributes,
492 }
493 }
494
495 pub fn get_or_create(&self, name: &str, entity_type: &str) -> Entity {
497 let mut entities = self.entities.write();
498
499 if let Some(entity) = entities.get(name) {
500 return entity.clone();
501 }
502
503 while entities.len() >= self.capacity {
505 if let Some(oldest) = entities
507 .iter()
508 .min_by_key(|(_, e)| e.updated_at)
509 .map(|(k, _)| k.clone())
510 {
511 entities.remove(&oldest);
512 }
513 }
514
515 let entity = Entity::new(name, entity_type);
516 entities.insert(name.to_string(), entity.clone());
517 entity
518 }
519
520 pub fn update_attribute(&self, name: &str, key: &str, value: &str) -> bool {
522 let mut entities = self.entities.write();
523
524 if let Some(entity) = entities.get_mut(name) {
525 if entity.attributes.len() < self.max_attributes || entity.attributes.contains_key(key)
526 {
527 entity.set_attribute(key, value);
528 return true;
529 }
530 }
531 false
532 }
533
534 pub fn get(&self, name: &str) -> Option<Entity> {
536 self.entities.read().get(name).cloned()
537 }
538
539 pub fn exists(&self, name: &str) -> bool {
541 self.entities.read().contains_key(name)
542 }
543
544 pub fn add_relation(&self, source: &str, relation_type: &str, target: &str) -> bool {
546 let mut entities = self.entities.write();
547
548 if let Some(entity) = entities.get_mut(source) {
549 entity.add_relation(relation_type, target);
550 return true;
551 }
552 false
553 }
554
555 pub fn get_by_type(&self, entity_type: &str) -> Vec<Entity> {
557 self.entities
558 .read()
559 .values()
560 .filter(|e| e.entity_type == entity_type)
561 .cloned()
562 .collect()
563 }
564
565 pub fn search_by_attribute(&self, key: &str, value: &str) -> Vec<Entity> {
567 self.entities
568 .read()
569 .values()
570 .filter(|e| e.attributes.get(key).map(|v| v == value).unwrap_or(false))
571 .cloned()
572 .collect()
573 }
574
575 pub fn get_all(&self) -> Vec<Entity> {
577 self.entities.read().values().cloned().collect()
578 }
579
580 pub fn remove(&self, name: &str) -> bool {
582 self.entities.write().remove(name).is_some()
583 }
584
585 pub fn clear(&self) {
587 self.entities.write().clear();
588 }
589
590 pub fn len(&self) -> usize {
591 self.entities.read().len()
592 }
593
594 pub fn is_empty(&self) -> bool {
595 self.entities.read().is_empty()
596 }
597
598 pub fn entity_context(&self, name: &str) -> Option<String> {
600 self.get(name).map(|entity| {
601 let mut lines = vec![format!("{} ({})", entity.name, entity.entity_type)];
602
603 for (key, value) in &entity.attributes {
604 lines.push(format!(" - {}: {}", key, value));
605 }
606
607 for (relation, target) in &entity.relations {
608 lines.push(format!(" - {} -> {}", relation, target));
609 }
610
611 lines.join("\n")
612 })
613 }
614}
615
616pub struct MultiMemory {
622 pub short_term: Arc<ShortTermMemory>,
624 pub long_term: Arc<LongTermMemory>,
626 pub entity: Arc<EntityMemory>,
628 #[allow(dead_code)]
630 config: MemoryConfig,
631}
632
633impl MultiMemory {
634 pub fn new(config: MemoryConfig) -> Self {
635 Self {
636 short_term: Arc::new(ShortTermMemory::new(
637 config.short_term_capacity,
638 config.short_term_ttl,
639 )),
640 long_term: Arc::new(LongTermMemory::new(
641 config.long_term_capacity,
642 config.similarity_threshold,
643 )),
644 entity: Arc::new(EntityMemory::new(
645 config.entity_capacity,
646 config.max_attributes_per_entity,
647 )),
648 config,
649 }
650 }
651
652 pub fn default_config() -> Self {
654 Self::new(MemoryConfig::default())
655 }
656
657 pub fn recall(
659 &self,
660 query: &str,
661 query_embedding: Option<&[f32]>,
662 limit: usize,
663 ) -> RecallResult {
664 let mut result = RecallResult::default();
665
666 result.short_term = self.short_term.search(query);
668
669 if let Some(embedding) = query_embedding {
671 result.long_term = self.long_term.search(embedding, limit);
672 }
673
674 let words: Vec<&str> = query.split_whitespace().collect();
676 for word in words {
677 if let Some(entity) = self.entity.get(word) {
678 result.entities.push(entity);
679 }
680 }
681
682 result
683 }
684
685 pub fn build_context(
687 &self,
688 query: &str,
689 query_embedding: Option<&[f32]>,
690 max_short_term: usize,
691 max_long_term: usize,
692 ) -> String {
693 let mut sections = Vec::new();
694
695 let recent = self.short_term.get_recent(max_short_term);
697 if !recent.is_empty() {
698 let context = recent
699 .iter()
700 .map(|item| format!("- {}", item.content))
701 .collect::<Vec<_>>()
702 .join("\n");
703 sections.push(format!("## Recent Context\n{}", context));
704 }
705
706 if let Some(embedding) = query_embedding {
708 let long_term = self.long_term.search(embedding, max_long_term);
709 if !long_term.is_empty() {
710 let context = long_term
711 .iter()
712 .map(|(item, score)| format!("- {} (relevance: {:.2})", item.content, score))
713 .collect::<Vec<_>>()
714 .join("\n");
715 sections.push(format!("## Relevant Knowledge\n{}", context));
716 }
717 }
718
719 let words: Vec<&str> = query.split_whitespace().collect();
721 let mut entity_contexts = Vec::new();
722 for word in words {
723 if let Some(ctx) = self.entity.entity_context(word) {
724 entity_contexts.push(ctx);
725 }
726 }
727 if !entity_contexts.is_empty() {
728 sections.push(format!(
729 "## Known Entities\n{}",
730 entity_contexts.join("\n\n")
731 ));
732 }
733
734 sections.join("\n\n")
735 }
736
737 pub fn clear_all(&self) {
739 self.short_term.clear();
740 self.long_term.clear();
741 self.entity.clear();
742 info!("All memories cleared");
743 }
744
745 pub fn stats(&self) -> MemoryStats {
747 MemoryStats {
748 short_term_count: self.short_term.len(),
749 long_term_count: self.long_term.len(),
750 entity_count: self.entity.len(),
751 }
752 }
753}
754
755#[derive(Debug, Default)]
757pub struct RecallResult {
758 pub short_term: Vec<MemoryItem>,
760 pub long_term: Vec<(MemoryItem, f32)>,
762 pub entities: Vec<Entity>,
764}
765
766impl RecallResult {
767 pub fn is_empty(&self) -> bool {
768 self.short_term.is_empty() && self.long_term.is_empty() && self.entities.is_empty()
769 }
770
771 pub fn total_count(&self) -> usize {
772 self.short_term.len() + self.long_term.len() + self.entities.len()
773 }
774}
775
776#[derive(Debug, Clone, Serialize, Deserialize)]
778pub struct MemoryStats {
779 pub short_term_count: usize,
780 pub long_term_count: usize,
781 pub entity_count: usize,
782}
783
#[cfg(test)]
mod tests {
    use super::*;

    /// Long enough that nothing expires while a test runs.
    const HOUR: Duration = Duration::from_secs(3600);

    #[test]
    fn test_short_term_memory_add_and_retrieve() {
        let stm = ShortTermMemory::new(10, HOUR);

        stm.add("First message", MemorySource::User);
        stm.add("Second message", MemorySource::Agent);

        assert_eq!(stm.get_all().len(), 2);
    }

    #[test]
    fn test_short_term_memory_capacity() {
        let stm = ShortTermMemory::new(3, HOUR);

        for label in ["1", "2", "3", "4"] {
            stm.add(label, MemorySource::User);
        }

        let kept = stm.get_all();
        assert_eq!(kept.len(), 3);
        // "1" was the oldest item and must have been evicted.
        assert_eq!(kept[0].content, "2");
    }

    #[test]
    fn test_short_term_memory_search() {
        let stm = ShortTermMemory::new(10, HOUR);

        for text in ["Hello world", "Goodbye world", "Something else"] {
            stm.add(text, MemorySource::User);
        }

        assert_eq!(stm.search("world").len(), 2);
    }

    #[test]
    fn test_long_term_memory_store_and_search() {
        let ltm = LongTermMemory::new(100, 0.5);
        let x_axis = vec![1.0, 0.0, 0.0];

        ltm.store("First fact", x_axis.clone(), MemorySource::External);
        ltm.store("Second fact", vec![0.0, 1.0, 0.0], MemorySource::External);
        ltm.store("Third fact", vec![0.9, 0.1, 0.0], MemorySource::External);

        let hits = ltm.search(&x_axis, 2);
        assert_eq!(hits.len(), 2);
        // Results come back ordered by descending similarity.
        assert!(hits[0].1 > hits[1].1);
    }

    #[test]
    fn test_long_term_memory_threshold() {
        let ltm = LongTermMemory::new(100, 0.9);
        let x_axis = vec![1.0, 0.0, 0.0];

        ltm.store("First fact", x_axis.clone(), MemorySource::External);
        ltm.store("Second fact", vec![0.0, 1.0, 0.0], MemorySource::External);

        // Only the identical embedding clears the 0.9 threshold.
        assert_eq!(ltm.search(&x_axis, 10).len(), 1);
    }

    #[test]
    fn test_entity_memory_create_and_update() {
        let entities = EntityMemory::new(100, 50);

        let created = entities.get_or_create("Rust", "programming_language");
        assert_eq!(created.name, "Rust");

        entities.update_attribute("Rust", "creator", "Mozilla");
        entities.update_attribute("Rust", "year", "2010");

        let rust = entities.get("Rust").unwrap();
        assert_eq!(rust.attributes.get("creator").unwrap(), "Mozilla");
        assert_eq!(rust.attributes.get("year").unwrap(), "2010");
    }

    #[test]
    fn test_entity_memory_relations() {
        let entities = EntityMemory::new(100, 50);

        entities.get_or_create("Rust", "programming_language");
        entities.get_or_create("Mozilla", "organization");
        entities.add_relation("Rust", "created_by", "Mozilla");

        let rust = entities.get("Rust").unwrap();
        assert_eq!(rust.relations.len(), 1);
        assert_eq!(
            rust.relations[0],
            ("created_by".to_string(), "Mozilla".to_string())
        );
    }

    #[test]
    fn test_entity_memory_search() {
        let entities = EntityMemory::new(100, 50);

        entities.get_or_create("Rust", "programming_language");
        entities.get_or_create("Python", "programming_language");
        entities.get_or_create("Mozilla", "organization");

        entities.update_attribute("Rust", "type", "compiled");
        entities.update_attribute("Python", "type", "interpreted");

        assert_eq!(entities.get_by_type("programming_language").len(), 2);

        let compiled = entities.search_by_attribute("type", "compiled");
        assert_eq!(compiled.len(), 1);
        assert_eq!(compiled[0].name, "Rust");
    }

    #[test]
    fn test_multi_memory_recall() {
        let mm = MultiMemory::default_config();
        let x_axis = vec![1.0, 0.0, 0.0];

        mm.short_term
            .add("User asked about Rust programming", MemorySource::User);
        mm.long_term.store(
            "Rust is a systems programming language",
            x_axis.clone(),
            MemorySource::External,
        );
        mm.entity.get_or_create("Rust", "programming_language");
        mm.entity.update_attribute("Rust", "creator", "Mozilla");

        let found = mm.recall("Rust", Some(&x_axis), 10);

        // One hit from each store.
        assert_eq!(found.short_term.len(), 1);
        assert_eq!(found.long_term.len(), 1);
        assert_eq!(found.entities.len(), 1);
    }

    #[test]
    fn test_multi_memory_build_context() {
        let mm = MultiMemory::default_config();
        let x_axis = vec![1.0, 0.0, 0.0];

        mm.short_term.add("Previous message", MemorySource::User);
        mm.long_term
            .store("Relevant fact", x_axis.clone(), MemorySource::External);
        mm.entity.get_or_create("Test", "concept");
        mm.entity
            .update_attribute("Test", "description", "A test entity");

        let rendered = mm.build_context("Test query", Some(&x_axis), 5, 5);

        assert!(rendered.contains("Previous message"));
        assert!(rendered.contains("Relevant fact"));
        assert!(rendered.contains("Test"));
    }

    #[test]
    fn test_memory_stats() {
        let mm = MultiMemory::default_config();

        mm.short_term.add("Message 1", MemorySource::User);
        mm.short_term.add("Message 2", MemorySource::User);
        mm.long_term
            .store("Fact 1", vec![1.0], MemorySource::External);
        for name in ["Entity1", "Entity2", "Entity3"] {
            mm.entity.get_or_create(name, "type");
        }

        let stats = mm.stats();
        assert_eq!(stats.short_term_count, 2);
        assert_eq!(stats.long_term_count, 1);
        assert_eq!(stats.entity_count, 3);
    }

    #[test]
    fn test_cosine_similarity() {
        // Identical direction.
        assert!((cosine_similarity(&[1.0, 0.0], &[1.0, 0.0]) - 1.0).abs() < 0.001);
        // Orthogonal.
        assert!((cosine_similarity(&[1.0, 0.0], &[0.0, 1.0])).abs() < 0.001);
        // Opposite direction.
        assert!((cosine_similarity(&[1.0, 0.0], &[-1.0, 0.0]) + 1.0).abs() < 0.001);
    }

    #[test]
    fn test_memory_item_builder() {
        let item = MemoryItem::new("Test content", MemorySource::User)
            .with_importance(0.8)
            .with_metadata("key", "value");

        assert_eq!(item.content, "Test content");
        assert_eq!(item.importance, 0.8);
        assert_eq!(item.metadata.get("key").unwrap(), "value");
    }
}