
agent_sdk/providers/openai.rs

//! `OpenAI` API provider implementation.
//!
//! This module provides an implementation of `LlmProvider` for the `OpenAI`
//! Chat Completions API. It also supports `OpenAI`-compatible APIs (Ollama, vLLM, etc.)
//! via the `with_base_url` constructor.
//!
//! Models that require the Responses API (like `gpt-5.2-codex`) are automatically
//! routed to the correct endpoint.
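//!
//! # Examples
//!
//! A minimal sketch of a non-streaming call. The crate name (`agent_sdk`) and
//! re-export paths are assumed from this file's location; the `ChatRequest`
//! field values are illustrative.
//!
//! ```rust,no_run
//! # async fn run() -> anyhow::Result<()> {
//! use agent_sdk::llm::{ChatOutcome, ChatRequest, LlmProvider, Message};
//! use agent_sdk::providers::openai::OpenAIProvider;
//!
//! // Pick a model via a factory constructor; `new` accepts any model string.
//! let provider = OpenAIProvider::gpt4o(std::env::var("OPENAI_API_KEY")?);
//!
//! let request = ChatRequest {
//!     system: "You are helpful.".to_string(),
//!     messages: vec![Message::user("Hello")],
//!     tools: None,
//!     max_tokens: 1024,
//!     thinking: None,
//! };
//!
//! // API-level failures (rate limits, 4xx/5xx) come back as non-Success outcomes.
//! if let ChatOutcome::Success(response) = provider.chat(request).await? {
//!     // `response.content` holds the returned text / tool-use blocks.
//!     assert!(!response.content.is_empty());
//! }
//! # Ok(())
//! # }
//! ```
//!
//! Streaming uses the same request type and yields `StreamDelta` values
//! (again a sketch, under the same path assumptions):
//!
//! ```rust,no_run
//! # use agent_sdk::llm::{ChatRequest, LlmProvider, StreamDelta};
//! # use agent_sdk::providers::openai::OpenAIProvider;
//! # use futures::StreamExt;
//! # async fn run(provider: OpenAIProvider, request: ChatRequest) {
//! let mut stream = std::pin::pin!(provider.chat_stream(request));
//! while let Some(delta) = stream.next().await {
//!     if let Ok(StreamDelta::TextDelta { delta, .. }) = delta {
//!         print!("{delta}");
//!     }
//! }
//! # }
//! ```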

use crate::llm::{
    ChatOutcome, ChatRequest, ChatResponse, Content, ContentBlock, LlmProvider, StopReason,
    StreamBox, StreamDelta, Usage,
};
use anyhow::Result;
use async_trait::async_trait;
use futures::StreamExt;
use reqwest::StatusCode;
use serde::{Deserialize, Serialize};

use super::openai_responses::OpenAIResponsesProvider;

const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1";

/// Check if a model requires the Responses API instead of Chat Completions.
fn requires_responses_api(model: &str) -> bool {
    model.contains("codex")
}

// GPT-5.2 series (latest flagship, Dec 2025)
pub const MODEL_GPT52_INSTANT: &str = "gpt-5.2-instant";
pub const MODEL_GPT52_THINKING: &str = "gpt-5.2-thinking";
pub const MODEL_GPT52_PRO: &str = "gpt-5.2-pro";
pub const MODEL_GPT52_CODEX: &str = "gpt-5.2-codex";

// GPT-5 series (400k context)
pub const MODEL_GPT5: &str = "gpt-5";
pub const MODEL_GPT5_MINI: &str = "gpt-5-mini";
pub const MODEL_GPT5_NANO: &str = "gpt-5-nano";

// o-series reasoning models
pub const MODEL_O3: &str = "o3";
pub const MODEL_O3_MINI: &str = "o3-mini";
pub const MODEL_O4_MINI: &str = "o4-mini";
pub const MODEL_O1: &str = "o1";
pub const MODEL_O1_MINI: &str = "o1-mini";

// GPT-4.1 series (improved instruction following, 1M context)
pub const MODEL_GPT41: &str = "gpt-4.1";
pub const MODEL_GPT41_MINI: &str = "gpt-4.1-mini";
pub const MODEL_GPT41_NANO: &str = "gpt-4.1-nano";

// GPT-4o series
pub const MODEL_GPT4O: &str = "gpt-4o";
pub const MODEL_GPT4O_MINI: &str = "gpt-4o-mini";

/// `OpenAI` LLM provider using the Chat Completions API.
///
/// Also supports `OpenAI`-compatible APIs (Ollama, vLLM, Azure `OpenAI`, etc.)
/// via the `with_base_url` constructor.
#[derive(Clone)]
pub struct OpenAIProvider {
    client: reqwest::Client,
    api_key: String,
    model: String,
    base_url: String,
}

impl OpenAIProvider {
    /// Create a new `OpenAI` provider with the specified API key and model.
    #[must_use]
    pub fn new(api_key: String, model: String) -> Self {
        Self {
            client: reqwest::Client::new(),
            api_key,
            model,
            base_url: DEFAULT_BASE_URL.to_owned(),
        }
    }

    /// Create a new provider with a custom base URL for `OpenAI`-compatible APIs.
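    ///
    /// A minimal sketch pointing at a local Ollama server (the URL, model
    /// name, and placeholder key are illustrative; crate paths are assumed
    /// as in the module-level examples):
    ///
    /// ```rust,no_run
    /// use agent_sdk::providers::openai::OpenAIProvider;
    ///
    /// let provider = OpenAIProvider::with_base_url(
    ///     "ollama".to_string(), // many local servers ignore the key, but it must be set
    ///     "llama3".to_string(),
    ///     "http://localhost:11434/v1".to_string(),
    /// );
    /// ```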
    #[must_use]
    pub fn with_base_url(api_key: String, model: String, base_url: String) -> Self {
        Self {
            client: reqwest::Client::new(),
            api_key,
            model,
            base_url,
        }
    }

    /// Create a provider using GPT-5.2 Instant (speed-optimized for routine queries).
    #[must_use]
    pub fn gpt52_instant(api_key: String) -> Self {
        Self::new(api_key, MODEL_GPT52_INSTANT.to_owned())
    }

    /// Create a provider using GPT-5.2 Thinking (complex reasoning, coding, analysis).
    #[must_use]
    pub fn gpt52_thinking(api_key: String) -> Self {
        Self::new(api_key, MODEL_GPT52_THINKING.to_owned())
    }

    /// Create a provider using GPT-5.2 Pro (maximum accuracy for difficult problems).
    #[must_use]
    pub fn gpt52_pro(api_key: String) -> Self {
        Self::new(api_key, MODEL_GPT52_PRO.to_owned())
    }

    /// Create a provider using GPT-5.2 Codex (optimized for agentic coding).
    ///
    /// Note: This model uses the Responses API internally.
    #[must_use]
    pub fn codex(api_key: String) -> Self {
        Self::new(api_key, MODEL_GPT52_CODEX.to_owned())
    }

    /// Create a provider using GPT-5 (400k context, coding and reasoning).
    #[must_use]
    pub fn gpt5(api_key: String) -> Self {
        Self::new(api_key, MODEL_GPT5.to_owned())
    }

    /// Create a provider using GPT-5-mini (faster, cost-efficient GPT-5).
    #[must_use]
    pub fn gpt5_mini(api_key: String) -> Self {
        Self::new(api_key, MODEL_GPT5_MINI.to_owned())
    }

    /// Create a provider using GPT-5-nano (fastest, cheapest GPT-5 variant).
    #[must_use]
    pub fn gpt5_nano(api_key: String) -> Self {
        Self::new(api_key, MODEL_GPT5_NANO.to_owned())
    }

    /// Create a provider using o3 (most intelligent reasoning model).
    #[must_use]
    pub fn o3(api_key: String) -> Self {
        Self::new(api_key, MODEL_O3.to_owned())
    }

    /// Create a provider using o3-mini (smaller o3 variant).
    #[must_use]
    pub fn o3_mini(api_key: String) -> Self {
        Self::new(api_key, MODEL_O3_MINI.to_owned())
    }

    /// Create a provider using o4-mini (fast, cost-efficient reasoning).
    #[must_use]
    pub fn o4_mini(api_key: String) -> Self {
        Self::new(api_key, MODEL_O4_MINI.to_owned())
    }

    /// Create a provider using o1 (reasoning model).
    #[must_use]
    pub fn o1(api_key: String) -> Self {
        Self::new(api_key, MODEL_O1.to_owned())
    }

    /// Create a provider using o1-mini (fast reasoning model).
    #[must_use]
    pub fn o1_mini(api_key: String) -> Self {
        Self::new(api_key, MODEL_O1_MINI.to_owned())
    }

    /// Create a provider using GPT-4.1 (improved instruction following, 1M context).
    #[must_use]
    pub fn gpt41(api_key: String) -> Self {
        Self::new(api_key, MODEL_GPT41.to_owned())
    }

    /// Create a provider using GPT-4.1-mini (smaller, faster GPT-4.1).
    #[must_use]
    pub fn gpt41_mini(api_key: String) -> Self {
        Self::new(api_key, MODEL_GPT41_MINI.to_owned())
    }

    /// Create a provider using GPT-4o.
    #[must_use]
    pub fn gpt4o(api_key: String) -> Self {
        Self::new(api_key, MODEL_GPT4O.to_owned())
    }

    /// Create a provider using GPT-4o-mini (fast and cost-effective).
    #[must_use]
    pub fn gpt4o_mini(api_key: String) -> Self {
        Self::new(api_key, MODEL_GPT4O_MINI.to_owned())
    }
}

#[async_trait]
impl LlmProvider for OpenAIProvider {
    async fn chat(&self, request: ChatRequest) -> Result<ChatOutcome> {
        // Route to Responses API for models that require it (e.g., gpt-5.2-codex)
        if requires_responses_api(&self.model) {
            let responses_provider =
                OpenAIResponsesProvider::new(self.api_key.clone(), self.model.clone());
            return responses_provider.chat(request).await;
        }

        let messages = build_api_messages(&request);
        let tools: Option<Vec<ApiTool>> = request
            .tools
            .map(|ts| ts.into_iter().map(convert_tool).collect());

        let api_request = ApiChatRequest {
            model: &self.model,
            messages: &messages,
            max_completion_tokens: Some(request.max_tokens),
            tools: tools.as_deref(),
        };

        tracing::debug!(
            model = %self.model,
            max_tokens = request.max_tokens,
            "OpenAI LLM request"
        );

        let response = self
            .client
            .post(format!("{}/chat/completions", self.base_url))
            .header("Content-Type", "application/json")
            .header("Authorization", format!("Bearer {}", self.api_key))
            .json(&api_request)
            .send()
            .await
            .map_err(|e| anyhow::anyhow!("request failed: {e}"))?;

        let status = response.status();
        let bytes = response
            .bytes()
            .await
            .map_err(|e| anyhow::anyhow!("failed to read response body: {e}"))?;

        tracing::debug!(
            status = %status,
            body_len = bytes.len(),
            "OpenAI LLM response"
        );

        if status == StatusCode::TOO_MANY_REQUESTS {
            return Ok(ChatOutcome::RateLimited);
        }

        if status.is_server_error() {
            let body = String::from_utf8_lossy(&bytes);
            tracing::error!(status = %status, body = %body, "OpenAI server error");
            return Ok(ChatOutcome::ServerError(body.into_owned()));
        }

        if status.is_client_error() {
            let body = String::from_utf8_lossy(&bytes);
            tracing::warn!(status = %status, body = %body, "OpenAI client error");
            return Ok(ChatOutcome::InvalidRequest(body.into_owned()));
        }

        let api_response: ApiChatResponse = serde_json::from_slice(&bytes)
            .map_err(|e| anyhow::anyhow!("failed to parse response: {e}"))?;

        let choice = api_response
            .choices
            .into_iter()
            .next()
            .ok_or_else(|| anyhow::anyhow!("no choices in response"))?;

        let content = build_content_blocks(&choice.message);

        let stop_reason = choice.finish_reason.map(|r| match r {
            ApiFinishReason::Stop => StopReason::EndTurn,
            ApiFinishReason::ToolCalls => StopReason::ToolUse,
            ApiFinishReason::Length => StopReason::MaxTokens,
            // content_filter has no direct StopReason equivalent; mapped to StopSequence
            ApiFinishReason::ContentFilter => StopReason::StopSequence,
        });

        Ok(ChatOutcome::Success(ChatResponse {
            id: api_response.id,
            content,
            model: api_response.model,
            stop_reason,
            usage: Usage {
                input_tokens: api_response.usage.prompt_tokens,
                output_tokens: api_response.usage.completion_tokens,
            },
        }))
    }

    fn chat_stream(&self, request: ChatRequest) -> StreamBox<'_> {
        // Route to Responses API for models that require it (e.g., gpt-5.2-codex)
        if requires_responses_api(&self.model) {
            let api_key = self.api_key.clone();
            let model = self.model.clone();
            return Box::pin(async_stream::stream! {
                let responses_provider = OpenAIResponsesProvider::new(api_key, model);
                let mut stream = std::pin::pin!(responses_provider.chat_stream(request));
                while let Some(item) = futures::StreamExt::next(&mut stream).await {
                    yield item;
                }
            });
        }

        Box::pin(async_stream::stream! {
            let messages = build_api_messages(&request);
            let tools: Option<Vec<ApiTool>> = request
                .tools
                .map(|ts| ts.into_iter().map(convert_tool).collect());

            let api_request = ApiChatRequestStreaming {
                model: &self.model,
                messages: &messages,
                max_completion_tokens: Some(request.max_tokens),
                tools: tools.as_deref(),
                stream: true,
            };

            tracing::debug!(
                model = %self.model,
                max_tokens = request.max_tokens,
                "OpenAI streaming LLM request"
            );

            let Ok(response) = self.client
                .post(format!("{}/chat/completions", self.base_url))
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {}", self.api_key))
                .json(&api_request)
                .send()
                .await
            else {
                yield Err(anyhow::anyhow!("request failed"));
                return;
            };

            let status = response.status();

            if !status.is_success() {
                let body = response.text().await.unwrap_or_default();
                let (recoverable, level) = if status == StatusCode::TOO_MANY_REQUESTS {
                    (true, "rate_limit")
                } else if status.is_server_error() {
                    (true, "server_error")
                } else {
                    (false, "client_error")
                };
                tracing::warn!(status = %status, body = %body, kind = level, "OpenAI error");
                yield Ok(StreamDelta::Error { message: body, recoverable });
                return;
            }

            // Track tool call state across deltas
            let mut tool_calls: std::collections::HashMap<usize, ToolCallAccumulator> =
                std::collections::HashMap::new();
            let mut usage: Option<Usage> = None;
            let mut buffer = String::new();
            let mut stream = response.bytes_stream();

            while let Some(chunk_result) = stream.next().await {
                let Ok(chunk) = chunk_result else {
                    yield Err(anyhow::anyhow!("stream error: {}", chunk_result.unwrap_err()));
                    return;
                };
                buffer.push_str(&String::from_utf8_lossy(&chunk));

                // Process complete SSE lines; any partial line stays in the buffer
                while let Some(pos) = buffer.find('\n') {
                    let line = buffer[..pos].trim().to_string();
                    buffer = buffer[pos + 1..].to_string();
                    if line.is_empty() {
                        continue;
                    }
                    let Some(data) = line.strip_prefix("data: ") else {
                        continue;
                    };

                    for result in process_sse_data(data) {
                        match result {
                            SseProcessResult::TextDelta(c) => {
                                yield Ok(StreamDelta::TextDelta { delta: c, block_index: 0 });
                            }
                            SseProcessResult::ToolCallUpdate { index, id, name, arguments } => {
                                apply_tool_call_update(&mut tool_calls, index, id, name, arguments);
                            }
                            SseProcessResult::Usage(u) => usage = Some(u),
                            SseProcessResult::Done(sr) => {
                                for d in build_stream_end_deltas(&tool_calls, usage.take(), sr) {
                                    yield Ok(d);
                                }
                                return;
                            }
                            SseProcessResult::Sentinel => {
                                let sr = if tool_calls.is_empty() {
                                    StopReason::EndTurn
                                } else {
                                    StopReason::ToolUse
                                };
                                for d in build_stream_end_deltas(&tool_calls, usage.take(), sr) {
                                    yield Ok(d);
                                }
                                return;
                            }
                        }
                    }
                }
            }

            // Stream ended without [DONE] - emit what we have
            for delta in build_stream_end_deltas(&tool_calls, usage, StopReason::EndTurn) {
                yield Ok(delta);
            }
        })
    }

    fn model(&self) -> &str {
        &self.model
    }

    fn provider(&self) -> &'static str {
        "openai"
    }
}

/// Apply a tool call update to the accumulator.
fn apply_tool_call_update(
    tool_calls: &mut std::collections::HashMap<usize, ToolCallAccumulator>,
    index: usize,
    id: Option<String>,
    name: Option<String>,
    arguments: Option<String>,
) {
    let entry = tool_calls
        .entry(index)
        .or_insert_with(|| ToolCallAccumulator {
            id: String::new(),
            name: String::new(),
            arguments: String::new(),
        });
    if let Some(id) = id {
        entry.id = id;
    }
    if let Some(name) = name {
        entry.name = name;
    }
    if let Some(args) = arguments {
        entry.arguments.push_str(&args);
    }
}

/// Helper to emit tool call deltas and done event.
fn build_stream_end_deltas(
    tool_calls: &std::collections::HashMap<usize, ToolCallAccumulator>,
    usage: Option<Usage>,
    stop_reason: StopReason,
) -> Vec<StreamDelta> {
    let mut deltas = Vec::new();

    // Emit tool calls (text occupies block 0, so tool blocks are offset by 1)
    for (idx, tool) in tool_calls {
        deltas.push(StreamDelta::ToolUseStart {
            id: tool.id.clone(),
            name: tool.name.clone(),
            block_index: *idx + 1,
        });
        deltas.push(StreamDelta::ToolInputDelta {
            id: tool.id.clone(),
            delta: tool.arguments.clone(),
            block_index: *idx + 1,
        });
    }

    // Emit usage
    if let Some(u) = usage {
        deltas.push(StreamDelta::Usage(u));
    }

    // Emit done
    deltas.push(StreamDelta::Done {
        stop_reason: Some(stop_reason),
    });

    deltas
}

/// Result of processing an SSE chunk.
enum SseProcessResult {
    /// Emit a text delta.
    TextDelta(String),
    /// Update the tool call accumulator.
    ToolCallUpdate {
        index: usize,
        id: Option<String>,
        name: Option<String>,
        arguments: Option<String>,
    },
    /// Usage information.
    Usage(Usage),
    /// Stream is done with a stop reason.
    Done(StopReason),
    /// Stream sentinel [DONE] was received.
    Sentinel,
}

/// Process an SSE data line and return results to apply.
fn process_sse_data(data: &str) -> Vec<SseProcessResult> {
    if data == "[DONE]" {
        return vec![SseProcessResult::Sentinel];
    }

    let Ok(chunk) = serde_json::from_str::<SseChunk>(data) else {
        return vec![];
    };

    let mut results = Vec::new();

    // Extract usage if present
    if let Some(u) = chunk.usage {
        results.push(SseProcessResult::Usage(Usage {
            input_tokens: u.prompt_tokens,
            output_tokens: u.completion_tokens,
        }));
    }

    // Process choices
    if let Some(choice) = chunk.choices.into_iter().next() {
        // Handle text content delta
        if let Some(content) = choice.delta.content
            && !content.is_empty()
        {
            results.push(SseProcessResult::TextDelta(content));
        }

        // Handle tool call deltas
        if let Some(tc_deltas) = choice.delta.tool_calls {
            for tc in tc_deltas {
                results.push(SseProcessResult::ToolCallUpdate {
                    index: tc.index,
                    id: tc.id,
                    name: tc.function.as_ref().and_then(|f| f.name.clone()),
                    arguments: tc.function.as_ref().and_then(|f| f.arguments.clone()),
                });
            }
        }

        // Check for finish reason
        if let Some(finish_reason) = choice.finish_reason {
            let stop_reason = match finish_reason {
                SseFinishReason::Stop => StopReason::EndTurn,
                SseFinishReason::ToolCalls => StopReason::ToolUse,
                SseFinishReason::Length => StopReason::MaxTokens,
                // content_filter has no direct StopReason equivalent; mapped to StopSequence
                SseFinishReason::ContentFilter => StopReason::StopSequence,
            };
            results.push(SseProcessResult::Done(stop_reason));
        }
    }

    results
}

/// Convert an SDK `ChatRequest` into OpenAI Chat Completions messages.
fn build_api_messages(request: &ChatRequest) -> Vec<ApiMessage> {
    let mut messages = Vec::new();

    // Add system message first (OpenAI uses a separate message for system prompt)
    if !request.system.is_empty() {
        messages.push(ApiMessage {
            role: ApiRole::System,
            content: Some(request.system.clone()),
            tool_calls: None,
            tool_call_id: None,
        });
    }

    // Convert SDK messages to OpenAI format
    for msg in &request.messages {
        match &msg.content {
            Content::Text(text) => {
                messages.push(ApiMessage {
                    role: match msg.role {
                        crate::llm::Role::User => ApiRole::User,
                        crate::llm::Role::Assistant => ApiRole::Assistant,
                    },
                    content: Some(text.clone()),
                    tool_calls: None,
                    tool_call_id: None,
                });
            }
            Content::Blocks(blocks) => {
                // Handle mixed content blocks
                let mut text_parts = Vec::new();
                let mut tool_calls = Vec::new();

                for block in blocks {
                    match block {
                        ContentBlock::Text { text } => {
                            text_parts.push(text.clone());
                        }
                        ContentBlock::Thinking { .. } => {
                            // Thinking blocks are ephemeral - not sent back to the API
                        }
                        ContentBlock::ToolUse {
                            id, name, input, ..
                        } => {
                            tool_calls.push(ApiToolCall {
                                id: id.clone(),
                                r#type: "function".to_owned(),
                                function: ApiFunctionCall {
                                    name: name.clone(),
                                    arguments: serde_json::to_string(input)
                                        .unwrap_or_else(|_| "{}".to_owned()),
                                },
                            });
                        }
                        ContentBlock::ToolResult {
                            tool_use_id,
                            content,
                            ..
                        } => {
                            // Tool results are separate messages in OpenAI
                            messages.push(ApiMessage {
                                role: ApiRole::Tool,
                                content: Some(content.clone()),
                                tool_calls: None,
                                tool_call_id: Some(tool_use_id.clone()),
                            });
                        }
                    }
                }

                // Add assistant message with text and/or tool calls
                if !text_parts.is_empty() || !tool_calls.is_empty() {
                    let role = match msg.role {
                        crate::llm::Role::User => ApiRole::User,
                        crate::llm::Role::Assistant => ApiRole::Assistant,
                    };

                    // Only add if it's an assistant message or has text content
                    if role == ApiRole::Assistant || !text_parts.is_empty() {
                        messages.push(ApiMessage {
                            role,
                            content: if text_parts.is_empty() {
                                None
                            } else {
                                Some(text_parts.join("\n"))
                            },
                            tool_calls: if tool_calls.is_empty() {
                                None
                            } else {
                                Some(tool_calls)
                            },
                            tool_call_id: None,
                        });
                    }
                }
            }
        }
    }

    messages
}

/// Convert an SDK tool definition into the OpenAI function-tool format.
fn convert_tool(t: crate::llm::Tool) -> ApiTool {
    ApiTool {
        r#type: "function".to_owned(),
        function: ApiFunction {
            name: t.name,
            description: t.description,
            parameters: t.input_schema,
        },
    }
}

/// Convert an API response message into SDK content blocks.
fn build_content_blocks(message: &ApiResponseMessage) -> Vec<ContentBlock> {
    let mut blocks = Vec::new();

    // Add text content if present
    if let Some(content) = &message.content
        && !content.is_empty()
    {
        blocks.push(ContentBlock::Text {
            text: content.clone(),
        });
    }

    // Add tool calls if present
    if let Some(tool_calls) = &message.tool_calls {
        for tc in tool_calls {
            let input: serde_json::Value =
                serde_json::from_str(&tc.function.arguments).unwrap_or(serde_json::Value::Null);
            blocks.push(ContentBlock::ToolUse {
                id: tc.id.clone(),
                name: tc.function.name.clone(),
                input,
                thought_signature: None,
            });
        }
    }

    blocks
}

// ============================================================================
// API Request Types
// ============================================================================

#[derive(Serialize)]
struct ApiChatRequest<'a> {
    model: &'a str,
    messages: &'a [ApiMessage],
    #[serde(skip_serializing_if = "Option::is_none")]
    max_completion_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<&'a [ApiTool]>,
}

#[derive(Serialize)]
struct ApiChatRequestStreaming<'a> {
    model: &'a str,
    messages: &'a [ApiMessage],
    #[serde(skip_serializing_if = "Option::is_none")]
    max_completion_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<&'a [ApiTool]>,
    stream: bool,
}

#[derive(Serialize)]
struct ApiMessage {
    role: ApiRole,
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<ApiToolCall>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}

#[derive(Debug, Serialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
enum ApiRole {
    System,
    User,
    Assistant,
    Tool,
}

#[derive(Serialize)]
struct ApiToolCall {
    id: String,
    r#type: String,
    function: ApiFunctionCall,
}

#[derive(Serialize)]
struct ApiFunctionCall {
    name: String,
    arguments: String,
}

#[derive(Serialize)]
struct ApiTool {
    r#type: String,
    function: ApiFunction,
}

#[derive(Serialize)]
struct ApiFunction {
    name: String,
    description: String,
    parameters: serde_json::Value,
}

// ============================================================================
// API Response Types
// ============================================================================

#[derive(Deserialize)]
struct ApiChatResponse {
    id: String,
    choices: Vec<ApiChoice>,
    model: String,
    usage: ApiUsage,
}

#[derive(Deserialize)]
struct ApiChoice {
    message: ApiResponseMessage,
    finish_reason: Option<ApiFinishReason>,
}

#[derive(Deserialize)]
struct ApiResponseMessage {
    content: Option<String>,
    tool_calls: Option<Vec<ApiResponseToolCall>>,
}

#[derive(Deserialize)]
struct ApiResponseToolCall {
    id: String,
    function: ApiResponseFunctionCall,
}

#[derive(Deserialize)]
struct ApiResponseFunctionCall {
    name: String,
    arguments: String,
}

#[derive(Deserialize)]
#[serde(rename_all = "snake_case")]
enum ApiFinishReason {
    Stop,
    ToolCalls,
    Length,
    ContentFilter,
}

#[derive(Deserialize)]
struct ApiUsage {
    prompt_tokens: u32,
    completion_tokens: u32,
}

// ============================================================================
// SSE Streaming Types
// ============================================================================

/// Accumulator for tool call state across stream deltas.
struct ToolCallAccumulator {
    id: String,
    name: String,
    arguments: String,
}

/// A single chunk in `OpenAI`'s SSE stream.
#[derive(Deserialize)]
struct SseChunk {
    choices: Vec<SseChoice>,
    #[serde(default)]
    usage: Option<SseUsage>,
}

#[derive(Deserialize)]
struct SseChoice {
    delta: SseDelta,
    finish_reason: Option<SseFinishReason>,
}

#[derive(Deserialize)]
struct SseDelta {
    content: Option<String>,
    tool_calls: Option<Vec<SseToolCallDelta>>,
}

#[derive(Deserialize)]
struct SseToolCallDelta {
    index: usize,
    id: Option<String>,
    function: Option<SseFunctionDelta>,
}

#[derive(Deserialize)]
struct SseFunctionDelta {
    name: Option<String>,
    arguments: Option<String>,
}

#[derive(Deserialize)]
#[serde(rename_all = "snake_case")]
enum SseFinishReason {
    Stop,
    ToolCalls,
    Length,
    ContentFilter,
}

#[derive(Deserialize)]
struct SseUsage {
    prompt_tokens: u32,
    completion_tokens: u32,
}

#[cfg(test)]
mod tests {
    use super::*;

    // ===================
    // Constructor Tests
    // ===================

    #[test]
    fn test_new_creates_provider_with_custom_model() {
        let provider = OpenAIProvider::new("test-api-key".to_string(), "custom-model".to_string());

        assert_eq!(provider.model(), "custom-model");
        assert_eq!(provider.provider(), "openai");
        assert_eq!(provider.base_url, DEFAULT_BASE_URL);
    }

    #[test]
    fn test_with_base_url_creates_provider_with_custom_url() {
        let provider = OpenAIProvider::with_base_url(
            "test-api-key".to_string(),
            "llama3".to_string(),
            "http://localhost:11434/v1".to_string(),
        );

        assert_eq!(provider.model(), "llama3");
        assert_eq!(provider.base_url, "http://localhost:11434/v1");
    }

    #[test]
    fn test_gpt4o_factory_creates_gpt4o_provider() {
        let provider = OpenAIProvider::gpt4o("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_GPT4O);
        assert_eq!(provider.provider(), "openai");
    }

    #[test]
    fn test_gpt4o_mini_factory_creates_gpt4o_mini_provider() {
        let provider = OpenAIProvider::gpt4o_mini("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_GPT4O_MINI);
        assert_eq!(provider.provider(), "openai");
    }

    #[test]
    fn test_gpt52_thinking_factory_creates_provider() {
        let provider = OpenAIProvider::gpt52_thinking("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_GPT52_THINKING);
        assert_eq!(provider.provider(), "openai");
    }

    #[test]
    fn test_gpt5_factory_creates_gpt5_provider() {
        let provider = OpenAIProvider::gpt5("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_GPT5);
        assert_eq!(provider.provider(), "openai");
    }

    #[test]
    fn test_gpt5_mini_factory_creates_provider() {
        let provider = OpenAIProvider::gpt5_mini("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_GPT5_MINI);
        assert_eq!(provider.provider(), "openai");
    }

    #[test]
    fn test_o3_factory_creates_o3_provider() {
        let provider = OpenAIProvider::o3("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_O3);
        assert_eq!(provider.provider(), "openai");
    }

    #[test]
    fn test_o4_mini_factory_creates_o4_mini_provider() {
        let provider = OpenAIProvider::o4_mini("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_O4_MINI);
        assert_eq!(provider.provider(), "openai");
    }

    #[test]
    fn test_o1_factory_creates_o1_provider() {
        let provider = OpenAIProvider::o1("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_O1);
        assert_eq!(provider.provider(), "openai");
    }

    #[test]
    fn test_gpt41_factory_creates_gpt41_provider() {
        let provider = OpenAIProvider::gpt41("test-api-key".to_string());

        assert_eq!(provider.model(), MODEL_GPT41);
        assert_eq!(provider.provider(), "openai");
    }

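    // ===================
    // Responses API Routing Tests
    // ===================

    // Sketch checks for the routing predicate defined above, using the model
    // constants from this file.
    #[test]
    fn test_requires_responses_api_detects_codex_models() {
        assert!(requires_responses_api(MODEL_GPT52_CODEX));
        assert!(!requires_responses_api(MODEL_GPT52_THINKING));
        assert!(!requires_responses_api(MODEL_GPT4O));
    }
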
    // ===================
    // Model Constants Tests
    // ===================

    #[test]
    fn test_model_constants_have_expected_values() {
        // GPT-5.2 series
        assert_eq!(MODEL_GPT52_INSTANT, "gpt-5.2-instant");
        assert_eq!(MODEL_GPT52_THINKING, "gpt-5.2-thinking");
        assert_eq!(MODEL_GPT52_PRO, "gpt-5.2-pro");
        assert_eq!(MODEL_GPT52_CODEX, "gpt-5.2-codex");
        // GPT-5 series
        assert_eq!(MODEL_GPT5, "gpt-5");
        assert_eq!(MODEL_GPT5_MINI, "gpt-5-mini");
        assert_eq!(MODEL_GPT5_NANO, "gpt-5-nano");
        // o-series
        assert_eq!(MODEL_O3, "o3");
        assert_eq!(MODEL_O3_MINI, "o3-mini");
        assert_eq!(MODEL_O4_MINI, "o4-mini");
        assert_eq!(MODEL_O1, "o1");
        assert_eq!(MODEL_O1_MINI, "o1-mini");
        // GPT-4.1 series
        assert_eq!(MODEL_GPT41, "gpt-4.1");
        assert_eq!(MODEL_GPT41_MINI, "gpt-4.1-mini");
        assert_eq!(MODEL_GPT41_NANO, "gpt-4.1-nano");
        // GPT-4o series
        assert_eq!(MODEL_GPT4O, "gpt-4o");
        assert_eq!(MODEL_GPT4O_MINI, "gpt-4o-mini");
    }

    // ===================
    // Clone Tests
    // ===================

    #[test]
    fn test_provider_is_cloneable() {
        let provider = OpenAIProvider::new("test-api-key".to_string(), "test-model".to_string());
        let cloned = provider.clone();

        assert_eq!(provider.model(), cloned.model());
        assert_eq!(provider.provider(), cloned.provider());
        assert_eq!(provider.base_url, cloned.base_url);
    }

    // ===================
    // API Type Serialization Tests
    // ===================

    #[test]
    fn test_api_role_serialization() {
        let system_role = ApiRole::System;
        let user_role = ApiRole::User;
        let assistant_role = ApiRole::Assistant;
        let tool_role = ApiRole::Tool;

        assert_eq!(serde_json::to_string(&system_role).unwrap(), "\"system\"");
        assert_eq!(serde_json::to_string(&user_role).unwrap(), "\"user\"");
        assert_eq!(
            serde_json::to_string(&assistant_role).unwrap(),
            "\"assistant\""
        );
        assert_eq!(serde_json::to_string(&tool_role).unwrap(), "\"tool\"");
    }

    #[test]
    fn test_api_message_serialization_simple() {
        let message = ApiMessage {
            role: ApiRole::User,
            content: Some("Hello, world!".to_string()),
            tool_calls: None,
            tool_call_id: None,
        };

        let json = serde_json::to_string(&message).unwrap();
        assert!(json.contains("\"role\":\"user\""));
        assert!(json.contains("\"content\":\"Hello, world!\""));
        // Optional fields should be omitted
        assert!(!json.contains("tool_calls"));
        assert!(!json.contains("tool_call_id"));
    }

    #[test]
    fn test_api_message_serialization_with_tool_calls() {
        let message = ApiMessage {
            role: ApiRole::Assistant,
            content: Some("Let me help.".to_string()),
            tool_calls: Some(vec![ApiToolCall {
                id: "call_123".to_string(),
                r#type: "function".to_string(),
                function: ApiFunctionCall {
                    name: "read_file".to_string(),
                    arguments: "{\"path\": \"/test.txt\"}".to_string(),
                },
            }]),
            tool_call_id: None,
        };

        let json = serde_json::to_string(&message).unwrap();
        assert!(json.contains("\"role\":\"assistant\""));
        assert!(json.contains("\"tool_calls\""));
        assert!(json.contains("\"id\":\"call_123\""));
        assert!(json.contains("\"type\":\"function\""));
        assert!(json.contains("\"name\":\"read_file\""));
    }

    #[test]
    fn test_api_tool_message_serialization() {
        let message = ApiMessage {
            role: ApiRole::Tool,
            content: Some("File contents here".to_string()),
            tool_calls: None,
            tool_call_id: Some("call_123".to_string()),
        };

        let json = serde_json::to_string(&message).unwrap();
        assert!(json.contains("\"role\":\"tool\""));
        assert!(json.contains("\"tool_call_id\":\"call_123\""));
        assert!(json.contains("\"content\":\"File contents here\""));
    }

    #[test]
    fn test_api_tool_serialization() {
        let tool = ApiTool {
            r#type: "function".to_string(),
            function: ApiFunction {
                name: "test_tool".to_string(),
                description: "A test tool".to_string(),
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": {
                        "arg": {"type": "string"}
                    }
                }),
            },
        };

        let json = serde_json::to_string(&tool).unwrap();
        assert!(json.contains("\"type\":\"function\""));
        assert!(json.contains("\"name\":\"test_tool\""));
        assert!(json.contains("\"description\":\"A test tool\""));
        assert!(json.contains("\"parameters\""));
    }

    // ===================
    // API Type Deserialization Tests
    // ===================

    #[test]
    fn test_api_response_deserialization() {
        let json = r#"{
            "id": "chatcmpl-123",
            "choices": [
                {
                    "message": {
                        "content": "Hello!"
                    },
                    "finish_reason": "stop"
                }
            ],
            "model": "gpt-4o",
            "usage": {
                "prompt_tokens": 100,
                "completion_tokens": 50
            }
        }"#;

        let response: ApiChatResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.id, "chatcmpl-123");
        assert_eq!(response.model, "gpt-4o");
        assert_eq!(response.usage.prompt_tokens, 100);
        assert_eq!(response.usage.completion_tokens, 50);
        assert_eq!(response.choices.len(), 1);
        assert_eq!(
            response.choices[0].message.content,
            Some("Hello!".to_string())
        );
    }

    #[test]
    fn test_api_response_with_tool_calls_deserialization() {
        let json = r#"{
            "id": "chatcmpl-456",
            "choices": [
                {
                    "message": {
                        "content": null,
                        "tool_calls": [
                            {
                                "id": "call_abc",
                                "type": "function",
                                "function": {
                                    "name": "read_file",
                                    "arguments": "{\"path\": \"test.txt\"}"
                                }
                            }
                        ]
                    },
                    "finish_reason": "tool_calls"
                }
            ],
            "model": "gpt-4o",
            "usage": {
                "prompt_tokens": 150,
                "completion_tokens": 30
            }
        }"#;

        let response: ApiChatResponse = serde_json::from_str(json).unwrap();
        let tool_calls = response.choices[0].message.tool_calls.as_ref().unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].id, "call_abc");
        assert_eq!(tool_calls[0].function.name, "read_file");
    }

    #[test]
    fn test_api_finish_reason_deserialization() {
        let stop: ApiFinishReason = serde_json::from_str("\"stop\"").unwrap();
        let tool_calls: ApiFinishReason = serde_json::from_str("\"tool_calls\"").unwrap();
        let length: ApiFinishReason = serde_json::from_str("\"length\"").unwrap();
        let content_filter: ApiFinishReason = serde_json::from_str("\"content_filter\"").unwrap();

        assert!(matches!(stop, ApiFinishReason::Stop));
        assert!(matches!(tool_calls, ApiFinishReason::ToolCalls));
        assert!(matches!(length, ApiFinishReason::Length));
        assert!(matches!(content_filter, ApiFinishReason::ContentFilter));
    }

    // ===================
    // Message Conversion Tests
    // ===================

    #[test]
    fn test_build_api_messages_with_system() {
        let request = ChatRequest {
            system: "You are helpful.".to_string(),
            messages: vec![crate::llm::Message::user("Hello")],
            tools: None,
            max_tokens: 1024,
            thinking: None,
        };

        let api_messages = build_api_messages(&request);
        assert_eq!(api_messages.len(), 2);
        assert_eq!(api_messages[0].role, ApiRole::System);
        assert_eq!(
            api_messages[0].content,
            Some("You are helpful.".to_string())
        );
        assert_eq!(api_messages[1].role, ApiRole::User);
        assert_eq!(api_messages[1].content, Some("Hello".to_string()));
    }

    #[test]
    fn test_build_api_messages_empty_system() {
        let request = ChatRequest {
            system: String::new(),
            messages: vec![crate::llm::Message::user("Hello")],
            tools: None,
            max_tokens: 1024,
            thinking: None,
        };

        let api_messages = build_api_messages(&request);
        assert_eq!(api_messages.len(), 1);
        assert_eq!(api_messages[0].role, ApiRole::User);
    }

    #[test]
    fn test_convert_tool() {
        let tool = crate::llm::Tool {
            name: "test_tool".to_string(),
            description: "A test tool".to_string(),
            input_schema: serde_json::json!({"type": "object"}),
        };

        let api_tool = convert_tool(tool);
        assert_eq!(api_tool.r#type, "function");
        assert_eq!(api_tool.function.name, "test_tool");
        assert_eq!(api_tool.function.description, "A test tool");
    }

    #[test]
    fn test_build_content_blocks_text_only() {
        let message = ApiResponseMessage {
            content: Some("Hello!".to_string()),
            tool_calls: None,
        };

        let blocks = build_content_blocks(&message);
        assert_eq!(blocks.len(), 1);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "Hello!"));
    }

    #[test]
    fn test_build_content_blocks_with_tool_calls() {
        let message = ApiResponseMessage {
            content: Some("Let me help.".to_string()),
            tool_calls: Some(vec![ApiResponseToolCall {
                id: "call_123".to_string(),
                function: ApiResponseFunctionCall {
                    name: "read_file".to_string(),
                    arguments: "{\"path\": \"test.txt\"}".to_string(),
                },
            }]),
        };

        let blocks = build_content_blocks(&message);
        assert_eq!(blocks.len(), 2);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "Let me help."));
        assert!(
            matches!(&blocks[1], ContentBlock::ToolUse { id, name, .. } if id == "call_123" && name == "read_file")
        );
    }

    // ===================
    // SSE Streaming Type Tests
    // ===================

    #[test]
    fn test_sse_chunk_text_delta_deserialization() {
        let json = r#"{
            "choices": [{
                "delta": {
                    "content": "Hello"
                },
                "finish_reason": null
            }]
        }"#;

        let chunk: SseChunk = serde_json::from_str(json).unwrap();
        assert_eq!(chunk.choices.len(), 1);
        assert_eq!(chunk.choices[0].delta.content, Some("Hello".to_string()));
        assert!(chunk.choices[0].finish_reason.is_none());
    }

    #[test]
    fn test_sse_chunk_tool_call_delta_deserialization() {
        let json = r#"{
            "choices": [{
                "delta": {
                    "tool_calls": [{
                        "index": 0,
                        "id": "call_abc",
                        "function": {
                            "name": "read_file",
                            "arguments": ""
                        }
                    }]
                },
                "finish_reason": null
            }]
        }"#;

        let chunk: SseChunk = serde_json::from_str(json).unwrap();
        let tool_calls = chunk.choices[0].delta.tool_calls.as_ref().unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].index, 0);
        assert_eq!(tool_calls[0].id, Some("call_abc".to_string()));
        assert_eq!(
            tool_calls[0].function.as_ref().unwrap().name,
            Some("read_file".to_string())
        );
    }

    #[test]
    fn test_sse_chunk_tool_call_arguments_delta_deserialization() {
        let json = r#"{
            "choices": [{
                "delta": {
                    "tool_calls": [{
                        "index": 0,
                        "function": {
                            "arguments": "{\"path\":"
                        }
                    }]
                },
                "finish_reason": null
            }]
        }"#;

        let chunk: SseChunk = serde_json::from_str(json).unwrap();
        let tool_calls = chunk.choices[0].delta.tool_calls.as_ref().unwrap();
        assert_eq!(tool_calls[0].id, None);
        assert_eq!(
            tool_calls[0].function.as_ref().unwrap().arguments,
            Some("{\"path\":".to_string())
        );
    }

    #[test]
    fn test_sse_chunk_with_finish_reason_deserialization() {
        let json = r#"{
            "choices": [{
                "delta": {},
                "finish_reason": "stop"
            }]
        }"#;

        let chunk: SseChunk = serde_json::from_str(json).unwrap();
        assert!(matches!(
            chunk.choices[0].finish_reason,
            Some(SseFinishReason::Stop)
        ));
    }

    #[test]
    fn test_sse_chunk_with_usage_deserialization() {
        let json = r#"{
            "choices": [{
                "delta": {},
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": 100,
                "completion_tokens": 50
            }
        }"#;

        let chunk: SseChunk = serde_json::from_str(json).unwrap();
        let usage = chunk.usage.unwrap();
        assert_eq!(usage.prompt_tokens, 100);
        assert_eq!(usage.completion_tokens, 50);
    }

    #[test]
    fn test_sse_finish_reason_deserialization() {
        let stop: SseFinishReason = serde_json::from_str("\"stop\"").unwrap();
        let tool_calls: SseFinishReason = serde_json::from_str("\"tool_calls\"").unwrap();
        let length: SseFinishReason = serde_json::from_str("\"length\"").unwrap();
        let content_filter: SseFinishReason = serde_json::from_str("\"content_filter\"").unwrap();

        assert!(matches!(stop, SseFinishReason::Stop));
        assert!(matches!(tool_calls, SseFinishReason::ToolCalls));
        assert!(matches!(length, SseFinishReason::Length));
        assert!(matches!(content_filter, SseFinishReason::ContentFilter));
    }

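    // ===================
    // SSE Processing Tests
    // ===================

    // Sketch tests for `process_sse_data`, exercising the text-delta and
    // [DONE]-sentinel paths with hand-written (illustrative) SSE payloads.
    #[test]
    fn test_process_sse_data_text_delta() {
        let data = r#"{"choices":[{"delta":{"content":"Hi"},"finish_reason":null}]}"#;
        let results = process_sse_data(data);
        assert_eq!(results.len(), 1);
        assert!(matches!(&results[0], SseProcessResult::TextDelta(t) if t == "Hi"));
    }

    #[test]
    fn test_process_sse_data_done_sentinel() {
        let results = process_sse_data("[DONE]");
        assert_eq!(results.len(), 1);
        assert!(matches!(results[0], SseProcessResult::Sentinel));
    }
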
    #[test]
    fn test_streaming_request_serialization() {
        let messages = vec![ApiMessage {
            role: ApiRole::User,
            content: Some("Hello".to_string()),
            tool_calls: None,
            tool_call_id: None,
        }];

        let request = ApiChatRequestStreaming {
            model: "gpt-4o",
            messages: &messages,
            max_completion_tokens: Some(1024),
            tools: None,
            stream: true,
        };

        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("\"stream\":true"));
        assert!(json.contains("\"model\":\"gpt-4o\""));
    }
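
    // ===================
    // Stream Accumulator Tests
    // ===================

    // A sketch of the tool-call accumulation flow: partial argument deltas are
    // concatenated, then flushed as ToolUseStart/ToolInputDelta/Done.
    #[test]
    fn test_apply_tool_call_update_accumulates_arguments() {
        let mut calls = std::collections::HashMap::new();
        apply_tool_call_update(
            &mut calls,
            0,
            Some("call_1".to_string()),
            Some("read_file".to_string()),
            Some("{\"path\":".to_string()),
        );
        apply_tool_call_update(&mut calls, 0, None, None, Some(" \"test.txt\"}".to_string()));

        let acc = &calls[&0];
        assert_eq!(acc.id, "call_1");
        assert_eq!(acc.name, "read_file");
        assert_eq!(acc.arguments, "{\"path\": \"test.txt\"}");
    }

    #[test]
    fn test_build_stream_end_deltas_emits_tool_calls_and_done() {
        let mut calls = std::collections::HashMap::new();
        apply_tool_call_update(
            &mut calls,
            0,
            Some("call_1".to_string()),
            Some("read_file".to_string()),
            Some("{}".to_string()),
        );

        let deltas = build_stream_end_deltas(&calls, None, StopReason::ToolUse);
        assert_eq!(deltas.len(), 3);
        assert!(matches!(
            &deltas[0],
            StreamDelta::ToolUseStart { id, block_index: 1, .. } if id == "call_1"
        ));
        assert!(matches!(
            deltas.last(),
            Some(StreamDelta::Done { stop_reason: Some(StopReason::ToolUse) })
        ));
    }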
}