Skip to main content

rig/providers/mistral/
client.rs

1use crate::{
2    client::{
3        self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
4        ProviderClient,
5    },
6    http_client,
7};
8use serde::{Deserialize, Serialize};
9use std::fmt::Debug;
10
11const MISTRAL_API_BASE_URL: &str = "https://api.mistral.ai";
12
13#[derive(Debug, Default, Clone, Copy)]
14pub struct MistralExt;
15#[derive(Debug, Default, Clone, Copy)]
16pub struct MistralBuilder;
17
18type MistralApiKey = BearerAuth;
19
20pub type Client<H = reqwest::Client> = client::Client<MistralExt, H>;
21pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<MistralBuilder, String, H>;
22
23impl Provider for MistralExt {
24    type Builder = MistralBuilder;
25
26    const VERIFY_PATH: &'static str = "/models";
27
28    fn build<H>(
29        _: &client::ClientBuilder<Self::Builder, MistralApiKey, H>,
30    ) -> http_client::Result<Self> {
31        Ok(Self)
32    }
33}
34
35impl<H> Capabilities<H> for MistralExt {
36    type Completion = Capable<super::CompletionModel<H>>;
37    type Embeddings = Capable<super::EmbeddingModel<H>>;
38
39    type Transcription = Nothing;
40    type ModelListing = Nothing;
41    #[cfg(feature = "image")]
42    type ImageGeneration = Nothing;
43
44    #[cfg(feature = "audio")]
45    type AudioGeneration = Nothing;
46}
47
48impl DebugExt for MistralExt {}
49
50impl ProviderBuilder for MistralBuilder {
51    type Output = MistralExt;
52    type ApiKey = MistralApiKey;
53
54    const BASE_URL: &'static str = MISTRAL_API_BASE_URL;
55}
56
57impl ProviderClient for Client {
58    type Input = String;
59
60    /// Create a new Mistral client from the `MISTRAL_API_KEY` environment variable.
61    /// Panics if the environment variable is not set.
62    fn from_env() -> Self
63    where
64        Self: Sized,
65    {
66        let api_key = std::env::var("MISTRAL_API_KEY").expect("MISTRAL_API_KEY not set");
67        Self::new(&api_key).unwrap()
68    }
69
70    fn from_val(input: Self::Input) -> Self {
71        Self::new(&input).unwrap()
72    }
73}
74
75#[derive(Clone, Debug, Deserialize, Serialize)]
76pub struct Usage {
77    pub completion_tokens: usize,
78    pub prompt_tokens: usize,
79    pub total_tokens: usize,
80}
81
82impl std::fmt::Display for Usage {
83    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84        write!(
85            f,
86            "Prompt tokens: {} Total tokens: {}",
87            self.prompt_tokens, self.total_tokens
88        )
89    }
90}
91
92#[derive(Debug, Deserialize)]
93pub struct ApiErrorResponse {
94    pub(crate) message: String,
95}
96
97#[derive(Debug, Deserialize)]
98#[serde(untagged)]
99pub(crate) enum ApiResponse<T> {
100    Ok(T),
101    Err(ApiErrorResponse),
102}
103
#[cfg(test)]
mod tests {
    use crate::providers::mistral::Client;

    /// Both construction paths (direct and builder) must succeed with a dummy key.
    #[test]
    fn test_client_initialization() {
        let _direct: Client = Client::new("dummy-key").expect("Client::new() failed");

        let _built: Client = Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}