context_mcp/
embeddings.rs

1//! Mock embedding generation traits and implementations
2//!
3//! This module provides trait definitions for embedding generation that will
4//! be replaced with real implementations (candle, embeddenator-vsa) in future PRs.
5
6use crate::error::Result;
7use async_trait::async_trait;
8
9/// Trait for generating embeddings from text
10#[async_trait]
11pub trait EmbeddingGenerator: Send + Sync {
12    /// Generate an embedding vector from text
13    async fn generate(&self, text: &str) -> Result<Vec<f32>>;
14
15    /// Get the dimension of embeddings produced by this generator
16    fn dimension(&self) -> usize;
17}
18
/// Mock embedding generator for testing and development.
///
/// Holds only the requested embedding size; no model is loaded.
#[derive(Debug, Clone)]
pub struct MockEmbeddingGenerator {
    /// Number of components in every generated embedding.
    dimension: usize,
}

impl MockEmbeddingGenerator {
    /// Creates a mock generator whose embeddings have `dimension` components.
    #[must_use]
    pub const fn new(dimension: usize) -> Self {
        Self { dimension }
    }
}
29
30#[async_trait]
31impl EmbeddingGenerator for MockEmbeddingGenerator {
32    async fn generate(&self, text: &str) -> Result<Vec<f32>> {
33        use std::collections::hash_map::DefaultHasher;
34        use std::hash::{Hash, Hasher};
35
36        // Generate deterministic embedding from text hash
37        let mut hasher = DefaultHasher::new();
38        text.hash(&mut hasher);
39        let hash = hasher.finish();
40
41        let mut embedding = Vec::with_capacity(self.dimension);
42        for i in 0..self.dimension {
43            let value = ((hash.wrapping_mul(i as u64 + 1)) as f32) / (u64::MAX as f32);
44            embedding.push(value);
45        }
46
47        // Normalize the vector
48        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
49        if norm > 0.0 {
50            for val in embedding.iter_mut() {
51                *val /= norm;
52            }
53        }
54
55        Ok(embedding)
56    }
57
58    fn dimension(&self) -> usize {
59        self.dimension
60    }
61}
62
#[cfg(test)]
mod tests {
    use super::*;

    /// Euclidean (L2) norm of `v`.
    fn l2_norm(v: &[f32]) -> f32 {
        v.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    #[tokio::test]
    async fn test_mock_embedding_deterministic() {
        let generator = MockEmbeddingGenerator::new(384);

        let emb1 = generator.generate("test text").await.unwrap();
        let emb2 = generator.generate("test text").await.unwrap();

        assert_eq!(emb1.len(), 384);
        assert_eq!(emb1, emb2); // Should be deterministic
    }

    #[tokio::test]
    async fn test_mock_embedding_normalized() {
        let generator = MockEmbeddingGenerator::new(384);
        for text in ["test", "", "a considerably longer piece of input text"] {
            let embedding = generator.generate(text).await.unwrap();
            let norm = l2_norm(&embedding);
            // Should be unit length for any input, including the empty string.
            assert!((norm - 1.0).abs() < 0.001, "norm {norm} for {text:?}");
        }
    }

    #[tokio::test]
    async fn test_mock_embedding_distinguishes_inputs() {
        let generator = MockEmbeddingGenerator::new(384);
        let first = generator.generate("first text").await.unwrap();
        let second = generator.generate("second text").await.unwrap();
        assert_ne!(first, second);
    }

    #[tokio::test]
    async fn test_mock_embedding_length_matches_dimension() {
        for dim in [1usize, 3, 768] {
            let generator = MockEmbeddingGenerator::new(dim);
            let embedding = generator.generate("sample").await.unwrap();
            assert_eq!(embedding.len(), dim);
        }
    }

    #[tokio::test]
    async fn test_mock_embedding_zero_dimension() {
        let generator = MockEmbeddingGenerator::new(0);
        let embedding = generator.generate("anything").await.unwrap();
        assert!(embedding.is_empty());
    }

    #[test]
    fn test_mock_dimension_reported() {
        let generator = MockEmbeddingGenerator::new(128);
        assert_eq!(generator.dimension(), 128);
    }
}