Skip to main content

rig_core/providers/mistral/
client.rs

1#[cfg(any(feature = "image", feature = "audio"))]
2use crate::client::Nothing;
3use crate::{
4    client::{
5        self, BearerAuth, Capabilities, Capable, DebugExt, Provider, ProviderBuilder,
6        ProviderClient,
7    },
8    http_client,
9    providers::mistral::MistralModelLister,
10};
11use serde::{Deserialize, Serialize};
12use std::fmt::Debug;
13
/// Root URL of the hosted Mistral API, exposed to rig as `ProviderBuilder::BASE_URL`.
/// Note that it is the bare host: it carries no `/v1` version segment.
const MISTRAL_API_BASE_URL: &str = "https://api.mistral.ai";
15
/// Stateless provider marker for Mistral.
///
/// It holds no data; its only role is to select the Mistral-specific trait
/// implementations (`Provider`, `Capabilities`, `DebugExt`) for [`Client`].
#[derive(Clone, Copy, Debug, Default)]
pub struct MistralExt;

/// Stateless builder marker used as the provider parameter of [`ClientBuilder`].
#[derive(Clone, Copy, Debug, Default)]
pub struct MistralBuilder;
20
21type MistralApiKey = BearerAuth;
22
23pub type Client<H = reqwest::Client> = client::Client<MistralExt, H>;
24pub type ClientBuilder<H = crate::markers::Missing> =
25    client::ClientBuilder<MistralBuilder, String, H>;
26
27impl Provider for MistralExt {
28    type Builder = MistralBuilder;
29    const VERIFY_PATH: &'static str = "/models";
30}
31
32impl<H> Capabilities<H> for MistralExt {
33    type Completion = Capable<super::CompletionModel<H>>;
34    type Embeddings = Capable<super::EmbeddingModel<H>>;
35
36    type Transcription = Capable<super::TranscriptionModel<H>>;
37    type ModelListing = Capable<MistralModelLister<H>>;
38    #[cfg(feature = "image")]
39    type ImageGeneration = Nothing;
40
41    #[cfg(feature = "audio")]
42    type AudioGeneration = Nothing;
43}
44
45impl DebugExt for MistralExt {}
46
47impl ProviderBuilder for MistralBuilder {
48    type Extension<H>
49        = MistralExt
50    where
51        H: http_client::HttpClientExt;
52    type ApiKey = MistralApiKey;
53
54    const BASE_URL: &'static str = MISTRAL_API_BASE_URL;
55
56    fn build<H>(
57        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
58    ) -> http_client::Result<Self::Extension<H>>
59    where
60        H: http_client::HttpClientExt,
61    {
62        Ok(MistralExt)
63    }
64}
65
66impl ProviderClient for Client {
67    type Input = String;
68    type Error = crate::client::ProviderClientError;
69
70    /// Create a new Mistral client from the `MISTRAL_API_KEY` environment variable.
71    fn from_env() -> Result<Self, Self::Error>
72    where
73        Self: Sized,
74    {
75        let api_key = crate::client::required_env_var("MISTRAL_API_KEY")?;
76        Self::new(&api_key).map_err(Into::into)
77    }
78
79    fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
80        Self::new(&input).map_err(Into::into)
81    }
82}
83
84/// In-depth details on prompt tokens.
85///
86/// Mirrors Mistral's `PromptTokensDetails` schema. The Mistral API also exposes
87/// the same shape under the singular field name `prompt_token_details`; the
88/// `Usage` field accepts either form via `serde(alias = ...)`.
89#[derive(Clone, Debug, Default, Deserialize, Serialize)]
90pub struct PromptTokensDetails {
91    /// Number of tokens served from the prompt cache.
92    #[serde(default)]
93    pub cached_tokens: u64,
94}
95
96/// Token usage returned by Mistral's chat completions and embeddings endpoints.
97///
98/// See <https://docs.mistral.ai/api/> (`UsageInfo` schema). The three counts are
99/// always present; the remaining fields are populated by Mistral on a best-effort
100/// basis (e.g. cached-token information appears once a prompt is large enough to
101/// be cached).
102#[derive(Clone, Debug, Deserialize, Serialize)]
103pub struct Usage {
104    pub completion_tokens: usize,
105    pub prompt_tokens: usize,
106    pub total_tokens: usize,
107    /// Duration in seconds of audio tokens in the prompt (audio-input models only).
108    #[serde(default, skip_serializing_if = "Option::is_none")]
109    pub prompt_audio_seconds: Option<u64>,
110    /// Total cached prompt tokens reported at the top level. Some Mistral
111    /// responses populate this in addition to (or instead of)
112    /// `prompt_tokens_details.cached_tokens`.
113    #[serde(default, skip_serializing_if = "Option::is_none")]
114    pub num_cached_tokens: Option<u64>,
115    /// In-depth breakdown of prompt token usage (currently only cached tokens).
116    #[serde(
117        default,
118        alias = "prompt_token_details",
119        skip_serializing_if = "Option::is_none"
120    )]
121    pub prompt_tokens_details: Option<PromptTokensDetails>,
122}
123
124impl Usage {
125    /// Returns the number of cached prompt tokens, preferring the structured
126    /// `prompt_tokens_details.cached_tokens` field and falling back to the
127    /// top-level `num_cached_tokens`. Returns 0 when neither is present.
128    pub fn cached_tokens(&self) -> u64 {
129        self.prompt_tokens_details
130            .as_ref()
131            .map(|d| d.cached_tokens)
132            .or(self.num_cached_tokens)
133            .unwrap_or(0)
134    }
135}
136
137impl std::fmt::Display for Usage {
138    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
139        write!(
140            f,
141            "Prompt tokens: {} Total tokens: {}",
142            self.prompt_tokens, self.total_tokens
143        )
144    }
145}
146
147#[derive(Debug, Deserialize)]
148pub struct ApiErrorResponse {
149    pub(crate) message: String,
150}
151
152#[derive(Debug, Deserialize)]
153#[serde(untagged)]
154pub(crate) enum ApiResponse<T> {
155    Ok(T),
156    Err(ApiErrorResponse),
157}
158
#[cfg(test)]
mod tests {
    use crate::providers::mistral::Client;

    /// Both construction paths, `Client::new` and the builder, must accept a
    /// placeholder key.
    #[test]
    fn test_client_initialization() {
        let _direct = Client::new("dummy-key").expect("Client::new() failed");
        let _built = Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}