1use std::sync::Arc;
2pub use synaptic_core::{ChatModel, ChatRequest, ChatResponse, ChatStream, Message, SynapticError};
3use synaptic_models::ProviderBackend;
4use synaptic_openai::{OpenAiChatModel, OpenAiConfig};
5pub use synaptic_openai::{OpenAiEmbeddings, OpenAiEmbeddingsConfig};
6
/// Mistral model identifiers: a few predefined names plus a free-form
/// [`MistralModel::Custom`] escape hatch for any other model string.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum MistralModel {
    /// `mistral-large-latest`
    MistralLargeLatest,
    /// `mistral-small-latest`
    MistralSmallLatest,
    /// `open-mistral-nemo`
    OpenMistralNemo,
    /// `codestral-latest`
    CodestralLatest,
    /// Any other model identifier; the contained string is used verbatim.
    Custom(String),
}

impl MistralModel {
    /// Returns the model identifier string for this variant.
    ///
    /// For [`MistralModel::Custom`] this borrows the stored string unchanged.
    #[must_use]
    pub fn as_str(&self) -> &str {
        match self {
            Self::MistralLargeLatest => "mistral-large-latest",
            Self::MistralSmallLatest => "mistral-small-latest",
            Self::OpenMistralNemo => "open-mistral-nemo",
            Self::CodestralLatest => "codestral-latest",
            Self::Custom(name) => name.as_str(),
        }
    }
}

impl std::fmt::Display for MistralModel {
    /// Writes the identifier returned by [`MistralModel::as_str`].
    ///
    /// `Formatter::pad` (rather than `write!(f, "{}", ..)`) is used so that
    /// width, fill/alignment and precision flags such as `{:>20}` are honoured
    /// exactly as they are for a plain `&str`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.as_str())
    }
}
31
/// Settings used to build a Mistral chat model.
///
/// Every `Option` field left as `None` is simply not forwarded when the
/// config is converted into an `OpenAiConfig`.
///
/// `Debug` is implemented by hand so that `api_key` is never printed.
#[derive(Clone)]
pub struct MistralConfig {
    /// API key passed through to the underlying OpenAI-compatible client.
    pub api_key: String,
    /// Model identifier string (e.g. the output of `MistralModel::as_str`).
    pub model: String,
    /// Optional cap on generated tokens.
    pub max_tokens: Option<u32>,
    /// Optional sampling temperature.
    pub temperature: Option<f64>,
    /// Optional nucleus-sampling `top_p`.
    pub top_p: Option<f64>,
    /// Optional stop sequences.
    pub stop: Option<Vec<String>>,
    /// Optional sampling seed.
    pub seed: Option<u64>,
}

impl std::fmt::Debug for MistralConfig {
    /// Formats every field except `api_key`, which is replaced by a
    /// placeholder so credentials do not leak into logs or panic messages.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MistralConfig")
            .field("api_key", &"[REDACTED]")
            .field("model", &self.model)
            .field("max_tokens", &self.max_tokens)
            .field("temperature", &self.temperature)
            .field("top_p", &self.top_p)
            .field("stop", &self.stop)
            .field("seed", &self.seed)
            .finish()
    }
}
42impl MistralConfig {
43 pub fn new(api_key: impl Into<String>, model: MistralModel) -> Self {
44 Self {
45 api_key: api_key.into(),
46 model: model.to_string(),
47 max_tokens: None,
48 temperature: None,
49 top_p: None,
50 stop: None,
51 seed: None,
52 }
53 }
54 pub fn new_custom(api_key: impl Into<String>, model: impl Into<String>) -> Self {
55 Self {
56 api_key: api_key.into(),
57 model: model.into(),
58 max_tokens: None,
59 temperature: None,
60 top_p: None,
61 stop: None,
62 seed: None,
63 }
64 }
65 pub fn with_max_tokens(mut self, v: u32) -> Self {
66 self.max_tokens = Some(v);
67 self
68 }
69 pub fn with_temperature(mut self, v: f64) -> Self {
70 self.temperature = Some(v);
71 self
72 }
73 pub fn with_top_p(mut self, v: f64) -> Self {
74 self.top_p = Some(v);
75 self
76 }
77 pub fn with_stop(mut self, v: Vec<String>) -> Self {
78 self.stop = Some(v);
79 self
80 }
81 pub fn with_seed(mut self, v: u64) -> Self {
82 self.seed = Some(v);
83 self
84 }
85}
86impl From<MistralConfig> for OpenAiConfig {
87 fn from(c: MistralConfig) -> Self {
88 let mut cfg =
89 OpenAiConfig::new(c.api_key, c.model).with_base_url("https://api.mistral.ai/v1");
90 if let Some(v) = c.max_tokens {
91 cfg = cfg.with_max_tokens(v);
92 }
93 if let Some(v) = c.temperature {
94 cfg = cfg.with_temperature(v);
95 }
96 if let Some(v) = c.top_p {
97 cfg = cfg.with_top_p(v);
98 }
99 if let Some(v) = c.stop {
100 cfg = cfg.with_stop(v);
101 }
102 if let Some(v) = c.seed {
103 cfg = cfg.with_seed(v);
104 }
105 cfg
106 }
107}
108
109pub struct MistralChatModel {
110 inner: OpenAiChatModel,
111}
112
113impl MistralChatModel {
114 pub fn new(config: MistralConfig, backend: Arc<dyn ProviderBackend>) -> Self {
115 Self {
116 inner: OpenAiChatModel::new(config.into(), backend),
117 }
118 }
119}
120
121#[async_trait::async_trait]
122impl ChatModel for MistralChatModel {
123 async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, SynapticError> {
124 self.inner.chat(request).await
125 }
126 fn stream_chat(&self, request: ChatRequest) -> ChatStream<'_> {
127 self.inner.stream_chat(request)
128 }
129}
130
131pub fn mistral_embeddings(
132 api_key: impl Into<String>,
133 model: impl Into<String>,
134 backend: Arc<dyn ProviderBackend>,
135) -> OpenAiEmbeddings {
136 let config = OpenAiEmbeddingsConfig::new(api_key)
137 .with_model(model)
138 .with_base_url("https://api.mistral.ai/v1");
139 OpenAiEmbeddings::new(config, backend)
140}