Skip to main content

everruns_core/
openai_protocol.rs

1// OpenAI Protocol LLM Driver
2//
3// Base implementation of the OpenAI chat completion protocol.
4// This driver can be used with any OpenAI-compatible API endpoint.
5//
6// Rate limit handling: On 429 errors, the driver automatically retries with
7// exponential backoff, respecting x-ratelimit-reset-* and retry-after headers.
8// Retry metadata is included in the response for observability.
9//
10// This is the base protocol implementation used in examples.
11// For production use with OpenAI-specific features, use OpenAILlmDriver from everruns-openai.
12//
13// Note: OTel instrumentation is handled via the event-listener pattern.
14// llm.generation events are emitted by ReasonAtom, and OtelEventListener
15// creates the appropriate gen-ai spans. No direct tracing in drivers.
16
17use async_trait::async_trait;
18use eventsource_stream::Eventsource;
19use futures::StreamExt;
20use reqwest::{Client, RequestBuilder, Url};
21use serde::{Deserialize, Serialize};
22use serde_json::{Value, json};
23use std::sync::{Arc, Mutex};
24
25use crate::error::{AgentLoopError, Result};
26use crate::llm_driver_registry::{
27    LlmCallConfig, LlmCompletionMetadata, LlmContentPart, LlmDriver, LlmMessage, LlmMessageContent,
28    LlmMessageRole, LlmResponseStream, LlmStreamEvent,
29};
30use crate::llm_retry::{
31    LlmRetryConfig, RateLimitInfo, RetryMetadata, is_rate_limit_status, is_transient_error,
32};
33use crate::tool_types::{ToolCall, ToolDefinition};
34
35const DEFAULT_API_URL: &str = "https://api.openai.com/v1/chat/completions";
36
37pub(crate) fn apply_openai_api_auth(
38    request: RequestBuilder,
39    api_url: &str,
40    api_key: &str,
41) -> RequestBuilder {
42    if is_azure_openai_api_url(api_url) {
43        request.header("api-key", api_key)
44    } else {
45        request.header("Authorization", format!("Bearer {}", api_key))
46    }
47}
48
49pub fn is_azure_openai_api_url(api_url: &str) -> bool {
50    Url::parse(api_url)
51        .ok()
52        .and_then(|url| url.host_str().map(|host| host.to_ascii_lowercase()))
53        .is_some_and(|host| {
54            host.ends_with(".openai.azure.com") || host.ends_with(".services.ai.azure.com")
55        })
56}
57
58/// Whether `api_url` points at OpenAI's hosted API (`api.openai.com`).
59///
60/// Host-based (not prefix-based) so it tolerates ports and trailing paths.
61pub fn is_openai_api_url(api_url: &str) -> bool {
62    Url::parse(api_url)
63        .ok()
64        .and_then(|url| url.host_str().map(|host| host.to_ascii_lowercase()))
65        .is_some_and(|host| host == "api.openai.com")
66}
67
/// OpenAI Protocol LLM Driver
///
/// Base implementation of `LlmDriver` for OpenAI-compatible APIs.
/// Supports streaming responses and tool calls.
///
/// Rate limit handling: On 429 errors, automatically retries with exponential
/// backoff, respecting `x-ratelimit-reset-*` and `retry-after` headers.
///
/// This is the base protocol driver used in examples and for OpenAI-compatible endpoints.
/// For production use with OpenAI, consider using `OpenAILlmDriver` from the `everruns-openai` crate.
///
/// # Example
///
/// ```ignore
/// use everruns_core::OpenAIProtocolLlmDriver;
///
/// let driver = OpenAIProtocolLlmDriver::from_env()?;
/// // or
/// let driver = OpenAIProtocolLlmDriver::new("your-api-key");
/// // or with custom endpoint
/// let driver = OpenAIProtocolLlmDriver::with_base_url("your-api-key", "https://api.example.com/v1/chat/completions");
/// // or with custom retry config
/// let driver = OpenAIProtocolLlmDriver::new("your-api-key")
///     .with_retry_config(LlmRetryConfig::aggressive());
/// ```
#[derive(Clone)]
pub struct OpenAIProtocolLlmDriver {
    /// HTTP client used to POST requests to `api_url`.
    client: Client,
    /// Credential sent with every request; the header it travels in is chosen by
    /// `apply_openai_api_auth` based on `api_url`. Redacted in the `Debug` output.
    api_key: String,
    /// Full endpoint URL that requests are posted to as-is (no path is appended),
    /// e.g. `https://api.openai.com/v1/chat/completions`.
    api_url: String,
    /// Retry configuration applied to rate-limit (429) and other transient error
    /// responses: bounds the number of attempts and drives the backoff calculation.
    retry_config: LlmRetryConfig,
}
101
102impl OpenAIProtocolLlmDriver {
103    /// Create a new driver with the given API key
104    pub fn new(api_key: impl Into<String>) -> Self {
105        Self {
106            client: Client::new(),
107            api_key: api_key.into(),
108            api_url: DEFAULT_API_URL.to_string(),
109            retry_config: LlmRetryConfig::default(),
110        }
111    }
112
113    /// Create a new driver from the OPENAI_API_KEY environment variable
114    pub fn from_env() -> Result<Self> {
115        let api_key = std::env::var("OPENAI_API_KEY")
116            .map_err(|_| AgentLoopError::llm("OPENAI_API_KEY environment variable not set"))?;
117        Ok(Self::new(api_key))
118    }
119
120    /// Create a new driver with a custom API URL (for OpenAI-compatible APIs)
121    pub fn with_base_url(api_key: impl Into<String>, api_url: impl Into<String>) -> Self {
122        Self {
123            client: Client::new(),
124            api_key: api_key.into(),
125            api_url: api_url.into(),
126            retry_config: LlmRetryConfig::default(),
127        }
128    }
129
130    /// Configure retry behavior for rate limit errors
131    pub fn with_retry_config(mut self, config: LlmRetryConfig) -> Self {
132        self.retry_config = config;
133        self
134    }
135
136    /// Get the API URL
137    pub fn api_url(&self) -> &str {
138        &self.api_url
139    }
140
141    /// Get the API key (for subclass access)
142    pub fn api_key(&self) -> &str {
143        &self.api_key
144    }
145
146    /// Get the HTTP client (for subclass access)
147    pub fn client(&self) -> &Client {
148        &self.client
149    }
150
151    fn convert_role(role: &LlmMessageRole) -> &'static str {
152        match role {
153            LlmMessageRole::System => "system",
154            LlmMessageRole::User => "user",
155            LlmMessageRole::Assistant => "assistant",
156            LlmMessageRole::Tool => "tool",
157        }
158    }
159
160    fn convert_message(msg: &LlmMessage) -> OpenAiMessage {
161        let content = match &msg.content {
162            LlmMessageContent::Text(text) => OpenAiContent::Text(text.clone()),
163            LlmMessageContent::Parts(parts) => {
164                let openai_parts: Vec<OpenAiContentPart> = parts
165                    .iter()
166                    .map(|part| match part {
167                        LlmContentPart::Text { text } => OpenAiContentPart::Text {
168                            r#type: "text".to_string(),
169                            text: text.clone(),
170                        },
171                        LlmContentPart::Image { url } => OpenAiContentPart::ImageUrl {
172                            r#type: "image_url".to_string(),
173                            image_url: OpenAiImageUrl { url: url.clone() },
174                        },
175                        LlmContentPart::Audio { url } => OpenAiContentPart::InputAudio {
176                            r#type: "input_audio".to_string(),
177                            input_audio: OpenAiInputAudio {
178                                data: url.clone(),
179                                format: "wav".to_string(),
180                            },
181                        },
182                    })
183                    .collect();
184                OpenAiContent::Parts(openai_parts)
185            }
186        };
187
188        // OpenAI only accepts tool_calls on assistant messages
189        let tool_calls = if msg.role == LlmMessageRole::Assistant {
190            msg.tool_calls.as_ref().map(|calls| {
191                calls
192                    .iter()
193                    .map(|tc| OpenAiToolCall {
194                        id: tc.id.clone(),
195                        r#type: "function".to_string(),
196                        function: OpenAiFunctionCall {
197                            name: tc.name.clone(),
198                            arguments: serde_json::to_string(&tc.arguments).unwrap_or_default(),
199                        },
200                    })
201                    .collect()
202            })
203        } else {
204            None
205        };
206
207        OpenAiMessage {
208            role: Self::convert_role(&msg.role).to_string(),
209            content: Some(content),
210            tool_calls,
211            tool_call_id: msg.tool_call_id.clone(),
212        }
213    }
214
215    fn convert_tools(tools: &[ToolDefinition]) -> Vec<OpenAiTool> {
216        tools
217            .iter()
218            .map(|tool| OpenAiTool {
219                r#type: "function".to_string(),
220                function: OpenAiFunction {
221                    name: tool.name().to_string(),
222                    description: tool.description().to_string(),
223                    parameters: tool.parameters().clone(),
224                },
225            })
226            .collect()
227    }
228}
229
230/// Drop Tool-role messages whose tool_call_id has no matching assistant tool call in the
231/// visible window. Chat Completions rejects payloads where a `tool`-role message references
232/// a call that is absent from the conversation.
233fn drop_orphaned_tool_messages(messages: &[LlmMessage]) -> Vec<LlmMessage> {
234    use std::collections::HashSet;
235
236    let visible_call_ids: HashSet<&str> = messages
237        .iter()
238        .filter(|m| m.role == LlmMessageRole::Assistant)
239        .flat_map(|m| m.tool_calls.iter().flatten())
240        .map(|tc| tc.id.as_str())
241        .collect();
242
243    if visible_call_ids.is_empty() {
244        return messages
245            .iter()
246            .filter(|m| m.role != LlmMessageRole::Tool)
247            .cloned()
248            .collect();
249    }
250
251    messages
252        .iter()
253        .filter(|m| {
254            if m.role == LlmMessageRole::Tool {
255                return m
256                    .tool_call_id
257                    .as_deref()
258                    .is_none_or(|id| visible_call_ids.contains(id));
259            }
260            true
261        })
262        .cloned()
263        .collect()
264}
265
#[async_trait]
impl LlmDriver for OpenAIProtocolLlmDriver {
    /// Sends a streaming chat completion request and adapts the SSE response into an
    /// `LlmResponseStream`.
    ///
    /// Flow:
    /// 1. Orphaned tool messages are dropped and the rest converted to wire format.
    /// 2. The request is always sent with `stream: true` and `include_usage: true`.
    /// 3. Non-success responses whose status `is_transient_error` accepts are retried
    ///    until `retry_config.max_retries` attempts have been recorded. Rate-limit
    ///    statuses take their wait from the response headers; other transient statuses
    ///    use `retry_config.calculate_backoff`.
    /// 4. The successful response body is parsed as server-sent events and mapped to
    ///    `LlmStreamEvent`s; the `[DONE]` marker produces the final `Done` event.
    ///
    /// # Errors
    ///
    /// - `AgentLoopError::llm` when the request cannot be sent (send failures are not
    ///   retried), or on a non-retryable / retries-exhausted error response.
    /// - `AgentLoopError::request_too_large` when `is_openai_request_too_large` matches.
    /// - `AgentLoopError::model_not_available` when `is_openai_model_not_found` matches.
    ///
    /// Failures that occur after streaming has started (SSE transport errors, chunks
    /// that fail to parse) are delivered in-band as `Ok(LlmStreamEvent::Error(..))`.
    async fn chat_completion_stream(
        &self,
        messages: Vec<LlmMessage>,
        config: &LlmCallConfig,
    ) -> Result<LlmResponseStream> {
        // Note: OTel instrumentation is handled via event listeners.
        // ReasonAtom emits llm.generation events, and OtelEventListener
        // creates gen-ai spans from those events.
        let messages = drop_orphaned_tool_messages(&messages);
        let openai_messages: Vec<OpenAiMessage> =
            messages.iter().map(Self::convert_message).collect();

        // Omit the `tools` field entirely when no tools are configured.
        let tools = if config.tools.is_empty() {
            None
        } else {
            Some(Self::convert_tools(&config.tools))
        };

        // Build metadata for request tracking
        let metadata = if config.metadata.is_empty() {
            None
        } else {
            Some(config.metadata.clone())
        };

        let request = OpenAiRequest {
            model: config.model.clone(),
            messages: openai_messages,
            temperature: config.temperature,
            max_tokens: config.max_tokens,
            stream: true,
            stream_options: Some(OpenAiStreamOptions {
                include_usage: true,
            }),
            tools,
            // Skip "none" — sending reasoning_effort to non-thinking models causes API errors
            reasoning_effort: config
                .reasoning_effort
                .as_ref()
                .filter(|e| !e.eq_ignore_ascii_case("none"))
                .cloned(),
            metadata,
        };

        // Retry loop for rate limit (429) and transient errors
        let mut retry_metadata = RetryMetadata::default();
        // Body of the most recent retried error response, kept for the final error message.
        let mut last_error: Option<String> = None;

        let response = loop {
            // A failure to send (as opposed to an error status) aborts immediately via `?`.
            let response = apply_openai_api_auth(
                self.client.post(&self.api_url),
                &self.api_url,
                &self.api_key,
            )
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await
            .map_err(|e| AgentLoopError::llm(format!("Failed to send request: {}", e)))?;

            let status = response.status();

            if status.is_success() {
                // Success - exit retry loop
                break response;
            }

            // Check if this is a retryable error
            if is_transient_error(status) && retry_metadata.attempts < self.retry_config.max_retries
            {
                // Parse rate limit info from headers before consuming response body
                let rate_limit_info = if is_rate_limit_status(status) {
                    Some(RateLimitInfo::from_openai_headers(response.headers()))
                } else {
                    None
                };

                // An unreadable body is treated as empty rather than as a new error.
                let error_text = response.text().await.unwrap_or_default();

                // Don't retry if this is a request-too-large error (not transient)
                if is_openai_request_too_large(status, &error_text) {
                    return Err(AgentLoopError::request_too_large(format!(
                        "OpenAI API error ({}): {}",
                        status, error_text
                    )));
                }

                // Calculate wait duration
                let wait_duration = rate_limit_info
                    .as_ref()
                    .map(|info| info.recommended_wait(&self.retry_config, retry_metadata.attempts))
                    .unwrap_or_else(|| {
                        self.retry_config.calculate_backoff(retry_metadata.attempts)
                    });

                tracing::warn!(
                    status = %status,
                    attempt = retry_metadata.attempts + 1,
                    max_retries = self.retry_config.max_retries,
                    wait_secs = wait_duration.as_secs_f64(),
                    retry_after = ?rate_limit_info.as_ref().and_then(|i| i.retry_after_secs),
                    "OpenAIProtocolDriver: rate limit or transient error, retrying"
                );

                // Record retry attempt
                retry_metadata.record_retry(wait_duration, rate_limit_info);
                last_error = Some(error_text);

                // Wait before retry
                tokio::time::sleep(wait_duration).await;
                continue;
            }

            // Non-retryable error or max retries exceeded
            let error_text = response.text().await.unwrap_or_default();
            let error_msg = format!("OpenAI API error ({}): {}", status, error_text);

            // Check if this is a model-not-found error
            if is_openai_model_not_found(status, &error_text) {
                return Err(AgentLoopError::model_not_available(config.model.clone()));
            }

            // Check if this is a request-too-large error
            if is_openai_request_too_large(status, &error_text) {
                return Err(AgentLoopError::request_too_large(error_msg));
            }

            // If we exhausted retries, include that in the error message
            if retry_metadata.attempts > 0 {
                return Err(AgentLoopError::llm(format!(
                    "{} (after {} retries, last error: {})",
                    error_msg,
                    retry_metadata.attempts,
                    last_error.unwrap_or_default()
                )));
            }

            return Err(AgentLoopError::llm(error_msg));
        };

        // Log successful retry recovery
        if retry_metadata.had_retries() {
            tracing::info!(
                attempts = retry_metadata.attempts,
                total_wait_secs = retry_metadata.total_retry_wait.as_secs_f64(),
                "OpenAIProtocolDriver: request succeeded after retries"
            );
        }

        let byte_stream = response.bytes_stream();
        let event_stream = byte_stream.eventsource();

        // State shared across the per-event closures below.
        let model = config.model.clone();
        // NOTE: despite its name, `total_tokens` holds the *completion* token count: it is
        // overwritten from `usage.completion_tokens` and reported as `completion_tokens`
        // in the Done event. It is also handed to `process_stream_choice`, which may
        // update it as well.
        let total_tokens = Arc::new(Mutex::new(0u32));
        let prompt_tokens = Arc::new(Mutex::new(0u32));
        let cache_read_tokens = Arc::new(Mutex::new(Option::<u32>::None));
        // OpenAI-compatible gateways (e.g. OpenRouter) report an authoritative
        // per-request cost in `usage.cost`; direct OpenAI leaves it absent.
        let provider_cost_usd = Arc::new(Mutex::new(Option::<f64>::None));
        let accumulated_tool_calls = Arc::new(Mutex::new(Vec::<ToolCall>::new()));
        let finish_reason = Arc::new(Mutex::new(Option::<String>::None));
        // Share retry metadata with stream closure (only set if retries occurred)
        let shared_retry_metadata = if retry_metadata.had_retries() {
            Some(Arc::new(retry_metadata))
        } else {
            None
        };

        // Each SSE event maps to zero-or-more stream events (the [DONE] marker can
        // emit a flushed ToolCalls plus Done), so the closure yields a Vec that is
        // flattened back into the stream.
        let converted_stream: LlmResponseStream = Box::pin(
            event_stream
                .then(move |result| {
                    // Fresh handles for this event's future; the closure itself is reused
                    // for every event, so it must not give its own copies away.
                    let model = model.clone();
                    let total_tokens = Arc::clone(&total_tokens);
                    let prompt_tokens = Arc::clone(&prompt_tokens);
                    let cache_read_tokens = Arc::clone(&cache_read_tokens);
                    let provider_cost_usd = Arc::clone(&provider_cost_usd);
                    let accumulated_tool_calls = Arc::clone(&accumulated_tool_calls);
                    let finish_reason = Arc::clone(&finish_reason);
                    let retry_metadata_for_done = shared_retry_metadata.clone();

                    async move {
                        // Transport-level SSE failures are surfaced in-band as Error events.
                        let event = match result {
                            Ok(event) => event,
                            Err(e) => {
                                return vec![Ok(LlmStreamEvent::Error(format!(
                                    "Stream error: {}",
                                    e
                                )))];
                            }
                        };

                        if event.data == "[DONE]" {
                            // Snapshot the accumulated counters for the final metadata.
                            let output_tokens = *total_tokens.lock().unwrap();
                            let input_tokens = *prompt_tokens.lock().unwrap();
                            let cached = *cache_read_tokens.lock().unwrap();
                            let cost = *provider_cost_usd.lock().unwrap();
                            let mut reason = finish_reason.lock().unwrap().clone();

                            let mut events = Vec::new();

                            // Defense in depth (EVE-522): flush any tool calls that
                            // were accumulated but never emitted before Done, so they
                            // are never silently dropped. The normal path drains the
                            // accumulator at the finish chunk, so this only fires as a
                            // fallback — e.g. a provider that ends the stream with
                            // [DONE] without a tool_calls finish chunk reaching the
                            // handler. When it fires, reflect the tool-call completion
                            // in the reported finish_reason.
                            {
                                let mut acc = accumulated_tool_calls.lock().unwrap();
                                if let Some(event) = take_pending_tool_calls(&mut acc) {
                                    events.push(Ok(event));
                                    // Only fills in a reason when none was recorded.
                                    reason.get_or_insert_with(|| "tool_calls".to_string());
                                }
                            }

                            events.push(Ok(LlmStreamEvent::Done(Box::new(
                                LlmCompletionMetadata {
                                    total_tokens: Some(input_tokens + output_tokens),
                                    prompt_tokens: Some(input_tokens),
                                    completion_tokens: Some(output_tokens),
                                    cache_read_tokens: cached,
                                    cache_creation_tokens: None,
                                    provider_cost_usd: cost,
                                    model: Some(model),
                                    // Falls back to "stop" when no finish reason was seen.
                                    finish_reason: reason.or_else(|| Some("stop".to_string())),
                                    retry_metadata: retry_metadata_for_done
                                        .map(|arc| (*arc).clone()),
                                    response_id: None,
                                    phase: None,
                                },
                            ))));

                            return events;
                        }

                        match serde_json::from_str::<OpenAiStreamChunk>(&event.data) {
                            Ok(chunk) => {
                                // Capture usage from chunk if available
                                if let Some(usage) = &chunk.usage {
                                    if let Some(pt) = usage.prompt_tokens {
                                        *prompt_tokens.lock().unwrap() = pt;
                                    }
                                    if let Some(ct) = usage.completion_tokens {
                                        *total_tokens.lock().unwrap() = ct;
                                    }
                                    // Capture cached tokens from prompt_tokens_details
                                    if let Some(details) = &usage.prompt_tokens_details
                                        && details.cached_tokens.is_some()
                                    {
                                        *cache_read_tokens.lock().unwrap() = details.cached_tokens;
                                    }
                                    // Authoritative cost from OpenAI-compatible gateways
                                    // (e.g. OpenRouter `usage.cost`, in USD credits).
                                    if usage.cost.is_some() {
                                        *provider_cost_usd.lock().unwrap() = usage.cost;
                                    }
                                }

                                // Only the first choice of a chunk is processed. The guards
                                // are held across a synchronous call only (no `.await`).
                                if let Some(choice) = chunk.choices.first() {
                                    let mut tt = total_tokens.lock().unwrap();
                                    let mut acc = accumulated_tool_calls.lock().unwrap();
                                    let mut fr = finish_reason.lock().unwrap();
                                    let stream_event =
                                        process_stream_choice(choice, &mut tt, &mut acc, &mut fr);
                                    return vec![Ok(stream_event)];
                                }
                                // A chunk without choices (e.g. one carrying only usage)
                                // is reported as an empty text delta.
                                vec![Ok(LlmStreamEvent::TextDelta(String::new()))]
                            }
                            // Unparseable payloads are surfaced in-band as Error events.
                            Err(e) => vec![Ok(LlmStreamEvent::Error(format!(
                                "Failed to parse chunk: {}",
                                e
                            )))],
                        }
                    }
                })
                .flat_map(futures::stream::iter),
        );

        Ok(converted_stream)
    }
}
553
554impl std::fmt::Debug for OpenAIProtocolLlmDriver {
555    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
556        f.debug_struct("OpenAIProtocolLlmDriver")
557            .field("api_url", &self.api_url)
558            .field("api_key", &"[REDACTED]")
559            .finish()
560    }
561}
562
563// ============================================================================
564// Error Detection Helpers
565// ============================================================================
566
567/// Check if the error indicates the model was not found.
568///
569/// OpenAI returns 404 or 400 with `"model_not_found"` code or `"does not exist"` message.
570/// OpenAI can also return 403 with `"model_not_found"` for tier-gated models — these must
571/// be classified as model_unavailable rather than provider_misconfigured.
572/// Also handles Gemini/OpenAI-compatible endpoints with similar patterns.
573pub fn is_openai_model_not_found(status: reqwest::StatusCode, error_text: &str) -> bool {
574    let error_lower = error_text.to_lowercase();
575
576    // OpenAI can return 404, 400, or 403 (tier-gated access) for nonexistent/inaccessible models
577    if status == reqwest::StatusCode::NOT_FOUND
578        || status == reqwest::StatusCode::BAD_REQUEST
579        || status == reqwest::StatusCode::FORBIDDEN
580    {
581        // OpenAI: {"error":{"code":"model_not_found","message":"The model 'x' does not exist"}}
582        if error_lower.contains("model_not_found") {
583            return true;
584        }
585    }
586
587    // 404 with generic model-not-found patterns
588    if status == reqwest::StatusCode::NOT_FOUND {
589        if error_lower.contains("does not exist") {
590            return true;
591        }
592        if error_lower.contains("model") && error_lower.contains("not found") {
593            return true;
594        }
595    }
596
597    false
598}
599
600/// Check if an OpenAI API error indicates the request is too large.
601///
602/// Detects:
603/// - 429 with "Request too large" or token limit messages
604/// - 400 with "context_length_exceeded" code
605/// - Any message about maximum context length being exceeded
606pub fn is_openai_request_too_large(status: reqwest::StatusCode, error_text: &str) -> bool {
607    let error_lower = error_text.to_lowercase();
608
609    // HTTP 429 with token-related errors
610    if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
611        // "Request too large for gpt-4" pattern
612        if error_lower.contains("request too large") {
613            return true;
614        }
615        // Token limit errors: "tokens per min (TPM): Limit X, Requested Y"
616        if error_lower.contains("tokens") && error_lower.contains("limit") {
617            return true;
618        }
619    }
620
621    // HTTP 400 with context length errors
622    if status == reqwest::StatusCode::BAD_REQUEST {
623        // "context_length_exceeded" error code
624        if error_lower.contains("context_length_exceeded") {
625            return true;
626        }
627        // "maximum context length" message
628        if error_lower.contains("maximum context length") {
629            return true;
630        }
631    }
632
633    // Generic patterns that could appear with various status codes
634    if error_lower.contains("tokens must be reduced")
635        || error_lower.contains("reduce the length")
636        || error_lower.contains("input is too long")
637    {
638        return true;
639    }
640
641    false
642}
643
644// ============================================================================
645// OpenAI API Types
646// ============================================================================
647
/// JSON body sent to the chat completions endpoint by `chat_completion_stream`.
///
/// Every `Option` field is left out of the serialized payload when it is `None`.
#[derive(Debug, Serialize)]
struct OpenAiRequest {
    /// Model identifier, copied from the call configuration.
    model: String,
    /// Conversation in wire format.
    messages: Vec<OpenAiMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    /// Whether to stream the response; this driver always sets it to `true`.
    stream: bool,
    /// Request usage info in streaming response (required for token counts)
    #[serde(skip_serializing_if = "Option::is_none")]
    stream_options: Option<OpenAiStreamOptions>,
    /// Tool definitions offered to the model; omitted when no tools are configured.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<OpenAiTool>>,
    /// Reasoning effort hint; the driver omits it when the configured value is "none".
    #[serde(skip_serializing_if = "Option::is_none")]
    reasoning_effort: Option<String>,
    /// Metadata for tracking API usage (up to 16 key-value pairs).
    /// Useful for correlating requests with session_id, agent_id, org_id, etc.
    #[serde(skip_serializing_if = "Option::is_none")]
    metadata: Option<std::collections::HashMap<String, String>>,
}
669
/// Value of the request's `stream_options` field.
#[derive(Debug, Serialize)]
struct OpenAiStreamOptions {
    /// Asks the API to include token usage in the streamed response; the driver
    /// always sets this to `true`.
    include_usage: bool,
}
674
/// Message content in wire format.
///
/// Untagged: serialized either as a bare JSON string (`Text`) or as an array of typed
/// parts (`Parts`), with no wrapper naming the variant.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
enum OpenAiContent {
    /// Plain text content.
    Text(String),
    /// Multi-part content (text, image and audio parts).
    Parts(Vec<OpenAiContentPart>),
}
681
/// One element of multi-part message content.
///
/// Untagged: each variant serializes as a flat object. The `type` discriminator is an
/// ordinary field that `convert_message` fills in ("text", "image_url", "input_audio").
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
enum OpenAiContentPart {
    /// Text part.
    Text {
        r#type: String,
        text: String,
    },
    /// Image part referenced by URL.
    ImageUrl {
        r#type: String,
        image_url: OpenAiImageUrl,
    },
    /// Audio input part.
    InputAudio {
        r#type: String,
        input_audio: OpenAiInputAudio,
    },
}
698
/// Payload of an `image_url` content part.
#[derive(Debug, Serialize, Deserialize)]
struct OpenAiImageUrl {
    /// Image location, copied verbatim from the source message part.
    url: String,
}
703
/// Payload of an `input_audio` content part.
#[derive(Debug, Serialize, Deserialize)]
struct OpenAiInputAudio {
    /// Audio payload; `convert_message` fills it from the audio part's `url` field.
    data: String,
    /// Audio format label; `convert_message` always sets "wav".
    format: String,
}
709
/// A single conversation message in wire format.
///
/// `None` fields are omitted from the serialized payload.
#[derive(Debug, Serialize, Deserialize)]
struct OpenAiMessage {
    /// Role string: "system", "user", "assistant" or "tool".
    role: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<OpenAiContent>,
    /// Tool calls issued by the message; `convert_message` only sets this for
    /// assistant messages.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<OpenAiToolCall>>,
    /// Id of the tool call this message answers, copied from the source message.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}
720
/// A tool offered to the model in the request's `tools` array.
#[derive(Debug, Serialize, Deserialize)]
struct OpenAiTool {
    /// Tool kind; `convert_tools` always sets "function".
    r#type: String,
    /// The function's name, description and parameter schema.
    function: OpenAiFunction,
}
726
/// Function description nested inside an `OpenAiTool`.
#[derive(Debug, Serialize, Deserialize)]
struct OpenAiFunction {
    name: String,
    description: String,
    /// Parameter description as arbitrary JSON, cloned from the tool definition.
    parameters: Value,
}
733
/// A tool call attached to an assistant message in the request history.
#[derive(Debug, Serialize, Deserialize)]
struct OpenAiToolCall {
    /// Call identifier; tool-role messages refer to it through `tool_call_id`.
    id: String,
    /// Call kind; `convert_message` always sets "function".
    r#type: String,
    function: OpenAiFunctionCall,
}
740
/// Name and arguments of a function tool call.
#[derive(Debug, Serialize, Deserialize)]
struct OpenAiFunctionCall {
    name: String,
    /// Arguments as a JSON-encoded string (not a nested JSON object).
    arguments: String,
}
746
/// One streamed response chunk, deserialized from the `data` payload of an SSE event.
#[derive(Debug, Deserialize)]
#[allow(dead_code)] // id and model are deserialized but used by event listeners, not directly
struct OpenAiStreamChunk {
    /// Unique identifier for this completion
    #[serde(default)]
    id: Option<String>,
    /// Model used for completion (may differ from requested)
    #[serde(default)]
    model: Option<String>,
    /// Choices carried by the chunk. Required: a payload without this field fails to
    /// deserialize. Only the first choice is processed by the driver.
    choices: Vec<OpenAiStreamChoice>,
    /// Token usage, when the chunk carries it.
    #[serde(default)]
    usage: Option<OpenAiUsage>,
}
760
/// Token usage (and optional cost) reported inside a stream chunk.
#[derive(Debug, Deserialize)]
struct OpenAiUsage {
    /// Prompt (input) token count.
    prompt_tokens: Option<u32>,
    /// Completion (output) token count.
    completion_tokens: Option<u32>,
    /// Detailed breakdown of prompt tokens (includes cached tokens)
    #[serde(default)]
    prompt_tokens_details: Option<OpenAiPromptTokensDetails>,
    /// Authoritative per-request cost in USD credits, returned by
    /// OpenAI-compatible gateways such as OpenRouter. Absent for direct OpenAI.
    #[serde(default)]
    cost: Option<f64>,
}
773
/// Breakdown of prompt tokens nested under `usage.prompt_tokens_details`.
#[derive(Debug, Deserialize, Default)]
struct OpenAiPromptTokensDetails {
    /// Number of tokens retrieved from cache
    #[serde(default)]
    cached_tokens: Option<u32>,
}
780
/// One entry of a stream chunk's `choices` array.
#[derive(Debug, Deserialize)]
struct OpenAiStreamChoice {
    /// Incremental payload (text and/or tool-call fragments). Required on the
    /// wire, though it may be an empty object.
    delta: OpenAiDelta,
    /// Finish reason as sent by the provider (e.g. `"stop"`, `"tool_calls"`);
    /// `None` when absent or `null`.
    #[serde(default)]
    finish_reason: Option<String>,
}
787
/// Incremental payload of a stream choice. Both fields are optional, so
/// `{}` deserializes to a delta with nothing set.
#[derive(Debug, Deserialize)]
struct OpenAiDelta {
    /// Text fragment; may be `None` or an empty string.
    #[serde(default)]
    content: Option<String>,
    /// Tool-call fragments carried by this chunk, if any.
    #[serde(default)]
    tool_calls: Option<Vec<OpenAiStreamToolCall>>,
}
795
/// A streamed fragment of a tool call. Fragments that share an `index`
/// belong to the same call and are merged by `process_stream_choice`.
#[derive(Debug, Deserialize)]
struct OpenAiStreamToolCall {
    /// Position of the call this fragment belongs to. Required on the wire:
    /// a fragment without `index` fails to deserialize.
    index: u32,
    /// Call id, when this fragment carries one.
    id: Option<String>,
    /// Function name and/or argument fragment, when present.
    function: Option<OpenAiStreamFunction>,
}
802
/// Function part of a streamed tool-call fragment.
#[derive(Debug, Deserialize)]
struct OpenAiStreamFunction {
    /// Function name, when this fragment carries it.
    name: Option<String>,
    /// Piece of the argument string; pieces are concatenated in arrival order.
    arguments: Option<String>,
}
808
809/// Parses each accumulated tool call's argument string (assembled from streamed
810/// fragments) into JSON, falling back to an empty object on parse failure.
811fn finalize_tool_calls(tool_calls: Vec<ToolCall>) -> Vec<ToolCall> {
812    tool_calls
813        .into_iter()
814        .map(|mut tc| {
815            if let Some(args_str) = tc.arguments.as_str() {
816                tc.arguments = serde_json::from_str(args_str).unwrap_or(json!({}));
817            }
818            tc
819        })
820        .collect()
821}
822
823/// Drains tool calls that were accumulated but not yet emitted, returning a
824/// final `ToolCalls` event for the `[DONE]` handler. Returns `None` when nothing
825/// is pending (the common case, since the finish chunk normally drains them).
826fn take_pending_tool_calls(accumulated_tool_calls: &mut Vec<ToolCall>) -> Option<LlmStreamEvent> {
827    if accumulated_tool_calls.is_empty() {
828        return None;
829    }
830    let calls = std::mem::take(accumulated_tool_calls);
831    Some(LlmStreamEvent::ToolCalls(finalize_tool_calls(calls)))
832}
833
834/// Processes a single chat-completion stream choice, updating the running
835/// accumulators and returning the event to emit.
836///
837/// EVE-522: some OpenAI-compatible providers (OpenRouter/DeepInfra) send an
838/// empty `content: ""` delta in the *same* chunk that carries
839/// `finish_reason: "tool_calls"`. The content branch must therefore ignore
840/// empty content, otherwise it short-circuits before the finish handler and the
841/// accumulated tool calls are silently dropped. Emitting drains the accumulator
842/// so a repeated finish chunk does not re-emit the same calls.
843fn process_stream_choice(
844    choice: &OpenAiStreamChoice,
845    total_tokens: &mut u32,
846    accumulated_tool_calls: &mut Vec<ToolCall>,
847    finish_reason: &mut Option<String>,
848) -> LlmStreamEvent {
849    // Accumulate streamed tool-call fragments.
850    if let Some(tool_calls) = &choice.delta.tool_calls {
851        for tc in tool_calls {
852            let idx = tc.index as usize;
853            while accumulated_tool_calls.len() <= idx {
854                accumulated_tool_calls.push(ToolCall {
855                    id: String::new(),
856                    name: String::new(),
857                    arguments: json!(""),
858                });
859            }
860
861            if let Some(id) = &tc.id {
862                accumulated_tool_calls[idx].id = id.clone();
863            }
864            if let Some(function) = &tc.function {
865                if let Some(name) = &function.name {
866                    accumulated_tool_calls[idx].name = name.clone();
867                }
868                if let Some(args) = &function.arguments {
869                    let current = accumulated_tool_calls[idx].arguments.as_str().unwrap_or("");
870                    let combined = format!("{}{}", current, args);
871                    accumulated_tool_calls[idx].arguments = json!(combined);
872                }
873            }
874        }
875        return LlmStreamEvent::TextDelta(String::new());
876    }
877
878    // Content delta. Guard on non-empty: an empty-content delta that rides along
879    // with finish_reason must not short-circuit the finish handler below.
880    if let Some(content) = &choice.delta.content
881        && !content.is_empty()
882    {
883        *total_tokens += 1;
884        return LlmStreamEvent::TextDelta(content.clone());
885    }
886
887    // Finish reason. Store it for the [DONE] handler; for tool_calls, emit the
888    // accumulated calls immediately so the agent can start working. Draining the
889    // accumulator prevents a second finish chunk from re-emitting the calls.
890    if let Some(fr) = &choice.finish_reason {
891        *finish_reason = Some(fr.clone());
892
893        if fr == "tool_calls" && !accumulated_tool_calls.is_empty() {
894            let calls = std::mem::take(accumulated_tool_calls);
895            return LlmStreamEvent::ToolCalls(finalize_tool_calls(calls));
896        }
897    }
898
899    LlmStreamEvent::TextDelta(String::new())
900}
901
902// ============================================================================
903// Tests
904// ============================================================================
905
// Unit tests for the wire types, error classifiers and stream-chunk handling
// defined in this file. No test performs network I/O.
#[cfg(test)]
mod tests {
    use super::*;

    /// A driver built from just an API key has the expected `Debug` name.
    #[test]
    fn test_driver_with_api_key() {
        let driver = OpenAIProtocolLlmDriver::new("test-key");
        assert!(format!("{:?}", driver).contains("OpenAIProtocolLlmDriver"));
    }

    /// A custom base URL is stored verbatim and returned by `api_url()`.
    #[test]
    fn test_driver_with_base_url() {
        let driver = OpenAIProtocolLlmDriver::with_base_url(
            "test-key",
            "https://custom.api.com/v1/completions",
        );
        assert!(format!("{:?}", driver).contains("OpenAIProtocolLlmDriver"));
        assert_eq!(driver.api_url(), "https://custom.api.com/v1/completions");
    }

    /// Both Azure host families are detected; the public OpenAI host is not.
    #[test]
    fn test_is_azure_openai_api_url() {
        assert!(is_azure_openai_api_url(
            "https://example.openai.azure.com/openai/v1/chat/completions"
        ));
        assert!(is_azure_openai_api_url(
            "https://example.services.ai.azure.com/openai/v1/responses"
        ));
        assert!(!is_azure_openai_api_url(
            "https://api.openai.com/v1/chat/completions"
        ));
    }

    /// `stream_options.include_usage` is serialized when set.
    #[test]
    fn test_request_includes_stream_options_for_usage() {
        // OpenAI streaming API requires stream_options.include_usage=true
        // to return token usage in the response
        let request = OpenAiRequest {
            model: "gpt-4o".to_string(),
            messages: vec![OpenAiMessage {
                role: "user".to_string(),
                content: Some(OpenAiContent::Text("Hello".to_string())),
                tool_calls: None,
                tool_call_id: None,
            }],
            temperature: None,
            max_tokens: None,
            stream: true,
            stream_options: Some(OpenAiStreamOptions {
                include_usage: true,
            }),
            tools: None,
            reasoning_effort: None,
            metadata: None,
        };

        let json = serde_json::to_value(&request).unwrap();
        assert_eq!(json["stream"], true);
        assert_eq!(json["stream_options"]["include_usage"], true);
    }

    /// Request metadata entries are serialized as a flat string map.
    #[test]
    fn test_request_includes_metadata() {
        // Metadata should be included when provided
        let mut metadata = std::collections::HashMap::new();
        metadata.insert("session_id".to_string(), "session_abc123".to_string());
        metadata.insert("agent_id".to_string(), "agent_xyz789".to_string());

        let request = OpenAiRequest {
            model: "gpt-4o".to_string(),
            messages: vec![OpenAiMessage {
                role: "user".to_string(),
                content: Some(OpenAiContent::Text("Hello".to_string())),
                tool_calls: None,
                tool_call_id: None,
            }],
            temperature: None,
            max_tokens: None,
            stream: true,
            stream_options: None,
            tools: None,
            reasoning_effort: None,
            metadata: Some(metadata),
        };

        let json = serde_json::to_value(&request).unwrap();
        assert_eq!(json["metadata"]["session_id"], "session_abc123");
        assert_eq!(json["metadata"]["agent_id"], "agent_xyz789");
    }

    /// A usage-only chunk (empty `choices`) parses, and unknown keys such as
    /// `object`, `created` and `total_tokens` are ignored.
    #[test]
    fn test_usage_chunk_parsing() {
        // OpenAI sends usage in a separate chunk after finish_reason
        // This test verifies we can parse it correctly
        let usage_chunk = r#"{
            "id": "chatcmpl-123",
            "object": "chat.completion.chunk",
            "created": 1234567890,
            "model": "gpt-4o",
            "choices": [],
            "usage": {
                "prompt_tokens": 150,
                "completion_tokens": 42,
                "total_tokens": 192
            }
        }"#;

        let chunk: OpenAiStreamChunk = serde_json::from_str(usage_chunk).unwrap();
        assert!(chunk.usage.is_some());
        let usage = chunk.usage.unwrap();
        assert_eq!(usage.prompt_tokens, Some(150));
        assert_eq!(usage.completion_tokens, Some(42));
    }

    /// `prompt_tokens_details.cached_tokens` is picked up when present.
    #[test]
    fn test_usage_chunk_with_cached_tokens() {
        // OpenAI includes cached_tokens in prompt_tokens_details
        let usage_chunk = r#"{
            "id": "chatcmpl-123",
            "choices": [],
            "usage": {
                "prompt_tokens": 150,
                "completion_tokens": 42,
                "prompt_tokens_details": {
                    "cached_tokens": 100
                }
            }
        }"#;

        let chunk: OpenAiStreamChunk = serde_json::from_str(usage_chunk).unwrap();
        let usage = chunk.usage.unwrap();
        assert_eq!(usage.prompt_tokens, Some(150));
        assert_eq!(usage.completion_tokens, Some(42));
        assert!(usage.prompt_tokens_details.is_some());
        assert_eq!(
            usage.prompt_tokens_details.unwrap().cached_tokens,
            Some(100)
        );
    }

    /// The gateway-provided `usage.cost` is parsed as a float.
    #[test]
    fn test_usage_chunk_with_openrouter_cost() {
        // OpenAI-compatible gateways like OpenRouter add `usage.cost` (USD credits).
        let usage_chunk = r#"{
            "id": "gen-123",
            "choices": [],
            "usage": {
                "prompt_tokens": 194,
                "completion_tokens": 2,
                "total_tokens": 196,
                "cost": 0.00095
            }
        }"#;

        let chunk: OpenAiStreamChunk = serde_json::from_str(usage_chunk).unwrap();
        let usage = chunk.usage.unwrap();
        assert_eq!(usage.cost, Some(0.00095));
    }

    /// A usage object without `cost` still parses, with `cost == None`.
    #[test]
    fn test_usage_chunk_without_cost_defaults_none() {
        // Direct OpenAI omits `cost`; it must deserialize to None, not error.
        let usage_chunk = r#"{
            "id": "chatcmpl-123",
            "choices": [],
            "usage": { "prompt_tokens": 10, "completion_tokens": 5 }
        }"#;

        let chunk: OpenAiStreamChunk = serde_json::from_str(usage_chunk).unwrap();
        assert_eq!(chunk.usage.unwrap().cost, None);
    }

    /// A finish chunk with an empty `delta` parses and exposes `finish_reason`.
    #[test]
    fn test_finish_reason_chunk_parsing() {
        // Finish reason comes in a chunk BEFORE the usage chunk
        let finish_chunk = r#"{
            "id": "chatcmpl-123",
            "choices": [{
                "index": 0,
                "delta": {},
                "finish_reason": "stop"
            }]
        }"#;

        let chunk: OpenAiStreamChunk = serde_json::from_str(finish_chunk).unwrap();
        assert!(chunk.usage.is_none()); // No usage in finish_reason chunk
        assert_eq!(chunk.choices.len(), 1);
        assert_eq!(chunk.choices[0].finish_reason, Some("stop".to_string()));
    }

    // ========================================================================
    // Request-too-large detection tests
    // ========================================================================

    /// 429 whose message says "Request too large" is classified as too large.
    #[test]
    fn test_is_openai_request_too_large_429_request_too_large() {
        let error = r#"{"error":{"message":"Request too large for gpt-4o in organization org-xxx on tokens per min (TPM): Limit 500000, Requested 538772."}}"#;
        assert!(is_openai_request_too_large(
            reqwest::StatusCode::TOO_MANY_REQUESTS,
            error
        ));
    }

    /// 429 that only mentions the TPM limit is also classified as too large.
    #[test]
    fn test_is_openai_request_too_large_429_token_limit() {
        let error =
            r#"{"error":{"message":"tokens per min (TPM): Limit 500000, Requested 600000"}}"#;
        assert!(is_openai_request_too_large(
            reqwest::StatusCode::TOO_MANY_REQUESTS,
            error
        ));
    }

    /// 400 with the `context_length_exceeded` code is classified as too large.
    #[test]
    fn test_is_openai_request_too_large_400_context_length() {
        let error = r#"{"error":{"code":"context_length_exceeded","message":"This model's maximum context length is 128000 tokens."}}"#;
        assert!(is_openai_request_too_large(
            reqwest::StatusCode::BAD_REQUEST,
            error
        ));
    }

    /// 400 matched by the "maximum context length" message alone (no code).
    #[test]
    fn test_is_openai_request_too_large_400_max_context() {
        let error =
            r#"{"error":{"message":"This model's maximum context length is 128000 tokens"}}"#;
        assert!(is_openai_request_too_large(
            reqwest::StatusCode::BAD_REQUEST,
            error
        ));
    }

    /// 400 matched by the "tokens must be reduced" message.
    #[test]
    fn test_is_openai_request_too_large_tokens_must_be_reduced() {
        let error = r#"{"error":{"message":"The input or output tokens must be reduced"}}"#;
        assert!(is_openai_request_too_large(
            reqwest::StatusCode::BAD_REQUEST,
            error
        ));
    }

    /// Unrelated 429/500/400 errors must not be classified as too large.
    #[test]
    fn test_is_openai_request_too_large_false_for_other_errors() {
        // Regular rate limit (not token-related)
        let error = r#"{"error":{"message":"Rate limit exceeded: too many requests per minute"}}"#;
        assert!(!is_openai_request_too_large(
            reqwest::StatusCode::TOO_MANY_REQUESTS,
            error
        ));

        // Internal server error
        let error = r#"{"error":{"message":"Internal server error"}}"#;
        assert!(!is_openai_request_too_large(
            reqwest::StatusCode::INTERNAL_SERVER_ERROR,
            error
        ));

        // Generic 400 error
        let error = r#"{"error":{"message":"Invalid request"}}"#;
        assert!(!is_openai_request_too_large(
            reqwest::StatusCode::BAD_REQUEST,
            error
        ));
    }

    // ========================================================================
    // Model-not-found detection tests
    // ========================================================================

    /// 404 with the `model_not_found` code is classified as model-not-found.
    #[test]
    fn test_is_openai_model_not_found_real_error() {
        // Real OpenAI 404 response for nonexistent model
        let error = r#"{"error":{"code":"model_not_found","message":"The model 'gpt-99' does not exist or you do not have access to it.","type":"invalid_request_error","param":null}}"#;
        assert!(is_openai_model_not_found(
            reqwest::StatusCode::NOT_FOUND,
            error
        ));
    }

    /// 404 matched by a "does not exist" message without an error code.
    #[test]
    fn test_is_openai_model_not_found_does_not_exist() {
        let error = r#"{"error":{"message":"The model 'fake-model' does not exist"}}"#;
        assert!(is_openai_model_not_found(
            reqwest::StatusCode::NOT_FOUND,
            error
        ));
    }

    /// 404 matched by a generic "Model not found" message.
    #[test]
    fn test_is_openai_model_not_found_generic_not_found() {
        let error = r#"{"error":{"message":"Model not found"}}"#;
        assert!(is_openai_model_not_found(
            reqwest::StatusCode::NOT_FOUND,
            error
        ));
    }

    /// 400 carrying the `model_not_found` code is still model-not-found.
    #[test]
    fn test_is_openai_model_not_found_400_with_model_not_found_code() {
        // OpenAI Responses API returns 400 (not 404) for nonexistent models
        let error = r#"{"error":{"code":"model_not_found","message":"The requested model 'gpt-99' does not exist.","type":"invalid_request_error","param":"model"}}"#;
        assert!(is_openai_model_not_found(
            reqwest::StatusCode::BAD_REQUEST,
            error
        ));
    }

    /// 400 with some other error code is not model-not-found.
    #[test]
    fn test_is_openai_model_not_found_false_for_non_model_error() {
        // 400 without model_not_found code should not match
        let error = r#"{"error":{"code":"invalid_request","message":"Some other error"}}"#;
        assert!(!is_openai_model_not_found(
            reqwest::StatusCode::BAD_REQUEST,
            error
        ));
    }

    /// 404 whose message is not about a model is not model-not-found.
    #[test]
    fn test_is_openai_model_not_found_false_for_other_404() {
        // 404 without model-related message
        let error = r#"{"error":{"message":"Endpoint not found"}}"#;
        assert!(!is_openai_model_not_found(
            reqwest::StatusCode::NOT_FOUND,
            error
        ));
    }

    /// 403 carrying the `model_not_found` code is model-not-found.
    #[test]
    fn test_is_openai_model_not_found_403_tier_gated_model() {
        // OpenAI returns 403 for models that exist but require a higher API tier;
        // these must classify as model_unavailable, not provider_misconfigured.
        let error = r#"{"error":{"code":"model_not_found","message":"The model 'gpt-5.4-mini' does not exist or you do not have access to it.","type":"invalid_request_error","param":null}}"#;
        assert!(is_openai_model_not_found(
            reqwest::StatusCode::FORBIDDEN,
            error
        ));
    }

    /// 403 without the `model_not_found` code stays an auth error.
    #[test]
    fn test_is_openai_model_not_found_403_plain_auth_error_is_not_model_not_found() {
        // A plain 403 without model_not_found code is a real auth error and must
        // NOT be classified as model_unavailable.
        let error = r#"{"error":{"message":"Invalid authentication credentials","type":"authentication_error"}}"#;
        assert!(!is_openai_model_not_found(
            reqwest::StatusCode::FORBIDDEN,
            error
        ));
    }

    // ========================================================================
    // Reasoning effort guard tests
    // ========================================================================

    /// A `None` `reasoning_effort` is left out of the serialized request.
    ///
    /// NOTE(review): the "none" filter is applied inline here, so this test
    /// does not exercise the driver's own request-building code.
    #[test]
    fn test_reasoning_effort_none_is_omitted() {
        // When reasoning_effort is "none", it should be filtered out
        // to avoid "Unrecognized request argument" errors on non-thinking models
        let request = OpenAiRequest {
            model: "gpt-4o-mini".to_string(),
            messages: vec![OpenAiMessage {
                role: "user".to_string(),
                content: Some(OpenAiContent::Text("Hello".to_string())),
                tool_calls: None,
                tool_call_id: None,
            }],
            temperature: None,
            max_tokens: None,
            stream: true,
            stream_options: None,
            tools: None,
            reasoning_effort: Some("none".to_string())
                .as_ref()
                .filter(|e| !e.eq_ignore_ascii_case("none"))
                .cloned(),
            metadata: None,
        };

        let json = serde_json::to_value(&request).unwrap();
        assert!(
            json.get("reasoning_effort").is_none(),
            "reasoning_effort should be omitted when effort is 'none'"
        );
    }

    /// A real effort level passes the same filter and is serialized.
    #[test]
    fn test_reasoning_effort_high_is_included() {
        let request = OpenAiRequest {
            model: "o3-mini".to_string(),
            messages: vec![OpenAiMessage {
                role: "user".to_string(),
                content: Some(OpenAiContent::Text("Hello".to_string())),
                tool_calls: None,
                tool_call_id: None,
            }],
            temperature: None,
            max_tokens: None,
            stream: true,
            stream_options: None,
            tools: None,
            reasoning_effort: Some("high".to_string())
                .as_ref()
                .filter(|e| !e.eq_ignore_ascii_case("none"))
                .cloned(),
            metadata: None,
        };

        let json = serde_json::to_value(&request).unwrap();
        assert_eq!(json["reasoning_effort"], "high");
    }

    // ------------------------------------------------------------------
    // EVE-522: streaming chunk handling (process_stream_choice)
    // ------------------------------------------------------------------

    /// Test helper: parses one `choices[]` entry from its JSON text.
    /// Panics on malformed JSON, which is acceptable in tests.
    fn choice(json_str: &str) -> OpenAiStreamChoice {
        serde_json::from_str(json_str).unwrap()
    }

    /// EVE-522 regression: providers such as OpenRouter/DeepInfra send an empty
    /// `content: ""` in the same chunk that carries `finish_reason: "tool_calls"`.
    /// The accumulated tool calls must still be emitted exactly once.
    #[test]
    fn test_empty_content_finish_chunk_still_emits_tool_calls() {
        let mut total_tokens = 0u32;
        let mut acc: Vec<ToolCall> = Vec::new();
        let mut finish_reason: Option<String> = None;

        // Chunk 2: tool_calls delta opens the call (id + name).
        let e = process_stream_choice(
            &choice(
                r#"{"delta":{"content":null,"tool_calls":[{"index":0,"id":"call_1","function":{"name":"read_file","arguments":""}}]},"finish_reason":null}"#,
            ),
            &mut total_tokens,
            &mut acc,
            &mut finish_reason,
        );
        assert!(matches!(e, LlmStreamEvent::TextDelta(s) if s.is_empty()));

        // Chunk 3: tool_calls delta streams the arguments.
        let e = process_stream_choice(
            &choice(
                r#"{"delta":{"content":null,"tool_calls":[{"index":0,"function":{"arguments":"{\"path\":\"Cargo.toml\"}"}}]},"finish_reason":null}"#,
            ),
            &mut total_tokens,
            &mut acc,
            &mut finish_reason,
        );
        assert!(matches!(e, LlmStreamEvent::TextDelta(s) if s.is_empty()));

        // Chunk 4: content:"" alongside finish_reason:"tool_calls" — must NOT
        // short-circuit; emits the accumulated call with parsed JSON arguments.
        let e = process_stream_choice(
            &choice(r#"{"delta":{"content":""},"finish_reason":"tool_calls"}"#),
            &mut total_tokens,
            &mut acc,
            &mut finish_reason,
        );
        match e {
            LlmStreamEvent::ToolCalls(calls) => {
                assert_eq!(calls.len(), 1);
                assert_eq!(calls[0].id, "call_1");
                assert_eq!(calls[0].name, "read_file");
                assert_eq!(calls[0].arguments, json!({"path": "Cargo.toml"}));
            }
            other => panic!("expected ToolCalls, got {:?}", other),
        }
        assert_eq!(finish_reason.as_deref(), Some("tool_calls"));

        // Chunk 5: second finish chunk with content:"" — the accumulator was
        // drained, so the same call must not be emitted again.
        let e = process_stream_choice(
            &choice(r#"{"delta":{"content":""},"finish_reason":"tool_calls"}"#),
            &mut total_tokens,
            &mut acc,
            &mut finish_reason,
        );
        assert!(
            matches!(e, LlmStreamEvent::TextDelta(s) if s.is_empty()),
            "tool calls must only be emitted once"
        );
    }

    /// Non-empty content deltas are still emitted and counted as output tokens.
    #[test]
    fn test_non_empty_content_is_emitted() {
        let mut total_tokens = 0u32;
        let mut acc: Vec<ToolCall> = Vec::new();
        let mut finish_reason: Option<String> = None;

        let e = process_stream_choice(
            &choice(r#"{"delta":{"content":"hello"},"finish_reason":null}"#),
            &mut total_tokens,
            &mut acc,
            &mut finish_reason,
        );
        assert!(matches!(e, LlmStreamEvent::TextDelta(s) if s == "hello"));
        assert_eq!(total_tokens, 1);
    }

    /// OpenAI's native path sends `delta: {}` (no content key) in the finish
    /// chunk; the existing behavior of emitting tool calls there is preserved.
    #[test]
    fn test_finish_chunk_without_content_emits_tool_calls() {
        let mut total_tokens = 0u32;
        let mut acc: Vec<ToolCall> = Vec::new();
        let mut finish_reason: Option<String> = None;

        // The whole call arrives in one fragment; the returned event is not
        // interesting here, only the accumulator side effect.
        process_stream_choice(
            &choice(
                r#"{"delta":{"tool_calls":[{"index":0,"id":"call_9","function":{"name":"list_dir","arguments":"{}"}}]},"finish_reason":null}"#,
            ),
            &mut total_tokens,
            &mut acc,
            &mut finish_reason,
        );

        let e = process_stream_choice(
            &choice(r#"{"delta":{},"finish_reason":"tool_calls"}"#),
            &mut total_tokens,
            &mut acc,
            &mut finish_reason,
        );
        match e {
            LlmStreamEvent::ToolCalls(calls) => {
                assert_eq!(calls.len(), 1);
                assert_eq!(calls[0].name, "list_dir");
            }
            other => panic!("expected ToolCalls, got {:?}", other),
        }
    }

    /// The [DONE] fallback flushes accumulated-but-unemitted tool calls and
    /// drains the accumulator; once drained it returns None.
    #[test]
    fn test_take_pending_tool_calls_flushes_then_drains() {
        // `json!` of a string literal yields a JSON *string*, i.e. the same
        // shape the accumulator holds while arguments are still streaming.
        let mut acc = vec![ToolCall {
            id: "call_1".to_string(),
            name: "read_file".to_string(),
            arguments: json!(r#"{"path":"Cargo.toml"}"#),
        }];

        match take_pending_tool_calls(&mut acc) {
            Some(LlmStreamEvent::ToolCalls(calls)) => {
                assert_eq!(calls.len(), 1);
                assert_eq!(calls[0].name, "read_file");
                assert_eq!(calls[0].arguments, json!({"path": "Cargo.toml"}));
            }
            other => panic!("expected ToolCalls, got {:?}", other),
        }
        assert!(acc.is_empty(), "accumulator must be drained after flush");
        assert!(take_pending_tool_calls(&mut acc).is_none());
    }

    /// String-encoded arguments are parsed into a JSON object on finalize.
    #[test]
    fn test_finalize_tool_calls_parses_arguments() {
        let calls = vec![ToolCall {
            id: "call_1".to_string(),
            name: "read_file".to_string(),
            arguments: json!(r#"{"path":"src/main.rs"}"#),
        }];
        let finalized = finalize_tool_calls(calls);
        assert_eq!(finalized[0].arguments, json!({"path": "src/main.rs"}));
    }

    /// A tool result with no assistant tool call of the same id is dropped.
    #[test]
    fn drop_orphaned_tool_messages_removes_unmatched_tool_results() {
        use crate::llm_driver_registry::LlmMessageContent;

        let messages = vec![
            LlmMessage::text(LlmMessageRole::User, "hello"),
            LlmMessage {
                role: LlmMessageRole::Tool,
                content: LlmMessageContent::Text("result".to_string()),
                tool_calls: None,
                tool_call_id: Some("call_trimmed".to_string()),
                phase: None,
                thinking: None,
                thinking_signature: None,
            },
        ];
        let filtered = drop_orphaned_tool_messages(&messages);
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].role, LlmMessageRole::User);
    }

    /// A tool result whose id matches a preceding assistant tool call is kept.
    #[test]
    fn drop_orphaned_tool_messages_keeps_matched_tool_results() {
        use crate::llm_driver_registry::LlmMessageContent;
        use crate::tool_types::ToolCall;

        let messages = vec![
            LlmMessage {
                role: LlmMessageRole::Assistant,
                content: LlmMessageContent::Text(String::new()),
                tool_calls: Some(vec![ToolCall {
                    id: "call_1".to_string(),
                    name: "read_file".to_string(),
                    arguments: json!({}),
                }]),
                tool_call_id: None,
                phase: None,
                thinking: None,
                thinking_signature: None,
            },
            LlmMessage {
                role: LlmMessageRole::Tool,
                content: LlmMessageContent::Text("file content".to_string()),
                tool_calls: None,
                tool_call_id: Some("call_1".to_string()),
                phase: None,
                thinking: None,
                thinking_signature: None,
            },
        ];
        let filtered = drop_orphaned_tool_messages(&messages);
        assert_eq!(filtered.len(), 2);
    }
}