1#![allow(
2 clippy::unwrap_used,
3 clippy::expect_used,
4 reason = "mock infrastructure — panics are acceptable"
5)]
6
7use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9use sha2::{Digest, Sha256};
10use std::sync::{Arc, Mutex};
11
12use crate::engine::EmbeddingEngine;
13use crate::error::{EmbeddingError, EmbeddingResult};
14
15#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
17#[serde(rename_all = "lowercase")]
18pub enum MockVectorMode {
19 #[default]
21 Zero,
22 Deterministic,
26}
27
28pub struct MockEmbeddingEngine {
35 dimensions: usize,
36 batch_size: usize,
37 mode: MockVectorMode,
38 failure_after: Arc<Mutex<Option<usize>>>,
42 call_count: Arc<Mutex<usize>>,
44}
45
46impl MockEmbeddingEngine {
47 pub fn new(dimensions: usize) -> Self {
51 Self {
52 dimensions,
53 batch_size: 100,
54 mode: MockVectorMode::Zero,
55 failure_after: Arc::new(Mutex::new(None)),
56 call_count: Arc::new(Mutex::new(0)),
57 }
58 }
59
60 pub fn with_batch_size(dimensions: usize, batch_size: usize) -> Self {
64 Self {
65 dimensions,
66 batch_size,
67 mode: MockVectorMode::Zero,
68 failure_after: Arc::new(Mutex::new(None)),
69 call_count: Arc::new(Mutex::new(0)),
70 }
71 }
72
73 pub fn deterministic(dimensions: usize) -> Self {
76 Self {
77 dimensions,
78 batch_size: 100,
79 mode: MockVectorMode::Deterministic,
80 failure_after: Arc::new(Mutex::new(None)),
81 call_count: Arc::new(Mutex::new(0)),
82 }
83 }
84
85 pub fn with_mode(mut self, mode: MockVectorMode) -> Self {
87 self.mode = mode;
88 self
89 }
90
91 fn deterministic_vector(&self, text: &str) -> Vec<f32> {
98 let digest = Sha256::digest(text.as_bytes());
99 let len = digest.len();
100 let mut vec = Vec::with_capacity(self.dimensions);
101 for i in 0..self.dimensions {
102 let offset = (i * 4) % len;
103 let mut chunk = [0u8; 4];
107 let end = (offset + 4).min(len);
108 chunk[..end - offset].copy_from_slice(&digest[offset..end]);
109 let raw = f32::from_le_bytes(chunk);
110 let scaled = raw / 1e38_f32;
111 let val = if scaled.is_finite() {
112 scaled.clamp(-1.0, 1.0)
113 } else {
114 0.0
115 };
116 vec.push(val);
117 }
118 vec
119 }
120
121 pub fn set_failure_after(&self, n: usize) {
126 let mut slot = self.failure_after.lock().unwrap(); *slot = Some(n);
128 }
129}
130
131#[async_trait]
132impl EmbeddingEngine for MockEmbeddingEngine {
133 async fn embed(&self, texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
134 let count_after = {
136 let mut count = self.call_count.lock().unwrap(); *count += 1;
138 *count
139 };
140 let failure_threshold = {
141 let slot = self.failure_after.lock().unwrap(); *slot
143 };
144 if let Some(n) = failure_threshold
145 && count_after > n
146 {
147 return Err(EmbeddingError::InferenceError(format!(
148 "MockEmbeddingEngine: injected failure after {n} successful call(s)"
149 )));
150 }
151 match self.mode {
152 MockVectorMode::Zero => Ok(vec![vec![0.0_f32; self.dimensions]; texts.len()]),
153 MockVectorMode::Deterministic => {
154 Ok(texts.iter().map(|t| self.deterministic_vector(t)).collect())
155 }
156 }
157 }
158
159 fn dimension(&self) -> usize {
160 self.dimensions
161 }
162
163 fn batch_size(&self) -> usize {
164 self.batch_size
165 }
166
167 fn max_sequence_length(&self) -> usize {
168 usize::MAX
169 }
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    /// Calls `embed` and unwraps the result for tests where the mock is
    /// expected to succeed.
    async fn embed_ok(engine: &MockEmbeddingEngine, texts: &[&str]) -> Vec<Vec<f32>> {
        engine
            .embed(texts)
            .await
            .expect("embed must not fail for mock engine")
    }

    /// Asserts that every component of every vector is exactly `0.0`.
    fn assert_all_zero(vectors: &[Vec<f32>]) {
        for vector in vectors {
            for &component in vector {
                assert_eq!(component, 0.0_f32);
            }
        }
    }

    // ---- zero mode and basic accessors -------------------------------------

    #[tokio::test]
    async fn test_embed_returns_correct_count() {
        let engine = MockEmbeddingEngine::new(384);
        let texts = vec!["hello", "world", "foo"];
        let embeddings = embed_ok(&engine, &texts).await;
        assert_eq!(embeddings.len(), texts.len());
    }

    #[tokio::test]
    async fn test_embed_returns_correct_dimensions() {
        let engine = MockEmbeddingEngine::new(512);
        let embeddings = embed_ok(&engine, &["some text"]).await;
        assert_eq!(embeddings[0].len(), 512);
    }

    #[tokio::test]
    async fn test_embed_returns_zero_vectors() {
        let engine = MockEmbeddingEngine::new(128);
        let embeddings = embed_ok(&engine, &["a", "b"]).await;
        assert_all_zero(&embeddings);
    }

    #[tokio::test]
    async fn test_embed_empty_input() {
        let engine = MockEmbeddingEngine::new(384);
        let texts: Vec<&str> = vec![];
        let embeddings = embed_ok(&engine, &texts).await;
        assert_eq!(embeddings.len(), 0);
    }

    #[test]
    fn test_dimension() {
        let engine = MockEmbeddingEngine::new(256);
        assert_eq!(engine.dimension(), 256);
    }

    #[test]
    fn test_batch_size_default() {
        let engine = MockEmbeddingEngine::new(384);
        assert_eq!(engine.batch_size(), 100);
    }

    #[test]
    fn test_with_batch_size() {
        let engine = MockEmbeddingEngine::with_batch_size(384, 50);
        assert_eq!(engine.batch_size(), 50);
        assert_eq!(engine.dimension(), 384);
    }

    #[test]
    fn test_max_sequence_length() {
        let engine = MockEmbeddingEngine::new(384);
        assert_eq!(engine.max_sequence_length(), usize::MAX);
    }

    // ---- deterministic mode -------------------------------------------------

    #[tokio::test]
    async fn test_deterministic_same_input_identical() {
        let engine = MockEmbeddingEngine::deterministic(384);
        let a = embed_ok(&engine, &["hello world"]).await;
        let b = embed_ok(&engine, &["hello world"]).await;
        assert_eq!(a, b);
    }

    /// Determinism must not depend on per-engine state: two separately
    /// constructed engines agree on the same text.
    #[tokio::test]
    async fn test_deterministic_same_input_across_engines() {
        let first = MockEmbeddingEngine::deterministic(64);
        let second = MockEmbeddingEngine::deterministic(64);
        let a = embed_ok(&first, &["hello world"]).await;
        let b = embed_ok(&second, &["hello world"]).await;
        assert_eq!(a, b);
    }

    #[tokio::test]
    async fn test_deterministic_different_inputs_differ() {
        let engine = MockEmbeddingEngine::deterministic(384);
        let out = embed_ok(&engine, &["hello world", "goodbye world"]).await;
        assert_ne!(out[0], out[1]);
    }

    #[tokio::test]
    async fn test_deterministic_finite_and_clamped() {
        let engine = MockEmbeddingEngine::deterministic(512);
        let out = embed_ok(&engine, &["some representative text"]).await;
        assert_eq!(out[0].len(), 512);
        for &val in &out[0] {
            assert!(val.is_finite(), "component must be finite, got {val}");
            assert!(
                (-1.0..=1.0).contains(&val),
                "component {val} out of [-1, 1]"
            );
        }
    }

    #[tokio::test]
    async fn test_deterministic_dimensionality() {
        let engine = MockEmbeddingEngine::deterministic(128);
        let out = embed_ok(&engine, &["abc"]).await;
        assert_eq!(out[0].len(), 128);
        assert_eq!(engine.dimension(), 128);
    }

    #[tokio::test]
    async fn test_with_mode_selects_deterministic() {
        let engine = MockEmbeddingEngine::new(64).with_mode(MockVectorMode::Deterministic);
        let out = embed_ok(&engine, &["x"]).await;
        assert!(out[0].iter().any(|&v| v != 0.0));
    }

    #[tokio::test]
    async fn test_zero_mode_still_returns_zeros() {
        let engine = MockEmbeddingEngine::new(128);
        let out = embed_ok(&engine, &["a", "b"]).await;
        assert_all_zero(&out);
    }

    // ---- failure injection --------------------------------------------------

    /// With threshold `n = 2` the first two calls succeed and every later call
    /// fails.
    #[tokio::test]
    async fn test_failure_injected_after_threshold() {
        let engine = MockEmbeddingEngine::new(8);
        engine.set_failure_after(2);

        assert!(engine.embed(&["a"]).await.is_ok());
        assert!(engine.embed(&["b"]).await.is_ok());
        assert!(engine.embed(&["c"]).await.is_err());
        // The engine stays in the failing state on subsequent calls.
        assert!(engine.embed(&["d"]).await.is_err());
    }

    /// A threshold of zero makes the very first call fail.
    #[tokio::test]
    async fn test_failure_after_zero_fails_first_call() {
        let engine = MockEmbeddingEngine::new(8);
        engine.set_failure_after(0);
        assert!(engine.embed(&["a"]).await.is_err());
    }

    /// The call counter runs from construction, so calls made before the
    /// threshold is armed count towards it.
    #[tokio::test]
    async fn test_failure_threshold_counts_earlier_calls() {
        let engine = MockEmbeddingEngine::new(8);
        assert!(engine.embed(&["warm-up"]).await.is_ok());

        engine.set_failure_after(1);
        assert!(engine.embed(&["a"]).await.is_err());
    }

    /// The injected error is an `InferenceError` whose message names the
    /// configured threshold.
    #[tokio::test]
    async fn test_failure_error_variant_and_message() {
        let engine = MockEmbeddingEngine::deterministic(8);
        engine.set_failure_after(0);

        let err = engine
            .embed(&["a"])
            .await
            .expect_err("embed must fail once the threshold is exceeded");
        assert!(matches!(
            err,
            EmbeddingError::InferenceError(ref message)
                if message.contains("injected failure after 0 successful call(s)")
        ));
    }

    /// Without `set_failure_after` the engine never injects failures.
    #[tokio::test]
    async fn test_no_failure_when_not_armed() {
        let engine = MockEmbeddingEngine::new(4);
        for _ in 0..5 {
            assert!(engine.embed(&["a"]).await.is_ok());
        }
    }
}