argentor_memory/
embedding.rs1use argentor_core::{ArgentorError, ArgentorResult};
2use async_trait::async_trait;
3use std::collections::HashMap;
4
5#[async_trait]
7pub trait EmbeddingProvider: Send + Sync {
8 async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>>;
10
11 async fn embed_batch(&self, texts: &[&str]) -> ArgentorResult<Vec<Vec<f32>>> {
13 let mut results = Vec::with_capacity(texts.len());
14 for text in texts {
15 results.push(self.embed(text).await?);
16 }
17 Ok(results)
18 }
19
20 fn dimension(&self) -> usize;
22}
23
/// An embedding provider that computes vectors locally by hashing words into a
/// fixed-length `f32` vector, without calling any model or external service.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LocalEmbedding {
    /// Number of components in every produced vector.
    dimension: usize,
}
30
31impl LocalEmbedding {
32 pub fn new(dimension: usize) -> Self {
34 Self { dimension }
35 }
36}
37
38impl Default for LocalEmbedding {
39 fn default() -> Self {
40 Self::new(256)
41 }
42}
43
44#[async_trait]
45impl EmbeddingProvider for LocalEmbedding {
46 async fn embed(&self, text: &str) -> ArgentorResult<Vec<f32>> {
47 if text.is_empty() {
48 return Err(ArgentorError::Agent("Cannot embed empty text".to_string()));
49 }
50
51 let mut vector = vec![0.0f32; self.dimension];
53
54 let lowered = text.to_lowercase();
55 let words: Vec<&str> = lowered
56 .split(|c: char| !c.is_alphanumeric())
57 .filter(|w| !w.is_empty() && w.len() > 1)
58 .collect();
59
60 let mut freq: HashMap<&str, f32> = HashMap::new();
62 for word in &words {
63 *freq.entry(word).or_insert(0.0) += 1.0;
64 }
65
66 let total = words.len() as f32;
67 if total == 0.0 {
68 return Ok(vector);
69 }
70
71 for (word, count) in &freq {
73 let tf = count / total;
74 let hash1 = simple_hash(word.as_bytes()) as usize;
76 let hash2 = simple_hash(&[word.as_bytes(), &[1u8]].concat()) as usize;
77 let hash3 = simple_hash(&[word.as_bytes(), &[2u8]].concat()) as usize;
78
79 vector[hash1 % self.dimension] += tf;
80 vector[hash2 % self.dimension] += tf * 0.7;
81 vector[hash3 % self.dimension] += tf * 0.5;
82 }
83
84 let norm: f32 = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
86 if norm > 0.0 {
87 for v in &mut vector {
88 *v /= norm;
89 }
90 }
91
92 Ok(vector)
93 }
94
95 fn dimension(&self) -> usize {
96 self.dimension
97 }
98}
99
/// 32-bit FNV-1a hash of `data`.
///
/// This is a pure function of the input bytes that uses wrapping `u32` arithmetic, so the
/// same bytes always give the same value. FNV-1a is not a cryptographic hash.
fn simple_hash(data: &[u8]) -> u32 {
    const FNV_OFFSET_BASIS: u32 = 0x811c_9dc5;
    const FNV_PRIME: u32 = 0x0100_0193;

    data.iter().fold(FNV_OFFSET_BASIS, |acc, &byte| {
        (acc ^ u32::from(byte)).wrapping_mul(FNV_PRIME)
    })
}
109
#[cfg(test)]
// Panicking on `unwrap`/`expect` is the intended failure mode inside tests.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Cosine similarity of `a` and `b`, or `0.0` when either vector has zero norm.
    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        let l2_norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
        let (na, nb) = (l2_norm(a), l2_norm(b));
        if na == 0.0 || nb == 0.0 {
            return 0.0;
        }
        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        dot / (na * nb)
    }

    #[tokio::test]
    async fn test_local_embedding_dimension() {
        let provider = LocalEmbedding::new(128);
        assert_eq!(provider.dimension(), 128);

        let embedding = provider.embed("hello world").await.unwrap();
        assert_eq!(embedding.len(), 128);
    }

    #[tokio::test]
    async fn test_local_embedding_normalized() {
        let provider = LocalEmbedding::default();
        let embedding = provider.embed("the quick brown fox jumps").await.unwrap();
        let norm = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.01);
    }

    #[tokio::test]
    async fn test_local_embedding_similar_texts() {
        let provider = LocalEmbedding::default();
        let rust_lang = provider.embed("rust programming language").await.unwrap();
        let rust_sys = provider.embed("rust programming systems").await.unwrap();
        let cooking = provider.embed("cooking recipes for dinner").await.unwrap();

        let sim_12 = cosine_similarity(&rust_lang, &rust_sys);
        let sim_13 = cosine_similarity(&rust_lang, &cooking);
        assert!(
            sim_12 > sim_13,
            "sim(rust-rust)={sim_12} should be > sim(rust-cooking)={sim_13}"
        );
    }

    #[tokio::test]
    async fn test_local_embedding_empty() {
        let provider = LocalEmbedding::default();
        let result = provider.embed("").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_local_embedding_deterministic() {
        let provider = LocalEmbedding::default();
        let first = provider.embed("test input").await.unwrap();
        let second = provider.embed("test input").await.unwrap();
        assert_eq!(first, second);
    }

    #[tokio::test]
    async fn test_embed_batch() {
        let provider = LocalEmbedding::default();
        let batch = provider.embed_batch(&["hello", "world"]).await.unwrap();
        assert_eq!(batch.len(), 2);
        assert_eq!(batch[0].len(), 256);
    }
}