// engram/embedding/tfidf.rs
use std::collections::HashMap;
7use std::hash::{Hash, Hasher};
8
9use crate::embedding::Embedder;
10use crate::error::Result;
11
/// Hashing-trick TF-IDF text embedder.
///
/// Maps text to a fixed-size dense vector by hashing unigram and bigram
/// features into `dimensions` buckets with pseudo-random signs (see the
/// `Embedder` impl). No vocabulary is stored, so the struct is trivially
/// cheap to clone.
#[derive(Debug, Clone)]
pub struct TfIdfEmbedder {
    // Length of every vector returned by `embed`.
    dimensions: usize,
}
16
17impl TfIdfEmbedder {
18 pub fn new(dimensions: usize) -> Self {
19 Self { dimensions }
20 }
21
22 fn tokenize(text: &str) -> Vec<String> {
24 text.to_lowercase()
25 .split(|c: char| !c.is_alphanumeric())
26 .filter(|s| s.len() > 1)
27 .map(String::from)
28 .collect()
29 }
30
31 fn hash_token(token: &str, dimensions: usize) -> usize {
33 let mut hasher = std::collections::hash_map::DefaultHasher::new();
34 token.hash(&mut hasher);
35 (hasher.finish() as usize) % dimensions
36 }
37
38 fn hash_bigram(token1: &str, token2: &str, dimensions: usize) -> usize {
40 let mut hasher = std::collections::hash_map::DefaultHasher::new();
41 token1.hash(&mut hasher);
42 "_".hash(&mut hasher);
43 token2.hash(&mut hasher);
44 (hasher.finish() as usize) % dimensions
45 }
46
47 fn hash_sign(token: &str) -> f32 {
49 let mut hasher = std::collections::hash_map::DefaultHasher::new();
50 token.hash(&mut hasher);
51 "_sign".hash(&mut hasher);
52 if hasher.finish().is_multiple_of(2) {
53 1.0
54 } else {
55 -1.0
56 }
57 }
58
59 fn hash_bigram_sign(token1: &str, token2: &str) -> f32 {
61 let mut hasher = std::collections::hash_map::DefaultHasher::new();
62 token1.hash(&mut hasher);
63 "_".hash(&mut hasher);
64 token2.hash(&mut hasher);
65 "_sign".hash(&mut hasher);
66 if hasher.finish().is_multiple_of(2) {
67 1.0
68 } else {
69 -1.0
70 }
71 }
72}
73
74impl Embedder for TfIdfEmbedder {
75 fn embed(&self, text: &str) -> Result<Vec<f32>> {
76 let tokens = Self::tokenize(text);
77 let mut embedding = vec![0.0_f32; self.dimensions];
78
79 if tokens.is_empty() {
80 return Ok(embedding);
81 }
82
83 let doc_len = tokens.len() as f32;
84
85 for window in tokens.windows(2) {
88 let idx = Self::hash_bigram(&window[0], &window[1], self.dimensions);
89 let sign = Self::hash_bigram_sign(&window[0], &window[1]);
90 embedding[idx] += 0.5 * sign; }
92
93 let mut tf: HashMap<String, f32> = HashMap::new();
96 for token in tokens {
97 *tf.entry(token).or_insert(0.0) += 1.0;
98 }
99
100 for (token, count) in tf {
102 let tf_score = (1.0 + count / doc_len).ln();
104
105 let idf_score = 1.0 + (token.len() as f32 * 0.1);
107
108 let weight = tf_score * idf_score;
109 let idx = Self::hash_token(&token, self.dimensions);
110 let sign = Self::hash_sign(&token);
111
112 embedding[idx] += weight * sign;
113 }
114
115 let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
117 if norm > 0.0 {
118 for x in &mut embedding {
119 *x /= norm;
120 }
121 }
122
123 Ok(embedding)
124 }
125
126 fn dimensions(&self) -> usize {
127 self.dimensions
128 }
129
130 fn model_name(&self) -> &str {
131 "tfidf"
132 }
133}
134
#[cfg(test)]
mod tests {
    use super::*;
    use crate::embedding::cosine_similarity;

    // Identical input must produce bit-identical vectors.
    #[test]
    fn test_tfidf_basic() {
        let embedder = TfIdfEmbedder::new(384);
        let first = embedder.embed("hello world").unwrap();
        let second = embedder.embed("hello world").unwrap();
        assert_eq!(first, second);
    }

    // Sentences sharing vocabulary should score closer than sentences
    // from an unrelated domain.
    #[test]
    fn test_tfidf_similarity() {
        let embedder = TfIdfEmbedder::new(384);

        let fox_a = embedder
            .embed("the quick brown fox jumps over the lazy dog")
            .unwrap();
        let fox_b = embedder
            .embed("a fast brown fox leaps over a sleepy dog")
            .unwrap();
        let physics = embedder
            .embed("quantum physics and thermodynamics")
            .unwrap();

        let sim_similar = cosine_similarity(&fox_a, &fox_b);
        let sim_different = cosine_similarity(&fox_a, &physics);

        assert!(
            sim_similar > sim_different,
            "Similar sentences should have higher similarity"
        );
    }

    // Empty input maps to the all-zero vector of the configured length.
    #[test]
    fn test_tfidf_empty() {
        let embedding = TfIdfEmbedder::new(384).embed("").unwrap();
        assert_eq!(embedding.len(), 384);
        assert!(embedding.iter().all(|&v| v == 0.0));
    }

    // Non-empty embeddings carry unit L2 norm.
    #[test]
    fn test_tfidf_normalized() {
        let embedding = TfIdfEmbedder::new(384)
            .embed("this is a test sentence with multiple words")
            .unwrap();

        let norm = embedding.iter().map(|v| v * v).sum::<f32>().sqrt();
        assert!(
            (norm - 1.0).abs() < 0.001,
            "Embedding should be L2 normalized"
        );
    }
}