Skip to main content

rig/providers/mistral/
client.rs

1use crate::{
2    client::{
3        self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
4        ProviderClient,
5    },
6    http_client,
7};
8use serde::{Deserialize, Serialize};
9use std::fmt::Debug;
10
11const MISTRAL_API_BASE_URL: &str = "https://api.mistral.ai";
12
13#[derive(Debug, Default, Clone, Copy)]
14pub struct MistralExt;
15#[derive(Debug, Default, Clone, Copy)]
16pub struct MistralBuilder;
17
18type MistralApiKey = BearerAuth;
19
20pub type Client<H = reqwest::Client> = client::Client<MistralExt, H>;
21pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<MistralBuilder, String, H>;
22
23impl Provider for MistralExt {
24    type Builder = MistralBuilder;
25    const VERIFY_PATH: &'static str = "/models";
26}
27
28impl<H> Capabilities<H> for MistralExt {
29    type Completion = Capable<super::CompletionModel<H>>;
30    type Embeddings = Capable<super::EmbeddingModel<H>>;
31
32    type Transcription = Capable<super::TranscriptionModel<H>>;
33    type ModelListing = Nothing;
34    #[cfg(feature = "image")]
35    type ImageGeneration = Nothing;
36
37    #[cfg(feature = "audio")]
38    type AudioGeneration = Nothing;
39}
40
41impl DebugExt for MistralExt {}
42
43impl ProviderBuilder for MistralBuilder {
44    type Extension<H>
45        = MistralExt
46    where
47        H: http_client::HttpClientExt;
48    type ApiKey = MistralApiKey;
49
50    const BASE_URL: &'static str = MISTRAL_API_BASE_URL;
51
52    fn build<H>(
53        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
54    ) -> http_client::Result<Self::Extension<H>>
55    where
56        H: http_client::HttpClientExt,
57    {
58        Ok(MistralExt)
59    }
60}
61
62impl ProviderClient for Client {
63    type Input = String;
64
65    /// Create a new Mistral client from the `MISTRAL_API_KEY` environment variable.
66    /// Panics if the environment variable is not set.
67    fn from_env() -> Self
68    where
69        Self: Sized,
70    {
71        let api_key = std::env::var("MISTRAL_API_KEY").expect("MISTRAL_API_KEY not set");
72        Self::new(&api_key).unwrap()
73    }
74
75    fn from_val(input: Self::Input) -> Self {
76        Self::new(&input).unwrap()
77    }
78}
79
80#[derive(Clone, Debug, Deserialize, Serialize)]
81pub struct Usage {
82    pub completion_tokens: usize,
83    pub prompt_tokens: usize,
84    pub total_tokens: usize,
85}
86
87impl std::fmt::Display for Usage {
88    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89        write!(
90            f,
91            "Prompt tokens: {} Total tokens: {}",
92            self.prompt_tokens, self.total_tokens
93        )
94    }
95}
96
97#[derive(Debug, Deserialize)]
98pub struct ApiErrorResponse {
99    pub(crate) message: String,
100}
101
102#[derive(Debug, Deserialize)]
103#[serde(untagged)]
104pub(crate) enum ApiResponse<T> {
105    Ok(T),
106    Err(ApiErrorResponse),
107}
108
#[cfg(test)]
mod tests {
    use crate::providers::mistral::Client;

    /// Both construction paths (direct `new` and the builder) should succeed
    /// with a placeholder key.
    #[test]
    fn test_client_initialization() {
        let _client = Client::new("dummy-key").expect("Client::new() failed");

        let _client_from_builder = Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}