1use async_stream::stream;
2use serde::{Deserialize, Serialize};
3use std::{convert::Infallible, str::FromStr};
4use tracing::{Instrument, Level, enabled, info_span};
5
6use super::client::{Client, Usage};
7use crate::completion::GetTokenUsage;
8use crate::http_client::{self, HttpClientExt};
9use crate::streaming::{RawStreamingChoice, RawStreamingToolCall, StreamingCompletionResponse};
10use crate::{
11 OneOrMany,
12 completion::{self, CompletionError, CompletionRequest},
13 json_utils, message,
14 providers::mistral::client::ApiResponse,
15 telemetry::SpanCombinator,
16};
17
/// Model identifier `codestral-latest`.
pub const CODESTRAL: &str = "codestral-latest";
/// Model identifier `mistral-large-latest`.
pub const MISTRAL_LARGE: &str = "mistral-large-latest";
/// Model identifier `pixtral-large-latest`.
pub const PIXTRAL_LARGE: &str = "pixtral-large-latest";
/// Model identifier `mistral-saba-latest`.
pub const MISTRAL_SABA: &str = "mistral-saba-latest";
/// Model identifier `ministral-3b-latest`.
pub const MINISTRAL_3B: &str = "ministral-3b-latest";
/// Model identifier `ministral-8b-latest`.
pub const MINISTRAL_8B: &str = "ministral-8b-latest";

/// Model identifier `mistral-small-latest`.
pub const MISTRAL_SMALL: &str = "mistral-small-latest";
/// Model identifier `pixtral-12b-2409` (a pinned, dated version rather than a `-latest` alias).
pub const PIXTRAL_SMALL: &str = "pixtral-12b-2409";
/// Model identifier `open-mistral-nemo`.
pub const MISTRAL_NEMO: &str = "open-mistral-nemo";
/// Model identifier `open-codestral-mamba`.
pub const CODESTRAL_MAMBA: &str = "open-codestral-mamba";
39
/// Plain-text content of an assistant message.
///
/// NOTE(review): `#[serde(tag = "type")]` on a struct with named fields makes serde emit an
/// extra `"type"` field holding the struct's name; `rename_all` only renames fields here.
/// Confirm this is the intended wire format.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub struct AssistantContent {
    text: String,
}
49
/// Content of a user message, internally tagged by a `"type"` field whose value is the
/// lowercased variant name (e.g. `"text"`).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    /// A plain-text segment.
    Text { text: String },
}
55
/// One entry of the `choices` array in a Mistral chat completion response.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Choice {
    /// Position of this choice within `choices`.
    pub index: usize,
    /// The message generated for this choice.
    pub message: Message,
    /// Log-probability information, kept as raw JSON when present.
    pub logprobs: Option<serde_json::Value>,
    /// Reason generation stopped, as reported by the API (e.g. `"stop"`).
    pub finish_reason: String,
}
63
/// A chat message in Mistral's wire format.
///
/// Internally tagged by a `"role"` field holding the lowercased variant name
/// (`"user"`, `"assistant"`, `"system"` or `"tool"`).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// A message authored by the end user.
    User {
        content: String,
    },
    /// A message produced by the model.
    Assistant {
        /// Text content; may be empty when the message only carries tool calls.
        content: String,
        // A missing field falls back to an empty list, and an empty list is omitted on
        // serialization. `null_or_vec` presumably also maps a JSON `null` to an empty
        // list (verify in `json_utils`).
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<ToolCall>,
        /// Defaults to `false` when absent from the payload.
        #[serde(default)]
        prefix: bool,
    },
    /// A system prompt.
    System {
        content: String,
    },
    /// The output of a tool invocation sent back to the model.
    Tool {
        name: String,
        content: String,
        /// Identifier of the tool call this message responds to.
        tool_call_id: String,
    },
}
93
94impl Message {
95 pub fn user(content: String) -> Self {
96 Message::User { content }
97 }
98
99 pub fn assistant(content: String, tool_calls: Vec<ToolCall>, prefix: bool) -> Self {
100 Message::Assistant {
101 content,
102 tool_calls,
103 prefix,
104 }
105 }
106
107 pub fn system(content: String) -> Self {
108 Message::System { content }
109 }
110}
111
112impl TryFrom<message::Message> for Vec<Message> {
113 type Error = message::MessageError;
114
115 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
116 match message {
117 message::Message::User { content } => {
118 let mut tool_result_messages = Vec::new();
119 let mut other_messages = Vec::new();
120
121 for content_item in content {
122 match content_item {
123 message::UserContent::ToolResult(message::ToolResult {
124 id,
125 call_id,
126 content: tool_content,
127 }) => {
128 let call_id_key = call_id.unwrap_or_else(|| id.clone());
129 let content_text = tool_content
130 .into_iter()
131 .find_map(|content_item| match content_item {
132 message::ToolResultContent::Text(text) => Some(text.text),
133 message::ToolResultContent::Image(_) => None,
134 })
135 .unwrap_or_default();
136 tool_result_messages.push(Message::Tool {
137 name: id,
138 content: content_text,
139 tool_call_id: call_id_key,
140 });
141 }
142 message::UserContent::Text(message::Text { text }) => {
143 other_messages.push(Message::User { content: text });
144 }
145 _ => {}
146 }
147 }
148
149 tool_result_messages.append(&mut other_messages);
150 Ok(tool_result_messages)
151 }
152 message::Message::Assistant { content, .. } => {
153 let (text_content, tool_calls) = content.into_iter().fold(
154 (Vec::new(), Vec::new()),
155 |(mut texts, mut tools), content| {
156 match content {
157 message::AssistantContent::Text(text) => texts.push(text),
158 message::AssistantContent::ToolCall(tool_call) => tools.push(tool_call),
159 message::AssistantContent::Reasoning(_) => {
160 panic!("Reasoning content is not currently supported on Mistral via Rig");
161 }
162 message::AssistantContent::Image(_) => {
163 panic!("Image content is not currently supported on Mistral via Rig");
164 }
165 }
166 (texts, tools)
167 },
168 );
169
170 Ok(vec![Message::Assistant {
171 content: text_content
172 .into_iter()
173 .next()
174 .map(|content| content.text)
175 .unwrap_or_default(),
176 tool_calls: tool_calls
177 .into_iter()
178 .map(|tool_call| tool_call.into())
179 .collect::<Vec<_>>(),
180 prefix: false,
181 }])
182 }
183 }
184 }
185}
186
/// A tool call requested by the model in an assistant message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    /// Identifier of this call.
    pub id: String,
    /// Kind of tool; falls back to [`ToolType::Function`] when absent.
    #[serde(default)]
    pub r#type: ToolType,
    /// Name and arguments of the function to invoke.
    pub function: Function,
}
194
195impl From<message::ToolCall> for ToolCall {
196 fn from(tool_call: message::ToolCall) -> Self {
197 Self {
198 id: tool_call.id,
199 r#type: ToolType::default(),
200 function: Function {
201 name: tool_call.function.name,
202 arguments: tool_call.function.arguments,
203 },
204 }
205 }
206}
207
/// A function invocation: its name plus JSON arguments.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    pub name: String,
    /// Arguments as a JSON value. On the wire they appear as a JSON-encoded string
    /// (e.g. `"{ }"`), handled by `json_utils::stringified_json`.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
214
/// Kind of tool referenced by a [`ToolCall`]; serialized in lowercase (`"function"`).
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    #[default]
    Function,
}
221
/// A tool offered to the model in a request: a `type` tag plus rig's function definition.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolDefinition {
    /// Tool kind; set to `"function"` by the `From<completion::ToolDefinition>` conversion.
    pub r#type: String,
    pub function: completion::ToolDefinition,
}
227
228impl From<completion::ToolDefinition> for ToolDefinition {
229 fn from(tool: completion::ToolDefinition) -> Self {
230 Self {
231 r#type: "function".into(),
232 function: tool,
233 }
234 }
235}
236
/// A text block in a tool result.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
    /// Content kind; falls back to [`ToolResultContentType::Text`] when absent.
    #[serde(default)]
    r#type: ToolResultContentType,
    text: String,
}
243
/// Kind of a [`ToolResultContent`] block; serialized in lowercase (`"text"`).
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolResultContentType {
    #[default]
    Text,
}
250
251impl From<String> for ToolResultContent {
252 fn from(s: String) -> Self {
253 ToolResultContent {
254 r#type: ToolResultContentType::default(),
255 text: s,
256 }
257 }
258}
259
260impl From<String> for UserContent {
261 fn from(s: String) -> Self {
262 UserContent::Text { text: s }
263 }
264}
265
266impl FromStr for UserContent {
267 type Err = Infallible;
268
269 fn from_str(s: &str) -> Result<Self, Self::Err> {
270 Ok(UserContent::Text {
271 text: s.to_string(),
272 })
273 }
274}
275
276impl From<String> for AssistantContent {
277 fn from(s: String) -> Self {
278 AssistantContent { text: s }
279 }
280}
281
282impl FromStr for AssistantContent {
283 type Err = Infallible;
284
285 fn from_str(s: &str) -> Result<Self, Self::Err> {
286 Ok(AssistantContent {
287 text: s.to_string(),
288 })
289 }
290}
291
/// Handle to a Mistral chat completion model.
///
/// `T` is the HTTP client type used by [`Client`]; it defaults to `reqwest::Client`.
#[derive(Clone)]
pub struct CompletionModel<T = reqwest::Client> {
    pub(crate) client: Client<T>,
    /// Model identifier sent with each request (e.g. [`MISTRAL_SMALL`]).
    pub model: String,
}
297
298#[derive(Debug, Default, Serialize, Deserialize)]
299pub enum ToolChoice {
300 #[default]
301 Auto,
302 None,
303 Any,
304}
305
306impl TryFrom<message::ToolChoice> for ToolChoice {
307 type Error = CompletionError;
308
309 fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
310 let res = match value {
311 message::ToolChoice::Auto => Self::Auto,
312 message::ToolChoice::None => Self::None,
313 message::ToolChoice::Required => Self::Any,
314 message::ToolChoice::Specific { .. } => {
315 return Err(CompletionError::ProviderError(
316 "Mistral doesn't support requiring specific tools to be called".to_string(),
317 ));
318 }
319 };
320
321 Ok(res)
322 }
323}
324
/// JSON body posted to Mistral's `v1/chat/completions` endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct MistralCompletionRequest {
    model: String,
    /// Full conversation: optional system preamble, then documents, then chat history.
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    /// Omitted from the body when no tools are offered.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    /// Uses the OpenAI provider's tool-choice type rather than this module's `ToolChoice`.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
    /// Extra provider parameters, flattened into the top level of the JSON body.
    #[serde(flatten, skip_serializing_if = "Option::is_none")]
    pub additional_params: Option<serde_json::Value>,
}
338
339impl TryFrom<(&str, CompletionRequest)> for MistralCompletionRequest {
340 type Error = CompletionError;
341
342 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
343 let mut full_history: Vec<Message> = match &req.preamble {
344 Some(preamble) => vec![Message::system(preamble.clone())],
345 None => vec![],
346 };
347 if let Some(docs) = req.normalized_documents() {
348 let docs: Vec<Message> = docs.try_into()?;
349 full_history.extend(docs);
350 }
351
352 let chat_history: Vec<Message> = req
353 .chat_history
354 .clone()
355 .into_iter()
356 .map(|message| message.try_into())
357 .collect::<Result<Vec<Vec<Message>>, _>>()?
358 .into_iter()
359 .flatten()
360 .collect();
361
362 full_history.extend(chat_history);
363
364 let tool_choice = req
365 .tool_choice
366 .clone()
367 .map(crate::providers::openai::completion::ToolChoice::try_from)
368 .transpose()?;
369
370 Ok(Self {
371 model: model.to_string(),
372 messages: full_history,
373 temperature: req.temperature,
374 tools: req
375 .tools
376 .clone()
377 .into_iter()
378 .map(ToolDefinition::from)
379 .collect::<Vec<_>>(),
380 tool_choice,
381 additional_params: req.additional_params,
382 })
383 }
384}
385
386impl<T> CompletionModel<T> {
387 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
388 Self {
389 client,
390 model: model.into(),
391 }
392 }
393
394 pub fn with_model(client: Client<T>, model: &str) -> Self {
395 Self {
396 client,
397 model: model.into(),
398 }
399 }
400}
401
/// A non-streaming chat completion response from Mistral.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct CompletionResponse {
    /// Response identifier (e.g. `"cmpl-..."`).
    pub id: String,
    /// Object kind reported by the API (e.g. `"chat.completion"`).
    pub object: String,
    /// Creation time as reported by the API; presumably a Unix timestamp in seconds (confirm).
    pub created: u64,
    /// Model that produced the response.
    pub model: String,
    pub system_fingerprint: Option<String>,
    /// Generated choices; conversion into rig's response uses only the first one.
    pub choices: Vec<Choice>,
    /// Token accounting, when the API includes it.
    pub usage: Option<Usage>,
}
412
413impl crate::telemetry::ProviderResponseExt for CompletionResponse {
414 type OutputMessage = Choice;
415 type Usage = Usage;
416
417 fn get_response_id(&self) -> Option<String> {
418 Some(self.id.clone())
419 }
420
421 fn get_response_model_name(&self) -> Option<String> {
422 Some(self.model.clone())
423 }
424
425 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
426 self.choices.clone()
427 }
428
429 fn get_text_response(&self) -> Option<String> {
430 let res = self
431 .choices
432 .iter()
433 .filter_map(|choice| match choice.message {
434 Message::Assistant { ref content, .. } => {
435 if content.is_empty() {
436 None
437 } else {
438 Some(content.to_string())
439 }
440 }
441 _ => None,
442 })
443 .collect::<Vec<String>>()
444 .join("\n");
445
446 if res.is_empty() { None } else { Some(res) }
447 }
448
449 fn get_usage(&self) -> Option<Self::Usage> {
450 self.usage.clone()
451 }
452}
453
454impl GetTokenUsage for CompletionResponse {
455 fn token_usage(&self) -> Option<crate::completion::Usage> {
456 let api_usage = self.usage.clone()?;
457
458 let mut usage = crate::completion::Usage::new();
459 usage.input_tokens = api_usage.prompt_tokens as u64;
460 usage.output_tokens = api_usage.completion_tokens as u64;
461 usage.total_tokens = api_usage.total_tokens as u64;
462
463 Some(usage)
464 }
465}
466
467impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
468 type Error = CompletionError;
469
470 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
471 let choice = response.choices.first().ok_or_else(|| {
472 CompletionError::ResponseError("Response contained no choices".to_owned())
473 })?;
474 let content = match &choice.message {
475 Message::Assistant {
476 content,
477 tool_calls,
478 ..
479 } => {
480 let mut content = if content.is_empty() {
481 vec![]
482 } else {
483 vec![completion::AssistantContent::text(content.clone())]
484 };
485
486 content.extend(
487 tool_calls
488 .iter()
489 .map(|call| {
490 completion::AssistantContent::tool_call(
491 &call.id,
492 &call.function.name,
493 call.function.arguments.clone(),
494 )
495 })
496 .collect::<Vec<_>>(),
497 );
498 Ok(content)
499 }
500 _ => Err(CompletionError::ResponseError(
501 "Response did not contain a valid message or tool call".into(),
502 )),
503 }?;
504
505 let choice = OneOrMany::many(content).map_err(|_| {
506 CompletionError::ResponseError(
507 "Response contained no message or tool call (empty)".to_owned(),
508 )
509 })?;
510
511 let usage = response
512 .usage
513 .as_ref()
514 .map(|usage| completion::Usage {
515 input_tokens: usage.prompt_tokens as u64,
516 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
517 total_tokens: usage.total_tokens as u64,
518 })
519 .unwrap_or_default();
520
521 Ok(completion::CompletionResponse {
522 choice,
523 usage,
524 raw_response: response,
525 })
526 }
527}
528
529impl<T> completion::CompletionModel for CompletionModel<T>
530where
531 T: HttpClientExt + Send + Clone + std::fmt::Debug + 'static,
532{
533 type Response = CompletionResponse;
534 type StreamingResponse = CompletionResponse;
535
536 type Client = Client<T>;
537
538 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
539 Self::new(client.clone(), model.into())
540 }
541
542 async fn completion(
543 &self,
544 completion_request: CompletionRequest,
545 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
546 let preamble = completion_request.preamble.clone();
547 let request =
548 MistralCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
549
550 if enabled!(Level::TRACE) {
551 tracing::trace!(
552 target: "rig::completions",
553 "Mistral completion request: {}",
554 serde_json::to_string_pretty(&request)?
555 );
556 }
557
558 let span = if tracing::Span::current().is_disabled() {
559 info_span!(
560 target: "rig::completions",
561 "chat",
562 gen_ai.operation.name = "chat",
563 gen_ai.provider.name = "mistral",
564 gen_ai.request.model = self.model,
565 gen_ai.system_instructions = &preamble,
566 gen_ai.response.id = tracing::field::Empty,
567 gen_ai.response.model = tracing::field::Empty,
568 gen_ai.usage.output_tokens = tracing::field::Empty,
569 gen_ai.usage.input_tokens = tracing::field::Empty,
570 )
571 } else {
572 tracing::Span::current()
573 };
574
575 let body = serde_json::to_vec(&request)?;
576
577 let request = self
578 .client
579 .post("v1/chat/completions")?
580 .body(body)
581 .map_err(|e| CompletionError::HttpError(e.into()))?;
582
583 async move {
584 let response = self.client.send(request).await?;
585
586 if response.status().is_success() {
587 let text = http_client::text(response).await?;
588 match serde_json::from_str::<ApiResponse<CompletionResponse>>(&text)? {
589 ApiResponse::Ok(response) => {
590 let span = tracing::Span::current();
591 span.record_token_usage(&response);
592 span.record_response_metadata(&response);
593 response.try_into()
594 }
595 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
596 }
597 } else {
598 let text = http_client::text(response).await?;
599 Err(CompletionError::ProviderError(text))
600 }
601 }
602 .instrument(span)
603 .await
604 }
605
606 async fn stream(
607 &self,
608 request: CompletionRequest,
609 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
610 let resp = self.completion(request).await?;
611
612 let stream = stream! {
613 for c in resp.choice.clone() {
614 match c {
615 message::AssistantContent::Text(t) => {
616 yield Ok(RawStreamingChoice::Message(t.text.clone()))
617 }
618 message::AssistantContent::ToolCall(tc) => {
619 yield Ok(RawStreamingChoice::ToolCall(
620 RawStreamingToolCall::new(
621 tc.id.clone(),
622 tc.function.name.clone(),
623 tc.function.arguments.clone(),
624 )
625 ))
626 }
627 message::AssistantContent::Reasoning(_) => {
628 panic!("Reasoning is not supported on Mistral via Rig")
629 }
630 message::AssistantContent::Image(_) => {
631 panic!("Image content is not supported on Mistral via Rig")
632 }
633 }
634 }
635
636 yield Ok(RawStreamingChoice::FinalResponse(resp.raw_response.clone()));
637 };
638
639 Ok(StreamingCompletionResponse::stream(Box::pin(stream)))
640 }
641}
642
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes a sample response with one tool-calling assistant choice and checks
    /// the top-level metadata and token usage.
    #[test]
    fn test_response_deserialization() {
        let json_data = r#"
        {
            "id": "cmpl-e5cc70bb28c444948073e77776eb30ef",
            "object": "chat.completion",
            "model": "mistral-small-latest",
            "usage": {
                "prompt_tokens": 16,
                "completion_tokens": 34,
                "total_tokens": 50
            },
            "created": 1702256327,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "content": "string",
                        "tool_calls": [
                            {
                                "id": "null",
                                "type": "function",
                                "function": {
                                    "name": "string",
                                    "arguments": "{ }"
                                },
                                "index": 0
                            }
                        ],
                        "prefix": false,
                        "role": "assistant"
                    },
                    "finish_reason": "stop"
                }
            ]
        }
        "#;

        let response: CompletionResponse = serde_json::from_str(json_data).unwrap();

        // Top-level metadata.
        assert_eq!(response.model, MISTRAL_SMALL);
        assert_eq!(response.id, "cmpl-e5cc70bb28c444948073e77776eb30ef");
        assert_eq!(response.object, "chat.completion".to_string());
        assert_eq!(response.created, 1702256327);
        assert_eq!(response.choices.len(), 1);

        // Token accounting; the exhaustive pattern also pins the shape of `Usage`.
        let Usage {
            completion_tokens,
            prompt_tokens,
            total_tokens,
        } = response.usage.unwrap();

        assert_eq!(prompt_tokens, 16);
        assert_eq!(completion_tokens, 34);
        assert_eq!(total_tokens, 50);
    }
}