1use super::{common, types};
2use crate::impl_builder_methods;
3
4use serde::de::{self, MapAccess, SeqAccess, Visitor};
5use serde::ser::SerializeMap;
6use serde::{Deserialize, Deserializer, Serialize, Serializer};
7use serde_json::Value;
8use std::collections::HashMap;
9use std::fmt;
/// Controls whether, and which, tools the model may call.
///
/// NOTE(review): the derived `Serialize`/`Deserialize` impls use serde's default
/// externally-tagged form (e.g. `"Auto"`, `{"ToolChoice": {"tool": ...}}`). When this value is
/// sent as `ChatCompletionRequest::tool_choice` it is written by `serialize_tool_choice`
/// instead, which emits `"none"` / `"auto"` / `"required"` or a `{"type", "function"}` object.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
pub enum ToolChoiceType {
    /// Written as `"none"` by `serialize_tool_choice`.
    None,
    /// Written as `"auto"` by `serialize_tool_choice`.
    Auto,
    /// Written as `"required"` by `serialize_tool_choice`.
    Required,
    /// A specific tool; written as `{"type": ..., "function": ...}` by `serialize_tool_choice`.
    ToolChoice { tool: Tool },
}
17
/// Requested reasoning effort level.
///
/// `rename_all = "lowercase"` makes the wire strings `"low"`, `"medium"` and `"high"`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum ReasoningEffort {
    Low,
    Medium,
    High,
}
25
/// How the reasoning budget is expressed.
///
/// `#[serde(untagged)]`: no variant tag is written. A variant serializes as just its field,
/// e.g. `{"effort": "high"}` or `{"max_tokens": 2000}`. When deserializing, serde tries the
/// variants in declaration order, so their order here is significant.
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(untagged)]
pub enum ReasoningMode {
    /// Qualitative effort level.
    Effort { effort: ReasoningEffort },
    /// Explicit token budget for reasoning.
    MaxTokens { max_tokens: i64 },
}
32
/// Reasoning configuration attached to a `ChatCompletionRequest`.
///
/// `mode` is flattened, so its field (`effort` or `max_tokens`) sits at the same level as
/// `exclude` / `enabled`, e.g. `{"effort": "high", "exclude": false}`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Reasoning {
    /// Effort level or token budget, flattened into this object.
    #[serde(flatten)]
    pub mode: Option<ReasoningMode>,
    /// Omitted from the JSON when `None`.
    /// NOTE(review): presumably asks the provider to leave reasoning text out of the
    /// response; confirm against the provider's API docs.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub exclude: Option<bool>,
    /// Omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub enabled: Option<bool>,
}
42
43#[derive(Debug, Serialize, Deserialize, Clone)]
44pub struct ChatCompletionRequest {
45 pub model: String,
46 pub messages: Vec<ChatCompletionMessage>,
47 #[serde(skip_serializing_if = "Option::is_none")]
48 pub temperature: Option<f64>,
49 #[serde(skip_serializing_if = "Option::is_none")]
50 pub top_p: Option<f64>,
51 #[serde(skip_serializing_if = "Option::is_none")]
52 pub n: Option<i64>,
53 #[serde(skip_serializing_if = "Option::is_none")]
54 pub response_format: Option<Value>,
55 #[serde(skip_serializing_if = "Option::is_none")]
56 pub stream: Option<bool>,
57 #[serde(skip_serializing_if = "Option::is_none")]
58 pub stop: Option<Vec<String>>,
59 #[serde(skip_serializing_if = "Option::is_none")]
60 pub max_tokens: Option<i64>,
61 #[serde(skip_serializing_if = "Option::is_none")]
62 pub presence_penalty: Option<f64>,
63 #[serde(skip_serializing_if = "Option::is_none")]
64 pub frequency_penalty: Option<f64>,
65 #[serde(skip_serializing_if = "Option::is_none")]
66 pub logit_bias: Option<HashMap<String, i32>>,
67 #[serde(skip_serializing_if = "Option::is_none")]
68 pub user: Option<String>,
69 #[serde(skip_serializing_if = "Option::is_none")]
70 pub seed: Option<i64>,
71 #[serde(skip_serializing_if = "Option::is_none")]
72 pub tools: Option<Vec<Tool>>,
73 #[serde(skip_serializing_if = "Option::is_none")]
74 pub parallel_tool_calls: Option<bool>,
75 #[serde(skip_serializing_if = "Option::is_none")]
76 #[serde(serialize_with = "serialize_tool_choice")]
77 pub tool_choice: Option<ToolChoiceType>,
78 #[serde(skip_serializing_if = "Option::is_none")]
79 pub reasoning: Option<Reasoning>,
80 #[serde(skip_serializing_if = "Option::is_none")]
86 pub transforms: Option<Vec<String>>,
87}
88
89impl ChatCompletionRequest {
90 pub fn new(model: String, messages: Vec<ChatCompletionMessage>) -> Self {
91 Self {
92 model,
93 messages,
94 temperature: None,
95 top_p: None,
96 stream: None,
97 n: None,
98 response_format: None,
99 stop: None,
100 max_tokens: None,
101 presence_penalty: None,
102 frequency_penalty: None,
103 logit_bias: None,
104 user: None,
105 seed: None,
106 tools: None,
107 parallel_tool_calls: None,
108 tool_choice: None,
109 reasoning: None,
110 transforms: None,
111 }
112 }
113}
114
// Generates builder-style setters on `ChatCompletionRequest` for every optional field.
// Each `name: Type` pair below names a field and its inner (non-`Option`) type.
// Per `test_transforms_builder_method`, a setter takes the value and stores it as `Some(..)`.
// NOTE(review): the exact signatures come from the project's `impl_builder_methods!` macro
// (defined elsewhere); confirm there.
impl_builder_methods!(
    ChatCompletionRequest,
    temperature: f64,
    top_p: f64,
    n: i64,
    response_format: Value,
    stream: bool,
    stop: Vec<String>,
    max_tokens: i64,
    presence_penalty: f64,
    frequency_penalty: f64,
    logit_bias: HashMap<String, i32>,
    user: String,
    seed: i64,
    tools: Vec<Tool>,
    parallel_tool_calls: bool,
    tool_choice: ToolChoiceType,
    reasoning: Reasoning,
    transforms: Vec<String>
);
135
/// Author role of a chat message.
///
/// The variants are deliberately lowercase (hence `non_camel_case_types`), so the derived
/// serde impls read and write exactly `"user"`, `"system"`, `"assistant"`, `"function"` and
/// `"tool"` without a rename attribute.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[allow(non_camel_case_types)]
pub enum MessageRole {
    user,
    system,
    assistant,
    function,
    tool,
}
145
/// Content of a request message: plain text, or a list of content parts (text / image).
///
/// Serialization is hand-written in the `Serialize` / `Deserialize` impls for `Content` in
/// this file. Empty `Text` is written as `null`, and `null` reads back as empty `Text`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Content {
    /// Plain text; an empty string stands for "no content".
    Text(String),
    /// Multi-part content; each element is one content part.
    ImageUrl(Vec<ImageUrl>),
}
151
152impl serde::Serialize for Content {
153 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
154 where
155 S: serde::Serializer,
156 {
157 match *self {
158 Content::Text(ref text) => {
159 if text.is_empty() {
160 serializer.serialize_none()
161 } else {
162 serializer.serialize_str(text)
163 }
164 }
165 Content::ImageUrl(ref image_url) => image_url.serialize(serializer),
166 }
167 }
168}
169
170impl<'de> Deserialize<'de> for Content {
171 fn deserialize<D>(deserializer: D) -> Result<Content, D::Error>
172 where
173 D: Deserializer<'de>,
174 {
175 struct ContentVisitor;
176
177 impl<'de> Visitor<'de> for ContentVisitor {
178 type Value = Content;
179
180 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
181 formatter.write_str("a valid content type")
182 }
183
184 fn visit_str<E>(self, value: &str) -> Result<Content, E>
185 where
186 E: de::Error,
187 {
188 Ok(Content::Text(value.to_string()))
189 }
190
191 fn visit_seq<A>(self, seq: A) -> Result<Content, A::Error>
192 where
193 A: SeqAccess<'de>,
194 {
195 let image_urls: Vec<ImageUrl> =
196 Deserialize::deserialize(de::value::SeqAccessDeserializer::new(seq))?;
197 Ok(Content::ImageUrl(image_urls))
198 }
199
200 fn visit_map<M>(self, map: M) -> Result<Content, M::Error>
201 where
202 M: MapAccess<'de>,
203 {
204 let image_urls: Vec<ImageUrl> =
205 Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))?;
206 Ok(Content::ImageUrl(image_urls))
207 }
208
209 fn visit_none<E>(self) -> Result<Self::Value, E>
210 where
211 E: de::Error,
212 {
213 Ok(Content::Text(String::new()))
214 }
215
216 fn visit_unit<E>(self) -> Result<Self::Value, E>
217 where
218 E: de::Error,
219 {
220 Ok(Content::Text(String::new()))
221 }
222 }
223
224 deserializer.deserialize_any(ContentVisitor)
225 }
226}
227
/// Discriminator of a content part.
///
/// Lowercase variant names are used directly as the wire strings `"text"` and `"image_url"`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[allow(non_camel_case_types)]
pub enum ContentType {
    text,
    image_url,
}
234
235#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
236#[allow(non_camel_case_types)]
237pub struct ImageUrlType {
238 pub url: String,
239}
240
241#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
242#[allow(non_camel_case_types)]
243pub struct ImageUrl {
244 pub r#type: ContentType,
245 #[serde(skip_serializing_if = "Option::is_none")]
246 pub text: Option<String>,
247 #[serde(skip_serializing_if = "Option::is_none")]
248 pub image_url: Option<ImageUrlType>,
249}
250
/// A message sent in `ChatCompletionRequest::messages`.
///
/// `role` and `content` are always serialized. The remaining fields are omitted when `None`.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ChatCompletionMessage {
    pub role: MessageRole,
    /// Text or multi-part content; empty text is serialized as `null`.
    pub content: Content,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Tool calls attached to this message.
    /// NOTE(review): presumably only set on assistant messages; confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    /// NOTE(review): presumably the `ToolCall::id` a `tool`-role message answers; confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
}
262
/// A message as returned in `ChatCompletionChoice::message`.
///
/// Unlike `ChatCompletionMessage`, `content` is an optional plain string. All optional fields
/// are omitted when `None` on re-serialization.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ChatCompletionMessageForResponse {
    pub role: MessageRole,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    /// NOTE(review): presumably the model's reasoning text, when the provider returns it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
}
275
/// One alternative completion in a `ChatCompletionResponse`.
#[derive(Debug, Deserialize, Serialize)]
pub struct ChatCompletionChoice {
    /// Position of this choice among `ChatCompletionResponse::choices`.
    pub index: i64,
    pub message: ChatCompletionMessageForResponse,
    pub finish_reason: Option<FinishReason>,
    /// NOTE(review): appears to be an alternative, more detailed finish report that some
    /// providers send; confirm which providers populate it.
    pub finish_details: Option<FinishDetails>,
}
283
/// Response body of a chat-completion call.
#[derive(Debug, Deserialize, Serialize)]
pub struct ChatCompletionResponse {
    pub id: Option<String>,
    pub object: String,
    /// Creation time as returned by the API.
    /// NOTE(review): presumably a Unix timestamp in seconds; confirm.
    pub created: i64,
    pub model: String,
    pub choices: Vec<ChatCompletionChoice>,
    pub usage: common::Usage,
    pub system_fingerprint: Option<String>,
}
294
/// Why generation stopped for a choice.
///
/// Lowercase variant names are the wire strings.
/// NOTE(review): the `null` variant matches only the literal *string* `"null"`; a JSON `null`
/// in an `Option<FinishReason>` field deserializes to `None` instead.
#[derive(Debug, Deserialize, Serialize, PartialEq, Eq)]
#[allow(non_camel_case_types)]
pub enum FinishReason {
    stop,
    length,
    content_filter,
    tool_calls,
    null,
}
304
305#[derive(Debug, Deserialize, Serialize)]
306#[allow(non_camel_case_types)]
307pub struct FinishDetails {
308 pub r#type: FinishReason,
309 pub stop: String,
310}
311
/// A tool invocation attached to a message.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolCall {
    pub id: String,
    /// Serialized under the key `"type"`; kept as a free-form string here.
    pub r#type: String,
    pub function: ToolCallFunction,
}
318
/// Function name and arguments of a `ToolCall`.
///
/// Both fields are optional and omitted when `None`.
/// NOTE(review): `arguments` is a raw string, presumably JSON-encoded by the model; confirm.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolCallFunction {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub arguments: Option<String>,
}
326
327fn serialize_tool_choice<S>(
328 value: &Option<ToolChoiceType>,
329 serializer: S,
330) -> Result<S::Ok, S::Error>
331where
332 S: Serializer,
333{
334 match value {
335 Some(ToolChoiceType::None) => serializer.serialize_str("none"),
336 Some(ToolChoiceType::Auto) => serializer.serialize_str("auto"),
337 Some(ToolChoiceType::Required) => serializer.serialize_str("required"),
338 Some(ToolChoiceType::ToolChoice { tool }) => {
339 let mut map = serializer.serialize_map(Some(2))?;
340 map.serialize_entry("type", &tool.r#type)?;
341 map.serialize_entry("function", &tool.function)?;
342 map.end()
343 }
344 None => serializer.serialize_none(),
345 }
346}
347
/// A tool the model may call, as listed in `ChatCompletionRequest::tools`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
pub struct Tool {
    /// Serialized under the key `"type"`.
    pub r#type: ToolType,
    pub function: types::Function,
}
353
/// Kind of a `Tool`.
///
/// `rename_all = "snake_case"` makes the only variant serialize as `"function"`.
#[derive(Debug, Deserialize, Serialize, Copy, Clone, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ToolType {
    Function,
}
359
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Fresh `gpt-4` request with no messages and every optional field unset.
    fn empty_request() -> ChatCompletionRequest {
        ChatCompletionRequest::new(String::from("gpt-4"), Vec::new())
    }

    /// Owned copies of the transform names shared by the `transforms` tests.
    fn sample_transforms() -> Vec<String> {
        ["transform1", "transform2"]
            .iter()
            .map(|name| name.to_string())
            .collect()
    }

    #[test]
    fn test_reasoning_effort_serialization() {
        let high_effort = Some(ReasoningMode::Effort {
            effort: ReasoningEffort::High,
        });
        let reasoning = Reasoning {
            mode: high_effort,
            exclude: Some(false),
            enabled: None,
        };

        let actual = serde_json::to_value(&reasoning).unwrap();
        assert_eq!(actual, json!({ "effort": "high", "exclude": false }));
    }

    #[test]
    fn test_reasoning_max_tokens_serialization() {
        let reasoning = Reasoning {
            mode: Some(ReasoningMode::MaxTokens { max_tokens: 2000 }),
            exclude: None,
            enabled: Some(true),
        };

        let actual = serde_json::to_value(&reasoning).unwrap();
        assert_eq!(actual, json!({ "max_tokens": 2000, "enabled": true }));
    }

    #[test]
    fn test_reasoning_deserialization() {
        let reasoning: Reasoning =
            serde_json::from_str(r#"{"effort": "medium", "exclude": true}"#).unwrap();

        if let Some(ReasoningMode::Effort { effort }) = reasoning.mode {
            assert_eq!(effort, ReasoningEffort::Medium);
        } else {
            panic!("Expected effort mode");
        }
        assert_eq!(reasoning.exclude, Some(true));
    }

    #[test]
    fn test_chat_completion_request_with_reasoning() {
        let mut req = empty_request();
        req.reasoning = Some(Reasoning {
            mode: Some(ReasoningMode::Effort {
                effort: ReasoningEffort::Low,
            }),
            exclude: None,
            enabled: None,
        });

        let body = serde_json::to_value(&req).unwrap();
        assert_eq!(body["reasoning"]["effort"], "low");
    }

    #[test]
    fn test_transforms_none_serialization() {
        let body = serde_json::to_value(&empty_request()).unwrap();
        let has_transforms = body.as_object().unwrap().contains_key("transforms");
        assert!(!has_transforms);
    }

    #[test]
    fn test_transforms_some_serialization() {
        let req = ChatCompletionRequest {
            transforms: Some(sample_transforms()),
            ..empty_request()
        };

        let body = serde_json::to_value(&req).unwrap();
        assert_eq!(body["transforms"], json!(["transform1", "transform2"]));
    }

    #[test]
    fn test_transforms_some_deserialization() {
        let raw =
            r#"{"model": "gpt-4", "messages": [], "transforms": ["transform1", "transform2"]}"#;
        let req: ChatCompletionRequest = serde_json::from_str(raw).unwrap();
        assert_eq!(req.transforms, Some(sample_transforms()));
    }

    #[test]
    fn test_transforms_none_deserialization() {
        let req: ChatCompletionRequest =
            serde_json::from_str(r#"{"model": "gpt-4", "messages": []}"#).unwrap();
        assert!(req.transforms.is_none());
    }

    #[test]
    fn test_transforms_builder_method() {
        let expected = sample_transforms();
        let req = empty_request().transforms(expected.clone());
        assert_eq!(req.transforms, Some(expected));
    }
}