1use super::{common, types};
2use crate::impl_builder_methods;
3
4use serde::de::{self, MapAccess, SeqAccess, Visitor};
5use serde::ser::SerializeMap;
6use serde::{Deserialize, Deserializer, Serialize, Serializer};
7use serde_json::Value;
8use std::collections::HashMap;
9use std::fmt;
10#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
11pub enum ToolChoiceType {
12 None,
13 Auto,
14 Required,
15 ToolChoice { tool: Tool },
16}
17
18#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
19#[serde(rename_all = "lowercase")]
20pub enum ReasoningEffort {
21 Low,
22 Medium,
23 High,
24}
25
26#[derive(Debug, Serialize, Deserialize, Clone)]
27#[serde(untagged)]
28pub enum ReasoningMode {
29 Effort { effort: ReasoningEffort },
30 MaxTokens { max_tokens: i64 },
31}
32
33#[derive(Debug, Serialize, Deserialize, Clone)]
34pub struct Reasoning {
35 #[serde(flatten)]
36 pub mode: Option<ReasoningMode>,
37 #[serde(skip_serializing_if = "Option::is_none")]
38 pub exclude: Option<bool>,
39 #[serde(skip_serializing_if = "Option::is_none")]
40 pub enabled: Option<bool>,
41}
42
43#[derive(Debug, Serialize, Deserialize, Clone)]
44pub struct ChatCompletionRequest {
45 pub model: String,
46 pub messages: Vec<ChatCompletionMessage>,
47 #[serde(skip_serializing_if = "Option::is_none")]
48 pub temperature: Option<f64>,
49 #[serde(skip_serializing_if = "Option::is_none")]
50 pub top_p: Option<f64>,
51 #[serde(skip_serializing_if = "Option::is_none")]
52 pub n: Option<i64>,
53 #[serde(skip_serializing_if = "Option::is_none")]
54 pub response_format: Option<Value>,
55 #[serde(skip_serializing_if = "Option::is_none")]
56 pub stream: Option<bool>,
57 #[serde(skip_serializing_if = "Option::is_none")]
58 pub stop: Option<Vec<String>>,
59 #[serde(skip_serializing_if = "Option::is_none")]
60 pub max_tokens: Option<i64>,
61 #[serde(skip_serializing_if = "Option::is_none")]
62 pub presence_penalty: Option<f64>,
63 #[serde(skip_serializing_if = "Option::is_none")]
64 pub frequency_penalty: Option<f64>,
65 #[serde(skip_serializing_if = "Option::is_none")]
66 pub logit_bias: Option<HashMap<String, i32>>,
67 #[serde(skip_serializing_if = "Option::is_none")]
68 pub user: Option<String>,
69 #[serde(skip_serializing_if = "Option::is_none")]
70 pub seed: Option<i64>,
71 #[serde(skip_serializing_if = "Option::is_none")]
72 pub tools: Option<Vec<Tool>>,
73 #[serde(skip_serializing_if = "Option::is_none")]
74 pub parallel_tool_calls: Option<bool>,
75 #[serde(skip_serializing_if = "Option::is_none")]
76 #[serde(serialize_with = "serialize_tool_choice")]
77 pub tool_choice: Option<ToolChoiceType>,
78 #[serde(skip_serializing_if = "Option::is_none")]
79 pub reasoning: Option<Reasoning>,
80}
81
82impl ChatCompletionRequest {
83 pub fn new(model: String, messages: Vec<ChatCompletionMessage>) -> Self {
84 Self {
85 model,
86 messages,
87 temperature: None,
88 top_p: None,
89 stream: None,
90 n: None,
91 response_format: None,
92 stop: None,
93 max_tokens: None,
94 presence_penalty: None,
95 frequency_penalty: None,
96 logit_bias: None,
97 user: None,
98 seed: None,
99 tools: None,
100 parallel_tool_calls: None,
101 tool_choice: None,
102 reasoning: None,
103 }
104 }
105}
106
// Builder-style setters for every optional field of `ChatCompletionRequest`.
// `model` and `messages` are not listed because they are the required arguments
// of `ChatCompletionRequest::new`. Each entry pairs a field with the inner type
// of its `Option`. The generated method names and signatures are defined by the
// `impl_builder_methods!` macro imported from the crate root.
// NOTE(review): presumably each entry yields a chainable setter that wraps the
// value in `Some` — confirm against the macro definition.
impl_builder_methods!(
    ChatCompletionRequest,
    temperature: f64,
    top_p: f64,
    n: i64,
    response_format: Value,
    stream: bool,
    stop: Vec<String>,
    max_tokens: i64,
    presence_penalty: f64,
    frequency_penalty: f64,
    logit_bias: HashMap<String, i32>,
    user: String,
    seed: i64,
    tools: Vec<Tool>,
    parallel_tool_calls: bool,
    tool_choice: ToolChoiceType,
    reasoning: Reasoning
);
126
127#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
128#[allow(non_camel_case_types)]
129pub enum MessageRole {
130 user,
131 system,
132 assistant,
133 function,
134 tool,
135}
136
137#[derive(Debug, Clone, PartialEq, Eq)]
138pub enum Content {
139 Text(String),
140 ImageUrl(Vec<ImageUrl>),
141}
142
143impl serde::Serialize for Content {
144 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
145 where
146 S: serde::Serializer,
147 {
148 match *self {
149 Content::Text(ref text) => {
150 if text.is_empty() {
151 serializer.serialize_none()
152 } else {
153 serializer.serialize_str(text)
154 }
155 }
156 Content::ImageUrl(ref image_url) => image_url.serialize(serializer),
157 }
158 }
159}
160
161impl<'de> Deserialize<'de> for Content {
162 fn deserialize<D>(deserializer: D) -> Result<Content, D::Error>
163 where
164 D: Deserializer<'de>,
165 {
166 struct ContentVisitor;
167
168 impl<'de> Visitor<'de> for ContentVisitor {
169 type Value = Content;
170
171 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
172 formatter.write_str("a valid content type")
173 }
174
175 fn visit_str<E>(self, value: &str) -> Result<Content, E>
176 where
177 E: de::Error,
178 {
179 Ok(Content::Text(value.to_string()))
180 }
181
182 fn visit_seq<A>(self, seq: A) -> Result<Content, A::Error>
183 where
184 A: SeqAccess<'de>,
185 {
186 let image_urls: Vec<ImageUrl> =
187 Deserialize::deserialize(de::value::SeqAccessDeserializer::new(seq))?;
188 Ok(Content::ImageUrl(image_urls))
189 }
190
191 fn visit_map<M>(self, map: M) -> Result<Content, M::Error>
192 where
193 M: MapAccess<'de>,
194 {
195 let image_urls: Vec<ImageUrl> =
196 Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))?;
197 Ok(Content::ImageUrl(image_urls))
198 }
199
200 fn visit_none<E>(self) -> Result<Self::Value, E>
201 where
202 E: de::Error,
203 {
204 Ok(Content::Text(String::new()))
205 }
206
207 fn visit_unit<E>(self) -> Result<Self::Value, E>
208 where
209 E: de::Error,
210 {
211 Ok(Content::Text(String::new()))
212 }
213 }
214
215 deserializer.deserialize_any(ContentVisitor)
216 }
217}
218
219#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
220#[allow(non_camel_case_types)]
221pub enum ContentType {
222 text,
223 image_url,
224}
225
226#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
227#[allow(non_camel_case_types)]
228pub struct ImageUrlType {
229 pub url: String,
230}
231
232#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
233#[allow(non_camel_case_types)]
234pub struct ImageUrl {
235 pub r#type: ContentType,
236 #[serde(skip_serializing_if = "Option::is_none")]
237 pub text: Option<String>,
238 #[serde(skip_serializing_if = "Option::is_none")]
239 pub image_url: Option<ImageUrlType>,
240}
241
242#[derive(Debug, Deserialize, Serialize, Clone)]
243pub struct ChatCompletionMessage {
244 pub role: MessageRole,
245 pub content: Content,
246 #[serde(skip_serializing_if = "Option::is_none")]
247 pub name: Option<String>,
248 #[serde(skip_serializing_if = "Option::is_none")]
249 pub tool_calls: Option<Vec<ToolCall>>,
250 #[serde(skip_serializing_if = "Option::is_none")]
251 pub tool_call_id: Option<String>,
252}
253
254#[derive(Debug, Deserialize, Serialize, Clone)]
255pub struct ChatCompletionMessageForResponse {
256 pub role: MessageRole,
257 #[serde(skip_serializing_if = "Option::is_none")]
258 pub content: Option<String>,
259 #[serde(skip_serializing_if = "Option::is_none")]
260 pub reasoning_content: Option<String>,
261 #[serde(skip_serializing_if = "Option::is_none")]
262 pub name: Option<String>,
263 #[serde(skip_serializing_if = "Option::is_none")]
264 pub tool_calls: Option<Vec<ToolCall>>,
265}
266
267#[derive(Debug, Deserialize, Serialize)]
268pub struct ChatCompletionChoice {
269 pub index: i64,
270 pub message: ChatCompletionMessageForResponse,
271 pub finish_reason: Option<FinishReason>,
272 pub finish_details: Option<FinishDetails>,
273}
274
275#[derive(Debug, Deserialize, Serialize)]
276pub struct ChatCompletionResponse {
277 pub id: Option<String>,
278 pub object: String,
279 pub created: i64,
280 pub model: String,
281 pub choices: Vec<ChatCompletionChoice>,
282 pub usage: common::Usage,
283 pub system_fingerprint: Option<String>,
284}
285
286#[derive(Debug, Deserialize, Serialize, PartialEq, Eq)]
287#[allow(non_camel_case_types)]
288pub enum FinishReason {
289 stop,
290 length,
291 content_filter,
292 tool_calls,
293 null,
294}
295
296#[derive(Debug, Deserialize, Serialize)]
297#[allow(non_camel_case_types)]
298pub struct FinishDetails {
299 pub r#type: FinishReason,
300 pub stop: String,
301}
302
303#[derive(Debug, Deserialize, Serialize, Clone)]
304pub struct ToolCall {
305 pub id: String,
306 pub r#type: String,
307 pub function: ToolCallFunction,
308}
309
310#[derive(Debug, Deserialize, Serialize, Clone)]
311pub struct ToolCallFunction {
312 #[serde(skip_serializing_if = "Option::is_none")]
313 pub name: Option<String>,
314 #[serde(skip_serializing_if = "Option::is_none")]
315 pub arguments: Option<String>,
316}
317
318fn serialize_tool_choice<S>(
319 value: &Option<ToolChoiceType>,
320 serializer: S,
321) -> Result<S::Ok, S::Error>
322where
323 S: Serializer,
324{
325 match value {
326 Some(ToolChoiceType::None) => serializer.serialize_str("none"),
327 Some(ToolChoiceType::Auto) => serializer.serialize_str("auto"),
328 Some(ToolChoiceType::Required) => serializer.serialize_str("required"),
329 Some(ToolChoiceType::ToolChoice { tool }) => {
330 let mut map = serializer.serialize_map(Some(2))?;
331 map.serialize_entry("type", &tool.r#type)?;
332 map.serialize_entry("function", &tool.function)?;
333 map.end()
334 }
335 None => serializer.serialize_none(),
336 }
337}
338
339#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
340pub struct Tool {
341 pub r#type: ToolType,
342 pub function: types::Function,
343}
344
345#[derive(Debug, Deserialize, Serialize, Copy, Clone, PartialEq, Eq)]
346#[serde(rename_all = "snake_case")]
347pub enum ToolType {
348 Function,
349}
350
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_reasoning_effort_serialization() {
        let reasoning = Reasoning {
            mode: Some(ReasoningMode::Effort {
                effort: ReasoningEffort::High,
            }),
            exclude: Some(false),
            enabled: None,
        };

        let serialized = serde_json::to_value(&reasoning).unwrap();
        let expected = json!({
            "effort": "high",
            "exclude": false
        });

        assert_eq!(serialized, expected);
    }

    #[test]
    fn test_reasoning_max_tokens_serialization() {
        let reasoning = Reasoning {
            mode: Some(ReasoningMode::MaxTokens { max_tokens: 2000 }),
            exclude: None,
            enabled: Some(true),
        };

        let serialized = serde_json::to_value(&reasoning).unwrap();
        let expected = json!({
            "max_tokens": 2000,
            "enabled": true
        });

        assert_eq!(serialized, expected);
    }

    #[test]
    fn test_reasoning_deserialization() {
        let json_str = r#"{"effort": "medium", "exclude": true}"#;
        let reasoning: Reasoning = serde_json::from_str(json_str).unwrap();

        match reasoning.mode {
            Some(ReasoningMode::Effort { effort }) => {
                assert_eq!(effort, ReasoningEffort::Medium);
            }
            _ => panic!("Expected effort mode"),
        }
        assert_eq!(reasoning.exclude, Some(true));
    }

    /// `MaxTokens` is the second untagged variant, so this covers the fallback
    /// after `Effort` fails to match.
    #[test]
    fn test_reasoning_max_tokens_deserialization() {
        let json_str = r#"{"max_tokens": 500, "enabled": true}"#;
        let reasoning: Reasoning = serde_json::from_str(json_str).unwrap();

        match reasoning.mode {
            Some(ReasoningMode::MaxTokens { max_tokens }) => assert_eq!(max_tokens, 500),
            _ => panic!("Expected max_tokens mode"),
        }
        assert_eq!(reasoning.enabled, Some(true));
        assert_eq!(reasoning.exclude, None);
    }

    /// A flattened `None` mode contributes no keys. A payload without
    /// `effort`/`max_tokens` reads back as a `None` mode.
    #[test]
    fn test_reasoning_without_mode_round_trip() {
        let reasoning = Reasoning {
            mode: None,
            exclude: None,
            enabled: Some(true),
        };

        let serialized = serde_json::to_value(&reasoning).unwrap();
        assert_eq!(serialized, json!({ "enabled": true }));

        let parsed: Reasoning = serde_json::from_value(serialized).unwrap();
        assert!(parsed.mode.is_none());
        assert_eq!(parsed.enabled, Some(true));
        assert_eq!(parsed.exclude, None);
    }

    #[test]
    fn test_chat_completion_request_with_reasoning() {
        let mut req = ChatCompletionRequest::new("gpt-4".to_string(), vec![]);

        req.reasoning = Some(Reasoning {
            mode: Some(ReasoningMode::Effort {
                effort: ReasoningEffort::Low,
            }),
            exclude: None,
            enabled: None,
        });

        let serialized = serde_json::to_value(&req).unwrap();
        assert_eq!(serialized["reasoning"]["effort"], "low");
    }

    /// Empty text is written as `null`, and `null` reads back as empty text.
    #[test]
    fn test_empty_text_content_round_trips_through_null() {
        let serialized = serde_json::to_value(Content::Text(String::new())).unwrap();
        assert_eq!(serialized, serde_json::Value::Null);

        let parsed: Content = serde_json::from_str("null").unwrap();
        assert_eq!(parsed, Content::Text(String::new()));

        let parsed: Content = serde_json::from_str(r#""hi""#).unwrap();
        assert_eq!(parsed, Content::Text("hi".to_string()));
    }

    #[test]
    fn test_content_parts_array_deserialization() {
        let json_str = r#"[{"type": "text", "text": "hi"}]"#;
        let parsed: Content = serde_json::from_str(json_str).unwrap();

        assert_eq!(
            parsed,
            Content::ImageUrl(vec![ImageUrl {
                r#type: ContentType::text,
                text: Some("hi".to_string()),
                image_url: None,
            }])
        );
    }
}