rig/providers/mistral/
client.rs1use serde::Deserialize;
2
3use super::{
4 embedding::{EmbeddingModel, MISTRAL_EMBED},
5 CompletionModel,
6};
7use crate::client::{CompletionClient, EmbeddingsClient, ProviderClient};
8use crate::impl_conversion_traits;
9
10const MISTRAL_API_BASE_URL: &str = "https://api.mistral.ai";
11
12#[derive(Debug, Clone)]
13pub struct Client {
14 base_url: String,
15 http_client: reqwest::Client,
16}
17
18impl Client {
19 pub fn new(api_key: &str) -> Self {
20 Self::from_url(api_key, MISTRAL_API_BASE_URL)
21 }
22
23 pub fn from_url(api_key: &str, base_url: &str) -> Self {
24 Self {
25 base_url: base_url.to_string(),
26 http_client: reqwest::Client::builder()
27 .default_headers({
28 let mut headers = reqwest::header::HeaderMap::new();
29 headers.insert(
30 "Authorization",
31 format!("Bearer {api_key}")
32 .parse()
33 .expect("Bearer token should parse"),
34 );
35 headers
36 })
37 .build()
38 .expect("Mistral reqwest client should build"),
39 }
40 }
41
42 pub(crate) fn post(&self, path: &str) -> reqwest::RequestBuilder {
43 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
44 self.http_client.post(url)
45 }
46}
47
48impl ProviderClient for Client {
49 fn from_env() -> Self
52 where
53 Self: Sized,
54 {
55 let api_key = std::env::var("MISTRAL_API_KEY").expect("MISTRAL_API_KEY not set");
56 Self::new(&api_key)
57 }
58}
59
60impl CompletionClient for Client {
61 type CompletionModel = CompletionModel;
62
63 fn completion_model(&self, model: &str) -> Self::CompletionModel {
75 CompletionModel::new(self.clone(), model)
76 }
77}
78
79impl EmbeddingsClient for Client {
80 type EmbeddingModel = EmbeddingModel;
81
82 fn embedding_model(&self, model: &str) -> EmbeddingModel {
95 let ndims = match model {
96 MISTRAL_EMBED => 1024,
97 _ => 0,
98 };
99 EmbeddingModel::new(self.clone(), model, ndims)
100 }
101
102 fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> Self::EmbeddingModel {
103 EmbeddingModel::new(self.clone(), model, ndims)
104 }
105}
106
107impl_conversion_traits!(AsTranscription, AsAudioGeneration, AsImageGeneration for Client);
108
109#[derive(Clone, Debug, Deserialize)]
110pub struct Usage {
111 pub completion_tokens: usize,
112 pub prompt_tokens: usize,
113 pub total_tokens: usize,
114}
115
116impl std::fmt::Display for Usage {
117 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118 write!(
119 f,
120 "Prompt tokens: {} Total tokens: {}",
121 self.prompt_tokens, self.total_tokens
122 )
123 }
124}
125
126#[derive(Debug, Deserialize)]
127pub struct ApiErrorResponse {
128 pub(crate) message: String,
129}
130
131#[derive(Debug, Deserialize)]
132#[serde(untagged)]
133pub(crate) enum ApiResponse<T> {
134 Ok(T),
135 Err(ApiErrorResponse),
136}