rig/providers/mistral/
client.rs

1use serde::{Deserialize, Serialize};
2
3use super::{
4    CompletionModel,
5    embedding::{EmbeddingModel, MISTRAL_EMBED},
6};
7use crate::client::{ClientBuilderError, CompletionClient, EmbeddingsClient, ProviderClient};
8use crate::impl_conversion_traits;
9
10const MISTRAL_API_BASE_URL: &str = "https://api.mistral.ai";
11
12pub struct ClientBuilder<'a> {
13    api_key: &'a str,
14    base_url: &'a str,
15    http_client: Option<reqwest::Client>,
16}
17
18impl<'a> ClientBuilder<'a> {
19    pub fn new(api_key: &'a str) -> Self {
20        Self {
21            api_key,
22            base_url: MISTRAL_API_BASE_URL,
23            http_client: None,
24        }
25    }
26
27    pub fn base_url(mut self, base_url: &'a str) -> Self {
28        self.base_url = base_url;
29        self
30    }
31
32    pub fn custom_client(mut self, client: reqwest::Client) -> Self {
33        self.http_client = Some(client);
34        self
35    }
36
37    pub fn build(self) -> Result<Client, ClientBuilderError> {
38        let http_client = if let Some(http_client) = self.http_client {
39            http_client
40        } else {
41            reqwest::Client::builder().build()?
42        };
43
44        Ok(Client {
45            base_url: self.base_url.to_string(),
46            api_key: self.api_key.to_string(),
47            http_client,
48        })
49    }
50}
51
52#[derive(Clone)]
53pub struct Client {
54    base_url: String,
55    api_key: String,
56    http_client: reqwest::Client,
57}
58
59impl std::fmt::Debug for Client {
60    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61        f.debug_struct("Client")
62            .field("base_url", &self.base_url)
63            .field("http_client", &self.http_client)
64            .field("api_key", &"<REDACTED>")
65            .finish()
66    }
67}
68
69impl Client {
70    /// Create a new Mistral client builder.
71    ///
72    /// # Example
73    /// ```
74    /// use rig::providers::mistral::{ClientBuilder, self};
75    ///
76    /// // Initialize the Mistral client
77    /// let mistral = Client::builder("your-mistral-api-key")
78    ///    .build()
79    /// ```
80    pub fn builder(api_key: &str) -> ClientBuilder<'_> {
81        ClientBuilder::new(api_key)
82    }
83
84    /// Create a new Mistral client. For more control, use the `builder` method.
85    ///
86    /// # Panics
87    /// - If the reqwest client cannot be built (if the TLS backend cannot be initialized).
88    pub fn new(api_key: &str) -> Self {
89        Self::builder(api_key)
90            .build()
91            .expect("Mistral client should build")
92    }
93
94    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
95        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
96        self.http_client.post(url).bearer_auth(&self.api_key)
97    }
98}
99
100impl ProviderClient for Client {
101    /// Create a new Mistral client from the `MISTRAL_API_KEY` environment variable.
102    /// Panics if the environment variable is not set.
103    fn from_env() -> Self
104    where
105        Self: Sized,
106    {
107        let api_key = std::env::var("MISTRAL_API_KEY").expect("MISTRAL_API_KEY not set");
108        Self::new(&api_key)
109    }
110
111    fn from_val(input: crate::client::ProviderValue) -> Self {
112        let crate::client::ProviderValue::Simple(api_key) = input else {
113            panic!("Incorrect provider value type")
114        };
115        Self::new(&api_key)
116    }
117}
118
119impl CompletionClient for Client {
120    type CompletionModel = CompletionModel;
121
122    /// Create a completion model with the given name.
123    ///
124    /// # Example
125    /// ```
126    /// use rig::providers::mistral::{Client, self};
127    ///
128    /// // Initialize the Mistral client
129    /// let mistral = Client::new("your-mistral-api-key");
130    ///
131    /// let codestral = mistral.completion_model(mistral::CODESTRAL);
132    /// ```
133    fn completion_model(&self, model: &str) -> Self::CompletionModel {
134        CompletionModel::new(self.clone(), model)
135    }
136}
137
138impl EmbeddingsClient for Client {
139    type EmbeddingModel = EmbeddingModel;
140
141    /// Create an embedding model with the given name.
142    /// Note: default embedding dimension of 0 will be used if model is not known.
143    ///
144    /// # Example
145    /// ```
146    /// use rig::providers::mistral::{Client, self};
147    ///
148    /// // Initialize mistral client
149    /// let mistral = Client::new("your-mistral-api-key");
150    ///
151    /// let embedding_model = mistral.embedding_model(mistral::MISTRAL_EMBED);
152    /// ```
153    fn embedding_model(&self, model: &str) -> EmbeddingModel {
154        let ndims = match model {
155            MISTRAL_EMBED => 1024,
156            _ => 0,
157        };
158        EmbeddingModel::new(self.clone(), model, ndims)
159    }
160
161    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
162        EmbeddingModel::new(self.clone(), model, ndims)
163    }
164}
165
// Generate the `AsTranscription`, `AsAudioGeneration` and `AsImageGeneration` conversion-trait
// impls for `Client` via the crate-level `impl_conversion_traits!` macro. What those impls return
// is decided by the macro definition (not visible here) — NOTE(review): check it before relying
// on these conversions for Mistral.
impl_conversion_traits!(AsTranscription, AsAudioGeneration, AsImageGeneration for Client);
167
168#[derive(Clone, Debug, Deserialize, Serialize)]
169pub struct Usage {
170    pub completion_tokens: usize,
171    pub prompt_tokens: usize,
172    pub total_tokens: usize,
173}
174
175impl std::fmt::Display for Usage {
176    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
177        write!(
178            f,
179            "Prompt tokens: {} Total tokens: {}",
180            self.prompt_tokens, self.total_tokens
181        )
182    }
183}
184
185#[derive(Debug, Deserialize)]
186pub struct ApiErrorResponse {
187    pub(crate) message: String,
188}
189
190#[derive(Debug, Deserialize)]
191#[serde(untagged)]
192pub(crate) enum ApiResponse<T> {
193    Ok(T),
194    Err(ApiErrorResponse),
195}