1use serde::Deserialize;
15
16use crate::llm::message::{ContentBlock, StopReason, Usage};
17
18#[derive(Debug, Clone)]
20pub enum StreamEvent {
21 TextDelta(String),
23
24 ContentBlockComplete(ContentBlock),
26
27 ToolUseStart { id: String, name: String },
29
30 ToolInputDelta { id: String, partial_json: String },
32
33 Done {
35 usage: Usage,
36 stop_reason: Option<StopReason>,
37 },
38
39 Ttft(u64),
41
42 Error(String),
44}
45
46#[derive(Debug, Deserialize)]
48#[serde(tag = "type")]
49pub enum RawSseEvent {
50 #[serde(rename = "message_start")]
51 MessageStart { message: MessageStartPayload },
52
53 #[serde(rename = "content_block_start")]
54 ContentBlockStart {
55 index: usize,
56 content_block: RawContentBlock,
57 },
58
59 #[serde(rename = "content_block_delta")]
60 ContentBlockDelta { index: usize, delta: RawDelta },
61
62 #[serde(rename = "content_block_stop")]
63 ContentBlockStop { index: usize },
64
65 #[serde(rename = "message_delta")]
66 MessageDelta {
67 delta: MessageDeltaPayload,
68 usage: Option<Usage>,
69 },
70
71 #[serde(rename = "message_stop")]
72 MessageStop {},
73
74 #[serde(rename = "ping")]
75 Ping {},
76
77 #[serde(rename = "error")]
78 Error { error: ErrorPayload },
79}
80
81#[derive(Debug, Deserialize)]
83pub struct MessageStartPayload {
84 pub id: Option<String>,
85 pub model: Option<String>,
86 pub usage: Option<Usage>,
87}
88
89#[derive(Debug, Deserialize)]
91pub struct MessageDeltaPayload {
92 pub stop_reason: Option<StopReason>,
93}
94
95#[derive(Debug, Deserialize)]
97pub struct ErrorPayload {
98 #[serde(rename = "type")]
99 pub error_type: Option<String>,
100 pub message: Option<String>,
101}
102
103#[derive(Debug, Deserialize)]
105#[serde(tag = "type")]
106pub enum RawContentBlock {
107 #[serde(rename = "text")]
108 Text { text: Option<String> },
109
110 #[serde(rename = "tool_use")]
111 ToolUse {
112 id: String,
113 name: String,
114 input: Option<serde_json::Value>,
115 },
116
117 #[serde(rename = "thinking")]
118 Thinking {
119 thinking: Option<String>,
120 signature: Option<String>,
121 },
122}
123
124#[derive(Debug, Deserialize)]
126#[serde(tag = "type")]
127#[allow(clippy::enum_variant_names)]
128pub enum RawDelta {
129 #[serde(rename = "text_delta")]
130 TextDelta { text: String },
131
132 #[serde(rename = "input_json_delta")]
133 InputJsonDelta { partial_json: String },
134
135 #[serde(rename = "thinking_delta")]
136 ThinkingDelta { thinking: String },
137
138 #[serde(rename = "signature_delta")]
139 SignatureDelta { signature: String },
140}
141
142pub struct StreamParser {
147 blocks: Vec<PartialBlock>,
149 usage: Usage,
151 pub model: Option<String>,
153 pub request_id: Option<String>,
155}
156
/// A content block still being received; deltas are appended until its stop event.
enum PartialBlock {
    /// Accumulated text. Also used as an empty placeholder for unstarted
    /// or already-finished indices.
    Text(String),
    /// A tool call whose JSON input is concatenated from delta fragments.
    ToolUse {
        id: String,
        name: String,
        input_json: String,
    },
    /// A thinking block; text and signature are accumulated separately.
    Thinking {
        thinking: String,
        signature: String,
    },
}
170
171impl StreamParser {
172 pub fn new() -> Self {
173 Self {
174 blocks: Vec::new(),
175 usage: Usage::default(),
176 model: None,
177 request_id: None,
178 }
179 }
180
181 pub fn process(&mut self, raw: RawSseEvent) -> Vec<StreamEvent> {
183 match raw {
184 RawSseEvent::MessageStart { message } => {
185 if let Some(usage) = message.usage {
186 self.usage = usage;
187 }
188 self.model = message.model;
189 self.request_id = message.id;
190 vec![]
191 }
192
193 RawSseEvent::ContentBlockStart {
194 index,
195 content_block,
196 } => {
197 while self.blocks.len() <= index {
199 self.blocks.push(PartialBlock::Text(String::new()));
200 }
201
202 match content_block {
203 RawContentBlock::Text { text } => {
204 self.blocks[index] = PartialBlock::Text(text.unwrap_or_default());
205 vec![]
206 }
207 RawContentBlock::ToolUse { id, name, input: _ } => {
208 let event = StreamEvent::ToolUseStart {
209 id: id.clone(),
210 name: name.clone(),
211 };
212 self.blocks[index] = PartialBlock::ToolUse {
213 id,
214 name,
215 input_json: String::new(),
216 };
217 vec![event]
218 }
219 RawContentBlock::Thinking {
220 thinking,
221 signature,
222 } => {
223 self.blocks[index] = PartialBlock::Thinking {
224 thinking: thinking.unwrap_or_default(),
225 signature: signature.unwrap_or_default(),
226 };
227 vec![]
228 }
229 }
230 }
231
232 RawSseEvent::ContentBlockDelta { index, delta } => {
233 if index >= self.blocks.len() {
234 return vec![];
235 }
236
237 match delta {
238 RawDelta::TextDelta { text } => {
239 if let PartialBlock::Text(ref mut buf) = self.blocks[index] {
240 buf.push_str(&text);
241 }
242 vec![StreamEvent::TextDelta(text)]
243 }
244 RawDelta::InputJsonDelta { partial_json } => {
245 let mut events = vec![];
246 if let PartialBlock::ToolUse {
247 ref id,
248 ref mut input_json,
249 ..
250 } = self.blocks[index]
251 {
252 input_json.push_str(&partial_json);
253 events.push(StreamEvent::ToolInputDelta {
254 id: id.clone(),
255 partial_json,
256 });
257 }
258 events
259 }
260 RawDelta::ThinkingDelta { thinking } => {
261 if let PartialBlock::Thinking {
262 thinking: ref mut buf,
263 ..
264 } = self.blocks[index]
265 {
266 buf.push_str(&thinking);
267 }
268 vec![]
269 }
270 RawDelta::SignatureDelta { signature } => {
271 if let PartialBlock::Thinking {
272 signature: ref mut buf,
273 ..
274 } = self.blocks[index]
275 {
276 buf.push_str(&signature);
277 }
278 vec![]
279 }
280 }
281 }
282
283 RawSseEvent::ContentBlockStop { index } => {
284 if index >= self.blocks.len() {
285 return vec![];
286 }
287
288 let block =
289 std::mem::replace(&mut self.blocks[index], PartialBlock::Text(String::new()));
290
291 let content_block = match block {
292 PartialBlock::Text(text) => ContentBlock::Text { text },
293 PartialBlock::ToolUse {
294 id,
295 name,
296 input_json,
297 } => {
298 let input = serde_json::from_str(&input_json)
299 .unwrap_or(serde_json::Value::Object(Default::default()));
300 ContentBlock::ToolUse { id, name, input }
301 }
302 PartialBlock::Thinking {
303 thinking,
304 signature,
305 } => ContentBlock::Thinking {
306 thinking,
307 signature: if signature.is_empty() {
308 None
309 } else {
310 Some(signature)
311 },
312 },
313 };
314
315 vec![StreamEvent::ContentBlockComplete(content_block)]
316 }
317
318 RawSseEvent::MessageDelta { delta, usage } => {
319 if let Some(u) = usage {
320 self.usage.merge(&u);
321 }
322 vec![StreamEvent::Done {
323 usage: self.usage.clone(),
324 stop_reason: delta.stop_reason,
325 }]
326 }
327
328 RawSseEvent::MessageStop {} => vec![],
329
330 RawSseEvent::Ping {} => vec![],
331
332 RawSseEvent::Error { error } => {
333 let msg = error
334 .message
335 .unwrap_or_else(|| "Unknown stream error".to_string());
336 vec![StreamEvent::Error(msg)]
337 }
338 }
339 }
340}