//! Mistral provider for Synaptic: a thin wrapper over the OpenAI-compatible
//! client in `synaptic_openai`, pointed at `https://api.mistral.ai/v1`.
1use std::sync::Arc;
2pub use synaptic_core::{ChatModel, ChatRequest, ChatResponse, ChatStream, Message, SynapticError};
3use synaptic_models::ProviderBackend;
4use synaptic_openai::{OpenAiChatModel, OpenAiConfig};
5pub use synaptic_openai::{OpenAiEmbeddings, OpenAiEmbeddingsConfig};
6
/// Identifiers for Mistral chat models, with [`MistralModel::Custom`] as an
/// escape hatch for model names not covered by the named variants.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MistralModel {
    MistralLargeLatest,
    MistralSmallLatest,
    OpenMistralNemo,
    CodestralLatest,
    /// Any other model name, passed through verbatim.
    Custom(String),
}

impl MistralModel {
    /// Returns the wire-format model identifier sent to the Mistral API.
    pub fn as_str(&self) -> &str {
        match self {
            Self::MistralLargeLatest => "mistral-large-latest",
            Self::MistralSmallLatest => "mistral-small-latest",
            Self::OpenMistralNemo => "open-mistral-nemo",
            Self::CodestralLatest => "codestral-latest",
            Self::Custom(name) => name.as_str(),
        }
    }
}

impl std::fmt::Display for MistralModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
31
/// Configuration for a Mistral chat model.
///
/// Convertible into an [`OpenAiConfig`] aimed at Mistral's OpenAI-compatible
/// endpoint; `None` options are simply not forwarded during that conversion.
#[derive(Debug, Clone)]
pub struct MistralConfig {
    /// Mistral API key used to authenticate requests.
    pub api_key: String,
    /// Model identifier, e.g. "mistral-large-latest".
    pub model: String,
    /// Maximum number of tokens to generate; omitted from requests when `None`.
    pub max_tokens: Option<u32>,
    /// Sampling temperature; omitted when `None`.
    pub temperature: Option<f64>,
    /// Nucleus-sampling (top-p) value; omitted when `None`.
    pub top_p: Option<f64>,
    /// Stop sequences that end generation early; omitted when `None`.
    pub stop: Option<Vec<String>>,
    /// Seed for reproducible sampling; omitted when `None`.
    pub seed: Option<u64>,
}
42impl MistralConfig {
43    pub fn new(api_key: impl Into<String>, model: MistralModel) -> Self {
44        Self {
45            api_key: api_key.into(),
46            model: model.to_string(),
47            max_tokens: None,
48            temperature: None,
49            top_p: None,
50            stop: None,
51            seed: None,
52        }
53    }
54    pub fn new_custom(api_key: impl Into<String>, model: impl Into<String>) -> Self {
55        Self {
56            api_key: api_key.into(),
57            model: model.into(),
58            max_tokens: None,
59            temperature: None,
60            top_p: None,
61            stop: None,
62            seed: None,
63        }
64    }
65    pub fn with_max_tokens(mut self, v: u32) -> Self {
66        self.max_tokens = Some(v);
67        self
68    }
69    pub fn with_temperature(mut self, v: f64) -> Self {
70        self.temperature = Some(v);
71        self
72    }
73    pub fn with_top_p(mut self, v: f64) -> Self {
74        self.top_p = Some(v);
75        self
76    }
77    pub fn with_stop(mut self, v: Vec<String>) -> Self {
78        self.stop = Some(v);
79        self
80    }
81    pub fn with_seed(mut self, v: u64) -> Self {
82        self.seed = Some(v);
83        self
84    }
85}
86impl From<MistralConfig> for OpenAiConfig {
87    fn from(c: MistralConfig) -> Self {
88        let mut cfg =
89            OpenAiConfig::new(c.api_key, c.model).with_base_url("https://api.mistral.ai/v1");
90        if let Some(v) = c.max_tokens {
91            cfg = cfg.with_max_tokens(v);
92        }
93        if let Some(v) = c.temperature {
94            cfg = cfg.with_temperature(v);
95        }
96        if let Some(v) = c.top_p {
97            cfg = cfg.with_top_p(v);
98        }
99        if let Some(v) = c.stop {
100            cfg = cfg.with_stop(v);
101        }
102        if let Some(v) = c.seed {
103            cfg = cfg.with_seed(v);
104        }
105        cfg
106    }
107}
108
/// A chat model for Mistral, implemented as a thin wrapper around
/// [`OpenAiChatModel`] configured for Mistral's OpenAI-compatible endpoint.
pub struct MistralChatModel {
    // Underlying OpenAI-protocol client that performs the actual requests.
    inner: OpenAiChatModel,
}
112
113impl MistralChatModel {
114    pub fn new(config: MistralConfig, backend: Arc<dyn ProviderBackend>) -> Self {
115        Self {
116            inner: OpenAiChatModel::new(config.into(), backend),
117        }
118    }
119}
120
#[async_trait::async_trait]
impl ChatModel for MistralChatModel {
    /// Sends a single (non-streaming) chat request by delegating to the inner
    /// OpenAI-protocol client.
    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, SynapticError> {
        self.inner.chat(request).await
    }
    /// Starts a streaming chat request by delegating to the inner client.
    fn stream_chat(&self, request: ChatRequest) -> ChatStream<'_> {
        self.inner.stream_chat(request)
    }
}
130
131pub fn mistral_embeddings(
132    api_key: impl Into<String>,
133    model: impl Into<String>,
134    backend: Arc<dyn ProviderBackend>,
135) -> OpenAiEmbeddings {
136    let config = OpenAiEmbeddingsConfig::new(api_key)
137        .with_model(model)
138        .with_base_url("https://api.mistral.ai/v1");
139    OpenAiEmbeddings::new(config, backend)
140}