1use serde::{Deserialize, Serialize};
13
14use crate::error::ApiErrorPayload;
15use crate::forward_compat::dispatch_known_or_other;
16use crate::messages::content::ContentBlock;
17use crate::messages::response::Message;
18use crate::types::{StopReason, Usage};
19
20#[cfg(feature = "streaming")]
21use crate::error::{Error, Result, StreamError};
22#[cfg(feature = "streaming")]
23use crate::messages::content::KnownBlock;
24
25#[allow(clippy::large_enum_variant)]
34#[derive(Debug, Clone, PartialEq)]
35pub enum StreamEvent {
36 Known(KnownStreamEvent),
38 Other(serde_json::Value),
40}
41
42#[allow(clippy::large_enum_variant)]
44#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
45#[serde(tag = "type", rename_all = "snake_case")]
46#[non_exhaustive]
47pub enum KnownStreamEvent {
48 MessageStart {
51 message: Message,
53 },
54 ContentBlockStart {
56 index: u32,
58 content_block: ContentBlock,
60 },
61 ContentBlockDelta {
63 index: u32,
65 delta: ContentDelta,
67 },
68 ContentBlockStop {
70 index: u32,
72 },
73 MessageDelta {
75 delta: MessageDelta,
77 usage: Usage,
79 },
80 MessageStop,
82 Ping,
84 Error {
86 error: ApiErrorPayload,
88 },
89}
90
/// Wire `type` tags of every [`KnownStreamEvent`] variant.
///
/// Passed to the dispatch helper when deserializing a [`StreamEvent`]; keep
/// it in sync with the enum's variants.
const KNOWN_EVENT_TAGS: &[&str] = &[
    "message_start",
    "content_block_start",
    "content_block_delta",
    "content_block_stop",
    "message_delta",
    "message_stop",
    "ping",
    "error",
];
102
103impl Serialize for StreamEvent {
104 fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
105 match self {
106 StreamEvent::Known(k) => k.serialize(s),
107 StreamEvent::Other(v) => v.serialize(s),
108 }
109 }
110}
111
112impl<'de> Deserialize<'de> for StreamEvent {
113 fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
114 let raw = serde_json::Value::deserialize(d)?;
115 dispatch_known_or_other(
116 raw,
117 KNOWN_EVENT_TAGS,
118 StreamEvent::Known,
119 StreamEvent::Other,
120 )
121 .map_err(serde::de::Error::custom)
122 }
123}
124
125impl From<KnownStreamEvent> for StreamEvent {
126 fn from(k: KnownStreamEvent) -> Self {
127 StreamEvent::Known(k)
128 }
129}
130
131impl StreamEvent {
132 pub fn known(&self) -> Option<&KnownStreamEvent> {
134 match self {
135 Self::Known(k) => Some(k),
136 Self::Other(_) => None,
137 }
138 }
139
140 pub fn other(&self) -> Option<&serde_json::Value> {
142 match self {
143 Self::Other(v) => Some(v),
144 Self::Known(_) => None,
145 }
146 }
147
148 pub fn type_tag(&self) -> Option<&str> {
150 match self {
151 Self::Known(k) => Some(known_event_tag(k)),
152 Self::Other(v) => v.get("type").and_then(serde_json::Value::as_str),
153 }
154 }
155}
156
157fn known_event_tag(k: &KnownStreamEvent) -> &'static str {
158 match k {
159 KnownStreamEvent::MessageStart { .. } => "message_start",
160 KnownStreamEvent::ContentBlockStart { .. } => "content_block_start",
161 KnownStreamEvent::ContentBlockDelta { .. } => "content_block_delta",
162 KnownStreamEvent::ContentBlockStop { .. } => "content_block_stop",
163 KnownStreamEvent::MessageDelta { .. } => "message_delta",
164 KnownStreamEvent::MessageStop => "message_stop",
165 KnownStreamEvent::Ping => "ping",
166 KnownStreamEvent::Error { .. } => "error",
167 }
168}
169
170#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
173#[non_exhaustive]
174pub struct MessageDelta {
175 #[serde(default, skip_serializing_if = "Option::is_none")]
177 pub stop_reason: Option<StopReason>,
178 #[serde(default, skip_serializing_if = "Option::is_none")]
180 pub stop_sequence: Option<String>,
181}
182
183#[derive(Debug, Clone, PartialEq)]
187pub enum ContentDelta {
188 Known(KnownContentDelta),
190 Other(serde_json::Value),
192}
193
194#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
196#[serde(tag = "type", rename_all = "snake_case")]
197#[non_exhaustive]
198pub enum KnownContentDelta {
199 TextDelta {
201 text: String,
203 },
204 InputJsonDelta {
206 partial_json: String,
208 },
209 ThinkingDelta {
211 thinking: String,
213 },
214 SignatureDelta {
216 signature: String,
218 },
219 CitationsDelta {
221 citation: crate::messages::citation::Citation,
223 },
224}
225
/// Wire `type` tags of every [`KnownContentDelta`] variant.
///
/// Passed to the dispatch helper when deserializing a [`ContentDelta`]; keep
/// it in sync with the enum's variants.
const KNOWN_DELTA_TAGS: &[&str] = &[
    "text_delta",
    "input_json_delta",
    "thinking_delta",
    "signature_delta",
    "citations_delta",
];
233
234impl Serialize for ContentDelta {
235 fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
236 match self {
237 ContentDelta::Known(k) => k.serialize(s),
238 ContentDelta::Other(v) => v.serialize(s),
239 }
240 }
241}
242
243impl<'de> Deserialize<'de> for ContentDelta {
244 fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
245 let raw = serde_json::Value::deserialize(d)?;
246 dispatch_known_or_other(
247 raw,
248 KNOWN_DELTA_TAGS,
249 ContentDelta::Known,
250 ContentDelta::Other,
251 )
252 .map_err(serde::de::Error::custom)
253 }
254}
255
256impl From<KnownContentDelta> for ContentDelta {
257 fn from(k: KnownContentDelta) -> Self {
258 ContentDelta::Known(k)
259 }
260}
261
262impl ContentDelta {
263 pub fn known(&self) -> Option<&KnownContentDelta> {
265 match self {
266 Self::Known(k) => Some(k),
267 Self::Other(_) => None,
268 }
269 }
270
271 pub fn other(&self) -> Option<&serde_json::Value> {
273 match self {
274 Self::Other(v) => Some(v),
275 Self::Known(_) => None,
276 }
277 }
278
279 pub fn type_tag(&self) -> Option<&str> {
281 match self {
282 Self::Known(k) => Some(match k {
283 KnownContentDelta::TextDelta { .. } => "text_delta",
284 KnownContentDelta::InputJsonDelta { .. } => "input_json_delta",
285 KnownContentDelta::ThinkingDelta { .. } => "thinking_delta",
286 KnownContentDelta::SignatureDelta { .. } => "signature_delta",
287 KnownContentDelta::CitationsDelta { .. } => "citations_delta",
288 }),
289 Self::Other(v) => v.get("type").and_then(serde_json::Value::as_str),
290 }
291 }
292}
293
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ApiErrorKind;
    use pretty_assertions::assert_eq;
    use serde_json::json;

    /// Serializes `value`, compares the JSON with `wire`, then decodes that
    /// JSON again and checks that the result equals `value`.
    fn check_round_trip<T>(value: &T, wire: &serde_json::Value)
    where
        T: serde::Serialize + serde::de::DeserializeOwned + PartialEq + std::fmt::Debug,
    {
        let encoded = serde_json::to_value(value).expect("serialize");
        assert_eq!(&encoded, wire, "wire form mismatch");
        let decoded: T = serde_json::from_value(encoded).expect("deserialize");
        assert_eq!(&decoded, value, "round-trip mismatch");
    }

    fn round_trip_event(event: &StreamEvent, expected: &serde_json::Value) {
        check_round_trip(event, expected);
    }

    fn round_trip_delta(delta: &ContentDelta, expected: &serde_json::Value) {
        check_round_trip(delta, expected);
    }

    // ---- stream events ----

    #[test]
    fn message_stop_round_trips() {
        let event = StreamEvent::Known(KnownStreamEvent::MessageStop);
        round_trip_event(&event, &json!({"type": "message_stop"}));
    }

    #[test]
    fn ping_round_trips() {
        let event = StreamEvent::Known(KnownStreamEvent::Ping);
        round_trip_event(&event, &json!({"type": "ping"}));
    }

    #[test]
    fn content_block_start_round_trips() {
        let wire = json!({
            "type": "content_block_start",
            "index": 0,
            "content_block": {"type": "text", "text": ""}
        });
        let event = StreamEvent::Known(KnownStreamEvent::ContentBlockStart {
            index: 0,
            content_block: ContentBlock::text(""),
        });
        round_trip_event(&event, &wire);
    }

    #[test]
    fn content_block_delta_round_trips() {
        let wire = json!({
            "type": "content_block_delta",
            "index": 1,
            "delta": {"type": "text_delta", "text": "Hello"}
        });
        let delta = ContentDelta::Known(KnownContentDelta::TextDelta {
            text: "Hello".into(),
        });
        let event = StreamEvent::Known(KnownStreamEvent::ContentBlockDelta { index: 1, delta });
        round_trip_event(&event, &wire);
    }

    #[test]
    fn content_block_stop_round_trips() {
        let wire = json!({"type": "content_block_stop", "index": 2});
        let event = StreamEvent::Known(KnownStreamEvent::ContentBlockStop { index: 2 });
        round_trip_event(&event, &wire);
    }

    #[test]
    fn message_delta_round_trips() {
        let wire = json!({
            "type": "message_delta",
            "delta": {"stop_reason": "end_turn"},
            "usage": {"input_tokens": 5, "output_tokens": 10}
        });
        let delta = MessageDelta {
            stop_reason: Some(StopReason::EndTurn),
            stop_sequence: None,
        };
        let usage = Usage {
            input_tokens: 5,
            output_tokens: 10,
            ..Usage::default()
        };
        let event = StreamEvent::Known(KnownStreamEvent::MessageDelta { delta, usage });
        round_trip_event(&event, &wire);
    }

    #[test]
    fn error_event_round_trips() {
        let wire = json!({
            "type": "error",
            "error": {"type": "overloaded_error", "message": "try again"}
        });
        let error = ApiErrorPayload {
            kind: ApiErrorKind::OverloadedError,
            message: "try again".into(),
        };
        let event = StreamEvent::Known(KnownStreamEvent::Error { error });
        round_trip_event(&event, &wire);
    }

    #[test]
    fn unknown_event_type_falls_back_to_other_preserving_json() {
        let raw = json!({
            "type": "future_event",
            "payload": {"x": 1, "y": [2, 3]}
        });
        let ev: StreamEvent = serde_json::from_value(raw.clone()).expect("deserialize");
        assert!(ev.other().is_some());
        assert_eq!(ev.type_tag(), Some("future_event"));

        let reserialized = serde_json::to_value(&ev).expect("serialize");
        assert_eq!(reserialized, raw, "Other must round-trip byte-for-byte");
    }

    #[test]
    fn malformed_known_event_is_an_error() {
        // `index` must be a number; a known tag with a bad body must not be
        // accepted as `Other`.
        let raw = json!({"type": "content_block_stop", "index": "nope"});
        let outcome = serde_json::from_value::<StreamEvent>(raw);
        assert!(
            outcome.is_err(),
            "malformed known event must error, not silently fall through to Other"
        );
    }

    // ---- content deltas ----

    #[test]
    fn text_delta_round_trips() {
        let delta = ContentDelta::Known(KnownContentDelta::TextDelta { text: "hi".into() });
        round_trip_delta(&delta, &json!({"type": "text_delta", "text": "hi"}));
    }

    #[test]
    fn input_json_delta_round_trips() {
        let delta = ContentDelta::Known(KnownContentDelta::InputJsonDelta {
            partial_json: r#"{"city":"P"#.into(),
        });
        round_trip_delta(
            &delta,
            &json!({"type": "input_json_delta", "partial_json": "{\"city\":\"P"}),
        );
    }

    #[test]
    fn thinking_delta_round_trips() {
        let delta = ContentDelta::Known(KnownContentDelta::ThinkingDelta {
            thinking: " more thinking".into(),
        });
        round_trip_delta(
            &delta,
            &json!({"type": "thinking_delta", "thinking": " more thinking"}),
        );
    }

    #[test]
    fn signature_delta_round_trips() {
        let delta = ContentDelta::Known(KnownContentDelta::SignatureDelta {
            signature: "sig123".into(),
        });
        round_trip_delta(
            &delta,
            &json!({"type": "signature_delta", "signature": "sig123"}),
        );
    }

    #[test]
    fn citations_delta_round_trips() {
        use crate::messages::citation::{Citation, KnownCitation};

        let citation = Citation::Known(KnownCitation::CharLocation {
            document_index: 0,
            document_title: Some("Doc".into()),
            cited_text: "hello".into(),
            start_char_index: 0,
            end_char_index: 5,
        });
        let wire = json!({
            "type": "citations_delta",
            "citation": {
                "type": "char_location",
                "document_index": 0,
                "document_title": "Doc",
                "cited_text": "hello",
                "start_char_index": 0,
                "end_char_index": 5
            }
        });
        round_trip_delta(
            &ContentDelta::Known(KnownContentDelta::CitationsDelta { citation }),
            &wire,
        );
    }

    #[test]
    fn unknown_delta_type_falls_back_to_other_preserving_json() {
        let raw = json!({"type": "future_delta", "stuff": [1, 2]});
        let d: ContentDelta = serde_json::from_value(raw.clone()).expect("deserialize");
        assert!(d.other().is_some());
        assert_eq!(d.type_tag(), Some("future_delta"));
        let reserialized = serde_json::to_value(&d).expect("serialize");
        assert_eq!(reserialized, raw);
    }

    #[test]
    fn malformed_known_delta_is_an_error() {
        let raw = json!({"type": "text_delta", "text": 42});
        assert!(serde_json::from_value::<ContentDelta>(raw).is_err());
    }

    // ---- whole sequence ----

    #[test]
    fn golden_sequence_decodes_end_to_end() {
        let events = vec![
            json!({
                "type": "message_start",
                "message": {
                    "id": "msg_X",
                    "type": "message",
                    "role": "assistant",
                    "content": [],
                    "model": "claude-sonnet-4-6",
                    "usage": {"input_tokens": 10, "output_tokens": 0}
                }
            }),
            json!({
                "type": "content_block_start",
                "index": 0,
                "content_block": {"type": "text", "text": ""}
            }),
            json!({
                "type": "content_block_delta",
                "index": 0,
                "delta": {"type": "text_delta", "text": "Hello"}
            }),
            json!({
                "type": "content_block_delta",
                "index": 0,
                "delta": {"type": "text_delta", "text": " world"}
            }),
            json!({"type": "content_block_stop", "index": 0}),
            json!({
                "type": "message_delta",
                "delta": {"stop_reason": "end_turn"},
                "usage": {"input_tokens": 10, "output_tokens": 2}
            }),
            json!({"type": "message_stop"}),
        ];

        let mut parsed: Vec<StreamEvent> = Vec::with_capacity(events.len());
        for raw in events {
            parsed.push(serde_json::from_value(raw).expect("decode"));
        }

        assert_eq!(parsed.len(), 7);
        assert_eq!(parsed[0].type_tag(), Some("message_start"));
        assert_eq!(parsed[6].type_tag(), Some("message_stop"));

        // The third event is the first text delta.
        let StreamEvent::Known(KnownStreamEvent::ContentBlockDelta { delta, .. }) = &parsed[2]
        else {
            panic!("expected ContentBlockDelta");
        };
        let ContentDelta::Known(KnownContentDelta::TextDelta { text }) = delta else {
            panic!("expected TextDelta");
        };
        assert_eq!(text, "Hello");
    }
}
585
/// A stream of [`StreamEvent`]s with optional callbacks attached.
///
/// Polling it as a `Stream` yields the events unchanged and does not call
/// the callbacks. They are only invoked while `aggregate` drives the stream.
#[cfg(feature = "streaming")]
#[cfg_attr(docsrs, doc(cfg(feature = "streaming")))]
pub struct EventStream {
    /// The underlying event source.
    inner: futures_util::stream::BoxStream<'static, Result<StreamEvent>>,
    /// Callbacks registered through the `on_*` builder methods.
    handlers: MessageStreamHandlers,
}
609
#[cfg(feature = "streaming")]
impl EventStream {
    /// Builds an event stream over an HTTP response by passing it to
    /// `crate::sse::into_typed_stream`. No callbacks are registered.
    pub(crate) fn from_response(response: reqwest::Response) -> Self {
        use futures_util::StreamExt;

        let inner = crate::sse::into_typed_stream::<StreamEvent>(response).boxed();
        Self {
            inner,
            handlers: MessageStreamHandlers::default(),
        }
    }

    /// Test-only constructor that replays a fixed list of stream items.
    #[cfg(test)]
    fn from_events(events: Vec<Result<StreamEvent>>) -> Self {
        use futures_util::StreamExt;

        let inner = futures_util::stream::iter(events).boxed();
        Self {
            inner,
            handlers: MessageStreamHandlers::default(),
        }
    }

    /// Registers a callback invoked with each text fragment as
    /// [`aggregate`](Self::aggregate) processes `text_delta` updates.
    ///
    /// A later registration replaces an earlier one.
    #[must_use]
    pub fn on_text_delta<F>(mut self, handler: F) -> Self
    where
        F: FnMut(&str) + Send + 'static,
    {
        self.handlers.text_delta = Some(Box::new(handler));
        self
    }

    /// Registers a callback invoked when a tool-use (or server-tool-use)
    /// block receives its `content_block_stop` during
    /// [`aggregate`](Self::aggregate).
    ///
    /// The arguments are the block's `id`, `name`, and `input` as stored at
    /// that point, i.e. after the accumulated input JSON has been applied.
    /// A later registration replaces an earlier one.
    #[must_use]
    pub fn on_tool_use_complete<F>(mut self, handler: F) -> Self
    where
        F: FnMut(&str, &str, &serde_json::Value) + Send + 'static,
    {
        self.handlers.tool_use_complete = Some(Box::new(handler));
        self
    }

    /// Registers a callback invoked with each thinking fragment as
    /// [`aggregate`](Self::aggregate) processes `thinking_delta` updates.
    ///
    /// A later registration replaces an earlier one.
    #[must_use]
    pub fn on_thinking_delta<F>(mut self, handler: F) -> Self
    where
        F: FnMut(&str) + Send + 'static,
    {
        self.handlers.thinking_delta = Some(Box::new(handler));
        self
    }

    /// Registers a callback invoked with the message's usage when
    /// [`aggregate`](Self::aggregate) processes `message_stop`.
    ///
    /// It is not invoked if no `message_start` was seen before that point.
    /// A later registration replaces an earlier one.
    #[must_use]
    pub fn on_message_stop<F>(mut self, handler: F) -> Self
    where
        F: FnMut(&Usage) + Send + 'static,
    {
        self.handlers.message_stop = Some(Box::new(handler));
        self
    }

    /// Registers a callback invoked with the error whenever
    /// [`aggregate`](Self::aggregate) is about to return `Err`.
    ///
    /// The error is still returned to the caller afterwards. A later
    /// registration replaces an earlier one.
    #[must_use]
    pub fn on_error<F>(mut self, handler: F) -> Self
    where
        F: FnMut(&Error) + Send + 'static,
    {
        self.handlers.error = Some(Box::new(handler));
        self
    }

    /// Drives the stream to its end and assembles the final [`Message`],
    /// invoking the registered callbacks along the way.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by the underlying stream, by
    /// processing an event (including a server-sent `error` event), or by
    /// final assembly when no `message_start` was ever received. In every
    /// case the `on_error` callback, if registered, is invoked with the
    /// error before it is returned.
    pub async fn aggregate(self) -> Result<Message> {
        use futures_util::StreamExt;

        let Self {
            mut inner,
            handlers,
        } = self;
        let mut agg = Aggregator::with_handlers(handlers);

        while let Some(item) = inner.next().await {
            // Stream failures and event-processing failures are reported
            // the same way: notify, then stop.
            if let Err(err) = item.and_then(|event| agg.handle(event)) {
                agg.fire_error(&err);
                return Err(err);
            }
        }

        // `finalize` consumes the aggregator, so detach the callbacks first;
        // otherwise a failed finalization could not be reported to
        // `on_error` the way every other error above is.
        let mut handlers = std::mem::take(&mut agg.handlers);
        agg.finalize().inspect_err(|err| {
            if let Some(on_error) = handlers.error.as_mut() {
                on_error(err);
            }
        })
    }
}
723
/// Boxed callback receiving a text fragment.
#[cfg(feature = "streaming")]
type TextDeltaHandler = Box<dyn FnMut(&str) + Send>;
/// Boxed callback receiving a finished tool-use block's id, name, and input.
#[cfg(feature = "streaming")]
type ToolUseCompleteHandler = Box<dyn FnMut(&str, &str, &serde_json::Value) + Send>;
/// Boxed callback receiving a thinking fragment.
#[cfg(feature = "streaming")]
type ThinkingDeltaHandler = Box<dyn FnMut(&str) + Send>;
/// Boxed callback receiving the message usage at `message_stop`.
#[cfg(feature = "streaming")]
type MessageStopHandler = Box<dyn FnMut(&Usage) + Send>;
/// Boxed callback receiving an error before it is returned.
#[cfg(feature = "streaming")]
type ErrorHandler = Box<dyn FnMut(&Error) + Send>;
734
/// The optional callbacks a stream consumer can register.
///
/// Every slot defaults to `None`, meaning nothing is called for that event.
#[cfg(feature = "streaming")]
#[derive(Default)]
struct MessageStreamHandlers {
    /// Called with each text fragment.
    text_delta: Option<TextDeltaHandler>,
    /// Called when a tool-use block is complete.
    tool_use_complete: Option<ToolUseCompleteHandler>,
    /// Called with each thinking fragment.
    thinking_delta: Option<ThinkingDeltaHandler>,
    /// Called when the message stops.
    message_stop: Option<MessageStopHandler>,
    /// Called with an error before it is returned.
    error: Option<ErrorHandler>,
}
745
/// Closures cannot be printed, so each slot is shown as `Some("<fn>")` when
/// a callback is registered and `None` otherwise.
#[cfg(feature = "streaming")]
impl std::fmt::Debug for MessageStreamHandlers {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        /// Placeholder standing in for a registered callback.
        fn marker<T>(slot: &Option<T>) -> Option<&'static str> {
            slot.as_ref().map(|_| "<fn>")
        }

        let mut out = f.debug_struct("MessageStreamHandlers");
        out.field("text_delta", &marker(&self.text_delta));
        out.field("tool_use_complete", &marker(&self.tool_use_complete));
        out.field("thinking_delta", &marker(&self.thinking_delta));
        out.field("message_stop", &marker(&self.message_stop));
        out.field("error", &marker(&self.error));
        out.finish()
    }
}
764
/// Yields the underlying events unchanged; the registered callbacks are not
/// involved when the stream is polled directly.
#[cfg(feature = "streaming")]
impl futures_util::Stream for EventStream {
    type Item = Result<StreamEvent>;

    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        // Both fields are boxed, so the struct can be accessed through the
        // pin directly; the inner stream carries its own pin.
        let this = self.get_mut();
        futures_util::Stream::poll_next(this.inner.as_mut(), cx)
    }
}
776
/// Prints only the type name; neither the boxed stream nor the callbacks
/// have a useful debug representation.
#[cfg(feature = "streaming")]
impl std::fmt::Debug for EventStream {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("EventStream");
        out.finish_non_exhaustive()
    }
}
783
/// Folds a sequence of [`StreamEvent`]s into a single [`Message`].
///
/// Feed events with `handle` and call `finalize` once the stream has ended.
#[cfg(feature = "streaming")]
#[derive(Debug, Default)]
pub struct Aggregator {
    /// The message from `message_start`; `None` until that event arrives.
    message: Option<Message>,
    /// Content blocks received so far, in index order.
    blocks: Vec<ContentBlock>,
    /// Accumulated `input_json_delta` fragments keyed by block index. An
    /// entry is removed when its block receives `content_block_stop`.
    tool_input_buffers: std::collections::HashMap<u32, String>,
    /// Callbacks to invoke while aggregating.
    handlers: MessageStreamHandlers,
}
799
800#[cfg(feature = "streaming")]
801impl Aggregator {
802 fn with_handlers(handlers: MessageStreamHandlers) -> Self {
805 Self {
806 handlers,
807 ..Self::default()
808 }
809 }
810
811 fn fire_error(&mut self, err: &Error) {
814 if let Some(h) = self.handlers.error.as_mut() {
815 h(err);
816 }
817 }
818
819 pub fn handle(&mut self, event: StreamEvent) -> Result<()> {
821 match event {
822 StreamEvent::Known(known) => self.handle_known(known),
823 StreamEvent::Other(value) => {
824 tracing::debug!(?value, "claude-api: ignoring unknown stream event");
825 Ok(())
826 }
827 }
828 }
829
830 fn handle_known(&mut self, event: KnownStreamEvent) -> Result<()> {
831 match event {
832 KnownStreamEvent::MessageStart { message } => {
833 self.message = Some(message);
834 }
835 KnownStreamEvent::ContentBlockStart {
836 index,
837 content_block,
838 } => {
839 if index as usize != self.blocks.len() {
840 return Err(Error::Stream(StreamError::Parse(format!(
841 "out-of-order content_block_start: index {} but {} blocks already received",
842 index,
843 self.blocks.len()
844 ))));
845 }
846 self.blocks.push(content_block);
847 }
848 KnownStreamEvent::ContentBlockDelta { index, delta } => {
849 self.apply_delta(index, delta);
850 }
851 KnownStreamEvent::ContentBlockStop { index } => {
852 if let Some(buf) = self.tool_input_buffers.remove(&index) {
853 self.finalize_tool_input(index, &buf);
854 }
855 if let Some(handler) = self.handlers.tool_use_complete.as_mut()
857 && let Some(ContentBlock::Known(
858 KnownBlock::ToolUse { id, name, input }
859 | KnownBlock::ServerToolUse { id, name, input },
860 )) = self.blocks.get(index as usize)
861 {
862 handler(id, name, input);
863 }
864 }
865 KnownStreamEvent::MessageDelta { delta, usage } => {
866 if let Some(msg) = self.message.as_mut() {
867 if let Some(sr) = delta.stop_reason {
868 msg.stop_reason = Some(sr);
869 }
870 if let Some(ss) = delta.stop_sequence {
871 msg.stop_sequence = Some(ss);
872 }
873 msg.usage = usage;
874 }
875 }
876 KnownStreamEvent::MessageStop => {
877 if let Some(handler) = self.handlers.message_stop.as_mut()
878 && let Some(msg) = self.message.as_ref()
879 {
880 handler(&msg.usage);
881 }
882 }
883 KnownStreamEvent::Ping => {}
884 KnownStreamEvent::Error { error } => {
885 return Err(Error::Stream(StreamError::Server {
886 kind: error.kind,
887 message: error.message,
888 }));
889 }
890 }
891 Ok(())
892 }
893
894 fn apply_delta(&mut self, index: u32, delta: ContentDelta) {
895 let Some(block) = self.blocks.get_mut(index as usize) else {
896 tracing::warn!(index, "claude-api: delta for unknown block index, dropping");
897 return;
898 };
899 match delta {
900 ContentDelta::Known(KnownContentDelta::TextDelta { text }) => {
901 if let ContentBlock::Known(KnownBlock::Text { text: existing, .. }) = block {
902 existing.push_str(&text);
903 }
904 if let Some(handler) = self.handlers.text_delta.as_mut() {
905 handler(&text);
906 }
907 }
908 ContentDelta::Known(KnownContentDelta::InputJsonDelta { partial_json }) => {
909 self.tool_input_buffers
910 .entry(index)
911 .or_default()
912 .push_str(&partial_json);
913 }
914 ContentDelta::Known(KnownContentDelta::ThinkingDelta { thinking }) => {
915 if let ContentBlock::Known(KnownBlock::Thinking {
916 thinking: existing, ..
917 }) = block
918 {
919 existing.push_str(&thinking);
920 }
921 if let Some(handler) = self.handlers.thinking_delta.as_mut() {
922 handler(&thinking);
923 }
924 }
925 ContentDelta::Known(KnownContentDelta::SignatureDelta { signature }) => {
926 if let ContentBlock::Known(KnownBlock::Thinking { signature: sig, .. }) = block {
927 *sig = signature;
928 }
929 }
930 ContentDelta::Known(KnownContentDelta::CitationsDelta { citation }) => {
931 if let ContentBlock::Known(KnownBlock::Text { citations, .. }) = block {
932 citations.get_or_insert_with(Vec::new).push(citation);
933 }
934 }
935 ContentDelta::Other(value) => {
936 tracing::debug!(?value, "claude-api: ignoring unknown content delta");
937 }
938 }
939 }
940
941 fn finalize_tool_input(&mut self, index: u32, buffer: &str) {
942 let Some(block) = self.blocks.get_mut(index as usize) else {
943 return;
944 };
945 let parsed = if buffer.is_empty() {
946 return;
948 } else {
949 serde_json::from_str::<serde_json::Value>(buffer).unwrap_or_else(|e| {
950 tracing::warn!(
951 error = %e,
952 "claude-api: tool_use input failed to parse; storing raw string"
953 );
954 serde_json::Value::String(buffer.to_owned())
955 })
956 };
957 match block {
958 ContentBlock::Known(
959 KnownBlock::ToolUse { input, .. } | KnownBlock::ServerToolUse { input, .. },
960 ) => {
961 *input = parsed;
962 }
963 _ => {
964 tracing::warn!(
965 index,
966 "claude-api: input_json_delta accumulated for non-tool-use block"
967 );
968 }
969 }
970 }
971
972 pub fn finalize(mut self) -> Result<Message> {
975 let mut message = self.message.take().ok_or_else(|| {
976 Error::Stream(StreamError::Parse(
977 "stream ended without a message_start event".into(),
978 ))
979 })?;
980 message.content = self.blocks;
981 Ok(message)
982 }
983}
984
#[cfg(all(test, feature = "streaming"))]
mod aggregator_tests {
    use super::*;
    use crate::error::{ApiErrorKind, ApiErrorPayload};
    use crate::types::{ModelId, Role};
    use pretty_assertions::assert_eq;
    use serde_json::json;

    /// A `message_start` event carrying an empty assistant message.
    fn message_start_event() -> StreamEvent {
        let message = serde_json::from_value(json!({
            "id": "msg_x",
            "type": "message",
            "role": "assistant",
            "content": [],
            "model": "claude-sonnet-4-6",
            "usage": {"input_tokens": 5, "output_tokens": 0}
        }))
        .unwrap();
        StreamEvent::Known(KnownStreamEvent::MessageStart { message })
    }

    /// An aggregator that has already handled `message_start`.
    fn started() -> Aggregator {
        let mut agg = Aggregator::default();
        agg.handle(message_start_event()).unwrap();
        agg
    }

    /// Feeds a typed event and requires it to be accepted.
    #[track_caller]
    fn feed(agg: &mut Aggregator, event: KnownStreamEvent) {
        agg.handle(StreamEvent::Known(event)).unwrap();
    }

    /// A `content_block_delta` event wrapping a typed delta.
    fn block_delta(index: u32, delta: KnownContentDelta) -> KnownStreamEvent {
        KnownStreamEvent::ContentBlockDelta {
            index,
            delta: ContentDelta::Known(delta),
        }
    }

    #[test]
    fn aggregator_reconstructs_text_message() {
        let mut agg = started();
        feed(
            &mut agg,
            KnownStreamEvent::ContentBlockStart {
                index: 0,
                content_block: ContentBlock::text(""),
            },
        );
        for chunk in ["Hello", " world"] {
            feed(
                &mut agg,
                block_delta(0, KnownContentDelta::TextDelta { text: chunk.into() }),
            );
        }
        feed(&mut agg, KnownStreamEvent::ContentBlockStop { index: 0 });
        feed(
            &mut agg,
            KnownStreamEvent::MessageDelta {
                delta: MessageDelta {
                    stop_reason: Some(StopReason::EndTurn),
                    stop_sequence: None,
                },
                usage: Usage {
                    input_tokens: 5,
                    output_tokens: 2,
                    ..Usage::default()
                },
            },
        );
        feed(&mut agg, KnownStreamEvent::MessageStop);

        let msg = agg.finalize().unwrap();
        assert_eq!(msg.id, "msg_x");
        assert_eq!(msg.role, Role::Assistant);
        assert_eq!(msg.model, ModelId::SONNET_4_6);
        assert_eq!(msg.stop_reason, Some(StopReason::EndTurn));
        assert_eq!(msg.usage.output_tokens, 2);
        assert_eq!(msg.content.len(), 1);
        let ContentBlock::Known(KnownBlock::Text { text, .. }) = &msg.content[0] else {
            panic!("expected text block");
        };
        assert_eq!(text, "Hello world");
    }

    #[test]
    fn aggregator_reconstructs_tool_use_input_from_partial_json_deltas() {
        let mut agg = started();
        feed(
            &mut agg,
            KnownStreamEvent::ContentBlockStart {
                index: 0,
                content_block: ContentBlock::Known(KnownBlock::ToolUse {
                    id: "toolu_1".into(),
                    name: "get_weather".into(),
                    input: json!({}),
                }),
            },
        );
        // The chunks only form valid JSON once concatenated.
        for chunk in ["{\"city\":", "\"Paris\"", ",\"unit\":\"C\"}"] {
            feed(
                &mut agg,
                block_delta(
                    0,
                    KnownContentDelta::InputJsonDelta {
                        partial_json: chunk.into(),
                    },
                ),
            );
        }
        feed(&mut agg, KnownStreamEvent::ContentBlockStop { index: 0 });
        feed(&mut agg, KnownStreamEvent::MessageStop);

        let msg = agg.finalize().unwrap();
        let ContentBlock::Known(KnownBlock::ToolUse { input, name, .. }) = &msg.content[0] else {
            panic!("expected ToolUse block");
        };
        assert_eq!(name, "get_weather");
        assert_eq!(input, &json!({"city": "Paris", "unit": "C"}));
    }

    #[test]
    fn aggregator_reconstructs_thinking_block() {
        let mut agg = started();
        feed(
            &mut agg,
            KnownStreamEvent::ContentBlockStart {
                index: 0,
                content_block: ContentBlock::Known(KnownBlock::Thinking {
                    thinking: String::new(),
                    signature: String::new(),
                }),
            },
        );
        for chunk in ["let me ", "think"] {
            feed(
                &mut agg,
                block_delta(
                    0,
                    KnownContentDelta::ThinkingDelta {
                        thinking: chunk.into(),
                    },
                ),
            );
        }
        feed(
            &mut agg,
            block_delta(
                0,
                KnownContentDelta::SignatureDelta {
                    signature: "sig_xyz".into(),
                },
            ),
        );
        feed(&mut agg, KnownStreamEvent::ContentBlockStop { index: 0 });
        feed(&mut agg, KnownStreamEvent::MessageStop);

        let msg = agg.finalize().unwrap();
        let ContentBlock::Known(KnownBlock::Thinking {
            thinking,
            signature,
        }) = &msg.content[0]
        else {
            panic!("expected Thinking block");
        };
        assert_eq!(thinking, "let me think");
        assert_eq!(signature, "sig_xyz");
    }

    #[test]
    fn aggregator_unknown_event_is_ignored() {
        let mut agg = started();
        // An unrecognised event must be accepted without changing state.
        agg.handle(StreamEvent::Other(json!({"type": "future_event"})))
            .unwrap();
        feed(&mut agg, KnownStreamEvent::MessageStop);
        let msg = agg.finalize().unwrap();
        assert!(msg.content.is_empty());
    }

    #[test]
    fn aggregator_unknown_delta_is_ignored() {
        let mut agg = started();
        feed(
            &mut agg,
            KnownStreamEvent::ContentBlockStart {
                index: 0,
                content_block: ContentBlock::text(""),
            },
        );
        // Passing means the unrecognised delta was accepted without error.
        feed(
            &mut agg,
            KnownStreamEvent::ContentBlockDelta {
                index: 0,
                delta: ContentDelta::Other(json!({"type": "future_delta"})),
            },
        );
    }

    #[test]
    fn aggregator_server_error_event_propagates() {
        let mut agg = started();
        let error = ApiErrorPayload {
            kind: ApiErrorKind::OverloadedError,
            message: "boom".into(),
        };
        let err = agg
            .handle(StreamEvent::Known(KnownStreamEvent::Error { error }))
            .unwrap_err();
        let Error::Stream(StreamError::Server { kind, message }) = err else {
            panic!("expected Stream::Server, got {err:?}");
        };
        assert_eq!(kind, ApiErrorKind::OverloadedError);
        assert_eq!(message, "boom");
    }

    #[test]
    fn aggregator_out_of_order_block_start_errors() {
        let mut agg = started();
        // No block 0 has been started yet, so index 1 is out of order.
        let skipped = StreamEvent::Known(KnownStreamEvent::ContentBlockStart {
            index: 1,
            content_block: ContentBlock::text(""),
        });
        let err = agg.handle(skipped).unwrap_err();
        assert!(matches!(err, Error::Stream(StreamError::Parse(_))));
    }

    #[test]
    fn aggregator_finalize_without_message_start_errors() {
        let err = Aggregator::default().finalize().unwrap_err();
        assert!(matches!(err, Error::Stream(StreamError::Parse(_))));
    }
}
1228
#[cfg(all(test, feature = "streaming"))]
mod stream_callback_tests {
    //! Exercises the callback hooks registered on an `EventStream`
    //! (`on_text_delta`, `on_thinking_delta`, `on_tool_use_complete`,
    //! `on_message_stop`, `on_error`) against hand-written event scripts fed
    //! through `EventStream::from_events`.

    use super::*;
    use crate::error::{ApiErrorKind, ApiErrorPayload};
    use pretty_assertions::assert_eq;
    use serde_json::json;
    use std::sync::{Arc, Mutex};

    /// Returns the `message_start` event that opens every script in this module.
    fn message_start_event() -> StreamEvent {
        let message: Message = serde_json::from_value(json!({
            "id": "msg_x",
            "type": "message",
            "role": "assistant",
            "content": [],
            "model": "claude-sonnet-4-6",
            "usage": {"input_tokens": 5, "output_tokens": 0}
        }))
        .unwrap();
        StreamEvent::Known(KnownStreamEvent::MessageStart { message })
    }

    /// Wraps `content_block` in a `content_block_start` event at `index`.
    fn block_start(index: u32, content_block: ContentBlock) -> StreamEvent {
        StreamEvent::Known(KnownStreamEvent::ContentBlockStart {
            index,
            content_block,
        })
    }

    /// Wraps a known delta in a `content_block_delta` event at `index`.
    fn block_delta(index: u32, delta: KnownContentDelta) -> StreamEvent {
        StreamEvent::Known(KnownStreamEvent::ContentBlockDelta {
            index,
            delta: ContentDelta::Known(delta),
        })
    }

    /// Builds the `content_block_stop` event for `index`.
    fn block_stop(index: u32) -> StreamEvent {
        StreamEvent::Known(KnownStreamEvent::ContentBlockStop { index })
    }

    /// Builds the terminal `message_stop` event.
    fn message_stop() -> StreamEvent {
        StreamEvent::Known(KnownStreamEvent::MessageStop)
    }

    /// Turns a list of events into the all-`Ok` item list that
    /// `EventStream::from_events` is fed in these tests.
    fn scripted(events: Vec<StreamEvent>) -> Vec<Result<StreamEvent>> {
        events.into_iter().map(Ok).collect()
    }

    #[tokio::test]
    async fn on_text_delta_fires_for_each_text_chunk() {
        let seen: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
        let recorder = Arc::clone(&seen);

        // One text block streamed as two separate chunks.
        let first_chunk = KnownContentDelta::TextDelta {
            text: "Hello".into(),
        };
        let second_chunk = KnownContentDelta::TextDelta {
            text: " world".into(),
        };
        let script = scripted(vec![
            message_start_event(),
            block_start(0, ContentBlock::text("")),
            block_delta(0, first_chunk),
            block_delta(0, second_chunk),
            block_stop(0),
            message_stop(),
        ]);

        EventStream::from_events(script)
            .on_text_delta(move |chunk| {
                recorder.lock().unwrap().push(chunk.to_string());
            })
            .aggregate()
            .await
            .unwrap();

        assert_eq!(*seen.lock().unwrap(), vec!["Hello", " world"]);
    }

    #[tokio::test]
    async fn on_tool_use_complete_fires_with_parsed_input() {
        let seen: Arc<Mutex<Vec<(String, String, serde_json::Value)>>> =
            Arc::new(Mutex::new(Vec::new()));
        let recorder = Arc::clone(&seen);

        // The block starts with an empty input object; the real arguments
        // arrive afterwards as a JSON fragment.
        let tool_block = ContentBlock::Known(KnownBlock::ToolUse {
            id: "toolu_1".into(),
            name: "get_weather".into(),
            input: json!({}),
        });
        let arguments = KnownContentDelta::InputJsonDelta {
            partial_json: "{\"city\":\"Paris\"}".into(),
        };
        let script = scripted(vec![
            message_start_event(),
            block_start(0, tool_block),
            block_delta(0, arguments),
            block_stop(0),
            message_stop(),
        ]);

        EventStream::from_events(script)
            .on_tool_use_complete(move |id, name, input| {
                let call = (id.to_string(), name.to_string(), input.clone());
                recorder.lock().unwrap().push(call);
            })
            .aggregate()
            .await
            .unwrap();

        let calls = seen.lock().unwrap();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].0, "toolu_1");
        assert_eq!(calls[0].1, "get_weather");
        assert_eq!(calls[0].2, json!({"city": "Paris"}));
    }

    #[tokio::test]
    async fn on_tool_use_complete_fires_for_server_tool_use_blocks() {
        let seen: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
        let recorder = Arc::clone(&seen);

        // Same shape as a client tool call, but carried by a
        // `ServerToolUse` block.
        let server_block = ContentBlock::Known(KnownBlock::ServerToolUse {
            id: "srvu_1".into(),
            name: "web_search".into(),
            input: json!({}),
        });
        let arguments = KnownContentDelta::InputJsonDelta {
            partial_json: "{\"q\":\"rust\"}".into(),
        };
        let script = scripted(vec![
            message_start_event(),
            block_start(0, server_block),
            block_delta(0, arguments),
            block_stop(0),
            message_stop(),
        ]);

        EventStream::from_events(script)
            .on_tool_use_complete(move |id, _, _| {
                recorder.lock().unwrap().push(id.to_string());
            })
            .aggregate()
            .await
            .unwrap();

        assert_eq!(*seen.lock().unwrap(), vec!["srvu_1"]);
    }

    #[tokio::test]
    async fn on_thinking_delta_fires_for_each_thinking_chunk() {
        let seen: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
        let recorder = Arc::clone(&seen);

        // One thinking block streamed as two separate chunks.
        let thinking_block = ContentBlock::Known(KnownBlock::Thinking {
            thinking: String::new(),
            signature: String::new(),
        });
        let first_chunk = KnownContentDelta::ThinkingDelta {
            thinking: "let me ".into(),
        };
        let second_chunk = KnownContentDelta::ThinkingDelta {
            thinking: "think".into(),
        };
        let script = scripted(vec![
            message_start_event(),
            block_start(0, thinking_block),
            block_delta(0, first_chunk),
            block_delta(0, second_chunk),
            block_stop(0),
            message_stop(),
        ]);

        EventStream::from_events(script)
            .on_thinking_delta(move |chunk| {
                recorder.lock().unwrap().push(chunk.to_string());
            })
            .aggregate()
            .await
            .unwrap();

        assert_eq!(*seen.lock().unwrap(), vec!["let me ", "think"]);
    }

    #[tokio::test]
    async fn on_message_stop_fires_once_with_usage() {
        let seen: Arc<Mutex<Vec<Usage>>> = Arc::new(Mutex::new(Vec::new()));
        let recorder = Arc::clone(&seen);

        // The `message_delta` event carries the usage numbers asserted below.
        let greeting = KnownContentDelta::TextDelta { text: "hi".into() };
        let closing_delta = StreamEvent::Known(KnownStreamEvent::MessageDelta {
            delta: MessageDelta {
                stop_reason: Some(StopReason::EndTurn),
                stop_sequence: None,
            },
            usage: Usage {
                input_tokens: 5,
                output_tokens: 7,
                ..Usage::default()
            },
        });
        let script = scripted(vec![
            message_start_event(),
            block_start(0, ContentBlock::text("")),
            block_delta(0, greeting),
            block_stop(0),
            closing_delta,
            message_stop(),
        ]);

        EventStream::from_events(script)
            .on_message_stop(move |usage| {
                recorder.lock().unwrap().push(usage.clone());
            })
            .aggregate()
            .await
            .unwrap();

        let reports = seen.lock().unwrap();
        assert_eq!(reports.len(), 1);
        assert_eq!(reports[0].input_tokens, 5);
        assert_eq!(reports[0].output_tokens, 7);
    }

    #[tokio::test]
    async fn on_error_fires_before_propagating_server_error() {
        let fired = Arc::new(Mutex::new(0u32));
        let counter = Arc::clone(&fired);

        // An in-band `error` event immediately after the opening event.
        let overloaded = StreamEvent::Known(KnownStreamEvent::Error {
            error: ApiErrorPayload {
                kind: ApiErrorKind::OverloadedError,
                message: "boom".into(),
            },
        });
        let script = scripted(vec![message_start_event(), overloaded]);

        let failure = EventStream::from_events(script)
            .on_error(move |_| {
                *counter.lock().unwrap() += 1;
            })
            .aggregate()
            .await
            .unwrap_err();

        assert!(matches!(
            failure,
            Error::Stream(StreamError::Server {
                kind: ApiErrorKind::OverloadedError,
                ..
            })
        ));
        assert_eq!(
            *fired.lock().unwrap(),
            1,
            "handler should fire exactly once"
        );
    }

    #[tokio::test]
    async fn on_error_fires_for_transport_error() {
        let fired = Arc::new(Mutex::new(0u32));
        let counter = Arc::clone(&fired);

        // Here the failure is an `Err` item of the stream itself rather than
        // an `error` event.
        let mut script: Vec<Result<StreamEvent>> = scripted(vec![message_start_event()]);
        script.push(Err(Error::Stream(StreamError::Parse(
            "decode failed".into(),
        ))));

        let failure = EventStream::from_events(script)
            .on_error(move |_| {
                *counter.lock().unwrap() += 1;
            })
            .aggregate()
            .await
            .unwrap_err();

        assert!(matches!(failure, Error::Stream(StreamError::Parse(_))));
        assert_eq!(*fired.lock().unwrap(), 1);
    }

    #[tokio::test]
    async fn raw_stream_iteration_does_not_fire_callbacks() {
        use futures_util::StreamExt;

        let fired = Arc::new(Mutex::new(0u32));
        let counter = Arc::clone(&fired);

        let greeting = KnownContentDelta::TextDelta { text: "hi".into() };
        let script = scripted(vec![
            message_start_event(),
            block_start(0, ContentBlock::text("")),
            block_delta(0, greeting),
            block_stop(0),
            message_stop(),
        ]);

        let mut stream = EventStream::from_events(script).on_text_delta(move |_| {
            *counter.lock().unwrap() += 1;
        });

        // Drain the stream item by item, without going through `aggregate`.
        while stream.next().await.is_some() {}

        assert_eq!(*fired.lock().unwrap(), 0);
    }
}