1use async_stream::stream;
2use serde::{Deserialize, Serialize};
3use serde_json::{Value, json};
4use std::{convert::Infallible, str::FromStr};
5
6use super::client::{Client, Usage};
7use crate::completion::GetTokenUsage;
8use crate::streaming::{RawStreamingChoice, StreamingCompletionResponse};
9use crate::{
10 OneOrMany,
11 completion::{self, CompletionError, CompletionRequest},
12 json_utils, message,
13 providers::mistral::client::ApiResponse,
14};
15
// Mistral model identifiers. Each value is sent verbatim as the request's `model`
// field when passed to `CompletionModel::new`.

/// Identifier for Codestral (`"codestral-latest"`).
pub const CODESTRAL: &str = "codestral-latest";
/// Identifier for Mistral Large (`"mistral-large-latest"`).
pub const MISTRAL_LARGE: &str = "mistral-large-latest";
/// Identifier for Pixtral Large (`"pixtral-large-latest"`).
pub const PIXTRAL_LARGE: &str = "pixtral-large-latest";
/// Identifier for Mistral Saba (`"mistral-saba-latest"`).
pub const MISTRAL_SABA: &str = "mistral-saba-latest";
/// Identifier for Ministral 3B (`"ministral-3b-latest"`).
pub const MINISTRAL_3B: &str = "ministral-3b-latest";
/// Identifier for Ministral 8B (`"ministral-8b-latest"`).
pub const MINISTRAL_8B: &str = "ministral-8b-latest";

/// Identifier for Mistral Small (`"mistral-small-latest"`).
pub const MISTRAL_SMALL: &str = "mistral-small-latest";
/// Identifier for Pixtral 12B, pinned to the `2409` release (`"pixtral-12b-2409"`).
pub const PIXTRAL_SMALL: &str = "pixtral-12b-2409";
/// Identifier for Mistral NeMo (`"open-mistral-nemo"`).
pub const MISTRAL_NEMO: &str = "open-mistral-nemo";
/// Identifier for Codestral Mamba (`"open-codestral-mamba"`).
pub const CODESTRAL_MAMBA: &str = "open-codestral-mamba";
28
29#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
34#[serde(tag = "type", rename_all = "lowercase")]
35pub struct AssistantContent {
36 text: String,
37}
38
/// A user-message content part.
///
/// Internally tagged on `type` with lowercase variant names, so `Text` serializes as
/// `{"type": "text", "text": ...}`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    /// Plain text content.
    Text { text: String },
}
44
/// One entry of the `choices` array in a Mistral chat-completion response.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Choice {
    /// Index of this choice as reported by the API.
    pub index: usize,
    /// The generated message. Converting to a rig response requires this to be
    /// `Message::Assistant`; other roles are rejected there.
    pub message: Message,
    /// Log-probability data kept as raw JSON; `None` when absent or `null`.
    pub logprobs: Option<serde_json::Value>,
    /// Why generation stopped (e.g. `"stop"`), as reported by the API.
    pub finish_reason: String,
}
52
/// A chat message in Mistral's wire format.
///
/// Internally tagged on `role` with lowercase variant names, so each variant
/// serializes as `{"role": "user" | "assistant" | "system", "content": ..., ...}`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// A message authored by the user.
    User {
        content: String,
    },
    /// A message produced by the model, or replayed from chat history.
    Assistant {
        /// Text content; may be empty when the turn consists only of tool calls.
        content: String,
        /// Tool calls requested by the model. Left out of the serialized JSON when
        /// empty. `json_utils::null_or_vec` presumably maps a JSON `null` to an empty
        /// vec (confirm against `json_utils`).
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<ToolCall>,
        /// Mistral's `prefix` flag. Defaults to `false` when missing on input and is
        /// always serialized.
        #[serde(default)]
        prefix: bool,
    },
    /// A system instruction; the request preamble is sent as this variant.
    System {
        content: String,
    },
}
74
75impl Message {
76 pub fn user(content: String) -> Self {
77 Message::User { content }
78 }
79
80 pub fn assistant(content: String, tool_calls: Vec<ToolCall>, prefix: bool) -> Self {
81 Message::Assistant {
82 content,
83 tool_calls,
84 prefix,
85 }
86 }
87
88 pub fn system(content: String) -> Self {
89 Message::System { content }
90 }
91}
92
93impl TryFrom<message::Message> for Vec<Message> {
94 type Error = message::MessageError;
95
96 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
97 match message {
98 message::Message::User { content } => {
99 let (_, other_content): (Vec<_>, Vec<_>) = content
100 .into_iter()
101 .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
102
103 let messages = other_content
104 .into_iter()
105 .filter_map(|content| match content {
106 message::UserContent::Text(message::Text { text }) => {
107 Some(Message::User { content: text })
108 }
109 _ => None,
110 })
111 .collect::<Vec<_>>();
112
113 Ok(messages)
114 }
115 message::Message::Assistant { content, .. } => {
116 let (text_content, tool_calls) = content.into_iter().fold(
117 (Vec::new(), Vec::new()),
118 |(mut texts, mut tools), content| {
119 match content {
120 message::AssistantContent::Text(text) => texts.push(text),
121 message::AssistantContent::ToolCall(tool_call) => tools.push(tool_call),
122 message::AssistantContent::Reasoning(_) => {
123 unimplemented!("Reasoning content is not currently supported on Mistral via Rig");
124 }
125 }
126 (texts, tools)
127 },
128 );
129
130 Ok(vec![Message::Assistant {
131 content: text_content
132 .into_iter()
133 .next()
134 .map(|content| content.text)
135 .unwrap_or_default(),
136 tool_calls: tool_calls
137 .into_iter()
138 .map(|tool_call| tool_call.into())
139 .collect::<Vec<_>>(),
140 prefix: false,
141 }])
142 }
143 }
144 }
145}
146
/// A tool call as it appears in an assistant message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    /// Identifier of this tool call.
    pub id: String,
    /// Kind of call; defaults to `ToolType::Function` when missing from input.
    #[serde(default)]
    pub r#type: ToolType,
    /// Name and arguments of the function to invoke.
    pub function: Function,
}
154
155impl From<message::ToolCall> for ToolCall {
156 fn from(tool_call: message::ToolCall) -> Self {
157 Self {
158 id: tool_call.id,
159 r#type: ToolType::default(),
160 function: Function {
161 name: tool_call.function.name,
162 arguments: tool_call.function.arguments,
163 },
164 }
165 }
166}
167
/// The function part of a [`ToolCall`].
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    /// Name of the tool to invoke.
    pub name: String,
    /// Call arguments. On the wire this is a JSON-encoded string (e.g. `"{ }"`), converted
    /// by `json_utils::stringified_json`; in memory it is a parsed `serde_json::Value`.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
174
/// Kind of a tool call. `function` is the only kind modeled here and is the default;
/// it serializes as the lowercase string `"function"`.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
181
/// A tool advertised in the request's `tools` array: rig's tool definition wrapped
/// with a `type` field, which `From<completion::ToolDefinition>` sets to `"function"`.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolDefinition {
    pub r#type: String,
    pub function: completion::ToolDefinition,
}
187
188impl From<completion::ToolDefinition> for ToolDefinition {
189 fn from(tool: completion::ToolDefinition) -> Self {
190 Self {
191 r#type: "function".into(),
192 function: tool,
193 }
194 }
195}
196
/// A text content part for a tool result. `type` defaults to `text` when absent on input.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
    #[serde(default)]
    r#type: ToolResultContentType,
    text: String,
}
203
/// Content kind for [`ToolResultContent`]. Only `text` is modeled; it serializes as
/// the lowercase string `"text"`.
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolResultContentType {
    #[default]
    Text,
}
210
211impl From<String> for ToolResultContent {
212 fn from(s: String) -> Self {
213 ToolResultContent {
214 r#type: ToolResultContentType::default(),
215 text: s,
216 }
217 }
218}
219
220impl From<String> for UserContent {
221 fn from(s: String) -> Self {
222 UserContent::Text { text: s }
223 }
224}
225
226impl FromStr for UserContent {
227 type Err = Infallible;
228
229 fn from_str(s: &str) -> Result<Self, Self::Err> {
230 Ok(UserContent::Text {
231 text: s.to_string(),
232 })
233 }
234}
235
236impl From<String> for AssistantContent {
237 fn from(s: String) -> Self {
238 AssistantContent { text: s }
239 }
240}
241
242impl FromStr for AssistantContent {
243 type Err = Infallible;
244
245 fn from_str(s: &str) -> Result<Self, Self::Err> {
246 Ok(AssistantContent {
247 text: s.to_string(),
248 })
249 }
250}
251
/// A Mistral chat-completion model bound to an API client.
#[derive(Clone)]
pub struct CompletionModel {
    /// Client used to send requests (e.g. `POST v1/chat/completions`).
    pub(crate) client: Client,
    /// Model identifier, sent verbatim as the request's `model` field.
    pub model: String,
}
257
258impl CompletionModel {
259 pub fn new(client: Client, model: &str) -> Self {
260 Self {
261 client,
262 model: model.to_string(),
263 }
264 }
265
266 pub(crate) fn create_completion_request(
267 &self,
268 completion_request: CompletionRequest,
269 ) -> Result<Value, CompletionError> {
270 let mut partial_history = vec![];
271 if let Some(docs) = completion_request.normalized_documents() {
272 partial_history.push(docs);
273 }
274
275 partial_history.extend(completion_request.chat_history);
276
277 let mut full_history: Vec<Message> = match &completion_request.preamble {
278 Some(preamble) => vec![Message::system(preamble.clone())],
279 None => vec![],
280 };
281
282 full_history.extend(
283 partial_history
284 .into_iter()
285 .map(message::Message::try_into)
286 .collect::<Result<Vec<Vec<Message>>, _>>()?
287 .into_iter()
288 .flatten()
289 .collect::<Vec<_>>(),
290 );
291
292 let request = if completion_request.tools.is_empty() {
293 json!({
294 "model": self.model,
295 "messages": full_history,
296
297 })
298 } else {
299 json!({
300 "model": self.model,
301 "messages": full_history,
302 "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
303 "tool_choice": "auto",
304 })
305 };
306
307 let request = if let Some(temperature) = completion_request.temperature {
308 json_utils::merge(
309 request,
310 json!({
311 "temperature": temperature,
312 }),
313 )
314 } else {
315 request
316 };
317
318 let request = if let Some(params) = completion_request.additional_params {
319 json_utils::merge(request, params)
320 } else {
321 request
322 };
323
324 Ok(request)
325 }
326}
327
/// Body of a successful Mistral chat-completion response.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct CompletionResponse {
    /// Response identifier (e.g. `"cmpl-..."`).
    pub id: String,
    /// Object type, e.g. `"chat.completion"`.
    pub object: String,
    /// Creation timestamp as reported by the API (presumably Unix seconds — confirm).
    pub created: u64,
    /// Model that produced the response.
    pub model: String,
    /// System fingerprint; `None` when absent.
    pub system_fingerprint: Option<String>,
    /// Generated choices; only the first is used when converting to a rig response.
    pub choices: Vec<Choice>,
    /// Token accounting; `None` when the API omits it.
    pub usage: Option<Usage>,
}
338
339impl GetTokenUsage for CompletionResponse {
340 fn token_usage(&self) -> Option<crate::completion::Usage> {
341 let api_usage = self.usage.clone()?;
342
343 let mut usage = crate::completion::Usage::new();
344 usage.input_tokens = api_usage.prompt_tokens as u64;
345 usage.output_tokens = api_usage.completion_tokens as u64;
346 usage.total_tokens = api_usage.total_tokens as u64;
347
348 Some(usage)
349 }
350}
351
352impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
353 type Error = CompletionError;
354
355 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
356 let choice = response.choices.first().ok_or_else(|| {
357 CompletionError::ResponseError("Response contained no choices".to_owned())
358 })?;
359 let content = match &choice.message {
360 Message::Assistant {
361 content,
362 tool_calls,
363 ..
364 } => {
365 let mut content = if content.is_empty() {
366 vec![]
367 } else {
368 vec![completion::AssistantContent::text(content.clone())]
369 };
370
371 content.extend(
372 tool_calls
373 .iter()
374 .map(|call| {
375 completion::AssistantContent::tool_call(
376 &call.id,
377 &call.function.name,
378 call.function.arguments.clone(),
379 )
380 })
381 .collect::<Vec<_>>(),
382 );
383 Ok(content)
384 }
385 _ => Err(CompletionError::ResponseError(
386 "Response did not contain a valid message or tool call".into(),
387 )),
388 }?;
389
390 let choice = OneOrMany::many(content).map_err(|_| {
391 CompletionError::ResponseError(
392 "Response contained no message or tool call (empty)".to_owned(),
393 )
394 })?;
395
396 let usage = response
397 .usage
398 .as_ref()
399 .map(|usage| completion::Usage {
400 input_tokens: usage.prompt_tokens as u64,
401 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
402 total_tokens: usage.total_tokens as u64,
403 })
404 .unwrap_or_default();
405
406 Ok(completion::CompletionResponse {
407 choice,
408 usage,
409 raw_response: response,
410 })
411 }
412}
413
414impl completion::CompletionModel for CompletionModel {
415 type Response = CompletionResponse;
416 type StreamingResponse = CompletionResponse;
417
418 #[cfg_attr(feature = "worker", worker::send)]
419 async fn completion(
420 &self,
421 completion_request: CompletionRequest,
422 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
423 let request = self.create_completion_request(completion_request)?;
424
425 let response = self
426 .client
427 .post("v1/chat/completions")
428 .json(&request)
429 .send()
430 .await?;
431
432 if response.status().is_success() {
433 let text = response.text().await?;
434 match serde_json::from_str::<ApiResponse<CompletionResponse>>(&text)? {
435 ApiResponse::Ok(response) => {
436 tracing::debug!(target: "rig",
437 "Mistral completion token usage: {:?}",
438 response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
439 );
440 response.try_into()
441 }
442 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
443 }
444 } else {
445 Err(CompletionError::ProviderError(response.text().await?))
446 }
447 }
448
449 #[cfg_attr(feature = "worker", worker::send)]
450 async fn stream(
451 &self,
452 request: CompletionRequest,
453 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
454 let resp = self.completion(request).await?;
455
456 let stream = Box::pin(stream! {
457 for c in resp.choice.clone() {
458 match c {
459 message::AssistantContent::Text(t) => {
460 yield Ok(RawStreamingChoice::Message(t.text.clone()))
461 }
462 message::AssistantContent::ToolCall(tc) => {
463 yield Ok(RawStreamingChoice::ToolCall {
464 id: tc.id.clone(),
465 name: tc.function.name.clone(),
466 arguments: tc.function.arguments.clone(),
467 call_id: None
468 })
469 }
470 message::AssistantContent::Reasoning(_) => {
471 unimplemented!("Reasoning is not supported on Mistral via Rig")
472 }
473 }
474 }
475
476 yield Ok(RawStreamingChoice::FinalResponse(resp.raw_response.clone()));
477 });
478
479 Ok(StreamingCompletionResponse::stream(stream))
480 }
481}
482
#[cfg(test)]
mod tests {
    use super::*;

    /// Sample response body: a single assistant choice carrying both text and a tool call.
    const RESPONSE_JSON: &str = r#"
    {
        "id": "cmpl-e5cc70bb28c444948073e77776eb30ef",
        "object": "chat.completion",
        "model": "mistral-small-latest",
        "usage": {
            "prompt_tokens": 16,
            "completion_tokens": 34,
            "total_tokens": 50
        },
        "created": 1702256327,
        "choices": [
            {
                "index": 0,
                "message": {
                    "content": "string",
                    "tool_calls": [
                        {
                            "id": "null",
                            "type": "function",
                            "function": {
                                "name": "string",
                                "arguments": "{ }"
                            },
                            "index": 0
                        }
                    ],
                    "prefix": false,
                    "role": "assistant"
                },
                "finish_reason": "stop"
            }
        ]
    }
    "#;

    /// Top-level fields, usage, and the assistant message (including stringified
    /// tool-call arguments) all deserialize as expected.
    #[test]
    fn test_response_deserialization() {
        let completion_response = serde_json::from_str::<CompletionResponse>(RESPONSE_JSON).unwrap();
        assert_eq!(completion_response.model, MISTRAL_SMALL);

        let CompletionResponse {
            id,
            object,
            created,
            choices,
            usage,
            ..
        } = completion_response;

        assert_eq!(id, "cmpl-e5cc70bb28c444948073e77776eb30ef");

        let Usage {
            completion_tokens,
            prompt_tokens,
            total_tokens,
        } = usage.unwrap();

        assert_eq!(prompt_tokens, 16);
        assert_eq!(completion_tokens, 34);
        assert_eq!(total_tokens, 50);
        assert_eq!(object, "chat.completion".to_string());
        assert_eq!(created, 1702256327);
        assert_eq!(choices.len(), 1);

        let choice = &choices[0];
        assert_eq!(choice.finish_reason, "stop");
        match &choice.message {
            Message::Assistant {
                content,
                tool_calls,
                prefix,
            } => {
                assert_eq!(content, "string");
                assert!(!*prefix);
                assert_eq!(tool_calls.len(), 1);
                assert_eq!(tool_calls[0].id, "null");
                assert_eq!(tool_calls[0].r#type, ToolType::Function);
                assert_eq!(tool_calls[0].function.name, "string");
                assert_eq!(tool_calls[0].function.arguments, serde_json::json!({}));
            }
            other => panic!("expected an assistant message, got {other:?}"),
        }
    }

    /// Conversion to rig's generic response keeps text before tool calls and reports
    /// `completion_tokens` as output tokens.
    #[test]
    fn test_response_conversion() {
        let raw: CompletionResponse = serde_json::from_str(RESPONSE_JSON).unwrap();

        let reported = raw.token_usage().unwrap();
        assert_eq!(reported.output_tokens, 34);

        let converted: completion::CompletionResponse<CompletionResponse> =
            raw.try_into().unwrap();

        assert_eq!(converted.usage.input_tokens, 16);
        assert_eq!(converted.usage.output_tokens, 34);
        assert_eq!(converted.usage.total_tokens, 50);

        let parts: Vec<_> = converted.choice.into_iter().collect();
        assert_eq!(parts.len(), 2);
        assert!(matches!(
            &parts[0],
            message::AssistantContent::Text(text) if text.text == "string"
        ));
        assert!(matches!(
            &parts[1],
            message::AssistantContent::ToolCall(call)
                if call.id == "null" && call.function.name == "string"
        ));
    }
}