// rig/providers/mistral/client.rs

1use serde::Deserialize;
2
3use super::{
4    embedding::{EmbeddingModel, MISTRAL_EMBED},
5    CompletionModel,
6};
7use crate::client::{CompletionClient, EmbeddingsClient, ProviderClient};
8use crate::impl_conversion_traits;
9
/// Base URL of the public Mistral API; [`Client::new`] uses it as the default endpoint.
const MISTRAL_API_BASE_URL: &str = "https://api.mistral.ai";

/// HTTP client for the Mistral API.
///
/// Pairs the base URL that request paths are appended to with a `reqwest::Client`
/// whose default headers carry the bearer-token `Authorization` header (both are set
/// up in [`Client::from_url`]).
#[derive(Debug, Clone)]
pub struct Client {
    /// Root URL that `Client::post` joins request paths onto.
    base_url: String,
    /// Underlying HTTP client, preconfigured with the `Authorization` default header.
    http_client: reqwest::Client,
}

18impl Client {
19    pub fn new(api_key: &str) -> Self {
20        Self::from_url(api_key, MISTRAL_API_BASE_URL)
21    }
22
23    pub fn from_url(api_key: &str, base_url: &str) -> Self {
24        Self {
25            base_url: base_url.to_string(),
26            http_client: reqwest::Client::builder()
27                .default_headers({
28                    let mut headers = reqwest::header::HeaderMap::new();
29                    headers.insert(
30                        "Authorization",
31                        format!("Bearer {api_key}")
32                            .parse()
33                            .expect("Bearer token should parse"),
34                    );
35                    headers
36                })
37                .build()
38                .expect("Mistral reqwest client should build"),
39        }
40    }
41
42    pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
43        let url = format!("{}/{}", self.base_url, path).replace("//", "/");
44        self.http_client.post(url)
45    }
46}
47
48impl ProviderClient for Client {
49    /// Create a new Mistral client from the `MISTRAL_API_KEY` environment variable.
50    /// Panics if the environment variable is not set.
51    fn from_env() -> Self
52    where
53        Self: Sized,
54    {
55        let api_key = std::env::var("MISTRAL_API_KEY").expect("MISTRAL_API_KEY not set");
56        Self::new(&api_key)
57    }
58}
59
60impl CompletionClient for Client {
61    type CompletionModel = CompletionModel;
62
63    /// Create a completion model with the given name.
64    ///
65    /// # Example
66    /// ```
67    /// use rig::providers::mistral::{Client, self};
68    ///
69    /// // Initialize the Mistral client
70    /// let mistral = Client::new("your-mistral-api-key");
71    ///
72    /// let codestral = mistral.completion_model(mistral::CODESTRAL);
73    /// ```
74    fn completion_model(&self, model: &str) -> Self::CompletionModel {
75        CompletionModel::new(self.clone(), model)
76    }
77}
78
79impl EmbeddingsClient for Client {
80    type EmbeddingModel = EmbeddingModel;
81
82    /// Create an embedding model with the given name.
83    /// Note: default embedding dimension of 0 will be used if model is not known.
84    ///
85    /// # Example
86    /// ```
87    /// use rig::providers::mistral::{Client, self};
88    ///
89    /// // Initialize mistral client
90    /// let mistral = Client::new("your-mistral-api-key");
91    ///
92    /// let embedding_model = mistral.embedding_model(mistral::MISTRAL_EMBED);
93    /// ```
94    fn embedding_model(&self, model: &str) -> EmbeddingModel {
95        let ndims = match model {
96            MISTRAL_EMBED => 1024,
97            _ => 0,
98        };
99        EmbeddingModel::new(self.clone(), model, ndims)
100    }
101
102    fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
103        EmbeddingModel::new(self.clone(), model, ndims)
104    }
105}
106
// Generate the `AsTranscription`, `AsAudioGeneration` and `AsImageGeneration` conversion
// trait impls for `Client` via the crate-level `impl_conversion_traits!` macro.
// NOTE(review): presumably these mark the capabilities as unsupported for Mistral (no
// transcription/audio/image model is wired up in the code visible here) — confirm
// against the macro definition.
impl_conversion_traits!(AsTranscription, AsAudioGeneration, AsImageGeneration for Client);

/// Token usage counts deserialized from a Mistral API response.
///
/// NOTE(review): presumably mirrors the `usage` object returned by the Mistral endpoints —
/// confirm against the Mistral API reference.
#[derive(Clone, Debug, Deserialize)]
pub struct Usage {
    /// Value of the response's `completion_tokens` field.
    pub completion_tokens: usize,
    /// Value of the response's `prompt_tokens` field.
    pub prompt_tokens: usize,
    /// Value of the response's `total_tokens` field.
    pub total_tokens: usize,
}

116impl std::fmt::Display for Usage {
117    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118        write!(
119            f,
120            "Prompt tokens: {} Total tokens: {}",
121            self.prompt_tokens, self.total_tokens
122        )
123    }
124}
125
/// Error body returned by the Mistral API.
///
/// Only the `message` field is deserialized; other fields in the payload are ignored,
/// since no `#[serde(deny_unknown_fields)]` is set.
#[derive(Debug, Deserialize)]
pub struct ApiErrorResponse {
    /// Human-readable error message from the API.
    pub(crate) message: String,
}

/// Envelope for an API response that is either a success payload `T` or an
/// [`ApiErrorResponse`].
///
/// `#[serde(untagged)]` tries the variants in declaration order. The body is first
/// deserialized as `T`, and only if that fails is it deserialized as `ApiErrorResponse`.
/// Keep `Ok` before `Err`: reordering changes which variant wins when a body matches both.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub(crate) enum ApiResponse<T> {
    /// Successful response payload.
    Ok(T),
    /// Error payload carrying the API's `message`.
    Err(ApiErrorResponse),
}