Skip to main content

noether_engine/index/
embedding.rs

1#![warn(clippy::unwrap_used)]
2#![cfg_attr(test, allow(clippy::unwrap_used))]
3
4use sha2::{Digest, Sha256};
5
/// A dense embedding vector, stored as one `f32` per dimension.
pub type Embedding = Vec<f32>;
7
8#[derive(Debug, thiserror::Error)]
9pub enum EmbeddingError {
10    #[error("embedding provider error: {0}")]
11    Provider(String),
12}
13
14/// Trait for generating vector embeddings from text.
15pub trait EmbeddingProvider: Send + Sync {
16    fn dimensions(&self) -> usize;
17    fn embed(&self, text: &str) -> Result<Embedding, EmbeddingError>;
18
19    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Embedding>, EmbeddingError> {
20        texts.iter().map(|t| self.embed(t)).collect()
21    }
22}
23
/// Deterministic mock embedding provider built on SHA-256 hashing.
///
/// Identical text always yields an identical embedding, and different text
/// yields effectively uncorrelated vectors. There is no semantic similarity:
/// the output depends only on the input bytes, so use a real provider when
/// semantic quality matters. Non-empty outputs are L2-normalized.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MockEmbeddingProvider {
    /// Number of components in every embedding this provider returns.
    dimensions: usize,
}

impl MockEmbeddingProvider {
    /// Creates a provider whose embeddings have `dimensions` components.
    pub fn new(dimensions: usize) -> Self {
        Self { dimensions }
    }
}
38
39impl EmbeddingProvider for MockEmbeddingProvider {
40    fn dimensions(&self) -> usize {
41        self.dimensions
42    }
43
44    fn embed(&self, text: &str) -> Result<Embedding, EmbeddingError> {
45        // Generate deterministic bytes by iteratively hashing
46        let mut bytes = Vec::with_capacity(self.dimensions);
47        let mut current = Sha256::digest(text.as_bytes()).to_vec();
48        while bytes.len() < self.dimensions {
49            for &b in &current {
50                if bytes.len() >= self.dimensions {
51                    break;
52                }
53                bytes.push(b);
54            }
55            // Hash the current hash to get more bytes
56            current = Sha256::digest(&current).to_vec();
57        }
58
59        let mut vec: Vec<f32> = bytes[..self.dimensions]
60            .iter()
61            .map(|&b| (b as f32 / 127.5) - 1.0)
62            .collect();
63
64        // L2-normalize
65        let norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
66        if norm > 0.0 {
67            for v in &mut vec {
68                *v /= norm;
69            }
70        }
71
72        Ok(vec)
73    }
74}
75
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `emb` has unit L2 norm (within float tolerance).
    fn assert_unit_norm(emb: &[f32]) {
        let norm: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5, "norm should be ~1.0, got {norm}");
    }

    #[test]
    fn mock_produces_correct_dimensions() {
        let provider = MockEmbeddingProvider::new(64);
        let emb = provider.embed("hello").unwrap();
        assert_eq!(emb.len(), 64);
    }

    #[test]
    fn mock_reports_configured_dimensions() {
        assert_eq!(MockEmbeddingProvider::new(7).dimensions(), 7);
    }

    #[test]
    fn mock_is_deterministic() {
        let provider = MockEmbeddingProvider::new(32);
        let e1 = provider.embed("hello world").unwrap();
        let e2 = provider.embed("hello world").unwrap();
        assert_eq!(e1, e2);
    }

    #[test]
    fn mock_different_text_different_embedding() {
        let provider = MockEmbeddingProvider::new(32);
        let e1 = provider.embed("hello").unwrap();
        let e2 = provider.embed("world").unwrap();
        assert_ne!(e1, e2);
    }

    #[test]
    fn mock_embeddings_are_normalized() {
        let provider = MockEmbeddingProvider::new(128);
        let emb = provider.embed("test text").unwrap();
        assert_unit_norm(&emb);
    }

    #[test]
    fn mock_handles_partial_digest_blocks() {
        // A SHA-256 digest is 32 bytes; sizes that are not multiples of 32
        // exercise the truncated final block, which the sizes above never hit.
        for dims in [1, 31, 33, 100] {
            let provider = MockEmbeddingProvider::new(dims);
            let emb = provider.embed("partial").unwrap();
            assert_eq!(emb.len(), dims);
            assert_unit_norm(&emb);
        }
    }

    #[test]
    fn mock_zero_dimensions_yields_empty_vector() {
        let provider = MockEmbeddingProvider::new(0);
        assert!(provider.embed("anything").unwrap().is_empty());
    }

    #[test]
    fn mock_batch_matches_individual() {
        let provider = MockEmbeddingProvider::new(32);
        let batch = provider.embed_batch(&["a", "b", "c"]).unwrap();
        let individual: Vec<Embedding> = ["a", "b", "c"]
            .iter()
            .map(|t| provider.embed(t).unwrap())
            .collect();
        assert_eq!(batch, individual);
    }
}