Skip to main content

entelix_core/
stream.rs

1//! `StreamAggregator` — accumulates streaming model deltas into a coherent
2//! [`ModelResponse`].
3//!
4//! Tool-call ordering is preserved: each `ToolUseStart` opens a fresh
5//! tool block, subsequent `ToolUseInputDelta`s append into that block
6//! until `ToolUseStop` closes it.
7//!
8//! ## Variant naming — semantic, not wire-aligned
9//!
10//! Variant names describe the *meaning* of each delta
11//! (`TextDelta`, `ThinkingDelta`, `ToolUseStart`) rather than mirror
12//! one vendor's SSE event names (`content_block_delta`,
13//! `message_start`). Per invariant 5 the IR never returns
14//! vendor-shaped JSON, and the same principle applies to the
15//! streaming surface: codecs translate their wire events into
16//! these variants so consumers writing against `StreamDelta` work
17//! across Anthropic, `OpenAI` (Chat Completions and Responses),
18//! Gemini, and Bedrock without renames. Renaming a variant to match
19//! one provider's wire format would couple the public API to that
20//! vendor's terminology and force a churn whenever they renumber an
21//! event type.
22
23use std::pin::Pin;
24use std::task::{Context, Poll};
25
26use futures::Stream;
27use futures::StreamExt;
28use futures::future::BoxFuture;
29use tokio::sync::oneshot;
30
31use crate::codecs::BoxDeltaStream;
32use crate::error::{Error, Result};
33use crate::ir::{
34    ContentPart, ModelResponse, ModelWarning, ProviderEchoSnapshot, StopReason, Usage,
35};
36use crate::rate_limit::RateLimitSnapshot;
37use crate::service::ModelStream;
38
39/// One chunk from a streaming model response.
40#[derive(Clone, Debug, PartialEq, Eq)]
41#[non_exhaustive]
42pub enum StreamDelta {
43    /// First message — vendor's response id and model identifier.
44    Start {
45        /// Vendor message id (echoed in the final `ModelResponse`).
46        id: String,
47        /// Resolved model identifier.
48        model: String,
49        /// Response-level vendor opaque round-trip tokens — OpenAI
50        /// Responses `Response.id` (so the next request can chain via
51        /// `previous_response_id` from `ModelRequest::continued_from`),
52        /// or anything else the codec wants to carry at response root
53        /// rather than on a single content part. The aggregator
54        /// surfaces these on [`ModelResponse::provider_echoes`] at
55        /// finalize time, mirroring the non-streaming decode path.
56        provider_echoes: Vec<ProviderEchoSnapshot>,
57    },
58    /// Append text to the in-progress text block. Consecutive `TextDelta`s
59    /// fold into a single `ContentPart::Text` in the output.
60    TextDelta {
61        /// Text fragment to append.
62        text: String,
63        /// Vendor opaque round-trip tokens this fragment carries
64        /// (Gemini 3.x attaches `thought_signature` to `text` parts on
65        /// reasoning turns). The aggregator extends the open-text
66        /// block's accumulated echoes — a single `ContentPart::Text`
67        /// finalises with the union of every delta's echoes.
68        provider_echoes: Vec<ProviderEchoSnapshot>,
69    },
70    /// Append text (or vendor opaque tokens) to the in-progress
71    /// thinking block. Consecutive `ThinkingDelta`s fold into a
72    /// single `ContentPart::Thinking` in the output. A delta carrying
73    /// only `provider_echoes` (empty `text`) attaches the round-trip
74    /// marker without growing the body — Anthropic emits the
75    /// signature on a discrete `signature_delta` SSE event with no
76    /// associated text.
77    ThinkingDelta {
78        /// Text fragment to append. Empty when the delta carries
79        /// only a `provider_echoes` update.
80        text: String,
81        /// Vendor opaque round-trip tokens (Anthropic `signature`,
82        /// Gemini `thought_signature`, OpenAI Responses
83        /// `encrypted_content`). Codecs pre-wrap the wire-shape blob
84        /// into [`ProviderEchoSnapshot`] before yielding the delta;
85        /// the aggregator stays codec-agnostic and just accumulates.
86        provider_echoes: Vec<ProviderEchoSnapshot>,
87    },
88    /// Begin a new tool-use block. Closes any open text block so the
89    /// output preserves the model's intended ordering.
90    ToolUseStart {
91        /// Stable tool-use id.
92        id: String,
93        /// Tool name to call.
94        name: String,
95        /// Vendor opaque round-trip tokens attached to this tool call
96        /// (Gemini 3.x `thought_signature` on `functionCall` parts —
97        /// missing on the next turn yields HTTP 400 on the first
98        /// `functionCall` of a step).
99        provider_echoes: Vec<ProviderEchoSnapshot>,
100    },
101    /// Append partial JSON to the open tool-use block's input buffer.
102    ToolUseInputDelta {
103        /// Raw JSON fragment — the aggregator concatenates and parses
104        /// once the block closes.
105        partial_json: String,
106    },
107    /// Close the current tool-use block. Returns `Err` if the buffered
108    /// JSON does not parse.
109    ToolUseStop,
110    /// Token usage update (last value wins).
111    Usage(Usage),
112    /// Provider rate-limit snapshot, typically emitted as the leading
113    /// chunk by `ChatModel::stream_deltas` before the first content
114    /// delta. Last value wins inside an aggregator.
115    RateLimit(RateLimitSnapshot),
116    /// Provider warning surfaced inline.
117    Warning(ModelWarning),
118    /// End of stream with stop reason.
119    Stop {
120        /// Reason the model halted.
121        stop_reason: StopReason,
122    },
123}
124
125/// Per-tool-block scratch space.
126struct PendingTool {
127    id: String,
128    name: String,
129    input_buffer: String,
130    provider_echoes: Vec<ProviderEchoSnapshot>,
131}
132
133/// Per-thinking-block scratch space.
134#[derive(Default)]
135struct PendingThinking {
136    text: String,
137    provider_echoes: Vec<ProviderEchoSnapshot>,
138}
139
140/// Per-text-block scratch space.
141#[derive(Default)]
142struct PendingText {
143    text: String,
144    provider_echoes: Vec<ProviderEchoSnapshot>,
145}
146
147/// Accumulator that turns a sequence of `StreamDelta`s into a
148/// `ModelResponse`.
149///
150/// Typical usage:
151/// ```ignore
152/// let mut agg = StreamAggregator::new();
153/// while let Some(delta) = stream.next().await {
154///     agg.push(delta?)?;
155/// }
156/// let response = agg.finalize()?;
157/// ```
158#[derive(Default)]
159pub struct StreamAggregator {
160    id: String,
161    model: String,
162    parts: Vec<ContentPart>,
163    /// Buffer for the currently-open text block. `None` when the next
164    /// `TextDelta` should start a fresh block.
165    open_text: Option<PendingText>,
166    /// Buffer for the currently-open thinking block. `None` when the
167    /// next `ThinkingDelta` should start a fresh block. The text and
168    /// tool buffers are mutually exclusive with the thinking buffer:
169    /// any non-thinking delta closes an open thinking block first
170    /// (intra-turn order is preserved).
171    open_thinking: Option<PendingThinking>,
172    pending_tool: Option<PendingTool>,
173    usage: Option<Usage>,
174    rate_limit: Option<RateLimitSnapshot>,
175    stop_reason: Option<StopReason>,
176    warnings: Vec<ModelWarning>,
177    /// Response-level vendor opaque round-trip tokens captured from
178    /// the streaming `Start` delta. Surfaced on
179    /// [`ModelResponse::provider_echoes`] at finalize so streaming
180    /// and non-streaming decode produce equivalent IR.
181    response_echoes: Vec<ProviderEchoSnapshot>,
182}
183
184impl StreamAggregator {
185    /// Empty aggregator.
186    pub fn new() -> Self {
187        Self::default()
188    }
189
190    /// Apply one delta. Returns `Err` on protocol violations
191    /// (`ToolUseInputDelta` outside a tool block, malformed JSON in
192    /// `ToolUseStop`, double `ToolUseStart`).
193    pub fn push(&mut self, delta: StreamDelta) -> Result<()> {
194        match delta {
195            StreamDelta::Start {
196                id,
197                model,
198                provider_echoes,
199            } => {
200                if !self.id.is_empty() || !self.model.is_empty() {
201                    return Err(Error::invalid_request(
202                        "StreamAggregator: duplicate Start delta",
203                    ));
204                }
205                self.id = id;
206                self.model = model;
207                self.response_echoes.extend(provider_echoes);
208            }
209            StreamDelta::TextDelta {
210                text,
211                provider_echoes,
212            } => {
213                if self.pending_tool.is_some() {
214                    return Err(Error::invalid_request(
215                        "StreamAggregator: TextDelta during open tool_use block",
216                    ));
217                }
218                self.flush_thinking();
219                let pending = self.open_text.get_or_insert_with(PendingText::default);
220                pending.text.push_str(&text);
221                pending.provider_echoes.extend(provider_echoes);
222            }
223            StreamDelta::ThinkingDelta {
224                text,
225                provider_echoes,
226            } => {
227                if self.pending_tool.is_some() {
228                    return Err(Error::invalid_request(
229                        "StreamAggregator: ThinkingDelta during open tool_use block",
230                    ));
231                }
232                self.flush_text();
233                let pending = self
234                    .open_thinking
235                    .get_or_insert_with(PendingThinking::default);
236                pending.text.push_str(&text);
237                pending.provider_echoes.extend(provider_echoes);
238            }
239            StreamDelta::ToolUseStart {
240                id,
241                name,
242                provider_echoes,
243            } => {
244                if self.pending_tool.is_some() {
245                    return Err(Error::invalid_request(
246                        "StreamAggregator: ToolUseStart while another tool block is open",
247                    ));
248                }
249                self.flush_text();
250                self.flush_thinking();
251                self.pending_tool = Some(PendingTool {
252                    id,
253                    name,
254                    input_buffer: String::new(),
255                    provider_echoes,
256                });
257            }
258            StreamDelta::ToolUseInputDelta { partial_json } => {
259                let pending = self.pending_tool.as_mut().ok_or_else(|| {
260                    Error::invalid_request(
261                        "StreamAggregator: ToolUseInputDelta with no open tool block",
262                    )
263                })?;
264                pending.input_buffer.push_str(&partial_json);
265            }
266            StreamDelta::ToolUseStop => self.close_tool_block()?,
267            StreamDelta::Usage(u) => self.usage = Some(u),
268            StreamDelta::RateLimit(r) => self.rate_limit = Some(r),
269            StreamDelta::Warning(w) => self.warnings.push(w),
270            StreamDelta::Stop { stop_reason } => {
271                // Refuse to overwrite a stop reason. Some providers
272                // misbehave and ship a follow-up Stop delta after a
273                // valid terminal one (rare, but real); silently
274                // accepting the second value would change the
275                // observed termination cause from `EndTurn` to
276                // `MaxTokens` — a meaningful semantic flip that
277                // operators would never see. Fail closed instead.
278                if self.stop_reason.is_some() {
279                    return Err(Error::invalid_request(
280                        "StreamAggregator: duplicate Stop delta — terminal state already set",
281                    ));
282                }
283                self.stop_reason = Some(stop_reason);
284            }
285        }
286        Ok(())
287    }
288
289    /// Convenience: returns true after a `Stop` delta has been pushed.
290    pub const fn is_finished(&self) -> bool {
291        self.stop_reason.is_some()
292    }
293
294    /// Drain into a final `ModelResponse`. Returns `Err` if a tool block
295    /// was left open or no `Stop` delta was seen.
296    pub fn finalize(mut self) -> Result<ModelResponse> {
297        if self.pending_tool.is_some() {
298            return Err(Error::invalid_request(
299                "StreamAggregator: stream ended with an open tool block",
300            ));
301        }
302        let stop_reason = self.stop_reason.take().ok_or_else(|| {
303            Error::invalid_request("StreamAggregator: stream ended without Stop delta")
304        })?;
305        self.flush_text();
306        self.flush_thinking();
307        // A streaming response that closes without ever emitting a
308        // `Usage` delta silently zeros out the cost meter — every
309        // downstream `gen_ai.usage.cost` becomes a phantom $0
310        // charge. Surface a `LossyEncode` warning so operators see
311        // the miss in observability instead of debugging a
312        // suspiciously-cheap month at billing time.
313        if self.usage.is_none() {
314            self.warnings.push(crate::ir::ModelWarning::LossyEncode {
315                field: "usage".to_owned(),
316                detail: "streaming response closed without Usage delta — cost will be zero"
317                    .to_owned(),
318            });
319        }
320        Ok(ModelResponse {
321            id: self.id,
322            model: self.model,
323            stop_reason,
324            content: self.parts,
325            usage: self.usage.unwrap_or_default(),
326            rate_limit: self.rate_limit,
327            warnings: self.warnings,
328            provider_echoes: self.response_echoes,
329        })
330    }
331
332    /// Close an open `tool_use` block — parses the buffered JSON
333    /// arguments and pushes the finalised `ContentPart::ToolUse` (with
334    /// any accumulated `provider_echoes`) onto `parts`. Returns
335    /// `Err(Error::invalid_request)` if there is no open tool block or
336    /// the buffered arguments fail to parse.
337    fn close_tool_block(&mut self) -> Result<()> {
338        let pending = self.pending_tool.take().ok_or_else(|| {
339            Error::invalid_request("StreamAggregator: ToolUseStop with no open tool block")
340        })?;
341        let input: serde_json::Value = if pending.input_buffer.is_empty() {
342            serde_json::json!({})
343        } else {
344            // Surface the tool name + id and the buffered payload so
345            // operators can see which tool's arguments arrived
346            // malformed. The bare serde_json::Error message is opaque
347            // ("expected value at line 1 column 7"); without context,
348            // a multi-tool agent run leaves the operator hunting
349            // through logs.
350            serde_json::from_str(&pending.input_buffer).map_err(|e| {
351                Error::invalid_request(format!(
352                    "StreamAggregator: ToolUse '{}' (id={}) arguments are not valid JSON: \
353                     {e}; buffered={:?}",
354                    pending.name,
355                    pending.id,
356                    truncate_for_diagnostic(&pending.input_buffer),
357                ))
358            })?
359        };
360        self.parts.push(ContentPart::ToolUse {
361            id: pending.id,
362            name: pending.name,
363            input,
364            provider_echoes: pending.provider_echoes,
365        });
366        Ok(())
367    }
368
369    /// Close the open text buffer, if any, into a `ContentPart::Text`.
370    fn flush_text(&mut self) {
371        if let Some(pending) = self.open_text.take()
372            && !(pending.text.is_empty() && pending.provider_echoes.is_empty())
373        {
374            self.parts.push(ContentPart::Text {
375                text: pending.text,
376                cache_control: None,
377                provider_echoes: pending.provider_echoes,
378            });
379        }
380    }
381
382    /// Close the open thinking buffer, if any, into a
383    /// `ContentPart::Thinking`.
384    fn flush_thinking(&mut self) {
385        if let Some(pending) = self.open_thinking.take()
386            && !(pending.text.is_empty() && pending.provider_echoes.is_empty())
387        {
388            self.parts.push(ContentPart::Thinking {
389                text: pending.text,
390                cache_control: None,
391                provider_echoes: pending.provider_echoes,
392            });
393        }
394    }
395}
396
397/// Cap a malformed-JSON payload before it rides into an error
398/// message. A streaming tool-use arguments buffer can be arbitrarily
399/// large under provider misbehavior; including the full buffer
400/// inflates structured logs and pollutes traces. 256 bytes is enough
401/// for an operator to see the rough shape and cheap to keep.
402/// Wrap a raw `BoxDeltaStream` in a [`ModelStream`] whose
403/// [`ModelStream::completion`] future resolves to the aggregated
404/// [`ModelResponse`] after the consumer drains the stream.
405///
406/// The aggregator runs as a stateful side-effect inside the
407/// returned stream — each delta the consumer reads is also pushed
408/// into a local `StreamAggregator`. When the consumer reads the
409/// terminal `Stop` (or the inner stream ends without one), the
410/// aggregator finalises and the `completion` future resolves. If
411/// the consumer drops the stream early, the aggregator is dropped
412/// without finalising and `completion` resolves to
413/// `Err(Error::Cancelled)` so observability layers gating on
414/// `completion.await.is_ok()` (cost emission) do not fire on
415/// abandoned streams (invariant 12).
416///
417/// Mid-stream `Err` propagates twofold: the consumer sees the
418/// `Err` on the next `next().await`, and `completion` resolves to
419/// the same error so wrapping layers see the failure path on the
420/// post-stream branch.
421pub fn tap_aggregator(inner: BoxDeltaStream<'static>) -> ModelStream {
422    let (tx, rx) = oneshot::channel::<Result<ModelResponse>>();
423    let tap = AggregatorTap {
424        inner,
425        agg: StreamAggregator::new(),
426        completion: Some(tx),
427        terminated: false,
428    };
429    ModelStream {
430        stream: Box::pin(tap),
431        completion: Box::pin(async move {
432            match rx.await {
433                Ok(result) => result,
434                // Sender dropped before sending — the wrapping
435                // stream was abandoned without reaching terminal
436                // Stop. Surface as Cancelled so layers gate on Ok.
437                Err(_) => Err(Error::Cancelled),
438            }
439        }) as BoxFuture<'static, Result<ModelResponse>>,
440    }
441}
442
443/// `Stream<Item = Result<StreamDelta>>` wrapper that taps each
444/// delta into a `StreamAggregator`. On terminal `Stop` (or stream
445/// EOF, or mid-stream `Err`), it sends the aggregator's final
446/// state through a `oneshot::Sender` so the paired
447/// [`ModelStream::completion`] future resolves with the
448/// aggregated response or the propagated error.
449struct AggregatorTap {
450    inner: BoxDeltaStream<'static>,
451    agg: StreamAggregator,
452    completion: Option<oneshot::Sender<Result<ModelResponse>>>,
453    terminated: bool,
454}
455
456impl AggregatorTap {
457    /// Send the aggregator's terminal state through the completion
458    /// channel. Idempotent — subsequent calls are no-ops, so a
459    /// stream that finalises on `Stop` and is then dropped does
460    /// not double-send.
461    fn finalize(&mut self, outcome: Result<ModelResponse>) {
462        if let Some(tx) = self.completion.take() {
463            // Receiver may have been dropped (operator abandoned
464            // `completion` future before consuming the stream); the
465            // send error is not actionable on this side.
466            let _ = tx.send(outcome);
467        }
468    }
469}
470
471impl Stream for AggregatorTap {
472    type Item = Result<StreamDelta>;
473
474    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
475        if self.terminated {
476            return Poll::Ready(None);
477        }
478        match self.inner.poll_next_unpin(cx) {
479            Poll::Pending => Poll::Pending,
480            Poll::Ready(None) => {
481                // Inner stream ended without terminal Stop —
482                // finalise produces `Err` (the aggregator's own
483                // protocol-violation message). `completion`
484                // resolves Err so wrapping layers see failure.
485                let agg = std::mem::take(&mut self.agg);
486                let outcome = agg.finalize();
487                self.finalize(outcome);
488                self.terminated = true;
489                Poll::Ready(None)
490            }
491            Poll::Ready(Some(Err(e))) => {
492                // Mid-stream error — clone the error for the
493                // completion channel (consumer sees the original
494                // Err on this branch).
495                let cloned = clone_error(&e);
496                self.finalize(Err(cloned));
497                self.terminated = true;
498                Poll::Ready(Some(Err(e)))
499            }
500            Poll::Ready(Some(Ok(delta))) => {
501                let is_stop = matches!(delta, StreamDelta::Stop { .. });
502                if let Err(e) = self.agg.push(delta.clone()) {
503                    // Aggregator rejected a protocol violation —
504                    // surface to the consumer (so they see why)
505                    // *and* through completion (so layers see
506                    // the failure path).
507                    let cloned = clone_error(&e);
508                    self.finalize(Err(cloned));
509                    self.terminated = true;
510                    return Poll::Ready(Some(Err(e)));
511                }
512                if is_stop {
513                    // Terminal Stop — finalise immediately so the
514                    // completion future resolves before the
515                    // consumer's next `.next()` call. Any further
516                    // poll returns `None`.
517                    let agg = std::mem::take(&mut self.agg);
518                    let outcome = agg.finalize();
519                    self.finalize(outcome);
520                    self.terminated = true;
521                }
522                Poll::Ready(Some(Ok(delta)))
523            }
524        }
525    }
526}
527
528impl Drop for AggregatorTap {
529    fn drop(&mut self) {
530        // Stream dropped without terminal Stop — completion
531        // resolves Err(Cancelled) so cost-emit layers gating on
532        // Ok branch do not fire on abandoned streams.
533        if self.completion.is_some() {
534            self.finalize(Err(Error::Cancelled));
535        }
536    }
537}
538
539/// Best-effort clone of an `Error` for the `completion` channel.
540/// `Error` is not `Clone` because `serde_json::Error` and the
541/// `Auth` variant carry non-Clone payloads, but the streaming-tap
542/// path needs to forward both the consumer-side `Err` and the
543/// completion-future `Err`. The reconstruction preserves the
544/// variant + message for observability purposes; the source
545/// chain on the consumer side stays intact (the original `Err` is
546/// what the consumer receives).
547fn clone_error(e: &Error) -> Error {
548    use crate::error::ProviderErrorKind;
549    match e {
550        Error::InvalidRequest(msg) => Error::invalid_request(msg.clone()),
551        Error::Config(msg) => Error::config(msg.clone()),
552        Error::Provider {
553            kind,
554            message,
555            retry_after,
556            ..
557        } => {
558            let cloned = match kind {
559                ProviderErrorKind::Network => Error::provider_network(message.clone()),
560                ProviderErrorKind::Tls => Error::provider_tls(message.clone()),
561                ProviderErrorKind::Dns => Error::provider_dns(message.clone()),
562                ProviderErrorKind::Http(status) => Error::provider_http(*status, message.clone()),
563            };
564            match retry_after {
565                Some(after) => cloned.with_retry_after(*after),
566                None => cloned,
567            }
568        }
569        Error::Auth(_) => Error::config("authentication failed (cloned for stream completion)"),
570        Error::Cancelled => Error::Cancelled,
571        Error::DeadlineExceeded => Error::DeadlineExceeded,
572        Error::Interrupted { kind, payload } => Error::Interrupted {
573            kind: kind.clone(),
574            payload: payload.clone(),
575        },
576        Error::Serde(_) => {
577            Error::invalid_request("output serialisation failed (cloned for stream completion)")
578        }
579        Error::UsageLimitExceeded(breach) => Error::UsageLimitExceeded(breach.clone()),
580        Error::ModelRetry { hint, attempt } => Error::ModelRetry {
581            hint: hint.clone(),
582            attempt: *attempt,
583        },
584    }
585}
586
/// Cap a malformed-JSON payload before it rides into an error
/// message.
///
/// A streaming tool-use arguments buffer can be arbitrarily large
/// under provider misbehavior; including the full buffer inflates
/// structured logs and pollutes traces. 256 bytes is enough for an
/// operator to see the rough shape and cheap to keep.
///
/// Returns `s` unchanged when it is at most 256 bytes long.
/// Otherwise returns the longest prefix of at most 256 bytes that
/// ends on a UTF-8 character boundary, followed by `…`.
fn truncate_for_diagnostic(s: &str) -> String {
    const BUDGET: usize = 256;
    if s.len() <= BUDGET {
        return s.to_owned();
    }
    // Slicing mid-character would panic, so back off to the nearest
    // boundary at or below the budget. Index 0 is always a boundary,
    // which makes the fallback unreachable in practice.
    let cut = (0..=BUDGET)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);
    format!("{}…", &s[..cut])
}