1#![allow(
2 clippy::unwrap_used,
3 clippy::expect_used,
4 reason = "mock infrastructure — panics are acceptable"
5)]
6
7use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9use sha2::{Digest, Sha256};
10use std::sync::{Arc, Mutex};
11
12use crate::engine::EmbeddingEngine;
13use crate::error::{EmbeddingError, EmbeddingResult};
14
15#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
17#[serde(rename_all = "lowercase")]
18pub enum MockVectorMode {
19 #[default]
21 Zero,
22 Deterministic,
26}
27
28pub struct MockEmbeddingEngine {
35 dimensions: usize,
36 batch_size: usize,
37 mode: MockVectorMode,
38 failure_after: Arc<Mutex<Option<usize>>>,
42 call_count: Arc<Mutex<usize>>,
44 text_count: Arc<Mutex<usize>>,
46}
47
48impl MockEmbeddingEngine {
49 pub fn new(dimensions: usize) -> Self {
53 Self {
54 dimensions,
55 batch_size: 100,
56 mode: MockVectorMode::Zero,
57 failure_after: Arc::new(Mutex::new(None)),
58 call_count: Arc::new(Mutex::new(0)),
59 text_count: Arc::new(Mutex::new(0)),
60 }
61 }
62
63 pub fn with_batch_size(dimensions: usize, batch_size: usize) -> Self {
67 Self {
68 dimensions,
69 batch_size,
70 mode: MockVectorMode::Zero,
71 failure_after: Arc::new(Mutex::new(None)),
72 call_count: Arc::new(Mutex::new(0)),
73 text_count: Arc::new(Mutex::new(0)),
74 }
75 }
76
77 pub fn deterministic(dimensions: usize) -> Self {
80 Self {
81 dimensions,
82 batch_size: 100,
83 mode: MockVectorMode::Deterministic,
84 failure_after: Arc::new(Mutex::new(None)),
85 call_count: Arc::new(Mutex::new(0)),
86 text_count: Arc::new(Mutex::new(0)),
87 }
88 }
89
90 pub fn with_mode(mut self, mode: MockVectorMode) -> Self {
92 self.mode = mode;
93 self
94 }
95
96 fn deterministic_vector(&self, text: &str) -> Vec<f32> {
103 let digest = Sha256::digest(text.as_bytes());
104 let len = digest.len();
105 let mut vec = Vec::with_capacity(self.dimensions);
106 for i in 0..self.dimensions {
107 let offset = (i * 4) % len;
108 let mut chunk = [0u8; 4];
112 let end = (offset + 4).min(len);
113 chunk[..end - offset].copy_from_slice(&digest[offset..end]);
114 let raw = f32::from_le_bytes(chunk);
115 let scaled = raw / 1e38_f32;
116 let val = if scaled.is_finite() {
117 scaled.clamp(-1.0, 1.0)
118 } else {
119 0.0
120 };
121 vec.push(val);
122 }
123 vec
124 }
125
126 pub fn set_failure_after(&self, n: usize) {
131 let mut slot = self.failure_after.lock().unwrap(); *slot = Some(n);
133 }
134
135 pub fn call_count(&self) -> usize {
137 *self.call_count.lock().unwrap() }
139
140 pub fn embedded_text_count(&self) -> usize {
142 *self.text_count.lock().unwrap() }
144}
145
146#[async_trait]
147impl EmbeddingEngine for MockEmbeddingEngine {
148 async fn embed(&self, texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
149 let count_after = {
151 let mut count = self.call_count.lock().unwrap(); *count += 1;
153 *count
154 };
155 {
156 let mut texts_seen = self.text_count.lock().unwrap(); *texts_seen += texts.len();
158 }
159 let failure_threshold = {
160 let slot = self.failure_after.lock().unwrap(); *slot
162 };
163 if let Some(n) = failure_threshold
164 && count_after > n
165 {
166 return Err(EmbeddingError::InferenceError(format!(
167 "MockEmbeddingEngine: injected failure after {n} successful call(s)"
168 )));
169 }
170 match self.mode {
171 MockVectorMode::Zero => Ok(vec![vec![0.0_f32; self.dimensions]; texts.len()]),
172 MockVectorMode::Deterministic => {
173 Ok(texts.iter().map(|t| self.deterministic_vector(t)).collect())
174 }
175 }
176 }
177
178 fn dimension(&self) -> usize {
179 self.dimensions
180 }
181
182 fn batch_size(&self) -> usize {
183 self.batch_size
184 }
185
186 fn max_sequence_length(&self) -> usize {
187 usize::MAX
188 }
189}
190
#[cfg(test)]
mod tests {
    use super::*;

    /// One vector comes back per input text.
    #[tokio::test]
    async fn test_embed_returns_correct_count() {
        let inputs = ["hello", "world", "foo"];
        let engine = MockEmbeddingEngine::new(384);
        let result = engine
            .embed(&inputs)
            .await
            .expect("embed must not fail for mock engine");
        assert_eq!(result.len(), inputs.len());
    }

    /// Each vector has the configured number of components.
    #[tokio::test]
    async fn test_embed_returns_correct_dimensions() {
        let engine = MockEmbeddingEngine::new(512);
        let result = engine
            .embed(&["some text"])
            .await
            .expect("embed must not fail for mock engine");
        assert_eq!(result[0].len(), 512);
    }

    /// The default mode yields only zero components.
    #[tokio::test]
    async fn test_embed_returns_zero_vectors() {
        let engine = MockEmbeddingEngine::new(128);
        let result = engine
            .embed(&["a", "b"])
            .await
            .expect("embed must not fail for mock engine");
        for component in result.iter().flatten() {
            assert_eq!(*component, 0.0_f32);
        }
    }

    /// An empty batch produces an empty result.
    #[tokio::test]
    async fn test_embed_empty_input() {
        let engine = MockEmbeddingEngine::new(384);
        let no_texts: [&str; 0] = [];
        let result = engine
            .embed(&no_texts)
            .await
            .expect("embed must not fail for mock engine");
        assert!(result.is_empty());
    }

    /// `dimension()` reports the constructor argument.
    #[test]
    fn test_dimension() {
        assert_eq!(MockEmbeddingEngine::new(256).dimension(), 256);
    }

    /// `new` uses a batch size of 100.
    #[test]
    fn test_batch_size_default() {
        assert_eq!(MockEmbeddingEngine::new(384).batch_size(), 100);
    }

    /// `with_batch_size` stores both of its arguments.
    #[test]
    fn test_with_batch_size() {
        let engine = MockEmbeddingEngine::with_batch_size(384, 50);
        assert_eq!(
            (engine.batch_size(), engine.dimension()),
            (50_usize, 384_usize)
        );
    }

    /// The mock places no limit on sequence length.
    #[test]
    fn test_max_sequence_length() {
        assert_eq!(MockEmbeddingEngine::new(384).max_sequence_length(), usize::MAX);
    }

    /// Embedding the same text twice gives identical output.
    #[tokio::test]
    async fn test_deterministic_same_input_identical() {
        let engine = MockEmbeddingEngine::deterministic(384);
        let first = engine
            .embed(&["hello world"])
            .await
            .expect("embed must not fail for mock engine");
        let second = engine
            .embed(&["hello world"])
            .await
            .expect("embed must not fail for mock engine");
        assert_eq!(first, second);
    }

    /// Two different texts do not map to the same vector.
    #[tokio::test]
    async fn test_deterministic_different_inputs_differ() {
        let engine = MockEmbeddingEngine::deterministic(384);
        let vectors = engine
            .embed(&["hello world", "goodbye world"])
            .await
            .expect("embed must not fail for mock engine");
        assert_ne!(vectors[0], vectors[1]);
    }

    /// Deterministic components are finite and lie in `[-1, 1]`.
    #[tokio::test]
    async fn test_deterministic_finite_and_clamped() {
        let engine = MockEmbeddingEngine::deterministic(512);
        let vectors = engine
            .embed(&["some representative text"])
            .await
            .expect("embed must not fail for mock engine");
        let vector = &vectors[0];
        assert_eq!(vector.len(), 512);
        for &val in vector {
            assert!(val.is_finite(), "component must be finite, got {val}");
            assert!(
                (-1.0..=1.0).contains(&val),
                "component {val} out of [-1, 1]"
            );
        }
    }

    /// Deterministic mode honours the configured dimensionality.
    #[tokio::test]
    async fn test_deterministic_dimensionality() {
        let engine = MockEmbeddingEngine::deterministic(128);
        let vectors = engine
            .embed(&["abc"])
            .await
            .expect("embed must not fail for mock engine");
        assert_eq!(vectors[0].len(), 128);
        assert_eq!(engine.dimension(), 128);
    }

    /// `with_mode` switches a zero-mode engine to deterministic output.
    #[tokio::test]
    async fn test_with_mode_selects_deterministic() {
        let engine = MockEmbeddingEngine::new(64).with_mode(MockVectorMode::Deterministic);
        let vectors = engine
            .embed(&["x"])
            .await
            .expect("embed must not fail for mock engine");
        let all_zero = vectors[0].iter().all(|&v| v == 0.0);
        assert!(!all_zero);
    }

    /// An engine built with `new` keeps producing zero vectors.
    #[tokio::test]
    async fn test_zero_mode_still_returns_zeros() {
        let engine = MockEmbeddingEngine::new(128);
        let vectors = engine
            .embed(&["a", "b"])
            .await
            .expect("embed must not fail for mock engine");
        for component in vectors.iter().flatten() {
            assert_eq!(*component, 0.0_f32);
        }
    }
}