// noether_engine/index/embedding.rs
use sha2::{Digest, Sha256};
2
/// A text embedding: a flat vector of `f32` components.
pub type Embedding = Vec<f32>;
4
5#[derive(Debug, thiserror::Error)]
6pub enum EmbeddingError {
7 #[error("embedding provider error: {0}")]
8 Provider(String),
9}
10
11pub trait EmbeddingProvider: Send + Sync {
13 fn dimensions(&self) -> usize;
14 fn embed(&self, text: &str) -> Result<Embedding, EmbeddingError>;
15
16 fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Embedding>, EmbeddingError> {
17 texts.iter().map(|t| self.embed(t)).collect()
18 }
19}
20
/// An embedding provider that stores only the length of the vectors it
/// should produce.
#[derive(Debug, Clone)]
pub struct MockEmbeddingProvider {
    /// Number of components in every embedding this provider returns.
    dimensions: usize,
}
29
30impl MockEmbeddingProvider {
31 pub fn new(dimensions: usize) -> Self {
32 Self { dimensions }
33 }
34}
35
36impl EmbeddingProvider for MockEmbeddingProvider {
37 fn dimensions(&self) -> usize {
38 self.dimensions
39 }
40
41 fn embed(&self, text: &str) -> Result<Embedding, EmbeddingError> {
42 let mut bytes = Vec::with_capacity(self.dimensions);
44 let mut current = Sha256::digest(text.as_bytes()).to_vec();
45 while bytes.len() < self.dimensions {
46 for &b in ¤t {
47 if bytes.len() >= self.dimensions {
48 break;
49 }
50 bytes.push(b);
51 }
52 current = Sha256::digest(¤t).to_vec();
54 }
55
56 let mut vec: Vec<f32> = bytes[..self.dimensions]
57 .iter()
58 .map(|&b| (b as f32 / 127.5) - 1.0)
59 .collect();
60
61 let norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
63 if norm > 0.0 {
64 for v in &mut vec {
65 *v /= norm;
66 }
67 }
68
69 Ok(vec)
70 }
71}
72
#[cfg(test)]
mod tests {
    use super::*;

    /// The returned vector has exactly the configured number of components.
    #[test]
    fn mock_produces_correct_dimensions() {
        let provider = MockEmbeddingProvider::new(64);
        let emb = provider.embed("hello").unwrap();
        assert_eq!(emb.len(), 64);
    }

    /// Embedding the same text twice yields identical vectors.
    #[test]
    fn mock_is_deterministic() {
        let provider = MockEmbeddingProvider::new(32);
        let e1 = provider.embed("hello world").unwrap();
        let e2 = provider.embed("hello world").unwrap();
        assert_eq!(e1, e2);
    }

    /// Distinct inputs map to distinct vectors.
    #[test]
    fn mock_different_text_different_embedding() {
        let provider = MockEmbeddingProvider::new(32);
        let e1 = provider.embed("hello").unwrap();
        let e2 = provider.embed("world").unwrap();
        assert_ne!(e1, e2);
    }

    /// The L2 norm of a non-empty embedding is 1 within float tolerance.
    #[test]
    fn mock_embeddings_are_normalized() {
        let provider = MockEmbeddingProvider::new(128);
        let emb = provider.embed("test text").unwrap();
        let norm: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5, "norm should be ~1.0, got {norm}");
    }

    /// The default batch implementation agrees with per-item calls.
    #[test]
    fn mock_batch_matches_individual() {
        let provider = MockEmbeddingProvider::new(32);
        let batch = provider.embed_batch(&["a", "b", "c"]).unwrap();
        let individual: Vec<Embedding> = ["a", "b", "c"]
            .iter()
            .map(|t| provider.embed(t).unwrap())
            .collect();
        assert_eq!(batch, individual);
    }

    /// A zero-dimension provider returns an empty vector rather than failing.
    #[test]
    fn mock_zero_dimensions_is_empty() {
        let provider = MockEmbeddingProvider::new(0);
        let emb = provider.embed("hello").unwrap();
        assert!(emb.is_empty());
    }
}