Skip to main content

claude_api/messages/
stream.rs

1//! Streaming event types and reconstruction.
2//!
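//! Both [`StreamEvent`] and [`ContentDelta`] are forward-compatible: unknown
//! `type` tags deserialize into the `Other` arm with the raw JSON preserved as
//! a `serde_json::Value` (it re-serializes to a structurally equal value, but
//! whitespace, number formatting and key order are not guaranteed to match the
//! original wire bytes). Strict-on-known semantics: a known tag with a malformed
//! body returns a deserialization error rather than silently falling through.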
7//!
8//! [`EventStream`] is the typed wrapper around the SSE wire format. Two
9//! consumption patterns:
10//!
11//! - **Event-by-event** via [`futures_util::StreamExt::next`] -- print text
12//!   deltas as they arrive (see `examples/streaming.rs`).
13//! - **Aggregate** via [`EventStream::aggregate`] -- wait for the full
14//!   `message_start ... message_stop` sequence and get back a [`Message`],
15//!   identical to what [`Messages::create`](crate::messages::Messages::create)
16//!   returns.
17//!
18//! # Print text deltas as they arrive
19//!
20//! ```no_run
21//! use claude_api::{Client, messages::{CreateMessageRequest, KnownContentDelta,
22//!     KnownStreamEvent, ContentDelta, StreamEvent}, types::ModelId};
23//! use futures_util::StreamExt;
24//! # async fn run() -> Result<(), claude_api::Error> {
25//! let client = Client::new(std::env::var("ANTHROPIC_API_KEY").unwrap());
26//! let request = CreateMessageRequest::builder()
27//!     .model(ModelId::SONNET_4_6)
28//!     .max_tokens(256)
29//!     .user("Tell me a fact.")
30//!     .build()?;
31//! let mut stream = client.messages().create_stream(request).await?;
32//! while let Some(event) = stream.next().await {
33//!     if let StreamEvent::Known(KnownStreamEvent::ContentBlockDelta {
34//!         delta: ContentDelta::Known(KnownContentDelta::TextDelta { text }),
35//!         ..
36//!     }) = event?
37//!     {
38//!         print!("{text}");
39//!     }
40//! }
41//! # Ok(())
42//! # }
43//! ```
44//!
//! Within this module, [`EventStream`] and `Aggregator` are gated on the
//! `streaming` feature; the event and delta types are not.
46
47use serde::{Deserialize, Serialize};
48
49use crate::error::ApiErrorPayload;
50use crate::forward_compat::dispatch_known_or_other;
51use crate::messages::content::ContentBlock;
52use crate::messages::response::Message;
53use crate::types::{StopReason, Usage};
54
55#[cfg(feature = "streaming")]
56use crate::error::{Error, Result, StreamError};
57#[cfg(feature = "streaming")]
58use crate::messages::content::KnownBlock;
59
60/// A single event from the Messages streaming endpoint.
61///
62/// Forward-compatible wrapper around [`KnownStreamEvent`]; unknown event types
63/// land in [`StreamEvent::Other`] preserving the raw JSON.
64//
65// Suppress `large_enum_variant`: boxing Known would break pattern-match
66// ergonomics. Worth revisiting in a v1.0 release that's free to break the
67// stream-event API.
68#[allow(clippy::large_enum_variant)]
69#[derive(Debug, Clone, PartialEq)]
70pub enum StreamEvent {
71    /// An event whose `type` is recognized by this SDK version.
72    Known(KnownStreamEvent),
73    /// An event whose `type` is not recognized; the raw JSON is preserved.
74    Other(serde_json::Value),
75}
76
77/// All streaming event types known to this SDK version.
78#[allow(clippy::large_enum_variant)]
79#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
80#[serde(tag = "type", rename_all = "snake_case")]
81#[non_exhaustive]
82pub enum KnownStreamEvent {
83    /// Begins a new streamed message; carries the empty [`Message`] shell
84    /// that subsequent events will fill in.
85    MessageStart {
86        /// The opening message snapshot.
87        message: Message,
88    },
89    /// Begins a new content block within the message.
90    ContentBlockStart {
91        /// Index of the block within the message's content array.
92        index: u32,
93        /// Initial state of the block.
94        content_block: ContentBlock,
95    },
96    /// Incremental update to a content block.
97    ContentBlockDelta {
98        /// Index of the block being updated.
99        index: u32,
100        /// The delta payload.
101        delta: ContentDelta,
102    },
103    /// Marks a content block as complete.
104    ContentBlockStop {
105        /// Index of the block that finished.
106        index: u32,
107    },
108    /// Late-arriving updates to message-level fields, plus final usage.
109    MessageDelta {
110        /// Updated message-level fields.
111        delta: MessageDelta,
112        /// Cumulative usage at the point this delta was emitted.
113        usage: Usage,
114    },
115    /// Final event in a successful stream.
116    MessageStop,
117    /// Keep-alive ping; no payload.
118    Ping,
119    /// Server reported a fatal error mid-stream.
120    Error {
121        /// The error payload.
122        error: ApiErrorPayload,
123    },
124}
125
/// Wire-level `type` tags that are decoded strictly as [`KnownStreamEvent`].
/// Any other tag is routed to [`StreamEvent::Other`].
///
/// This list must stay in sync with the variants of [`KnownStreamEvent`] and
/// with `known_event_tag`.
const KNOWN_EVENT_TAGS: &[&str] = &[
    "message_start",
    "content_block_start",
    "content_block_delta",
    "content_block_stop",
    "message_delta",
    "message_stop",
    "ping",
    "error",
];
137
138impl Serialize for StreamEvent {
139    fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
140        match self {
141            StreamEvent::Known(k) => k.serialize(s),
142            StreamEvent::Other(v) => v.serialize(s),
143        }
144    }
145}
146
147impl<'de> Deserialize<'de> for StreamEvent {
148    fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
149        let raw = serde_json::Value::deserialize(d)?;
150        dispatch_known_or_other(
151            raw,
152            KNOWN_EVENT_TAGS,
153            StreamEvent::Known,
154            StreamEvent::Other,
155        )
156        .map_err(serde::de::Error::custom)
157    }
158}
159
160impl From<KnownStreamEvent> for StreamEvent {
161    fn from(k: KnownStreamEvent) -> Self {
162        StreamEvent::Known(k)
163    }
164}
165
166impl StreamEvent {
167    /// If this is a known event, return the inner [`KnownStreamEvent`].
168    pub fn known(&self) -> Option<&KnownStreamEvent> {
169        match self {
170            Self::Known(k) => Some(k),
171            Self::Other(_) => None,
172        }
173    }
174
175    /// If this is an unknown event, return the raw JSON.
176    pub fn other(&self) -> Option<&serde_json::Value> {
177        match self {
178            Self::Other(v) => Some(v),
179            Self::Known(_) => None,
180        }
181    }
182
183    /// Wire-level `type` tag for this event regardless of variant.
184    pub fn type_tag(&self) -> Option<&str> {
185        match self {
186            Self::Known(k) => Some(known_event_tag(k)),
187            Self::Other(v) => v.get("type").and_then(serde_json::Value::as_str),
188        }
189    }
190}
191
192fn known_event_tag(k: &KnownStreamEvent) -> &'static str {
193    match k {
194        KnownStreamEvent::MessageStart { .. } => "message_start",
195        KnownStreamEvent::ContentBlockStart { .. } => "content_block_start",
196        KnownStreamEvent::ContentBlockDelta { .. } => "content_block_delta",
197        KnownStreamEvent::ContentBlockStop { .. } => "content_block_stop",
198        KnownStreamEvent::MessageDelta { .. } => "message_delta",
199        KnownStreamEvent::MessageStop => "message_stop",
200        KnownStreamEvent::Ping => "ping",
201        KnownStreamEvent::Error { .. } => "error",
202    }
203}
204
205/// Late-arriving updates to message-level fields, emitted in
206/// [`KnownStreamEvent::MessageDelta`].
207#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
208#[non_exhaustive]
209pub struct MessageDelta {
210    /// Why the model stopped (if known at this point).
211    #[serde(default, skip_serializing_if = "Option::is_none")]
212    pub stop_reason: Option<StopReason>,
213    /// Stop sequence that triggered termination, if any.
214    #[serde(default, skip_serializing_if = "Option::is_none")]
215    pub stop_sequence: Option<String>,
216}
217
218/// One delta update inside a [`KnownStreamEvent::ContentBlockDelta`].
219///
220/// Forward-compatible wrapper around [`KnownContentDelta`].
221#[derive(Debug, Clone, PartialEq)]
222pub enum ContentDelta {
223    /// A delta whose `type` is recognized by this SDK version.
224    Known(KnownContentDelta),
225    /// A delta whose `type` is not recognized; the raw JSON is preserved.
226    Other(serde_json::Value),
227}
228
229/// All content-delta variants known to this SDK version.
230#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
231#[serde(tag = "type", rename_all = "snake_case")]
232#[non_exhaustive]
233pub enum KnownContentDelta {
234    /// Append text to a `text` block.
235    TextDelta {
236        /// Additional text.
237        text: String,
238    },
239    /// Append a partial-JSON fragment to a `tool_use`'s `input`.
240    InputJsonDelta {
241        /// Partial JSON fragment.
242        partial_json: String,
243    },
244    /// Append text to a `thinking` block.
245    ThinkingDelta {
246        /// Additional thinking text.
247        thinking: String,
248    },
249    /// Update the `signature` of a `thinking` block.
250    SignatureDelta {
251        /// Updated signature.
252        signature: String,
253    },
254    /// Append a citation to a `text` block.
255    CitationsDelta {
256        /// The citation payload (typed enum with forward-compat fallback).
257        citation: crate::messages::citation::Citation,
258    },
259}
260
/// Wire-level `type` tags that are decoded strictly as [`KnownContentDelta`].
/// Any other tag is routed to [`ContentDelta::Other`].
///
/// This list must stay in sync with the variants of [`KnownContentDelta`] and
/// with `ContentDelta::type_tag`.
const KNOWN_DELTA_TAGS: &[&str] = &[
    "text_delta",
    "input_json_delta",
    "thinking_delta",
    "signature_delta",
    "citations_delta",
];
268
269impl Serialize for ContentDelta {
270    fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
271        match self {
272            ContentDelta::Known(k) => k.serialize(s),
273            ContentDelta::Other(v) => v.serialize(s),
274        }
275    }
276}
277
278impl<'de> Deserialize<'de> for ContentDelta {
279    fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
280        let raw = serde_json::Value::deserialize(d)?;
281        dispatch_known_or_other(
282            raw,
283            KNOWN_DELTA_TAGS,
284            ContentDelta::Known,
285            ContentDelta::Other,
286        )
287        .map_err(serde::de::Error::custom)
288    }
289}
290
291impl From<KnownContentDelta> for ContentDelta {
292    fn from(k: KnownContentDelta) -> Self {
293        ContentDelta::Known(k)
294    }
295}
296
297impl ContentDelta {
298    /// If this is a known delta, return the inner [`KnownContentDelta`].
299    pub fn known(&self) -> Option<&KnownContentDelta> {
300        match self {
301            Self::Known(k) => Some(k),
302            Self::Other(_) => None,
303        }
304    }
305
306    /// If this is an unknown delta, return the raw JSON.
307    pub fn other(&self) -> Option<&serde_json::Value> {
308        match self {
309            Self::Other(v) => Some(v),
310            Self::Known(_) => None,
311        }
312    }
313
314    /// Wire-level `type` tag for this delta regardless of variant.
315    pub fn type_tag(&self) -> Option<&str> {
316        match self {
317            Self::Known(k) => Some(match k {
318                KnownContentDelta::TextDelta { .. } => "text_delta",
319                KnownContentDelta::InputJsonDelta { .. } => "input_json_delta",
320                KnownContentDelta::ThinkingDelta { .. } => "thinking_delta",
321                KnownContentDelta::SignatureDelta { .. } => "signature_delta",
322                KnownContentDelta::CitationsDelta { .. } => "citations_delta",
323            }),
324            Self::Other(v) => v.get("type").and_then(serde_json::Value::as_str),
325        }
326    }
327}
328
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ApiErrorKind;
    use pretty_assertions::assert_eq;
    use serde_json::json;

    fn round_trip_event(event: &StreamEvent, expected: &serde_json::Value) {
        let v = serde_json::to_value(event).expect("serialize");
        assert_eq!(&v, expected, "wire form mismatch");
        let parsed: StreamEvent = serde_json::from_value(v).expect("deserialize");
        assert_eq!(&parsed, event, "round-trip mismatch");
    }

    fn round_trip_delta(delta: &ContentDelta, expected: &serde_json::Value) {
        let v = serde_json::to_value(delta).expect("serialize");
        assert_eq!(&v, expected, "wire form mismatch");
        let parsed: ContentDelta = serde_json::from_value(v).expect("deserialize");
        assert_eq!(&parsed, delta, "round-trip mismatch");
    }

    /// A citation payload whose round trip is covered by
    /// `citations_delta_round_trips`.
    fn sample_citation_json() -> serde_json::Value {
        json!({
            "type": "char_location",
            "document_index": 0,
            "document_title": "Doc",
            "cited_text": "hello",
            "start_char_index": 0,
            "end_char_index": 5
        })
    }

    /// A minimal `message_start` payload shared by the golden and tag tests.
    fn sample_message_start_json() -> serde_json::Value {
        json!({
            "type": "message_start",
            "message": {
                "id": "msg_X",
                "type": "message",
                "role": "assistant",
                "content": [],
                "model": "claude-sonnet-4-6",
                "usage": {"input_tokens": 10, "output_tokens": 0}
            }
        })
    }

    // ---- StreamEvent variants ----

    #[test]
    fn message_stop_round_trips() {
        round_trip_event(
            &StreamEvent::Known(KnownStreamEvent::MessageStop),
            &json!({"type": "message_stop"}),
        );
    }

    #[test]
    fn ping_round_trips() {
        round_trip_event(
            &StreamEvent::Known(KnownStreamEvent::Ping),
            &json!({"type": "ping"}),
        );
    }

    #[test]
    fn content_block_start_round_trips() {
        let ev = StreamEvent::Known(KnownStreamEvent::ContentBlockStart {
            index: 0,
            content_block: ContentBlock::text(""),
        });
        round_trip_event(
            &ev,
            &json!({
                "type": "content_block_start",
                "index": 0,
                "content_block": {"type": "text", "text": ""}
            }),
        );
    }

    #[test]
    fn content_block_delta_round_trips() {
        let ev = StreamEvent::Known(KnownStreamEvent::ContentBlockDelta {
            index: 1,
            delta: ContentDelta::Known(KnownContentDelta::TextDelta {
                text: "Hello".into(),
            }),
        });
        round_trip_event(
            &ev,
            &json!({
                "type": "content_block_delta",
                "index": 1,
                "delta": {"type": "text_delta", "text": "Hello"}
            }),
        );
    }

    #[test]
    fn content_block_stop_round_trips() {
        let ev = StreamEvent::Known(KnownStreamEvent::ContentBlockStop { index: 2 });
        round_trip_event(&ev, &json!({"type": "content_block_stop", "index": 2}));
    }

    #[test]
    fn message_delta_round_trips() {
        let ev = StreamEvent::Known(KnownStreamEvent::MessageDelta {
            delta: MessageDelta {
                stop_reason: Some(StopReason::EndTurn),
                stop_sequence: None,
            },
            usage: Usage {
                input_tokens: 5,
                output_tokens: 10,
                ..Usage::default()
            },
        });
        round_trip_event(
            &ev,
            &json!({
                "type": "message_delta",
                "delta": {"stop_reason": "end_turn"},
                "usage": {"input_tokens": 5, "output_tokens": 10}
            }),
        );
    }

    #[test]
    fn error_event_round_trips() {
        let ev = StreamEvent::Known(KnownStreamEvent::Error {
            error: ApiErrorPayload {
                kind: ApiErrorKind::OverloadedError,
                message: "try again".into(),
            },
        });
        round_trip_event(
            &ev,
            &json!({
                "type": "error",
                "error": {"type": "overloaded_error", "message": "try again"}
            }),
        );
    }

    // ---- Forward-compat ----

    #[test]
    fn unknown_event_type_falls_back_to_other_preserving_json() {
        let raw = json!({
            "type": "future_event",
            "payload": {"x": 1, "y": [2, 3]}
        });
        let ev: StreamEvent = serde_json::from_value(raw.clone()).expect("deserialize");
        assert!(ev.other().is_some());
        assert_eq!(ev.type_tag(), Some("future_event"));

        let reserialized = serde_json::to_value(&ev).expect("serialize");
        assert_eq!(reserialized, raw, "Other must round-trip byte-for-byte");
    }

    #[test]
    fn malformed_known_event_is_an_error() {
        // Known type, but `index` should be a u32, not a string.
        let raw = json!({"type": "content_block_stop", "index": "nope"});
        let result: Result<StreamEvent, _> = serde_json::from_value(raw);
        assert!(
            result.is_err(),
            "malformed known event must error, not silently fall through to Other"
        );
    }

    // ---- Tag tables stay in sync with the enums ----

    /// Every tag in `KNOWN_EVENT_TAGS` decodes as `Known`. Its `type_tag()`
    /// must equal the wire tag, which catches drift between the tag list,
    /// `known_event_tag`, and serde's `rename_all`.
    #[test]
    fn known_event_tags_agree_with_type_tag_and_wire() {
        let samples = [
            sample_message_start_json(),
            json!({
                "type": "content_block_start",
                "index": 0,
                "content_block": {"type": "text", "text": ""}
            }),
            json!({
                "type": "content_block_delta",
                "index": 0,
                "delta": {"type": "text_delta", "text": "x"}
            }),
            json!({"type": "content_block_stop", "index": 0}),
            json!({
                "type": "message_delta",
                "delta": {},
                "usage": {"input_tokens": 1, "output_tokens": 1}
            }),
            json!({"type": "message_stop"}),
            json!({"type": "ping"}),
            json!({
                "type": "error",
                "error": {"type": "overloaded_error", "message": "x"}
            }),
        ];

        let mut covered = std::collections::BTreeSet::new();
        for raw in samples {
            let wire_tag = raw["type"].as_str().expect("sample has a type").to_owned();
            let ev: StreamEvent = serde_json::from_value(raw).expect("decode");
            assert!(ev.known().is_some(), "{wire_tag} should decode as Known");
            assert_eq!(ev.type_tag(), Some(wire_tag.as_str()));
            assert!(KNOWN_EVENT_TAGS.contains(&wire_tag.as_str()));
            covered.insert(wire_tag);
        }
        for tag in KNOWN_EVENT_TAGS {
            assert!(covered.contains(*tag), "no sample exercises known tag {tag}");
        }
    }

    /// Delta-side counterpart of `known_event_tags_agree_with_type_tag_and_wire`.
    #[test]
    fn known_delta_tags_agree_with_type_tag_and_wire() {
        let samples = [
            json!({"type": "text_delta", "text": "a"}),
            json!({"type": "input_json_delta", "partial_json": "{"}),
            json!({"type": "thinking_delta", "thinking": "t"}),
            json!({"type": "signature_delta", "signature": "s"}),
            json!({"type": "citations_delta", "citation": sample_citation_json()}),
        ];

        let mut covered = std::collections::BTreeSet::new();
        for raw in samples {
            let wire_tag = raw["type"].as_str().expect("sample has a type").to_owned();
            let d: ContentDelta = serde_json::from_value(raw).expect("decode");
            assert!(d.known().is_some(), "{wire_tag} should decode as Known");
            assert_eq!(d.type_tag(), Some(wire_tag.as_str()));
            assert!(KNOWN_DELTA_TAGS.contains(&wire_tag.as_str()));
            covered.insert(wire_tag);
        }
        for tag in KNOWN_DELTA_TAGS {
            assert!(covered.contains(*tag), "no sample exercises known tag {tag}");
        }
    }

    // ---- ContentDelta variants ----

    #[test]
    fn text_delta_round_trips() {
        round_trip_delta(
            &ContentDelta::Known(KnownContentDelta::TextDelta { text: "hi".into() }),
            &json!({"type": "text_delta", "text": "hi"}),
        );
    }

    #[test]
    fn input_json_delta_round_trips() {
        round_trip_delta(
            &ContentDelta::Known(KnownContentDelta::InputJsonDelta {
                partial_json: r#"{"city":"P"#.into(),
            }),
            &json!({"type": "input_json_delta", "partial_json": "{\"city\":\"P"}),
        );
    }

    #[test]
    fn thinking_delta_round_trips() {
        round_trip_delta(
            &ContentDelta::Known(KnownContentDelta::ThinkingDelta {
                thinking: " more thinking".into(),
            }),
            &json!({"type": "thinking_delta", "thinking": " more thinking"}),
        );
    }

    #[test]
    fn signature_delta_round_trips() {
        round_trip_delta(
            &ContentDelta::Known(KnownContentDelta::SignatureDelta {
                signature: "sig123".into(),
            }),
            &json!({"type": "signature_delta", "signature": "sig123"}),
        );
    }

    #[test]
    fn citations_delta_round_trips() {
        use crate::messages::citation::{Citation, KnownCitation};
        round_trip_delta(
            &ContentDelta::Known(KnownContentDelta::CitationsDelta {
                citation: Citation::Known(KnownCitation::CharLocation {
                    document_index: 0,
                    document_title: Some("Doc".into()),
                    cited_text: "hello".into(),
                    start_char_index: 0,
                    end_char_index: 5,
                }),
            }),
            &json!({
                "type": "citations_delta",
                "citation": sample_citation_json()
            }),
        );
    }

    #[test]
    fn unknown_delta_type_falls_back_to_other_preserving_json() {
        let raw = json!({"type": "future_delta", "stuff": [1, 2]});
        let d: ContentDelta = serde_json::from_value(raw.clone()).expect("deserialize");
        assert!(d.other().is_some());
        assert_eq!(d.type_tag(), Some("future_delta"));
        let reserialized = serde_json::to_value(&d).expect("serialize");
        assert_eq!(reserialized, raw);
    }

    #[test]
    fn malformed_known_delta_is_an_error() {
        let raw = json!({"type": "text_delta", "text": 42});
        let result: Result<ContentDelta, _> = serde_json::from_value(raw);
        assert!(result.is_err());
    }

    // ---- Golden sequence: a typical stream from start to stop ----

    #[test]
    fn golden_sequence_decodes_end_to_end() {
        let events = vec![
            sample_message_start_json(),
            json!({
                "type": "content_block_start",
                "index": 0,
                "content_block": {"type": "text", "text": ""}
            }),
            json!({
                "type": "content_block_delta",
                "index": 0,
                "delta": {"type": "text_delta", "text": "Hello"}
            }),
            json!({
                "type": "content_block_delta",
                "index": 0,
                "delta": {"type": "text_delta", "text": " world"}
            }),
            json!({"type": "content_block_stop", "index": 0}),
            json!({
                "type": "message_delta",
                "delta": {"stop_reason": "end_turn"},
                "usage": {"input_tokens": 10, "output_tokens": 2}
            }),
            json!({"type": "message_stop"}),
        ];

        let parsed: Vec<StreamEvent> = events
            .into_iter()
            .map(|v| serde_json::from_value(v).expect("decode"))
            .collect();

        assert_eq!(parsed.len(), 7);
        assert_eq!(parsed[0].type_tag(), Some("message_start"));
        assert_eq!(parsed[6].type_tag(), Some("message_stop"));

        // Both text_delta events must decode to their own chunks, in order.
        let texts: Vec<&str> = parsed[2..4]
            .iter()
            .map(|ev| match ev {
                StreamEvent::Known(KnownStreamEvent::ContentBlockDelta {
                    delta: ContentDelta::Known(KnownContentDelta::TextDelta { text }),
                    ..
                }) => text.as_str(),
                other => panic!("expected a TextDelta ContentBlockDelta, got {other:?}"),
            })
            .collect();
        assert_eq!(texts, vec!["Hello", " world"]);
    }
}
620
621// ---------------------------------------------------------------------------
622// EventStream + Aggregator (gated on the `streaming` feature)
623// ---------------------------------------------------------------------------
624
/// A typed stream of [`StreamEvent`]s produced by a streaming Messages request.
///
/// Consume it event by event through its [`futures_util::Stream`]
/// implementation, or call [`Self::aggregate`] to drain it and rebuild the
/// complete [`Message`].
///
/// The `on_*` builder methods attach optional **callback hooks**. These hooks
/// fire only inside [`Self::aggregate`]; polling the raw `Stream` never invokes
/// them. They are handy for token-by-token UI updates without matching on
/// `StreamEvent` yourself.
///
/// A connection failure in the middle of the stream is not retried, because a
/// retry would silently drop content. See [`crate::error::Error::is_retryable`].
#[cfg(feature = "streaming")]
#[cfg_attr(docsrs, doc(cfg(feature = "streaming")))]
pub struct EventStream {
    /// Decoded SSE events, boxed to erase the concrete stream type.
    inner: futures_util::stream::BoxStream<'static, Result<StreamEvent>>,
    /// Hooks consulted by `aggregate`.
    handlers: MessageStreamHandlers,
}
644
#[cfg(feature = "streaming")]
impl EventStream {
    /// Wraps a streaming HTTP response. Each SSE frame is decoded as a
    /// [`StreamEvent`]. No callback hooks are attached yet.
    pub(crate) fn from_response(response: reqwest::Response) -> Self {
        use futures_util::StreamExt;
        let inner = crate::sse::into_typed_stream::<StreamEvent>(response).boxed();
        Self {
            inner,
            handlers: MessageStreamHandlers::default(),
        }
    }

    /// Test-only constructor. Replays a fixed list of
    /// `Result<StreamEvent>`s so hook wiring can be exercised without an HTTP
    /// connection.
    #[cfg(test)]
    fn from_events(events: Vec<Result<StreamEvent>>) -> Self {
        use futures_util::StreamExt;
        let inner = futures_util::stream::iter(events).boxed();
        Self {
            inner,
            handlers: MessageStreamHandlers::default(),
        }
    }

    /// Registers a hook called with each new chunk of a `text` content block.
    ///
    /// Only the fresh chunk is passed. The full text is available on the
    /// [`Message`] returned by [`Self::aggregate`].
    #[must_use]
    pub fn on_text_delta<F>(mut self, handler: F) -> Self
    where
        F: FnMut(&str) + Send + 'static,
    {
        let boxed: TextDeltaHandler = Box::new(handler);
        self.handlers.text_delta = Some(boxed);
        self
    }

    /// Registers a hook called once a `tool_use` block has finished streaming
    /// and its `input` JSON has been rebuilt. The hook receives
    /// `(id, name, &input)`. It also fires for `server_tool_use` blocks, such
    /// as web search invocations.
    #[must_use]
    pub fn on_tool_use_complete<F>(mut self, handler: F) -> Self
    where
        F: FnMut(&str, &str, &serde_json::Value) + Send + 'static,
    {
        let boxed: ToolUseCompleteHandler = Box::new(handler);
        self.handlers.tool_use_complete = Some(boxed);
        self
    }

    /// Registers a hook called with each delta of a `thinking` block.
    #[must_use]
    pub fn on_thinking_delta<F>(mut self, handler: F) -> Self
    where
        F: FnMut(&str) + Send + 'static,
    {
        let boxed: ThinkingDeltaHandler = Box::new(handler);
        self.handlers.thinking_delta = Some(boxed);
        self
    }

    /// Registers a hook called once when the final `message_stop` event
    /// arrives. The hook receives the message's cumulative [`Usage`].
    #[must_use]
    pub fn on_message_stop<F>(mut self, handler: F) -> Self
    where
        F: FnMut(&Usage) + Send + 'static,
    {
        let boxed: MessageStopHandler = Box::new(handler);
        self.handlers.message_stop = Some(boxed);
        self
    }

    /// Registers a hook called when aggregation aborts. That happens when the
    /// server sends an `error` stream event, or when a stream-parse failure
    /// surfaces mid-aggregation. The hook runs before the error is returned
    /// from [`Self::aggregate`].
    #[must_use]
    pub fn on_error<F>(mut self, handler: F) -> Self
    where
        F: FnMut(&Error) + Send + 'static,
    {
        let boxed: ErrorHandler = Box::new(handler);
        self.handlers.error = Some(boxed);
        self
    }

    /// Drains the stream and returns the rebuilt [`Message`].
    ///
    /// The result matches what a non-streamed `messages.create(...)` call
    /// returns. Any hooks attached through the `on_*` builders fire as their
    /// events are processed.
    pub async fn aggregate(self) -> Result<Message> {
        use futures_util::StreamExt;
        let Self { mut inner, handlers } = self;
        let mut aggregator = Aggregator::with_handlers(handlers);
        while let Some(item) = inner.next().await {
            // Transport/parse failures and aggregation failures are reported
            // the same way: fire the error hook, then bail out.
            if let Err(err) = item.and_then(|event| aggregator.handle(event)) {
                aggregator.fire_error(&err);
                return Err(err);
            }
        }
        aggregator.finalize()
    }
}
758
#[cfg(feature = "streaming")]
type TextDeltaHandler = Box<dyn FnMut(&str) + Send>;
#[cfg(feature = "streaming")]
type ToolUseCompleteHandler = Box<dyn FnMut(&str, &str, &serde_json::Value) + Send>;
// Same signature as the text hook: each call receives only the new chunk.
#[cfg(feature = "streaming")]
type ThinkingDeltaHandler = TextDeltaHandler;
#[cfg(feature = "streaming")]
type MessageStopHandler = Box<dyn FnMut(&Usage) + Send>;
#[cfg(feature = "streaming")]
type ErrorHandler = Box<dyn FnMut(&Error) + Send>;

/// Optional callback hooks consulted by [`EventStream::aggregate`].
///
/// Each slot stays `None` until the matching `EventStream::on_*` builder fills
/// it.
#[cfg(feature = "streaming")]
#[derive(Default)]
struct MessageStreamHandlers {
    /// Called with each new chunk of a `text` block.
    text_delta: Option<TextDeltaHandler>,
    /// Called with `(id, name, &input)` when a tool-use block finishes.
    tool_use_complete: Option<ToolUseCompleteHandler>,
    /// Called with each new chunk of a `thinking` block.
    thinking_delta: Option<ThinkingDeltaHandler>,
    /// Called with the cumulative usage on `message_stop`.
    message_stop: Option<MessageStopHandler>,
    /// Called with the error that aborts aggregation.
    error: Option<ErrorHandler>,
}
780
#[cfg(feature = "streaming")]
impl std::fmt::Debug for MessageStreamHandlers {
    /// Renders each hook slot as `Some("<fn>")` or `None`, because closures
    /// have no meaningful `Debug` output of their own.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        fn placeholder<T>(slot: Option<&T>) -> Option<&'static str> {
            slot.map(|_| "<fn>")
        }

        f.debug_struct("MessageStreamHandlers")
            .field("text_delta", &placeholder(self.text_delta.as_ref()))
            .field(
                "tool_use_complete",
                &placeholder(self.tool_use_complete.as_ref()),
            )
            .field("thinking_delta", &placeholder(self.thinking_delta.as_ref()))
            .field("message_stop", &placeholder(self.message_stop.as_ref()))
            .field("error", &placeholder(self.error.as_ref()))
            .finish()
    }
}
799
#[cfg(feature = "streaming")]
impl futures_util::Stream for EventStream {
    type Item = Result<StreamEvent>;

    /// Delegates directly to the boxed SSE stream. Callback hooks are not
    /// consulted on this path.
    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        // `get_mut` needs `EventStream: Unpin`. The compiler checks that bound;
        // the inner stream is already pinned inside its own box.
        let this = self.get_mut();
        this.inner.as_mut().poll_next(cx)
    }
}
811
#[cfg(feature = "streaming")]
impl std::fmt::Debug for EventStream {
    /// Shows which callback hooks are attached, using `MessageStreamHandlers`'
    /// placeholder rendering. The boxed inner stream is opaque, so it is elided
    /// and the output is marked non-exhaustive.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("EventStream")
            .field("handlers", &self.handlers)
            .finish_non_exhaustive()
    }
}
818
/// Rebuilds a complete [`Message`] from a sequence of [`StreamEvent`]s.
///
/// This is plain state with no I/O, so it can be tested in isolation by
/// feeding events through [`Self::handle`].
#[cfg(feature = "streaming")]
#[derive(Debug, Default)]
pub struct Aggregator {
    /// Message shell from `message_start`; `None` until that event arrives.
    message: Option<Message>,
    /// Content blocks in stream order; position equals the block's `index`.
    blocks: Vec<ContentBlock>,
    /// Accumulated `partial_json` strings, keyed by block index. Each buffer is
    /// parsed at `ContentBlockStop` and stored back as the `input` of the
    /// matching `ToolUse` or `ServerToolUse` block.
    tool_input_buffers: std::collections::HashMap<u32, String>,
    /// Callback hooks fired while events are applied.
    handlers: MessageStreamHandlers,
}
834
835#[cfg(feature = "streaming")]
836impl Aggregator {
837    /// Build an Aggregator pre-populated with stream callback hooks.
838    /// Internal -- callers wire handlers through [`EventStream`].
839    fn with_handlers(handlers: MessageStreamHandlers) -> Self {
840        Self {
841            handlers,
842            ..Self::default()
843        }
844    }
845
846    /// Fire the `error` handler for an aborting error. Internal helper
847    /// for [`EventStream::aggregate`].
848    fn fire_error(&mut self, err: &Error) {
849        if let Some(h) = self.handlers.error.as_mut() {
850            h(err);
851        }
852    }
853
854    /// Apply one event to the aggregator's state.
855    pub fn handle(&mut self, event: StreamEvent) -> Result<()> {
856        match event {
857            StreamEvent::Known(known) => self.handle_known(known),
858            StreamEvent::Other(value) => {
859                tracing::debug!(?value, "claude-api: ignoring unknown stream event");
860                Ok(())
861            }
862        }
863    }
864
865    fn handle_known(&mut self, event: KnownStreamEvent) -> Result<()> {
866        match event {
867            KnownStreamEvent::MessageStart { message } => {
868                self.message = Some(message);
869            }
870            KnownStreamEvent::ContentBlockStart {
871                index,
872                content_block,
873            } => {
874                if index as usize != self.blocks.len() {
875                    return Err(Error::Stream(StreamError::Parse(format!(
876                        "out-of-order content_block_start: index {} but {} blocks already received",
877                        index,
878                        self.blocks.len()
879                    ))));
880                }
881                self.blocks.push(content_block);
882            }
883            KnownStreamEvent::ContentBlockDelta { index, delta } => {
884                self.apply_delta(index, delta);
885            }
886            KnownStreamEvent::ContentBlockStop { index } => {
887                if let Some(buf) = self.tool_input_buffers.remove(&index) {
888                    self.finalize_tool_input(index, &buf);
889                }
890                // Fire on_tool_use_complete for tool_use / server_tool_use blocks.
891                if let Some(handler) = self.handlers.tool_use_complete.as_mut()
892                    && let Some(ContentBlock::Known(
893                        KnownBlock::ToolUse { id, name, input }
894                        | KnownBlock::ServerToolUse { id, name, input },
895                    )) = self.blocks.get(index as usize)
896                {
897                    handler(id, name, input);
898                }
899            }
900            KnownStreamEvent::MessageDelta { delta, usage } => {
901                if let Some(msg) = self.message.as_mut() {
902                    if let Some(sr) = delta.stop_reason {
903                        msg.stop_reason = Some(sr);
904                    }
905                    if let Some(ss) = delta.stop_sequence {
906                        msg.stop_sequence = Some(ss);
907                    }
908                    msg.usage = usage;
909                }
910            }
911            KnownStreamEvent::MessageStop => {
912                if let Some(handler) = self.handlers.message_stop.as_mut()
913                    && let Some(msg) = self.message.as_ref()
914                {
915                    handler(&msg.usage);
916                }
917            }
918            KnownStreamEvent::Ping => {}
919            KnownStreamEvent::Error { error } => {
920                return Err(Error::Stream(StreamError::Server {
921                    kind: error.kind,
922                    message: error.message,
923                }));
924            }
925        }
926        Ok(())
927    }
928
929    fn apply_delta(&mut self, index: u32, delta: ContentDelta) {
930        let Some(block) = self.blocks.get_mut(index as usize) else {
931            tracing::warn!(index, "claude-api: delta for unknown block index, dropping");
932            return;
933        };
934        match delta {
935            ContentDelta::Known(KnownContentDelta::TextDelta { text }) => {
936                if let ContentBlock::Known(KnownBlock::Text { text: existing, .. }) = block {
937                    existing.push_str(&text);
938                }
939                if let Some(handler) = self.handlers.text_delta.as_mut() {
940                    handler(&text);
941                }
942            }
943            ContentDelta::Known(KnownContentDelta::InputJsonDelta { partial_json }) => {
944                self.tool_input_buffers
945                    .entry(index)
946                    .or_default()
947                    .push_str(&partial_json);
948            }
949            ContentDelta::Known(KnownContentDelta::ThinkingDelta { thinking }) => {
950                if let ContentBlock::Known(KnownBlock::Thinking {
951                    thinking: existing, ..
952                }) = block
953                {
954                    existing.push_str(&thinking);
955                }
956                if let Some(handler) = self.handlers.thinking_delta.as_mut() {
957                    handler(&thinking);
958                }
959            }
960            ContentDelta::Known(KnownContentDelta::SignatureDelta { signature }) => {
961                if let ContentBlock::Known(KnownBlock::Thinking { signature: sig, .. }) = block {
962                    *sig = signature;
963                }
964            }
965            ContentDelta::Known(KnownContentDelta::CitationsDelta { citation }) => {
966                if let ContentBlock::Known(KnownBlock::Text { citations, .. }) = block {
967                    citations.get_or_insert_with(Vec::new).push(citation);
968                }
969            }
970            ContentDelta::Other(value) => {
971                tracing::debug!(?value, "claude-api: ignoring unknown content delta");
972            }
973        }
974    }
975
976    fn finalize_tool_input(&mut self, index: u32, buffer: &str) {
977        let Some(block) = self.blocks.get_mut(index as usize) else {
978            return;
979        };
980        let parsed = if buffer.is_empty() {
981            // Nothing to parse; leave whatever the start event provided.
982            return;
983        } else {
984            serde_json::from_str::<serde_json::Value>(buffer).unwrap_or_else(|e| {
985                tracing::warn!(
986                    error = %e,
987                    "claude-api: tool_use input failed to parse; storing raw string"
988                );
989                serde_json::Value::String(buffer.to_owned())
990            })
991        };
992        match block {
993            ContentBlock::Known(
994                KnownBlock::ToolUse { input, .. } | KnownBlock::ServerToolUse { input, .. },
995            ) => {
996                *input = parsed;
997            }
998            _ => {
999                tracing::warn!(
1000                    index,
1001                    "claude-api: input_json_delta accumulated for non-tool-use block"
1002                );
1003            }
1004        }
1005    }
1006
1007    /// Finalize: combine the accumulated `MessageStart` shell with the
1008    /// reconstructed content blocks.
1009    pub fn finalize(mut self) -> Result<Message> {
1010        let mut message = self.message.take().ok_or_else(|| {
1011            Error::Stream(StreamError::Parse(
1012                "stream ended without a message_start event".into(),
1013            ))
1014        })?;
1015        message.content = self.blocks;
1016        Ok(message)
1017    }
1018}
1019
#[cfg(all(test, feature = "streaming"))]
mod aggregator_tests {
    use super::*;
    use crate::error::{ApiErrorKind, ApiErrorPayload};
    use crate::types::{ModelId, Role};
    use pretty_assertions::assert_eq;
    use serde_json::json;

    /// `message_start` carrying an empty assistant message with 5 input tokens.
    fn message_start_event() -> StreamEvent {
        StreamEvent::Known(KnownStreamEvent::MessageStart {
            message: serde_json::from_value(json!({
                "id": "msg_x",
                "type": "message",
                "role": "assistant",
                "content": [],
                "model": "claude-sonnet-4-6",
                "usage": {"input_tokens": 5, "output_tokens": 0}
            }))
            .unwrap(),
        })
    }

    #[test]
    fn aggregator_reconstructs_text_message() {
        let mut agg = Aggregator::default();
        agg.handle(message_start_event()).unwrap();
        agg.handle(StreamEvent::Known(KnownStreamEvent::ContentBlockStart {
            index: 0,
            content_block: ContentBlock::text(""),
        }))
        .unwrap();
        agg.handle(StreamEvent::Known(KnownStreamEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::Known(KnownContentDelta::TextDelta {
                text: "Hello".into(),
            }),
        }))
        .unwrap();
        agg.handle(StreamEvent::Known(KnownStreamEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::Known(KnownContentDelta::TextDelta {
                text: " world".into(),
            }),
        }))
        .unwrap();
        agg.handle(StreamEvent::Known(KnownStreamEvent::ContentBlockStop {
            index: 0,
        }))
        .unwrap();
        agg.handle(StreamEvent::Known(KnownStreamEvent::MessageDelta {
            delta: MessageDelta {
                stop_reason: Some(StopReason::EndTurn),
                stop_sequence: None,
            },
            usage: Usage {
                input_tokens: 5,
                output_tokens: 2,
                ..Usage::default()
            },
        }))
        .unwrap();
        agg.handle(StreamEvent::Known(KnownStreamEvent::MessageStop))
            .unwrap();

        let msg = agg.finalize().unwrap();
        assert_eq!(msg.id, "msg_x");
        assert_eq!(msg.role, Role::Assistant);
        assert_eq!(msg.model, ModelId::SONNET_4_6);
        assert_eq!(msg.stop_reason, Some(StopReason::EndTurn));
        assert_eq!(msg.usage.output_tokens, 2);
        assert_eq!(msg.content.len(), 1);
        match &msg.content[0] {
            ContentBlock::Known(KnownBlock::Text { text, .. }) => {
                assert_eq!(text, "Hello world");
            }
            _ => panic!("expected text block"),
        }
    }

    #[test]
    fn aggregator_reconstructs_tool_use_input_from_partial_json_deltas() {
        let mut agg = Aggregator::default();
        agg.handle(message_start_event()).unwrap();
        agg.handle(StreamEvent::Known(KnownStreamEvent::ContentBlockStart {
            index: 0,
            content_block: ContentBlock::Known(KnownBlock::ToolUse {
                id: "toolu_1".into(),
                name: "get_weather".into(),
                input: json!({}),
            }),
        }))
        .unwrap();
        for chunk in ["{\"city\":", "\"Paris\"", ",\"unit\":\"C\"}"] {
            agg.handle(StreamEvent::Known(KnownStreamEvent::ContentBlockDelta {
                index: 0,
                delta: ContentDelta::Known(KnownContentDelta::InputJsonDelta {
                    partial_json: chunk.into(),
                }),
            }))
            .unwrap();
        }
        agg.handle(StreamEvent::Known(KnownStreamEvent::ContentBlockStop {
            index: 0,
        }))
        .unwrap();
        agg.handle(StreamEvent::Known(KnownStreamEvent::MessageStop))
            .unwrap();

        let msg = agg.finalize().unwrap();
        match &msg.content[0] {
            ContentBlock::Known(KnownBlock::ToolUse { input, name, .. }) => {
                assert_eq!(name, "get_weather");
                assert_eq!(input, &json!({"city": "Paris", "unit": "C"}));
            }
            _ => panic!("expected ToolUse block"),
        }
    }

    #[test]
    fn aggregator_stores_unparseable_tool_input_as_raw_string() {
        let mut agg = Aggregator::default();
        agg.handle(message_start_event()).unwrap();
        agg.handle(StreamEvent::Known(KnownStreamEvent::ContentBlockStart {
            index: 0,
            content_block: ContentBlock::Known(KnownBlock::ToolUse {
                id: "toolu_1".into(),
                name: "get_weather".into(),
                input: json!({}),
            }),
        }))
        .unwrap();
        // Truncated JSON: the closing value and brace never arrive.
        agg.handle(StreamEvent::Known(KnownStreamEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::Known(KnownContentDelta::InputJsonDelta {
                partial_json: "{\"city\":".into(),
            }),
        }))
        .unwrap();
        agg.handle(StreamEvent::Known(KnownStreamEvent::ContentBlockStop {
            index: 0,
        }))
        .unwrap();

        let msg = agg.finalize().unwrap();
        match &msg.content[0] {
            ContentBlock::Known(KnownBlock::ToolUse { input, .. }) => {
                assert_eq!(input, &json!("{\"city\":"));
            }
            _ => panic!("expected ToolUse block"),
        }
    }

    #[test]
    fn aggregator_keeps_start_input_when_no_json_deltas_arrive() {
        let mut agg = Aggregator::default();
        agg.handle(message_start_event()).unwrap();
        agg.handle(StreamEvent::Known(KnownStreamEvent::ContentBlockStart {
            index: 0,
            content_block: ContentBlock::Known(KnownBlock::ToolUse {
                id: "toolu_1".into(),
                name: "get_weather".into(),
                input: json!({"city": "Paris"}),
            }),
        }))
        .unwrap();
        agg.handle(StreamEvent::Known(KnownStreamEvent::ContentBlockStop {
            index: 0,
        }))
        .unwrap();

        let msg = agg.finalize().unwrap();
        match &msg.content[0] {
            ContentBlock::Known(KnownBlock::ToolUse { input, .. }) => {
                assert_eq!(input, &json!({"city": "Paris"}));
            }
            _ => panic!("expected ToolUse block"),
        }
    }

    #[test]
    fn aggregator_reconstructs_thinking_block() {
        let mut agg = Aggregator::default();
        agg.handle(message_start_event()).unwrap();
        agg.handle(StreamEvent::Known(KnownStreamEvent::ContentBlockStart {
            index: 0,
            content_block: ContentBlock::Known(KnownBlock::Thinking {
                thinking: String::new(),
                signature: String::new(),
            }),
        }))
        .unwrap();
        agg.handle(StreamEvent::Known(KnownStreamEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::Known(KnownContentDelta::ThinkingDelta {
                thinking: "let me ".into(),
            }),
        }))
        .unwrap();
        agg.handle(StreamEvent::Known(KnownStreamEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::Known(KnownContentDelta::ThinkingDelta {
                thinking: "think".into(),
            }),
        }))
        .unwrap();
        agg.handle(StreamEvent::Known(KnownStreamEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::Known(KnownContentDelta::SignatureDelta {
                signature: "sig_xyz".into(),
            }),
        }))
        .unwrap();
        agg.handle(StreamEvent::Known(KnownStreamEvent::ContentBlockStop {
            index: 0,
        }))
        .unwrap();
        agg.handle(StreamEvent::Known(KnownStreamEvent::MessageStop))
            .unwrap();

        let msg = agg.finalize().unwrap();
        match &msg.content[0] {
            ContentBlock::Known(KnownBlock::Thinking {
                thinking,
                signature,
            }) => {
                assert_eq!(thinking, "let me think");
                assert_eq!(signature, "sig_xyz");
            }
            _ => panic!("expected Thinking block"),
        }
    }

    #[test]
    fn aggregator_unknown_event_is_ignored() {
        let mut agg = Aggregator::default();
        agg.handle(message_start_event()).unwrap();
        // Unknown event should not error.
        agg.handle(StreamEvent::Other(json!({"type": "future_event"})))
            .unwrap();
        agg.handle(StreamEvent::Known(KnownStreamEvent::MessageStop))
            .unwrap();
        let msg = agg.finalize().unwrap();
        assert!(msg.content.is_empty());
    }

    #[test]
    fn aggregator_unknown_delta_is_ignored() {
        let mut agg = Aggregator::default();
        agg.handle(message_start_event()).unwrap();
        agg.handle(StreamEvent::Known(KnownStreamEvent::ContentBlockStart {
            index: 0,
            content_block: ContentBlock::text(""),
        }))
        .unwrap();
        agg.handle(StreamEvent::Known(KnownStreamEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::Other(json!({"type": "future_delta"})),
        }))
        .unwrap();

        // The unknown delta must leave the block untouched.
        let msg = agg.finalize().unwrap();
        assert_eq!(msg.content.len(), 1);
        match &msg.content[0] {
            ContentBlock::Known(KnownBlock::Text { text, .. }) => {
                assert!(text.is_empty(), "unexpected text {text:?}");
            }
            _ => panic!("expected text block"),
        }
    }

    #[test]
    fn aggregator_delta_for_unknown_block_index_is_dropped() {
        let mut agg = Aggregator::default();
        agg.handle(message_start_event()).unwrap();
        // No content_block_start for index 3 -- the delta has nowhere to go.
        agg.handle(StreamEvent::Known(KnownStreamEvent::ContentBlockDelta {
            index: 3,
            delta: ContentDelta::Known(KnownContentDelta::TextDelta {
                text: "orphan".into(),
            }),
        }))
        .unwrap();
        let msg = agg.finalize().unwrap();
        assert!(msg.content.is_empty());
    }

    #[test]
    fn aggregator_server_error_event_propagates() {
        let mut agg = Aggregator::default();
        agg.handle(message_start_event()).unwrap();
        let err = agg
            .handle(StreamEvent::Known(KnownStreamEvent::Error {
                error: ApiErrorPayload {
                    kind: ApiErrorKind::OverloadedError,
                    message: "boom".into(),
                },
            }))
            .unwrap_err();
        match err {
            Error::Stream(StreamError::Server { kind, message }) => {
                assert_eq!(kind, ApiErrorKind::OverloadedError);
                assert_eq!(message, "boom");
            }
            other => panic!("expected Stream::Server, got {other:?}"),
        }
    }

    #[test]
    fn aggregator_out_of_order_block_start_errors() {
        let mut agg = Aggregator::default();
        agg.handle(message_start_event()).unwrap();
        // Skip index 0; start with index 1.
        let err = agg
            .handle(StreamEvent::Known(KnownStreamEvent::ContentBlockStart {
                index: 1,
                content_block: ContentBlock::text(""),
            }))
            .unwrap_err();
        assert!(matches!(err, Error::Stream(StreamError::Parse(_))));
    }

    #[test]
    fn aggregator_finalize_without_message_start_errors() {
        let agg = Aggregator::default();
        let err = agg.finalize().unwrap_err();
        assert!(matches!(err, Error::Stream(StreamError::Parse(_))));
    }
}
1263
#[cfg(all(test, feature = "streaming"))]
mod stream_callback_tests {
    use super::*;
    use crate::error::{ApiErrorKind, ApiErrorPayload};
    use pretty_assertions::assert_eq;
    use serde_json::json;
    use std::sync::{Arc, Mutex};

    /// Create a shared cell and a second handle to it: the first stays with
    /// the test for assertions, the second moves into a callback closure.
    fn shared<T>(initial: T) -> (Arc<Mutex<T>>, Arc<Mutex<T>>) {
        let cell = Arc::new(Mutex::new(initial));
        let handle = Arc::clone(&cell);
        (cell, handle)
    }

    /// Wrap a known event as a successful stream item.
    fn ok(event: KnownStreamEvent) -> Result<StreamEvent> {
        Ok(StreamEvent::Known(event))
    }

    /// A known content delta for block `index`, as a stream item.
    fn block_delta(index: u32, delta: KnownContentDelta) -> Result<StreamEvent> {
        ok(KnownStreamEvent::ContentBlockDelta {
            index,
            delta: ContentDelta::Known(delta),
        })
    }

    /// `content_block_stop` for block `index`, as a stream item.
    fn block_stop(index: u32) -> Result<StreamEvent> {
        ok(KnownStreamEvent::ContentBlockStop { index })
    }

    /// `message_stop`, as a stream item.
    fn message_stop() -> Result<StreamEvent> {
        ok(KnownStreamEvent::MessageStop)
    }

    /// A text delta carrying `chunk`.
    fn text(chunk: &str) -> KnownContentDelta {
        KnownContentDelta::TextDelta { text: chunk.into() }
    }

    /// A thinking delta carrying `chunk`.
    fn thinking(chunk: &str) -> KnownContentDelta {
        KnownContentDelta::ThinkingDelta {
            thinking: chunk.into(),
        }
    }

    /// `message_start` carrying an empty assistant message with 5 input tokens.
    fn message_start_event() -> StreamEvent {
        StreamEvent::Known(KnownStreamEvent::MessageStart {
            message: serde_json::from_value(json!({
                "id": "msg_x",
                "type": "message",
                "role": "assistant",
                "content": [],
                "model": "claude-sonnet-4-6",
                "usage": {"input_tokens": 5, "output_tokens": 0}
            }))
            .unwrap(),
        })
    }

    #[tokio::test]
    async fn on_text_delta_fires_for_each_text_chunk() {
        let (captured, sink) = shared(Vec::<String>::new());
        let events = vec![
            Ok(message_start_event()),
            ok(KnownStreamEvent::ContentBlockStart {
                index: 0,
                content_block: ContentBlock::text(""),
            }),
            block_delta(0, text("Hello")),
            block_delta(0, text(" world")),
            block_stop(0),
            message_stop(),
        ];

        EventStream::from_events(events)
            .on_text_delta(move |chunk| {
                sink.lock().unwrap().push(chunk.to_owned());
            })
            .aggregate()
            .await
            .unwrap();

        assert_eq!(*captured.lock().unwrap(), vec!["Hello", " world"]);
    }

    #[tokio::test]
    async fn on_tool_use_complete_fires_with_parsed_input() {
        let (captured, sink) = shared(Vec::<(String, String, serde_json::Value)>::new());
        let events = vec![
            Ok(message_start_event()),
            ok(KnownStreamEvent::ContentBlockStart {
                index: 0,
                content_block: ContentBlock::Known(KnownBlock::ToolUse {
                    id: "toolu_1".into(),
                    name: "get_weather".into(),
                    input: json!({}),
                }),
            }),
            block_delta(
                0,
                KnownContentDelta::InputJsonDelta {
                    partial_json: "{\"city\":\"Paris\"}".into(),
                },
            ),
            block_stop(0),
            message_stop(),
        ];

        EventStream::from_events(events)
            .on_tool_use_complete(move |id, name, input| {
                sink.lock()
                    .unwrap()
                    .push((id.to_owned(), name.to_owned(), input.clone()));
            })
            .aggregate()
            .await
            .unwrap();

        let calls = captured.lock().unwrap();
        assert_eq!(calls.len(), 1);
        let (id, name, input) = &calls[0];
        assert_eq!(id, "toolu_1");
        assert_eq!(name, "get_weather");
        assert_eq!(input, &json!({"city": "Paris"}));
    }

    #[tokio::test]
    async fn on_tool_use_complete_fires_for_server_tool_use_blocks() {
        let (captured, sink) = shared(Vec::<String>::new());
        let events = vec![
            Ok(message_start_event()),
            ok(KnownStreamEvent::ContentBlockStart {
                index: 0,
                content_block: ContentBlock::Known(KnownBlock::ServerToolUse {
                    id: "srvu_1".into(),
                    name: "web_search".into(),
                    input: json!({}),
                }),
            }),
            block_delta(
                0,
                KnownContentDelta::InputJsonDelta {
                    partial_json: "{\"q\":\"rust\"}".into(),
                },
            ),
            block_stop(0),
            message_stop(),
        ];

        EventStream::from_events(events)
            .on_tool_use_complete(move |id, _, _| {
                sink.lock().unwrap().push(id.to_owned());
            })
            .aggregate()
            .await
            .unwrap();

        assert_eq!(*captured.lock().unwrap(), vec!["srvu_1"]);
    }

    #[tokio::test]
    async fn on_thinking_delta_fires_for_each_thinking_chunk() {
        let (captured, sink) = shared(Vec::<String>::new());
        let events = vec![
            Ok(message_start_event()),
            ok(KnownStreamEvent::ContentBlockStart {
                index: 0,
                content_block: ContentBlock::Known(KnownBlock::Thinking {
                    thinking: String::new(),
                    signature: String::new(),
                }),
            }),
            block_delta(0, thinking("let me ")),
            block_delta(0, thinking("think")),
            block_stop(0),
            message_stop(),
        ];

        EventStream::from_events(events)
            .on_thinking_delta(move |chunk| {
                sink.lock().unwrap().push(chunk.to_owned());
            })
            .aggregate()
            .await
            .unwrap();

        assert_eq!(*captured.lock().unwrap(), vec!["let me ", "think"]);
    }

    #[tokio::test]
    async fn on_message_stop_fires_once_with_usage() {
        let (captured, sink) = shared(Vec::<Usage>::new());
        let events = vec![
            Ok(message_start_event()),
            ok(KnownStreamEvent::ContentBlockStart {
                index: 0,
                content_block: ContentBlock::text(""),
            }),
            block_delta(0, text("hi")),
            block_stop(0),
            ok(KnownStreamEvent::MessageDelta {
                delta: MessageDelta {
                    stop_reason: Some(StopReason::EndTurn),
                    stop_sequence: None,
                },
                usage: Usage {
                    input_tokens: 5,
                    output_tokens: 7,
                    ..Usage::default()
                },
            }),
            message_stop(),
        ];

        EventStream::from_events(events)
            .on_message_stop(move |usage| {
                sink.lock().unwrap().push(usage.clone());
            })
            .aggregate()
            .await
            .unwrap();

        let reports = captured.lock().unwrap();
        assert_eq!(reports.len(), 1);
        let usage = &reports[0];
        assert_eq!(usage.input_tokens, 5);
        assert_eq!(usage.output_tokens, 7);
    }

    #[tokio::test]
    async fn on_error_fires_before_propagating_server_error() {
        let (count, sink) = shared(0u32);
        let events = vec![
            Ok(message_start_event()),
            ok(KnownStreamEvent::Error {
                error: ApiErrorPayload {
                    kind: ApiErrorKind::OverloadedError,
                    message: "boom".into(),
                },
            }),
        ];

        let err = EventStream::from_events(events)
            .on_error(move |_| {
                *sink.lock().unwrap() += 1;
            })
            .aggregate()
            .await
            .unwrap_err();

        assert!(matches!(
            err,
            Error::Stream(StreamError::Server {
                kind: ApiErrorKind::OverloadedError,
                ..
            })
        ));
        assert_eq!(
            *count.lock().unwrap(),
            1,
            "handler should fire exactly once"
        );
    }

    #[tokio::test]
    async fn on_error_fires_for_transport_error() {
        let (count, sink) = shared(0u32);
        let events: Vec<Result<StreamEvent>> = vec![
            Ok(message_start_event()),
            Err(Error::Stream(StreamError::Parse("decode failed".into()))),
        ];

        let err = EventStream::from_events(events)
            .on_error(move |_| {
                *sink.lock().unwrap() += 1;
            })
            .aggregate()
            .await
            .unwrap_err();

        assert!(matches!(err, Error::Stream(StreamError::Parse(_))));
        assert_eq!(*count.lock().unwrap(), 1);
    }

    #[tokio::test]
    async fn raw_stream_iteration_does_not_fire_callbacks() {
        use futures_util::StreamExt;
        let (count, sink) = shared(0u32);
        let events = vec![
            Ok(message_start_event()),
            ok(KnownStreamEvent::ContentBlockStart {
                index: 0,
                content_block: ContentBlock::text(""),
            }),
            block_delta(0, text("hi")),
            block_stop(0),
            message_stop(),
        ];

        let mut stream = EventStream::from_events(events).on_text_delta(move |_| {
            *sink.lock().unwrap() += 1;
        });
        // Drain every item without aggregating.
        while stream.next().await.is_some() {}

        // Callbacks only fire during aggregate(), not raw .next().
        assert_eq!(*count.lock().unwrap(), 0);
    }
}