do_memory_core/embeddings/
provider.rs1use anyhow::Result;
4use async_trait::async_trait;
5
/// An embedding vector together with optional bookkeeping about how it was produced.
///
/// `PartialEq` is derived so results can be compared directly (for example in
/// `assert_eq!`); `Eq` is intentionally not derived because the vector holds `f32`s.
#[derive(Debug, Clone, PartialEq)]
pub struct EmbeddingResult {
    /// The embedding vector itself.
    pub embedding: Vec<f32>,
    /// Token count for the embedded input, when the producer supplied one.
    pub token_count: Option<usize>,
    /// Name of the model associated with this embedding.
    pub model: String,
    /// Generation time, when the producer supplied one.
    /// NOTE(review): the unit is presumably milliseconds, judging by the field name.
    pub generation_time_ms: Option<u64>,
}
18
19impl EmbeddingResult {
20 #[must_use]
22 pub fn new(embedding: Vec<f32>, model: String) -> Self {
23 Self {
24 embedding,
25 token_count: None,
26 model,
27 generation_time_ms: None,
28 }
29 }
30
31 #[must_use]
33 pub fn detailed(
34 embedding: Vec<f32>,
35 model: String,
36 token_count: usize,
37 generation_time_ms: u64,
38 ) -> Self {
39 Self {
40 embedding,
41 token_count: Some(token_count),
42 model,
43 generation_time_ms: Some(generation_time_ms),
44 }
45 }
46}
47
48#[async_trait]
50pub trait EmbeddingProvider: Send + Sync {
51 async fn embed_text(&self, text: &str) -> Result<Vec<f32>>;
59
60 async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
71 let mut embeddings = Vec::with_capacity(texts.len());
72 for text in texts {
73 let embedding = self.embed_text(text).await?;
74 embeddings.push(embedding);
75 }
76 Ok(embeddings)
77 }
78
79 async fn similarity(&self, text1: &str, text2: &str) -> Result<f32> {
88 let embedding1 = self.embed_text(text1).await?;
89 let embedding2 = self.embed_text(text2).await?;
90 Ok(crate::embeddings::similarity::cosine_similarity(
91 &embedding1,
92 &embedding2,
93 ))
94 }
95
96 fn embedding_dimension(&self) -> usize;
98
99 fn model_name(&self) -> &str;
101
102 async fn is_available(&self) -> bool {
104 self.embed_text("test").await.is_ok()
106 }
107
108 async fn warmup(&self) -> Result<()> {
110 self.embed_text("warmup test").await?;
112 Ok(())
113 }
114
115 fn metadata(&self) -> serde_json::Value {
117 serde_json::json!({
118 "model": self.model_name(),
119 "dimension": self.embedding_dimension()
120 })
121 }
122}
123
124pub mod utils {
126 use anyhow::Result;
127
128 pub fn normalize_vector(mut vector: Vec<f32>) -> Vec<f32> {
130 let magnitude = (vector.iter().map(|x| x * x).sum::<f32>()).sqrt();
131 if magnitude > 0.0 {
132 for x in &mut vector {
133 *x /= magnitude;
134 }
135 }
136 vector
137 }
138
139 #[allow(dead_code)]
141 pub fn validate_dimension(embedding: &[f32], expected: usize) -> Result<()> {
142 if embedding.len() != expected {
143 anyhow::bail!(
144 "Embedding dimension mismatch: got {}, expected {}",
145 embedding.len(),
146 expected
147 );
148 }
149 Ok(())
150 }
151
152 #[allow(dead_code)]
155 pub fn chunk_text(text: &str, max_chars: usize) -> Vec<String> {
156 if text.len() <= max_chars {
157 return vec![text.to_string()];
158 }
159
160 let mut chunks = Vec::new();
161 let words: Vec<&str> = text.split_whitespace().collect();
162 let mut current_chunk = String::new();
163
164 for word in words {
165 if current_chunk.len() + word.len() + 1 > max_chars && !current_chunk.is_empty() {
166 chunks.push(current_chunk.trim().to_string());
167 current_chunk = word.to_string();
168 } else {
169 if !current_chunk.is_empty() {
170 current_chunk.push(' ');
171 }
172 current_chunk.push_str(word);
173 }
174 }
175
176 if !current_chunk.is_empty() {
177 chunks.push(current_chunk.trim().to_string());
178 }
179
180 chunks
181 }
182
183 #[allow(dead_code)]
186 pub fn average_embeddings(embeddings: &[Vec<f32>]) -> Result<Vec<f32>> {
187 if embeddings.is_empty() {
188 anyhow::bail!("Cannot average empty embeddings list");
189 }
190
191 let dimension = embeddings[0].len();
192 let mut result = vec![0.0; dimension];
193
194 for embedding in embeddings {
195 if embedding.len() != dimension {
196 anyhow::bail!("Inconsistent embedding dimensions");
197 }
198 for (i, &value) in embedding.iter().enumerate() {
199 result[i] += value;
200 }
201 }
202
203 let count = embeddings.len() as f32;
204 for value in &mut result {
205 *value /= count;
206 }
207
208 Ok(normalize_vector(result))
209 }
210}
211
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance shared by the floating-point comparisons below.
    const TOLERANCE: f32 = 0.001;

    /// Returns `true` when `a` and `b` differ by less than [`TOLERANCE`].
    fn close(a: f32, b: f32) -> bool {
        (a - b).abs() < TOLERANCE
    }

    /// A 3-4-5 triangle normalizes to (0.6, 0.8) and has unit length afterwards.
    #[test]
    fn test_normalize_vector() {
        let unit = utils::normalize_vector(vec![3.0, 4.0]);

        assert!(close(unit[0], 0.6));
        assert!(close(unit[1], 0.8));

        let length = unit.iter().map(|c| c * c).sum::<f32>().sqrt();
        assert!(close(length, 1.0));
    }

    /// Chunking respects the size limit and neither loses nor reorders words.
    #[test]
    fn test_chunk_text() {
        let text =
            "This is a long text that needs to be chunked into smaller pieces for processing";
        let chunks = utils::chunk_text(text, 25);

        assert!(chunks.len() > 1);
        assert!(chunks.iter().all(|chunk| chunk.len() <= 25));

        let rejoined = chunks.join(" ");
        let words_before = text.split_whitespace().collect::<Vec<&str>>();
        let words_after = rejoined.split_whitespace().collect::<Vec<&str>>();
        assert_eq!(words_before, words_after);
    }

    /// The mean of the three inputs is (2, 4, 6); the result is that mean, normalized.
    #[test]
    fn test_average_embeddings() {
        let embeddings = vec![
            vec![1.0, 2.0, 3.0],
            vec![2.0, 4.0, 6.0],
            vec![3.0, 6.0, 9.0],
        ];

        let averaged = utils::average_embeddings(&embeddings)
            .expect("average_embeddings should succeed with valid embedding vectors");

        let expected_magnitude = (4.0 + 16.0 + 36.0_f32).sqrt();
        let expected = [2.0, 4.0, 6.0].map(|component: f32| component / expected_magnitude);

        for (actual, wanted) in averaged.iter().zip(expected.iter()) {
            assert!(close(*actual, *wanted));
        }
    }

    /// A matching length is accepted; a mismatching one is rejected.
    #[test]
    fn test_validate_dimension() {
        let embedding = vec![1.0, 2.0, 3.0];

        let matching = utils::validate_dimension(&embedding, 3);
        let mismatching = utils::validate_dimension(&embedding, 4);

        assert!(matching.is_ok());
        assert!(mismatching.is_err());
    }
}