Skip to main content

oxi_ai/providers/
openai.rs

1//! OpenAI-compatible provider implementation
2
3use async_trait::async_trait;
4use bytes::Bytes;
5use futures::{Stream, StreamExt};
6use reqwest::Client;
7use serde::Deserialize;
8use serde_json::Value as JsonValue;
9use std::pin::Pin;
10
11use super::openai_responses_shared::parse_streaming_json;
12use super::shared_client;
13use crate::{
14    error::ProviderError, Api, AssistantMessage, ContentBlock, Context, Model, Provider,
15    ProviderEvent, StopReason, StreamOptions, TextContent, ThinkingContent, Usage,
16};
17
18/// Detect whether a model targets the ZAI provider.
19fn is_zai(model: &Model) -> bool {
20    model.provider.eq_ignore_ascii_case("zai") || model.base_url.contains("api.z.ai")
21}
22
23/// OpenAI-compatible provider
24#[derive(Clone)]
25pub struct OpenAiProvider {
26    client: &'static Client,
27    api_key: Option<String>,
28    base_url: Option<String>,
29}
30
31impl OpenAiProvider {
32    /// Create a new OpenAI provider without an API key.
33    ///
34    /// API keys are resolved at request time via auth.json or StreamOptions.
35    /// Use `with_api_key()` for explicit key injection.
36    pub fn new() -> Self {
37        Self {
38            client: shared_client(),
39            api_key: None,
40            base_url: None,
41        }
42    }
43
44    /// Create with explicit API key (public API for external consumers)
45    pub fn with_api_key(api_key: impl Into<String>) -> Self {
46        Self {
47            client: shared_client(),
48            api_key: Some(api_key.into()),
49            base_url: None,
50        }
51    }
52
53    /// Create with a custom base URL (API key resolved from auth storage).
54    ///
55    /// Used for built-in OpenAI-compatible providers like ZAI.
56    pub fn with_base_url(base_url: &str) -> Self {
57        Self {
58            client: shared_client(),
59            api_key: None,
60            base_url: Some(base_url.to_string()),
61        }
62    }
63
64    /// Create with a custom base URL and optional API key.
65    ///
66    /// Used for registering custom OpenAI-compatible providers (Minimax, ZAI, etc.).
67    pub fn with_base_url_and_key(base_url: &str, api_key: Option<String>) -> Self {
68        Self {
69            client: shared_client(),
70            api_key,
71            base_url: Some(base_url.to_string()),
72        }
73    }
74}
75
76impl Default for OpenAiProvider {
77    fn default() -> Self {
78        Self::new()
79    }
80}
81
82#[async_trait]
83impl Provider for OpenAiProvider {
84    async fn stream(
85        &self,
86        model: &Model,
87        context: &Context,
88        options: Option<StreamOptions>,
89    ) -> Result<Pin<Box<dyn Stream<Item = ProviderEvent> + Send>>, ProviderError> {
90        let options = options.unwrap_or_default();
91
92        // Build the request
93        let effective_base_url = self.base_url.as_deref().unwrap_or(&model.base_url);
94        let url = format!("{}/chat/completions", effective_base_url);
95
96        // Get API key
97        let api_key = options
98            .api_key
99            .as_ref()
100            .or(self.api_key.as_ref())
101            .ok_or_else(|| ProviderError::MissingApiKey)?;
102
103        // Build messages
104        let messages = build_messages(context)?;
105
106        // Build request body
107        let mut body = serde_json::json!({
108            "model": model.id,
109            "messages": messages,
110            "stream": true,
111            "stream_options": { "include_usage": true },
112        });
113
114        // Add optional parameters
115        if let Some(temp) = options.temperature {
116            body["temperature"] = serde_json::json!(temp);
117        }
118
119        if let Some(max) = options.max_tokens {
120            body["max_tokens"] = serde_json::json!(max);
121        }
122
123        // Add tools if present
124        if !context.tools.is_empty() {
125            body["tools"] = build_tools(&context.tools)?;
126        }
127
128        // ── ZAI-specific parameters ──────────────────────────────────
129        // Mirror pi's detectCompat: when provider is ZAI (or base_url contains
130        // api.z.ai), send enable_thinking and tool_stream.
131        if is_zai(model) {
132            if model.reasoning {
133                body["enable_thinking"] = serde_json::json!(true);
134            }
135            if !context.tools.is_empty() {
136                body["tool_stream"] = serde_json::json!(true);
137            }
138        }
139
140        tracing::info!(
141            "Sending request to {} model={} body_len={} enable_thinking={} tool_stream={}",
142            url,
143            model.id,
144            body.to_string().len(),
145            body.get("enable_thinking").is_some(),
146            body.get("tool_stream").is_some()
147        );
148        tracing::debug!("Request body: {}", body.to_string());
149
150        // Build headers
151        let mut headers = reqwest::header::HeaderMap::new();
152        headers.insert(
153            reqwest::header::AUTHORIZATION,
154            format!("Bearer {}", api_key)
155                .parse()
156                .expect("valid bearer header"),
157        );
158        headers.insert(
159            reqwest::header::CONTENT_TYPE,
160            "application/json".parse().expect("valid header value"),
161        );
162
163        for (k, v) in &options.headers {
164            if let (Ok(name), Ok(value)) = (
165                k.parse::<reqwest::header::HeaderName>(),
166                v.parse::<reqwest::header::HeaderValue>(),
167            ) {
168                headers.insert(name, value);
169            }
170        }
171
172        // Make request
173        let response = self
174            .client
175            .post(&url)
176            .headers(headers)
177            .json(&body)
178            .send()
179            .await
180            .map_err(ProviderError::RequestFailed)?;
181
182        if !response.status().is_success() {
183            let status = response.status();
184            let body: String = response.text().await.unwrap_or_default();
185            return Err(ProviderError::HttpError(status.as_u16(), body));
186        }
187
188        // Create event stream
189        let provider_name = model.provider.clone();
190        let model_id = model.id.clone();
191
192        // Emit Start event once at the beginning of the stream (matches pi's behavior)
193        let start_event = ProviderEvent::Start {
194            partial: AssistantMessage::new(Api::OpenAiCompletions, &provider_name, &model_id),
195        };
196
197        // Stateful stream parser that accumulates tool calls across chunks.
198        // OpenAI sends tool calls as multiple deltas (id, name, arguments fragments)
199        // that must be reassembled before emitting ToolCallEnd.
200        //
201        // State:
202        //   pending_bytes     – incomplete UTF-8 bytes from the previous HTTP chunk
203        //   pending_tc_index  – accumulated tool calls keyed by streaming index
204        //   pending_tc_id     – secondary lookup by tool-call ID (ZAI et al. may
205        //                       omit the index on continuation deltas)
206        //   thinking_started  – whether ThinkingStart has been emitted
207        let stream = response
208            .bytes_stream()
209            .scan(
210                (
211                    Vec::new(),
212                    std::collections::HashMap::<usize, (String, String, String)>::new(),
213                    std::collections::HashMap::<String, usize>::new(), // id → index
214                    false,
215                    AssistantMessage::new(Api::OpenAiCompletions, &provider_name, &model_id),
216                ),
217                move |(
218                    pending_bytes,
219                    pending_tc,
220                    tc_id_to_index,
221                    thinking_started,
222                    accumulated_output,
223                ),
224                      chunk: Result<Bytes, reqwest::Error>| {
225                    let events = match chunk {
226                        Ok(bytes) => {
227                            // Prepend any incomplete bytes from previous chunk
228                            let mut combined =
229                                Vec::with_capacity(pending_bytes.len() + bytes.len());
230                            combined.extend_from_slice(pending_bytes);
231                            combined.extend_from_slice(&bytes);
232
233                            // Split into complete lines (ending with \n) and trailing incomplete data.
234                            // This prevents JSON parse failures from partial SSE lines
235                            // that were split across HTTP chunks.
236                            let (text, trailing) = split_complete_lines(&combined);
237                            *pending_bytes = trailing;
238
239                            tracing::debug!(
240                                "parse_sse_events input: {} bytes, {} lines",
241                                text.len(),
242                                text.lines().count()
243                            );
244                            let raw_events = parse_sse_events(
245                                &text,
246                                &provider_name,
247                                &model_id,
248                                accumulated_output,
249                            );
250                            tracing::debug!("parse_sse_events output: {} events", raw_events.len());
251
252                            // Post-process: accumulate tool call deltas, inject ThinkingStart once
253                            let mut processed = Vec::new();
254                            for event in raw_events {
255                                match &event {
256                                    ProviderEvent::ThinkingDelta { content_index, .. } => {
257                                        // Inject ThinkingStart before the first ThinkingDelta
258                                        if !*thinking_started {
259                                            *thinking_started = true;
260                                            processed.push(ProviderEvent::ThinkingStart {
261                                                content_index: *content_index,
262                                                partial: AssistantMessage::new(
263                                                    Api::OpenAiCompletions,
264                                                    &provider_name,
265                                                    &model_id,
266                                                ),
267                                            });
268                                        }
269                                        processed.push(event);
270                                    }
271                                    ProviderEvent::ToolCallStart {
272                                        content_index,
273                                        tool_call_id,
274                                        tool_name,
275                                        ..
276                                    } => {
277                                        let entry =
278                                            pending_tc.entry(*content_index).or_insert_with(|| {
279                                                (String::new(), String::new(), String::new())
280                                            });
281                                        if let Some(ref id) = tool_call_id {
282                                            if !id.is_empty() {
283                                                entry.0 = id.clone();
284                                                tc_id_to_index.insert(id.clone(), *content_index);
285                                            }
286                                        }
287                                        if let Some(ref name) = tool_name {
288                                            if !name.is_empty() {
289                                                entry.1 = name.clone();
290                                            }
291                                        }
292                                        processed.push(event);
293                                    }
294                                    ProviderEvent::ToolCallDelta {
295                                        content_index,
296                                        delta,
297                                        ..
298                                    } => {
299                                        // Dual-map lookup: prefer index, fall back to ID
300                                        let idx = if pending_tc.contains_key(content_index) {
301                                            *content_index
302                                        } else {
303                                            // Scan id→index map for a match
304                                            tc_id_to_index
305                                                .values()
306                                                .copied()
307                                                .find(|i| *i == *content_index)
308                                                .unwrap_or(*content_index)
309                                        };
310                                        let entry = pending_tc.entry(idx).or_insert_with(|| {
311                                            (String::new(), String::new(), String::new())
312                                        });
313                                        tracing::debug!(
314                                            "[TC-DELTA] idx={}, delta_len={}, accumulated_len={}",
315                                            idx,
316                                            delta.len(),
317                                            entry.2.len() + delta.len()
318                                        );
319                                        entry.2.push_str(delta);
320                                        processed.push(event);
321                                    }
322                                    ProviderEvent::ToolCallEnd { .. } => {
323                                        // Already a ToolCallEnd from parse_sse_events
324                                        processed.push(event);
325                                    }
326                                    ProviderEvent::Done { reason, .. } => {
327                                        // Before Done, emit ToolCallEnd for all accumulated tool calls
328                                        if matches!(reason, StopReason::ToolUse) {
329                                            let mut indices: Vec<usize> =
330                                                pending_tc.keys().copied().collect();
331                                            indices.sort();
332                                            for idx in indices {
333                                                let (id, name, arguments) = &pending_tc[&idx];
334                                                tracing::debug!(
335                                                    "[TC-END] idx={}, id={}, name={}, args_len={}",
336                                                    idx,
337                                                    id.len(),
338                                                    name.len(),
339                                                    arguments.len()
340                                                );
341                                                let args_value = parse_streaming_json(arguments);
342                                                processed.push(ProviderEvent::ToolCallEnd {
343                                                    content_index: idx,
344                                                    tool_call: crate::ToolCall {
345                                                        content_type:
346                                                            crate::messages::ToolCallType::ToolCall,
347                                                        id: id.clone(),
348                                                        name: name.clone(),
349                                                        arguments: args_value,
350                                                        thought_signature: None,
351                                                    },
352                                                    partial: AssistantMessage::new(
353                                                        Api::OpenAiCompletions,
354                                                        &provider_name,
355                                                        &model_id,
356                                                    ),
357                                                });
358                                            }
359                                        }
360                                        // Clear pending_tc for the next stream/turn.
361                                        // Without this, tool call arguments from the previous
362                                        // turn leak into the next turn's accumulation.
363                                        pending_tc.clear();
364                                        tc_id_to_index.clear();
365                                        processed.push(event);
366                                    }
367                                    _ => {
368                                        processed.push(event);
369                                    }
370                                }
371                            }
372                            processed
373                        }
374                        Err(e) => {
375                            vec![ProviderEvent::Error {
376                                reason: StopReason::Error,
377                                error: create_error_message(
378                                    &e.to_string(),
379                                    &provider_name,
380                                    &model_id,
381                                ),
382                            }]
383                        }
384                    };
385                    // Return Some to continue, wrap events in an iterator
386                    async move { Some(futures::stream::iter(events)) }
387                },
388            )
389            .flatten();
390
391        // Prepend Start event to the stream
392        let stream_with_start = futures::stream::once(async move { start_event }).chain(stream);
393        Ok(Box::pin(stream_with_start))
394    }
395
396    fn name(&self) -> &str {
397        "openai"
398    }
399}
400
401/// Build messages array from context
402fn build_messages(context: &Context) -> Result<Vec<JsonValue>, ProviderError> {
403    let mut messages = Vec::new();
404
405    // System prompt
406    if let Some(ref prompt) = context.system_prompt {
407        messages.push(serde_json::json!({
408            "role": "system",
409            "content": prompt,
410        }));
411    }
412
413    // Conversation messages
414    for msg in &context.messages {
415        match msg {
416            crate::Message::User(u) => {
417                let content: String = match &u.content {
418                    crate::MessageContent::Text(s) => s.clone(),
419                    crate::MessageContent::Blocks(blocks) => blocks_to_content(blocks)?.to_string(),
420                };
421                messages.push(serde_json::json!({
422                    "role": "user",
423                    "content": content,
424                }));
425            }
426            crate::Message::Assistant(a) => {
427                // OpenAI format: separate content (text) and tool_calls
428                let mut text_parts = Vec::new();
429                let mut tool_calls = Vec::new();
430                for block in &a.content {
431                    match block {
432                        ContentBlock::Text(t) => {
433                            text_parts.push(t.text.clone());
434                        }
435                        ContentBlock::Thinking(_) => {
436                            // Skip thinking blocks in message history
437                        }
438                        ContentBlock::ToolCall(tc) => {
439                            tool_calls.push(serde_json::json!({
440                                "id": tc.id,
441                                "type": "function",
442                                "function": {
443                                    "name": tc.name,
444                                    "arguments": tc.arguments.to_string(),
445                                },
446                            }));
447                        }
448                        ContentBlock::Image(_) | ContentBlock::Unknown(_) => {}
449                    }
450                }
451                let mut msg = serde_json::json!({
452                    "role": "assistant",
453                    "content": text_parts.join(""),
454                });
455                if !tool_calls.is_empty() {
456                    msg["tool_calls"] = serde_json::json!(tool_calls);
457                }
458                messages.push(msg);
459            }
460            crate::Message::ToolResult(t) => {
461                let result_text: String = t
462                    .content
463                    .iter()
464                    .filter_map(|b| b.as_text())
465                    .collect::<Vec<_>>()
466                    .join("");
467                messages.push(serde_json::json!({
468                    "role": "tool",
469                    "tool_call_id": t.tool_call_id,
470                    "content": result_text,
471                }));
472            }
473        }
474    }
475
476    Ok(messages)
477}
478
479/// Convert content blocks to a string representation
480fn blocks_to_content(blocks: &[ContentBlock]) -> Result<JsonValue, ProviderError> {
481    if blocks.len() == 1 {
482        if let Some(text) = blocks[0].as_text() {
483            return Ok(JsonValue::String(text.to_string()));
484        }
485    }
486
487    let items: Result<Vec<_>, _> = blocks
488        .iter()
489        .map(|block| match block {
490            ContentBlock::Text(t) => Ok(serde_json::json!({
491                "type": "text",
492                "text": t.text,
493            })),
494            ContentBlock::ToolCall(tc) => Ok(serde_json::json!({
495                "type": "function",
496                "id": tc.id,
497                "function": {
498                    "name": tc.name,
499                    "arguments": tc.arguments.to_string(),
500                },
501            })),
502            ContentBlock::Thinking(th) => Ok(serde_json::json!({
503                "type": "thinking",
504                "thinking": th.thinking,
505            })),
506            ContentBlock::Image(img) => Ok(serde_json::json!({
507                "type": "image_url",
508                "image_url": {
509                    "url": format!("data:{};base64,{}", img.mime_type, img.data),
510                },
511            })),
512            ContentBlock::Unknown(_) => Err(ProviderError::InvalidResponse(
513                "Unknown content block type".into(),
514            )),
515        })
516        .collect();
517
518    Ok(serde_json::json!(items?))
519}
520
521/// Build tools array
522fn build_tools(tools: &[crate::Tool]) -> Result<JsonValue, ProviderError> {
523    let items: Vec<_> = tools
524        .iter()
525        .map(|tool| {
526            serde_json::json!({
527                "type": "function",
528                "function": {
529                    "name": tool.name,
530                    "description": tool.description,
531                    "parameters": tool.parameters,
532                },
533            })
534        })
535        .collect();
536
537    Ok(serde_json::json!(items))
538}
539
540/// Extract the longest valid UTF-8 prefix from a byte slice.
541///
542/// Returns the valid string and the trailing bytes that form an incomplete UTF-8
543/// sequence. These trailing bytes should be prepended to the next chunk to
544/// ensure no characters are lost at HTTP chunk boundaries.
545fn find_valid_utf8_prefix(bytes: &[u8]) -> (String, Vec<u8>) {
546    match std::str::from_utf8(bytes) {
547        Ok(s) => (s.to_string(), Vec::new()),
548        Err(e) => {
549            let valid = &bytes[..e.valid_up_to()];
550            let trailing = bytes[e.valid_up_to()..].to_vec();
551            (String::from_utf8_lossy(valid).to_string(), trailing)
552        }
553    }
554}
555
556/// Split bytes into complete lines (ending with \n) and trailing incomplete data.
557/// This ensures `parse_sse_events` only receives complete SSE `data:` lines,
558/// preventing JSON parse failures from lines split across HTTP chunks.
559pub fn split_complete_lines(bytes: &[u8]) -> (String, Vec<u8>) {
560    // Find the last newline — everything up to and including it is complete.
561    match bytes.iter().rposition(|&b| b == b'\n') {
562        Some(last_nl) => {
563            let split_at = last_nl + 1;
564            let complete = match std::str::from_utf8(&bytes[..split_at]) {
565                Ok(s) => s.to_string(),
566                Err(_) => {
567                    let (s, _) = find_valid_utf8_prefix(&bytes[..split_at]);
568                    s
569                }
570            };
571            let trailing = bytes[split_at..].to_vec();
572            (complete, trailing)
573        }
574        None => {
575            // No newline at all — the entire buffer is incomplete.
576            // Check if it's valid UTF-8; if not, save as pending.
577            (String::new(), bytes.to_vec())
578        }
579    }
580}
581
582/// Parse SSE event stream from a byte buffer.
583///
584/// Optimizations over a naïve implementation:
585/// - **Fast-line splitting** – iterates over `\n` boundaries via `split`
586///   instead of allocating an intermediate `String` per line.
587/// - **Early `DONE` exit** – breaks immediately when `data: [DONE]` is
588///   encountered.
589/// - **Pre-allocated events** – reserves capacity based on data-line count.
590/// - **Accumulated usage** – tracks usage separately, only cloning into
591///   the Done message at stream end, not on every chunk.
592fn parse_sse_events(
593    text: &str,
594    _provider: &str,
595    _model_id: &str,
596    output: &mut AssistantMessage,
597) -> Vec<ProviderEvent> {
598    let mut events = Vec::new();
599
600    // Pre-estimate capacity: one event per data line is a reasonable upper bound.
601    let estimated_events = text.split('\n').filter(|l| l.starts_with("data: ")).count();
602    events.reserve(estimated_events);
603
604    let mut accumulated_usage = Usage::default();
605
606    for line in text.split('\n') {
607        let line = line.trim_end_matches('\r');
608        if line.is_empty() {
609            continue;
610        }
611
612        // Fast rejection for non-data lines (comments, event tags, etc.)
613        if !line.starts_with("data: ") {
614            continue;
615        }
616
617        let data = &line[6..]; // skip "data: "
618
619        // Early exit on stream end
620        if data == "[DONE]" {
621            break;
622        }
623
624        if data.is_empty() {
625            continue;
626        }
627
628        let chunk = match serde_json::from_str::<SSEChunk>(data) {
629            Ok(c) => c,
630            Err(_) => continue,
631        };
632
633        // ── Accumulate usage BEFORE processing choices ────────────────
634        // OpenAI with include_usage sends usage in a final chunk with
635        // empty choices. By accumulating before the choice loop, the
636        // Done event (triggered by finish_reason in an earlier chunk)
637        // and any subsequent rendering sees the latest usage.
638        if let Some(chunk_usage) = &chunk.usage {
639            accumulated_usage.input = chunk_usage.prompt_tokens.max(accumulated_usage.input);
640            accumulated_usage.output = chunk_usage.completion_tokens.max(accumulated_usage.output);
641            accumulated_usage.cache_read = chunk_usage
642                .prompt_tokens_details
643                .as_ref()
644                .map(|d| d.cached_tokens)
645                .unwrap_or(0)
646                .max(accumulated_usage.cache_read);
647            accumulated_usage.total_tokens =
648                chunk_usage.total_tokens.max(accumulated_usage.total_tokens);
649        }
650
651        for choice in &chunk.choices {
652            if let Some(delta) = &choice.delta {
653                if let Some(content) = &delta.content {
654                    // pi-mono: append to the output's text block
655                    let last_text_idx = output
656                        .content
657                        .iter()
658                        .rposition(|b| matches!(b, ContentBlock::Text(_)));
659                    if let Some(idx) = last_text_idx {
660                        if let ContentBlock::Text(t) = &mut output.content[idx] {
661                            t.text.push_str(content);
662                        }
663                    } else {
664                        output
665                            .content
666                            .push(ContentBlock::Text(TextContent::new(content.clone())));
667                    }
668                    events.push(ProviderEvent::TextDelta {
669                        content_index: choice.index,
670                        delta: content.clone(),
671                        partial: output.clone(),
672                    });
673                }
674
675                // Handle GLM's reasoning_content field (thinking/thought chain)
676                if let Some(ref reasoning) = delta.reasoning_content {
677                    if !reasoning.is_empty() {
678                        // pi-mono: append to the output's thinking block
679                        let last_think_idx = output
680                            .content
681                            .iter()
682                            .rposition(|b| matches!(b, ContentBlock::Thinking(_)));
683                        if let Some(idx) = last_think_idx {
684                            if let ContentBlock::Thinking(t) = &mut output.content[idx] {
685                                t.thinking.push_str(reasoning);
686                            }
687                        } else {
688                            output
689                                .content
690                                .push(ContentBlock::Thinking(ThinkingContent::new(
691                                    reasoning.clone(),
692                                )));
693                        }
694                        events.push(ProviderEvent::ThinkingDelta {
695                            content_index: choice.index,
696                            delta: reasoning.clone(),
697                            partial: output.clone(),
698                        });
699                    }
700                }
701
702                if let Some(tool_calls) = &delta.tool_calls {
703                    for tc in tool_calls {
704                        let tc_index = tc.index.unwrap_or(choice.index);
705
706                        // Emit ToolCallStart when id or name is present (first delta)
707                        if tc.id.is_some()
708                            || tc.function.as_ref().and_then(|f| f.name.as_ref()).is_some()
709                        {
710                            events.push(ProviderEvent::ToolCallStart {
711                                content_index: tc_index,
712                                tool_call_id: tc.id.clone(),
713                                tool_name: tc.function.as_ref().and_then(|f| f.name.clone()),
714                                partial: output.clone(),
715                            });
716                        }
717
718                        // Emit ToolCallDelta for arguments
719                        if let Some(func) = &tc.function {
720                            events.push(ProviderEvent::ToolCallDelta {
721                                content_index: tc_index,
722                                delta: func.arguments.clone().unwrap_or_default(),
723                                partial: output.clone(),
724                            });
725                        }
726                    }
727                }
728            }
729
730            if choice.finish_reason.is_some() {
731                let reason = match choice.finish_reason.as_deref() {
732                    Some("stop") | Some("end") => StopReason::Stop,
733                    Some("length") => StopReason::Length,
734                    Some("tool_calls") | Some("function_call") => StopReason::ToolUse,
735                    Some("content_filter") => StopReason::Error,
736                    Some(unknown) => {
737                        tracing::warn!("Unknown finish_reason: '{}', treating as Error", unknown);
738                        StopReason::Error
739                    }
740                    None => StopReason::Stop,
741                };
742                tracing::info!("finish_reason={:?} → {:?}", choice.finish_reason, reason);
743
744                let mut done_msg = output.clone();
745                done_msg.stop_reason = reason;
746                done_msg.usage = accumulated_usage.clone();
747                events.push(ProviderEvent::Done {
748                    reason,
749                    message: done_msg,
750                });
751            }
752        }
753    }
754
755    events
756}
757
758/// Create error assistant message
759fn create_error_message(msg: &str, provider: &str, model_id: &str) -> AssistantMessage {
760    let mut message = AssistantMessage::new(Api::OpenAiCompletions, provider, model_id);
761    message.stop_reason = StopReason::Error;
762    message.error_message = Some(msg.to_string());
763    message
764}
765
766// SSE chunk structure
767#[derive(Debug, Deserialize)]
768// serde deserialization structs
769struct SSEChunk {
770    _id: Option<String>,
771    #[serde(rename = "model")]
772    _model: Option<String>,
773    choices: Vec<Choice>,
774    usage: Option<UsageInfo>,
775}
776
777#[derive(Debug, Deserialize)]
778// serde deserialization structs
779struct Choice {
780    index: usize,
781    delta: Option<Delta>,
782    finish_reason: Option<String>,
783}
784
785#[derive(Debug, Deserialize)]
786struct Delta {
787    content: Option<String>,
788    reasoning_content: Option<String>,
789    tool_calls: Option<Vec<ToolCallDelta>>,
790}
791
792#[derive(Debug, Deserialize)]
793// serde deserialization structs
794struct ToolCallDelta {
795    index: Option<usize>,
796    id: Option<String>,
797    #[serde(rename = "type")]
798    _type_: Option<String>,
799    function: Option<FunctionDelta>,
800}
801
802#[derive(Debug, Deserialize)]
803// serde deserialization structs
804struct FunctionDelta {
805    name: Option<String>,
806    arguments: Option<String>,
807}
808
809#[derive(Debug, Deserialize, Clone)]
810struct UsageInfo {
811    prompt_tokens: usize,
812    completion_tokens: usize,
813    total_tokens: usize,
814    #[serde(rename = "prompt_tokens_details")]
815    prompt_tokens_details: Option<PromptTokensDetails>,
816}
817
818#[derive(Debug, Deserialize, Clone)]
819struct PromptTokensDetails {
820    #[serde(rename = "cached_tokens")]
821    cached_tokens: usize,
822}
823
#[cfg(test)]
mod tests {
    use super::*;

    const PROVIDER: &str = "openai";
    const MODEL: &str = "gpt-4o";

    /// Run `parse_sse_events` over `sse` with a fresh, empty output message.
    fn parse_sse(sse: &str) -> Vec<ProviderEvent> {
        let mut output = AssistantMessage::new(Api::OpenAiCompletions, PROVIDER, MODEL);
        parse_sse_events(sse, PROVIDER, MODEL, &mut output)
    }

    /// Parse one finish-only chunk with `finish_reason = raw` and return the
    /// stop reason of the single `Done` event it must produce.
    fn stop_reason_for(raw: &str) -> StopReason {
        let sse = format!(
            "data: {{\"id\":\"c\",\"choices\":[{{\"index\":0,\"delta\":null,\"finish_reason\":\"{raw}\"}}]}}\n\n"
        );
        let events = parse_sse(&sse);
        assert_eq!(
            events.len(),
            1,
            "expected exactly one event for finish_reason={raw}"
        );
        match &events[0] {
            ProviderEvent::Done { reason, .. } => *reason,
            other => panic!("expected Done for finish_reason={raw}, got {other:?}"),
        }
    }

    // ── SSE event parsing ──────────────────────────────────────────────

    #[test]
    fn parse_single_text_event() {
        let sse = "data: {\"id\":\"chatcmpl-1\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"}}]}\n\n";
        let events = parse_sse(sse);
        assert_eq!(events.len(), 1);
        match &events[0] {
            ProviderEvent::TextDelta {
                delta,
                content_index,
                ..
            } => {
                assert_eq!(delta, "Hello");
                assert_eq!(*content_index, 0);
            }
            other => panic!("expected TextDelta, got {other:?}"),
        }
    }

    #[test]
    fn parse_multiple_text_events() {
        let sse = concat!(
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hel\"}}]}\n",
            "\n",
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"lo!\"}}]}\n",
            "\n"
        );
        let events = parse_sse(sse);
        assert_eq!(events.len(), 2);
        let texts: Vec<&str> = events
            .iter()
            .filter_map(|e| match e {
                ProviderEvent::TextDelta { delta, .. } => Some(delta.as_str()),
                _ => None,
            })
            .collect();
        assert_eq!(texts, vec!["Hel", "lo!"]);
    }

    #[test]
    fn parse_done_terminator() {
        let sse = concat!(
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"X\"}}]}\n",
            "\n",
            "data: [DONE]\n",
            "\n",
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"NEVER\"}}]}\n"
        );
        let events = parse_sse(sse);
        // Should stop at [DONE]; the final data line is never parsed
        assert_eq!(events.len(), 1);
        match &events[0] {
            ProviderEvent::TextDelta { delta, .. } => assert_eq!(delta, "X"),
            other => panic!("expected TextDelta, got {other:?}"),
        }
    }

    // ── Finish reason mapping ──────────────────────────────────────────

    #[test]
    fn parse_finish_reason_stop() {
        assert!(matches!(stop_reason_for("stop"), StopReason::Stop));
    }

    #[test]
    fn parse_finish_reason_end_alias() {
        // Some OpenAI-compatible servers report "end" instead of "stop".
        assert!(matches!(stop_reason_for("end"), StopReason::Stop));
    }

    #[test]
    fn parse_finish_reason_length() {
        assert!(matches!(stop_reason_for("length"), StopReason::Length));
    }

    #[test]
    fn parse_finish_reason_tool_calls() {
        assert!(matches!(stop_reason_for("tool_calls"), StopReason::ToolUse));
    }

    #[test]
    fn parse_finish_reason_function_call_alias() {
        // Legacy function-calling finish reason maps to the same ToolUse stop.
        assert!(matches!(
            stop_reason_for("function_call"),
            StopReason::ToolUse
        ));
    }

    #[test]
    fn parse_finish_reason_content_filter_is_error() {
        assert!(matches!(
            stop_reason_for("content_filter"),
            StopReason::Error
        ));
    }

    #[test]
    fn parse_finish_reason_unknown_is_error() {
        assert!(matches!(
            stop_reason_for("something_new"),
            StopReason::Error
        ));
    }

    // ── Reasoning (thinking) deltas ────────────────────────────────────

    #[test]
    fn parse_reasoning_content_accumulates_into_one_thinking_block() {
        let sse = concat!(
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"reasoning_content\":\"Let me \"}}]}\n",
            "\n",
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"reasoning_content\":\"think\"}}]}\n",
            "\n",
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":null,\"finish_reason\":\"stop\"}]}\n"
        );
        let events = parse_sse(sse);
        // 2 ThinkingDelta + Done
        assert_eq!(events.len(), 3);
        let deltas: Vec<&str> = events
            .iter()
            .filter_map(|e| match e {
                ProviderEvent::ThinkingDelta { delta, .. } => Some(delta.as_str()),
                _ => None,
            })
            .collect();
        assert_eq!(deltas, vec!["Let me ", "think"]);
        match &events[2] {
            ProviderEvent::Done { message, .. } => {
                // Consecutive reasoning fragments append to the same block.
                let thinking: Vec<&str> = message
                    .content
                    .iter()
                    .filter_map(|b| match b {
                        ContentBlock::Thinking(t) => Some(t.thinking.as_str()),
                        _ => None,
                    })
                    .collect();
                assert_eq!(thinking, vec!["Let me think"]);
            }
            other => panic!("expected Done, got {other:?}"),
        }
    }

    #[test]
    fn parse_empty_reasoning_content_is_ignored() {
        let sse = "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"reasoning_content\":\"\"}}]}\n\n";
        let events = parse_sse(sse);
        assert!(events.is_empty(), "unexpected events: {events:?}");
    }

    // ── Tool call delta accumulation ───────────────────────────────────

    #[test]
    fn parse_tool_call_deltas() {
        let sse = concat!(
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"type\":\"function\",\"function\":{\"name\":\"get_weather\",\"arguments\":\"\"}}]}}]}\n",
            "\n",
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"{\\\"city\\\":\\\"SF\\\"}\"}}]}}]}\n",
            "\n"
        );
        let events = parse_sse(sse);
        // First chunk: ToolCallStart (id+name present) + ToolCallDelta (function present)
        // Second chunk: ToolCallDelta only
        assert_eq!(events.len(), 3);
        let starts: Vec<&str> = events
            .iter()
            .filter_map(|e| match e {
                ProviderEvent::ToolCallStart { tool_name, .. } => tool_name.as_deref(),
                _ => None,
            })
            .collect();
        assert_eq!(starts, vec!["get_weather"]);
        let deltas: Vec<&str> = events
            .iter()
            .filter_map(|e| match e {
                ProviderEvent::ToolCallDelta { delta, .. } => Some(delta.as_str()),
                _ => None,
            })
            .collect();
        assert_eq!(deltas, vec!["", "{\"city\":\"SF\"}"]);
    }

    #[test]
    fn parse_tool_call_with_no_arguments_field() {
        // function field present but arguments is null → emits ToolCallStart + ToolCallDelta
        let sse = "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"name\":\"run\"}}]}}]}\n\n";
        let events = parse_sse(sse);
        assert_eq!(events.len(), 2);
        match &events[0] {
            ProviderEvent::ToolCallStart { tool_name, .. } => {
                assert_eq!(tool_name.as_deref(), Some("run"));
            }
            other => panic!("expected ToolCallStart, got {other:?}"),
        }
        match &events[1] {
            ProviderEvent::ToolCallDelta { delta, .. } => assert_eq!(delta, ""),
            other => panic!("expected ToolCallDelta, got {other:?}"),
        }
    }

    // ── Usage accumulation ─────────────────────────────────────────────

    #[test]
    fn parse_usage_in_chunk() {
        // Usage is accumulated from earlier chunks; the Done event captures
        // usage that was accumulated *before* the finish_reason chunk.
        let sse = concat!(
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"hi\"}}],\"usage\":{\"prompt_tokens\":10,\"completion_tokens\":8,\"total_tokens\":18,\"prompt_tokens_details\":{\"cached_tokens\":3}}}\n",
            "\n",
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":null,\"finish_reason\":\"stop\"}]}\n"
        );
        let events = parse_sse(sse);
        // TextDelta + Done
        assert_eq!(events.len(), 2);
        match &events[1] {
            ProviderEvent::Done { message, .. } => {
                assert_eq!(message.usage.input, 10);
                assert_eq!(message.usage.output, 8);
                assert_eq!(message.usage.total_tokens, 18);
                assert_eq!(message.usage.cache_read, 3);
            }
            other => panic!("expected Done, got {other:?}"),
        }
    }

    #[test]
    fn parse_usage_without_cache_details() {
        // Usage from an earlier chunk; Done event on a separate chunk without usage.
        let sse = concat!(
            "data: {\"id\":\"c\",\"choices\":[],\"usage\":{\"prompt_tokens\":5,\"completion_tokens\":2,\"total_tokens\":7}}\n",
            "\n",
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":null,\"finish_reason\":\"stop\"}]}\n"
        );
        let events = parse_sse(sse);
        // The usage-only chunk emits nothing; only the Done event is produced.
        assert_eq!(events.len(), 1);
        match &events[0] {
            ProviderEvent::Done { message, .. } => {
                assert_eq!(message.usage.input, 5);
                assert_eq!(message.usage.output, 2);
                assert_eq!(message.usage.cache_read, 0);
            }
            other => panic!("expected Done, got {other:?}"),
        }
    }

    // ── Empty / malformed handling ─────────────────────────────────────

    #[test]
    fn parse_empty_input() {
        let events = parse_sse("");
        assert!(events.is_empty());
    }

    #[test]
    fn parse_only_empty_lines() {
        let events = parse_sse("\n\n\n");
        assert!(events.is_empty());
    }

    #[test]
    fn parse_malformed_json_after_data() {
        let sse = "data: {not json at all}\ndata: also bad\ndata: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"ok\"}}]}\n";
        let events = parse_sse(sse);
        // Malformed lines are skipped, only the valid one emits
        assert_eq!(events.len(), 1);
        match &events[0] {
            ProviderEvent::TextDelta { delta, .. } => assert_eq!(delta, "ok"),
            other => panic!("expected TextDelta, got {other:?}"),
        }
    }

    #[test]
    fn parse_empty_data_line() {
        let sse = "data: \ndata: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"X\"}}]}\n";
        let events = parse_sse(sse);
        assert_eq!(events.len(), 1);
    }

    #[test]
    fn parse_non_data_lines_ignored() {
        let sse = "event: ping\nid: 42\nretry: 5000\ndata: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Y\"}}]}\n";
        let events = parse_sse(sse);
        assert_eq!(events.len(), 1);
    }

    #[test]
    fn parse_carriage_return_line_endings() {
        let sse = "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"CR\"}}]}\r\n\r\n";
        let events = parse_sse(sse);
        assert_eq!(events.len(), 1);
        match &events[0] {
            ProviderEvent::TextDelta { delta, .. } => assert_eq!(delta, "CR"),
            other => panic!("expected TextDelta, got {other:?}"),
        }
    }

    // ── Mixed content + tool + done ────────────────────────────────────

    #[test]
    fn parse_full_stream_with_text_tool_and_done() {
        let sse = concat!(
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Let me\"}}]}\n",
            "\n",
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" check\"}}]}\n",
            "\n",
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"type\":\"function\",\"function\":{\"name\":\"search\",\"arguments\":\"{\\\"q\\\":\\\"rust\\\"}\"}}]}}]}\n",
            "\n",
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":null,\"finish_reason\":\"tool_calls\"}]}\n",
            "\n",
            "data: [DONE]\n"
        );
        let events = parse_sse(sse);
        assert_eq!(events.len(), 5); // 2 TextDelta + ToolCallStart + ToolCallDelta + Done

        let mut text_count = 0;
        let mut tc_start_count = 0;
        let mut tc_delta_count = 0;
        let mut done_count = 0;
        for e in &events {
            match e {
                ProviderEvent::TextDelta { .. } => text_count += 1,
                ProviderEvent::ToolCallStart { .. } => tc_start_count += 1,
                ProviderEvent::ToolCallDelta { .. } => tc_delta_count += 1,
                ProviderEvent::Done { reason, .. } => {
                    done_count += 1;
                    assert!(matches!(reason, StopReason::ToolUse));
                }
                other => panic!("unexpected event: {other:?}"),
            }
        }
        assert_eq!(text_count, 2);
        assert_eq!(tc_start_count, 1);
        assert_eq!(tc_delta_count, 1);
        assert_eq!(done_count, 1);
    }
}