// pi/providers/openai.rs
1//! OpenAI Chat Completions API provider implementation.
2//!
3//! This module implements the Provider trait for the OpenAI Chat Completions API,
4//! supporting streaming responses and tool use. Compatible with:
5//! - OpenAI direct API (api.openai.com)
6//! - Azure OpenAI
7//! - Any OpenAI-compatible API (Groq, Together, etc.)
8
9use std::borrow::Cow;
10
11use crate::error::{Error, Result};
12use crate::http::client::Client;
13use crate::model::{
14    AssistantMessage, ContentBlock, Message, StopReason, StreamEvent, TextContent, ThinkingContent,
15    ToolCall, Usage, UserContent,
16};
17use crate::models::CompatConfig;
18use crate::provider::{Context, Provider, StreamOptions, ToolDef};
19use crate::sse::SseStream;
20use async_trait::async_trait;
21use futures::StreamExt;
22use futures::stream::{self, Stream};
23use serde::{Deserialize, Serialize};
24use std::collections::VecDeque;
25use std::pin::Pin;
26
27// ============================================================================
28// Constants
29// ============================================================================
30
/// Default endpoint for the OpenAI Chat Completions API.
const OPENAI_API_URL: &str = "https://api.openai.com/v1/chat/completions";
/// Token cap applied when `StreamOptions::max_tokens` is unset.
const DEFAULT_MAX_TOKENS: u32 = 4096;
/// Fallback `HTTP-Referer` header for OpenRouter (overridable via env).
const OPENROUTER_DEFAULT_HTTP_REFERER: &str = "https://github.com/Dicklesworthstone/pi_agent_rust";
/// Fallback `X-Title` header for OpenRouter (overridable via env).
const OPENROUTER_DEFAULT_X_TITLE: &str = "Pi Agent Rust";
35
/// Map a role string (which may come from compat config at runtime) to a
/// `Cow<str>` suitable for the message `role` field.
///
/// The returned `Cow` borrows the input: its lifetime is tied to the input
/// string, so no allocation is ever required — not even for exotic role
/// names supplied via compat overrides.
fn to_cow_role(role: &str) -> Cow<'_, str> {
    // The previous implementation matched the well-known OpenAI role names
    // and allocated an owned copy for anything else. Since the return
    // lifetime already equals the input lifetime, borrowing is always
    // correct and the match/allocation were dead weight. (The old doc
    // comment claimed unknown roles were leaked to `&'static str`; no leak
    // ever happened, nor is one needed.)
    Cow::Borrowed(role)
}
54
/// Return `true` when `headers` contains (case-insensitively) any of the
/// header names listed in `names`.
fn map_has_any_header(headers: &std::collections::HashMap<String, String>, names: &[&str]) -> bool {
    for key in headers.keys() {
        for name in names {
            if key.eq_ignore_ascii_case(name) {
                return true;
            }
        }
    }
    false
}
60
/// Look up each environment variable in `keys`, in order, and return the
/// first value that is non-empty after trimming surrounding whitespace.
fn first_non_empty_env(keys: &[&str]) -> Option<String> {
    for key in keys {
        if let Ok(raw) = std::env::var(key) {
            let trimmed = raw.trim();
            if !trimmed.is_empty() {
                return Some(trimmed.to_string());
            }
        }
    }
    None
}
69
70fn openrouter_default_http_referer() -> String {
71    first_non_empty_env(&["OPENROUTER_HTTP_REFERER", "PI_OPENROUTER_HTTP_REFERER"])
72        .unwrap_or_else(|| OPENROUTER_DEFAULT_HTTP_REFERER.to_string())
73}
74
75fn openrouter_default_x_title() -> String {
76    first_non_empty_env(&["OPENROUTER_X_TITLE", "PI_OPENROUTER_X_TITLE"])
77        .unwrap_or_else(|| OPENROUTER_DEFAULT_X_TITLE.to_string())
78}
79
80// ============================================================================
81// OpenAI Provider
82// ============================================================================
83
/// OpenAI Chat Completions API provider.
pub struct OpenAIProvider {
    /// HTTP client used for all requests (replaceable for tests/VCR).
    client: Client,
    /// Model identifier sent in the request body.
    model: String,
    /// Endpoint URL; defaults to the official OpenAI completions endpoint.
    base_url: String,
    /// Provider name surfaced in streamed events (defaults to "openai").
    provider: String,
    /// Optional per-provider compatibility overrides.
    compat: Option<CompatConfig>,
}
92
impl OpenAIProvider {
    /// Create a new OpenAI provider.
    pub fn new(model: impl Into<String>) -> Self {
        Self {
            client: Client::new(),
            model: model.into(),
            base_url: OPENAI_API_URL.to_string(),
            provider: "openai".to_string(),
            compat: None,
        }
    }

    /// Override the provider name reported in streamed events.
    ///
    /// This is useful for OpenAI-compatible backends (Groq, Together, etc.) that use this
    /// implementation but should still surface their own provider identifier in session logs.
    #[must_use]
    pub fn with_provider_name(mut self, provider: impl Into<String>) -> Self {
        self.provider = provider.into();
        self
    }

    /// Create with a custom base URL (for Azure, Groq, etc.).
    #[must_use]
    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = base_url.into();
        self
    }

    /// Create with a custom HTTP client (VCR, test harness, etc.).
    #[must_use]
    pub fn with_client(mut self, client: Client) -> Self {
        self.client = client;
        self
    }

    /// Attach provider-specific compatibility overrides.
    ///
    /// Overrides are applied during request building (field names, headers,
    /// capability flags) and response parsing (stop-reason mapping).
    #[must_use]
    pub fn with_compat(mut self, compat: Option<CompatConfig>) -> Self {
        self.compat = compat;
        self
    }

    /// Build the request body for the OpenAI API.
    ///
    /// Borrows from both `self` and `context`; the returned request is
    /// serialized (and possibly further patched) by `build_request_json`.
    pub fn build_request<'a>(
        &'a self,
        context: &'a Context<'_>,
        options: &StreamOptions,
    ) -> OpenAIRequest<'a> {
        // Compat config may rename the system role (e.g. "developer").
        let system_role = self
            .compat
            .as_ref()
            .and_then(|c| c.system_role_name.as_deref())
            .unwrap_or("system");
        let messages = Self::build_messages_with_role(context, system_role);

        // Tool definitions are omitted entirely when the backend does not
        // support them (or none were supplied).
        let tools_supported = self
            .compat
            .as_ref()
            .and_then(|c| c.supports_tools)
            .unwrap_or(true);

        let tools: Option<Vec<OpenAITool<'a>>> = if context.tools.is_empty() || !tools_supported {
            None
        } else {
            Some(context.tools.iter().map(convert_tool_to_openai).collect())
        };

        // Determine which max-tokens field to populate based on compat config.
        let use_alt_field = self
            .compat
            .as_ref()
            .and_then(|c| c.max_tokens_field.as_deref())
            .is_some_and(|f| f == "max_completion_tokens");

        // Exactly one of the two token-limit fields ends up populated; the
        // other is None and skipped during serialization.
        let token_limit = options.max_tokens.or(Some(DEFAULT_MAX_TOKENS));
        let (max_tokens, max_completion_tokens) = if use_alt_field {
            (None, token_limit)
        } else {
            (token_limit, None)
        };

        // Ask for usage in the stream unless the backend can't handle it.
        let include_usage = self
            .compat
            .as_ref()
            .and_then(|c| c.supports_usage_in_streaming)
            .unwrap_or(true);

        let stream_options = Some(OpenAIStreamOptions { include_usage });

        OpenAIRequest {
            model: &self.model,
            messages,
            max_tokens,
            max_completion_tokens,
            temperature: options.temperature,
            tools,
            stream: true,
            stream_options,
        }
    }

    /// Serialize the built request to JSON and apply OpenRouter routing
    /// overrides when applicable.
    ///
    /// # Errors
    /// Fails if serialization fails or routing overrides are malformed.
    fn build_request_json(
        &self,
        context: &Context<'_>,
        options: &StreamOptions,
    ) -> Result<serde_json::Value> {
        let request = self.build_request(context, options);
        let mut value = serde_json::to_value(request)
            .map_err(|e| Error::api(format!("Failed to serialize OpenAI request: {e}")))?;
        self.apply_openrouter_routing_overrides(&mut value)?;
        Ok(value)
    }

    /// Merge `openRouterRouting` keys from the compat config into the
    /// serialized request body. No-op for every provider but OpenRouter.
    fn apply_openrouter_routing_overrides(&self, request: &mut serde_json::Value) -> Result<()> {
        if !self.provider.eq_ignore_ascii_case("openrouter") {
            return Ok(());
        }

        let Some(routing) = self
            .compat
            .as_ref()
            .and_then(|compat| compat.open_router_routing.as_ref())
        else {
            return Ok(());
        };

        let Some(request_obj) = request.as_object_mut() else {
            return Err(Error::api(
                "OpenAI request body must serialize to a JSON object",
            ));
        };
        let Some(routing_obj) = routing.as_object() else {
            return Err(Error::config(
                "openRouterRouting must be a JSON object when configured",
            ));
        };

        // Routing keys overwrite identically-named top-level request keys.
        for (key, value) in routing_obj {
            request_obj.insert(key.clone(), value.clone());
        }
        Ok(())
    }

    /// Build the messages array with system prompt prepended using the given role name.
    fn build_messages_with_role<'a>(
        context: &'a Context<'_>,
        system_role: &'a str,
    ) -> Vec<OpenAIMessage<'a>> {
        let mut messages = Vec::with_capacity(context.messages.len() + 1);

        // Add system prompt as first message
        if let Some(system) = &context.system_prompt {
            messages.push(OpenAIMessage {
                role: to_cow_role(system_role),
                content: Some(OpenAIContent::Text(Cow::Borrowed(system))),
                tool_calls: None,
                tool_call_id: None,
            });
        }

        // Convert conversation messages
        for message in context.messages.iter() {
            messages.extend(convert_message_to_openai(message));
        }

        messages
    }
}
265
#[async_trait]
impl Provider for OpenAIProvider {
    /// Provider name as configured (may differ from "openai" for
    /// OpenAI-compatible backends).
    fn name(&self) -> &str {
        &self.provider
    }

    /// API family identifier; fixed for this implementation.
    fn api(&self) -> &'static str {
        "openai-completions"
    }

    /// Model identifier this provider was constructed with.
    fn model_id(&self) -> &str {
        &self.model
    }

    /// Open a streaming completion request and return the event stream.
    ///
    /// # Errors
    /// Fails when no API key is available, the request cannot be built or
    /// sent, the server returns a non-2xx status, or the response is not an
    /// SSE stream.
    #[allow(clippy::too_many_lines)]
    async fn stream(
        &self,
        context: &Context<'_>,
        options: &StreamOptions,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>> {
        // A caller-supplied Authorization header takes precedence: in that
        // case we neither require nor send our own bearer token.
        let has_authorization_header = options
            .headers
            .keys()
            .any(|key| key.eq_ignore_ascii_case("authorization"));

        let auth_value = if has_authorization_header {
            None
        } else {
            Some(
                options
                    .api_key
                    .clone()
                    .or_else(|| std::env::var("OPENAI_API_KEY").ok())
                    .ok_or_else(|| {
                        Error::provider(
                            &self.provider,
                            "Missing API key for OpenAI. Set OPENAI_API_KEY or configure in settings.",
                        )
                    })?,
            )
        };

        let request_body = self.build_request_json(context, options)?;

        // Note: Content-Type is set by .json() below; setting it here too
        // produces a duplicate header that OpenAI's server rejects.
        let mut request = self
            .client
            .post(&self.base_url)
            .header("Accept", "text/event-stream");

        if let Some(auth_value) = auth_value {
            request = request.header("Authorization", format!("Bearer {auth_value}"));
        }

        // OpenRouter attribution headers: only add defaults when neither the
        // per-request options nor the compat config already provide them.
        if self.provider.eq_ignore_ascii_case("openrouter") {
            let compat_headers = self
                .compat
                .as_ref()
                .and_then(|compat| compat.custom_headers.as_ref());
            let has_referer = map_has_any_header(&options.headers, &["http-referer", "referer"])
                || compat_headers.is_some_and(|headers| {
                    map_has_any_header(headers, &["http-referer", "referer"])
                });
            if !has_referer {
                request = request.header("HTTP-Referer", openrouter_default_http_referer());
            }

            let has_title = map_has_any_header(&options.headers, &["x-title"])
                || compat_headers.is_some_and(|headers| map_has_any_header(headers, &["x-title"]));
            if !has_title {
                request = request.header("X-Title", openrouter_default_x_title());
            }
        }

        // Apply provider-specific custom headers from compat config.
        if let Some(compat) = &self.compat {
            if let Some(custom_headers) = &compat.custom_headers {
                for (key, value) in custom_headers {
                    request = request.header(key, value);
                }
            }
        }

        // Per-request headers from StreamOptions (highest priority).
        for (key, value) in &options.headers {
            request = request.header(key, value);
        }

        let request = request.json(&request_body)?;

        let response = Box::pin(request.send()).await?;
        let status = response.status();
        if !(200..300).contains(&status) {
            // Include the response body in the error: it usually carries the
            // server's human-readable failure reason.
            let body = response
                .text()
                .await
                .unwrap_or_else(|e| format!("<failed to read body: {e}>"));
            return Err(Error::provider(
                &self.provider,
                format!("OpenAI API error (HTTP {status}): {body}"),
            ));
        }

        // Reject non-SSE responses up front instead of feeding garbage into
        // the SSE parser.
        let content_type = response
            .headers()
            .iter()
            .find(|(name, _)| name.eq_ignore_ascii_case("content-type"))
            .map(|(_, value)| value.to_ascii_lowercase());
        if !content_type
            .as_deref()
            .is_some_and(|value| value.contains("text/event-stream"))
        {
            let message = content_type.map_or_else(
                || {
                    format!(
                        "OpenAI API protocol error (HTTP {status}): missing Content-Type (expected text/event-stream)"
                    )
                },
                |value| {
                    format!(
                        "OpenAI API protocol error (HTTP {status}): unexpected Content-Type {value} (expected text/event-stream)"
                    )
                },
            );
            return Err(Error::api(message));
        }

        // Create SSE stream for streaming responses.
        let event_source = SseStream::new(response.bytes_stream());

        // Create stream state
        let model = self.model.clone();
        let api = self.api().to_string();
        let provider = self.name().to_string();

        // Unfold: drain queued events first, then pull/parse the next SSE
        // message; terminate after emitting Done or an error.
        let stream = stream::unfold(
            StreamState::new(event_source, model, api, provider),
            |mut state| async move {
                if state.done {
                    return None;
                }
                loop {
                    if let Some(event) = state.pending_events.pop_front() {
                        return Some((Ok(event), state));
                    }

                    match state.event_source.next().await {
                        Some(Ok(msg)) => {
                            // OpenAI sends "[DONE]" as final message
                            if msg.data == "[DONE]" {
                                state.done = true;
                                let reason = state.partial.stop_reason;
                                let message = std::mem::take(&mut state.partial);
                                return Some((Ok(StreamEvent::Done { reason, message }), state));
                            }

                            if let Err(e) = state.process_event(&msg.data) {
                                state.done = true;
                                return Some((Err(e), state));
                            }
                        }
                        Some(Err(e)) => {
                            state.done = true;
                            let err = Error::api(format!("SSE error: {e}"));
                            return Some((Err(err), state));
                        }
                        // Stream ended without [DONE] sentinel (e.g.
                        // premature server disconnect).  Emit a Done event
                        // so the agent loop receives the accumulated partial
                        // instead of silently losing it.
                        None => {
                            state.done = true;
                            let reason = state.partial.stop_reason;
                            let message = std::mem::take(&mut state.partial);
                            return Some((Ok(StreamEvent::Done { reason, message }), state));
                        }
                    }
                }
            },
        );

        Ok(Box::pin(stream))
    }
}
451
452// ============================================================================
453// Stream State
454// ============================================================================
455
/// Mutable accumulator threaded through `stream::unfold` while consuming
/// the SSE response.
struct StreamState<S>
where
    S: Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin,
{
    /// Underlying SSE event source.
    event_source: SseStream<S>,
    /// Assistant message assembled incrementally from deltas.
    partial: AssistantMessage,
    /// Per-tool-call accumulation state (arguments arrive in fragments).
    tool_calls: Vec<ToolCallState>,
    /// Events parsed but not yet yielded to the consumer.
    pending_events: VecDeque<StreamEvent>,
    /// Whether the Start event has been queued.
    started: bool,
    /// Whether the stream has terminated (Done or error emitted).
    done: bool,
}
467
/// Accumulates one tool call across multiple streaming deltas.
struct ToolCallState {
    /// Logical tool-call index reported by the API (may be sparse).
    index: usize,
    /// Position of the matching `ToolCall` block in `partial.content`.
    content_index: usize,
    /// Tool call id (set when a delta carries it).
    id: String,
    /// Function name (set when a delta carries it).
    name: String,
    /// Raw JSON argument string, concatenated from delta fragments.
    arguments: String,
}
475
476impl<S> StreamState<S>
477where
478    S: Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin,
479{
480    fn new(event_source: SseStream<S>, model: String, api: String, provider: String) -> Self {
481        Self {
482            event_source,
483            partial: AssistantMessage {
484                content: Vec::new(),
485                api,
486                provider,
487                model,
488                usage: Usage::default(),
489                stop_reason: StopReason::Stop,
490                error_message: None,
491                timestamp: chrono::Utc::now().timestamp_millis(),
492            },
493            tool_calls: Vec::new(),
494            pending_events: VecDeque::new(),
495            started: false,
496            done: false,
497        }
498    }
499
500    fn ensure_started(&mut self) {
501        if !self.started {
502            self.started = true;
503            self.pending_events.push_back(StreamEvent::Start {
504                partial: self.partial.clone(),
505            });
506        }
507    }
508
509    fn process_event(&mut self, data: &str) -> Result<()> {
510        let chunk: OpenAIStreamChunk =
511            serde_json::from_str(data).map_err(|e| Error::api(format!("JSON parse error: {e}")))?;
512
513        // Handle usage in final chunk
514        if let Some(usage) = chunk.usage {
515            self.partial.usage.input = usage.prompt_tokens;
516            self.partial.usage.output = usage.completion_tokens.unwrap_or(0);
517            self.partial.usage.total_tokens = usage.total_tokens;
518        }
519
520        if let Some(error) = chunk.error {
521            self.partial.stop_reason = StopReason::Error;
522            if let Some(message) = error.message {
523                let message = message.trim();
524                if !message.is_empty() {
525                    self.partial.error_message = Some(message.to_string());
526                }
527            }
528        }
529
530        // Process choices
531        if let Some(choice) = chunk.choices.into_iter().next() {
532            if !self.started
533                && choice.finish_reason.is_none()
534                && choice.delta.content.is_none()
535                && choice.delta.tool_calls.is_none()
536            {
537                self.ensure_started();
538                return Ok(());
539            }
540
541            self.process_choice(choice);
542        }
543
544        Ok(())
545    }
546
547    fn finalize_tool_call_arguments(&mut self) {
548        for tc in &self.tool_calls {
549            let arguments: serde_json::Value = match serde_json::from_str(&tc.arguments) {
550                Ok(args) => args,
551                Err(e) => {
552                    tracing::warn!(
553                        error = %e,
554                        raw = %tc.arguments,
555                        "Failed to parse tool arguments as JSON"
556                    );
557                    serde_json::Value::Null
558                }
559            };
560
561            if let Some(ContentBlock::ToolCall(block)) =
562                self.partial.content.get_mut(tc.content_index)
563            {
564                block.arguments = arguments;
565            }
566        }
567    }
568
569    #[allow(clippy::too_many_lines)]
570    fn process_choice(&mut self, choice: OpenAIChoice) {
571        let delta = choice.delta;
572        if delta.content.is_some()
573            || delta.tool_calls.is_some()
574            || delta.reasoning_content.is_some()
575        {
576            self.ensure_started();
577        }
578
579        // Handle finish reason - may arrive in empty delta without content/tool_calls
580        // Ensure we emit Start before processing finish_reason
581        if choice.finish_reason.is_some() {
582            self.ensure_started();
583        }
584
585        // Handle reasoning content (e.g. DeepSeek R1)
586        if let Some(reasoning) = delta.reasoning_content {
587            // Update partial content
588            let last_is_thinking =
589                matches!(self.partial.content.last(), Some(ContentBlock::Thinking(_)));
590
591            let content_index = if last_is_thinking {
592                self.partial.content.len() - 1
593            } else {
594                let idx = self.partial.content.len();
595                self.partial
596                    .content
597                    .push(ContentBlock::Thinking(ThinkingContent {
598                        thinking: String::new(),
599                        thinking_signature: None,
600                    }));
601
602                self.pending_events
603                    .push_back(StreamEvent::ThinkingStart { content_index: idx });
604
605                idx
606            };
607
608            if let Some(ContentBlock::Thinking(t)) = self.partial.content.get_mut(content_index) {
609                t.thinking.push_str(&reasoning);
610            }
611
612            self.pending_events.push_back(StreamEvent::ThinkingDelta {
613                content_index,
614                delta: reasoning,
615            });
616        }
617
618        // Handle text content
619
620        if let Some(content) = delta.content {
621            // Update partial content
622
623            let last_is_text = matches!(self.partial.content.last(), Some(ContentBlock::Text(_)));
624
625            let content_index = if last_is_text {
626                self.partial.content.len() - 1
627            } else {
628                let idx = self.partial.content.len();
629
630                self.partial
631                    .content
632                    .push(ContentBlock::Text(TextContent::new("")));
633
634                self.pending_events
635                    .push_back(StreamEvent::TextStart { content_index: idx });
636
637                idx
638            };
639
640            if let Some(ContentBlock::Text(t)) = self.partial.content.get_mut(content_index) {
641                t.text.push_str(&content);
642            }
643
644            self.pending_events.push_back(StreamEvent::TextDelta {
645                content_index,
646
647                delta: content,
648            });
649        }
650
651        // Handle tool calls
652
653        if let Some(tool_calls) = delta.tool_calls {
654            for tc_delta in tool_calls {
655                let index = tc_delta.index as usize;
656
657                // OpenAI may emit sparse tool-call indices. Match by logical index
658
659                // instead of assuming contiguous 0..N ordering in arrival order.
660
661                let tool_state_idx = if let Some(existing_idx) =
662                    self.tool_calls.iter().position(|tc| tc.index == index)
663                {
664                    existing_idx
665                } else {
666                    let content_index = self.partial.content.len();
667
668                    self.tool_calls.push(ToolCallState {
669                        index,
670
671                        content_index,
672
673                        id: String::new(),
674
675                        name: String::new(),
676
677                        arguments: String::new(),
678                    });
679
680                    // Initialize the tool call block in partial content
681
682                    self.partial.content.push(ContentBlock::ToolCall(ToolCall {
683                        id: String::new(),
684
685                        name: String::new(),
686
687                        arguments: serde_json::Value::Null,
688
689                        thought_signature: None,
690                    }));
691
692                    self.pending_events
693                        .push_back(StreamEvent::ToolCallStart { content_index });
694
695                    self.tool_calls.len() - 1
696                };
697
698                let tc = &mut self.tool_calls[tool_state_idx];
699
700                let content_index = tc.content_index;
701
702                // Update ID if present
703
704                if let Some(id) = tc_delta.id {
705                    tc.id = id;
706
707                    if let Some(ContentBlock::ToolCall(block)) =
708                        self.partial.content.get_mut(content_index)
709                    {
710                        block.id.clone_from(&tc.id);
711                    }
712                }
713
714                // Update function name if present
715
716                if let Some(function) = tc_delta.function {
717                    if let Some(name) = function.name {
718                        tc.name = name;
719
720                        if let Some(ContentBlock::ToolCall(block)) =
721                            self.partial.content.get_mut(content_index)
722                        {
723                            block.name.clone_from(&tc.name);
724                        }
725                    }
726
727                    if let Some(args) = function.arguments {
728                        tc.arguments.push_str(&args);
729
730                        // Update arguments in partial (best effort parse, or just raw string if we supported it)
731
732                        // Note: We don't update partial.arguments here because it requires valid JSON.
733
734                        // We only update it at the end or if we switched to storing raw string args.
735
736                        // But we MUST emit the delta.
737
738                        self.pending_events.push_back(StreamEvent::ToolCallDelta {
739                            content_index,
740
741                            delta: args,
742                        });
743                    }
744                }
745            }
746        }
747
748        // Handle finish reason (MUST happen after delta processing to capture final chunks)
749
750        if let Some(reason) = choice.finish_reason {
751            self.partial.stop_reason = match reason.as_str() {
752                "length" => StopReason::Length,
753
754                "tool_calls" => StopReason::ToolUse,
755
756                "content_filter" | "error" => StopReason::Error,
757
758                _ => StopReason::Stop,
759            };
760
761            // Emit TextEnd/ThinkingEnd for all open text/thinking blocks (not just the last one,
762            // since text/thinking may precede tool calls).
763
764            for (content_index, block) in self.partial.content.iter().enumerate() {
765                if let ContentBlock::Text(t) = block {
766                    self.pending_events.push_back(StreamEvent::TextEnd {
767                        content_index,
768                        content: t.text.clone(),
769                    });
770                } else if let ContentBlock::Thinking(t) = block {
771                    self.pending_events.push_back(StreamEvent::ThinkingEnd {
772                        content_index,
773                        content: t.thinking.clone(),
774                    });
775                }
776            }
777
778            // Finalize tool call arguments
779
780            self.finalize_tool_call_arguments();
781
782            // Emit ToolCallEnd for each accumulated tool call
783
784            for tc in &self.tool_calls {
785                if let Some(ContentBlock::ToolCall(tool_call)) =
786                    self.partial.content.get(tc.content_index)
787                {
788                    self.pending_events.push_back(StreamEvent::ToolCallEnd {
789                        content_index: tc.content_index,
790
791                        tool_call: tool_call.clone(),
792                    });
793                }
794            }
795        }
796    }
797}
798
799// ============================================================================
800// OpenAI API Types
801// ============================================================================
802
/// Request body for the Chat Completions endpoint (always streaming here).
#[derive(Debug, Serialize)]
pub struct OpenAIRequest<'a> {
    model: &'a str,
    messages: Vec<OpenAIMessage<'a>>,
    /// Classic token cap; None (and skipped) when the compat config selects
    /// `max_completion_tokens` instead.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    /// Some providers (e.g., o1-series) use `max_completion_tokens` instead of `max_tokens`.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_completion_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<OpenAITool<'a>>>,
    /// Always `true`: this provider only implements the streaming path.
    stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream_options: Option<OpenAIStreamOptions>,
}
820
/// `stream_options` object: requests token usage reporting in the stream.
#[derive(Debug, Serialize)]
struct OpenAIStreamOptions {
    include_usage: bool,
}
825
/// One entry in the request `messages` array.
#[derive(Debug, Serialize)]
struct OpenAIMessage<'a> {
    /// Role name; usually borrowed, owned only for exotic compat roles.
    role: Cow<'a, str>,
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<OpenAIContent<'a>>,
    /// Present on assistant messages that invoked tools.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<OpenAIToolCallRef<'a>>>,
    /// Present on tool-result messages (`role: "tool"`).
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<&'a str>,
}
836
/// Message content: either a plain string or a list of typed parts
/// (serialized untagged, matching the API's two accepted shapes).
#[derive(Debug, Serialize)]
#[serde(untagged)]
enum OpenAIContent<'a> {
    Text(Cow<'a, str>),
    Parts(Vec<OpenAIContentPart<'a>>),
}
843
/// One typed content part (`{"type": "text", ...}` / `{"type": "image_url", ...}`).
#[derive(Debug, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum OpenAIContentPart<'a> {
    Text { text: Cow<'a, str> },
    ImageUrl { image_url: OpenAIImageUrl<'a> },
}
850
/// Payload of an `image_url` content part.
#[derive(Debug, Serialize)]
struct OpenAIImageUrl<'a> {
    /// URL (or data URL); owned because it is built with `format!` at the
    /// construction site rather than borrowed from the message.
    url: String,
    #[serde(skip)]
    // Keeps the `'a` parameter in use so this type fits the lifetime of the
    // surrounding content-part enum even though `url` is owned.
    _phantom: std::marker::PhantomData<&'a ()>,
}
858
/// A tool call echoed back on an assistant message (request direction).
#[derive(Debug, Serialize)]
struct OpenAIToolCallRef<'a> {
    id: &'a str,
    /// Always `"function"` in the current API.
    r#type: &'static str,
    function: OpenAIFunctionRef<'a>,
}
865
/// Function name + JSON-encoded arguments for an echoed tool call.
#[derive(Debug, Serialize)]
struct OpenAIFunctionRef<'a> {
    name: &'a str,
    /// Arguments re-serialized to a JSON string, as the API expects.
    arguments: String,
}
871
/// One entry in the request `tools` array.
#[derive(Debug, Serialize)]
struct OpenAITool<'a> {
    /// Always `"function"` in the current API.
    r#type: &'static str,
    function: OpenAIFunction<'a>,
}
877
878#[derive(Debug, Serialize)]
879struct OpenAIFunction<'a> {
880    name: &'a str,
881    description: &'a str,
882    parameters: &'a serde_json::Value,
883}
884
885// ============================================================================
886// Streaming Response Types
887// ============================================================================
888
/// One SSE chunk of a streaming Chat Completions response.
#[derive(Debug, Deserialize)]
struct OpenAIStreamChunk {
    #[serde(default)]
    choices: Vec<OpenAIChoice>,
    /// Only present on the final chunk when `include_usage` was requested.
    #[serde(default)]
    usage: Option<OpenAIUsage>,
    /// Some OpenAI-compatible gateways report mid-stream failures inline.
    #[serde(default)]
    error: Option<OpenAIChunkError>,
}

/// A single choice inside a chunk.
#[derive(Debug, Deserialize)]
struct OpenAIChoice {
    delta: OpenAIDelta,
    /// Set on the terminal chunk: e.g. "stop", "tool_calls", "length".
    #[serde(default)]
    finish_reason: Option<String>,
}

/// Incremental content attached to a choice.
#[derive(Debug, Deserialize)]
struct OpenAIDelta {
    #[serde(default)]
    content: Option<String>,
    /// Reasoning/"thinking" text emitted by some OpenAI-compatible providers
    /// (not part of the official OpenAI schema).
    #[serde(default)]
    reasoning_content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<OpenAIToolCallDelta>>,
}

/// Fragment of a tool call. `index` correlates fragments belonging to the
/// same call across successive chunks.
#[derive(Debug, Deserialize)]
struct OpenAIToolCallDelta {
    index: u32,
    /// Typically only present on the first fragment of a call.
    #[serde(default)]
    id: Option<String>,
    #[serde(default)]
    function: Option<OpenAIFunctionDelta>,
}

/// Partial function payload: `name` usually arrives once, `arguments`
/// arrives as string pieces to be concatenated.
#[derive(Debug, Deserialize)]
struct OpenAIFunctionDelta {
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    arguments: Option<String>,
}

/// Token accounting reported by the API.
#[derive(Debug, Deserialize)]
#[allow(clippy::struct_field_names)]
struct OpenAIUsage {
    prompt_tokens: u64,
    /// Optional because some compatible providers omit it.
    #[serde(default)]
    completion_tokens: Option<u64>,
    total_tokens: u64,
}

/// Inline error object some gateways embed in a chunk.
#[derive(Debug, Deserialize)]
struct OpenAIChunkError {
    #[serde(default)]
    message: Option<String>,
}
947
948// ============================================================================
949// Conversion Functions
950// ============================================================================
951
952fn convert_message_to_openai(message: &Message) -> Vec<OpenAIMessage<'_>> {
953    match message {
954        Message::User(user) => vec![OpenAIMessage {
955            role: Cow::Borrowed("user"),
956            content: Some(convert_user_content(&user.content)),
957            tool_calls: None,
958            tool_call_id: None,
959        }],
960        Message::Custom(custom) => vec![OpenAIMessage {
961            role: Cow::Borrowed("user"),
962            content: Some(OpenAIContent::Text(Cow::Borrowed(&custom.content))),
963            tool_calls: None,
964            tool_call_id: None,
965        }],
966        Message::Assistant(assistant) => {
967            let mut messages = Vec::new();
968
969            // Collect text content
970            let text: String = assistant
971                .content
972                .iter()
973                .filter_map(|b| match b {
974                    ContentBlock::Text(t) => Some(t.text.as_str()),
975                    _ => None,
976                })
977                .collect::<String>();
978
979            // Collect tool calls
980            let tool_calls: Vec<OpenAIToolCallRef<'_>> = assistant
981                .content
982                .iter()
983                .filter_map(|b| match b {
984                    ContentBlock::ToolCall(tc) => Some(OpenAIToolCallRef {
985                        id: &tc.id,
986                        r#type: "function",
987                        function: OpenAIFunctionRef {
988                            name: &tc.name,
989                            arguments: tc.arguments.to_string(),
990                        },
991                    }),
992                    _ => None,
993                })
994                .collect();
995
996            let content = if text.is_empty() {
997                None
998            } else {
999                Some(OpenAIContent::Text(Cow::Owned(text)))
1000            };
1001
1002            let tool_calls = if tool_calls.is_empty() {
1003                None
1004            } else {
1005                Some(tool_calls)
1006            };
1007
1008            messages.push(OpenAIMessage {
1009                role: Cow::Borrowed("assistant"),
1010                content,
1011                tool_calls,
1012                tool_call_id: None,
1013            });
1014
1015            messages
1016        }
1017        Message::ToolResult(result) => {
1018            // OpenAI expects tool results as separate messages with role "tool"
1019            let parts: Vec<OpenAIContentPart<'_>> = result
1020                .content
1021                .iter()
1022                .filter_map(|block| match block {
1023                    ContentBlock::Text(t) => Some(OpenAIContentPart::Text {
1024                        text: Cow::Borrowed(&t.text),
1025                    }),
1026                    ContentBlock::Image(img) => {
1027                        let url = format!("data:{};base64,{}", img.mime_type, img.data);
1028                        Some(OpenAIContentPart::ImageUrl {
1029                            image_url: OpenAIImageUrl {
1030                                url,
1031                                _phantom: std::marker::PhantomData,
1032                            },
1033                        })
1034                    }
1035                    _ => None,
1036                })
1037                .collect();
1038
1039            let content = if parts.is_empty() {
1040                None
1041            } else if parts.len() == 1 && matches!(parts[0], OpenAIContentPart::Text { .. }) {
1042                // Optimization: use simple text content if possible
1043                if let OpenAIContentPart::Text { text } = &parts[0] {
1044                    Some(OpenAIContent::Text(text.clone()))
1045                } else {
1046                    Some(OpenAIContent::Parts(parts))
1047                }
1048            } else {
1049                Some(OpenAIContent::Parts(parts))
1050            };
1051
1052            vec![OpenAIMessage {
1053                role: Cow::Borrowed("tool"),
1054                content,
1055                tool_calls: None,
1056                tool_call_id: Some(&result.tool_call_id),
1057            }]
1058        }
1059    }
1060}
1061
1062fn convert_user_content(content: &UserContent) -> OpenAIContent<'_> {
1063    match content {
1064        UserContent::Text(text) => OpenAIContent::Text(Cow::Borrowed(text)),
1065        UserContent::Blocks(blocks) => {
1066            let parts: Vec<OpenAIContentPart<'_>> = blocks
1067                .iter()
1068                .filter_map(|block| match block {
1069                    ContentBlock::Text(t) => Some(OpenAIContentPart::Text {
1070                        text: Cow::Borrowed(&t.text),
1071                    }),
1072                    ContentBlock::Image(img) => {
1073                        // Convert to data URL for OpenAI
1074                        let url = format!("data:{};base64,{}", img.mime_type, img.data);
1075                        Some(OpenAIContentPart::ImageUrl {
1076                            image_url: OpenAIImageUrl {
1077                                url,
1078                                _phantom: std::marker::PhantomData,
1079                            },
1080                        })
1081                    }
1082                    _ => None,
1083                })
1084                .collect();
1085            OpenAIContent::Parts(parts)
1086        }
1087    }
1088}
1089
1090fn convert_tool_to_openai(tool: &ToolDef) -> OpenAITool<'_> {
1091    OpenAITool {
1092        r#type: "function",
1093        function: OpenAIFunction {
1094            name: &tool.name,
1095            description: &tool.description,
1096            parameters: &tool.parameters,
1097        },
1098    }
1099}
1100
1101// ============================================================================
1102// Tests
1103// ============================================================================
1104
1105#[cfg(test)]
1106mod tests {
1107    use super::*;
1108    use asupersync::runtime::RuntimeBuilder;
1109    use futures::{StreamExt, stream};
1110    use serde::{Deserialize, Serialize};
1111    use serde_json::{Value, json};
1112    use std::collections::HashMap;
1113    use std::io::{Read, Write};
1114    use std::net::TcpListener;
1115    use std::path::PathBuf;
1116    use std::sync::mpsc;
1117    use std::time::Duration;
1118
1119    #[test]
1120    fn test_convert_user_text_message() {
1121        let message = Message::User(crate::model::UserMessage {
1122            content: UserContent::Text("Hello".to_string()),
1123            timestamp: 0,
1124        });
1125
1126        let converted = convert_message_to_openai(&message);
1127        assert_eq!(converted.len(), 1);
1128        assert_eq!(converted[0].role, "user");
1129    }
1130
1131    #[test]
1132    fn test_tool_conversion() {
1133        let tool = ToolDef {
1134            name: "test_tool".to_string(),
1135            description: "A test tool".to_string(),
1136            parameters: serde_json::json!({
1137                "type": "object",
1138                "properties": {
1139                    "arg": {"type": "string"}
1140                }
1141            }),
1142        };
1143
1144        let converted = convert_tool_to_openai(&tool);
1145        assert_eq!(converted.r#type, "function");
1146        assert_eq!(converted.function.name, "test_tool");
1147        assert_eq!(converted.function.description, "A test tool");
1148        assert_eq!(
1149            converted.function.parameters,
1150            &serde_json::json!({
1151                "type": "object",
1152                "properties": {
1153                    "arg": {"type": "string"}
1154                }
1155            })
1156        );
1157    }
1158
1159    #[test]
1160    fn test_provider_info() {
1161        let provider = OpenAIProvider::new("gpt-4o");
1162        assert_eq!(provider.name(), "openai");
1163        assert_eq!(provider.api(), "openai-completions");
1164    }
1165
    /// The serialized request must carry the system prompt as the first
    /// message, the user turn, sampling options, the tool schema, and the
    /// streaming flags (`stream` + `stream_options.include_usage`).
    #[test]
    fn test_build_request_includes_system_tools_and_stream_options() {
        let provider = OpenAIProvider::new("gpt-4o");
        let context = Context {
            system_prompt: Some("You are concise.".to_string().into()),
            messages: vec![Message::User(crate::model::UserMessage {
                content: UserContent::Text("Ping".to_string()),
                timestamp: 0,
            })]
            .into(),
            tools: vec![ToolDef {
                name: "search".to_string(),
                description: "Search docs".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "q": { "type": "string" }
                    },
                    "required": ["q"]
                }),
            }]
            .into(),
        };
        let options = StreamOptions {
            temperature: Some(0.2),
            max_tokens: Some(123),
            ..Default::default()
        };

        let request = provider.build_request(&context, &options);
        let value = serde_json::to_value(&request).expect("serialize request");
        assert_eq!(value["model"], "gpt-4o");
        assert_eq!(value["messages"][0]["role"], "system");
        assert_eq!(value["messages"][0]["content"], "You are concise.");
        assert_eq!(value["messages"][1]["role"], "user");
        assert_eq!(value["messages"][1]["content"], "Ping");
        // f32 -> f64 JSON round-trip: compare with a tolerance, not equality.
        let temperature = value["temperature"]
            .as_f64()
            .expect("temperature should serialize as number");
        assert!((temperature - 0.2).abs() < 1e-6);
        assert_eq!(value["max_tokens"], 123);
        assert_eq!(value["stream"], true);
        assert_eq!(value["stream_options"]["include_usage"], true);
        assert_eq!(value["tools"][0]["type"], "function");
        assert_eq!(value["tools"][0]["function"]["name"], "search");
        assert_eq!(value["tools"][0]["function"]["description"], "Search docs");
        assert_eq!(
            value["tools"][0]["function"]["parameters"],
            json!({
                "type": "object",
                "properties": {
                    "q": { "type": "string" }
                },
                "required": ["q"]
            })
        );
    }
1223
    /// Tool-call arguments arrive as string fragments spread over several
    /// chunks (correlated by `index`); the stream must surface each fragment
    /// as a delta event and splice them into one JSON arguments object on
    /// the final assistant message.
    #[test]
    fn test_stream_accumulates_tool_call_argument_deltas() {
        let events = vec![
            json!({ "choices": [{ "delta": {} }] }),
            json!({
                "choices": [{
                    "delta": {
                        "tool_calls": [{
                            "index": 0,
                            "id": "call_1",
                            "function": {
                                "name": "search",
                                "arguments": "{\"q\":\"ru"
                            }
                        }]
                    }
                }]
            }),
            json!({
                "choices": [{
                    "delta": {
                        "tool_calls": [{
                            "index": 0,
                            "function": {
                                "arguments": "st\"}"
                            }
                        }]
                    }
                }]
            }),
            json!({ "choices": [{ "delta": {}, "finish_reason": "tool_calls" }] }),
            Value::String("[DONE]".to_string()),
        ];

        let out = collect_events(&events);
        // Both raw fragments must surface as ToolCallStart / ToolCallDelta.
        assert!(
            out.iter()
                .any(|e| matches!(e, StreamEvent::ToolCallStart { .. }))
        );
        assert!(out.iter().any(
            |e| matches!(e, StreamEvent::ToolCallDelta { delta, .. } if delta == "{\"q\":\"ru")
        ));
        assert!(
            out.iter()
                .any(|e| matches!(e, StreamEvent::ToolCallDelta { delta, .. } if delta == "st\"}"))
        );
        let done = out
            .iter()
            .find_map(|event| match event {
                StreamEvent::Done { message, .. } => Some(message),
                _ => None,
            })
            .expect("done event");
        let tool_call = done
            .content
            .iter()
            .find_map(|block| match block {
                ContentBlock::ToolCall(tc) => Some(tc),
                _ => None,
            })
            .expect("assembled tool call content");
        assert_eq!(tool_call.id, "call_1");
        assert_eq!(tool_call.name, "search");
        // The two fragments reassemble into valid JSON.
        assert_eq!(tool_call.arguments, json!({ "q": "rust" }));
        assert!(out.iter().any(|e| matches!(
            e,
            StreamEvent::Done {
                reason: StopReason::ToolUse,
                ..
            }
        )));
    }
1296
    /// A provider may start tool-call indices above zero; the accumulator
    /// must handle the sparse index without panicking and still produce
    /// exactly one assembled tool call.
    #[test]
    fn test_stream_handles_sparse_tool_call_index_without_panic() {
        let events = vec![
            json!({ "choices": [{ "delta": {} }] }),
            json!({
                "choices": [{
                    "delta": {
                        "tool_calls": [{
                            "index": 2,
                            "id": "call_sparse",
                            "function": {
                                "name": "lookup",
                                "arguments": "{\"q\":\"sparse\"}"
                            }
                        }]
                    }
                }]
            }),
            json!({ "choices": [{ "delta": {}, "finish_reason": "tool_calls" }] }),
            Value::String("[DONE]".to_string()),
        ];

        let out = collect_events(&events);
        let done = out
            .iter()
            .find_map(|event| match event {
                StreamEvent::Done { message, .. } => Some(message),
                _ => None,
            })
            .expect("done event");
        let tool_calls: Vec<&ToolCall> = done
            .content
            .iter()
            .filter_map(|block| match block {
                ContentBlock::ToolCall(tc) => Some(tc),
                _ => None,
            })
            .collect();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].id, "call_sparse");
        assert_eq!(tool_calls[0].name, "lookup");
        assert_eq!(tool_calls[0].arguments, json!({ "q": "sparse" }));
        assert!(
            out.iter()
                .any(|event| matches!(event, StreamEvent::ToolCallStart { .. })),
            "expected tool call start event"
        );
    }
1345
1346    #[test]
1347    fn test_stream_maps_finish_reason_error_to_stop_reason_error() {
1348        let events = vec![
1349            json!({
1350                "choices": [{ "delta": {}, "finish_reason": "error" }],
1351                "error": { "message": "upstream provider timeout" }
1352            }),
1353            Value::String("[DONE]".to_string()),
1354        ];
1355
1356        let out = collect_events(&events);
1357        let done = out
1358            .iter()
1359            .find_map(|event| match event {
1360                StreamEvent::Done { reason, message } => Some((reason, message)),
1361                _ => None,
1362            })
1363            .expect("done event");
1364        assert_eq!(*done.0, StopReason::Error);
1365        assert_eq!(
1366            done.1.error_message.as_deref(),
1367            Some("upstream provider timeout")
1368        );
1369    }
1370
1371    #[test]
1372    fn test_finish_reason_without_prior_content_emits_start() {
1373        let events = vec![
1374            json!({ "choices": [{ "delta": {}, "finish_reason": "stop" }] }),
1375            Value::String("[DONE]".to_string()),
1376        ];
1377
1378        let out = collect_events(&events);
1379
1380        // Should have: Start, Done
1381        // First event must be Start (bug would skip this)
1382        assert!(!out.is_empty(), "expected at least one event");
1383        assert!(
1384            matches!(out[0], StreamEvent::Start { .. }),
1385            "First event should be Start, got {:?}",
1386            out[0]
1387        );
1388    }
1389
    /// Happy-path text stream: verifies the exact event sequence
    /// Start -> TextStart -> TextDelta x2 -> TextEnd -> Done, with the
    /// TextEnd carrying the concatenated text.
    #[test]
    fn test_stream_emits_all_events_in_correct_order() {
        let events = vec![
            json!({ "choices": [{ "delta": { "content": "Hello" } }] }),
            json!({ "choices": [{ "delta": { "content": " world" } }] }),
            json!({ "choices": [{ "delta": {}, "finish_reason": "stop" }] }),
            Value::String("[DONE]".to_string()),
        ];

        let out = collect_events(&events);

        // Verify sequence: Start, TextStart, TextDelta, TextDelta, TextEnd, Done
        assert_eq!(out.len(), 6, "Expected 6 events, got {}", out.len());

        assert!(
            matches!(out[0], StreamEvent::Start { .. }),
            "Event 0 should be Start, got {:?}",
            out[0]
        );

        assert!(
            matches!(
                out[1],
                StreamEvent::TextStart {
                    content_index: 0,
                    ..
                }
            ),
            "Event 1 should be TextStart at index 0, got {:?}",
            out[1]
        );

        assert!(
            matches!(&out[2], StreamEvent::TextDelta { content_index: 0, delta, .. } if delta == "Hello"),
            "Event 2 should be TextDelta 'Hello' at index 0, got {:?}",
            out[2]
        );

        assert!(
            matches!(&out[3], StreamEvent::TextDelta { content_index: 0, delta, .. } if delta == " world"),
            "Event 3 should be TextDelta ' world' at index 0, got {:?}",
            out[3]
        );

        assert!(
            matches!(&out[4], StreamEvent::TextEnd { content_index: 0, content, .. } if content == "Hello world"),
            "Event 4 should be TextEnd 'Hello world' at index 0, got {:?}",
            out[4]
        );

        assert!(
            matches!(
                out[5],
                StreamEvent::Done {
                    reason: StopReason::Stop,
                    ..
                }
            ),
            "Event 5 should be Done with Stop reason, got {:?}",
            out[5]
        );
    }
1452
    /// OpenRouter routing options supplied via compat config must be merged
    /// into the request JSON (`route`, `provider`, `models` keys).
    #[test]
    fn test_build_request_applies_openrouter_routing_overrides() {
        let provider = OpenAIProvider::new("openai/gpt-4o-mini")
            .with_provider_name("openrouter")
            .with_compat(Some(CompatConfig {
                open_router_routing: Some(json!({
                    "models": ["openai/gpt-4o-mini", "anthropic/claude-3.5-sonnet"],
                    "provider": {
                        "order": ["openai", "anthropic"],
                        "allow_fallbacks": false
                    },
                    "route": "fallback"
                })),
                ..CompatConfig::default()
            }));
        let context = Context {
            system_prompt: None,
            messages: vec![Message::User(crate::model::UserMessage {
                content: UserContent::Text("Ping".to_string()),
                timestamp: 0,
            })]
            .into(),
            tools: Vec::new().into(),
        };
        let options = StreamOptions::default();

        let request = provider
            .build_request_json(&context, &options)
            .expect("request json");
        assert_eq!(request["model"], "openai/gpt-4o-mini");
        assert_eq!(request["route"], "fallback");
        assert_eq!(request["provider"]["allow_fallbacks"], false);
        assert_eq!(request["models"][0], "openai/gpt-4o-mini");
        assert_eq!(request["models"][1], "anthropic/claude-3.5-sonnet");
    }
1488
1489    #[test]
1490    fn test_stream_sets_bearer_auth_header() {
1491        let captured = run_stream_and_capture_headers().expect("captured request");
1492        assert_eq!(
1493            captured.headers.get("authorization").map(String::as_str),
1494            Some("Bearer test-openai-key")
1495        );
1496        assert_eq!(
1497            captured.headers.get("accept").map(String::as_str),
1498            Some("text/event-stream")
1499        );
1500
1501        let body: Value = serde_json::from_str(&captured.body).expect("request body json");
1502        assert_eq!(body["stream"], true);
1503        assert_eq!(body["stream_options"]["include_usage"], true);
1504    }
1505
1506    #[test]
1507    fn test_stream_openrouter_injects_default_attribution_headers() {
1508        let options = StreamOptions {
1509            api_key: Some("test-openrouter-key".to_string()),
1510            ..Default::default()
1511        };
1512        let captured = run_stream_and_capture_headers_with(
1513            OpenAIProvider::new("openai/gpt-4o-mini").with_provider_name("openrouter"),
1514            &options,
1515        )
1516        .expect("captured request");
1517
1518        assert_eq!(
1519            captured.headers.get("http-referer").map(String::as_str),
1520            Some(OPENROUTER_DEFAULT_HTTP_REFERER)
1521        );
1522        assert_eq!(
1523            captured.headers.get("x-title").map(String::as_str),
1524            Some(OPENROUTER_DEFAULT_X_TITLE)
1525        );
1526    }
1527
1528    #[test]
1529    fn test_stream_openrouter_respects_explicit_attribution_headers() {
1530        let options = StreamOptions {
1531            api_key: Some("test-openrouter-key".to_string()),
1532            headers: HashMap::from([
1533                (
1534                    "HTTP-Referer".to_string(),
1535                    "https://example.test/app".to_string(),
1536                ),
1537                (
1538                    "X-Title".to_string(),
1539                    "Custom OpenRouter Client".to_string(),
1540                ),
1541            ]),
1542            ..Default::default()
1543        };
1544        let captured = run_stream_and_capture_headers_with(
1545            OpenAIProvider::new("openai/gpt-4o-mini").with_provider_name("openrouter"),
1546            &options,
1547        )
1548        .expect("captured request");
1549
1550        assert_eq!(
1551            captured.headers.get("http-referer").map(String::as_str),
1552            Some("https://example.test/app")
1553        );
1554        assert_eq!(
1555            captured.headers.get("x-title").map(String::as_str),
1556            Some("Custom OpenRouter Client")
1557        );
1558    }
1559
    /// Top-level shape of a provider stream fixture file.
    #[derive(Debug, Deserialize)]
    struct ProviderFixture {
        cases: Vec<ProviderCase>,
    }

    /// One named fixture case: raw SSE event payloads plus the expected
    /// summarized event sequence.
    #[derive(Debug, Deserialize)]
    struct ProviderCase {
        name: String,
        events: Vec<Value>,
        expected: Vec<EventSummary>,
    }

    /// Flattened, comparable view of a `StreamEvent` used by fixture cases;
    /// absent fields mean "not applicable for this event kind".
    #[derive(Debug, Deserialize, Serialize, PartialEq)]
    struct EventSummary {
        kind: String,
        #[serde(default)]
        content_index: Option<usize>,
        #[serde(default)]
        delta: Option<String>,
        #[serde(default)]
        content: Option<String>,
        #[serde(default)]
        reason: Option<String>,
    }
1584
1585    #[test]
1586    fn test_stream_fixtures() {
1587        let fixture = load_fixture("openai_stream.json");
1588        for case in fixture.cases {
1589            let events = collect_events(&case.events);
1590            let summaries: Vec<EventSummary> = events.iter().map(summarize_event).collect();
1591            assert_eq!(summaries, case.expected, "case {}", case.name);
1592        }
1593    }
1594
1595    fn load_fixture(file_name: &str) -> ProviderFixture {
1596        let path = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
1597            .join("tests/fixtures/provider_responses")
1598            .join(file_name);
1599        let raw = std::fs::read_to_string(path).expect("fixture read");
1600        serde_json::from_str(&raw).expect("fixture parse")
1601    }
1602
    /// Request captured by the in-process test server: lower-cased header
    /// map plus the raw body text.
    #[derive(Debug)]
    struct CapturedRequest {
        headers: HashMap<String, String>,
        body: String,
    }
1608
1609    fn run_stream_and_capture_headers() -> Option<CapturedRequest> {
1610        let options = StreamOptions {
1611            api_key: Some("test-openai-key".to_string()),
1612            ..Default::default()
1613        };
1614        run_stream_and_capture_headers_with(OpenAIProvider::new("gpt-4o"), &options)
1615    }
1616
    /// Spin up a one-shot local HTTP server, run a full stream round-trip
    /// against it with the given provider/options, and return the request
    /// the server captured (headers + body). Returns `None` if nothing
    /// arrived within the timeout.
    fn run_stream_and_capture_headers_with(
        provider: OpenAIProvider,
        options: &StreamOptions,
    ) -> Option<CapturedRequest> {
        let (base_url, rx) = spawn_test_server(200, "text/event-stream", &success_sse_body());
        let provider = provider.with_base_url(base_url);
        let context = Context {
            system_prompt: None,
            messages: vec![Message::User(crate::model::UserMessage {
                content: UserContent::Text("ping".to_string()),
                timestamp: 0,
            })]
            .into(),
            tools: Vec::new().into(),
        };

        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");
        runtime.block_on(async {
            let mut stream = provider.stream(&context, options).await.expect("stream");
            // Drain until Done so the request has definitely been sent and
            // the server thread has had a chance to capture it.
            while let Some(event) = stream.next().await {
                if matches!(event.expect("stream event"), StreamEvent::Done { .. }) {
                    break;
                }
            }
        });

        rx.recv_timeout(Duration::from_secs(2)).ok()
    }
1647
1648    fn success_sse_body() -> String {
1649        [
1650            r#"data: {"choices":[{"delta":{}}]}"#,
1651            "",
1652            r#"data: {"choices":[{"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}"#,
1653            "",
1654            "data: [DONE]",
1655            "",
1656        ]
1657        .join("\n")
1658    }
1659
    /// Start a single-connection HTTP server on an ephemeral localhost port.
    ///
    /// The server thread accepts exactly one connection, reads the request
    /// (headers, then as much of the body as `Content-Length` promises),
    /// sends the captured request through the returned channel, and replies
    /// with the given status/content-type/body before closing.
    ///
    /// Returns the `…/chat/completions` base URL to point the provider at,
    /// plus the receiver for the captured request.
    fn spawn_test_server(
        status_code: u16,
        content_type: &str,
        body: &str,
    ) -> (String, mpsc::Receiver<CapturedRequest>) {
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind test server");
        let addr = listener.local_addr().expect("local addr");
        let (tx, rx) = mpsc::channel();
        // Own the response pieces so they can move into the server thread.
        let body = body.to_string();
        let content_type = content_type.to_string();

        std::thread::spawn(move || {
            let (mut socket, _) = listener.accept().expect("accept");
            // Bound all reads so a misbehaving client can't hang the test.
            socket
                .set_read_timeout(Some(Duration::from_secs(2)))
                .expect("set read timeout");

            // Read until the end-of-headers marker (\r\n\r\n); body bytes
            // that arrive in the same reads are kept and split off below.
            let mut bytes = Vec::new();
            let mut chunk = [0_u8; 4096];
            loop {
                match socket.read(&mut chunk) {
                    Ok(0) => break,
                    Ok(n) => {
                        bytes.extend_from_slice(&chunk[..n]);
                        if bytes.windows(4).any(|window| window == b"\r\n\r\n") {
                            break;
                        }
                    }
                    Err(err)
                        if err.kind() == std::io::ErrorKind::WouldBlock
                            || err.kind() == std::io::ErrorKind::TimedOut =>
                    {
                        break;
                    }
                    Err(err) => panic!("read request failed: {err}"),
                }
            }

            let header_end = bytes
                .windows(4)
                .position(|window| window == b"\r\n\r\n")
                .expect("request header boundary");
            let header_text = String::from_utf8_lossy(&bytes[..header_end]).to_string();
            let headers = parse_headers(&header_text);
            let mut request_body = bytes[header_end + 4..].to_vec();

            // Keep reading until the advertised Content-Length is satisfied
            // (or the client stops sending).
            let content_length = headers
                .get("content-length")
                .and_then(|value| value.parse::<usize>().ok())
                .unwrap_or(0);
            while request_body.len() < content_length {
                match socket.read(&mut chunk) {
                    Ok(0) => break,
                    Ok(n) => request_body.extend_from_slice(&chunk[..n]),
                    Err(err)
                        if err.kind() == std::io::ErrorKind::WouldBlock
                            || err.kind() == std::io::ErrorKind::TimedOut =>
                    {
                        break;
                    }
                    Err(err) => panic!("read request body failed: {err}"),
                }
            }

            let captured = CapturedRequest {
                headers,
                body: String::from_utf8_lossy(&request_body).to_string(),
            };
            tx.send(captured).expect("send captured request");

            // Minimal reason-phrase table for the statuses tests exercise.
            let reason = match status_code {
                401 => "Unauthorized",
                500 => "Internal Server Error",
                _ => "OK",
            };
            let response = format!(
                "HTTP/1.1 {status_code} {reason}\r\nContent-Type: {content_type}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}",
                body.len()
            );
            socket
                .write_all(response.as_bytes())
                .expect("write response");
            socket.flush().expect("flush response");
        });

        (format!("http://{addr}/chat/completions"), rx)
    }
1747
1748    fn parse_headers(header_text: &str) -> HashMap<String, String> {
1749        let mut headers = HashMap::new();
1750        for line in header_text.lines().skip(1) {
1751            if let Some((name, value)) = line.split_once(':') {
1752                headers.insert(name.trim().to_ascii_lowercase(), value.trim().to_string());
1753            }
1754        }
1755        headers
1756    }
1757
1758    fn collect_events(events: &[Value]) -> Vec<StreamEvent> {
1759        let runtime = RuntimeBuilder::current_thread()
1760            .build()
1761            .expect("runtime build");
1762        runtime.block_on(async move {
1763            let byte_stream = stream::iter(
1764                events
1765                    .iter()
1766                    .map(|event| {
1767                        let data = match event {
1768                            Value::String(text) => text.clone(),
1769                            _ => serde_json::to_string(event).expect("serialize event"),
1770                        };
1771                        format!("data: {data}\n\n").into_bytes()
1772                    })
1773                    .map(Ok),
1774            );
1775            let event_source = crate::sse::SseStream::new(Box::pin(byte_stream));
1776            let mut state = StreamState::new(
1777                event_source,
1778                "gpt-test".to_string(),
1779                "openai".to_string(),
1780                "openai".to_string(),
1781            );
1782            let mut out = Vec::new();
1783
1784            while let Some(item) = state.event_source.next().await {
1785                let msg = item.expect("SSE event");
1786                if msg.data == "[DONE]" {
1787                    out.extend(state.pending_events.drain(..));
1788                    let reason = state.partial.stop_reason;
1789                    out.push(StreamEvent::Done {
1790                        reason,
1791                        message: std::mem::take(&mut state.partial),
1792                    });
1793                    break;
1794                }
1795                state.process_event(&msg.data).expect("process_event");
1796                out.extend(state.pending_events.drain(..));
1797            }
1798
1799            out
1800        })
1801    }
1802
1803    fn summarize_event(event: &StreamEvent) -> EventSummary {
1804        match event {
1805            StreamEvent::Start { .. } => EventSummary {
1806                kind: "start".to_string(),
1807                content_index: None,
1808                delta: None,
1809                content: None,
1810                reason: None,
1811            },
1812            StreamEvent::TextDelta {
1813                content_index,
1814                delta,
1815                ..
1816            } => EventSummary {
1817                kind: "text_delta".to_string(),
1818                content_index: Some(*content_index),
1819                delta: Some(delta.clone()),
1820                content: None,
1821                reason: None,
1822            },
1823            StreamEvent::Done { reason, .. } => EventSummary {
1824                kind: "done".to_string(),
1825                content_index: None,
1826                delta: None,
1827                content: None,
1828                reason: Some(reason_to_string(*reason)),
1829            },
1830            StreamEvent::Error { reason, .. } => EventSummary {
1831                kind: "error".to_string(),
1832                content_index: None,
1833                delta: None,
1834                content: None,
1835                reason: Some(reason_to_string(*reason)),
1836            },
1837            StreamEvent::TextStart { content_index, .. } => EventSummary {
1838                kind: "text_start".to_string(),
1839                content_index: Some(*content_index),
1840                delta: None,
1841                content: None,
1842                reason: None,
1843            },
1844            StreamEvent::TextEnd {
1845                content_index,
1846                content,
1847                ..
1848            } => EventSummary {
1849                kind: "text_end".to_string(),
1850                content_index: Some(*content_index),
1851                delta: None,
1852                content: Some(content.clone()),
1853                reason: None,
1854            },
1855            _ => EventSummary {
1856                kind: "other".to_string(),
1857                content_index: None,
1858                delta: None,
1859                content: None,
1860                reason: None,
1861            },
1862        }
1863    }
1864
1865    fn reason_to_string(reason: StopReason) -> String {
1866        match reason {
1867            StopReason::Stop => "stop",
1868            StopReason::Length => "length",
1869            StopReason::ToolUse => "tool_use",
1870            StopReason::Error => "error",
1871            StopReason::Aborted => "aborted",
1872        }
1873        .to_string()
1874    }
1875
1876    // ── bd-3uqg.2.4: compat override behavior ──────────────────────
1877
1878    fn context_with_tools() -> Context<'static> {
1879        Context {
1880            system_prompt: Some("You are helpful.".to_string().into()),
1881            messages: vec![Message::User(crate::model::UserMessage {
1882                content: UserContent::Text("Hi".to_string()),
1883                timestamp: 0,
1884            })]
1885            .into(),
1886            tools: vec![ToolDef {
1887                name: "search".to_string(),
1888                description: "Search".to_string(),
1889                parameters: json!({"type": "object", "properties": {}}),
1890            }]
1891            .into(),
1892        }
1893    }
1894
1895    fn default_stream_options() -> StreamOptions {
1896        StreamOptions {
1897            max_tokens: Some(1024),
1898            ..Default::default()
1899        }
1900    }
1901
1902    #[test]
1903    fn compat_system_role_name_overrides_default() {
1904        let provider = OpenAIProvider::new("gpt-4o").with_compat(Some(CompatConfig {
1905            system_role_name: Some("developer".to_string()),
1906            ..Default::default()
1907        }));
1908        let context = context_with_tools();
1909        let options = default_stream_options();
1910        let req = provider.build_request(&context, &options);
1911        let value = serde_json::to_value(&req).expect("serialize");
1912        assert_eq!(
1913            value["messages"][0]["role"], "developer",
1914            "system message should use overridden role name"
1915        );
1916    }
1917
1918    #[test]
1919    fn compat_none_uses_default_system_role() {
1920        let provider = OpenAIProvider::new("gpt-4o");
1921        let context = context_with_tools();
1922        let options = default_stream_options();
1923        let req = provider.build_request(&context, &options);
1924        let value = serde_json::to_value(&req).expect("serialize");
1925        assert_eq!(
1926            value["messages"][0]["role"], "system",
1927            "default system role should be 'system'"
1928        );
1929    }
1930
1931    #[test]
1932    fn compat_supports_tools_false_omits_tools() {
1933        let provider = OpenAIProvider::new("gpt-4o").with_compat(Some(CompatConfig {
1934            supports_tools: Some(false),
1935            ..Default::default()
1936        }));
1937        let context = context_with_tools();
1938        let options = default_stream_options();
1939        let req = provider.build_request(&context, &options);
1940        let value = serde_json::to_value(&req).expect("serialize");
1941        assert!(
1942            value["tools"].is_null(),
1943            "tools should be omitted when supports_tools=false"
1944        );
1945    }
1946
1947    #[test]
1948    fn compat_supports_tools_true_includes_tools() {
1949        let provider = OpenAIProvider::new("gpt-4o").with_compat(Some(CompatConfig {
1950            supports_tools: Some(true),
1951            ..Default::default()
1952        }));
1953        let context = context_with_tools();
1954        let options = default_stream_options();
1955        let req = provider.build_request(&context, &options);
1956        let value = serde_json::to_value(&req).expect("serialize");
1957        assert!(
1958            value["tools"].is_array(),
1959            "tools should be included when supports_tools=true"
1960        );
1961    }
1962
1963    #[test]
1964    fn compat_max_tokens_field_routes_to_max_completion_tokens() {
1965        let provider = OpenAIProvider::new("o1").with_compat(Some(CompatConfig {
1966            max_tokens_field: Some("max_completion_tokens".to_string()),
1967            ..Default::default()
1968        }));
1969        let context = context_with_tools();
1970        let options = default_stream_options();
1971        let req = provider.build_request(&context, &options);
1972        let value = serde_json::to_value(&req).expect("serialize");
1973        assert!(
1974            value["max_tokens"].is_null(),
1975            "max_tokens should be absent when routed to max_completion_tokens"
1976        );
1977        assert_eq!(
1978            value["max_completion_tokens"], 1024,
1979            "max_completion_tokens should carry the token limit"
1980        );
1981    }
1982
1983    #[test]
1984    fn compat_default_routes_to_max_tokens() {
1985        let provider = OpenAIProvider::new("gpt-4o");
1986        let context = context_with_tools();
1987        let options = default_stream_options();
1988        let req = provider.build_request(&context, &options);
1989        let value = serde_json::to_value(&req).expect("serialize");
1990        assert_eq!(
1991            value["max_tokens"], 1024,
1992            "default should use max_tokens field"
1993        );
1994        assert!(
1995            value["max_completion_tokens"].is_null(),
1996            "max_completion_tokens should be absent by default"
1997        );
1998    }
1999
2000    #[test]
2001    fn compat_supports_usage_in_streaming_false() {
2002        let provider = OpenAIProvider::new("gpt-4o").with_compat(Some(CompatConfig {
2003            supports_usage_in_streaming: Some(false),
2004            ..Default::default()
2005        }));
2006        let context = context_with_tools();
2007        let options = default_stream_options();
2008        let req = provider.build_request(&context, &options);
2009        let value = serde_json::to_value(&req).expect("serialize");
2010        assert_eq!(
2011            value["stream_options"]["include_usage"], false,
2012            "include_usage should be false when supports_usage_in_streaming=false"
2013        );
2014    }
2015
2016    #[test]
2017    fn compat_combined_overrides() {
2018        let provider = OpenAIProvider::new("custom-model").with_compat(Some(CompatConfig {
2019            system_role_name: Some("developer".to_string()),
2020            max_tokens_field: Some("max_completion_tokens".to_string()),
2021            supports_tools: Some(false),
2022            supports_usage_in_streaming: Some(false),
2023            ..Default::default()
2024        }));
2025        let context = context_with_tools();
2026        let options = default_stream_options();
2027        let req = provider.build_request(&context, &options);
2028        let value = serde_json::to_value(&req).expect("serialize");
2029        assert_eq!(value["messages"][0]["role"], "developer");
2030        assert!(value["max_tokens"].is_null());
2031        assert_eq!(value["max_completion_tokens"], 1024);
2032        assert!(value["tools"].is_null());
2033        assert_eq!(value["stream_options"]["include_usage"], false);
2034    }
2035
    /// End-to-end check that `custom_headers` from the compat config are
    /// attached to the outgoing streaming HTTP request.
    ///
    /// A one-shot local server (see `spawn_test_server`) captures the raw
    /// request so the headers can be inspected after the stream completes.
    #[test]
    fn compat_custom_headers_injected_into_stream_request() {
        let mut custom = HashMap::new();
        custom.insert("X-Custom-Tag".to_string(), "test-123".to_string());
        custom.insert("X-Provider-Region".to_string(), "us-east-1".to_string());
        // Server replies with a canned SSE success body; `rx` yields the
        // captured request once the server has read it.
        let (base_url, rx) = spawn_test_server(200, "text/event-stream", &success_sse_body());
        let provider = OpenAIProvider::new("gpt-4o")
            .with_base_url(base_url)
            .with_compat(Some(CompatConfig {
                custom_headers: Some(custom),
                ..Default::default()
            }));

        // Minimal context: one user message, no system prompt, no tools.
        let context = Context {
            system_prompt: None,
            messages: vec![Message::User(crate::model::UserMessage {
                content: UserContent::Text("ping".to_string()),
                timestamp: 0,
            })]
            .into(),
            tools: Vec::new().into(),
        };
        let options = StreamOptions {
            api_key: Some("test-key".to_string()),
            ..Default::default()
        };

        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");
        // Drain the stream until Done so the request is fully sent and the
        // server has had a chance to capture it.
        runtime.block_on(async {
            let mut stream = provider.stream(&context, &options).await.expect("stream");
            while let Some(event) = stream.next().await {
                if matches!(event.expect("stream event"), StreamEvent::Done { .. }) {
                    break;
                }
            }
        });

        // Header names are lowercased by `parse_headers` on the server side.
        let captured = rx
            .recv_timeout(Duration::from_secs(2))
            .expect("captured request");
        assert_eq!(
            captured.headers.get("x-custom-tag").map(String::as_str),
            Some("test-123"),
            "custom header should be present in request"
        );
        assert_eq!(
            captured
                .headers
                .get("x-provider-region")
                .map(String::as_str),
            Some("us-east-1"),
            "custom header should be present in request"
        );
    }
2092
2093    // ========================================================================
2094    // Proptest — process_event() fuzz coverage (FUZZ-P1.3)
2095    // ========================================================================
2096
    /// Property tests asserting that `StreamState::process_event` never
    /// panics, for both well-formed OpenAI stream chunks and arbitrary
    /// ("chaos") payloads. Return values are deliberately ignored — the
    /// property under test is panic-freedom, not output shape.
    mod proptest_process_event {
        use super::*;
        use proptest::prelude::*;

        /// Fresh `StreamState` over an empty SSE source. Only
        /// `process_event` is exercised here, so the source is never polled.
        fn make_state()
        -> StreamState<impl Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin>
        {
            let empty = stream::empty::<std::result::Result<Vec<u8>, std::io::Error>>();
            let sse = crate::sse::SseStream::new(Box::pin(empty));
            StreamState::new(sse, "gpt-test".into(), "openai".into(), "openai".into())
        }

        /// Short strings: empty, identifier-like, or arbitrary printable ASCII.
        fn small_string() -> impl Strategy<Value = String> {
            prop_oneof![Just(String::new()), "[a-zA-Z0-9_]{1,16}", "[ -~]{0,32}",]
        }

        /// `None` or a `small_string()` wrapped in `Some`.
        fn optional_string() -> impl Strategy<Value = Option<String>> {
            prop_oneof![Just(None), small_string().prop_map(Some),]
        }

        /// Token counts weighted toward realistic values, with edge cases at
        /// zero and near `u64::MAX` to probe overflow handling.
        fn token_count() -> impl Strategy<Value = u64> {
            prop_oneof![
                5 => 0u64..10_000u64,
                2 => Just(0u64),
                1 => Just(u64::MAX),
                1 => (u64::MAX - 100)..=u64::MAX,
            ]
        }

        /// Known OpenAI finish reasons plus an arbitrary-string arm for
        /// unrecognized values; `None` is the most common case.
        fn finish_reason() -> impl Strategy<Value = Option<String>> {
            prop_oneof![
                3 => Just(None),
                1 => Just(Some("stop".to_string())),
                1 => Just(Some("length".to_string())),
                1 => Just(Some("tool_calls".to_string())),
                1 => Just(Some("content_filter".to_string())),
                1 => small_string().prop_map(Some),
            ]
        }

        /// Tool-call indices: mostly small, with `u32::MAX` and a large-gap
        /// range to probe sparse/out-of-order index handling.
        fn tool_call_index() -> impl Strategy<Value = u32> {
            prop_oneof![
                5 => 0u32..3u32,
                1 => Just(u32::MAX),
                1 => 100u32..200u32,
            ]
        }

        /// Generate valid `OpenAIStreamChunk` JSON.
        fn openai_chunk_json() -> impl Strategy<Value = String> {
            prop_oneof![
                // Text content delta
                3 => (small_string(), finish_reason()).prop_map(|(text, fr)| {
                    let mut choice = serde_json::json!({
                        "delta": {"content": text}
                    });
                    if let Some(reason) = fr {
                        choice["finish_reason"] = serde_json::Value::String(reason);
                    }
                    serde_json::json!({"choices": [choice]}).to_string()
                }),
                // Empty delta (initial or heartbeat)
                2 => Just(r#"{"choices":[{"delta":{}}]}"#.to_string()),
                // Finish-only delta
                2 => finish_reason().prop_filter("some reason", Option::is_some).prop_map(|fr| {
                    serde_json::json!({
                        "choices": [{"delta": {}, "finish_reason": fr.unwrap()}]
                    })
                    .to_string()
                }),
                // Tool call delta — id/name/arguments each independently optional
                3 => (tool_call_index(), optional_string(), optional_string(), optional_string())
                    .prop_map(|(idx, id, name, args)| {
                        let mut tc = serde_json::json!({"index": idx});
                        if let Some(id) = id { tc["id"] = serde_json::Value::String(id); }
                        let mut func = serde_json::Map::new();
                        if let Some(n) = name { func.insert("name".into(), serde_json::Value::String(n)); }
                        if let Some(a) = args { func.insert("arguments".into(), serde_json::Value::String(a)); }
                        if !func.is_empty() { tc["function"] = serde_json::Value::Object(func); }
                        serde_json::json!({
                            "choices": [{"delta": {"tool_calls": [tc]}}]
                        })
                        .to_string()
                    }),
                // Usage-only chunk (no choices)
                2 => (token_count(), token_count(), token_count()).prop_map(|(prompt, compl, total)| {
                    serde_json::json!({
                        "choices": [],
                        "usage": {
                            "prompt_tokens": prompt,
                            "completion_tokens": compl,
                            "total_tokens": total
                        }
                    })
                    .to_string()
                }),
                // Error chunk
                1 => small_string().prop_map(|msg| {
                    serde_json::json!({
                        "choices": [],
                        "error": {"message": msg}
                    })
                    .to_string()
                }),
                // Empty choices
                1 => Just(r#"{"choices":[]}"#.to_string()),
            ]
        }

        /// Chaos — arbitrary JSON strings.
        fn chaos_json() -> impl Strategy<Value = String> {
            prop_oneof![
                Just(String::new()),
                Just("{}".to_string()),
                Just("[]".to_string()),
                Just("null".to_string()),
                Just("{".to_string()),
                Just(r#"{"choices":"not_array"}"#.to_string()),
                Just(r#"{"choices":[{"delta":null}]}"#.to_string()),
                "[a-z_]{1,20}".prop_map(|t| format!(r#"{{"type":"{t}"}}"#)),
                "[ -~]{0,64}",
            ]
        }

        proptest! {
            #![proptest_config(ProptestConfig {
                cases: 256,
                max_shrink_iters: 100,
                .. ProptestConfig::default()
            })]

            // Well-formed chunks must never panic, individually…
            #[test]
            fn process_event_valid_never_panics(data in openai_chunk_json()) {
                let mut state = make_state();
                let _ = state.process_event(&data);
            }

            // …nor must malformed/arbitrary payloads…
            #[test]
            fn process_event_chaos_never_panics(data in chaos_json()) {
                let mut state = make_state();
                let _ = state.process_event(&data);
            }

            // …nor sequences of valid chunks sharing one state…
            #[test]
            fn process_event_sequence_never_panics(
                events in prop::collection::vec(openai_chunk_json(), 1..8)
            ) {
                let mut state = make_state();
                for event in &events {
                    let _ = state.process_event(event);
                }
            }

            // …nor interleavings of valid and chaos payloads.
            #[test]
            fn process_event_mixed_sequence_never_panics(
                events in prop::collection::vec(
                    prop_oneof![openai_chunk_json(), chaos_json()],
                    1..12
                )
            ) {
                let mut state = make_state();
                for event in &events {
                    let _ = state.process_event(event);
                }
            }
        }
    }
2264}
2265
2266// ============================================================================
2267// Fuzzing support
2268// ============================================================================
2269
#[cfg(feature = "fuzzing")]
pub mod fuzz {
    use super::*;
    use futures::stream;
    use std::pin::Pin;

    /// An always-empty byte stream: the fuzz harness feeds payloads directly
    /// via `process_event`, so the SSE source is never polled.
    type FuzzStream =
        Pin<Box<futures::stream::Empty<std::result::Result<Vec<u8>, std::io::Error>>>>;

    /// Opaque wrapper around the OpenAI stream processor state.
    pub struct Processor(StreamState<FuzzStream>);

    impl Processor {
        /// Create a fresh processor with default state.
        pub fn new() -> Self {
            let source: FuzzStream = Box::pin(stream::empty());
            let state = StreamState::new(
                crate::sse::SseStream::new(source),
                "gpt-fuzz".into(),
                "openai".into(),
                "openai".into(),
            );
            Self(state)
        }

        /// Feed one SSE data payload and return any emitted `StreamEvent`s.
        pub fn process_event(&mut self, data: &str) -> crate::error::Result<Vec<StreamEvent>> {
            self.0.process_event(data)?;
            let emitted: Vec<StreamEvent> = self.0.pending_events.drain(..).collect();
            Ok(emitted)
        }
    }

    impl Default for Processor {
        fn default() -> Self {
            Self::new()
        }
    }
}