noether_engine/llm/
mistral.rs1use crate::index::embedding::{Embedding, EmbeddingError, EmbeddingProvider};
12use crate::llm::{LlmConfig, LlmError, LlmProvider, Message, Role};
13use serde_json::{json, Value};
14
/// Root URL of the Mistral REST API; endpoint paths such as
/// `/chat/completions` and `/embeddings` are appended to it.
const MISTRAL_API_BASE: &str = "https://api.mistral.ai/v1";
16
17pub struct MistralNativeProvider {
31 api_key: String,
32 client: reqwest::blocking::Client,
33}
34
35impl MistralNativeProvider {
36 pub fn new(api_key: impl Into<String>) -> Self {
37 let client = reqwest::blocking::Client::builder()
38 .timeout(std::time::Duration::from_secs(120))
39 .connect_timeout(std::time::Duration::from_secs(15))
40 .build()
41 .expect("failed to build reqwest client");
42 Self {
43 api_key: api_key.into(),
44 client,
45 }
46 }
47
48 pub fn from_env() -> Result<Self, String> {
50 let key = std::env::var("MISTRAL_API_KEY")
51 .map_err(|_| "MISTRAL_API_KEY is not set".to_string())?;
52 Ok(Self::new(key))
53 }
54}
55
56impl LlmProvider for MistralNativeProvider {
57 fn complete(&self, messages: &[Message], config: &LlmConfig) -> Result<String, LlmError> {
58 let url = format!("{MISTRAL_API_BASE}/chat/completions");
59
60 let model = std::env::var("MISTRAL_MODEL")
63 .or_else(|_| std::env::var("VERTEX_AI_MODEL"))
64 .unwrap_or_else(|_| config.model.clone());
65
66 let model = model
68 .strip_prefix("mistralai/")
69 .map(|s| s.to_string())
70 .unwrap_or(model);
71
72 let msgs: Vec<Value> = messages
73 .iter()
74 .map(|m| {
75 let role = match m.role {
76 Role::System => "system",
77 Role::User => "user",
78 Role::Assistant => "assistant",
79 };
80 json!({"role": role, "content": m.content})
81 })
82 .collect();
83
84 let body = json!({
85 "model": model,
86 "messages": msgs,
87 "max_tokens": config.max_tokens,
88 "temperature": config.temperature,
89 "stream": false,
90 });
91
92 let resp = self
93 .client
94 .post(&url)
95 .bearer_auth(&self.api_key)
96 .json(&body)
97 .send()
98 .map_err(|e| LlmError::Http(e.to_string()))?;
99
100 let status = resp.status();
101 let text = resp.text().map_err(|e| LlmError::Http(e.to_string()))?;
102
103 if !status.is_success() {
104 return Err(LlmError::Provider(format!(
105 "Mistral API HTTP {status}: {text}"
106 )));
107 }
108
109 let json: Value =
110 serde_json::from_str(&text).map_err(|e| LlmError::Parse(e.to_string()))?;
111
112 json["choices"][0]["message"]["content"]
113 .as_str()
114 .map(|s| s.to_string())
115 .ok_or_else(|| LlmError::Parse(format!("unexpected Mistral response shape: {json}")))
116 }
117}
118
119pub struct MistralNativeEmbeddingProvider {
127 api_key: String,
128 model: String,
129 client: reqwest::blocking::Client,
130}
131
132impl MistralNativeEmbeddingProvider {
133 pub fn new(api_key: impl Into<String>) -> Self {
134 let client = reqwest::blocking::Client::builder()
135 .timeout(std::time::Duration::from_secs(30))
136 .connect_timeout(std::time::Duration::from_secs(15))
137 .build()
138 .expect("failed to build reqwest client");
139 Self {
140 api_key: api_key.into(),
141 model: std::env::var("MISTRAL_EMBEDDING_MODEL")
142 .unwrap_or_else(|_| "mistral-embed".into()),
143 client,
144 }
145 }
146
147 pub fn from_env() -> Result<Self, String> {
149 let key = std::env::var("MISTRAL_API_KEY")
150 .map_err(|_| "MISTRAL_API_KEY is not set".to_string())?;
151 Ok(Self::new(key))
152 }
153}
154
155impl EmbeddingProvider for MistralNativeEmbeddingProvider {
156 fn dimensions(&self) -> usize {
157 1024 }
159
160 fn embed(&self, text: &str) -> Result<Embedding, EmbeddingError> {
161 let mut batch = self.embed_batch(&[text])?;
163 batch
164 .pop()
165 .ok_or_else(|| EmbeddingError::Provider("empty response".into()))
166 }
167
168 fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Embedding>, EmbeddingError> {
170 if texts.is_empty() {
171 return Ok(Vec::new());
172 }
173
174 let url = format!("{MISTRAL_API_BASE}/embeddings");
175 let body = json!({
176 "model": self.model,
177 "input": texts,
178 "encoding_format": "float",
179 });
180
181 let resp = self
182 .client
183 .post(&url)
184 .bearer_auth(&self.api_key)
185 .json(&body)
186 .send()
187 .map_err(|e| EmbeddingError::Provider(e.to_string()))?;
188
189 let status = resp.status();
190 let text = resp
191 .text()
192 .map_err(|e| EmbeddingError::Provider(e.to_string()))?;
193
194 if !status.is_success() {
195 return Err(EmbeddingError::Provider(format!(
196 "Mistral embeddings HTTP {status}: {text}"
197 )));
198 }
199
200 let json: Value =
201 serde_json::from_str(&text).map_err(|e| EmbeddingError::Provider(e.to_string()))?;
202
203 let mut items: Vec<(usize, Embedding)> = json["data"]
206 .as_array()
207 .ok_or_else(|| EmbeddingError::Provider("missing 'data' field".into()))?
208 .iter()
209 .map(|item| {
210 let index = item["index"].as_u64().unwrap_or(0) as usize;
211 let vec: Embedding = item["embedding"]
212 .as_array()
213 .unwrap_or(&vec![])
214 .iter()
215 .filter_map(|v| v.as_f64().map(|f| f as f32))
216 .collect();
217 (index, vec)
218 })
219 .collect();
220
221 items.sort_by_key(|(idx, _)| *idx);
222 Ok(items.into_iter().map(|(_, v)| v).collect())
223 }
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    /// Puts `MISTRAL_API_KEY` back to its captured state when dropped, so the
    /// variable is restored even if an assertion in the test panics.
    struct ApiKeyGuard(Option<String>);

    impl Drop for ApiKeyGuard {
        fn drop(&mut self) {
            match self.0.take() {
                Some(key) => std::env::set_var("MISTRAL_API_KEY", key),
                None => std::env::remove_var("MISTRAL_API_KEY"),
            }
        }
    }

    /// Both `from_env` constructors must fail when the key is absent.
    #[test]
    fn from_env_errors_without_key() {
        // NOTE(review): the environment is process-wide, so this can still
        // interfere with tests running in parallel that read the same key.
        let _restore = ApiKeyGuard(std::env::var("MISTRAL_API_KEY").ok());
        std::env::remove_var("MISTRAL_API_KEY");

        assert!(MistralNativeProvider::from_env().is_err());
        assert!(MistralNativeEmbeddingProvider::from_env().is_err());
    }

    /// Exercises the same prefix-stripping expression that `complete` applies
    /// to the model name (a copy of the logic, not the method itself).
    #[test]
    fn strips_vendor_prefix() {
        let model = "mistralai/mistral-small-latest".to_string();
        let normalised = model
            .strip_prefix("mistralai/")
            .map(|s| s.to_string())
            .unwrap_or(model);
        assert_eq!(normalised, "mistral-small-latest");

        // A name without the prefix must pass through unchanged.
        let plain = "mistral-small-latest".to_string();
        let unchanged = plain
            .strip_prefix("mistralai/")
            .map(|s| s.to_string())
            .unwrap_or(plain);
        assert_eq!(unchanged, "mistral-small-latest");
    }
}