Skip to main content

llmsdk_mistral/
config.rs

1//! Provider configuration and entry point.
2//!
3//! Mirrors `@ai-sdk/mistral/src/mistral-provider.ts`.
4// Rust guideline compliant 2026-05-25
5
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use llmsdk_provider::ProviderError;
10use llmsdk_provider_utils::http::HttpClient;
11
12use crate::chat::MistralChatModel;
13use crate::embedding::MistralEmbeddingModel;
14use crate::{API_KEY_ENV_VAR, DEFAULT_BASE_URL};
15
16/// Mistral provider handle — entry point for model construction.
17///
18/// Cheap to clone; the underlying HTTP client and headers are shared.
19#[derive(Debug, Clone)]
20pub struct Mistral {
21    inner: Arc<Inner>,
22}
23
/// Signature of a caller-provided id generator for streaming reasoning blocks.
///
/// Counterpart of `MistralProviderSettings.generateId` in upstream
/// `mistral-provider.ts:77`. If a generator is configured, the chat model
/// calls it whenever a new reasoning block needs an identifier; if not, the
/// chat model uses a deterministic in-process counter instead.
///
/// The `Send + Sync` bounds allow the callback to be shared behind an `Arc`
/// by handles that live on different threads.
pub type GenerateIdFn = dyn Fn() -> String + Sync + Send;
31
32pub(crate) struct Inner {
33    pub(crate) base_url: String,
34    pub(crate) headers: HashMap<String, Option<String>>,
35    pub(crate) http: HttpClient,
36    pub(crate) generate_id: Option<Arc<GenerateIdFn>>,
37}
38
39impl std::fmt::Debug for Inner {
40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41        f.debug_struct("Inner")
42            .field("base_url", &self.base_url)
43            .field("headers", &self.headers)
44            .field("http", &self.http)
45            .field("generate_id", &self.generate_id.is_some())
46            .finish()
47    }
48}
49
50impl Mistral {
51    /// Open a [`MistralBuilder`].
52    #[must_use]
53    pub fn builder() -> MistralBuilder {
54        MistralBuilder::default()
55    }
56
57    /// Build with defaults: API key from `MISTRAL_API_KEY`, default base URL.
58    ///
59    /// # Errors
60    ///
61    /// Returns [`ProviderError::load_api_key`] when the env var is unset
62    /// or empty.
63    pub fn from_env() -> Result<Self, ProviderError> {
64        Self::builder().build()
65    }
66
67    /// Construct a Chat Completions model handle.
68    ///
69    /// `model_id` is the Mistral model name, e.g. `"mistral-large-latest"`.
70    #[must_use]
71    pub fn chat(&self, model_id: impl Into<String>) -> MistralChatModel {
72        MistralChatModel::new(Arc::clone(&self.inner), model_id.into())
73    }
74
75    /// Alias of [`Self::chat`] — mirrors ai-sdk's `provider.languageModel(id)`.
76    #[must_use]
77    pub fn language_model(&self, model_id: impl Into<String>) -> MistralChatModel {
78        self.chat(model_id)
79    }
80
81    /// Construct a text-embedding model handle.
82    ///
83    /// `model_id` is the Mistral embedding model name, e.g.
84    /// `"mistral-embed"` or `"codestral-embed"`.
85    #[must_use]
86    pub fn embedding(&self, model_id: impl Into<String>) -> MistralEmbeddingModel {
87        MistralEmbeddingModel::new(Arc::clone(&self.inner), model_id.into())
88    }
89
90    /// Alias of [`Self::embedding`] — mirrors ai-sdk's `provider.embeddingModel(id)`.
91    #[must_use]
92    pub fn embedding_model(&self, model_id: impl Into<String>) -> MistralEmbeddingModel {
93        self.embedding(model_id)
94    }
95
96    /// Deprecated alias of [`Self::embedding`] retained for ai-sdk parity.
97    #[must_use]
98    pub fn text_embedding(&self, model_id: impl Into<String>) -> MistralEmbeddingModel {
99        self.embedding(model_id)
100    }
101
102    /// Deprecated alias of [`Self::embedding_model`] retained for ai-sdk parity.
103    #[must_use]
104    pub fn text_embedding_model(&self, model_id: impl Into<String>) -> MistralEmbeddingModel {
105        self.embedding(model_id)
106    }
107}
108
109/// Builder for [`Mistral`].
110///
111/// All setters are optional; `build()` falls back to env / library defaults.
112#[derive(Default, Clone)]
113pub struct MistralBuilder {
114    api_key: Option<String>,
115    base_url: Option<String>,
116    extra_headers: HashMap<String, Option<String>>,
117    http: Option<HttpClient>,
118    generate_id: Option<Arc<GenerateIdFn>>,
119}
120
121impl std::fmt::Debug for MistralBuilder {
122    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123        f.debug_struct("MistralBuilder")
124            .field("api_key", &self.api_key.as_ref().map(|_| "***"))
125            .field("base_url", &self.base_url)
126            .field("extra_headers", &self.extra_headers)
127            .field("http", &self.http.is_some())
128            .field("generate_id", &self.generate_id.is_some())
129            .finish()
130    }
131}
132
133impl MistralBuilder {
134    /// Set the API key explicitly.
135    #[must_use]
136    pub fn api_key(mut self, key: impl Into<String>) -> Self {
137        self.api_key = Some(key.into());
138        self
139    }
140
141    /// Override the base URL (e.g. for a local proxy).
142    #[must_use]
143    pub fn base_url(mut self, url: impl Into<String>) -> Self {
144        self.base_url = Some(url.into());
145        self
146    }
147
148    /// Append or override a header.
149    ///
150    /// Passing `None` for `value` removes the header.
151    #[must_use]
152    pub fn header(mut self, name: impl Into<String>, value: Option<String>) -> Self {
153        self.extra_headers.insert(name.into(), value);
154        self
155    }
156
157    /// Inject a pre-configured HTTP client.
158    #[must_use]
159    pub fn http_client(mut self, client: HttpClient) -> Self {
160        self.http = Some(client);
161        self
162    }
163
164    /// Override the id generator used for streaming reasoning blocks.
165    ///
166    /// Mirrors `config.generateId` on the upstream `MistralChatLanguageModel`
167    /// (`mistral-provider.ts:77`). When unset, each new reasoning block
168    /// receives a deterministic `reasoning-N` id from an internal counter,
169    /// which is fine for tests and offline replay but does not collide-proof
170    /// against ids issued by other sessions or downstream consumers.
171    #[must_use]
172    pub fn generate_id<F>(mut self, f: F) -> Self
173    where
174        F: Fn() -> String + Send + Sync + 'static,
175    {
176        self.generate_id = Some(Arc::new(f));
177        self
178    }
179
180    /// Finalize the provider.
181    ///
182    /// # Errors
183    ///
184    /// - [`ProviderError::load_api_key`] when no explicit key is given and
185    ///   `MISTRAL_API_KEY` is unset / empty.
186    /// - [`ProviderError`] from [`HttpClient::new`] if the TLS stack fails
187    ///   to initialize (rare).
188    pub fn build(self) -> Result<Mistral, ProviderError> {
189        let api_key = llmsdk_provider_utils::api_key::load_api_key(
190            &llmsdk_provider_utils::api_key::LoadApiKey {
191                api_key: self.api_key.as_deref(),
192                env_var: API_KEY_ENV_VAR,
193                description: "Mistral",
194                parameter_name: Some("api_key"),
195            },
196        )?;
197
198        let base_url = self.base_url.unwrap_or_else(|| DEFAULT_BASE_URL.to_owned());
199
200        let mut headers = self.extra_headers;
201        headers.insert("authorization".into(), Some(format!("Bearer {api_key}")));
202
203        let http = match self.http {
204            Some(client) => client,
205            None => HttpClient::new()?,
206        };
207
208        Ok(Mistral {
209            inner: Arc::new(Inner {
210                base_url,
211                headers,
212                http,
213                generate_id: self.generate_id,
214            }),
215        })
216    }
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder pre-loaded with an explicit API key, shared by every test.
    fn keyed_builder(key: &str) -> MistralBuilder {
        Mistral::builder().api_key(key)
    }

    /// An explicit key yields the default base URL and a bearer header.
    #[test]
    fn builder_with_explicit_key_succeeds() {
        let provider = keyed_builder("test-key").build().expect("ok");
        assert_eq!(provider.inner.base_url, DEFAULT_BASE_URL);

        let authorization = provider
            .inner
            .headers
            .get("authorization")
            .unwrap()
            .as_deref()
            .unwrap();
        assert!(authorization.starts_with("Bearer "));
    }

    /// A base URL override is stored on the provider.
    #[test]
    fn builder_custom_base_url() {
        let proxy = "https://proxy.example.com/v1";
        let provider = keyed_builder("k").base_url(proxy).build().expect("ok");
        assert_eq!(provider.inner.base_url, "https://proxy.example.com/v1");
    }

    /// Mirrors upstream `mistral-provider.ts` accepting
    /// `generateId?: () => string` and forwarding it onto the chat model
    /// config (lines 77, 108).
    #[test]
    fn builder_generate_id_is_stored() {
        let provider = keyed_builder("k")
            .generate_id(|| "custom-id".to_owned())
            .build()
            .expect("ok");

        let Some(gen_fn) = provider.inner.generate_id.as_ref() else {
            panic!("generate_id stored");
        };
        assert_eq!(gen_fn(), "custom-id");
    }

    /// A caller-supplied header is kept alongside the authorization header.
    #[test]
    fn builder_custom_header() {
        let provider = keyed_builder("k")
            .header("x-feature", Some("y".into()))
            .build()
            .expect("ok");

        let feature = provider.inner.headers.get("x-feature").unwrap();
        assert_eq!(feature.as_deref(), Some("y"));
    }
}