1use async_stream::stream;
2use serde::{Deserialize, Serialize};
3use serde_json::{Value, json};
4use std::{convert::Infallible, str::FromStr};
5use tracing::{Instrument, info_span};
6
7use super::client::{Client, Usage};
8use crate::completion::GetTokenUsage;
9use crate::http_client::{self, HttpClientExt};
10use crate::streaming::{RawStreamingChoice, StreamingCompletionResponse};
11use crate::{
12 OneOrMany,
13 completion::{self, CompletionError, CompletionRequest},
14 json_utils, message,
15 providers::mistral::client::ApiResponse,
16 telemetry::SpanCombinator,
17};
18
/// Model identifier `codestral-latest`.
pub const CODESTRAL: &str = "codestral-latest";
/// Model identifier `mistral-large-latest`.
pub const MISTRAL_LARGE: &str = "mistral-large-latest";
/// Model identifier `pixtral-large-latest`.
pub const PIXTRAL_LARGE: &str = "pixtral-large-latest";
/// Model identifier `mistral-saba-latest`.
pub const MISTRAL_SABA: &str = "mistral-saba-latest";
/// Model identifier `ministral-3b-latest`.
pub const MINISTRAL_3B: &str = "ministral-3b-latest";
/// Model identifier `ministral-8b-latest`.
pub const MINISTRAL_8B: &str = "ministral-8b-latest";

/// Model identifier `mistral-small-latest`.
pub const MISTRAL_SMALL: &str = "mistral-small-latest";
/// Model identifier `pixtral-12b-2409`.
pub const PIXTRAL_SMALL: &str = "pixtral-12b-2409";
/// Model identifier `open-mistral-nemo`.
pub const MISTRAL_NEMO: &str = "open-mistral-nemo";
/// Model identifier `open-codestral-mamba`.
pub const CODESTRAL_MAMBA: &str = "open-codestral-mamba";
31
32#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
37#[serde(tag = "type", rename_all = "lowercase")]
38pub struct AssistantContent {
39 text: String,
40}
41
42#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
43#[serde(tag = "type", rename_all = "lowercase")]
44pub enum UserContent {
45 Text { text: String },
46}
47
48#[derive(Debug, Serialize, Deserialize, Clone)]
49pub struct Choice {
50 pub index: usize,
51 pub message: Message,
52 pub logprobs: Option<serde_json::Value>,
53 pub finish_reason: String,
54}
55
56#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
57#[serde(tag = "role", rename_all = "lowercase")]
58pub enum Message {
59 User {
60 content: String,
61 },
62 Assistant {
63 content: String,
64 #[serde(
65 default,
66 deserialize_with = "json_utils::null_or_vec",
67 skip_serializing_if = "Vec::is_empty"
68 )]
69 tool_calls: Vec<ToolCall>,
70 #[serde(default)]
71 prefix: bool,
72 },
73 System {
74 content: String,
75 },
76}
77
78impl Message {
79 pub fn user(content: String) -> Self {
80 Message::User { content }
81 }
82
83 pub fn assistant(content: String, tool_calls: Vec<ToolCall>, prefix: bool) -> Self {
84 Message::Assistant {
85 content,
86 tool_calls,
87 prefix,
88 }
89 }
90
91 pub fn system(content: String) -> Self {
92 Message::System { content }
93 }
94}
95
96impl TryFrom<message::Message> for Vec<Message> {
97 type Error = message::MessageError;
98
99 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
100 match message {
101 message::Message::User { content } => {
102 let (_, other_content): (Vec<_>, Vec<_>) = content
103 .into_iter()
104 .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
105
106 let messages = other_content
107 .into_iter()
108 .filter_map(|content| match content {
109 message::UserContent::Text(message::Text { text }) => {
110 Some(Message::User { content: text })
111 }
112 _ => None,
113 })
114 .collect::<Vec<_>>();
115
116 Ok(messages)
117 }
118 message::Message::Assistant { content, .. } => {
119 let (text_content, tool_calls) = content.into_iter().fold(
120 (Vec::new(), Vec::new()),
121 |(mut texts, mut tools), content| {
122 match content {
123 message::AssistantContent::Text(text) => texts.push(text),
124 message::AssistantContent::ToolCall(tool_call) => tools.push(tool_call),
125 message::AssistantContent::Reasoning(_) => {
126 unimplemented!("Reasoning content is not currently supported on Mistral via Rig");
127 }
128 }
129 (texts, tools)
130 },
131 );
132
133 Ok(vec![Message::Assistant {
134 content: text_content
135 .into_iter()
136 .next()
137 .map(|content| content.text)
138 .unwrap_or_default(),
139 tool_calls: tool_calls
140 .into_iter()
141 .map(|tool_call| tool_call.into())
142 .collect::<Vec<_>>(),
143 prefix: false,
144 }])
145 }
146 }
147 }
148}
149
150#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
151pub struct ToolCall {
152 pub id: String,
153 #[serde(default)]
154 pub r#type: ToolType,
155 pub function: Function,
156}
157
158impl From<message::ToolCall> for ToolCall {
159 fn from(tool_call: message::ToolCall) -> Self {
160 Self {
161 id: tool_call.id,
162 r#type: ToolType::default(),
163 function: Function {
164 name: tool_call.function.name,
165 arguments: tool_call.function.arguments,
166 },
167 }
168 }
169}
170
171#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
172pub struct Function {
173 pub name: String,
174 #[serde(with = "json_utils::stringified_json")]
175 pub arguments: serde_json::Value,
176}
177
178#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
179#[serde(rename_all = "lowercase")]
180pub enum ToolType {
181 #[default]
182 Function,
183}
184
185#[derive(Debug, Deserialize, Serialize, Clone)]
186pub struct ToolDefinition {
187 pub r#type: String,
188 pub function: completion::ToolDefinition,
189}
190
191impl From<completion::ToolDefinition> for ToolDefinition {
192 fn from(tool: completion::ToolDefinition) -> Self {
193 Self {
194 r#type: "function".into(),
195 function: tool,
196 }
197 }
198}
199
200#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
201pub struct ToolResultContent {
202 #[serde(default)]
203 r#type: ToolResultContentType,
204 text: String,
205}
206
207#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
208#[serde(rename_all = "lowercase")]
209pub enum ToolResultContentType {
210 #[default]
211 Text,
212}
213
214impl From<String> for ToolResultContent {
215 fn from(s: String) -> Self {
216 ToolResultContent {
217 r#type: ToolResultContentType::default(),
218 text: s,
219 }
220 }
221}
222
223impl From<String> for UserContent {
224 fn from(s: String) -> Self {
225 UserContent::Text { text: s }
226 }
227}
228
229impl FromStr for UserContent {
230 type Err = Infallible;
231
232 fn from_str(s: &str) -> Result<Self, Self::Err> {
233 Ok(UserContent::Text {
234 text: s.to_string(),
235 })
236 }
237}
238
239impl From<String> for AssistantContent {
240 fn from(s: String) -> Self {
241 AssistantContent { text: s }
242 }
243}
244
245impl FromStr for AssistantContent {
246 type Err = Infallible;
247
248 fn from_str(s: &str) -> Result<Self, Self::Err> {
249 Ok(AssistantContent {
250 text: s.to_string(),
251 })
252 }
253}
254
255#[derive(Clone)]
256pub struct CompletionModel<T = reqwest::Client> {
257 pub(crate) client: Client<T>,
258 pub model: String,
259}
260
261#[derive(Debug, Default, Serialize, Deserialize)]
262pub enum ToolChoice {
263 #[default]
264 Auto,
265 None,
266 Any,
267}
268
269impl TryFrom<message::ToolChoice> for ToolChoice {
270 type Error = CompletionError;
271
272 fn try_from(value: message::ToolChoice) -> Result<Self, Self::Error> {
273 let res = match value {
274 message::ToolChoice::Auto => Self::Auto,
275 message::ToolChoice::None => Self::None,
276 message::ToolChoice::Required => Self::Any,
277 message::ToolChoice::Specific { .. } => {
278 return Err(CompletionError::ProviderError(
279 "Mistral doesn't support requiring specific tools to be called".to_string(),
280 ));
281 }
282 };
283
284 Ok(res)
285 }
286}
287
288impl<T> CompletionModel<T> {
289 pub fn new(client: Client<T>, model: &str) -> Self {
290 Self {
291 client,
292 model: model.to_string(),
293 }
294 }
295
296 pub(crate) fn create_completion_request(
297 &self,
298 completion_request: CompletionRequest,
299 ) -> Result<Value, CompletionError> {
300 let mut partial_history = vec![];
301 if let Some(docs) = completion_request.normalized_documents() {
302 partial_history.push(docs);
303 }
304
305 partial_history.extend(completion_request.chat_history);
306
307 let mut full_history: Vec<Message> = match &completion_request.preamble {
308 Some(preamble) => vec![Message::system(preamble.clone())],
309 None => vec![],
310 };
311
312 full_history.extend(
313 partial_history
314 .into_iter()
315 .map(message::Message::try_into)
316 .collect::<Result<Vec<Vec<Message>>, _>>()?
317 .into_iter()
318 .flatten()
319 .collect::<Vec<_>>(),
320 );
321
322 let tool_choice = completion_request
323 .tool_choice
324 .map(ToolChoice::try_from)
325 .transpose()?;
326
327 let request = if completion_request.tools.is_empty() {
328 json!({
329 "model": self.model,
330 "messages": full_history,
331
332 })
333 } else {
334 json!({
335 "model": self.model,
336 "messages": full_history,
337 "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
338 "tool_choice": tool_choice,
339 })
340 };
341
342 let request = if let Some(temperature) = completion_request.temperature {
343 json_utils::merge(
344 request,
345 json!({
346 "temperature": temperature,
347 }),
348 )
349 } else {
350 request
351 };
352
353 let request = if let Some(params) = completion_request.additional_params {
354 json_utils::merge(request, params)
355 } else {
356 request
357 };
358
359 Ok(request)
360 }
361}
362
363#[derive(Debug, Deserialize, Clone, Serialize)]
364pub struct CompletionResponse {
365 pub id: String,
366 pub object: String,
367 pub created: u64,
368 pub model: String,
369 pub system_fingerprint: Option<String>,
370 pub choices: Vec<Choice>,
371 pub usage: Option<Usage>,
372}
373
374impl crate::telemetry::ProviderResponseExt for CompletionResponse {
375 type OutputMessage = Choice;
376 type Usage = Usage;
377
378 fn get_response_id(&self) -> Option<String> {
379 Some(self.id.clone())
380 }
381
382 fn get_response_model_name(&self) -> Option<String> {
383 Some(self.model.clone())
384 }
385
386 fn get_output_messages(&self) -> Vec<Self::OutputMessage> {
387 self.choices.clone()
388 }
389
390 fn get_text_response(&self) -> Option<String> {
391 let res = self
392 .choices
393 .iter()
394 .filter_map(|choice| match choice.message {
395 Message::Assistant { ref content, .. } => {
396 if content.is_empty() {
397 None
398 } else {
399 Some(content.to_string())
400 }
401 }
402 _ => None,
403 })
404 .collect::<Vec<String>>()
405 .join("\n");
406
407 if res.is_empty() { None } else { Some(res) }
408 }
409
410 fn get_usage(&self) -> Option<Self::Usage> {
411 self.usage.clone()
412 }
413}
414
415impl GetTokenUsage for CompletionResponse {
416 fn token_usage(&self) -> Option<crate::completion::Usage> {
417 let api_usage = self.usage.clone()?;
418
419 let mut usage = crate::completion::Usage::new();
420 usage.input_tokens = api_usage.prompt_tokens as u64;
421 usage.output_tokens = api_usage.completion_tokens as u64;
422 usage.total_tokens = api_usage.total_tokens as u64;
423
424 Some(usage)
425 }
426}
427
428impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
429 type Error = CompletionError;
430
431 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
432 let choice = response.choices.first().ok_or_else(|| {
433 CompletionError::ResponseError("Response contained no choices".to_owned())
434 })?;
435 let content = match &choice.message {
436 Message::Assistant {
437 content,
438 tool_calls,
439 ..
440 } => {
441 let mut content = if content.is_empty() {
442 vec![]
443 } else {
444 vec![completion::AssistantContent::text(content.clone())]
445 };
446
447 content.extend(
448 tool_calls
449 .iter()
450 .map(|call| {
451 completion::AssistantContent::tool_call(
452 &call.id,
453 &call.function.name,
454 call.function.arguments.clone(),
455 )
456 })
457 .collect::<Vec<_>>(),
458 );
459 Ok(content)
460 }
461 _ => Err(CompletionError::ResponseError(
462 "Response did not contain a valid message or tool call".into(),
463 )),
464 }?;
465
466 let choice = OneOrMany::many(content).map_err(|_| {
467 CompletionError::ResponseError(
468 "Response contained no message or tool call (empty)".to_owned(),
469 )
470 })?;
471
472 let usage = response
473 .usage
474 .as_ref()
475 .map(|usage| completion::Usage {
476 input_tokens: usage.prompt_tokens as u64,
477 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
478 total_tokens: usage.total_tokens as u64,
479 })
480 .unwrap_or_default();
481
482 Ok(completion::CompletionResponse {
483 choice,
484 usage,
485 raw_response: response,
486 })
487 }
488}
489
490impl<T> completion::CompletionModel for CompletionModel<T>
491where
492 T: HttpClientExt + Send + Clone + std::fmt::Debug + 'static,
493{
494 type Response = CompletionResponse;
495 type StreamingResponse = CompletionResponse;
496
497 #[cfg_attr(feature = "worker", worker::send)]
498 async fn completion(
499 &self,
500 completion_request: CompletionRequest,
501 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
502 let preamble = completion_request.preamble.clone();
503 let request = self.create_completion_request(completion_request)?;
504 let body = serde_json::to_vec(&request)?;
505
506 let span = if tracing::Span::current().is_disabled() {
507 info_span!(
508 target: "rig::completions",
509 "chat",
510 gen_ai.operation.name = "chat",
511 gen_ai.provider.name = "mistral",
512 gen_ai.request.model = self.model,
513 gen_ai.system_instructions = &preamble,
514 gen_ai.response.id = tracing::field::Empty,
515 gen_ai.response.model = tracing::field::Empty,
516 gen_ai.usage.output_tokens = tracing::field::Empty,
517 gen_ai.usage.input_tokens = tracing::field::Empty,
518 gen_ai.input.messages = serde_json::to_string(request.get("messages").unwrap()).unwrap(),
519 gen_ai.output.messages = tracing::field::Empty,
520 )
521 } else {
522 tracing::Span::current()
523 };
524
525 let request = self
526 .client
527 .post("v1/chat/completions")?
528 .header("Content-Type", "application/json")
529 .body(body)
530 .map_err(|e| CompletionError::HttpError(e.into()))?;
531
532 async move {
533 let response = self.client.send(request).await?;
534
535 if response.status().is_success() {
536 let text = http_client::text(response).await?;
537 match serde_json::from_str::<ApiResponse<CompletionResponse>>(&text)? {
538 ApiResponse::Ok(response) => {
539 let span = tracing::Span::current();
540 span.record_token_usage(&response);
541 span.record_model_output(&response.choices);
542 span.record_response_metadata(&response);
543 response.try_into()
544 }
545 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
546 }
547 } else {
548 let text = http_client::text(response).await?;
549 Err(CompletionError::ProviderError(text))
550 }
551 }
552 .instrument(span)
553 .await
554 }
555
556 #[cfg_attr(feature = "worker", worker::send)]
557 async fn stream(
558 &self,
559 request: CompletionRequest,
560 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
561 let resp = self.completion(request).await?;
562
563 let stream = stream! {
564 for c in resp.choice.clone() {
565 match c {
566 message::AssistantContent::Text(t) => {
567 yield Ok(RawStreamingChoice::Message(t.text.clone()))
568 }
569 message::AssistantContent::ToolCall(tc) => {
570 yield Ok(RawStreamingChoice::ToolCall {
571 id: tc.id.clone(),
572 name: tc.function.name.clone(),
573 arguments: tc.function.arguments.clone(),
574 call_id: None
575 })
576 }
577 message::AssistantContent::Reasoning(_) => {
578 unimplemented!("Reasoning is not supported on Mistral via Rig")
579 }
580 }
581 }
582
583 yield Ok(RawStreamingChoice::FinalResponse(resp.raw_response.clone()));
584 };
585
586 Ok(StreamingCompletionResponse::stream(Box::pin(stream)))
587 }
588}
589
#[cfg(test)]
mod tests {
    use super::*;

    /// A sample chat completion body deserializes into [`CompletionResponse`]
    /// with the expected top-level fields and usage counts.
    #[test]
    fn test_response_deserialization() {
        let json_data = r#"
        {
            "id": "cmpl-e5cc70bb28c444948073e77776eb30ef",
            "object": "chat.completion",
            "model": "mistral-small-latest",
            "usage": {
                "prompt_tokens": 16,
                "completion_tokens": 34,
                "total_tokens": 50
            },
            "created": 1702256327,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "content": "string",
                        "tool_calls": [
                            {
                                "id": "null",
                                "type": "function",
                                "function": {
                                    "name": "string",
                                    "arguments": "{ }"
                                },
                                "index": 0
                            }
                        ],
                        "prefix": false,
                        "role": "assistant"
                    },
                    "finish_reason": "stop"
                }
            ]
        }
        "#;

        let parsed: CompletionResponse = serde_json::from_str(json_data).unwrap();
        assert_eq!(parsed.model, MISTRAL_SMALL);
        assert_eq!(parsed.id, "cmpl-e5cc70bb28c444948073e77776eb30ef");

        // Destructure exhaustively so the test stops compiling if `Usage`
        // gains or loses a field.
        let Usage {
            completion_tokens,
            prompt_tokens,
            total_tokens,
        } = parsed.usage.unwrap();
        assert_eq!(prompt_tokens, 16);
        assert_eq!(completion_tokens, 34);
        assert_eq!(total_tokens, 50);

        assert_eq!(parsed.object, "chat.completion".to_string());
        assert_eq!(parsed.created, 1702256327);
        assert_eq!(parsed.choices.len(), 1);
    }
}