memvid_cli/
mistral_embeddings.rs1use anyhow::{anyhow, bail, Result};
16use reqwest::blocking::Client;
17use serde::{Deserialize, Serialize};
18use std::time::Duration;
19use tracing::{debug, info, warn};
20
/// Endpoint that accepts embedding requests.
const MISTRAL_EMBEDDINGS_URL: &str = "https://api.mistral.ai/v1/embeddings";

/// Model used when neither the caller nor `MISTRAL_EMBEDDING_MODEL` picks one.
const DEFAULT_MODEL: &str = "mistral-embed";

/// Largest number of texts sent to the API in a single request.
const MAX_BATCH_SIZE: usize = 100;

/// Timeout handed to `crate::http::blocking_client` when building the HTTP client.
const REQUEST_TIMEOUT: Duration = Duration::from_secs(60);

/// Upper bound on a single input text, measured in bytes (not characters).
const MAX_EMBEDDING_TEXT_LEN: usize = 20 * 1_000;
36
37fn truncate_for_embedding(text: &str) -> std::borrow::Cow<'_, str> {
39 if text.len() <= MAX_EMBEDDING_TEXT_LEN {
40 std::borrow::Cow::Borrowed(text)
41 } else {
42 let end = text[..MAX_EMBEDDING_TEXT_LEN]
43 .char_indices()
44 .rev()
45 .next()
46 .map(|(i, c)| i + c.len_utf8())
47 .unwrap_or(MAX_EMBEDDING_TEXT_LEN);
48 warn!(
49 "Truncating embedding text from {} to {} chars to avoid token limit",
50 text.len(),
51 end
52 );
53 std::borrow::Cow::Owned(text[..end].to_string())
54 }
55}
56
57#[derive(Debug, Serialize)]
59struct MistralEmbeddingRequest<'a> {
60 model: &'a str,
61 input: Vec<&'a str>,
62}
63
64#[derive(Debug, Deserialize)]
66struct MistralEmbeddingResponse {
67 data: Vec<MistralEmbeddingData>,
68 model: String,
69 usage: MistralUsage,
70}
71
72#[derive(Debug, Deserialize)]
73struct MistralEmbeddingData {
74 embedding: Vec<f32>,
75 index: usize,
76}
77
78#[derive(Debug, Deserialize)]
79#[allow(dead_code)]
80struct MistralUsage {
81 prompt_tokens: usize,
82 total_tokens: usize,
83}
84
85#[derive(Debug, Deserialize)]
87#[allow(dead_code)]
88struct MistralErrorResponse {
89 message: String,
90 #[serde(rename = "type")]
91 error_type: Option<String>,
92}
93
94#[derive(Clone)]
96pub struct MistralEmbeddingProvider {
97 api_key: String,
98 model: String,
99 client: Client,
100}
101
102impl std::fmt::Debug for MistralEmbeddingProvider {
103 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104 f.debug_struct("MistralEmbeddingProvider")
105 .field("model", &self.model)
106 .finish()
107 }
108}
109
110impl MistralEmbeddingProvider {
111 pub const DIMENSION: usize = 1024;
113
114 pub fn new(api_key: String, model: Option<&str>) -> Result<Self> {
116 if api_key.is_empty() {
117 bail!("Mistral API key cannot be empty");
118 }
119
120 let client = crate::http::blocking_client(REQUEST_TIMEOUT)
121 .map_err(|e| anyhow!("Failed to create HTTP client: {}", e))?;
122
123 let model = model.unwrap_or(DEFAULT_MODEL).to_string();
124
125 Ok(Self {
126 api_key,
127 model,
128 client,
129 })
130 }
131
132 pub fn from_env() -> Result<Self> {
134 let api_key = std::env::var("MISTRAL_API_KEY")
135 .map_err(|_| anyhow!("MISTRAL_API_KEY environment variable not set"))?;
136
137 let model = std::env::var("MISTRAL_EMBEDDING_MODEL").ok();
138 Self::new(api_key, model.as_deref())
139 }
140
141 pub fn model(&self) -> &str {
143 &self.model
144 }
145
146 pub fn kind(&self) -> &'static str {
148 "mistral"
149 }
150
151 pub fn dimension(&self) -> usize {
153 Self::DIMENSION
154 }
155
156 pub fn embed_text(&self, text: &str) -> Result<Vec<f32>> {
158 let text = truncate_for_embedding(text);
159 self.embed_with_retry(&[&text], 3)
160 .map(|mut v| v.pop().unwrap_or_default())
161 }
162
163 pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
165 if texts.is_empty() {
166 return Ok(Vec::new());
167 }
168
169 let truncated: Vec<std::borrow::Cow<'_, str>> =
171 texts.iter().map(|t| truncate_for_embedding(t)).collect();
172 let truncated_refs: Vec<&str> = truncated.iter().map(|c| c.as_ref()).collect();
173
174 let mut all_embeddings = Vec::with_capacity(texts.len());
175
176 for chunk in truncated_refs.chunks(MAX_BATCH_SIZE) {
178 let embeddings = self.embed_with_retry(chunk, 3)?;
179 all_embeddings.extend(embeddings);
180 }
181
182 Ok(all_embeddings)
183 }
184
185 fn embed_with_retry(&self, texts: &[&str], max_retries: usize) -> Result<Vec<Vec<f32>>> {
187 let request = MistralEmbeddingRequest {
188 model: &self.model,
189 input: texts.to_vec(),
190 };
191
192 let mut last_error = None;
193
194 for attempt in 0..max_retries {
195 let response = self
196 .client
197 .post(MISTRAL_EMBEDDINGS_URL)
198 .header("Authorization", format!("Bearer {}", self.api_key))
199 .header("Content-Type", "application/json")
200 .json(&request)
201 .send();
202
203 match response {
204 Ok(resp) => {
205 let status = resp.status();
206 let body = resp.text().unwrap_or_default();
207
208 if status.is_success() {
209 let embed_response: MistralEmbeddingResponse = serde_json::from_str(&body)
210 .map_err(|e| anyhow!("Failed to parse Mistral response: {}", e))?;
211
212 debug!(
213 "Mistral embeddings: {} texts, {} tokens, model={}",
214 texts.len(),
215 embed_response.usage.total_tokens,
216 embed_response.model
217 );
218
219 let mut data = embed_response.data;
221 data.sort_by_key(|d| d.index);
222
223 let embeddings: Vec<Vec<f32>> =
224 data.into_iter().map(|d| d.embedding).collect();
225
226 if let Some(first) = embeddings.first() {
228 if first.len() != Self::DIMENSION {
229 warn!(
230 "Mistral returned dimension {} but expected {}",
231 first.len(),
232 Self::DIMENSION
233 );
234 }
235 }
236
237 return Ok(embeddings);
238 }
239
240 if status.as_u16() == 429 {
242 let backoff = Duration::from_millis(500 * (1 << attempt));
243 warn!(
244 "Rate limited by Mistral, retrying in {:?} (attempt {}/{})",
245 backoff,
246 attempt + 1,
247 max_retries
248 );
249 std::thread::sleep(backoff);
250 last_error = Some(anyhow!("Rate limited"));
251 continue;
252 }
253
254 if let Ok(error_response) = serde_json::from_str::<MistralErrorResponse>(&body)
256 {
257 return Err(anyhow!(
258 "Mistral API error: {}",
259 error_response.message
260 ));
261 }
262
263 return Err(anyhow!(
264 "Mistral API request failed with status {}: {}",
265 status,
266 body
267 ));
268 }
269 Err(e) => {
270 if attempt < max_retries - 1 {
271 let backoff = Duration::from_millis(500 * (1 << attempt));
272 warn!(
273 "Mistral request failed, retrying in {:?} (attempt {}/{}): {}",
274 backoff,
275 attempt + 1,
276 max_retries,
277 e
278 );
279 std::thread::sleep(backoff);
280 last_error = Some(anyhow!("Request failed: {}", e));
281 continue;
282 }
283 return Err(anyhow!("Mistral API request failed: {}", e));
284 }
285 }
286 }
287
288 Err(last_error.unwrap_or_else(|| anyhow!("Failed to embed after {} retries", max_retries)))
289 }
290}
291
292pub fn try_mistral_provider() -> Option<MistralEmbeddingProvider> {
294 match MistralEmbeddingProvider::from_env() {
295 Ok(provider) => {
296 info!("Mistral embedding provider available");
297 Some(provider)
298 }
299 Err(e) => {
300 debug!("Mistral provider not available: {}", e);
301 None
302 }
303 }
304}
305
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_empty_api_key() {
        // The empty-key check runs before any HTTP client is built.
        assert!(MistralEmbeddingProvider::new(String::new(), None).is_err());
    }

    #[test]
    fn test_dimension() {
        let provider = MistralEmbeddingProvider::new(String::from("test-key"), None)
            .expect("a non-empty key should construct a provider");
        assert_eq!(provider.dimension(), 1024);
    }

    /// Live round-trip against the real API; run with `cargo test -- --ignored`
    /// and `MISTRAL_API_KEY` set.
    #[test]
    #[ignore]
    fn test_real_embedding() {
        let provider =
            MistralEmbeddingProvider::from_env().expect("MISTRAL_API_KEY must be set");
        let vector = provider.embed_text("Hello, world!").expect("embed");
        assert!(!vector.is_empty());
        assert_eq!(vector.len(), 1024);
    }
}