Skip to main content

edgequake_llm/providers/
openai.rs

1//! OpenAI provider implementation.
2//!
3//! Supports OpenAI and OpenAI-compatible APIs (Ollama, LM Studio, etc.)
4
5use async_openai::{
6    config::OpenAIConfig,
7    types::chat::{
8        ChatCompletionMessageToolCall, ChatCompletionMessageToolCalls,
9        ChatCompletionNamedToolChoice, ChatCompletionRequestAssistantMessageArgs,
10        ChatCompletionRequestMessage, ChatCompletionRequestMessageContentPartImage,
11        ChatCompletionRequestMessageContentPartText, ChatCompletionRequestSystemMessageArgs,
12        ChatCompletionRequestToolMessageArgs, ChatCompletionRequestUserMessageArgs,
13        ChatCompletionRequestUserMessageContent, ChatCompletionRequestUserMessageContentPart,
14        ChatCompletionStreamOptions, ChatCompletionTool, ChatCompletionToolChoiceOption,
15        ChatCompletionTools, CompletionUsage, CreateChatCompletionRequestArgs, FinishReason,
16        FunctionCall, FunctionName, FunctionObjectArgs, ImageDetail, ImageUrl, ToolChoiceOptions,
17    },
18    Client,
19};
20use async_trait::async_trait;
21use futures::StreamExt;
22use std::collections::HashMap;
23use tracing::debug;
24
25use crate::error::{LlmError, Result};
26use crate::traits::FunctionCall as TraitFunctionCall;
27use crate::traits::ImageData;
28use crate::traits::ToolCall;
29use crate::traits::{
30    ChatMessage, ChatRole, CompletionOptions, EmbeddingProvider, LLMProvider, LLMResponse,
31    StreamChunk, StreamUsage, ToolChoice, ToolDefinition,
32};
33
/// OpenAI provider for text completion and embeddings.
///
/// # Issue #164: Lenient Embedding Deserialization
///
/// Uses a manual HTTP client for embeddings instead of async-openai's strict
/// types. This supports HuggingFace TEI and other OpenAI-compatible servers
/// that omit the cosmetic `object` field in embedding responses.
pub struct OpenAIProvider {
    /// async-openai client used for chat-completion requests.
    client: Client<OpenAIConfig>,
    /// Completion model name sent with every chat request.
    model: String,
    /// Embedding model name.
    embedding_model: String,
    /// Context window (in tokens) assumed for `model`; refreshed by `with_model`.
    max_context_length: usize,
    /// Vector dimension assumed for `embedding_model`; refreshed by
    /// `with_embedding_model`.
    embedding_dimension: usize,
    /// Raw API key for manual HTTP calls (embedding fallback).
    raw_api_key: String,
    /// Base URL for the API (empty = default OpenAI).
    raw_base_url: String,
}
52
53impl OpenAIProvider {
54    /// Create a new OpenAI provider with the given API key.
55    pub fn new(api_key: impl Into<String>) -> Self {
56        let key = api_key.into();
57        let config = OpenAIConfig::new().with_api_key(&key);
58        Self::with_config_and_key(config, key, String::new())
59    }
60
61    /// Create a provider with custom configuration.
62    /// Defaults to GPT-5-mini for best balance of performance and cost.
63    pub fn with_config(config: OpenAIConfig) -> Self {
64        // Extract key/url from config — not directly accessible, so use empty defaults.
65        // Callers that need lenient embedding should use with_config_and_key().
66        Self::with_config_and_key(config, String::new(), String::new())
67    }
68
69    /// Create a provider with custom configuration and explicit key/base_url for embedding fallback.
70    fn with_config_and_key(config: OpenAIConfig, api_key: String, base_url: String) -> Self {
71        Self {
72            client: Client::with_config(config),
73            model: "gpt-5-mini".to_string(),
74            embedding_model: "text-embedding-3-small".to_string(),
75            max_context_length: 200000,
76            embedding_dimension: 1536,
77            raw_api_key: api_key,
78            raw_base_url: base_url,
79        }
80    }
81
82    /// Create a provider for an OpenAI-compatible API.
83    pub fn compatible(api_key: impl Into<String>, base_url: impl Into<String>) -> Self {
84        let key = api_key.into();
85        let url = base_url.into();
86        let config = OpenAIConfig::new().with_api_key(&key).with_api_base(&url);
87        Self::with_config_and_key(config, key, url)
88    }
89
90    /// Create from environment variables.
91    ///
92    /// Loads `.env` first (dotenvy). Then reads:
93    /// - **Required:** `OPENAI_API_KEY`
94    /// - **Optional:** `OPENAI_MODEL` (default: `gpt-5-mini`)
95    /// - **Optional:** `OPENAI_BASE_URL` — for compatible APIs
96    ///
97    /// ```no_run
98    /// use edgequake_llm::OpenAIProvider;
99    /// let provider = OpenAIProvider::from_env().unwrap();
100    /// ```
101    pub fn from_env() -> crate::error::Result<Self> {
102        let _ = dotenvy::dotenv();
103        let api_key = std::env::var("OPENAI_API_KEY")
104            .map_err(|_| crate::error::LlmError::ConfigError("OPENAI_API_KEY not set".into()))?;
105        let base_url = std::env::var("OPENAI_BASE_URL").unwrap_or_default();
106        let mut config = OpenAIConfig::new().with_api_key(&api_key);
107        if !base_url.is_empty() {
108            config = config.with_api_base(&base_url);
109        }
110        let mut provider = Self::with_config_and_key(config, api_key, base_url);
111        if let Ok(model) = std::env::var("OPENAI_MODEL") {
112            provider = provider.with_model(model);
113        }
114        Ok(provider)
115    }
116
117    /// Set the completion model.
118    pub fn with_model(mut self, model: impl Into<String>) -> Self {
119        self.model = model.into();
120        self.max_context_length = Self::context_length_for_model(&self.model);
121        self
122    }
123
124    /// Set the embedding model.
125    pub fn with_embedding_model(mut self, model: impl Into<String>) -> Self {
126        self.embedding_model = model.into();
127        self.embedding_dimension = Self::dimension_for_model(&self.embedding_model);
128        self
129    }
130
131    /// Get the context length for a model.
132    fn context_length_for_model(model: &str) -> usize {
133        match model {
134            // GPT-5 series (2026 models)
135            m if m.contains("gpt-5.2") || m.contains("gpt-5.1") => 200000,
136            m if m.contains("gpt-5-nano") => 128000,
137            m if m.contains("gpt-5-mini") || m.contains("gpt-5") => 200000,
138
139            // GPT-4.1 series
140            m if m.contains("gpt-4.1") => 128000,
141
142            // O-series reasoning models
143            m if m.contains("o4") || m.contains("o3") => 200000,
144            m if m.contains("o1") => 200000,
145
146            // GPT-4 series
147            m if m.contains("gpt-4o") => 128000,
148            m if m.contains("gpt-4-turbo") => 128000,
149            m if m.contains("gpt-4-32k") => 32768,
150            m if m.contains("gpt-4") => 8192,
151
152            // GPT-3.5 series
153            m if m.contains("gpt-3.5-turbo-16k") => 16384,
154            m if m.contains("gpt-3.5") => 4096,
155
156            // Codex models
157            m if m.contains("codex") => 200000,
158
159            // Realtime and audio models
160            m if m.contains("gpt-realtime") || m.contains("gpt-audio") => 128000,
161
162            _ => 128000, // Updated default for newer models
163        }
164    }
165
166    /// Get the embedding dimension for a model.
167    fn dimension_for_model(model: &str) -> usize {
168        match model {
169            m if m.contains("text-embedding-3-large") => 3072,
170            m if m.contains("text-embedding-3-small") => 1536,
171            m if m.contains("text-embedding-ada") => 1536,
172            _ => 1536, // Default
173        }
174    }
175
176    fn extract_usage(
177        usage: Option<CompletionUsage>,
178    ) -> (usize, usize, usize, Option<usize>, Option<usize>) {
179        let usage = usage.unwrap_or(CompletionUsage {
180            prompt_tokens: 0,
181            completion_tokens: 0,
182            total_tokens: 0,
183            prompt_tokens_details: None,
184            completion_tokens_details: None,
185        });
186
187        let cache_hit_tokens = usage
188            .prompt_tokens_details
189            .as_ref()
190            .and_then(|d| d.cached_tokens)
191            .map(|t| t as usize);
192        let thinking_tokens = usage
193            .completion_tokens_details
194            .as_ref()
195            .and_then(|d| d.reasoning_tokens)
196            .map(|t| t as usize);
197
198        (
199            usage.prompt_tokens as usize,
200            usage.completion_tokens as usize,
201            usage.total_tokens as usize,
202            cache_hit_tokens,
203            thinking_tokens,
204        )
205    }
206
207    fn extract_stream_usage(usage: Option<CompletionUsage>) -> Option<StreamUsage> {
208        let (prompt_tokens, completion_tokens, _total_tokens, cache_hit_tokens, thinking_tokens) =
209            Self::extract_usage(usage);
210
211        if prompt_tokens == 0
212            && completion_tokens == 0
213            && cache_hit_tokens.is_none()
214            && thinking_tokens.is_none()
215        {
216            return None;
217        }
218
219        let mut usage = StreamUsage::new(prompt_tokens, completion_tokens);
220        if let Some(tokens) = cache_hit_tokens {
221            usage = usage.with_cache_hit_tokens(tokens);
222        }
223        if let Some(tokens) = thinking_tokens {
224            usage = usage.with_thinking_tokens(tokens);
225        }
226        Some(usage)
227    }
228
229    /// Convert chat messages to OpenAI format.
230    ///
231    /// Per the OpenAI API spec:
232    /// - `system` → `ChatCompletionRequestSystemMessage`
233    /// - `user`   → `ChatCompletionRequestUserMessage` (supports multimodal content)
234    /// - `assistant` → `ChatCompletionRequestAssistantMessage`; must carry `tool_calls`
235    ///   when the model previously requested function calls so the conversation history
236    ///   is replayed correctly.
237    /// - `tool`   → `ChatCompletionRequestToolMessage` with `tool_call_id`; sending
238    ///   tool results as `user` messages is an API contract violation that yields 400.
239    /// - `function` → legacy `ChatCompletionRequestFunctionMessage` (deprecated by OpenAI).
240    fn convert_messages(messages: &[ChatMessage]) -> Result<Vec<ChatCompletionRequestMessage>> {
241        messages
242            .iter()
243            .map(|msg| {
244                match msg.role {
245                    ChatRole::System => ChatCompletionRequestSystemMessageArgs::default()
246                        .content(msg.content.as_str())
247                        .build()
248                        .map(Into::into)
249                        .map_err(|e| LlmError::InvalidRequest(e.to_string())),
250
251                    ChatRole::User => {
252                        let content = Self::build_user_content(msg);
253                        ChatCompletionRequestUserMessageArgs::default()
254                            .content(content)
255                            .build()
256                            .map(Into::into)
257                            .map_err(|e| LlmError::InvalidRequest(e.to_string()))
258                    }
259
260                    ChatRole::Assistant => {
261                        let mut builder = ChatCompletionRequestAssistantMessageArgs::default();
262                        // Content may be empty when the assistant only emits tool calls.
263                        if !msg.content.is_empty() {
264                            builder.content(msg.content.clone());
265                        }
266                        // Propagate tool_calls so the replayed conversation includes
267                        // the model's previous function-calling requests.
268                        if let Some(ref tool_calls) = msg.tool_calls {
269                            let openai_calls: Vec<ChatCompletionMessageToolCalls> = tool_calls
270                                .iter()
271                                .map(|tc| {
272                                    ChatCompletionMessageToolCalls::Function(
273                                        ChatCompletionMessageToolCall {
274                                            id: tc.id.clone(),
275                                            function: FunctionCall {
276                                                name: tc.function.name.clone(),
277                                                arguments: tc.function.arguments.clone(),
278                                            },
279                                        },
280                                    )
281                                })
282                                .collect();
283                            builder.tool_calls(openai_calls);
284                        }
285                        builder
286                            .build()
287                            .map(Into::into)
288                            .map_err(|e| LlmError::InvalidRequest(e.to_string()))
289                    }
290
291                    ChatRole::Tool => {
292                        // The API requires role=tool with the matching tool_call_id.
293                        // Sending as a user message loses the correlation identifier and
294                        // causes the model to error or misinterpret its own history.
295                        let tool_call_id = msg.tool_call_id.clone().ok_or_else(|| {
296                            LlmError::InvalidRequest(
297                                "Tool message missing required tool_call_id".into(),
298                            )
299                        })?;
300                        ChatCompletionRequestToolMessageArgs::default()
301                            .content(msg.content.clone())
302                            .tool_call_id(tool_call_id)
303                            .build()
304                            .map(Into::into)
305                            .map_err(|e| LlmError::InvalidRequest(e.to_string()))
306                    }
307
308                    ChatRole::Function => {
309                        // Deprecated by OpenAI in favour of the `tool` role.
310                        // Keep as user message for backward compatibility with older callers.
311                        ChatCompletionRequestUserMessageArgs::default()
312                            .content(msg.content.as_str())
313                            .build()
314                            .map(Into::into)
315                            .map_err(|e| LlmError::InvalidRequest(e.to_string()))
316                    }
317                }
318            })
319            .collect()
320    }
321
322    /// Build user message content, supporting multimodal (text + images).
323    ///
324    /// WHY: Vision-capable OpenAI models (gpt-4o, gpt-4-vision-preview, etc.) require
325    /// content to be an array of typed parts when images are present. This function
326    /// detects image presence and builds the appropriate content representation.
327    fn build_user_content(msg: &ChatMessage) -> ChatCompletionRequestUserMessageContent {
328        if msg.has_images() {
329            let mut parts: Vec<ChatCompletionRequestUserMessageContentPart> = Vec::new();
330
331            // Add text part first (if non-empty)
332            if !msg.content.is_empty() {
333                parts.push(ChatCompletionRequestUserMessageContentPart::Text(
334                    ChatCompletionRequestMessageContentPartText {
335                        text: msg.content.clone(),
336                    },
337                ));
338            }
339
340            // Add image parts
341            if let Some(ref images) = msg.images {
342                for img in images {
343                    let detail = Self::parse_image_detail(img);
344                    parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(
345                        ChatCompletionRequestMessageContentPartImage {
346                            image_url: ImageUrl {
347                                // Pass URL directly for URL images; wrap base64 in data URI otherwise.
348                                url: img.to_api_url(),
349                                detail,
350                            },
351                        },
352                    ));
353                }
354            }
355
356            ChatCompletionRequestUserMessageContent::Array(parts)
357        } else {
358            ChatCompletionRequestUserMessageContent::Text(msg.content.clone())
359        }
360    }
361
362    /// Parse image detail level from ImageData.
363    fn parse_image_detail(img: &ImageData) -> Option<ImageDetail> {
364        match img.detail.as_deref() {
365            Some("low") => Some(ImageDetail::Low),
366            Some("high") => Some(ImageDetail::High),
367            Some("auto") => Some(ImageDetail::Auto),
368            _ => None,
369        }
370    }
371}
372
373#[async_trait]
374impl LLMProvider for OpenAIProvider {
375    fn name(&self) -> &str {
376        "openai"
377    }
378
379    fn model(&self) -> &str {
380        &self.model
381    }
382
383    fn max_context_length(&self) -> usize {
384        self.max_context_length
385    }
386
387    async fn complete(&self, prompt: &str) -> Result<LLMResponse> {
388        self.complete_with_options(prompt, &CompletionOptions::default())
389            .await
390    }
391
392    async fn complete_with_options(
393        &self,
394        prompt: &str,
395        options: &CompletionOptions,
396    ) -> Result<LLMResponse> {
397        let mut messages = Vec::new();
398
399        if let Some(system) = &options.system_prompt {
400            messages.push(ChatMessage::system(system));
401        }
402        messages.push(ChatMessage::user(prompt));
403
404        self.chat(&messages, Some(options)).await
405    }
406
407    async fn chat(
408        &self,
409        messages: &[ChatMessage],
410        options: Option<&CompletionOptions>,
411    ) -> Result<LLMResponse> {
412        let openai_messages = Self::convert_messages(messages)?;
413        let options = options.cloned().unwrap_or_default();
414
415        let mut request_builder = CreateChatCompletionRequestArgs::default();
416        request_builder.model(&self.model).messages(openai_messages);
417
418        if let Some(max_tokens) = options.max_tokens {
419            // Use max_completion_tokens (the modern, universal parameter).
420            // async-openai 0.33 supports this natively for all models including
421            // o1/o3/o4 and gpt-4.1 families that previously rejected max_tokens.
422            request_builder.max_completion_tokens(max_tokens as u32);
423        }
424
425        if let Some(temp) = options.temperature {
426            // Bug #15: gpt-4.1-nano, o1, o4-mini only accept the default temperature (1.0).
427            // Skip setting temperature when it equals the model default to avoid 400 errors.
428            if (temp - 1.0_f32).abs() > f32::EPSILON {
429                request_builder.temperature(temp);
430            }
431        }
432
433        if let Some(top_p) = options.top_p {
434            request_builder.top_p(top_p);
435        }
436
437        if let Some(stop) = options.stop {
438            request_builder.stop(stop);
439        }
440
441        if let Some(freq_penalty) = options.frequency_penalty {
442            request_builder.frequency_penalty(freq_penalty);
443        }
444
445        if let Some(pres_penalty) = options.presence_penalty {
446            request_builder.presence_penalty(pres_penalty);
447        }
448
449        let request = request_builder
450            .build()
451            .map_err(|e| LlmError::InvalidRequest(e.to_string()))?;
452
453        let response = self.client.chat().create(request).await?;
454
455        // Debug logging for token tracking
456        debug!(
457            "OpenAI response - usage: {:?}, model: {}",
458            response.usage, response.model
459        );
460
461        let choice = response
462            .choices
463            .first()
464            .ok_or_else(|| LlmError::ApiError("No choices in response".to_string()))?;
465
466        // Guardrail: surface content-filter as an explicit error.
467        if let Some(FinishReason::ContentFilter) = choice.finish_reason {
468            return Err(LlmError::ApiError(
469                "Response blocked by OpenAI content filter (finish_reason=content_filter)".into(),
470            ));
471        }
472
473        let content = choice.message.content.clone().unwrap_or_default();
474
475        let (prompt_tokens, completion_tokens, total_tokens, cache_hit_tokens, thinking_tokens) =
476            Self::extract_usage(response.usage.clone());
477
478        // Log extracted token counts
479        debug!(
480            "OpenAI token usage - prompt: {}, completion: {}, total: {}, cached: {:?}, reasoning: {:?}",
481            prompt_tokens, completion_tokens, total_tokens,
482            cache_hit_tokens, thinking_tokens
483        );
484
485        let mut metadata = HashMap::new();
486        metadata.insert("response_id".to_string(), serde_json::json!(response.id));
487
488        Ok(LLMResponse {
489            content,
490            prompt_tokens,
491            completion_tokens,
492            total_tokens,
493            model: response.model,
494            finish_reason: choice.finish_reason.map(|r| format!("{:?}", r)),
495            tool_calls: Vec::new(),
496            metadata,
497            cache_hit_tokens,
498            cache_write_tokens: None,
499            thinking_tokens,
500            thinking_content: None,
501        })
502    }
503
504    /// Non-streaming chat with tool/function-calling support.
505    ///
506    /// Implements the full OpenAI tool-use contract:
507    /// 1. Sends tools + tool_choice in the request.
508    /// 2. Extracts `tool_calls` from the assistant response (if any).
509    /// 3. Returns them in `LLMResponse::tool_calls` so the caller can execute
510    ///    them and replay the history via `ChatMessage::tool_result(...)`.
511    async fn chat_with_tools(
512        &self,
513        messages: &[ChatMessage],
514        tools: &[ToolDefinition],
515        tool_choice: Option<ToolChoice>,
516        options: Option<&CompletionOptions>,
517    ) -> Result<LLMResponse> {
518        let openai_messages = Self::convert_messages(messages)?;
519        let opts = options.cloned().unwrap_or_default();
520
521        let openai_tools: Vec<ChatCompletionTools> = tools
522            .iter()
523            .map(|t| {
524                ChatCompletionTools::Function(ChatCompletionTool {
525                    function: FunctionObjectArgs::default()
526                        .name(&t.function.name)
527                        .description(&t.function.description)
528                        .parameters(t.function.parameters.clone())
529                        .build()
530                        .expect("Invalid tool definition"),
531                })
532            })
533            .collect();
534
535        let mut request_builder = CreateChatCompletionRequestArgs::default();
536        request_builder
537            .model(&self.model)
538            .messages(openai_messages)
539            .tools(openai_tools);
540
541        if let Some(tc) = tool_choice {
542            match tc {
543                ToolChoice::Auto(_) => {
544                    request_builder.tool_choice(ChatCompletionToolChoiceOption::Mode(
545                        ToolChoiceOptions::Auto,
546                    ));
547                }
548                ToolChoice::Required(_) => {
549                    request_builder.tool_choice(ChatCompletionToolChoiceOption::Mode(
550                        ToolChoiceOptions::Required,
551                    ));
552                }
553                ToolChoice::Function { ref function, .. } => {
554                    request_builder.tool_choice(ChatCompletionToolChoiceOption::Function(
555                        ChatCompletionNamedToolChoice {
556                            function: FunctionName {
557                                name: function.name.clone(),
558                            },
559                        },
560                    ));
561                }
562            }
563        }
564
565        if let Some(max_tokens) = opts.max_tokens {
566            request_builder.max_completion_tokens(max_tokens as u32);
567        }
568
569        if let Some(temp) = opts.temperature {
570            // Skip temperature=1.0 for strict-mode models (o1/o3/o4) which reject it.
571            if (temp - 1.0_f32).abs() > f32::EPSILON {
572                request_builder.temperature(temp);
573            }
574        }
575
576        let request = request_builder
577            .build()
578            .map_err(|e| LlmError::InvalidRequest(e.to_string()))?;
579
580        let response = self.client.chat().create(request).await?;
581
582        debug!(
583            "OpenAI chat_with_tools response id={} model={}",
584            response.id, response.model
585        );
586
587        let choice = response
588            .choices
589            .first()
590            .ok_or_else(|| LlmError::ApiError("No choices in response".to_string()))?;
591
592        if let Some(FinishReason::ContentFilter) = choice.finish_reason {
593            return Err(LlmError::ApiError(
594                "Response blocked by OpenAI content filter (finish_reason=content_filter)".into(),
595            ));
596        }
597
598        // Extract tool calls from the assistant message.
599        let tool_calls: Vec<ToolCall> = choice
600            .message
601            .tool_calls
602            .as_deref()
603            .unwrap_or_default()
604            .iter()
605            .filter_map(|tc| {
606                if let ChatCompletionMessageToolCalls::Function(f) = tc {
607                    Some(ToolCall {
608                        id: f.id.clone(),
609                        call_type: "function".to_string(),
610                        function: TraitFunctionCall {
611                            name: f.function.name.clone(),
612                            arguments: f.function.arguments.clone(),
613                        },
614                        thought_signature: None,
615                    })
616                } else {
617                    None
618                }
619            })
620            .collect();
621
622        let content = choice.message.content.clone().unwrap_or_default();
623
624        let (prompt_tokens, completion_tokens, total_tokens, cache_hit_tokens, thinking_tokens) =
625            Self::extract_usage(response.usage.clone());
626
627        let mut metadata = HashMap::new();
628        metadata.insert("response_id".to_string(), serde_json::json!(response.id));
629
630        Ok(LLMResponse {
631            content,
632            prompt_tokens,
633            completion_tokens,
634            total_tokens,
635            model: response.model,
636            finish_reason: choice.finish_reason.map(|r| format!("{:?}", r)),
637            tool_calls,
638            metadata,
639            cache_hit_tokens,
640            cache_write_tokens: None,
641            thinking_tokens,
642            thinking_content: None,
643        })
644    }
645
646    fn supports_function_calling(&self) -> bool {
647        true
648    }
649
650    async fn stream(
651        &self,
652        prompt: &str,
653    ) -> Result<futures::stream::BoxStream<'static, Result<String>>> {
654        let request = ChatCompletionRequestUserMessageArgs::default()
655            .content(prompt)
656            .build()
657            .map(Into::into)
658            .map_err(|e| LlmError::InvalidRequest(e.to_string()))?;
659
660        let request = CreateChatCompletionRequestArgs::default()
661            .model(&self.model)
662            .messages(vec![request])
663            .stream(true)
664            .build()
665            .map_err(|e| LlmError::InvalidRequest(e.to_string()))?;
666
667        let stream = self.client.chat().create_stream(request).await?;
668
669        let mapped_stream = stream.map(|res| match res {
670            Ok(response) => {
671                let content = response
672                    .choices
673                    .first()
674                    .and_then(|c| c.delta.content.clone())
675                    .unwrap_or_default();
676                Ok(content)
677            }
678            // IMPORTANT: Use LlmError::from(e) instead of LlmError::ApiError(e.to_string()).
679            // e.to_string() uses ApiError's Display which formats as "{type}: {message}"
680            // (e.g. "tokens: Rate limit reached..."), masking the structured code/type
681            // fields needed for correct RateLimited classification. From<OpenAIError>
682            // inspects api_err.code and api_err.message to produce the right variant.
683            Err(e) => Err(LlmError::from(e)),
684        });
685
686        Ok(mapped_stream.boxed())
687    }
688
689    fn supports_streaming(&self) -> bool {
690        true
691    }
692
693    async fn chat_with_tools_stream(
694        &self,
695        messages: &[ChatMessage],
696        tools: &[ToolDefinition],
697        tool_choice: Option<ToolChoice>,
698        options: Option<&CompletionOptions>,
699    ) -> Result<futures::stream::BoxStream<'static, Result<StreamChunk>>> {
700        let openai_messages = Self::convert_messages(messages)?;
701        let options = options.cloned().unwrap_or_default();
702
703        // Convert tools to OpenAI 0.33 format.
704        // ChatCompletionTools is now an enum: Function(ChatCompletionTool) or Custom(...)
705        let openai_tools: Vec<ChatCompletionTools> = tools
706            .iter()
707            .map(|tool| {
708                ChatCompletionTools::Function(ChatCompletionTool {
709                    function: FunctionObjectArgs::default()
710                        .name(&tool.function.name)
711                        .description(&tool.function.description)
712                        .parameters(tool.function.parameters.clone())
713                        .build()
714                        .expect("Invalid tool definition"),
715                })
716            })
717            .collect();
718
719        // Build request
720        let mut request_builder = CreateChatCompletionRequestArgs::default();
721        request_builder
722            .model(&self.model)
723            .messages(openai_messages)
724            .tools(openai_tools)
725            .stream(true)
726            .stream_options(ChatCompletionStreamOptions {
727                include_usage: Some(true),
728                include_obfuscation: None,
729            }); // Enable streaming and final usage
730
731        // Set tool choice if specified.
732        // In async-openai 0.33, Auto/Required are ChatCompletionToolChoiceOption::Mode(...)
733        if let Some(tc) = tool_choice {
734            match tc {
735                ToolChoice::Auto(_) => {
736                    request_builder.tool_choice(ChatCompletionToolChoiceOption::Mode(
737                        ToolChoiceOptions::Auto,
738                    ));
739                }
740                ToolChoice::Required(_) => {
741                    request_builder.tool_choice(ChatCompletionToolChoiceOption::Mode(
742                        ToolChoiceOptions::Required,
743                    ));
744                }
745                ToolChoice::Function { ref function, .. } => {
746                    // Force a specific function: {"type":"function","function":{"name":"..."}}
747                    request_builder.tool_choice(ChatCompletionToolChoiceOption::Function(
748                        ChatCompletionNamedToolChoice {
749                            function: FunctionName {
750                                name: function.name.clone(),
751                            },
752                        },
753                    ));
754                }
755            }
756        }
757
758        if let Some(temp) = options.temperature {
759            // Bug #15: gpt-4.1-nano, o1, o4-mini only accept the default temperature (1.0).
760            // Skip setting temperature when it equals the model default to avoid 400 errors.
761            if (temp - 1.0_f32).abs() > f32::EPSILON {
762                request_builder.temperature(temp);
763            }
764        }
765
766        if let Some(max_tokens) = options.max_tokens {
767            // Use max_completion_tokens (the modern, universal parameter).
768            request_builder.max_completion_tokens(max_tokens as u32);
769        }
770
771        let request = request_builder
772            .build()
773            .map_err(|e| LlmError::InvalidRequest(e.to_string()))?;
774
775        let stream = self.client.chat().create_stream(request).await?;
776
777        // Map OpenAI stream to our StreamChunk format
778        let mapped_stream = stream.map(|result| {
779            match result {
780                Ok(response) => {
781                    let stream_usage = Self::extract_stream_usage(response.usage.clone());
782
783                    let choice = response.choices.first();
784                    if let Some(choice) = choice {
785                        // Stream content chunks immediately
786                        if let Some(content) = &choice.delta.content {
787                            return Ok(StreamChunk::Content(content.clone()));
788                        }
789
790                        // Return tool call delta (agent will accumulate)
791                        if let Some(tool_call_chunks) = &choice.delta.tool_calls {
792                            if let Some(chunk) = tool_call_chunks.first() {
793                                return Ok(StreamChunk::ToolCallDelta {
794                                    index: chunk.index as usize,
795                                    id: chunk.id.clone(),
796                                    function_name: chunk
797                                        .function
798                                        .as_ref()
799                                        .and_then(|f| f.name.clone()),
800                                    function_arguments: chunk
801                                        .function
802                                        .as_ref()
803                                        .and_then(|f| f.arguments.clone()),
804                                    thought_signature: None,
805                                });
806                            }
807                        }
808
809                        // Check finish reason
810                        if let Some(finish_reason) = &choice.finish_reason {
811                            let reason = match finish_reason {
812                                FinishReason::Stop => "stop",
813                                FinishReason::Length => "length",
814                                FinishReason::ToolCalls => "tool_calls",
815                                FinishReason::ContentFilter => "content_filter",
816                                FinishReason::FunctionCall => "function_call",
817                            };
818                            return Ok(StreamChunk::Finished {
819                                reason: reason.to_string(),
820                                ttft_ms: None,
821                                usage: stream_usage,
822                            });
823                        }
824                    }
825                    if stream_usage.is_some() {
826                        return Ok(StreamChunk::Finished {
827                            reason: "stop".to_string(),
828                            ttft_ms: None,
829                            usage: stream_usage,
830                        });
831                    }
832                    // Empty chunk (no content or tool calls)
833                    Ok(StreamChunk::Content(String::new()))
834                }
835                // IMPORTANT: Propagate through From<OpenAIError> so that rate-limit
836                // errors (code: rate_limit_exceeded) are classified as RateLimited,
837                // not ApiError. Using e.to_string() loses the structured code field.
838                Err(e) => Err(LlmError::from(e)),
839            }
840        });
841
842        Ok(mapped_stream.boxed())
843    }
844
    /// Reports whether this provider streams tool-call deltas; it always returns `true`.
    fn supports_tool_streaming(&self) -> bool {
        true
    }
848
849    fn supports_json_mode(&self) -> bool {
850        // All modern OpenAI chat models support JSON object response format.
851        // Exclude only legacy completions models (text-davinci etc).
852        let m = &self.model;
853        m.contains("gpt-4")
854            || m.contains("gpt-3.5-turbo")
855            || m.contains("gpt-5")
856            || m.starts_with("o1")
857            || m.starts_with("o3")
858            || m.starts_with("o4")
859    }
860}
861
862#[async_trait]
863impl EmbeddingProvider for OpenAIProvider {
864    fn name(&self) -> &str {
865        "openai"
866    }
867
868    /// Returns the embedding model name (not completion model).
869    ///
870    /// Note: `model` field refers to completion, `embedding_model` is for embeddings.
871    #[allow(clippy::misnamed_getters)]
872    fn model(&self) -> &str {
873        &self.embedding_model
874    }
875
876    fn dimension(&self) -> usize {
877        self.embedding_dimension
878    }
879
880    fn max_tokens(&self) -> usize {
881        8191 // OpenAI embedding models support 8191 tokens
882    }
883
884    async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
885        if texts.is_empty() {
886            return Ok(Vec::new());
887        }
888
889        // Issue #164: Use lenient HTTP-based embedding to support HuggingFace TEI
890        // and other OpenAI-compatible servers that omit the cosmetic `object` field.
891        let base_url = if self.raw_base_url.is_empty() {
892            "https://api.openai.com/v1".to_string()
893        } else {
894            self.raw_base_url.trim_end_matches('/').to_string()
895        };
896        let url = format!("{}/embeddings", base_url);
897
898        let request_body = serde_json::json!({
899            "model": self.embedding_model,
900            "input": texts,
901            "encoding_format": "float"
902        });
903
904        let http_client = reqwest::Client::new();
905        let mut req = http_client.post(&url).json(&request_body);
906        if !self.raw_api_key.is_empty() {
907            req = req.header("Authorization", format!("Bearer {}", self.raw_api_key));
908        }
909
910        let response = req
911            .send()
912            .await
913            .map_err(|e| LlmError::NetworkError(format!("Embedding request failed: {}", e)))?;
914
915        let status = response.status();
916        let body = response.text().await.map_err(|e| {
917            LlmError::NetworkError(format!("Failed to read embedding response: {}", e))
918        })?;
919
920        if !status.is_success() {
921            return Err(LlmError::ApiError(format!(
922                "Embedding API returned {} {}: {}",
923                status.as_u16(),
924                status.canonical_reason().unwrap_or(""),
925                &body[..body.len().min(500)]
926            )));
927        }
928
929        // Lenient deserialization: `object` fields are optional (HuggingFace TEI omits them)
930        #[derive(serde::Deserialize)]
931        struct LenientEmbeddingResponse {
932            data: Vec<LenientEmbeddingObject>,
933        }
934        #[derive(serde::Deserialize)]
935        struct LenientEmbeddingObject {
936            embedding: Vec<f32>,
937        }
938
939        let parsed: LenientEmbeddingResponse = serde_json::from_str(&body).map_err(|e| {
940            LlmError::InvalidRequest(format!(
941                "Failed to parse embedding response: {} – body: {}",
942                e,
943                &body[..body.len().min(500)]
944            ))
945        })?;
946
947        Ok(parsed.data.into_iter().map(|o| o.embedding).collect())
948    }
949}
950
/// Unit tests for the OpenAI provider.
///
/// None of these tests perform network I/O: they exercise the static lookup
/// helpers, message conversion, capability flags, and the serialization shape
/// of the async-openai request types the provider relies on.
#[cfg(test)]
mod tests {
    use super::*;

    /// Context-length lookup for one model from each baseline chat family.
    #[test]
    fn test_context_length_detection() {
        assert_eq!(OpenAIProvider::context_length_for_model("gpt-4o"), 128000);
        assert_eq!(OpenAIProvider::context_length_for_model("gpt-4"), 8192);
        assert_eq!(
            OpenAIProvider::context_length_for_model("gpt-3.5-turbo"),
            4096
        );
    }

    /// Embedding dimension lookup for the `text-embedding-3-*` models.
    #[test]
    fn test_embedding_dimension_detection() {
        assert_eq!(
            OpenAIProvider::dimension_for_model("text-embedding-3-large"),
            3072
        );
        assert_eq!(
            OpenAIProvider::dimension_for_model("text-embedding-3-small"),
            1536
        );
    }

    /// Builder methods set the completion model and the embedding dimension.
    #[test]
    fn test_provider_builder() {
        let provider = OpenAIProvider::new("test-key")
            .with_model("gpt-4")
            .with_embedding_model("text-embedding-3-large");

        assert_eq!(LLMProvider::model(&provider), "gpt-4");
        assert_eq!(provider.dimension(), 3072);
    }

    /// System, user, and assistant messages each convert to one request message.
    #[test]
    fn test_message_conversion() {
        let messages = vec![
            ChatMessage::system("You are helpful"),
            ChatMessage::user("Hello"),
            ChatMessage::assistant("Hi there!"),
        ];

        let converted = OpenAIProvider::convert_messages(&messages).unwrap();
        assert_eq!(converted.len(), 3);
    }

    // ---- Iteration 27: Additional OpenAI tests ----

    /// Context-length lookup across GPT-5 series model names.
    #[test]
    fn test_context_length_gpt5_series() {
        assert_eq!(
            OpenAIProvider::context_length_for_model("gpt-5.2-turbo"),
            200000
        );
        assert_eq!(
            OpenAIProvider::context_length_for_model("gpt-5.1-preview"),
            200000
        );
        assert_eq!(
            OpenAIProvider::context_length_for_model("gpt-5-nano"),
            128000
        );
        assert_eq!(
            OpenAIProvider::context_length_for_model("gpt-5-mini"),
            200000
        );
        assert_eq!(OpenAIProvider::context_length_for_model("gpt-5"), 200000);
    }

    /// Context-length lookup across o-series model names.
    #[test]
    fn test_context_length_o_series() {
        assert_eq!(OpenAIProvider::context_length_for_model("o4-mini"), 200000);
        assert_eq!(
            OpenAIProvider::context_length_for_model("o3-preview"),
            200000
        );
        assert_eq!(
            OpenAIProvider::context_length_for_model("o1-preview"),
            200000
        );
    }

    /// Context-length lookup across GPT-4 variants (turbo, 32k, base).
    #[test]
    fn test_context_length_gpt4_variants() {
        assert_eq!(
            OpenAIProvider::context_length_for_model("gpt-4-turbo-preview"),
            128000
        );
        assert_eq!(
            OpenAIProvider::context_length_for_model("gpt-4-32k-0613"),
            32768
        );
        assert_eq!(OpenAIProvider::context_length_for_model("gpt-4-0613"), 8192);
    }

    /// Context-length lookup across GPT-3.5 variants (16k and base).
    #[test]
    fn test_context_length_gpt35_variants() {
        assert_eq!(
            OpenAIProvider::context_length_for_model("gpt-3.5-turbo-16k"),
            16384
        );
        assert_eq!(
            OpenAIProvider::context_length_for_model("gpt-3.5-turbo-1106"),
            4096
        );
    }

    /// Unrecognized model names fall back to the 128K default.
    #[test]
    fn test_context_length_unknown_defaults_high() {
        // Unknown models default to 128K (newer default)
        assert_eq!(
            OpenAIProvider::context_length_for_model("unknown-future-model"),
            128000
        );
    }

    /// Embedding dimension lookup for `text-embedding-ada-002`.
    #[test]
    fn test_dimension_ada_model() {
        assert_eq!(
            OpenAIProvider::dimension_for_model("text-embedding-ada-002"),
            1536
        );
    }

    /// Unrecognized embedding model names fall back to 1536 dimensions.
    #[test]
    fn test_dimension_unknown_defaults() {
        assert_eq!(
            OpenAIProvider::dimension_for_model("unknown-embedding"),
            1536
        );
    }

    /// The LLM provider name is `"openai"`.
    #[test]
    fn test_provider_name() {
        let provider = OpenAIProvider::new("test-key");
        assert_eq!(LLMProvider::name(&provider), "openai");
    }

    /// `max_context_length` reflects the configured completion model.
    #[test]
    fn test_provider_max_context_length() {
        let provider = OpenAIProvider::new("test-key").with_model("gpt-4");
        assert_eq!(provider.max_context_length(), 8192);
    }

    /// `dimension` reflects the configured embedding model.
    #[test]
    fn test_provider_dimension() {
        let provider =
            OpenAIProvider::new("test-key").with_embedding_model("text-embedding-3-large");
        assert_eq!(provider.dimension(), 3072);
    }

    /// `EmbeddingProvider::model` returns the embedding model, not the completion model.
    #[test]
    fn test_provider_embedding_model() {
        let provider =
            OpenAIProvider::new("test-key").with_embedding_model("text-embedding-3-small");
        assert_eq!(
            EmbeddingProvider::model(&provider),
            "text-embedding-3-small"
        );
    }

    /// Tool results convert to a `Tool` request message carrying the call id.
    #[test]
    fn test_message_conversion_tool_role() {
        // Bug #1 regression: Tool messages must use role=tool + tool_call_id, not role=user.
        let messages = vec![ChatMessage::tool_result("call_abc", "result data")];
        let converted = OpenAIProvider::convert_messages(&messages).unwrap();
        assert_eq!(converted.len(), 1);
        match &converted[0] {
            ChatCompletionRequestMessage::Tool(m) => {
                assert_eq!(m.tool_call_id, "call_abc");
            }
            other => panic!("Expected Tool message, got {:?}", other),
        }
    }

    /// A tool-role message without a `tool_call_id` is rejected with an error.
    #[test]
    fn test_tool_message_missing_id_returns_err() {
        let mut msg = ChatMessage::user("orphan");
        msg.role = ChatRole::Tool;
        msg.tool_call_id = None;
        let r = OpenAIProvider::convert_messages(&[msg]);
        assert!(
            r.is_err(),
            "Expected Err for tool message without tool_call_id"
        );
    }

    /// Assistant messages keep their tool calls through conversion.
    #[test]
    fn test_assistant_with_tool_calls_serialized() {
        // Bug #2 regression: assistant tool_calls must be propagated.
        let calls = vec![ToolCall {
            id: "call_xyz".to_string(),
            call_type: "function".to_string(),
            function: TraitFunctionCall {
                name: "get_weather".to_string(),
                arguments: r#"{"city":"Paris"}"#.to_string(),
            },
            thought_signature: None,
        }];
        let msg = ChatMessage::assistant_with_tools("", calls);
        let converted = OpenAIProvider::convert_messages(&[msg]).unwrap();
        assert_eq!(converted.len(), 1);
        match &converted[0] {
            ChatCompletionRequestMessage::Assistant(m) => {
                let tcs = m.tool_calls.as_ref().expect("tool_calls must be present");
                assert_eq!(tcs.len(), 1);
                if let ChatCompletionMessageToolCalls::Function(f) = &tcs[0] {
                    assert_eq!(f.id, "call_xyz");
                    assert_eq!(f.function.name, "get_weather");
                } else {
                    panic!("Expected Function tool call");
                }
            }
            other => panic!("Expected Assistant message, got {:?}", other),
        }
    }

    /// The provider reports streaming support.
    #[test]
    fn test_supports_streaming() {
        let provider = OpenAIProvider::new("test-key");
        assert!(provider.supports_streaming());
    }

    /// JSON mode is reported for a GPT-4 family model.
    #[test]
    fn test_supports_json_mode_gpt4() {
        let provider = OpenAIProvider::new("test-key").with_model("gpt-4o");
        assert!(provider.supports_json_mode());
    }

    /// JSON mode is reported for `gpt-3.5-turbo`.
    #[test]
    fn test_supports_json_mode_gpt35() {
        let provider = OpenAIProvider::new("test-key").with_model("gpt-3.5-turbo");
        assert!(provider.supports_json_mode());
    }

    /// JSON mode is also reported for GPT-5 names and the o1/o3/o4 prefixes.
    #[test]
    fn test_supports_json_mode_gpt5_and_o_series() {
        for model in &["gpt-5-mini", "o1-preview", "o3-mini", "o4-mini"] {
            let provider = OpenAIProvider::new("test-key").with_model(*model);
            assert!(
                provider.supports_json_mode(),
                "JSON mode should be reported for model {}",
                model
            );
        }
    }

    /// JSON mode is not reported for a model outside the known families.
    #[test]
    fn test_supports_json_mode_default_is_false() {
        // Old completion-only models do not support JSON mode
        let provider = OpenAIProvider::new("test-key").with_model("davinci-002");
        assert!(!provider.supports_json_mode());
    }

    // ---- Vision / multimodal message tests ----

    /// A plain user message yields `Text` content.
    #[test]
    fn test_build_user_content_text_only() {
        let msg = ChatMessage::user("Hello");
        let content = OpenAIProvider::build_user_content(&msg);
        match content {
            ChatCompletionRequestUserMessageContent::Text(t) => assert_eq!(t, "Hello"),
            _ => panic!("Expected text content"),
        }
    }

    /// A user message with one image yields a text part followed by an image part.
    #[test]
    fn test_build_user_content_with_image() {
        let img = ImageData::new("base64data", "image/png");
        let msg = ChatMessage::user_with_images("Describe this", vec![img]);
        let content = OpenAIProvider::build_user_content(&msg);
        match content {
            ChatCompletionRequestUserMessageContent::Array(parts) => {
                assert_eq!(parts.len(), 2, "Should have text + image parts");
                assert!(
                    matches!(
                        parts[0],
                        ChatCompletionRequestUserMessageContentPart::Text(_)
                    ),
                    "First part should be text"
                );
                assert!(
                    matches!(
                        parts[1],
                        ChatCompletionRequestUserMessageContentPart::ImageUrl(_)
                    ),
                    "Second part should be image_url"
                );
            }
            _ => panic!("Expected array content for vision message"),
        }
    }

    /// The image part carries a `data:<mime>;base64,<payload>` URI.
    #[test]
    fn test_build_user_content_image_data_uri() {
        let img = ImageData::new("abc123", "image/jpeg");
        let msg = ChatMessage::user_with_images("What's here?", vec![img]);
        let content = OpenAIProvider::build_user_content(&msg);
        if let ChatCompletionRequestUserMessageContent::Array(parts) = content {
            if let ChatCompletionRequestUserMessageContentPart::ImageUrl(img_part) = &parts[1] {
                assert_eq!(
                    img_part.image_url.url, "data:image/jpeg;base64,abc123",
                    "Data URI should be correct"
                );
            } else {
                panic!("Expected ImageUrl part");
            }
        } else {
            panic!("Expected array content");
        }
    }

    /// An image tagged with detail `"high"` parses to `ImageDetail::High`.
    #[test]
    fn test_build_user_content_image_with_detail() {
        // Assert on the image actually constructed here rather than on a second,
        // unrelated instance.
        let img = ImageData::new("data", "image/png").with_detail("high");
        let detail = OpenAIProvider::parse_image_detail(&img);
        assert!(matches!(detail, Some(ImageDetail::High)));
    }

    /// Detail `"low"` parses to `ImageDetail::Low`.
    #[test]
    fn test_parse_image_detail_low() {
        let img = ImageData::new("x", "image/png").with_detail("low");
        let d = OpenAIProvider::parse_image_detail(&img);
        assert!(matches!(d, Some(ImageDetail::Low)));
    }

    /// Detail `"auto"` parses to `ImageDetail::Auto`.
    #[test]
    fn test_parse_image_detail_auto() {
        let img = ImageData::new("x", "image/png").with_detail("auto");
        let d = OpenAIProvider::parse_image_detail(&img);
        assert!(matches!(d, Some(ImageDetail::Auto)));
    }

    /// An image without a detail setting parses to `None`.
    #[test]
    fn test_parse_image_detail_none() {
        let img = ImageData::new("x", "image/png");
        let d = OpenAIProvider::parse_image_detail(&img);
        assert!(d.is_none());
    }

    /// A converted vision message serializes its content as a JSON array of parts.
    #[test]
    fn test_convert_messages_with_image_produces_array_content() {
        let img = ImageData::new("iVBORw0KGgo", "image/png");
        let messages = vec![
            ChatMessage::system("You are a vision assistant"),
            ChatMessage::user_with_images("What is in this image?", vec![img]),
        ];
        let converted = OpenAIProvider::convert_messages(&messages).unwrap();
        assert_eq!(converted.len(), 2);
        // Verify the user message is a ChatCompletionRequestMessage::User with Array content
        // We can validate via JSON serialization
        let json = serde_json::to_value(&converted[1]).unwrap();
        let content = &json["content"];
        assert!(
            content.is_array(),
            "Vision user message content must be a JSON array, got: {:?}",
            content
        );
        let parts = content.as_array().unwrap();
        assert_eq!(parts.len(), 2, "Should have text + image parts");
        assert_eq!(parts[0]["type"], "text");
        assert_eq!(parts[1]["type"], "image_url");
        assert!(parts[1]["image_url"]["url"]
            .as_str()
            .unwrap()
            .starts_with("data:image/png;base64,"));
    }

    /// A converted text-only message serializes its content as a JSON string.
    #[test]
    fn test_convert_messages_without_image_produces_text_content() {
        let messages = vec![ChatMessage::user("Just text")];
        let converted = OpenAIProvider::convert_messages(&messages).unwrap();
        let json = serde_json::to_value(&converted[0]).unwrap();
        let content = &json["content"];
        assert!(
            content.is_string(),
            "Plain text user message content must be a JSON string"
        );
        assert_eq!(content.as_str().unwrap(), "Just text");
    }

    // ---- async-openai 0.33 upgrade tests ----

    /// Test that tools are correctly wrapped in ChatCompletionTools::Function.
    /// In 0.33, the tools parameter takes Vec<ChatCompletionTools> (enum) not
    /// Vec<ChatCompletionTool> (struct).
    #[test]
    fn test_chat_completion_tools_function_wrapping() {
        use crate::traits::FunctionDefinition;
        let tool_def = ToolDefinition {
            tool_type: "function".to_string(),
            function: FunctionDefinition {
                name: "get_weather".to_string(),
                description: "Get the current weather".to_string(),
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": {
                        "location": { "type": "string" }
                    },
                    "required": ["location"]
                }),
                strict: None,
            },
        };

        // Build the tool using the new 0.33 pattern
        let openai_tool = ChatCompletionTools::Function(ChatCompletionTool {
            function: FunctionObjectArgs::default()
                .name(&tool_def.function.name)
                .description(&tool_def.function.description)
                .parameters(tool_def.function.parameters.clone())
                .build()
                .unwrap(),
        });

        // Verify it serializes correctly
        let json = serde_json::to_value(&openai_tool).unwrap();
        assert_eq!(json["type"], "function");
        assert_eq!(json["function"]["name"], "get_weather");
        assert_eq!(json["function"]["description"], "Get the current weather");
    }

    /// Test that ToolChoiceOptions::Auto/Required serialize correctly.
    /// In 0.33, these are behind ChatCompletionToolChoiceOption::Mode(...)
    #[test]
    fn test_tool_choice_auto_serialization() {
        let choice = ChatCompletionToolChoiceOption::Mode(ToolChoiceOptions::Auto);
        let json = serde_json::to_value(&choice).unwrap();
        assert_eq!(json, "auto");
    }

    /// `ToolChoiceOptions::Required` serializes to the string `"required"`.
    #[test]
    fn test_tool_choice_required_serialization() {
        let choice = ChatCompletionToolChoiceOption::Mode(ToolChoiceOptions::Required);
        let json = serde_json::to_value(&choice).unwrap();
        assert_eq!(json, "required");
    }

    /// `ToolChoiceOptions::None` serializes to the string `"none"`.
    #[test]
    fn test_tool_choice_none_serialization() {
        let choice = ChatCompletionToolChoiceOption::Mode(ToolChoiceOptions::None);
        let json = serde_json::to_value(&choice).unwrap();
        assert_eq!(json, "none");
    }

    /// Test that max_completion_tokens is correctly serialized in the request.
    /// This verifies the fix for issue #13 — models like o1/o3/o4 and gpt-4.1
    /// require max_completion_tokens, not the deprecated max_tokens.
    #[test]
    fn test_max_completion_tokens_in_request_serialization() {
        let request = CreateChatCompletionRequestArgs::default()
            .model("o3-mini")
            .messages(vec![ChatCompletionRequestUserMessageArgs::default()
                .content("Hello")
                .build()
                .unwrap()
                .into()])
            .max_completion_tokens(1024u32)
            .build()
            .unwrap();

        let json = serde_json::to_value(&request).unwrap();
        assert_eq!(
            json["max_completion_tokens"], 1024,
            "max_completion_tokens should be set in request"
        );
        assert!(
            json["max_tokens"].is_null(),
            "deprecated max_tokens should NOT be set"
        );
    }

    /// Test that max_completion_tokens works for legacy models too.
    /// In 0.33, all models accept max_completion_tokens — old max_tokens is deprecated.
    #[test]
    fn test_max_completion_tokens_works_for_all_models() {
        for model in &[
            "gpt-4o",
            "gpt-3.5-turbo",
            "o1-preview",
            "o3-mini",
            "gpt-4.1-nano",
        ] {
            let request = CreateChatCompletionRequestArgs::default()
                .model(*model)
                .messages(vec![ChatCompletionRequestUserMessageArgs::default()
                    .content("Test")
                    .build()
                    .unwrap()
                    .into()])
                .max_completion_tokens(512u32)
                .build()
                .unwrap();

            let json = serde_json::to_value(&request).unwrap();
            assert_eq!(
                json["max_completion_tokens"], 512,
                "max_completion_tokens should be set for model {}",
                model
            );
        }
    }

    /// Test cache hit token extraction from PromptTokensDetails.
    /// This validates the new async-openai 0.33 feature used in chat() method.
    #[test]
    fn test_cache_hit_token_extraction() {
        use async_openai::types::chat::PromptTokensDetails;

        let usage = CompletionUsage {
            prompt_tokens: 100,
            completion_tokens: 50,
            total_tokens: 150,
            prompt_tokens_details: Some(PromptTokensDetails {
                cached_tokens: Some(80),
                audio_tokens: None,
            }),
            completion_tokens_details: None,
        };

        let cache_hit_tokens = usage
            .prompt_tokens_details
            .as_ref()
            .and_then(|d| d.cached_tokens)
            .map(|t| t as usize);

        assert_eq!(cache_hit_tokens, Some(80));
    }

    /// Test reasoning token extraction from CompletionTokensDetails.
    /// This validates the new async-openai 0.33 feature for o-series models.
    #[test]
    fn test_reasoning_token_extraction() {
        use async_openai::types::chat::CompletionTokensDetails;

        let usage = CompletionUsage {
            prompt_tokens: 50,
            completion_tokens: 200,
            total_tokens: 250,
            prompt_tokens_details: None,
            completion_tokens_details: Some(CompletionTokensDetails {
                reasoning_tokens: Some(150),
                audio_tokens: None,
                accepted_prediction_tokens: None,
                rejected_prediction_tokens: None,
            }),
        };

        let thinking_tokens = usage
            .completion_tokens_details
            .as_ref()
            .and_then(|d| d.reasoning_tokens)
            .map(|t| t as usize);

        assert_eq!(thinking_tokens, Some(150));
    }

    /// Test that missing token details returns None (no panic).
    #[test]
    fn test_token_details_none_is_safe() {
        let usage = CompletionUsage {
            prompt_tokens: 10,
            completion_tokens: 20,
            total_tokens: 30,
            prompt_tokens_details: None,
            completion_tokens_details: None,
        };

        let cache_hit = usage
            .prompt_tokens_details
            .as_ref()
            .and_then(|d| d.cached_tokens)
            .map(|t| t as usize);

        let reasoning = usage
            .completion_tokens_details
            .as_ref()
            .and_then(|d| d.reasoning_tokens)
            .map(|t| t as usize);

        assert_eq!(cache_hit, None);
        assert_eq!(reasoning, None);
    }

    /// Test that FinishReason variants still exist and format correctly.
    /// Validates no regression in finish reason handling after 0.33 upgrade.
    #[test]
    fn test_finish_reason_variants() {
        let cases = vec![
            (FinishReason::Stop, "Stop"),
            (FinishReason::Length, "Length"),
            (FinishReason::ToolCalls, "ToolCalls"),
            (FinishReason::ContentFilter, "ContentFilter"),
            (FinishReason::FunctionCall, "FunctionCall"),
        ];

        for (reason, expected_debug) in cases {
            let formatted = format!("{:?}", reason);
            assert_eq!(
                formatted, expected_debug,
                "FinishReason::{} should format as {:?}",
                expected_debug, expected_debug
            );
        }
    }

    /// Test OpenAIError::JSONDeserialize now takes 2 args in 0.33.
    /// Validates that our error conversion handles the new signature.
    #[test]
    fn test_json_deserialize_error_conversion() {
        // Simulate the 2-arg variant matching
        let serde_err = serde_json::from_str::<serde_json::Value>("invalid json {{").unwrap_err();
        let openai_err = async_openai::error::OpenAIError::JSONDeserialize(
            serde_err,
            "invalid json {{".to_string(),
        );
        let llm_err = LlmError::from(openai_err);
        assert!(
            matches!(llm_err, LlmError::SerializationError(_)),
            "JSONDeserialize error should convert to SerializationError"
        );
    }

    /// Test that ChatCompletionTool no longer has a type field in 0.33.
    /// In 0.24 it had `r#type: ChatCompletionToolType::Function` but
    /// in 0.33 the type is encoded in the ChatCompletionTools enum variant.
    #[test]
    fn test_chat_completion_tool_serialization() {
        let tool = ChatCompletionTool {
            function: FunctionObjectArgs::default()
                .name("my_func")
                .description("A test function")
                .parameters(serde_json::json!({"type": "object"}))
                .build()
                .unwrap(),
        };
        let wrapped = ChatCompletionTools::Function(tool);
        let json = serde_json::to_value(&wrapped).unwrap();

        // In 0.33, type is determined by the enum variant tag
        assert_eq!(json["type"], "function");
        assert_eq!(json["function"]["name"], "my_func");
    }
}