rig_core/providers/mistral/
client.rs1use crate::{
2 client::{
3 self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
4 ProviderClient,
5 },
6 http_client,
7 providers::mistral::MistralModelLister,
8};
9use serde::{Deserialize, Serialize};
10use std::fmt::Debug;
11
/// Root URL of the hosted Mistral API; request paths are appended to this.
const MISTRAL_API_BASE_URL: &str = "https://api.mistral.ai";
13
/// Zero-sized marker identifying the Mistral provider inside the generic client.
#[derive(Clone, Copy, Debug, Default)]
pub struct MistralExt;

/// Zero-sized marker used as the provider builder for [`MistralExt`].
#[derive(Clone, Copy, Debug, Default)]
pub struct MistralBuilder;
18
19type MistralApiKey = BearerAuth;
20
21pub type Client<H = reqwest::Client> = client::Client<MistralExt, H>;
22pub type ClientBuilder<H = crate::markers::Missing> =
23 client::ClientBuilder<MistralBuilder, String, H>;
24
25impl Provider for MistralExt {
26 type Builder = MistralBuilder;
27 const VERIFY_PATH: &'static str = "/models";
28}
29
30impl<H> Capabilities<H> for MistralExt {
31 type Completion = Capable<super::CompletionModel<H>>;
32 type Embeddings = Capable<super::EmbeddingModel<H>>;
33
34 type Transcription = Capable<super::TranscriptionModel<H>>;
35 type ModelListing = Capable<MistralModelLister<H>>;
36 #[cfg(feature = "image")]
37 type ImageGeneration = Nothing;
38
39 #[cfg(feature = "audio")]
40 type AudioGeneration = Nothing;
41 type Rerank = Nothing;
42}
43
44impl DebugExt for MistralExt {}
45
46impl ProviderBuilder for MistralBuilder {
47 type Extension<H>
48 = MistralExt
49 where
50 H: http_client::HttpClientExt;
51 type ApiKey = MistralApiKey;
52
53 const BASE_URL: &'static str = MISTRAL_API_BASE_URL;
54
55 fn build<H>(
56 _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
57 ) -> http_client::Result<Self::Extension<H>>
58 where
59 H: http_client::HttpClientExt,
60 {
61 Ok(MistralExt)
62 }
63}
64
65impl ProviderClient for Client {
66 type Input = String;
67 type Error = crate::client::ProviderClientError;
68
69 fn from_env() -> Result<Self, Self::Error>
71 where
72 Self: Sized,
73 {
74 let api_key = crate::client::required_env_var("MISTRAL_API_KEY")?;
75 Self::new(&api_key).map_err(Into::into)
76 }
77
78 fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
79 Self::new(&input).map_err(Into::into)
80 }
81}
82
83#[derive(Clone, Debug, Default, Deserialize, Serialize)]
89pub struct PromptTokensDetails {
90 #[serde(default)]
92 pub cached_tokens: u64,
93}
94
95#[derive(Clone, Debug, Deserialize, Serialize)]
102pub struct Usage {
103 pub completion_tokens: usize,
104 pub prompt_tokens: usize,
105 pub total_tokens: usize,
106 #[serde(default, skip_serializing_if = "Option::is_none")]
108 pub prompt_audio_seconds: Option<u64>,
109 #[serde(default, skip_serializing_if = "Option::is_none")]
113 pub num_cached_tokens: Option<u64>,
114 #[serde(
116 default,
117 alias = "prompt_token_details",
118 skip_serializing_if = "Option::is_none"
119 )]
120 pub prompt_tokens_details: Option<PromptTokensDetails>,
121}
122
123impl Usage {
124 pub fn cached_tokens(&self) -> u64 {
128 self.prompt_tokens_details
129 .as_ref()
130 .map(|d| d.cached_tokens)
131 .or(self.num_cached_tokens)
132 .unwrap_or(0)
133 }
134}
135
136impl std::fmt::Display for Usage {
137 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
138 write!(
139 f,
140 "Prompt tokens: {} Total tokens: {}",
141 self.prompt_tokens, self.total_tokens
142 )
143 }
144}
145
146#[derive(Debug, Deserialize)]
147pub struct ApiErrorResponse {
148 pub(crate) message: String,
149}
150
151#[derive(Debug, Deserialize)]
152#[serde(untagged)]
153pub(crate) enum ApiResponse<T> {
154 Ok(T),
155 Err(ApiErrorResponse),
156}
157
#[cfg(test)]
mod tests {
    use crate::providers::mistral::Client;

    /// Both construction paths (direct `new` and the builder) succeed with a placeholder key.
    #[test]
    fn test_client_initialization() {
        let _direct = Client::new("dummy-key").expect("Client::new() failed");

        let _built = Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}