Skip to main content

pw_core/
chat.rs

1//! OpenAI-compatible Chat API Types
2//!
3//! These types match the OpenAI API schema for SDK compatibility.
4//! Works with OpenAI, Anthropic (via adapter), Groq, OpenRouter, etc.
5
6use serde::{Deserialize, Serialize};
7use uuid::Uuid;
8
9// ============================================================================
10// Request Types
11// ============================================================================
12
13/// Chat completion request - matches OpenAI schema exactly
14#[derive(Debug, Clone, Deserialize, Serialize)]
15pub struct ChatCompletionRequest {
16    /// Model to use (e.g., "gpt-4", "llama-3.1-70b")
17    pub model: String,
18
19    /// Messages in the conversation
20    pub messages: Vec<Message>,
21
22    /// Sampling temperature (0-2)
23    #[serde(default = "default_temperature")]
24    pub temperature: f32,
25
26    /// Top-p nucleus sampling
27    #[serde(default = "default_top_p")]
28    pub top_p: f32,
29
30    /// Number of completions to generate
31    #[serde(default = "default_n")]
32    pub n: u32,
33
34    /// Whether to stream responses
35    #[serde(default)]
36    pub stream: bool,
37
38    /// Stop sequences
39    #[serde(default, skip_serializing_if = "Option::is_none")]
40    pub stop: Option<Vec<String>>,
41
42    /// Maximum tokens to generate
43    #[serde(skip_serializing_if = "Option::is_none")]
44    pub max_tokens: Option<u32>,
45
46    /// Presence penalty (-2 to 2)
47    #[serde(default)]
48    pub presence_penalty: f32,
49
50    /// Frequency penalty (-2 to 2)
51    #[serde(default)]
52    pub frequency_penalty: f32,
53
54    /// User identifier for abuse detection
55    #[serde(skip_serializing_if = "Option::is_none")]
56    pub user: Option<String>,
57
58    /// Response format
59    #[serde(skip_serializing_if = "Option::is_none")]
60    pub response_format: Option<ResponseFormat>,
61
62    /// Seed for deterministic outputs
63    #[serde(skip_serializing_if = "Option::is_none")]
64    pub seed: Option<u64>,
65
66    /// Tool/function definitions
67    #[serde(default, skip_serializing_if = "Vec::is_empty")]
68    pub tools: Vec<Tool>,
69
70    /// How to handle tool calls
71    #[serde(skip_serializing_if = "Option::is_none")]
72    pub tool_choice: Option<ToolChoice>,
73}
74
75fn default_temperature() -> f32 {
76    1.0
77}
78fn default_top_p() -> f32 {
79    1.0
80}
81fn default_n() -> u32 {
82    1
83}
84
85impl Default for ChatCompletionRequest {
86    fn default() -> Self {
87        Self {
88            model: "gpt-4".to_string(),
89            messages: vec![],
90            temperature: default_temperature(),
91            top_p: default_top_p(),
92            n: default_n(),
93            stream: false,
94            stop: None,
95            max_tokens: None,
96            presence_penalty: 0.0,
97            frequency_penalty: 0.0,
98            user: None,
99            response_format: None,
100            seed: None,
101            tools: vec![],
102            tool_choice: None,
103        }
104    }
105}
106
107// ============================================================================
108// Message Types
109// ============================================================================
110
111/// Message in a chat conversation
112#[derive(Debug, Clone, Deserialize, Serialize)]
113pub struct Message {
114    pub role: Role,
115    #[serde(skip_serializing_if = "Option::is_none")]
116    pub content: Option<MessageContent>,
117    #[serde(skip_serializing_if = "Option::is_none")]
118    pub name: Option<String>,
119    #[serde(skip_serializing_if = "Option::is_none")]
120    pub tool_calls: Option<Vec<ToolCall>>,
121    #[serde(skip_serializing_if = "Option::is_none")]
122    pub tool_call_id: Option<String>,
123}
124
125impl Message {
126    /// Create a system message
127    pub fn system(content: impl Into<String>) -> Self {
128        Self {
129            role: Role::System,
130            content: Some(MessageContent::Text(content.into())),
131            name: None,
132            tool_calls: None,
133            tool_call_id: None,
134        }
135    }
136
137    /// Create a user message
138    pub fn user(content: impl Into<String>) -> Self {
139        Self {
140            role: Role::User,
141            content: Some(MessageContent::Text(content.into())),
142            name: None,
143            tool_calls: None,
144            tool_call_id: None,
145        }
146    }
147
148    /// Create an assistant message
149    pub fn assistant(content: impl Into<String>) -> Self {
150        Self {
151            role: Role::Assistant,
152            content: Some(MessageContent::Text(content.into())),
153            name: None,
154            tool_calls: None,
155            tool_call_id: None,
156        }
157    }
158
159    /// Create a tool result message
160    pub fn tool(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
161        Self {
162            role: Role::Tool,
163            content: Some(MessageContent::Text(content.into())),
164            name: None,
165            tool_calls: None,
166            tool_call_id: Some(tool_call_id.into()),
167        }
168    }
169
170    /// Get content as text (extracts from Parts if needed)
171    pub fn text(&self) -> Option<&str> {
172        self.content.as_ref().and_then(|c| c.as_text())
173    }
174}
175
176/// Role of a message sender
177#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq, Hash)]
178#[serde(rename_all = "lowercase")]
179pub enum Role {
180    System,
181    User,
182    Assistant,
183    Tool,
184}
185
186/// Message content - can be string or array of content parts
187#[derive(Debug, Clone, Deserialize, Serialize)]
188#[serde(untagged)]
189pub enum MessageContent {
190    Text(String),
191    Parts(Vec<ContentPart>),
192}
193
194impl MessageContent {
195    /// Get content as text (extracts first text from Parts if needed)
196    pub fn as_text(&self) -> Option<&str> {
197        match self {
198            MessageContent::Text(s) => Some(s),
199            MessageContent::Parts(parts) => parts.iter().find_map(|p| {
200                if let ContentPart::Text { text } = p {
201                    Some(text.as_str())
202                } else {
203                    None
204                }
205            }),
206        }
207    }
208
209    /// Convert to owned String
210    pub fn into_text(self) -> Option<String> {
211        match self {
212            MessageContent::Text(s) => Some(s),
213            MessageContent::Parts(parts) => parts.into_iter().find_map(|p| {
214                if let ContentPart::Text { text } = p {
215                    Some(text)
216                } else {
217                    None
218                }
219            }),
220        }
221    }
222}
223
224impl From<String> for MessageContent {
225    fn from(s: String) -> Self {
226        MessageContent::Text(s)
227    }
228}
229
230impl From<&str> for MessageContent {
231    fn from(s: &str) -> Self {
232        MessageContent::Text(s.to_string())
233    }
234}
235
236/// Content part for multimodal messages
237#[derive(Debug, Clone, Deserialize, Serialize)]
238#[serde(tag = "type", rename_all = "snake_case")]
239pub enum ContentPart {
240    Text { text: String },
241    ImageUrl { image_url: ImageUrl },
242}
243
244#[derive(Debug, Clone, Deserialize, Serialize)]
245pub struct ImageUrl {
246    pub url: String,
247    #[serde(skip_serializing_if = "Option::is_none")]
248    pub detail: Option<String>,
249}
250
251// ============================================================================
252// Tool Types
253// ============================================================================
254
255/// Response format specification
256#[derive(Debug, Clone, Deserialize, Serialize)]
257pub struct ResponseFormat {
258    #[serde(rename = "type")]
259    pub format_type: String,
260}
261
262impl ResponseFormat {
263    pub fn json() -> Self {
264        Self {
265            format_type: "json_object".to_string(),
266        }
267    }
268
269    pub fn text() -> Self {
270        Self {
271            format_type: "text".to_string(),
272        }
273    }
274}
275
276/// Tool definition
277#[derive(Debug, Clone, Deserialize, Serialize)]
278pub struct Tool {
279    #[serde(rename = "type")]
280    pub tool_type: String,
281    pub function: FunctionDefinition,
282}
283
284impl Tool {
285    /// Create a function tool
286    pub fn function(
287        name: impl Into<String>,
288        description: Option<String>,
289        parameters: Option<serde_json::Value>,
290    ) -> Self {
291        Self {
292            tool_type: "function".to_string(),
293            function: FunctionDefinition {
294                name: name.into(),
295                description,
296                parameters,
297            },
298        }
299    }
300}
301
302#[derive(Debug, Clone, Deserialize, Serialize)]
303pub struct FunctionDefinition {
304    pub name: String,
305    #[serde(skip_serializing_if = "Option::is_none")]
306    pub description: Option<String>,
307    #[serde(skip_serializing_if = "Option::is_none")]
308    pub parameters: Option<serde_json::Value>,
309}
310
311/// Tool call in assistant message
312#[derive(Debug, Clone, Deserialize, Serialize)]
313pub struct ToolCall {
314    pub id: String,
315    #[serde(rename = "type")]
316    pub tool_type: String,
317    pub function: FunctionCall,
318}
319
320#[derive(Debug, Clone, Deserialize, Serialize)]
321pub struct FunctionCall {
322    pub name: String,
323    pub arguments: String,
324}
325
326/// How to select tools
327#[derive(Debug, Clone, Deserialize, Serialize)]
328#[serde(untagged)]
329pub enum ToolChoice {
330    Mode(String), // "none", "auto", "required"
331    Specific {
332        #[serde(rename = "type")]
333        tool_type: String,
334        function: FunctionName,
335    },
336}
337
338impl ToolChoice {
339    pub fn none() -> Self {
340        ToolChoice::Mode("none".to_string())
341    }
342
343    pub fn auto() -> Self {
344        ToolChoice::Mode("auto".to_string())
345    }
346
347    pub fn required() -> Self {
348        ToolChoice::Mode("required".to_string())
349    }
350
351    pub fn function(name: impl Into<String>) -> Self {
352        ToolChoice::Specific {
353            tool_type: "function".to_string(),
354            function: FunctionName { name: name.into() },
355        }
356    }
357}
358
359#[derive(Debug, Clone, Deserialize, Serialize)]
360pub struct FunctionName {
361    pub name: String,
362}
363
364// ============================================================================
365// Response Types
366// ============================================================================
367
368/// Chat completion response - non-streaming
369#[derive(Debug, Clone, Serialize, Deserialize)]
370pub struct ChatCompletionResponse {
371    pub id: String,
372    pub object: String,
373    pub created: i64,
374    pub model: String,
375    pub choices: Vec<Choice>,
376    pub usage: Usage,
377    #[serde(skip_serializing_if = "Option::is_none")]
378    pub system_fingerprint: Option<String>,
379}
380
381impl ChatCompletionResponse {
382    /// Create a simple response with text content
383    pub fn new(model: impl Into<String>, content: impl Into<String>) -> Self {
384        Self {
385            id: format!("chatcmpl-{}", Uuid::new_v4()),
386            object: "chat.completion".to_string(),
387            created: chrono::Utc::now().timestamp(),
388            model: model.into(),
389            choices: vec![Choice {
390                index: 0,
391                message: Message::assistant(content),
392                finish_reason: Some("stop".to_string()),
393                logprobs: None,
394            }],
395            usage: Usage::default(),
396            system_fingerprint: None,
397        }
398    }
399
400    /// Get the first choice's message content as text
401    pub fn text(&self) -> Option<&str> {
402        self.choices.first().and_then(|c| c.message.text())
403    }
404}
405
406#[derive(Debug, Clone, Serialize, Deserialize)]
407pub struct Choice {
408    pub index: u32,
409    pub message: Message,
410    #[serde(skip_serializing_if = "Option::is_none")]
411    pub finish_reason: Option<String>,
412    #[serde(skip_serializing_if = "Option::is_none")]
413    pub logprobs: Option<serde_json::Value>,
414}
415
416#[derive(Debug, Clone, Serialize, Deserialize, Default)]
417pub struct Usage {
418    pub prompt_tokens: u32,
419    pub completion_tokens: u32,
420    pub total_tokens: u32,
421}
422
423/// Streaming chunk response
424#[derive(Debug, Clone, Serialize, Deserialize)]
425pub struct ChatCompletionChunk {
426    pub id: String,
427    pub object: String,
428    pub created: i64,
429    pub model: String,
430    pub choices: Vec<ChunkChoice>,
431    #[serde(skip_serializing_if = "Option::is_none")]
432    pub system_fingerprint: Option<String>,
433}
434
435impl ChatCompletionChunk {
436    pub fn new(
437        id: &str,
438        model: &str,
439        delta: ChunkDelta,
440        finish_reason: Option<String>,
441    ) -> Self {
442        Self {
443            id: id.to_string(),
444            object: "chat.completion.chunk".to_string(),
445            created: chrono::Utc::now().timestamp(),
446            model: model.to_string(),
447            choices: vec![ChunkChoice {
448                index: 0,
449                delta,
450                finish_reason,
451                logprobs: None,
452            }],
453            system_fingerprint: None,
454        }
455    }
456}
457
458#[derive(Debug, Clone, Serialize, Deserialize)]
459pub struct ChunkChoice {
460    pub index: u32,
461    pub delta: ChunkDelta,
462    #[serde(skip_serializing_if = "Option::is_none")]
463    pub finish_reason: Option<String>,
464    #[serde(skip_serializing_if = "Option::is_none")]
465    pub logprobs: Option<serde_json::Value>,
466}
467
468#[derive(Debug, Clone, Serialize, Deserialize, Default)]
469pub struct ChunkDelta {
470    #[serde(skip_serializing_if = "Option::is_none")]
471    pub role: Option<Role>,
472    #[serde(skip_serializing_if = "Option::is_none")]
473    pub content: Option<String>,
474    #[serde(skip_serializing_if = "Option::is_none")]
475    pub tool_calls: Option<Vec<ToolCallChunk>>,
476}
477
478#[derive(Debug, Clone, Serialize, Deserialize)]
479pub struct ToolCallChunk {
480    pub index: u32,
481    #[serde(skip_serializing_if = "Option::is_none")]
482    pub id: Option<String>,
483    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
484    pub tool_type: Option<String>,
485    #[serde(skip_serializing_if = "Option::is_none")]
486    pub function: Option<FunctionCallChunk>,
487}
488
489#[derive(Debug, Clone, Serialize, Deserialize)]
490pub struct FunctionCallChunk {
491    #[serde(skip_serializing_if = "Option::is_none")]
492    pub name: Option<String>,
493    #[serde(skip_serializing_if = "Option::is_none")]
494    pub arguments: Option<String>,
495}
496
497// ============================================================================
498// Models Endpoint
499// ============================================================================
500
501#[derive(Debug, Clone, Serialize, Deserialize)]
502pub struct ModelsResponse {
503    pub object: String,
504    pub data: Vec<ModelInfo>,
505}
506
507#[derive(Debug, Clone, Serialize, Deserialize)]
508pub struct ModelInfo {
509    pub id: String,
510    pub object: String,
511    pub created: i64,
512    pub owned_by: String,
513    #[serde(skip_serializing_if = "Option::is_none")]
514    pub context_window: Option<u32>,
515    #[serde(skip_serializing_if = "Option::is_none")]
516    pub max_completion_tokens: Option<u32>,
517}
518
519// ============================================================================
520// Error Response
521// ============================================================================
522
523#[derive(Debug, Clone, Serialize, Deserialize)]
524pub struct ErrorResponse {
525    pub error: ErrorDetail,
526}
527
528#[derive(Debug, Clone, Serialize, Deserialize)]
529pub struct ErrorDetail {
530    pub message: String,
531    #[serde(rename = "type")]
532    pub error_type: String,
533    #[serde(skip_serializing_if = "Option::is_none")]
534    pub param: Option<String>,
535    #[serde(skip_serializing_if = "Option::is_none")]
536    pub code: Option<String>,
537}
538
539impl ErrorResponse {
540    pub fn new(message: impl Into<String>, error_type: impl Into<String>) -> Self {
541        Self {
542            error: ErrorDetail {
543                message: message.into(),
544                error_type: error_type.into(),
545                param: None,
546                code: None,
547            },
548        }
549    }
550}
551
552// ============================================================================
553// Tests
554// ============================================================================
555
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_message_constructors() {
        let system = Message::system("You are helpful");
        assert_eq!(system.role, Role::System);
        assert_eq!(system.text(), Some("You are helpful"));

        assert_eq!(Message::user("Hello").role, Role::User);
        assert_eq!(Message::assistant("Hi there!").role, Role::Assistant);

        let tool_result = Message::tool("call_123", r#"{"result": 42}"#);
        assert_eq!(tool_result.role, Role::Tool);
        assert_eq!(tool_result.tool_call_id.as_deref(), Some("call_123"));
    }

    #[test]
    fn test_request_serialization() {
        let body = serde_json::to_value(ChatCompletionRequest {
            model: String::from("gpt-4"),
            messages: vec![Message::user("Hello")],
            ..ChatCompletionRequest::default()
        })
        .unwrap();

        assert_eq!(body["model"], "gpt-4");
        assert_eq!(body["messages"][0]["role"], "user");
    }

    #[test]
    fn test_response_creation() {
        let resp = ChatCompletionResponse::new("gpt-4", "Hello!");

        assert!(resp.id.starts_with("chatcmpl-"));
        assert_eq!(resp.object, "chat.completion");
        assert_eq!(resp.text(), Some("Hello!"));
    }

    #[test]
    fn test_message_content_variants() {
        let cases = [
            (MessageContent::Text("hello".to_string()), "hello"),
            (
                MessageContent::Parts(vec![ContentPart::Text {
                    text: "world".to_string(),
                }]),
                "world",
            ),
        ];
        for (content, expected) in &cases {
            assert_eq!(content.as_text(), Some(*expected));
        }
    }

    #[test]
    fn test_tool_choice() {
        let auto_json = serde_json::to_value(ToolChoice::auto()).unwrap();
        assert_eq!(auto_json, "auto");

        let specific_json = serde_json::to_value(ToolChoice::function("get_weather")).unwrap();
        assert_eq!(specific_json["function"]["name"], "get_weather");
    }
}
620}