Skip to main content

adk_anthropic/
accumulating_stream.rs

1//! Accumulates streaming events into a complete message while passing events through.
2
3use std::pin::Pin;
4
5use futures::Stream;
6use serde_json::Value;
7
8use crate::{
9    CacheControlEphemeral, Citation, ContentBlock, ContentBlockDelta, Error, Message,
10    MessageStreamEvent, ServerToolUseBlock, StopReason, TextBlock, TextCitation, ThinkingBlock,
11    ToolUseBlock,
12};
13
/// A stream wrapper that accumulates `MessageStreamEvent`s into a complete `Message`.
///
/// This allows streaming tokens to the user while simultaneously building the final message
/// without buffering. When the stream is fully drained, the accumulated message is sent via
/// the oneshot channel returned by `new()`.
pub struct AccumulatingStream {
    /// The wrapped event stream; every item it yields is passed through to the consumer.
    inner: Pin<Box<dyn Stream<Item = Result<MessageStreamEvent, Error>> + Send>>,
    /// Sender for the final result. Becomes `None` once the result has been sent at end of
    /// stream, or once `finalize_partial` has been called.
    message_tx: Option<tokio::sync::oneshot::Sender<Result<Message, Error>>>,
    /// The message being assembled: the optional seed passed to `new_with_message`, replaced
    /// by the message carried in a `MessageStart` event. Taken when the message is finalized.
    message: Option<Message>,
    /// In-progress content blocks, positioned by the index carried in content block events.
    content_blocks: Vec<ContentBlockBuilder>,
}
25
26impl AccumulatingStream {
27    /// Wraps a `MessageStreamEvent` stream to accumulate events into a `Message`.
28    ///
29    /// Returns the stream and a receiver that will contain the accumulated `Message` once the
30    /// stream is fully drained.
31    pub fn new<S>(stream: S) -> (Self, tokio::sync::oneshot::Receiver<Result<Message, Error>>)
32    where
33        S: Stream<Item = Result<MessageStreamEvent, Error>> + Send + 'static,
34    {
35        Self::new_with_message(stream, None)
36    }
37
38    /// Wraps a `MessageStreamEvent` stream and seeds accumulation with a fallback message.
39    pub fn new_with_message<S>(
40        stream: S,
41        message: impl Into<Option<Message>>,
42    ) -> (Self, tokio::sync::oneshot::Receiver<Result<Message, Error>>)
43    where
44        S: Stream<Item = Result<MessageStreamEvent, Error>> + Send + 'static,
45    {
46        let (tx, rx) = tokio::sync::oneshot::channel();
47        let this = Self {
48            inner: Box::pin(stream),
49            message_tx: Some(tx),
50            message: message.into(),
51            content_blocks: Vec::new(),
52        };
53        (this, rx)
54    }
55
56    fn accumulate_event(&mut self, event: &MessageStreamEvent) {
57        match event {
58            MessageStreamEvent::MessageStart(start) => {
59                self.message = Some(start.message.clone());
60            }
61            MessageStreamEvent::ContentBlockStart(start) => {
62                let idx = start.index;
63                while self.content_blocks.len() <= idx {
64                    self.content_blocks.push(ContentBlockBuilder::Empty);
65                }
66                self.content_blocks[idx] =
67                    ContentBlockBuilder::from_content_block(start.content_block.clone());
68            }
69            MessageStreamEvent::ContentBlockDelta(delta_event) => {
70                let idx = delta_event.index;
71                if idx < self.content_blocks.len() {
72                    self.content_blocks[idx].apply_delta(delta_event.delta.clone());
73                }
74            }
75            MessageStreamEvent::ContentBlockStop(_) => {}
76            MessageStreamEvent::MessageDelta(delta_event) => {
77                if let Some(ref mut msg) = self.message {
78                    if delta_event.delta.stop_reason.is_some() {
79                        msg.stop_reason = delta_event.delta.stop_reason;
80                    }
81                    if delta_event.delta.stop_sequence.is_some() {
82                        msg.stop_sequence = delta_event.delta.stop_sequence.clone();
83                    }
84                    if let Some(input_tokens) = delta_event.usage.input_tokens {
85                        msg.usage.input_tokens = input_tokens;
86                    }
87                    msg.usage.output_tokens = delta_event.usage.output_tokens;
88                    if let Some(cache) = delta_event.usage.cache_creation_input_tokens {
89                        msg.usage.cache_creation_input_tokens = Some(cache);
90                    }
91                    if let Some(cache_read) = delta_event.usage.cache_read_input_tokens {
92                        msg.usage.cache_read_input_tokens = Some(cache_read);
93                    }
94                    if let Some(server_tool) = delta_event.usage.server_tool_use {
95                        msg.usage.server_tool_use = Some(server_tool);
96                    }
97                }
98            }
99            MessageStreamEvent::MessageStop(_) => {}
100            MessageStreamEvent::Ping => {}
101            // New event types that don't affect message accumulation
102            MessageStreamEvent::ToolInputStart { .. } => {}
103            MessageStreamEvent::ToolInputDelta { .. } => {}
104            MessageStreamEvent::CompactionEvent(_) => {}
105            MessageStreamEvent::StreamError { .. } => {}
106        }
107    }
108
109    fn finalize(&mut self) -> Result<Message, Error> {
110        let mut msg = self
111            .message
112            .take()
113            .ok_or_else(|| Error::streaming("stream ended without a message start event", None))?;
114        let mut blocks = Vec::new();
115        for builder in std::mem::take(&mut self.content_blocks) {
116            if let Some(block) = builder.build(msg.stop_reason)? {
117                blocks.push(block);
118            }
119        }
120        msg.content = blocks;
121        Ok(msg)
122    }
123
124    /// Finalizes the currently accumulated message without draining the stream.
125    pub fn finalize_partial(&mut self) -> Result<Message, Error> {
126        self.message_tx.take();
127        self.finalize()
128    }
129}
130
131impl Stream for AccumulatingStream {
132    type Item = Result<MessageStreamEvent, Error>;
133
134    fn poll_next(
135        mut self: Pin<&mut Self>,
136        cx: &mut std::task::Context<'_>,
137    ) -> std::task::Poll<Option<Self::Item>> {
138        match self.inner.as_mut().poll_next(cx) {
139            std::task::Poll::Ready(Some(Ok(event))) => {
140                self.accumulate_event(&event);
141                std::task::Poll::Ready(Some(Ok(event)))
142            }
143            std::task::Poll::Ready(Some(Err(e))) => std::task::Poll::Ready(Some(Err(e))),
144            std::task::Poll::Ready(None) => {
145                if let Some(tx) = self.message_tx.take() {
146                    let _ = tx.send(self.finalize());
147                }
148                std::task::Poll::Ready(None)
149            }
150            std::task::Poll::Pending => std::task::Poll::Pending,
151        }
152    }
153}
154
155enum ContentBlockBuilder {
156    Empty,
157    Text {
158        text: String,
159        citations: Option<Vec<TextCitation>>,
160        cache_control: Option<CacheControlEphemeral>,
161    },
162    ToolUse {
163        id: String,
164        name: String,
165        input_json: String,
166        input_value: Option<Value>,
167        saw_delta: bool,
168        cache_control: Option<CacheControlEphemeral>,
169    },
170    ServerToolUse {
171        id: String,
172        name: String,
173        input: Value,
174        cache_control: Option<CacheControlEphemeral>,
175    },
176    Thinking {
177        thinking: String,
178        signature: String,
179    },
180    Complete(ContentBlock),
181}
182
183impl ContentBlockBuilder {
184    fn from_content_block(block: ContentBlock) -> Self {
185        match block {
186            ContentBlock::Text(text_block) => ContentBlockBuilder::Text {
187                text: text_block.text,
188                citations: text_block.citations,
189                cache_control: text_block.cache_control,
190            },
191            ContentBlock::ToolUse(tool_use) => ContentBlockBuilder::ToolUse {
192                id: tool_use.id,
193                name: tool_use.name,
194                input_json: String::new(),
195                input_value: Some(tool_use.input),
196                saw_delta: false,
197                cache_control: tool_use.cache_control,
198            },
199            ContentBlock::ServerToolUse(server_tool_use) => ContentBlockBuilder::ServerToolUse {
200                id: server_tool_use.id,
201                name: server_tool_use.name,
202                input: server_tool_use.input,
203                cache_control: server_tool_use.cache_control,
204            },
205            ContentBlock::Thinking(thinking) => ContentBlockBuilder::Thinking {
206                thinking: thinking.thinking,
207                signature: thinking.signature,
208            },
209            other => ContentBlockBuilder::Complete(other),
210        }
211    }
212
213    fn apply_delta(&mut self, delta: ContentBlockDelta) {
214        match (self, delta) {
215            (ContentBlockBuilder::Text { text, .. }, ContentBlockDelta::TextDelta(text_delta)) => {
216                text.push_str(&text_delta.text);
217            }
218            (
219                ContentBlockBuilder::Text { citations, .. },
220                ContentBlockDelta::CitationsDelta(citations_delta),
221            ) => {
222                let citation = match citations_delta.citation {
223                    Citation::CharLocation(loc) => TextCitation::CharLocation(loc),
224                    Citation::PageLocation(loc) => TextCitation::PageLocation(loc),
225                    Citation::ContentBlockLocation(loc) => TextCitation::ContentBlockLocation(loc),
226                    Citation::WebSearchResultLocation(loc) => {
227                        TextCitation::WebSearchResultLocation(loc)
228                    }
229                };
230                citations.get_or_insert_with(Vec::new).push(citation);
231            }
232            (
233                ContentBlockBuilder::ToolUse { input_json, saw_delta, .. },
234                ContentBlockDelta::InputJsonDelta(json_delta),
235            ) => {
236                *saw_delta = true;
237                input_json.push_str(&json_delta.partial_json);
238            }
239            (
240                ContentBlockBuilder::Thinking { thinking, .. },
241                ContentBlockDelta::ThinkingDelta(thinking_delta),
242            ) => {
243                thinking.push_str(&thinking_delta.thinking);
244            }
245            (
246                ContentBlockBuilder::Thinking { signature, .. },
247                ContentBlockDelta::SignatureDelta(sig_delta),
248            ) => {
249                signature.push_str(&sig_delta.signature);
250            }
251            _ => {}
252        }
253    }
254
255    fn build(self, stop_reason: Option<StopReason>) -> Result<Option<ContentBlock>, Error> {
256        match self {
257            ContentBlockBuilder::Empty => Ok(None),
258            ContentBlockBuilder::Text { text, citations, cache_control } => {
259                Ok(Some(ContentBlock::Text(TextBlock { text, citations, cache_control })))
260            }
261            ContentBlockBuilder::ToolUse {
262                id,
263                name,
264                input_json,
265                input_value,
266                saw_delta,
267                cache_control,
268            } => {
269                let input = if saw_delta {
270                    if input_json.trim().is_empty() {
271                        Value::Object(serde_json::Map::new())
272                    } else {
273                        match serde_json::from_str::<Value>(&input_json) {
274                            Ok(value) => value,
275                            Err(_err) => {
276                                if stop_reason == Some(StopReason::MaxTokens) {
277                                    return Ok(None);
278                                }
279                                Value::String(input_json)
280                            }
281                        }
282                    }
283                } else if let Some(input) = input_value {
284                    input
285                } else if input_json.trim().is_empty() {
286                    Value::Object(serde_json::Map::new())
287                } else {
288                    match serde_json::from_str::<Value>(&input_json) {
289                        Ok(value) => value,
290                        Err(_err) => {
291                            if stop_reason == Some(StopReason::MaxTokens) {
292                                return Ok(None);
293                            }
294                            Value::String(input_json)
295                        }
296                    }
297                };
298                Ok(Some(ContentBlock::ToolUse(ToolUseBlock { id, name, input, cache_control })))
299            }
300            ContentBlockBuilder::ServerToolUse { id, name, input, cache_control } => {
301                Ok(Some(ContentBlock::ServerToolUse(ServerToolUseBlock {
302                    id,
303                    name,
304                    input,
305                    cache_control,
306                })))
307            }
308            ContentBlockBuilder::Thinking { thinking, signature } => {
309                Ok(Some(ContentBlock::Thinking(ThinkingBlock { thinking, signature })))
310            }
311            ContentBlockBuilder::Complete(block) => Ok(Some(block)),
312        }
313    }
314}
315
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent, InputJsonDelta,
        KnownModel, MessageDelta, MessageDeltaEvent, MessageDeltaUsage, MessageStartEvent, Model,
        TextDelta, Usage,
    };
    use futures::{stream, StreamExt};

    /// Builds a `message_start` event for an empty message with the given usage.
    fn message_start(usage: Usage) -> MessageStreamEvent {
        let message = Message::new(
            "msg_test".to_string(),
            Vec::new(),
            Model::Known(KnownModel::ClaudeSonnet46),
            usage,
        );
        MessageStreamEvent::MessageStart(MessageStartEvent::new(message))
    }

    /// Builds a `message_delta` event reporting 10 output tokens and the given stop reason.
    fn message_delta(stop_reason: StopReason) -> MessageStreamEvent {
        let delta = MessageDelta::new().with_stop_reason(stop_reason);
        MessageStreamEvent::MessageDelta(MessageDeltaEvent::new(delta, MessageDeltaUsage::new(10)))
    }

    /// Builds a `content_block_start` event at index 0 for a tool call with `input`.
    fn tool_use_start(input: Value) -> MessageStreamEvent {
        let block = ContentBlock::ToolUse(ToolUseBlock::new("tool_123", "get_document", input));
        MessageStreamEvent::ContentBlockStart(ContentBlockStartEvent::new(block, 0))
    }

    /// Builds a `content_block_delta` event at index 0 carrying a JSON fragment.
    fn input_json_delta(partial_json: &str) -> MessageStreamEvent {
        let delta = ContentBlockDelta::InputJsonDelta(InputJsonDelta::new(partial_json.to_string()));
        MessageStreamEvent::ContentBlockDelta(ContentBlockDeltaEvent::new(delta, 0))
    }

    /// Builds a `content_block_stop` event for index 0.
    fn content_stop() -> MessageStreamEvent {
        MessageStreamEvent::ContentBlockStop(ContentBlockStopEvent::new(0))
    }

    /// Drains an `AccumulatingStream` over `events` and returns what the receiver got.
    async fn accumulate(events: Vec<MessageStreamEvent>) -> Result<Message, Error> {
        let events: Vec<Result<MessageStreamEvent, Error>> = events.into_iter().map(Ok).collect();
        let (mut acc_stream, rx) = AccumulatingStream::new(stream::iter(events));
        while acc_stream.next().await.is_some() {}
        rx.await.expect("channel closed")
    }

    /// Verifies that cache tokens from message_start are preserved through streaming.
    #[tokio::test]
    async fn cache_tokens_from_message_start_preserved() {
        let usage_with_cache = Usage::new(100, 0)
            .with_cache_creation_input_tokens(50)
            .with_cache_read_input_tokens(25);

        let text_block = ContentBlock::Text(TextBlock::new(String::new()));
        let content_start =
            MessageStreamEvent::ContentBlockStart(ContentBlockStartEvent::new(text_block, 0));
        let text_delta = TextDelta::new("Hello".to_string());
        let content_delta = MessageStreamEvent::ContentBlockDelta(ContentBlockDeltaEvent::new(
            ContentBlockDelta::TextDelta(text_delta),
            0,
        ));

        // The message_delta carries only output tokens; cache tokens were in message_start.
        let message = accumulate(vec![
            message_start(usage_with_cache),
            content_start,
            content_delta,
            message_delta(StopReason::EndTurn),
        ])
        .await
        .expect("accumulation failed");

        assert_eq!(
            message.usage.cache_creation_input_tokens,
            Some(50),
            "cache_creation_input_tokens should be preserved from message_start"
        );
        assert_eq!(
            message.usage.cache_read_input_tokens,
            Some(25),
            "cache_read_input_tokens should be preserved from message_start"
        );
        assert_eq!(message.usage.output_tokens, 10, "output_tokens should be from message_delta");
    }

    /// Verifies that tool use with empty input JSON becomes an empty object, not null.
    #[tokio::test]
    async fn empty_tool_input_becomes_empty_object() {
        // An empty JSON fragment simulates a tool call with no input parameters.
        let message = accumulate(vec![
            message_start(Usage::new(100, 0)),
            tool_use_start(Value::Null),
            input_json_delta(""),
            content_stop(),
            message_delta(StopReason::ToolUse),
        ])
        .await
        .expect("accumulation failed");

        assert_eq!(message.content.len(), 1, "Should have one content block");
        let tool_use = message.content[0].as_tool_use().expect("Expected ToolUseBlock");

        assert!(
            tool_use.input.is_object(),
            "Empty tool input should be an object, not null. Got: {:?}",
            tool_use.input
        );
        assert!(
            tool_use.input.as_object().expect("input should be object").is_empty(),
            "Empty tool input should be an empty object"
        );
    }

    /// Verifies that tool use with no delta events uses initial input_value.
    #[tokio::test]
    async fn tool_input_without_delta_uses_initial_value() {
        let input = serde_json::json!({"key": "value"});

        // No delta events: the input must come from the start event.
        let message = accumulate(vec![
            message_start(Usage::new(100, 0)),
            tool_use_start(input.clone()),
            content_stop(),
            message_delta(StopReason::ToolUse),
        ])
        .await
        .expect("accumulation failed");

        let tool_use = message.content[0].as_tool_use().expect("Expected ToolUseBlock");
        assert_eq!(tool_use.input, input, "Tool input should match initial value");
    }

    /// Verifies that a tool call cut off by max_tokens mid-input is dropped from the message.
    #[tokio::test]
    async fn truncated_tool_input_dropped_on_max_tokens() {
        let message = accumulate(vec![
            message_start(Usage::new(100, 0)),
            tool_use_start(Value::Null),
            input_json_delta("{\"key"),
            message_delta(StopReason::MaxTokens),
        ])
        .await
        .expect("accumulation failed");

        assert!(message.content.is_empty(), "Truncated tool use should not be emitted");
    }

    /// Verifies that draining a stream with no message_start reports an error.
    #[tokio::test]
    async fn missing_message_start_is_an_error() {
        let result = accumulate(Vec::new()).await;
        assert!(result.is_err(), "Finalizing without a message should fail");
    }
}