mem0_rust/embeddings/
mock.rs

//! Mock embedder for testing.
//!
//! Uses hash-based embeddings that are deterministic but not semantic.
4
5use async_trait::async_trait;
6use std::collections::hash_map::DefaultHasher;
7use std::hash::{Hash, Hasher};
8
9use super::traits::Embedder;
10use crate::errors::EmbeddingError;
11
/// Hash-based mock embedder for testing.
///
/// Produces deterministic, non-semantic vectors by hashing whitespace-separated
/// tokens into `dimensions` buckets. Construct it with [`MockEmbedder::new`].
///
/// `Default` is deliberately not derived: a zero `dimensions` would make the
/// bucket computation (`hash % dimensions`) divide by zero.
#[derive(Debug, Clone)]
pub struct MockEmbedder {
    /// Length of every vector this embedder produces; `new` keeps it >= 1.
    dimensions: usize,
}
16
17impl MockEmbedder {
18    /// Create a new mock embedder with the specified dimensions
19    pub fn new(dimensions: usize) -> Self {
20        Self {
21            dimensions: dimensions.max(1),
22        }
23    }
24}
25
26#[async_trait]
27impl Embedder for MockEmbedder {
28    async fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
29        let mut vector = vec![0.0f32; self.dimensions];
30
31        for token in text.split_whitespace() {
32            let mut hasher = DefaultHasher::new();
33            token.to_lowercase().hash(&mut hasher);
34            let hash = hasher.finish();
35            let idx = (hash as usize) % self.dimensions;
36            let sign = if hash & 1 == 0 { 1.0 } else { -1.0 };
37            let magnitude = 1.0 + ((hash >> 1) as f32 / u64::MAX as f32);
38            vector[idx] += sign * magnitude;
39        }
40
41        // Normalize
42        let norm: f32 = vector.iter().map(|v| v * v).sum::<f32>().sqrt();
43        if norm > 0.0 {
44            for value in &mut vector {
45                *value /= norm;
46            }
47        }
48
49        Ok(vector)
50    }
51
52    fn dimensions(&self) -> usize {
53        self.dimensions
54    }
55
56    fn model_name(&self) -> &str {
57        "mock-hash-embedder"
58    }
59}
60
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_mock_embedder() {
        let embedder = MockEmbedder::new(128);
        let embedding = embedder.embed("Hello world").await.unwrap();
        assert_eq!(embedding.len(), 128);

        // Check normalization
        let norm: f32 = embedding.iter().map(|v| v * v).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.01);
    }

    #[tokio::test]
    async fn test_deterministic() {
        let embedder = MockEmbedder::new(64);
        let e1 = embedder.embed("test").await.unwrap();
        let e2 = embedder.embed("test").await.unwrap();
        assert_eq!(e1, e2);
    }

    /// Tokens are lowercased and split on any whitespace, so casing and
    /// spacing differences must not change the embedding.
    #[tokio::test]
    async fn test_case_and_whitespace_insensitive() {
        let embedder = MockEmbedder::new(64);
        let a = embedder.embed("Hello World").await.unwrap();
        let b = embedder.embed("  hello\tworld\n").await.unwrap();
        assert_eq!(a, b);
    }

    /// With no tokens there is nothing to normalize, so the vector stays zero.
    #[tokio::test]
    async fn test_empty_text_yields_zero_vector() {
        let embedder = MockEmbedder::new(16);
        for text in ["", "   \t\n"] {
            let embedding = embedder.embed(text).await.unwrap();
            assert_eq!(embedding, vec![0.0; 16]);
        }
    }

    /// `new(0)` is clamped to one dimension, and embedding still works.
    #[tokio::test]
    async fn test_zero_dimensions_clamped() {
        let embedder = MockEmbedder::new(0);
        assert_eq!(embedder.dimensions(), 1);
        let embedding = embedder.embed("anything").await.unwrap();
        assert_eq!(embedding.len(), 1);
        assert!((embedding[0].abs() - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_metadata() {
        let embedder = MockEmbedder::new(32);
        assert_eq!(embedder.dimensions(), 32);
        assert_eq!(embedder.model_name(), "mock-hash-embedder");
    }
}