context_mcp/
embeddings.rs1use crate::error::Result;
7use async_trait::async_trait;
8
9#[async_trait]
11pub trait EmbeddingGenerator: Send + Sync {
12 async fn generate(&self, text: &str) -> Result<Vec<f32>>;
14
15 fn dimension(&self) -> usize;
17}
18
/// Deterministic stand-in for a real [`EmbeddingGenerator`].
///
/// Its embeddings are derived purely from a hash of the input text, and every
/// vector has the fixed length chosen at construction time.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MockEmbeddingGenerator {
    /// Length of every vector this generator produces.
    dimension: usize,
}

impl MockEmbeddingGenerator {
    /// Creates a mock generator whose embeddings have `dimension` components.
    ///
    /// A `dimension` of zero is accepted; `generate` then yields empty vectors.
    #[must_use]
    pub fn new(dimension: usize) -> Self {
        Self { dimension }
    }
}
29
30#[async_trait]
31impl EmbeddingGenerator for MockEmbeddingGenerator {
32 async fn generate(&self, text: &str) -> Result<Vec<f32>> {
33 use std::collections::hash_map::DefaultHasher;
34 use std::hash::{Hash, Hasher};
35
36 let mut hasher = DefaultHasher::new();
38 text.hash(&mut hasher);
39 let hash = hasher.finish();
40
41 let mut embedding = Vec::with_capacity(self.dimension);
42 for i in 0..self.dimension {
43 let value = ((hash.wrapping_mul(i as u64 + 1)) as f32) / (u64::MAX as f32);
44 embedding.push(value);
45 }
46
47 let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
49 if norm > 0.0 {
50 for val in embedding.iter_mut() {
51 *val /= norm;
52 }
53 }
54
55 Ok(embedding)
56 }
57
58 fn dimension(&self) -> usize {
59 self.dimension
60 }
61}
62
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_mock_embedding_deterministic() {
        let generator = MockEmbeddingGenerator::new(384);

        let emb1 = generator.generate("test text").await.unwrap();
        let emb2 = generator.generate("test text").await.unwrap();

        assert_eq!(emb1.len(), 384);
        // The same input must always map to the same vector.
        assert_eq!(emb1, emb2);
    }

    #[tokio::test]
    async fn test_mock_embedding_normalized() {
        let generator = MockEmbeddingGenerator::new(384);
        let embedding = generator.generate("test").await.unwrap();

        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        // The output is L2-normalized, up to f32 rounding.
        assert!((norm - 1.0).abs() < 0.001);
    }

    #[tokio::test]
    async fn test_mock_embedding_distinct_inputs() {
        let generator = MockEmbeddingGenerator::new(64);

        let a = generator.generate("alpha").await.unwrap();
        let b = generator.generate("beta").await.unwrap();

        // Different texts hash to different seeds and thus different vectors.
        assert_ne!(a, b);
    }

    #[tokio::test]
    async fn test_mock_embedding_zero_dimension() {
        let generator = MockEmbeddingGenerator::new(0);
        assert_eq!(generator.dimension(), 0);

        // The zero-norm path must not divide by zero and must yield no components.
        let embedding = generator.generate("anything").await.unwrap();
        assert!(embedding.is_empty());
    }

    #[test]
    fn test_dimension_accessor() {
        assert_eq!(MockEmbeddingGenerator::new(384).dimension(), 384);
    }
}