use std::collections::HashMap;
use std::sync::Arc;
use llmsdk_provider::ProviderError;
use llmsdk_provider_utils::http::HttpClient;
use crate::chat::MistralChatModel;
use crate::embedding::MistralEmbeddingModel;
use crate::{API_KEY_ENV_VAR, DEFAULT_BASE_URL};
/// Entry point for the Mistral provider: creates chat and embedding model handles.
///
/// Cloning is cheap; every clone points at the same reference-counted shared state.
#[derive(Debug, Clone)]
pub struct Mistral {
// Shared configuration, also handed to each model created from this provider.
inner: Arc<Inner>,
}
/// Signature of the ID-generator callback accepted by [`MistralBuilder::generate_id`]:
/// a thread-safe, zero-argument closure returning a fresh `String`.
// NOTE(review): where the generated IDs end up is decided by the model code, not in this file.
pub type GenerateIdFn = dyn Fn() -> String + Send + Sync;
/// Configuration shared, behind an `Arc`, by a [`Mistral`] provider and the models it creates.
pub(crate) struct Inner {
/// API base URL: the builder override if one was given, otherwise `DEFAULT_BASE_URL`.
pub(crate) base_url: String,
/// Request headers: the builder's extra headers plus the `authorization` bearer header
/// inserted by `MistralBuilder::build`.
// NOTE(review): a `None` value presumably means "do not send this header" — confirm
// against the HTTP helper that consumes this map.
pub(crate) headers: HashMap<String, Option<String>>,
/// HTTP client: the one supplied to the builder, otherwise `HttpClient::new()`.
pub(crate) http: HttpClient,
/// Optional caller-supplied ID generator; `None` when the builder was not given one.
pub(crate) generate_id: Option<Arc<GenerateIdFn>>,
}
impl std::fmt::Debug for Inner {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Inner")
.field("base_url", &self.base_url)
.field("headers", &self.headers)
.field("http", &self.http)
.field("generate_id", &self.generate_id.is_some())
.finish()
}
}
impl Mistral {
    /// Returns a [`MistralBuilder`] with nothing configured yet.
    #[must_use]
    pub fn builder() -> MistralBuilder {
        Default::default()
    }

    /// Builds a provider from builder defaults only.
    ///
    /// No API key is set explicitly, so resolving it is left to the key loader that
    /// [`MistralBuilder::build`] invokes with `API_KEY_ENV_VAR`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`MistralBuilder::build`].
    pub fn from_env() -> Result<Self, ProviderError> {
        MistralBuilder::default().build()
    }

    /// Creates a chat model handle for `model_id` that shares this provider's state.
    #[must_use]
    pub fn chat(&self, model_id: impl Into<String>) -> MistralChatModel {
        let shared = Arc::clone(&self.inner);
        let id: String = model_id.into();
        MistralChatModel::new(shared, id)
    }

    /// Alias for [`Mistral::chat`].
    #[must_use]
    pub fn language_model(&self, model_id: impl Into<String>) -> MistralChatModel {
        Self::chat(self, model_id)
    }

    /// Creates an embedding model handle for `model_id` that shares this provider's state.
    #[must_use]
    pub fn embedding(&self, model_id: impl Into<String>) -> MistralEmbeddingModel {
        let shared = Arc::clone(&self.inner);
        let id: String = model_id.into();
        MistralEmbeddingModel::new(shared, id)
    }

    /// Alias for [`Mistral::embedding`].
    #[must_use]
    pub fn embedding_model(&self, model_id: impl Into<String>) -> MistralEmbeddingModel {
        Self::embedding(self, model_id)
    }

    /// Alias for [`Mistral::embedding`].
    #[must_use]
    pub fn text_embedding(&self, model_id: impl Into<String>) -> MistralEmbeddingModel {
        Self::embedding(self, model_id)
    }

    /// Alias for [`Mistral::embedding`].
    #[must_use]
    pub fn text_embedding_model(&self, model_id: impl Into<String>) -> MistralEmbeddingModel {
        Self::embedding(self, model_id)
    }
}
/// Builder for [`Mistral`]. Every setting is optional; [`MistralBuilder::build`] fills in
/// the defaults.
#[derive(Default, Clone)]
pub struct MistralBuilder {
// Explicit API key. When unset, `build` passes `None` to the key loader together with
// `API_KEY_ENV_VAR`.
api_key: Option<String>,
// Base URL override; `build` falls back to `DEFAULT_BASE_URL`.
base_url: Option<String>,
// Extra headers; `build` adds the `authorization` header on top of these.
extra_headers: HashMap<String, Option<String>>,
// Caller-supplied HTTP client; `build` calls `HttpClient::new()` when absent.
http: Option<HttpClient>,
// Optional ID-generator callback, stored on the provider unchanged.
generate_id: Option<Arc<GenerateIdFn>>,
}
/// Diagnostic formatting for the builder: the API key is masked, and the HTTP client and
/// ID generator are reported only as present or absent.
impl std::fmt::Debug for MistralBuilder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: a newly added field fails to compile here until it is
        // given a (possibly masked) representation.
        let Self {
            api_key,
            base_url,
            extra_headers,
            http,
            generate_id,
        } = self;

        // Never print the key itself; only show whether one was set.
        let masked_key = api_key.is_some().then_some("***");

        let mut out = f.debug_struct("MistralBuilder");
        out.field("api_key", &masked_key);
        out.field("base_url", base_url);
        out.field("extra_headers", extra_headers);
        out.field("http", &http.is_some());
        out.field("generate_id", &generate_id.is_some());
        out.finish()
    }
}
impl MistralBuilder {
#[must_use]
pub fn api_key(mut self, key: impl Into<String>) -> Self {
self.api_key = Some(key.into());
self
}
#[must_use]
pub fn base_url(mut self, url: impl Into<String>) -> Self {
self.base_url = Some(url.into());
self
}
#[must_use]
pub fn header(mut self, name: impl Into<String>, value: Option<String>) -> Self {
self.extra_headers.insert(name.into(), value);
self
}
#[must_use]
pub fn http_client(mut self, client: HttpClient) -> Self {
self.http = Some(client);
self
}
#[must_use]
pub fn generate_id<F>(mut self, f: F) -> Self
where
F: Fn() -> String + Send + Sync + 'static,
{
self.generate_id = Some(Arc::new(f));
self
}
pub fn build(self) -> Result<Mistral, ProviderError> {
let api_key = llmsdk_provider_utils::api_key::load_api_key(
&llmsdk_provider_utils::api_key::LoadApiKey {
api_key: self.api_key.as_deref(),
env_var: API_KEY_ENV_VAR,
description: "Mistral",
parameter_name: Some("api_key"),
},
)?;
let base_url = self.base_url.unwrap_or_else(|| DEFAULT_BASE_URL.to_owned());
let mut headers = self.extra_headers;
headers.insert("authorization".into(), Some(format!("Bearer {api_key}")));
let http = match self.http {
Some(client) => client,
None => HttpClient::new()?,
};
Ok(Mistral {
inner: Arc::new(Inner {
base_url,
headers,
http,
generate_id: self.generate_id,
}),
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads a header value from the built provider, flattening "absent" and "stored as
    /// `None`" into `None`.
    fn header<'a>(m: &'a Mistral, name: &str) -> Option<&'a str> {
        m.inner.headers.get(name).and_then(|v| v.as_deref())
    }

    /// An explicit key yields the default base URL and the exact bearer header.
    #[test]
    fn builder_with_explicit_key_succeeds() {
        let m = Mistral::builder().api_key("test-key").build().expect("ok");
        assert_eq!(m.inner.base_url, DEFAULT_BASE_URL);
        // Compare the whole value: a prefix-only check would still pass if the key were
        // dropped or mangled while building the header.
        assert_eq!(header(&m, "authorization"), Some("Bearer test-key"));
    }

    /// A base URL set on the builder is stored unchanged.
    #[test]
    fn builder_custom_base_url() {
        let m = Mistral::builder()
            .api_key("k")
            .base_url("https://proxy.example.com/v1")
            .build()
            .expect("ok");
        assert_eq!(m.inner.base_url, "https://proxy.example.com/v1");
    }

    /// The ID-generator callback survives `build` and stays callable.
    #[test]
    fn builder_generate_id_is_stored() {
        let m = Mistral::builder()
            .api_key("k")
            .generate_id(|| "custom-id".to_owned())
            .build()
            .expect("ok");
        let gen_fn = m.inner.generate_id.as_ref().expect("generate_id stored");
        assert_eq!(gen_fn(), "custom-id");
    }

    /// Extra headers are carried over to the built provider.
    #[test]
    fn builder_custom_header() {
        let m = Mistral::builder()
            .api_key("k")
            .header("x-feature", Some("y".into()))
            .build()
            .expect("ok");
        assert_eq!(header(&m, "x-feature"), Some("y"));
    }

    /// The bearer header built from the API key replaces a caller-supplied
    /// `authorization` entry, because `build` inserts it after the extra headers.
    #[test]
    fn builder_authorization_overrides_caller_header() {
        let m = Mistral::builder()
            .api_key("k")
            .header("authorization", Some("Bearer other".into()))
            .build()
            .expect("ok");
        assert_eq!(header(&m, "authorization"), Some("Bearer k"));
    }

    /// A header registered with `None` is kept in the map as `None`, not dropped.
    #[test]
    fn builder_header_none_is_stored() {
        let m = Mistral::builder()
            .api_key("k")
            .header("x-feature", None)
            .build()
            .expect("ok");
        assert_eq!(m.inner.headers.get("x-feature"), Some(&None));
    }
}