rig/providers/mistral/
client.rs1use serde::{Deserialize, Serialize};
2
3use super::{
4 CompletionModel,
5 embedding::{EmbeddingModel, MISTRAL_EMBED},
6};
7use crate::client::{ClientBuilderError, CompletionClient, EmbeddingsClient, ProviderClient};
8use crate::impl_conversion_traits;
9
/// Root URL of the public Mistral API.
///
/// [`ClientBuilder::new`] starts from this value; callers can override it with
/// [`ClientBuilder::base_url`].
const MISTRAL_API_BASE_URL: &str = "https://api.mistral.ai";
11
12pub struct ClientBuilder<'a> {
13 api_key: &'a str,
14 base_url: &'a str,
15 http_client: Option<reqwest::Client>,
16}
17
18impl<'a> ClientBuilder<'a> {
19 pub fn new(api_key: &'a str) -> Self {
20 Self {
21 api_key,
22 base_url: MISTRAL_API_BASE_URL,
23 http_client: None,
24 }
25 }
26
27 pub fn base_url(mut self, base_url: &'a str) -> Self {
28 self.base_url = base_url;
29 self
30 }
31
32 pub fn custom_client(mut self, client: reqwest::Client) -> Self {
33 self.http_client = Some(client);
34 self
35 }
36
37 pub fn build(self) -> Result<Client, ClientBuilderError> {
38 let http_client = if let Some(http_client) = self.http_client {
39 http_client
40 } else {
41 reqwest::Client::builder().build()?
42 };
43
44 Ok(Client {
45 base_url: self.base_url.to_string(),
46 api_key: self.api_key.to_string(),
47 http_client,
48 })
49 }
50}
51
52#[derive(Clone)]
53pub struct Client {
54 base_url: String,
55 api_key: String,
56 http_client: reqwest::Client,
57}
58
59impl std::fmt::Debug for Client {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 f.debug_struct("Client")
62 .field("base_url", &self.base_url)
63 .field("http_client", &self.http_client)
64 .field("api_key", &"<REDACTED>")
65 .finish()
66 }
67}
68
69impl Client {
70 pub fn builder(api_key: &str) -> ClientBuilder<'_> {
81 ClientBuilder::new(api_key)
82 }
83
84 pub fn new(api_key: &str) -> Self {
89 Self::builder(api_key)
90 .build()
91 .expect("Mistral client should build")
92 }
93
94 pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
95 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
96 self.http_client.post(url).bearer_auth(&self.api_key)
97 }
98}
99
100impl ProviderClient for Client {
101 fn from_env() -> Self
104 where
105 Self: Sized,
106 {
107 let api_key = std::env::var("MISTRAL_API_KEY").expect("MISTRAL_API_KEY not set");
108 Self::new(&api_key)
109 }
110
111 fn from_val(input: crate::client::ProviderValue) -> Self {
112 let crate::client::ProviderValue::Simple(api_key) = input else {
113 panic!("Incorrect provider value type")
114 };
115 Self::new(&api_key)
116 }
117}
118
119impl CompletionClient for Client {
120 type CompletionModel = CompletionModel;
121
122 fn completion_model(&self, model: &str) -> Self::CompletionModel {
134 CompletionModel::new(self.clone(), model)
135 }
136}
137
138impl EmbeddingsClient for Client {
139 type EmbeddingModel = EmbeddingModel;
140
141 fn embedding_model(&self, model: &str) -> EmbeddingModel {
154 let ndims = match model {
155 MISTRAL_EMBED => 1024,
156 _ => 0,
157 };
158 EmbeddingModel::new(self.clone(), model, ndims)
159 }
160
161 fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
162 EmbeddingModel::new(self.clone(), model, ndims)
163 }
164}
165
// Generates the `AsTranscription`, `AsAudioGeneration` and `AsImageGeneration`
// conversion-trait impls for `Client` via the crate's `impl_conversion_traits!`
// macro.
// NOTE(review): presumably these impls report the capability as unsupported
// for Mistral; confirm against the macro definition.
impl_conversion_traits!(AsTranscription, AsAudioGeneration, AsImageGeneration for Client);
167
168#[derive(Clone, Debug, Deserialize, Serialize)]
169pub struct Usage {
170 pub completion_tokens: usize,
171 pub prompt_tokens: usize,
172 pub total_tokens: usize,
173}
174
175impl std::fmt::Display for Usage {
176 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
177 write!(
178 f,
179 "Prompt tokens: {} Total tokens: {}",
180 self.prompt_tokens, self.total_tokens
181 )
182 }
183}
184
185#[derive(Debug, Deserialize)]
186pub struct ApiErrorResponse {
187 pub(crate) message: String,
188}
189
/// A response body that is either a successful payload `T` or an
/// [`ApiErrorResponse`].
///
/// Because of `#[serde(untagged)]`, serde tries the variants in declaration
/// order. The body is first deserialized as `T`, and only if that fails as an
/// `ApiErrorResponse`. Keep `Ok` first: reordering changes which variant wins
/// when a body matches both.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub(crate) enum ApiResponse<T> {
    /// Successfully deserialized payload.
    Ok(T),
    /// Error body carrying a `message`.
    Err(ApiErrorResponse),
}