rig_core/providers/mistral/
client.rs1#[cfg(any(feature = "image", feature = "audio"))]
2use crate::client::Nothing;
3use crate::{
4 client::{
5 self, BearerAuth, Capabilities, Capable, DebugExt, Provider, ProviderBuilder,
6 ProviderClient,
7 },
8 http_client,
9 providers::mistral::MistralModelLister,
10};
11use serde::{Deserialize, Serialize};
12use std::fmt::Debug;
13
/// Root URL of the Mistral REST API; installed as `ProviderBuilder::BASE_URL` for
/// [`MistralBuilder`].
const MISTRAL_API_BASE_URL: &'static str = "https://api.mistral.ai";

/// Zero-sized provider extension marker for Mistral. It holds no state of its own.
#[derive(Clone, Copy, Debug, Default)]
pub struct MistralExt;

/// Zero-sized builder marker whose `build` step yields a [`MistralExt`].
#[derive(Clone, Copy, Debug, Default)]
pub struct MistralBuilder;
20
21type MistralApiKey = BearerAuth;
22
23pub type Client<H = reqwest::Client> = client::Client<MistralExt, H>;
24pub type ClientBuilder<H = crate::markers::Missing> =
25 client::ClientBuilder<MistralBuilder, String, H>;
26
27impl Provider for MistralExt {
28 type Builder = MistralBuilder;
29 const VERIFY_PATH: &'static str = "/models";
30}
31
32impl<H> Capabilities<H> for MistralExt {
33 type Completion = Capable<super::CompletionModel<H>>;
34 type Embeddings = Capable<super::EmbeddingModel<H>>;
35
36 type Transcription = Capable<super::TranscriptionModel<H>>;
37 type ModelListing = Capable<MistralModelLister<H>>;
38 #[cfg(feature = "image")]
39 type ImageGeneration = Nothing;
40
41 #[cfg(feature = "audio")]
42 type AudioGeneration = Nothing;
43}
44
45impl DebugExt for MistralExt {}
46
47impl ProviderBuilder for MistralBuilder {
48 type Extension<H>
49 = MistralExt
50 where
51 H: http_client::HttpClientExt;
52 type ApiKey = MistralApiKey;
53
54 const BASE_URL: &'static str = MISTRAL_API_BASE_URL;
55
56 fn build<H>(
57 _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
58 ) -> http_client::Result<Self::Extension<H>>
59 where
60 H: http_client::HttpClientExt,
61 {
62 Ok(MistralExt)
63 }
64}
65
66impl ProviderClient for Client {
67 type Input = String;
68 type Error = crate::client::ProviderClientError;
69
70 fn from_env() -> Result<Self, Self::Error>
72 where
73 Self: Sized,
74 {
75 let api_key = crate::client::required_env_var("MISTRAL_API_KEY")?;
76 Self::new(&api_key).map_err(Into::into)
77 }
78
79 fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
80 Self::new(&input).map_err(Into::into)
81 }
82}
83
84#[derive(Clone, Debug, Default, Deserialize, Serialize)]
90pub struct PromptTokensDetails {
91 #[serde(default)]
93 pub cached_tokens: u64,
94}
95
96#[derive(Clone, Debug, Deserialize, Serialize)]
103pub struct Usage {
104 pub completion_tokens: usize,
105 pub prompt_tokens: usize,
106 pub total_tokens: usize,
107 #[serde(default, skip_serializing_if = "Option::is_none")]
109 pub prompt_audio_seconds: Option<u64>,
110 #[serde(default, skip_serializing_if = "Option::is_none")]
114 pub num_cached_tokens: Option<u64>,
115 #[serde(
117 default,
118 alias = "prompt_token_details",
119 skip_serializing_if = "Option::is_none"
120 )]
121 pub prompt_tokens_details: Option<PromptTokensDetails>,
122}
123
124impl Usage {
125 pub fn cached_tokens(&self) -> u64 {
129 self.prompt_tokens_details
130 .as_ref()
131 .map(|d| d.cached_tokens)
132 .or(self.num_cached_tokens)
133 .unwrap_or(0)
134 }
135}
136
137impl std::fmt::Display for Usage {
138 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
139 write!(
140 f,
141 "Prompt tokens: {} Total tokens: {}",
142 self.prompt_tokens, self.total_tokens
143 )
144 }
145}
146
147#[derive(Debug, Deserialize)]
148pub struct ApiErrorResponse {
149 pub(crate) message: String,
150}
151
152#[derive(Debug, Deserialize)]
153#[serde(untagged)]
154pub(crate) enum ApiResponse<T> {
155 Ok(T),
156 Err(ApiErrorResponse),
157}
158
#[cfg(test)]
mod tests {
    use crate::providers::mistral::Client;

    /// Both construction paths (direct `new` and the builder) succeed with a placeholder
    /// key and perform no network access at construction time.
    #[test]
    fn test_client_initialization() {
        let _direct = Client::new("dummy-key").expect("Client::new() failed");

        let _built = Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}