1pub mod chat_request;
2
3use crate::chat::chat_request::ChatRequest;
4use crate::{MistralApiError, MistralClient, MistralError};
5use serde::de::DeserializeOwned;
6use serde::{Deserialize, Serialize};
7use tracing::{info, trace};
8
9#[derive(Serialize, Deserialize, Debug)]
14pub struct ChatResponse {
15 id: String,
16 object: String,
17 created: u64,
18 model: String,
19 pub choices: Vec<Choice>,
20 usage: Usage,
21}
22
23#[derive(Serialize, Deserialize, Debug)]
24pub struct Choice {
25 index: u32,
26 pub message: MessageContent,
27 finish_reason: String,
28}
29
30#[derive(Serialize, Deserialize, Debug)]
31pub struct MessageContent {
32 role: String,
33 tool_calls: Option<serde_json::Value>,
34 pub content: String,
35}
36
37#[derive(Serialize, Deserialize, Debug)]
38pub struct Usage {
39 prompt_tokens: u32,
40 total_tokens: u32,
41 completion_tokens: u32,
42}
43
44
45
46
47pub struct ChatClient<'a> {
48 mistral_client: &'a MistralClient,
49 chat_path: String,
50 model: String,
51 temperature: f32,
52}
53
54impl<'a> ChatClient<'a> {
55 pub fn new(mistral_client: &'a MistralClient, model: &str, temperature: f32) -> Self {
56 ChatClient {
57 mistral_client,
58 chat_path: "chat/completions".to_string(),
59 model: model.to_string(),
60 temperature,
61 }
62 }
63
64 pub fn request_builder<S: Into<String>>(&self, system_prompt: S) -> chat_request::ChatRequestBuilder {
65 chat_request::ChatRequestBuilder::new(self.model.clone(), system_prompt.into(), self.temperature)
66 .temperature(self.temperature)
67 }
68
69 pub async fn chat_complete(&self, request: &ChatRequest) -> Result<ChatResponse, MistralError> {
70 info!("Chat request to {:?}", request.model);
71 trace!("Request: {}", serde_json::to_string_pretty(request).unwrap_or("Can't serialize request".to_string()));
72
73 let response = self
74 .mistral_client
75 .client
76 .post(&format!("{}/{}", self.mistral_client.base_url, self.chat_path))
77 .bearer_auth(&self.mistral_client.api_key)
78 .json(request)
79 .send()
80 .await
81 .map_err(MistralError::Network)?;
82
83 let status = response.status();
84 let text = response.text().await.map_err(MistralError::Network)?;
85 trace!("Response: {}", text);
86
87 if !status.is_success() {
88 let api_error: Result<MistralApiError, _> = serde_json::from_str(&text);
90 return match api_error {
91 Ok(err) => Err(MistralError::Api(err)),
92 Err(_) => Err(MistralError::Http(status)),
93 };
94 }
95
96 serde_json::from_str(&text).map_err(MistralError::Parse)
98 }
99
100 pub async fn chat_complete_struct<T> (
101 &self,
102 request: &ChatRequest,
103 ) -> Result<T, MistralError>
104 where
105 T: DeserializeOwned,
106 {
107 let response = self.chat_complete(request).await?;
108 let content = response.choices[0].message.content.clone();
109 extract_struct_from_chat_response::<T>(&content)
110 }
111
112}
113
114
115pub fn extract_struct_from_chat_response<T>(content: &str) -> Result<T, MistralError>
116where
117 T: DeserializeOwned,
118{
119 let json_str = content
120 .trim()
121 .trim_start_matches("```json")
122 .trim_end_matches("```")
123 .trim();
124
125 serde_json::from_str(json_str).map_err(MistralError::Parse)
126}