mistralai_client/v1/
chat.rs

1use serde::{Deserialize, Serialize};
2
3use crate::v1::{common, constants, tool};
4
5// -----------------------------------------------------------------------------
6// Definitions
7
8#[derive(Clone, Debug, Deserialize, Serialize)]
9pub struct ChatMessage {
10    pub role: ChatMessageRole,
11    pub content: String,
12    #[serde(skip_serializing_if = "Option::is_none")]
13    pub tool_calls: Option<Vec<tool::ToolCall>>,
14}
15impl ChatMessage {
16    pub fn new_assistant_message(content: &str, tool_calls: Option<Vec<tool::ToolCall>>) -> Self {
17        Self {
18            role: ChatMessageRole::Assistant,
19            content: content.to_string(),
20            tool_calls,
21        }
22    }
23
24    pub fn new_user_message(content: &str) -> Self {
25        Self {
26            role: ChatMessageRole::User,
27            content: content.to_string(),
28            tool_calls: None,
29        }
30    }
31}
32
33/// See the [Mistral AI API documentation](https://docs.mistral.ai/capabilities/completion/#chat-messages) for more information.
34#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
35pub enum ChatMessageRole {
36    #[serde(rename = "system")]
37    System,
38    #[serde(rename = "assistant")]
39    Assistant,
40    #[serde(rename = "user")]
41    User,
42    #[serde(rename = "tool")]
43    Tool,
44}
45
46/// The format that the model must output.
47///
48/// See the [API documentation](https://docs.mistral.ai/api/#operation/createChatCompletion) for more information.
49#[derive(Clone, Debug, Serialize, Deserialize)]
50pub struct ResponseFormat {
51    #[serde(rename = "type")]
52    pub type_: String,
53}
54impl ResponseFormat {
55    pub fn json_object() -> Self {
56        Self {
57            type_: "json_object".to_string(),
58        }
59    }
60}
61
62// -----------------------------------------------------------------------------
63// Request
64
65/// The parameters for the chat request.
66///
67/// See the [API documentation](https://docs.mistral.ai/api/#operation/createChatCompletion) for more information.
68#[derive(Clone, Debug)]
69pub struct ChatParams {
70    /// The maximum number of tokens to generate in the completion.
71    ///
72    /// Defaults to `None`.
73    pub max_tokens: Option<u32>,
74    /// The seed to use for random sampling. If set, different calls will generate deterministic results.
75    ///
76    /// Defaults to `None`.
77    pub random_seed: Option<u32>,
78    /// The format that the model must output.
79    ///
80    /// Defaults to `None`.
81    pub response_format: Option<ResponseFormat>,
82    /// Whether to inject a safety prompt before all conversations.
83    ///
84    /// Defaults to `false`.
85    pub safe_prompt: bool,
86    /// What sampling temperature to use, between `Some(0.0)` and `Some(1.0)`.
87    ///
88    /// Defaults to `0.7`.
89    pub temperature: f32,
90    /// Specifies if/how functions are called.
91    ///
92    /// Defaults to `None`.
93    pub tool_choice: Option<tool::ToolChoice>,
94    /// A list of available tools for the model.
95    ///
96    /// Defaults to `None`.
97    pub tools: Option<Vec<tool::Tool>>,
98    /// Nucleus sampling, where the model considers the results of the tokens with `top_p` probability mass.
99    ///
100    /// Defaults to `1.0`.
101    pub top_p: f32,
102}
103impl Default for ChatParams {
104    fn default() -> Self {
105        Self {
106            max_tokens: None,
107            random_seed: None,
108            safe_prompt: false,
109            response_format: None,
110            temperature: 0.7,
111            tool_choice: None,
112            tools: None,
113            top_p: 1.0,
114        }
115    }
116}
117impl ChatParams {
118    pub fn json_default() -> Self {
119        Self {
120            max_tokens: None,
121            random_seed: None,
122            safe_prompt: false,
123            response_format: None,
124            temperature: 0.7,
125            tool_choice: None,
126            tools: None,
127            top_p: 1.0,
128        }
129    }
130}
131
132#[derive(Debug, Serialize, Deserialize)]
133pub struct ChatRequest {
134    pub messages: Vec<ChatMessage>,
135    pub model: constants::Model,
136
137    #[serde(skip_serializing_if = "Option::is_none")]
138    pub max_tokens: Option<u32>,
139    #[serde(skip_serializing_if = "Option::is_none")]
140    pub random_seed: Option<u32>,
141    #[serde(skip_serializing_if = "Option::is_none")]
142    pub response_format: Option<ResponseFormat>,
143    pub safe_prompt: bool,
144    pub stream: bool,
145    pub temperature: f32,
146    #[serde(skip_serializing_if = "Option::is_none")]
147    pub tool_choice: Option<tool::ToolChoice>,
148    #[serde(skip_serializing_if = "Option::is_none")]
149    pub tools: Option<Vec<tool::Tool>>,
150    pub top_p: f32,
151}
152impl ChatRequest {
153    pub fn new(
154        model: constants::Model,
155        messages: Vec<ChatMessage>,
156        stream: bool,
157        options: Option<ChatParams>,
158    ) -> Self {
159        let ChatParams {
160            max_tokens,
161            random_seed,
162            safe_prompt,
163            temperature,
164            tool_choice,
165            tools,
166            top_p,
167            response_format,
168        } = options.unwrap_or_default();
169
170        Self {
171            messages,
172            model,
173
174            max_tokens,
175            random_seed,
176            safe_prompt,
177            stream,
178            temperature,
179            tool_choice,
180            tools,
181            top_p,
182            response_format,
183        }
184    }
185}
186
187// -----------------------------------------------------------------------------
188// Response
189
190#[derive(Clone, Debug, Deserialize, Serialize)]
191pub struct ChatResponse {
192    pub id: String,
193    pub object: String,
194    /// Unix timestamp (in seconds).
195    pub created: u32,
196    pub model: constants::Model,
197    pub choices: Vec<ChatResponseChoice>,
198    pub usage: common::ResponseUsage,
199}
200
201#[derive(Clone, Debug, Deserialize, Serialize)]
202pub struct ChatResponseChoice {
203    pub index: u32,
204    pub message: ChatMessage,
205    pub finish_reason: ChatResponseChoiceFinishReason,
206    // TODO Check this prop (seen in API responses but undocumented).
207    // pub logprobs: ???
208}
209
210#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
211pub enum ChatResponseChoiceFinishReason {
212    #[serde(rename = "stop")]
213    Stop,
214    #[serde(rename = "tool_calls")]
215    ToolCalls,
216}