rig/providers/mistral/
client.rs

1use serde::Deserialize;
2
3use super::{
4    CompletionModel,
5    embedding::{EmbeddingModel, MISTRAL_EMBED},
6};
7use crate::client::{CompletionClient, EmbeddingsClient, ProviderClient};
8use crate::impl_conversion_traits;
9
10const MISTRAL_API_BASE_URL: &str = "https://api.mistral.ai";
11
12#[derive(Clone)]
13pub struct Client {
14    base_url: String,
15    api_key: String,
16    http_client: reqwest::Client,
17}
18
19impl std::fmt::Debug for Client {
20    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21        f.debug_struct("Client")
22            .field("base_url", &self.base_url)
23            .field("http_client", &self.http_client)
24            .field("api_key", &"<REDACTED>")
25            .finish()
26    }
27}
28
29impl Client {
30    pub fn new(api_key: &str) -> Self {
31        Self::from_url(api_key, MISTRAL_API_BASE_URL)
32    }
33
34    pub fn from_url(api_key: &str, base_url: &str) -> Self {
35        Self {
36            base_url: base_url.to_string(),
37            api_key: api_key.to_string(),
38            http_client: reqwest::Client::builder()
39                .build()
40                .expect("Mistral reqwest client should build"),
41        }
42    }
43
44    /// Use your own `reqwest::Client`.
45    /// The required headers will be automatically attached upon trying to make a request.
46    pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
47        self.http_client = client;
48
49        self
50    }
51
52    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
53        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
54        self.http_client.post(url).bearer_auth(&self.api_key)
55    }
56}
57
58impl ProviderClient for Client {
59    /// Create a new Mistral client from the `MISTRAL_API_KEY` environment variable.
60    /// Panics if the environment variable is not set.
61    fn from_env() -> Self
62    where
63        Self: Sized,
64    {
65        let api_key = std::env::var("MISTRAL_API_KEY").expect("MISTRAL_API_KEY not set");
66        Self::new(&api_key)
67    }
68
69    fn from_val(input: crate::client::ProviderValue) -> Self {
70        let crate::client::ProviderValue::Simple(api_key) = input else {
71            panic!("Incorrect provider value type")
72        };
73        Self::new(&api_key)
74    }
75}
76
77impl CompletionClient for Client {
78    type CompletionModel = CompletionModel;
79
80    /// Create a completion model with the given name.
81    ///
82    /// # Example
83    /// ```
84    /// use rig::providers::mistral::{Client, self};
85    ///
86    /// // Initialize the Mistral client
87    /// let mistral = Client::new("your-mistral-api-key");
88    ///
89    /// let codestral = mistral.completion_model(mistral::CODESTRAL);
90    /// ```
91    fn completion_model(&self, model: &str) -> Self::CompletionModel {
92        CompletionModel::new(self.clone(), model)
93    }
94}
95
96impl EmbeddingsClient for Client {
97    type EmbeddingModel = EmbeddingModel;
98
99    /// Create an embedding model with the given name.
100    /// Note: default embedding dimension of 0 will be used if model is not known.
101    ///
102    /// # Example
103    /// ```
104    /// use rig::providers::mistral::{Client, self};
105    ///
106    /// // Initialize mistral client
107    /// let mistral = Client::new("your-mistral-api-key");
108    ///
109    /// let embedding_model = mistral.embedding_model(mistral::MISTRAL_EMBED);
110    /// ```
111    fn embedding_model(&self, model: &str) -> EmbeddingModel {
112        let ndims = match model {
113            MISTRAL_EMBED => 1024,
114            _ => 0,
115        };
116        EmbeddingModel::new(self.clone(), model, ndims)
117    }
118
119    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
120        EmbeddingModel::new(self.clone(), model, ndims)
121    }
122}
123
124impl_conversion_traits!(AsTranscription, AsAudioGeneration, AsImageGeneration for Client);
125
126#[derive(Clone, Debug, Deserialize)]
127pub struct Usage {
128    pub completion_tokens: usize,
129    pub prompt_tokens: usize,
130    pub total_tokens: usize,
131}
132
133impl std::fmt::Display for Usage {
134    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135        write!(
136            f,
137            "Prompt tokens: {} Total tokens: {}",
138            self.prompt_tokens, self.total_tokens
139        )
140    }
141}
142
143#[derive(Debug, Deserialize)]
144pub struct ApiErrorResponse {
145    pub(crate) message: String,
146}
147
148#[derive(Debug, Deserialize)]
149#[serde(untagged)]
150pub(crate) enum ApiResponse<T> {
151    Ok(T),
152    Err(ApiErrorResponse),
153}