Skip to main content

fx_mistral/chat/
mod.rs

1pub mod chat_request;
2
3use crate::chat::chat_request::ChatRequest;
4use crate::{MistralApiError, MistralClient, MistralError};
5use serde::de::DeserializeOwned;
6use serde::{Deserialize, Serialize};
7use tracing::{info, trace};
8
9//
10// Chat Response structs.
11//
12
13#[derive(Serialize, Deserialize, Debug)]
14pub struct ChatResponse {
15    id: String,
16    object: String,
17    created: u64,
18    model: String,
19    pub choices: Vec<Choice>,
20    usage: Usage,
21}
22
23#[derive(Serialize, Deserialize, Debug)]
24pub struct Choice {
25    index: u32,
26    pub message: MessageContent,
27    finish_reason: String,
28}
29
30#[derive(Serialize, Deserialize, Debug)]
31pub struct MessageContent {
32    role: String,
33    tool_calls: Option<serde_json::Value>,
34    pub content: String,
35}
36
37#[derive(Serialize, Deserialize, Debug)]
38pub struct Usage {
39    prompt_tokens: u32,
40    total_tokens: u32,
41    completion_tokens: u32,
42}
43
44
45
46
47pub struct ChatClient<'a> {
48    mistral_client: &'a MistralClient,
49    chat_path: String,
50    model: String,
51    temperature: f32,
52}
53
54impl<'a> ChatClient<'a> {
55    pub fn new(mistral_client: &'a MistralClient, model: &str, temperature: f32) -> Self {
56        ChatClient {
57            mistral_client,
58            chat_path: "chat/completions".to_string(),
59            model: model.to_string(),
60            temperature,
61        }
62    }
63
64    pub fn request_builder<S: Into<String>>(&self, system_prompt: S) -> chat_request::ChatRequestBuilder {
65        chat_request::ChatRequestBuilder::new(self.model.clone(), system_prompt.into(), self.temperature)
66            .temperature(self.temperature)
67    }
68
69    pub async fn chat_complete(&self, request: &ChatRequest) -> Result<ChatResponse, MistralError> {
70        info!("Chat request to {:?}", request.model);
71        trace!("Request: {}", serde_json::to_string_pretty(request).unwrap_or("Can't serialize request".to_string()));
72
73        let response = self
74            .mistral_client
75            .client
76            .post(&format!("{}/{}", self.mistral_client.base_url, self.chat_path))
77            .bearer_auth(&self.mistral_client.api_key)
78            .json(request)
79            .send()
80            .await
81            .map_err(MistralError::Network)?;
82
83        let status = response.status();
84        let text = response.text().await.map_err(MistralError::Network)?;
85        trace!("Response: {}", text);
86
87        if !status.is_success() {
88            // Try to parse API error JSON
89            let api_error: Result<MistralApiError, _> = serde_json::from_str(&text);
90            return match api_error {
91                Ok(err) => Err(MistralError::Api(err)),
92                Err(_) => Err(MistralError::Http(status)),
93            };
94        }
95
96        // Try to parse success response
97        serde_json::from_str(&text).map_err(MistralError::Parse)
98    }
99
100    pub async fn chat_complete_struct<T> (
101        &self,
102        request: &ChatRequest,
103    ) -> Result<T, MistralError>
104    where
105        T: DeserializeOwned,
106    {
107        let response = self.chat_complete(request).await?;
108        let content = response.choices[0].message.content.clone();
109        extract_struct_from_chat_response::<T>(&content)
110    }
111
112}
113
114
115pub fn extract_struct_from_chat_response<T>(content: &str) -> Result<T, MistralError>
116where
117    T: DeserializeOwned,
118{
119    let json_str = content
120        .trim()
121        .trim_start_matches("```json")
122        .trim_end_matches("```")
123        .trim();
124
125    serde_json::from_str(json_str).map_err(MistralError::Parse)
126}