Skip to main content

fx_mistral/chat/
chat_request.rs

1use serde::{Deserialize, Serialize};
2
3//
4// Chat Request structs
5//
6
7#[derive(Serialize, Deserialize, Debug)]
8pub struct ChatRequest {
9    pub model: String,
10    pub messages: Vec<Message>,
11    #[serde(skip_serializing_if = "Option::is_none")]
12    pub response_format: Option<ResponseFormat>,
13    #[serde(skip_serializing_if = "Option::is_none")]
14    pub max_tokens: Option<u32>,
15    #[serde(skip_serializing_if = "Option::is_none")]
16    pub temperature: Option<f32>,
17}
18
19#[derive(Serialize, Deserialize, Debug)]
20pub struct ResponseFormat {
21    #[serde(rename = "type")]
22    pub format_type: String, // should be "json_schema"
23    pub json_schema: JsonSchemaFormat,
24}
25
26#[derive(Serialize, Deserialize, Debug)]
27pub struct JsonSchemaFormat {
28    pub schema: serde_json::Value, // Accept raw JSON schema as a serde_json::Value
29    pub name: String,
30    pub strict: bool,
31}
32
33#[derive(Serialize, Deserialize, Debug)]
34#[serde(untagged)]
35pub enum Message {
36    Simple {
37        role: String,
38        content: String, // For system messages
39    },
40    WithContentArray {
41        role: String,
42        content: Vec<Content>, // For user messages with docs/texts
43    },
44}
45
46#[derive(Serialize, Deserialize, Debug)]
47#[serde(tag = "type")]
48pub enum Content {
49    #[serde(rename = "text")]
50    Text { text: String },
51    #[serde(rename = "document_url")]
52    DocumentUrl { document_url: String },
53}
54
55pub struct Messages {
56    pub messages: Vec<Message>,
57}
58
59//
60// Chat Request Builder
61//
62
63pub struct ChatRequestBuilder {
64    model: String,
65    messages: Vec<Message>,
66    response_format: Option<ResponseFormat>,
67    max_tokens: Option<u32>,
68    temperature: Option<f32>,
69}
70
71impl ChatRequestBuilder {
72    pub fn new<S: Into<String>>(model: S, system_message: S, temparatur: f32) -> Self {
73        ChatRequestBuilder {
74            model: model.into(),
75            messages: vec![Message::Simple {
76                role: "system".to_string(),
77                content: system_message.into(),
78            }],
79            response_format: None,
80            max_tokens: None,
81            temperature: Some(temparatur),
82        }
83    }
84
85    pub fn add_user_message<S: Into<String>>(mut self, text: S) -> Self {
86        self.messages.push(Message::WithContentArray {
87            role: "user".to_string(),
88            content: vec![Content::Text { text: text.into() }],
89        });
90        self
91    }
92
93    pub fn add_document_message<S: Into<String>>(mut self, text: S, document_url: S) -> Self {
94        self.messages.push(Message::WithContentArray {
95            role: "user".to_string(),
96            content: vec![
97                Content::Text { text: text.into() },
98                Content::DocumentUrl {
99                    document_url: document_url.into(),
100                },
101            ],
102        });
103        self
104    }
105
106    pub fn response_format_from_json<S: Into<String>>(mut self, schema_json: S, name: S, strict: bool) -> Self {
107        let schema_value: serde_json::Value = serde_json::from_str(&schema_json.into()).expect("Invalid JSON schema");
108        self.response_format = Some(ResponseFormat {
109            format_type: "json_schema".to_string(),
110            json_schema: JsonSchemaFormat {
111                schema: schema_value,
112                name: name.into(),
113                strict,
114            },
115        });
116        self
117    }
118
119    pub fn max_tokens(mut self, value: u32) -> Self {
120        self.max_tokens = Some(value);
121        self
122    }
123
124    pub fn temperature(mut self, value: f32) -> Self {
125        self.temperature = Some(value);
126        self
127    }
128
129    pub fn build(self) -> ChatRequest {
130        ChatRequest {
131            model: self.model,
132            messages: self.messages,
133            response_format: self.response_format,
134            max_tokens: self.max_tokens,
135            temperature: self.temperature,
136        }
137    }
138}