nt_memory/agentdb/
embeddings.rs

1//! Embedding generation for semantic search
2
3use serde::Serialize;
4use std::hash::{Hash, Hasher};
5
/// A dense vector embedding: one `f32` component per dimension.
pub type Embedding = Vec<f32>;
8
9/// Embedding provider trait
10#[async_trait::async_trait]
11pub trait EmbeddingProvider: Send + Sync {
12    /// Generate embedding for text
13    async fn embed(&self, text: &str) -> anyhow::Result<Embedding>;
14
15    /// Generate embeddings for batch of texts
16    async fn embed_batch(&self, texts: &[String]) -> anyhow::Result<Vec<Embedding>>;
17
18    /// Embedding dimension
19    fn dimension(&self) -> usize;
20}
21
22/// Simple deterministic embedding provider (for testing)
23pub struct DeterministicEmbedder {
24    dimension: usize,
25}
26
27impl DeterministicEmbedder {
28    pub fn new(dimension: usize) -> Self {
29        Self { dimension }
30    }
31
32    fn hash_to_embedding(&self, text: &str) -> Embedding {
33        use std::collections::hash_map::DefaultHasher;
34
35        let mut hasher = DefaultHasher::new();
36        text.hash(&mut hasher);
37        let hash = hasher.finish();
38
39        // Generate deterministic embedding from hash
40        let mut embedding = Vec::with_capacity(self.dimension);
41        let mut current_hash = hash;
42
43        for _ in 0..self.dimension {
44            // Use linear congruential generator for more values
45            current_hash = current_hash.wrapping_mul(1103515245).wrapping_add(12345);
46            let value = ((current_hash >> 16) & 0xFFFF) as f32 / 65535.0;
47            embedding.push(value * 2.0 - 1.0); // Normalize to [-1, 1]
48        }
49
50        // Normalize to unit vector
51        let magnitude: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
52        if magnitude > 0.0 {
53            embedding.iter_mut().for_each(|x| *x /= magnitude);
54        }
55
56        embedding
57    }
58}
59
60#[async_trait::async_trait]
61impl EmbeddingProvider for DeterministicEmbedder {
62    async fn embed(&self, text: &str) -> anyhow::Result<Embedding> {
63        Ok(self.hash_to_embedding(text))
64    }
65
66    async fn embed_batch(&self, texts: &[String]) -> anyhow::Result<Vec<Embedding>> {
67        Ok(texts
68            .iter()
69            .map(|text| self.hash_to_embedding(text))
70            .collect())
71    }
72
73    fn dimension(&self) -> usize {
74        self.dimension
75    }
76}
77
/// Cosine similarity between two embeddings.
///
/// Computes `dot(a, b) / (|a| * |b|)` and clamps it to `[-1.0, 1.0]`.
/// Without the clamp, floating-point rounding could push the result just
/// outside the valid range and break callers that pass it to `acos` or use
/// `1.0 - similarity` as a distance.
///
/// Returns `0.0` in these cases:
/// - the slices differ in length;
/// - either vector has zero magnitude (including empty slices);
/// - either vector contains a `NaN` component.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    // Accumulate in f64. Squares of any finite f32 neither overflow nor
    // underflow there, so very large or very small vectors still give a
    // meaningful answer instead of NaN or a spurious 0.0. The extra precision
    // also helps long vectors.
    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, norm_a_sq, norm_b_sq), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, norm_a_sq + x * x, norm_b_sq + y * y)
        },
    );

    // A NaN norm fails these comparisons too, so NaN input yields 0.0.
    if norm_a_sq > 0.0 && norm_b_sq > 0.0 {
        let similarity = dot / (norm_a_sq.sqrt() * norm_b_sq.sqrt());
        (similarity as f32).clamp(-1.0, 1.0)
    } else {
        0.0
    }
}
95
/// Euclidean (L2) distance between two embeddings.
///
/// When the slices differ in length the result is `f32::MAX`, so mismatched
/// vectors rank as maximally distant. Two equal slices are at distance zero.
pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    if a.len() == b.len() {
        let squared_sum: f32 = a
            .iter()
            .zip(b)
            .map(|(left, right)| {
                let delta = left - right;
                delta.powi(2)
            })
            .sum();
        squared_sum.sqrt()
    } else {
        f32::MAX
    }
}
108
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_deterministic_embedder() {
        let embedder = DeterministicEmbedder::new(384);

        let text = "test string";
        let embedding1 = embedder.embed(text).await.unwrap();
        let embedding2 = embedder.embed(text).await.unwrap();

        // Same input must always produce the same vector.
        assert_eq!(embedding1, embedding2);
        assert_eq!(embedding1.len(), 384);
        assert_eq!(embedder.dimension(), 384);
    }

    #[tokio::test]
    async fn test_embeddings_are_unit_length_and_text_sensitive() {
        let embedder = DeterministicEmbedder::new(64);

        let hello = embedder.embed("hello").await.unwrap();
        let world = embedder.embed("world").await.unwrap();

        // The embedder normalizes every non-empty vector to unit length.
        let norm = hello.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-4);

        // Different inputs should map to different vectors.
        assert_ne!(hello, world);
    }

    #[tokio::test]
    async fn test_batch_embedding() {
        let embedder = DeterministicEmbedder::new(128);

        let texts = vec!["hello".to_string(), "world".to_string()];
        let embeddings = embedder.embed_batch(&texts).await.unwrap();

        assert_eq!(embeddings.len(), 2);
        assert_eq!(embeddings[0].len(), 128);
        assert_eq!(embeddings[1].len(), 128);

        // Batch results must agree with single embeddings, in input order.
        assert_eq!(embeddings[0], embedder.embed("hello").await.unwrap());
        assert_eq!(embeddings[1], embedder.embed("world").await.unwrap());
    }

    #[tokio::test]
    async fn test_zero_dimension_embedder() {
        let embedder = DeterministicEmbedder::new(0);
        assert!(embedder.embed("anything").await.unwrap().is_empty());
    }

    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        let c = vec![0.0, 1.0, 0.0];
        let d = vec![-1.0, 0.0, 0.0];

        // Identical vectors
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.001);

        // Orthogonal vectors
        assert!(cosine_similarity(&a, &c).abs() < 0.001);

        // Opposite vectors
        assert!((cosine_similarity(&a, &d) + 1.0).abs() < 0.001);

        // Length mismatch and zero magnitude both report 0.0.
        assert_eq!(cosine_similarity(&a, &[1.0, 0.0]), 0.0);
        assert_eq!(cosine_similarity(&a, &[0.0, 0.0, 0.0]), 0.0);
    }

    #[test]
    fn test_euclidean_distance() {
        let a = vec![0.0, 0.0];
        let b = vec![3.0, 4.0];

        // Distance should be 5.0
        assert!((euclidean_distance(&a, &b) - 5.0).abs() < 0.001);

        // A vector is at distance zero from itself.
        assert_eq!(euclidean_distance(&b, &b), 0.0);

        // Length mismatch reports the maximal distance.
        assert_eq!(euclidean_distance(&a, &[1.0]), f32::MAX);
    }
}
159}