Skip to main content

phago_embeddings/
simple.rs

1//! Simple hash-based embedder (no external dependencies).
2//!
3//! This embedder uses a hash-based approach to create fixed-dimension vectors.
4//! While not as semantically rich as neural embeddings, it provides a fast
5//! baseline that works without any ML models.
6
7use crate::{Embedder, EmbeddingError, EmbeddingResult};
8use std::collections::hash_map::DefaultHasher;
9use std::hash::{Hash, Hasher};
10
/// Simple hash-based embedder.
///
/// Maps text to a fixed-dimension vector using signed feature hashing (the
/// "hashing trick"). Each token is hashed with several seeds. Every hash picks a
/// bucket and a `+1`/`-1` sign that is added into that bucket, and the resulting
/// vector is L2-normalized. No model files or external services are required.
///
/// # Example
///
/// ```rust
/// use phago_embeddings::SimpleEmbedder;
/// use phago_embeddings::Embedder;
///
/// let embedder = SimpleEmbedder::new(128);
/// let vec = embedder.embed("hello world").unwrap();
/// assert_eq!(vec.len(), 128);
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SimpleEmbedder {
    /// Length of every vector produced by this embedder (number of hash buckets).
    dimension: usize,
    /// Number of seeded hash functions applied to each token.
    num_hashes: usize,
}
30
31impl SimpleEmbedder {
32    /// Create a new simple embedder with specified dimension.
33    pub fn new(dimension: usize) -> Self {
34        Self {
35            dimension,
36            num_hashes: 4, // Multiple hashes for better distribution
37        }
38    }
39
40    /// Create with default dimension (256).
41    pub fn default_dimension() -> Self {
42        Self::new(256)
43    }
44
45    /// Tokenize text into words.
46    fn tokenize(&self, text: &str) -> Vec<String> {
47        text.to_lowercase()
48            .split(|c: char| !c.is_alphanumeric())
49            .filter(|s| s.len() > 1)
50            .map(|s| s.to_string())
51            .collect()
52    }
53
54    /// Hash a word with a seed to get an index.
55    fn hash_with_seed(&self, word: &str, seed: u64) -> usize {
56        let mut hasher = DefaultHasher::new();
57        seed.hash(&mut hasher);
58        word.hash(&mut hasher);
59        (hasher.finish() as usize) % self.dimension
60    }
61
62    /// Hash a word with a seed to get a sign (+1 or -1).
63    fn sign_hash(&self, word: &str, seed: u64) -> f32 {
64        let mut hasher = DefaultHasher::new();
65        (seed + 1000).hash(&mut hasher);
66        word.hash(&mut hasher);
67        if hasher.finish() % 2 == 0 { 1.0 } else { -1.0 }
68    }
69}
70
71impl Default for SimpleEmbedder {
72    fn default() -> Self {
73        Self::default_dimension()
74    }
75}
76
77impl Embedder for SimpleEmbedder {
78    fn embed(&self, text: &str) -> EmbeddingResult<Vec<f32>> {
79        if text.is_empty() {
80            return Err(EmbeddingError::InvalidInput("Empty text".to_string()));
81        }
82
83        let tokens = self.tokenize(text);
84        if tokens.is_empty() {
85            // Return zero vector for text with no valid tokens
86            return Ok(vec![0.0; self.dimension]);
87        }
88
89        let mut vector = vec![0.0f32; self.dimension];
90
91        // Use multiple hash functions for each token
92        for token in &tokens {
93            for seed in 0..self.num_hashes as u64 {
94                let idx = self.hash_with_seed(token, seed);
95                let sign = self.sign_hash(token, seed);
96                vector[idx] += sign;
97            }
98        }
99
100        // Normalize by token count and number of hashes
101        let scale = 1.0 / ((tokens.len() * self.num_hashes) as f32).sqrt();
102        for v in &mut vector {
103            *v *= scale;
104        }
105
106        // L2 normalize
107        let norm: f32 = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
108        if norm > 0.0 {
109            for v in &mut vector {
110                *v /= norm;
111            }
112        }
113
114        Ok(vector)
115    }
116
117    fn dimension(&self) -> usize {
118        self.dimension
119    }
120
121    fn model_name(&self) -> &str {
122        "simple-hash"
123    }
124}
125
#[cfg(test)]
mod tests {
    use super::*;

    /// Euclidean length of `v`.
    fn l2_norm(v: &[f32]) -> f32 {
        v.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    #[test]
    fn test_simple_embedder() {
        let embedder = SimpleEmbedder::new(128);

        let v1 = embedder.embed("hello world").unwrap();
        let v2 = embedder.embed("hello world").unwrap();
        let v3 = embedder.embed("goodbye universe").unwrap();

        assert_eq!(v1.len(), 128);

        // Same text should produce same embedding
        let sim_same = embedder.similarity(&v1, &v2).unwrap();
        assert!((sim_same - 1.0).abs() < 0.001);

        // Different text should produce different embedding
        let sim_diff = embedder.similarity(&v1, &v3).unwrap();
        assert!(sim_diff < 0.9);
    }

    #[test]
    fn test_similar_texts() {
        let embedder = SimpleEmbedder::new(256);

        let v1 = embedder.embed("cell membrane transport").unwrap();
        let v2 = embedder.embed("membrane cell transport proteins").unwrap();
        let v3 = embedder.embed("quantum computing algorithms").unwrap();

        let sim_related = embedder.similarity(&v1, &v2).unwrap();
        let sim_unrelated = embedder.similarity(&v1, &v3).unwrap();

        // Related texts should have higher similarity
        assert!(sim_related > sim_unrelated);
    }

    #[test]
    fn test_empty_text_is_rejected() {
        let embedder = SimpleEmbedder::new(64);
        assert!(matches!(
            embedder.embed(""),
            Err(EmbeddingError::InvalidInput(_))
        ));
    }

    #[test]
    fn test_text_without_tokens_gives_zero_vector() {
        let embedder = SimpleEmbedder::new(64);
        // Punctuation and single-character words produce no tokens.
        let v = embedder.embed("!? a , b").unwrap();
        assert_eq!(v, vec![0.0; 64]);
    }

    #[test]
    fn test_output_is_unit_length() {
        let embedder = SimpleEmbedder::new(64);
        let v = embedder.embed("the quick brown fox jumps").unwrap();
        assert!((l2_norm(&v) - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_case_and_punctuation_insensitive() {
        let embedder = SimpleEmbedder::new(64);
        assert_eq!(
            embedder.embed("Hello, World!").unwrap(),
            embedder.embed("hello world").unwrap()
        );
    }

    #[test]
    fn test_metadata_and_default() {
        let embedder = SimpleEmbedder::new(32);
        assert_eq!(embedder.dimension(), 32);
        assert_eq!(embedder.model_name(), "simple-hash");
        assert_eq!(SimpleEmbedder::default().dimension(), 256);
    }
}