agent_sdk/providers/
openai.rs

1//! `OpenAI` API provider implementation.
2//!
3//! This module provides an implementation of `LlmProvider` for the `OpenAI`
4//! Chat Completions API. It also supports `OpenAI`-compatible APIs (Ollama, vLLM, etc.)
5//! via the `with_base_url` constructor.
6
7use crate::llm::{
8    ChatOutcome, ChatRequest, ChatResponse, Content, ContentBlock, LlmProvider, StopReason, Usage,
9};
10use anyhow::Result;
11use async_trait::async_trait;
12use reqwest::StatusCode;
13use serde::{Deserialize, Serialize};
14
/// Default Chat Completions endpoint root. Overridden via
/// [`OpenAIProvider::with_base_url`] for OpenAI-compatible servers
/// (Ollama, vLLM, etc.).
const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1";

// Model identifier constants, grouped by family. Each has a matching
// convenience constructor on `OpenAIProvider`.

// GPT-5.2 series (latest flagship, Dec 2025)
pub const MODEL_GPT52_INSTANT: &str = "gpt-5.2-instant";
pub const MODEL_GPT52_THINKING: &str = "gpt-5.2-thinking";
pub const MODEL_GPT52_PRO: &str = "gpt-5.2-pro";

// GPT-5 series (400k context)
pub const MODEL_GPT5: &str = "gpt-5";
pub const MODEL_GPT5_MINI: &str = "gpt-5-mini";
pub const MODEL_GPT5_NANO: &str = "gpt-5-nano";

// o-series reasoning models
pub const MODEL_O3: &str = "o3";
pub const MODEL_O3_MINI: &str = "o3-mini";
pub const MODEL_O4_MINI: &str = "o4-mini";
pub const MODEL_O1: &str = "o1";
pub const MODEL_O1_MINI: &str = "o1-mini";

// GPT-4.1 series (improved instruction following, 1M context)
pub const MODEL_GPT41: &str = "gpt-4.1";
pub const MODEL_GPT41_MINI: &str = "gpt-4.1-mini";
pub const MODEL_GPT41_NANO: &str = "gpt-4.1-nano";

// GPT-4o series
pub const MODEL_GPT4O: &str = "gpt-4o";
pub const MODEL_GPT4O_MINI: &str = "gpt-4o-mini";
42
/// `OpenAI` LLM provider using the Chat Completions API.
///
/// Also supports `OpenAI`-compatible APIs (Ollama, vLLM, Azure `OpenAI`, etc.)
/// via the `with_base_url` constructor.
///
/// Cloning is cheap: `reqwest::Client` is internally reference-counted, so
/// clones share the same connection pool.
#[derive(Clone)]
pub struct OpenAIProvider {
    // Shared HTTP client (connection pooling across requests).
    client: reqwest::Client,
    // Sent as a `Bearer` token in the `Authorization` header.
    api_key: String,
    // Model identifier, e.g. one of the `MODEL_*` constants.
    model: String,
    // Endpoint root without trailing `/chat/completions`.
    base_url: String,
}
54
55impl OpenAIProvider {
56    /// Create a new `OpenAI` provider with the specified API key and model.
57    #[must_use]
58    pub fn new(api_key: String, model: String) -> Self {
59        Self {
60            client: reqwest::Client::new(),
61            api_key,
62            model,
63            base_url: DEFAULT_BASE_URL.to_owned(),
64        }
65    }
66
67    /// Create a new provider with a custom base URL for OpenAI-compatible APIs.
68    #[must_use]
69    pub fn with_base_url(api_key: String, model: String, base_url: String) -> Self {
70        Self {
71            client: reqwest::Client::new(),
72            api_key,
73            model,
74            base_url,
75        }
76    }
77
78    /// Create a provider using GPT-5.2 Instant (speed-optimized for routine queries).
79    #[must_use]
80    pub fn gpt52_instant(api_key: String) -> Self {
81        Self::new(api_key, MODEL_GPT52_INSTANT.to_owned())
82    }
83
84    /// Create a provider using GPT-5.2 Thinking (complex reasoning, coding, analysis).
85    #[must_use]
86    pub fn gpt52_thinking(api_key: String) -> Self {
87        Self::new(api_key, MODEL_GPT52_THINKING.to_owned())
88    }
89
90    /// Create a provider using GPT-5.2 Pro (maximum accuracy for difficult problems).
91    #[must_use]
92    pub fn gpt52_pro(api_key: String) -> Self {
93        Self::new(api_key, MODEL_GPT52_PRO.to_owned())
94    }
95
96    /// Create a provider using GPT-5 (400k context, coding and reasoning).
97    #[must_use]
98    pub fn gpt5(api_key: String) -> Self {
99        Self::new(api_key, MODEL_GPT5.to_owned())
100    }
101
102    /// Create a provider using GPT-5-mini (faster, cost-efficient GPT-5).
103    #[must_use]
104    pub fn gpt5_mini(api_key: String) -> Self {
105        Self::new(api_key, MODEL_GPT5_MINI.to_owned())
106    }
107
108    /// Create a provider using GPT-5-nano (fastest, cheapest GPT-5 variant).
109    #[must_use]
110    pub fn gpt5_nano(api_key: String) -> Self {
111        Self::new(api_key, MODEL_GPT5_NANO.to_owned())
112    }
113
114    /// Create a provider using o3 (most intelligent reasoning model).
115    #[must_use]
116    pub fn o3(api_key: String) -> Self {
117        Self::new(api_key, MODEL_O3.to_owned())
118    }
119
120    /// Create a provider using o3-mini (smaller o3 variant).
121    #[must_use]
122    pub fn o3_mini(api_key: String) -> Self {
123        Self::new(api_key, MODEL_O3_MINI.to_owned())
124    }
125
126    /// Create a provider using o4-mini (fast, cost-efficient reasoning).
127    #[must_use]
128    pub fn o4_mini(api_key: String) -> Self {
129        Self::new(api_key, MODEL_O4_MINI.to_owned())
130    }
131
132    /// Create a provider using o1 (reasoning model).
133    #[must_use]
134    pub fn o1(api_key: String) -> Self {
135        Self::new(api_key, MODEL_O1.to_owned())
136    }
137
138    /// Create a provider using o1-mini (fast reasoning model).
139    #[must_use]
140    pub fn o1_mini(api_key: String) -> Self {
141        Self::new(api_key, MODEL_O1_MINI.to_owned())
142    }
143
144    /// Create a provider using GPT-4.1 (improved instruction following, 1M context).
145    #[must_use]
146    pub fn gpt41(api_key: String) -> Self {
147        Self::new(api_key, MODEL_GPT41.to_owned())
148    }
149
150    /// Create a provider using GPT-4.1-mini (smaller, faster GPT-4.1).
151    #[must_use]
152    pub fn gpt41_mini(api_key: String) -> Self {
153        Self::new(api_key, MODEL_GPT41_MINI.to_owned())
154    }
155
156    /// Create a provider using GPT-4o.
157    #[must_use]
158    pub fn gpt4o(api_key: String) -> Self {
159        Self::new(api_key, MODEL_GPT4O.to_owned())
160    }
161
162    /// Create a provider using GPT-4o-mini (fast and cost-effective).
163    #[must_use]
164    pub fn gpt4o_mini(api_key: String) -> Self {
165        Self::new(api_key, MODEL_GPT4O_MINI.to_owned())
166    }
167}
168
#[async_trait]
impl LlmProvider for OpenAIProvider {
    /// Send one chat turn to the Chat Completions endpoint.
    ///
    /// Transport failures and malformed response bodies surface as `Err`;
    /// API-level failures (rate limiting, 4xx, 5xx) are mapped onto
    /// [`ChatOutcome`] variants so callers can retry or report them
    /// without string-matching error messages.
    async fn chat(&self, request: ChatRequest) -> Result<ChatOutcome> {
        // Flatten the SDK request into OpenAI's wire shape.
        let messages = build_api_messages(&request);
        let tools: Option<Vec<ApiTool>> = request
            .tools
            .map(|ts| ts.into_iter().map(convert_tool).collect());

        let api_request = ApiChatRequest {
            model: &self.model,
            messages: &messages,
            max_completion_tokens: Some(request.max_tokens),
            tools: tools.as_deref(),
        };

        tracing::debug!(
            model = %self.model,
            max_tokens = request.max_tokens,
            "OpenAI LLM request"
        );

        let response = self
            .client
            .post(format!("{}/chat/completions", self.base_url))
            .header("Content-Type", "application/json")
            .header("Authorization", format!("Bearer {}", self.api_key))
            .json(&api_request)
            .send()
            .await
            .map_err(|e| anyhow::anyhow!("request failed: {e}"))?;

        // Capture status before consuming the response body.
        let status = response.status();
        let bytes = response
            .bytes()
            .await
            .map_err(|e| anyhow::anyhow!("failed to read response body: {e}"))?;

        tracing::debug!(
            status = %status,
            body_len = bytes.len(),
            "OpenAI LLM response"
        );

        // 429 is checked before the generic client-error branch so it maps to
        // the dedicated (retryable) outcome instead of InvalidRequest.
        if status == StatusCode::TOO_MANY_REQUESTS {
            return Ok(ChatOutcome::RateLimited);
        }

        if status.is_server_error() {
            let body = String::from_utf8_lossy(&bytes);
            tracing::error!(status = %status, body = %body, "OpenAI server error");
            return Ok(ChatOutcome::ServerError(body.into_owned()));
        }

        if status.is_client_error() {
            let body = String::from_utf8_lossy(&bytes);
            tracing::warn!(status = %status, body = %body, "OpenAI client error");
            return Ok(ChatOutcome::InvalidRequest(body.into_owned()));
        }

        let api_response: ApiChatResponse = serde_json::from_slice(&bytes)
            .map_err(|e| anyhow::anyhow!("failed to parse response: {e}"))?;

        // Only the first choice is used; `n` is never set on the request, so
        // the API returns at most one.
        let choice = api_response
            .choices
            .into_iter()
            .next()
            .ok_or_else(|| anyhow::anyhow!("no choices in response"))?;

        let content = build_content_blocks(&choice.message);

        let stop_reason = choice.finish_reason.map(|r| match r {
            ApiFinishReason::Stop => StopReason::EndTurn,
            ApiFinishReason::ToolCalls => StopReason::ToolUse,
            ApiFinishReason::Length => StopReason::MaxTokens,
            // NOTE(review): `content_filter` has no direct SDK equivalent and
            // is folded into StopSequence — confirm this is the intended
            // mapping rather than surfacing a distinct outcome.
            ApiFinishReason::ContentFilter => StopReason::StopSequence,
        });

        Ok(ChatOutcome::Success(ChatResponse {
            id: api_response.id,
            content,
            model: api_response.model,
            stop_reason,
            usage: Usage {
                input_tokens: api_response.usage.prompt_tokens,
                output_tokens: api_response.usage.completion_tokens,
            },
        }))
    }

    /// Model identifier this provider was constructed with.
    fn model(&self) -> &str {
        &self.model
    }

    /// Stable provider name used for routing/telemetry.
    fn provider(&self) -> &'static str {
        "openai"
    }
}
266
/// Convert an SDK [`ChatRequest`] into OpenAI's flat message list.
///
/// OpenAI's wire format differs from the SDK shape in two ways: the system
/// prompt is an ordinary leading message with role `system`, and tool results
/// are standalone messages with role `tool` rather than content blocks inside
/// a user message.
fn build_api_messages(request: &ChatRequest) -> Vec<ApiMessage> {
    let mut messages = Vec::new();

    // Add system message first (OpenAI uses a separate message for system prompt)
    if !request.system.is_empty() {
        messages.push(ApiMessage {
            role: ApiRole::System,
            content: Some(request.system.clone()),
            tool_calls: None,
            tool_call_id: None,
        });
    }

    // Convert SDK messages to OpenAI format
    for msg in &request.messages {
        match &msg.content {
            // Plain text maps 1:1 onto a single OpenAI message.
            Content::Text(text) => {
                messages.push(ApiMessage {
                    role: match msg.role {
                        crate::llm::Role::User => ApiRole::User,
                        crate::llm::Role::Assistant => ApiRole::Assistant,
                    },
                    content: Some(text.clone()),
                    tool_calls: None,
                    tool_call_id: None,
                });
            }
            Content::Blocks(blocks) => {
                // Handle mixed content blocks: text and tool-use blocks are
                // accumulated into one message; tool results are emitted
                // immediately as separate `tool`-role messages.
                let mut text_parts = Vec::new();
                let mut tool_calls = Vec::new();

                for block in blocks {
                    match block {
                        ContentBlock::Text { text } => {
                            text_parts.push(text.clone());
                        }
                        ContentBlock::ToolUse { id, name, input } => {
                            // SDK tool input is a JSON value, but OpenAI wants
                            // the arguments as a JSON *string*; serialization
                            // failure degrades to "{}" instead of aborting.
                            tool_calls.push(ApiToolCall {
                                id: id.clone(),
                                r#type: "function".to_owned(),
                                function: ApiFunctionCall {
                                    name: name.clone(),
                                    arguments: serde_json::to_string(input)
                                        .unwrap_or_else(|_| "{}".to_owned()),
                                },
                            });
                        }
                        ContentBlock::ToolResult {
                            tool_use_id,
                            content,
                            ..
                        } => {
                            // Tool results are separate messages in OpenAI.
                            // NOTE(review): these are pushed *before* the
                            // text/tool-call message assembled from the same
                            // SDK message; this is fine when results arrive in
                            // their own message (the usual flow) — confirm no
                            // caller mixes results with assistant tool calls
                            // in one message.
                            messages.push(ApiMessage {
                                role: ApiRole::Tool,
                                content: Some(content.clone()),
                                tool_calls: None,
                                tool_call_id: Some(tool_use_id.clone()),
                            });
                        }
                    }
                }

                // Add assistant message with text and/or tool calls
                if !text_parts.is_empty() || !tool_calls.is_empty() {
                    let role = match msg.role {
                        crate::llm::Role::User => ApiRole::User,
                        crate::llm::Role::Assistant => ApiRole::Assistant,
                    };

                    // Only add if it's an assistant message or has text content
                    // (a user message consisting solely of tool-use blocks is
                    // dropped — user messages cannot carry tool calls).
                    if role == ApiRole::Assistant || !text_parts.is_empty() {
                        messages.push(ApiMessage {
                            role,
                            content: if text_parts.is_empty() {
                                None
                            } else {
                                Some(text_parts.join("\n"))
                            },
                            tool_calls: if tool_calls.is_empty() {
                                None
                            } else {
                                Some(tool_calls)
                            },
                            tool_call_id: None,
                        });
                    }
                }
            }
        }
    }

    messages
}
362
363fn convert_tool(t: crate::llm::Tool) -> ApiTool {
364    ApiTool {
365        r#type: "function".to_owned(),
366        function: ApiFunction {
367            name: t.name,
368            description: t.description,
369            parameters: t.input_schema,
370        },
371    }
372}
373
374fn build_content_blocks(message: &ApiResponseMessage) -> Vec<ContentBlock> {
375    let mut blocks = Vec::new();
376
377    // Add text content if present
378    if let Some(content) = &message.content
379        && !content.is_empty()
380    {
381        blocks.push(ContentBlock::Text {
382            text: content.clone(),
383        });
384    }
385
386    // Add tool calls if present
387    if let Some(tool_calls) = &message.tool_calls {
388        for tc in tool_calls {
389            let input: serde_json::Value =
390                serde_json::from_str(&tc.function.arguments).unwrap_or(serde_json::Value::Null);
391            blocks.push(ContentBlock::ToolUse {
392                id: tc.id.clone(),
393                name: tc.function.name.clone(),
394                input,
395            });
396        }
397    }
398
399    blocks
400}
401
// ============================================================================
// API Request Types
// ============================================================================

/// Top-level body for `POST {base_url}/chat/completions`.
///
/// Borrows from the caller so building a request never clones the message
/// list or tool definitions.
#[derive(Serialize)]
struct ApiChatRequest<'a> {
    model: &'a str,
    messages: &'a [ApiMessage],
    // Newer parameter name; the legacy `max_tokens` field is not sent.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_completion_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<&'a [ApiTool]>,
}

/// One entry in the `messages` array. Optional fields are omitted from the
/// JSON entirely (not serialized as `null`).
#[derive(Serialize)]
struct ApiMessage {
    role: ApiRole,
    // Absent for assistant messages that contain only tool calls.
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<String>,
    // Only set on assistant messages.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<ApiToolCall>>,
    // Only set on `tool`-role messages; links the result to its call.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}

/// Message role; serializes as the lowercase variant name
/// ("system" / "user" / "assistant" / "tool").
#[derive(Debug, Serialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
enum ApiRole {
    System,
    User,
    Assistant,
    Tool,
}

/// An assistant-initiated tool invocation. `type` is always "function".
#[derive(Serialize)]
struct ApiToolCall {
    id: String,
    r#type: String,
    function: ApiFunctionCall,
}

/// Function name plus its arguments as a raw JSON string (OpenAI's format).
#[derive(Serialize)]
struct ApiFunctionCall {
    name: String,
    arguments: String,
}

/// Tool definition envelope; `type` is always "function".
#[derive(Serialize)]
struct ApiTool {
    r#type: String,
    function: ApiFunction,
}

/// Function declaration: name, human-readable description, and a JSON Schema
/// describing the parameters.
#[derive(Serialize)]
struct ApiFunction {
    name: String,
    description: String,
    parameters: serde_json::Value,
}
461
// ============================================================================
// API Response Types
// ============================================================================

/// Successful response body from `POST /chat/completions`.
/// Unknown fields sent by the API are ignored by serde's default behavior.
#[derive(Deserialize)]
struct ApiChatResponse {
    id: String,
    choices: Vec<ApiChoice>,
    // Model actually used, as reported by the server.
    model: String,
    usage: ApiUsage,
}

/// One generated completion; only the first choice is consumed by `chat`.
#[derive(Deserialize)]
struct ApiChoice {
    message: ApiResponseMessage,
    finish_reason: Option<ApiFinishReason>,
}

/// The assistant message inside a choice. `content` may be null when the
/// model responded exclusively with tool calls.
#[derive(Deserialize)]
struct ApiResponseMessage {
    content: Option<String>,
    tool_calls: Option<Vec<ApiResponseToolCall>>,
}

/// Tool call echoed back by the model (the wire `type` field is not needed
/// and therefore not captured).
#[derive(Deserialize)]
struct ApiResponseToolCall {
    id: String,
    function: ApiResponseFunctionCall,
}

/// Function invocation details; `arguments` is a raw JSON string.
#[derive(Deserialize)]
struct ApiResponseFunctionCall {
    name: String,
    arguments: String,
}

/// Why generation stopped.
/// NOTE(review): any value outside these four (e.g. a future addition by an
/// OpenAI-compatible server) makes the whole response fail to deserialize —
/// consider a `#[serde(other)]` catch-all variant if that becomes a problem.
#[derive(Deserialize)]
#[serde(rename_all = "snake_case")]
enum ApiFinishReason {
    Stop,
    ToolCalls,
    Length,
    ContentFilter,
}

/// Token accounting; mapped onto the SDK's `Usage` struct.
#[derive(Deserialize)]
struct ApiUsage {
    prompt_tokens: u32,
    completion_tokens: u32,
}
512
#[cfg(test)]
mod tests {
    use super::*;

    // Unit tests cover constructors, model constants, serde wire formats, and
    // message conversion. No network calls are made; `chat` itself is not
    // exercised here.

    // ===================
    // Constructor Tests
    // ===================

    #[test]
    fn test_new_creates_provider_with_custom_model() {
        let provider = OpenAIProvider::new("test-api-key".to_string(), "custom-model".to_string());

        assert_eq!(provider.model(), "custom-model");
        assert_eq!(provider.provider(), "openai");
        assert_eq!(provider.base_url, DEFAULT_BASE_URL);
    }

    #[test]
    fn test_with_base_url_creates_provider_with_custom_url() {
        let provider = OpenAIProvider::with_base_url(
            "test-api-key".to_string(),
            "llama3".to_string(),
            "http://localhost:11434/v1".to_string(),
        );

        assert_eq!(provider.model(), "llama3");
        assert_eq!(provider.base_url, "http://localhost:11434/v1");
    }

    #[test]
    fn test_gpt4o_factory_creates_gpt4o_provider() {
        let provider = OpenAIProvider::gpt4o("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_GPT4O);
        assert_eq!(provider.provider(), "openai");
    }

    #[test]
    fn test_gpt4o_mini_factory_creates_gpt4o_mini_provider() {
        let provider = OpenAIProvider::gpt4o_mini("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_GPT4O_MINI);
        assert_eq!(provider.provider(), "openai");
    }

    #[test]
    fn test_gpt52_thinking_factory_creates_provider() {
        let provider = OpenAIProvider::gpt52_thinking("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_GPT52_THINKING);
        assert_eq!(provider.provider(), "openai");
    }

    #[test]
    fn test_gpt5_factory_creates_gpt5_provider() {
        let provider = OpenAIProvider::gpt5("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_GPT5);
        assert_eq!(provider.provider(), "openai");
    }

    #[test]
    fn test_gpt5_mini_factory_creates_provider() {
        let provider = OpenAIProvider::gpt5_mini("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_GPT5_MINI);
        assert_eq!(provider.provider(), "openai");
    }

    #[test]
    fn test_o3_factory_creates_o3_provider() {
        let provider = OpenAIProvider::o3("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_O3);
        assert_eq!(provider.provider(), "openai");
    }

    #[test]
    fn test_o4_mini_factory_creates_o4_mini_provider() {
        let provider = OpenAIProvider::o4_mini("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_O4_MINI);
        assert_eq!(provider.provider(), "openai");
    }

    #[test]
    fn test_o1_factory_creates_o1_provider() {
        let provider = OpenAIProvider::o1("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_O1);
        assert_eq!(provider.provider(), "openai");
    }

    #[test]
    fn test_gpt41_factory_creates_gpt41_provider() {
        let provider = OpenAIProvider::gpt41("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_GPT41);
        assert_eq!(provider.provider(), "openai");
    }

    // ===================
    // Model Constants Tests
    // ===================

    // Guards against accidental edits to the wire identifiers, which would
    // silently change which model is billed/invoked.
    #[test]
    fn test_model_constants_have_expected_values() {
        // GPT-5.2 series
        assert_eq!(MODEL_GPT52_INSTANT, "gpt-5.2-instant");
        assert_eq!(MODEL_GPT52_THINKING, "gpt-5.2-thinking");
        assert_eq!(MODEL_GPT52_PRO, "gpt-5.2-pro");
        // GPT-5 series
        assert_eq!(MODEL_GPT5, "gpt-5");
        assert_eq!(MODEL_GPT5_MINI, "gpt-5-mini");
        assert_eq!(MODEL_GPT5_NANO, "gpt-5-nano");
        // o-series
        assert_eq!(MODEL_O3, "o3");
        assert_eq!(MODEL_O3_MINI, "o3-mini");
        assert_eq!(MODEL_O4_MINI, "o4-mini");
        assert_eq!(MODEL_O1, "o1");
        assert_eq!(MODEL_O1_MINI, "o1-mini");
        // GPT-4.1 series
        assert_eq!(MODEL_GPT41, "gpt-4.1");
        assert_eq!(MODEL_GPT41_MINI, "gpt-4.1-mini");
        assert_eq!(MODEL_GPT41_NANO, "gpt-4.1-nano");
        // GPT-4o series
        assert_eq!(MODEL_GPT4O, "gpt-4o");
        assert_eq!(MODEL_GPT4O_MINI, "gpt-4o-mini");
    }

    // ===================
    // Clone Tests
    // ===================

    #[test]
    fn test_provider_is_cloneable() {
        let provider = OpenAIProvider::new("test-api-key".to_string(), "test-model".to_string());
        let cloned = provider.clone();

        assert_eq!(provider.model(), cloned.model());
        assert_eq!(provider.provider(), cloned.provider());
        assert_eq!(provider.base_url, cloned.base_url);
    }

    // ===================
    // API Type Serialization Tests
    // ===================

    #[test]
    fn test_api_role_serialization() {
        let system_role = ApiRole::System;
        let user_role = ApiRole::User;
        let assistant_role = ApiRole::Assistant;
        let tool_role = ApiRole::Tool;

        assert_eq!(serde_json::to_string(&system_role).unwrap(), "\"system\"");
        assert_eq!(serde_json::to_string(&user_role).unwrap(), "\"user\"");
        assert_eq!(
            serde_json::to_string(&assistant_role).unwrap(),
            "\"assistant\""
        );
        assert_eq!(serde_json::to_string(&tool_role).unwrap(), "\"tool\"");
    }

    #[test]
    fn test_api_message_serialization_simple() {
        let message = ApiMessage {
            role: ApiRole::User,
            content: Some("Hello, world!".to_string()),
            tool_calls: None,
            tool_call_id: None,
        };

        let json = serde_json::to_string(&message).unwrap();
        assert!(json.contains("\"role\":\"user\""));
        assert!(json.contains("\"content\":\"Hello, world!\""));
        // Optional fields should be omitted
        assert!(!json.contains("tool_calls"));
        assert!(!json.contains("tool_call_id"));
    }

    #[test]
    fn test_api_message_serialization_with_tool_calls() {
        let message = ApiMessage {
            role: ApiRole::Assistant,
            content: Some("Let me help.".to_string()),
            tool_calls: Some(vec![ApiToolCall {
                id: "call_123".to_string(),
                r#type: "function".to_string(),
                function: ApiFunctionCall {
                    name: "read_file".to_string(),
                    arguments: "{\"path\": \"/test.txt\"}".to_string(),
                },
            }]),
            tool_call_id: None,
        };

        let json = serde_json::to_string(&message).unwrap();
        assert!(json.contains("\"role\":\"assistant\""));
        assert!(json.contains("\"tool_calls\""));
        assert!(json.contains("\"id\":\"call_123\""));
        assert!(json.contains("\"type\":\"function\""));
        assert!(json.contains("\"name\":\"read_file\""));
    }

    #[test]
    fn test_api_tool_message_serialization() {
        let message = ApiMessage {
            role: ApiRole::Tool,
            content: Some("File contents here".to_string()),
            tool_calls: None,
            tool_call_id: Some("call_123".to_string()),
        };

        let json = serde_json::to_string(&message).unwrap();
        assert!(json.contains("\"role\":\"tool\""));
        assert!(json.contains("\"tool_call_id\":\"call_123\""));
        assert!(json.contains("\"content\":\"File contents here\""));
    }

    #[test]
    fn test_api_tool_serialization() {
        let tool = ApiTool {
            r#type: "function".to_string(),
            function: ApiFunction {
                name: "test_tool".to_string(),
                description: "A test tool".to_string(),
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": {
                        "arg": {"type": "string"}
                    }
                }),
            },
        };

        let json = serde_json::to_string(&tool).unwrap();
        assert!(json.contains("\"type\":\"function\""));
        assert!(json.contains("\"name\":\"test_tool\""));
        assert!(json.contains("\"description\":\"A test tool\""));
        assert!(json.contains("\"parameters\""));
    }

    // ===================
    // API Type Deserialization Tests
    // ===================

    #[test]
    fn test_api_response_deserialization() {
        let json = r#"{
            "id": "chatcmpl-123",
            "choices": [
                {
                    "message": {
                        "content": "Hello!"
                    },
                    "finish_reason": "stop"
                }
            ],
            "model": "gpt-4o",
            "usage": {
                "prompt_tokens": 100,
                "completion_tokens": 50
            }
        }"#;

        let response: ApiChatResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.id, "chatcmpl-123");
        assert_eq!(response.model, "gpt-4o");
        assert_eq!(response.usage.prompt_tokens, 100);
        assert_eq!(response.usage.completion_tokens, 50);
        assert_eq!(response.choices.len(), 1);
        assert_eq!(
            response.choices[0].message.content,
            Some("Hello!".to_string())
        );
    }

    #[test]
    fn test_api_response_with_tool_calls_deserialization() {
        let json = r#"{
            "id": "chatcmpl-456",
            "choices": [
                {
                    "message": {
                        "content": null,
                        "tool_calls": [
                            {
                                "id": "call_abc",
                                "type": "function",
                                "function": {
                                    "name": "read_file",
                                    "arguments": "{\"path\": \"test.txt\"}"
                                }
                            }
                        ]
                    },
                    "finish_reason": "tool_calls"
                }
            ],
            "model": "gpt-4o",
            "usage": {
                "prompt_tokens": 150,
                "completion_tokens": 30
            }
        }"#;

        let response: ApiChatResponse = serde_json::from_str(json).unwrap();
        let tool_calls = response.choices[0].message.tool_calls.as_ref().unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].id, "call_abc");
        assert_eq!(tool_calls[0].function.name, "read_file");
    }

    #[test]
    fn test_api_finish_reason_deserialization() {
        let stop: ApiFinishReason = serde_json::from_str("\"stop\"").unwrap();
        let tool_calls: ApiFinishReason = serde_json::from_str("\"tool_calls\"").unwrap();
        let length: ApiFinishReason = serde_json::from_str("\"length\"").unwrap();
        let content_filter: ApiFinishReason = serde_json::from_str("\"content_filter\"").unwrap();

        assert!(matches!(stop, ApiFinishReason::Stop));
        assert!(matches!(tool_calls, ApiFinishReason::ToolCalls));
        assert!(matches!(length, ApiFinishReason::Length));
        assert!(matches!(content_filter, ApiFinishReason::ContentFilter));
    }

    // ===================
    // Message Conversion Tests
    // ===================

    #[test]
    fn test_build_api_messages_with_system() {
        let request = ChatRequest {
            system: "You are helpful.".to_string(),
            messages: vec![crate::llm::Message::user("Hello")],
            tools: None,
            max_tokens: 1024,
        };

        let api_messages = build_api_messages(&request);
        assert_eq!(api_messages.len(), 2);
        assert_eq!(api_messages[0].role, ApiRole::System);
        assert_eq!(
            api_messages[0].content,
            Some("You are helpful.".to_string())
        );
        assert_eq!(api_messages[1].role, ApiRole::User);
        assert_eq!(api_messages[1].content, Some("Hello".to_string()));
    }

    #[test]
    fn test_build_api_messages_empty_system() {
        let request = ChatRequest {
            system: String::new(),
            messages: vec![crate::llm::Message::user("Hello")],
            tools: None,
            max_tokens: 1024,
        };

        let api_messages = build_api_messages(&request);
        assert_eq!(api_messages.len(), 1);
        assert_eq!(api_messages[0].role, ApiRole::User);
    }

    #[test]
    fn test_convert_tool() {
        let tool = crate::llm::Tool {
            name: "test_tool".to_string(),
            description: "A test tool".to_string(),
            input_schema: serde_json::json!({"type": "object"}),
        };

        let api_tool = convert_tool(tool);
        assert_eq!(api_tool.r#type, "function");
        assert_eq!(api_tool.function.name, "test_tool");
        assert_eq!(api_tool.function.description, "A test tool");
    }

    #[test]
    fn test_build_content_blocks_text_only() {
        let message = ApiResponseMessage {
            content: Some("Hello!".to_string()),
            tool_calls: None,
        };

        let blocks = build_content_blocks(&message);
        assert_eq!(blocks.len(), 1);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "Hello!"));
    }

    #[test]
    fn test_build_content_blocks_with_tool_calls() {
        let message = ApiResponseMessage {
            content: Some("Let me help.".to_string()),
            tool_calls: Some(vec![ApiResponseToolCall {
                id: "call_123".to_string(),
                function: ApiResponseFunctionCall {
                    name: "read_file".to_string(),
                    arguments: "{\"path\": \"test.txt\"}".to_string(),
                },
            }]),
        };

        let blocks = build_content_blocks(&message);
        assert_eq!(blocks.len(), 2);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "Let me help."));
        assert!(
            matches!(&blocks[1], ContentBlock::ToolUse { id, name, .. } if id == "call_123" && name == "read_file")
        );
    }
}