1use std::collections::HashMap;
7use std::sync::Arc;
8
9use llmsdk_provider::ProviderError;
10use llmsdk_provider_utils::http::HttpClient;
11
12use crate::chat::MistralChatModel;
13use crate::embedding::MistralEmbeddingModel;
14use crate::{API_KEY_ENV_VAR, DEFAULT_BASE_URL};
15
16#[derive(Debug, Clone)]
20pub struct Mistral {
21 inner: Arc<Inner>,
22}
23
/// Signature of a custom ID generator: a thread-safe closure producing a fresh
/// `String` on each call.
pub type GenerateIdFn = dyn Fn() -> String + Sync + Send;
31
32pub(crate) struct Inner {
33 pub(crate) base_url: String,
34 pub(crate) headers: HashMap<String, Option<String>>,
35 pub(crate) http: HttpClient,
36 pub(crate) generate_id: Option<Arc<GenerateIdFn>>,
37}
38
39impl std::fmt::Debug for Inner {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 f.debug_struct("Inner")
42 .field("base_url", &self.base_url)
43 .field("headers", &self.headers)
44 .field("http", &self.http)
45 .field("generate_id", &self.generate_id.is_some())
46 .finish()
47 }
48}
49
50impl Mistral {
51 #[must_use]
53 pub fn builder() -> MistralBuilder {
54 MistralBuilder::default()
55 }
56
57 pub fn from_env() -> Result<Self, ProviderError> {
64 Self::builder().build()
65 }
66
67 #[must_use]
71 pub fn chat(&self, model_id: impl Into<String>) -> MistralChatModel {
72 MistralChatModel::new(Arc::clone(&self.inner), model_id.into())
73 }
74
75 #[must_use]
77 pub fn language_model(&self, model_id: impl Into<String>) -> MistralChatModel {
78 self.chat(model_id)
79 }
80
81 #[must_use]
86 pub fn embedding(&self, model_id: impl Into<String>) -> MistralEmbeddingModel {
87 MistralEmbeddingModel::new(Arc::clone(&self.inner), model_id.into())
88 }
89
90 #[must_use]
92 pub fn embedding_model(&self, model_id: impl Into<String>) -> MistralEmbeddingModel {
93 self.embedding(model_id)
94 }
95
96 #[must_use]
98 pub fn text_embedding(&self, model_id: impl Into<String>) -> MistralEmbeddingModel {
99 self.embedding(model_id)
100 }
101
102 #[must_use]
104 pub fn text_embedding_model(&self, model_id: impl Into<String>) -> MistralEmbeddingModel {
105 self.embedding(model_id)
106 }
107}
108
109#[derive(Default, Clone)]
113pub struct MistralBuilder {
114 api_key: Option<String>,
115 base_url: Option<String>,
116 extra_headers: HashMap<String, Option<String>>,
117 http: Option<HttpClient>,
118 generate_id: Option<Arc<GenerateIdFn>>,
119}
120
121impl std::fmt::Debug for MistralBuilder {
122 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123 f.debug_struct("MistralBuilder")
124 .field("api_key", &self.api_key.as_ref().map(|_| "***"))
125 .field("base_url", &self.base_url)
126 .field("extra_headers", &self.extra_headers)
127 .field("http", &self.http.is_some())
128 .field("generate_id", &self.generate_id.is_some())
129 .finish()
130 }
131}
132
133impl MistralBuilder {
134 #[must_use]
136 pub fn api_key(mut self, key: impl Into<String>) -> Self {
137 self.api_key = Some(key.into());
138 self
139 }
140
141 #[must_use]
143 pub fn base_url(mut self, url: impl Into<String>) -> Self {
144 self.base_url = Some(url.into());
145 self
146 }
147
148 #[must_use]
152 pub fn header(mut self, name: impl Into<String>, value: Option<String>) -> Self {
153 self.extra_headers.insert(name.into(), value);
154 self
155 }
156
157 #[must_use]
159 pub fn http_client(mut self, client: HttpClient) -> Self {
160 self.http = Some(client);
161 self
162 }
163
164 #[must_use]
172 pub fn generate_id<F>(mut self, f: F) -> Self
173 where
174 F: Fn() -> String + Send + Sync + 'static,
175 {
176 self.generate_id = Some(Arc::new(f));
177 self
178 }
179
180 pub fn build(self) -> Result<Mistral, ProviderError> {
189 let api_key = llmsdk_provider_utils::api_key::load_api_key(
190 &llmsdk_provider_utils::api_key::LoadApiKey {
191 api_key: self.api_key.as_deref(),
192 env_var: API_KEY_ENV_VAR,
193 description: "Mistral",
194 parameter_name: Some("api_key"),
195 },
196 )?;
197
198 let base_url = self.base_url.unwrap_or_else(|| DEFAULT_BASE_URL.to_owned());
199
200 let mut headers = self.extra_headers;
201 headers.insert("authorization".into(), Some(format!("Bearer {api_key}")));
202
203 let http = match self.http {
204 Some(client) => client,
205 None => HttpClient::new()?,
206 };
207
208 Ok(Mistral {
209 inner: Arc::new(Inner {
210 base_url,
211 headers,
212 http,
213 generate_id: self.generate_id,
214 }),
215 })
216 }
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn builder_with_explicit_key_succeeds() {
        let m = Mistral::builder().api_key("test-key").build().expect("ok");
        assert_eq!(m.inner.base_url, DEFAULT_BASE_URL);
        // Pin the exact header value, not just its prefix.
        assert_eq!(
            m.inner
                .headers
                .get("authorization")
                .and_then(Option::as_deref),
            Some("Bearer test-key")
        );
    }

    #[test]
    fn explicit_key_replaces_lowercase_authorization_header() {
        let m = Mistral::builder()
            .api_key("k")
            .header("authorization", Some("Bearer other".into()))
            .build()
            .expect("ok");
        assert_eq!(
            m.inner
                .headers
                .get("authorization")
                .and_then(Option::as_deref),
            Some("Bearer k")
        );
    }

    #[test]
    fn builder_custom_base_url() {
        let m = Mistral::builder()
            .api_key("k")
            .base_url("https://proxy.example.com/v1")
            .build()
            .expect("ok");
        assert_eq!(m.inner.base_url, "https://proxy.example.com/v1");
    }

    #[test]
    fn builder_generate_id_is_stored() {
        let m = Mistral::builder()
            .api_key("k")
            .generate_id(|| "custom-id".to_owned())
            .build()
            .expect("ok");
        let gen_fn = m.inner.generate_id.as_ref().expect("generate_id stored");
        assert_eq!(gen_fn(), "custom-id");
    }

    #[test]
    fn builder_custom_header() {
        let m = Mistral::builder()
            .api_key("k")
            .header("x-feature", Some("y".into()))
            .build()
            .expect("ok");
        assert_eq!(
            m.inner.headers.get("x-feature").unwrap().as_deref(),
            Some("y")
        );
    }

    #[test]
    fn builder_debug_masks_api_key() {
        let rendered = format!("{:?}", Mistral::builder().api_key("super-secret"));
        assert!(!rendered.contains("super-secret"));
        assert!(rendered.contains("***"));
    }
}