
engram/embedding/tfidf.rs

//! TF-IDF based embedding fallback
//!
//! Simple, fast, no external dependencies. Good for testing and
//! environments where API calls aren't possible.
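//!
//! A minimal usage sketch (assuming the crate is named `engram`, per this
//! file's path):
//!
//! ```ignore
//! use engram::embedding::{Embedder, TfIdfEmbedder};
//!
//! let embedder = TfIdfEmbedder::new(384);
//! let vector = embedder.embed("hello world")?;
//! assert_eq!(vector.len(), 384);
//! ```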

use std::collections::HashMap;
use std::hash::{Hash, Hasher};

use crate::embedding::Embedder;
use crate::error::Result;

/// TF-IDF based embedder using the hashing trick
pub struct TfIdfEmbedder {
    dimensions: usize,
}

impl TfIdfEmbedder {
    pub fn new(dimensions: usize) -> Self {
        Self { dimensions }
    }

    /// Tokenize text into lowercase alphanumeric words, dropping
    /// single-character tokens
    fn tokenize(text: &str) -> Vec<String> {
        text.to_lowercase()
            .split(|c: char| !c.is_alphanumeric())
            .filter(|s| s.len() > 1)
            .map(String::from)
            .collect()
    }

    /// Hash a token to a dimension index (the hashing trick: an unbounded
    /// vocabulary maps into a fixed number of dimensions, with no stored
    /// vocabulary table)
    fn hash_token(token: &str, dimensions: usize) -> usize {
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        token.hash(&mut hasher);
        (hasher.finish() as usize) % dimensions
    }

    /// Hash a bigram to a dimension index, feeding token1, "_", and token2
    /// to the hasher in sequence rather than allocating a joined string
    fn hash_bigram(token1: &str, token2: &str, dimensions: usize) -> usize {
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        token1.hash(&mut hasher);
        "_".hash(&mut hasher);
        token2.hash(&mut hasher);
        (hasher.finish() as usize) % dimensions
    }

    /// Get sign for feature hashing: colliding tokens add with pseudo-random
    /// ±1 signs, so they tend to cancel rather than compound, which reduces
    /// collision impact
    fn hash_sign(token: &str) -> f32 {
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        token.hash(&mut hasher);
        "_sign".hash(&mut hasher);
        if hasher.finish().is_multiple_of(2) {
            1.0
        } else {
            -1.0
        }
    }

    /// Get sign for bigram feature hashing
    fn hash_bigram_sign(token1: &str, token2: &str) -> f32 {
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        token1.hash(&mut hasher);
        "_".hash(&mut hasher);
        token2.hash(&mut hasher);
        "_sign".hash(&mut hasher);
        if hasher.finish().is_multiple_of(2) {
            1.0
        } else {
            -1.0
        }
    }
}

impl Embedder for TfIdfEmbedder {
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        let tokens = Self::tokenize(text);
        let mut embedding = vec![0.0_f32; self.dimensions];

        if tokens.is_empty() {
            return Ok(embedding);
        }

        let doc_len = tokens.len() as f32;

        // Add bigram features for better semantic capture. Bigrams are
        // processed first, by reference, because the token Vec is consumed
        // below when counting term frequencies.
        for window in tokens.windows(2) {
            let idx = Self::hash_bigram(&window[0], &window[1], self.dimensions);
            let sign = Self::hash_bigram_sign(&window[0], &window[1]);
            embedding[idx] += 0.5 * sign; // Bigrams weighted less than unigrams
        }

        // Count term frequencies, consuming `tokens` so the strings move into
        // the HashMap without cloning
        let mut tf: HashMap<String, f32> = HashMap::new();
        for token in tokens {
            *tf.entry(token).or_insert(0.0) += 1.0;
        }

        // Apply TF-IDF-like weighting with feature hashing
        for (token, count) in tf {
            // TF: log(1 + count/doc_len)
            let tf_score = (1.0 + count / doc_len).ln();

            // IDF approximation based on token length (longer = rarer)
            let idf_score = 1.0 + (token.len() as f32 * 0.1);

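            // Worked example (illustrative numbers): a token of length 5
            // seen twice in a 10-token document gets
            //   tf  = ln(1 + 2/10) ≈ 0.182
            //   idf = 1.0 + 5 * 0.1 = 1.5
            //   weight ≈ 0.273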
            let weight = tf_score * idf_score;
            let idx = Self::hash_token(&token, self.dimensions);
            let sign = Self::hash_sign(&token);

            embedding[idx] += weight * sign;
        }

        // L2 normalize, so dot products behave as cosine similarity
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            for x in &mut embedding {
                *x /= norm;
            }
        }

        Ok(embedding)
    }

    fn dimensions(&self) -> usize {
        self.dimensions
    }

    fn model_name(&self) -> &str {
        "tfidf"
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::embedding::cosine_similarity;

    #[test]
    fn test_tfidf_basic() {
        let embedder = TfIdfEmbedder::new(384);

        let e1 = embedder.embed("hello world").unwrap();
        let e2 = embedder.embed("hello world").unwrap();

        // Same text should produce identical embeddings
        assert_eq!(e1, e2);
    }

    #[test]
    fn test_tfidf_similarity() {
        let embedder = TfIdfEmbedder::new(384);

        let e1 = embedder
            .embed("the quick brown fox jumps over the lazy dog")
            .unwrap();
        let e2 = embedder
            .embed("a fast brown fox leaps over a sleepy dog")
            .unwrap();
        let e3 = embedder
            .embed("quantum physics and thermodynamics")
            .unwrap();

        // Similar sentences should have higher similarity
        let sim_similar = cosine_similarity(&e1, &e2);
        let sim_different = cosine_similarity(&e1, &e3);

        assert!(
            sim_similar > sim_different,
            "Similar sentences should have higher similarity"
        );
    }

    #[test]
    fn test_tfidf_empty() {
        let embedder = TfIdfEmbedder::new(384);
        let e = embedder.embed("").unwrap();
        assert_eq!(e.len(), 384);
        assert!(e.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_tfidf_normalized() {
        let embedder = TfIdfEmbedder::new(384);
        let e = embedder
            .embed("this is a test sentence with multiple words")
            .unwrap();

        let norm: f32 = e.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(
            (norm - 1.0).abs() < 0.001,
            "Embedding should be L2 normalized"
        );
    }
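
    // Sketch of an extra property check: `tokenize` drops single-character
    // tokens, so a text made only of one-letter words embeds to zero.
    #[test]
    fn test_tfidf_single_char_tokens() {
        let embedder = TfIdfEmbedder::new(384);
        let e = embedder.embed("a b c").unwrap();
        assert!(e.iter().all(|&x| x == 0.0));
    }

    // Sketch: trait accessors should agree with the constructor argument
    // and the produced vector length.
    #[test]
    fn test_tfidf_accessors() {
        let embedder = TfIdfEmbedder::new(128);
        assert_eq!(embedder.dimensions(), 128);
        assert_eq!(embedder.model_name(), "tfidf");
        assert_eq!(embedder.embed("hello world").unwrap().len(), 128);
    }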
}