Skip to main content

claude_api/messages/
stream.rs

1//! Streaming event types and reconstruction.
2//!
//! Both [`StreamEvent`] and [`ContentDelta`] are forward-compatible: unknown
//! `type` tags deserialize into the `Other` arm, which keeps the payload as a
//! parsed [`serde_json::Value`] so it re-serializes to equivalent JSON.
//! Known tags are strict: a known tag with a malformed body returns a
//! deserialization error rather than silently falling through to `Other`.
7//!
8//! [`EventStream`] is the typed wrapper around the SSE wire format; call
9//! [`EventStream::aggregate`] to reconstruct a [`Message`] from a full
10//! `message_start → ... → message_stop` sequence.
11
12use serde::{Deserialize, Serialize};
13
14use crate::error::ApiErrorPayload;
15use crate::forward_compat::dispatch_known_or_other;
16use crate::messages::content::ContentBlock;
17use crate::messages::response::Message;
18use crate::types::{StopReason, Usage};
19
20#[cfg(feature = "streaming")]
21use crate::error::{Error, Result, StreamError};
22#[cfg(feature = "streaming")]
23use crate::messages::content::KnownBlock;
24
25/// A single event from the Messages streaming endpoint.
26///
27/// Forward-compatible wrapper around [`KnownStreamEvent`]; unknown event types
28/// land in [`StreamEvent::Other`] preserving the raw JSON.
29//
30// Suppress `large_enum_variant`: boxing Known would break pattern-match
31// ergonomics. Worth revisiting in a v1.0 release that's free to break the
32// stream-event API.
33#[allow(clippy::large_enum_variant)]
34#[derive(Debug, Clone, PartialEq)]
35pub enum StreamEvent {
36    /// An event whose `type` is recognized by this SDK version.
37    Known(KnownStreamEvent),
38    /// An event whose `type` is not recognized; the raw JSON is preserved.
39    Other(serde_json::Value),
40}
41
42/// All streaming event types known to this SDK version.
43#[allow(clippy::large_enum_variant)]
44#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
45#[serde(tag = "type", rename_all = "snake_case")]
46#[non_exhaustive]
47pub enum KnownStreamEvent {
48    /// Begins a new streamed message; carries the empty [`Message`] shell
49    /// that subsequent events will fill in.
50    MessageStart {
51        /// The opening message snapshot.
52        message: Message,
53    },
54    /// Begins a new content block within the message.
55    ContentBlockStart {
56        /// Index of the block within the message's content array.
57        index: u32,
58        /// Initial state of the block.
59        content_block: ContentBlock,
60    },
61    /// Incremental update to a content block.
62    ContentBlockDelta {
63        /// Index of the block being updated.
64        index: u32,
65        /// The delta payload.
66        delta: ContentDelta,
67    },
68    /// Marks a content block as complete.
69    ContentBlockStop {
70        /// Index of the block that finished.
71        index: u32,
72    },
73    /// Late-arriving updates to message-level fields, plus final usage.
74    MessageDelta {
75        /// Updated message-level fields.
76        delta: MessageDelta,
77        /// Cumulative usage at the point this delta was emitted.
78        usage: Usage,
79    },
80    /// Final event in a successful stream.
81    MessageStop,
82    /// Keep-alive ping; no payload.
83    Ping,
84    /// Server reported a fatal error mid-stream.
85    Error {
86        /// The error payload.
87        error: ApiErrorPayload,
88    },
89}
90
/// Wire-level `type` tags of the stream events this SDK decodes into
/// [`KnownStreamEvent`]; passed to the known-or-other dispatcher when
/// deserializing a [`StreamEvent`].
const KNOWN_EVENT_TAGS: &[&str] = &[
    "message_start",
    "content_block_start",
    "content_block_delta",
    "content_block_stop",
    "message_delta",
    "message_stop",
    "ping",
    "error",
];
102
103impl Serialize for StreamEvent {
104    fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
105        match self {
106            StreamEvent::Known(k) => k.serialize(s),
107            StreamEvent::Other(v) => v.serialize(s),
108        }
109    }
110}
111
112impl<'de> Deserialize<'de> for StreamEvent {
113    fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
114        let raw = serde_json::Value::deserialize(d)?;
115        dispatch_known_or_other(
116            raw,
117            KNOWN_EVENT_TAGS,
118            StreamEvent::Known,
119            StreamEvent::Other,
120        )
121        .map_err(serde::de::Error::custom)
122    }
123}
124
125impl From<KnownStreamEvent> for StreamEvent {
126    fn from(k: KnownStreamEvent) -> Self {
127        StreamEvent::Known(k)
128    }
129}
130
131impl StreamEvent {
132    /// If this is a known event, return the inner [`KnownStreamEvent`].
133    pub fn known(&self) -> Option<&KnownStreamEvent> {
134        match self {
135            Self::Known(k) => Some(k),
136            Self::Other(_) => None,
137        }
138    }
139
140    /// If this is an unknown event, return the raw JSON.
141    pub fn other(&self) -> Option<&serde_json::Value> {
142        match self {
143            Self::Other(v) => Some(v),
144            Self::Known(_) => None,
145        }
146    }
147
148    /// Wire-level `type` tag for this event regardless of variant.
149    pub fn type_tag(&self) -> Option<&str> {
150        match self {
151            Self::Known(k) => Some(known_event_tag(k)),
152            Self::Other(v) => v.get("type").and_then(serde_json::Value::as_str),
153        }
154    }
155}
156
157fn known_event_tag(k: &KnownStreamEvent) -> &'static str {
158    match k {
159        KnownStreamEvent::MessageStart { .. } => "message_start",
160        KnownStreamEvent::ContentBlockStart { .. } => "content_block_start",
161        KnownStreamEvent::ContentBlockDelta { .. } => "content_block_delta",
162        KnownStreamEvent::ContentBlockStop { .. } => "content_block_stop",
163        KnownStreamEvent::MessageDelta { .. } => "message_delta",
164        KnownStreamEvent::MessageStop => "message_stop",
165        KnownStreamEvent::Ping => "ping",
166        KnownStreamEvent::Error { .. } => "error",
167    }
168}
169
170/// Late-arriving updates to message-level fields, emitted in
171/// [`KnownStreamEvent::MessageDelta`].
172#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
173#[non_exhaustive]
174pub struct MessageDelta {
175    /// Why the model stopped (if known at this point).
176    #[serde(default, skip_serializing_if = "Option::is_none")]
177    pub stop_reason: Option<StopReason>,
178    /// Stop sequence that triggered termination, if any.
179    #[serde(default, skip_serializing_if = "Option::is_none")]
180    pub stop_sequence: Option<String>,
181}
182
183/// One delta update inside a [`KnownStreamEvent::ContentBlockDelta`].
184///
185/// Forward-compatible wrapper around [`KnownContentDelta`].
186#[derive(Debug, Clone, PartialEq)]
187pub enum ContentDelta {
188    /// A delta whose `type` is recognized by this SDK version.
189    Known(KnownContentDelta),
190    /// A delta whose `type` is not recognized; the raw JSON is preserved.
191    Other(serde_json::Value),
192}
193
194/// All content-delta variants known to this SDK version.
195#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
196#[serde(tag = "type", rename_all = "snake_case")]
197#[non_exhaustive]
198pub enum KnownContentDelta {
199    /// Append text to a `text` block.
200    TextDelta {
201        /// Additional text.
202        text: String,
203    },
204    /// Append a partial-JSON fragment to a `tool_use`'s `input`.
205    InputJsonDelta {
206        /// Partial JSON fragment.
207        partial_json: String,
208    },
209    /// Append text to a `thinking` block.
210    ThinkingDelta {
211        /// Additional thinking text.
212        thinking: String,
213    },
214    /// Update the `signature` of a `thinking` block.
215    SignatureDelta {
216        /// Updated signature.
217        signature: String,
218    },
219    /// Append a citation to a `text` block.
220    CitationsDelta {
221        /// The citation payload (typed enum with forward-compat fallback).
222        citation: crate::messages::citation::Citation,
223    },
224}
225
/// Wire-level `type` tags of the content deltas this SDK decodes into
/// [`KnownContentDelta`]; passed to the known-or-other dispatcher when
/// deserializing a [`ContentDelta`].
const KNOWN_DELTA_TAGS: &[&str] = &[
    "text_delta",
    "input_json_delta",
    "thinking_delta",
    "signature_delta",
    "citations_delta",
];
233
234impl Serialize for ContentDelta {
235    fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
236        match self {
237            ContentDelta::Known(k) => k.serialize(s),
238            ContentDelta::Other(v) => v.serialize(s),
239        }
240    }
241}
242
243impl<'de> Deserialize<'de> for ContentDelta {
244    fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
245        let raw = serde_json::Value::deserialize(d)?;
246        dispatch_known_or_other(
247            raw,
248            KNOWN_DELTA_TAGS,
249            ContentDelta::Known,
250            ContentDelta::Other,
251        )
252        .map_err(serde::de::Error::custom)
253    }
254}
255
256impl From<KnownContentDelta> for ContentDelta {
257    fn from(k: KnownContentDelta) -> Self {
258        ContentDelta::Known(k)
259    }
260}
261
262impl ContentDelta {
263    /// If this is a known delta, return the inner [`KnownContentDelta`].
264    pub fn known(&self) -> Option<&KnownContentDelta> {
265        match self {
266            Self::Known(k) => Some(k),
267            Self::Other(_) => None,
268        }
269    }
270
271    /// If this is an unknown delta, return the raw JSON.
272    pub fn other(&self) -> Option<&serde_json::Value> {
273        match self {
274            Self::Other(v) => Some(v),
275            Self::Known(_) => None,
276        }
277    }
278
279    /// Wire-level `type` tag for this delta regardless of variant.
280    pub fn type_tag(&self) -> Option<&str> {
281        match self {
282            Self::Known(k) => Some(match k {
283                KnownContentDelta::TextDelta { .. } => "text_delta",
284                KnownContentDelta::InputJsonDelta { .. } => "input_json_delta",
285                KnownContentDelta::ThinkingDelta { .. } => "thinking_delta",
286                KnownContentDelta::SignatureDelta { .. } => "signature_delta",
287                KnownContentDelta::CitationsDelta { .. } => "citations_delta",
288            }),
289            Self::Other(v) => v.get("type").and_then(serde_json::Value::as_str),
290        }
291    }
292}
293
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ApiErrorKind;
    use pretty_assertions::assert_eq;
    use serde_json::json;

    /// Serialize `event`, check it against the `expected` wire form, then
    /// deserialize that wire form and check it equals `event` again.
    fn round_trip_event(event: &StreamEvent, expected: &serde_json::Value) {
        let wire = serde_json::to_value(event).expect("serialize");
        assert_eq!(&wire, expected, "wire form mismatch");
        let decoded: StreamEvent = serde_json::from_value(wire).expect("deserialize");
        assert_eq!(&decoded, event, "round-trip mismatch");
    }

    /// Same as [`round_trip_event`], for a bare [`ContentDelta`].
    fn round_trip_delta(delta: &ContentDelta, expected: &serde_json::Value) {
        let wire = serde_json::to_value(delta).expect("serialize");
        assert_eq!(&wire, expected, "wire form mismatch");
        let decoded: ContentDelta = serde_json::from_value(wire).expect("deserialize");
        assert_eq!(&decoded, delta, "round-trip mismatch");
    }

    /// Extract the text carried by a `content_block_delta` / `text_delta`
    /// event, panicking with a descriptive message for any other shape.
    fn text_delta_of(event: &StreamEvent) -> &str {
        match event {
            StreamEvent::Known(KnownStreamEvent::ContentBlockDelta { delta, .. }) => match delta {
                ContentDelta::Known(KnownContentDelta::TextDelta { text }) => text.as_str(),
                _ => panic!("expected TextDelta"),
            },
            _ => panic!("expected ContentBlockDelta"),
        }
    }

    // ---- StreamEvent variants ----

    #[test]
    fn message_stop_round_trips() {
        round_trip_event(
            &StreamEvent::Known(KnownStreamEvent::MessageStop),
            &json!({"type": "message_stop"}),
        );
    }

    #[test]
    fn ping_round_trips() {
        round_trip_event(
            &StreamEvent::Known(KnownStreamEvent::Ping),
            &json!({"type": "ping"}),
        );
    }

    #[test]
    fn content_block_start_round_trips() {
        let ev = StreamEvent::Known(KnownStreamEvent::ContentBlockStart {
            index: 0,
            content_block: ContentBlock::text(""),
        });
        round_trip_event(
            &ev,
            &json!({
                "type": "content_block_start",
                "index": 0,
                "content_block": {"type": "text", "text": ""}
            }),
        );
    }

    #[test]
    fn content_block_delta_round_trips() {
        let ev = StreamEvent::Known(KnownStreamEvent::ContentBlockDelta {
            index: 1,
            delta: ContentDelta::Known(KnownContentDelta::TextDelta {
                text: "Hello".into(),
            }),
        });
        round_trip_event(
            &ev,
            &json!({
                "type": "content_block_delta",
                "index": 1,
                "delta": {"type": "text_delta", "text": "Hello"}
            }),
        );
    }

    #[test]
    fn content_block_stop_round_trips() {
        let ev = StreamEvent::Known(KnownStreamEvent::ContentBlockStop { index: 2 });
        round_trip_event(&ev, &json!({"type": "content_block_stop", "index": 2}));
    }

    #[test]
    fn message_delta_round_trips() {
        let ev = StreamEvent::Known(KnownStreamEvent::MessageDelta {
            delta: MessageDelta {
                stop_reason: Some(StopReason::EndTurn),
                stop_sequence: None,
            },
            usage: Usage {
                input_tokens: 5,
                output_tokens: 10,
                ..Usage::default()
            },
        });
        // `stop_sequence: None` must be omitted from the wire form.
        round_trip_event(
            &ev,
            &json!({
                "type": "message_delta",
                "delta": {"stop_reason": "end_turn"},
                "usage": {"input_tokens": 5, "output_tokens": 10}
            }),
        );
    }

    #[test]
    fn error_event_round_trips() {
        let ev = StreamEvent::Known(KnownStreamEvent::Error {
            error: ApiErrorPayload {
                kind: ApiErrorKind::OverloadedError,
                message: "try again".into(),
            },
        });
        round_trip_event(
            &ev,
            &json!({
                "type": "error",
                "error": {"type": "overloaded_error", "message": "try again"}
            }),
        );
    }

    // ---- Forward-compat ----

    #[test]
    fn unknown_event_type_falls_back_to_other_preserving_json() {
        let raw = json!({
            "type": "future_event",
            "payload": {"x": 1, "y": [2, 3]}
        });
        let ev: StreamEvent = serde_json::from_value(raw.clone()).expect("deserialize");
        assert!(ev.other().is_some());
        assert_eq!(ev.type_tag(), Some("future_event"));

        let reserialized = serde_json::to_value(&ev).expect("serialize");
        assert_eq!(reserialized, raw, "Other must round-trip byte-for-byte");
    }

    #[test]
    fn malformed_known_event_is_an_error() {
        // Known type, but `index` should be a u32, not a string.
        let raw = json!({"type": "content_block_stop", "index": "nope"});
        let result: Result<StreamEvent, _> = serde_json::from_value(raw);
        assert!(
            result.is_err(),
            "malformed known event must error, not silently fall through to Other"
        );
    }

    // ---- ContentDelta variants ----

    #[test]
    fn text_delta_round_trips() {
        round_trip_delta(
            &ContentDelta::Known(KnownContentDelta::TextDelta { text: "hi".into() }),
            &json!({"type": "text_delta", "text": "hi"}),
        );
    }

    #[test]
    fn input_json_delta_round_trips() {
        round_trip_delta(
            &ContentDelta::Known(KnownContentDelta::InputJsonDelta {
                partial_json: r#"{"city":"P"#.into(),
            }),
            &json!({"type": "input_json_delta", "partial_json": "{\"city\":\"P"}),
        );
    }

    #[test]
    fn thinking_delta_round_trips() {
        round_trip_delta(
            &ContentDelta::Known(KnownContentDelta::ThinkingDelta {
                thinking: " more thinking".into(),
            }),
            &json!({"type": "thinking_delta", "thinking": " more thinking"}),
        );
    }

    #[test]
    fn signature_delta_round_trips() {
        round_trip_delta(
            &ContentDelta::Known(KnownContentDelta::SignatureDelta {
                signature: "sig123".into(),
            }),
            &json!({"type": "signature_delta", "signature": "sig123"}),
        );
    }

    #[test]
    fn citations_delta_round_trips() {
        use crate::messages::citation::{Citation, KnownCitation};
        round_trip_delta(
            &ContentDelta::Known(KnownContentDelta::CitationsDelta {
                citation: Citation::Known(KnownCitation::CharLocation {
                    document_index: 0,
                    document_title: Some("Doc".into()),
                    cited_text: "hello".into(),
                    start_char_index: 0,
                    end_char_index: 5,
                }),
            }),
            &json!({
                "type": "citations_delta",
                "citation": {
                    "type": "char_location",
                    "document_index": 0,
                    "document_title": "Doc",
                    "cited_text": "hello",
                    "start_char_index": 0,
                    "end_char_index": 5
                }
            }),
        );
    }

    #[test]
    fn unknown_delta_type_falls_back_to_other_preserving_json() {
        let raw = json!({"type": "future_delta", "stuff": [1, 2]});
        let d: ContentDelta = serde_json::from_value(raw.clone()).expect("deserialize");
        assert!(d.other().is_some());
        assert_eq!(d.type_tag(), Some("future_delta"));
        let reserialized = serde_json::to_value(&d).expect("serialize");
        assert_eq!(reserialized, raw);
    }

    #[test]
    fn malformed_known_delta_is_an_error() {
        // Known type, but `text` should be a string, not a number.
        let raw = json!({"type": "text_delta", "text": 42});
        let result: Result<ContentDelta, _> = serde_json::from_value(raw);
        assert!(result.is_err());
    }

    // ---- Golden sequence: a typical stream from start to stop ----

    #[test]
    fn golden_sequence_decodes_end_to_end() {
        let events = vec![
            json!({
                "type": "message_start",
                "message": {
                    "id": "msg_X",
                    "type": "message",
                    "role": "assistant",
                    "content": [],
                    "model": "claude-sonnet-4-6",
                    "usage": {"input_tokens": 10, "output_tokens": 0}
                }
            }),
            json!({
                "type": "content_block_start",
                "index": 0,
                "content_block": {"type": "text", "text": ""}
            }),
            json!({
                "type": "content_block_delta",
                "index": 0,
                "delta": {"type": "text_delta", "text": "Hello"}
            }),
            json!({
                "type": "content_block_delta",
                "index": 0,
                "delta": {"type": "text_delta", "text": " world"}
            }),
            json!({"type": "content_block_stop", "index": 0}),
            json!({
                "type": "message_delta",
                "delta": {"stop_reason": "end_turn"},
                "usage": {"input_tokens": 10, "output_tokens": 2}
            }),
            json!({"type": "message_stop"}),
        ];

        let parsed: Vec<StreamEvent> = events
            .into_iter()
            .map(|v| serde_json::from_value(v).expect("decode"))
            .collect();

        assert_eq!(parsed.len(), 7);
        assert_eq!(parsed[0].type_tag(), Some("message_start"));
        assert_eq!(parsed[6].type_tag(), Some("message_stop"));

        // Both text_delta events must decode to their exact text chunks.
        assert_eq!(text_delta_of(&parsed[2]), "Hello");
        assert_eq!(text_delta_of(&parsed[3]), " world");
    }
}
585
586// ---------------------------------------------------------------------------
587// EventStream + Aggregator (gated on the `streaming` feature)
588// ---------------------------------------------------------------------------
589
/// The typed sequence of [`StreamEvent`]s produced by a streaming Messages
/// request.
///
/// It implements [`futures_util::Stream`], so it can be consumed one event at
/// a time. Alternatively, [`Self::aggregate`] drains the whole stream and
/// rebuilds the complete [`Message`].
///
/// **Callback hooks** registered through the `on_*` builder methods run only
/// while [`Self::aggregate`] is driving the stream; iterating the raw
/// `Stream` never invokes them. They are handy for token-by-token UI updates
/// when you would rather not pattern-match `StreamEvent` by hand.
///
/// A connection failure in the middle of a stream is not retried -- a retry
/// would silently drop content. See [`crate::error::Error::is_retryable`].
#[cfg(feature = "streaming")]
#[cfg_attr(docsrs, doc(cfg(feature = "streaming")))]
pub struct EventStream {
    /// Boxed source of decoded events (or the errors hit while decoding).
    inner: futures_util::stream::BoxStream<'static, Result<StreamEvent>>,
    /// Callbacks handed to the aggregator by [`Self::aggregate`].
    handlers: MessageStreamHandlers,
}
609
#[cfg(feature = "streaming")]
impl EventStream {
    /// Build an `EventStream` over a streaming HTTP response, with no
    /// callbacks attached.
    pub(crate) fn from_response(response: reqwest::Response) -> Self {
        use futures_util::StreamExt;
        let inner = crate::sse::into_typed_stream::<StreamEvent>(response).boxed();
        Self {
            inner,
            handlers: MessageStreamHandlers::default(),
        }
    }

    /// Test helper: build an `EventStream` that replays a fixed list of
    /// `Result<StreamEvent>`s, so callback wiring can be exercised without
    /// an HTTP connection.
    #[cfg(test)]
    fn from_events(events: Vec<Result<StreamEvent>>) -> Self {
        use futures_util::StreamExt;
        let inner = futures_util::stream::iter(events).boxed();
        Self {
            inner,
            handlers: MessageStreamHandlers::default(),
        }
    }

    /// Register a callback for every text delta applied to a `text` content
    /// block. Only the newly arrived chunk is passed in; the full text is
    /// available on the [`Message`] returned by [`Self::aggregate`].
    #[must_use]
    pub fn on_text_delta<F>(mut self, handler: F) -> Self
    where
        F: FnMut(&str) + Send + 'static,
    {
        let boxed: TextDeltaHandler = Box::new(handler);
        self.handlers.text_delta = Some(boxed);
        self
    }

    /// Register a callback for the moment a `tool_use` content block has
    /// finished streaming, i.e. once its `input` JSON has been rebuilt.
    /// The callback is given `(id, name, &input)`. `server_tool_use` blocks
    /// (for example web search invocations) trigger it as well.
    #[must_use]
    pub fn on_tool_use_complete<F>(mut self, handler: F) -> Self
    where
        F: FnMut(&str, &str, &serde_json::Value) + Send + 'static,
    {
        let boxed: ToolUseCompleteHandler = Box::new(handler);
        self.handlers.tool_use_complete = Some(boxed);
        self
    }

    /// Register a callback for every delta applied to a `thinking` block.
    #[must_use]
    pub fn on_thinking_delta<F>(mut self, handler: F) -> Self
    where
        F: FnMut(&str) + Send + 'static,
    {
        let boxed: ThinkingDeltaHandler = Box::new(handler);
        self.handlers.thinking_delta = Some(boxed);
        self
    }

    /// Register a callback for the stream's closing `message_stop` event.
    /// It is given the cumulative [`Usage`] recorded on the message.
    #[must_use]
    pub fn on_message_stop<F>(mut self, handler: F) -> Self
    where
        F: FnMut(&Usage) + Send + 'static,
    {
        let boxed: MessageStopHandler = Box::new(handler);
        self.handlers.message_stop = Some(boxed);
        self
    }

    /// Register a callback for failures: either the server sent an `error`
    /// stream event, or a stream-parse failure surfaced during aggregation.
    /// The callback runs before the error is returned from
    /// [`Self::aggregate`].
    #[must_use]
    pub fn on_error<F>(mut self, handler: F) -> Self
    where
        F: FnMut(&Error) + Send + 'static,
    {
        let boxed: ErrorHandler = Box::new(handler);
        self.handlers.error = Some(boxed);
        self
    }

    /// Consume the stream to its end and return the rebuilt [`Message`].
    ///
    /// The result matches what a non-streamed `messages.create(...)` call
    /// would return. Any callbacks registered through the `on_*` builder
    /// methods are invoked as the matching events are processed.
    ///
    /// # Errors
    ///
    /// Returns the first error yielded by the underlying stream or raised
    /// while applying an event; the `on_error` callback (if any) is invoked
    /// first. Errors from the final reconstruction step are returned without
    /// invoking that callback.
    pub async fn aggregate(self) -> Result<Message> {
        use futures_util::StreamExt;
        let Self {
            mut inner,
            handlers,
        } = self;
        let mut agg = Aggregator::with_handlers(handlers);
        while let Some(item) = inner.next().await {
            // A transport/decode error and an event-handling error take the
            // same exit: notify the error hook, then abort.
            if let Err(err) = item.and_then(|ev| agg.handle(ev)) {
                agg.fire_error(&err);
                return Err(err);
            }
        }
        agg.finalize()
    }
}
723
/// Boxed callback stored for `EventStream::on_text_delta`.
#[cfg(feature = "streaming")]
type TextDeltaHandler = Box<dyn FnMut(&str) + Send>;
/// Boxed callback stored for `EventStream::on_tool_use_complete`; receives
/// `(id, name, &input)`.
#[cfg(feature = "streaming")]
type ToolUseCompleteHandler = Box<dyn FnMut(&str, &str, &serde_json::Value) + Send>;
/// Boxed callback stored for `EventStream::on_thinking_delta`.
#[cfg(feature = "streaming")]
type ThinkingDeltaHandler = Box<dyn FnMut(&str) + Send>;
/// Boxed callback stored for `EventStream::on_message_stop`.
#[cfg(feature = "streaming")]
type MessageStopHandler = Box<dyn FnMut(&Usage) + Send>;
/// Boxed callback stored for `EventStream::on_error`.
#[cfg(feature = "streaming")]
type ErrorHandler = Box<dyn FnMut(&Error) + Send>;
734
/// The set of optional callbacks invoked while [`EventStream::aggregate`]
/// runs. Every slot starts out empty (`Default`).
#[cfg(feature = "streaming")]
#[derive(Default)]
struct MessageStreamHandlers {
    /// Called with each new text chunk.
    text_delta: Option<TextDeltaHandler>,
    /// Called with `(id, name, &input)` when a tool-use block completes.
    tool_use_complete: Option<ToolUseCompleteHandler>,
    /// Called with each new thinking chunk.
    thinking_delta: Option<ThinkingDeltaHandler>,
    /// Called with the message usage on `message_stop`.
    message_stop: Option<MessageStopHandler>,
    /// Called with the error that aborts aggregation.
    error: Option<ErrorHandler>,
}
745
/// Manual `Debug`: boxed closures are not `Debug`, so each slot is rendered
/// as `Some("<fn>")` when a callback is present and `None` otherwise.
#[cfg(feature = "streaming")]
impl std::fmt::Debug for MessageStreamHandlers {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        /// Stand-in shown in place of an installed callback.
        fn marker<T>(slot: &Option<T>) -> Option<&'static str> {
            slot.as_ref().map(|_| "<fn>")
        }

        let mut out = f.debug_struct("MessageStreamHandlers");
        out.field("text_delta", &marker(&self.text_delta));
        out.field("tool_use_complete", &marker(&self.tool_use_complete));
        out.field("thinking_delta", &marker(&self.thinking_delta));
        out.field("message_stop", &marker(&self.message_stop));
        out.field("error", &marker(&self.error));
        out.finish()
    }
}
764
/// Event-by-event consumption: polling simply forwards to the boxed inner
/// stream. Registered callbacks are not involved on this path.
#[cfg(feature = "streaming")]
impl futures_util::Stream for EventStream {
    type Item = Result<StreamEvent>;

    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        let this = self.get_mut();
        this.inner.as_mut().poll_next(cx)
    }
}
776
/// Opaque `Debug`: prints only the type name and marks the fields as elided.
#[cfg(feature = "streaming")]
impl std::fmt::Debug for EventStream {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("EventStream");
        out.finish_non_exhaustive()
    }
}
783
/// Rebuilds a [`Message`] out of a sequence of [`StreamEvent`]s.
///
/// A plain data structure that performs no I/O, so it can be tested in
/// isolation by feeding it events through [`Self::handle`].
#[cfg(feature = "streaming")]
#[derive(Debug, Default)]
pub struct Aggregator {
    /// The message snapshot from `message_start`, updated by later events.
    message: Option<Message>,
    /// Content blocks received so far, in index order.
    blocks: Vec<ContentBlock>,
    /// `partial_json` text collected per block index. It is parsed when the
    /// block's `ContentBlockStop` arrives and written back to the `input` of
    /// the matching `ToolUse` or `ServerToolUse` block.
    tool_input_buffers: std::collections::HashMap<u32, String>,
    /// Callbacks to invoke while events are applied.
    handlers: MessageStreamHandlers,
}
799
800#[cfg(feature = "streaming")]
801impl Aggregator {
802    /// Build an Aggregator pre-populated with stream callback hooks.
803    /// Internal -- callers wire handlers through [`EventStream`].
804    fn with_handlers(handlers: MessageStreamHandlers) -> Self {
805        Self {
806            handlers,
807            ..Self::default()
808        }
809    }
810
811    /// Fire the `error` handler for an aborting error. Internal helper
812    /// for [`EventStream::aggregate`].
813    fn fire_error(&mut self, err: &Error) {
814        if let Some(h) = self.handlers.error.as_mut() {
815            h(err);
816        }
817    }
818
819    /// Apply one event to the aggregator's state.
820    pub fn handle(&mut self, event: StreamEvent) -> Result<()> {
821        match event {
822            StreamEvent::Known(known) => self.handle_known(known),
823            StreamEvent::Other(value) => {
824                tracing::debug!(?value, "claude-api: ignoring unknown stream event");
825                Ok(())
826            }
827        }
828    }
829
830    fn handle_known(&mut self, event: KnownStreamEvent) -> Result<()> {
831        match event {
832            KnownStreamEvent::MessageStart { message } => {
833                self.message = Some(message);
834            }
835            KnownStreamEvent::ContentBlockStart {
836                index,
837                content_block,
838            } => {
839                if index as usize != self.blocks.len() {
840                    return Err(Error::Stream(StreamError::Parse(format!(
841                        "out-of-order content_block_start: index {} but {} blocks already received",
842                        index,
843                        self.blocks.len()
844                    ))));
845                }
846                self.blocks.push(content_block);
847            }
848            KnownStreamEvent::ContentBlockDelta { index, delta } => {
849                self.apply_delta(index, delta);
850            }
851            KnownStreamEvent::ContentBlockStop { index } => {
852                if let Some(buf) = self.tool_input_buffers.remove(&index) {
853                    self.finalize_tool_input(index, &buf);
854                }
855                // Fire on_tool_use_complete for tool_use / server_tool_use blocks.
856                if let Some(handler) = self.handlers.tool_use_complete.as_mut()
857                    && let Some(ContentBlock::Known(
858                        KnownBlock::ToolUse { id, name, input }
859                        | KnownBlock::ServerToolUse { id, name, input },
860                    )) = self.blocks.get(index as usize)
861                {
862                    handler(id, name, input);
863                }
864            }
865            KnownStreamEvent::MessageDelta { delta, usage } => {
866                if let Some(msg) = self.message.as_mut() {
867                    if let Some(sr) = delta.stop_reason {
868                        msg.stop_reason = Some(sr);
869                    }
870                    if let Some(ss) = delta.stop_sequence {
871                        msg.stop_sequence = Some(ss);
872                    }
873                    msg.usage = usage;
874                }
875            }
876            KnownStreamEvent::MessageStop => {
877                if let Some(handler) = self.handlers.message_stop.as_mut()
878                    && let Some(msg) = self.message.as_ref()
879                {
880                    handler(&msg.usage);
881                }
882            }
883            KnownStreamEvent::Ping => {}
884            KnownStreamEvent::Error { error } => {
885                return Err(Error::Stream(StreamError::Server {
886                    kind: error.kind,
887                    message: error.message,
888                }));
889            }
890        }
891        Ok(())
892    }
893
894    fn apply_delta(&mut self, index: u32, delta: ContentDelta) {
895        let Some(block) = self.blocks.get_mut(index as usize) else {
896            tracing::warn!(index, "claude-api: delta for unknown block index, dropping");
897            return;
898        };
899        match delta {
900            ContentDelta::Known(KnownContentDelta::TextDelta { text }) => {
901                if let ContentBlock::Known(KnownBlock::Text { text: existing, .. }) = block {
902                    existing.push_str(&text);
903                }
904                if let Some(handler) = self.handlers.text_delta.as_mut() {
905                    handler(&text);
906                }
907            }
908            ContentDelta::Known(KnownContentDelta::InputJsonDelta { partial_json }) => {
909                self.tool_input_buffers
910                    .entry(index)
911                    .or_default()
912                    .push_str(&partial_json);
913            }
914            ContentDelta::Known(KnownContentDelta::ThinkingDelta { thinking }) => {
915                if let ContentBlock::Known(KnownBlock::Thinking {
916                    thinking: existing, ..
917                }) = block
918                {
919                    existing.push_str(&thinking);
920                }
921                if let Some(handler) = self.handlers.thinking_delta.as_mut() {
922                    handler(&thinking);
923                }
924            }
925            ContentDelta::Known(KnownContentDelta::SignatureDelta { signature }) => {
926                if let ContentBlock::Known(KnownBlock::Thinking { signature: sig, .. }) = block {
927                    *sig = signature;
928                }
929            }
930            ContentDelta::Known(KnownContentDelta::CitationsDelta { citation }) => {
931                if let ContentBlock::Known(KnownBlock::Text { citations, .. }) = block {
932                    citations.get_or_insert_with(Vec::new).push(citation);
933                }
934            }
935            ContentDelta::Other(value) => {
936                tracing::debug!(?value, "claude-api: ignoring unknown content delta");
937            }
938        }
939    }
940
941    fn finalize_tool_input(&mut self, index: u32, buffer: &str) {
942        let Some(block) = self.blocks.get_mut(index as usize) else {
943            return;
944        };
945        let parsed = if buffer.is_empty() {
946            // Nothing to parse; leave whatever the start event provided.
947            return;
948        } else {
949            serde_json::from_str::<serde_json::Value>(buffer).unwrap_or_else(|e| {
950                tracing::warn!(
951                    error = %e,
952                    "claude-api: tool_use input failed to parse; storing raw string"
953                );
954                serde_json::Value::String(buffer.to_owned())
955            })
956        };
957        match block {
958            ContentBlock::Known(
959                KnownBlock::ToolUse { input, .. } | KnownBlock::ServerToolUse { input, .. },
960            ) => {
961                *input = parsed;
962            }
963            _ => {
964                tracing::warn!(
965                    index,
966                    "claude-api: input_json_delta accumulated for non-tool-use block"
967                );
968            }
969        }
970    }
971
972    /// Finalize: combine the accumulated `MessageStart` shell with the
973    /// reconstructed content blocks.
974    pub fn finalize(mut self) -> Result<Message> {
975        let mut message = self.message.take().ok_or_else(|| {
976            Error::Stream(StreamError::Parse(
977                "stream ended without a message_start event".into(),
978            ))
979        })?;
980        message.content = self.blocks;
981        Ok(message)
982    }
983}
984
#[cfg(all(test, feature = "streaming"))]
mod aggregator_tests {
    use super::*;
    use crate::error::{ApiErrorKind, ApiErrorPayload};
    use crate::types::{ModelId, Role};
    use pretty_assertions::assert_eq;
    use serde_json::json;

    /// Wrap a recognized event in the forward-compatible envelope.
    fn known(event: KnownStreamEvent) -> StreamEvent {
        StreamEvent::Known(event)
    }

    /// `message_start` carrying an empty assistant message shell.
    fn message_start_event() -> StreamEvent {
        let message = serde_json::from_value(json!({
            "id": "msg_x",
            "type": "message",
            "role": "assistant",
            "content": [],
            "model": "claude-sonnet-4-6",
            "usage": {"input_tokens": 5, "output_tokens": 0}
        }))
        .unwrap();
        known(KnownStreamEvent::MessageStart { message })
    }

    /// `content_block_start` for block 0.
    fn open_block(content_block: ContentBlock) -> StreamEvent {
        known(KnownStreamEvent::ContentBlockStart {
            index: 0,
            content_block,
        })
    }

    /// `content_block_delta` for block 0 carrying a recognized delta.
    fn block_delta(delta: KnownContentDelta) -> StreamEvent {
        known(KnownStreamEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::Known(delta),
        })
    }

    /// `content_block_stop` for block 0.
    fn close_block() -> StreamEvent {
        known(KnownStreamEvent::ContentBlockStop { index: 0 })
    }

    /// The terminal `message_stop` event.
    fn message_stop() -> StreamEvent {
        known(KnownStreamEvent::MessageStop)
    }

    /// Push every event through the aggregator, panicking on the first error.
    fn feed(agg: &mut Aggregator, events: impl IntoIterator<Item = StreamEvent>) {
        for event in events {
            agg.handle(event).unwrap();
        }
    }

    #[test]
    fn aggregator_reconstructs_text_message() {
        let mut agg = Aggregator::default();
        feed(
            &mut agg,
            [
                message_start_event(),
                open_block(ContentBlock::text("")),
                block_delta(KnownContentDelta::TextDelta {
                    text: "Hello".into(),
                }),
                block_delta(KnownContentDelta::TextDelta {
                    text: " world".into(),
                }),
                close_block(),
                known(KnownStreamEvent::MessageDelta {
                    delta: MessageDelta {
                        stop_reason: Some(StopReason::EndTurn),
                        stop_sequence: None,
                    },
                    usage: Usage {
                        input_tokens: 5,
                        output_tokens: 2,
                        ..Usage::default()
                    },
                }),
                message_stop(),
            ],
        );

        let msg = agg.finalize().unwrap();
        assert_eq!(msg.id, "msg_x");
        assert_eq!(msg.role, Role::Assistant);
        assert_eq!(msg.model, ModelId::SONNET_4_6);
        assert_eq!(msg.stop_reason, Some(StopReason::EndTurn));
        assert_eq!(msg.usage.output_tokens, 2);
        assert_eq!(msg.content.len(), 1);
        let ContentBlock::Known(KnownBlock::Text { text, .. }) = &msg.content[0] else {
            panic!("expected text block");
        };
        assert_eq!(text, "Hello world");
    }

    #[test]
    fn aggregator_reconstructs_tool_use_input_from_partial_json_deltas() {
        let mut agg = Aggregator::default();
        feed(
            &mut agg,
            [
                message_start_event(),
                open_block(ContentBlock::Known(KnownBlock::ToolUse {
                    id: "toolu_1".into(),
                    name: "get_weather".into(),
                    input: json!({}),
                })),
            ],
        );
        // The JSON object is split mid-token across three deltas.
        let fragments = ["{\"city\":", "\"Paris\"", ",\"unit\":\"C\"}"];
        feed(
            &mut agg,
            fragments.into_iter().map(|chunk| {
                block_delta(KnownContentDelta::InputJsonDelta {
                    partial_json: chunk.into(),
                })
            }),
        );
        feed(&mut agg, [close_block(), message_stop()]);

        let msg = agg.finalize().unwrap();
        let ContentBlock::Known(KnownBlock::ToolUse { input, name, .. }) = &msg.content[0] else {
            panic!("expected ToolUse block");
        };
        assert_eq!(name, "get_weather");
        assert_eq!(input, &json!({"city": "Paris", "unit": "C"}));
    }

    #[test]
    fn aggregator_reconstructs_thinking_block() {
        let mut agg = Aggregator::default();
        feed(
            &mut agg,
            [
                message_start_event(),
                open_block(ContentBlock::Known(KnownBlock::Thinking {
                    thinking: String::new(),
                    signature: String::new(),
                })),
                block_delta(KnownContentDelta::ThinkingDelta {
                    thinking: "let me ".into(),
                }),
                block_delta(KnownContentDelta::ThinkingDelta {
                    thinking: "think".into(),
                }),
                block_delta(KnownContentDelta::SignatureDelta {
                    signature: "sig_xyz".into(),
                }),
                close_block(),
                message_stop(),
            ],
        );

        let msg = agg.finalize().unwrap();
        let ContentBlock::Known(KnownBlock::Thinking {
            thinking,
            signature,
        }) = &msg.content[0]
        else {
            panic!("expected Thinking block");
        };
        assert_eq!(thinking, "let me think");
        assert_eq!(signature, "sig_xyz");
    }

    #[test]
    fn aggregator_unknown_event_is_ignored() {
        let mut agg = Aggregator::default();
        feed(
            &mut agg,
            [
                message_start_event(),
                // Unknown event should not error.
                StreamEvent::Other(json!({"type": "future_event"})),
                message_stop(),
            ],
        );
        let msg = agg.finalize().unwrap();
        assert!(msg.content.is_empty());
    }

    #[test]
    fn aggregator_unknown_delta_is_ignored() {
        let mut agg = Aggregator::default();
        feed(
            &mut agg,
            [
                message_start_event(),
                open_block(ContentBlock::text("")),
                known(KnownStreamEvent::ContentBlockDelta {
                    index: 0,
                    delta: ContentDelta::Other(json!({"type": "future_delta"})),
                }),
            ],
        );
        // Aggregator should not have crashed.
    }

    #[test]
    fn aggregator_server_error_event_propagates() {
        let mut agg = Aggregator::default();
        feed(&mut agg, [message_start_event()]);
        let failure = known(KnownStreamEvent::Error {
            error: ApiErrorPayload {
                kind: ApiErrorKind::OverloadedError,
                message: "boom".into(),
            },
        });
        match agg.handle(failure).unwrap_err() {
            Error::Stream(StreamError::Server { kind, message }) => {
                assert_eq!(kind, ApiErrorKind::OverloadedError);
                assert_eq!(message, "boom");
            }
            other => panic!("expected Stream::Server, got {other:?}"),
        }
    }

    #[test]
    fn aggregator_out_of_order_block_start_errors() {
        let mut agg = Aggregator::default();
        feed(&mut agg, [message_start_event()]);
        // Skip index 0; start with index 1.
        let skipped_ahead = known(KnownStreamEvent::ContentBlockStart {
            index: 1,
            content_block: ContentBlock::text(""),
        });
        let err = agg.handle(skipped_ahead).unwrap_err();
        assert!(matches!(err, Error::Stream(StreamError::Parse(_))));
    }

    #[test]
    fn aggregator_finalize_without_message_start_errors() {
        let err = Aggregator::default().finalize().unwrap_err();
        assert!(matches!(err, Error::Stream(StreamError::Parse(_))));
    }
}
1228
#[cfg(all(test, feature = "streaming"))]
mod stream_callback_tests {
    use super::*;
    use crate::error::{ApiErrorKind, ApiErrorPayload};
    use pretty_assertions::assert_eq;
    use serde_json::json;
    use std::sync::{Arc, Mutex};

    /// Wrap a recognized event in the forward-compatible envelope.
    fn known(event: KnownStreamEvent) -> StreamEvent {
        StreamEvent::Known(event)
    }

    /// `message_start` carrying an empty assistant message shell.
    fn message_start_event() -> StreamEvent {
        let message = serde_json::from_value(json!({
            "id": "msg_x",
            "type": "message",
            "role": "assistant",
            "content": [],
            "model": "claude-sonnet-4-6",
            "usage": {"input_tokens": 5, "output_tokens": 0}
        }))
        .unwrap();
        known(KnownStreamEvent::MessageStart { message })
    }

    /// `content_block_start` for block 0.
    fn open_block(content_block: ContentBlock) -> StreamEvent {
        known(KnownStreamEvent::ContentBlockStart {
            index: 0,
            content_block,
        })
    }

    /// `content_block_delta` for block 0 carrying a recognized delta.
    fn block_delta(delta: KnownContentDelta) -> StreamEvent {
        known(KnownStreamEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::Known(delta),
        })
    }

    /// `content_block_stop` for block 0.
    fn close_block() -> StreamEvent {
        known(KnownStreamEvent::ContentBlockStop { index: 0 })
    }

    /// The terminal `message_stop` event.
    fn message_stop() -> StreamEvent {
        known(KnownStreamEvent::MessageStop)
    }

    /// Turn a list of events into an all-`Ok` script for `EventStream::from_events`.
    fn scripted(events: Vec<StreamEvent>) -> Vec<Result<StreamEvent>> {
        events.into_iter().map(Ok).collect()
    }

    #[tokio::test]
    async fn on_text_delta_fires_for_each_text_chunk() {
        let captured: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&captured);
        let events = scripted(vec![
            message_start_event(),
            open_block(ContentBlock::text("")),
            block_delta(KnownContentDelta::TextDelta {
                text: "Hello".into(),
            }),
            block_delta(KnownContentDelta::TextDelta {
                text: " world".into(),
            }),
            close_block(),
            message_stop(),
        ]);

        EventStream::from_events(events)
            .on_text_delta(move |chunk| {
                sink.lock().unwrap().push(chunk.to_string());
            })
            .aggregate()
            .await
            .unwrap();

        assert_eq!(*captured.lock().unwrap(), vec!["Hello", " world"]);
    }

    #[tokio::test]
    async fn on_tool_use_complete_fires_with_parsed_input() {
        let captured: Arc<Mutex<Vec<(String, String, serde_json::Value)>>> =
            Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&captured);
        let events = scripted(vec![
            message_start_event(),
            open_block(ContentBlock::Known(KnownBlock::ToolUse {
                id: "toolu_1".into(),
                name: "get_weather".into(),
                input: json!({}),
            })),
            block_delta(KnownContentDelta::InputJsonDelta {
                partial_json: "{\"city\":\"Paris\"}".into(),
            }),
            close_block(),
            message_stop(),
        ]);

        EventStream::from_events(events)
            .on_tool_use_complete(move |id, name, input| {
                let record = (id.to_string(), name.to_string(), input.clone());
                sink.lock().unwrap().push(record);
            })
            .aggregate()
            .await
            .unwrap();

        let calls = captured.lock().unwrap();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].0, "toolu_1");
        assert_eq!(calls[0].1, "get_weather");
        assert_eq!(calls[0].2, json!({"city": "Paris"}));
    }

    #[tokio::test]
    async fn on_tool_use_complete_fires_for_server_tool_use_blocks() {
        let captured: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&captured);
        let events = scripted(vec![
            message_start_event(),
            open_block(ContentBlock::Known(KnownBlock::ServerToolUse {
                id: "srvu_1".into(),
                name: "web_search".into(),
                input: json!({}),
            })),
            block_delta(KnownContentDelta::InputJsonDelta {
                partial_json: "{\"q\":\"rust\"}".into(),
            }),
            close_block(),
            message_stop(),
        ]);

        EventStream::from_events(events)
            .on_tool_use_complete(move |id, _, _| {
                sink.lock().unwrap().push(id.to_string());
            })
            .aggregate()
            .await
            .unwrap();

        assert_eq!(*captured.lock().unwrap(), vec!["srvu_1"]);
    }

    #[tokio::test]
    async fn on_thinking_delta_fires_for_each_thinking_chunk() {
        let captured: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&captured);
        let events = scripted(vec![
            message_start_event(),
            open_block(ContentBlock::Known(KnownBlock::Thinking {
                thinking: String::new(),
                signature: String::new(),
            })),
            block_delta(KnownContentDelta::ThinkingDelta {
                thinking: "let me ".into(),
            }),
            block_delta(KnownContentDelta::ThinkingDelta {
                thinking: "think".into(),
            }),
            close_block(),
            message_stop(),
        ]);

        EventStream::from_events(events)
            .on_thinking_delta(move |chunk| {
                sink.lock().unwrap().push(chunk.to_string());
            })
            .aggregate()
            .await
            .unwrap();

        assert_eq!(*captured.lock().unwrap(), vec!["let me ", "think"]);
    }

    #[tokio::test]
    async fn on_message_stop_fires_once_with_usage() {
        let captured: Arc<Mutex<Vec<Usage>>> = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&captured);
        let events = scripted(vec![
            message_start_event(),
            open_block(ContentBlock::text("")),
            block_delta(KnownContentDelta::TextDelta { text: "hi".into() }),
            close_block(),
            known(KnownStreamEvent::MessageDelta {
                delta: MessageDelta {
                    stop_reason: Some(StopReason::EndTurn),
                    stop_sequence: None,
                },
                usage: Usage {
                    input_tokens: 5,
                    output_tokens: 7,
                    ..Usage::default()
                },
            }),
            message_stop(),
        ]);

        EventStream::from_events(events)
            .on_message_stop(move |usage| {
                sink.lock().unwrap().push(usage.clone());
            })
            .aggregate()
            .await
            .unwrap();

        let seen = captured.lock().unwrap();
        assert_eq!(seen.len(), 1);
        assert_eq!(seen[0].input_tokens, 5);
        assert_eq!(seen[0].output_tokens, 7);
    }

    #[tokio::test]
    async fn on_error_fires_before_propagating_server_error() {
        let count = Arc::new(Mutex::new(0u32));
        let sink = Arc::clone(&count);
        let events = scripted(vec![
            message_start_event(),
            known(KnownStreamEvent::Error {
                error: ApiErrorPayload {
                    kind: ApiErrorKind::OverloadedError,
                    message: "boom".into(),
                },
            }),
        ]);

        let outcome = EventStream::from_events(events)
            .on_error(move |_| {
                *sink.lock().unwrap() += 1;
            })
            .aggregate()
            .await;
        let err = outcome.unwrap_err();
        assert!(matches!(
            err,
            Error::Stream(StreamError::Server {
                kind: ApiErrorKind::OverloadedError,
                ..
            })
        ));
        assert_eq!(
            *count.lock().unwrap(),
            1,
            "handler should fire exactly once"
        );
    }

    #[tokio::test]
    async fn on_error_fires_for_transport_error() {
        let count = Arc::new(Mutex::new(0u32));
        let sink = Arc::clone(&count);
        // A healthy start followed by a failure surfaced by the transport.
        let mut events = scripted(vec![message_start_event()]);
        events.push(Err(Error::Stream(StreamError::Parse(
            "decode failed".into(),
        ))));

        let outcome = EventStream::from_events(events)
            .on_error(move |_| {
                *sink.lock().unwrap() += 1;
            })
            .aggregate()
            .await;
        let err = outcome.unwrap_err();
        assert!(matches!(err, Error::Stream(StreamError::Parse(_))));
        assert_eq!(*count.lock().unwrap(), 1);
    }

    #[tokio::test]
    async fn raw_stream_iteration_does_not_fire_callbacks() {
        use futures_util::StreamExt;
        let count = Arc::new(Mutex::new(0u32));
        let sink = Arc::clone(&count);
        let events = scripted(vec![
            message_start_event(),
            open_block(ContentBlock::text("")),
            block_delta(KnownContentDelta::TextDelta { text: "hi".into() }),
            close_block(),
            message_stop(),
        ]);

        let mut stream = EventStream::from_events(events).on_text_delta(move |_| {
            *sink.lock().unwrap() += 1;
        });
        while let Some(_ev) = stream.next().await {}
        // Callbacks only fire during aggregate(), not raw .next().
        assert_eq!(*count.lock().unwrap(), 0);
    }
}