rig/providers/mistral/
client.rs1use serde::Deserialize;
2
3use super::{
4 CompletionModel,
5 embedding::{EmbeddingModel, MISTRAL_EMBED},
6};
7use crate::client::{CompletionClient, EmbeddingsClient, ProviderClient};
8use crate::impl_conversion_traits;
9
/// Default endpoint used by [`Client::new`]; override it with [`Client::from_url`].
const MISTRAL_API_BASE_URL: &str = "https://api.mistral.ai";

12#[derive(Clone)]
13pub struct Client {
14 base_url: String,
15 api_key: String,
16 http_client: reqwest::Client,
17}
18
19impl std::fmt::Debug for Client {
20 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21 f.debug_struct("Client")
22 .field("base_url", &self.base_url)
23 .field("http_client", &self.http_client)
24 .field("api_key", &"<REDACTED>")
25 .finish()
26 }
27}
28
29impl Client {
30 pub fn new(api_key: &str) -> Self {
31 Self::from_url(api_key, MISTRAL_API_BASE_URL)
32 }
33
34 pub fn from_url(api_key: &str, base_url: &str) -> Self {
35 Self {
36 base_url: base_url.to_string(),
37 api_key: api_key.to_string(),
38 http_client: reqwest::Client::builder()
39 .build()
40 .expect("Mistral reqwest client should build"),
41 }
42 }
43
44 pub fn with_custom_client(mut self, client: reqwest::Client) -> Self {
47 self.http_client = client;
48
49 self
50 }
51
52 pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
53 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
54 self.http_client.post(url).bearer_auth(&self.api_key)
55 }
56}
57
58impl ProviderClient for Client {
59 fn from_env() -> Self
62 where
63 Self: Sized,
64 {
65 let api_key = std::env::var("MISTRAL_API_KEY").expect("MISTRAL_API_KEY not set");
66 Self::new(&api_key)
67 }
68
69 fn from_val(input: crate::client::ProviderValue) -> Self {
70 let crate::client::ProviderValue::Simple(api_key) = input else {
71 panic!("Incorrect provider value type")
72 };
73 Self::new(&api_key)
74 }
75}
76
77impl CompletionClient for Client {
78 type CompletionModel = CompletionModel;
79
80 fn completion_model(&self, model: &str) -> Self::CompletionModel {
92 CompletionModel::new(self.clone(), model)
93 }
94}
95
96impl EmbeddingsClient for Client {
97 type EmbeddingModel = EmbeddingModel;
98
99 fn embedding_model(&self, model: &str) -> EmbeddingModel {
112 let ndims = match model {
113 MISTRAL_EMBED => 1024,
114 _ => 0,
115 };
116 EmbeddingModel::new(self.clone(), model, ndims)
117 }
118
119 fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
120 EmbeddingModel::new(self.clone(), model, ndims)
121 }
122}
123
124impl_conversion_traits!(AsTranscription, AsAudioGeneration, AsImageGeneration for Client);
125
126#[derive(Clone, Debug, Deserialize)]
127pub struct Usage {
128 pub completion_tokens: usize,
129 pub prompt_tokens: usize,
130 pub total_tokens: usize,
131}
132
133impl std::fmt::Display for Usage {
134 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135 write!(
136 f,
137 "Prompt tokens: {} Total tokens: {}",
138 self.prompt_tokens, self.total_tokens
139 )
140 }
141}
142
143#[derive(Debug, Deserialize)]
144pub struct ApiErrorResponse {
145 pub(crate) message: String,
146}
147
148#[derive(Debug, Deserialize)]
149#[serde(untagged)]
150pub(crate) enum ApiResponse<T> {
151 Ok(T),
152 Err(ApiErrorResponse),
153}