// fabryk_vector/embedding.rs
use async_trait::async_trait;
use fabryk_core::Result;
13
14#[async_trait]
25pub trait EmbeddingProvider: Send + Sync {
26 async fn embed(&self, text: &str) -> Result<Vec<f32>>;
28
29 async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
34 let mut results = Vec::with_capacity(texts.len());
35 for text in texts {
36 results.push(self.embed(text).await?);
37 }
38 Ok(results)
39 }
40
41 fn dimension(&self) -> usize;
43
44 fn name(&self) -> &str;
46}
47
/// Embedding provider that fabricates reproducible vectors from the bytes of
/// the input text, without calling any external model.
///
/// The same text and dimension always yield the same vector.
#[derive(Debug, Clone)]
pub struct MockEmbeddingProvider {
    /// Number of components in every generated vector.
    dimension: usize,
}
56
57impl MockEmbeddingProvider {
58 pub fn new(dimension: usize) -> Self {
60 Self { dimension }
61 }
62
63 fn deterministic_embedding(&self, text: &str) -> Vec<f32> {
65 let mut embedding = vec![0.0f32; self.dimension];
66 let bytes = text.as_bytes();
67
68 for (i, val) in embedding.iter_mut().enumerate() {
69 let byte_idx = i % bytes.len().max(1);
71 let byte_val = if bytes.is_empty() {
72 0u8
73 } else {
74 bytes[byte_idx]
75 };
76 *val = ((byte_val as f32 + i as f32) % 256.0) / 256.0;
77 }
78
79 let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
81 if norm > 0.0 {
82 for val in &mut embedding {
83 *val /= norm;
84 }
85 }
86
87 embedding
88 }
89}
90
91#[async_trait]
92impl EmbeddingProvider for MockEmbeddingProvider {
93 async fn embed(&self, text: &str) -> Result<Vec<f32>> {
94 Ok(self.deterministic_embedding(text))
95 }
96
97 async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
98 Ok(texts
99 .iter()
100 .map(|t| self.deterministic_embedding(t))
101 .collect())
102 }
103
104 fn dimension(&self) -> usize {
105 self.dimension
106 }
107
108 fn name(&self) -> &str {
109 "mock"
110 }
111}
112
#[cfg(test)]
mod tests {
    use super::*;

    /// The constructor's dimension and the fixed name are reported back.
    #[test]
    fn test_mock_provider_creation() {
        let provider = MockEmbeddingProvider::new(384);
        assert_eq!(provider.dimension(), 384);
        assert_eq!(provider.name(), "mock");
    }

    /// A single embedding has the requested length and unit L2 norm.
    #[tokio::test]
    async fn test_mock_embed_single() {
        let provider = MockEmbeddingProvider::new(8);
        let embedding = provider.embed("hello world").await.unwrap();

        assert_eq!(embedding.len(), 8);

        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5);
    }

    /// Embedding the same text twice yields identical vectors.
    #[tokio::test]
    async fn test_mock_embed_deterministic() {
        let provider = MockEmbeddingProvider::new(16);
        let e1 = provider.embed("same text").await.unwrap();
        let e2 = provider.embed("same text").await.unwrap();

        assert_eq!(e1, e2);
    }

    /// Distinct texts yield distinct vectors.
    #[tokio::test]
    async fn test_mock_embed_different_texts() {
        let provider = MockEmbeddingProvider::new(16);
        let e1 = provider.embed("text one").await.unwrap();
        let e2 = provider.embed("text two").await.unwrap();

        assert_ne!(e1, e2);
    }

    /// A batch returns one correctly sized vector per input, in input order,
    /// and each matches what `embed` returns for the same text.
    #[tokio::test]
    async fn test_mock_embed_batch() {
        let provider = MockEmbeddingProvider::new(8);
        let texts = vec!["hello", "world", "test"];
        let embeddings = provider.embed_batch(&texts).await.unwrap();

        assert_eq!(embeddings.len(), 3);
        for (&text, emb) in texts.iter().zip(&embeddings) {
            assert_eq!(emb.len(), 8);
            assert_eq!(*emb, provider.embed(text).await.unwrap());
        }
    }

    /// The empty string is accepted and still yields a full-length vector.
    #[tokio::test]
    async fn test_mock_embed_empty_text() {
        let provider = MockEmbeddingProvider::new(4);
        let embedding = provider.embed("").await.unwrap();

        assert_eq!(embedding.len(), 4);
    }

    /// An empty batch yields an empty result rather than an error.
    #[tokio::test]
    async fn test_mock_embed_batch_empty() {
        let provider = MockEmbeddingProvider::new(4);
        let texts: Vec<&str> = vec![];
        let embeddings = provider.embed_batch(&texts).await.unwrap();

        assert!(embeddings.is_empty());
    }

    /// Compile-time check that the trait can be used as `dyn EmbeddingProvider`.
    #[test]
    fn test_trait_object_safety() {
        fn assert_object_safe(_: &dyn EmbeddingProvider) {}

        // Perform the coercion so a concrete provider is checked as well.
        assert_object_safe(&MockEmbeddingProvider::new(4));
    }
}