1use async_stream::stream;
2use serde::{Deserialize, Serialize};
3use serde_json::{Value, json};
4use std::{convert::Infallible, str::FromStr};
5
6use super::client::{Client, Usage};
7use crate::streaming::{RawStreamingChoice, StreamingCompletionResponse};
8use crate::{
9 OneOrMany,
10 completion::{self, CompletionError, CompletionRequest},
11 json_utils, message,
12 providers::mistral::client::ApiResponse,
13};
14
/// Model identifier `codestral-latest`.
pub const CODESTRAL: &str = "codestral-latest";
/// Model identifier `mistral-large-latest`.
pub const MISTRAL_LARGE: &str = "mistral-large-latest";
/// Model identifier `pixtral-large-latest`.
pub const PIXTRAL_LARGE: &str = "pixtral-large-latest";
/// Model identifier `mistral-saba-latest`.
pub const MISTRAL_SABA: &str = "mistral-saba-latest";
/// Model identifier `ministral-3b-latest`.
pub const MINISTRAL_3B: &str = "ministral-3b-latest";
/// Model identifier `ministral-8b-latest`.
pub const MINISTRAL_8B: &str = "ministral-8b-latest";

/// Model identifier `mistral-small-latest`.
pub const MISTRAL_SMALL: &str = "mistral-small-latest";
/// Model identifier `pixtral-12b-2409`.
pub const PIXTRAL_SMALL: &str = "pixtral-12b-2409";
/// Model identifier `open-mistral-nemo`.
pub const MISTRAL_NEMO: &str = "open-mistral-nemo";
/// Model identifier `open-codestral-mamba`.
pub const CODESTRAL_MAMBA: &str = "open-codestral-mamba";
27
/// Text-only assistant content.
///
/// NOTE(review): `#[serde(tag = "type")]` on a struct adds a `type` field to
/// the serialized form; presumably it carries the struct name. Confirm this
/// matches the wire format the API expects.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub struct AssistantContent {
    /// The text payload.
    text: String,
}
37
/// User content, internally tagged by a `type` field whose value is the
/// lowercase variant name.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    /// Plain text; per the serde attributes above it is tagged `"type": "text"`.
    Text { text: String },
}
43
/// One element of the `choices` list in a [`CompletionResponse`].
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Choice {
    /// Position of this choice as reported in the response.
    pub index: usize,
    /// The message carried by this choice.
    pub message: Message,
    /// Log-probability data, kept as untyped JSON; `None` when absent.
    pub logprobs: Option<serde_json::Value>,
    /// Finish reason string as reported in the response (e.g. `"stop"`).
    pub finish_reason: String,
}
51
/// A chat message, internally tagged by a lowercase `role` field
/// (`user`, `assistant` or `system`).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// Message with role `user`.
    User {
        content: String,
    },
    /// Message with role `assistant`.
    Assistant {
        content: String,
        // Absent on input -> empty vec; omitted from output when empty.
        // NOTE(review): `json_utils::null_or_vec` presumably also maps an
        // explicit `null` to an empty vec — confirm in `json_utils`.
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<ToolCall>,
        // Falls back to `false` when the field is absent on input.
        #[serde(default)]
        prefix: bool,
    },
    /// Message with role `system`.
    System {
        content: String,
    },
}
73
74impl Message {
75 pub fn user(content: String) -> Self {
76 Message::User { content }
77 }
78
79 pub fn assistant(content: String, tool_calls: Vec<ToolCall>, prefix: bool) -> Self {
80 Message::Assistant {
81 content,
82 tool_calls,
83 prefix,
84 }
85 }
86
87 pub fn system(content: String) -> Self {
88 Message::System { content }
89 }
90}
91
92impl TryFrom<message::Message> for Vec<Message> {
93 type Error = message::MessageError;
94
95 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
96 match message {
97 message::Message::User { content } => {
98 let (_, other_content): (Vec<_>, Vec<_>) = content
99 .into_iter()
100 .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
101
102 let messages = other_content
103 .into_iter()
104 .filter_map(|content| match content {
105 message::UserContent::Text(message::Text { text }) => {
106 Some(Message::User { content: text })
107 }
108 _ => None,
109 })
110 .collect::<Vec<_>>();
111
112 Ok(messages)
113 }
114 message::Message::Assistant { content, .. } => {
115 let (text_content, tool_calls) = content.into_iter().fold(
116 (Vec::new(), Vec::new()),
117 |(mut texts, mut tools), content| {
118 match content {
119 message::AssistantContent::Text(text) => texts.push(text),
120 message::AssistantContent::ToolCall(tool_call) => tools.push(tool_call),
121 message::AssistantContent::Reasoning(_) => {
122 unimplemented!("Reasoning content is not currently supported on Mistral via Rig");
123 }
124 }
125 (texts, tools)
126 },
127 );
128
129 Ok(vec![Message::Assistant {
130 content: text_content
131 .into_iter()
132 .next()
133 .map(|content| content.text)
134 .unwrap_or_default(),
135 tool_calls: tool_calls
136 .into_iter()
137 .map(|tool_call| tool_call.into())
138 .collect::<Vec<_>>(),
139 prefix: false,
140 }])
141 }
142 }
143 }
144}
145
/// A tool invocation attached to an assistant message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    /// Identifier of this call.
    pub id: String,
    // Falls back to `ToolType::default()` when absent on input.
    #[serde(default)]
    pub r#type: ToolType,
    /// Name and arguments of the function being called.
    pub function: Function,
}
153
154impl From<message::ToolCall> for ToolCall {
155 fn from(tool_call: message::ToolCall) -> Self {
156 Self {
157 id: tool_call.id,
158 r#type: ToolType::default(),
159 function: Function {
160 name: tool_call.function.name,
161 arguments: tool_call.function.arguments,
162 },
163 }
164 }
165}
166
/// The function part of a [`ToolCall`]: its name and arguments.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    /// Name of the function to call.
    pub name: String,
    // NOTE(review): `json_utils::stringified_json` presumably encodes the
    // arguments as a JSON string on the wire (the test payload in this file
    // uses `"arguments": "{ }"`) — confirm in `json_utils`.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
173
/// Kind of tool referenced by a [`ToolCall`]; serialized in lowercase.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    /// A function tool (`"function"`); this is the default.
    #[default]
    Function,
}
180
/// Request-side wrapper pairing a generic tool definition with a `type` string.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolDefinition {
    /// Tool kind; the `From<completion::ToolDefinition>` impl sets it to `"function"`.
    pub r#type: String,
    /// The wrapped generic tool definition.
    pub function: completion::ToolDefinition,
}
186
187impl From<completion::ToolDefinition> for ToolDefinition {
188 fn from(tool: completion::ToolDefinition) -> Self {
189 Self {
190 r#type: "function".into(),
191 function: tool,
192 }
193 }
194}
195
/// Text content of a tool result.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
    // Falls back to `ToolResultContentType::default()` when absent on input.
    #[serde(default)]
    r#type: ToolResultContentType,
    // The textual payload.
    text: String,
}
202
/// Kind of content held by a [`ToolResultContent`]; serialized in lowercase.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolResultContentType {
    /// Textual content (`"text"`); this is the default.
    #[default]
    Text,
}
209
210impl From<String> for ToolResultContent {
211 fn from(s: String) -> Self {
212 ToolResultContent {
213 r#type: ToolResultContentType::default(),
214 text: s,
215 }
216 }
217}
218
219impl From<String> for UserContent {
220 fn from(s: String) -> Self {
221 UserContent::Text { text: s }
222 }
223}
224
225impl FromStr for UserContent {
226 type Err = Infallible;
227
228 fn from_str(s: &str) -> Result<Self, Self::Err> {
229 Ok(UserContent::Text {
230 text: s.to_string(),
231 })
232 }
233}
234
235impl From<String> for AssistantContent {
236 fn from(s: String) -> Self {
237 AssistantContent { text: s }
238 }
239}
240
241impl FromStr for AssistantContent {
242 type Err = Infallible;
243
244 fn from_str(s: &str) -> Result<Self, Self::Err> {
245 Ok(AssistantContent {
246 text: s.to_string(),
247 })
248 }
249}
250
/// Handle for running completions against one model through a [`Client`].
#[derive(Clone)]
pub struct CompletionModel {
    // Client used to send the HTTP requests.
    pub(crate) client: Client,
    /// Model identifier sent as the `model` field of each request.
    pub model: String,
}
256
257impl CompletionModel {
258 pub fn new(client: Client, model: &str) -> Self {
259 Self {
260 client,
261 model: model.to_string(),
262 }
263 }
264
265 pub(crate) fn create_completion_request(
266 &self,
267 completion_request: CompletionRequest,
268 ) -> Result<Value, CompletionError> {
269 let mut partial_history = vec![];
270 if let Some(docs) = completion_request.normalized_documents() {
271 partial_history.push(docs);
272 }
273
274 partial_history.extend(completion_request.chat_history);
275
276 let mut full_history: Vec<Message> = match &completion_request.preamble {
277 Some(preamble) => vec![Message::system(preamble.clone())],
278 None => vec![],
279 };
280
281 full_history.extend(
282 partial_history
283 .into_iter()
284 .map(message::Message::try_into)
285 .collect::<Result<Vec<Vec<Message>>, _>>()?
286 .into_iter()
287 .flatten()
288 .collect::<Vec<_>>(),
289 );
290
291 let request = if completion_request.tools.is_empty() {
292 json!({
293 "model": self.model,
294 "messages": full_history,
295
296 })
297 } else {
298 json!({
299 "model": self.model,
300 "messages": full_history,
301 "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
302 "tool_choice": "auto",
303 })
304 };
305
306 let request = if let Some(temperature) = completion_request.temperature {
307 json_utils::merge(
308 request,
309 json!({
310 "temperature": temperature,
311 }),
312 )
313 } else {
314 request
315 };
316
317 let request = if let Some(params) = completion_request.additional_params {
318 json_utils::merge(request, params)
319 } else {
320 request
321 };
322
323 Ok(request)
324 }
325}
326
/// Deserialized body of a chat-completion response.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct CompletionResponse {
    /// Identifier of the completion.
    pub id: String,
    /// Object type string (e.g. `"chat.completion"`).
    pub object: String,
    /// Creation timestamp as reported by the API.
    pub created: u64,
    /// Name of the model that produced the response.
    pub model: String,
    /// System fingerprint; `None` when absent.
    pub system_fingerprint: Option<String>,
    /// Generated choices.
    pub choices: Vec<Choice>,
    /// Token usage; `None` when absent.
    pub usage: Option<Usage>,
}
337
338impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
339 type Error = CompletionError;
340
341 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
342 let choice = response.choices.first().ok_or_else(|| {
343 CompletionError::ResponseError("Response contained no choices".to_owned())
344 })?;
345 let content = match &choice.message {
346 Message::Assistant {
347 content,
348 tool_calls,
349 ..
350 } => {
351 let mut content = if content.is_empty() {
352 vec![]
353 } else {
354 vec![completion::AssistantContent::text(content.clone())]
355 };
356
357 content.extend(
358 tool_calls
359 .iter()
360 .map(|call| {
361 completion::AssistantContent::tool_call(
362 &call.id,
363 &call.function.name,
364 call.function.arguments.clone(),
365 )
366 })
367 .collect::<Vec<_>>(),
368 );
369 Ok(content)
370 }
371 _ => Err(CompletionError::ResponseError(
372 "Response did not contain a valid message or tool call".into(),
373 )),
374 }?;
375
376 let choice = OneOrMany::many(content).map_err(|_| {
377 CompletionError::ResponseError(
378 "Response contained no message or tool call (empty)".to_owned(),
379 )
380 })?;
381
382 let usage = response
383 .usage
384 .as_ref()
385 .map(|usage| completion::Usage {
386 input_tokens: usage.prompt_tokens as u64,
387 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
388 total_tokens: usage.total_tokens as u64,
389 })
390 .unwrap_or_default();
391
392 Ok(completion::CompletionResponse {
393 choice,
394 usage,
395 raw_response: response,
396 })
397 }
398}
399
400impl completion::CompletionModel for CompletionModel {
401 type Response = CompletionResponse;
402 type StreamingResponse = CompletionResponse;
403
404 #[cfg_attr(feature = "worker", worker::send)]
405 async fn completion(
406 &self,
407 completion_request: CompletionRequest,
408 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
409 let request = self.create_completion_request(completion_request)?;
410
411 let response = self
412 .client
413 .post("v1/chat/completions")
414 .json(&request)
415 .send()
416 .await?;
417
418 if response.status().is_success() {
419 let text = response.text().await?;
420 match serde_json::from_str::<ApiResponse<CompletionResponse>>(&text)? {
421 ApiResponse::Ok(response) => {
422 tracing::debug!(target: "rig",
423 "Mistral completion token usage: {:?}",
424 response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
425 );
426 response.try_into()
427 }
428 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
429 }
430 } else {
431 Err(CompletionError::ProviderError(response.text().await?))
432 }
433 }
434
435 #[cfg_attr(feature = "worker", worker::send)]
436 async fn stream(
437 &self,
438 request: CompletionRequest,
439 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
440 let resp = self.completion(request).await?;
441
442 let stream = Box::pin(stream! {
443 for c in resp.choice.clone() {
444 match c {
445 message::AssistantContent::Text(t) => {
446 yield Ok(RawStreamingChoice::Message(t.text.clone()))
447 }
448 message::AssistantContent::ToolCall(tc) => {
449 yield Ok(RawStreamingChoice::ToolCall {
450 id: tc.id.clone(),
451 name: tc.function.name.clone(),
452 arguments: tc.function.arguments.clone(),
453 call_id: None
454 })
455 }
456 message::AssistantContent::Reasoning(_) => {
457 unimplemented!("Reasoning is not supported on Mistral via Rig")
458 }
459 }
460 }
461
462 yield Ok(RawStreamingChoice::FinalResponse(resp.raw_response.clone()));
463 });
464
465 Ok(StreamingCompletionResponse::stream(stream))
466 }
467}
468
#[cfg(test)]
mod tests {
    use super::*;

    /// A complete chat-completion payload with one assistant choice and one
    /// tool call deserializes into `CompletionResponse` with the expected
    /// top-level fields and usage counts.
    #[test]
    fn test_response_deserialization() {
        let json_data = r#"

        {
            "id": "cmpl-e5cc70bb28c444948073e77776eb30ef",
            "object": "chat.completion",
            "model": "mistral-small-latest",
            "usage": {
                "prompt_tokens": 16,
                "completion_tokens": 34,
                "total_tokens": 50
            },
            "created": 1702256327,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "content": "string",
                        "tool_calls": [
                            {
                                "id": "null",
                                "type": "function",
                                "function": {
                                    "name": "string",
                                    "arguments": "{ }"
                                },
                                "index": 0
                            }
                        ],
                        "prefix": false,
                        "role": "assistant"
                    },
                    "finish_reason": "stop"
                }
            ]
        }
        "#;

        let parsed: CompletionResponse = serde_json::from_str(json_data).unwrap();

        // Top-level fields.
        assert_eq!(parsed.model, MISTRAL_SMALL);
        assert_eq!(parsed.id, "cmpl-e5cc70bb28c444948073e77776eb30ef");

        // Token usage.
        let token_usage = parsed.usage.unwrap();
        assert_eq!(token_usage.prompt_tokens, 16);
        assert_eq!(token_usage.completion_tokens, 34);
        assert_eq!(token_usage.total_tokens, 50);

        assert_eq!(parsed.object, "chat.completion".to_string());
        assert_eq!(parsed.created, 1702256327);
        assert_eq!(parsed.choices.len(), 1);
    }
}