1use crate::types::{ContentBlock, ContentBlockInfo, Delta, StreamEventPayload};
12
/// Accumulator for a content block that has been started but not yet stopped.
///
/// Each variant holds whatever has been received so far for one kind of block;
/// the assembler appends to it on every matching delta and converts it into a
/// finished `ContentBlock` when the block stops.
#[derive(Debug, Clone)]
enum PartialBlock {
    /// A text block under construction.
    Text {
        /// Text received so far (initial text plus every appended fragment).
        text: String,
    },
    /// A thinking block under construction.
    Thinking {
        /// Thinking text received so far.
        thinking: String,
    },
    /// A tool-use block under construction.
    ToolUse {
        /// Tool-use identifier taken from the block-start event.
        id: String,
        /// Tool name taken from the block-start event.
        name: String,
        /// Raw JSON fragments concatenated in arrival order; only parsed once
        /// the block stops, because intermediate states are not valid JSON.
        partial_json: String,
    },
}
27
28#[derive(Debug, Clone)]
30pub enum AssembledEvent {
31 MessageStart {
33 metadata: serde_json::Value,
35 },
36
37 ContentBlockComplete {
39 index: u64,
41 block: ContentBlock,
43 },
44
45 TextDelta {
48 index: u64,
50 text: String,
52 },
53
54 ThinkingDelta {
56 index: u64,
58 thinking: String,
60 },
61
62 MessageComplete {
64 stop_reason: Option<String>,
66 },
67}
68
69#[derive(Debug, Default)]
85pub struct StreamAssembler {
86 blocks: Vec<Option<PartialBlock>>,
88}
89
90impl StreamAssembler {
91 pub fn new() -> Self {
93 Self::default()
94 }
95
96 pub fn reset(&mut self) {
98 self.blocks.clear();
99 }
100
101 pub fn process(&mut self, event: &StreamEventPayload) -> Vec<AssembledEvent> {
103 match event {
104 StreamEventPayload::MessageStart { message } => {
105 self.reset();
106 vec![AssembledEvent::MessageStart {
107 metadata: message.clone(),
108 }]
109 }
110
111 StreamEventPayload::ContentBlockStart {
112 index,
113 content_block,
114 } => {
115 let idx = *index as usize;
116 if self.blocks.len() <= idx {
118 self.blocks.resize_with(idx + 1, || None);
119 }
120 self.blocks[idx] = Some(match content_block {
121 ContentBlockInfo::Text { text } => PartialBlock::Text { text: text.clone() },
122 ContentBlockInfo::Thinking { thinking } => PartialBlock::Thinking {
123 thinking: thinking.clone(),
124 },
125 ContentBlockInfo::ToolUse { id, name, .. } => PartialBlock::ToolUse {
126 id: id.clone(),
127 name: name.clone(),
128 partial_json: String::new(),
129 },
130 });
131 vec![]
132 }
133
134 StreamEventPayload::ContentBlockDelta { index, delta } => {
135 let idx = *index as usize;
136 let mut events = Vec::new();
137
138 if let Some(Some(partial)) = self.blocks.get_mut(idx) {
139 match (partial, delta) {
140 (PartialBlock::Text { text }, Delta::TextDelta { text: fragment }) => {
141 text.push_str(fragment);
142 events.push(AssembledEvent::TextDelta {
143 index: *index,
144 text: fragment.clone(),
145 });
146 }
147 (
148 PartialBlock::Thinking { thinking },
149 Delta::ThinkingDelta { thinking: fragment },
150 ) => {
151 thinking.push_str(fragment);
152 events.push(AssembledEvent::ThinkingDelta {
153 index: *index,
154 thinking: fragment.clone(),
155 });
156 }
157 (
158 PartialBlock::ToolUse { partial_json, .. },
159 Delta::InputJsonDelta {
160 partial_json: fragment,
161 },
162 ) => {
163 partial_json.push_str(fragment);
164 }
165 _ => {
166 }
168 }
169 }
170
171 events
172 }
173
174 StreamEventPayload::ContentBlockStop { index } => {
175 let idx = *index as usize;
176 let mut events = Vec::new();
177
178 if let Some(partial) = self.blocks.get_mut(idx).and_then(Option::take) {
179 let block = match partial {
180 PartialBlock::Text { text } => ContentBlock::Text { text },
181 PartialBlock::Thinking { thinking } => {
182 ContentBlock::Thinking {
185 thinking,
186 signature: String::new(),
187 }
188 }
189 PartialBlock::ToolUse {
190 id,
191 name,
192 partial_json,
193 } => {
194 let input = serde_json::from_str(&partial_json)
195 .unwrap_or(serde_json::Value::Object(serde_json::Map::new()));
196 ContentBlock::ToolUse { id, name, input }
197 }
198 };
199 events.push(AssembledEvent::ContentBlockComplete {
200 index: *index,
201 block,
202 });
203 }
204
205 events
206 }
207
208 StreamEventPayload::MessageDelta { delta, .. } => {
209 let _stop_reason = delta
212 .get("stop_reason")
213 .and_then(|v| v.as_str())
214 .map(String::from);
215 vec![]
216 }
217
218 StreamEventPayload::MessageStop => {
219 vec![AssembledEvent::MessageComplete { stop_reason: None }]
220 }
221
222 StreamEventPayload::Unknown => vec![],
223 }
224 }
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a text delta event for block 0.
    fn text_delta(fragment: &str) -> StreamEventPayload {
        StreamEventPayload::ContentBlockDelta {
            index: 0,
            delta: Delta::TextDelta {
                text: String::from(fragment),
            },
        }
    }

    /// Builds a tool-input JSON delta event for block 0.
    fn json_delta(fragment: &str) -> StreamEventPayload {
        StreamEventPayload::ContentBlockDelta {
            index: 0,
            delta: Delta::InputJsonDelta {
                partial_json: String::from(fragment),
            },
        }
    }

    /// A text block streamed in two fragments is forwarded fragment by
    /// fragment and then reported once, fully concatenated, when it stops.
    #[test]
    fn assemble_text_block() {
        let mut assembler = StreamAssembler::new();

        let started = assembler.process(&StreamEventPayload::MessageStart {
            message: serde_json::json!({"id": "msg_1", "role": "assistant"}),
        });
        assert!(matches!(started[0], AssembledEvent::MessageStart { .. }));

        let opened = assembler.process(&StreamEventPayload::ContentBlockStart {
            index: 0,
            content_block: ContentBlockInfo::Text {
                text: String::new(),
            },
        });
        assert!(opened.is_empty());

        let first = assembler.process(&text_delta("Hello"));
        assert_eq!(first.len(), 1);
        assert!(matches!(&first[0], AssembledEvent::TextDelta { text, .. } if text == "Hello"));

        let second = assembler.process(&text_delta(" world"));
        assert_eq!(second.len(), 1);

        let stopped = assembler.process(&StreamEventPayload::ContentBlockStop { index: 0 });
        assert_eq!(stopped.len(), 1);
        match &stopped[0] {
            AssembledEvent::ContentBlockComplete {
                block: ContentBlock::Text { text },
                ..
            } => assert_eq!(text, "Hello world"),
            AssembledEvent::ContentBlockComplete { block: other, .. } => {
                panic!("expected Text, got {other:?}")
            }
            other => panic!("expected ContentBlockComplete, got {other:?}"),
        }

        let finished = assembler.process(&StreamEventPayload::MessageStop);
        assert!(matches!(finished[0], AssembledEvent::MessageComplete { .. }));
    }

    /// Tool input split across two JSON fragments is parsed as a whole when
    /// the block stops.
    #[test]
    fn assemble_tool_use_block() {
        let mut assembler = StreamAssembler::new();

        assembler.process(&StreamEventPayload::MessageStart {
            message: serde_json::json!({}),
        });

        assembler.process(&StreamEventPayload::ContentBlockStart {
            index: 0,
            content_block: ContentBlockInfo::ToolUse {
                id: String::from("tu_1"),
                name: String::from("Bash"),
                input: serde_json::Value::Object(serde_json::Map::new()),
            },
        });

        assembler.process(&json_delta(r#"{"command":"#));
        assembler.process(&json_delta(r#""ls -la"}"#));

        let stopped = assembler.process(&StreamEventPayload::ContentBlockStop { index: 0 });
        assert_eq!(stopped.len(), 1);
        match &stopped[0] {
            AssembledEvent::ContentBlockComplete {
                block: ContentBlock::ToolUse { id, name, input },
                ..
            } => {
                assert_eq!(id, "tu_1");
                assert_eq!(name, "Bash");
                assert_eq!(input["command"], "ls -la");
            }
            AssembledEvent::ContentBlockComplete { block: other, .. } => {
                panic!("expected ToolUse, got {other:?}")
            }
            other => panic!("expected ContentBlockComplete, got {other:?}"),
        }
    }
}