Skip to main content

// rig_core/providers/mistral/client.rs

1use crate::{
2    client::{
3        self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
4        ProviderClient,
5    },
6    http_client,
7    providers::mistral::MistralModelLister,
8};
9use serde::{Deserialize, Serialize};
10use std::fmt::Debug;
11
12const MISTRAL_API_BASE_URL: &str = "https://api.mistral.ai";
13
14#[derive(Debug, Default, Clone, Copy)]
15pub struct MistralExt;
16#[derive(Debug, Default, Clone, Copy)]
17pub struct MistralBuilder;
18
19type MistralApiKey = BearerAuth;
20
21pub type Client<H = reqwest::Client> = client::Client<MistralExt, H>;
22pub type ClientBuilder<H = crate::markers::Missing> =
23    client::ClientBuilder<MistralBuilder, String, H>;
24
25impl Provider for MistralExt {
26    type Builder = MistralBuilder;
27    const VERIFY_PATH: &'static str = "/models";
28}
29
30impl<H> Capabilities<H> for MistralExt {
31    type Completion = Capable<super::CompletionModel<H>>;
32    type Embeddings = Capable<super::EmbeddingModel<H>>;
33
34    type Transcription = Capable<super::TranscriptionModel<H>>;
35    type ModelListing = Capable<MistralModelLister<H>>;
36    #[cfg(feature = "image")]
37    type ImageGeneration = Nothing;
38
39    #[cfg(feature = "audio")]
40    type AudioGeneration = Nothing;
41    type Rerank = Nothing;
42}
43
44impl DebugExt for MistralExt {}
45
46impl ProviderBuilder for MistralBuilder {
47    type Extension<H>
48        = MistralExt
49    where
50        H: http_client::HttpClientExt;
51    type ApiKey = MistralApiKey;
52
53    const BASE_URL: &'static str = MISTRAL_API_BASE_URL;
54
55    fn build<H>(
56        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
57    ) -> http_client::Result<Self::Extension<H>>
58    where
59        H: http_client::HttpClientExt,
60    {
61        Ok(MistralExt)
62    }
63}
64
65impl ProviderClient for Client {
66    type Input = String;
67    type Error = crate::client::ProviderClientError;
68
69    /// Create a new Mistral client from the `MISTRAL_API_KEY` environment variable.
70    fn from_env() -> Result<Self, Self::Error>
71    where
72        Self: Sized,
73    {
74        let api_key = crate::client::required_env_var("MISTRAL_API_KEY")?;
75        Self::new(&api_key).map_err(Into::into)
76    }
77
78    fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
79        Self::new(&input).map_err(Into::into)
80    }
81}
82
83/// In-depth details on prompt tokens.
84///
85/// Mirrors Mistral's `PromptTokensDetails` schema. The Mistral API also exposes
86/// the same shape under the singular field name `prompt_token_details`; the
87/// `Usage` field accepts either form via `serde(alias = ...)`.
88#[derive(Clone, Debug, Default, Deserialize, Serialize)]
89pub struct PromptTokensDetails {
90    /// Number of tokens served from the prompt cache.
91    #[serde(default)]
92    pub cached_tokens: u64,
93}
94
95/// Token usage returned by Mistral's chat completions and embeddings endpoints.
96///
97/// See <https://docs.mistral.ai/api/> (`UsageInfo` schema). The three counts are
98/// always present; the remaining fields are populated by Mistral on a best-effort
99/// basis (e.g. cached-token information appears once a prompt is large enough to
100/// be cached).
101#[derive(Clone, Debug, Deserialize, Serialize)]
102pub struct Usage {
103    pub completion_tokens: usize,
104    pub prompt_tokens: usize,
105    pub total_tokens: usize,
106    /// Duration in seconds of audio tokens in the prompt (audio-input models only).
107    #[serde(default, skip_serializing_if = "Option::is_none")]
108    pub prompt_audio_seconds: Option<u64>,
109    /// Total cached prompt tokens reported at the top level. Some Mistral
110    /// responses populate this in addition to (or instead of)
111    /// `prompt_tokens_details.cached_tokens`.
112    #[serde(default, skip_serializing_if = "Option::is_none")]
113    pub num_cached_tokens: Option<u64>,
114    /// In-depth breakdown of prompt token usage (currently only cached tokens).
115    #[serde(
116        default,
117        alias = "prompt_token_details",
118        skip_serializing_if = "Option::is_none"
119    )]
120    pub prompt_tokens_details: Option<PromptTokensDetails>,
121}
122
123impl Usage {
124    /// Returns the number of cached prompt tokens, preferring the structured
125    /// `prompt_tokens_details.cached_tokens` field and falling back to the
126    /// top-level `num_cached_tokens`. Returns 0 when neither is present.
127    pub fn cached_tokens(&self) -> u64 {
128        self.prompt_tokens_details
129            .as_ref()
130            .map(|d| d.cached_tokens)
131            .or(self.num_cached_tokens)
132            .unwrap_or(0)
133    }
134}
135
136impl std::fmt::Display for Usage {
137    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
138        write!(
139            f,
140            "Prompt tokens: {} Total tokens: {}",
141            self.prompt_tokens, self.total_tokens
142        )
143    }
144}
145
146#[derive(Debug, Deserialize)]
147pub struct ApiErrorResponse {
148    pub(crate) message: String,
149}
150
151#[derive(Debug, Deserialize)]
152#[serde(untagged)]
153pub(crate) enum ApiResponse<T> {
154    Ok(T),
155    Err(ApiErrorResponse),
156}
157
#[cfg(test)]
mod tests {
    use crate::providers::mistral::Client;

    /// Both construction paths (`Client::new` and the builder) must succeed
    /// when given a placeholder API key.
    #[test]
    fn test_client_initialization() {
        let _direct = Client::new("dummy-key").expect("Client::new() failed");
        let _built = Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}