1use bytes::Bytes;
2use serde::{Deserialize, Serialize};
3
4use super::{
5 CompletionModel,
6 embedding::{EmbeddingModel, MISTRAL_EMBED},
7};
8use crate::{
9 client::{CompletionClient, EmbeddingsClient, ProviderClient, VerifyClient, VerifyError},
10 http_client::HttpClientExt,
11};
12use crate::{http_client, impl_conversion_traits};
13use std::fmt::Debug;
14
15const MISTRAL_API_BASE_URL: &str = "https://api.mistral.ai";
16
17pub struct ClientBuilder<'a, T = reqwest::Client> {
18 api_key: &'a str,
19 base_url: &'a str,
20 http_client: T,
21}
22
23impl<'a> ClientBuilder<'a, reqwest::Client> {
24 pub fn new(api_key: &'a str) -> Self {
25 Self {
26 api_key,
27 base_url: MISTRAL_API_BASE_URL,
28 http_client: Default::default(),
29 }
30 }
31}
32
33impl<'a, T> ClientBuilder<'a, T> {
34 pub fn with_client<U>(self, http_client: U) -> ClientBuilder<'a, U> {
35 ClientBuilder {
36 api_key: self.api_key,
37 base_url: self.base_url,
38 http_client,
39 }
40 }
41
42 pub fn base_url(mut self, base_url: &'a str) -> Self {
43 self.base_url = base_url;
44 self
45 }
46
47 pub fn build(self) -> Client<T> {
48 Client {
49 base_url: self.base_url.to_string(),
50 api_key: self.api_key.to_string(),
51 http_client: self.http_client,
52 }
53 }
54}
55
56#[derive(Clone)]
57pub struct Client<T = reqwest::Client> {
58 base_url: String,
59 api_key: String,
60 http_client: T,
61}
62
63impl<T> std::fmt::Debug for Client<T>
64where
65 T: Debug,
66{
67 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68 f.debug_struct("Client")
69 .field("base_url", &self.base_url)
70 .field("http_client", &self.http_client)
71 .field("api_key", &"<REDACTED>")
72 .finish()
73 }
74}
75
76impl Client<reqwest::Client> {
77 pub fn new(api_key: &str) -> Self {
82 Self::builder(api_key).build()
83 }
84
85 pub fn builder(api_key: &str) -> ClientBuilder<'_> {
96 ClientBuilder::new(api_key)
97 }
98}
99
100impl<T> Client<T>
101where
102 T: HttpClientExt,
103{
104 pub(crate) fn post(&self, path: &str) -> http_client::Result<http_client::Builder> {
105 let url = format!("{}/{}", self.base_url, path.trim_start_matches('/'));
106
107 http_client::with_bearer_auth(http_client::Request::post(url), &self.api_key)
108 }
109
110 pub(crate) fn get(&self, path: &str) -> http_client::Result<http_client::Builder> {
111 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
112
113 http_client::with_bearer_auth(http_client::Request::get(url), &self.api_key)
114 }
115
116 pub(crate) async fn send<Body, R>(
117 &self,
118 req: http_client::Request<Body>,
119 ) -> http_client::Result<http_client::Response<http_client::LazyBody<R>>>
120 where
121 Body: Into<Bytes> + Send,
122 R: From<Bytes> + Send + 'static,
123 {
124 self.http_client.send(req).await
125 }
126}
127
128impl ProviderClient for Client<reqwest::Client> {
129 fn from_env() -> Self
132 where
133 Self: Sized,
134 {
135 let api_key = std::env::var("MISTRAL_API_KEY").expect("MISTRAL_API_KEY not set");
136 Self::new(&api_key)
137 }
138
139 fn from_val(input: crate::client::ProviderValue) -> Self {
140 let crate::client::ProviderValue::Simple(api_key) = input else {
141 panic!("Incorrect provider value type")
142 };
143 Self::new(&api_key)
144 }
145}
146
147impl CompletionClient for Client<reqwest::Client> {
148 type CompletionModel = CompletionModel<reqwest::Client>;
149
150 fn completion_model(&self, model: &str) -> Self::CompletionModel {
162 CompletionModel::new(self.clone(), model)
163 }
164}
165
166impl EmbeddingsClient for Client<reqwest::Client> {
167 type EmbeddingModel = EmbeddingModel<reqwest::Client>;
168
169 fn embedding_model(&self, model: &str) -> Self::EmbeddingModel {
182 let ndims = match model {
183 MISTRAL_EMBED => 1024,
184 _ => 0,
185 };
186 EmbeddingModel::new(self.clone(), model, ndims)
187 }
188
189 fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
190 EmbeddingModel::new(self.clone(), model, ndims)
191 }
192}
193
194impl VerifyClient for Client<reqwest::Client> {
195 #[cfg_attr(feature = "worker", worker::send)]
196 async fn verify(&self) -> Result<(), VerifyError> {
197 let req = self
198 .get("/models")?
199 .body(http_client::NoBody)
200 .map_err(|e| VerifyError::HttpError(e.into()))?;
201
202 let response = HttpClientExt::send(&self.http_client, req).await?;
203
204 match response.status() {
205 reqwest::StatusCode::OK => Ok(()),
206 reqwest::StatusCode::UNAUTHORIZED => Err(VerifyError::InvalidAuthentication),
207 reqwest::StatusCode::INTERNAL_SERVER_ERROR => {
208 let text = http_client::text(response).await?;
209 Err(VerifyError::ProviderError(text))
210 }
211 _ => {
212 Ok(())
215 }
216 }
217 }
218}
219
220impl_conversion_traits!(AsTranscription, AsAudioGeneration, AsImageGeneration for Client<T>);
221
222#[derive(Clone, Debug, Deserialize, Serialize)]
223pub struct Usage {
224 pub completion_tokens: usize,
225 pub prompt_tokens: usize,
226 pub total_tokens: usize,
227}
228
229impl std::fmt::Display for Usage {
230 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
231 write!(
232 f,
233 "Prompt tokens: {} Total tokens: {}",
234 self.prompt_tokens, self.total_tokens
235 )
236 }
237}
238
239#[derive(Debug, Deserialize)]
240pub struct ApiErrorResponse {
241 pub(crate) message: String,
242}
243
244#[derive(Debug, Deserialize)]
245#[serde(untagged)]
246pub(crate) enum ApiResponse<T> {
247 Ok(T),
248 Err(ApiErrorResponse),
249}