Skip to main content

agent_io/llm/mistral/
mod.rs

1//! Mistral Chat Model implementation
2
3mod builder;
4
5use async_trait::async_trait;
6
7use crate::llm::{
8    BaseChatModel, ChatCompletion, ChatStream, LlmError, Message, ToolChoice, ToolDefinition,
9};
10
11pub use builder::ChatMistralBuilder;
12
13/// Mistral Chat Model
14///
15/// # Example
16/// ```ignore
17/// use agent_io::llm::ChatMistral;
18///
19/// let llm = ChatMistral::new("mistral-large-latest")?;
20/// ```
21pub struct ChatMistral {
22    pub(super) inner: crate::llm::openai_compatible::ChatOpenAICompatible,
23}
24
25impl ChatMistral {
26    /// Create a new Mistral chat model
27    pub fn new(model: impl Into<String>) -> Result<Self, LlmError> {
28        Self::builder().model(model).build()
29    }
30
31    /// Create a builder for configuration
32    pub fn builder() -> ChatMistralBuilder {
33        ChatMistralBuilder::default()
34    }
35}
36
37#[async_trait]
38impl BaseChatModel for ChatMistral {
39    fn model(&self) -> &str {
40        self.inner.model()
41    }
42
43    fn provider(&self) -> &str {
44        "mistral"
45    }
46
47    fn context_window(&self) -> Option<u64> {
48        let model = self.model().to_lowercase();
49        if model.contains("large") || model.contains("codestral") {
50            Some(128_000)
51        } else {
52            Some(32_000)
53        }
54    }
55
56    async fn invoke(
57        &self,
58        messages: Vec<Message>,
59        tools: Option<Vec<ToolDefinition>>,
60        tool_choice: Option<ToolChoice>,
61    ) -> Result<ChatCompletion, LlmError> {
62        self.inner.invoke(messages, tools, tool_choice).await
63    }
64
65    async fn invoke_stream(
66        &self,
67        messages: Vec<Message>,
68        tools: Option<Vec<ToolDefinition>>,
69        tool_choice: Option<ToolChoice>,
70    ) -> Result<ChatStream, LlmError> {
71        self.inner.invoke_stream(messages, tools, tool_choice).await
72    }
73
74    fn supports_vision(&self) -> bool {
75        let model = self.model().to_lowercase();
76        model.contains("pixtral")
77    }
78}