rig/providers/mistral/
client.rs

1use bytes::Bytes;
2use serde::{Deserialize, Serialize};
3
4use super::{
5    CompletionModel,
6    embedding::{EmbeddingModel, MISTRAL_EMBED},
7};
8use crate::{
9    client::{CompletionClient, EmbeddingsClient, ProviderClient, VerifyClient, VerifyError},
10    http_client::HttpClientExt,
11};
12use crate::{http_client, impl_conversion_traits};
13use std::fmt::Debug;
14
15const MISTRAL_API_BASE_URL: &str = "https://api.mistral.ai";
16
17pub struct ClientBuilder<'a, T = reqwest::Client> {
18    api_key: &'a str,
19    base_url: &'a str,
20    http_client: T,
21}
22
23impl<'a, T> ClientBuilder<'a, T>
24where
25    T: Default,
26{
27    pub fn new(api_key: &'a str) -> Self {
28        Self {
29            api_key,
30            base_url: MISTRAL_API_BASE_URL,
31            http_client: Default::default(),
32        }
33    }
34}
35
36impl<'a, T> ClientBuilder<'a, T> {
37    pub fn new_with_client(api_key: &'a str, http_client: T) -> Self {
38        Self {
39            api_key,
40            base_url: MISTRAL_API_BASE_URL,
41            http_client,
42        }
43    }
44
45    pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
46        ClientBuilder {
47            api_key: self.api_key,
48            base_url: self.base_url,
49            http_client,
50        }
51    }
52
53    pub fn base_url(mut self, base_url: &'a str) -> Self {
54        self.base_url = base_url;
55        self
56    }
57
58    pub fn build(self) -> Client<T> {
59        Client {
60            base_url: self.base_url.to_string(),
61            api_key: self.api_key.to_string(),
62            http_client: self.http_client,
63        }
64    }
65}
66
67#[derive(Clone)]
68pub struct Client<T = reqwest::Client> {
69    base_url: String,
70    api_key: String,
71    http_client: T,
72}
73
74impl<T> std::fmt::Debug for Client<T>
75where
76    T: Debug,
77{
78    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79        f.debug_struct("Client")
80            .field("base_url", &self.base_url)
81            .field("http_client", &self.http_client)
82            .field("api_key", &"<REDACTED>")
83            .finish()
84    }
85}
86
87impl Client<reqwest::Client> {
88    /// Create a new Mistral client. For more control, use the `builder` method.
89    ///
90    /// # Panics
91    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
92    pub fn new(api_key: &str) -> Self {
93        Self::builder(api_key).build()
94    }
95
96    /// Create a new Mistral client builder.
97    ///
98    /// # Example
99    /// ```
100    /// use rig::providers::mistral::{ClientBuilder, self};
101    ///
102    /// // Initialize the Mistral client
103    /// let mistral = Client::builder("your-mistral-api-key")
104    ///    .build()
105    /// ```
106    pub fn builder(api_key: &str) -> ClientBuilder<'_, reqwest::Client> {
107        ClientBuilder::new(api_key)
108    }
109
110    pub fn from_env() -> Self {
111        <Self as ProviderClient>::from_env()
112    }
113}
114
115impl<T> Client<T>
116where
117    T: HttpClientExt,
118{
119    pub(crate) fn post(&self, path: &str) -> http_client::Result<http_client::Builder> {
120        let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
121
122        http_client::with_bearer_auth(http_client::Request::post(url), &self.api_key)
123    }
124
125    pub(crate) fn get(&self, path: &str) -> http_client::Result<http_client::Builder> {
126        let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
127
128        http_client::with_bearer_auth(http_client::Request::get(url), &self.api_key)
129    }
130
131    pub(crate) async fn send<Body, R>(
132        &self,
133        req: http_client::Request<Body>,
134    ) -> http_client::Result<http_client::Response<http_client::LazyBody<R>>>
135    where
136        Body: Into<Bytes> + Send,
137        R: From<Bytes> + Send + 'static,
138    {
139        self.http_client.send(req).await
140    }
141}
142
143impl<T> ProviderClient for Client<T>
144where
145    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
146{
147    /// Create a new Mistral client from the `MISTRAL_API_KEY` environment variable.
148    /// Panics if the environment variable is not set.
149    fn from_env() -> Self
150    where
151        Self: Sized,
152    {
153        let api_key = std::env::var("MISTRAL_API_KEY").expect("MISTRAL_API_KEY not set");
154        ClientBuilder::<T>::new(&api_key).build()
155    }
156
157    fn from_val(input: crate::client::ProviderValue) -> Self {
158        let crate::client::ProviderValue::Simple(api_key) = input else {
159            panic!("Incorrect provider value type")
160        };
161        ClientBuilder::<T>::new(&api_key).build()
162    }
163}
164
165impl<T> CompletionClient for Client<T>
166where
167    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
168{
169    type CompletionModel = CompletionModel<T>;
170
171    /// Create a completion model with the given name.
172    ///
173    /// # Example
174    /// ```
175    /// use rig::providers::mistral::{Client, self};
176    ///
177    /// // Initialize the Mistral client
178    /// let mistral = Client::new("your-mistral-api-key");
179    ///
180    /// let codestral = mistral.completion_model(mistral::CODESTRAL);
181    /// ```
182    fn completion_model(&self, model: &str) -> Self::CompletionModel {
183        CompletionModel::new(self.clone(), model)
184    }
185}
186
187impl<T> EmbeddingsClient for Client<T>
188where
189    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
190{
191    type EmbeddingModel = EmbeddingModel<T>;
192
193    /// Create an embedding model with the given name.
194    /// Note: default embedding dimension of 0 will be used if model is not known.
195    ///
196    /// # Example
197    /// ```
198    /// use rig::providers::mistral::{Client, self};
199    ///
200    /// // Initialize mistral client
201    /// let mistral = Client::new("your-mistral-api-key");
202    ///
203    /// let embedding_model = mistral.embedding_model(mistral::MISTRAL_EMBED);
204    /// ```
205    fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
206        let ndims = match model {
207            MISTRAL_EMBED => 1024,
208            _ => 0,
209        };
210        EmbeddingModel::new(self.clone(), model, ndims)
211    }
212
213    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
214        EmbeddingModel::new(self.clone(), model, ndims)
215    }
216}
217
218impl<T> VerifyClient for Client<T>
219where
220    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
221{
222    #[cfg_attr(feature = "worker", worker::send)]
223    async fn verify(&self) -> Result<(), VerifyError> {
224        let req = self
225            .get("/models")?
226            .body(http_client::NoBody)
227            .map_err(|e| VerifyError::HttpError(e.into()))?;
228
229        let response = HttpClientExt::send(&self.http_client, req).await?;
230
231        match response.status() {
232            reqwest::StatusCode::OK => Ok(()),
233            reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
234            reqwest::StatusCode::INTERNAL_SERVER_ERROR => {
235                let text = http_client::text(response).await?;
236                Err(VerifyError::ProviderError(text))
237            }
238            _ => {
239                // TODO: implement equivalent with `http` crate `StatusCode` type
240                //response.error_for_status()?;
241                Ok(())
242            }
243        }
244    }
245}
246
247impl_conversion_traits!(AsTranscription, AsAudioGeneration, AsImageGeneration for Client<T>);
248
249#[derive(Clone, Debug, Deserialize, Serialize)]
250pub struct Usage {
251    pub completion_tokens: usize,
252    pub prompt_tokens: usize,
253    pub total_tokens: usize,
254}
255
256impl std::fmt::Display for Usage {
257    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
258        write!(
259            f,
260            "Prompt tokens: {} Total tokens: {}",
261            self.prompt_tokens, self.total_tokens
262        )
263    }
264}
265
266#[derive(Debug, Deserialize)]
267pub struct ApiErrorResponse {
268    pub(crate) message: String,
269}
270
271#[derive(Debug, Deserialize)]
272#[serde(untagged)]
273pub(crate) enum ApiResponse<T> {
274    Ok(T),
275    Err(ApiErrorResponse),
276}