1use std::collections::HashMap;
2use std::sync::Arc;
3use tokio::sync::RwLock;
4
5pub struct EmbeddingIndex {
9 documents: Arc<RwLock<HashMap<String, HashMap<String, f32>>>>,
11 doc_freq: Arc<RwLock<HashMap<String, usize>>>,
13}
14
15impl EmbeddingIndex {
16 pub fn new() -> Self {
17 Self {
18 documents: Arc::new(RwLock::new(HashMap::new())),
19 doc_freq: Arc::new(RwLock::new(HashMap::new())),
20 }
21 }
22
23 pub async fn index(&self, doc_id: impl Into<String>, text: &str) {
25 let doc_id = doc_id.into();
26 let tf = term_frequencies(text);
27
28 let mut doc_freq = self.doc_freq.write().await;
29 for term in tf.keys() {
30 *doc_freq.entry(term.clone()).or_insert(0) += 1;
31 }
32 drop(doc_freq);
33
34 self.documents.write().await.insert(doc_id, tf);
35 }
36
37 pub async fn remove(&self, doc_id: &str) {
39 let mut docs = self.documents.write().await;
40 if let Some(tf) = docs.remove(doc_id) {
41 let mut df = self.doc_freq.write().await;
42 for term in tf.keys() {
43 if let Some(count) = df.get_mut(term) {
44 *count = count.saturating_sub(1);
45 if *count == 0 {
46 df.remove(term);
47 }
48 }
49 }
50 }
51 }
52
53 pub async fn search(&self, query: &str, top_k: usize) -> Vec<(String, f32)> {
55 let query_tf = term_frequencies(query);
56 let docs = self.documents.read().await;
57 let df = self.doc_freq.read().await;
58 let n_docs = docs.len().max(1) as f32;
59
60 let mut scores: Vec<(String, f32)> = docs
61 .iter()
62 .map(|(doc_id, doc_tf)| {
63 let score = cosine_tfidf(&query_tf, doc_tf, &df, n_docs);
64 (doc_id.clone(), score)
65 })
66 .filter(|(_, s)| *s > 0.0)
67 .collect();
68
69 scores.sort_by(|a, b| {
70 b.1.partial_cmp(&a.1)
71 .unwrap_or(std::cmp::Ordering::Equal)
72 });
73 scores.truncate(top_k);
74 scores
75 }
76}
77
78impl Default for EmbeddingIndex {
79 fn default() -> Self {
80 Self::new()
81 }
82}
83
/// Computes normalized term frequencies for `text`.
///
/// The text is split on whitespace; each token is lowercased and stripped of
/// leading and trailing non-alphanumeric characters. Tokens that end up empty
/// (for example pure punctuation such as `--`) are discarded and do not count
/// towards the total, so the returned frequencies sum to 1 for any text that
/// contains at least one term.
///
/// Returns an empty map when `text` contains no terms.
fn term_frequencies(text: &str) -> HashMap<String, f32> {
    let mut counts: HashMap<String, usize> = HashMap::new();
    let mut total = 0usize;

    for word in text.split_whitespace() {
        let lowered = word.to_lowercase();
        let term = lowered.trim_matches(|c: char| !c.is_alphanumeric());
        if term.is_empty() {
            continue;
        }
        *counts.entry(term.to_owned()).or_insert(0) += 1;
        total += 1;
    }

    // Divide once per term rather than accumulating `1 / total` repeatedly,
    // which avoids compounding rounding error. `total` is non-zero whenever
    // `counts` is non-empty, so no division by zero can occur here.
    let total = total as f32;
    counts
        .into_iter()
        .map(|(term, count)| (term, count as f32 / total))
        .collect()
}
98
/// Cosine similarity between a query and a document in TF-IDF space.
///
/// Each term is weighted by `ln((n_docs + 1) / (df + 1)) + 1`, where `df` is
/// the term's entry in `df` (zero when absent). The dot product runs over the
/// query's terms; each norm runs over its own vector's terms. Returns `0.0`
/// when either vector has zero magnitude.
fn cosine_tfidf(
    query_tf: &HashMap<String, f32>,
    doc_tf: &HashMap<String, f32>,
    df: &HashMap<String, usize>,
    n_docs: f32,
) -> f32 {
    // Smoothed inverse document frequency for a single term.
    let idf_of = |term: &str| -> f32 {
        let containing = df.get(term).copied().unwrap_or(0) as f32;
        ((n_docs + 1.0) / (containing + 1.0)).ln() + 1.0
    };

    // One pass over the query yields both the dot product and the squared
    // query norm.
    let (dot, query_sq) = query_tf.iter().fold(
        (0.0f32, 0.0f32),
        |(dot_acc, norm_acc), (term, q_tf)| {
            let weight = idf_of(term.as_str());
            let q_weighted = q_tf * weight;
            let dot_next = match doc_tf.get(term) {
                Some(d_tf) => dot_acc + q_weighted * (d_tf * weight),
                None => dot_acc,
            };
            (dot_next, norm_acc + q_weighted * q_weighted)
        },
    );

    // The document norm covers every document term, not just the overlap.
    let doc_sq = doc_tf.iter().fold(0.0f32, |acc, (term, d_tf)| {
        let weight = idf_of(term.as_str());
        acc + (d_tf * weight) * (d_tf * weight)
    });

    let magnitude = query_sq.sqrt() * doc_sq.sqrt();
    if magnitude == 0.0 {
        return 0.0;
    }
    dot / magnitude
}
134
#[cfg(test)]
mod tests {
    use super::*;

    /// A query sharing terms with only one document ranks that document first.
    #[tokio::test]
    async fn index_and_search() {
        let index = EmbeddingIndex::new();
        let corpus = [
            ("doc1", "cats are fluffy animals that meow"),
            ("doc2", "dogs are loyal animals that bark"),
            ("doc3", "the weather is sunny today nice"),
        ];
        for &(id, body) in &corpus {
            index.index(id, body).await;
        }

        let hits = index.search("fluffy cats", 2).await;
        assert!(!hits.is_empty());
        assert_eq!(hits[0].0, "doc1");
    }

    /// A removed document no longer shows up in search results.
    #[tokio::test]
    async fn remove_from_index() {
        let index = EmbeddingIndex::new();
        index.index("doc1", "cats meow loudly").await;
        index.remove("doc1").await;

        let hits = index.search("cats", 5).await;
        assert!(hits.is_empty());
    }

    /// Searching an index with no documents yields no results.
    #[tokio::test]
    async fn empty_index_returns_empty() {
        let index = EmbeddingIndex::new();
        let hits = index.search("anything", 5).await;
        assert!(hits.is_empty());
    }
}