1use async_stream::stream;
2use serde::{Deserialize, Serialize};
3use serde_json::{Value, json};
4use std::{convert::Infallible, str::FromStr};
5
6use super::client::{Client, Usage};
7use crate::streaming::{RawStreamingChoice, StreamingCompletionResponse};
8use crate::{
9 OneOrMany,
10 completion::{self, CompletionError, CompletionRequest},
11 json_utils, message,
12 providers::mistral::client::ApiResponse,
13};
14
// Model name strings. `CompletionModel::new` takes the model name as a `&str`,
// so these are presumably meant to be passed there — verify against callers.

/// The `codestral-latest` model name.
pub const CODESTRAL: &str = "codestral-latest";
/// The `mistral-large-latest` model name.
pub const MISTRAL_LARGE: &str = "mistral-large-latest";
/// The `pixtral-large-latest` model name.
pub const PIXTRAL_LARGE: &str = "pixtral-large-latest";
/// The `mistral-saba-latest` model name.
pub const MISTRAL_SABA: &str = "mistral-saba-latest";
/// The `ministral-3b-latest` model name.
pub const MINISTRAL_3B: &str = "ministral-3b-latest";
/// The `ministral-8b-latest` model name.
pub const MINISTRAL_8B: &str = "ministral-8b-latest";

/// The `mistral-small-latest` model name.
pub const MISTRAL_SMALL: &str = "mistral-small-latest";
/// The `pixtral-12b-2409` model name.
pub const PIXTRAL_SMALL: &str = "pixtral-12b-2409";
/// The `open-mistral-nemo` model name.
pub const MISTRAL_NEMO: &str = "open-mistral-nemo";
/// The `open-codestral-mamba` model name.
pub const CODESTRAL_MAMBA: &str = "open-codestral-mamba";
27
/// Text-only assistant content.
///
/// NOTE(review): per the serde documentation, `tag` on a struct writes an extra
/// `type` field holding the struct's name, and `rename_all` on a struct renames
/// fields rather than that name — confirm this is the intended wire shape.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub struct AssistantContent {
    /// The text payload.
    text: String,
}
37
/// User content, internally tagged through a `type` field with lowercase
/// variant names (so `Text` is tagged as `"text"`).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum UserContent {
    /// Plain text content.
    Text { text: String },
}
43
/// One entry of [`CompletionResponse::choices`].
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Choice {
    /// The `index` value reported for this choice.
    pub index: usize,
    /// The message carried by this choice. The conversion into
    /// `completion::CompletionResponse` only accepts the `Assistant` variant.
    pub message: Message,
    /// Kept as untyped JSON; `None` when the field is absent.
    pub logprobs: Option<serde_json::Value>,
    /// The `finish_reason` string as received (the test fixture uses `"stop"`).
    pub finish_reason: String,
}
51
/// A chat message, internally tagged through a `role` field with lowercase
/// variant names (`"user"`, `"assistant"`, `"system"`).
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum Message {
    /// A message with the `user` role.
    User {
        content: String,
    },
    /// A message with the `assistant` role.
    Assistant {
        content: String,
        // Defaults to an empty vector when the field is missing and is left out
        // of the serialized form when empty. `json_utils::null_or_vec`
        // presumably also maps an explicit `null` to an empty vector — confirm
        // against the helper.
        #[serde(
            default,
            deserialize_with = "json_utils::null_or_vec",
            skip_serializing_if = "Vec::is_empty"
        )]
        tool_calls: Vec<ToolCall>,
        // Falls back to `false` when the field is missing.
        #[serde(default)]
        prefix: bool,
    },
    /// A message with the `system` role.
    System {
        content: String,
    },
}
73
74impl Message {
75 pub fn user(content: String) -> Self {
76 Message::User { content }
77 }
78
79 pub fn assistant(content: String, tool_calls: Vec<ToolCall>, prefix: bool) -> Self {
80 Message::Assistant {
81 content,
82 tool_calls,
83 prefix,
84 }
85 }
86
87 pub fn system(content: String) -> Self {
88 Message::System { content }
89 }
90}
91
92impl TryFrom<message::Message> for Vec<Message> {
93 type Error = message::MessageError;
94
95 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
96 match message {
97 message::Message::User { content } => {
98 let (_, other_content): (Vec<_>, Vec<_>) = content
99 .into_iter()
100 .partition(|content| matches!(content, message::UserContent::ToolResult(_)));
101
102 let messages = other_content
103 .into_iter()
104 .filter_map(|content| match content {
105 message::UserContent::Text(message::Text { text }) => {
106 Some(Message::User { content: text })
107 }
108 _ => None,
109 })
110 .collect::<Vec<_>>();
111
112 Ok(messages)
113 }
114 message::Message::Assistant { content, .. } => {
115 let (text_content, tool_calls) = content.into_iter().fold(
116 (Vec::new(), Vec::new()),
117 |(mut texts, mut tools), content| {
118 match content {
119 message::AssistantContent::Text(text) => texts.push(text),
120 message::AssistantContent::ToolCall(tool_call) => tools.push(tool_call),
121 }
122 (texts, tools)
123 },
124 );
125
126 Ok(vec![Message::Assistant {
127 content: text_content
128 .into_iter()
129 .next()
130 .map(|content| content.text)
131 .unwrap_or_default(),
132 tool_calls: tool_calls
133 .into_iter()
134 .map(|tool_call| tool_call.into())
135 .collect::<Vec<_>>(),
136 prefix: false,
137 }])
138 }
139 }
140 }
141}
142
/// A tool invocation attached to an assistant message.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolCall {
    /// Identifier of this tool call.
    pub id: String,
    // Falls back to `ToolType::Function` when the field is missing.
    #[serde(default)]
    pub r#type: ToolType,
    /// The function name and arguments of the call.
    pub function: Function,
}
150
151impl From<message::ToolCall> for ToolCall {
152 fn from(tool_call: message::ToolCall) -> Self {
153 Self {
154 id: tool_call.id,
155 r#type: ToolType::default(),
156 function: Function {
157 name: tool_call.function.name,
158 arguments: tool_call.function.arguments,
159 },
160 }
161 }
162}
163
/// The function part of a [`ToolCall`].
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct Function {
    /// Name of the function being called.
    pub name: String,
    // (De)serialized through `json_utils::stringified_json`; presumably the
    // arguments travel as a JSON document encoded inside a string (the test
    // fixture supplies `"{ }"`) — confirm against the helper.
    #[serde(with = "json_utils::stringified_json")]
    pub arguments: serde_json::Value,
}
170
/// Kind of a tool call; serialized in lowercase (`"function"`).
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    /// The only kind defined here; also the default value.
    #[default]
    Function,
}
177
/// A tool definition as placed under the request's `tools` key: a
/// `completion::ToolDefinition` paired with a `type` string.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolDefinition {
    /// Tool kind; the `From<completion::ToolDefinition>` impl sets it to `"function"`.
    pub r#type: String,
    /// The wrapped generic tool definition.
    pub function: completion::ToolDefinition,
}
183
184impl From<completion::ToolDefinition> for ToolDefinition {
185 fn from(tool: completion::ToolDefinition) -> Self {
186 Self {
187 r#type: "function".into(),
188 function: tool,
189 }
190 }
191}
192
/// Textual content of a tool result.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct ToolResultContent {
    // Falls back to `ToolResultContentType::Text` when the field is missing.
    #[serde(default)]
    r#type: ToolResultContentType,
    /// The text payload.
    text: String,
}
199
/// Kind of a [`ToolResultContent`]; serialized in lowercase (`"text"`).
#[derive(Default, Debug, Serialize, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum ToolResultContentType {
    /// The only kind defined here; also the default value.
    #[default]
    Text,
}
206
207impl From<String> for ToolResultContent {
208 fn from(s: String) -> Self {
209 ToolResultContent {
210 r#type: ToolResultContentType::default(),
211 text: s,
212 }
213 }
214}
215
216impl From<String> for UserContent {
217 fn from(s: String) -> Self {
218 UserContent::Text { text: s }
219 }
220}
221
222impl FromStr for UserContent {
223 type Err = Infallible;
224
225 fn from_str(s: &str) -> Result<Self, Self::Err> {
226 Ok(UserContent::Text {
227 text: s.to_string(),
228 })
229 }
230}
231
232impl From<String> for AssistantContent {
233 fn from(s: String) -> Self {
234 AssistantContent { text: s }
235 }
236}
237
238impl FromStr for AssistantContent {
239 type Err = Infallible;
240
241 fn from_str(s: &str) -> Result<Self, Self::Err> {
242 Ok(AssistantContent {
243 text: s.to_string(),
244 })
245 }
246}
247
/// A completion model: a client handle plus the name of the model to query.
#[derive(Clone)]
pub struct CompletionModel {
    /// Client used to send requests.
    pub(crate) client: Client,
    /// Model name placed under the request's `model` key.
    pub model: String,
}
253
254impl CompletionModel {
255 pub fn new(client: Client, model: &str) -> Self {
256 Self {
257 client,
258 model: model.to_string(),
259 }
260 }
261
262 pub(crate) fn create_completion_request(
263 &self,
264 completion_request: CompletionRequest,
265 ) -> Result<Value, CompletionError> {
266 let mut partial_history = vec![];
267 if let Some(docs) = completion_request.normalized_documents() {
268 partial_history.push(docs);
269 }
270
271 partial_history.extend(completion_request.chat_history);
272
273 let mut full_history: Vec<Message> = match &completion_request.preamble {
274 Some(preamble) => vec![Message::system(preamble.clone())],
275 None => vec![],
276 };
277
278 full_history.extend(
279 partial_history
280 .into_iter()
281 .map(message::Message::try_into)
282 .collect::<Result<Vec<Vec<Message>>, _>>()?
283 .into_iter()
284 .flatten()
285 .collect::<Vec<_>>(),
286 );
287
288 let request = if completion_request.tools.is_empty() {
289 json!({
290 "model": self.model,
291 "messages": full_history,
292
293 })
294 } else {
295 json!({
296 "model": self.model,
297 "messages": full_history,
298 "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
299 "tool_choice": "auto",
300 })
301 };
302
303 let request = if let Some(temperature) = completion_request.temperature {
304 json_utils::merge(
305 request,
306 json!({
307 "temperature": temperature,
308 }),
309 )
310 } else {
311 request
312 };
313
314 let request = if let Some(params) = completion_request.additional_params {
315 json_utils::merge(request, params)
316 } else {
317 request
318 };
319
320 Ok(request)
321 }
322}
323
/// Deserialized body of a successful chat completion response.
#[derive(Debug, Deserialize, Clone)]
pub struct CompletionResponse {
    /// Response identifier.
    pub id: String,
    /// The `object` string (the test fixture uses `"chat.completion"`).
    pub object: String,
    /// The `created` value as an unsigned integer.
    pub created: u64,
    /// Name of the model reported by the response.
    pub model: String,
    /// `None` when the field is absent.
    pub system_fingerprint: Option<String>,
    /// Returned choices; only the first one is used when converting into
    /// `completion::CompletionResponse`.
    pub choices: Vec<Choice>,
    /// Token usage, when reported.
    pub usage: Option<Usage>,
}
334
335impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
336 type Error = CompletionError;
337
338 fn try_from(response: CompletionResponse) -> Result<Self, Self::Error> {
339 let choice = response.choices.first().ok_or_else(|| {
340 CompletionError::ResponseError("Response contained no choices".to_owned())
341 })?;
342 let content = match &choice.message {
343 Message::Assistant {
344 content,
345 tool_calls,
346 ..
347 } => {
348 let mut content = if content.is_empty() {
349 vec![]
350 } else {
351 vec![completion::AssistantContent::text(content.clone())]
352 };
353
354 content.extend(
355 tool_calls
356 .iter()
357 .map(|call| {
358 completion::AssistantContent::tool_call(
359 &call.id,
360 &call.function.name,
361 call.function.arguments.clone(),
362 )
363 })
364 .collect::<Vec<_>>(),
365 );
366 Ok(content)
367 }
368 _ => Err(CompletionError::ResponseError(
369 "Response did not contain a valid message or tool call".into(),
370 )),
371 }?;
372
373 let choice = OneOrMany::many(content).map_err(|_| {
374 CompletionError::ResponseError(
375 "Response contained no message or tool call (empty)".to_owned(),
376 )
377 })?;
378
379 Ok(completion::CompletionResponse {
380 choice,
381 raw_response: response,
382 })
383 }
384}
385
386impl completion::CompletionModel for CompletionModel {
387 type Response = CompletionResponse;
388 type StreamingResponse = CompletionResponse;
389
390 #[cfg_attr(feature = "worker", worker::send)]
391 async fn completion(
392 &self,
393 completion_request: CompletionRequest,
394 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
395 let request = self.create_completion_request(completion_request)?;
396
397 let response = self
398 .client
399 .post("v1/chat/completions")
400 .json(&request)
401 .send()
402 .await?;
403
404 if response.status().is_success() {
405 let text = response.text().await?;
406 match serde_json::from_str::<ApiResponse<CompletionResponse>>(&text)? {
407 ApiResponse::Ok(response) => {
408 tracing::debug!(target: "rig",
409 "Mistral completion token usage: {:?}",
410 response.usage.clone().map(|usage| format!("{usage}")).unwrap_or("N/A".to_string())
411 );
412 response.try_into()
413 }
414 ApiResponse::Err(err) => Err(CompletionError::ProviderError(err.message)),
415 }
416 } else {
417 Err(CompletionError::ProviderError(response.text().await?))
418 }
419 }
420
421 #[cfg_attr(feature = "worker", worker::send)]
422 async fn stream(
423 &self,
424 request: CompletionRequest,
425 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
426 let resp = self.completion(request).await?;
427
428 let stream = Box::pin(stream! {
429 for c in resp.choice.clone() {
430 match c {
431 message::AssistantContent::Text(t) => {
432 yield Ok(RawStreamingChoice::Message(t.text.clone()))
433 }
434 message::AssistantContent::ToolCall(tc) => {
435 yield Ok(RawStreamingChoice::ToolCall {
436 id: tc.id.clone(),
437 name: tc.function.name.clone(),
438 arguments: tc.function.arguments.clone(),
439 call_id: None
440 })
441 }
442 }
443 }
444
445 yield Ok(RawStreamingChoice::FinalResponse(resp.raw_response.clone()));
446 });
447
448 Ok(StreamingCompletionResponse::stream(stream))
449 }
450}
451
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes a sample chat completion body and checks the top-level
    /// fields and the token usage.
    #[test]
    fn test_response_deserialization() {
        // The fixture has an extra `index` key on the tool call and no
        // `logprobs` / `system_fingerprint`, so it also exercises unknown-field
        // and missing-optional handling.
        let json_data = r#"
        {
            "id": "cmpl-e5cc70bb28c444948073e77776eb30ef",
            "object": "chat.completion",
            "model": "mistral-small-latest",
            "usage": {
                "prompt_tokens": 16,
                "completion_tokens": 34,
                "total_tokens": 50
            },
            "created": 1702256327,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "content": "string",
                        "tool_calls": [
                            {
                                "id": "null",
                                "type": "function",
                                "function": {
                                    "name": "string",
                                    "arguments": "{ }"
                                },
                                "index": 0
                            }
                        ],
                        "prefix": false,
                        "role": "assistant"
                    },
                    "finish_reason": "stop"
                }
            ]
        }
        "#;
        let completion_response = serde_json::from_str::<CompletionResponse>(json_data).unwrap();
        assert_eq!(completion_response.model, MISTRAL_SMALL);

        // Take the response apart so each field can be checked on its own.
        let CompletionResponse {
            id,
            object,
            created,
            choices,
            usage,
            ..
        } = completion_response;

        assert_eq!(id, "cmpl-e5cc70bb28c444948073e77776eb30ef");

        let Usage {
            completion_tokens,
            prompt_tokens,
            total_tokens,
        } = usage.unwrap();

        assert_eq!(prompt_tokens, 16);
        assert_eq!(completion_tokens, 34);
        assert_eq!(total_tokens, 50);
        assert_eq!(object, "chat.completion".to_string());
        assert_eq!(created, 1702256327);
        assert_eq!(choices.len(), 1);
    }
}