agent_sdk/providers/openai.rs

//! `OpenAI` API provider implementation.
//!
//! This module provides an implementation of `LlmProvider` for the `OpenAI`
//! Chat Completions API. It also supports `OpenAI`-compatible APIs (Ollama, vLLM, etc.)
//! via the `with_base_url` constructor.
//!
//! Models that require the Responses API (like `gpt-5.2-codex`) are automatically
//! routed to the correct endpoint.
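//!
//! # Example
//!
//! A minimal construction sketch; the crate path in the `use` line is an
//! assumption and may differ in your workspace:
//!
//! ```ignore
//! use agent_sdk::providers::openai::OpenAIProvider;
//!
//! // Hosted OpenAI endpoint:
//! let hosted = OpenAIProvider::gpt4o(std::env::var("OPENAI_API_KEY").unwrap());
//!
//! // OpenAI-compatible server (e.g. a local Ollama instance); many such
//! // servers accept any placeholder API key:
//! let local = OpenAIProvider::with_base_url(
//!     "unused".to_owned(),
//!     "llama3".to_owned(),
//!     "http://localhost:11434/v1".to_owned(),
//! );
//! ```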

use crate::llm::{
    ChatOutcome, ChatRequest, ChatResponse, Content, ContentBlock, LlmProvider, StopReason,
    StreamBox, StreamDelta, Usage,
};
use anyhow::Result;
use async_trait::async_trait;
use futures::StreamExt;
use reqwest::StatusCode;
use serde::{Deserialize, Serialize};

use super::openai_responses::OpenAIResponsesProvider;

const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1";

/// Check if a model requires the Responses API instead of Chat Completions.
fn requires_responses_api(model: &str) -> bool {
    model.contains("codex")
}

// GPT-5.2 series (latest flagship, Dec 2025)
pub const MODEL_GPT52_INSTANT: &str = "gpt-5.2-instant";
pub const MODEL_GPT52_THINKING: &str = "gpt-5.2-thinking";
pub const MODEL_GPT52_PRO: &str = "gpt-5.2-pro";
pub const MODEL_GPT52_CODEX: &str = "gpt-5.2-codex";

// GPT-5 series (400k context)
pub const MODEL_GPT5: &str = "gpt-5";
pub const MODEL_GPT5_MINI: &str = "gpt-5-mini";
pub const MODEL_GPT5_NANO: &str = "gpt-5-nano";

// o-series reasoning models
pub const MODEL_O3: &str = "o3";
pub const MODEL_O3_MINI: &str = "o3-mini";
pub const MODEL_O4_MINI: &str = "o4-mini";
pub const MODEL_O1: &str = "o1";
pub const MODEL_O1_MINI: &str = "o1-mini";

// GPT-4.1 series (improved instruction following, 1M context)
pub const MODEL_GPT41: &str = "gpt-4.1";
pub const MODEL_GPT41_MINI: &str = "gpt-4.1-mini";
pub const MODEL_GPT41_NANO: &str = "gpt-4.1-nano";

// GPT-4o series
pub const MODEL_GPT4O: &str = "gpt-4o";
pub const MODEL_GPT4O_MINI: &str = "gpt-4o-mini";

/// `OpenAI` LLM provider using the Chat Completions API.
///
/// Also supports `OpenAI`-compatible APIs (Ollama, vLLM, Azure `OpenAI`, etc.)
/// via the `with_base_url` constructor.
#[derive(Clone)]
pub struct OpenAIProvider {
    client: reqwest::Client,
    api_key: String,
    model: String,
    base_url: String,
}

impl OpenAIProvider {
    /// Create a new `OpenAI` provider with the specified API key and model.
    #[must_use]
    pub fn new(api_key: String, model: String) -> Self {
        Self {
            client: reqwest::Client::new(),
            api_key,
            model,
            base_url: DEFAULT_BASE_URL.to_owned(),
        }
    }

    /// Create a new provider with a custom base URL for OpenAI-compatible APIs.
    #[must_use]
    pub fn with_base_url(api_key: String, model: String, base_url: String) -> Self {
        Self {
            client: reqwest::Client::new(),
            api_key,
            model,
            base_url,
        }
    }

    /// Create a provider using GPT-5.2 Instant (speed-optimized for routine queries).
    #[must_use]
    pub fn gpt52_instant(api_key: String) -> Self {
        Self::new(api_key, MODEL_GPT52_INSTANT.to_owned())
    }

    /// Create a provider using GPT-5.2 Thinking (complex reasoning, coding, analysis).
    #[must_use]
    pub fn gpt52_thinking(api_key: String) -> Self {
        Self::new(api_key, MODEL_GPT52_THINKING.to_owned())
    }

    /// Create a provider using GPT-5.2 Pro (maximum accuracy for difficult problems).
    #[must_use]
    pub fn gpt52_pro(api_key: String) -> Self {
        Self::new(api_key, MODEL_GPT52_PRO.to_owned())
    }

    /// Create a provider using GPT-5.2 Codex (optimized for agentic coding).
    ///
    /// Note: This model uses the Responses API internally.
    #[must_use]
    pub fn codex(api_key: String) -> Self {
        Self::new(api_key, MODEL_GPT52_CODEX.to_owned())
    }

    /// Create a provider using GPT-5 (400k context, coding and reasoning).
    #[must_use]
    pub fn gpt5(api_key: String) -> Self {
        Self::new(api_key, MODEL_GPT5.to_owned())
    }

    /// Create a provider using GPT-5-mini (faster, cost-efficient GPT-5).
    #[must_use]
    pub fn gpt5_mini(api_key: String) -> Self {
        Self::new(api_key, MODEL_GPT5_MINI.to_owned())
    }

    /// Create a provider using GPT-5-nano (fastest, cheapest GPT-5 variant).
    #[must_use]
    pub fn gpt5_nano(api_key: String) -> Self {
        Self::new(api_key, MODEL_GPT5_NANO.to_owned())
    }

    /// Create a provider using o3 (most intelligent reasoning model).
    #[must_use]
    pub fn o3(api_key: String) -> Self {
        Self::new(api_key, MODEL_O3.to_owned())
    }

    /// Create a provider using o3-mini (smaller o3 variant).
    #[must_use]
    pub fn o3_mini(api_key: String) -> Self {
        Self::new(api_key, MODEL_O3_MINI.to_owned())
    }

    /// Create a provider using o4-mini (fast, cost-efficient reasoning).
    #[must_use]
    pub fn o4_mini(api_key: String) -> Self {
        Self::new(api_key, MODEL_O4_MINI.to_owned())
    }

    /// Create a provider using o1 (reasoning model).
    #[must_use]
    pub fn o1(api_key: String) -> Self {
        Self::new(api_key, MODEL_O1.to_owned())
    }

    /// Create a provider using o1-mini (fast reasoning model).
    #[must_use]
    pub fn o1_mini(api_key: String) -> Self {
        Self::new(api_key, MODEL_O1_MINI.to_owned())
    }

    /// Create a provider using GPT-4.1 (improved instruction following, 1M context).
    #[must_use]
    pub fn gpt41(api_key: String) -> Self {
        Self::new(api_key, MODEL_GPT41.to_owned())
    }

    /// Create a provider using GPT-4.1-mini (smaller, faster GPT-4.1).
    #[must_use]
    pub fn gpt41_mini(api_key: String) -> Self {
        Self::new(api_key, MODEL_GPT41_MINI.to_owned())
    }

    /// Create a provider using GPT-4o.
    #[must_use]
    pub fn gpt4o(api_key: String) -> Self {
        Self::new(api_key, MODEL_GPT4O.to_owned())
    }

    /// Create a provider using GPT-4o-mini (fast and cost-effective).
    #[must_use]
    pub fn gpt4o_mini(api_key: String) -> Self {
        Self::new(api_key, MODEL_GPT4O_MINI.to_owned())
    }
}

#[async_trait]
impl LlmProvider for OpenAIProvider {
    async fn chat(&self, request: ChatRequest) -> Result<ChatOutcome> {
        // Route to Responses API for models that require it (e.g., gpt-5.2-codex)
        if requires_responses_api(&self.model) {
            let responses_provider =
                OpenAIResponsesProvider::new(self.api_key.clone(), self.model.clone());
            return responses_provider.chat(request).await;
        }

        let messages = build_api_messages(&request);
        let tools: Option<Vec<ApiTool>> = request
            .tools
            .map(|ts| ts.into_iter().map(convert_tool).collect());

        let api_request = ApiChatRequest {
            model: &self.model,
            messages: &messages,
            max_completion_tokens: Some(request.max_tokens),
            tools: tools.as_deref(),
        };

        tracing::debug!(
            model = %self.model,
            max_tokens = request.max_tokens,
            "OpenAI LLM request"
        );

        let response = self
            .client
            .post(format!("{}/chat/completions", self.base_url))
            .header("Content-Type", "application/json")
            .header("Authorization", format!("Bearer {}", self.api_key))
            .json(&api_request)
            .send()
            .await
            .map_err(|e| anyhow::anyhow!("request failed: {e}"))?;

        let status = response.status();
        let bytes = response
            .bytes()
            .await
            .map_err(|e| anyhow::anyhow!("failed to read response body: {e}"))?;

        tracing::debug!(
            status = %status,
            body_len = bytes.len(),
            "OpenAI LLM response"
        );

        if status == StatusCode::TOO_MANY_REQUESTS {
            return Ok(ChatOutcome::RateLimited);
        }

        if status.is_server_error() {
            let body = String::from_utf8_lossy(&bytes);
            tracing::error!(status = %status, body = %body, "OpenAI server error");
            return Ok(ChatOutcome::ServerError(body.into_owned()));
        }

        if status.is_client_error() {
            let body = String::from_utf8_lossy(&bytes);
            tracing::warn!(status = %status, body = %body, "OpenAI client error");
            return Ok(ChatOutcome::InvalidRequest(body.into_owned()));
        }

        let api_response: ApiChatResponse = serde_json::from_slice(&bytes)
            .map_err(|e| anyhow::anyhow!("failed to parse response: {e}"))?;

        let choice = api_response
            .choices
            .into_iter()
            .next()
            .ok_or_else(|| anyhow::anyhow!("no choices in response"))?;

        let content = build_content_blocks(&choice.message);

        let stop_reason = choice.finish_reason.map(|r| match r {
            ApiFinishReason::Stop => StopReason::EndTurn,
            ApiFinishReason::ToolCalls => StopReason::ToolUse,
            ApiFinishReason::Length => StopReason::MaxTokens,
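            // "content_filter" has no direct StopReason equivalent here;
            // StopSequence is used as the closest available mapping.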
            ApiFinishReason::ContentFilter => StopReason::StopSequence,
        });

        Ok(ChatOutcome::Success(ChatResponse {
            id: api_response.id,
            content,
            model: api_response.model,
            stop_reason,
            usage: Usage {
                input_tokens: api_response.usage.prompt_tokens,
                output_tokens: api_response.usage.completion_tokens,
            },
        }))
    }

    fn chat_stream(&self, request: ChatRequest) -> StreamBox<'_> {
        // Route to Responses API for models that require it (e.g., gpt-5.2-codex)
        if requires_responses_api(&self.model) {
            let api_key = self.api_key.clone();
            let model = self.model.clone();
            return Box::pin(async_stream::stream! {
                let responses_provider = OpenAIResponsesProvider::new(api_key, model);
                let mut stream = std::pin::pin!(responses_provider.chat_stream(request));
                while let Some(item) = futures::StreamExt::next(&mut stream).await {
                    yield item;
                }
            });
        }

        Box::pin(async_stream::stream! {
            let messages = build_api_messages(&request);
            let tools: Option<Vec<ApiTool>> = request
                .tools
                .map(|ts| ts.into_iter().map(convert_tool).collect());

            let api_request = ApiChatRequestStreaming {
                model: &self.model,
                messages: &messages,
                max_completion_tokens: Some(request.max_tokens),
                tools: tools.as_deref(),
                stream: true,
            };

            tracing::debug!(
                model = %self.model,
                max_tokens = request.max_tokens,
                "OpenAI streaming LLM request"
            );

            let response = match self.client
                .post(format!("{}/chat/completions", self.base_url))
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {}", self.api_key))
                .json(&api_request)
                .send()
                .await
            {
                Ok(response) => response,
                Err(e) => {
                    yield Err(anyhow::anyhow!("request failed: {e}"));
                    return;
                }
            };

            let status = response.status();

            if !status.is_success() {
                let body = response.text().await.unwrap_or_default();
                let (recoverable, level) = if status == StatusCode::TOO_MANY_REQUESTS {
                    (true, "rate_limit")
                } else if status.is_server_error() {
                    (true, "server_error")
                } else {
                    (false, "client_error")
                };
                tracing::warn!(status = %status, body = %body, kind = level, "OpenAI error");
                yield Ok(StreamDelta::Error { message: body, recoverable });
                return;
            }

            // Track tool call state across deltas
            let mut tool_calls: std::collections::HashMap<usize, ToolCallAccumulator> =
                std::collections::HashMap::new();
            let mut usage: Option<Usage> = None;
            let mut buffer = String::new();
            let mut stream = response.bytes_stream();

            while let Some(chunk_result) = stream.next().await {
                // A `let ... else` cannot report the error here: the value is
                // moved into the pattern match, so a plain match is needed.
                let chunk = match chunk_result {
                    Ok(chunk) => chunk,
                    Err(e) => {
                        yield Err(anyhow::anyhow!("stream error: {e}"));
                        return;
                    }
                };
                buffer.push_str(&String::from_utf8_lossy(&chunk));

                while let Some(pos) = buffer.find('\n') {
                    let line = buffer[..pos].trim().to_string();
                    buffer = buffer[pos + 1..].to_string();
                    if line.is_empty() { continue; }
                    let Some(data) = line.strip_prefix("data: ") else { continue; };

                    for result in process_sse_data(data) {
                        match result {
                            SseProcessResult::TextDelta(c) => {
                                yield Ok(StreamDelta::TextDelta { delta: c, block_index: 0 });
                            }
                            SseProcessResult::ToolCallUpdate { index, id, name, arguments } => {
                                apply_tool_call_update(&mut tool_calls, index, id, name, arguments);
                            }
                            SseProcessResult::Usage(u) => usage = Some(u),
                            SseProcessResult::Done(sr) => {
                                for d in build_stream_end_deltas(&tool_calls, usage.take(), sr) {
                                    yield Ok(d);
                                }
                                return;
                            }
                            SseProcessResult::Sentinel => {
                                let sr = if tool_calls.is_empty() {
                                    StopReason::EndTurn
                                } else {
                                    StopReason::ToolUse
                                };
                                for d in build_stream_end_deltas(&tool_calls, usage.take(), sr) {
                                    yield Ok(d);
                                }
                                return;
                            }
                        }
                    }
                }
            }

            // Stream ended without [DONE] - emit what we have
            for delta in build_stream_end_deltas(&tool_calls, usage, StopReason::EndTurn) {
                yield Ok(delta);
            }
        })
    }

    fn model(&self) -> &str {
        &self.model
    }

    fn provider(&self) -> &'static str {
        "openai"
    }
}

/// Apply a tool call update to the accumulator.
fn apply_tool_call_update(
    tool_calls: &mut std::collections::HashMap<usize, ToolCallAccumulator>,
    index: usize,
    id: Option<String>,
    name: Option<String>,
    arguments: Option<String>,
) {
    let entry = tool_calls
        .entry(index)
        .or_insert_with(|| ToolCallAccumulator {
            id: String::new(),
            name: String::new(),
            arguments: String::new(),
        });
    if let Some(id) = id {
        entry.id = id;
    }
    if let Some(name) = name {
        entry.name = name;
    }
    if let Some(args) = arguments {
        entry.arguments.push_str(&args);
    }
}

/// Helper to emit tool call deltas and done event.
fn build_stream_end_deltas(
    tool_calls: &std::collections::HashMap<usize, ToolCallAccumulator>,
    usage: Option<Usage>,
    stop_reason: StopReason,
) -> Vec<StreamDelta> {
    let mut deltas = Vec::new();

    // Emit tool calls in index order (HashMap iteration order is unspecified)
    let mut ordered: Vec<_> = tool_calls.iter().collect();
    ordered.sort_by_key(|(idx, _)| **idx);
    for (idx, tool) in ordered {
        deltas.push(StreamDelta::ToolUseStart {
            id: tool.id.clone(),
            name: tool.name.clone(),
            block_index: *idx + 1,
        });
        deltas.push(StreamDelta::ToolInputDelta {
            id: tool.id.clone(),
            delta: tool.arguments.clone(),
            block_index: *idx + 1,
        });
    }

    // Emit usage
    if let Some(u) = usage {
        deltas.push(StreamDelta::Usage(u));
    }

    // Emit done
    deltas.push(StreamDelta::Done {
        stop_reason: Some(stop_reason),
    });

    deltas
}

/// Result of processing an SSE chunk.
enum SseProcessResult {
    /// Emit a text delta.
    TextDelta(String),
    /// Update tool call accumulator (index, optional id, optional name, optional args).
    ToolCallUpdate {
        index: usize,
        id: Option<String>,
        name: Option<String>,
        arguments: Option<String>,
    },
    /// Usage information.
    Usage(Usage),
    /// Stream is done with a stop reason.
    Done(StopReason),
    /// Stream sentinel [DONE] was received.
    Sentinel,
}

/// Process an SSE data line and return results to apply.
fn process_sse_data(data: &str) -> Vec<SseProcessResult> {
    if data == "[DONE]" {
        return vec![SseProcessResult::Sentinel];
    }

    let Ok(chunk) = serde_json::from_str::<SseChunk>(data) else {
        return vec![];
    };

    let mut results = Vec::new();

    // Extract usage if present
    if let Some(u) = chunk.usage {
        results.push(SseProcessResult::Usage(Usage {
            input_tokens: u.prompt_tokens,
            output_tokens: u.completion_tokens,
        }));
    }

    // Process choices
    if let Some(choice) = chunk.choices.into_iter().next() {
        // Handle text content delta
        if let Some(content) = choice.delta.content
            && !content.is_empty()
        {
            results.push(SseProcessResult::TextDelta(content));
        }

        // Handle tool call deltas
        if let Some(tc_deltas) = choice.delta.tool_calls {
            for tc in tc_deltas {
                results.push(SseProcessResult::ToolCallUpdate {
                    index: tc.index,
                    id: tc.id,
                    name: tc.function.as_ref().and_then(|f| f.name.clone()),
                    arguments: tc.function.as_ref().and_then(|f| f.arguments.clone()),
                });
            }
        }

        // Check for finish reason
        if let Some(finish_reason) = choice.finish_reason {
            let stop_reason = match finish_reason {
                SseFinishReason::Stop => StopReason::EndTurn,
                SseFinishReason::ToolCalls => StopReason::ToolUse,
                SseFinishReason::Length => StopReason::MaxTokens,
                SseFinishReason::ContentFilter => StopReason::StopSequence,
            };
            results.push(SseProcessResult::Done(stop_reason));
        }
    }

    results
}

fn build_api_messages(request: &ChatRequest) -> Vec<ApiMessage> {
    let mut messages = Vec::new();

    // Add system message first (OpenAI uses a separate message for system prompt)
    if !request.system.is_empty() {
        messages.push(ApiMessage {
            role: ApiRole::System,
            content: Some(request.system.clone()),
            tool_calls: None,
            tool_call_id: None,
        });
    }

    // Convert SDK messages to OpenAI format
    for msg in &request.messages {
        match &msg.content {
            Content::Text(text) => {
                messages.push(ApiMessage {
                    role: match msg.role {
                        crate::llm::Role::User => ApiRole::User,
                        crate::llm::Role::Assistant => ApiRole::Assistant,
                    },
                    content: Some(text.clone()),
                    tool_calls: None,
                    tool_call_id: None,
                });
            }
            Content::Blocks(blocks) => {
                // Handle mixed content blocks
                let mut text_parts = Vec::new();
                let mut tool_calls = Vec::new();

                for block in blocks {
                    match block {
                        ContentBlock::Text { text } => {
                            text_parts.push(text.clone());
                        }
                        ContentBlock::ToolUse {
                            id, name, input, ..
                        } => {
                            tool_calls.push(ApiToolCall {
                                id: id.clone(),
                                r#type: "function".to_owned(),
                                function: ApiFunctionCall {
                                    name: name.clone(),
                                    arguments: serde_json::to_string(input)
                                        .unwrap_or_else(|_| "{}".to_owned()),
                                },
                            });
                        }
                        ContentBlock::ToolResult {
                            tool_use_id,
                            content,
                            ..
                        } => {
                            // Tool results are separate messages in OpenAI
                            messages.push(ApiMessage {
                                role: ApiRole::Tool,
                                content: Some(content.clone()),
                                tool_calls: None,
                                tool_call_id: Some(tool_use_id.clone()),
                            });
                        }
                    }
                }

                // Add assistant message with text and/or tool calls
                if !text_parts.is_empty() || !tool_calls.is_empty() {
                    let role = match msg.role {
                        crate::llm::Role::User => ApiRole::User,
                        crate::llm::Role::Assistant => ApiRole::Assistant,
                    };

                    // Only add if it's an assistant message or has text content
                    if role == ApiRole::Assistant || !text_parts.is_empty() {
                        messages.push(ApiMessage {
                            role,
                            content: if text_parts.is_empty() {
                                None
                            } else {
                                Some(text_parts.join("\n"))
                            },
                            tool_calls: if tool_calls.is_empty() {
                                None
                            } else {
                                Some(tool_calls)
                            },
                            tool_call_id: None,
                        });
                    }
                }
            }
        }
    }

    messages
}

fn convert_tool(t: crate::llm::Tool) -> ApiTool {
    ApiTool {
        r#type: "function".to_owned(),
        function: ApiFunction {
            name: t.name,
            description: t.description,
            parameters: t.input_schema,
        },
    }
}

fn build_content_blocks(message: &ApiResponseMessage) -> Vec<ContentBlock> {
    let mut blocks = Vec::new();

    // Add text content if present
    if let Some(content) = &message.content
        && !content.is_empty()
    {
        blocks.push(ContentBlock::Text {
            text: content.clone(),
        });
    }

    // Add tool calls if present
    if let Some(tool_calls) = &message.tool_calls {
        for tc in tool_calls {
            let input: serde_json::Value =
                serde_json::from_str(&tc.function.arguments).unwrap_or(serde_json::Value::Null);
            blocks.push(ContentBlock::ToolUse {
                id: tc.id.clone(),
                name: tc.function.name.clone(),
                input,
                thought_signature: None,
            });
        }
    }

    blocks
}

// ============================================================================
// API Request Types
// ============================================================================

#[derive(Serialize)]
struct ApiChatRequest<'a> {
    model: &'a str,
    messages: &'a [ApiMessage],
    #[serde(skip_serializing_if = "Option::is_none")]
    max_completion_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<&'a [ApiTool]>,
}

#[derive(Serialize)]
struct ApiChatRequestStreaming<'a> {
    model: &'a str,
    messages: &'a [ApiMessage],
    #[serde(skip_serializing_if = "Option::is_none")]
    max_completion_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<&'a [ApiTool]>,
    stream: bool,
}

#[derive(Serialize)]
struct ApiMessage {
    role: ApiRole,
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<ApiToolCall>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}

#[derive(Debug, Serialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
enum ApiRole {
    System,
    User,
    Assistant,
    Tool,
}

#[derive(Serialize)]
struct ApiToolCall {
    id: String,
    r#type: String,
    function: ApiFunctionCall,
}

#[derive(Serialize)]
struct ApiFunctionCall {
    name: String,
    arguments: String,
}

#[derive(Serialize)]
struct ApiTool {
    r#type: String,
    function: ApiFunction,
}

#[derive(Serialize)]
struct ApiFunction {
    name: String,
    description: String,
    parameters: serde_json::Value,
}

// ============================================================================
// API Response Types
// ============================================================================

#[derive(Deserialize)]
struct ApiChatResponse {
    id: String,
    choices: Vec<ApiChoice>,
    model: String,
    usage: ApiUsage,
}

#[derive(Deserialize)]
struct ApiChoice {
    message: ApiResponseMessage,
    finish_reason: Option<ApiFinishReason>,
}

#[derive(Deserialize)]
struct ApiResponseMessage {
    content: Option<String>,
    tool_calls: Option<Vec<ApiResponseToolCall>>,
}

#[derive(Deserialize)]
struct ApiResponseToolCall {
    id: String,
    function: ApiResponseFunctionCall,
}

#[derive(Deserialize)]
struct ApiResponseFunctionCall {
    name: String,
    arguments: String,
}

#[derive(Deserialize)]
#[serde(rename_all = "snake_case")]
enum ApiFinishReason {
    Stop,
    ToolCalls,
    Length,
    ContentFilter,
}

#[derive(Deserialize)]
struct ApiUsage {
    prompt_tokens: u32,
    completion_tokens: u32,
}

// ============================================================================
// SSE Streaming Types
// ============================================================================

/// Accumulator for tool call state across stream deltas.
struct ToolCallAccumulator {
    id: String,
    name: String,
    arguments: String,
}

/// A single chunk in `OpenAI`'s SSE stream.
#[derive(Deserialize)]
struct SseChunk {
    choices: Vec<SseChoice>,
    #[serde(default)]
    usage: Option<SseUsage>,
}

#[derive(Deserialize)]
struct SseChoice {
    delta: SseDelta,
    finish_reason: Option<SseFinishReason>,
}

#[derive(Deserialize)]
struct SseDelta {
    content: Option<String>,
    tool_calls: Option<Vec<SseToolCallDelta>>,
}

#[derive(Deserialize)]
struct SseToolCallDelta {
    index: usize,
    id: Option<String>,
    function: Option<SseFunctionDelta>,
}

#[derive(Deserialize)]
struct SseFunctionDelta {
    name: Option<String>,
    arguments: Option<String>,
}

#[derive(Deserialize)]
#[serde(rename_all = "snake_case")]
enum SseFinishReason {
    Stop,
    ToolCalls,
    Length,
    ContentFilter,
}

#[derive(Deserialize)]
struct SseUsage {
    prompt_tokens: u32,
    completion_tokens: u32,
}

#[cfg(test)]
mod tests {
    use super::*;

    // ===================
    // Constructor Tests
    // ===================

    #[test]
    fn test_new_creates_provider_with_custom_model() {
        let provider = OpenAIProvider::new("test-api-key".to_string(), "custom-model".to_string());

        assert_eq!(provider.model(), "custom-model");
        assert_eq!(provider.provider(), "openai");
        assert_eq!(provider.base_url, DEFAULT_BASE_URL);
    }

    #[test]
    fn test_with_base_url_creates_provider_with_custom_url() {
        let provider = OpenAIProvider::with_base_url(
            "test-api-key".to_string(),
            "llama3".to_string(),
            "http://localhost:11434/v1".to_string(),
        );

        assert_eq!(provider.model(), "llama3");
        assert_eq!(provider.base_url, "http://localhost:11434/v1");
    }

    #[test]
    fn test_gpt4o_factory_creates_gpt4o_provider() {
        let provider = OpenAIProvider::gpt4o("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_GPT4O);
        assert_eq!(provider.provider(), "openai");
    }

    #[test]
    fn test_gpt4o_mini_factory_creates_gpt4o_mini_provider() {
        let provider = OpenAIProvider::gpt4o_mini("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_GPT4O_MINI);
        assert_eq!(provider.provider(), "openai");
    }

    #[test]
    fn test_gpt52_thinking_factory_creates_provider() {
        let provider = OpenAIProvider::gpt52_thinking("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_GPT52_THINKING);
        assert_eq!(provider.provider(), "openai");
    }

    #[test]
    fn test_gpt5_factory_creates_gpt5_provider() {
        let provider = OpenAIProvider::gpt5("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_GPT5);
        assert_eq!(provider.provider(), "openai");
    }

    #[test]
    fn test_gpt5_mini_factory_creates_provider() {
        let provider = OpenAIProvider::gpt5_mini("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_GPT5_MINI);
        assert_eq!(provider.provider(), "openai");
    }

    #[test]
    fn test_o3_factory_creates_o3_provider() {
        let provider = OpenAIProvider::o3("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_O3);
        assert_eq!(provider.provider(), "openai");
    }

    #[test]
    fn test_o4_mini_factory_creates_o4_mini_provider() {
        let provider = OpenAIProvider::o4_mini("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_O4_MINI);
        assert_eq!(provider.provider(), "openai");
    }

    #[test]
    fn test_o1_factory_creates_o1_provider() {
        let provider = OpenAIProvider::o1("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_O1);
        assert_eq!(provider.provider(), "openai");
    }

    #[test]
    fn test_gpt41_factory_creates_gpt41_provider() {
        let provider = OpenAIProvider::gpt41("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_GPT41);
        assert_eq!(provider.provider(), "openai");
    }

    // ===================
    // Model Constants Tests
    // ===================

    #[test]
    fn test_model_constants_have_expected_values() {
        // GPT-5.2 series
        assert_eq!(MODEL_GPT52_INSTANT, "gpt-5.2-instant");
        assert_eq!(MODEL_GPT52_THINKING, "gpt-5.2-thinking");
        assert_eq!(MODEL_GPT52_PRO, "gpt-5.2-pro");
        assert_eq!(MODEL_GPT52_CODEX, "gpt-5.2-codex");
        // GPT-5 series
        assert_eq!(MODEL_GPT5, "gpt-5");
        assert_eq!(MODEL_GPT5_MINI, "gpt-5-mini");
        assert_eq!(MODEL_GPT5_NANO, "gpt-5-nano");
        // o-series
        assert_eq!(MODEL_O3, "o3");
        assert_eq!(MODEL_O3_MINI, "o3-mini");
        assert_eq!(MODEL_O4_MINI, "o4-mini");
        assert_eq!(MODEL_O1, "o1");
        assert_eq!(MODEL_O1_MINI, "o1-mini");
        // GPT-4.1 series
        assert_eq!(MODEL_GPT41, "gpt-4.1");
        assert_eq!(MODEL_GPT41_MINI, "gpt-4.1-mini");
        assert_eq!(MODEL_GPT41_NANO, "gpt-4.1-nano");
        // GPT-4o series
        assert_eq!(MODEL_GPT4O, "gpt-4o");
        assert_eq!(MODEL_GPT4O_MINI, "gpt-4o-mini");
    }
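
    // ===================
    // Responses API Routing Tests
    // ===================

    // Added sketch covering the `requires_responses_api` routing predicate
    // defined at the top of this module.
    #[test]
    fn test_requires_responses_api_routes_codex_only() {
        assert!(requires_responses_api(MODEL_GPT52_CODEX));
        assert!(!requires_responses_api(MODEL_GPT4O));
        assert!(!requires_responses_api(MODEL_GPT52_THINKING));
    }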

    // ===================
    // Clone Tests
    // ===================

    #[test]
    fn test_provider_is_cloneable() {
        let provider = OpenAIProvider::new("test-api-key".to_string(), "test-model".to_string());
        let cloned = provider.clone();

        assert_eq!(provider.model(), cloned.model());
        assert_eq!(provider.provider(), cloned.provider());
        assert_eq!(provider.base_url, cloned.base_url);
    }

    // ===================
    // API Type Serialization Tests
    // ===================

    #[test]
    fn test_api_role_serialization() {
        let system_role = ApiRole::System;
        let user_role = ApiRole::User;
        let assistant_role = ApiRole::Assistant;
        let tool_role = ApiRole::Tool;

        assert_eq!(serde_json::to_string(&system_role).unwrap(), "\"system\"");
        assert_eq!(serde_json::to_string(&user_role).unwrap(), "\"user\"");
        assert_eq!(
            serde_json::to_string(&assistant_role).unwrap(),
            "\"assistant\""
        );
        assert_eq!(serde_json::to_string(&tool_role).unwrap(), "\"tool\"");
    }

    #[test]
    fn test_api_message_serialization_simple() {
        let message = ApiMessage {
            role: ApiRole::User,
            content: Some("Hello, world!".to_string()),
            tool_calls: None,
            tool_call_id: None,
        };

        let json = serde_json::to_string(&message).unwrap();
        assert!(json.contains("\"role\":\"user\""));
        assert!(json.contains("\"content\":\"Hello, world!\""));
        // Optional fields should be omitted
        assert!(!json.contains("tool_calls"));
        assert!(!json.contains("tool_call_id"));
    }

    #[test]
    fn test_api_message_serialization_with_tool_calls() {
        let message = ApiMessage {
            role: ApiRole::Assistant,
            content: Some("Let me help.".to_string()),
            tool_calls: Some(vec![ApiToolCall {
                id: "call_123".to_string(),
                r#type: "function".to_string(),
                function: ApiFunctionCall {
                    name: "read_file".to_string(),
                    arguments: "{\"path\": \"/test.txt\"}".to_string(),
                },
            }]),
            tool_call_id: None,
        };

        let json = serde_json::to_string(&message).unwrap();
        assert!(json.contains("\"role\":\"assistant\""));
        assert!(json.contains("\"tool_calls\""));
        assert!(json.contains("\"id\":\"call_123\""));
        assert!(json.contains("\"type\":\"function\""));
        assert!(json.contains("\"name\":\"read_file\""));
    }

    #[test]
    fn test_api_tool_message_serialization() {
        let message = ApiMessage {
            role: ApiRole::Tool,
            content: Some("File contents here".to_string()),
            tool_calls: None,
            tool_call_id: Some("call_123".to_string()),
        };

        let json = serde_json::to_string(&message).unwrap();
        assert!(json.contains("\"role\":\"tool\""));
        assert!(json.contains("\"tool_call_id\":\"call_123\""));
        assert!(json.contains("\"content\":\"File contents here\""));
    }

    #[test]
    fn test_api_tool_serialization() {
        let tool = ApiTool {
            r#type: "function".to_string(),
            function: ApiFunction {
                name: "test_tool".to_string(),
                description: "A test tool".to_string(),
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": {
                        "arg": {"type": "string"}
                    }
                }),
            },
        };

        let json = serde_json::to_string(&tool).unwrap();
        assert!(json.contains("\"type\":\"function\""));
        assert!(json.contains("\"name\":\"test_tool\""));
        assert!(json.contains("\"description\":\"A test tool\""));
        assert!(json.contains("\"parameters\""));
    }

    // ===================
    // API Type Deserialization Tests
    // ===================

    #[test]
    fn test_api_response_deserialization() {
        let json = r#"{
            "id": "chatcmpl-123",
            "choices": [
                {
                    "message": {
                        "content": "Hello!"
                    },
                    "finish_reason": "stop"
                }
            ],
            "model": "gpt-4o",
            "usage": {
                "prompt_tokens": 100,
                "completion_tokens": 50
            }
        }"#;

        let response: ApiChatResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.id, "chatcmpl-123");
        assert_eq!(response.model, "gpt-4o");
        assert_eq!(response.usage.prompt_tokens, 100);
        assert_eq!(response.usage.completion_tokens, 50);
        assert_eq!(response.choices.len(), 1);
        assert_eq!(
            response.choices[0].message.content,
            Some("Hello!".to_string())
        );
    }

    #[test]
    fn test_api_response_with_tool_calls_deserialization() {
        let json = r#"{
            "id": "chatcmpl-456",
            "choices": [
                {
                    "message": {
                        "content": null,
                        "tool_calls": [
                            {
                                "id": "call_abc",
                                "type": "function",
                                "function": {
                                    "name": "read_file",
                                    "arguments": "{\"path\": \"test.txt\"}"
                                }
                            }
                        ]
                    },
                    "finish_reason": "tool_calls"
                }
            ],
            "model": "gpt-4o",
            "usage": {
                "prompt_tokens": 150,
                "completion_tokens": 30
            }
        }"#;

        let response: ApiChatResponse = serde_json::from_str(json).unwrap();
        let tool_calls = response.choices[0].message.tool_calls.as_ref().unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].id, "call_abc");
        assert_eq!(tool_calls[0].function.name, "read_file");
    }

    #[test]
    fn test_api_finish_reason_deserialization() {
        let stop: ApiFinishReason = serde_json::from_str("\"stop\"").unwrap();
        let tool_calls: ApiFinishReason = serde_json::from_str("\"tool_calls\"").unwrap();
        let length: ApiFinishReason = serde_json::from_str("\"length\"").unwrap();
        let content_filter: ApiFinishReason = serde_json::from_str("\"content_filter\"").unwrap();

        assert!(matches!(stop, ApiFinishReason::Stop));
        assert!(matches!(tool_calls, ApiFinishReason::ToolCalls));
        assert!(matches!(length, ApiFinishReason::Length));
        assert!(matches!(content_filter, ApiFinishReason::ContentFilter));
    }

    // ===================
    // Message Conversion Tests
    // ===================

    #[test]
    fn test_build_api_messages_with_system() {
        let request = ChatRequest {
            system: "You are helpful.".to_string(),
            messages: vec![crate::llm::Message::user("Hello")],
            tools: None,
            max_tokens: 1024,
        };

        let api_messages = build_api_messages(&request);
        assert_eq!(api_messages.len(), 2);
        assert_eq!(api_messages[0].role, ApiRole::System);
        assert_eq!(
            api_messages[0].content,
            Some("You are helpful.".to_string())
        );
        assert_eq!(api_messages[1].role, ApiRole::User);
        assert_eq!(api_messages[1].content, Some("Hello".to_string()));
    }

    #[test]
    fn test_build_api_messages_empty_system() {
        let request = ChatRequest {
            system: String::new(),
            messages: vec![crate::llm::Message::user("Hello")],
            tools: None,
            max_tokens: 1024,
        };

        let api_messages = build_api_messages(&request);
        assert_eq!(api_messages.len(), 1);
        assert_eq!(api_messages[0].role, ApiRole::User);
    }

    #[test]
    fn test_convert_tool() {
        let tool = crate::llm::Tool {
            name: "test_tool".to_string(),
            description: "A test tool".to_string(),
            input_schema: serde_json::json!({"type": "object"}),
        };

        let api_tool = convert_tool(tool);
        assert_eq!(api_tool.r#type, "function");
        assert_eq!(api_tool.function.name, "test_tool");
        assert_eq!(api_tool.function.description, "A test tool");
    }

    #[test]
    fn test_build_content_blocks_text_only() {
        let message = ApiResponseMessage {
            content: Some("Hello!".to_string()),
            tool_calls: None,
        };

        let blocks = build_content_blocks(&message);
        assert_eq!(blocks.len(), 1);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "Hello!"));
    }

    #[test]
    fn test_build_content_blocks_with_tool_calls() {
        let message = ApiResponseMessage {
            content: Some("Let me help.".to_string()),
            tool_calls: Some(vec![ApiResponseToolCall {
                id: "call_123".to_string(),
                function: ApiResponseFunctionCall {
                    name: "read_file".to_string(),
                    arguments: "{\"path\": \"test.txt\"}".to_string(),
                },
            }]),
        };

        let blocks = build_content_blocks(&message);
        assert_eq!(blocks.len(), 2);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "Let me help."));
        assert!(
            matches!(&blocks[1], ContentBlock::ToolUse { id, name, .. } if id == "call_123" && name == "read_file")
        );
    }

    // ===================
    // SSE Streaming Type Tests
    // ===================

    #[test]
    fn test_sse_chunk_text_delta_deserialization() {
        let json = r#"{
            "choices": [{
                "delta": {
                    "content": "Hello"
                },
                "finish_reason": null
            }]
        }"#;

        let chunk: SseChunk = serde_json::from_str(json).unwrap();
        assert_eq!(chunk.choices.len(), 1);
        assert_eq!(chunk.choices[0].delta.content, Some("Hello".to_string()));
        assert!(chunk.choices[0].finish_reason.is_none());
    }

    #[test]
    fn test_sse_chunk_tool_call_delta_deserialization() {
        let json = r#"{
            "choices": [{
                "delta": {
                    "tool_calls": [{
                        "index": 0,
                        "id": "call_abc",
                        "function": {
                            "name": "read_file",
                            "arguments": ""
                        }
                    }]
                },
                "finish_reason": null
            }]
        }"#;

        let chunk: SseChunk = serde_json::from_str(json).unwrap();
        let tool_calls = chunk.choices[0].delta.tool_calls.as_ref().unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].index, 0);
        assert_eq!(tool_calls[0].id, Some("call_abc".to_string()));
        assert_eq!(
            tool_calls[0].function.as_ref().unwrap().name,
            Some("read_file".to_string())
        );
    }

    #[test]
    fn test_sse_chunk_tool_call_arguments_delta_deserialization() {
        let json = r#"{
            "choices": [{
                "delta": {
                    "tool_calls": [{
                        "index": 0,
                        "function": {
                            "arguments": "{\"path\":"
                        }
                    }]
                },
                "finish_reason": null
            }]
        }"#;

        let chunk: SseChunk = serde_json::from_str(json).unwrap();
        let tool_calls = chunk.choices[0].delta.tool_calls.as_ref().unwrap();
        assert_eq!(tool_calls[0].id, None);
        assert_eq!(
            tool_calls[0].function.as_ref().unwrap().arguments,
            Some("{\"path\":".to_string())
        );
    }

    #[test]
    fn test_sse_chunk_with_finish_reason_deserialization() {
        let json = r#"{
            "choices": [{
                "delta": {},
                "finish_reason": "stop"
            }]
        }"#;

        let chunk: SseChunk = serde_json::from_str(json).unwrap();
        assert!(matches!(
            chunk.choices[0].finish_reason,
            Some(SseFinishReason::Stop)
        ));
    }

    #[test]
    fn test_sse_chunk_with_usage_deserialization() {
        let json = r#"{
            "choices": [{
                "delta": {},
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": 100,
                "completion_tokens": 50
            }
        }"#;

        let chunk: SseChunk = serde_json::from_str(json).unwrap();
        let usage = chunk.usage.unwrap();
        assert_eq!(usage.prompt_tokens, 100);
        assert_eq!(usage.completion_tokens, 50);
    }

    #[test]
    fn test_sse_finish_reason_deserialization() {
        let stop: SseFinishReason = serde_json::from_str("\"stop\"").unwrap();
        let tool_calls: SseFinishReason = serde_json::from_str("\"tool_calls\"").unwrap();
        let length: SseFinishReason = serde_json::from_str("\"length\"").unwrap();
        let content_filter: SseFinishReason = serde_json::from_str("\"content_filter\"").unwrap();

        assert!(matches!(stop, SseFinishReason::Stop));
        assert!(matches!(tool_calls, SseFinishReason::ToolCalls));
        assert!(matches!(length, SseFinishReason::Length));
        assert!(matches!(content_filter, SseFinishReason::ContentFilter));
    }

    #[test]
    fn test_streaming_request_serialization() {
        let messages = vec![ApiMessage {
            role: ApiRole::User,
            content: Some("Hello".to_string()),
            tool_calls: None,
            tool_call_id: None,
        }];

        let request = ApiChatRequestStreaming {
            model: "gpt-4o",
            messages: &messages,
            max_completion_tokens: Some(1024),
            tools: None,
            stream: true,
        };

        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("\"stream\":true"));
        assert!(json.contains("\"model\":\"gpt-4o\""));
    }
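
    // ===================
    // Stream Helper Tests
    // ===================

    // Added sketches for the private stream helpers above; they construct
    // `Usage` and match `StreamDelta`/`StopReason` variants exactly as the
    // helpers themselves do.

    #[test]
    fn test_apply_tool_call_update_accumulates_arguments() {
        let mut calls = std::collections::HashMap::new();
        apply_tool_call_update(
            &mut calls,
            0,
            Some("call_1".to_string()),
            Some("read_file".to_string()),
            Some("{\"pa".to_string()),
        );
        apply_tool_call_update(&mut calls, 0, None, None, Some("th\":\"x\"}".to_string()));

        let acc = &calls[&0];
        assert_eq!(acc.id, "call_1");
        assert_eq!(acc.name, "read_file");
        assert_eq!(acc.arguments, "{\"path\":\"x\"}");
    }

    #[test]
    fn test_build_stream_end_deltas_emits_tools_usage_and_done() {
        let mut calls = std::collections::HashMap::new();
        calls.insert(
            0,
            ToolCallAccumulator {
                id: "call_1".to_string(),
                name: "read_file".to_string(),
                arguments: "{}".to_string(),
            },
        );

        let deltas = build_stream_end_deltas(
            &calls,
            Some(Usage { input_tokens: 1, output_tokens: 2 }),
            StopReason::ToolUse,
        );

        // One ToolUseStart + one ToolInputDelta per call, then Usage, then Done.
        assert_eq!(deltas.len(), 4);
        assert!(matches!(
            &deltas[0],
            StreamDelta::ToolUseStart { id, block_index, .. } if id == "call_1" && *block_index == 1
        ));
        assert!(matches!(
            deltas.last(),
            Some(StreamDelta::Done { stop_reason: Some(StopReason::ToolUse) })
        ));
    }

    #[test]
    fn test_process_sse_data_sentinel_and_text_delta() {
        assert!(matches!(
            process_sse_data("[DONE]").as_slice(),
            [SseProcessResult::Sentinel]
        ));

        let results =
            process_sse_data(r#"{"choices":[{"delta":{"content":"hi"},"finish_reason":null}]}"#);
        assert!(matches!(
            results.as_slice(),
            [SseProcessResult::TextDelta(t)] if t == "hi"
        ));
    }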
}