openai_api_rs/v1/
chat_completion.rs

1use super::{common, types};
2use crate::impl_builder_methods;
3
4use serde::de::{self, MapAccess, SeqAccess, Visitor};
5use serde::ser::SerializeMap;
6use serde::{Deserialize, Deserializer, Serialize, Serializer};
7use serde_json::Value;
8use std::collections::HashMap;
9use std::fmt;
10#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
11pub enum ToolChoiceType {
12    None,
13    Auto,
14    Required,
15    ToolChoice { tool: Tool },
16}
17
18#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
19#[serde(rename_all = "lowercase")]
20pub enum ReasoningEffort {
21    Low,
22    Medium,
23    High,
24}
25
26#[derive(Debug, Serialize, Deserialize, Clone)]
27#[serde(untagged)]
28pub enum ReasoningMode {
29    Effort { effort: ReasoningEffort },
30    MaxTokens { max_tokens: i64 },
31}
32
33#[derive(Debug, Serialize, Deserialize, Clone)]
34pub struct Reasoning {
35    #[serde(flatten)]
36    pub mode: Option<ReasoningMode>,
37    #[serde(skip_serializing_if = "Option::is_none")]
38    pub exclude: Option<bool>,
39    #[serde(skip_serializing_if = "Option::is_none")]
40    pub enabled: Option<bool>,
41}
42
43#[derive(Debug, Serialize, Deserialize, Clone)]
44pub struct ChatCompletionRequest {
45    pub model: String,
46    pub messages: Vec<ChatCompletionMessage>,
47    #[serde(skip_serializing_if = "Option::is_none")]
48    pub temperature: Option<f64>,
49    #[serde(skip_serializing_if = "Option::is_none")]
50    pub top_p: Option<f64>,
51    #[serde(skip_serializing_if = "Option::is_none")]
52    pub n: Option<i64>,
53    #[serde(skip_serializing_if = "Option::is_none")]
54    pub response_format: Option<Value>,
55    #[serde(skip_serializing_if = "Option::is_none")]
56    pub stream: Option<bool>,
57    #[serde(skip_serializing_if = "Option::is_none")]
58    pub stop: Option<Vec<String>>,
59    #[serde(skip_serializing_if = "Option::is_none")]
60    pub max_tokens: Option<i64>,
61    #[serde(skip_serializing_if = "Option::is_none")]
62    pub presence_penalty: Option<f64>,
63    #[serde(skip_serializing_if = "Option::is_none")]
64    pub frequency_penalty: Option<f64>,
65    #[serde(skip_serializing_if = "Option::is_none")]
66    pub logit_bias: Option<HashMap<String, i32>>,
67    #[serde(skip_serializing_if = "Option::is_none")]
68    pub user: Option<String>,
69    #[serde(skip_serializing_if = "Option::is_none")]
70    pub seed: Option<i64>,
71    #[serde(skip_serializing_if = "Option::is_none")]
72    pub tools: Option<Vec<Tool>>,
73    #[serde(skip_serializing_if = "Option::is_none")]
74    pub parallel_tool_calls: Option<bool>,
75    #[serde(skip_serializing_if = "Option::is_none")]
76    #[serde(serialize_with = "serialize_tool_choice")]
77    pub tool_choice: Option<ToolChoiceType>,
78    #[serde(skip_serializing_if = "Option::is_none")]
79    pub reasoning: Option<Reasoning>,
80}
81
82impl ChatCompletionRequest {
83    pub fn new(model: String, messages: Vec<ChatCompletionMessage>) -> Self {
84        Self {
85            model,
86            messages,
87            temperature: None,
88            top_p: None,
89            stream: None,
90            n: None,
91            response_format: None,
92            stop: None,
93            max_tokens: None,
94            presence_penalty: None,
95            frequency_penalty: None,
96            logit_bias: None,
97            user: None,
98            seed: None,
99            tools: None,
100            parallel_tool_calls: None,
101            tool_choice: None,
102            reasoning: None,
103        }
104    }
105}
106
107impl_builder_methods!(
108    ChatCompletionRequest,
109    temperature: f64,
110    top_p: f64,
111    n: i64,
112    response_format: Value,
113    stream: bool,
114    stop: Vec<String>,
115    max_tokens: i64,
116    presence_penalty: f64,
117    frequency_penalty: f64,
118    logit_bias: HashMap<String, i32>,
119    user: String,
120    seed: i64,
121    tools: Vec<Tool>,
122    parallel_tool_calls: bool,
123    tool_choice: ToolChoiceType,
124    reasoning: Reasoning
125);
126
127#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
128#[allow(non_camel_case_types)]
129pub enum MessageRole {
130    user,
131    system,
132    assistant,
133    function,
134    tool,
135}
136
137#[derive(Debug, Clone, PartialEq, Eq)]
138pub enum Content {
139    Text(String),
140    ImageUrl(Vec<ImageUrl>),
141}
142
143impl serde::Serialize for Content {
144    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
145    where
146        S: serde::Serializer,
147    {
148        match *self {
149            Content::Text(ref text) => {
150                if text.is_empty() {
151                    serializer.serialize_none()
152                } else {
153                    serializer.serialize_str(text)
154                }
155            }
156            Content::ImageUrl(ref image_url) => image_url.serialize(serializer),
157        }
158    }
159}
160
161impl<'de> Deserialize<'de> for Content {
162    fn deserialize<D>(deserializer: D) -> Result<Content, D::Error>
163    where
164        D: Deserializer<'de>,
165    {
166        struct ContentVisitor;
167
168        impl<'de> Visitor<'de> for ContentVisitor {
169            type Value = Content;
170
171            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
172                formatter.write_str("a valid content type")
173            }
174
175            fn visit_str<E>(self, value: &str) -> Result<Content, E>
176            where
177                E: de::Error,
178            {
179                Ok(Content::Text(value.to_string()))
180            }
181
182            fn visit_seq<A>(self, seq: A) -> Result<Content, A::Error>
183            where
184                A: SeqAccess<'de>,
185            {
186                let image_urls: Vec<ImageUrl> =
187                    Deserialize::deserialize(de::value::SeqAccessDeserializer::new(seq))?;
188                Ok(Content::ImageUrl(image_urls))
189            }
190
191            fn visit_map<M>(self, map: M) -> Result<Content, M::Error>
192            where
193                M: MapAccess<'de>,
194            {
195                let image_urls: Vec<ImageUrl> =
196                    Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))?;
197                Ok(Content::ImageUrl(image_urls))
198            }
199
200            fn visit_none<E>(self) -> Result<Self::Value, E>
201            where
202                E: de::Error,
203            {
204                Ok(Content::Text(String::new()))
205            }
206
207            fn visit_unit<E>(self) -> Result<Self::Value, E>
208            where
209                E: de::Error,
210            {
211                Ok(Content::Text(String::new()))
212            }
213        }
214
215        deserializer.deserialize_any(ContentVisitor)
216    }
217}
218
219#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
220#[allow(non_camel_case_types)]
221pub enum ContentType {
222    text,
223    image_url,
224}
225
226#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
227#[allow(non_camel_case_types)]
228pub struct ImageUrlType {
229    pub url: String,
230}
231
232#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
233#[allow(non_camel_case_types)]
234pub struct ImageUrl {
235    pub r#type: ContentType,
236    #[serde(skip_serializing_if = "Option::is_none")]
237    pub text: Option<String>,
238    #[serde(skip_serializing_if = "Option::is_none")]
239    pub image_url: Option<ImageUrlType>,
240}
241
242#[derive(Debug, Deserialize, Serialize, Clone)]
243pub struct ChatCompletionMessage {
244    pub role: MessageRole,
245    pub content: Content,
246    #[serde(skip_serializing_if = "Option::is_none")]
247    pub name: Option<String>,
248    #[serde(skip_serializing_if = "Option::is_none")]
249    pub tool_calls: Option<Vec<ToolCall>>,
250    #[serde(skip_serializing_if = "Option::is_none")]
251    pub tool_call_id: Option<String>,
252}
253
254#[derive(Debug, Deserialize, Serialize, Clone)]
255pub struct ChatCompletionMessageForResponse {
256    pub role: MessageRole,
257    #[serde(skip_serializing_if = "Option::is_none")]
258    pub content: Option<String>,
259    #[serde(skip_serializing_if = "Option::is_none")]
260    pub reasoning_content: Option<String>,
261    #[serde(skip_serializing_if = "Option::is_none")]
262    pub name: Option<String>,
263    #[serde(skip_serializing_if = "Option::is_none")]
264    pub tool_calls: Option<Vec<ToolCall>>,
265}
266
267#[derive(Debug, Deserialize, Serialize)]
268pub struct ChatCompletionChoice {
269    pub index: i64,
270    pub message: ChatCompletionMessageForResponse,
271    pub finish_reason: Option<FinishReason>,
272    pub finish_details: Option<FinishDetails>,
273}
274
275#[derive(Debug, Deserialize, Serialize)]
276pub struct ChatCompletionResponse {
277    pub id: Option<String>,
278    pub object: String,
279    pub created: i64,
280    pub model: String,
281    pub choices: Vec<ChatCompletionChoice>,
282    pub usage: common::Usage,
283    pub system_fingerprint: Option<String>,
284}
285
286#[derive(Debug, Deserialize, Serialize, PartialEq, Eq)]
287#[allow(non_camel_case_types)]
288pub enum FinishReason {
289    stop,
290    length,
291    content_filter,
292    tool_calls,
293    null,
294}
295
296#[derive(Debug, Deserialize, Serialize)]
297#[allow(non_camel_case_types)]
298pub struct FinishDetails {
299    pub r#type: FinishReason,
300    pub stop: String,
301}
302
303#[derive(Debug, Deserialize, Serialize, Clone)]
304pub struct ToolCall {
305    pub id: String,
306    pub r#type: String,
307    pub function: ToolCallFunction,
308}
309
310#[derive(Debug, Deserialize, Serialize, Clone)]
311pub struct ToolCallFunction {
312    #[serde(skip_serializing_if = "Option::is_none")]
313    pub name: Option<String>,
314    #[serde(skip_serializing_if = "Option::is_none")]
315    pub arguments: Option<String>,
316}
317
318fn serialize_tool_choice<S>(
319    value: &Option<ToolChoiceType>,
320    serializer: S,
321) -> Result<S::Ok, S::Error>
322where
323    S: Serializer,
324{
325    match value {
326        Some(ToolChoiceType::None) => serializer.serialize_str("none"),
327        Some(ToolChoiceType::Auto) => serializer.serialize_str("auto"),
328        Some(ToolChoiceType::Required) => serializer.serialize_str("required"),
329        Some(ToolChoiceType::ToolChoice { tool }) => {
330            let mut map = serializer.serialize_map(Some(2))?;
331            map.serialize_entry("type", &tool.r#type)?;
332            map.serialize_entry("function", &tool.function)?;
333            map.end()
334        }
335        None => serializer.serialize_none(),
336    }
337}
338
339#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
340pub struct Tool {
341    pub r#type: ToolType,
342    pub function: types::Function,
343}
344
345#[derive(Debug, Deserialize, Serialize, Copy, Clone, PartialEq, Eq)]
346#[serde(rename_all = "snake_case")]
347pub enum ToolType {
348    Function,
349}
350
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// The flattened effort mode and a set `exclude` share one JSON object;
    /// the unset `enabled` is omitted.
    #[test]
    fn test_reasoning_effort_serialization() {
        let mode = ReasoningMode::Effort {
            effort: ReasoningEffort::High,
        };
        let config = Reasoning {
            mode: Some(mode),
            exclude: Some(false),
            enabled: None,
        };

        let actual = serde_json::to_value(&config).unwrap();

        assert_eq!(
            actual,
            json!({
                "effort": "high",
                "exclude": false
            })
        );
    }

    /// The token-limit mode is flattened the same way; the unset `exclude` is omitted.
    #[test]
    fn test_reasoning_max_tokens_serialization() {
        let config = Reasoning {
            mode: Some(ReasoningMode::MaxTokens { max_tokens: 2000 }),
            exclude: None,
            enabled: Some(true),
        };

        let actual = serde_json::to_value(&config).unwrap();

        assert_eq!(
            actual,
            json!({
                "max_tokens": 2000,
                "enabled": true
            })
        );
    }

    /// A flat object with an `effort` key is read back as the effort mode.
    #[test]
    fn test_reasoning_deserialization() {
        let raw = r#"{"effort": "medium", "exclude": true}"#;
        let parsed: Reasoning = serde_json::from_str(raw).unwrap();

        if let Some(ReasoningMode::Effort { effort }) = parsed.mode {
            assert_eq!(effort, ReasoningEffort::Medium);
        } else {
            panic!("Expected effort mode");
        }
        assert_eq!(parsed.exclude, Some(true));
    }

    /// A request with reasoning set nests the flattened object under `reasoning`.
    #[test]
    fn test_chat_completion_request_with_reasoning() {
        let mut request = ChatCompletionRequest::new("gpt-4".to_string(), vec![]);
        request.reasoning = Some(Reasoning {
            mode: Some(ReasoningMode::Effort {
                effort: ReasoningEffort::Low,
            }),
            exclude: None,
            enabled: None,
        });

        let body = serde_json::to_value(&request).unwrap();

        assert_eq!(body["reasoning"]["effort"], "low");
    }
}
421}