use crate::types::{ContentBlock, ContentBlockInfo, Delta, StreamEventPayload};
/// Accumulator for a content block whose stop event has not arrived yet.
#[derive(Debug, Clone)]
enum PartialBlock {
    /// Plain text, grown by appending text fragments.
    Text { text: String },
    /// Model thinking text, grown by appending thinking fragments.
    Thinking { thinking: String },
    /// A tool invocation. `partial_json` collects the raw input-JSON
    /// fragments verbatim; it is only parsed once the block is finished.
    ToolUse {
        id: String,
        name: String,
        partial_json: String,
    },
}
/// Higher-level event produced by [`StreamAssembler::process`].
#[derive(Debug, Clone)]
pub enum AssembledEvent {
    /// A new message began; `metadata` is the raw message object carried by
    /// the stream's message-start event.
    MessageStart {
        metadata: serde_json::Value,
    },
    /// The block at `index` finished and has been fully assembled.
    ContentBlockComplete {
        index: u64,
        block: ContentBlock,
    },
    /// A text fragment that was just appended to the block at `index`.
    TextDelta {
        index: u64,
        text: String,
    },
    /// A thinking fragment that was just appended to the block at `index`.
    ThinkingDelta {
        index: u64,
        thinking: String,
    },
    /// The message ended; `stop_reason` is set only if one was recorded.
    MessageComplete {
        stop_reason: Option<String>,
    },
}
#[derive(Debug, Default)]
pub struct StreamAssembler {
blocks: Vec<Option<PartialBlock>>,
}
impl StreamAssembler {
pub fn new() -> Self {
Self::default()
}
pub fn reset(&mut self) {
self.blocks.clear();
}
pub fn process(&mut self, event: &StreamEventPayload) -> Vec<AssembledEvent> {
match event {
StreamEventPayload::MessageStart { message } => {
self.reset();
vec![AssembledEvent::MessageStart {
metadata: message.clone(),
}]
}
StreamEventPayload::ContentBlockStart {
index,
content_block,
} => {
let idx = *index as usize;
if self.blocks.len() <= idx {
self.blocks.resize_with(idx + 1, || None);
}
self.blocks[idx] = Some(match content_block {
ContentBlockInfo::Text { text } => PartialBlock::Text { text: text.clone() },
ContentBlockInfo::Thinking { thinking } => PartialBlock::Thinking {
thinking: thinking.clone(),
},
ContentBlockInfo::ToolUse { id, name, .. } => PartialBlock::ToolUse {
id: id.clone(),
name: name.clone(),
partial_json: String::new(),
},
});
vec![]
}
StreamEventPayload::ContentBlockDelta { index, delta } => {
let idx = *index as usize;
let mut events = Vec::new();
if let Some(Some(partial)) = self.blocks.get_mut(idx) {
match (partial, delta) {
(PartialBlock::Text { text }, Delta::TextDelta { text: fragment }) => {
text.push_str(fragment);
events.push(AssembledEvent::TextDelta {
index: *index,
text: fragment.clone(),
});
}
(
PartialBlock::Thinking { thinking },
Delta::ThinkingDelta { thinking: fragment },
) => {
thinking.push_str(fragment);
events.push(AssembledEvent::ThinkingDelta {
index: *index,
thinking: fragment.clone(),
});
}
(
PartialBlock::ToolUse { partial_json, .. },
Delta::InputJsonDelta {
partial_json: fragment,
},
) => {
partial_json.push_str(fragment);
}
_ => {
}
}
}
events
}
StreamEventPayload::ContentBlockStop { index } => {
let idx = *index as usize;
let mut events = Vec::new();
if let Some(partial) = self.blocks.get_mut(idx).and_then(Option::take) {
let block = match partial {
PartialBlock::Text { text } => ContentBlock::Text { text },
PartialBlock::Thinking { thinking } => {
ContentBlock::Thinking {
thinking,
signature: String::new(),
}
}
PartialBlock::ToolUse {
id,
name,
partial_json,
} => {
let input = serde_json::from_str(&partial_json)
.unwrap_or(serde_json::Value::Object(serde_json::Map::new()));
ContentBlock::ToolUse { id, name, input }
}
};
events.push(AssembledEvent::ContentBlockComplete {
index: *index,
block,
});
}
events
}
StreamEventPayload::MessageDelta { delta, .. } => {
let _stop_reason = delta
.get("stop_reason")
.and_then(|v| v.as_str())
.map(String::from);
vec![]
}
StreamEventPayload::MessageStop => {
vec![AssembledEvent::MessageComplete { stop_reason: None }]
}
StreamEventPayload::Unknown => vec![],
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn assemble_text_block() {
        let mut asm = StreamAssembler::new();
        let events = asm.process(&StreamEventPayload::MessageStart {
            message: serde_json::json!({"id": "msg_1", "role": "assistant"}),
        });
        assert!(matches!(events[0], AssembledEvent::MessageStart { .. }));
        let events = asm.process(&StreamEventPayload::ContentBlockStart {
            index: 0,
            content_block: ContentBlockInfo::Text {
                text: String::new(),
            },
        });
        assert!(events.is_empty());
        let events = asm.process(&StreamEventPayload::ContentBlockDelta {
            index: 0,
            delta: Delta::TextDelta {
                text: "Hello".to_owned(),
            },
        });
        assert_eq!(events.len(), 1);
        assert!(matches!(&events[0], AssembledEvent::TextDelta { text, .. } if text == "Hello"));
        let events = asm.process(&StreamEventPayload::ContentBlockDelta {
            index: 0,
            delta: Delta::TextDelta {
                text: " world".to_owned(),
            },
        });
        assert_eq!(events.len(), 1);
        let events = asm.process(&StreamEventPayload::ContentBlockStop { index: 0 });
        assert_eq!(events.len(), 1);
        match &events[0] {
            AssembledEvent::ContentBlockComplete { block, .. } => match block {
                ContentBlock::Text { text } => assert_eq!(text, "Hello world"),
                other => panic!("expected Text, got {other:?}"),
            },
            other => panic!("expected ContentBlockComplete, got {other:?}"),
        }
        let events = asm.process(&StreamEventPayload::MessageStop);
        assert!(matches!(events[0], AssembledEvent::MessageComplete { .. }));
    }

    #[test]
    fn assemble_tool_use_block() {
        let mut asm = StreamAssembler::new();
        asm.process(&StreamEventPayload::MessageStart {
            message: serde_json::json!({}),
        });
        asm.process(&StreamEventPayload::ContentBlockStart {
            index: 0,
            content_block: ContentBlockInfo::ToolUse {
                id: "tu_1".to_owned(),
                name: "Bash".to_owned(),
                input: serde_json::Value::Object(serde_json::Map::new()),
            },
        });
        asm.process(&StreamEventPayload::ContentBlockDelta {
            index: 0,
            delta: Delta::InputJsonDelta {
                partial_json: r#"{"command":"#.to_owned(),
            },
        });
        asm.process(&StreamEventPayload::ContentBlockDelta {
            index: 0,
            delta: Delta::InputJsonDelta {
                partial_json: r#""ls -la"}"#.to_owned(),
            },
        });
        let events = asm.process(&StreamEventPayload::ContentBlockStop { index: 0 });
        assert_eq!(events.len(), 1);
        match &events[0] {
            AssembledEvent::ContentBlockComplete { block, .. } => match block {
                ContentBlock::ToolUse { id, name, input } => {
                    assert_eq!(id, "tu_1");
                    assert_eq!(name, "Bash");
                    assert_eq!(input["command"], "ls -la");
                }
                other => panic!("expected ToolUse, got {other:?}"),
            },
            other => panic!("expected ContentBlockComplete, got {other:?}"),
        }
    }

    /// Thinking fragments are forwarded and concatenated onto the seed text;
    /// the assembled block carries an empty signature.
    #[test]
    fn assemble_thinking_block() {
        let mut asm = StreamAssembler::new();
        asm.process(&StreamEventPayload::ContentBlockStart {
            index: 1,
            content_block: ContentBlockInfo::Thinking {
                thinking: "Let ".to_owned(),
            },
        });
        let events = asm.process(&StreamEventPayload::ContentBlockDelta {
            index: 1,
            delta: Delta::ThinkingDelta {
                thinking: "me think".to_owned(),
            },
        });
        assert_eq!(events.len(), 1);
        assert!(matches!(
            &events[0],
            AssembledEvent::ThinkingDelta { index: 1, thinking } if thinking == "me think"
        ));
        let events = asm.process(&StreamEventPayload::ContentBlockStop { index: 1 });
        assert_eq!(events.len(), 1);
        match &events[0] {
            AssembledEvent::ContentBlockComplete { index, block } => {
                assert_eq!(*index, 1);
                match block {
                    ContentBlock::Thinking {
                        thinking,
                        signature,
                    } => {
                        assert_eq!(thinking, "Let me think");
                        assert!(signature.is_empty());
                    }
                    other => panic!("expected Thinking, got {other:?}"),
                }
            }
            other => panic!("expected ContentBlockComplete, got {other:?}"),
        }
    }

    /// A delta whose kind does not match the block kind is ignored and does
    /// not corrupt the accumulated text.
    #[test]
    fn mismatched_delta_is_ignored() {
        let mut asm = StreamAssembler::new();
        asm.process(&StreamEventPayload::ContentBlockStart {
            index: 0,
            content_block: ContentBlockInfo::Text {
                text: "keep".to_owned(),
            },
        });
        let events = asm.process(&StreamEventPayload::ContentBlockDelta {
            index: 0,
            delta: Delta::InputJsonDelta {
                partial_json: "{}".to_owned(),
            },
        });
        assert!(events.is_empty());
        let events = asm.process(&StreamEventPayload::ContentBlockStop { index: 0 });
        match &events[0] {
            AssembledEvent::ContentBlockComplete { block, .. } => match block {
                ContentBlock::Text { text } => assert_eq!(text, "keep"),
                other => panic!("expected Text, got {other:?}"),
            },
            other => panic!("expected ContentBlockComplete, got {other:?}"),
        }
    }

    /// Deltas and stops for an index that was never started produce nothing.
    #[test]
    fn events_for_unknown_index_are_ignored() {
        let mut asm = StreamAssembler::new();
        asm.process(&StreamEventPayload::MessageStart {
            message: serde_json::json!({}),
        });
        let events = asm.process(&StreamEventPayload::ContentBlockDelta {
            index: 3,
            delta: Delta::TextDelta {
                text: "orphan".to_owned(),
            },
        });
        assert!(events.is_empty());
        let events = asm.process(&StreamEventPayload::ContentBlockStop { index: 3 });
        assert!(events.is_empty());
    }
}