deepstrike_core/memory/semantic.rs

1use std::collections::{HashSet, VecDeque};
2use std::sync::Mutex;
3
4use serde::{Deserialize, Serialize};
5
6use crate::types::error::Result;
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct MemoryEntry {
10    pub text: String,
11    pub score: f64,
12    pub metadata: serde_json::Value,
13}
14
15pub trait SemanticMemory: Send + Sync {
16    fn query(&self, text: &str, top_k: usize) -> Result<Vec<MemoryEntry>>;
17    fn store(&self, entry: MemoryEntry) -> Result<()>;
18}
19
20const MAX_ENTRIES: usize = 10_000;
21
22pub struct InMemorySemanticStore {
23    entries: Mutex<VecDeque<MemoryEntry>>,
24}
25
26impl InMemorySemanticStore {
27    pub fn new() -> Self {
28        Self {
29            entries: Mutex::new(VecDeque::new()),
30        }
31    }
32
33    fn jaccard(a: &str, b: &str) -> f64 {
34        let sa: HashSet<&str> = a.split_whitespace().collect();
35        let sb: HashSet<&str> = b.split_whitespace().collect();
36        let inter = sa.intersection(&sb).count();
37        let union = sa.union(&sb).count();
38        if union == 0 {
39            0.0
40        } else {
41            inter as f64 / union as f64
42        }
43    }
44}
45
46impl Default for InMemorySemanticStore {
47    fn default() -> Self {
48        Self::new()
49    }
50}
51
52impl SemanticMemory for InMemorySemanticStore {
53    fn store(&self, entry: MemoryEntry) -> Result<()> {
54        let mut entries = self.entries.lock().unwrap();
55        if entries.len() >= MAX_ENTRIES {
56            entries.pop_front();
57        }
58        entries.push_back(entry);
59        Ok(())
60    }
61
62    fn query(&self, text: &str, top_k: usize) -> Result<Vec<MemoryEntry>> {
63        let entries = self.entries.lock().unwrap();
64        let mut scored: Vec<(f64, &MemoryEntry)> = entries
65            .iter()
66            .map(|e| (Self::jaccard(text, &e.text), e))
67            .collect();
68        scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap());
69        Ok(scored
70            .into_iter()
71            .take(top_k)
72            .map(|(score, e)| MemoryEntry { score, ..e.clone() })
73            .collect())
74    }
75}
76
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an entry with the given text, a zero score and null metadata.
    fn entry(text: &str) -> MemoryEntry {
        MemoryEntry {
            text: text.into(),
            score: 0.0,
            metadata: serde_json::Value::Null,
        }
    }

    #[test]
    fn query_returns_top_k_by_jaccard() {
        let store = InMemorySemanticStore::new();
        store.store(entry("foo bar baz")).unwrap();
        store.store(entry("hello world")).unwrap();
        let results = store.query("foo bar", 1).unwrap();
        assert_eq!(results.len(), 1);
        assert!(results[0].text.contains("foo"));
        // {foo, bar} vs {foo, bar, baz}: 2 shared words out of 3 distinct ones.
        assert!((results[0].score - 2.0 / 3.0).abs() < 1e-12);
    }

    #[test]
    fn query_with_zero_top_k_returns_nothing() {
        let store = InMemorySemanticStore::new();
        store.store(entry("foo")).unwrap();
        assert!(store.query("foo", 0).unwrap().is_empty());
    }

    #[test]
    fn query_on_empty_store_returns_nothing() {
        let store = InMemorySemanticStore::default();
        assert!(store.query("foo", 3).unwrap().is_empty());
    }

    #[test]
    fn empty_query_scores_every_entry_zero() {
        let store = InMemorySemanticStore::new();
        store.store(entry("foo")).unwrap();
        store.store(entry("")).unwrap();
        let results = store.query("", 10).unwrap();
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|e| e.score == 0.0));
    }

    #[test]
    fn ties_keep_insertion_order() {
        let store = InMemorySemanticStore::new();
        store.store(entry("a b")).unwrap();
        store.store(entry("a c")).unwrap();
        let results = store.query("a", 2).unwrap();
        let texts: Vec<&str> = results.iter().map(|e| e.text.as_str()).collect();
        assert_eq!(texts, ["a b", "a c"]);
    }

    #[test]
    fn query_preserves_metadata_and_overwrites_score() {
        let store = InMemorySemanticStore::new();
        store
            .store(MemoryEntry {
                text: "foo".into(),
                score: 42.0,
                metadata: serde_json::json!({ "source": "unit-test" }),
            })
            .unwrap();
        let results = store.query("foo", 1).unwrap();
        assert_eq!(results[0].score, 1.0);
        assert_eq!(results[0].metadata["source"], "unit-test");
    }

    #[test]
    fn store_evicts_oldest_entry_when_full() {
        let store = InMemorySemanticStore::new();
        for i in 0..=MAX_ENTRIES {
            store.store(entry(&format!("entry{i}"))).unwrap();
        }
        let results = store.query("entry0", MAX_ENTRIES + 1).unwrap();
        assert_eq!(results.len(), MAX_ENTRIES);
        assert!(results.iter().all(|e| e.text != "entry0"));
        assert!(results
            .iter()
            .any(|e| e.text == format!("entry{MAX_ENTRIES}")));
    }
}