rig/providers/mistral/
client.rs1use crate::{
2 client::{
3 self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
4 ProviderClient,
5 },
6 http_client,
7};
8use serde::{Deserialize, Serialize};
9use std::fmt::Debug;
10
/// Root URL of the hosted Mistral API.
///
/// Note that it carries no `/v1` version prefix; request paths must supply it.
const MISTRAL_API_BASE_URL: &str = "https://api.mistral.ai";
12
/// Zero-sized marker that selects the Mistral provider for the generic
/// [`client::Client`]; it holds no state of its own.
#[derive(Clone, Copy, Debug, Default)]
pub struct MistralExt;
/// Zero-sized marker that selects Mistral-specific settings for the generic
/// [`client::ClientBuilder`]; it holds no state of its own.
#[derive(Clone, Copy, Debug, Default)]
pub struct MistralBuilder;
17
18type MistralApiKey = BearerAuth;
19
20pub type Client<H = reqwest::Client> = client::Client<MistralExt, H>;
21pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<MistralBuilder, String, H>;
22
23impl Provider for MistralExt {
24 type Builder = MistralBuilder;
25
26 const VERIFY_PATH: &'static str = "/models";
27
28 fn build<H>(
29 _: &client::ClientBuilder<Self::Builder, MistralApiKey, H>,
30 ) -> http_client::Result<Self> {
31 Ok(Self)
32 }
33}
34
35impl<H> Capabilities<H> for MistralExt {
36 type Completion = Capable<super::CompletionModel<H>>;
37 type Embeddings = Capable<super::EmbeddingModel<H>>;
38
39 type Transcription = Nothing;
40 type ModelListing = Nothing;
41 #[cfg(feature = "image")]
42 type ImageGeneration = Nothing;
43
44 #[cfg(feature = "audio")]
45 type AudioGeneration = Nothing;
46}
47
48impl DebugExt for MistralExt {}
49
50impl ProviderBuilder for MistralBuilder {
51 type Output = MistralExt;
52 type ApiKey = MistralApiKey;
53
54 const BASE_URL: &'static str = MISTRAL_API_BASE_URL;
55}
56
57impl ProviderClient for Client {
58 type Input = String;
59
60 fn from_env() -> Self
63 where
64 Self: Sized,
65 {
66 let api_key = std::env::var("MISTRAL_API_KEY").expect("MISTRAL_API_KEY not set");
67 Self::new(&api_key).unwrap()
68 }
69
70 fn from_val(input: Self::Input) -> Self {
71 Self::new(&input).unwrap()
72 }
73}
74
75#[derive(Clone, Debug, Deserialize, Serialize)]
76pub struct Usage {
77 pub completion_tokens: usize,
78 pub prompt_tokens: usize,
79 pub total_tokens: usize,
80}
81
82impl std::fmt::Display for Usage {
83 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84 write!(
85 f,
86 "Prompt tokens: {} Total tokens: {}",
87 self.prompt_tokens, self.total_tokens
88 )
89 }
90}
91
92#[derive(Debug, Deserialize)]
93pub struct ApiErrorResponse {
94 pub(crate) message: String,
95}
96
97#[derive(Debug, Deserialize)]
98#[serde(untagged)]
99pub(crate) enum ApiResponse<T> {
100 Ok(T),
101 Err(ApiErrorResponse),
102}
103
#[cfg(test)]
mod tests {
    use crate::providers::mistral::Client;

    /// Both construction paths, `Client::new` and the builder, must succeed
    /// when given a placeholder API key.
    #[test]
    fn test_client_initialization() {
        let _direct: Client = Client::new("dummy-key").expect("Client::new() failed");

        let _built: Client = Client::builder()
            .api_key("dummy-key")
            .build()
            .expect("Client::builder() failed");
    }
}