// argentor_memory/embedding.rs

1use argentor_core::{ArgentorError, ArgentorResult};
2use async_trait::async_trait;
3use std::collections::HashMap;
4
5/// Trait for computing text embeddings (vector representations).
6#[async_trait]
7pub trait EmbeddingProvider: Send + Sync {
8    /// Compute embedding vector for a single text.
9    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>>;
10
11    /// Compute embeddings for a batch of texts.
12    async fn embed_batch(&self, texts: &[&str]) -> ArgentorResult<Vec<Vec<f32>>> {
13        let mut results = Vec::with_capacity(texts.len());
14        for text in texts {
15            results.push(self.embed(text).await?);
16        }
17        Ok(results)
18    }
19
20    /// Dimension of the embedding vectors produced by this provider.
21    fn dimension(&self) -> usize;
22}
23
/// Local feature-hashed bag-of-words embedding (no external API needed).
///
/// Each word's term frequency is hashed into a fixed number of dimensions, so
/// similarity between two vectors reflects *shared words* (lexical overlap),
/// not meaning: "car" and "automobile" share nothing under this model. That is
/// enough for basic keyword-style retrieval in an MVP; use a model-backed
/// provider (e.g. OpenAI/Cohere embeddings) for real semantic search.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LocalEmbedding {
    /// Number of components in every produced vector.
    dimension: usize,
}

impl LocalEmbedding {
    /// Create a new local embedding generator producing vectors of `dimension` components.
    ///
    /// `dimension` should be non-zero: a zero-length vector has no room for any
    /// hashed word feature.
    #[must_use]
    pub fn new(dimension: usize) -> Self {
        Self { dimension }
    }
}

impl Default for LocalEmbedding {
    /// A 256-dimensional local embedding.
    fn default() -> Self {
        Self::new(256)
    }
}
43
44#[async_trait]
45impl EmbeddingProvider for LocalEmbedding {
46    async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
47        if text.is_empty() {
48            return Err(ArgentorError::Agent("Cannot embed empty text".to_string()));
49        }
50
51        // Simple bag-of-words hashing to a fixed-size vector
52        let mut vector = vec![0.0f32; self.dimension];
53
54        let lowered = text.to_lowercase();
55        let words: Vec<&str> = lowered
56            .split(|c: char| !c.is_alphanumeric())
57            .filter(|w| !w.is_empty() && w.len() > 1)
58            .collect();
59
60        // Count word frequencies
61        let mut freq: HashMap<&str, f32> = HashMap::new();
62        for word in &words {
63            *freq.entry(word).or_insert(0.0) += 1.0;
64        }
65
66        let total = words.len() as f32;
67        if total == 0.0 {
68            return Ok(vector);
69        }
70
71        // Hash each word to vector dimensions and add TF weight
72        for (word, count) in &freq {
73            let tf = count / total;
74            // Use multiple hash positions per word for better distribution
75            let hash1 = simple_hash(word.as_bytes()) as usize;
76            let hash2 = simple_hash(&[word.as_bytes(), &[1u8]].concat()) as usize;
77            let hash3 = simple_hash(&[word.as_bytes(), &[2u8]].concat()) as usize;
78
79            vector[hash1 % self.dimension] += tf;
80            vector[hash2 % self.dimension] += tf * 0.7;
81            vector[hash3 % self.dimension] += tf * 0.5;
82        }
83
84        // L2 normalize
85        let norm: f32 = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
86        if norm > 0.0 {
87            for v in &mut vector {
88                *v /= norm;
89            }
90        }
91
92        Ok(vector)
93    }
94
95    fn dimension(&self) -> usize {
96        self.dimension
97    }
98}
99
/// Deterministic 32-bit FNV-1a hash of `data`.
///
/// Unlike the randomly seeded hasher behind `std::collections::HashMap`, this
/// produces the same value in every process, keeping embeddings reproducible.
fn simple_hash(data: &[u8]) -> u32 {
    const FNV_OFFSET_BASIS: u32 = 0x811c_9dc5;
    const FNV_PRIME: u32 = 0x0100_0193;

    data.iter()
        .fold(FNV_OFFSET_BASIS, |acc, &byte| (acc ^ u32::from(byte)).wrapping_mul(FNV_PRIME))
}
109
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Cosine of the angle between `a` and `b`; 0.0 when either is the zero vector.
    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        let norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
        let (norm_a, norm_b) = (norm(a), norm(b));
        if norm_a == 0.0 || norm_b == 0.0 {
            return 0.0;
        }
        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        dot / (norm_a * norm_b)
    }

    #[tokio::test]
    async fn test_local_embedding_dimension() {
        let provider = LocalEmbedding::new(128);
        assert_eq!(provider.dimension(), 128);

        let embedding = provider.embed("hello world").await.unwrap();
        assert_eq!(embedding.len(), 128);
    }

    #[tokio::test]
    async fn test_local_embedding_normalized() {
        let provider = LocalEmbedding::default();
        let embedding = provider.embed("the quick brown fox jumps").await.unwrap();
        let length = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((length - 1.0).abs() < 0.01);
    }

    #[tokio::test]
    async fn test_local_embedding_similar_texts() {
        let provider = LocalEmbedding::default();
        let rust_lang = provider.embed("rust programming language").await.unwrap();
        let rust_sys = provider.embed("rust programming systems").await.unwrap();
        let cooking = provider.embed("cooking recipes for dinner").await.unwrap();

        let sim_12 = cosine_similarity(&rust_lang, &rust_sys);
        let sim_13 = cosine_similarity(&rust_lang, &cooking);

        // Texts that share words must score higher than texts that share none.
        assert!(
            sim_12 > sim_13,
            "sim(rust-rust)={sim_12} should be > sim(rust-cooking)={sim_13}"
        );
    }

    #[tokio::test]
    async fn test_local_embedding_empty() {
        let provider = LocalEmbedding::default();
        let result = provider.embed("").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_local_embedding_deterministic() {
        let provider = LocalEmbedding::default();
        let first = provider.embed("test input").await.unwrap();
        let second = provider.embed("test input").await.unwrap();
        assert_eq!(first, second);
    }

    #[tokio::test]
    async fn test_embed_batch() {
        let provider = LocalEmbedding::default();
        let batch = provider.embed_batch(&["hello", "world"]).await.unwrap();
        assert_eq!(batch.len(), 2);
        assert_eq!(batch[0].len(), 256);
    }
}