1use async_trait::async_trait;
4
5use crate::llm::openai_compatible::ChatOpenAICompatible;
6use crate::llm::{
7 BaseChatModel, ChatCompletion, ChatStream, LlmError, Message, ToolChoice, ToolDefinition,
8};
9
10const MISTRAL_URL: &str = "https://api.mistral.ai/v1";
11
12pub struct ChatMistral {
21 inner: ChatOpenAICompatible,
22}
23
24impl ChatMistral {
25 pub fn new(model: impl Into<String>) -> Result<Self, LlmError> {
27 Self::builder().model(model).build()
28 }
29
30 pub fn builder() -> ChatMistralBuilder {
32 ChatMistralBuilder::default()
33 }
34}
35
/// Builder for [`ChatMistral`].
///
/// Every setting starts out unset; [`ChatMistralBuilder::build`] decides which are
/// required and which fall back to environment variables or defaults.
#[derive(Default)]
pub struct ChatMistralBuilder {
    model: Option<String>,
    api_key: Option<String>,
    base_url: Option<String>,
    temperature: Option<f32>,
    max_tokens: Option<u64>,
}

// Written by hand instead of derived so the API key never leaks into logs,
// `dbg!` output or panic messages; only its presence is shown.
impl std::fmt::Debug for ChatMistralBuilder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ChatMistralBuilder")
            .field("model", &self.model)
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("base_url", &self.base_url)
            .field("temperature", &self.temperature)
            .field("max_tokens", &self.max_tokens)
            .finish()
    }
}
44
45impl ChatMistralBuilder {
46 pub fn model(mut self, model: impl Into<String>) -> Self {
47 self.model = Some(model.into());
48 self
49 }
50
51 pub fn api_key(mut self, key: impl Into<String>) -> Self {
52 self.api_key = Some(key.into());
53 self
54 }
55
56 pub fn base_url(mut self, url: impl Into<String>) -> Self {
57 self.base_url = Some(url.into());
58 self
59 }
60
61 pub fn temperature(mut self, temp: f32) -> Self {
62 self.temperature = Some(temp);
63 self
64 }
65
66 pub fn max_tokens(mut self, tokens: u64) -> Self {
67 self.max_tokens = Some(tokens);
68 self
69 }
70
71 pub fn build(self) -> Result<ChatMistral, LlmError> {
72 let model = self
73 .model
74 .ok_or_else(|| LlmError::Config("model is required".into()))?;
75
76 let api_key = self
77 .api_key
78 .or_else(|| std::env::var("MISTRAL_API_KEY").ok())
79 .ok_or_else(|| LlmError::Config("MISTRAL_API_KEY not set".into()))?;
80
81 let base_url = self
82 .base_url
83 .or_else(|| std::env::var("MISTRAL_BASE_URL").ok())
84 .unwrap_or_else(|| MISTRAL_URL.to_string());
85
86 let inner = ChatOpenAICompatible::builder()
87 .model(&model)
88 .base_url(&base_url)
89 .provider("mistral")
90 .api_key(Some(api_key))
91 .temperature(self.temperature.unwrap_or(0.2))
92 .max_completion_tokens(self.max_tokens)
93 .build()?;
94
95 Ok(ChatMistral { inner })
96 }
97}
98
99#[async_trait]
100impl BaseChatModel for ChatMistral {
101 fn model(&self) -> &str {
102 self.inner.model()
103 }
104
105 fn provider(&self) -> &str {
106 "mistral"
107 }
108
109 fn context_window(&self) -> Option<u64> {
110 let model = self.model().to_lowercase();
111 if model.contains("large") || model.contains("codestral") {
112 Some(128_000)
113 } else {
114 Some(32_000)
115 }
116 }
117
118 async fn invoke(
119 &self,
120 messages: Vec<Message>,
121 tools: Option<Vec<ToolDefinition>>,
122 tool_choice: Option<ToolChoice>,
123 ) -> Result<ChatCompletion, LlmError> {
124 self.inner.invoke(messages, tools, tool_choice).await
125 }
126
127 async fn invoke_stream(
128 &self,
129 messages: Vec<Message>,
130 tools: Option<Vec<ToolDefinition>>,
131 tool_choice: Option<ToolChoice>,
132 ) -> Result<ChatStream, LlmError> {
133 self.inner.invoke_stream(messages, tools, tool_choice).await
134 }
135
136 fn supports_vision(&self) -> bool {
137 let model = self.model().to_lowercase();
138 model.contains("pixtral")
139 }
140}