1use async_stream::stream;
2use serde::{Deserialize, Serialize};
3use serde_json::{Value, json};
4use std::{convert::Infallible, str::FromStr};
5
6use super::client::{Client, Usage};
7use crate::streaming::{RawStreamingChoice, StreamingCompletionResponse};
8use crate::{
9 OneOrMany,
10 completion::{self, CompletionError, CompletionRequest},
11 json_utils, message,
12 providers::mistral::client::ApiResponse,
13};
14
/// Model identifier `codestral-latest`.
pub const CODESTRAL: &str = "codestral-latest";
/// Model identifier `mistral-large-latest`.
pub const MISTRAL_LARGE: &str = "mistral-large-latest";
/// Model identifier `pixtral-large-latest`.
pub const PIXTRAL_LARGE: &str = "pixtral-large-latest";
/// Model identifier `mistral-saba-latest`.
pub const MISTRAL_SABA: &str = "mistral-saba-latest";
/// Model identifier `ministral-3b-latest`.
pub const MINISTRAL_3B: &str = "ministral-3b-latest";
/// Model identifier `ministral-8b-latest`.
pub const MINISTRAL_8B: &str = "ministral-8b-latest";

/// Model identifier `mistral-small-latest`.
pub const MISTRAL_SMALL: &str = "mistral-small-latest";
/// Model identifier `pixtral-12b-2409`.
pub const PIXTRAL_SMALL: &str = "pixtral-12b-2409";
/// Model identifier `open-mistral-nemo`.
pub const MISTRAL_NEMO: &str = "open-mistral-nemo";
/// Model identifier `open-codestral-mamba`.
pub const CODESTRAL_MAMBA: &str = "open-codestral-mamba";
27
28#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
33#[serde(tag = "type", rename_all = "lowercase")]
34pub struct AssistantContent {
35 text: String,
36}
37
38#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
39#[serde(tag = "type", rename_all = "lowercase")]
40pub enum UserContent {
41 Text { text: String },
42}
43
44#[derive(Debug, Serialize, Deserialize, Clone)]
45pub struct Choice {
46 pub index: usize,
47 pub message: Message,
48 pub logprobs: Option<serde_json::Value>,
49 pub finish_reason: String,
50}
51
52#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
53#[serde(tag = "role", rename_all = "lowercase")]
54pub enum Message {
55 User {
56 content: String,
57 },
58 Assistant {
59 content: String,
60 #[serde(
61 default,
62 deserialize_with = "json_utils::null_or_vec",
63 skip_serializing_if = "Vec::is_empty"
64 )]
65 tool_calls: Vec<ToolCall>,
66 #[serde(default)]
67 prefix: bool,
68 },
69 System {
70 content: String,
71 },
72}
73
74impl Message {
75 pub fn user(content: String) -> Self {
76 Message::User { content }
77 }
78
79 pub fn assistant(content: String, tool_calls: Vec<ToolCall>, prefix: bool) -> Self {
80 Message::Assistant {
81 content,
82 tool_calls,
83 prefix,
84 }
85 }
86
87 pub fn system(content: String) -> Self {
88 Message::System { content }
89 }
90}
91
92impl TryFrom<message::Message> for Vec<Message> {
93 type Error = message::MessageError;
94
95 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
96 match message {
97 message::Message::User { content } => {
98 let (_, other_content): (Vec<_>, Vec<_>) = content
99 .into_iter()
100 .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
101
102 let messages = other_content
103 .into_iter()
104 .filter_map(|content| match content {
105 message::UserContent::Text(message::Text { text }) => {
106 Some(Message::User { content: text })
107 }
108 _ => None,
109 })
110 .collect::<Vec<_>>();
111
112 Ok(messages)
113 }
114 message::Message::Assistant { content, .. } => {
115 let (text_content, tool_calls) = content.into_iter().fold(
116 (Vec::new(), Vec::new()),
117 |(mut texts, mut tools), content| {
118 match content {
119 message::AssistantContent::Text(text) => texts.push(text),
120 message::AssistantContent::ToolCall(tool_call) => tools.push(tool_call),
121 }
122 (texts, tools)
123 },
124 );
125
126 Ok(vec![Message::Assistant {
127 content: text_content
128 .into_iter()
129 .next()
130 .map(|content| content.text)
131 .unwrap_or_default(),
132 tool_calls: tool_calls
133 .into_iter()
134 .map(|tool_call| tool_call.into())
135 .collect::<Vec<_>>(),
136 prefix: false,
137 }])
138 }
139 }
140 }
141}
142
143#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
144pub struct ToolCall {
145 pub id: String,
146 #[serde(default)]
147 pub r#type: ToolType,
148 pub function: Function,
149}
150
151impl From<message::ToolCall> for ToolCall {
152 fn from(tool_call: message::ToolCall) -> Self {
153 Self {
154 id: tool_call.id,
155 r#type: ToolType::default(),
156 function: Function {
157 name: tool_call.function.name,
158 arguments: tool_call.function.arguments,
159 },
160 }
161 }
162}
163
164#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
165pub struct Function {
166 pub name: String,
167 #[serde(with = "json_utils::stringified_json")]
168 pub arguments: serde_json::Value,
169}
170
171#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
172#[serde(rename_all = "lowercase")]
173pub enum ToolType {
174 #[default]
175 Function,
176}
177
178#[derive(Debug, Deserialize, Serialize, Clone)]
179pub struct ToolDefinition {
180 pub r#type: String,
181 pub function: completion::ToolDefinition,
182}
183
184impl From<completion::ToolDefinition> for ToolDefinition {
185 fn from(tool: completion::ToolDefinition) -> Self {
186 Self {
187 r#type: "function".into(),
188 function: tool,
189 }
190 }
191}
192
193#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
194pub struct ToolResultContent {
195 #[serde(default)]
196 r#type: ToolResultContentType,
197 text: String,
198}
199
200#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
201#[serde(rename_all = "lowercase")]
202pub enum ToolResultContentType {
203 #[default]
204 Text,
205}
206
207impl From<String> for ToolResultContent {
208 fn from(s: String) -> Self {
209 ToolResultContent {
210 r#type: ToolResultContentType::default(),
211 text: s,
212 }
213 }
214}
215
216impl From<String> for UserContent {
217 fn from(s: String) -> Self {
218 UserContent::Text { text: s }
219 }
220}
221
222impl FromStr for UserContent {
223 type Err = Infallible;
224
225 fn from_str(s: &str) -> Result<Self, Self::Err> {
226 Ok(UserContent::Text {
227 text: s.to_string(),
228 })
229 }
230}
231
232impl From<String> for AssistantContent {
233 fn from(s: String) -> Self {
234 AssistantContent { text: s }
235 }
236}
237
238impl FromStr for AssistantContent {
239 type Err = Infallible;
240
241 fn from_str(s: &str) -> Result<Self, Self::Err> {
242 Ok(AssistantContent {
243 text: s.to_string(),
244 })
245 }
246}
247
248#[derive(Clone)]
249pub struct CompletionModel {
250 pub(crate) client: Client,
251 pub model: String,
252}
253
254impl CompletionModel {
255 pub fn new(client: Client, model: &str) -> Self {
256 Self {
257 client,
258 model: model.to_string(),
259 }
260 }
261
262 pub(crate) fn create_completion_request(
263 &self,
264 completion_request: CompletionRequest,
265 ) -> Result<Value, CompletionError> {
266 let mut partial_history = vec![];
267 if let Some(docs) = completion_request.normalized_documents() {
268 partial_history.push(docs);
269 }
270
271 partial_history.extend(completion_request.chat_history);
272
273 let mut full_history: Vec<Message> = match &completion_request.preamble {
274 Some(preamble) => vec![Message::system(preamble.clone())],
275 None => vec![],
276 };
277
278 full_history.extend(
279 partial_history
280 .into_iter()
281 .map(message::Message::try_into)
282 .collect::<Result<Vec<Vec<Message>>, _>>()?
283 .into_iter()
284 .flatten()
285 .collect::<Vec<_>>(),
286 );
287
288 let request = if completion_request.tools.is_empty() {
289 json!({
290 "model": self.model,
291 "messages": full_history,
292
293 })
294 } else {
295 json!({
296 "model": self.model,
297 "messages": full_history,
298 "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
299 "tool_choice": "auto",
300 })
301 };
302
303 let request = if let Some(temperature) = completion_request.temperature {
304 json_utils::merge(
305 request,
306 json!({
307 "temperature": temperature,
308 }),
309 )
310 } else {
311 request
312 };
313
314 let request = if let Some(params) = completion_request.additional_params {
315 json_utils::merge(request, params)
316 } else {
317 request
318 };
319
320 Ok(request)
321 }
322}
323
324#[derive(Debug, Deserialize, Clone)]
325pub struct CompletionResponse {
326 pub id: String,
327 pub object: String,
328 pub created: u64,
329 pub model: String,
330 pub system_fingerprint: Option<String>,
331 pub choices: Vec<Choice>,
332 pub usage: Option<Usage>,
333}
334
335impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
336 type Error = CompletionError;
337
338 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
339 let choice = response.choices.first().ok_or_else(|| {
340 CompletionError::ResponseError("Response contained no choices".to_owned())
341 })?;
342 let content = match &choice.message {
343 Message::Assistant {
344 content,
345 tool_calls,
346 ..
347 } => {
348 let mut content = if content.is_empty() {
349 vec![]
350 } else {
351 vec![completion::AssistantContent::text(content.clone())]
352 };
353
354 content.extend(
355 tool_calls
356 .iter()
357 .map(|call| {
358 completion::AssistantContent::tool_call(
359 &call.id,
360 &call.function.name,
361 call.function.arguments.clone(),
362 )
363 })
364 .collect::<Vec<_>>(),
365 );
366 Ok(content)
367 }
368 _ => Err(CompletionError::ResponseError(
369 "Response did not contain a valid message or tool call".into(),
370 )),
371 }?;
372
373 let choice = OneOrMany::many(content).map_err(|_| {
374 CompletionError::ResponseError(
375 "Response contained no message or tool call (empty)".to_owned(),
376 )
377 })?;
378
379 let usage = response
380 .usage
381 .as_ref()
382 .map(|usage| completion::Usage {
383 input_tokens: usage.prompt_tokens as u64,
384 output_tokens: (usage.total_tokens - usage.prompt_tokens) as u64,
385 total_tokens: usage.total_tokens as u64,
386 })
387 .unwrap_or_default();
388
389 Ok(completion::CompletionResponse {
390 choice,
391 usage,
392 raw_response: response,
393 })
394 }
395}
396
397impl completion::CompletionModel for CompletionModel {
398 type Response = CompletionResponse;
399 type StreamingResponse = CompletionResponse;
400
401 #[cfg_attr(feature = "worker", worker::send)]
402 async fn completion(
403 &self,
404 completion_request: CompletionRequest,
405 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
406 let request = self.create_completion_request(completion_request)?;
407
408 let response = self
409 .client
410 .post("v1/chat/completions")
411 .json(&request)
412 .send()
413 .await?;
414
415 if response.status().is_success() {
416 let text = response.text().await?;
417 match serde_json::from_str::<ApiResponse<CompletionResponse>>(&text)? {
418 ApiResponse::Ok(response) => {
419 tracing::debug!(target: "rig",
420 "Mistral completion token usage: {:?}",
421 response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
422 );
423 response.try_into()
424 }
425 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
426 }
427 } else {
428 Err(CompletionError::ProviderError(response.text().await?))
429 }
430 }
431
432 #[cfg_attr(feature = "worker", worker::send)]
433 async fn stream(
434 &self,
435 request: CompletionRequest,
436 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
437 let resp = self.completion(request).await?;
438
439 let stream = Box::pin(stream! {
440 for c in resp.choice.clone() {
441 match c {
442 message::AssistantContent::Text(t) => {
443 yield Ok(RawStreamingChoice::Message(t.text.clone()))
444 }
445 message::AssistantContent::ToolCall(tc) => {
446 yield Ok(RawStreamingChoice::ToolCall {
447 id: tc.id.clone(),
448 name: tc.function.name.clone(),
449 arguments: tc.function.arguments.clone(),
450 call_id: None
451 })
452 }
453 }
454 }
455
456 yield Ok(RawStreamingChoice::FinalResponse(resp.raw_response.clone()));
457 });
458
459 Ok(StreamingCompletionResponse::stream(stream))
460 }
461}
462
#[cfg(test)]
mod tests {
    use super::*;

    /// A sample chat-completion payload deserializes into
    /// [`CompletionResponse`] with the expected top-level fields and usage.
    #[test]
    fn test_response_deserialization() {
        let json_data = r#"
        {
            "id": "cmpl-e5cc70bb28c444948073e77776eb30ef",
            "object": "chat.completion",
            "model": "mistral-small-latest",
            "usage": {
                "prompt_tokens": 16,
                "completion_tokens": 34,
                "total_tokens": 50
            },
            "created": 1702256327,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "content": "string",
                        "tool_calls": [
                            {
                                "id": "null",
                                "type": "function",
                                "function": {
                                    "name": "string",
                                    "arguments": "{ }"
                                },
                                "index": 0
                            }
                        ],
                        "prefix": false,
                        "role": "assistant"
                    },
                    "finish_reason": "stop"
                }
            ]
        }
        "#;

        let parsed: CompletionResponse = serde_json::from_str(json_data).unwrap();

        assert_eq!(parsed.model, MISTRAL_SMALL);
        assert_eq!(parsed.id, "cmpl-e5cc70bb28c444948073e77776eb30ef");

        let token_usage = parsed.usage.unwrap();
        assert_eq!(token_usage.prompt_tokens, 16);
        assert_eq!(token_usage.completion_tokens, 34);
        assert_eq!(token_usage.total_tokens, 50);

        assert_eq!(parsed.object, "chat.completion".to_string());
        assert_eq!(parsed.created, 1702256327);
        assert_eq!(parsed.choices.len(), 1);
    }
}