1use std::pin::Pin;
4
5use futures::Stream;
6use serde_json::Value;
7
8use crate::{
9 CacheControlEphemeral, Citation, ContentBlock, ContentBlockDelta, Error, Message,
10 MessageStreamEvent, ServerToolUseBlock, StopReason, TextBlock, TextCitation, ThinkingBlock,
11 ToolUseBlock,
12};
13
14pub struct AccumulatingStream {
20 inner: Pin<Box<dyn Stream<Item = Result<MessageStreamEvent, Error>> + Send>>,
21 message_tx: Option<tokio::sync::oneshot::Sender<Result<Message, Error>>>,
22 message: Option<Message>,
23 content_blocks: Vec<ContentBlockBuilder>,
24}
25
26impl AccumulatingStream {
27 pub fn new<S>(stream: S) -> (Self, tokio::sync::oneshot::Receiver<Result<Message, Error>>)
32 where
33 S: Stream<Item = Result<MessageStreamEvent, Error>> + Send + 'static,
34 {
35 Self::new_with_message(stream, None)
36 }
37
38 pub fn new_with_message<S>(
40 stream: S,
41 message: impl Into<Option<Message>>,
42 ) -> (Self, tokio::sync::oneshot::Receiver<Result<Message, Error>>)
43 where
44 S: Stream<Item = Result<MessageStreamEvent, Error>> + Send + 'static,
45 {
46 let (tx, rx) = tokio::sync::oneshot::channel();
47 let this = Self {
48 inner: Box::pin(stream),
49 message_tx: Some(tx),
50 message: message.into(),
51 content_blocks: Vec::new(),
52 };
53 (this, rx)
54 }
55
56 fn accumulate_event(&mut self, event: &MessageStreamEvent) {
57 match event {
58 MessageStreamEvent::MessageStart(start) => {
59 self.message = Some(start.message.clone());
60 }
61 MessageStreamEvent::ContentBlockStart(start) => {
62 let idx = start.index;
63 while self.content_blocks.len() <= idx {
64 self.content_blocks.push(ContentBlockBuilder::Empty);
65 }
66 self.content_blocks[idx] =
67 ContentBlockBuilder::from_content_block(start.content_block.clone());
68 }
69 MessageStreamEvent::ContentBlockDelta(delta_event) => {
70 let idx = delta_event.index;
71 if idx < self.content_blocks.len() {
72 self.content_blocks[idx].apply_delta(delta_event.delta.clone());
73 }
74 }
75 MessageStreamEvent::ContentBlockStop(_) => {}
76 MessageStreamEvent::MessageDelta(delta_event) => {
77 if let Some(ref mut msg) = self.message {
78 if delta_event.delta.stop_reason.is_some() {
79 msg.stop_reason = delta_event.delta.stop_reason;
80 }
81 if delta_event.delta.stop_sequence.is_some() {
82 msg.stop_sequence = delta_event.delta.stop_sequence.clone();
83 }
84 if let Some(input_tokens) = delta_event.usage.input_tokens {
85 msg.usage.input_tokens = input_tokens;
86 }
87 msg.usage.output_tokens = delta_event.usage.output_tokens;
88 if let Some(cache) = delta_event.usage.cache_creation_input_tokens {
89 msg.usage.cache_creation_input_tokens = Some(cache);
90 }
91 if let Some(cache_read) = delta_event.usage.cache_read_input_tokens {
92 msg.usage.cache_read_input_tokens = Some(cache_read);
93 }
94 if let Some(server_tool) = delta_event.usage.server_tool_use {
95 msg.usage.server_tool_use = Some(server_tool);
96 }
97 }
98 }
99 MessageStreamEvent::MessageStop(_) => {}
100 MessageStreamEvent::Ping => {}
101 MessageStreamEvent::ToolInputStart { .. } => {}
103 MessageStreamEvent::ToolInputDelta { .. } => {}
104 MessageStreamEvent::CompactionEvent(_) => {}
105 MessageStreamEvent::StreamError { .. } => {}
106 }
107 }
108
109 fn finalize(&mut self) -> Result<Message, Error> {
110 let mut msg = self
111 .message
112 .take()
113 .ok_or_else(|| Error::streaming("stream ended without a message start event", None))?;
114 let mut blocks = Vec::new();
115 for builder in std::mem::take(&mut self.content_blocks) {
116 if let Some(block) = builder.build(msg.stop_reason)? {
117 blocks.push(block);
118 }
119 }
120 msg.content = blocks;
121 Ok(msg)
122 }
123
124 pub fn finalize_partial(&mut self) -> Result<Message, Error> {
126 self.message_tx.take();
127 self.finalize()
128 }
129}
130
131impl Stream for AccumulatingStream {
132 type Item = Result<MessageStreamEvent, Error>;
133
134 fn poll_next(
135 mut self: Pin<&mut Self>,
136 cx: &mut std::task::Context<'_>,
137 ) -> std::task::Poll<Option<Self::Item>> {
138 match self.inner.as_mut().poll_next(cx) {
139 std::task::Poll::Ready(Some(Ok(event))) => {
140 self.accumulate_event(&event);
141 std::task::Poll::Ready(Some(Ok(event)))
142 }
143 std::task::Poll::Ready(Some(Err(e))) => std::task::Poll::Ready(Some(Err(e))),
144 std::task::Poll::Ready(None) => {
145 if let Some(tx) = self.message_tx.take() {
146 let _ = tx.send(self.finalize());
147 }
148 std::task::Poll::Ready(None)
149 }
150 std::task::Poll::Pending => std::task::Poll::Pending,
151 }
152 }
153}
154
155enum ContentBlockBuilder {
156 Empty,
157 Text {
158 text: String,
159 citations: Option<Vec<TextCitation>>,
160 cache_control: Option<CacheControlEphemeral>,
161 },
162 ToolUse {
163 id: String,
164 name: String,
165 input_json: String,
166 input_value: Option<Value>,
167 saw_delta: bool,
168 cache_control: Option<CacheControlEphemeral>,
169 },
170 ServerToolUse {
171 id: String,
172 name: String,
173 input: Value,
174 cache_control: Option<CacheControlEphemeral>,
175 },
176 Thinking {
177 thinking: String,
178 signature: String,
179 },
180 Complete(ContentBlock),
181}
182
183impl ContentBlockBuilder {
184 fn from_content_block(block: ContentBlock) -> Self {
185 match block {
186 ContentBlock::Text(text_block) => ContentBlockBuilder::Text {
187 text: text_block.text,
188 citations: text_block.citations,
189 cache_control: text_block.cache_control,
190 },
191 ContentBlock::ToolUse(tool_use) => ContentBlockBuilder::ToolUse {
192 id: tool_use.id,
193 name: tool_use.name,
194 input_json: String::new(),
195 input_value: Some(tool_use.input),
196 saw_delta: false,
197 cache_control: tool_use.cache_control,
198 },
199 ContentBlock::ServerToolUse(server_tool_use) => ContentBlockBuilder::ServerToolUse {
200 id: server_tool_use.id,
201 name: server_tool_use.name,
202 input: server_tool_use.input,
203 cache_control: server_tool_use.cache_control,
204 },
205 ContentBlock::Thinking(thinking) => ContentBlockBuilder::Thinking {
206 thinking: thinking.thinking,
207 signature: thinking.signature,
208 },
209 other => ContentBlockBuilder::Complete(other),
210 }
211 }
212
213 fn apply_delta(&mut self, delta: ContentBlockDelta) {
214 match (self, delta) {
215 (ContentBlockBuilder::Text { text, .. }, ContentBlockDelta::TextDelta(text_delta)) => {
216 text.push_str(&text_delta.text);
217 }
218 (
219 ContentBlockBuilder::Text { citations, .. },
220 ContentBlockDelta::CitationsDelta(citations_delta),
221 ) => {
222 let citation = match citations_delta.citation {
223 Citation::CharLocation(loc) => TextCitation::CharLocation(loc),
224 Citation::PageLocation(loc) => TextCitation::PageLocation(loc),
225 Citation::ContentBlockLocation(loc) => TextCitation::ContentBlockLocation(loc),
226 Citation::WebSearchResultLocation(loc) => {
227 TextCitation::WebSearchResultLocation(loc)
228 }
229 };
230 citations.get_or_insert_with(Vec::new).push(citation);
231 }
232 (
233 ContentBlockBuilder::ToolUse { input_json, saw_delta, .. },
234 ContentBlockDelta::InputJsonDelta(json_delta),
235 ) => {
236 *saw_delta = true;
237 input_json.push_str(&json_delta.partial_json);
238 }
239 (
240 ContentBlockBuilder::Thinking { thinking, .. },
241 ContentBlockDelta::ThinkingDelta(thinking_delta),
242 ) => {
243 thinking.push_str(&thinking_delta.thinking);
244 }
245 (
246 ContentBlockBuilder::Thinking { signature, .. },
247 ContentBlockDelta::SignatureDelta(sig_delta),
248 ) => {
249 signature.push_str(&sig_delta.signature);
250 }
251 _ => {}
252 }
253 }
254
255 fn build(self, stop_reason: Option<StopReason>) -> Result<Option<ContentBlock>, Error> {
256 match self {
257 ContentBlockBuilder::Empty => Ok(None),
258 ContentBlockBuilder::Text { text, citations, cache_control } => {
259 Ok(Some(ContentBlock::Text(TextBlock { text, citations, cache_control })))
260 }
261 ContentBlockBuilder::ToolUse {
262 id,
263 name,
264 input_json,
265 input_value,
266 saw_delta,
267 cache_control,
268 } => {
269 let input = if saw_delta {
270 if input_json.trim().is_empty() {
271 Value::Object(serde_json::Map::new())
272 } else {
273 match serde_json::from_str::<Value>(&input_json) {
274 Ok(value) => value,
275 Err(_err) => {
276 if stop_reason == Some(StopReason::MaxTokens) {
277 return Ok(None);
278 }
279 Value::String(input_json)
280 }
281 }
282 }
283 } else if let Some(input) = input_value {
284 input
285 } else if input_json.trim().is_empty() {
286 Value::Object(serde_json::Map::new())
287 } else {
288 match serde_json::from_str::<Value>(&input_json) {
289 Ok(value) => value,
290 Err(_err) => {
291 if stop_reason == Some(StopReason::MaxTokens) {
292 return Ok(None);
293 }
294 Value::String(input_json)
295 }
296 }
297 };
298 Ok(Some(ContentBlock::ToolUse(ToolUseBlock { id, name, input, cache_control })))
299 }
300 ContentBlockBuilder::ServerToolUse { id, name, input, cache_control } => {
301 Ok(Some(ContentBlock::ServerToolUse(ServerToolUseBlock {
302 id,
303 name,
304 input,
305 cache_control,
306 })))
307 }
308 ContentBlockBuilder::Thinking { thinking, signature } => {
309 Ok(Some(ContentBlock::Thinking(ThinkingBlock { thinking, signature })))
310 }
311 ContentBlockBuilder::Complete(block) => Ok(Some(block)),
312 }
313 }
314}
315
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        ContentBlockDeltaEvent, ContentBlockStartEvent, ContentBlockStopEvent, InputJsonDelta,
        KnownModel, MessageDelta, MessageDeltaEvent, MessageDeltaUsage, MessageStartEvent, Model,
        TextDelta, Usage,
    };
    use futures::{stream, StreamExt};

    /// Builds a `message_start` event for an empty message carrying `usage`.
    fn message_start(usage: Usage) -> MessageStreamEvent {
        let message = Message::new(
            "msg_test".to_string(),
            Vec::new(),
            Model::Known(KnownModel::ClaudeSonnet46),
            usage,
        );
        MessageStreamEvent::MessageStart(MessageStartEvent::new(message))
    }

    /// Builds a `message_delta` event reporting `stop_reason` and 10 output tokens.
    fn message_delta(stop_reason: StopReason) -> MessageStreamEvent {
        let delta = MessageDelta::new().with_stop_reason(stop_reason);
        MessageStreamEvent::MessageDelta(MessageDeltaEvent::new(delta, MessageDeltaUsage::new(10)))
    }

    /// Builds a `content_block_start` event announcing `block` at index 0.
    fn block_start(block: ContentBlock) -> MessageStreamEvent {
        MessageStreamEvent::ContentBlockStart(ContentBlockStartEvent::new(block, 0))
    }

    /// Builds a `content_block_delta` event for index 0.
    fn block_delta(delta: ContentBlockDelta) -> MessageStreamEvent {
        MessageStreamEvent::ContentBlockDelta(ContentBlockDeltaEvent::new(delta, 0))
    }

    /// Builds a `content_block_stop` event for index 0.
    fn block_stop() -> MessageStreamEvent {
        MessageStreamEvent::ContentBlockStop(ContentBlockStopEvent::new(0))
    }

    /// Drives `events` through an `AccumulatingStream` to completion and returns the
    /// message delivered on its completion channel.
    async fn accumulate(events: Vec<MessageStreamEvent>) -> Message {
        let source = stream::iter(events.into_iter().map(Ok::<_, Error>));
        let (mut acc_stream, rx) = AccumulatingStream::new(source);
        while acc_stream.next().await.is_some() {}
        rx.await.expect("channel closed").expect("accumulation failed")
    }

    #[tokio::test]
    async fn cache_tokens_from_message_start_preserved() {
        let usage = Usage::new(100, 0)
            .with_cache_creation_input_tokens(50)
            .with_cache_read_input_tokens(25);

        let message = accumulate(vec![
            message_start(usage),
            block_start(ContentBlock::Text(TextBlock::new(String::new()))),
            block_delta(ContentBlockDelta::TextDelta(TextDelta::new("Hello".to_string()))),
            message_delta(StopReason::EndTurn),
        ])
        .await;

        assert_eq!(
            message.usage.cache_creation_input_tokens,
            Some(50),
            "cache_creation_input_tokens should be preserved from message_start"
        );
        assert_eq!(
            message.usage.cache_read_input_tokens,
            Some(25),
            "cache_read_input_tokens should be preserved from message_start"
        );
        assert_eq!(message.usage.output_tokens, 10, "output_tokens should be from message_delta");
        assert_eq!(message.content.len(), 1, "Should have one content block");
    }

    #[tokio::test]
    async fn empty_tool_input_becomes_empty_object() {
        let tool_use_block =
            ContentBlock::ToolUse(ToolUseBlock::new("tool_123", "get_document", Value::Null));

        let message = accumulate(vec![
            message_start(Usage::new(100, 0)),
            block_start(tool_use_block),
            block_delta(ContentBlockDelta::InputJsonDelta(InputJsonDelta::new(String::new()))),
            block_stop(),
            message_delta(StopReason::ToolUse),
        ])
        .await;

        assert_eq!(message.content.len(), 1, "Should have one content block");
        let tool_use = message.content[0].as_tool_use().expect("Expected ToolUseBlock");
        assert!(
            tool_use.input.is_object(),
            "Empty tool input should be an object, not null. Got: {:?}",
            tool_use.input
        );
        assert!(
            tool_use.input.as_object().expect("input should be object").is_empty(),
            "Empty tool input should be an empty object"
        );
    }

    #[tokio::test]
    async fn tool_input_without_delta_uses_initial_value() {
        let input = serde_json::json!({"key": "value"});
        let tool_use_block =
            ContentBlock::ToolUse(ToolUseBlock::new("tool_123", "get_document", input.clone()));

        let message = accumulate(vec![
            message_start(Usage::new(100, 0)),
            block_start(tool_use_block),
            block_stop(),
            message_delta(StopReason::ToolUse),
        ])
        .await;

        assert_eq!(message.content.len(), 1, "Should have one content block");
        let tool_use = message.content[0].as_tool_use().expect("Expected ToolUseBlock");
        assert_eq!(tool_use.input, input, "Tool input should match initial value");
    }
}