// memvid_cli/mistral_embeddings.rs
1//! Mistral AI Embeddings Provider
2//!
3//! This module provides an embedding implementation that uses
4//! Mistral AI's embeddings API for generating high-quality embeddings.
5//!
6//! ## Environment Variables
7//! - `MISTRAL_API_KEY`: Required API key for Mistral AI
8//! - `MISTRAL_EMBEDDING_MODEL`: Optional model override (default: mistral-embed)
9//!
10//! ## Features
11//! - Uses mistral-embed model (1024 dimensions)
12//! - Efficient batch processing
13//! - Thread-safe for concurrent use
14
15use anyhow::{anyhow, bail, Result};
16use reqwest::blocking::Client;
17use serde::{Deserialize, Serialize};
18use std::time::Duration;
19use tracing::{debug, info, warn};
20
/// Mistral embeddings API endpoint
const MISTRAL_EMBEDDINGS_URL: &str = "https://api.mistral.ai/v1/embeddings";

/// Default embedding model, used when no override is supplied
const DEFAULT_MODEL: &str = "mistral-embed";

/// Maximum texts per request (Mistral supports batching)
const MAX_BATCH_SIZE: usize = 100;

/// Timeout applied to each embeddings HTTP request
const REQUEST_TIMEOUT: Duration = Duration::from_secs(60);

/// Maximum length in BYTES (compared against `str::len`) for embedding text
/// to avoid token limits. mistral-embed has an 8192 token limit; this is a
/// conservative estimate.
const MAX_EMBEDDING_TEXT_LEN: usize = 20_000;
36
37/// Truncate text to MAX_EMBEDDING_TEXT_LEN to avoid token limit errors.
38fn truncate_for_embedding(text: &str) -> std::borrow::Cow<'_, str> {
39    if text.len() <= MAX_EMBEDDING_TEXT_LEN {
40        std::borrow::Cow::Borrowed(text)
41    } else {
42        let end = text[..MAX_EMBEDDING_TEXT_LEN]
43            .char_indices()
44            .rev()
45            .next()
46            .map(|(i, c)| i + c.len_utf8())
47            .unwrap_or(MAX_EMBEDDING_TEXT_LEN);
48        warn!(
49            "Truncating embedding text from {} to {} chars to avoid token limit",
50            text.len(),
51            end
52        );
53        std::borrow::Cow::Owned(text[..end].to_string())
54    }
55}
56
/// Mistral embedding request payload.
///
/// Borrows the model name and input texts so serialization makes no copies.
#[derive(Debug, Serialize)]
struct MistralEmbeddingRequest<'a> {
    model: &'a str,
    input: Vec<&'a str>,
}
63
/// Mistral embedding response (top-level JSON object).
#[derive(Debug, Deserialize)]
struct MistralEmbeddingResponse {
    // One entry per input text; entries carry an `index` so callers can
    // restore input order.
    data: Vec<MistralEmbeddingData>,
    // Model name echoed back by the API.
    model: String,
    // Token accounting reported by the API (logged for diagnostics).
    usage: MistralUsage,
}
71
/// A single embedding vector plus the position of its source text in the request.
#[derive(Debug, Deserialize)]
struct MistralEmbeddingData {
    embedding: Vec<f32>,
    // Index into the request's `input` array; used to sort results back into
    // input order.
    index: usize,
}
77
/// Token usage reported by the embeddings API.
// Fields are deserialized for completeness but only `total_tokens` is read
// (in a debug log), hence the dead_code allowance.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct MistralUsage {
    prompt_tokens: usize,
    total_tokens: usize,
}
84
85/// Mistral error response
86#[derive(Debug, Deserialize)]
87#[allow(dead_code)]
88struct MistralErrorResponse {
89    message: String,
90    #[serde(rename = "type")]
91    error_type: Option<String>,
92}
93
94/// Mistral Embedding Provider
95#[derive(Clone)]
96pub struct MistralEmbeddingProvider {
97    api_key: String,
98    model: String,
99    client: Client,
100}
101
102impl std::fmt::Debug for MistralEmbeddingProvider {
103    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104        f.debug_struct("MistralEmbeddingProvider")
105            .field("model", &self.model)
106            .finish()
107    }
108}
109
110impl MistralEmbeddingProvider {
111    /// Embedding dimension for mistral-embed
112    pub const DIMENSION: usize = 1024;
113
114    /// Create a new Mistral embedding provider
115    pub fn new(api_key: String, model: Option<&str>) -> Result<Self> {
116        if api_key.is_empty() {
117            bail!("Mistral API key cannot be empty");
118        }
119
120        let client = crate::http::blocking_client(REQUEST_TIMEOUT)
121            .map_err(|e| anyhow!("Failed to create HTTP client: {}", e))?;
122
123        let model = model.unwrap_or(DEFAULT_MODEL).to_string();
124
125        Ok(Self {
126            api_key,
127            model,
128            client,
129        })
130    }
131
132    /// Create provider from environment variables
133    pub fn from_env() -> Result<Self> {
134        let api_key = std::env::var("MISTRAL_API_KEY")
135            .map_err(|_| anyhow!("MISTRAL_API_KEY environment variable not set"))?;
136
137        let model = std::env::var("MISTRAL_EMBEDDING_MODEL").ok();
138        Self::new(api_key, model.as_deref())
139    }
140
141    /// Get model name
142    pub fn model(&self) -> &str {
143        &self.model
144    }
145
146    /// Get provider kind
147    pub fn kind(&self) -> &'static str {
148        "mistral"
149    }
150
151    /// Get embedding dimension
152    pub fn dimension(&self) -> usize {
153        Self::DIMENSION
154    }
155
156    /// Embed a single text
157    pub fn embed_text(&self, text: &str) -> Result<Vec<f32>> {
158        let text = truncate_for_embedding(text);
159        self.embed_with_retry(&[&text], 3)
160            .map(|mut v| v.pop().unwrap_or_default())
161    }
162
163    /// Embed multiple texts in batch
164    pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
165        if texts.is_empty() {
166            return Ok(Vec::new());
167        }
168
169        // Truncate all texts first
170        let truncated: Vec<std::borrow::Cow<'_, str>> =
171            texts.iter().map(|t| truncate_for_embedding(t)).collect();
172        let truncated_refs: Vec<&str> = truncated.iter().map(|c| c.as_ref()).collect();
173
174        let mut all_embeddings = Vec::with_capacity(texts.len());
175
176        // Process in batches
177        for chunk in truncated_refs.chunks(MAX_BATCH_SIZE) {
178            let embeddings = self.embed_with_retry(chunk, 3)?;
179            all_embeddings.extend(embeddings);
180        }
181
182        Ok(all_embeddings)
183    }
184
185    /// Embed texts with retry logic
186    fn embed_with_retry(&self, texts: &[&str], max_retries: usize) -> Result<Vec<Vec<f32>>> {
187        let request = MistralEmbeddingRequest {
188            model: &self.model,
189            input: texts.to_vec(),
190        };
191
192        let mut last_error = None;
193
194        for attempt in 0..max_retries {
195            let response = self
196                .client
197                .post(MISTRAL_EMBEDDINGS_URL)
198                .header("Authorization", format!("Bearer {}", self.api_key))
199                .header("Content-Type", "application/json")
200                .json(&request)
201                .send();
202
203            match response {
204                Ok(resp) => {
205                    let status = resp.status();
206                    let body = resp.text().unwrap_or_default();
207
208                    if status.is_success() {
209                        let embed_response: MistralEmbeddingResponse = serde_json::from_str(&body)
210                            .map_err(|e| anyhow!("Failed to parse Mistral response: {}", e))?;
211
212                        debug!(
213                            "Mistral embeddings: {} texts, {} tokens, model={}",
214                            texts.len(),
215                            embed_response.usage.total_tokens,
216                            embed_response.model
217                        );
218
219                        // Sort by index and extract embeddings
220                        let mut data = embed_response.data;
221                        data.sort_by_key(|d| d.index);
222
223                        let embeddings: Vec<Vec<f32>> =
224                            data.into_iter().map(|d| d.embedding).collect();
225
226                        // Validate dimensions
227                        if let Some(first) = embeddings.first() {
228                            if first.len() != Self::DIMENSION {
229                                warn!(
230                                    "Mistral returned dimension {} but expected {}",
231                                    first.len(),
232                                    Self::DIMENSION
233                                );
234                            }
235                        }
236
237                        return Ok(embeddings);
238                    }
239
240                    // Check for rate limiting
241                    if status.as_u16() == 429 {
242                        let backoff = Duration::from_millis(500 * (1 << attempt));
243                        warn!(
244                            "Rate limited by Mistral, retrying in {:?} (attempt {}/{})",
245                            backoff,
246                            attempt + 1,
247                            max_retries
248                        );
249                        std::thread::sleep(backoff);
250                        last_error = Some(anyhow!("Rate limited"));
251                        continue;
252                    }
253
254                    // Try to parse error response
255                    if let Ok(error_response) = serde_json::from_str::<MistralErrorResponse>(&body)
256                    {
257                        return Err(anyhow!("Mistral API error: {}", error_response.message));
258                    }
259
260                    return Err(anyhow!(
261                        "Mistral API request failed with status {}: {}",
262                        status,
263                        body
264                    ));
265                }
266                Err(e) => {
267                    if attempt < max_retries - 1 {
268                        let backoff = Duration::from_millis(500 * (1 << attempt));
269                        warn!(
270                            "Mistral request failed, retrying in {:?} (attempt {}/{}): {}",
271                            backoff,
272                            attempt + 1,
273                            max_retries,
274                            e
275                        );
276                        std::thread::sleep(backoff);
277                        last_error = Some(anyhow!("Request failed: {}", e));
278                        continue;
279                    }
280                    return Err(anyhow!("Mistral API request failed: {}", e));
281                }
282            }
283        }
284
285        Err(last_error.unwrap_or_else(|| anyhow!("Failed to embed after {} retries", max_retries)))
286    }
287}
288
289/// Helper to create a Mistral provider or return None
290pub fn try_mistral_provider() -> Option<MistralEmbeddingProvider> {
291    match MistralEmbeddingProvider::from_env() {
292        Ok(provider) => {
293            info!("Mistral embedding provider available");
294            Some(provider)
295        }
296        Err(e) => {
297            debug!("Mistral provider not available: {}", e);
298            None
299        }
300    }
301}
302
#[cfg(test)]
mod tests {
    use super::*;

    // The constructor must reject an empty API key.
    #[test]
    fn test_empty_api_key() {
        let result = MistralEmbeddingProvider::new(String::new(), None);
        assert!(result.is_err());
    }

    // The advertised dimension is fixed at 1024 for mistral-embed.
    #[test]
    fn test_dimension() {
        let provider = MistralEmbeddingProvider::new("test-key".to_string(), None).unwrap();
        assert_eq!(provider.dimension(), 1024);
    }

    // Round-trip against the live API; run manually with a real key.
    #[test]
    #[ignore] // Requires valid API key
    fn test_real_embedding() {
        let provider = MistralEmbeddingProvider::from_env().expect("MISTRAL_API_KEY must be set");
        let embedding = provider.embed_text("Hello, world!").expect("embed");
        assert!(!embedding.is_empty());
        assert_eq!(embedding.len(), 1024);
    }
}
327}