noether_engine/index/
embedding.rs1#![warn(clippy::unwrap_used)]
2#![cfg_attr(test, allow(clippy::unwrap_used))]
3
4use sha2::{Digest, Sha256};
5
/// A dense embedding vector: an owned list of `f32` components.
pub type Embedding = Vec<f32>;
7
/// Error type returned by the [`EmbeddingProvider`] methods `embed` and
/// `embed_batch`.
#[derive(Debug, thiserror::Error)]
pub enum EmbeddingError {
    /// A failure described by a free-form message. The message is appended to
    /// the `embedding provider error: ` prefix when the error is displayed.
    #[error("embedding provider error: {0}")]
    Provider(String),
}
13
14pub trait EmbeddingProvider: Send + Sync {
16 fn dimensions(&self) -> usize;
17 fn embed(&self, text: &str) -> Result<Embedding, EmbeddingError>;
18
19 fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Embedding>, EmbeddingError> {
20 texts.iter().map(|t| self.embed(t)).collect()
21 }
22}
23
/// [`EmbeddingProvider`] whose embeddings are derived from SHA-256 digests of
/// the input text (see its `EmbeddingProvider` impl in this file), so the same
/// text always produces the same vector.
///
/// NOTE(review): the name suggests this is a test/development stand-in for a
/// real model-backed provider — confirm against callers.
pub struct MockEmbeddingProvider {
    /// Number of components in every embedding this provider returns.
    dimensions: usize,
}
32
33impl MockEmbeddingProvider {
34 pub fn new(dimensions: usize) -> Self {
35 Self { dimensions }
36 }
37}
38
39impl EmbeddingProvider for MockEmbeddingProvider {
40 fn dimensions(&self) -> usize {
41 self.dimensions
42 }
43
44 fn embed(&self, text: &str) -> Result<Embedding, EmbeddingError> {
45 let mut bytes = Vec::with_capacity(self.dimensions);
47 let mut current = Sha256::digest(text.as_bytes()).to_vec();
48 while bytes.len() < self.dimensions {
49 for &b in ¤t {
50 if bytes.len() >= self.dimensions {
51 break;
52 }
53 bytes.push(b);
54 }
55 current = Sha256::digest(¤t).to_vec();
57 }
58
59 let mut vec: Vec<f32> = bytes[..self.dimensions]
60 .iter()
61 .map(|&b| (b as f32 / 127.5) - 1.0)
62 .collect();
63
64 let norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
66 if norm > 0.0 {
67 for v in &mut vec {
68 *v /= norm;
69 }
70 }
71
72 Ok(vec)
73 }
74}
75
#[cfg(test)]
mod tests {
    use super::*;

    /// The returned vector has exactly the requested number of components.
    #[test]
    fn mock_produces_correct_dimensions() {
        let embedding = MockEmbeddingProvider::new(64).embed("hello").unwrap();
        assert_eq!(embedding.len(), 64);
    }

    /// Embedding the same text twice yields identical vectors.
    #[test]
    fn mock_is_deterministic() {
        let mock = MockEmbeddingProvider::new(32);
        let first = mock.embed("hello world").unwrap();
        let second = mock.embed("hello world").unwrap();
        assert_eq!(first, second);
    }

    /// Distinct texts map to distinct vectors.
    #[test]
    fn mock_different_text_different_embedding() {
        let mock = MockEmbeddingProvider::new(32);
        let hello = mock.embed("hello").unwrap();
        let world = mock.embed("world").unwrap();
        assert_ne!(hello, world);
    }

    /// The L2 norm of an embedding is 1 within floating-point tolerance.
    #[test]
    fn mock_embeddings_are_normalized() {
        let embedding = MockEmbeddingProvider::new(128).embed("test text").unwrap();
        let squared_sum: f32 = embedding
            .iter()
            .map(|component| component * component)
            .sum();
        let norm = squared_sum.sqrt();
        assert!((norm - 1.0).abs() < 1e-5, "norm should be ~1.0, got {norm}");
    }

    /// `embed_batch` agrees with calling `embed` on each text in order.
    #[test]
    fn mock_batch_matches_individual() {
        let mock = MockEmbeddingProvider::new(32);
        let texts = ["a", "b", "c"];
        let batch = mock.embed_batch(&texts).unwrap();

        let mut individual: Vec<Embedding> = Vec::with_capacity(texts.len());
        for text in texts {
            individual.push(mock.embed(text).unwrap());
        }

        assert_eq!(batch, individual);
    }
}