// openai_api_rs/v1/chat_completion.rs

1use super::{common, types};
2use crate::impl_builder_methods;
3
4use serde::de::{self, MapAccess, SeqAccess, Visitor};
5use serde::ser::SerializeMap;
6use serde::{Deserialize, Deserializer, Serialize, Serializer};
7use serde_json::Value;
8use std::collections::HashMap;
9use std::fmt;
10#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
11pub enum ToolChoiceType {
12    None,
13    Auto,
14    Required,
15    ToolChoice { tool: Tool },
16}
17
18#[derive(Debug, Serialize, Deserialize, Clone)]
19pub struct ChatCompletionRequest {
20    pub model: String,
21    pub messages: Vec<ChatCompletionMessage>,
22    #[serde(skip_serializing_if = "Option::is_none")]
23    pub temperature: Option<f64>,
24    #[serde(skip_serializing_if = "Option::is_none")]
25    pub top_p: Option<f64>,
26    #[serde(skip_serializing_if = "Option::is_none")]
27    pub n: Option<i64>,
28    #[serde(skip_serializing_if = "Option::is_none")]
29    pub response_format: Option<Value>,
30    #[serde(skip_serializing_if = "Option::is_none")]
31    pub stream: Option<bool>,
32    #[serde(skip_serializing_if = "Option::is_none")]
33    pub stop: Option<Vec<String>>,
34    #[serde(skip_serializing_if = "Option::is_none")]
35    pub max_tokens: Option<i64>,
36    #[serde(skip_serializing_if = "Option::is_none")]
37    pub presence_penalty: Option<f64>,
38    #[serde(skip_serializing_if = "Option::is_none")]
39    pub frequency_penalty: Option<f64>,
40    #[serde(skip_serializing_if = "Option::is_none")]
41    pub logit_bias: Option<HashMap<String, i32>>,
42    #[serde(skip_serializing_if = "Option::is_none")]
43    pub user: Option<String>,
44    #[serde(skip_serializing_if = "Option::is_none")]
45    pub seed: Option<i64>,
46    #[serde(skip_serializing_if = "Option::is_none")]
47    pub tools: Option<Vec<Tool>>,
48    #[serde(skip_serializing_if = "Option::is_none")]
49    pub parallel_tool_calls: Option<bool>,
50    #[serde(skip_serializing_if = "Option::is_none")]
51    #[serde(serialize_with = "serialize_tool_choice")]
52    pub tool_choice: Option<ToolChoiceType>,
53}
54
55impl ChatCompletionRequest {
56    pub fn new(model: String, messages: Vec<ChatCompletionMessage>) -> Self {
57        Self {
58            model,
59            messages,
60            temperature: None,
61            top_p: None,
62            stream: None,
63            n: None,
64            response_format: None,
65            stop: None,
66            max_tokens: None,
67            presence_penalty: None,
68            frequency_penalty: None,
69            logit_bias: None,
70            user: None,
71            seed: None,
72            tools: None,
73            parallel_tool_calls: None,
74            tool_choice: None,
75        }
76    }
77}
78
79impl_builder_methods!(
80    ChatCompletionRequest,
81    temperature: f64,
82    top_p: f64,
83    n: i64,
84    response_format: Value,
85    stream: bool,
86    stop: Vec<String>,
87    max_tokens: i64,
88    presence_penalty: f64,
89    frequency_penalty: f64,
90    logit_bias: HashMap<String, i32>,
91    user: String,
92    seed: i64,
93    tools: Vec<Tool>,
94    parallel_tool_calls: bool,
95    tool_choice: ToolChoiceType
96);
97
98#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
99#[allow(non_camel_case_types)]
100pub enum MessageRole {
101    user,
102    system,
103    assistant,
104    function,
105    tool,
106}
107
108#[derive(Debug, Clone, PartialEq, Eq)]
109pub enum Content {
110    Text(String),
111    ImageUrl(Vec<ImageUrl>),
112}
113
114impl serde::Serialize for Content {
115    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
116    where
117        S: serde::Serializer,
118    {
119        match *self {
120            Content::Text(ref text) => {
121                if text.is_empty() {
122                    serializer.serialize_none()
123                } else {
124                    serializer.serialize_str(text)
125                }
126            }
127            Content::ImageUrl(ref image_url) => image_url.serialize(serializer),
128        }
129    }
130}
131
132impl<'de> Deserialize<'de> for Content {
133    fn deserialize<D>(deserializer: D) -> Result<Content, D::Error>
134    where
135        D: Deserializer<'de>,
136    {
137        struct ContentVisitor;
138
139        impl<'de> Visitor<'de> for ContentVisitor {
140            type Value = Content;
141
142            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
143                formatter.write_str("a valid content type")
144            }
145
146            fn visit_str<E>(self, value: &str) -> Result<Content, E>
147            where
148                E: de::Error,
149            {
150                Ok(Content::Text(value.to_string()))
151            }
152
153            fn visit_seq<A>(self, seq: A) -> Result<Content, A::Error>
154            where
155                A: SeqAccess<'de>,
156            {
157                let image_urls: Vec<ImageUrl> =
158                    Deserialize::deserialize(de::value::SeqAccessDeserializer::new(seq))?;
159                Ok(Content::ImageUrl(image_urls))
160            }
161
162            fn visit_map<M>(self, map: M) -> Result<Content, M::Error>
163            where
164                M: MapAccess<'de>,
165            {
166                let image_urls: Vec<ImageUrl> =
167                    Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))?;
168                Ok(Content::ImageUrl(image_urls))
169            }
170
171            fn visit_none<E>(self) -> Result<Self::Value, E>
172            where
173                E: de::Error,
174            {
175                Ok(Content::Text(String::new()))
176            }
177
178            fn visit_unit<E>(self) -> Result<Self::Value, E>
179            where
180                E: de::Error,
181            {
182                Ok(Content::Text(String::new()))
183            }
184        }
185
186        deserializer.deserialize_any(ContentVisitor)
187    }
188}
189
190#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
191#[allow(non_camel_case_types)]
192pub enum ContentType {
193    text,
194    image_url,
195}
196
197#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
198#[allow(non_camel_case_types)]
199pub struct ImageUrlType {
200    pub url: String,
201}
202
203#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
204#[allow(non_camel_case_types)]
205pub struct ImageUrl {
206    pub r#type: ContentType,
207    #[serde(skip_serializing_if = "Option::is_none")]
208    pub text: Option<String>,
209    #[serde(skip_serializing_if = "Option::is_none")]
210    pub image_url: Option<ImageUrlType>,
211}
212
213#[derive(Debug, Deserialize, Serialize, Clone)]
214pub struct ChatCompletionMessage {
215    pub role: MessageRole,
216    pub content: Content,
217    #[serde(skip_serializing_if = "Option::is_none")]
218    pub name: Option<String>,
219    #[serde(skip_serializing_if = "Option::is_none")]
220    pub tool_calls: Option<Vec<ToolCall>>,
221    #[serde(skip_serializing_if = "Option::is_none")]
222    pub tool_call_id: Option<String>,
223}
224
225#[derive(Debug, Deserialize, Serialize, Clone)]
226pub struct ChatCompletionMessageForResponse {
227    pub role: MessageRole,
228    #[serde(skip_serializing_if = "Option::is_none")]
229    pub content: Option<String>,
230    #[serde(skip_serializing_if = "Option::is_none")]
231    pub reasoning_content: Option<String>,
232    #[serde(skip_serializing_if = "Option::is_none")]
233    pub name: Option<String>,
234    #[serde(skip_serializing_if = "Option::is_none")]
235    pub tool_calls: Option<Vec<ToolCall>>,
236}
237
238#[derive(Debug, Deserialize, Serialize)]
239pub struct ChatCompletionChoice {
240    pub index: i64,
241    pub message: ChatCompletionMessageForResponse,
242    pub finish_reason: Option<FinishReason>,
243    pub finish_details: Option<FinishDetails>,
244}
245
246#[derive(Debug, Deserialize, Serialize)]
247pub struct ChatCompletionResponse {
248    pub id: Option<String>,
249    pub object: String,
250    pub created: i64,
251    pub model: String,
252    pub choices: Vec<ChatCompletionChoice>,
253    pub usage: common::Usage,
254    pub system_fingerprint: Option<String>,
255}
256
257#[derive(Debug, Deserialize, Serialize, PartialEq, Eq)]
258#[allow(non_camel_case_types)]
259pub enum FinishReason {
260    stop,
261    length,
262    content_filter,
263    tool_calls,
264    null,
265}
266
267#[derive(Debug, Deserialize, Serialize)]
268#[allow(non_camel_case_types)]
269pub struct FinishDetails {
270    pub r#type: FinishReason,
271    pub stop: String,
272}
273
274#[derive(Debug, Deserialize, Serialize, Clone)]
275pub struct ToolCall {
276    pub id: String,
277    pub r#type: String,
278    pub function: ToolCallFunction,
279}
280
281#[derive(Debug, Deserialize, Serialize, Clone)]
282pub struct ToolCallFunction {
283    #[serde(skip_serializing_if = "Option::is_none")]
284    pub name: Option<String>,
285    #[serde(skip_serializing_if = "Option::is_none")]
286    pub arguments: Option<String>,
287}
288
289fn serialize_tool_choice<S>(
290    value: &Option<ToolChoiceType>,
291    serializer: S,
292) -> Result<S::Ok, S::Error>
293where
294    S: Serializer,
295{
296    match value {
297        Some(ToolChoiceType::None) => serializer.serialize_str("none"),
298        Some(ToolChoiceType::Auto) => serializer.serialize_str("auto"),
299        Some(ToolChoiceType::Required) => serializer.serialize_str("required"),
300        Some(ToolChoiceType::ToolChoice { tool }) => {
301            let mut map = serializer.serialize_map(Some(2))?;
302            map.serialize_entry("type", &tool.r#type)?;
303            map.serialize_entry("function", &tool.function)?;
304            map.end()
305        }
306        None => serializer.serialize_none(),
307    }
308}
309
310#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
311pub struct Tool {
312    pub r#type: ToolType,
313    pub function: types::Function,
314}
315
316#[derive(Debug, Deserialize, Serialize, Copy, Clone, PartialEq, Eq)]
317#[serde(rename_all = "snake_case")]
318pub enum ToolType {
319    Function,
320}