Skip to main content

noether_engine/llm/
mistral.rs

//! Mistral AI native API provider.
//!
//! Calls `api.mistral.ai` directly — no Google Cloud required.
//! Auth: `MISTRAL_API_KEY` environment variable.
//!
//! This is the preferred provider for the European deployment stack:
//! - Mistral AI is headquartered in Paris.
//! - Data stays within the EU (Mistral's infrastructure is EU-based).
//! - No dependency on any US cloud provider.
use crate::index::embedding::{Embedding, EmbeddingError, EmbeddingProvider};
use crate::llm::{LlmConfig, LlmError, LlmProvider, Message, Role};
use serde_json::{json, Value};

/// Base URL for the Mistral native REST API (shared by the chat-completions
/// and embeddings endpoints in this module).
const MISTRAL_API_BASE: &str = "https://api.mistral.ai/v1";
// ── LLM provider ────────────────────────────────────────────────────────────
/// Calls `api.mistral.ai/v1/chat/completions` with an API key.
///
/// Supports all Mistral chat models:
/// - `mistral-small-latest` — fastest, cheapest  (€0.10/1M tokens in)
/// - `mistral-medium-latest` — balanced
/// - `mistral-large-latest` — most capable
/// - `codestral-latest` — code specialist
///
/// Set `MISTRAL_API_KEY` to your API key from console.mistral.ai.
/// Override model with `VERTEX_AI_MODEL` (name reused for compatibility) or
/// the new `MISTRAL_MODEL` env var.
pub struct MistralNativeProvider {
    /// Bearer token sent on every request (`Authorization: Bearer <key>`).
    api_key: String,
    /// Blocking HTTP client, reused across requests (connection pooling).
    client: reqwest::blocking::Client,
}
35impl MistralNativeProvider {
36    pub fn new(api_key: impl Into<String>) -> Self {
37        let client = reqwest::blocking::Client::builder()
38            .timeout(std::time::Duration::from_secs(120))
39            .connect_timeout(std::time::Duration::from_secs(15))
40            .build()
41            .expect("failed to build reqwest client");
42        Self {
43            api_key: api_key.into(),
44            client,
45        }
46    }
47
48    /// Construct from environment. Returns `Err` if `MISTRAL_API_KEY` is not set.
49    pub fn from_env() -> Result<Self, String> {
50        let key = std::env::var("MISTRAL_API_KEY")
51            .map_err(|_| "MISTRAL_API_KEY is not set".to_string())?;
52        Ok(Self::new(key))
53    }
54}
55
56impl LlmProvider for MistralNativeProvider {
57    fn complete(&self, messages: &[Message], config: &LlmConfig) -> Result<String, LlmError> {
58        let url = format!("{MISTRAL_API_BASE}/chat/completions");
59
60        // Model resolution: prefer MISTRAL_MODEL, fall back to VERTEX_AI_MODEL (compat),
61        // then the LlmConfig model value.
62        let model = std::env::var("MISTRAL_MODEL")
63            .or_else(|_| std::env::var("VERTEX_AI_MODEL"))
64            .unwrap_or_else(|_| config.model.clone());
65
66        // Normalise model name: strip vendor prefix if present (e.g. "mistralai/mistral-small")
67        let model = model
68            .strip_prefix("mistralai/")
69            .map(|s| s.to_string())
70            .unwrap_or(model);
71
72        let msgs: Vec<Value> = messages
73            .iter()
74            .map(|m| {
75                let role = match m.role {
76                    Role::System => "system",
77                    Role::User => "user",
78                    Role::Assistant => "assistant",
79                };
80                json!({"role": role, "content": m.content})
81            })
82            .collect();
83
84        let body = json!({
85            "model": model,
86            "messages": msgs,
87            "max_tokens": config.max_tokens,
88            "temperature": config.temperature,
89            "stream": false,
90        });
91
92        let resp = self
93            .client
94            .post(&url)
95            .bearer_auth(&self.api_key)
96            .json(&body)
97            .send()
98            .map_err(|e| LlmError::Http(e.to_string()))?;
99
100        let status = resp.status();
101        let text = resp.text().map_err(|e| LlmError::Http(e.to_string()))?;
102
103        if !status.is_success() {
104            return Err(LlmError::Provider(format!(
105                "Mistral API HTTP {status}: {text}"
106            )));
107        }
108
109        let json: Value =
110            serde_json::from_str(&text).map_err(|e| LlmError::Parse(e.to_string()))?;
111
112        json["choices"][0]["message"]["content"]
113            .as_str()
114            .map(|s| s.to_string())
115            .ok_or_else(|| LlmError::Parse(format!("unexpected Mistral response shape: {json}")))
116    }
117}
118
// ── Embedding provider ───────────────────────────────────────────────────────
/// Calls `api.mistral.ai/v1/embeddings` using the `mistral-embed` model.
///
/// - Dimension: 1024
/// - Context window: 8192 tokens
/// - EU-hosted, GDPR-compliant
pub struct MistralNativeEmbeddingProvider {
    /// Bearer token sent on every request.
    api_key: String,
    /// Embedding model name: `MISTRAL_EMBEDDING_MODEL` env override, else
    /// "mistral-embed". NOTE(review): `dimensions()` is hard-coded to 1024,
    /// which matches `mistral-embed` — confirm any override shares that size.
    model: String,
    /// Blocking HTTP client, reused across requests.
    client: reqwest::blocking::Client,
}
132impl MistralNativeEmbeddingProvider {
133    pub fn new(api_key: impl Into<String>) -> Self {
134        let client = reqwest::blocking::Client::builder()
135            .timeout(std::time::Duration::from_secs(30))
136            .connect_timeout(std::time::Duration::from_secs(15))
137            .build()
138            .expect("failed to build reqwest client");
139        Self {
140            api_key: api_key.into(),
141            model: std::env::var("MISTRAL_EMBEDDING_MODEL")
142                .unwrap_or_else(|_| "mistral-embed".into()),
143            client,
144        }
145    }
146
147    /// Construct from environment. Returns `Err` if `MISTRAL_API_KEY` is not set.
148    pub fn from_env() -> Result<Self, String> {
149        let key = std::env::var("MISTRAL_API_KEY")
150            .map_err(|_| "MISTRAL_API_KEY is not set".to_string())?;
151        Ok(Self::new(key))
152    }
153}
154
155impl EmbeddingProvider for MistralNativeEmbeddingProvider {
156    fn dimensions(&self) -> usize {
157        1024 // mistral-embed fixed dimension
158    }
159
160    fn embed(&self, text: &str) -> Result<Embedding, EmbeddingError> {
161        // Delegate to batch — avoids duplicating the HTTP/parse logic.
162        let mut batch = self.embed_batch(&[text])?;
163        batch
164            .pop()
165            .ok_or_else(|| EmbeddingError::Provider("empty response".into()))
166    }
167
168    /// Override the default batch implementation to call the API once for all texts.
169    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Embedding>, EmbeddingError> {
170        if texts.is_empty() {
171            return Ok(Vec::new());
172        }
173
174        let url = format!("{MISTRAL_API_BASE}/embeddings");
175        let body = json!({
176            "model": self.model,
177            "input": texts,
178            "encoding_format": "float",
179        });
180
181        let resp = self
182            .client
183            .post(&url)
184            .bearer_auth(&self.api_key)
185            .json(&body)
186            .send()
187            .map_err(|e| EmbeddingError::Provider(e.to_string()))?;
188
189        let status = resp.status();
190        let text = resp
191            .text()
192            .map_err(|e| EmbeddingError::Provider(e.to_string()))?;
193
194        if !status.is_success() {
195            return Err(EmbeddingError::Provider(format!(
196                "Mistral embeddings HTTP {status}: {text}"
197            )));
198        }
199
200        let json: Value =
201            serde_json::from_str(&text).map_err(|e| EmbeddingError::Provider(e.to_string()))?;
202
203        // Response: { "data": [{ "embedding": [...], "index": 0 }, ...] }
204        // Sort by index to preserve input order.
205        let mut items: Vec<(usize, Embedding)> = json["data"]
206            .as_array()
207            .ok_or_else(|| EmbeddingError::Provider("missing 'data' field".into()))?
208            .iter()
209            .map(|item| {
210                let index = item["index"].as_u64().unwrap_or(0) as usize;
211                let vec: Embedding = item["embedding"]
212                    .as_array()
213                    .unwrap_or(&vec![])
214                    .iter()
215                    .filter_map(|v| v.as_f64().map(|f| f as f32))
216                    .collect();
217                (index, vec)
218            })
219            .collect();
220
221        items.sort_by_key(|(idx, _)| *idx);
222        Ok(items.into_iter().map(|(_, v)| v).collect())
223    }
224}
225
// ── Tests ────────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn from_env_errors_without_key() {
        // Save the current value so the test leaves no lasting side effects
        // (parallel tests may also read this variable), then drop the key and
        // verify both constructors refuse to build.
        let previous = std::env::var("MISTRAL_API_KEY").ok();
        std::env::remove_var("MISTRAL_API_KEY");

        assert!(MistralNativeProvider::from_env().is_err());
        assert!(MistralNativeEmbeddingProvider::from_env().is_err());

        if let Some(value) = previous {
            std::env::set_var("MISTRAL_API_KEY", value);
        }
    }

    #[test]
    fn strips_vendor_prefix() {
        // Mirror of the normalisation done in `complete`: drop an optional
        // "mistralai/" vendor prefix from the model name.
        let model = String::from("mistralai/mistral-small-latest");
        let normalised = match model.strip_prefix("mistralai/") {
            Some(rest) => rest.to_string(),
            None => model,
        };
        assert_eq!(normalised, "mistral-small-latest");
    }
}
254}