Skip to main content

defect_llm/protocol/
openai_chat.rs

1//! OpenAI Chat Completions protocol encoding and decoding.
2//!
3//! Encodes [`defect_core::llm::CompletionRequest`] into the wire format
4//! [`crate::wire::openai::components::CreateChatCompletionRequest`],
5//! and decodes an SSE [`Sse`] stream of [`CreateChatCompletionStreamResponse`] into a
6//! [`defect_core::llm::ProviderChunk`] stream.
7//!
8//! OpenAI Chat Completions API protocol mapping.
9//!
10//! [`Sse`]: ::sse_stream::Sse
11//! [`CreateChatCompletionStreamResponse`]:
12//!     crate::wire::openai::components::CreateChatCompletionStreamResponse
13
14use std::collections::HashMap;
15use std::pin::Pin;
16use std::task::{Context, Poll};
17
18use defect_core::error::BoxError;
19use defect_core::llm::{
20    CompletionRequest, ImageData, Message, MessageContent, ProviderChunk, ProviderError,
21    ProviderErrorKind, ReasoningEffort, Role, StopReason, ThinkingConfig, ThinkingEcho, ToolChoice,
22    ToolResultBody, ToolResultContent, Usage,
23};
24use defect_core::tool::ToolSchema;
25use futures::Stream;
26use sse_stream::Sse;
27use toac::body::codec::sse::SseEventStream;
28use tokio_util::sync::CancellationToken;
29use tracing::warn;
30
31use crate::wire::openai::components as wire;
32
33// encode
34
/// Namespace-and-version prefix of every `prompt_cache_key` built by
/// `build_prompt_cache_key`; changing the version segment changes every emitted key.
const PROMPT_CACHE_KEY_PREFIX: &str = "defect:chat:v1:";
/// Initial state of `PromptCacheKeyHasher` (the standard 64-bit FNV-1a offset basis).
const PROMPT_CACHE_KEY_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
/// Multiplier applied after every byte in `PromptCacheKeyHasher` (the standard 64-bit
/// FNV-1a prime).
const PROMPT_CACHE_KEY_PRIME: u64 = 0x0000_0001_0000_01b3;

/// Vendor hook that converts a streamed usage report into a [`Usage`].
///
/// NOTE(review): the first argument is presumably the raw JSON of the chunk (or its `usage`
/// object) so vendors can read fields missing from the wire schema — confirm against
/// `usage_from_wire` and the `handle_chunk` caller.
type UsageParser = fn(Option<&serde_json::Value>, &wire::CompletionUsage) -> Usage;
40
41/// Encodes a [`CompletionRequest`] into a wire request body.
42///
43/// Key mapping decisions:
44///
45/// - Forces `stream = true` + `stream_options.include_usage = true`:
46///   the protocol layer only runs the SSE branch, and **must** let the upstream
47///   send a trailing usage chunk, otherwise token billing is unavailable.
48/// - Promotes `system` to `messages[0]` as a system message — OpenAI has no
49///   top-level `system` field (unlike Anthropic).
50/// - A single [`Message`] may be split into multiple wire messages in OpenAI
51///   format: if a user message mixes [`MessageContent::ToolResult`], it needs
52///   a separate tool message (OpenAI uses `role: tool` + `tool_call_id` for
53///   tool results, which cannot be mixed with user text in the same message).
54/// - [`MessageContent::ToolUse`] in assistant messages maps to the `tool_calls`
55///   field rather than content blocks. `args` is serialized via `serde_json::to_string`
56///   (the OpenAI protocol requires `function.arguments` to be stringified JSON).
57/// - `top_k` is absent in the OpenAI protocol; the provider layer handles this.
58/// - `max_tokens`: the OpenAI dialect deprecates `max_tokens` in favor of
59///   `max_completion_tokens`. The DeepSeek-compatible dialect still uses
60///   `max_tokens` to align with its OpenAI-compatible endpoint and opencode
61///   request format. Neither sets a default like Anthropic — the model decides
62///   when not specified.
63pub fn encode_request(req: &CompletionRequest) -> wire::CreateChatCompletionRequest {
64    encode_request_with_echo(req, ThinkingEcho::Forbidden)
65}
66
67/// Same shape as [`encode_request`], but explicitly accepts a thinking-echo policy.
68///
69/// `echo_mode` is read by the provider layer from [`defect_core::llm::Capabilities`]
70/// and passed in: when `Required`, the [`MessageContent::Thinking`] text on the
71/// assistant message is written to the non-standard `reasoning_content` field on the
72/// wire; when `Forbidden` (including unconfigured), it is never written.
73pub fn encode_request_with_echo(
74    req: &CompletionRequest,
75    echo_mode: ThinkingEcho,
76) -> wire::CreateChatCompletionRequest {
77    encode_request_full(req, echo_mode, None)
78}
79
80/// Same shape as [`encode_request_with_echo`], but allows the provider layer to forcibly
81/// override the `reasoning_effort` field. When `effort_override` is `Some(_)`, the value
82/// of `SamplingParams::thinking` is ignored and the override is written directly to the
83/// wire; when `None`, the old behavior (thinking enabled → medium) is preserved.
84pub fn encode_request_full(
85    req: &CompletionRequest,
86    echo_mode: ThinkingEcho,
87    effort_override: Option<ReasoningEffort>,
88) -> wire::CreateChatCompletionRequest {
89    encode_request_with_dialect(req, echo_mode, effort_override, ChatDialect::OpenAi)
90}
91
/// OpenAI Chat-compatible request dialect.
///
/// OpenAI and compatible vendors share one JSON schema, but a few fields carry slightly
/// different semantics. In this module the dialect selects:
///
/// - which field receives the token limit (`max_completion_tokens` vs. `max_tokens`);
/// - whether a `prompt_cache_key` is sent;
/// - how assistant `reasoning_content` is replayed.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum ChatDialect {
    /// Upstream OpenAI semantics; the default.
    #[default]
    OpenAi,
    /// DeepSeek's OpenAI-compatible endpoint.
    DeepSeek,
}
102
103/// Same shape as [`encode_request_full`], but allows the provider to specify a compatible
104/// vendor dialect.
105pub fn encode_request_with_dialect(
106    req: &CompletionRequest,
107    echo_mode: ThinkingEcho,
108    effort_override: Option<ReasoningEffort>,
109    dialect: ChatDialect,
110) -> wire::CreateChatCompletionRequest {
111    let mut messages = Vec::with_capacity(req.messages.len() + 1);
112    if let Some(sys) = req.system.as_ref() {
113        messages.push(encode_system_message(sys));
114    }
115    for m in &req.messages {
116        encode_message_into(m, echo_mode, dialect, &mut messages);
117    }
118
119    let max_tokens = req.sampling.max_tokens.map(i64::from);
120    #[allow(deprecated)]
121    wire::CreateChatCompletionRequest {
122        // ---- fields we use ----
123        messages,
124        model: wire::ModelIdsShared::ModelIdsSharedVariant0(req.model.clone()),
125        stream: Some(true),
126        stream_options: Some(wire::ChatCompletionStreamOptions::ChatCompletionStreamOptionsVariant0(
127            wire::ChatCompletionStreamOptionsVariant0 {
128                include_usage: Some(true),
129                include_obfuscation: None,
130            },
131        )),
132        max_completion_tokens: match dialect {
133            ChatDialect::OpenAi => max_tokens,
134            ChatDialect::DeepSeek => None,
135        },
136        temperature: req.sampling.temperature.map(|t| {
137            wire::CreateChatCompletionRequestTemperature::CreateChatCompletionRequestTemperatureVariant0(
138                f64::from(t),
139            )
140        }),
141        top_p: req.sampling.top_p.map(|t| {
142            wire::CreateChatCompletionRequestTopP::CreateChatCompletionRequestTopPVariant0(
143                f64::from(t),
144            )
145        }),
146        stop: if req.sampling.stop_sequences.is_empty() {
147            None
148        } else {
149            Some(wire::StopConfiguration::StopConfigurationVariant1(
150                req.sampling.stop_sequences.clone(),
151            ))
152        },
153        // Priority: per-session `sampling.reasoning_effort` (ACP thought-level,
154        // switchable at runtime) > provider-level `effort_override` (fixed in config) >
155        // derived from `thinking`. The first two directly materialize the level; the last
156        // can only map to medium.
157        reasoning_effort: req
158            .sampling
159            .reasoning_effort
160            .or(effort_override)
161            .map(encode_reasoning_effort)
162            .or_else(|| encode_thinking(req.sampling.thinking)),
163        tools: if req.tools.is_empty() {
164            None
165        } else {
166            Some(req.tools.iter().map(encode_tool).collect())
167        },
168        tool_choice: encode_tool_choice(&req.tool_choice),
169        // Unused fields: explicitly set to None for easy grepping later
170        metadata: None,
171        top_logprobs: None,
172        user: None,
173        safety_identifier: None,
174        prompt_cache_key: match dialect {
175            ChatDialect::OpenAi => Some(build_prompt_cache_key(req, echo_mode)),
176            ChatDialect::DeepSeek => None,
177        },
178        service_tier: None,
179        prompt_cache_retention: None,
180        modalities: None,
181        verbosity: None,
182        frequency_penalty: None,
183        presence_penalty: None,
184        web_search_options: None,
185        response_format: None,
186        audio: None,
187        store: None,
188        logit_bias: None,
189        logprobs: None,
190        max_tokens: match dialect {
191            ChatDialect::OpenAi => None,
192            ChatDialect::DeepSeek => max_tokens,
193        },
194        n: None,
195        prediction: None,
196        seed: None,
197        parallel_tool_calls: None,
198        function_call: None,
199        functions: None,
200    }
201}
202
203fn build_prompt_cache_key(req: &CompletionRequest, echo_mode: ThinkingEcho) -> String {
204    let mut hasher = PromptCacheKeyHasher::new();
205    hasher.write_str(&req.model);
206    if let Some(system) = req.system.as_deref() {
207        hasher.write_str(system);
208    }
209    hasher.write_str(prompt_cache_echo_mode(echo_mode));
210    hasher.write_str(prompt_cache_tool_choice(&req.tool_choice));
211    hasher.write_json(&req.tools);
212    format!("{PROMPT_CACHE_KEY_PREFIX}{:016x}", hasher.finish())
213}
214
215fn prompt_cache_echo_mode(mode: ThinkingEcho) -> &'static str {
216    match mode {
217        ThinkingEcho::Forbidden => "forbidden",
218        ThinkingEcho::Required => "required",
219        ThinkingEcho::Optional => "optional",
220    }
221}
222
223fn prompt_cache_tool_choice(choice: &ToolChoice) -> &str {
224    match choice {
225        ToolChoice::Auto => "auto",
226        ToolChoice::Required => "required",
227        ToolChoice::Named { name } => name.as_str(),
228        ToolChoice::None => "none",
229    }
230}
231
232struct PromptCacheKeyHasher {
233    state: u64,
234}
235
236impl PromptCacheKeyHasher {
237    fn new() -> Self {
238        Self {
239            state: PROMPT_CACHE_KEY_OFFSET_BASIS,
240        }
241    }
242
243    fn write_json<T>(&mut self, value: &T)
244    where
245        T: serde::Serialize,
246    {
247        let Ok(encoded) = serde_json::to_vec(value) else {
248            return;
249        };
250        self.write_bytes(&encoded);
251    }
252
253    fn write_str(&mut self, value: &str) {
254        self.write_bytes(value.as_bytes());
255    }
256
257    fn write_bytes(&mut self, bytes: &[u8]) {
258        for byte in bytes {
259            self.state ^= u64::from(*byte);
260            self.state = self.state.wrapping_mul(PROMPT_CACHE_KEY_PRIME);
261        }
262        self.state ^= u64::from(b'\n');
263        self.state = self.state.wrapping_mul(PROMPT_CACHE_KEY_PRIME);
264    }
265
266    fn finish(self) -> u64 {
267        self.state
268    }
269}
270
271fn encode_system_message(text: &str) -> wire::ChatCompletionRequestMessage {
272    wire::ChatCompletionRequestMessage::ChatCompletionRequestSystemMessage(
273        wire::ChatCompletionRequestSystemMessage {
274            content: wire::ChatCompletionRequestSystemMessageContent::ChatCompletionRequestSystemMessageContentVariant0(
275                text.to_owned(),
276            ),
277            role: wire::ChatCompletionRequestSystemMessageRole::System,
278            name: None,
279        },
280    )
281}
282
283/// A single [`Message`] may fan out into multiple wire messages:
284/// - Each [`MessageContent::ToolResult`] embedded in a user message becomes a separate
285///   tool message.
286/// - Each [`MessageContent::ToolUse`] in an assistant message is lifted to the top-level
287///   `tool_calls` field instead of being part of the content.
288fn encode_message_into(
289    m: &Message,
290    echo_mode: ThinkingEcho,
291    dialect: ChatDialect,
292    out: &mut Vec<wire::ChatCompletionRequestMessage>,
293) {
294    match m.role {
295        Role::User => encode_user_message_into(m, out),
296        Role::Assistant => encode_assistant_message_into(m, echo_mode, dialect, out),
297    }
298}
299
300fn encode_user_message_into(m: &Message, out: &mut Vec<wire::ChatCompletionRequestMessage>) {
301    let mut user_parts: Vec<wire::ChatCompletionRequestUserMessageContentPart> = Vec::new();
302    let mut tool_results: Vec<(String, String)> = Vec::new(); // (tool_use_id, text)
303
304    for c in m.content.iter() {
305        match c {
306            MessageContent::Text { text } => {
307                user_parts.push(
308                    wire::ChatCompletionRequestUserMessageContentPart::ChatCompletionRequestMessageContentPartText(
309                        wire::ChatCompletionRequestMessageContentPartText {
310                            r#type: wire::ChatCompletionRequestMessageContentPartTextType::Text,
311                            text: text.clone(),
312                        },
313                    ),
314                );
315            }
316            MessageContent::Image { mime, data } => {
317                user_parts.push(image_part(mime, data));
318            }
319            MessageContent::ToolResult {
320                tool_use_id,
321                output,
322                is_error: _,
323            } => {
324                // OpenAI's tool message has no `is_error` field; we use a prefix to
325                // signal the error state so the model can read it from the content.
326                // `is_error` is primarily for Anthropic; here we preserve its semantics
327                // but in a different form.
328                //
329                // OpenAI's tool message only accepts text—images from multimodal results
330                // cannot be placed inside a tool message. Strategy: extract image blocks
331                // and push them into `user_parts` (the user message immediately following
332                // the tool message), leaving only text plus a placeholder hint in the
333                // tool message so the model knows the images are in the next message.
334                let text = match output {
335                    ToolResultBody::Text { text } => text.clone(),
336                    ToolResultBody::Json { value } => value.to_string(),
337                    ToolResultBody::Content { blocks } => {
338                        let mut text = String::new();
339                        let mut image_count = 0usize;
340                        for block in blocks {
341                            match block {
342                                ToolResultContent::Text { text: t } => {
343                                    if !text.is_empty() {
344                                        text.push('\n');
345                                    }
346                                    text.push_str(t);
347                                }
348                                ToolResultContent::Image { mime, data } => {
349                                    image_count += 1;
350                                    user_parts.push(image_part(mime, data));
351                                }
352                            }
353                        }
354                        if image_count > 0 {
355                            if !text.is_empty() {
356                                text.push('\n');
357                            }
358                            text.push_str(&format!(
359                                "[{image_count} image(s) from this tool result follow in the next user message]"
360                            ));
361                        }
362                        text
363                    }
364                };
365                tool_results.push((tool_use_id.clone(), text));
366            }
367            // Fallback for `non_exhaustive`: keep the slot but leave the content empty.
368            _ => {
369                user_parts.push(
370                    wire::ChatCompletionRequestUserMessageContentPart::ChatCompletionRequestMessageContentPartText(
371                        wire::ChatCompletionRequestMessageContentPartText {
372                            r#type: wire::ChatCompletionRequestMessageContentPartTextType::Text,
373                            text: String::new(),
374                        },
375                    ),
376                );
377            }
378        }
379    }
380
381    // OpenAI / LiteLLM require that an assistant message with tool_calls must be
382    // immediately followed by the corresponding tool messages; a subsequent user message
383    // cannot be inserted in between.
384    for (tool_use_id, text) in tool_results {
385        out.push(wire::ChatCompletionRequestMessage::ChatCompletionRequestToolMessage(
386            wire::ChatCompletionRequestToolMessage {
387                role: wire::ChatCompletionRequestToolMessageRole::Tool,
388                content: wire::ChatCompletionRequestToolMessageContent::ChatCompletionRequestToolMessageContentVariant0(
389                    text,
390                ),
391                tool_call_id: tool_use_id,
392            },
393        ));
394    }
395    if !user_parts.is_empty() {
396        out.push(wire::ChatCompletionRequestMessage::ChatCompletionRequestUserMessage(
397            wire::ChatCompletionRequestUserMessage {
398                content: wire::ChatCompletionRequestUserMessageContent::ChatCompletionRequestUserMessageContentVariant1(
399                    user_parts,
400                ),
401                role: wire::ChatCompletionRequestUserMessageRole::User,
402                name: None,
403            },
404        ));
405    }
406}
407
408fn encode_assistant_message_into(
409    m: &Message,
410    echo_mode: ThinkingEcho,
411    dialect: ChatDialect,
412    out: &mut Vec<wire::ChatCompletionRequestMessage>,
413) {
414    const EMPTY_ASSISTANT_CONTENT: &str = "";
415
416    let mut text_parts: Vec<String> = Vec::new();
417    let mut tool_calls: Vec<wire::ChatCompletionMessageToolCallsItem> = Vec::new();
418    let mut reasoning_text = String::new();
419
420    for c in m.content.iter() {
421        match c {
422            MessageContent::Text { text } => text_parts.push(text.clone()),
423            MessageContent::Thinking { text, .. } => {
424                // The `signature` field is irrelevant on the OpenAI path (neither
425                // DeepSeek nor OpenAI uses it); only the text is taken.
426                reasoning_text.push_str(text);
427            }
428            MessageContent::ToolUse { id, name, args } => {
429                tool_calls.push(
430                    wire::ChatCompletionMessageToolCallsItem::ChatCompletionMessageToolCall(
431                        wire::ChatCompletionMessageToolCall {
432                            id: id.clone(),
433                            r#type: wire::ChatCompletionMessageToolCallType::Function,
434                            function: wire::ChatCompletionMessageToolCallFunction {
435                                name: name.clone(),
436                                arguments: serde_json::to_string(args).unwrap_or_default(),
437                            },
438                        },
439                    ),
440                );
441            }
442            // ToolResult/Image should not appear in the assistant role; the
443            // non_exhaustive fallback also reaches here. Ignore and do not send over the
444            // wire.
445            _ => {}
446        }
447    }
448
449    let reasoning_content = match dialect {
450        ChatDialect::DeepSeek => Some(reasoning_text),
451        ChatDialect::OpenAi => match (echo_mode, reasoning_text.is_empty()) {
452            (ThinkingEcho::Required, false) => Some(reasoning_text),
453            // Treat `Optional` the same as `Required`: replaying is safer when the server
454            // tolerates extra thinking fields (DeepSeek-v4-pro docs list it as `must`;
455            // other `Optional` vendors do not error on extra fields either).
456            (ThinkingEcho::Optional, false) => Some(reasoning_text),
457            _ => None,
458        },
459    };
460    let content = if text_parts.is_empty() {
461        if tool_calls.is_empty() && reasoning_content.is_some() {
462            // DeepSeek v4 series validates that assistant messages have at least
463            // `content` or `tool_calls`; replaying a thinking-only history requires
464            // adding an empty `content`.
465            Some(wire::ChatCompletionRequestAssistantMessageContent::ChatCompletionRequestAssistantMessageContentVariant0(
466                wire::ChatCompletionRequestAssistantMessageContentVariant0::ChatCompletionRequestAssistantMessageContentVariant0Variant0(
467                    EMPTY_ASSISTANT_CONTENT.to_owned(),
468                ),
469            ))
470        } else {
471            None
472        }
473    } else {
474        Some(wire::ChatCompletionRequestAssistantMessageContent::ChatCompletionRequestAssistantMessageContentVariant0(
475            wire::ChatCompletionRequestAssistantMessageContentVariant0::ChatCompletionRequestAssistantMessageContentVariant0Variant0(
476                text_parts.join(""),
477            ),
478        ))
479    };
480
481    #[allow(deprecated)]
482    out.push(
483        wire::ChatCompletionRequestMessage::ChatCompletionRequestAssistantMessage(
484            wire::ChatCompletionRequestAssistantMessage {
485                content,
486                refusal: None,
487                role: wire::ChatCompletionRequestAssistantMessageRole::Assistant,
488                name: None,
489                audio: None,
490                tool_calls: if tool_calls.is_empty() {
491                    None
492                } else {
493                    Some(tool_calls)
494                },
495                function_call: None,
496                reasoning_content,
497            },
498        ),
499    );
500}
501
502/// Build an OpenAI user-message image part. This is shared with the image block extracted
503/// from a multimodal `tool_result` via `MessageContent::Image`.
504fn image_part(mime: &str, data: &ImageData) -> wire::ChatCompletionRequestUserMessageContentPart {
505    wire::ChatCompletionRequestUserMessageContentPart::ChatCompletionRequestMessageContentPartImage(
506        wire::ChatCompletionRequestMessageContentPartImage {
507            r#type: wire::ChatCompletionRequestMessageContentPartImageType::ImageUrl,
508            image_url: wire::ChatCompletionRequestMessageContentPartImageImageUrl {
509                url: image_url_string(mime, data),
510                detail: None,
511            },
512        },
513    )
514}
515
516fn image_url_string(mime: &str, data: &ImageData) -> String {
517    match data {
518        ImageData::Url { url } => url.clone(),
519        ImageData::Base64 { encoded } => format!("data:{mime};base64,{encoded}"),
520    }
521}
522
523fn encode_thinking(t: ThinkingConfig) -> Option<wire::ReasoningEffort> {
524    match t {
525        ThinkingConfig::Disabled => None,
526        // OpenAI's thinking does not accept `budget_tokens` (unlike Anthropic); it only
527        // supports effort levels. The budget value is discarded and uniformly mapped to
528        // `medium`.
529        ThinkingConfig::Enabled { .. } => Some(wire::ReasoningEffort::ReasoningEffortVariant0(
530            wire::ReasoningEffortVariant0::Medium,
531        )),
532    }
533}
534
535fn encode_reasoning_effort(effort: ReasoningEffort) -> wire::ReasoningEffort {
536    use ReasoningEffort as E;
537    use wire::ReasoningEffortVariant0 as V;
538    let v = match effort {
539        E::None => V::None,
540        E::Minimal => V::Minimal,
541        E::Low => V::Low,
542        E::Medium => V::Medium,
543        E::High => V::High,
544        E::Xhigh => V::Xhigh,
545    };
546    wire::ReasoningEffort::ReasoningEffortVariant0(v)
547}
548
549fn encode_tool_choice(c: &ToolChoice) -> Option<wire::ChatCompletionToolChoiceOption> {
550    match c {
551        ToolChoice::Auto => Some(
552            wire::ChatCompletionToolChoiceOption::ChatCompletionToolChoiceOptionVariant0(
553                wire::ChatCompletionToolChoiceOptionVariant0::Auto,
554            ),
555        ),
556        ToolChoice::Required => Some(
557            wire::ChatCompletionToolChoiceOption::ChatCompletionToolChoiceOptionVariant0(
558                wire::ChatCompletionToolChoiceOptionVariant0::Required,
559            ),
560        ),
561        ToolChoice::None => Some(
562            wire::ChatCompletionToolChoiceOption::ChatCompletionToolChoiceOptionVariant0(
563                wire::ChatCompletionToolChoiceOptionVariant0::None,
564            ),
565        ),
566        ToolChoice::Named { name } => Some(
567            wire::ChatCompletionToolChoiceOption::ChatCompletionNamedToolChoice(
568                wire::ChatCompletionNamedToolChoice {
569                    r#type: wire::ChatCompletionNamedToolChoiceType::Function,
570                    function: wire::ChatCompletionNamedToolChoiceFunction { name: name.clone() },
571                },
572            ),
573        ),
574    }
575}
576
577fn encode_tool(t: &ToolSchema) -> wire::CreateChatCompletionRequestTools {
578    wire::CreateChatCompletionRequestTools::ChatCompletionTool(wire::ChatCompletionTool {
579        r#type: wire::ChatCompletionToolType::Function,
580        function: wire::FunctionObject {
581            name: t.name.clone(),
582            description: if t.description.is_empty() {
583                None
584            } else {
585                Some(t.description.clone())
586            },
587            parameters: Some(json_value_to_parameters(&t.input_schema)),
588            strict: None,
589        },
590    })
591}
592
593fn json_value_to_parameters(v: &serde_json::Value) -> wire::FunctionParameters {
594    v.as_object()
595        .map(|m| m.iter().map(|(k, v)| (k.clone(), v.clone())).collect())
596        .unwrap_or_default()
597}
598
599// ---------- decode -------------------------------------------------------
600
/// Internal state of the SSE → `ProviderChunk` decoding state machine.
#[derive(Debug, Default)]
struct DecoderState {
    /// `MessageStart` has been emitted.
    started: bool,
    /// `Stop` has been emitted. After `Stop`, only `Usage` is allowed.
    stopped: bool,
    /// The `data: [DONE]` terminator has been seen. The decoder finishes once `pending` is
    /// drained.
    done: bool,
    /// A fatal error was received (a parsing failure; decoding cannot continue). The decoder
    /// finishes once `pending` is drained, and the end-of-stream protocol-violation check is
    /// skipped.
    fatal: bool,
    /// Maps `delta.tool_calls[].index` → tool-call state.
    ///
    /// OpenAI streaming chunks identify tool calls by index. The first frame carries
    /// `id` + `name`; later `args` frames carry only the index. This table maps the index
    /// back to the `id` used in `ProviderChunk`.
    tool_calls: HashMap<i64, ToolCallState>,
    /// Indices in the order their tool calls first arrived, so `ToolUseEnd` can be emitted
    /// in arrival order on `Stop`.
    tool_call_order: Vec<i64>,
}

/// Bookkeeping for one streamed tool call, keyed by its stream index in
/// `DecoderState::tool_calls`.
#[derive(Debug, Clone)]
struct ToolCallState {
    /// The tool call's `id`, as announced in its first frame.
    id: String,
    /// Whether the `ToolUseEnd` for this call has already been sent.
    closed: bool,
}
628
629/// SSE stream → `ProviderChunk` stream. The return value implements [`Stream`]; dropping
630/// it cancels the stream.
631///
632/// After `cancel` is triggered, the stream silently terminates, consistent with the LLM
633/// trait contract.
634pub fn decode_stream(
635    sse: SseEventStream,
636    cancel: CancellationToken,
637) -> impl Stream<Item = Result<ProviderChunk, ProviderError>> + Send {
638    decode_stream_with_usage_parser(sse, cancel, usage_from_wire)
639}
640
641/// Same shape as [`decode_stream`], but generic over the input `Stream` type for easier
642/// testing — feed it directly with `futures::stream::iter` without going through the toac
643/// transport.
644pub fn decode_stream_generic<S, E>(
645    sse: S,
646    cancel: CancellationToken,
647) -> impl Stream<Item = Result<ProviderChunk, ProviderError>> + Send
648where
649    S: Stream<Item = Result<Sse, E>> + Send + 'static,
650    E: std::error::Error + Send + Sync + 'static,
651{
652    decode_stream_generic_with_usage_parser(sse, cancel, usage_from_wire)
653}
654
655/// Same shape as [`decode_stream`], but allows vendor-specific overrides of the usage
656/// parsing logic.
657pub(crate) fn decode_stream_with_usage_parser(
658    sse: SseEventStream,
659    cancel: CancellationToken,
660    usage_parser: UsageParser,
661) -> impl Stream<Item = Result<ProviderChunk, ProviderError>> + Send {
662    decode_stream_generic_with_usage_parser(sse, cancel, usage_parser)
663}
664
665fn decode_stream_generic_with_usage_parser<S, E>(
666    sse: S,
667    cancel: CancellationToken,
668    usage_parser: UsageParser,
669) -> impl Stream<Item = Result<ProviderChunk, ProviderError>> + Send
670where
671    S: Stream<Item = Result<Sse, E>> + Send + 'static,
672    E: std::error::Error + Send + Sync + 'static,
673{
674    OpenAiSseDecoder {
675        inner: sse,
676        cancel,
677        state: DecoderState::default(),
678        pending: Vec::new(),
679        finished: false,
680        usage_parser,
681        _err: std::marker::PhantomData::<E>,
682    }
683}
684
/// Stream adapter that decodes raw SSE events from `inner` into `ProviderChunk`s.
///
/// `E` is the inner stream's error type. It appears only in the `Stream` impl's bounds, so
/// it is carried through `PhantomData`.
struct OpenAiSseDecoder<S, E> {
    /// Upstream SSE event stream.
    inner: S,
    /// Checked on every poll; once cancelled, the stream ends without yielding more events.
    cancel: CancellationToken,
    /// Decoding state machine, updated by `process_sse`.
    state: DecoderState,
    /// Chunks already decoded but not yet yielded.
    ///
    /// A single SSE frame may produce several chunks; a `finish_reason` frame is usually
    /// followed by `ToolUseEnd` × N + `Stop`. They are buffered here and `poll_next` hands
    /// them out one at a time.
    ///
    /// NOTE(review): `poll_next` drains this with `Vec::pop`, which returns the *last*
    /// pushed item first. Chunks are therefore yielded in reverse push order. Confirm that
    /// `handle_chunk` pushes in reverse emission order, or switch to a `VecDeque`.
    pending: Vec<Result<ProviderChunk, ProviderError>>,
    /// Set once the stream has terminated: `[DONE]`, fatal error, transport error, end of
    /// input, or cancellation. After that, only the remaining `pending` items are yielded.
    finished: bool,
    /// Vendor-specific usage decoding hook, forwarded to `process_sse`.
    usage_parser: UsageParser,
    _err: std::marker::PhantomData<E>,
}
697
impl<S, E> Stream for OpenAiSseDecoder<S, E>
where
    S: Stream<Item = Result<Sse, E>>,
    E: std::error::Error + Send + Sync + 'static,
{
    type Item = Result<ProviderChunk, ProviderError>;

    /// Yields buffered chunks, then pulls and decodes the next SSE event from `inner`.
    ///
    /// Each loop iteration checks, in order:
    /// 1. Any chunk in `pending` is yielded first. This also happens after `finished` is
    ///    set or after cancellation, so already-decoded output is not dropped.
    /// 2. Once `finished` is set, `None` is yielded.
    /// 3. If `cancel` has fired, the stream ends silently with `None`.
    /// 4. Otherwise `inner` is polled:
    ///    - End of input yields one `ProtocolViolation`, but only if the message had
    ///      started, no `[DONE]` or stop chunk was seen, and no fatal error occurred.
    ///    - A transport error is yielded as `Transport` and ends the stream.
    ///    - An event is decoded into `pending`. `[DONE]` or a fatal error marks the
    ///      stream finished, and it ends once `pending` is drained.
    ///
    /// NOTE(review): cancellation is only observed when this method runs. While `inner`
    /// returns `Pending`, no waker is registered on `cancel`. A cancelled but otherwise idle
    /// stream therefore stays pending until `inner` makes progress.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // SAFETY: Standard pin-projection through a single field. `inner` is never moved
        // out (it is re-pinned in place below), no other field is treated as pinned, and
        // `_err` is a zero-sized `PhantomData`.
        let this = unsafe { self.get_unchecked_mut() };
        loop {
            // NOTE(review): `pop` takes from the end of `pending` (LIFO).
            if let Some(item) = this.pending.pop() {
                return Poll::Ready(Some(item));
            }
            if this.finished {
                return Poll::Ready(None);
            }
            if this.cancel.is_cancelled() {
                this.finished = true;
                return Poll::Ready(None);
            }

            // SAFETY: pin-projection through a single field; `this.inner` is not moved
            // while pinned.
            let inner = unsafe { Pin::new_unchecked(&mut this.inner) };
            match inner.poll_next(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(None) => {
                    this.finished = true;
                    // Neither [DONE] nor a stop chunk arrived after the message started:
                    // report a ProtocolViolation, unless a fatal error was already reported.
                    if !this.state.done
                        && !this.state.stopped
                        && this.state.started
                        && !this.state.fatal
                    {
                        return Poll::Ready(Some(Err(ProviderError::new(
                            ProviderErrorKind::ProtocolViolation {
                                hint: "stream ended without finish_reason or [DONE]".into(),
                            },
                        ))));
                    }
                    return Poll::Ready(None);
                }
                Poll::Ready(Some(Err(e))) => {
                    // Transport failures are terminal.
                    this.finished = true;
                    return Poll::Ready(Some(Err(ProviderError::new(
                        ProviderErrorKind::Transport(BoxError::new(e)),
                    ))));
                }
                Poll::Ready(Some(Ok(sse))) => {
                    process_sse(&mut this.state, sse, &mut this.pending, this.usage_parser);
                    // Stop pulling after [DONE] or a fatal error; `pending` is still drained
                    // by the next iterations.
                    if this.state.done || this.state.fatal {
                        this.finished = true;
                    }
                }
            }
        }
    }
}
757
758fn process_sse(
759    state: &mut DecoderState,
760    sse: Sse,
761    out: &mut Vec<Result<ProviderChunk, ProviderError>>,
762    usage_parser: UsageParser,
763) {
764    let data = match sse.data {
765        Some(d) => d,
766        None => return,
767    };
768    let trimmed = data.trim();
769    // OpenAI stream terminator. Drop all subsequent data frames after receiving it (there
770    // won't be any in practice).
771    if trimmed == "[DONE]" {
772        state.done = true;
773        return;
774    }
775
776    // First parse as a raw `Value` to extract DeepSeek's proprietary
777    // `delta.reasoning_content` — the wire OAS lacks this field, so it would be lost
778    // after structured parsing.
779    let raw: serde_json::Value = match serde_json::from_str(trimmed) {
780        Ok(v) => v,
781        Err(e) => {
782            out.push(Err(ProviderError::new(ProviderErrorKind::Malformed(
783                BoxError::new(e),
784            ))));
785            return;
786        }
787    };
788
789    let parsed: Result<wire::CreateChatCompletionStreamResponse, _> =
790        serde_json::from_value(raw.clone());
791    let evt = match parsed {
792        Ok(e) => e,
793        Err(e) => {
794            out.push(Err(ProviderError::new(ProviderErrorKind::Malformed(
795                BoxError::new(e),
796            ))));
797            return;
798        }
799    };
800
801    handle_chunk(state, &raw, evt, out, usage_parser);
802}
803
804fn handle_chunk(
805    state: &mut DecoderState,
806    raw: &serde_json::Value,
807    evt: wire::CreateChatCompletionStreamResponse,
808    out: &mut Vec<Result<ProviderChunk, ProviderError>>,
809    usage_parser: UsageParser,
810) {
811    // `poll_next` uses `pop()`, so to emit in chronological order we must push in
812    // reverse.
813    let mut buf: Vec<Result<ProviderChunk, ProviderError>> = Vec::new();
814
815    // The first chunk seen implies a `MessageStart`. Unlike Anthropic, OpenAI does not
816    // have a dedicated `message_start` event; every chunk carries `id`/`model`, so the
817    // first frame is the start.
818    if !state.started {
819        state.started = true;
820        buf.push(Ok(ProviderChunk::MessageStart {
821            id: evt.id.clone(),
822            model: evt.model.clone(),
823        }));
824    }
825
826    // choices are typically length 1 (`n=1`); the final usage chunk is an empty array.
827    for (choice_idx, choice) in evt.choices.iter().enumerate() {
828        // Extract the raw delta for `reasoning_content`.
829        let raw_delta = raw
830            .get("choices")
831            .and_then(|v| v.as_array())
832            .and_then(|a| a.get(choice_idx))
833            .and_then(|c| c.get("delta"));
834
835        let delta = &choice.delta;
836
837        // DeepSeek `reasoning_content` is not present in the wire OAS, so it is taken
838        // from the raw delta.
839        if let Some(rc) = raw_delta
840            .and_then(|d| d.get("reasoning_content"))
841            .and_then(|v| v.as_str())
842            && !rc.is_empty()
843        {
844            buf.push(Ok(ProviderChunk::ThinkingDelta {
845                text: rc.to_owned(),
846            }));
847        }
848
849        // Text delta.
850        if let Some(
851            wire::ChatCompletionStreamResponseDeltaContent::ChatCompletionStreamResponseDeltaContentVariant0(
852                s,
853            ),
854        ) = &delta.content
855            && !s.is_empty()
856        {
857            buf.push(Ok(ProviderChunk::TextDelta { text: s.clone() }));
858        }
859
860        // For each tool call index, the first chunk containing `id` and `name` triggers a
861        // `ToolUseStart`; subsequent chunks with `arguments` become `ToolUseArgsDelta`.
862        if let Some(calls) = &delta.tool_calls {
863            for tc in calls {
864                handle_tool_call_chunk(state, tc, &mut buf);
865            }
866        }
867
868        // OpenAI uses `delta.refusal` to signal a safety refusal. We treat it as a
869        // `TextDelta` (with a distinguishable prefix), and later, when
870        // `finish_reason=content_filter`, we propagate the semantics upward via `Stop`.
871        if let Some(
872            wire::ChatCompletionStreamResponseDeltaRefusal::ChatCompletionStreamResponseDeltaRefusalVariant0(
873                s,
874            ),
875        ) = &delta.refusal
876            && !s.is_empty()
877        {
878            buf.push(Ok(ProviderChunk::TextDelta { text: s.clone() }));
879        }
880
881        // `finish_reason` is required in the OAS (no `Option`); most chunks in the stream
882        // are `Stop` (a "non-stop" reason). However, in OpenAI's actual wire format,
883        // non-terminal chunks have `finish_reason: null`, and what the codegen produces
884        // depends on the OAS. We take a **conservative approach**: only treat a chunk as
885        // terminal when we see any of `tool_calls` / `length` / `content_filter` /
886        // `function_call` and no more data follows; `stop` is also terminal. Simplified
887        // strategy: the last chunk of every non-empty `choices` always carries a terminal
888        // `finish_reason`, so emit immediately upon receipt.
889        //
890        // Note: when the wire schema fails to deserialize `finish_reason: null`, it falls
891        // into the `Malformed` branch above, and the state machine never reaches here.
892        // `finish_reason` is `null` on intermediate chunks (the OAS has been patched to
893        // `Option`); only terminal chunks have a value. When hit, close `tool_calls` and
894        // emit `Stop`.
895        if !state.stopped
896            && let Some(fr) = choice.finish_reason
897        {
898            let order = state.tool_call_order.clone();
899            for idx in order {
900                if let Some(tc) = state.tool_calls.get_mut(&idx)
901                    && !tc.closed
902                {
903                    tc.closed = true;
904                    buf.push(Ok(ProviderChunk::ToolUseEnd { id: tc.id.clone() }));
905                }
906            }
907            state.stopped = true;
908            buf.push(Ok(ProviderChunk::Stop {
909                reason: stop_reason_from_wire(fr),
910            }));
911        }
912    }
913
914    // Final usage chunk: choices are empty, usage is present.
915    if let Some(usage) = &evt.usage {
916        buf.push(Ok(ProviderChunk::Usage(usage_parser(
917            raw.get("usage"),
918            usage,
919        ))));
920    }
921
922    buf.reverse();
923    out.extend(buf);
924}
925
926fn handle_tool_call_chunk(
927    state: &mut DecoderState,
928    tc: &wire::ChatCompletionMessageToolCallChunk,
929    out: &mut Vec<Result<ProviderChunk, ProviderError>>,
930) {
931    let idx = tc.index;
932    let entry_existed = state.tool_calls.contains_key(&idx);
933
934    // First frame: must carry `id` (OpenAI docs specify that the first `tool_calls` chunk
935    // carries the full `id` and `function.name`; subsequent chunks carry only
936    // `arguments`).
937    if !entry_existed {
938        let Some(id) = tc.id.clone() else {
939            // No id and no prior state, so the chunk cannot be associated — treat as a
940            // protocol violation, but it is not fatal because the next frame may carry
941            // the id.
942            warn!(index = idx, "tool_calls chunk missing id on first frame");
943            return;
944        };
945        let name = tc
946            .function
947            .as_ref()
948            .and_then(|f| f.name.clone())
949            .unwrap_or_default();
950        state.tool_calls.insert(
951            idx,
952            ToolCallState {
953                id: id.clone(),
954                closed: false,
955            },
956        );
957        state.tool_call_order.push(idx);
958        out.push(Ok(ProviderChunk::ToolUseStart { id, name }));
959    }
960
961    if let Some(func) = &tc.function
962        && let Some(args) = &func.arguments
963        && !args.is_empty()
964        && let Some(tool) = state.tool_calls.get(&idx)
965    {
966        out.push(Ok(ProviderChunk::ToolUseArgsDelta {
967            id: tool.id.clone(),
968            fragment: args.clone(),
969        }));
970    }
971}
972
973fn stop_reason_from_wire(
974    r: wire::CreateChatCompletionStreamResponseChoicesFinishReason,
975) -> StopReason {
976    use wire::CreateChatCompletionStreamResponseChoicesFinishReason as W;
977    match r {
978        W::Stop => StopReason::EndTurn,
979        W::Length => StopReason::MaxTokens,
980        W::ToolCalls | W::FunctionCall => StopReason::ToolUse,
981        W::ContentFilter => StopReason::Refusal,
982    }
983}
984
985fn usage_from_wire(_raw_usage: Option<&serde_json::Value>, u: &wire::CompletionUsage) -> Usage {
986    Usage {
987        input_tokens: u64::try_from(u.prompt_tokens).ok(),
988        output_tokens: u64::try_from(u.completion_tokens).ok(),
989        cache_read_input_tokens: u
990            .prompt_tokens_details
991            .as_ref()
992            .and_then(|d| d.cached_tokens)
993            .and_then(|v| u64::try_from(v).ok()),
994        // OpenAI does not report cache creation tokens; `cached_tokens` only indicates
995        // the number of input tokens that hit the cache.
996        cache_creation_input_tokens: None,
997    }
998}
999
1000#[cfg(test)]
1001mod tests;