// do_memory_core/embeddings/provider.rs
//! Embedding provider trait and common functionality

use anyhow::Result;
use async_trait::async_trait;

/// Result from embedding generation
///
/// Bundles the generated vector with optional bookkeeping (token count
/// and timing) that a provider may or may not report.
#[derive(Debug, Clone)]
pub struct EmbeddingResult {
    /// The generated embedding vector
    pub embedding: Vec<f32>,
    /// Number of tokens processed, if the provider reports it
    pub token_count: Option<usize>,
    /// Model used for generation
    pub model: String,
    /// Generation time in milliseconds, if measured
    pub generation_time_ms: Option<u64>,
}

19impl EmbeddingResult {
20    /// Create a simple embedding result
21    #[must_use]
22    pub fn new(embedding: Vec<f32>, model: String) -> Self {
23        Self {
24            embedding,
25            token_count: None,
26            model,
27            generation_time_ms: None,
28        }
29    }
30
31    /// Create a detailed embedding result
32    #[must_use]
33    pub fn detailed(
34        embedding: Vec<f32>,
35        model: String,
36        token_count: usize,
37        generation_time_ms: u64,
38    ) -> Self {
39        Self {
40            embedding,
41            token_count: Some(token_count),
42            model,
43            generation_time_ms: Some(generation_time_ms),
44        }
45    }
46}
47
/// Trait for embedding providers that convert text to vectors
///
/// Implementors must supply `embed_text`, `embedding_dimension`, and
/// `model_name`; every other method has a default implementation built
/// on top of `embed_text`.
#[async_trait]
pub trait EmbeddingProvider: Send + Sync {
    /// Generate embedding for a single text
    ///
    /// # Arguments
    /// * `text` - Input text to embed
    ///
    /// # Returns
    /// Vector representation of the text
    ///
    /// # Errors
    /// Propagates any failure from the underlying provider.
    async fn embed_text(&self, text: &str) -> Result<Vec<f32>>;

    /// Generate embeddings for multiple texts in batch
    ///
    /// More efficient than calling `embed_text` multiple times.
    /// Default implementation calls `embed_text` for each text,
    /// awaiting sequentially in input order (no concurrency).
    ///
    /// # Arguments
    /// * `texts` - Batch of texts to embed
    ///
    /// # Returns
    /// Vector of embeddings in the same order as input texts
    ///
    /// # Errors
    /// Fails fast: the first `embed_text` error aborts the whole batch.
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        let mut embeddings = Vec::with_capacity(texts.len());
        for text in texts {
            let embedding = self.embed_text(text).await?;
            embeddings.push(embedding);
        }
        Ok(embeddings)
    }

    /// Calculate semantic similarity between two texts
    ///
    /// The default implementation re-embeds both texts on every call
    /// (no caching), so it costs two provider round-trips.
    ///
    /// # Arguments
    /// * `text1` - First text
    /// * `text2` - Second text
    ///
    /// # Returns
    /// Cosine similarity score (higher = more similar). NOTE(review):
    /// raw cosine similarity ranges over [-1.0, 1.0]; confirm whether
    /// `cosine_similarity` clamps/normalizes to [0.0, 1.0] before
    /// relying on that range.
    ///
    /// # Errors
    /// Propagates any failure from either `embed_text` call.
    async fn similarity(&self, text1: &str, text2: &str) -> Result<f32> {
        let embedding1 = self.embed_text(text1).await?;
        let embedding2 = self.embed_text(text2).await?;
        Ok(crate::embeddings::similarity::cosine_similarity(
            &embedding1,
            &embedding2,
        ))
    }

    /// Get the embedding dimension for this provider
    fn embedding_dimension(&self) -> usize;

    /// Get the model name/identifier
    fn model_name(&self) -> &str;

    /// Check if the provider is available/configured
    ///
    /// The default performs a real embedding of the literal `"test"`,
    /// so it may incur model-load or network cost.
    async fn is_available(&self) -> bool {
        // Default implementation tries to embed a simple test
        self.embed_text("test").await.is_ok()
    }

    /// Warm up the provider (load models, test connections, etc.)
    ///
    /// # Errors
    /// Propagates any failure from the warmup embedding call.
    async fn warmup(&self) -> Result<()> {
        // Default implementation does a simple test embedding
        self.embed_text("warmup test").await?;
        Ok(())
    }

    /// Get provider-specific metadata
    ///
    /// Default returns a JSON object with the model name and dimension.
    fn metadata(&self) -> serde_json::Value {
        serde_json::json!({
            "model": self.model_name(),
            "dimension": self.embedding_dimension()
        })
    }
}

124/// Utility functions for embedding providers
125pub mod utils {
126    use anyhow::Result;
127
128    /// Normalize a vector to unit length
129    pub fn normalize_vector(mut vector: Vec<f32>) -> Vec<f32> {
130        let magnitude = (vector.iter().map(|x| x * x).sum::<f32>()).sqrt();
131        if magnitude > 0.0 {
132            for x in &mut vector {
133                *x /= magnitude;
134            }
135        }
136        vector
137    }
138
139    /// Validate embedding dimension matches expected
140    #[allow(dead_code)]
141    pub fn validate_dimension(embedding: &[f32], expected: usize) -> Result<()> {
142        if embedding.len() != expected {
143            anyhow::bail!(
144                "Embedding dimension mismatch: got {}, expected {}",
145                embedding.len(),
146                expected
147            );
148        }
149        Ok(())
150    }
151
152    /// Chunk text into smaller pieces for embedding
153    /// Useful for long texts that exceed model token limits
154    #[allow(dead_code)]
155    pub fn chunk_text(text: &str, max_chars: usize) -> Vec<String> {
156        if text.len() <= max_chars {
157            return vec![text.to_string()];
158        }
159
160        let mut chunks = Vec::new();
161        let words: Vec<&str> = text.split_whitespace().collect();
162        let mut current_chunk = String::new();
163
164        for word in words {
165            if current_chunk.len() + word.len() + 1 > max_chars && !current_chunk.is_empty() {
166                chunks.push(current_chunk.trim().to_string());
167                current_chunk = word.to_string();
168            } else {
169                if !current_chunk.is_empty() {
170                    current_chunk.push(' ');
171                }
172                current_chunk.push_str(word);
173            }
174        }
175
176        if !current_chunk.is_empty() {
177            chunks.push(current_chunk.trim().to_string());
178        }
179
180        chunks
181    }
182
183    /// Average multiple embeddings into a single embedding
184    /// Useful for combining embeddings from chunked text
185    #[allow(dead_code)]
186    pub fn average_embeddings(embeddings: &[Vec<f32>]) -> Result<Vec<f32>> {
187        if embeddings.is_empty() {
188            anyhow::bail!("Cannot average empty embeddings list");
189        }
190
191        let dimension = embeddings[0].len();
192        let mut result = vec![0.0; dimension];
193
194        for embedding in embeddings {
195            if embedding.len() != dimension {
196                anyhow::bail!("Inconsistent embedding dimensions");
197            }
198            for (i, &value) in embedding.iter().enumerate() {
199                result[i] += value;
200            }
201        }
202
203        let count = embeddings.len() as f32;
204        for value in &mut result {
205            *value /= count;
206        }
207
208        Ok(normalize_vector(result))
209    }
210}
211
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_normalize_vector() {
        // 3-4-5 triangle: magnitude is 5, so the unit vector is [0.6, 0.8].
        let unit = utils::normalize_vector(vec![3.0, 4.0]);
        for (got, want) in unit.iter().zip([0.6_f32, 0.8].iter()) {
            assert!((got - want).abs() < 0.001);
        }

        // The result must have unit magnitude.
        let magnitude_sq: f32 = unit.iter().map(|x| x * x).sum();
        assert!((magnitude_sq.sqrt() - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_chunk_text() {
        let text =
            "This is a long text that needs to be chunked into smaller pieces for processing";
        let chunks = utils::chunk_text(text, 25);

        // Longer than the budget, so it must be split, and every chunk
        // must respect the budget.
        assert!(chunks.len() > 1);
        assert!(chunks.iter().all(|chunk| chunk.len() <= 25));

        // Rejoining the chunks must preserve every word in order.
        let rejoined = chunks.join(" ");
        assert!(text.split_whitespace().eq(rejoined.split_whitespace()));
    }

    #[test]
    fn test_average_embeddings() {
        let inputs = vec![
            vec![1.0, 2.0, 3.0],
            vec![2.0, 4.0, 6.0],
            vec![3.0, 6.0, 9.0],
        ];

        let averaged = utils::average_embeddings(&inputs)
            .expect("average_embeddings should succeed with valid embedding vectors");

        // The raw mean is [2.0, 4.0, 6.0]; the result is that direction
        // rescaled to unit length, i.e. divided by sqrt(56) ≈ 7.48.
        let magnitude = (4.0_f32 + 16.0 + 36.0).sqrt();
        let expected = [2.0 / magnitude, 4.0 / magnitude, 6.0 / magnitude];
        for (got, want) in averaged.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 0.001);
        }
    }

    #[test]
    fn test_validate_dimension() {
        let embedding = vec![1.0, 2.0, 3.0];

        // Matching dimension passes; any other length is rejected.
        assert!(utils::validate_dimension(&embedding, 3).is_ok());
        assert!(utils::validate_dimension(&embedding, 4).is_err());
    }
}