phago_embeddings/
simple.rs1use crate::{Embedder, EmbeddingError, EmbeddingResult};
8use std::collections::hash_map::DefaultHasher;
9use std::hash::{Hash, Hasher};
10
/// A dependency-free text embedder based on signed feature hashing.
///
/// Text is lowercased and split into alphanumeric tokens. Each token is hashed
/// `num_hashes` times with different seeds. Each hash selects one slot of the
/// output vector and a `+1`/`-1` sign that is added to that slot.
///
/// NOTE: hashing uses [`DefaultHasher`], whose algorithm the standard library
/// explicitly leaves unspecified across Rust releases. Embeddings are
/// deterministic within one build. Vectors persisted by a binary built with one
/// toolchain may not match vectors computed after a toolchain upgrade.
#[derive(Debug, Clone)]
pub struct SimpleEmbedder {
    /// Length of every produced embedding vector; always non-zero.
    dimension: usize,
    /// Number of seeded hashes applied to each token.
    num_hashes: usize,
}

impl SimpleEmbedder {
    /// Number of seeded hashes per token used by [`SimpleEmbedder::new`].
    const DEFAULT_NUM_HASHES: usize = 4;

    /// Creates an embedder that produces vectors of length `dimension`.
    ///
    /// # Panics
    ///
    /// Panics if `dimension` is zero. Every token hash is reduced modulo the
    /// dimension, so a zero dimension would otherwise cause a remainder-by-zero
    /// panic later, deep inside embedding.
    pub fn new(dimension: usize) -> Self {
        assert!(dimension > 0, "SimpleEmbedder dimension must be non-zero");
        Self {
            dimension,
            num_hashes: Self::DEFAULT_NUM_HASHES,
        }
    }

    /// Creates an embedder with the default dimension of 256.
    pub fn default_dimension() -> Self {
        Self::new(256)
    }

    /// Lowercases `text`, splits it on every non-alphanumeric character, and
    /// keeps only tokens longer than one character.
    ///
    /// Length is measured in `char`s rather than bytes. A lone non-ASCII letter
    /// such as `"é"` is therefore discarded exactly like a lone ASCII letter.
    fn tokenize(&self, text: &str) -> Vec<String> {
        text.to_lowercase()
            .split(|c: char| !c.is_alphanumeric())
            .filter(|s| s.chars().count() > 1)
            .map(str::to_string)
            .collect()
    }

    /// Maps `word` to a slot index in `0..self.dimension`, using `seed` to
    /// select one of several independent hash functions.
    fn hash_with_seed(&self, word: &str, seed: u64) -> usize {
        let mut hasher = DefaultHasher::new();
        seed.hash(&mut hasher);
        word.hash(&mut hasher);
        // Reduce in u64 before narrowing. Casting first would truncate the hash
        // on 32-bit targets and make indices platform-dependent. The result is
        // < dimension, which is a usize, so the final cast cannot truncate.
        (hasher.finish() % self.dimension as u64) as usize
    }

    /// Returns the `+1.0`/`-1.0` sign contributed by `word` for hash `seed`.
    ///
    /// The seed is offset by 1000 so the sign hash is computed independently
    /// of the slot hash produced by `hash_with_seed` for the same seed.
    fn sign_hash(&self, word: &str, seed: u64) -> f32 {
        let mut hasher = DefaultHasher::new();
        (seed + 1000).hash(&mut hasher);
        word.hash(&mut hasher);
        if hasher.finish() % 2 == 0 {
            1.0
        } else {
            -1.0
        }
    }
}

impl Default for SimpleEmbedder {
    /// Equivalent to [`SimpleEmbedder::default_dimension`].
    fn default() -> Self {
        Self::default_dimension()
    }
}
76
77impl Embedder for SimpleEmbedder {
78 fn embed(&self, text: &str) -> EmbeddingResult<Vec<f32>> {
79 if text.is_empty() {
80 return Err(EmbeddingError::InvalidInput("Empty text".to_string()));
81 }
82
83 let tokens = self.tokenize(text);
84 if tokens.is_empty() {
85 return Ok(vec![0.0; self.dimension]);
87 }
88
89 let mut vector = vec![0.0f32; self.dimension];
90
91 for token in &tokens {
93 for seed in 0..self.num_hashes as u64 {
94 let idx = self.hash_with_seed(token, seed);
95 let sign = self.sign_hash(token, seed);
96 vector[idx] += sign;
97 }
98 }
99
100 let scale = 1.0 / ((tokens.len() * self.num_hashes) as f32).sqrt();
102 for v in &mut vector {
103 *v *= scale;
104 }
105
106 let norm: f32 = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
108 if norm > 0.0 {
109 for v in &mut vector {
110 *v /= norm;
111 }
112 }
113
114 Ok(vector)
115 }
116
117 fn dimension(&self) -> usize {
118 self.dimension
119 }
120
121 fn model_name(&self) -> &str {
122 "simple-hash"
123 }
124}
125
#[cfg(test)]
mod tests {
    use super::*;

    /// Euclidean (L2) length of `v`.
    fn l2_norm(v: &[f32]) -> f32 {
        v.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    #[test]
    fn test_simple_embedder() {
        let embedder = SimpleEmbedder::new(128);

        let v1 = embedder.embed("hello world").unwrap();
        let v2 = embedder.embed("hello world").unwrap();
        let v3 = embedder.embed("goodbye universe").unwrap();

        assert_eq!(v1.len(), 128);

        // Identical input must embed identically.
        let sim_same = embedder.similarity(&v1, &v2).unwrap();
        assert!((sim_same - 1.0).abs() < 0.001);

        // Disjoint vocabularies should not look near-identical.
        let sim_diff = embedder.similarity(&v1, &v3).unwrap();
        assert!(sim_diff < 0.9);
    }

    #[test]
    fn test_similar_texts() {
        let embedder = SimpleEmbedder::new(256);

        let v1 = embedder.embed("cell membrane transport").unwrap();
        let v2 = embedder.embed("membrane cell transport proteins").unwrap();
        let v3 = embedder.embed("quantum computing algorithms").unwrap();

        let sim_related = embedder.similarity(&v1, &v2).unwrap();
        let sim_unrelated = embedder.similarity(&v1, &v3).unwrap();

        assert!(sim_related > sim_unrelated);
    }

    #[test]
    fn test_empty_text_is_rejected() {
        let embedder = SimpleEmbedder::default();
        assert!(embedder.embed("").is_err());
    }

    #[test]
    fn test_text_without_tokens_yields_zero_vector() {
        // Single-character words and punctuation are all filtered out.
        let embedder = SimpleEmbedder::new(32);
        let v = embedder.embed("a ! ?").unwrap();
        assert_eq!(v, vec![0.0; 32]);
    }

    #[test]
    fn test_output_is_unit_norm() {
        let embedder = SimpleEmbedder::new(64);
        let v = embedder.embed("the quick brown fox").unwrap();
        assert_eq!(v.len(), 64);
        assert!((l2_norm(&v) - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_dimension_and_model_name() {
        let embedder = SimpleEmbedder::new(48);
        assert_eq!(embedder.dimension(), 48);
        assert_eq!(embedder.model_name(), "simple-hash");
    }
}