1use async_stream::stream;
2use serde::{Deserialize, Serialize};
3use std::{convert::Infallible, str::FromStr};
4use tracing::{Instrument, info_span};
5
6use super::client::{Client, Usage};
7use crate::completion::GetTokenUsage;
8use crate::http_client::{self, HttpClientExt};
9use crate::streaming::{RawStreamingChoice, StreamingCompletionResponse};
10use crate::{
11 OneOrMany,
12 completion::{self, CompletionError, CompletionRequest},
13 json_utils, message,
14 providers::mistral::client::ApiResponse,
15 telemetry::SpanCombinator,
16};
17
// Model-name strings, usable wherever a model identifier is expected.

/// Model name `codestral-latest`.
pub const CODESTRAL: &str = "codestral-latest";
/// Model name `mistral-large-latest`.
pub const MISTRAL_LARGE: &str = "mistral-large-latest";
/// Model name `pixtral-large-latest`.
pub const PIXTRAL_LARGE: &str = "pixtral-large-latest";
/// Model name `mistral-saba-latest`.
pub const MISTRAL_SABA: &str = "mistral-saba-latest";
/// Model name `ministral-3b-latest`.
pub const MINISTRAL_3B: &str = "ministral-3b-latest";
/// Model name `ministral-8b-latest`.
pub const MINISTRAL_8B: &str = "ministral-8b-latest";

/// Model name `mistral-small-latest`.
pub const MISTRAL_SMALL: &str = "mistral-small-latest";
/// Model name `pixtral-12b-2409`.
pub const PIXTRAL_SMALL: &str = "pixtral-12b-2409";
/// Model name `open-mistral-nemo`.
pub const MISTRAL_NEMO: &str = "open-mistral-nemo";
/// Model name `open-codestral-mamba`.
pub const CODESTRAL_MAMBA: &str = "open-codestral-mamba";
39
40#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
45#[serde(tag = "type", rename_all = "lowercase")]
46pub struct AssistantContent {
47 text: String,
48}
49
50#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
51#[serde(tag = "type", rename_all = "lowercase")]
52pub enum UserContent {
53 Text { text: String },
54}
55
56#[derive(Debug, Serialize, Deserialize, Clone)]
57pub struct Choice {
58 pub index: usize,
59 pub message: Message,
60 pub logprobs: Option<serde_json::Value>,
61 pub finish_reason: String,
62}
63
64#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
65#[serde(tag = "role", rename_all = "lowercase")]
66pub enum Message {
67 User {
68 content: String,
69 },
70 Assistant {
71 content: String,
72 #[serde(
73 default,
74 deserialize_with = "json_utils::null_or_vec",
75 skip_serializing_if = "Vec::is_empty"
76 )]
77 tool_calls: Vec<ToolCall>,
78 #[serde(default)]
79 prefix: bool,
80 },
81 System {
82 content: String,
83 },
84}
85
86impl Message {
87 pub fn user(content: String) -> Self {
88 Message::User { content }
89 }
90
91 pub fn assistant(content: String, tool_calls: Vec<ToolCall>, prefix: bool) -> Self {
92 Message::Assistant {
93 content,
94 tool_calls,
95 prefix,
96 }
97 }
98
99 pub fn system(content: String) -> Self {
100 Message::System { content }
101 }
102}
103
104impl TryFrom<message::Message> for Vec<Message> {
105 type Error = message::MessageError;
106
107 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
108 match message {
109 message::Message::User { content } => {
110 let (_, other_content): (Vec<_>, Vec<_>) = content
111 .into_iter()
112 .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
113
114 let messages = other_content
115 .into_iter()
116 .filter_map(|content| match content {
117 message::UserContent::Text(message::Text { text }) => {
118 Some(Message::User { content: text })
119 }
120 _ => None,
121 })
122 .collect::<Vec<_>>();
123
124 Ok(messages)
125 }
126 message::Message::Assistant { content, .. } => {
127 let (text_content, tool_calls) = content.into_iter().fold(
128 (Vec::new(), Vec::new()),
129 |(mut texts, mut tools), content| {
130 match content {
131 message::AssistantContent::Text(text) => texts.push(text),
132 message::AssistantContent::ToolCall(tool_call) => tools.push(tool_call),
133 message::AssistantContent::Reasoning(_) => {
134 unimplemented!("Reasoning content is not currently supported on Mistral via Rig");
135 }
136 message::AssistantContent::Image(_) => {
137 unimplemented!("Image content is not currently supported on Mistral via Rig");
138 }
139 }
140 (texts, tools)
141 },
142 );
143
144 Ok(vec![Message::Assistant {
145 content: text_content
146 .into_iter()
147 .next()
148 .map(|content| content.text)
149 .unwrap_or_default(),
150 tool_calls: tool_calls
151 .into_iter()
152 .map(|tool_call| tool_call.into())
153 .collect::<Vec<_>>(),
154 prefix: false,
155 }])
156 }
157 }
158 }
159}
160
161#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
162pub struct ToolCall {
163 pub id: String,
164 #[serde(default)]
165 pub r#type: ToolType,
166 pub function: Function,
167}
168
169impl From<message::ToolCall> for ToolCall {
170 fn from(tool_call: message::ToolCall) -> Self {
171 Self {
172 id: tool_call.id,
173 r#type: ToolType::default(),
174 function: Function {
175 name: tool_call.function.name,
176 arguments: tool_call.function.arguments,
177 },
178 }
179 }
180}
181
182#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
183pub struct Function {
184 pub name: String,
185 #[serde(with = "json_utils::stringified_json")]
186 pub arguments: serde_json::Value,
187}
188
189#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
190#[serde(rename_all = "lowercase")]
191pub enum ToolType {
192 #[default]
193 Function,
194}
195
196#[derive(Debug, Deserialize, Serialize, Clone)]
197pub struct ToolDefinition {
198 pub r#type: String,
199 pub function: completion::ToolDefinition,
200}
201
202impl From<completion::ToolDefinition> for ToolDefinition {
203 fn from(tool: completion::ToolDefinition) -> Self {
204 Self {
205 r#type: "function".into(),
206 function: tool,
207 }
208 }
209}
210
211#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
212pub struct ToolResultContent {
213 #[serde(default)]
214 r#type: ToolResultContentType,
215 text: String,
216}
217
218#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
219#[serde(rename_all = "lowercase")]
220pub enum ToolResultContentType {
221 #[default]
222 Text,
223}
224
225impl From<String> for ToolResultContent {
226 fn from(s: String) -> Self {
227 ToolResultContent {
228 r#type: ToolResultContentType::default(),
229 text: s,
230 }
231 }
232}
233
234impl From<String> for UserContent {
235 fn from(s: String) -> Self {
236 UserContent::Text { text: s }
237 }
238}
239
240impl FromStr for UserContent {
241 type Err = Infallible;
242
243 fn from_str(s: &str) -> Result<Self, Self::Err> {
244 Ok(UserContent::Text {
245 text: s.to_string(),
246 })
247 }
248}
249
250impl From<String> for AssistantContent {
251 fn from(s: String) -> Self {
252 AssistantContent { text: s }
253 }
254}
255
256impl FromStr for AssistantContent {
257 type Err = Infallible;
258
259 fn from_str(s: &str) -> Result<Self, Self::Err> {
260 Ok(AssistantContent {
261 text: s.to_string(),
262 })
263 }
264}
265
266#[derive(Clone)]
267pub struct CompletionModel<T = reqwest::Client> {
268 pub(crate) client: Client<T>,
269 pub model: String,
270}
271
272#[derive(Debug, Default, Serialize, Deserialize)]
273pub enum ToolChoice {
274 #[default]
275 Auto,
276 None,
277 Any,
278}
279
280impl TryFrom<message::ToolChoice> for ToolChoice {
281 type Error = CompletionError;
282
283 fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
284 let res = match value {
285 message::ToolChoice::Auto => Self::Auto,
286 message::ToolChoice::None => Self::None,
287 message::ToolChoice::Required => Self::Any,
288 message::ToolChoice::Specific { .. } => {
289 return Err(CompletionError::ProviderError(
290 "Mistral doesn't support requiring specific tools to be called".to_string(),
291 ));
292 }
293 };
294
295 Ok(res)
296 }
297}
298
299#[derive(Debug, Serialize, Deserialize)]
300pub(super) struct MistralCompletionRequest {
301 model: String,
302 pub messages: Vec<Message>,
303 #[serde(flatten, skip_serializing_if = "Option::is_none")]
304 temperature: Option<f64>,
305 #[serde(skip_serializing_if = "Vec::is_empty")]
306 tools: Vec<ToolDefinition>,
307 #[serde(flatten, skip_serializing_if = "Option::is_none")]
308 tool_choice: Option<crate::providers::openai::completion::ToolChoice>,
309 #[serde(flatten, skip_serializing_if = "Option::is_none")]
310 pub additional_params: Option<serde_json::Value>,
311}
312
313impl TryFrom<(&str, CompletionRequest)> for MistralCompletionRequest {
314 type Error = CompletionError;
315
316 fn try_from((model, req): (&str, CompletionRequest)) -> Result<Self, Self::Error> {
317 let mut full_history: Vec<Message> = match &req.preamble {
318 Some(preamble) => vec![Message::system(preamble.clone())],
319 None => vec![],
320 };
321 if let Some(docs) = req.normalized_documents() {
322 let docs: Vec<Message> = docs.try_into()?;
323 full_history.extend(docs);
324 }
325
326 let chat_history: Vec<Message> = req
327 .chat_history
328 .clone()
329 .into_iter()
330 .map(|message| message.try_into())
331 .collect::<Result<Vec<Vec<Message>>, _>>()?
332 .into_iter()
333 .flatten()
334 .collect();
335
336 full_history.extend(chat_history);
337
338 let tool_choice = req
339 .tool_choice
340 .clone()
341 .map(crate::providers::openai::completion::ToolChoice::try_from)
342 .transpose()?;
343
344 Ok(Self {
345 model: model.to_string(),
346 messages: full_history,
347 temperature: req.temperature,
348 tools: req
349 .tools
350 .clone()
351 .into_iter()
352 .map(ToolDefinition::from)
353 .collect::<Vec<_>>(),
354 tool_choice,
355 additional_params: req.additional_params,
356 })
357 }
358}
359
360impl<T> CompletionModel<T> {
361 pub fn new(client: Client<T>, model: impl Into<String>) -> Self {
362 Self {
363 client,
364 model: model.into(),
365 }
366 }
367
368 pub fn with_model(client: Client<T>, model: &str) -> Self {
369 Self {
370 client,
371 model: model.into(),
372 }
373 }
374}
375
376#[derive(Debug, Deserialize, Clone, Serialize)]
377pub struct CompletionResponse {
378 pub id: String,
379 pub object: String,
380 pub created: u64,
381 pub model: String,
382 pub system_fingerprint: Option<String>,
383 pub choices: Vec<Choice>,
384 pub usage: Option<Usage>,
385}
386
387impl crate::telemetry::ProviderResponseExt for CompletionResponse {
388 type OutputMessage = Choice;
389 type Usage = Usage;
390
391 fn get_response_id(&self) -> Option<String> {
392 Some(self.id.clone())
393 }
394
395 fn get_response_model_name(&self) -> Option<String> {
396 Some(self.model.clone())
397 }
398
399 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
400 self.choices.clone()
401 }
402
403 fn get_text_response(&self) -> Option<String> {
404 let res = self
405 .choices
406 .iter()
407 .filter_map(|choice| match choice.message {
408 Message::Assistant { ref content, .. } => {
409 if content.is_empty() {
410 None
411 } else {
412 Some(content.to_string())
413 }
414 }
415 _ => None,
416 })
417 .collect::<Vec<String>>()
418 .join("\n");
419
420 if res.is_empty() { None } else { Some(res) }
421 }
422
423 fn get_usage(&self) -> Option<Self::Usage> {
424 self.usage.clone()
425 }
426}
427
428impl GetTokenUsage for CompletionResponse {
429 fn token_usage(&self) -> Option<crate::completion::Usage> {
430 let api_usage = self.usage.clone()?;
431
432 let mut usage = crate::completion::Usage::new();
433 usage.input_tokens = api_usage.prompt_tokens as u64;
434 usage.output_tokens = api_usage.completion_tokens as u64;
435 usage.total_tokens = api_usage.total_tokens as u64;
436
437 Some(usage)
438 }
439}
440
441impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
442 type Error = CompletionError;
443
444 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
445 let choice = response.choices.first().ok_or_else(|| {
446 CompletionError::ResponseError("Response contained no choices".to_owned())
447 })?;
448 let content = match &choice.message {
449 Message::Assistant {
450 content,
451 tool_calls,
452 ..
453 } => {
454 let mut content = if content.is_empty() {
455 vec![]
456 } else {
457 vec![completion::AssistantContent::text(content.clone())]
458 };
459
460 content.extend(
461 tool_calls
462 .iter()
463 .map(|call| {
464 completion::AssistantContent::tool_call(
465 &call.id,
466 &call.function.name,
467 call.function.arguments.clone(),
468 )
469 })
470 .collect::<Vec<_>>(),
471 );
472 Ok(content)
473 }
474 _ => Err(CompletionError::ResponseError(
475 "Response did not contain a valid message or tool call".into(),
476 )),
477 }?;
478
479 let choice = OneOrMany::many(content).map_err(|_| {
480 CompletionError::ResponseError(
481 "Response contained no message or tool call (empty)".to_owned(),
482 )
483 })?;
484
485 let usage = response
486 .usage
487 .as_ref()
488 .map(|usage| completion::Usage {
489 input_tokens: usage.prompt_tokens as u64,
490 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
491 total_tokens: usage.total_tokens as u64,
492 })
493 .unwrap_or_default();
494
495 Ok(completion::CompletionResponse {
496 choice,
497 usage,
498 raw_response: response,
499 })
500 }
501}
502
503impl<T> completion::CompletionModel for CompletionModel<T>
504where
505 T: HttpClientExt + Send + Clone + std::fmt::Debug + 'static,
506{
507 type Response = CompletionResponse;
508 type StreamingResponse = CompletionResponse;
509
510 type Client = Client<T>;
511
512 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
513 Self::new(client.clone(), model.into())
514 }
515
516 #[cfg_attr(feature = "worker", worker::send)]
517 async fn completion(
518 &self,
519 completion_request: CompletionRequest,
520 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
521 let preamble = completion_request.preamble.clone();
522 let request =
523 MistralCompletionRequest::try_from((self.model.as_ref(), completion_request))?;
524
525 let span = if tracing::Span::current().is_disabled() {
526 info_span!(
527 target: "rig::completions",
528 "chat",
529 gen_ai.operation.name = "chat",
530 gen_ai.provider.name = "mistral",
531 gen_ai.request.model = self.model,
532 gen_ai.system_instructions = &preamble,
533 gen_ai.response.id = tracing::field::Empty,
534 gen_ai.response.model = tracing::field::Empty,
535 gen_ai.usage.output_tokens = tracing::field::Empty,
536 gen_ai.usage.input_tokens = tracing::field::Empty,
537 gen_ai.input.messages = serde_json::to_string(&request.messages)?,
538 gen_ai.output.messages = tracing::field::Empty,
539 )
540 } else {
541 tracing::Span::current()
542 };
543
544 let body = serde_json::to_vec(&request)?;
545
546 let request = self
547 .client
548 .post("v1/chat/completions")?
549 .body(body)
550 .map_err(|e| CompletionError::HttpError(e.into()))?;
551
552 async move {
553 let response = self.client.send(request).await?;
554
555 if response.status().is_success() {
556 let text = http_client::text(response).await?;
557 match serde_json::from_str::<ApiResponse<CompletionResponse>>(&text)? {
558 ApiResponse::Ok(response) => {
559 let span = tracing::Span::current();
560 span.record_token_usage(&response);
561 span.record_model_output(&response.choices);
562 span.record_response_metadata(&response);
563 response.try_into()
564 }
565 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
566 }
567 } else {
568 let text = http_client::text(response).await?;
569 Err(CompletionError::ProviderError(text))
570 }
571 }
572 .instrument(span)
573 .await
574 }
575
576 #[cfg_attr(feature = "worker", worker::send)]
577 async fn stream(
578 &self,
579 request: CompletionRequest,
580 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
581 let resp = self.completion(request).await?;
582
583 let stream = stream! {
584 for c in resp.choice.clone() {
585 match c {
586 message::AssistantContent::Text(t) => {
587 yield Ok(RawStreamingChoice::Message(t.text.clone()))
588 }
589 message::AssistantContent::ToolCall(tc) => {
590 yield Ok(RawStreamingChoice::ToolCall {
591 id: tc.id.clone(),
592 name: tc.function.name.clone(),
593 arguments: tc.function.arguments.clone(),
594 call_id: None
595 })
596 }
597 message::AssistantContent::Reasoning(_) => {
598 unimplemented!("Reasoning is not supported on Mistral via Rig")
599 }
600 message::AssistantContent::Image(_) => {
601 unimplemented!("Image content is not supported on Mistral via Rig")
602 }
603 }
604 }
605
606 yield Ok(RawStreamingChoice::FinalResponse(resp.raw_response.clone()));
607 };
608
609 Ok(StreamingCompletionResponse::stream(Box::pin(stream)))
610 }
611}
612
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes a representative chat-completion payload and checks the decoded fields.
    #[test]
    fn test_response_deserialization() {
        let json_data = r#"
        {
            "id": "cmpl-e5cc70bb28c444948073e77776eb30ef",
            "object": "chat.completion",
            "model": "mistral-small-latest",
            "usage": {
                "prompt_tokens": 16,
                "completion_tokens": 34,
                "total_tokens": 50
            },
            "created": 1702256327,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "content": "string",
                        "tool_calls": [
                            {
                                "id": "null",
                                "type": "function",
                                "function": {
                                    "name": "string",
                                    "arguments": "{ }"
                                },
                                "index": 0
                            }
                        ],
                        "prefix": false,
                        "role": "assistant"
                    },
                    "finish_reason": "stop"
                }
            ]
        }
        "#;

        let parsed: CompletionResponse = serde_json::from_str(json_data).unwrap();

        // Top-level fields.
        assert_eq!(parsed.model, MISTRAL_SMALL);
        assert_eq!(parsed.id, "cmpl-e5cc70bb28c444948073e77776eb30ef");

        // Usage counters.
        let Usage {
            completion_tokens,
            prompt_tokens,
            total_tokens,
        } = parsed.usage.unwrap();
        assert_eq!(prompt_tokens, 16);
        assert_eq!(completion_tokens, 34);
        assert_eq!(total_tokens, 50);

        // Remaining metadata and the single choice.
        assert_eq!(parsed.object, "chat.completion".to_string());
        assert_eq!(parsed.created, 1702256327);
        assert_eq!(parsed.choices.len(), 1);
    }
}