//! devsper_memory/index.rs — simple TF-IDF style embedding index.
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;

/// Simple TF-IDF style embedding index for semantic search.
/// In production, replace with sentence-transformer embeddings.
/// Stores term frequency vectors per document.
///
/// Both maps are wrapped in `Arc<RwLock<…>>` so the index can be cloned
/// into / shared across async tasks; `search` takes read locks while
/// `index`/`remove` take write locks.
pub struct EmbeddingIndex {
    /// doc_id → term frequencies (normalized counts, see `term_frequencies`)
    documents: Arc<RwLock<HashMap<String, HashMap<String, f32>>>>,
    /// term → document frequency (how many docs contain it)
    doc_freq: Arc<RwLock<HashMap<String, usize>>>,
}
14
15impl EmbeddingIndex {
16    pub fn new() -> Self {
17        Self {
18            documents: Arc::new(RwLock::new(HashMap::new())),
19            doc_freq: Arc::new(RwLock::new(HashMap::new())),
20        }
21    }
22
23    /// Index a document by its text content
24    pub async fn index(&self, doc_id: impl Into<String>, text: &str) {
25        let doc_id = doc_id.into();
26        let tf = term_frequencies(text);
27
28        let mut doc_freq = self.doc_freq.write().await;
29        for term in tf.keys() {
30            *doc_freq.entry(term.clone()).or_insert(0) += 1;
31        }
32        drop(doc_freq);
33
34        self.documents.write().await.insert(doc_id, tf);
35    }
36
37    /// Remove a document from the index
38    pub async fn remove(&self, doc_id: &str) {
39        let mut docs = self.documents.write().await;
40        if let Some(tf) = docs.remove(doc_id) {
41            let mut df = self.doc_freq.write().await;
42            for term in tf.keys() {
43                if let Some(count) = df.get_mut(term) {
44                    *count = count.saturating_sub(1);
45                    if *count == 0 {
46                        df.remove(term);
47                    }
48                }
49            }
50        }
51    }
52
53    /// Search for the top-k most relevant documents using TF-IDF cosine similarity
54    pub async fn search(&self, query: &str, top_k: usize) -> Vec<(String, f32)> {
55        let query_tf = term_frequencies(query);
56        let docs = self.documents.read().await;
57        let df = self.doc_freq.read().await;
58        let n_docs = docs.len().max(1) as f32;
59
60        let mut scores: Vec<(String, f32)> = docs
61            .iter()
62            .map(|(doc_id, doc_tf)| {
63                let score = cosine_tfidf(&query_tf, doc_tf, &df, n_docs);
64                (doc_id.clone(), score)
65            })
66            .filter(|(_, s)| *s > 0.0)
67            .collect();
68
69        scores.sort_by(|a, b| {
70            b.1.partial_cmp(&a.1)
71                .unwrap_or(std::cmp::Ordering::Equal)
72        });
73        scores.truncate(top_k);
74        scores
75    }
76}
78impl Default for EmbeddingIndex {
79    fn default() -> Self {
80        Self::new()
81    }
82}
/// Compute normalized term frequencies for `text`.
///
/// Tokens are split on whitespace, lowercased, and stripped of leading
/// and trailing non-alphanumeric characters; tokens that normalize to
/// the empty string are skipped. Each surviving occurrence contributes
/// `1 / total_tokens`, where `total_tokens` counts the raw whitespace
/// tokens (clamped to at least 1 to avoid division by zero).
fn term_frequencies(text: &str) -> HashMap<String, f32> {
    let token_count = text.split_whitespace().count() as f32;
    // Per-occurrence weight; hoisted since it is the same for every token.
    let weight = 1.0 / token_count.max(1.0);

    let mut frequencies: HashMap<String, f32> = HashMap::new();
    for raw_token in text.split_whitespace() {
        let normalized = raw_token
            .to_lowercase()
            .trim_matches(|c: char| !c.is_alphanumeric())
            .to_string();
        if normalized.is_empty() {
            continue;
        }
        *frequencies.entry(normalized).or_insert(0.0) += weight;
    }
    frequencies
}
/// Cosine similarity between a query and a document in TF-IDF space.
///
/// IDF is smoothed as `ln((N + 1) / (df + 1)) + 1`, which stays positive
/// even for a term that appears in every document. The query norm is
/// accumulated over query terms only, the document norm over document
/// terms only, and the dot product over their intersection. Returns 0.0
/// when either vector has zero magnitude.
fn cosine_tfidf(
    query_tf: &HashMap<String, f32>,
    doc_tf: &HashMap<String, f32>,
    df: &HashMap<String, usize>,
    n_docs: f32,
) -> f32 {
    // Shared smoothed-IDF formula (was duplicated inline in both loops).
    let idf_of = |term: &String| -> f32 {
        let term_df = df.get(term).copied().unwrap_or(0) as f32;
        ((n_docs + 1.0) / (term_df + 1.0)).ln() + 1.0
    };

    let mut dot = 0.0f32;
    let mut query_norm_sq = 0.0f32;
    for (term, q_tf) in query_tf {
        let idf = idf_of(term);
        let q_weight = q_tf * idf;
        query_norm_sq += q_weight * q_weight;
        if let Some(d_tf) = doc_tf.get(term) {
            let d_weight = d_tf * idf;
            dot += q_weight * d_weight;
        }
    }

    let mut doc_norm_sq = 0.0f32;
    for (term, d_tf) in doc_tf {
        let d_weight = d_tf * idf_of(term);
        doc_norm_sq += d_weight * d_weight;
    }

    let magnitude = query_norm_sq.sqrt() * doc_norm_sq.sqrt();
    if magnitude == 0.0 {
        0.0
    } else {
        dot / magnitude
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Indexing three documents and querying should rank the document
    /// sharing both query terms first.
    #[tokio::test]
    async fn index_and_search() {
        let index = EmbeddingIndex::new();
        index.index("doc1", "cats are fluffy animals that meow").await;
        index.index("doc2", "dogs are loyal animals that bark").await;
        index.index("doc3", "the weather is sunny today nice").await;

        let hits = index.search("fluffy cats", 2).await;
        assert!(!hits.is_empty());
        assert_eq!(hits[0].0, "doc1");
    }

    /// After removal, the document must no longer match its own terms.
    #[tokio::test]
    async fn remove_from_index() {
        let index = EmbeddingIndex::new();
        index.index("doc1", "cats meow loudly").await;
        index.remove("doc1").await;

        let hits = index.search("cats", 5).await;
        assert!(hits.is_empty());
    }

    /// Searching an empty index yields no results.
    #[tokio::test]
    async fn empty_index_returns_empty() {
        let index = EmbeddingIndex::new();
        let hits = index.search("anything", 5).await;
        assert!(hits.is_empty());
    }
}