Skip to main content

cognee_embedding/
mock.rs

1#![allow(
2    clippy::unwrap_used,
3    clippy::expect_used,
4    reason = "mock infrastructure — panics are acceptable"
5)]
6
7use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9use sha2::{Digest, Sha256};
10use std::sync::{Arc, Mutex};
11
12use crate::engine::EmbeddingEngine;
13use crate::error::{EmbeddingError, EmbeddingResult};
14
15/// Controls how [`MockEmbeddingEngine`] produces vector components.
16#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
17#[serde(rename_all = "lowercase")]
18pub enum MockVectorMode {
19    /// Every component is `0.0` (default; preserves legacy test behavior).
20    #[default]
21    Zero,
22    /// Components are derived deterministically from `sha256(text)`, mirroring
23    /// the Python benchmark mock so that the same text always yields the same
24    /// vector and similar text yields stable neighbors.
25    Deterministic,
26}
27
28/// A mock embedding engine.
29///
30/// Useful for testing pipeline stages that depend on an `EmbeddingEngine`
31/// without requiring a real model or network connection. By default it returns
32/// zero vectors ([`MockVectorMode::Zero`]); in [`MockVectorMode::Deterministic`]
33/// it derives content-stable vectors from `sha256(text)`.
34pub struct MockEmbeddingEngine {
35    dimensions: usize,
36    batch_size: usize,
37    mode: MockVectorMode,
38    /// When `Some(n)`, the `n+1`-th call to `embed` (and every subsequent call)
39    /// returns an `EmbeddingError::InferenceError`. `set_failure_after(0)` causes
40    /// the very first call to fail.
41    failure_after: Arc<Mutex<Option<usize>>>,
42    /// Number of `embed` invocations observed.
43    call_count: Arc<Mutex<usize>>,
44    /// Total number of texts passed across all `embed` invocations.
45    text_count: Arc<Mutex<usize>>,
46}
47
48impl MockEmbeddingEngine {
49    /// Create a mock engine with the given output dimensionality and a default batch size of 100.
50    ///
51    /// Defaults to [`MockVectorMode::Zero`].
52    pub fn new(dimensions: usize) -> Self {
53        Self {
54            dimensions,
55            batch_size: 100,
56            mode: MockVectorMode::Zero,
57            failure_after: Arc::new(Mutex::new(None)),
58            call_count: Arc::new(Mutex::new(0)),
59            text_count: Arc::new(Mutex::new(0)),
60        }
61    }
62
63    /// Create a mock engine with explicit dimensionality and batch size.
64    ///
65    /// Defaults to [`MockVectorMode::Zero`].
66    pub fn with_batch_size(dimensions: usize, batch_size: usize) -> Self {
67        Self {
68            dimensions,
69            batch_size,
70            mode: MockVectorMode::Zero,
71            failure_after: Arc::new(Mutex::new(None)),
72            call_count: Arc::new(Mutex::new(0)),
73            text_count: Arc::new(Mutex::new(0)),
74        }
75    }
76
77    /// Create a mock engine that produces deterministic, content-stable vectors
78    /// derived from `sha256(text)` (see [`MockVectorMode::Deterministic`]).
79    pub fn deterministic(dimensions: usize) -> Self {
80        Self {
81            dimensions,
82            batch_size: 100,
83            mode: MockVectorMode::Deterministic,
84            failure_after: Arc::new(Mutex::new(None)),
85            call_count: Arc::new(Mutex::new(0)),
86            text_count: Arc::new(Mutex::new(0)),
87        }
88    }
89
90    /// Override the vector-generation mode, consuming and returning `self`.
91    pub fn with_mode(mut self, mode: MockVectorMode) -> Self {
92        self.mode = mode;
93        self
94    }
95
96    /// Compute a single deterministic vector from `text`, mirroring the Python
97    /// benchmark mock: `sha256(text)` digest, little-endian `f32` windows of the
98    /// digest, scaled by `1e38` and clamped to `[-1.0, 1.0]`.
99    ///
100    /// Non-finite values (NaN/inf) produced by `f32::from_le_bytes` are mapped to
101    /// `0.0` so downstream cosine math stays well-defined.
102    fn deterministic_vector(&self, text: &str) -> Vec<f32> {
103        let digest = Sha256::digest(text.as_bytes());
104        let len = digest.len();
105        let mut vec = Vec::with_capacity(self.dimensions);
106        for i in 0..self.dimensions {
107            let offset = (i * 4) % len;
108            // Right-pad with 0x00 to a full 4-byte window (matches Python's
109            // `ljust(4, b"\x00")`); `offset` is always a multiple of 4 < len, so
110            // this only matters defensively.
111            let mut chunk = [0u8; 4];
112            let end = (offset + 4).min(len);
113            chunk[..end - offset].copy_from_slice(&digest[offset..end]);
114            let raw = f32::from_le_bytes(chunk);
115            let scaled = raw / 1e38_f32;
116            let val = if scaled.is_finite() {
117                scaled.clamp(-1.0, 1.0)
118            } else {
119                0.0
120            };
121            vec.push(val);
122        }
123        vec
124    }
125
126    /// Configure the engine to fail after `n` successful `embed` calls.
127    ///
128    /// With `n = 0`, the very first call fails. With `n = 3`, the first three
129    /// calls succeed and the fourth and beyond fail.
130    pub fn set_failure_after(&self, n: usize) {
131        let mut slot = self.failure_after.lock().unwrap(); // lock poison is unrecoverable
132        *slot = Some(n);
133    }
134
135    /// Number of `embed` invocations observed so far (one per batched call).
136    pub fn call_count(&self) -> usize {
137        *self.call_count.lock().unwrap() // lock poison is unrecoverable
138    }
139
140    /// Total number of texts embedded across all `embed` invocations.
141    pub fn embedded_text_count(&self) -> usize {
142        *self.text_count.lock().unwrap() // lock poison is unrecoverable
143    }
144}
145
146#[async_trait]
147impl EmbeddingEngine for MockEmbeddingEngine {
148    async fn embed(&self, texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
149        // Track call count and optionally inject failure.
150        let count_after = {
151            let mut count = self.call_count.lock().unwrap(); // lock poison is unrecoverable
152            *count += 1;
153            *count
154        };
155        {
156            let mut texts_seen = self.text_count.lock().unwrap(); // lock poison is unrecoverable
157            *texts_seen += texts.len();
158        }
159        let failure_threshold = {
160            let slot = self.failure_after.lock().unwrap(); // lock poison is unrecoverable
161            *slot
162        };
163        if let Some(n) = failure_threshold
164            && count_after > n
165        {
166            return Err(EmbeddingError::InferenceError(format!(
167                "MockEmbeddingEngine: injected failure after {n} successful call(s)"
168            )));
169        }
170        match self.mode {
171            MockVectorMode::Zero => Ok(vec![vec![0.0_f32; self.dimensions]; texts.len()]),
172            MockVectorMode::Deterministic => {
173                Ok(texts.iter().map(|t| self.deterministic_vector(t)).collect())
174            }
175        }
176    }
177
178    fn dimension(&self) -> usize {
179        self.dimensions
180    }
181
182    fn batch_size(&self) -> usize {
183        self.batch_size
184    }
185
186    fn max_sequence_length(&self) -> usize {
187        usize::MAX
188    }
189}
190
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_embed_returns_correct_count() {
        let engine = MockEmbeddingEngine::new(384);
        let texts = vec!["hello", "world", "foo"];
        let embeddings = engine
            .embed(&texts)
            .await
            .expect("embed must not fail for mock engine");
        assert_eq!(embeddings.len(), texts.len());
    }

    #[tokio::test]
    async fn test_embed_returns_correct_dimensions() {
        let engine = MockEmbeddingEngine::new(512);
        let texts = vec!["some text"];
        let embeddings = engine
            .embed(&texts)
            .await
            .expect("embed must not fail for mock engine");
        assert_eq!(embeddings[0].len(), 512);
    }

    #[tokio::test]
    async fn test_embed_returns_zero_vectors() {
        let engine = MockEmbeddingEngine::new(128);
        let texts = vec!["a", "b"];
        let embeddings = engine
            .embed(&texts)
            .await
            .expect("embed must not fail for mock engine");
        for vec in &embeddings {
            for &val in vec {
                assert_eq!(val, 0.0_f32);
            }
        }
    }

    #[tokio::test]
    async fn test_embed_empty_input() {
        let engine = MockEmbeddingEngine::new(384);
        let texts: Vec<&str> = vec![];
        let embeddings = engine
            .embed(&texts)
            .await
            .expect("embed must not fail for mock engine");
        assert_eq!(embeddings.len(), 0);
    }

    #[test]
    fn test_dimension() {
        let engine = MockEmbeddingEngine::new(256);
        assert_eq!(engine.dimension(), 256);
    }

    #[test]
    fn test_batch_size_default() {
        let engine = MockEmbeddingEngine::new(384);
        assert_eq!(engine.batch_size(), 100);
    }

    #[test]
    fn test_with_batch_size() {
        let engine = MockEmbeddingEngine::with_batch_size(384, 50);
        assert_eq!(engine.batch_size(), 50);
        assert_eq!(engine.dimension(), 384);
    }

    #[test]
    fn test_max_sequence_length() {
        let engine = MockEmbeddingEngine::new(384);
        assert_eq!(engine.max_sequence_length(), usize::MAX);
    }

    #[tokio::test]
    async fn test_deterministic_same_input_identical() {
        let engine = MockEmbeddingEngine::deterministic(384);
        let a = engine
            .embed(&["hello world"])
            .await
            .expect("embed must not fail for mock engine");
        let b = engine
            .embed(&["hello world"])
            .await
            .expect("embed must not fail for mock engine");
        assert_eq!(a, b);
    }

    #[tokio::test]
    async fn test_deterministic_different_inputs_differ() {
        let engine = MockEmbeddingEngine::deterministic(384);
        let out = engine
            .embed(&["hello world", "goodbye world"])
            .await
            .expect("embed must not fail for mock engine");
        assert_ne!(out[0], out[1]);
    }

    #[tokio::test]
    async fn test_deterministic_finite_and_clamped() {
        let engine = MockEmbeddingEngine::deterministic(512);
        let out = engine
            .embed(&["some representative text"])
            .await
            .expect("embed must not fail for mock engine");
        assert_eq!(out[0].len(), 512);
        for &val in &out[0] {
            assert!(val.is_finite(), "component must be finite, got {val}");
            assert!(
                (-1.0..=1.0).contains(&val),
                "component {val} out of [-1, 1]"
            );
        }
    }

    #[tokio::test]
    async fn test_deterministic_dimensionality() {
        let engine = MockEmbeddingEngine::deterministic(128);
        let out = engine
            .embed(&["abc"])
            .await
            .expect("embed must not fail for mock engine");
        assert_eq!(out[0].len(), 128);
        assert_eq!(engine.dimension(), 128);
    }

    #[tokio::test]
    async fn test_with_mode_selects_deterministic() {
        let engine = MockEmbeddingEngine::new(64).with_mode(MockVectorMode::Deterministic);
        let out = engine
            .embed(&["x"])
            .await
            .expect("embed must not fail for mock engine");
        // Deterministic vectors are not all-zero for typical inputs.
        assert!(out[0].iter().any(|&v| v != 0.0));
    }

    #[tokio::test]
    async fn test_zero_mode_still_returns_zeros() {
        // Regression guard: default mode must remain zero vectors.
        let engine = MockEmbeddingEngine::new(128);
        let out = engine
            .embed(&["a", "b"])
            .await
            .expect("embed must not fail for mock engine");
        for vec in &out {
            for &val in vec {
                assert_eq!(val, 0.0_f32);
            }
        }
    }

    #[tokio::test]
    async fn test_deterministic_constructor_matches_with_mode() {
        // Two independent engines must agree: vectors depend only on the text.
        let a = MockEmbeddingEngine::deterministic(32)
            .embed(&["same text"])
            .await
            .expect("embed must not fail for mock engine");
        let b = MockEmbeddingEngine::new(32)
            .with_mode(MockVectorMode::Deterministic)
            .embed(&["same text"])
            .await
            .expect("embed must not fail for mock engine");
        assert_eq!(a, b);
    }

    #[tokio::test]
    async fn test_deterministic_repeats_with_period_eight() {
        // A SHA-256 digest yields eight 4-byte windows, reused cyclically.
        let engine = MockEmbeddingEngine::deterministic(24);
        let out = engine
            .embed(&["periodic"])
            .await
            .expect("embed must not fail for mock engine");
        let v = &out[0];
        for (a, b) in v.iter().zip(&v[8..]) {
            assert_eq!(a, b);
        }
    }

    #[test]
    fn test_with_mode_preserves_batch_size() {
        let engine =
            MockEmbeddingEngine::with_batch_size(64, 7).with_mode(MockVectorMode::Deterministic);
        assert_eq!(engine.batch_size(), 7);
        assert_eq!(engine.dimension(), 64);
    }

    #[tokio::test]
    async fn test_counters_track_calls_and_texts() {
        let engine = MockEmbeddingEngine::new(4);
        assert_eq!(engine.call_count(), 0);
        assert_eq!(engine.embedded_text_count(), 0);
        engine
            .embed(&["a", "b"])
            .await
            .expect("embed must not fail for mock engine");
        engine
            .embed(&["c"])
            .await
            .expect("embed must not fail for mock engine");
        assert_eq!(engine.call_count(), 2);
        assert_eq!(engine.embedded_text_count(), 3);
    }

    #[tokio::test]
    async fn test_failure_after_zero_fails_first_call() {
        let engine = MockEmbeddingEngine::new(4);
        engine.set_failure_after(0);
        let result = engine.embed(&["a"]).await;
        assert!(matches!(result, Err(EmbeddingError::InferenceError(_))));
        // Failed calls are still counted.
        assert_eq!(engine.call_count(), 1);
        assert_eq!(engine.embedded_text_count(), 1);
    }

    #[tokio::test]
    async fn test_failure_after_n_successful_calls() {
        let engine = MockEmbeddingEngine::new(4);
        engine.set_failure_after(2);
        for _ in 0..2 {
            engine
                .embed(&["ok"])
                .await
                .expect("the first two calls must succeed");
        }
        for _ in 0..2 {
            match engine.embed(&["boom"]).await {
                Err(EmbeddingError::InferenceError(msg)) => {
                    assert!(msg.contains("injected failure after 2"), "unexpected message: {msg}");
                }
                other => panic!("expected injected InferenceError, got {other:?}"),
            }
        }
        assert_eq!(engine.call_count(), 4);
        assert_eq!(engine.embedded_text_count(), 4);
    }
}