rig/providers/mistral/client.rs

1use crate::{
2    client::{
3        self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
4        ProviderClient,
5    },
6    http_client,
7};
8use serde::{Deserialize, Serialize};
9use std::fmt::Debug;
10
/// Root URL of the Mistral REST API (note: no `/v1` version segment).
/// Exposed to the generic client through `ProviderBuilder::BASE_URL`.
const MISTRAL_API_BASE_URL: &str = "https://api.mistral.ai";

/// Provider extension type for Mistral.
///
/// Carries no state; it implements `Provider` and `Capabilities` so the generic
/// client can be specialised for Mistral.
#[derive(Clone, Copy, Debug, Default)]
pub struct MistralExt;

/// Builder marker whose `ProviderBuilder` impl produces a [`MistralExt`].
#[derive(Clone, Copy, Debug, Default)]
pub struct MistralBuilder;
17
18type MistralApiKey = BearerAuth;
19
20pub type Client<H = reqwest::Client> = client::Client<MistralExt, H>;
21pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<MistralBuilder, String, H>;
22
23impl Provider for MistralExt {
24    type Builder = MistralBuilder;
25
26    const VERIFY_PATH: &'static str = "/models";
27
28    fn build<H>(
29        _: &client::ClientBuilder<Self::Builder, MistralApiKey, H>,
30    ) -> http_client::Result<Self> {
31        Ok(Self)
32    }
33}
34
35impl<H> Capabilities<H> for MistralExt {
36    type Completion = Capable<super::CompletionModel<H>>;
37    type Embeddings = Capable<super::EmbeddingModel<H>>;
38
39    type Transcription = Nothing;
40    #[cfg(feature = "image")]
41    type ImageGeneration = Nothing;
42
43    #[cfg(feature = "audio")]
44    type AudioGeneration = Nothing;
45}
46
47impl DebugExt for MistralExt {}
48
49impl ProviderBuilder for MistralBuilder {
50    type Output = MistralExt;
51    type ApiKey = MistralApiKey;
52
53    const BASE_URL: &'static str = MISTRAL_API_BASE_URL;
54}
55
56impl ProviderClient for Client {
57    type Input = String;
58
59    /// Create a new Mistral client from the `MISTRAL_API_KEY` environment variable.
60    /// Panics if the environment variable is not set.
61    fn from_env() -> Self
62    where
63        Self: Sized,
64    {
65        let api_key = std::env::var("MISTRAL_API_KEY").expect("MISTRAL_API_KEY not set");
66        Self::new(&api_key).unwrap()
67    }
68
69    fn from_val(input: Self::Input) -> Self {
70        Self::new(&input).unwrap()
71    }
72}
73
74#[derive(Clone, Debug, Deserialize, Serialize)]
75pub struct Usage {
76    pub completion_tokens: usize,
77    pub prompt_tokens: usize,
78    pub total_tokens: usize,
79}
80
81impl std::fmt::Display for Usage {
82    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83        write!(
84            f,
85            "Prompt tokens: {} Total tokens: {}",
86            self.prompt_tokens, self.total_tokens
87        )
88    }
89}
90
91#[derive(Debug, Deserialize)]
92pub struct ApiErrorResponse {
93    pub(crate) message: String,
94}
95
96#[derive(Debug, Deserialize)]
97#[serde(untagged)]
98pub(crate) enum ApiResponse<T> {
99    Ok(T),
100    Err(ApiErrorResponse),
101}