do_memory_core/embeddings/
provider.rs1use anyhow::Result;
4use async_trait::async_trait;
5
/// An embedding vector together with optional metadata about how it was produced.
#[derive(Debug, Clone, PartialEq)]
pub struct EmbeddingResult {
    /// The embedding vector itself.
    pub embedding: Vec<f32>,
    /// Token count associated with this embedding, if one was recorded.
    pub token_count: Option<usize>,
    /// Name of the model that produced the embedding.
    pub model: String,
    /// Time taken to generate the embedding in milliseconds, if one was recorded.
    pub generation_time_ms: Option<u64>,
}
19impl EmbeddingResult {
20 #[must_use]
22 pub fn new(embedding: Vec<f32>, model: String) -> Self {
23 Self {
24 embedding,
25 token_count: None,
26 model,
27 generation_time_ms: None,
28 }
29 }
30
31 #[must_use]
33 pub fn detailed(
34 embedding: Vec<f32>,
35 model: String,
36 token_count: usize,
37 generation_time_ms: u64,
38 ) -> Self {
39 Self {
40 embedding,
41 token_count: Some(token_count),
42 model,
43 generation_time_ms: Some(generation_time_ms),
44 }
45 }
46}
47
48#[async_trait]
50pub trait EmbeddingProvider: Send + Sync {
51 async fn embed_text(&self, text: &str) -> Result<Vec<f32>>;
59
60 async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
71 let mut embeddings = Vec::with_capacity(texts.len());
72 for text in texts {
73 let embedding = self.embed_text(text).await?;
74 embeddings.push(embedding);
75 }
76 Ok(embeddings)
77 }
78
79 async fn similarity(&self, text1: &str, text2: &str) -> Result<f32> {
88 let embedding1 = self.embed_text(text1).await?;
89 let embedding2 = self.embed_text(text2).await?;
90 Ok(crate::embeddings::similarity::cosine_similarity(
91 &embedding1,
92 &embedding2,
93 ))
94 }
95
96 fn embedding_dimension(&self) -> usize;
98
99 fn model_name(&self) -> &str;
101
102 async fn is_available(&self) -> bool {
104 self.embed_text("test").await.is_ok()
106 }
107
108 async fn warmup(&self) -> Result<()> {
110 self.embed_text("warmup test").await?;
112 Ok(())
113 }
114
115 fn metadata(&self) -> serde_json::Value {
117 serde_json::json!({
118 "model": self.model_name(),
119 "dimension": self.embedding_dimension()
120 })
121 }
122}
123
124pub mod utils {
126 use anyhow::Result;
127
128 #[must_use]
130 pub fn normalize_vector(mut vector: Vec<f32>) -> Vec<f32> {
131 let magnitude = (vector.iter().map(|x| x * x).sum::<f32>()).sqrt();
132 if magnitude > 0.0 {
133 for x in &mut vector {
134 *x /= magnitude;
135 }
136 }
137 vector
138 }
139
140 #[allow(dead_code)] pub fn validate_dimension(embedding: &[f32], expected: usize) -> Result<()> {
143 if embedding.len() != expected {
144 anyhow::bail!(
145 "Embedding dimension mismatch: got {}, expected {}",
146 embedding.len(),
147 expected
148 );
149 }
150 Ok(())
151 }
152
153 #[allow(dead_code)] pub fn chunk_text(text: &str, max_chars: usize) -> Vec<String> {
157 if text.len() <= max_chars {
158 return vec![text.to_string()];
159 }
160
161 let mut chunks = Vec::new();
162 let words: Vec<&str> = text.split_whitespace().collect();
163 let mut current_chunk = String::new();
164
165 for word in words {
166 if current_chunk.len() + word.len() + 1 > max_chars && !current_chunk.is_empty() {
167 chunks.push(current_chunk.trim().to_string());
168 current_chunk = word.to_string();
169 } else {
170 if !current_chunk.is_empty() {
171 current_chunk.push(' ');
172 }
173 current_chunk.push_str(word);
174 }
175 }
176
177 if !current_chunk.is_empty() {
178 chunks.push(current_chunk.trim().to_string());
179 }
180
181 chunks
182 }
183
184 #[allow(dead_code)] pub fn average_embeddings(embeddings: &[Vec<f32>]) -> Result<Vec<f32>> {
188 if embeddings.is_empty() {
189 anyhow::bail!("Cannot average empty embeddings list");
190 }
191
192 let dimension = embeddings[0].len();
193 let mut result = vec![0.0; dimension];
194
195 for embedding in embeddings {
196 if embedding.len() != dimension {
197 anyhow::bail!("Inconsistent embedding dimensions");
198 }
199 for (i, &value) in embedding.iter().enumerate() {
200 result[i] += value;
201 }
202 }
203
204 let count = embeddings.len() as f32;
205 for value in &mut result {
206 *value /= count;
207 }
208
209 Ok(normalize_vector(result))
210 }
211}
212
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_normalize_vector() {
        // 3-4-5 triangle: the unit vector is (0.6, 0.8).
        let normalized = utils::normalize_vector(vec![3.0, 4.0]);

        assert!((normalized[0] - 0.6).abs() < 0.001);
        assert!((normalized[1] - 0.8).abs() < 0.001);

        let magnitude = normalized.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((magnitude - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_normalize_degenerate_vectors_are_unchanged() {
        // Zero magnitude must not be divided by (that would produce NaN).
        assert_eq!(utils::normalize_vector(vec![0.0, 0.0]), vec![0.0, 0.0]);
        assert!(utils::normalize_vector(Vec::new()).is_empty());
    }

    #[test]
    fn test_chunk_text() {
        let text =
            "This is a long text that needs to be chunked into smaller pieces for processing";
        let chunks = utils::chunk_text(text, 25);

        assert!(chunks.len() > 1);
        for chunk in &chunks {
            assert!(chunk.len() <= 25);
        }

        // Chunking must neither drop nor reorder words.
        let rejoined = chunks.join(" ");
        let original_words: Vec<&str> = text.split_whitespace().collect();
        let rejoined_words: Vec<&str> = rejoined.split_whitespace().collect();
        assert_eq!(original_words, rejoined_words);
    }

    #[test]
    fn test_chunk_text_short_input_is_single_chunk() {
        assert_eq!(utils::chunk_text("short text", 25), vec!["short text".to_string()]);
    }

    #[test]
    fn test_chunk_text_oversized_word_gets_own_chunk() {
        let chunks = utils::chunk_text("a bbbbbbbbbb c", 5);
        assert_eq!(chunks, vec!["a", "bbbbbbbbbb", "c"]);
    }

    #[test]
    fn test_average_embeddings() {
        let embeddings = vec![
            vec![1.0, 2.0, 3.0],
            vec![2.0, 4.0, 6.0],
            vec![3.0, 6.0, 9.0],
        ];

        let averaged = utils::average_embeddings(&embeddings)
            .expect("average_embeddings should succeed with valid embedding vectors");

        // The mean is (2, 4, 6); the result is that vector at unit length.
        let expected_magnitude = (4.0 + 16.0 + 36.0_f32).sqrt();
        let expected = [
            2.0 / expected_magnitude,
            4.0 / expected_magnitude,
            6.0 / expected_magnitude,
        ];

        assert_eq!(averaged.len(), expected.len());
        for (actual, expected) in averaged.iter().zip(expected.iter()) {
            assert!((actual - expected).abs() < 0.001);
        }
    }

    #[test]
    fn test_average_embeddings_rejects_bad_input() {
        let empty: Vec<Vec<f32>> = Vec::new();
        assert!(utils::average_embeddings(&empty).is_err());

        let ragged = vec![vec![1.0, 2.0], vec![1.0]];
        assert!(utils::average_embeddings(&ragged).is_err());
    }

    #[test]
    fn test_validate_dimension() {
        let embedding = vec![1.0, 2.0, 3.0];

        assert!(utils::validate_dimension(&embedding, 3).is_ok());
        assert!(utils::validate_dimension(&embedding, 4).is_err());
    }
}