1use super::{common, types};
2use crate::impl_builder_methods;
3
4use serde::de::{self, MapAccess, SeqAccess, Visitor};
5use serde::ser::SerializeMap;
6use serde::{Deserialize, Deserializer, Serialize, Serializer};
7use serde_json::Value;
8use std::collections::HashMap;
9use std::fmt;
10#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
11pub enum ToolChoiceType {
12 None,
13 Auto,
14 Required,
15 ToolChoice { tool: Tool },
16}
17
18#[derive(Debug, Serialize, Deserialize, Clone)]
19pub struct ChatCompletionRequest {
20 pub model: String,
21 pub messages: Vec<ChatCompletionMessage>,
22 #[serde(skip_serializing_if = "Option::is_none")]
23 pub temperature: Option<f64>,
24 #[serde(skip_serializing_if = "Option::is_none")]
25 pub top_p: Option<f64>,
26 #[serde(skip_serializing_if = "Option::is_none")]
27 pub n: Option<i64>,
28 #[serde(skip_serializing_if = "Option::is_none")]
29 pub response_format: Option<Value>,
30 #[serde(skip_serializing_if = "Option::is_none")]
31 pub stream: Option<bool>,
32 #[serde(skip_serializing_if = "Option::is_none")]
33 pub stop: Option<Vec<String>>,
34 #[serde(skip_serializing_if = "Option::is_none")]
35 pub max_tokens: Option<i64>,
36 #[serde(skip_serializing_if = "Option::is_none")]
37 pub presence_penalty: Option<f64>,
38 #[serde(skip_serializing_if = "Option::is_none")]
39 pub frequency_penalty: Option<f64>,
40 #[serde(skip_serializing_if = "Option::is_none")]
41 pub logit_bias: Option<HashMap<String, i32>>,
42 #[serde(skip_serializing_if = "Option::is_none")]
43 pub user: Option<String>,
44 #[serde(skip_serializing_if = "Option::is_none")]
45 pub seed: Option<i64>,
46 #[serde(skip_serializing_if = "Option::is_none")]
47 pub tools: Option<Vec<Tool>>,
48 #[serde(skip_serializing_if = "Option::is_none")]
49 pub parallel_tool_calls: Option<bool>,
50 #[serde(skip_serializing_if = "Option::is_none")]
51 #[serde(serialize_with = "serialize_tool_choice")]
52 pub tool_choice: Option<ToolChoiceType>,
53}
54
55impl ChatCompletionRequest {
56 pub fn new(model: String, messages: Vec<ChatCompletionMessage>) -> Self {
57 Self {
58 model,
59 messages,
60 temperature: None,
61 top_p: None,
62 stream: None,
63 n: None,
64 response_format: None,
65 stop: None,
66 max_tokens: None,
67 presence_penalty: None,
68 frequency_penalty: None,
69 logit_bias: None,
70 user: None,
71 seed: None,
72 tools: None,
73 parallel_tool_calls: None,
74 tool_choice: None,
75 }
76 }
77}
78
// Builder methods for the optional fields of `ChatCompletionRequest`.
// The macro is brought in via `use crate::impl_builder_methods`; presumably it emits one
// setter per `field: Type` pair that stores `Some(value)` — confirm against its definition.
// Keep this list in sync with the optional fields of the struct.
impl_builder_methods!(
    ChatCompletionRequest,
    temperature: f64,
    top_p: f64,
    n: i64,
    response_format: Value,
    stream: bool,
    stop: Vec<String>,
    max_tokens: i64,
    presence_penalty: f64,
    frequency_penalty: f64,
    logit_bias: HashMap<String, i32>,
    user: String,
    seed: i64,
    tools: Vec<Tool>,
    parallel_tool_calls: bool,
    tool_choice: ToolChoiceType
);
97
98#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
99#[allow(non_camel_case_types)]
100pub enum MessageRole {
101 user,
102 system,
103 assistant,
104 function,
105 tool,
106}
107
108#[derive(Debug, Clone, PartialEq, Eq)]
109pub enum Content {
110 Text(String),
111 ImageUrl(Vec<ImageUrl>),
112}
113
114impl serde::Serialize for Content {
115 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
116 where
117 S: serde::Serializer,
118 {
119 match *self {
120 Content::Text(ref text) => {
121 if text.is_empty() {
122 serializer.serialize_none()
123 } else {
124 serializer.serialize_str(text)
125 }
126 }
127 Content::ImageUrl(ref image_url) => image_url.serialize(serializer),
128 }
129 }
130}
131
132impl<'de> Deserialize<'de> for Content {
133 fn deserialize<D>(deserializer: D) -> Result<Content, D::Error>
134 where
135 D: Deserializer<'de>,
136 {
137 struct ContentVisitor;
138
139 impl<'de> Visitor<'de> for ContentVisitor {
140 type Value = Content;
141
142 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
143 formatter.write_str("a valid content type")
144 }
145
146 fn visit_str<E>(self, value: &str) -> Result<Content, E>
147 where
148 E: de::Error,
149 {
150 Ok(Content::Text(value.to_string()))
151 }
152
153 fn visit_seq<A>(self, seq: A) -> Result<Content, A::Error>
154 where
155 A: SeqAccess<'de>,
156 {
157 let image_urls: Vec<ImageUrl> =
158 Deserialize::deserialize(de::value::SeqAccessDeserializer::new(seq))?;
159 Ok(Content::ImageUrl(image_urls))
160 }
161
162 fn visit_map<M>(self, map: M) -> Result<Content, M::Error>
163 where
164 M: MapAccess<'de>,
165 {
166 let image_urls: Vec<ImageUrl> =
167 Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))?;
168 Ok(Content::ImageUrl(image_urls))
169 }
170
171 fn visit_none<E>(self) -> Result<Self::Value, E>
172 where
173 E: de::Error,
174 {
175 Ok(Content::Text(String::new()))
176 }
177
178 fn visit_unit<E>(self) -> Result<Self::Value, E>
179 where
180 E: de::Error,
181 {
182 Ok(Content::Text(String::new()))
183 }
184 }
185
186 deserializer.deserialize_any(ContentVisitor)
187 }
188}
189
190#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
191#[allow(non_camel_case_types)]
192pub enum ContentType {
193 text,
194 image_url,
195}
196
197#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
198#[allow(non_camel_case_types)]
199pub struct ImageUrlType {
200 pub url: String,
201}
202
203#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
204#[allow(non_camel_case_types)]
205pub struct ImageUrl {
206 pub r#type: ContentType,
207 #[serde(skip_serializing_if = "Option::is_none")]
208 pub text: Option<String>,
209 #[serde(skip_serializing_if = "Option::is_none")]
210 pub image_url: Option<ImageUrlType>,
211}
212
213#[derive(Debug, Deserialize, Serialize, Clone)]
214pub struct ChatCompletionMessage {
215 pub role: MessageRole,
216 pub content: Content,
217 #[serde(skip_serializing_if = "Option::is_none")]
218 pub name: Option<String>,
219 #[serde(skip_serializing_if = "Option::is_none")]
220 pub tool_calls: Option<Vec<ToolCall>>,
221 #[serde(skip_serializing_if = "Option::is_none")]
222 pub tool_call_id: Option<String>,
223}
224
225#[derive(Debug, Deserialize, Serialize, Clone)]
226pub struct ChatCompletionMessageForResponse {
227 pub role: MessageRole,
228 #[serde(skip_serializing_if = "Option::is_none")]
229 pub content: Option<String>,
230 #[serde(skip_serializing_if = "Option::is_none")]
231 pub reasoning_content: Option<String>,
232 #[serde(skip_serializing_if = "Option::is_none")]
233 pub name: Option<String>,
234 #[serde(skip_serializing_if = "Option::is_none")]
235 pub tool_calls: Option<Vec<ToolCall>>,
236}
237
238#[derive(Debug, Deserialize, Serialize)]
239pub struct ChatCompletionChoice {
240 pub index: i64,
241 pub message: ChatCompletionMessageForResponse,
242 pub finish_reason: Option<FinishReason>,
243 pub finish_details: Option<FinishDetails>,
244}
245
246#[derive(Debug, Deserialize, Serialize)]
247pub struct ChatCompletionResponse {
248 pub id: Option<String>,
249 pub object: String,
250 pub created: i64,
251 pub model: String,
252 pub choices: Vec<ChatCompletionChoice>,
253 pub usage: common::Usage,
254 pub system_fingerprint: Option<String>,
255}
256
257#[derive(Debug, Deserialize, Serialize, PartialEq, Eq)]
258#[allow(non_camel_case_types)]
259pub enum FinishReason {
260 stop,
261 length,
262 content_filter,
263 tool_calls,
264 null,
265}
266
267#[derive(Debug, Deserialize, Serialize)]
268#[allow(non_camel_case_types)]
269pub struct FinishDetails {
270 pub r#type: FinishReason,
271 pub stop: String,
272}
273
274#[derive(Debug, Deserialize, Serialize, Clone)]
275pub struct ToolCall {
276 pub id: String,
277 pub r#type: String,
278 pub function: ToolCallFunction,
279}
280
281#[derive(Debug, Deserialize, Serialize, Clone)]
282pub struct ToolCallFunction {
283 #[serde(skip_serializing_if = "Option::is_none")]
284 pub name: Option<String>,
285 #[serde(skip_serializing_if = "Option::is_none")]
286 pub arguments: Option<String>,
287}
288
289fn serialize_tool_choice<S>(
290 value: &Option<ToolChoiceType>,
291 serializer: S,
292) -> Result<S::Ok, S::Error>
293where
294 S: Serializer,
295{
296 match value {
297 Some(ToolChoiceType::None) => serializer.serialize_str("none"),
298 Some(ToolChoiceType::Auto) => serializer.serialize_str("auto"),
299 Some(ToolChoiceType::Required) => serializer.serialize_str("required"),
300 Some(ToolChoiceType::ToolChoice { tool }) => {
301 let mut map = serializer.serialize_map(Some(2))?;
302 map.serialize_entry("type", &tool.r#type)?;
303 map.serialize_entry("function", &tool.function)?;
304 map.end()
305 }
306 None => serializer.serialize_none(),
307 }
308}
309
310#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
311pub struct Tool {
312 pub r#type: ToolType,
313 pub function: types::Function,
314}
315
316#[derive(Debug, Deserialize, Serialize, Copy, Clone, PartialEq, Eq)]
317#[serde(rename_all = "snake_case")]
318pub enum ToolType {
319 Function,
320}