Skip to main content

mentedb_embedding/
hash_provider.rs

1//! A deterministic, zero-dependency embedding provider for testing and development.
2//!
3//! Generates embeddings by hashing text content. Same text always produces the same
4//! embedding vector. NOT suitable for production similarity search.
5
6use mentedb_core::error::MenteResult;
7
8use crate::provider::{AsyncEmbeddingProvider, EmbeddingProvider};
9
/// Deterministic hash-based embedding provider.
///
/// Useful for testing the full embedding pipeline without requiring an ML model.
/// The same input text always produces the same embedding vector, independent of
/// the pointer width of the target the code was compiled for.
#[derive(Debug, Clone)]
pub struct HashEmbeddingProvider {
    /// Number of components in every embedding produced by this provider.
    dimensions: usize,
    /// Human-readable identifier of the form `hash-embedding-{dimensions}d`.
    model_name: String,
}

impl HashEmbeddingProvider {
    /// 64-bit FNV-1a offset basis used as the initial hash state.
    const FNV_OFFSET_BASIS: u64 = 0xcbf29ce484222325;
    /// 64-bit FNV prime the state is multiplied by after each byte.
    const FNV_PRIME: u64 = 0x100000001b3;

    /// Create a new hash embedding provider with the given dimensions.
    ///
    /// The model name is derived from `dimensions`, e.g. `hash-embedding-128d`.
    pub fn new(dimensions: usize) -> Self {
        Self {
            dimensions,
            model_name: format!("hash-embedding-{dimensions}d"),
        }
    }

    /// Create a new hash embedding provider with the default 384 dimensions.
    pub fn default_384() -> Self {
        Self::new(384)
    }

    /// Hash text combined with a dimension index to produce a deterministic f32 value.
    ///
    /// The hash is FNV-1a-style: the little-endian bytes of the dimension index are
    /// mixed in first, followed by the UTF-8 bytes of `text`. The resulting `u64`
    /// is mapped linearly onto `[-1.0, 1.0]`.
    fn hash_dimension(text: &str, dim: usize) -> f32 {
        // Widen the index to a fixed 8-byte representation before hashing.
        // `usize::to_le_bytes` yields 4 bytes on 32-bit targets and 8 bytes on
        // 64-bit ones, which would make the embedding depend on the platform.
        // `usize` is never wider than 64 bits, so the cast is lossless.
        let index_bytes = (dim as u64).to_le_bytes();

        let hash = index_bytes
            .iter()
            .chain(text.as_bytes())
            .fold(Self::FNV_OFFSET_BASIS, |state, &byte| {
                (state ^ u64::from(byte)).wrapping_mul(Self::FNV_PRIME)
            });

        // Scale to [0.0, 1.0] in f64 first, then shift to [-1.0, 1.0] and narrow.
        (((hash as f64) / (u64::MAX as f64)) * 2.0 - 1.0) as f32
    }

    /// Generate a normalized embedding for the given text.
    ///
    /// Returns a vector with exactly `self.dimensions` components. The vector is
    /// scaled to unit L2 length unless its norm is zero (which includes the
    /// zero-dimension case, where the result is empty); in that case the raw
    /// components are returned unchanged to avoid dividing by zero.
    fn compute_embedding(&self, text: &str) -> Vec<f32> {
        let mut embedding: Vec<f32> = (0..self.dimensions)
            .map(|dim| Self::hash_dimension(text, dim))
            .collect();

        // L2 normalize
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            for val in &mut embedding {
                *val /= norm;
            }
        }

        embedding
    }
}
72
73impl EmbeddingProvider for HashEmbeddingProvider {
74    fn embed(&self, text: &str) -> MenteResult<Vec<f32>> {
75        Ok(self.compute_embedding(text))
76    }
77
78    fn embed_batch(&self, texts: &[&str]) -> MenteResult<Vec<Vec<f32>>> {
79        Ok(texts.iter().map(|t| self.compute_embedding(t)).collect())
80    }
81
82    fn dimensions(&self) -> usize {
83        self.dimensions
84    }
85
86    fn model_name(&self) -> &str {
87        &self.model_name
88    }
89}
90
91impl AsyncEmbeddingProvider for HashEmbeddingProvider {
92    async fn embed(&self, text: &str) -> MenteResult<Vec<f32>> {
93        Ok(self.compute_embedding(text))
94    }
95
96    async fn embed_batch(&self, texts: &[&str]) -> MenteResult<Vec<Vec<f32>>> {
97        Ok(texts.iter().map(|t| self.compute_embedding(t)).collect())
98    }
99
100    fn dimensions(&self) -> usize {
101        self.dimensions
102    }
103
104    fn model_name(&self) -> &str {
105        &self.model_name
106    }
107}
108
#[cfg(test)]
mod tests {
    use super::*;

    /// Calls the synchronous trait method explicitly; both provider traits
    /// declare `embed`, so the call has to name the trait.
    fn embed_sync(provider: &HashEmbeddingProvider, text: &str) -> Vec<f32> {
        <HashEmbeddingProvider as EmbeddingProvider>::embed(provider, text).unwrap()
    }

    /// Embedding the same text twice yields identical vectors.
    #[test]
    fn test_deterministic() {
        let provider = HashEmbeddingProvider::default_384();
        let first = embed_sync(&provider, "hello world");
        let second = embed_sync(&provider, "hello world");
        assert_eq!(first, second);
    }

    /// The output length matches the configured dimension count.
    #[test]
    fn test_correct_dimensions() {
        let provider = HashEmbeddingProvider::new(128);
        let vector = embed_sync(&provider, "test");
        assert_eq!(vector.len(), 128);
    }

    /// The output has unit L2 norm within a small tolerance.
    #[test]
    fn test_normalized() {
        let provider = HashEmbeddingProvider::default_384();
        let vector = embed_sync(&provider, "test normalization");
        let squared_sum = vector.iter().map(|x| x * x).sum::<f32>();
        let deviation = (squared_sum.sqrt() - 1.0).abs();
        assert!(deviation < 1e-5);
    }
}
135}