rig/providers/mistral/
client.rs

1use bytes::Bytes;
2use serde::{Deserialize, Serialize};
3
4use super::{
5    CompletionModel,
6    embedding::{EmbeddingModel, MISTRAL_EMBED},
7};
8use crate::{
9    client::{CompletionClient, EmbeddingsClient, ProviderClient, VerifyClient, VerifyError},
10    http_client::HttpClientExt,
11};
12use crate::{http_client, impl_conversion_traits};
13use std::fmt::Debug;
14
/// Base URL used by [`ClientBuilder::new`] unless overridden with `ClientBuilder::base_url`.
const MISTRAL_API_BASE_URL: &str = "https://api.mistral.ai";
16
17pub struct ClientBuilder<'a, T = reqwest::Client> {
18    api_key: &'a str,
19    base_url: &'a str,
20    http_client: T,
21}
22
23impl<'a> ClientBuilder<'a, reqwest::Client> {
24    pub fn new(api_key: &'a str) -> Self {
25        Self {
26            api_key,
27            base_url: MISTRAL_API_BASE_URL,
28            http_client: Default::default(),
29        }
30    }
31}
32
33impl<'a, T> ClientBuilder<'a, T> {
34    pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
35        ClientBuilder {
36            api_key: self.api_key,
37            base_url: self.base_url,
38            http_client,
39        }
40    }
41
42    pub fn base_url(mut self, base_url: &'a str) -> Self {
43        self.base_url = base_url;
44        self
45    }
46
47    pub fn build(self) -> Client<T> {
48        Client {
49            base_url: self.base_url.to_string(),
50            api_key: self.api_key.to_string(),
51            http_client: self.http_client,
52        }
53    }
54}
55
56#[derive(Clone)]
57pub struct Client<T = reqwest::Client> {
58    base_url: String,
59    api_key: String,
60    http_client: T,
61}
62
63impl<T> std::fmt::Debug for Client<T>
64where
65    T: Debug,
66{
67    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68        f.debug_struct("Client")
69            .field("base_url", &self.base_url)
70            .field("http_client", &self.http_client)
71            .field("api_key", &"<REDACTED>")
72            .finish()
73    }
74}
75
76impl Client<reqwest::Client> {
77    /// Create a new Mistral client. For more control, use the `builder` method.
78    ///
79    /// # Panics
80    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
81    pub fn new(api_key: &str) -> Self {
82        Self::builder(api_key).build()
83    }
84
85    /// Create a new Mistral client builder.
86    ///
87    /// # Example
88    /// ```
89    /// use rig::providers::mistral::{ClientBuilder, self};
90    ///
91    /// // Initialize the Mistral client
92    /// let mistral = Client::builder("your-mistral-api-key")
93    ///    .build()
94    /// ```
95    pub fn builder(api_key: &str) -> ClientBuilder<'_> {
96        ClientBuilder::new(api_key)
97    }
98}
99
100impl<T> Client<T>
101where
102    T: HttpClientExt,
103{
104    pub(crate) fn post(&self, path: &str) -> http_client::Result<http_client::Builder> {
105        let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
106
107        http_client::with_bearer_auth(http_client::Request::post(url), &self.api_key)
108    }
109
110    pub(crate) fn get(&self, path: &str) -> http_client::Result<http_client::Builder> {
111        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
112
113        http_client::with_bearer_auth(http_client::Request::get(url), &self.api_key)
114    }
115
116    pub(crate) async fn send<Body, R>(
117        &self,
118        req: http_client::Request<Body>,
119    ) -> http_client::Result<http_client::Response<http_client::LazyBody<R>>>
120    where
121        Body: Into<Bytes> + Send,
122        R: From<Bytes> + Send + 'static,
123    {
124        self.http_client.send(req).await
125    }
126}
127
128impl ProviderClient for Client<reqwest::Client> {
129    /// Create a new Mistral client from the `MISTRAL_API_KEY` environment variable.
130    /// Panics if the environment variable is not set.
131    fn from_env() -> Self
132    where
133        Self: Sized,
134    {
135        let api_key = std::env::var("MISTRAL_API_KEY").expect("MISTRAL_API_KEY not set");
136        Self::new(&api_key)
137    }
138
139    fn from_val(input: crate::client::ProviderValue) -> Self {
140        let crate::client::ProviderValue::Simple(api_key) = input else {
141            panic!("Incorrect provider value type")
142        };
143        Self::new(&api_key)
144    }
145}
146
147impl CompletionClient for Client<reqwest::Client> {
148    type CompletionModel = CompletionModel<reqwest::Client>;
149
150    /// Create a completion model with the given name.
151    ///
152    /// # Example
153    /// ```
154    /// use rig::providers::mistral::{Client, self};
155    ///
156    /// // Initialize the Mistral client
157    /// let mistral = Client::new("your-mistral-api-key");
158    ///
159    /// let codestral = mistral.completion_model(mistral::CODESTRAL);
160    /// ```
161    fn completion_model(&self, model: &str) -> Self::CompletionModel {
162        CompletionModel::new(self.clone(), model)
163    }
164}
165
166impl EmbeddingsClient for Client<reqwest::Client> {
167    type EmbeddingModel = EmbeddingModel<reqwest::Client>;
168
169    /// Create an embedding model with the given name.
170    /// Note: default embedding dimension of 0 will be used if model is not known.
171    ///
172    /// # Example
173    /// ```
174    /// use rig::providers::mistral::{Client, self};
175    ///
176    /// // Initialize mistral client
177    /// let mistral = Client::new("your-mistral-api-key");
178    ///
179    /// let embedding_model = mistral.embedding_model(mistral::MISTRAL_EMBED);
180    /// ```
181    fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
182        let ndims = match model {
183            MISTRAL_EMBED => 1024,
184            _ => 0,
185        };
186        EmbeddingModel::new(self.clone(), model, ndims)
187    }
188
189    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
190        EmbeddingModel::new(self.clone(), model, ndims)
191    }
192}
193
194impl VerifyClient for Client<reqwest::Client> {
195    #[cfg_attr(feature = "worker", worker::send)]
196    async fn verify(&self) -> Result<(), VerifyError> {
197        let req = self
198            .get("/models")?
199            .body(http_client::NoBody)
200            .map_err(|e| VerifyError::HttpError(e.into()))?;
201
202        let response = HttpClientExt::send(&self.http_client, req).await?;
203
204        match response.status() {
205            reqwest::StatusCode::OK => Ok(()),
206            reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
207            reqwest::StatusCode::INTERNAL_SERVER_ERROR => {
208                let text = http_client::text(response).await?;
209                Err(VerifyError::ProviderError(text))
210            }
211            _ => {
212                // TODO: implement equivalent with `http` crate `StatusCode` type
213                //response.error_for_status()?;
214                Ok(())
215            }
216        }
217    }
218}
219
220impl_conversion_traits!(AsTranscription, AsAudioGeneration, AsImageGeneration for Client<T>);
221
222#[derive(Clone, Debug, Deserialize, Serialize)]
223pub struct Usage {
224    pub completion_tokens: usize,
225    pub prompt_tokens: usize,
226    pub total_tokens: usize,
227}
228
229impl std::fmt::Display for Usage {
230    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
231        write!(
232            f,
233            "Prompt tokens: {} Total tokens: {}",
234            self.prompt_tokens, self.total_tokens
235        )
236    }
237}
238
239#[derive(Debug, Deserialize)]
240pub struct ApiErrorResponse {
241    pub(crate) message: String,
242}
243
244#[derive(Debug, Deserialize)]
245#[serde(untagged)]
246pub(crate) enum ApiResponse<T> {
247    Ok(T),
248    Err(ApiErrorResponse),
249}