1use serde::{Deserialize, Serialize};
5
6use super::common::{FinishReason, ResponseFormat, Role, Tool, ToolChoice, Usage};
7
/// Body of a request [`Message`]: either plain text or a list of multimodal parts.
///
/// Serialized with `#[serde(untagged)]`, so `Text` is written as a bare JSON string and
/// `Parts` as a JSON array of [`ContentPart`] objects. On deserialization serde tries the
/// variants in declaration order, so a JSON string always becomes `Text`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum MessageContent {
    /// Plain text, serialized as a JSON string.
    Text(String),
    /// Multimodal content, serialized as a JSON array of tagged parts.
    Parts(Vec<ContentPart>),
}
15
16impl From<&str> for MessageContent {
17 fn from(text: &str) -> Self {
18 Self::Text(text.to_string())
19 }
20}
21
22impl From<String> for MessageContent {
23 fn from(text: String) -> Self {
24 Self::Text(text)
25 }
26}
27
28impl From<Vec<ContentPart>> for MessageContent {
29 fn from(parts: Vec<ContentPart>) -> Self {
30 Self::Parts(parts)
31 }
32}
33
/// One element of a multimodal [`MessageContent::Parts`] list.
///
/// Internally tagged: each part serializes as a JSON object whose `"type"` field is the
/// snake_case variant name (`"text"`, `"image_url"`, `"input_audio"`), next to the
/// variant's own payload field of the same name.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentPart {
    /// `{"type": "text", "text": ...}`
    Text { text: String },
    /// `{"type": "image_url", "image_url": {...}}`
    ImageUrl { image_url: ImageUrl },
    /// `{"type": "input_audio", "input_audio": {...}}`
    InputAudio { input_audio: InputAudio },
}
43
44impl ContentPart {
45 pub fn text(text: impl Into<String>) -> Self {
46 Self::Text { text: text.into() }
47 }
48
49 pub fn image_url(url: impl Into<String>) -> Self {
51 Self::ImageUrl {
52 image_url: ImageUrl {
53 url: url.into(),
54 detail: None,
55 },
56 }
57 }
58
59 pub fn image_url_with_detail(url: impl Into<String>, detail: impl Into<String>) -> Self {
61 Self::ImageUrl {
62 image_url: ImageUrl {
63 url: url.into(),
64 detail: Some(detail.into()),
65 },
66 }
67 }
68
69 pub fn input_audio(data: impl Into<String>, format: impl Into<String>) -> Self {
71 Self::InputAudio {
72 input_audio: InputAudio {
73 data: data.into(),
74 format: format.into(),
75 },
76 }
77 }
78}
79
/// Payload of a [`ContentPart::ImageUrl`] part.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageUrl {
    /// Location of the image.
    pub url: String,
    /// Optional detail hint (e.g. `"low"`); omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
87
/// Payload of a [`ContentPart::InputAudio`] part.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InputAudio {
    /// Encoded audio bytes as a string (presumably base64 — TODO confirm against the API).
    pub data: String,
    /// Audio container/codec name, e.g. `"wav"`.
    pub format: String,
}
95
/// A single conversation message sent in a [`ChatCompletionRequest`].
///
/// Every optional field is omitted from the JSON when `None` and defaults to `None` when
/// absent on input. Prefer the role-specific constructors (`Message::user`,
/// `Message::tool`, ...) over building the struct by hand.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Author role of the message.
    pub role: Role,
    /// Text or multimodal content; `None` for assistant messages that only carry tool calls.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub content: Option<MessageContent>,
    /// Optional participant name.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Tool calls made by an assistant message (set by `Message::assistant_tool_calls`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    /// For tool-result messages, the id of the call being answered (set by `Message::tool`).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
}
111
112impl Message {
113 fn new(role: Role, content: impl Into<MessageContent>) -> Self {
114 Self {
115 role,
116 content: Some(content.into()),
117 name: None,
118 tool_calls: None,
119 tool_call_id: None,
120 }
121 }
122
123 pub fn system(content: impl Into<MessageContent>) -> Self {
124 Self::new(Role::System, content)
125 }
126
127 pub fn developer(content: impl Into<MessageContent>) -> Self {
128 Self::new(Role::Developer, content)
129 }
130
131 pub fn user(content: impl Into<MessageContent>) -> Self {
134 Self::new(Role::User, content)
135 }
136
137 pub fn assistant(content: impl Into<MessageContent>) -> Self {
138 Self::new(Role::Assistant, content)
139 }
140
141 pub fn assistant_tool_calls(tool_calls: Vec<ToolCall>) -> Self {
143 Self {
144 role: Role::Assistant,
145 content: None,
146 name: None,
147 tool_calls: Some(tool_calls),
148 tool_call_id: None,
149 }
150 }
151
152 pub fn tool(content: impl Into<MessageContent>, tool_call_id: impl Into<String>) -> Self {
154 Self {
155 role: Role::Tool,
156 content: Some(content.into()),
157 name: None,
158 tool_calls: None,
159 tool_call_id: Some(tool_call_id.into()),
160 }
161 }
162}
163
164#[derive(Debug, Clone, Serialize, Deserialize)]
166#[serde(untagged)]
167pub enum Stop {
168 One(String),
169 Many(Vec<String>),
170}
171
/// Options attached to a streaming request via `ChatCompletionRequest::stream_options`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct StreamOptions {
    /// Omitted from the JSON when `None`.
    /// NOTE(review): presumably requests a final chunk reporting token usage — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub include_usage: Option<bool>,
}
178
/// Request body for a chat-completion call.
///
/// Only `model` and `messages` are always serialized; every `Option` field is skipped when
/// `None`, and `extra` is flattened into the top-level JSON object. Field names are the
/// JSON parameter names. Build with `ChatCompletionRequest::new` plus the setter methods.
/// This type is serialize-only (no `Deserialize` derive).
#[derive(Debug, Clone, Default, Serialize)]
pub struct ChatCompletionRequest {
    /// Model identifier, e.g. `"gpt-4o"`.
    pub model: String,
    /// Conversation so far, in order.
    pub messages: Vec<Message>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_completion_tokens: Option<u64>,
    /// NOTE(review): presumably the older token limit superseded by
    /// `max_completion_tokens` — confirm against the API reference.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<i64>,
    /// One or several stop sequences (see [`Stop`]).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Stop>,
    /// No builder setter exists for this field in this impl; presumably set by the client
    /// when streaming — verify against callers.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream_options: Option<StreamOptions>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_logprobs: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
    /// Free-form string; valid values are defined by the API, not checked here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_effort: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<std::collections::HashMap<String, i32>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<std::collections::HashMap<String, String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub store: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub service_tier: Option<String>,
    /// Output modality names, passed through as strings.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub modalities: Option<Vec<String>>,
    /// Audio configuration, passed through as raw JSON.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub audio: Option<serde_json::Value>,
    /// Arbitrary additional parameters, merged into the top-level JSON object (populated
    /// by `ChatCompletionRequest::param`). Lets callers send parameters this struct does
    /// not model yet.
    #[serde(flatten, skip_serializing_if = "std::collections::HashMap::is_empty")]
    pub extra: std::collections::HashMap<String, serde_json::Value>,
}
246
247impl ChatCompletionRequest {
248 pub fn new(model: impl Into<String>, messages: Vec<Message>) -> Self {
249 Self {
250 model: model.into(),
251 messages,
252 ..Self::default()
253 }
254 }
255
256 pub fn temperature(mut self, temperature: f64) -> Self {
257 self.temperature = Some(temperature);
258 self
259 }
260
261 pub fn max_completion_tokens(mut self, max: u64) -> Self {
262 self.max_completion_tokens = Some(max);
263 self
264 }
265
266 pub fn max_tokens(mut self, max: u64) -> Self {
267 self.max_tokens = Some(max);
268 self
269 }
270
271 pub fn top_p(mut self, top_p: f64) -> Self {
272 self.top_p = Some(top_p);
273 self
274 }
275
276 pub fn n(mut self, n: u32) -> Self {
277 self.n = Some(n);
278 self
279 }
280
281 pub fn seed(mut self, seed: i64) -> Self {
282 self.seed = Some(seed);
283 self
284 }
285
286 pub fn frequency_penalty(mut self, penalty: f64) -> Self {
287 self.frequency_penalty = Some(penalty);
288 self
289 }
290
291 pub fn presence_penalty(mut self, penalty: f64) -> Self {
292 self.presence_penalty = Some(penalty);
293 self
294 }
295
296 pub fn logprobs(mut self, logprobs: bool) -> Self {
297 self.logprobs = Some(logprobs);
298 self
299 }
300
301 pub fn top_logprobs(mut self, top_logprobs: u32) -> Self {
302 self.top_logprobs = Some(top_logprobs);
303 self
304 }
305
306 pub fn stop(mut self, stop: Stop) -> Self {
307 self.stop = Some(stop);
308 self
309 }
310
311 pub fn response_format(mut self, response_format: ResponseFormat) -> Self {
312 self.response_format = Some(response_format);
313 self
314 }
315
316 pub fn tools(mut self, tools: Vec<Tool>) -> Self {
317 self.tools = Some(tools);
318 self
319 }
320
321 pub fn tool_choice(mut self, tool_choice: ToolChoice) -> Self {
322 self.tool_choice = Some(tool_choice);
323 self
324 }
325
326 pub fn stream_options(mut self, stream_options: StreamOptions) -> Self {
327 self.stream_options = Some(stream_options);
328 self
329 }
330
331 pub fn user(mut self, user: impl Into<String>) -> Self {
332 self.user = Some(user.into());
333 self
334 }
335
336 pub fn reasoning_effort(mut self, effort: impl Into<String>) -> Self {
339 self.reasoning_effort = Some(effort.into());
340 self
341 }
342
343 pub fn parallel_tool_calls(mut self, parallel: bool) -> Self {
344 self.parallel_tool_calls = Some(parallel);
345 self
346 }
347
348 pub fn metadata(mut self, metadata: std::collections::HashMap<String, String>) -> Self {
349 self.metadata = Some(metadata);
350 self
351 }
352
353 pub fn store(mut self, store: bool) -> Self {
354 self.store = Some(store);
355 self
356 }
357
358 pub fn param(mut self, key: impl Into<String>, value: impl Into<serde_json::Value>) -> Self {
361 self.extra.insert(key.into(), value.into());
362 self
363 }
364}
365
/// A tool invocation attached to an assistant message.
///
/// Appears in `ChatCompletionMessage::tool_calls` on responses and is sent back to the
/// API through `Message::assistant_tool_calls`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Call identifier, echoed back in the matching tool-result `Message::tool`.
    pub id: String,
    /// Serialized as `"type"`; `"function"` in the responses exercised by the tests.
    #[serde(rename = "type")]
    pub call_type: String,
    /// The function name and its arguments.
    pub function: FunctionCall,
}
374
/// Function name and arguments of a [`ToolCall`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionCall {
    /// Name of the function the model wants to call.
    pub name: String,
    /// Arguments as a raw JSON-encoded string (e.g. `{"city":"Hanoi"}`); not parsed here.
    pub arguments: String,
}
382
/// Log-probability entry for one generated token.
///
/// `#[non_exhaustive]`: it cannot be constructed outside this crate, so fields can be
/// added later without a breaking change.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct TokenLogprob {
    /// The token text.
    pub token: String,
    /// Log probability of this token.
    pub logprob: f64,
    /// `None` when absent or null. NOTE(review): presumably the token's UTF-8 bytes — confirm.
    #[serde(default)]
    pub bytes: Option<Vec<u8>>,
    /// Alternative candidates at this position; empty when absent from the response.
    #[serde(default)]
    pub top_logprobs: Vec<TopLogprob>,
}
396
/// One alternative candidate listed in [`TokenLogprob::top_logprobs`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct TopLogprob {
    /// Candidate token text.
    pub token: String,
    /// Log probability of the candidate.
    pub logprob: f64,
    /// `None` when absent or null. NOTE(review): presumably the token's UTF-8 bytes — confirm.
    #[serde(default)]
    pub bytes: Option<Vec<u8>>,
}
406
/// Log-probability information attached to a [`Choice`] or [`ChunkChoice`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct ChatLogprobs {
    /// Per-token entries for the message content; `None` when absent or null.
    #[serde(default)]
    pub content: Option<Vec<TokenLogprob>>,
    /// Per-token entries for a refusal; `None` when absent or null.
    #[serde(default)]
    pub refusal: Option<Vec<TokenLogprob>>,
}
417
/// The message returned inside a non-streaming [`Choice`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct ChatCompletionMessage {
    /// Author role (`"assistant"` in the tested responses).
    pub role: Role,
    /// Text reply. `None` when the response sends `null`, e.g. for tool-call-only replies.
    /// Always serialized (as `null` when `None`).
    #[serde(default)]
    pub content: Option<String>,
    /// Refusal text, if any; omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub refusal: Option<String>,
    /// Tool calls requested by the model; omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
}
430
/// One candidate completion in a [`ChatCompletion`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct Choice {
    /// Position of this choice in the response's `choices` list.
    pub index: u32,
    /// The generated message.
    pub message: ChatCompletionMessage,
    /// Why generation stopped (e.g. `"stop"`, `"tool_calls"`); `None` when absent or null.
    #[serde(default)]
    pub finish_reason: Option<FinishReason>,
    /// Log-probability details; omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<ChatLogprobs>,
}
442
/// Response body of a non-streaming chat-completion call.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct ChatCompletion {
    /// Completion identifier (e.g. `"chatcmpl-..."`).
    pub id: String,
    /// Candidate completions; required when deserializing.
    pub choices: Vec<Choice>,
    /// Creation time. NOTE(review): presumably Unix seconds — confirm.
    pub created: i64,
    /// Model that produced the completion.
    pub model: String,
    /// Object type string (e.g. `"chat.completion"`); empty when absent.
    #[serde(default)]
    pub object: String,
    /// Omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub service_tier: Option<String>,
    /// Omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub system_fingerprint: Option<String>,
    /// Token accounting; `None` when absent or null.
    #[serde(default)]
    pub usage: Option<Usage>,
}
460
461impl ChatCompletion {
462 pub fn content(&self) -> Option<&str> {
464 self.choices.first()?.message.content.as_deref()
465 }
466}
467
/// A streamed fragment of a tool call inside a [`ChoiceDelta`].
///
/// Every field except `index` is optional and omitted from the JSON when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct ChoiceDeltaToolCall {
    /// NOTE(review): presumably identifies which tool call this fragment extends, so
    /// fragments can be merged across chunks — confirm.
    pub index: u32,
    /// Call identifier, when present in this fragment.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Serialized as `"type"`.
    #[serde(default, rename = "type", skip_serializing_if = "Option::is_none")]
    pub call_type: Option<String>,
    /// Partial function name/arguments, when present in this fragment.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub function: Option<ChoiceDeltaFunction>,
}
480
/// Partial function data carried by a [`ChoiceDeltaToolCall`].
///
/// NOTE(review): `arguments` fragments are presumably concatenated by the caller across
/// chunks to form the full JSON string — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct ChoiceDeltaFunction {
    /// Function name, when present in this fragment.
    #[serde(default)]
    pub name: Option<String>,
    /// A piece of the JSON-encoded arguments string, when present in this fragment.
    #[serde(default)]
    pub arguments: Option<String>,
}
490
/// Incremental message update carried by a streaming [`ChunkChoice`].
///
/// Derives `Default` so that a chunk choice with no `delta` deserializes to an empty
/// delta (see the `#[serde(default)]` on `ChunkChoice::delta`).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[non_exhaustive]
pub struct ChoiceDelta {
    /// Author role, when present in this chunk.
    #[serde(default)]
    pub role: Option<Role>,
    /// A piece of text content, when present in this chunk.
    #[serde(default)]
    pub content: Option<String>,
    /// A piece of refusal text; omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub refusal: Option<String>,
    /// Tool-call fragments; omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ChoiceDeltaToolCall>>,
}
504
/// One choice inside a streaming [`ChatCompletionChunk`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct ChunkChoice {
    /// Position of the choice this chunk updates.
    pub index: u32,
    /// The incremental update; an empty [`ChoiceDelta`] when absent.
    #[serde(default)]
    pub delta: ChoiceDelta,
    /// Set on the chunk that ends this choice; `None` (absent or null) otherwise.
    #[serde(default)]
    pub finish_reason: Option<FinishReason>,
    /// Log-probability details; omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<ChatLogprobs>,
}
517
/// One event of a streaming chat-completion response.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct ChatCompletionChunk {
    /// Completion identifier.
    pub id: String,
    /// Choice updates in this chunk; empty when absent.
    #[serde(default)]
    pub choices: Vec<ChunkChoice>,
    /// Creation time. NOTE(review): presumably Unix seconds — confirm.
    pub created: i64,
    /// Model producing the stream.
    pub model: String,
    /// Object type string (e.g. `"chat.completion.chunk"`); empty when absent.
    #[serde(default)]
    pub object: String,
    /// Omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub service_tier: Option<String>,
    /// Omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub system_fingerprint: Option<String>,
    /// Token accounting; `None` when absent or null.
    /// NOTE(review): presumably only populated when `StreamOptions::include_usage` is set — confirm.
    #[serde(default)]
    pub usage: Option<Usage>,
}
537
538impl ChatCompletionChunk {
539 pub fn content(&self) -> Option<&str> {
541 self.choices.first()?.delta.content.as_deref()
542 }
543}
544
/// Wire-format tests: serialization of requests and messages, deserialization of
/// responses and stream chunks.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn multimodal_content_parts_serialize() {
        let message = Message::user(vec![
            ContentPart::text("What is in this image?"),
            ContentPart::image_url_with_detail("https://example.com/cat.png", "low"),
            ContentPart::input_audio("aGVsbG8=", "wav"),
        ]);
        assert_eq!(
            serde_json::to_value(&message).unwrap(),
            serde_json::json!({
                "role": "user",
                "content": [
                    {"type": "text", "text": "What is in this image?"},
                    {"type": "image_url", "image_url": {"url": "https://example.com/cat.png", "detail": "low"}},
                    {"type": "input_audio", "input_audio": {"data": "aGVsbG8=", "format": "wav"}}
                ]
            })
        );
    }

    #[test]
    fn string_content_still_serializes_as_plain_string() {
        let message = Message::user("hi");
        assert_eq!(
            serde_json::to_value(&message).unwrap(),
            serde_json::json!({"role": "user", "content": "hi"})
        );
    }

    /// Untagged `MessageContent` must come back as the same variant after a round trip.
    #[test]
    fn message_content_round_trips() {
        let messages = [
            Message::user("hi"),
            Message::user(vec![
                ContentPart::text("hi"),
                ContentPart::image_url("https://example.com/a.png"),
            ]),
        ];
        for message in messages {
            let json = serde_json::to_value(&message).unwrap();
            let back: Message = serde_json::from_value(json.clone()).unwrap();
            assert_eq!(serde_json::to_value(&back).unwrap(), json);
        }
    }

    #[test]
    fn tool_result_message_serializes_call_id() {
        let json = serde_json::to_value(Message::tool("22C", "call_1")).unwrap();
        assert_eq!(json["content"], "22C");
        assert_eq!(json["tool_call_id"], "call_1");
        assert!(json.get("tool_calls").is_none());
        assert!(json.get("name").is_none());
    }

    #[test]
    fn assistant_tool_calls_omits_content() {
        let call = ToolCall {
            id: "call_1".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "get_weather".to_string(),
                arguments: "{}".to_string(),
            },
        };
        let json = serde_json::to_value(Message::assistant_tool_calls(vec![call])).unwrap();
        assert!(json.get("content").is_none());
        assert!(json.get("tool_call_id").is_none());
        assert_eq!(json["tool_calls"][0]["id"], "call_1");
        assert_eq!(json["tool_calls"][0]["type"], "function");
        assert_eq!(json["tool_calls"][0]["function"]["name"], "get_weather");
    }

    #[test]
    fn request_skips_none_fields() {
        let request = ChatCompletionRequest::new("gpt-4o", vec![Message::user("hi")]);
        let json = serde_json::to_value(&request).unwrap();
        assert_eq!(
            json,
            serde_json::json!({
                "model": "gpt-4o",
                "messages": [{"role": "user", "content": "hi"}],
            })
        );
    }

    #[test]
    fn builder_setters_and_extra_params_serialize() {
        let request = ChatCompletionRequest::new("gpt-4o", vec![Message::user("hi")])
            .temperature(0.5)
            .max_completion_tokens(100)
            .stop(Stop::Many(vec!["END".to_string()]))
            .param("foo", "bar");
        let json = serde_json::to_value(&request).unwrap();
        assert_eq!(json["temperature"], serde_json::json!(0.5));
        assert_eq!(json["max_completion_tokens"], serde_json::json!(100));
        assert_eq!(json["stop"], serde_json::json!(["END"]));
        // `extra` is flattened into the top-level object.
        assert_eq!(json["foo"], "bar");
        assert!(json.get("extra").is_none());
    }

    #[test]
    fn single_stop_serializes_as_plain_string() {
        assert_eq!(
            serde_json::to_value(Stop::One("END".to_string())).unwrap(),
            serde_json::json!("END")
        );
    }

    #[test]
    fn deserializes_real_openai_response() {
        let body = r#"{
            "id": "chatcmpl-abc123",
            "object": "chat.completion",
            "created": 1728933352,
            "model": "gpt-4o-2024-08-06",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "Hello there!",
                    "refusal": null
                },
                "logprobs": null,
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": 19,
                "completion_tokens": 10,
                "total_tokens": 29,
                "completion_tokens_details": {"reasoning_tokens": 0}
            },
            "system_fingerprint": "fp_6b68a8204b"
        }"#;
        let completion: ChatCompletion = serde_json::from_str(body).unwrap();
        assert_eq!(completion.content(), Some("Hello there!"));
        assert_eq!(
            completion.choices[0].finish_reason,
            Some(FinishReason::Stop)
        );
        assert_eq!(completion.usage.as_ref().unwrap().total_tokens, 29);
    }

    #[test]
    fn deserializes_tool_call_response() {
        let body = r#"{
            "id": "chatcmpl-1",
            "object": "chat.completion",
            "created": 1,
            "model": "gpt-4o",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": null,
                    "tool_calls": [{
                        "id": "call_1",
                        "type": "function",
                        "function": {"name": "get_weather", "arguments": "{\"city\":\"Hanoi\"}"}
                    }]
                },
                "finish_reason": "tool_calls"
            }]
        }"#;
        let completion: ChatCompletion = serde_json::from_str(body).unwrap();
        let calls = completion.choices[0].message.tool_calls.as_ref().unwrap();
        assert_eq!(calls[0].function.name, "get_weather");
        assert_eq!(
            completion.choices[0].finish_reason,
            Some(FinishReason::ToolCalls)
        );
        assert_eq!(completion.content(), None);
    }

    #[test]
    fn deserializes_stream_chunk() {
        let body = r#"{
            "id": "chatcmpl-1",
            "object": "chat.completion.chunk",
            "created": 1,
            "model": "gpt-4o",
            "choices": [{"index": 0, "delta": {"content": "Hel"}, "finish_reason": null}]
        }"#;
        let chunk: ChatCompletionChunk = serde_json::from_str(body).unwrap();
        assert_eq!(chunk.content(), Some("Hel"));
    }

    #[test]
    fn stream_chunk_without_choices_has_no_content() {
        let body = r#"{
            "id": "chatcmpl-1",
            "object": "chat.completion.chunk",
            "created": 1,
            "model": "gpt-4o",
            "choices": []
        }"#;
        let chunk: ChatCompletionChunk = serde_json::from_str(body).unwrap();
        assert!(chunk.choices.is_empty());
        assert_eq!(chunk.content(), None);
    }
}