1use serde::Deserialize;
15
16use crate::llm::message::{ContentBlock, StopReason, Usage};
17
18#[derive(Debug, Clone)]
20pub enum StreamEvent {
21 TextDelta(String),
23
24 ContentBlockComplete(ContentBlock),
26
27 ToolUseStart { id: String, name: String },
29
30 ToolInputDelta { id: String, partial_json: String },
32
33 Done {
35 usage: Usage,
36 stop_reason: Option<StopReason>,
37 },
38
39 Ttft(u64),
41
42 Error(String),
44}
45
46#[derive(Debug, Deserialize)]
48#[serde(tag = "type")]
49pub enum RawSseEvent {
50 #[serde(rename = "message_start")]
51 MessageStart { message: MessageStartPayload },
52
53 #[serde(rename = "content_block_start")]
54 ContentBlockStart {
55 index: usize,
56 content_block: RawContentBlock,
57 },
58
59 #[serde(rename = "content_block_delta")]
60 ContentBlockDelta { index: usize, delta: RawDelta },
61
62 #[serde(rename = "content_block_stop")]
63 ContentBlockStop { index: usize },
64
65 #[serde(rename = "message_delta")]
66 MessageDelta {
67 delta: MessageDeltaPayload,
68 usage: Option<Usage>,
69 },
70
71 #[serde(rename = "message_stop")]
72 MessageStop {},
73
74 #[serde(rename = "ping")]
75 Ping {},
76
77 #[serde(rename = "error")]
78 Error { error: ErrorPayload },
79}
80
81#[derive(Debug, Deserialize)]
82pub struct MessageStartPayload {
83 pub id: Option<String>,
84 pub model: Option<String>,
85 pub usage: Option<Usage>,
86}
87
88#[derive(Debug, Deserialize)]
89pub struct MessageDeltaPayload {
90 pub stop_reason: Option<StopReason>,
91}
92
93#[derive(Debug, Deserialize)]
94pub struct ErrorPayload {
95 #[serde(rename = "type")]
96 pub error_type: Option<String>,
97 pub message: Option<String>,
98}
99
100#[derive(Debug, Deserialize)]
101#[serde(tag = "type")]
102pub enum RawContentBlock {
103 #[serde(rename = "text")]
104 Text { text: Option<String> },
105
106 #[serde(rename = "tool_use")]
107 ToolUse {
108 id: String,
109 name: String,
110 input: Option<serde_json::Value>,
111 },
112
113 #[serde(rename = "thinking")]
114 Thinking {
115 thinking: Option<String>,
116 signature: Option<String>,
117 },
118}
119
120#[derive(Debug, Deserialize)]
121#[serde(tag = "type")]
122#[allow(clippy::enum_variant_names)]
123pub enum RawDelta {
124 #[serde(rename = "text_delta")]
125 TextDelta { text: String },
126
127 #[serde(rename = "input_json_delta")]
128 InputJsonDelta { partial_json: String },
129
130 #[serde(rename = "thinking_delta")]
131 ThinkingDelta { thinking: String },
132
133 #[serde(rename = "signature_delta")]
134 SignatureDelta { signature: String },
135}
136
137pub struct StreamParser {
142 blocks: Vec<PartialBlock>,
144 usage: Usage,
146 pub model: Option<String>,
148 pub request_id: Option<String>,
150}
151
152enum PartialBlock {
154 Text(String),
155 ToolUse {
156 id: String,
157 name: String,
158 input_json: String,
159 },
160 Thinking {
161 thinking: String,
162 signature: String,
163 },
164}
165
166impl StreamParser {
167 pub fn new() -> Self {
168 Self {
169 blocks: Vec::new(),
170 usage: Usage::default(),
171 model: None,
172 request_id: None,
173 }
174 }
175
176 pub fn process(&mut self, raw: RawSseEvent) -> Vec<StreamEvent> {
178 match raw {
179 RawSseEvent::MessageStart { message } => {
180 if let Some(usage) = message.usage {
181 self.usage = usage;
182 }
183 self.model = message.model;
184 self.request_id = message.id;
185 vec![]
186 }
187
188 RawSseEvent::ContentBlockStart {
189 index,
190 content_block,
191 } => {
192 while self.blocks.len() <= index {
194 self.blocks.push(PartialBlock::Text(String::new()));
195 }
196
197 match content_block {
198 RawContentBlock::Text { text } => {
199 self.blocks[index] = PartialBlock::Text(text.unwrap_or_default());
200 vec![]
201 }
202 RawContentBlock::ToolUse { id, name, input: _ } => {
203 let event = StreamEvent::ToolUseStart {
204 id: id.clone(),
205 name: name.clone(),
206 };
207 self.blocks[index] = PartialBlock::ToolUse {
208 id,
209 name,
210 input_json: String::new(),
211 };
212 vec![event]
213 }
214 RawContentBlock::Thinking {
215 thinking,
216 signature,
217 } => {
218 self.blocks[index] = PartialBlock::Thinking {
219 thinking: thinking.unwrap_or_default(),
220 signature: signature.unwrap_or_default(),
221 };
222 vec![]
223 }
224 }
225 }
226
227 RawSseEvent::ContentBlockDelta { index, delta } => {
228 if index >= self.blocks.len() {
229 return vec![];
230 }
231
232 match delta {
233 RawDelta::TextDelta { text } => {
234 if let PartialBlock::Text(ref mut buf) = self.blocks[index] {
235 buf.push_str(&text);
236 }
237 vec![StreamEvent::TextDelta(text)]
238 }
239 RawDelta::InputJsonDelta { partial_json } => {
240 let mut events = vec![];
241 if let PartialBlock::ToolUse {
242 ref id,
243 ref mut input_json,
244 ..
245 } = self.blocks[index]
246 {
247 input_json.push_str(&partial_json);
248 events.push(StreamEvent::ToolInputDelta {
249 id: id.clone(),
250 partial_json,
251 });
252 }
253 events
254 }
255 RawDelta::ThinkingDelta { thinking } => {
256 if let PartialBlock::Thinking {
257 thinking: ref mut buf,
258 ..
259 } = self.blocks[index]
260 {
261 buf.push_str(&thinking);
262 }
263 vec![]
264 }
265 RawDelta::SignatureDelta { signature } => {
266 if let PartialBlock::Thinking {
267 signature: ref mut buf,
268 ..
269 } = self.blocks[index]
270 {
271 buf.push_str(&signature);
272 }
273 vec![]
274 }
275 }
276 }
277
278 RawSseEvent::ContentBlockStop { index } => {
279 if index >= self.blocks.len() {
280 return vec![];
281 }
282
283 let block =
284 std::mem::replace(&mut self.blocks[index], PartialBlock::Text(String::new()));
285
286 let content_block = match block {
287 PartialBlock::Text(text) => ContentBlock::Text { text },
288 PartialBlock::ToolUse {
289 id,
290 name,
291 input_json,
292 } => {
293 let input = serde_json::from_str(&input_json)
294 .unwrap_or(serde_json::Value::Object(Default::default()));
295 ContentBlock::ToolUse { id, name, input }
296 }
297 PartialBlock::Thinking {
298 thinking,
299 signature,
300 } => ContentBlock::Thinking {
301 thinking,
302 signature: if signature.is_empty() {
303 None
304 } else {
305 Some(signature)
306 },
307 },
308 };
309
310 vec![StreamEvent::ContentBlockComplete(content_block)]
311 }
312
313 RawSseEvent::MessageDelta { delta, usage } => {
314 if let Some(u) = usage {
315 self.usage.merge(&u);
316 }
317 vec![StreamEvent::Done {
318 usage: self.usage.clone(),
319 stop_reason: delta.stop_reason,
320 }]
321 }
322
323 RawSseEvent::MessageStop {} => vec![],
324
325 RawSseEvent::Ping {} => vec![],
326
327 RawSseEvent::Error { error } => {
328 let msg = error
329 .message
330 .unwrap_or_else(|| "Unknown stream error".to_string());
331 vec![StreamEvent::Error(msg)]
332 }
333 }
334 }
335}