Skip to main content

noether_engine/index/
embedding.rs

1use sha2::{Digest, Sha256};
2
/// A vector embedding: a list of `f32` components.
pub type Embedding = Vec<f32>;
4
5#[derive(Debug, thiserror::Error)]
6pub enum EmbeddingError {
7    #[error("embedding provider error: {0}")]
8    Provider(String),
9}
10
11/// Trait for generating vector embeddings from text.
12pub trait EmbeddingProvider: Send + Sync {
13    fn dimensions(&self) -> usize;
14    fn embed(&self, text: &str) -> Result<Embedding, EmbeddingError>;
15
16    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Embedding>, EmbeddingError> {
17        texts.iter().map(|t| self.embed(t)).collect()
18    }
19}
20
/// Deterministic mock embedding provider using SHA-256 hashing.
///
/// Produces normalized vectors where identical text always yields identical
/// embeddings. Different text yields uncorrelated embeddings. No semantic
/// similarity — purely structural; use a real provider for semantic quality.
#[derive(Debug, Clone)]
pub struct MockEmbeddingProvider {
    /// Length of every embedding this provider produces.
    dimensions: usize,
}
29
30impl MockEmbeddingProvider {
31    pub fn new(dimensions: usize) -> Self {
32        Self { dimensions }
33    }
34}
35
36impl EmbeddingProvider for MockEmbeddingProvider {
37    fn dimensions(&self) -> usize {
38        self.dimensions
39    }
40
41    fn embed(&self, text: &str) -> Result<Embedding, EmbeddingError> {
42        // Generate deterministic bytes by iteratively hashing
43        let mut bytes = Vec::with_capacity(self.dimensions);
44        let mut current = Sha256::digest(text.as_bytes()).to_vec();
45        while bytes.len() < self.dimensions {
46            for &b in &current {
47                if bytes.len() >= self.dimensions {
48                    break;
49                }
50                bytes.push(b);
51            }
52            // Hash the current hash to get more bytes
53            current = Sha256::digest(&current).to_vec();
54        }
55
56        let mut vec: Vec<f32> = bytes[..self.dimensions]
57            .iter()
58            .map(|&b| (b as f32 / 127.5) - 1.0)
59            .collect();
60
61        // L2-normalize
62        let norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
63        if norm > 0.0 {
64            for v in &mut vec {
65                *v /= norm;
66            }
67        }
68
69        Ok(vec)
70    }
71}
72
#[cfg(test)]
mod tests {
    use super::*;

    /// Euclidean (L2) length of an embedding.
    fn l2_norm(emb: &[f32]) -> f32 {
        emb.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    #[test]
    fn mock_produces_correct_dimensions() {
        let provider = MockEmbeddingProvider::new(64);
        let emb = provider.embed("hello").unwrap();
        assert_eq!(emb.len(), 64);
    }

    #[test]
    fn mock_is_deterministic() {
        let provider = MockEmbeddingProvider::new(32);
        let e1 = provider.embed("hello world").unwrap();
        let e2 = provider.embed("hello world").unwrap();
        assert_eq!(e1, e2);
    }

    #[test]
    fn mock_different_text_different_embedding() {
        let provider = MockEmbeddingProvider::new(32);
        let e1 = provider.embed("hello").unwrap();
        let e2 = provider.embed("world").unwrap();
        assert_ne!(e1, e2);
    }

    #[test]
    fn mock_embeddings_are_normalized() {
        let provider = MockEmbeddingProvider::new(128);
        let emb = provider.embed("test text").unwrap();
        let norm = l2_norm(&emb);
        assert!((norm - 1.0).abs() < 1e-5, "norm should be ~1.0, got {norm}");
    }

    #[test]
    fn mock_batch_matches_individual() {
        let provider = MockEmbeddingProvider::new(32);
        let batch = provider.embed_batch(&["a", "b", "c"]).unwrap();
        let individual: Vec<Embedding> = ["a", "b", "c"]
            .iter()
            .map(|t| provider.embed(t).unwrap())
            .collect();
        assert_eq!(batch, individual);
    }

    #[test]
    fn mock_zero_dimensions_yields_empty_embedding() {
        let provider = MockEmbeddingProvider::new(0);
        assert_eq!(provider.dimensions(), 0);
        assert!(provider.embed("hello").unwrap().is_empty());
    }

    #[test]
    fn mock_handles_dimensions_not_multiple_of_digest_size() {
        // 50 = one full 32-byte SHA-256 digest plus 18 bytes of the next one.
        let provider = MockEmbeddingProvider::new(50);
        let emb = provider.embed("partial block").unwrap();
        assert_eq!(emb.len(), 50);
        let norm = l2_norm(&emb);
        assert!((norm - 1.0).abs() < 1e-5, "norm should be ~1.0, got {norm}");
    }

    #[test]
    fn mock_empty_batch_is_empty() {
        let provider = MockEmbeddingProvider::new(32);
        assert!(provider.embed_batch(&[]).unwrap().is_empty());
    }
}