1use serde::{Deserialize, Serialize};
48
49use crate::error::ApiErrorPayload;
50use crate::forward_compat::dispatch_known_or_other;
51use crate::messages::content::ContentBlock;
52use crate::messages::response::Message;
53use crate::types::{StopReason, Usage};
54
55#[cfg(feature = "streaming")]
56use crate::error::{Error, Result, StreamError};
57#[cfg(feature = "streaming")]
58use crate::messages::content::KnownBlock;
59
60#[allow(clippy::large_enum_variant)]
69#[derive(Debug, Clone, PartialEq)]
70pub enum StreamEvent {
71 Known(KnownStreamEvent),
73 Other(serde_json::Value),
75}
76
77#[allow(clippy::large_enum_variant)]
79#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
80#[serde(tag = "type", rename_all = "snake_case")]
81#[non_exhaustive]
82pub enum KnownStreamEvent {
83 MessageStart {
86 message: Message,
88 },
89 ContentBlockStart {
91 index: u32,
93 content_block: ContentBlock,
95 },
96 ContentBlockDelta {
98 index: u32,
100 delta: ContentDelta,
102 },
103 ContentBlockStop {
105 index: u32,
107 },
108 MessageDelta {
110 delta: MessageDelta,
112 usage: Usage,
114 },
115 MessageStop,
117 Ping,
119 Error {
121 error: ApiErrorPayload,
123 },
124}
125
/// Every `type` tag decoded into [`KnownStreamEvent`]; all other tags become
/// [`StreamEvent::Other`].
///
/// Must list exactly the snake_case names of the `KnownStreamEvent` variants.
const KNOWN_EVENT_TAGS: &'static [&'static str] = &[
    // message lifecycle start
    "message_start",
    // per-block events
    "content_block_start",
    "content_block_delta",
    "content_block_stop",
    // message lifecycle end
    "message_delta",
    "message_stop",
    // out-of-band events
    "ping",
    "error",
];
137
138impl Serialize for StreamEvent {
139 fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
140 match self {
141 StreamEvent::Known(k) => k.serialize(s),
142 StreamEvent::Other(v) => v.serialize(s),
143 }
144 }
145}
146
147impl<'de> Deserialize<'de> for StreamEvent {
148 fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
149 let raw = serde_json::Value::deserialize(d)?;
150 dispatch_known_or_other(
151 raw,
152 KNOWN_EVENT_TAGS,
153 StreamEvent::Known,
154 StreamEvent::Other,
155 )
156 .map_err(serde::de::Error::custom)
157 }
158}
159
160impl From<KnownStreamEvent> for StreamEvent {
161 fn from(k: KnownStreamEvent) -> Self {
162 StreamEvent::Known(k)
163 }
164}
165
166impl StreamEvent {
167 pub fn known(&self) -> Option<&KnownStreamEvent> {
169 match self {
170 Self::Known(k) => Some(k),
171 Self::Other(_) => None,
172 }
173 }
174
175 pub fn other(&self) -> Option<&serde_json::Value> {
177 match self {
178 Self::Other(v) => Some(v),
179 Self::Known(_) => None,
180 }
181 }
182
183 pub fn type_tag(&self) -> Option<&str> {
185 match self {
186 Self::Known(k) => Some(known_event_tag(k)),
187 Self::Other(v) => v.get("type").and_then(serde_json::Value::as_str),
188 }
189 }
190}
191
192fn known_event_tag(k: &KnownStreamEvent) -> &'static str {
193 match k {
194 KnownStreamEvent::MessageStart { .. } => "message_start",
195 KnownStreamEvent::ContentBlockStart { .. } => "content_block_start",
196 KnownStreamEvent::ContentBlockDelta { .. } => "content_block_delta",
197 KnownStreamEvent::ContentBlockStop { .. } => "content_block_stop",
198 KnownStreamEvent::MessageDelta { .. } => "message_delta",
199 KnownStreamEvent::MessageStop => "message_stop",
200 KnownStreamEvent::Ping => "ping",
201 KnownStreamEvent::Error { .. } => "error",
202 }
203}
204
205#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
208#[non_exhaustive]
209pub struct MessageDelta {
210 #[serde(default, skip_serializing_if = "Option::is_none")]
212 pub stop_reason: Option<StopReason>,
213 #[serde(default, skip_serializing_if = "Option::is_none")]
215 pub stop_sequence: Option<String>,
216}
217
218#[derive(Debug, Clone, PartialEq)]
222pub enum ContentDelta {
223 Known(KnownContentDelta),
225 Other(serde_json::Value),
227}
228
229#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
231#[serde(tag = "type", rename_all = "snake_case")]
232#[non_exhaustive]
233pub enum KnownContentDelta {
234 TextDelta {
236 text: String,
238 },
239 InputJsonDelta {
241 partial_json: String,
243 },
244 ThinkingDelta {
246 thinking: String,
248 },
249 SignatureDelta {
251 signature: String,
253 },
254 CitationsDelta {
256 citation: crate::messages::citation::Citation,
258 },
259}
260
/// Every `type` tag decoded into [`KnownContentDelta`]; all other tags become
/// [`ContentDelta::Other`].
///
/// Must list exactly the snake_case names of the `KnownContentDelta` variants.
const KNOWN_DELTA_TAGS: &'static [&'static str] = &[
    // text-like payloads
    "text_delta",
    "input_json_delta",
    "thinking_delta",
    // metadata attached to existing blocks
    "signature_delta",
    "citations_delta",
];
268
269impl Serialize for ContentDelta {
270 fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
271 match self {
272 ContentDelta::Known(k) => k.serialize(s),
273 ContentDelta::Other(v) => v.serialize(s),
274 }
275 }
276}
277
278impl<'de> Deserialize<'de> for ContentDelta {
279 fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
280 let raw = serde_json::Value::deserialize(d)?;
281 dispatch_known_or_other(
282 raw,
283 KNOWN_DELTA_TAGS,
284 ContentDelta::Known,
285 ContentDelta::Other,
286 )
287 .map_err(serde::de::Error::custom)
288 }
289}
290
291impl From<KnownContentDelta> for ContentDelta {
292 fn from(k: KnownContentDelta) -> Self {
293 ContentDelta::Known(k)
294 }
295}
296
297impl ContentDelta {
298 pub fn known(&self) -> Option<&KnownContentDelta> {
300 match self {
301 Self::Known(k) => Some(k),
302 Self::Other(_) => None,
303 }
304 }
305
306 pub fn other(&self) -> Option<&serde_json::Value> {
308 match self {
309 Self::Other(v) => Some(v),
310 Self::Known(_) => None,
311 }
312 }
313
314 pub fn type_tag(&self) -> Option<&str> {
316 match self {
317 Self::Known(k) => Some(match k {
318 KnownContentDelta::TextDelta { .. } => "text_delta",
319 KnownContentDelta::InputJsonDelta { .. } => "input_json_delta",
320 KnownContentDelta::ThinkingDelta { .. } => "thinking_delta",
321 KnownContentDelta::SignatureDelta { .. } => "signature_delta",
322 KnownContentDelta::CitationsDelta { .. } => "citations_delta",
323 }),
324 Self::Other(v) => v.get("type").and_then(serde_json::Value::as_str),
325 }
326 }
327}
328
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::ApiErrorKind;
    use pretty_assertions::assert_eq;
    use serde_json::json;

    /// Serializes `event`, checks the wire form against `expected`, then
    /// deserializes it back and checks it equals `event`.
    fn round_trip_event(event: &StreamEvent, expected: &serde_json::Value) {
        let v = serde_json::to_value(event).expect("serialize");
        assert_eq!(&v, expected, "wire form mismatch");
        let parsed: StreamEvent = serde_json::from_value(v).expect("deserialize");
        assert_eq!(&parsed, event, "round-trip mismatch");
    }

    /// Content-delta counterpart of `round_trip_event`.
    fn round_trip_delta(delta: &ContentDelta, expected: &serde_json::Value) {
        let v = serde_json::to_value(delta).expect("serialize");
        assert_eq!(&v, expected, "wire form mismatch");
        let parsed: ContentDelta = serde_json::from_value(v).expect("deserialize");
        assert_eq!(&parsed, delta, "round-trip mismatch");
    }

    #[test]
    fn message_stop_round_trips() {
        round_trip_event(
            &StreamEvent::Known(KnownStreamEvent::MessageStop),
            &json!({"type": "message_stop"}),
        );
    }

    #[test]
    fn ping_round_trips() {
        round_trip_event(
            &StreamEvent::Known(KnownStreamEvent::Ping),
            &json!({"type": "ping"}),
        );
    }

    #[test]
    fn content_block_start_round_trips() {
        let ev = StreamEvent::Known(KnownStreamEvent::ContentBlockStart {
            index: 0,
            content_block: ContentBlock::text(""),
        });
        round_trip_event(
            &ev,
            &json!({
                "type": "content_block_start",
                "index": 0,
                "content_block": {"type": "text", "text": ""}
            }),
        );
    }

    #[test]
    fn content_block_delta_round_trips() {
        let ev = StreamEvent::Known(KnownStreamEvent::ContentBlockDelta {
            index: 1,
            delta: ContentDelta::Known(KnownContentDelta::TextDelta {
                text: "Hello".into(),
            }),
        });
        round_trip_event(
            &ev,
            &json!({
                "type": "content_block_delta",
                "index": 1,
                "delta": {"type": "text_delta", "text": "Hello"}
            }),
        );
    }

    #[test]
    fn content_block_stop_round_trips() {
        let ev = StreamEvent::Known(KnownStreamEvent::ContentBlockStop { index: 2 });
        round_trip_event(&ev, &json!({"type": "content_block_stop", "index": 2}));
    }

    #[test]
    fn message_delta_round_trips() {
        let ev = StreamEvent::Known(KnownStreamEvent::MessageDelta {
            delta: MessageDelta {
                stop_reason: Some(StopReason::EndTurn),
                stop_sequence: None,
            },
            usage: Usage {
                input_tokens: 5,
                output_tokens: 10,
                ..Usage::default()
            },
        });
        round_trip_event(
            &ev,
            &json!({
                "type": "message_delta",
                "delta": {"stop_reason": "end_turn"},
                "usage": {"input_tokens": 5, "output_tokens": 10}
            }),
        );
    }

    #[test]
    fn error_event_round_trips() {
        let ev = StreamEvent::Known(KnownStreamEvent::Error {
            error: ApiErrorPayload {
                kind: ApiErrorKind::OverloadedError,
                message: "try again".into(),
            },
        });
        round_trip_event(
            &ev,
            &json!({
                "type": "error",
                "error": {"type": "overloaded_error", "message": "try again"}
            }),
        );
    }

    #[test]
    fn unknown_event_type_falls_back_to_other_preserving_json() {
        let raw = json!({
            "type": "future_event",
            "payload": {"x": 1, "y": [2, 3]}
        });
        let ev: StreamEvent = serde_json::from_value(raw.clone()).expect("deserialize");
        assert!(ev.other().is_some());
        assert_eq!(ev.type_tag(), Some("future_event"));

        let reserialized = serde_json::to_value(&ev).expect("serialize");
        assert_eq!(reserialized, raw, "Other must round-trip byte-for-byte");
    }

    #[test]
    fn malformed_known_event_is_an_error() {
        let raw = json!({"type": "content_block_stop", "index": "nope"});
        let result: Result<StreamEvent, _> = serde_json::from_value(raw);
        assert!(
            result.is_err(),
            "malformed known event must error, not silently fall through to Other"
        );
    }

    /// Each known event's tag is spelled in three places: the serde `rename_all`,
    /// `KNOWN_EVENT_TAGS`, and `known_event_tag`. If they drift apart, a known event
    /// is silently routed to `Other`. This test checks they agree for every variant.
    #[test]
    fn known_event_tags_agree_with_serde_and_type_tag() {
        let message: Message = serde_json::from_value(json!({
            "id": "msg_X",
            "type": "message",
            "role": "assistant",
            "content": [],
            "model": "claude-sonnet-4-6",
            "usage": {"input_tokens": 1, "output_tokens": 0}
        }))
        .expect("message");
        let samples = vec![
            KnownStreamEvent::MessageStart { message },
            KnownStreamEvent::ContentBlockStart {
                index: 0,
                content_block: ContentBlock::text(""),
            },
            KnownStreamEvent::ContentBlockDelta {
                index: 0,
                delta: ContentDelta::Known(KnownContentDelta::TextDelta {
                    text: String::new(),
                }),
            },
            KnownStreamEvent::ContentBlockStop { index: 0 },
            KnownStreamEvent::MessageDelta {
                delta: MessageDelta::default(),
                usage: Usage::default(),
            },
            KnownStreamEvent::MessageStop,
            KnownStreamEvent::Ping,
            KnownStreamEvent::Error {
                error: ApiErrorPayload {
                    kind: ApiErrorKind::OverloadedError,
                    message: String::new(),
                },
            },
        ];
        assert_eq!(samples.len(), KNOWN_EVENT_TAGS.len());
        for ev in &samples {
            let wire = serde_json::to_value(ev).expect("serialize");
            let tag = wire["type"].as_str().expect("type tag");
            assert_eq!(tag, known_event_tag(ev));
            assert!(
                KNOWN_EVENT_TAGS.contains(&tag),
                "{tag} missing from KNOWN_EVENT_TAGS"
            );
        }
    }

    #[test]
    fn text_delta_round_trips() {
        round_trip_delta(
            &ContentDelta::Known(KnownContentDelta::TextDelta { text: "hi".into() }),
            &json!({"type": "text_delta", "text": "hi"}),
        );
    }

    #[test]
    fn input_json_delta_round_trips() {
        round_trip_delta(
            &ContentDelta::Known(KnownContentDelta::InputJsonDelta {
                partial_json: r#"{"city":"P"#.into(),
            }),
            &json!({"type": "input_json_delta", "partial_json": "{\"city\":\"P"}),
        );
    }

    #[test]
    fn thinking_delta_round_trips() {
        round_trip_delta(
            &ContentDelta::Known(KnownContentDelta::ThinkingDelta {
                thinking: " more thinking".into(),
            }),
            &json!({"type": "thinking_delta", "thinking": " more thinking"}),
        );
    }

    #[test]
    fn signature_delta_round_trips() {
        round_trip_delta(
            &ContentDelta::Known(KnownContentDelta::SignatureDelta {
                signature: "sig123".into(),
            }),
            &json!({"type": "signature_delta", "signature": "sig123"}),
        );
    }

    #[test]
    fn citations_delta_round_trips() {
        use crate::messages::citation::{Citation, KnownCitation};
        round_trip_delta(
            &ContentDelta::Known(KnownContentDelta::CitationsDelta {
                citation: Citation::Known(KnownCitation::CharLocation {
                    document_index: 0,
                    document_title: Some("Doc".into()),
                    cited_text: "hello".into(),
                    start_char_index: 0,
                    end_char_index: 5,
                }),
            }),
            &json!({
                "type": "citations_delta",
                "citation": {
                    "type": "char_location",
                    "document_index": 0,
                    "document_title": "Doc",
                    "cited_text": "hello",
                    "start_char_index": 0,
                    "end_char_index": 5
                }
            }),
        );
    }

    #[test]
    fn unknown_delta_type_falls_back_to_other_preserving_json() {
        let raw = json!({"type": "future_delta", "stuff": [1, 2]});
        let d: ContentDelta = serde_json::from_value(raw.clone()).expect("deserialize");
        assert!(d.other().is_some());
        assert_eq!(d.type_tag(), Some("future_delta"));
        let reserialized = serde_json::to_value(&d).expect("serialize");
        assert_eq!(reserialized, raw);
    }

    #[test]
    fn malformed_known_delta_is_an_error() {
        let raw = json!({"type": "text_delta", "text": 42});
        let result: Result<ContentDelta, _> = serde_json::from_value(raw);
        assert!(result.is_err());
    }

    /// Delta counterpart of `known_event_tags_agree_with_serde_and_type_tag`. It
    /// checks the serde tag, `KNOWN_DELTA_TAGS`, and `ContentDelta::type_tag`.
    #[test]
    fn known_delta_tags_agree_with_serde_and_type_tag() {
        use crate::messages::citation::{Citation, KnownCitation};
        let samples = vec![
            KnownContentDelta::TextDelta {
                text: String::new(),
            },
            KnownContentDelta::InputJsonDelta {
                partial_json: String::new(),
            },
            KnownContentDelta::ThinkingDelta {
                thinking: String::new(),
            },
            KnownContentDelta::SignatureDelta {
                signature: String::new(),
            },
            KnownContentDelta::CitationsDelta {
                citation: Citation::Known(KnownCitation::CharLocation {
                    document_index: 0,
                    document_title: None,
                    cited_text: String::new(),
                    start_char_index: 0,
                    end_char_index: 0,
                }),
            },
        ];
        assert_eq!(samples.len(), KNOWN_DELTA_TAGS.len());
        for d in samples {
            let wire = serde_json::to_value(&d).expect("serialize");
            let tag = wire["type"].as_str().expect("type tag").to_owned();
            let delta = ContentDelta::Known(d);
            assert_eq!(delta.type_tag(), Some(tag.as_str()));
            assert!(
                KNOWN_DELTA_TAGS.contains(&tag.as_str()),
                "{tag} missing from KNOWN_DELTA_TAGS"
            );
        }
    }

    #[test]
    fn golden_sequence_decodes_end_to_end() {
        let events = vec![
            json!({
                "type": "message_start",
                "message": {
                    "id": "msg_X",
                    "type": "message",
                    "role": "assistant",
                    "content": [],
                    "model": "claude-sonnet-4-6",
                    "usage": {"input_tokens": 10, "output_tokens": 0}
                }
            }),
            json!({
                "type": "content_block_start",
                "index": 0,
                "content_block": {"type": "text", "text": ""}
            }),
            json!({
                "type": "content_block_delta",
                "index": 0,
                "delta": {"type": "text_delta", "text": "Hello"}
            }),
            json!({
                "type": "content_block_delta",
                "index": 0,
                "delta": {"type": "text_delta", "text": " world"}
            }),
            json!({"type": "content_block_stop", "index": 0}),
            json!({
                "type": "message_delta",
                "delta": {"stop_reason": "end_turn"},
                "usage": {"input_tokens": 10, "output_tokens": 2}
            }),
            json!({"type": "message_stop"}),
        ];

        let parsed: Vec<StreamEvent> = events
            .into_iter()
            .map(|v| serde_json::from_value(v).expect("decode"))
            .collect();

        assert_eq!(parsed.len(), 7);
        assert_eq!(parsed[0].type_tag(), Some("message_start"));
        assert_eq!(parsed[6].type_tag(), Some("message_stop"));

        match &parsed[2] {
            StreamEvent::Known(KnownStreamEvent::ContentBlockDelta { delta, .. }) => match delta {
                ContentDelta::Known(KnownContentDelta::TextDelta { text }) => {
                    assert_eq!(text, "Hello");
                }
                _ => panic!("expected TextDelta"),
            },
            _ => panic!("expected ContentBlockDelta"),
        }
    }
}
620
#[cfg(feature = "streaming")]
/// A stream of decoded events from a streaming Messages request.
///
/// Poll it directly (it implements `futures_util::Stream`) to see each
/// [`StreamEvent`], or call [`EventStream::aggregate`] to fold the events into a
/// final [`Message`]. Callbacks registered with the `on_*` builders are invoked
/// only by `aggregate`. Polling the stream directly does not fire them.
#[cfg_attr(docsrs, doc(cfg(feature = "streaming")))]
pub struct EventStream {
    /// Source of `Result<StreamEvent>` items.
    inner: futures_util::stream::BoxStream<'static, Result<StreamEvent>>,
    /// Callbacks consulted while aggregating.
    handlers: MessageStreamHandlers,
}
644
#[cfg(feature = "streaming")]
impl EventStream {
    /// Builds an event stream by decoding `response` through
    /// `crate::sse::into_typed_stream`, with no callbacks registered.
    pub(crate) fn from_response(response: reqwest::Response) -> Self {
        use futures_util::StreamExt;
        let inner = crate::sse::into_typed_stream::<StreamEvent>(response).boxed();
        Self {
            inner,
            handlers: MessageStreamHandlers::default(),
        }
    }

    /// Test helper that replays a fixed list of stream items.
    #[cfg(test)]
    fn from_events(events: Vec<Result<StreamEvent>>) -> Self {
        use futures_util::StreamExt;
        let inner = futures_util::stream::iter(events).boxed();
        Self {
            inner,
            handlers: MessageStreamHandlers::default(),
        }
    }

    /// Registers a callback that [`EventStream::aggregate`] calls with every
    /// `text_delta` fragment. Replaces any earlier text-delta callback.
    #[must_use]
    pub fn on_text_delta<F>(mut self, handler: F) -> Self
    where
        F: FnMut(&str) + Send + 'static,
    {
        self.handlers.text_delta = Some(Box::new(handler));
        self
    }

    /// Registers a callback that [`EventStream::aggregate`] calls with
    /// `(id, name, input)` when a `tool_use` or `server_tool_use` block stops.
    /// The `input` has already been parsed from the accumulated JSON fragments.
    /// Replaces any earlier tool-use callback.
    #[must_use]
    pub fn on_tool_use_complete<F>(mut self, handler: F) -> Self
    where
        F: FnMut(&str, &str, &serde_json::Value) + Send + 'static,
    {
        self.handlers.tool_use_complete = Some(Box::new(handler));
        self
    }

    /// Registers a callback that [`EventStream::aggregate`] calls with every
    /// `thinking_delta` fragment. Replaces any earlier thinking-delta callback.
    #[must_use]
    pub fn on_thinking_delta<F>(mut self, handler: F) -> Self
    where
        F: FnMut(&str) + Send + 'static,
    {
        self.handlers.thinking_delta = Some(Box::new(handler));
        self
    }

    /// Registers a callback that [`EventStream::aggregate`] calls with the message's
    /// usage on `message_stop`. It is not called if no `message_start` was seen.
    /// Replaces any earlier message-stop callback.
    #[must_use]
    pub fn on_message_stop<F>(mut self, handler: F) -> Self
    where
        F: FnMut(&Usage) + Send + 'static,
    {
        self.handlers.message_stop = Some(Box::new(handler));
        self
    }

    /// Registers a callback that [`EventStream::aggregate`] calls with the error
    /// that is about to be returned. Replaces any earlier error callback.
    #[must_use]
    pub fn on_error<F>(mut self, handler: F) -> Self
    where
        F: FnMut(&Error) + Send + 'static,
    {
        self.handlers.error = Some(Box::new(handler));
        self
    }

    /// Consumes the stream and rebuilds the final [`Message`], firing registered
    /// callbacks along the way.
    ///
    /// # Errors
    ///
    /// Stops at the first `Err` item from the stream or the first aggregation
    /// error, calls the `on_error` callback, and returns that error. Also fails if
    /// the stream ended without a `message_start` event.
    pub async fn aggregate(self) -> Result<Message> {
        use futures_util::StreamExt;
        let Self {
            mut inner,
            handlers,
        } = self;
        let mut agg = Aggregator::with_handlers(handlers);
        while let Some(item) = inner.next().await {
            // Stream errors and aggregation errors share the same reporting path.
            if let Err(err) = item.and_then(|ev| agg.handle(ev)) {
                agg.fire_error(&err);
                return Err(err);
            }
        }
        agg.finalize()
    }
}
758
/// Callback receiving each text-delta fragment.
#[cfg(feature = "streaming")]
type TextDeltaHandler = Box<dyn FnMut(&str) + Send>;
/// Callback receiving `(id, name, input)` for a completed tool-use block.
#[cfg(feature = "streaming")]
type ToolUseCompleteHandler = Box<dyn FnMut(&str, &str, &serde_json::Value) + Send>;
/// Callback receiving each thinking-delta fragment.
#[cfg(feature = "streaming")]
type ThinkingDeltaHandler = Box<dyn FnMut(&str) + Send>;
/// Callback receiving the message usage at `message_stop`.
#[cfg(feature = "streaming")]
type MessageStopHandler = Box<dyn FnMut(&Usage) + Send>;
/// Callback receiving an error just before aggregation returns it.
#[cfg(feature = "streaming")]
type ErrorHandler = Box<dyn FnMut(&Error) + Send>;
769
/// The optional user callbacks attached to an [`EventStream`]; every slot starts
/// empty.
#[cfg(feature = "streaming")]
#[derive(Default)]
struct MessageStreamHandlers {
    /// Set by `EventStream::on_text_delta`.
    text_delta: Option<TextDeltaHandler>,
    /// Set by `EventStream::on_tool_use_complete`.
    tool_use_complete: Option<ToolUseCompleteHandler>,
    /// Set by `EventStream::on_thinking_delta`.
    thinking_delta: Option<ThinkingDeltaHandler>,
    /// Set by `EventStream::on_message_stop`.
    message_stop: Option<MessageStopHandler>,
    /// Set by `EventStream::on_error`.
    error: Option<ErrorHandler>,
}
780
#[cfg(feature = "streaming")]
impl std::fmt::Debug for MessageStreamHandlers {
    /// Closures cannot be printed, so each slot shows as `Some("<fn>")` or `None`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Registered callback -> Some("<fn>"), empty slot -> None.
        fn slot<T: ?Sized>(handler: &Option<Box<T>>) -> Option<&'static str> {
            handler.as_ref().map(|_| "<fn>")
        }

        f.debug_struct("MessageStreamHandlers")
            .field("text_delta", &slot(&self.text_delta))
            .field("tool_use_complete", &slot(&self.tool_use_complete))
            .field("thinking_delta", &slot(&self.thinking_delta))
            .field("message_stop", &slot(&self.message_stop))
            .field("error", &slot(&self.error))
            .finish()
    }
}
799
#[cfg(feature = "streaming")]
impl futures_util::Stream for EventStream {
    type Item = Result<StreamEvent>;

    /// Yields the next decoded event straight from the inner stream. Registered
    /// callbacks are not invoked on this path.
    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        // Both fields are boxes (one already pinned), so `EventStream` is `Unpin`
        // and a plain mutable reference to it is allowed.
        let this = self.get_mut();
        this.inner.as_mut().poll_next(cx)
    }
}
811
#[cfg(feature = "streaming")]
impl std::fmt::Debug for EventStream {
    /// Prints only the type name. The boxed stream and callbacks are opaque.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("EventStream");
        out.finish_non_exhaustive()
    }
}
818
/// Rebuilds a complete [`Message`] from a sequence of [`StreamEvent`]s.
///
/// Feed events in order with [`Aggregator::handle`], then call
/// [`Aggregator::finalize`] once the stream is exhausted.
#[cfg(feature = "streaming")]
#[derive(Debug, Default)]
pub struct Aggregator {
    /// Envelope from `message_start`; `None` until that event arrives.
    message: Option<Message>,
    /// Content blocks in index order, updated in place by deltas.
    blocks: Vec<ContentBlock>,
    /// Concatenated `input_json_delta` fragments per block index, parsed when the
    /// block stops.
    tool_input_buffers: std::collections::HashMap<u32, String>,
    /// Optional user callbacks.
    handlers: MessageStreamHandlers,
}
834
835#[cfg(feature = "streaming")]
836impl Aggregator {
837 fn with_handlers(handlers: MessageStreamHandlers) -> Self {
840 Self {
841 handlers,
842 ..Self::default()
843 }
844 }
845
846 fn fire_error(&mut self, err: &Error) {
849 if let Some(h) = self.handlers.error.as_mut() {
850 h(err);
851 }
852 }
853
854 pub fn handle(&mut self, event: StreamEvent) -> Result<()> {
856 match event {
857 StreamEvent::Known(known) => self.handle_known(known),
858 StreamEvent::Other(value) => {
859 tracing::debug!(?value, "claude-api: ignoring unknown stream event");
860 Ok(())
861 }
862 }
863 }
864
865 fn handle_known(&mut self, event: KnownStreamEvent) -> Result<()> {
866 match event {
867 KnownStreamEvent::MessageStart { message } => {
868 self.message = Some(message);
869 }
870 KnownStreamEvent::ContentBlockStart {
871 index,
872 content_block,
873 } => {
874 if index as usize != self.blocks.len() {
875 return Err(Error::Stream(StreamError::Parse(format!(
876 "out-of-order content_block_start: index {} but {} blocks already received",
877 index,
878 self.blocks.len()
879 ))));
880 }
881 self.blocks.push(content_block);
882 }
883 KnownStreamEvent::ContentBlockDelta { index, delta } => {
884 self.apply_delta(index, delta);
885 }
886 KnownStreamEvent::ContentBlockStop { index } => {
887 if let Some(buf) = self.tool_input_buffers.remove(&index) {
888 self.finalize_tool_input(index, &buf);
889 }
890 if let Some(handler) = self.handlers.tool_use_complete.as_mut()
892 && let Some(ContentBlock::Known(
893 KnownBlock::ToolUse { id, name, input }
894 | KnownBlock::ServerToolUse { id, name, input },
895 )) = self.blocks.get(index as usize)
896 {
897 handler(id, name, input);
898 }
899 }
900 KnownStreamEvent::MessageDelta { delta, usage } => {
901 if let Some(msg) = self.message.as_mut() {
902 if let Some(sr) = delta.stop_reason {
903 msg.stop_reason = Some(sr);
904 }
905 if let Some(ss) = delta.stop_sequence {
906 msg.stop_sequence = Some(ss);
907 }
908 msg.usage = usage;
909 }
910 }
911 KnownStreamEvent::MessageStop => {
912 if let Some(handler) = self.handlers.message_stop.as_mut()
913 && let Some(msg) = self.message.as_ref()
914 {
915 handler(&msg.usage);
916 }
917 }
918 KnownStreamEvent::Ping => {}
919 KnownStreamEvent::Error { error } => {
920 return Err(Error::Stream(StreamError::Server {
921 kind: error.kind,
922 message: error.message,
923 }));
924 }
925 }
926 Ok(())
927 }
928
929 fn apply_delta(&mut self, index: u32, delta: ContentDelta) {
930 let Some(block) = self.blocks.get_mut(index as usize) else {
931 tracing::warn!(index, "claude-api: delta for unknown block index, dropping");
932 return;
933 };
934 match delta {
935 ContentDelta::Known(KnownContentDelta::TextDelta { text }) => {
936 if let ContentBlock::Known(KnownBlock::Text { text: existing, .. }) = block {
937 existing.push_str(&text);
938 }
939 if let Some(handler) = self.handlers.text_delta.as_mut() {
940 handler(&text);
941 }
942 }
943 ContentDelta::Known(KnownContentDelta::InputJsonDelta { partial_json }) => {
944 self.tool_input_buffers
945 .entry(index)
946 .or_default()
947 .push_str(&partial_json);
948 }
949 ContentDelta::Known(KnownContentDelta::ThinkingDelta { thinking }) => {
950 if let ContentBlock::Known(KnownBlock::Thinking {
951 thinking: existing, ..
952 }) = block
953 {
954 existing.push_str(&thinking);
955 }
956 if let Some(handler) = self.handlers.thinking_delta.as_mut() {
957 handler(&thinking);
958 }
959 }
960 ContentDelta::Known(KnownContentDelta::SignatureDelta { signature }) => {
961 if let ContentBlock::Known(KnownBlock::Thinking { signature: sig, .. }) = block {
962 *sig = signature;
963 }
964 }
965 ContentDelta::Known(KnownContentDelta::CitationsDelta { citation }) => {
966 if let ContentBlock::Known(KnownBlock::Text { citations, .. }) = block {
967 citations.get_or_insert_with(Vec::new).push(citation);
968 }
969 }
970 ContentDelta::Other(value) => {
971 tracing::debug!(?value, "claude-api: ignoring unknown content delta");
972 }
973 }
974 }
975
976 fn finalize_tool_input(&mut self, index: u32, buffer: &str) {
977 let Some(block) = self.blocks.get_mut(index as usize) else {
978 return;
979 };
980 let parsed = if buffer.is_empty() {
981 return;
983 } else {
984 serde_json::from_str::<serde_json::Value>(buffer).unwrap_or_else(|e| {
985 tracing::warn!(
986 error = %e,
987 "claude-api: tool_use input failed to parse; storing raw string"
988 );
989 serde_json::Value::String(buffer.to_owned())
990 })
991 };
992 match block {
993 ContentBlock::Known(
994 KnownBlock::ToolUse { input, .. } | KnownBlock::ServerToolUse { input, .. },
995 ) => {
996 *input = parsed;
997 }
998 _ => {
999 tracing::warn!(
1000 index,
1001 "claude-api: input_json_delta accumulated for non-tool-use block"
1002 );
1003 }
1004 }
1005 }
1006
1007 pub fn finalize(mut self) -> Result<Message> {
1010 let mut message = self.message.take().ok_or_else(|| {
1011 Error::Stream(StreamError::Parse(
1012 "stream ended without a message_start event".into(),
1013 ))
1014 })?;
1015 message.content = self.blocks;
1016 Ok(message)
1017 }
1018}
1019
#[cfg(all(test, feature = "streaming"))]
mod aggregator_tests {
    use super::*;
    use crate::error::{ApiErrorKind, ApiErrorPayload};
    use crate::types::{ModelId, Role};
    use pretty_assertions::assert_eq;
    use serde_json::json;

    /// A `message_start` event with empty content and 5 input tokens.
    fn message_start_event() -> StreamEvent {
        StreamEvent::Known(KnownStreamEvent::MessageStart {
            message: serde_json::from_value(json!({
                "id": "msg_x",
                "type": "message",
                "role": "assistant",
                "content": [],
                "model": "claude-sonnet-4-6",
                "usage": {"input_tokens": 5, "output_tokens": 0}
            }))
            .unwrap(),
        })
    }

    #[test]
    fn aggregator_reconstructs_text_message() {
        let mut agg = Aggregator::default();
        agg.handle(message_start_event()).unwrap();
        agg.handle(StreamEvent::Known(KnownStreamEvent::ContentBlockStart {
            index: 0,
            content_block: ContentBlock::text(""),
        }))
        .unwrap();
        agg.handle(StreamEvent::Known(KnownStreamEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::Known(KnownContentDelta::TextDelta {
                text: "Hello".into(),
            }),
        }))
        .unwrap();
        agg.handle(StreamEvent::Known(KnownStreamEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::Known(KnownContentDelta::TextDelta {
                text: " world".into(),
            }),
        }))
        .unwrap();
        agg.handle(StreamEvent::Known(KnownStreamEvent::ContentBlockStop {
            index: 0,
        }))
        .unwrap();
        agg.handle(StreamEvent::Known(KnownStreamEvent::MessageDelta {
            delta: MessageDelta {
                stop_reason: Some(StopReason::EndTurn),
                stop_sequence: None,
            },
            usage: Usage {
                input_tokens: 5,
                output_tokens: 2,
                ..Usage::default()
            },
        }))
        .unwrap();
        agg.handle(StreamEvent::Known(KnownStreamEvent::MessageStop))
            .unwrap();

        let msg = agg.finalize().unwrap();
        assert_eq!(msg.id, "msg_x");
        assert_eq!(msg.role, Role::Assistant);
        assert_eq!(msg.model, ModelId::SONNET_4_6);
        assert_eq!(msg.stop_reason, Some(StopReason::EndTurn));
        assert_eq!(msg.usage.output_tokens, 2);
        assert_eq!(msg.content.len(), 1);
        match &msg.content[0] {
            ContentBlock::Known(KnownBlock::Text { text, .. }) => {
                assert_eq!(text, "Hello world");
            }
            _ => panic!("expected text block"),
        }
    }

    #[test]
    fn aggregator_reconstructs_tool_use_input_from_partial_json_deltas() {
        let mut agg = Aggregator::default();
        agg.handle(message_start_event()).unwrap();
        agg.handle(StreamEvent::Known(KnownStreamEvent::ContentBlockStart {
            index: 0,
            content_block: ContentBlock::Known(KnownBlock::ToolUse {
                id: "toolu_1".into(),
                name: "get_weather".into(),
                input: json!({}),
            }),
        }))
        .unwrap();
        for chunk in ["{\"city\":", "\"Paris\"", ",\"unit\":\"C\"}"] {
            agg.handle(StreamEvent::Known(KnownStreamEvent::ContentBlockDelta {
                index: 0,
                delta: ContentDelta::Known(KnownContentDelta::InputJsonDelta {
                    partial_json: chunk.into(),
                }),
            }))
            .unwrap();
        }
        agg.handle(StreamEvent::Known(KnownStreamEvent::ContentBlockStop {
            index: 0,
        }))
        .unwrap();
        agg.handle(StreamEvent::Known(KnownStreamEvent::MessageStop))
            .unwrap();

        let msg = agg.finalize().unwrap();
        match &msg.content[0] {
            ContentBlock::Known(KnownBlock::ToolUse { input, name, .. }) => {
                assert_eq!(name, "get_weather");
                assert_eq!(input, &json!({"city": "Paris", "unit": "C"}));
            }
            _ => panic!("expected ToolUse block"),
        }
    }

    /// A tool call with no `input_json_delta` events keeps the input it was
    /// started with.
    #[test]
    fn aggregator_tool_use_without_input_deltas_keeps_start_input() {
        let mut agg = Aggregator::default();
        agg.handle(message_start_event()).unwrap();
        agg.handle(StreamEvent::Known(KnownStreamEvent::ContentBlockStart {
            index: 0,
            content_block: ContentBlock::Known(KnownBlock::ToolUse {
                id: "toolu_1".into(),
                name: "no_args".into(),
                input: json!({}),
            }),
        }))
        .unwrap();
        agg.handle(StreamEvent::Known(KnownStreamEvent::ContentBlockStop {
            index: 0,
        }))
        .unwrap();

        let msg = agg.finalize().unwrap();
        match &msg.content[0] {
            ContentBlock::Known(KnownBlock::ToolUse { input, name, .. }) => {
                assert_eq!(name, "no_args");
                assert_eq!(input, &json!({}));
            }
            _ => panic!("expected ToolUse block"),
        }
    }

    #[test]
    fn aggregator_reconstructs_thinking_block() {
        let mut agg = Aggregator::default();
        agg.handle(message_start_event()).unwrap();
        agg.handle(StreamEvent::Known(KnownStreamEvent::ContentBlockStart {
            index: 0,
            content_block: ContentBlock::Known(KnownBlock::Thinking {
                thinking: String::new(),
                signature: String::new(),
            }),
        }))
        .unwrap();
        agg.handle(StreamEvent::Known(KnownStreamEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::Known(KnownContentDelta::ThinkingDelta {
                thinking: "let me ".into(),
            }),
        }))
        .unwrap();
        agg.handle(StreamEvent::Known(KnownStreamEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::Known(KnownContentDelta::ThinkingDelta {
                thinking: "think".into(),
            }),
        }))
        .unwrap();
        agg.handle(StreamEvent::Known(KnownStreamEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::Known(KnownContentDelta::SignatureDelta {
                signature: "sig_xyz".into(),
            }),
        }))
        .unwrap();
        agg.handle(StreamEvent::Known(KnownStreamEvent::ContentBlockStop {
            index: 0,
        }))
        .unwrap();
        agg.handle(StreamEvent::Known(KnownStreamEvent::MessageStop))
            .unwrap();

        let msg = agg.finalize().unwrap();
        match &msg.content[0] {
            ContentBlock::Known(KnownBlock::Thinking {
                thinking,
                signature,
            }) => {
                assert_eq!(thinking, "let me think");
                assert_eq!(signature, "sig_xyz");
            }
            _ => panic!("expected Thinking block"),
        }
    }

    #[test]
    fn aggregator_unknown_event_is_ignored() {
        let mut agg = Aggregator::default();
        agg.handle(message_start_event()).unwrap();
        agg.handle(StreamEvent::Other(json!({"type": "future_event"})))
            .unwrap();
        agg.handle(StreamEvent::Known(KnownStreamEvent::MessageStop))
            .unwrap();
        let msg = agg.finalize().unwrap();
        assert!(msg.content.is_empty());
    }

    /// An unknown delta must be accepted and must leave the target block untouched.
    #[test]
    fn aggregator_unknown_delta_is_ignored() {
        let mut agg = Aggregator::default();
        agg.handle(message_start_event()).unwrap();
        agg.handle(StreamEvent::Known(KnownStreamEvent::ContentBlockStart {
            index: 0,
            content_block: ContentBlock::text(""),
        }))
        .unwrap();
        agg.handle(StreamEvent::Known(KnownStreamEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::Other(json!({"type": "future_delta"})),
        }))
        .unwrap();

        let msg = agg.finalize().unwrap();
        assert_eq!(msg.content.len(), 1);
        match &msg.content[0] {
            ContentBlock::Known(KnownBlock::Text { text, .. }) => {
                assert_eq!(text, "");
            }
            _ => panic!("expected text block"),
        }
    }

    #[test]
    fn aggregator_server_error_event_propagates() {
        let mut agg = Aggregator::default();
        agg.handle(message_start_event()).unwrap();
        let err = agg
            .handle(StreamEvent::Known(KnownStreamEvent::Error {
                error: ApiErrorPayload {
                    kind: ApiErrorKind::OverloadedError,
                    message: "boom".into(),
                },
            }))
            .unwrap_err();
        match err {
            Error::Stream(StreamError::Server { kind, message }) => {
                assert_eq!(kind, ApiErrorKind::OverloadedError);
                assert_eq!(message, "boom");
            }
            other => panic!("expected Stream::Server, got {other:?}"),
        }
    }

    #[test]
    fn aggregator_out_of_order_block_start_errors() {
        let mut agg = Aggregator::default();
        agg.handle(message_start_event()).unwrap();
        // Index 1 arrives before index 0 has started.
        let err = agg
            .handle(StreamEvent::Known(KnownStreamEvent::ContentBlockStart {
                index: 1,
                content_block: ContentBlock::text(""),
            }))
            .unwrap_err();
        assert!(matches!(err, Error::Stream(StreamError::Parse(_))));
    }

    #[test]
    fn aggregator_finalize_without_message_start_errors() {
        let agg = Aggregator::default();
        let err = agg.finalize().unwrap_err();
        assert!(matches!(err, Error::Stream(StreamError::Parse(_))));
    }
}
1263
#[cfg(all(test, feature = "streaming"))]
mod stream_callback_tests {
    //! Tests for the `EventStream` callback hooks (`on_text_delta`,
    //! `on_tool_use_complete`, `on_thinking_delta`, `on_message_stop`,
    //! `on_error`). Each test feeds a canned event sequence through
    //! `EventStream::from_events` and checks which hooks fired, how often,
    //! and with which arguments.

    use super::*;
    use crate::error::{ApiErrorKind, ApiErrorPayload};
    use pretty_assertions::assert_eq;
    use serde_json::json;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::{Arc, Mutex};

    /// Builds a `message_start` event carrying an empty assistant message
    /// whose usage is `input_tokens: 5`, `output_tokens: 0`.
    fn message_start_event() -> StreamEvent {
        StreamEvent::Known(KnownStreamEvent::MessageStart {
            message: serde_json::from_value(json!({
                "id": "msg_x",
                "type": "message",
                "role": "assistant",
                "content": [],
                "model": "claude-sonnet-4-6",
                "usage": {"input_tokens": 5, "output_tokens": 0}
            }))
            .unwrap(),
        })
    }

    /// `content_block_start` opening block `index` with `content_block`.
    fn block_start(index: u32, content_block: ContentBlock) -> StreamEvent {
        StreamEvent::Known(KnownStreamEvent::ContentBlockStart {
            index,
            content_block,
        })
    }

    /// `content_block_delta` applying a known `delta` to block `index`.
    fn block_delta(index: u32, delta: KnownContentDelta) -> StreamEvent {
        StreamEvent::Known(KnownStreamEvent::ContentBlockDelta {
            index,
            delta: ContentDelta::Known(delta),
        })
    }

    /// `content_block_stop` closing block `index`.
    fn block_stop(index: u32) -> StreamEvent {
        StreamEvent::Known(KnownStreamEvent::ContentBlockStop { index })
    }

    /// The terminal `message_stop` event.
    fn message_stop() -> StreamEvent {
        StreamEvent::Known(KnownStreamEvent::MessageStop)
    }

    #[tokio::test]
    async fn on_text_delta_fires_for_each_text_chunk() {
        let captured: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&captured);
        let events = vec![
            Ok(message_start_event()),
            Ok(block_start(0, ContentBlock::text(""))),
            Ok(block_delta(
                0,
                KnownContentDelta::TextDelta {
                    text: "Hello".into(),
                },
            )),
            Ok(block_delta(
                0,
                KnownContentDelta::TextDelta {
                    text: " world".into(),
                },
            )),
            Ok(block_stop(0)),
            Ok(message_stop()),
        ];

        let stream = EventStream::from_events(events).on_text_delta(move |chunk| {
            sink.lock().unwrap().push(chunk.to_string());
        });
        stream.aggregate().await.unwrap();

        assert_eq!(*captured.lock().unwrap(), vec!["Hello", " world"]);
    }

    #[tokio::test]
    async fn on_tool_use_complete_fires_with_parsed_input() {
        let captured: Arc<Mutex<Vec<(String, String, serde_json::Value)>>> =
            Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&captured);
        // The tool input arrives as two `input_json_delta` fragments, neither of
        // which is valid JSON on its own, so the callback can only receive
        // {"city": "Paris"} if the fragments are joined before parsing.
        let events = vec![
            Ok(message_start_event()),
            Ok(block_start(
                0,
                ContentBlock::Known(KnownBlock::ToolUse {
                    id: "toolu_1".into(),
                    name: "get_weather".into(),
                    input: json!({}),
                }),
            )),
            Ok(block_delta(
                0,
                KnownContentDelta::InputJsonDelta {
                    partial_json: "{\"city\":".into(),
                },
            )),
            Ok(block_delta(
                0,
                KnownContentDelta::InputJsonDelta {
                    partial_json: "\"Paris\"}".into(),
                },
            )),
            Ok(block_stop(0)),
            Ok(message_stop()),
        ];

        let stream =
            EventStream::from_events(events).on_tool_use_complete(move |id, name, input| {
                sink.lock()
                    .unwrap()
                    .push((id.to_string(), name.to_string(), input.clone()));
            });
        stream.aggregate().await.unwrap();

        let captured = captured.lock().unwrap();
        assert_eq!(captured.len(), 1);
        assert_eq!(captured[0].0, "toolu_1");
        assert_eq!(captured[0].1, "get_weather");
        assert_eq!(captured[0].2, json!({"city": "Paris"}));
    }

    #[tokio::test]
    async fn on_tool_use_complete_fires_for_server_tool_use_blocks() {
        let captured: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&captured);
        let events = vec![
            Ok(message_start_event()),
            Ok(block_start(
                0,
                ContentBlock::Known(KnownBlock::ServerToolUse {
                    id: "srvu_1".into(),
                    name: "web_search".into(),
                    input: json!({}),
                }),
            )),
            Ok(block_delta(
                0,
                KnownContentDelta::InputJsonDelta {
                    partial_json: "{\"q\":\"rust\"}".into(),
                },
            )),
            Ok(block_stop(0)),
            Ok(message_stop()),
        ];

        let stream = EventStream::from_events(events).on_tool_use_complete(move |id, _, _| {
            sink.lock().unwrap().push(id.to_string());
        });
        stream.aggregate().await.unwrap();

        assert_eq!(*captured.lock().unwrap(), vec!["srvu_1"]);
    }

    #[tokio::test]
    async fn on_thinking_delta_fires_for_each_thinking_chunk() {
        let captured: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&captured);
        let events = vec![
            Ok(message_start_event()),
            Ok(block_start(
                0,
                ContentBlock::Known(KnownBlock::Thinking {
                    thinking: String::new(),
                    signature: String::new(),
                }),
            )),
            Ok(block_delta(
                0,
                KnownContentDelta::ThinkingDelta {
                    thinking: "let me ".into(),
                },
            )),
            Ok(block_delta(
                0,
                KnownContentDelta::ThinkingDelta {
                    thinking: "think".into(),
                },
            )),
            Ok(block_stop(0)),
            Ok(message_stop()),
        ];

        let stream = EventStream::from_events(events).on_thinking_delta(move |chunk| {
            sink.lock().unwrap().push(chunk.to_string());
        });
        stream.aggregate().await.unwrap();

        assert_eq!(*captured.lock().unwrap(), vec!["let me ", "think"]);
    }

    #[tokio::test]
    async fn on_message_stop_fires_once_with_usage() {
        let captured: Arc<Mutex<Vec<Usage>>> = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&captured);
        let events = vec![
            Ok(message_start_event()),
            Ok(block_start(0, ContentBlock::text(""))),
            Ok(block_delta(
                0,
                KnownContentDelta::TextDelta { text: "hi".into() },
            )),
            Ok(block_stop(0)),
            // output_tokens: 7 appears only here (message_start reports 0), so
            // the assertion below shows the reported usage reflects this delta.
            Ok(StreamEvent::Known(KnownStreamEvent::MessageDelta {
                delta: MessageDelta {
                    stop_reason: Some(StopReason::EndTurn),
                    stop_sequence: None,
                },
                usage: Usage {
                    input_tokens: 5,
                    output_tokens: 7,
                    ..Usage::default()
                },
            })),
            Ok(message_stop()),
        ];

        let stream = EventStream::from_events(events).on_message_stop(move |usage| {
            sink.lock().unwrap().push(usage.clone());
        });
        stream.aggregate().await.unwrap();

        let captured = captured.lock().unwrap();
        assert_eq!(captured.len(), 1);
        assert_eq!(captured[0].input_tokens, 5);
        assert_eq!(captured[0].output_tokens, 7);
    }

    #[tokio::test]
    async fn on_error_fires_before_propagating_server_error() {
        let count = Arc::new(AtomicU32::new(0));
        let sink = Arc::clone(&count);
        let events = vec![
            Ok(message_start_event()),
            Ok(StreamEvent::Known(KnownStreamEvent::Error {
                error: ApiErrorPayload {
                    kind: ApiErrorKind::OverloadedError,
                    message: "boom".into(),
                },
            })),
        ];

        let stream = EventStream::from_events(events).on_error(move |_| {
            sink.fetch_add(1, Ordering::SeqCst);
        });
        let err = stream.aggregate().await.unwrap_err();
        assert!(matches!(
            err,
            Error::Stream(StreamError::Server {
                kind: ApiErrorKind::OverloadedError,
                ..
            })
        ));
        assert_eq!(
            count.load(Ordering::SeqCst),
            1,
            "handler should fire exactly once"
        );
    }

    #[tokio::test]
    async fn on_error_fires_for_transport_error() {
        let count = Arc::new(AtomicU32::new(0));
        let sink = Arc::clone(&count);
        let events: Vec<Result<StreamEvent>> = vec![
            Ok(message_start_event()),
            Err(Error::Stream(StreamError::Parse("decode failed".into()))),
        ];

        let stream = EventStream::from_events(events).on_error(move |_| {
            sink.fetch_add(1, Ordering::SeqCst);
        });
        let err = stream.aggregate().await.unwrap_err();
        assert!(matches!(err, Error::Stream(StreamError::Parse(_))));
        assert_eq!(count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn raw_stream_iteration_does_not_fire_callbacks() {
        use futures_util::StreamExt;
        let count = Arc::new(AtomicU32::new(0));
        let sink = Arc::clone(&count);
        let events = vec![
            Ok(message_start_event()),
            Ok(block_start(0, ContentBlock::text(""))),
            Ok(block_delta(
                0,
                KnownContentDelta::TextDelta { text: "hi".into() },
            )),
            Ok(block_stop(0)),
            Ok(message_stop()),
        ];
        let expected = events.len();

        let mut stream = EventStream::from_events(events).on_text_delta(move |_| {
            sink.fetch_add(1, Ordering::SeqCst);
        });
        // Count what the raw stream yields so the "no callbacks" check below
        // cannot pass vacuously on a stream that produced nothing.
        let mut yielded = 0usize;
        while let Some(_ev) = stream.next().await {
            yielded += 1;
        }
        assert_eq!(yielded, expected, "raw iteration should surface every event");
        assert_eq!(count.load(Ordering::SeqCst), 0);
    }
}