1use bytes::Bytes;
2use serde::{Deserialize, Serialize};
3
4use super::{
5 CompletionModel,
6 embedding::{EmbeddingModel, MISTRAL_EMBED},
7};
8use crate::{
9 client::{CompletionClient, EmbeddingsClient, ProviderClient, VerifyClient, VerifyError},
10 http_client::HttpClientExt,
11};
12use crate::{http_client, impl_conversion_traits};
13use std::fmt::Debug;
14
/// Default API root used by `ClientBuilder::new` and `ClientBuilder::new_with_client`;
/// callers can override it with `ClientBuilder::base_url`.
const MISTRAL_API_BASE_URL: &str = "https://api.mistral.ai";
16
17pub struct ClientBuilder<'a, T = reqwest::Client> {
18 api_key: &'a str,
19 base_url: &'a str,
20 http_client: T,
21}
22
23impl<'a, T> ClientBuilder<'a, T>
24where
25 T: Default,
26{
27 pub fn new(api_key: &'a str) -> Self {
28 Self {
29 api_key,
30 base_url: MISTRAL_API_BASE_URL,
31 http_client: Default::default(),
32 }
33 }
34}
35
36impl<'a, T> ClientBuilder<'a, T> {
37 pub fn new_with_client(api_key: &'a str, http_client: T) -> Self {
38 Self {
39 api_key,
40 base_url: MISTRAL_API_BASE_URL,
41 http_client,
42 }
43 }
44
45 pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
46 ClientBuilder {
47 api_key: self.api_key,
48 base_url: self.base_url,
49 http_client,
50 }
51 }
52
53 pub fn base_url(mut self, base_url: &'a str) -> Self {
54 self.base_url = base_url;
55 self
56 }
57
58 pub fn build(self) -> Client<T> {
59 Client {
60 base_url: self.base_url.to_string(),
61 api_key: self.api_key.to_string(),
62 http_client: self.http_client,
63 }
64 }
65}
66
67#[derive(Clone)]
68pub struct Client<T = reqwest::Client> {
69 base_url: String,
70 api_key: String,
71 http_client: T,
72}
73
74impl<T> std::fmt::Debug for Client<T>
75where
76 T: Debug,
77{
78 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79 f.debug_struct("Client")
80 .field("base_url", &self.base_url)
81 .field("http_client", &self.http_client)
82 .field("api_key", &"<REDACTED>")
83 .finish()
84 }
85}
86
87impl Client<reqwest::Client> {
88 pub fn new(api_key: &str) -> Self {
93 Self::builder(api_key).build()
94 }
95
96 pub fn builder(api_key: &str) -> ClientBuilder<'_, reqwest::Client> {
107 ClientBuilder::new(api_key)
108 }
109
110 pub fn from_env() -> Self {
111 <Self as ProviderClient>::from_env()
112 }
113}
114
115impl<T> Client<T>
116where
117 T: HttpClientExt,
118{
119 pub(crate) fn post(&self, path: &str) -> http_client::Result<http_client::Builder> {
120 let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
121
122 http_client::with_bearer_auth(http_client::Request::post(url), &self.api_key)
123 }
124
125 pub(crate) fn get(&self, path: &str) -> http_client::Result<http_client::Builder> {
126 let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
127
128 http_client::with_bearer_auth(http_client::Request::get(url), &self.api_key)
129 }
130
131 pub(crate) async fn send<Body, R>(
132 &self,
133 req: http_client::Request<Body>,
134 ) -> http_client::Result<http_client::Response<http_client::LazyBody<R>>>
135 where
136 Body: Into<Bytes> + Send,
137 R: From<Bytes> + Send + 'static,
138 {
139 self.http_client.send(req).await
140 }
141}
142
143impl<T> ProviderClient for Client<T>
144where
145 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
146{
147 fn from_env() -> Self
150 where
151 Self: Sized,
152 {
153 let api_key = std::env::var("MISTRAL_API_KEY").expect("MISTRAL_API_KEY not set");
154 ClientBuilder::<T>::new(&api_key).build()
155 }
156
157 fn from_val(input: crate::client::ProviderValue) -> Self {
158 let crate::client::ProviderValue::Simple(api_key) = input else {
159 panic!("Incorrect provider value type")
160 };
161 ClientBuilder::<T>::new(&api_key).build()
162 }
163}
164
165impl<T> CompletionClient for Client<T>
166where
167 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
168{
169 type CompletionModel = CompletionModel<T>;
170
171 fn completion_model(&self, model: &str) -> Self::CompletionModel {
183 CompletionModel::new(self.clone(), model)
184 }
185}
186
187impl<T> EmbeddingsClient for Client<T>
188where
189 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
190{
191 type EmbeddingModel = EmbeddingModel<T>;
192
193 fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
206 let ndims = match model {
207 MISTRAL_EMBED => 1024,
208 _ => 0,
209 };
210 EmbeddingModel::new(self.clone(), model, ndims)
211 }
212
213 fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
214 EmbeddingModel::new(self.clone(), model, ndims)
215 }
216}
217
218impl<T> VerifyClient for Client<T>
219where
220 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
221{
222 #[cfg_attr(feature = "worker", worker::send)]
223 async fn verify(&self) -> Result<(), VerifyError> {
224 let req = self
225 .get("/models")?
226 .body(http_client::NoBody)
227 .map_err(|e| VerifyError::HttpError(e.into()))?;
228
229 let response = HttpClientExt::send(&self.http_client, req).await?;
230
231 match response.status() {
232 reqwest::StatusCode::OK => Ok(()),
233 reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
234 reqwest::StatusCode::INTERNAL_SERVER_ERROR => {
235 let text = http_client::text(response).await?;
236 Err(VerifyError::ProviderError(text))
237 }
238 _ => {
239 Ok(())
242 }
243 }
244 }
245}
246
// Generates impls of the listed conversion traits (`AsTranscription`, `AsAudioGeneration`,
// `AsImageGeneration`) for `Client<T>`. NOTE(review): presumably these are the
// "capability not supported" defaults, since this file defines no transcription, audio or
// image clients. Confirm against the `impl_conversion_traits!` definition.
impl_conversion_traits!(AsTranscription, AsAudioGeneration, AsImageGeneration for Client<T>);
248
/// Token counts attached to a Mistral API response.
///
/// Field order matters: it determines the order of the derived `Debug` and
/// `Serialize` output.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Usage {
    /// Completion token count as reported by the API.
    pub completion_tokens: usize,
    /// Prompt token count as reported by the API.
    pub prompt_tokens: usize,
    /// Total token count as reported by the API. NOTE(review): presumably
    /// prompt + completion; this type does not check that.
    pub total_tokens: usize,
}
255
256impl std::fmt::Display for Usage {
257 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
258 write!(
259 f,
260 "Prompt tokens: {} Total tokens: {}",
261 self.prompt_tokens, self.total_tokens
262 )
263 }
264}
265
/// Error body returned by the Mistral API. Only `message` is decoded.
///
/// Other JSON fields are ignored, because serde skips unknown fields unless
/// `deny_unknown_fields` is set.
#[derive(Debug, Deserialize)]
pub struct ApiErrorResponse {
    /// Human-readable error text from the API.
    pub(crate) message: String,
}
270
/// Decoding envelope for a response body that is either a success payload `T`
/// or an [`ApiErrorResponse`].
///
/// With `#[serde(untagged)]`, serde tries the variants in declaration order.
/// `Ok(T)` is attempted first, and `Err` is used only when the body does not
/// deserialize as `T`. Do not reorder the variants.
/// NOTE(review): if a `T` can deserialize from an error body (for example, a
/// type whose fields are all optional), errors would be decoded as `Ok`. Keep
/// such `T`s strict enough to reject `{"message": ...}`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub(crate) enum ApiResponse<T> {
    /// Successful payload.
    Ok(T),
    /// Error payload from the API.
    Err(ApiErrorResponse),
}