mentedb_embedding/
hash_provider.rs1use mentedb_core::error::MenteResult;
7
8use crate::provider::{AsyncEmbeddingProvider, EmbeddingProvider};
9
/// Deterministic embedding provider built purely from FNV-1a hashing.
///
/// Component `i` of an embedding is the 64-bit FNV-1a hash of the index `i`
/// (as 8 little-endian bytes) followed by the UTF-8 bytes of the text. That hash
/// is mapped linearly into `[-1.0, 1.0]`, and the whole vector is then
/// L2-normalised. Identical inputs always give identical vectors. The vectors
/// carry no semantic similarity, though: texts differing by one byte generally
/// produce unrelated vectors.
#[derive(Debug, Clone)]
pub struct HashEmbeddingProvider {
    /// Number of components in every embedding this provider returns.
    dimensions: usize,
    /// Identifier reported through the provider traits, formatted as
    /// `hash-embedding-{dimensions}d`.
    model_name: String,
}

impl HashEmbeddingProvider {
    /// FNV-1a 64-bit offset basis.
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    /// FNV-1a 64-bit prime.
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

    /// Creates a provider that emits `dimensions`-component embeddings.
    ///
    /// `dimensions == 0` is accepted and yields empty embeddings.
    #[must_use]
    pub fn new(dimensions: usize) -> Self {
        Self {
            dimensions,
            model_name: format!("hash-embedding-{dimensions}d"),
        }
    }

    /// Creates a provider producing 384-component embeddings.
    #[must_use]
    pub fn default_384() -> Self {
        Self::new(384)
    }

    /// Folds `bytes` into the running FNV-1a state `hash` and returns the new state.
    fn fnv1a(mut hash: u64, bytes: &[u8]) -> u64 {
        for &byte in bytes {
            hash ^= u64::from(byte);
            hash = hash.wrapping_mul(Self::FNV_PRIME);
        }
        hash
    }

    /// Returns the value in `[-1.0, 1.0]` for component `dim` of `text`'s embedding
    /// (before normalisation).
    fn hash_dimension(text: &str, dim: usize) -> f32 {
        // Hash the index as a fixed-width u64. `usize::to_le_bytes` is only 4 bytes
        // on 32-bit targets, which would make embeddings platform-dependent.
        // The cast is lossless: usize is at most 64 bits on every supported target.
        let index_bytes = (dim as u64).to_le_bytes();
        let seeded = Self::fnv1a(Self::FNV_OFFSET_BASIS, &index_bytes);
        let hash = Self::fnv1a(seeded, text.as_bytes());
        // Map [0, u64::MAX] linearly onto [-1.0, 1.0].
        ((hash as f64 / u64::MAX as f64) * 2.0 - 1.0) as f32
    }

    /// Builds the unit-length (L2) embedding for `text`.
    ///
    /// If every component is zero (e.g. `dimensions == 0`), the vector is returned
    /// unnormalised rather than dividing by zero.
    fn compute_embedding(&self, text: &str) -> Vec<f32> {
        let mut embedding: Vec<f32> = (0..self.dimensions)
            .map(|dim| Self::hash_dimension(text, dim))
            .collect();

        let norm = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            embedding.iter_mut().for_each(|val| *val /= norm);
        }
        embedding
    }
}
72
73impl EmbeddingProvider for HashEmbeddingProvider {
74 fn embed(&self, text: &str) -> MenteResult<Vec<f32>> {
75 Ok(self.compute_embedding(text))
76 }
77
78 fn embed_batch(&self, texts: &[&str]) -> MenteResult<Vec<Vec<f32>>> {
79 Ok(texts.iter().map(|t| self.compute_embedding(t)).collect())
80 }
81
82 fn dimensions(&self) -> usize {
83 self.dimensions
84 }
85
86 fn model_name(&self) -> &str {
87 &self.model_name
88 }
89}
90
91impl AsyncEmbeddingProvider for HashEmbeddingProvider {
92 async fn embed(&self, text: &str) -> MenteResult<Vec<f32>> {
93 Ok(self.compute_embedding(text))
94 }
95
96 async fn embed_batch(&self, texts: &[&str]) -> MenteResult<Vec<Vec<f32>>> {
97 Ok(texts.iter().map(|t| self.compute_embedding(t)).collect())
98 }
99
100 fn dimensions(&self) -> usize {
101 self.dimensions
102 }
103
104 fn model_name(&self) -> &str {
105 &self.model_name
106 }
107}
108
#[cfg(test)]
mod tests {
    use super::*;

    /// Euclidean (L2) norm of `v`.
    fn l2_norm(v: &[f32]) -> f32 {
        v.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    #[test]
    fn test_deterministic() {
        let provider = HashEmbeddingProvider::default_384();
        let e1 = EmbeddingProvider::embed(&provider, "hello world").unwrap();
        let e2 = EmbeddingProvider::embed(&provider, "hello world").unwrap();
        assert_eq!(e1, e2);
    }

    #[test]
    fn test_correct_dimensions() {
        let provider = HashEmbeddingProvider::new(128);
        let emb = EmbeddingProvider::embed(&provider, "test").unwrap();
        assert_eq!(emb.len(), 128);
    }

    #[test]
    fn test_normalized() {
        let provider = HashEmbeddingProvider::default_384();
        let emb = EmbeddingProvider::embed(&provider, "test normalization").unwrap();
        assert!((l2_norm(&emb) - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_empty_text_is_normalized() {
        // Empty text still hashes the component index, so the vector is non-zero.
        let provider = HashEmbeddingProvider::default_384();
        let emb = EmbeddingProvider::embed(&provider, "").unwrap();
        assert_eq!(emb.len(), 384);
        assert!((l2_norm(&emb) - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_zero_dimensions() {
        let provider = HashEmbeddingProvider::new(0);
        let emb = EmbeddingProvider::embed(&provider, "anything").unwrap();
        assert!(emb.is_empty());
        assert_eq!(EmbeddingProvider::dimensions(&provider), 0);
    }

    #[test]
    fn test_distinct_texts_differ() {
        let provider = HashEmbeddingProvider::new(64);
        let a = EmbeddingProvider::embed(&provider, "alpha").unwrap();
        let b = EmbeddingProvider::embed(&provider, "alphb").unwrap();
        assert_ne!(a, b);
    }

    #[test]
    fn test_batch_matches_single() {
        let provider = HashEmbeddingProvider::new(32);
        let texts = ["alpha", "beta", ""];
        let batch = EmbeddingProvider::embed_batch(&provider, &texts).unwrap();
        let singles: Vec<Vec<f32>> = texts
            .iter()
            .map(|&t| EmbeddingProvider::embed(&provider, t).unwrap())
            .collect();
        assert_eq!(batch, singles);

        let empty: [&str; 0] = [];
        assert!(EmbeddingProvider::embed_batch(&provider, &empty)
            .unwrap()
            .is_empty());
    }

    #[test]
    fn test_model_name_reflects_dimensions() {
        let provider = HashEmbeddingProvider::new(64);
        assert_eq!(
            EmbeddingProvider::model_name(&provider),
            "hash-embedding-64d"
        );
    }
}