Skip to main content

oxi_ai/providers/
openai.rs

1//! OpenAI-compatible provider implementation
2
3use async_trait::async_trait;
4use bytes::Bytes;
5use futures::{Stream, StreamExt};
6use reqwest::Client;
7use serde::Deserialize;
8use serde_json::Value as JsonValue;
9use std::pin::Pin;
10
11use super::openai_responses_shared::parse_streaming_json;
12use super::shared_client;
13use crate::{
14    error::ProviderError, Api, AssistantMessage, ContentBlock, Context, Model, Provider,
15    ProviderEvent, StopReason, StreamOptions, TextContent, ThinkingContent, Usage,
16};
17
18/// Detect whether a model targets the ZAI provider.
19fn is_zai(model: &Model) -> bool {
20    model.provider.eq_ignore_ascii_case("zai") || model.base_url.contains("api.z.ai")
21}
22
23/// OpenAI-compatible provider
24#[derive(Clone)]
25pub struct OpenAiProvider {
26    client: &'static Client,
27    api_key: Option<String>,
28    base_url: Option<String>,
29}
30
31impl OpenAiProvider {
32    /// Create a new OpenAI provider without an API key.
33    ///
34    /// API keys are resolved at request time via auth.json or StreamOptions.
35    /// Use `with_api_key()` for explicit key injection.
36    pub fn new() -> Self {
37        Self {
38            client: shared_client(),
39            api_key: None,
40            base_url: None,
41        }
42    }
43
44    /// Create with explicit API key (public API for external consumers)
45    pub fn with_api_key(api_key: impl Into<String>) -> Self {
46        Self {
47            client: shared_client(),
48            api_key: Some(api_key.into()),
49            base_url: None,
50        }
51    }
52
53    /// Create with a custom base URL (API key resolved from auth storage).
54    ///
55    /// Used for built-in OpenAI-compatible providers like ZAI.
56    pub fn with_base_url(base_url: &str) -> Self {
57        Self {
58            client: shared_client(),
59            api_key: None,
60            base_url: Some(base_url.to_string()),
61        }
62    }
63
64    /// Create with a custom base URL and optional API key.
65    ///
66    /// Used for registering custom OpenAI-compatible providers (Minimax, ZAI, etc.).
67    pub fn with_base_url_and_key(base_url: &str, api_key: Option<String>) -> Self {
68        Self {
69            client: shared_client(),
70            api_key,
71            base_url: Some(base_url.to_string()),
72        }
73    }
74}
75
76impl Default for OpenAiProvider {
77    fn default() -> Self {
78        Self::new()
79    }
80}
81
82#[async_trait]
83impl Provider for OpenAiProvider {
84    async fn stream(
85        &self,
86        model: &Model,
87        context: &Context,
88        options: Option<StreamOptions>,
89    ) -> Result<Pin<Box<dyn Stream<Item = ProviderEvent> + Send>>, ProviderError> {
90        let options = options.unwrap_or_default();
91
92        // Build the request
93        let effective_base_url = self.base_url.as_deref().unwrap_or(&model.base_url);
94        let url = format!("{}/chat/completions", effective_base_url);
95
96        // Get API key
97        let api_key = options
98            .api_key
99            .as_ref()
100            .or(self.api_key.as_ref())
101            .ok_or_else(|| ProviderError::MissingApiKey)?;
102
103        // Build messages
104        let messages = build_messages(context)?;
105
106        // Build request body
107        let mut body = serde_json::json!({
108            "model": model.id,
109            "messages": messages,
110            "stream": true,
111            "stream_options": { "include_usage": true },
112        });
113
114        // Add optional parameters
115        if let Some(temp) = options.temperature {
116            body["temperature"] = serde_json::json!(temp);
117        }
118
119        if let Some(max) = options.max_tokens {
120            body["max_tokens"] = serde_json::json!(max);
121        }
122
123        // Add tools if present
124        if !context.tools.is_empty() {
125            body["tools"] = build_tools(&context.tools)?;
126        }
127
128        // ── ZAI-specific parameters ──────────────────────────────────
129        // Mirror pi's detectCompat: when provider is ZAI (or base_url contains
130        // api.z.ai), send enable_thinking and tool_stream.
131        if is_zai(model) {
132            if model.reasoning {
133                body["enable_thinking"] = serde_json::json!(true);
134            }
135            if !context.tools.is_empty() {
136                body["tool_stream"] = serde_json::json!(true);
137            }
138        }
139
140        tracing::info!(
141            "Sending request to {} model={} body_len={} enable_thinking={} tool_stream={}",
142            url,
143            model.id,
144            body.to_string().len(),
145            body.get("enable_thinking").is_some(),
146            body.get("tool_stream").is_some()
147        );
148        tracing::debug!("Request body: {}", body.to_string());
149
150        // Build headers
151        let mut headers = reqwest::header::HeaderMap::new();
152        headers.insert(
153            reqwest::header::AUTHORIZATION,
154            format!("Bearer {}", api_key)
155                .parse()
156                .expect("valid bearer header"),
157        );
158        headers.insert(
159            reqwest::header::CONTENT_TYPE,
160            "application/json".parse().expect("valid header value"),
161        );
162
163        for (k, v) in &options.headers {
164            if let (Ok(name), Ok(value)) = (
165                k.parse::<reqwest::header::HeaderName>(),
166                v.parse::<reqwest::header::HeaderValue>(),
167            ) {
168                headers.insert(name, value);
169            }
170        }
171
172        // Make request
173        let response = self
174            .client
175            .post(&url)
176            .headers(headers)
177            .json(&body)
178            .send()
179            .await
180            .map_err(ProviderError::RequestFailed)?;
181
182        if !response.status().is_success() {
183            let status = response.status();
184            let body: String = response.text().await.unwrap_or_default();
185            return Err(ProviderError::HttpError(status.as_u16(), body));
186        }
187
188        // Create event stream
189        let provider_name = model.provider.clone();
190        let model_id = model.id.clone();
191
192        // Emit Start event once at the beginning of the stream (matches pi's behavior)
193        let start_event = ProviderEvent::Start {
194            partial: AssistantMessage::new(Api::OpenAiCompletions, &provider_name, &model_id),
195        };
196
197        // Stateful stream parser that accumulates tool calls across chunks.
198        // OpenAI sends tool calls as multiple deltas (id, name, arguments fragments)
199        // that must be reassembled before emitting ToolCallEnd.
200        //
201        // State:
202        //   pending_bytes     – incomplete UTF-8 bytes from the previous HTTP chunk
203        //   pending_tc_index  – accumulated tool calls keyed by streaming index
204        //   pending_tc_id     – secondary lookup by tool-call ID (ZAI et al. may
205        //                       omit the index on continuation deltas)
206        //   thinking_started  – whether ThinkingStart has been emitted
207        let stream = response
208            .bytes_stream()
209            .scan(
210                (
211                    Vec::new(),
212                    std::collections::HashMap::<usize, (String, String, String)>::new(),
213                    std::collections::HashMap::<String, usize>::new(), // id → index
214                    false,
215                    AssistantMessage::new(Api::OpenAiCompletions, &provider_name, &model_id),
216                ),
217                move |(
218                    pending_bytes,
219                    pending_tc,
220                    tc_id_to_index,
221                    thinking_started,
222                    accumulated_output,
223                ),
224                      chunk: Result<Bytes, reqwest::Error>| {
225                    let events = match chunk {
226                        Ok(bytes) => {
227                            // Prepend any incomplete bytes from previous chunk
228                            let mut combined =
229                                Vec::with_capacity(pending_bytes.len() + bytes.len());
230                            combined.extend_from_slice(pending_bytes);
231                            combined.extend_from_slice(&bytes);
232
233                            // Split into complete lines (ending with \n) and trailing incomplete data.
234                            // This prevents JSON parse failures from partial SSE lines
235                            // that were split across HTTP chunks.
236                            let (text, trailing) = split_complete_lines(&combined);
237                            *pending_bytes = trailing;
238
239                            tracing::debug!(
240                                "parse_sse_events input: {} bytes, {} lines",
241                                text.len(),
242                                text.lines().count()
243                            );
244                            let raw_events = parse_sse_events(
245                                &text,
246                                &provider_name,
247                                &model_id,
248                                accumulated_output,
249                            );
250                            tracing::debug!("parse_sse_events output: {} events", raw_events.len());
251
252                            // Post-process: accumulate tool call deltas, inject ThinkingStart once
253                            let mut processed = Vec::new();
254                            for event in raw_events {
255                                match &event {
256                                    ProviderEvent::ThinkingDelta { content_index, .. } => {
257                                        // Inject ThinkingStart before the first ThinkingDelta
258                                        if !*thinking_started {
259                                            *thinking_started = true;
260                                            processed.push(ProviderEvent::ThinkingStart {
261                                                content_index: *content_index,
262                                                partial: AssistantMessage::new(
263                                                    Api::OpenAiCompletions,
264                                                    &provider_name,
265                                                    &model_id,
266                                                ),
267                                            });
268                                        }
269                                        processed.push(event);
270                                    }
271                                    ProviderEvent::ToolCallStart {
272                                        content_index,
273                                        tool_call_id,
274                                        tool_name,
275                                        ..
276                                    } => {
277                                        let entry =
278                                            pending_tc.entry(*content_index).or_insert_with(|| {
279                                                (String::new(), String::new(), String::new())
280                                            });
281                                        if let Some(ref id) = tool_call_id {
282                                            if !id.is_empty() {
283                                                entry.0 = id.clone();
284                                                tc_id_to_index.insert(id.clone(), *content_index);
285                                            }
286                                        }
287                                        if let Some(ref name) = tool_name {
288                                            if !name.is_empty() {
289                                                entry.1 = name.clone();
290                                            }
291                                        }
292                                        processed.push(event);
293                                    }
294                                    ProviderEvent::ToolCallDelta {
295                                        content_index,
296                                        delta,
297                                        ..
298                                    } => {
299                                        // Dual-map lookup: prefer index, fall back to ID
300                                        let idx = if pending_tc.contains_key(content_index) {
301                                            *content_index
302                                        } else {
303                                            // Scan id→index map for a match
304                                            tc_id_to_index
305                                                .values()
306                                                .copied()
307                                                .find(|i| *i == *content_index)
308                                                .unwrap_or(*content_index)
309                                        };
310                                        let entry = pending_tc.entry(idx).or_insert_with(|| {
311                                            (String::new(), String::new(), String::new())
312                                        });
313                                        tracing::debug!(
314                                            "[TC-DELTA] idx={}, delta_len={}, accumulated_len={}",
315                                            idx,
316                                            delta.len(),
317                                            entry.2.len() + delta.len()
318                                        );
319                                        entry.2.push_str(delta);
320                                        processed.push(event);
321                                    }
322                                    ProviderEvent::ToolCallEnd { .. } => {
323                                        // Already a ToolCallEnd from parse_sse_events
324                                        processed.push(event);
325                                    }
326                                    ProviderEvent::Done { reason, .. } => {
327                                        // Before Done, emit ToolCallEnd for all accumulated tool calls
328                                        if matches!(reason, StopReason::ToolUse) {
329                                            let mut indices: Vec<usize> =
330                                                pending_tc.keys().copied().collect();
331                                            indices.sort();
332                                            for idx in indices {
333                                                let (id, name, arguments) = &pending_tc[&idx];
334                                                tracing::debug!(
335                                                    "[TC-END] idx={}, id={}, name={}, args_len={}",
336                                                    idx,
337                                                    id.len(),
338                                                    name.len(),
339                                                    arguments.len()
340                                                );
341                                                let args_value = parse_streaming_json(arguments);
342                                                processed.push(ProviderEvent::ToolCallEnd {
343                                                    content_index: idx,
344                                                    tool_call: crate::ToolCall {
345                                                        content_type:
346                                                            crate::messages::ToolCallType::ToolCall,
347                                                        id: id.clone(),
348                                                        name: name.clone(),
349                                                        arguments: args_value,
350                                                        thought_signature: None,
351                                                    },
352                                                    partial: AssistantMessage::new(
353                                                        Api::OpenAiCompletions,
354                                                        &provider_name,
355                                                        &model_id,
356                                                    ),
357                                                });
358                                            }
359                                        }
360                                        // Clear pending_tc for the next stream/turn.
361                                        // Without this, tool call arguments from the previous
362                                        // turn leak into the next turn's accumulation.
363                                        pending_tc.clear();
364                                        tc_id_to_index.clear();
365                                        processed.push(event);
366                                    }
367                                    _ => {
368                                        processed.push(event);
369                                    }
370                                }
371                            }
372                            processed
373                        }
374                        Err(e) => {
375                            vec![ProviderEvent::Error {
376                                reason: StopReason::Error,
377                                error: create_error_message(
378                                    &e.to_string(),
379                                    &provider_name,
380                                    &model_id,
381                                ),
382                            }]
383                        }
384                    };
385                    // Return Some to continue, wrap events in an iterator
386                    async move { Some(futures::stream::iter(events)) }
387                },
388            )
389            .flatten();
390
391        // Prepend Start event to the stream
392        let stream_with_start = futures::stream::once(async move { start_event }).chain(stream);
393        Ok(Box::pin(stream_with_start))
394    }
395
396    fn name(&self) -> &str {
397        "openai"
398    }
399}
400
401/// Build messages array from context
402fn build_messages(context: &Context) -> Result<Vec<JsonValue>, ProviderError> {
403    let mut messages = Vec::new();
404
405    // System prompt
406    if let Some(ref prompt) = context.system_prompt {
407        messages.push(serde_json::json!({
408            "role": "system",
409            "content": prompt,
410        }));
411    }
412
413    // Conversation messages
414    for msg in &context.messages {
415        match msg {
416            crate::Message::User(u) => {
417                let content: String = match &u.content {
418                    crate::MessageContent::Text(s) => s.clone(),
419                    crate::MessageContent::Blocks(blocks) => blocks_to_content(blocks)?.to_string(),
420                };
421                messages.push(serde_json::json!({
422                    "role": "user",
423                    "content": content,
424                }));
425            }
426            crate::Message::Assistant(a) => {
427                // OpenAI format: separate content (text) and tool_calls
428                let mut text_parts = Vec::new();
429                let mut tool_calls = Vec::new();
430                for block in &a.content {
431                    match block {
432                        ContentBlock::Text(t) => {
433                            text_parts.push(t.text.clone());
434                        }
435                        ContentBlock::Thinking(_) => {
436                            // Skip thinking blocks in message history
437                        }
438                        ContentBlock::ToolCall(tc) => {
439                            tool_calls.push(serde_json::json!({
440                                "id": tc.id,
441                                "type": "function",
442                                "function": {
443                                    "name": tc.name,
444                                    "arguments": tc.arguments.to_string(),
445                                },
446                            }));
447                        }
448                        ContentBlock::Image(_) | ContentBlock::Unknown(_) => {}
449                    }
450                }
451                let mut msg = serde_json::json!({
452                    "role": "assistant",
453                    "content": text_parts.join(""),
454                });
455                if !tool_calls.is_empty() {
456                    msg["tool_calls"] = serde_json::json!(tool_calls);
457                }
458                messages.push(msg);
459            }
460            crate::Message::ToolResult(t) => {
461                let result_text: String = t
462                    .content
463                    .iter()
464                    .filter_map(|b| b.as_text())
465                    .collect::<Vec<_>>()
466                    .join("");
467                messages.push(serde_json::json!({
468                    "role": "tool",
469                    "tool_call_id": t.tool_call_id,
470                    "content": result_text,
471                }));
472            }
473        }
474    }
475
476    Ok(messages)
477}
478
479/// Convert content blocks to a string representation
480fn blocks_to_content(blocks: &[ContentBlock]) -> Result<JsonValue, ProviderError> {
481    if blocks.len() == 1 {
482        if let Some(text) = blocks[0].as_text() {
483            return Ok(JsonValue::String(text.to_string()));
484        }
485    }
486
487    let items: Result<Vec<_>, _> = blocks
488        .iter()
489        .map(|block| match block {
490            ContentBlock::Text(t) => Ok(serde_json::json!({
491                "type": "text",
492                "text": t.text,
493            })),
494            ContentBlock::ToolCall(tc) => Ok(serde_json::json!({
495                "type": "function",
496                "id": tc.id,
497                "function": {
498                    "name": tc.name,
499                    "arguments": tc.arguments.to_string(),
500                },
501            })),
502            ContentBlock::Thinking(th) => Ok(serde_json::json!({
503                "type": "thinking",
504                "thinking": th.thinking,
505            })),
506            ContentBlock::Image(img) => Ok(serde_json::json!({
507                "type": "image_url",
508                "image_url": {
509                    "url": format!("data:{};base64,{}", img.mime_type, img.data),
510                },
511            })),
512            ContentBlock::Unknown(_) => Err(ProviderError::InvalidResponse(
513                "Unknown content block type".into(),
514            )),
515        })
516        .collect();
517
518    Ok(serde_json::json!(items?))
519}
520
521/// Build tools array
522fn build_tools(tools: &[crate::Tool]) -> Result<JsonValue, ProviderError> {
523    let items: Vec<_> = tools
524        .iter()
525        .map(|tool| {
526            serde_json::json!({
527                "type": "function",
528                "function": {
529                    "name": tool.name,
530                    "description": tool.description,
531                    "parameters": tool.parameters,
532                },
533            })
534        })
535        .collect();
536
537    Ok(serde_json::json!(items))
538}
539
/// Extract the longest valid UTF-8 prefix from a byte slice.
///
/// Returns the valid prefix as a `String` together with every byte that
/// follows it. When the whole input is valid UTF-8 the second element is
/// empty.
fn find_valid_utf8_prefix(bytes: &[u8]) -> (String, Vec<u8>) {
    let valid_len = match std::str::from_utf8(bytes) {
        Ok(text) => return (text.to_owned(), Vec::new()),
        Err(err) => err.valid_up_to(),
    };
    let (head, tail) = bytes.split_at(valid_len);
    (String::from_utf8_lossy(head).into_owned(), tail.to_vec())
}

/// Split bytes into complete lines (ending with `\n`) and trailing incomplete data.
///
/// Everything up to and including the last `\n` is returned as text, and the
/// bytes after it are returned untouched so the caller can prepend them to
/// the next chunk. This ensures `parse_sse_events` only receives complete SSE
/// `data:` lines, preventing JSON parse failures from lines split across HTTP
/// chunks. Without any `\n` the text is empty and all bytes are trailing.
///
/// Invalid UTF-8 inside the complete-line region is replaced with U+FFFD
/// rather than cutting the text short. A newline byte can never be part of a
/// multi-byte UTF-8 sequence, so an invalid byte before the final newline is
/// really invalid (not merely incomplete), and the lines after it must
/// survive.
pub fn split_complete_lines(bytes: &[u8]) -> (String, Vec<u8>) {
    // Find the last newline — everything up to and including it is complete.
    match bytes.iter().rposition(|&b| b == b'\n') {
        Some(last_nl) => {
            let (head, tail) = bytes.split_at(last_nl + 1);
            let (mut complete, undecoded) = find_valid_utf8_prefix(head);
            if !undecoded.is_empty() {
                // Keep the lines that follow the bad byte instead of dropping them.
                complete.push_str(&String::from_utf8_lossy(&undecoded));
            }
            (complete, tail.to_vec())
        }
        // No newline at all — the entire buffer is incomplete.
        None => (String::new(), bytes.to_vec()),
    }
}
581
582/// Parse SSE event stream from a byte buffer.
583///
584/// Optimizations over a naïve implementation:
585/// - **Fast-line splitting** – iterates over `\n` boundaries via `split`
586///   instead of allocating an intermediate `String` per line.
587/// - **Early `DONE` exit** – breaks immediately when `data: [DONE]` is
588///   encountered.
589/// - **Pre-allocated events** – reserves capacity based on data-line count.
590/// - **Accumulated usage** – tracks usage separately, only cloning into
591///   the Done message at stream end, not on every chunk.
592fn parse_sse_events(
593    text: &str,
594    _provider: &str,
595    _model_id: &str,
596    output: &mut AssistantMessage,
597) -> Vec<ProviderEvent> {
598    let mut events = Vec::new();
599
600    // Pre-estimate capacity: one event per data line is a reasonable upper bound.
601    let estimated_events = text.split('\n').filter(|l| l.starts_with("data: ")).count();
602    events.reserve(estimated_events);
603
604    let mut accumulated_usage = Usage::default();
605
606    for line in text.split('\n') {
607        let line = line.trim_end_matches('\r');
608        if line.is_empty() {
609            continue;
610        }
611
612        // Fast rejection for non-data lines (comments, event tags, etc.)
613        if !line.starts_with("data: ") {
614            continue;
615        }
616
617        let data = &line[6..]; // skip "data: "
618
619        // Early exit on stream end
620        if data == "[DONE]" {
621            break;
622        }
623
624        if data.is_empty() {
625            continue;
626        }
627
628        let chunk = match serde_json::from_str::<SSEChunk>(data) {
629            Ok(c) => c,
630            Err(_) => continue,
631        };
632
633        for choice in &chunk.choices {
634            if let Some(delta) = &choice.delta {
635                if let Some(content) = &delta.content {
636                    // pi-mono: append to the output's text block
637                    let last_text_idx = output
638                        .content
639                        .iter()
640                        .rposition(|b| matches!(b, ContentBlock::Text(_)));
641                    if let Some(idx) = last_text_idx {
642                        if let ContentBlock::Text(t) = &mut output.content[idx] {
643                            t.text.push_str(content);
644                        }
645                    } else {
646                        output
647                            .content
648                            .push(ContentBlock::Text(TextContent::new(content.clone())));
649                    }
650                    events.push(ProviderEvent::TextDelta {
651                        content_index: choice.index,
652                        delta: content.clone(),
653                        partial: output.clone(),
654                    });
655                }
656
657                // Handle GLM's reasoning_content field (thinking/thought chain)
658                if let Some(ref reasoning) = delta.reasoning_content {
659                    if !reasoning.is_empty() {
660                        // pi-mono: append to the output's thinking block
661                        let last_think_idx = output
662                            .content
663                            .iter()
664                            .rposition(|b| matches!(b, ContentBlock::Thinking(_)));
665                        if let Some(idx) = last_think_idx {
666                            if let ContentBlock::Thinking(t) = &mut output.content[idx] {
667                                t.thinking.push_str(reasoning);
668                            }
669                        } else {
670                            output
671                                .content
672                                .push(ContentBlock::Thinking(ThinkingContent::new(
673                                    reasoning.clone(),
674                                )));
675                        }
676                        events.push(ProviderEvent::ThinkingDelta {
677                            content_index: choice.index,
678                            delta: reasoning.clone(),
679                            partial: output.clone(),
680                        });
681                    }
682                }
683
684                if let Some(tool_calls) = &delta.tool_calls {
685                    for tc in tool_calls {
686                        let tc_index = tc.index.unwrap_or(choice.index);
687
688                        // Emit ToolCallStart when id or name is present (first delta)
689                        if tc.id.is_some()
690                            || tc.function.as_ref().and_then(|f| f.name.as_ref()).is_some()
691                        {
692                            events.push(ProviderEvent::ToolCallStart {
693                                content_index: tc_index,
694                                tool_call_id: tc.id.clone(),
695                                tool_name: tc.function.as_ref().and_then(|f| f.name.clone()),
696                                partial: output.clone(),
697                            });
698                        }
699
700                        // Emit ToolCallDelta for arguments
701                        if let Some(func) = &tc.function {
702                            events.push(ProviderEvent::ToolCallDelta {
703                                content_index: tc_index,
704                                delta: func.arguments.clone().unwrap_or_default(),
705                                partial: output.clone(),
706                            });
707                        }
708                    }
709                }
710            }
711
712            if choice.finish_reason.is_some() {
713                let reason = match choice.finish_reason.as_deref() {
714                    Some("stop") | Some("end") => StopReason::Stop,
715                    Some("length") => StopReason::Length,
716                    Some("tool_calls") | Some("function_call") => StopReason::ToolUse,
717                    Some("content_filter") => StopReason::Error,
718                    Some(unknown) => {
719                        tracing::warn!("Unknown finish_reason: '{}', treating as Error", unknown);
720                        StopReason::Error
721                    }
722                    None => StopReason::Stop,
723                };
724                tracing::info!("finish_reason={:?} → {:?}", choice.finish_reason, reason);
725
726                let mut done_msg = output.clone();
727                done_msg.stop_reason = reason;
728                done_msg.usage = accumulated_usage.clone();
729                events.push(ProviderEvent::Done {
730                    reason,
731                    message: done_msg,
732                });
733            }
734        }
735
736        // Accumulate usage from the chunk (if present).
737        if let Some(chunk_usage) = chunk.usage {
738            accumulated_usage.input = chunk_usage.prompt_tokens;
739            accumulated_usage.output = chunk_usage.completion_tokens;
740            accumulated_usage.cache_read = chunk_usage
741                .prompt_tokens_details
742                .as_ref()
743                .map(|d| d.cached_tokens)
744                .unwrap_or(0);
745            accumulated_usage.total_tokens = chunk_usage.total_tokens;
746        }
747    }
748
749    events
750}
751
752/// Create error assistant message
753fn create_error_message(msg: &str, provider: &str, model_id: &str) -> AssistantMessage {
754    let mut message = AssistantMessage::new(Api::OpenAiCompletions, provider, model_id);
755    message.stop_reason = StopReason::Error;
756    message.error_message = Some(msg.to_string());
757    message
758}
759
760// SSE chunk structure
761#[derive(Debug, Deserialize)]
762// serde deserialization structs
763struct SSEChunk {
764    _id: Option<String>,
765    #[serde(rename = "model")]
766    _model: Option<String>,
767    choices: Vec<Choice>,
768    usage: Option<UsageInfo>,
769}
770
771#[derive(Debug, Deserialize)]
772// serde deserialization structs
773struct Choice {
774    index: usize,
775    delta: Option<Delta>,
776    finish_reason: Option<String>,
777}
778
779#[derive(Debug, Deserialize)]
780struct Delta {
781    content: Option<String>,
782    reasoning_content: Option<String>,
783    tool_calls: Option<Vec<ToolCallDelta>>,
784}
785
786#[derive(Debug, Deserialize)]
787// serde deserialization structs
788struct ToolCallDelta {
789    index: Option<usize>,
790    id: Option<String>,
791    #[serde(rename = "type")]
792    _type_: Option<String>,
793    function: Option<FunctionDelta>,
794}
795
796#[derive(Debug, Deserialize)]
797// serde deserialization structs
798struct FunctionDelta {
799    name: Option<String>,
800    arguments: Option<String>,
801}
802
803#[derive(Debug, Deserialize, Clone)]
804struct UsageInfo {
805    prompt_tokens: usize,
806    completion_tokens: usize,
807    total_tokens: usize,
808    #[serde(rename = "prompt_tokens_details")]
809    prompt_tokens_details: Option<PromptTokensDetails>,
810}
811
812#[derive(Debug, Deserialize, Clone)]
813struct PromptTokensDetails {
814    #[serde(rename = "cached_tokens")]
815    cached_tokens: usize,
816}
817
#[cfg(test)]
mod tests {
    use super::*;

    const PROVIDER: &str = "openai";
    const MODEL: &str = "gpt-4o";

    /// Feed `sse` through `parse_sse_events` with a fresh assistant message
    /// and return the events it produced.
    fn parse_sse(sse: &str) -> Vec<ProviderEvent> {
        let mut partial = AssistantMessage::new(Api::OpenAiCompletions, PROVIDER, MODEL);
        parse_sse_events(sse, PROVIDER, MODEL, &mut partial)
    }

    /// Payloads of every `TextDelta` in `events`, in emission order.
    fn text_deltas(events: &[ProviderEvent]) -> Vec<&str> {
        let mut found = Vec::new();
        for event in events {
            if let ProviderEvent::TextDelta { delta, .. } = event {
                found.push(delta.as_str());
            }
        }
        found
    }

    /// Tool names announced by `ToolCallStart` events (starts without a name
    /// are skipped), in emission order.
    fn tool_start_names(events: &[ProviderEvent]) -> Vec<&str> {
        let mut found = Vec::new();
        for event in events {
            if let ProviderEvent::ToolCallStart {
                tool_name: Some(name),
                ..
            } = event
            {
                found.push(name.as_str());
            }
        }
        found
    }

    /// Argument fragments of every `ToolCallDelta` in `events`, in emission order.
    fn tool_argument_deltas(events: &[ProviderEvent]) -> Vec<&str> {
        let mut found = Vec::new();
        for event in events {
            if let ProviderEvent::ToolCallDelta { delta, .. } = event {
                found.push(delta.as_str());
            }
        }
        found
    }

    // ── SSE event parsing ──────────────────────────────────────────────

    /// A single content chunk yields exactly one `TextDelta` at index 0.
    #[test]
    fn parse_single_text_event() {
        let events = parse_sse(
            "data: {\"id\":\"chatcmpl-1\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\"}}]}\n\n",
        );
        assert_eq!(events.len(), 1);
        match &events[0] {
            ProviderEvent::TextDelta {
                delta,
                content_index,
                ..
            } => {
                assert_eq!(delta, "Hello");
                assert_eq!(*content_index, 0);
            }
            other => panic!("expected TextDelta, got {other:?}"),
        }
    }

    /// Consecutive content chunks produce one `TextDelta` each, in order.
    #[test]
    fn parse_multiple_text_events() {
        let events = parse_sse(concat!(
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hel\"}}]}\n",
            "\n",
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"lo!\"}}]}\n",
            "\n"
        ));
        assert_eq!(events.len(), 2);
        assert_eq!(text_deltas(&events), vec!["Hel", "lo!"]);
    }

    /// Nothing after the `[DONE]` sentinel is parsed.
    #[test]
    fn parse_done_terminator() {
        let events = parse_sse(concat!(
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"X\"}}]}\n",
            "\n",
            "data: [DONE]\n",
            "\n",
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"NEVER\"}}]}\n"
        ));
        // Only the chunk before [DONE] may surface.
        assert_eq!(events.len(), 1);
        match &events[0] {
            ProviderEvent::TextDelta { delta, .. } => assert_eq!(delta, "X"),
            other => panic!("expected TextDelta, got {other:?}"),
        }
    }

    // ── Finish reason mapping ──────────────────────────────────────────

    /// `finish_reason: "stop"` becomes a `Done` event with `StopReason::Stop`.
    #[test]
    fn parse_finish_reason_stop() {
        let events = parse_sse(
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":null,\"finish_reason\":\"stop\"}]}\n\n",
        );
        assert_eq!(events.len(), 1);
        match &events[0] {
            ProviderEvent::Done { reason, .. } => assert!(matches!(reason, StopReason::Stop)),
            other => panic!("expected Done, got {other:?}"),
        }
    }

    /// `finish_reason: "length"` maps to `StopReason::Length`.
    #[test]
    fn parse_finish_reason_length() {
        let events = parse_sse(
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":null,\"finish_reason\":\"length\"}]}\n\n",
        );
        match &events[0] {
            ProviderEvent::Done { reason, .. } => assert!(matches!(reason, StopReason::Length)),
            other => panic!("expected Done with Length, got {other:?}"),
        }
    }

    /// `finish_reason: "tool_calls"` maps to `StopReason::ToolUse`.
    #[test]
    fn parse_finish_reason_tool_calls() {
        let events = parse_sse(
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":null,\"finish_reason\":\"tool_calls\"}]}\n\n",
        );
        match &events[0] {
            ProviderEvent::Done { reason, .. } => assert!(matches!(reason, StopReason::ToolUse)),
            other => panic!("expected Done with ToolUse, got {other:?}"),
        }
    }

    // ── Tool call deltas ───────────────────────────────────────────────

    /// The first fragment (id + name) opens the call and also emits its
    /// (empty) arguments; the follow-up fragment only emits arguments.
    #[test]
    fn parse_tool_call_deltas() {
        let events = parse_sse(concat!(
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"type\":\"function\",\"function\":{\"name\":\"get_weather\",\"arguments\":\"\"}}]}}]}\n",
            "\n",
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"{\\\"city\\\":\\\"SF\\\"}\"}}]}}]}\n",
            "\n"
        ));
        // ToolCallStart + ToolCallDelta from chunk one, ToolCallDelta from chunk two.
        assert_eq!(events.len(), 3);
        assert_eq!(tool_start_names(&events), vec!["get_weather"]);
        assert_eq!(
            tool_argument_deltas(&events),
            vec!["", "{\"city\":\"SF\"}"]
        );
    }

    /// A function object without an `arguments` key still yields a
    /// `ToolCallStart` followed by an empty `ToolCallDelta`.
    #[test]
    fn parse_tool_call_with_no_arguments_field() {
        let events = parse_sse(
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"name\":\"run\"}}]}}]}\n\n",
        );
        assert_eq!(events.len(), 2);
        match &events[0] {
            ProviderEvent::ToolCallStart { tool_name, .. } => {
                assert_eq!(tool_name.as_deref(), Some("run"));
            }
            other => panic!("expected ToolCallStart, got {other:?}"),
        }
        match &events[1] {
            ProviderEvent::ToolCallDelta { delta, .. } => assert_eq!(delta, ""),
            other => panic!("expected ToolCallDelta, got {other:?}"),
        }
    }

    // ── Usage accumulation ─────────────────────────────────────────────

    /// Usage reported on an earlier chunk is what the later `Done` event
    /// carries: `Done` snapshots the usage accumulated before its own chunk.
    #[test]
    fn parse_usage_in_chunk() {
        let events = parse_sse(concat!(
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"hi\"}}],\"usage\":{\"prompt_tokens\":10,\"completion_tokens\":8,\"total_tokens\":18,\"prompt_tokens_details\":{\"cached_tokens\":3}}}\n",
            "\n",
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":null,\"finish_reason\":\"stop\"}]}\n"
        ));
        // One TextDelta followed by Done.
        assert_eq!(events.len(), 2);
        match &events[1] {
            ProviderEvent::Done { message, .. } => {
                assert_eq!(message.usage.input, 10);
                assert_eq!(message.usage.output, 8);
                assert_eq!(message.usage.total_tokens, 18);
                assert_eq!(message.usage.cache_read, 3);
            }
            other => panic!("expected Done, got {other:?}"),
        }
    }

    /// Without `prompt_tokens_details` the cache-read count stays at zero.
    /// The usage chunk has no choices, so `Done` is the first event.
    #[test]
    fn parse_usage_without_cache_details() {
        let events = parse_sse(concat!(
            "data: {\"id\":\"c\",\"choices\":[],\"usage\":{\"prompt_tokens\":5,\"completion_tokens\":2,\"total_tokens\":7}}\n",
            "\n",
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":null,\"finish_reason\":\"stop\"}]}\n"
        ));
        match &events[0] {
            ProviderEvent::Done { message, .. } => {
                assert_eq!(message.usage.input, 5);
                assert_eq!(message.usage.output, 2);
                assert_eq!(message.usage.cache_read, 0);
            }
            other => panic!("expected Done, got {other:?}"),
        }
    }

    // ── Empty / malformed handling ─────────────────────────────────────

    /// Empty input produces no events.
    #[test]
    fn parse_empty_input() {
        let events = parse_sse("");
        assert!(events.is_empty());
    }

    /// Blank lines alone produce no events.
    #[test]
    fn parse_only_empty_lines() {
        let events = parse_sse("\n\n\n");
        assert!(events.is_empty());
    }

    /// Unparseable `data:` payloads are skipped; the valid one still emits.
    #[test]
    fn parse_malformed_json_after_data() {
        let events = parse_sse(
            "data: {not json at all}\ndata: also bad\ndata: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"ok\"}}]}\n",
        );
        assert_eq!(events.len(), 1);
        match &events[0] {
            ProviderEvent::TextDelta { delta, .. } => assert_eq!(delta, "ok"),
            other => panic!("expected TextDelta, got {other:?}"),
        }
    }

    /// A `data:` line with an empty payload is ignored.
    #[test]
    fn parse_empty_data_line() {
        let events = parse_sse(
            "data: \ndata: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"X\"}}]}\n",
        );
        assert_eq!(events.len(), 1);
    }

    /// SSE fields other than `data:` do not produce events.
    #[test]
    fn parse_non_data_lines_ignored() {
        let events = parse_sse(
            "event: ping\nid: 42\nretry: 5000\ndata: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Y\"}}]}\n",
        );
        assert_eq!(events.len(), 1);
    }

    /// CRLF line endings are handled like plain LF.
    #[test]
    fn parse_carriage_return_line_endings() {
        let events = parse_sse(
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"CR\"}}]}\r\n\r\n",
        );
        assert_eq!(events.len(), 1);
        match &events[0] {
            ProviderEvent::TextDelta { delta, .. } => assert_eq!(delta, "CR"),
            other => panic!("expected TextDelta, got {other:?}"),
        }
    }

    // ── Mixed content + tool + done ────────────────────────────────────

    /// A realistic stream: two text fragments, one complete tool call, the
    /// `tool_calls` finish reason, then the `[DONE]` sentinel.
    #[test]
    fn parse_full_stream_with_text_tool_and_done() {
        let events = parse_sse(concat!(
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Let me\"}}]}\n",
            "\n",
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" check\"}}]}\n",
            "\n",
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"type\":\"function\",\"function\":{\"name\":\"search\",\"arguments\":\"{\\\"q\\\":\\\"rust\\\"}\"}}]}}]}\n",
            "\n",
            "data: {\"id\":\"c\",\"choices\":[{\"index\":0,\"delta\":null,\"finish_reason\":\"tool_calls\"}]}\n",
            "\n",
            "data: [DONE]\n"
        ));
        assert_eq!(events.len(), 5); // 2 TextDelta + ToolCallStart + ToolCallDelta + Done

        // Tally each event kind; anything outside the expected set is a failure.
        let (mut texts, mut starts, mut arg_deltas, mut dones) = (0, 0, 0, 0);
        for event in &events {
            match event {
                ProviderEvent::TextDelta { .. } => texts += 1,
                ProviderEvent::ToolCallStart { .. } => starts += 1,
                ProviderEvent::ToolCallDelta { .. } => arg_deltas += 1,
                ProviderEvent::Done { reason, .. } => {
                    dones += 1;
                    assert!(matches!(reason, StopReason::ToolUse));
                }
                other => panic!("unexpected event: {other:?}"),
            }
        }
        assert_eq!(texts, 2);
        assert_eq!(starts, 1);
        assert_eq!(arg_deltas, 1);
        assert_eq!(dones, 1);
    }
}