Skip to main content

rig/providers/mistral/
client.rs

1#[cfg(any(feature = "image", feature = "audio"))]
2use crate::client::Nothing;
3use crate::{
4    client::{
5        self, BearerAuth, Capabilities, Capable, DebugExt, Provider, ProviderBuilder,
6        ProviderClient,
7    },
8    http_client,
9    providers::mistral::MistralModelLister,
10};
11use serde::{Deserialize, Serialize};
12use std::fmt::Debug;
13
/// Root URL of the Mistral HTTP API; exposed to the generic client machinery
/// through `ProviderBuilder::BASE_URL` on `MistralBuilder`.
const MISTRAL_API_BASE_URL: &str = "https://api.mistral.ai";
15
/// Unit marker type that identifies Mistral as the provider behind the
/// generic `client::Client` (see the `Client` alias in this module).
///
/// It carries no data; all provider-specific behaviour is attached through
/// the trait implementations on this type.
#[derive(Clone, Copy, Debug, Default)]
pub struct MistralExt;
/// Unit marker type used as the builder-side counterpart of [`MistralExt`].
///
/// It is the `Builder` associated type of `MistralExt`'s `Provider`
/// implementation and the first parameter of the `ClientBuilder` alias.
#[derive(Clone, Copy, Debug, Default)]
pub struct MistralBuilder;
20
/// API-key representation used by `MistralBuilder` (its `ApiKey` associated
/// type); currently just an alias for `BearerAuth`.
type MistralApiKey = BearerAuth;
22
/// Mistral client: the generic `client::Client` specialised with
/// [`MistralExt`]. `H` is the HTTP backend and defaults to `reqwest::Client`.
pub type Client<H = reqwest::Client> = client::Client<MistralExt, H>;
/// Builder for [`Client`]: the generic `client::ClientBuilder` specialised
/// with [`MistralBuilder`] and a `String` key parameter. `H` is the HTTP
/// backend and defaults to `reqwest::Client`.
pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<MistralBuilder, String, H>;
25
26impl Provider for MistralExt {
27    type Builder = MistralBuilder;
28    const VERIFY_PATH: &'static str = "/models";
29}
30
31impl<H> Capabilities<H> for MistralExt {
32    type Completion = Capable<super::CompletionModel<H>>;
33    type Embeddings = Capable<super::EmbeddingModel<H>>;
34
35    type Transcription = Capable<super::TranscriptionModel<H>>;
36    type ModelListing = Capable<MistralModelLister<H>>;
37    #[cfg(feature = "image")]
38    type ImageGeneration = Nothing;
39
40    #[cfg(feature = "audio")]
41    type AudioGeneration = Nothing;
42}
43
// Empty implementation: no `DebugExt` items are overridden for Mistral.
impl DebugExt for MistralExt {}
45
46impl ProviderBuilder for MistralBuilder {
47    type Extension<H>
48        = MistralExt
49    where
50        H: http_client::HttpClientExt;
51    type ApiKey = MistralApiKey;
52
53    const BASE_URL: &'static str = MISTRAL_API_BASE_URL;
54
55    fn build<H>(
56        _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
57    ) -> http_client::Result<Self::Extension<H>>
58    where
59        H: http_client::HttpClientExt,
60    {
61        Ok(MistralExt)
62    }
63}
64
65impl ProviderClient for Client {
66    type Input = String;
67    type Error = crate::client::ProviderClientError;
68
69    /// Create a new Mistral client from the `MISTRAL_API_KEY` environment variable.
70    fn from_env() -> Result<Self, Self::Error>
71    where
72        Self: Sized,
73    {
74        let api_key = crate::client::required_env_var("MISTRAL_API_KEY")?;
75        Self::new(&api_key).map_err(Into::into)
76    }
77
78    fn from_val(input: Self::Input) -> Result<Self, Self::Error> {
79        Self::new(&input).map_err(Into::into)
80    }
81}
82
83#[derive(Clone, Debug, Deserialize, Serialize)]
84pub struct Usage {
85    pub completion_tokens: usize,
86    pub prompt_tokens: usize,
87    pub total_tokens: usize,
88}
89
90impl std::fmt::Display for Usage {
91    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92        write!(
93            f,
94            "Prompt tokens: {} Total tokens: {}",
95            self.prompt_tokens, self.total_tokens
96        )
97    }
98}
99
100#[derive(Debug, Deserialize)]
101pub struct ApiErrorResponse {
102    pub(crate) message: String,
103}
104
105#[derive(Debug, Deserialize)]
106#[serde(untagged)]
107pub(crate) enum ApiResponse<T> {
108    Ok(T),
109    Err(ApiErrorResponse),
110}
111
#[cfg(test)]
mod tests {
    use crate::providers::mistral::Client;

    /// Both construction paths — `Client::new` and the builder — must
    /// succeed when given a placeholder key.
    #[test]
    fn test_client_initialization() {
        // Direct construction from a key.
        let _direct = Client::new("dummy-key").expect("Client::new() failed");

        // Construction through the builder API.
        let _built = Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}
123}