memvid_cli/
mistral_embeddings.rs

1//! Mistral AI Embeddings Provider
2//!
3//! This module provides an embedding implementation that uses
4//! Mistral AI's embeddings API for generating high-quality embeddings.
5//!
6//! ## Environment Variables
7//! - `MISTRAL_API_KEY`: Required API key for Mistral AI
8//! - `MISTRAL_EMBEDDING_MODEL`: Optional model override (default: mistral-embed)
9//!
10//! ## Features
11//! - Uses mistral-embed model (1024 dimensions)
12//! - Efficient batch processing
13//! - Thread-safe for concurrent use
14
15use anyhow::{anyhow, bail, Result};
16use reqwest::blocking::Client;
17use serde::{Deserialize, Serialize};
18use std::time::Duration;
19use tracing::{debug, info, warn};
20
/// Mistral embeddings API endpoint (v1, JSON over HTTPS).
const MISTRAL_EMBEDDINGS_URL: &str = "https://api.mistral.ai/v1/embeddings";

/// Default embedding model, used when no override is supplied via
/// `new()` or the `MISTRAL_EMBEDDING_MODEL` environment variable.
const DEFAULT_MODEL: &str = "mistral-embed";

/// Maximum texts per request; `embed_batch` splits larger inputs into
/// chunks of this size (Mistral supports batching).
const MAX_BATCH_SIZE: usize = 100;

/// Per-request timeout handed to the blocking HTTP client.
const REQUEST_TIMEOUT: Duration = Duration::from_secs(60);

/// Maximum UTF-8 *byte* length for embedding text to avoid token limits
/// (`truncate_for_embedding` compares against `str::len()`, which counts
/// bytes, not characters). mistral-embed has an 8192 token limit; 20k
/// bytes is presumably a conservative bound (~4 chars/token for typical
/// text) — TODO confirm against the model's tokenizer.
const MAX_EMBEDDING_TEXT_LEN: usize = 20_000;
36
37/// Truncate text to MAX_EMBEDDING_TEXT_LEN to avoid token limit errors.
38fn truncate_for_embedding(text: &str) -> std::borrow::Cow<'_, str> {
39    if text.len() <= MAX_EMBEDDING_TEXT_LEN {
40        std::borrow::Cow::Borrowed(text)
41    } else {
42        let end = text[..MAX_EMBEDDING_TEXT_LEN]
43            .char_indices()
44            .rev()
45            .next()
46            .map(|(i, c)| i + c.len_utf8())
47            .unwrap_or(MAX_EMBEDDING_TEXT_LEN);
48        warn!(
49            "Truncating embedding text from {} to {} chars to avoid token limit",
50            text.len(),
51            end
52        );
53        std::borrow::Cow::Owned(text[..end].to_string())
54    }
55}
56
/// Mistral embedding request payload (serialized as the JSON body of the
/// POST to the embeddings endpoint).
#[derive(Debug, Serialize)]
struct MistralEmbeddingRequest<'a> {
    /// Model identifier, e.g. "mistral-embed".
    model: &'a str,
    /// Batch of input texts; borrows from the caller for the request's lifetime.
    input: Vec<&'a str>,
}
63
/// Mistral embedding response (top-level JSON object of a successful call).
#[derive(Debug, Deserialize)]
struct MistralEmbeddingResponse {
    /// One entry per input text; order restored via each entry's `index`.
    data: Vec<MistralEmbeddingData>,
    /// Model the server actually used (logged for diagnostics).
    model: String,
    /// Token accounting reported by the API.
    usage: MistralUsage,
}
71
/// A single embedding entry in the response `data` array.
#[derive(Debug, Deserialize)]
struct MistralEmbeddingData {
    /// The embedding vector (expected length: MistralEmbeddingProvider::DIMENSION).
    embedding: Vec<f32>,
    /// Position of the corresponding text in the request batch;
    /// used to sort results back into input order.
    index: usize,
}
77
/// Token usage block of the embeddings response.
#[derive(Debug, Deserialize)]
struct MistralUsage {
    // Deserialized for completeness; currently only total_tokens is read
    // (in the debug! log inside embed_with_retry).
    prompt_tokens: usize,
    total_tokens: usize,
}
83
/// Mistral error response body, parsed best-effort from non-2xx replies.
#[derive(Debug, Deserialize)]
struct MistralErrorResponse {
    /// Human-readable error description surfaced to the caller.
    message: String,
    /// Error category from the API ("type" in JSON); currently unread.
    #[serde(rename = "type")]
    error_type: Option<String>,
}
91
/// Mistral Embedding Provider
///
/// Cheap to clone (the inner reqwest `Client` is internally reference-counted),
/// so one instance can be shared across threads.
#[derive(Clone)]
pub struct MistralEmbeddingProvider {
    // Secret; deliberately excluded from the manual Debug impl below.
    api_key: String,
    // Model name sent with every request (default: DEFAULT_MODEL).
    model: String,
    // Blocking HTTP client, built once with REQUEST_TIMEOUT.
    client: Client,
}
99
/// Manual Debug impl: shows only the model name, keeping the API key out
/// of logs and panic messages.
impl std::fmt::Debug for MistralEmbeddingProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MistralEmbeddingProvider")
            .field("model", &self.model)
            .finish()
    }
}
107
108impl MistralEmbeddingProvider {
109    /// Embedding dimension for mistral-embed
110    pub const DIMENSION: usize = 1024;
111
112    /// Create a new Mistral embedding provider
113    pub fn new(api_key: String, model: Option<&str>) -> Result<Self> {
114        if api_key.is_empty() {
115            bail!("Mistral API key cannot be empty");
116        }
117
118        let client = crate::http::blocking_client(REQUEST_TIMEOUT)
119            .map_err(|e| anyhow!("Failed to create HTTP client: {}", e))?;
120
121        let model = model.unwrap_or(DEFAULT_MODEL).to_string();
122
123        Ok(Self {
124            api_key,
125            model,
126            client,
127        })
128    }
129
130    /// Create provider from environment variables
131    pub fn from_env() -> Result<Self> {
132        let api_key = std::env::var("MISTRAL_API_KEY")
133            .map_err(|_| anyhow!("MISTRAL_API_KEY environment variable not set"))?;
134
135        let model = std::env::var("MISTRAL_EMBEDDING_MODEL").ok();
136        Self::new(api_key, model.as_deref())
137    }
138
139    /// Get model name
140    pub fn model(&self) -> &str {
141        &self.model
142    }
143
144    /// Get provider kind
145    pub fn kind(&self) -> &'static str {
146        "mistral"
147    }
148
149    /// Get embedding dimension
150    pub fn dimension(&self) -> usize {
151        Self::DIMENSION
152    }
153
154    /// Embed a single text
155    pub fn embed_text(&self, text: &str) -> Result<Vec<f32>> {
156        let text = truncate_for_embedding(text);
157        self.embed_with_retry(&[&text], 3)
158            .map(|mut v| v.pop().unwrap_or_default())
159    }
160
161    /// Embed multiple texts in batch
162    pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
163        if texts.is_empty() {
164            return Ok(Vec::new());
165        }
166
167        // Truncate all texts first
168        let truncated: Vec<std::borrow::Cow<'_, str>> =
169            texts.iter().map(|t| truncate_for_embedding(t)).collect();
170        let truncated_refs: Vec<&str> = truncated.iter().map(|c| c.as_ref()).collect();
171
172        let mut all_embeddings = Vec::with_capacity(texts.len());
173
174        // Process in batches
175        for chunk in truncated_refs.chunks(MAX_BATCH_SIZE) {
176            let embeddings = self.embed_with_retry(chunk, 3)?;
177            all_embeddings.extend(embeddings);
178        }
179
180        Ok(all_embeddings)
181    }
182
183    /// Embed texts with retry logic
184    fn embed_with_retry(&self, texts: &[&str], max_retries: usize) -> Result<Vec<Vec<f32>>> {
185        let request = MistralEmbeddingRequest {
186            model: &self.model,
187            input: texts.to_vec(),
188        };
189
190        let mut last_error = None;
191
192        for attempt in 0..max_retries {
193            let response = self
194                .client
195                .post(MISTRAL_EMBEDDINGS_URL)
196                .header("Authorization", format!("Bearer {}", self.api_key))
197                .header("Content-Type", "application/json")
198                .json(&request)
199                .send();
200
201            match response {
202                Ok(resp) => {
203                    let status = resp.status();
204                    let body = resp.text().unwrap_or_default();
205
206                    if status.is_success() {
207                        let embed_response: MistralEmbeddingResponse = serde_json::from_str(&body)
208                            .map_err(|e| anyhow!("Failed to parse Mistral response: {}", e))?;
209
210                        debug!(
211                            "Mistral embeddings: {} texts, {} tokens, model={}",
212                            texts.len(),
213                            embed_response.usage.total_tokens,
214                            embed_response.model
215                        );
216
217                        // Sort by index and extract embeddings
218                        let mut data = embed_response.data;
219                        data.sort_by_key(|d| d.index);
220
221                        let embeddings: Vec<Vec<f32>> =
222                            data.into_iter().map(|d| d.embedding).collect();
223
224                        // Validate dimensions
225                        if let Some(first) = embeddings.first() {
226                            if first.len() != Self::DIMENSION {
227                                warn!(
228                                    "Mistral returned dimension {} but expected {}",
229                                    first.len(),
230                                    Self::DIMENSION
231                                );
232                            }
233                        }
234
235                        return Ok(embeddings);
236                    }
237
238                    // Check for rate limiting
239                    if status.as_u16() == 429 {
240                        let backoff = Duration::from_millis(500 * (1 << attempt));
241                        warn!(
242                            "Rate limited by Mistral, retrying in {:?} (attempt {}/{})",
243                            backoff,
244                            attempt + 1,
245                            max_retries
246                        );
247                        std::thread::sleep(backoff);
248                        last_error = Some(anyhow!("Rate limited"));
249                        continue;
250                    }
251
252                    // Try to parse error response
253                    if let Ok(error_response) = serde_json::from_str::<MistralErrorResponse>(&body)
254                    {
255                        return Err(anyhow!(
256                            "Mistral API error: {}",
257                            error_response.message
258                        ));
259                    }
260
261                    return Err(anyhow!(
262                        "Mistral API request failed with status {}: {}",
263                        status,
264                        body
265                    ));
266                }
267                Err(e) => {
268                    if attempt < max_retries - 1 {
269                        let backoff = Duration::from_millis(500 * (1 << attempt));
270                        warn!(
271                            "Mistral request failed, retrying in {:?} (attempt {}/{}): {}",
272                            backoff,
273                            attempt + 1,
274                            max_retries,
275                            e
276                        );
277                        std::thread::sleep(backoff);
278                        last_error = Some(anyhow!("Request failed: {}", e));
279                        continue;
280                    }
281                    return Err(anyhow!("Mistral API request failed: {}", e));
282                }
283            }
284        }
285
286        Err(last_error.unwrap_or_else(|| anyhow!("Failed to embed after {} retries", max_retries)))
287    }
288}
289
290/// Helper to create a Mistral provider or return None
291pub fn try_mistral_provider() -> Option<MistralEmbeddingProvider> {
292    match MistralEmbeddingProvider::from_env() {
293        Ok(provider) => {
294            info!("Mistral embedding provider available");
295            Some(provider)
296        }
297        Err(e) => {
298            debug!("Mistral provider not available: {}", e);
299            None
300        }
301    }
302}
303
#[cfg(test)]
mod tests {
    use super::*;

    // An empty API key must be rejected at construction time.
    #[test]
    fn test_empty_api_key() {
        assert!(MistralEmbeddingProvider::new(String::new(), None).is_err());
    }

    // The provider reports the fixed mistral-embed dimension.
    #[test]
    fn test_dimension() {
        let provider = MistralEmbeddingProvider::new("test-key".to_string(), None).unwrap();
        assert_eq!(provider.dimension(), 1024);
    }

    // Round-trips one string through the live API; run explicitly with
    // `cargo test -- --ignored` and a configured key.
    #[test]
    #[ignore] // Requires valid API key
    fn test_real_embedding() {
        let provider = MistralEmbeddingProvider::from_env().expect("MISTRAL_API_KEY must be set");
        let embedding = provider.embed_text("Hello, world!").expect("embed");
        assert!(!embedding.is_empty());
        assert_eq!(embedding.len(), 1024);
    }
}