memvid_cli/
mistral_embeddings.rs1use anyhow::{anyhow, bail, Result};
16use reqwest::blocking::Client;
17use serde::{Deserialize, Serialize};
18use std::time::Duration;
19use tracing::{debug, info, warn};
20
/// Endpoint of Mistral's embeddings API.
const MISTRAL_EMBEDDINGS_URL: &str = "https://api.mistral.ai/v1/embeddings";

/// Model used when neither the caller nor `MISTRAL_EMBEDDING_MODEL` picks one.
const DEFAULT_MODEL: &str = "mistral-embed";

/// Most inputs sent in a single request; `embed_batch` splits larger inputs into chunks of this size.
const MAX_BATCH_SIZE: usize = 100;

/// Timeout handed to the blocking HTTP client when the provider is constructed.
const REQUEST_TIMEOUT: Duration = Duration::from_millis(60 * 1000);

/// Per-input cap, in bytes, applied by `truncate_for_embedding` before a text is sent.
const MAX_EMBEDDING_TEXT_LEN: usize = 20 * 1000;
36
37fn truncate_for_embedding(text: &str) -> std::borrow::Cow<'_, str> {
39 if text.len() <= MAX_EMBEDDING_TEXT_LEN {
40 std::borrow::Cow::Borrowed(text)
41 } else {
42 let end = text[..MAX_EMBEDDING_TEXT_LEN]
43 .char_indices()
44 .rev()
45 .next()
46 .map(|(i, c)| i + c.len_utf8())
47 .unwrap_or(MAX_EMBEDDING_TEXT_LEN);
48 warn!(
49 "Truncating embedding text from {} to {} chars to avoid token limit",
50 text.len(),
51 end
52 );
53 std::borrow::Cow::Owned(text[..end].to_string())
54 }
55}
56
57#[derive(Debug, Serialize)]
59struct MistralEmbeddingRequest<'a> {
60 model: &'a str,
61 input: Vec<&'a str>,
62}
63
/// Body of a successful embeddings response.
#[derive(Debug, Deserialize)]
struct MistralEmbeddingResponse {
    /// One entry per embedded input; entries are sorted by their `index` before use.
    data: Vec<MistralEmbeddingData>,
    /// Model name reported by the API (only logged at debug level).
    model: String,
    /// Token accounting for the request.
    usage: MistralUsage,
}
71
/// A single embedding entry inside a successful response.
#[derive(Debug, Deserialize)]
struct MistralEmbeddingData {
    /// The embedding vector itself.
    embedding: Vec<f32>,
    /// Position of the entry, used to restore input order before the vectors are returned.
    index: usize,
}
77
78#[derive(Debug, Deserialize)]
79struct MistralUsage {
80 prompt_tokens: usize,
81 total_tokens: usize,
82}
83
84#[derive(Debug, Deserialize)]
86struct MistralErrorResponse {
87 message: String,
88 #[serde(rename = "type")]
89 error_type: Option<String>,
90}
91
/// Blocking client for Mistral's embeddings API.
///
/// `Debug` is implemented by hand (printing only `model`) so the API key never ends up in logs.
#[derive(Clone)]
pub struct MistralEmbeddingProvider {
    /// Secret sent as the bearer token on every request.
    api_key: String,
    /// Model name placed in every request body.
    model: String,
    /// HTTP client built once with `REQUEST_TIMEOUT` and reused for all requests.
    client: Client,
}
99
100impl std::fmt::Debug for MistralEmbeddingProvider {
101 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
102 f.debug_struct("MistralEmbeddingProvider")
103 .field("model", &self.model)
104 .finish()
105 }
106}
107
108impl MistralEmbeddingProvider {
109 pub const DIMENSION: usize = 1024;
111
112 pub fn new(api_key: String, model: Option<&str>) -> Result<Self> {
114 if api_key.is_empty() {
115 bail!("Mistral API key cannot be empty");
116 }
117
118 let client = crate::http::blocking_client(REQUEST_TIMEOUT)
119 .map_err(|e| anyhow!("Failed to create HTTP client: {}", e))?;
120
121 let model = model.unwrap_or(DEFAULT_MODEL).to_string();
122
123 Ok(Self {
124 api_key,
125 model,
126 client,
127 })
128 }
129
130 pub fn from_env() -> Result<Self> {
132 let api_key = std::env::var("MISTRAL_API_KEY")
133 .map_err(|_| anyhow!("MISTRAL_API_KEY environment variable not set"))?;
134
135 let model = std::env::var("MISTRAL_EMBEDDING_MODEL").ok();
136 Self::new(api_key, model.as_deref())
137 }
138
139 pub fn model(&self) -> &str {
141 &self.model
142 }
143
144 pub fn kind(&self) -> &'static str {
146 "mistral"
147 }
148
149 pub fn dimension(&self) -> usize {
151 Self::DIMENSION
152 }
153
154 pub fn embed_text(&self, text: &str) -> Result<Vec<f32>> {
156 let text = truncate_for_embedding(text);
157 self.embed_with_retry(&[&text], 3)
158 .map(|mut v| v.pop().unwrap_or_default())
159 }
160
161 pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
163 if texts.is_empty() {
164 return Ok(Vec::new());
165 }
166
167 let truncated: Vec<std::borrow::Cow<'_, str>> =
169 texts.iter().map(|t| truncate_for_embedding(t)).collect();
170 let truncated_refs: Vec<&str> = truncated.iter().map(|c| c.as_ref()).collect();
171
172 let mut all_embeddings = Vec::with_capacity(texts.len());
173
174 for chunk in truncated_refs.chunks(MAX_BATCH_SIZE) {
176 let embeddings = self.embed_with_retry(chunk, 3)?;
177 all_embeddings.extend(embeddings);
178 }
179
180 Ok(all_embeddings)
181 }
182
183 fn embed_with_retry(&self, texts: &[&str], max_retries: usize) -> Result<Vec<Vec<f32>>> {
185 let request = MistralEmbeddingRequest {
186 model: &self.model,
187 input: texts.to_vec(),
188 };
189
190 let mut last_error = None;
191
192 for attempt in 0..max_retries {
193 let response = self
194 .client
195 .post(MISTRAL_EMBEDDINGS_URL)
196 .header("Authorization", format!("Bearer {}", self.api_key))
197 .header("Content-Type", "application/json")
198 .json(&request)
199 .send();
200
201 match response {
202 Ok(resp) => {
203 let status = resp.status();
204 let body = resp.text().unwrap_or_default();
205
206 if status.is_success() {
207 let embed_response: MistralEmbeddingResponse = serde_json::from_str(&body)
208 .map_err(|e| anyhow!("Failed to parse Mistral response: {}", e))?;
209
210 debug!(
211 "Mistral embeddings: {} texts, {} tokens, model={}",
212 texts.len(),
213 embed_response.usage.total_tokens,
214 embed_response.model
215 );
216
217 let mut data = embed_response.data;
219 data.sort_by_key(|d| d.index);
220
221 let embeddings: Vec<Vec<f32>> =
222 data.into_iter().map(|d| d.embedding).collect();
223
224 if let Some(first) = embeddings.first() {
226 if first.len() != Self::DIMENSION {
227 warn!(
228 "Mistral returned dimension {} but expected {}",
229 first.len(),
230 Self::DIMENSION
231 );
232 }
233 }
234
235 return Ok(embeddings);
236 }
237
238 if status.as_u16() == 429 {
240 let backoff = Duration::from_millis(500 * (1 << attempt));
241 warn!(
242 "Rate limited by Mistral, retrying in {:?} (attempt {}/{})",
243 backoff,
244 attempt + 1,
245 max_retries
246 );
247 std::thread::sleep(backoff);
248 last_error = Some(anyhow!("Rate limited"));
249 continue;
250 }
251
252 if let Ok(error_response) = serde_json::from_str::<MistralErrorResponse>(&body)
254 {
255 return Err(anyhow!(
256 "Mistral API error: {}",
257 error_response.message
258 ));
259 }
260
261 return Err(anyhow!(
262 "Mistral API request failed with status {}: {}",
263 status,
264 body
265 ));
266 }
267 Err(e) => {
268 if attempt < max_retries - 1 {
269 let backoff = Duration::from_millis(500 * (1 << attempt));
270 warn!(
271 "Mistral request failed, retrying in {:?} (attempt {}/{}): {}",
272 backoff,
273 attempt + 1,
274 max_retries,
275 e
276 );
277 std::thread::sleep(backoff);
278 last_error = Some(anyhow!("Request failed: {}", e));
279 continue;
280 }
281 return Err(anyhow!("Mistral API request failed: {}", e));
282 }
283 }
284 }
285
286 Err(last_error.unwrap_or_else(|| anyhow!("Failed to embed after {} retries", max_retries)))
287 }
288}
289
290pub fn try_mistral_provider() -> Option<MistralEmbeddingProvider> {
292 match MistralEmbeddingProvider::from_env() {
293 Ok(provider) => {
294 info!("Mistral embedding provider available");
295 Some(provider)
296 }
297 Err(e) => {
298 debug!("Mistral provider not available: {}", e);
299 None
300 }
301 }
302}
303
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_empty_api_key() {
        let result = MistralEmbeddingProvider::new(String::new(), None);
        assert!(result.is_err());
    }

    #[test]
    fn test_dimension() {
        let provider = MistralEmbeddingProvider::new("test-key".to_string(), None).unwrap();
        assert_eq!(provider.dimension(), 1024);
    }

    #[test]
    fn test_defaults() {
        let provider = MistralEmbeddingProvider::new("test-key".to_string(), None).unwrap();
        assert_eq!(provider.model(), DEFAULT_MODEL);
        assert_eq!(provider.kind(), "mistral");
    }

    #[test]
    fn test_explicit_model() {
        let provider =
            MistralEmbeddingProvider::new("test-key".to_string(), Some("custom-embed")).unwrap();
        assert_eq!(provider.model(), "custom-embed");
    }

    #[test]
    fn test_short_text_is_not_truncated() {
        let text = "Hello, world!";
        let truncated = truncate_for_embedding(text);
        assert!(matches!(truncated, std::borrow::Cow::Borrowed(_)));
        assert_eq!(truncated, text);
    }

    #[test]
    fn test_long_ascii_text_is_capped() {
        let text = "a".repeat(MAX_EMBEDDING_TEXT_LEN + 10);
        let truncated = truncate_for_embedding(&text);
        assert_eq!(truncated.len(), MAX_EMBEDDING_TEXT_LEN);
        assert!(text.starts_with(truncated.as_ref()));
    }

    #[test]
    #[ignore = "requires MISTRAL_API_KEY and network access"]
    fn test_real_embedding() {
        let provider = MistralEmbeddingProvider::from_env().expect("MISTRAL_API_KEY must be set");
        let embedding = provider.embed_text("Hello, world!").expect("embed");
        assert!(!embedding.is_empty());
        assert_eq!(embedding.len(), 1024);
    }
}