Skip to main content

cognee_embedding/
mock.rs

1#![allow(
2    clippy::unwrap_used,
3    clippy::expect_used,
4    reason = "mock infrastructure — panics are acceptable"
5)]
6
7use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9use sha2::{Digest, Sha256};
10use std::sync::{Arc, Mutex};
11
12use crate::engine::EmbeddingEngine;
13use crate::error::{EmbeddingError, EmbeddingResult};
14
15/// Controls how [`MockEmbeddingEngine`] produces vector components.
16#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
17#[serde(rename_all = "lowercase")]
18pub enum MockVectorMode {
19    /// Every component is `0.0` (default; preserves legacy test behavior).
20    #[default]
21    Zero,
22    /// Components are derived deterministically from `sha256(text)`, mirroring
23    /// the Python benchmark mock so that the same text always yields the same
24    /// vector and similar text yields stable neighbors.
25    Deterministic,
26}
27
28/// A mock embedding engine.
29///
30/// Useful for testing pipeline stages that depend on an `EmbeddingEngine`
31/// without requiring a real model or network connection. By default it returns
32/// zero vectors ([`MockVectorMode::Zero`]); in [`MockVectorMode::Deterministic`]
33/// it derives content-stable vectors from `sha256(text)`.
34pub struct MockEmbeddingEngine {
35    dimensions: usize,
36    batch_size: usize,
37    mode: MockVectorMode,
38    /// When `Some(n)`, the `n+1`-th call to `embed` (and every subsequent call)
39    /// returns an `EmbeddingError::InferenceError`. `set_failure_after(0)` causes
40    /// the very first call to fail.
41    failure_after: Arc<Mutex<Option<usize>>>,
42    /// Number of `embed` invocations observed.
43    call_count: Arc<Mutex<usize>>,
44}
45
46impl MockEmbeddingEngine {
47    /// Create a mock engine with the given output dimensionality and a default batch size of 100.
48    ///
49    /// Defaults to [`MockVectorMode::Zero`].
50    pub fn new(dimensions: usize) -> Self {
51        Self {
52            dimensions,
53            batch_size: 100,
54            mode: MockVectorMode::Zero,
55            failure_after: Arc::new(Mutex::new(None)),
56            call_count: Arc::new(Mutex::new(0)),
57        }
58    }
59
60    /// Create a mock engine with explicit dimensionality and batch size.
61    ///
62    /// Defaults to [`MockVectorMode::Zero`].
63    pub fn with_batch_size(dimensions: usize, batch_size: usize) -> Self {
64        Self {
65            dimensions,
66            batch_size,
67            mode: MockVectorMode::Zero,
68            failure_after: Arc::new(Mutex::new(None)),
69            call_count: Arc::new(Mutex::new(0)),
70        }
71    }
72
73    /// Create a mock engine that produces deterministic, content-stable vectors
74    /// derived from `sha256(text)` (see [`MockVectorMode::Deterministic`]).
75    pub fn deterministic(dimensions: usize) -> Self {
76        Self {
77            dimensions,
78            batch_size: 100,
79            mode: MockVectorMode::Deterministic,
80            failure_after: Arc::new(Mutex::new(None)),
81            call_count: Arc::new(Mutex::new(0)),
82        }
83    }
84
85    /// Override the vector-generation mode, consuming and returning `self`.
86    pub fn with_mode(mut self, mode: MockVectorMode) -> Self {
87        self.mode = mode;
88        self
89    }
90
91    /// Compute a single deterministic vector from `text`, mirroring the Python
92    /// benchmark mock: `sha256(text)` digest, little-endian `f32` windows of the
93    /// digest, scaled by `1e38` and clamped to `[-1.0, 1.0]`.
94    ///
95    /// Non-finite values (NaN/inf) produced by `f32::from_le_bytes` are mapped to
96    /// `0.0` so downstream cosine math stays well-defined.
97    fn deterministic_vector(&self, text: &str) -> Vec<f32> {
98        let digest = Sha256::digest(text.as_bytes());
99        let len = digest.len();
100        let mut vec = Vec::with_capacity(self.dimensions);
101        for i in 0..self.dimensions {
102            let offset = (i * 4) % len;
103            // Right-pad with 0x00 to a full 4-byte window (matches Python's
104            // `ljust(4, b"\x00")`); `offset` is always a multiple of 4 < len, so
105            // this only matters defensively.
106            let mut chunk = [0u8; 4];
107            let end = (offset + 4).min(len);
108            chunk[..end - offset].copy_from_slice(&digest[offset..end]);
109            let raw = f32::from_le_bytes(chunk);
110            let scaled = raw / 1e38_f32;
111            let val = if scaled.is_finite() {
112                scaled.clamp(-1.0, 1.0)
113            } else {
114                0.0
115            };
116            vec.push(val);
117        }
118        vec
119    }
120
121    /// Configure the engine to fail after `n` successful `embed` calls.
122    ///
123    /// With `n = 0`, the very first call fails. With `n = 3`, the first three
124    /// calls succeed and the fourth and beyond fail.
125    pub fn set_failure_after(&self, n: usize) {
126        let mut slot = self.failure_after.lock().unwrap(); // lock poison is unrecoverable
127        *slot = Some(n);
128    }
129}
130
131#[async_trait]
132impl EmbeddingEngine for MockEmbeddingEngine {
133    async fn embed(&self, texts: &[&str]) -> EmbeddingResult<Vec<Vec<f32>>> {
134        // Track call count and optionally inject failure.
135        let count_after = {
136            let mut count = self.call_count.lock().unwrap(); // lock poison is unrecoverable
137            *count += 1;
138            *count
139        };
140        let failure_threshold = {
141            let slot = self.failure_after.lock().unwrap(); // lock poison is unrecoverable
142            *slot
143        };
144        if let Some(n) = failure_threshold
145            && count_after > n
146        {
147            return Err(EmbeddingError::InferenceError(format!(
148                "MockEmbeddingEngine: injected failure after {n} successful call(s)"
149            )));
150        }
151        match self.mode {
152            MockVectorMode::Zero => Ok(vec![vec![0.0_f32; self.dimensions]; texts.len()]),
153            MockVectorMode::Deterministic => {
154                Ok(texts.iter().map(|t| self.deterministic_vector(t)).collect())
155            }
156        }
157    }
158
159    fn dimension(&self) -> usize {
160        self.dimensions
161    }
162
163    fn batch_size(&self) -> usize {
164        self.batch_size
165    }
166
167    fn max_sequence_length(&self) -> usize {
168        usize::MAX
169    }
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_embed_returns_correct_count() {
        let engine = MockEmbeddingEngine::new(384);
        let texts = ["hello", "world", "foo"];
        let embeddings = engine
            .embed(&texts)
            .await
            .expect("embed must not fail for mock engine");
        assert_eq!(embeddings.len(), texts.len());
    }

    #[tokio::test]
    async fn test_embed_returns_correct_dimensions() {
        let engine = MockEmbeddingEngine::new(512);
        let texts = ["some text"];
        let embeddings = engine
            .embed(&texts)
            .await
            .expect("embed must not fail for mock engine");
        assert_eq!(embeddings[0].len(), 512);
    }

    #[tokio::test]
    async fn test_embed_returns_zero_vectors() {
        let engine = MockEmbeddingEngine::new(128);
        let texts = ["a", "b"];
        let embeddings = engine
            .embed(&texts)
            .await
            .expect("embed must not fail for mock engine");
        for vec in &embeddings {
            for &val in vec {
                assert_eq!(val, 0.0_f32);
            }
        }
    }

    #[tokio::test]
    async fn test_embed_empty_input() {
        let engine = MockEmbeddingEngine::new(384);
        let texts: [&str; 0] = [];
        let embeddings = engine
            .embed(&texts)
            .await
            .expect("embed must not fail for mock engine");
        assert_eq!(embeddings.len(), 0);
    }

    #[test]
    fn test_dimension() {
        let engine = MockEmbeddingEngine::new(256);
        assert_eq!(engine.dimension(), 256);
    }

    #[test]
    fn test_batch_size_default() {
        let engine = MockEmbeddingEngine::new(384);
        assert_eq!(engine.batch_size(), 100);
    }

    #[test]
    fn test_with_batch_size() {
        let engine = MockEmbeddingEngine::with_batch_size(384, 50);
        assert_eq!(engine.batch_size(), 50);
        assert_eq!(engine.dimension(), 384);
    }

    #[test]
    fn test_max_sequence_length() {
        let engine = MockEmbeddingEngine::new(384);
        assert_eq!(engine.max_sequence_length(), usize::MAX);
    }

    #[test]
    fn test_default_mode_is_zero() {
        assert_eq!(MockVectorMode::default(), MockVectorMode::Zero);
    }

    #[tokio::test]
    async fn test_deterministic_same_input_identical() {
        let engine = MockEmbeddingEngine::deterministic(384);
        let a = engine
            .embed(&["hello world"])
            .await
            .expect("embed must not fail for mock engine");
        let b = engine
            .embed(&["hello world"])
            .await
            .expect("embed must not fail for mock engine");
        assert_eq!(a, b);
    }

    #[tokio::test]
    async fn test_deterministic_different_inputs_differ() {
        let engine = MockEmbeddingEngine::deterministic(384);
        let out = engine
            .embed(&["hello world", "goodbye world"])
            .await
            .expect("embed must not fail for mock engine");
        assert_ne!(out[0], out[1]);
    }

    #[tokio::test]
    async fn test_deterministic_independent_of_batch_position() {
        // A text's vector depends only on its content, not on its neighbors
        // in the batch or its index within it.
        let engine = MockEmbeddingEngine::deterministic(32);
        let batch = engine
            .embed(&["first", "second"])
            .await
            .expect("embed must not fail for mock engine");
        let single = engine
            .embed(&["second"])
            .await
            .expect("embed must not fail for mock engine");
        assert_eq!(batch[1], single[0]);
    }

    #[tokio::test]
    async fn test_deterministic_finite_and_clamped() {
        let engine = MockEmbeddingEngine::deterministic(512);
        let out = engine
            .embed(&["some representative text"])
            .await
            .expect("embed must not fail for mock engine");
        assert_eq!(out[0].len(), 512);
        for &val in &out[0] {
            assert!(val.is_finite(), "component must be finite, got {val}");
            assert!(
                (-1.0..=1.0).contains(&val),
                "component {val} out of [-1, 1]"
            );
        }
    }

    #[tokio::test]
    async fn test_deterministic_dimensionality() {
        let engine = MockEmbeddingEngine::deterministic(128);
        let out = engine
            .embed(&["abc"])
            .await
            .expect("embed must not fail for mock engine");
        assert_eq!(out[0].len(), 128);
        assert_eq!(engine.dimension(), 128);
    }

    #[tokio::test]
    async fn test_with_mode_selects_deterministic() {
        let engine = MockEmbeddingEngine::new(64).with_mode(MockVectorMode::Deterministic);
        let out = engine
            .embed(&["x"])
            .await
            .expect("embed must not fail for mock engine");
        // Deterministic vectors are not all-zero for typical inputs.
        assert!(out[0].iter().any(|&v| v != 0.0));
    }

    #[tokio::test]
    async fn test_with_mode_zero_overrides_deterministic() {
        let engine = MockEmbeddingEngine::deterministic(16).with_mode(MockVectorMode::Zero);
        let out = engine
            .embed(&["x"])
            .await
            .expect("embed must not fail for mock engine");
        assert!(out[0].iter().all(|&v| v == 0.0));
    }

    #[tokio::test]
    async fn test_zero_mode_still_returns_zeros() {
        // Regression guard: default mode must remain zero vectors.
        let engine = MockEmbeddingEngine::new(128);
        let out = engine
            .embed(&["a", "b"])
            .await
            .expect("embed must not fail for mock engine");
        for vec in &out {
            for &val in vec {
                assert_eq!(val, 0.0_f32);
            }
        }
    }

    #[tokio::test]
    async fn test_failure_after_zero_fails_first_call() {
        let engine = MockEmbeddingEngine::new(8);
        engine.set_failure_after(0);
        let err = engine
            .embed(&["a"])
            .await
            .expect_err("first call must fail when failure_after is 0");
        assert!(matches!(err, EmbeddingError::InferenceError(_)));
    }

    #[tokio::test]
    async fn test_failure_after_n_allows_exactly_n_calls() {
        let engine = MockEmbeddingEngine::new(8);
        engine.set_failure_after(2);
        for _ in 0..2 {
            engine
                .embed(&["a"])
                .await
                .expect("calls within the threshold must succeed");
        }
        // Every call past the threshold keeps failing.
        for _ in 0..2 {
            let err = engine
                .embed(&["a"])
                .await
                .expect_err("calls past the threshold must fail");
            assert!(matches!(err, EmbeddingError::InferenceError(_)));
        }
    }
}