openai_api_rs/v1/
chat_completion.rs

1use super::{common, types};
2use crate::impl_builder_methods;
3
4use serde::de::{self, MapAccess, SeqAccess, Visitor};
5use serde::ser::SerializeMap;
6use serde::{Deserialize, Deserializer, Serialize, Serializer};
7use serde_json::Value;
8use std::collections::HashMap;
9use std::fmt;
10#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
11pub enum ToolChoiceType {
12    None,
13    Auto,
14    Required,
15    ToolChoice { tool: Tool },
16}
17
/// Qualitative reasoning effort level.
///
/// Serialized in lowercase: `"low"`, `"medium"`, `"high"`.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum ReasoningEffort {
    Low,
    Medium,
    High,
}
25
26#[derive(Debug, Serialize, Deserialize, Clone)]
27#[serde(untagged)]
28pub enum ReasoningMode {
29    Effort { effort: ReasoningEffort },
30    MaxTokens { max_tokens: i64 },
31}
32
33#[derive(Debug, Serialize, Deserialize, Clone)]
34pub struct Reasoning {
35    #[serde(flatten)]
36    pub mode: Option<ReasoningMode>,
37    #[serde(skip_serializing_if = "Option::is_none")]
38    pub exclude: Option<bool>,
39    #[serde(skip_serializing_if = "Option::is_none")]
40    pub enabled: Option<bool>,
41}
42
/// Parameters for a chat completion request.
///
/// `model` and `messages` are always serialized. Every other field is an
/// `Option` that is omitted from the serialized output while it is `None`.
/// Set these fields directly or through the generated builder methods.
/// Fields serialize in declaration order.
///
/// `tool_choice` is written by the custom `serialize_tool_choice` function
/// (lowercase strings or a `{"type", "function"}` object) rather than by
/// [`ToolChoiceType`]'s own `Serialize` impl.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ChatCompletionRequest {
    pub model: String,
    pub messages: Vec<ChatCompletionMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<i64>,
    // Free-form JSON; passed through as-is.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<i64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<HashMap<String, i32>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<i64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    // Only serialization is customized; deserialization uses ToolChoiceType's impl.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(serialize_with = "serialize_tool_choice")]
    pub tool_choice: Option<ToolChoiceType>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<Reasoning>,
    /// Optional list of named transforms to apply to the request.
    ///
    /// The names are serialized verbatim as a JSON array of strings. What each
    /// transform does is decided by the receiving API, not by this crate.
    /// NOTE(review): e.g. prompt rewriting or filtering — confirm against the
    /// provider's documentation. When `None`, the field is omitted entirely.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub transforms: Option<Vec<String>>,
}
88
89impl ChatCompletionRequest {
90    pub fn new(model: String, messages: Vec<ChatCompletionMessage>) -> Self {
91        Self {
92            model,
93            messages,
94            temperature: None,
95            top_p: None,
96            stream: None,
97            n: None,
98            response_format: None,
99            stop: None,
100            max_tokens: None,
101            presence_penalty: None,
102            frequency_penalty: None,
103            logit_bias: None,
104            user: None,
105            seed: None,
106            tools: None,
107            parallel_tool_calls: None,
108            tool_choice: None,
109            reasoning: None,
110            transforms: None,
111        }
112    }
113}
114
// Generate chainable setters for every optional request field. `model` and
// `messages` are excluded because they are required by `new`.
// NOTE(review): the macro is defined elsewhere in the crate. The tests in this
// file (`.transforms(vec)` yields `transforms == Some(vec)`) suggest each
// setter takes `self` by value, stores the argument wrapped in `Some`, and
// returns `Self` — confirm against the macro definition.
impl_builder_methods!(
    ChatCompletionRequest,
    temperature: f64,
    top_p: f64,
    n: i64,
    response_format: Value,
    stream: bool,
    stop: Vec<String>,
    max_tokens: i64,
    presence_penalty: f64,
    frequency_penalty: f64,
    logit_bias: HashMap<String, i32>,
    user: String,
    seed: i64,
    tools: Vec<Tool>,
    parallel_tool_calls: bool,
    tool_choice: ToolChoiceType,
    reasoning: Reasoning,
    transforms: Vec<String>
);
135
/// Author role of a chat message.
///
/// Variants are spelled in lowercase so the derived serde representation
/// matches the wire strings (`"user"`, `"system"`, `"assistant"`, `"function"`,
/// `"tool"`) without a `rename_all` attribute. That spelling is why
/// `non_camel_case_types` is allowed here.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[allow(non_camel_case_types)]
pub enum MessageRole {
    user,
    system,
    assistant,
    function,
    tool,
}
145
/// Body of a request message.
///
/// `Text` holds plain text. `ImageUrl` holds a list of content parts, each of
/// which may carry text or an image reference (see [`ImageUrl`]).
/// `Serialize`/`Deserialize` are implemented by hand: empty text is written
/// as `null`, and `null` is read back as empty text.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Content {
    Text(String),
    ImageUrl(Vec<ImageUrl>),
}
151
152impl serde::Serialize for Content {
153    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
154    where
155        S: serde::Serializer,
156    {
157        match *self {
158            Content::Text(ref text) => {
159                if text.is_empty() {
160                    serializer.serialize_none()
161                } else {
162                    serializer.serialize_str(text)
163                }
164            }
165            Content::ImageUrl(ref image_url) => image_url.serialize(serializer),
166        }
167    }
168}
169
170impl<'de> Deserialize<'de> for Content {
171    fn deserialize<D>(deserializer: D) -> Result<Content, D::Error>
172    where
173        D: Deserializer<'de>,
174    {
175        struct ContentVisitor;
176
177        impl<'de> Visitor<'de> for ContentVisitor {
178            type Value = Content;
179
180            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
181                formatter.write_str("a valid content type")
182            }
183
184            fn visit_str<E>(self, value: &str) -> Result<Content, E>
185            where
186                E: de::Error,
187            {
188                Ok(Content::Text(value.to_string()))
189            }
190
191            fn visit_seq<A>(self, seq: A) -> Result<Content, A::Error>
192            where
193                A: SeqAccess<'de>,
194            {
195                let image_urls: Vec<ImageUrl> =
196                    Deserialize::deserialize(de::value::SeqAccessDeserializer::new(seq))?;
197                Ok(Content::ImageUrl(image_urls))
198            }
199
200            fn visit_map<M>(self, map: M) -> Result<Content, M::Error>
201            where
202                M: MapAccess<'de>,
203            {
204                let image_urls: Vec<ImageUrl> =
205                    Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))?;
206                Ok(Content::ImageUrl(image_urls))
207            }
208
209            fn visit_none<E>(self) -> Result<Self::Value, E>
210            where
211                E: de::Error,
212            {
213                Ok(Content::Text(String::new()))
214            }
215
216            fn visit_unit<E>(self) -> Result<Self::Value, E>
217            where
218                E: de::Error,
219            {
220                Ok(Content::Text(String::new()))
221            }
222        }
223
224        deserializer.deserialize_any(ContentVisitor)
225    }
226}
227
/// Discriminator of a content part.
///
/// Variants are spelled exactly as the wire strings (`"text"`, `"image_url"`).
/// That spelling is why `non_camel_case_types` is allowed here.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[allow(non_camel_case_types)]
pub enum ContentType {
    text,
    image_url,
}
234
235#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
236#[allow(non_camel_case_types)]
237pub struct ImageUrlType {
238    pub url: String,
239}
240
241#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
242#[allow(non_camel_case_types)]
243pub struct ImageUrl {
244    pub r#type: ContentType,
245    #[serde(skip_serializing_if = "Option::is_none")]
246    pub text: Option<String>,
247    #[serde(skip_serializing_if = "Option::is_none")]
248    pub image_url: Option<ImageUrlType>,
249}
250
/// A message in the `messages` list of a [`ChatCompletionRequest`].
///
/// Optional fields are omitted from the serialized output while `None`.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ChatCompletionMessage {
    pub role: MessageRole,
    // Uses Content's hand-written (de)serialization: empty text <-> null.
    pub content: Content,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    // NOTE(review): presumably set on `tool`-role messages to link a result to
    // an earlier `ToolCall::id` — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
}
262
/// The message carried by a [`ChatCompletionChoice`].
///
/// Unlike [`ChatCompletionMessage`], `content` is a plain optional string, and
/// a missing or `null` value becomes `None`.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ChatCompletionMessageForResponse {
    pub role: MessageRole,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    // NOTE(review): presumably model reasoning text returned by some providers — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
}
275
/// One entry of [`ChatCompletionResponse::choices`].
///
/// `finish_reason` and `finish_details` are `Option`s, so a JSON `null` or a
/// missing key deserializes to `None`. Unlike the message structs, they are
/// serialized as `null` when `None` (no `skip_serializing_if`).
#[derive(Debug, Deserialize, Serialize)]
pub struct ChatCompletionChoice {
    pub index: i64,
    pub message: ChatCompletionMessageForResponse,
    pub finish_reason: Option<FinishReason>,
    pub finish_details: Option<FinishDetails>,
}
283
/// Response payload of a chat completion call.
///
/// `object`, `created`, `model`, `choices` and `usage` are required: a payload
/// lacking any of them fails to deserialize. `id` and `system_fingerprint` may
/// be absent or `null`.
#[derive(Debug, Deserialize, Serialize)]
pub struct ChatCompletionResponse {
    pub id: Option<String>,
    pub object: String,
    // NOTE(review): presumably a Unix timestamp in seconds — confirm.
    pub created: i64,
    pub model: String,
    pub choices: Vec<ChatCompletionChoice>,
    pub usage: common::Usage,
    pub system_fingerprint: Option<String>,
}
294
/// Why generation of a choice stopped.
///
/// Variants are spelled exactly as the wire strings, hence the `allow`. The
/// `null` variant matches the literal *string* `"null"`. A JSON `null` value
/// is instead handled by the surrounding `Option<FinishReason>` and becomes
/// `None`.
#[derive(Debug, Deserialize, Serialize, PartialEq, Eq)]
#[allow(non_camel_case_types)]
pub enum FinishReason {
    stop,
    length,
    content_filter,
    tool_calls,
    null,
}
304
305#[derive(Debug, Deserialize, Serialize)]
306#[allow(non_camel_case_types)]
307pub struct FinishDetails {
308    pub r#type: FinishReason,
309    pub stop: String,
310}
311
/// A tool invocation attached to a message.
///
/// `type` is kept as a free-form string here, unlike the typed [`ToolType`]
/// used by [`Tool`]. It is serialized under the key `"type"`.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolCall {
    pub id: String,
    pub r#type: String,
    pub function: ToolCallFunction,
}
318
/// Function name and arguments of a [`ToolCall`].
///
/// Both fields are optional and omitted from the serialized output while
/// `None`. `arguments` is kept as a raw string. NOTE(review): presumably
/// JSON-encoded text — confirm.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolCallFunction {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub arguments: Option<String>,
}
326
327fn serialize_tool_choice<S>(
328    value: &Option<ToolChoiceType>,
329    serializer: S,
330) -> Result<S::Ok, S::Error>
331where
332    S: Serializer,
333{
334    match value {
335        Some(ToolChoiceType::None) => serializer.serialize_str("none"),
336        Some(ToolChoiceType::Auto) => serializer.serialize_str("auto"),
337        Some(ToolChoiceType::Required) => serializer.serialize_str("required"),
338        Some(ToolChoiceType::ToolChoice { tool }) => {
339            let mut map = serializer.serialize_map(Some(2))?;
340            map.serialize_entry("type", &tool.r#type)?;
341            map.serialize_entry("function", &tool.function)?;
342            map.end()
343        }
344        None => serializer.serialize_none(),
345    }
346}
347
/// A tool definition offered in [`ChatCompletionRequest::tools`].
///
/// It is also the payload of [`ToolChoiceType::ToolChoice`]. Serialized as
/// `{"type": ..., "function": ...}`.
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
pub struct Tool {
    pub r#type: ToolType,
    pub function: types::Function,
}
353
/// Kind of a [`Tool`]. `Function` serializes as `"function"` (snake_case).
#[derive(Debug, Deserialize, Serialize, Copy, Clone, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ToolType {
    Function,
}
359
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// An otherwise-empty `gpt-4` request that individual tests customize.
    fn empty_request() -> ChatCompletionRequest {
        ChatCompletionRequest::new("gpt-4".to_string(), vec![])
    }

    /// The two-entry transforms list shared by the `transforms` tests.
    fn sample_transforms() -> Vec<String> {
        ["transform1", "transform2"]
            .iter()
            .map(|name| name.to_string())
            .collect()
    }

    #[test]
    fn test_reasoning_effort_serialization() {
        let reasoning = Reasoning {
            mode: Some(ReasoningMode::Effort {
                effort: ReasoningEffort::High,
            }),
            exclude: Some(false),
            enabled: None,
        };

        // The flattened mode sits beside `exclude`; `enabled: None` is skipped.
        assert_eq!(
            serde_json::to_value(&reasoning).unwrap(),
            json!({ "effort": "high", "exclude": false })
        );
    }

    #[test]
    fn test_reasoning_max_tokens_serialization() {
        let reasoning = Reasoning {
            mode: Some(ReasoningMode::MaxTokens { max_tokens: 2000 }),
            exclude: None,
            enabled: Some(true),
        };

        assert_eq!(
            serde_json::to_value(&reasoning).unwrap(),
            json!({ "max_tokens": 2000, "enabled": true })
        );
    }

    #[test]
    fn test_reasoning_deserialization() {
        let reasoning: Reasoning =
            serde_json::from_str(r#"{"effort": "medium", "exclude": true}"#).unwrap();

        assert_eq!(reasoning.exclude, Some(true));
        match reasoning.mode {
            Some(ReasoningMode::Effort { effort }) => {
                assert_eq!(effort, ReasoningEffort::Medium)
            }
            _ => panic!("Expected effort mode"),
        }
    }

    #[test]
    fn test_chat_completion_request_with_reasoning() {
        let req = ChatCompletionRequest {
            reasoning: Some(Reasoning {
                mode: Some(ReasoningMode::Effort {
                    effort: ReasoningEffort::Low,
                }),
                exclude: None,
                enabled: None,
            }),
            ..empty_request()
        };

        let value = serde_json::to_value(&req).unwrap();
        assert_eq!(value["reasoning"]["effort"], "low");
    }

    #[test]
    fn test_transforms_none_serialization() {
        let value = serde_json::to_value(empty_request()).unwrap();
        // `transforms: None` must be skipped entirely, not written as `null`.
        let object = value.as_object().unwrap();
        assert!(!object.contains_key("transforms"));
    }

    #[test]
    fn test_transforms_some_serialization() {
        let req = ChatCompletionRequest {
            transforms: Some(sample_transforms()),
            ..empty_request()
        };

        let value = serde_json::to_value(&req).unwrap();
        assert_eq!(value["transforms"], json!(["transform1", "transform2"]));
    }

    #[test]
    fn test_transforms_some_deserialization() {
        let req: ChatCompletionRequest = serde_json::from_str(
            r#"{"model": "gpt-4", "messages": [], "transforms": ["transform1", "transform2"]}"#,
        )
        .unwrap();

        assert_eq!(req.transforms, Some(sample_transforms()));
    }

    #[test]
    fn test_transforms_none_deserialization() {
        let req: ChatCompletionRequest =
            serde_json::from_str(r#"{"model": "gpt-4", "messages": []}"#).unwrap();

        // An absent key must map to `None`.
        assert!(req.transforms.is_none());
    }

    #[test]
    fn test_transforms_builder_method() {
        let expected = sample_transforms();
        let req = empty_request().transforms(expected.clone());

        assert_eq!(req.transforms, Some(expected));
    }
}