// rig/providers/mistral/client.rs

1use serde::{Deserialize, Serialize};
2
3use super::{
4    CompletionModel,
5    embedding::{EmbeddingModel, MISTRAL_EMBED},
6};
7use crate::client::{
8    ClientBuilderError, CompletionClient, EmbeddingsClient, ProviderClient, VerifyClient,
9    VerifyError,
10};
11use crate::impl_conversion_traits;
12
/// Default API root used by [`ClientBuilder::new`].
///
/// Override it per client with [`ClientBuilder::base_url`].
const MISTRAL_API_BASE_URL: &str = "https://api.mistral.ai";
14
15pub struct ClientBuilder<'a> {
16    api_key: &'a str,
17    base_url: &'a str,
18    http_client: Option<reqwest::Client>,
19}
20
21impl<'a> ClientBuilder<'a> {
22    pub fn new(api_key: &'a str) -> Self {
23        Self {
24            api_key,
25            base_url: MISTRAL_API_BASE_URL,
26            http_client: None,
27        }
28    }
29
30    pub fn base_url(mut self, base_url: &'a str) -> Self {
31        self.base_url = base_url;
32        self
33    }
34
35    pub fn custom_client(mut self, client: reqwest::Client) -> Self {
36        self.http_client = Some(client);
37        self
38    }
39
40    pub fn build(self) -> Result<Client, ClientBuilderError> {
41        let http_client = if let Some(http_client) = self.http_client {
42            http_client
43        } else {
44            reqwest::Client::builder().build()?
45        };
46
47        Ok(Client {
48            base_url: self.base_url.to_string(),
49            api_key: self.api_key.to_string(),
50            http_client,
51        })
52    }
53}
54
55#[derive(Clone)]
56pub struct Client {
57    base_url: String,
58    api_key: String,
59    http_client: reqwest::Client,
60}
61
62impl std::fmt::Debug for Client {
63    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64        f.debug_struct("Client")
65            .field("base_url", &self.base_url)
66            .field("http_client", &self.http_client)
67            .field("api_key", &"<REDACTED>")
68            .finish()
69    }
70}
71
72impl Client {
73    /// Create a new Mistral client builder.
74    ///
75    /// # Example
76    /// ```
77    /// use rig::providers::mistral::{ClientBuilder, self};
78    ///
79    /// // Initialize the Mistral client
80    /// let mistral = Client::builder("your-mistral-api-key")
81    ///    .build()
82    /// ```
83    pub fn builder(api_key: &str) -> ClientBuilder<'_> {
84        ClientBuilder::new(api_key)
85    }
86
87    /// Create a new Mistral client. For more control, use the `builder` method.
88    ///
89    /// # Panics
90    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
91    pub fn new(api_key: &str) -> Self {
92        Self::builder(api_key)
93            .build()
94            .expect("Mistral client should build")
95    }
96
97    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
98        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
99        self.http_client.post(url).bearer_auth(&self.api_key)
100    }
101
102    pub(crate) fn get(&self, path: &str) -> reqwest::RequestBuilder {
103        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
104        self.http_client.get(url).bearer_auth(&self.api_key)
105    }
106}
107
108impl ProviderClient for Client {
109    /// Create a new Mistral client from the `MISTRAL_API_KEY` environment variable.
110    /// Panics if the environment variable is not set.
111    fn from_env() -> Self
112    where
113        Self: Sized,
114    {
115        let api_key = std::env::var("MISTRAL_API_KEY").expect("MISTRAL_API_KEY not set");
116        Self::new(&api_key)
117    }
118
119    fn from_val(input: crate::client::ProviderValue) -> Self {
120        let crate::client::ProviderValue::Simple(api_key) = input else {
121            panic!("Incorrect provider value type")
122        };
123        Self::new(&api_key)
124    }
125}
126
127impl CompletionClient for Client {
128    type CompletionModel = CompletionModel;
129
130    /// Create a completion model with the given name.
131    ///
132    /// # Example
133    /// ```
134    /// use rig::providers::mistral::{Client, self};
135    ///
136    /// // Initialize the Mistral client
137    /// let mistral = Client::new("your-mistral-api-key");
138    ///
139    /// let codestral = mistral.completion_model(mistral::CODESTRAL);
140    /// ```
141    fn completion_model(&self, model: &str) -> Self::CompletionModel {
142        CompletionModel::new(self.clone(), model)
143    }
144}
145
146impl EmbeddingsClient for Client {
147    type EmbeddingModel = EmbeddingModel;
148
149    /// Create an embedding model with the given name.
150    /// Note: default embedding dimension of 0 will be used if model is not known.
151    ///
152    /// # Example
153    /// ```
154    /// use rig::providers::mistral::{Client, self};
155    ///
156    /// // Initialize mistral client
157    /// let mistral = Client::new("your-mistral-api-key");
158    ///
159    /// let embedding_model = mistral.embedding_model(mistral::MISTRAL_EMBED);
160    /// ```
161    fn embedding_model(&self, model: &str) -> EmbeddingModel {
162        let ndims = match model {
163            MISTRAL_EMBED => 1024,
164            _ => 0,
165        };
166        EmbeddingModel::new(self.clone(), model, ndims)
167    }
168
169    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
170        EmbeddingModel::new(self.clone(), model, ndims)
171    }
172}
173
174impl VerifyClient for Client {
175    #[cfg_attr(feature = "worker", worker::send)]
176    async fn verify(&self) -> Result<(), VerifyError> {
177        let response = self.get("/models").send().await?;
178        match response.status() {
179            reqwest::StatusCode::OK => Ok(()),
180            reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
181            reqwest::StatusCode::INTERNAL_SERVER_ERROR => {
182                Err(VerifyError::ProviderError(response.text().await?))
183            }
184            _ => {
185                response.error_for_status()?;
186                Ok(())
187            }
188        }
189    }
190}
191
192impl_conversion_traits!(AsTranscription, AsAudioGeneration, AsImageGeneration for Client);
193
194#[derive(Clone, Debug, Deserialize, Serialize)]
195pub struct Usage {
196    pub completion_tokens: usize,
197    pub prompt_tokens: usize,
198    pub total_tokens: usize,
199}
200
201impl std::fmt::Display for Usage {
202    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
203        write!(
204            f,
205            "Prompt tokens: {} Total tokens: {}",
206            self.prompt_tokens, self.total_tokens
207        )
208    }
209}
210
211#[derive(Debug, Deserialize)]
212pub struct ApiErrorResponse {
213    pub(crate) message: String,
214}
215
216#[derive(Debug, Deserialize)]
217#[serde(untagged)]
218pub(crate) enum ApiResponse<T> {
219    Ok(T),
220    Err(ApiErrorResponse),
221}