Skip to main content

do_memory_core/embeddings/
provider.rs

1//! Embedding provider trait and common functionality
2
3use anyhow::Result;
4use async_trait::async_trait;
5
/// Result from embedding generation
#[derive(Debug, Clone)]
pub struct EmbeddingResult {
    /// The generated embedding vector
    pub embedding: Vec<f32>,
    /// Number of tokens processed, when the provider reports it
    pub token_count: Option<usize>,
    /// Model used for generation
    pub model: String,
    /// Generation time in milliseconds, when it was measured
    pub generation_time_ms: Option<u64>,
}
18
19impl EmbeddingResult {
20    /// Create a simple embedding result
21    #[must_use]
22    pub fn new(embedding: Vec<f32>, model: String) -> Self {
23        Self {
24            embedding,
25            token_count: None,
26            model,
27            generation_time_ms: None,
28        }
29    }
30
31    /// Create a detailed embedding result
32    #[must_use]
33    pub fn detailed(
34        embedding: Vec<f32>,
35        model: String,
36        token_count: usize,
37        generation_time_ms: u64,
38    ) -> Self {
39        Self {
40            embedding,
41            token_count: Some(token_count),
42            model,
43            generation_time_ms: Some(generation_time_ms),
44        }
45    }
46}
47
48/// Trait for embedding providers that convert text to vectors
49#[async_trait]
50pub trait EmbeddingProvider: Send + Sync {
51    /// Generate embedding for a single text
52    ///
53    /// # Arguments
54    /// * `text` - Input text to embed
55    ///
56    /// # Returns
57    /// Vector representation of the text
58    async fn embed_text(&self, text: &str) -> Result<Vec<f32>>;
59
60    /// Generate embeddings for multiple texts in batch
61    ///
62    /// More efficient than calling `embed_text` multiple times.
63    /// Default implementation calls `embed_text` for each text.
64    ///
65    /// # Arguments
66    /// * `texts` - Batch of texts to embed
67    ///
68    /// # Returns
69    /// Vector of embeddings in the same order as input texts
70    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
71        let mut embeddings = Vec::with_capacity(texts.len());
72        for text in texts {
73            let embedding = self.embed_text(text).await?;
74            embeddings.push(embedding);
75        }
76        Ok(embeddings)
77    }
78
79    /// Calculate semantic similarity between two texts
80    ///
81    /// # Arguments
82    /// * `text1` - First text
83    /// * `text2` - Second text
84    ///
85    /// # Returns
86    /// Similarity score between 0.0 and 1.0 (higher = more similar)
87    async fn similarity(&self, text1: &str, text2: &str) -> Result<f32> {
88        let embedding1 = self.embed_text(text1).await?;
89        let embedding2 = self.embed_text(text2).await?;
90        Ok(crate::embeddings::similarity::cosine_similarity(
91            &embedding1,
92            &embedding2,
93        ))
94    }
95
96    /// Get the embedding dimension for this provider
97    fn embedding_dimension(&self) -> usize;
98
99    /// Get the model name/identifier
100    fn model_name(&self) -> &str;
101
102    /// Check if the provider is available/configured
103    async fn is_available(&self) -> bool {
104        // Default implementation tries to embed a simple test
105        self.embed_text("test").await.is_ok()
106    }
107
108    /// Warm up the provider (load models, test connections, etc.)
109    async fn warmup(&self) -> Result<()> {
110        // Default implementation does a simple test embedding
111        self.embed_text("warmup test").await?;
112        Ok(())
113    }
114
115    /// Get provider-specific metadata
116    fn metadata(&self) -> serde_json::Value {
117        serde_json::json!({
118            "model": self.model_name(),
119            "dimension": self.embedding_dimension()
120        })
121    }
122}
123
124/// Utility functions for embedding providers
125pub mod utils {
126    use anyhow::Result;
127
128    /// Normalize a vector to unit length
129    #[must_use]
130    pub fn normalize_vector(mut vector: Vec<f32>) -> Vec<f32> {
131        let magnitude = (vector.iter().map(|x| x * x).sum::<f32>()).sqrt();
132        if magnitude > 0.0 {
133            for x in &mut vector {
134                *x /= magnitude;
135            }
136        }
137        vector
138    }
139
140    /// Validate embedding dimension matches expected
141    #[allow(dead_code)] // Utility function kept for future use
142    pub fn validate_dimension(embedding: &[f32], expected: usize) -> Result<()> {
143        if embedding.len() != expected {
144            anyhow::bail!(
145                "Embedding dimension mismatch: got {}, expected {}",
146                embedding.len(),
147                expected
148            );
149        }
150        Ok(())
151    }
152
153    /// Chunk text into smaller pieces for embedding
154    /// Useful for long texts that exceed model token limits
155    #[allow(dead_code)] // Utility function kept for future use
156    pub fn chunk_text(text: &str, max_chars: usize) -> Vec<String> {
157        if text.len() <= max_chars {
158            return vec![text.to_string()];
159        }
160
161        let mut chunks = Vec::new();
162        let words: Vec<&str> = text.split_whitespace().collect();
163        let mut current_chunk = String::new();
164
165        for word in words {
166            if current_chunk.len() + word.len() + 1 > max_chars && !current_chunk.is_empty() {
167                chunks.push(current_chunk.trim().to_string());
168                current_chunk = word.to_string();
169            } else {
170                if !current_chunk.is_empty() {
171                    current_chunk.push(' ');
172                }
173                current_chunk.push_str(word);
174            }
175        }
176
177        if !current_chunk.is_empty() {
178            chunks.push(current_chunk.trim().to_string());
179        }
180
181        chunks
182    }
183
184    /// Average multiple embeddings into a single embedding
185    /// Useful for combining embeddings from chunked text
186    #[allow(dead_code)] // Utility function kept for future use
187    pub fn average_embeddings(embeddings: &[Vec<f32>]) -> Result<Vec<f32>> {
188        if embeddings.is_empty() {
189            anyhow::bail!("Cannot average empty embeddings list");
190        }
191
192        let dimension = embeddings[0].len();
193        let mut result = vec![0.0; dimension];
194
195        for embedding in embeddings {
196            if embedding.len() != dimension {
197                anyhow::bail!("Inconsistent embedding dimensions");
198            }
199            for (i, &value) in embedding.iter().enumerate() {
200                result[i] += value;
201            }
202        }
203
204        let count = embeddings.len() as f32;
205        for value in &mut result {
206            *value /= count;
207        }
208
209        Ok(normalize_vector(result))
210    }
211}
212
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_normalize_vector() {
        // A 3-4-5 triangle: the input magnitude is exactly 5.0,
        // so the normalized vector should be [0.6, 0.8].
        let normalized = utils::normalize_vector(vec![3.0, 4.0]);

        assert!((normalized[0] - 0.6).abs() < 0.001);
        assert!((normalized[1] - 0.8).abs() < 0.001);

        // The result must be a unit vector.
        let magnitude = normalized.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((magnitude - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_chunk_text() {
        let text =
            "This is a long text that needs to be chunked into smaller pieces for processing";
        let chunks = utils::chunk_text(text, 25);

        // Text longer than the limit must be split, and every piece must fit.
        assert!(chunks.len() > 1);
        assert!(chunks.iter().all(|chunk| chunk.len() <= 25));

        // Chunking must not lose or reorder any words.
        let rejoined = chunks.join(" ");
        assert_eq!(
            text.split_whitespace().collect::<Vec<_>>(),
            rejoined.split_whitespace().collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_average_embeddings() {
        let embeddings = vec![
            vec![1.0, 2.0, 3.0],
            vec![2.0, 4.0, 6.0],
            vec![3.0, 6.0, 9.0],
        ];

        let averaged = utils::average_embeddings(&embeddings)
            .expect("average_embeddings should succeed with valid embedding vectors");

        // The raw mean is [2.0, 4.0, 6.0]; the function then normalizes,
        // so expect that vector scaled to unit length (magnitude sqrt(56)).
        let norm = (4.0_f32 + 16.0 + 36.0).sqrt();
        let expected = [2.0 / norm, 4.0 / norm, 6.0 / norm];

        for (got, want) in averaged.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 0.001);
        }
    }

    #[test]
    fn test_validate_dimension() {
        let embedding = vec![1.0, 2.0, 3.0];

        // The matching dimension passes; any other dimension is rejected.
        assert!(utils::validate_dimension(&embedding, 3).is_ok());
        assert!(utils::validate_dimension(&embedding, 4).is_err());
    }
}