// agent_io/llm/mistral.rs
1//! Mistral Chat Model implementation
2
3use async_trait::async_trait;
4
5use crate::llm::openai_compatible::ChatOpenAICompatible;
6use crate::llm::{
7    BaseChatModel, ChatCompletion, ChatStream, LlmError, Message, ToolChoice, ToolDefinition,
8};
9
/// Default Mistral API base URL, used when neither the builder nor the
/// `MISTRAL_BASE_URL` environment variable supplies one.
const MISTRAL_URL: &str = "https://api.mistral.ai/v1";
11
/// Mistral Chat Model
///
/// A thin wrapper over [`ChatOpenAICompatible`] pointed at the Mistral
/// OpenAI-compatible endpoint; all request/stream handling is delegated.
///
/// # Example
/// ```ignore
/// use agent_io::llm::ChatMistral;
///
/// let llm = ChatMistral::new("mistral-large-latest")?;
/// ```
pub struct ChatMistral {
    // Underlying OpenAI-compatible client configured for Mistral.
    inner: ChatOpenAICompatible,
}
23
24impl ChatMistral {
25    /// Create a new Mistral chat model
26    pub fn new(model: impl Into<String>) -> Result<Self, LlmError> {
27        Self::builder().model(model).build()
28    }
29
30    /// Create a builder for configuration
31    pub fn builder() -> ChatMistralBuilder {
32        ChatMistralBuilder::default()
33    }
34}
35
/// Builder for [`ChatMistral`]; all fields are optional until `build`,
/// which requires at least a model name.
#[derive(Default)]
pub struct ChatMistralBuilder {
    // Model identifier, e.g. "mistral-large-latest" (required at build time).
    model: Option<String>,
    // Explicit API key; falls back to the MISTRAL_API_KEY env var.
    api_key: Option<String>,
    // Explicit base URL; falls back to MISTRAL_BASE_URL, then the default.
    base_url: Option<String>,
    // Sampling temperature; defaults to 0.2 when unset.
    temperature: Option<f32>,
    // Completion-token cap; passed through unchanged (None = provider default).
    max_tokens: Option<u64>,
}
44
45impl ChatMistralBuilder {
46    pub fn model(mut self, model: impl Into<String>) -> Self {
47        self.model = Some(model.into());
48        self
49    }
50
51    pub fn api_key(mut self, key: impl Into<String>) -> Self {
52        self.api_key = Some(key.into());
53        self
54    }
55
56    pub fn base_url(mut self, url: impl Into<String>) -> Self {
57        self.base_url = Some(url.into());
58        self
59    }
60
61    pub fn temperature(mut self, temp: f32) -> Self {
62        self.temperature = Some(temp);
63        self
64    }
65
66    pub fn max_tokens(mut self, tokens: u64) -> Self {
67        self.max_tokens = Some(tokens);
68        self
69    }
70
71    pub fn build(self) -> Result<ChatMistral, LlmError> {
72        let model = self
73            .model
74            .ok_or_else(|| LlmError::Config("model is required".into()))?;
75
76        let api_key = self
77            .api_key
78            .or_else(|| std::env::var("MISTRAL_API_KEY").ok())
79            .ok_or_else(|| LlmError::Config("MISTRAL_API_KEY not set".into()))?;
80
81        let base_url = self
82            .base_url
83            .or_else(|| std::env::var("MISTRAL_BASE_URL").ok())
84            .unwrap_or_else(|| MISTRAL_URL.to_string());
85
86        let inner = ChatOpenAICompatible::builder()
87            .model(&model)
88            .base_url(&base_url)
89            .provider("mistral")
90            .api_key(Some(api_key))
91            .temperature(self.temperature.unwrap_or(0.2))
92            .max_completion_tokens(self.max_tokens)
93            .build()?;
94
95        Ok(ChatMistral { inner })
96    }
97}
98
99#[async_trait]
100impl BaseChatModel for ChatMistral {
101    fn model(&self) -> &str {
102        self.inner.model()
103    }
104
105    fn provider(&self) -> &str {
106        "mistral"
107    }
108
109    fn context_window(&self) -> Option<u64> {
110        let model = self.model().to_lowercase();
111        if model.contains("large") || model.contains("codestral") {
112            Some(128_000)
113        } else {
114            Some(32_000)
115        }
116    }
117
118    async fn invoke(
119        &self,
120        messages: Vec<Message>,
121        tools: Option<Vec<ToolDefinition>>,
122        tool_choice: Option<ToolChoice>,
123    ) -> Result<ChatCompletion, LlmError> {
124        self.inner.invoke(messages, tools, tool_choice).await
125    }
126
127    async fn invoke_stream(
128        &self,
129        messages: Vec<Message>,
130        tools: Option<Vec<ToolDefinition>>,
131        tool_choice: Option<ToolChoice>,
132    ) -> Result<ChatStream, LlmError> {
133        self.inner.invoke_stream(messages, tools, tool_choice).await
134    }
135
136    fn supports_vision(&self) -> bool {
137        let model = self.model().to_lowercase();
138        model.contains("pixtral")
139    }
140}