1use serde::{Deserialize, Serialize};
7use uuid::Uuid;
8
9#[derive(Debug, Clone, Deserialize, Serialize)]
15pub struct ChatCompletionRequest {
16 pub model: String,
18
19 pub messages: Vec<Message>,
21
22 #[serde(default = "default_temperature")]
24 pub temperature: f32,
25
26 #[serde(default = "default_top_p")]
28 pub top_p: f32,
29
30 #[serde(default = "default_n")]
32 pub n: u32,
33
34 #[serde(default)]
36 pub stream: bool,
37
38 #[serde(default, skip_serializing_if = "Option::is_none")]
40 pub stop: Option<Vec<String>>,
41
42 #[serde(skip_serializing_if = "Option::is_none")]
44 pub max_tokens: Option<u32>,
45
46 #[serde(default)]
48 pub presence_penalty: f32,
49
50 #[serde(default)]
52 pub frequency_penalty: f32,
53
54 #[serde(skip_serializing_if = "Option::is_none")]
56 pub user: Option<String>,
57
58 #[serde(skip_serializing_if = "Option::is_none")]
60 pub response_format: Option<ResponseFormat>,
61
62 #[serde(skip_serializing_if = "Option::is_none")]
64 pub seed: Option<u64>,
65
66 #[serde(default, skip_serializing_if = "Vec::is_empty")]
68 pub tools: Vec<Tool>,
69
70 #[serde(skip_serializing_if = "Option::is_none")]
72 pub tool_choice: Option<ToolChoice>,
73}
74
/// Value (1.0) shared by the `temperature` and `top_p` serde defaults.
const DEFAULT_SAMPLING_VALUE: f32 = 1.0;

/// Serde default for the request's `temperature` field.
fn default_temperature() -> f32 {
    DEFAULT_SAMPLING_VALUE
}

/// Serde default for the request's `top_p` field.
fn default_top_p() -> f32 {
    DEFAULT_SAMPLING_VALUE
}

/// Serde default for the request's `n` field: one choice.
fn default_n() -> u32 {
    1
}
84
85impl Default for ChatCompletionRequest {
86 fn default() -> Self {
87 Self {
88 model: "gpt-4".to_string(),
89 messages: vec![],
90 temperature: default_temperature(),
91 top_p: default_top_p(),
92 n: default_n(),
93 stream: false,
94 stop: None,
95 max_tokens: None,
96 presence_penalty: 0.0,
97 frequency_penalty: 0.0,
98 user: None,
99 response_format: None,
100 seed: None,
101 tools: vec![],
102 tool_choice: None,
103 }
104 }
105}
106
107#[derive(Debug, Clone, Deserialize, Serialize)]
113pub struct Message {
114 pub role: Role,
115 #[serde(skip_serializing_if = "Option::is_none")]
116 pub content: Option<MessageContent>,
117 #[serde(skip_serializing_if = "Option::is_none")]
118 pub name: Option<String>,
119 #[serde(skip_serializing_if = "Option::is_none")]
120 pub tool_calls: Option<Vec<ToolCall>>,
121 #[serde(skip_serializing_if = "Option::is_none")]
122 pub tool_call_id: Option<String>,
123}
124
125impl Message {
126 pub fn system(content: impl Into<String>) -> Self {
128 Self {
129 role: Role::System,
130 content: Some(MessageContent::Text(content.into())),
131 name: None,
132 tool_calls: None,
133 tool_call_id: None,
134 }
135 }
136
137 pub fn user(content: impl Into<String>) -> Self {
139 Self {
140 role: Role::User,
141 content: Some(MessageContent::Text(content.into())),
142 name: None,
143 tool_calls: None,
144 tool_call_id: None,
145 }
146 }
147
148 pub fn assistant(content: impl Into<String>) -> Self {
150 Self {
151 role: Role::Assistant,
152 content: Some(MessageContent::Text(content.into())),
153 name: None,
154 tool_calls: None,
155 tool_call_id: None,
156 }
157 }
158
159 pub fn tool(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
161 Self {
162 role: Role::Tool,
163 content: Some(MessageContent::Text(content.into())),
164 name: None,
165 tool_calls: None,
166 tool_call_id: Some(tool_call_id.into()),
167 }
168 }
169
170 pub fn text(&self) -> Option<&str> {
172 self.content.as_ref().and_then(|c| c.as_text())
173 }
174}
175
176#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq, Hash)]
178#[serde(rename_all = "lowercase")]
179pub enum Role {
180 System,
181 User,
182 Assistant,
183 Tool,
184}
185
186#[derive(Debug, Clone, Deserialize, Serialize)]
188#[serde(untagged)]
189pub enum MessageContent {
190 Text(String),
191 Parts(Vec<ContentPart>),
192}
193
194impl MessageContent {
195 pub fn as_text(&self) -> Option<&str> {
197 match self {
198 MessageContent::Text(s) => Some(s),
199 MessageContent::Parts(parts) => parts.iter().find_map(|p| {
200 if let ContentPart::Text { text } = p {
201 Some(text.as_str())
202 } else {
203 None
204 }
205 }),
206 }
207 }
208
209 pub fn into_text(self) -> Option<String> {
211 match self {
212 MessageContent::Text(s) => Some(s),
213 MessageContent::Parts(parts) => parts.into_iter().find_map(|p| {
214 if let ContentPart::Text { text } = p {
215 Some(text)
216 } else {
217 None
218 }
219 }),
220 }
221 }
222}
223
224impl From<String> for MessageContent {
225 fn from(s: String) -> Self {
226 MessageContent::Text(s)
227 }
228}
229
230impl From<&str> for MessageContent {
231 fn from(s: &str) -> Self {
232 MessageContent::Text(s.to_string())
233 }
234}
235
236#[derive(Debug, Clone, Deserialize, Serialize)]
238#[serde(tag = "type", rename_all = "snake_case")]
239pub enum ContentPart {
240 Text { text: String },
241 ImageUrl { image_url: ImageUrl },
242}
243
244#[derive(Debug, Clone, Deserialize, Serialize)]
245pub struct ImageUrl {
246 pub url: String,
247 #[serde(skip_serializing_if = "Option::is_none")]
248 pub detail: Option<String>,
249}
250
251#[derive(Debug, Clone, Deserialize, Serialize)]
257pub struct ResponseFormat {
258 #[serde(rename = "type")]
259 pub format_type: String,
260}
261
262impl ResponseFormat {
263 pub fn json() -> Self {
264 Self {
265 format_type: "json_object".to_string(),
266 }
267 }
268
269 pub fn text() -> Self {
270 Self {
271 format_type: "text".to_string(),
272 }
273 }
274}
275
276#[derive(Debug, Clone, Deserialize, Serialize)]
278pub struct Tool {
279 #[serde(rename = "type")]
280 pub tool_type: String,
281 pub function: FunctionDefinition,
282}
283
284impl Tool {
285 pub fn function(
287 name: impl Into<String>,
288 description: Option<String>,
289 parameters: Option<serde_json::Value>,
290 ) -> Self {
291 Self {
292 tool_type: "function".to_string(),
293 function: FunctionDefinition {
294 name: name.into(),
295 description,
296 parameters,
297 },
298 }
299 }
300}
301
302#[derive(Debug, Clone, Deserialize, Serialize)]
303pub struct FunctionDefinition {
304 pub name: String,
305 #[serde(skip_serializing_if = "Option::is_none")]
306 pub description: Option<String>,
307 #[serde(skip_serializing_if = "Option::is_none")]
308 pub parameters: Option<serde_json::Value>,
309}
310
311#[derive(Debug, Clone, Deserialize, Serialize)]
313pub struct ToolCall {
314 pub id: String,
315 #[serde(rename = "type")]
316 pub tool_type: String,
317 pub function: FunctionCall,
318}
319
320#[derive(Debug, Clone, Deserialize, Serialize)]
321pub struct FunctionCall {
322 pub name: String,
323 pub arguments: String,
324}
325
326#[derive(Debug, Clone, Deserialize, Serialize)]
328#[serde(untagged)]
329pub enum ToolChoice {
330 Mode(String), Specific {
332 #[serde(rename = "type")]
333 tool_type: String,
334 function: FunctionName,
335 },
336}
337
338impl ToolChoice {
339 pub fn none() -> Self {
340 ToolChoice::Mode("none".to_string())
341 }
342
343 pub fn auto() -> Self {
344 ToolChoice::Mode("auto".to_string())
345 }
346
347 pub fn required() -> Self {
348 ToolChoice::Mode("required".to_string())
349 }
350
351 pub fn function(name: impl Into<String>) -> Self {
352 ToolChoice::Specific {
353 tool_type: "function".to_string(),
354 function: FunctionName { name: name.into() },
355 }
356 }
357}
358
359#[derive(Debug, Clone, Deserialize, Serialize)]
360pub struct FunctionName {
361 pub name: String,
362}
363
364#[derive(Debug, Clone, Serialize, Deserialize)]
370pub struct ChatCompletionResponse {
371 pub id: String,
372 pub object: String,
373 pub created: i64,
374 pub model: String,
375 pub choices: Vec<Choice>,
376 pub usage: Usage,
377 #[serde(skip_serializing_if = "Option::is_none")]
378 pub system_fingerprint: Option<String>,
379}
380
381impl ChatCompletionResponse {
382 pub fn new(model: impl Into<String>, content: impl Into<String>) -> Self {
384 Self {
385 id: format!("chatcmpl-{}", Uuid::new_v4()),
386 object: "chat.completion".to_string(),
387 created: chrono::Utc::now().timestamp(),
388 model: model.into(),
389 choices: vec![Choice {
390 index: 0,
391 message: Message::assistant(content),
392 finish_reason: Some("stop".to_string()),
393 logprobs: None,
394 }],
395 usage: Usage::default(),
396 system_fingerprint: None,
397 }
398 }
399
400 pub fn text(&self) -> Option<&str> {
402 self.choices.first().and_then(|c| c.message.text())
403 }
404}
405
406#[derive(Debug, Clone, Serialize, Deserialize)]
407pub struct Choice {
408 pub index: u32,
409 pub message: Message,
410 #[serde(skip_serializing_if = "Option::is_none")]
411 pub finish_reason: Option<String>,
412 #[serde(skip_serializing_if = "Option::is_none")]
413 pub logprobs: Option<serde_json::Value>,
414}
415
416#[derive(Debug, Clone, Serialize, Deserialize, Default)]
417pub struct Usage {
418 pub prompt_tokens: u32,
419 pub completion_tokens: u32,
420 pub total_tokens: u32,
421}
422
423#[derive(Debug, Clone, Serialize, Deserialize)]
425pub struct ChatCompletionChunk {
426 pub id: String,
427 pub object: String,
428 pub created: i64,
429 pub model: String,
430 pub choices: Vec<ChunkChoice>,
431 #[serde(skip_serializing_if = "Option::is_none")]
432 pub system_fingerprint: Option<String>,
433}
434
435impl ChatCompletionChunk {
436 pub fn new(
437 id: &str,
438 model: &str,
439 delta: ChunkDelta,
440 finish_reason: Option<String>,
441 ) -> Self {
442 Self {
443 id: id.to_string(),
444 object: "chat.completion.chunk".to_string(),
445 created: chrono::Utc::now().timestamp(),
446 model: model.to_string(),
447 choices: vec![ChunkChoice {
448 index: 0,
449 delta,
450 finish_reason,
451 logprobs: None,
452 }],
453 system_fingerprint: None,
454 }
455 }
456}
457
458#[derive(Debug, Clone, Serialize, Deserialize)]
459pub struct ChunkChoice {
460 pub index: u32,
461 pub delta: ChunkDelta,
462 #[serde(skip_serializing_if = "Option::is_none")]
463 pub finish_reason: Option<String>,
464 #[serde(skip_serializing_if = "Option::is_none")]
465 pub logprobs: Option<serde_json::Value>,
466}
467
468#[derive(Debug, Clone, Serialize, Deserialize, Default)]
469pub struct ChunkDelta {
470 #[serde(skip_serializing_if = "Option::is_none")]
471 pub role: Option<Role>,
472 #[serde(skip_serializing_if = "Option::is_none")]
473 pub content: Option<String>,
474 #[serde(skip_serializing_if = "Option::is_none")]
475 pub tool_calls: Option<Vec<ToolCallChunk>>,
476}
477
478#[derive(Debug, Clone, Serialize, Deserialize)]
479pub struct ToolCallChunk {
480 pub index: u32,
481 #[serde(skip_serializing_if = "Option::is_none")]
482 pub id: Option<String>,
483 #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
484 pub tool_type: Option<String>,
485 #[serde(skip_serializing_if = "Option::is_none")]
486 pub function: Option<FunctionCallChunk>,
487}
488
489#[derive(Debug, Clone, Serialize, Deserialize)]
490pub struct FunctionCallChunk {
491 #[serde(skip_serializing_if = "Option::is_none")]
492 pub name: Option<String>,
493 #[serde(skip_serializing_if = "Option::is_none")]
494 pub arguments: Option<String>,
495}
496
497#[derive(Debug, Clone, Serialize, Deserialize)]
502pub struct ModelsResponse {
503 pub object: String,
504 pub data: Vec<ModelInfo>,
505}
506
507#[derive(Debug, Clone, Serialize, Deserialize)]
508pub struct ModelInfo {
509 pub id: String,
510 pub object: String,
511 pub created: i64,
512 pub owned_by: String,
513 #[serde(skip_serializing_if = "Option::is_none")]
514 pub context_window: Option<u32>,
515 #[serde(skip_serializing_if = "Option::is_none")]
516 pub max_completion_tokens: Option<u32>,
517}
518
519#[derive(Debug, Clone, Serialize, Deserialize)]
524pub struct ErrorResponse {
525 pub error: ErrorDetail,
526}
527
528#[derive(Debug, Clone, Serialize, Deserialize)]
529pub struct ErrorDetail {
530 pub message: String,
531 #[serde(rename = "type")]
532 pub error_type: String,
533 #[serde(skip_serializing_if = "Option::is_none")]
534 pub param: Option<String>,
535 #[serde(skip_serializing_if = "Option::is_none")]
536 pub code: Option<String>,
537}
538
539impl ErrorResponse {
540 pub fn new(message: impl Into<String>, error_type: impl Into<String>) -> Self {
541 Self {
542 error: ErrorDetail {
543 message: message.into(),
544 error_type: error_type.into(),
545 param: None,
546 code: None,
547 },
548 }
549 }
550}
551
#[cfg(test)]
mod tests {
    //! Unit tests for the chat-completion wire types: constructors,
    //! serialization shape, and deserialization defaults.

    use super::*;

    #[test]
    fn test_message_constructors() {
        let sys = Message::system("You are helpful");
        assert_eq!(sys.role, Role::System);
        assert_eq!(sys.text(), Some("You are helpful"));

        let user = Message::user("Hello");
        assert_eq!(user.role, Role::User);

        let asst = Message::assistant("Hi there!");
        assert_eq!(asst.role, Role::Assistant);

        let tool = Message::tool("call_123", r#"{"result": 42}"#);
        assert_eq!(tool.role, Role::Tool);
        assert_eq!(tool.tool_call_id, Some("call_123".to_string()));
    }

    #[test]
    fn test_tool_message_serialization_omits_unset_fields() {
        let json = serde_json::to_value(Message::tool("call_1", "42")).unwrap();
        assert_eq!(json["role"], "tool");
        assert_eq!(json["content"], "42");
        assert_eq!(json["tool_call_id"], "call_1");
        // `skip_serializing_if = "Option::is_none"` must drop these keys.
        assert!(json.get("name").is_none());
        assert!(json.get("tool_calls").is_none());
    }

    #[test]
    fn test_request_serialization() {
        let request = ChatCompletionRequest {
            model: "gpt-4".to_string(),
            messages: vec![Message::user("Hello")],
            ..Default::default()
        };

        let json = serde_json::to_value(&request).unwrap();
        assert_eq!(json["model"], "gpt-4");
        assert_eq!(json["messages"][0]["role"], "user");
        assert!(json.get("stop").is_none());
        assert!(json.get("tools").is_none());
    }

    #[test]
    fn test_request_deserialization_defaults() {
        let request: ChatCompletionRequest = serde_json::from_str(
            r#"{"model":"gpt-4","messages":[{"role":"user","content":"Hi"}]}"#,
        )
        .unwrap();
        assert_eq!(request.temperature, 1.0);
        assert_eq!(request.top_p, 1.0);
        assert_eq!(request.n, 1);
        assert!(!request.stream);
        assert!(request.stop.is_none());
        assert!(request.tools.is_empty());
        assert_eq!(request.messages[0].text(), Some("Hi"));
    }

    #[test]
    fn test_response_creation() {
        let response = ChatCompletionResponse::new("gpt-4", "Hello!");

        assert!(response.id.starts_with("chatcmpl-"));
        assert_eq!(response.object, "chat.completion");
        assert_eq!(response.text(), Some("Hello!"));
    }

    #[test]
    fn test_message_content_variants() {
        let text = MessageContent::Text("hello".to_string());
        assert_eq!(text.as_text(), Some("hello"));

        let parts = MessageContent::Parts(vec![ContentPart::Text {
            text: "world".to_string(),
        }]);
        assert_eq!(parts.as_text(), Some("world"));

        // Only the first text part is returned; images are skipped.
        let mixed = MessageContent::Parts(vec![
            ContentPart::ImageUrl {
                image_url: ImageUrl {
                    url: "http://example.com/a.png".to_string(),
                    detail: None,
                },
            },
            ContentPart::Text {
                text: "first".to_string(),
            },
            ContentPart::Text {
                text: "second".to_string(),
            },
        ]);
        assert_eq!(mixed.as_text(), Some("first"));
        assert_eq!(mixed.into_text(), Some("first".to_string()));
    }

    #[test]
    fn test_content_parts_deserialization() {
        let msg: Message = serde_json::from_str(
            r#"{"role":"user","content":[{"type":"image_url","image_url":{"url":"http://x"}},{"type":"text","text":"caption"}]}"#,
        )
        .unwrap();
        assert!(matches!(
            msg.content,
            Some(MessageContent::Parts(ref parts)) if parts.len() == 2
        ));
        assert_eq!(msg.text(), Some("caption"));
    }

    #[test]
    fn test_tool_choice() {
        let auto = ToolChoice::auto();
        let json = serde_json::to_value(&auto).unwrap();
        assert_eq!(json, "auto");

        let specific = ToolChoice::function("get_weather");
        let json = serde_json::to_value(&specific).unwrap();
        assert_eq!(json["function"]["name"], "get_weather");
    }

    #[test]
    fn test_tool_choice_deserialization() {
        let mode: ToolChoice = serde_json::from_str(r#""required""#).unwrap();
        assert!(matches!(mode, ToolChoice::Mode(ref m) if m == "required"));

        let specific: ToolChoice =
            serde_json::from_str(r#"{"type":"function","function":{"name":"get_weather"}}"#)
                .unwrap();
        match specific {
            ToolChoice::Specific {
                tool_type,
                function,
            } => {
                assert_eq!(tool_type, "function");
                assert_eq!(function.name, "get_weather");
            }
            ToolChoice::Mode(m) => panic!("expected a specific tool choice, got mode {}", m),
        }
    }

    #[test]
    fn test_error_response_serialization() {
        let json =
            serde_json::to_value(ErrorResponse::new("bad request", "invalid_request_error"))
                .unwrap();
        assert_eq!(json["error"]["message"], "bad request");
        assert_eq!(json["error"]["type"], "invalid_request_error");
        assert!(json["error"].get("param").is_none());
        assert!(json["error"].get("code").is_none());
    }
}