rig/providers/mistral/
client.rs1use crate::{
2 client::{
3 self, BearerAuth, Capabilities, Capable, DebugExt, Nothing, Provider, ProviderBuilder,
4 ProviderClient,
5 },
6 http_client,
7};
8use serde::{Deserialize, Serialize};
9use std::fmt::Debug;
10
/// Root URL of the hosted Mistral API.
///
/// Assigned to `ProviderBuilder::BASE_URL` for `MistralBuilder` below.
const MISTRAL_API_BASE_URL: &str = "https://api.mistral.ai";
12
/// Zero-sized marker that selects the Mistral provider in the generic client.
#[derive(Clone, Copy, Debug, Default)]
pub struct MistralExt;

/// Zero-sized marker carrying Mistral's builder configuration
/// (see its `ProviderBuilder` impl).
#[derive(Clone, Copy, Debug, Default)]
pub struct MistralBuilder;
17
18type MistralApiKey = BearerAuth;
19
20pub type Client<H = reqwest::Client> = client::Client<MistralExt, H>;
21pub type ClientBuilder<H = reqwest::Client> = client::ClientBuilder<MistralBuilder, String, H>;
22
23impl Provider for MistralExt {
24 type Builder = MistralBuilder;
25
26 const VERIFY_PATH: &'static str = "/models";
27
28 fn build<H>(
29 _: &client::ClientBuilder<Self::Builder, MistralApiKey, H>,
30 ) -> http_client::Result<Self> {
31 Ok(Self)
32 }
33}
34
35impl<H> Capabilities<H> for MistralExt {
36 type Completion = Capable<super::CompletionModel<H>>;
37 type Embeddings = Capable<super::EmbeddingModel<H>>;
38
39 type Transcription = Nothing;
40 #[cfg(feature = "image")]
41 type ImageGeneration = Nothing;
42
43 #[cfg(feature = "audio")]
44 type AudioGeneration = Nothing;
45}
46
47impl DebugExt for MistralExt {}
48
49impl ProviderBuilder for MistralBuilder {
50 type Output = MistralExt;
51 type ApiKey = MistralApiKey;
52
53 const BASE_URL: &'static str = MISTRAL_API_BASE_URL;
54}
55
56impl ProviderClient for Client {
57 type Input = String;
58
59 fn from_env() -> Self
62 where
63 Self: Sized,
64 {
65 let api_key = std::env::var("MISTRAL_API_KEY").expect("MISTRAL_API_KEY not set");
66 Self::new(&api_key).unwrap()
67 }
68
69 fn from_val(input: Self::Input) -> Self {
70 Self::new(&input).unwrap()
71 }
72}
73
74#[derive(Clone, Debug, Deserialize, Serialize)]
75pub struct Usage {
76 pub completion_tokens: usize,
77 pub prompt_tokens: usize,
78 pub total_tokens: usize,
79}
80
81impl std::fmt::Display for Usage {
82 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83 write!(
84 f,
85 "Prompt tokens: {} Total tokens: {}",
86 self.prompt_tokens, self.total_tokens
87 )
88 }
89}
90
91#[derive(Debug, Deserialize)]
92pub struct ApiErrorResponse {
93 pub(crate) message: String,
94}
95
96#[derive(Debug, Deserialize)]
97#[serde(untagged)]
98pub(crate) enum ApiResponse<T> {
99 Ok(T),
100 Err(ApiErrorResponse),
101}