rig/providers/mistral/
client.rs1use crate::{
2 client::{
3 self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
4 ProviderClient,
5 },
6 http_client,
7};
8use serde::{Deserialize, Serialize};
9use std::fmt::Debug;
10
/// Default base URL for the Mistral API; used as `MistralBuilder`'s `BASE_URL`.
const MISTRAL_API_BASE_URL: &str = "https://api.mistral.ai";

/// Provider extension marker for Mistral. It is a unit struct and carries no state.
#[derive(Clone, Copy, Debug, Default)]
pub struct MistralExt;

/// Builder marker used by the generic client builder to produce a [`MistralExt`].
#[derive(Clone, Copy, Debug, Default)]
pub struct MistralBuilder;
17
18type MistralApiKey = BearerAuth;
19
20pub type Client<H = reqwest::Client> = client::Client<MistralExt, H>;
21pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<MistralBuilder, String, H>;
22
23impl Provider for MistralExt {
24 type Builder = MistralBuilder;
25 const VERIFY_PATH: &'static str = "/models";
26}
27
28impl<H> Capabilities<H> for MistralExt {
29 type Completion = Capable<super::CompletionModel<H>>;
30 type Embeddings = Capable<super::EmbeddingModel<H>>;
31
32 type Transcription = Capable<super::TranscriptionModel<H>>;
33 type ModelListing = Nothing;
34 #[cfg(feature = "image")]
35 type ImageGeneration = Nothing;
36
37 #[cfg(feature = "audio")]
38 type AudioGeneration = Nothing;
39}
40
// Empty impl: `MistralExt` relies entirely on the trait's default `DebugExt` methods,
// since the unit struct has no fields of its own to report.
impl DebugExt for MistralExt {}
42
43impl ProviderBuilder for MistralBuilder {
44 type Extension<H>
45 = MistralExt
46 where
47 H: http_client::HttpClientExt;
48 type ApiKey = MistralApiKey;
49
50 const BASE_URL: &'static str = MISTRAL_API_BASE_URL;
51
52 fn build<H>(
53 _builder: &client::ClientBuilder<Self, Self::ApiKey, H>,
54 ) -> http_client::Result<Self::Extension<H>>
55 where
56 H: http_client::HttpClientExt,
57 {
58 Ok(MistralExt)
59 }
60}
61
62impl ProviderClient for Client {
63 type Input = String;
64
65 fn from_env() -> Self
68 where
69 Self: Sized,
70 {
71 let api_key = std::env::var("MISTRAL_API_KEY").expect("MISTRAL_API_KEY not set");
72 Self::new(&api_key).unwrap()
73 }
74
75 fn from_val(input: Self::Input) -> Self {
76 Self::new(&input).unwrap()
77 }
78}
79
80#[derive(Clone, Debug, Deserialize, Serialize)]
81pub struct Usage {
82 pub completion_tokens: usize,
83 pub prompt_tokens: usize,
84 pub total_tokens: usize,
85}
86
87impl std::fmt::Display for Usage {
88 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89 write!(
90 f,
91 "Prompt tokens: {} Total tokens: {}",
92 self.prompt_tokens, self.total_tokens
93 )
94 }
95}
96
97#[derive(Debug, Deserialize)]
98pub struct ApiErrorResponse {
99 pub(crate) message: String,
100}
101
102#[derive(Debug, Deserialize)]
103#[serde(untagged)]
104pub(crate) enum ApiResponse<T> {
105 Ok(T),
106 Err(ApiErrorResponse),
107}
108
#[cfg(test)]
mod tests {
    use crate::providers::mistral::Client;

    /// Both the direct constructor and the builder must accept a placeholder key.
    #[test]
    fn test_client_initialization() {
        let _direct = Client::new("dummy-key").expect("Client::new() failed");

        let _from_builder = Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}