1use crate::tasks::generate::ToolCall;
7use crate::TokenUsage;
8use std::collections::HashMap;
9
10#[derive(Debug, Clone)]
12pub enum StreamEvent {
13 TextDelta(String),
15 ToolCallStart { name: String, index: usize, id: Option<String> },
17 ToolCallDelta {
19 index: usize,
20 arguments_delta: String,
21 },
22 Usage {
30 input_tokens: u64,
31 output_tokens: u64,
32 },
33 Done {
35 text: String,
36 tool_calls: Vec<ToolCall>,
37 },
38}
39
40pub fn parse_openai_sse_line(line: &str) -> Vec<StreamEvent> {
43 let data = match line.strip_prefix("data: ") {
44 Some(d) => d,
45 None => return Vec::new(),
46 };
47 if data == "[DONE]" {
48 return Vec::new();
49 }
50
51 let json: serde_json::Value = match serde_json::from_str(data) {
52 Ok(v) => v,
53 Err(_) => return Vec::new(),
54 };
55
56 let mut events = Vec::new();
57
58 if let Some(delta) = json
61 .get("choices")
62 .and_then(|c| c.as_array())
63 .and_then(|c| c.first())
64 .and_then(|c| c.get("delta"))
65 {
66 if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
67 if !content.is_empty() {
68 events.push(StreamEvent::TextDelta(content.to_string()));
69 }
70 }
71
72 if let Some(tool_calls) = delta.get("tool_calls").and_then(|t| t.as_array()) {
74 for tc in tool_calls {
75 let index = tc.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as usize;
76 if let Some(function) = tc.get("function") {
77 if let Some(name) = function.get("name").and_then(|n| n.as_str()) {
78 let id = tc.get("id").and_then(|i| i.as_str()).map(|s| s.to_string());
79 events.push(StreamEvent::ToolCallStart {
80 name: name.to_string(),
81 index,
82 id,
83 });
84 }
85 if let Some(args) = function.get("arguments").and_then(|a| a.as_str()) {
86 if !args.is_empty() {
87 events.push(StreamEvent::ToolCallDelta {
88 index,
89 arguments_delta: args.to_string(),
90 });
91 }
92 }
93 }
94 }
95 }
96 }
97
98 if let Some(usage) = json.get("usage") {
102 let input = usage
103 .get("prompt_tokens")
104 .and_then(|n| n.as_u64())
105 .unwrap_or(0);
106 let output = usage
107 .get("completion_tokens")
108 .and_then(|n| n.as_u64())
109 .unwrap_or(0);
110 if input != 0 || output != 0 {
111 events.push(StreamEvent::Usage {
112 input_tokens: input,
113 output_tokens: output,
114 });
115 }
116 }
117
118 events
119}
120
121pub fn parse_anthropic_sse_line(event_type: &str, data: &str) -> Vec<StreamEvent> {
123 match event_type {
124 "content_block_delta" => {
125 let json: serde_json::Value = match serde_json::from_str(data) {
126 Ok(v) => v,
127 Err(_) => return Vec::new(),
128 };
129 let delta = match json.get("delta") {
130 Some(d) => d,
131 None => return Vec::new(),
132 };
133 let delta_type = match delta.get("type").and_then(|t| t.as_str()) {
134 Some(t) => t,
135 None => return Vec::new(),
136 };
137
138 match delta_type {
139 "text_delta" => match delta.get("text").and_then(|t| t.as_str()) {
140 Some(text) => vec![StreamEvent::TextDelta(text.to_string())],
141 None => Vec::new(),
142 },
143 "input_json_delta" => match delta.get("partial_json").and_then(|p| p.as_str()) {
144 Some(partial) => {
145 let index =
146 json.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as usize;
147 vec![StreamEvent::ToolCallDelta {
148 index,
149 arguments_delta: partial.to_string(),
150 }]
151 }
152 None => Vec::new(),
153 },
154 _ => Vec::new(),
155 }
156 }
157 "content_block_start" => {
158 let json: serde_json::Value = match serde_json::from_str(data) {
159 Ok(v) => v,
160 Err(_) => return Vec::new(),
161 };
162 let block = match json.get("content_block") {
163 Some(b) => b,
164 None => return Vec::new(),
165 };
166 if block.get("type").and_then(|t| t.as_str()) == Some("tool_use") {
167 if let Some(name) = block.get("name").and_then(|n| n.as_str()) {
168 let index = json.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as usize;
169 let id = block.get("id").and_then(|i| i.as_str()).map(|s| s.to_string());
170 return vec![StreamEvent::ToolCallStart {
171 name: name.to_string(),
172 index,
173 id,
174 }];
175 }
176 }
177 Vec::new()
178 }
179 "message_start" => {
183 let json: serde_json::Value = match serde_json::from_str(data) {
184 Ok(v) => v,
185 Err(_) => return Vec::new(),
186 };
187 let Some(usage) = json.pointer("/message/usage") else {
188 return Vec::new();
189 };
190 let input = usage.get("input_tokens").and_then(|n| n.as_u64()).unwrap_or(0);
191 let output = usage.get("output_tokens").and_then(|n| n.as_u64()).unwrap_or(0);
192 if input == 0 && output == 0 {
193 return Vec::new();
194 }
195 vec![StreamEvent::Usage { input_tokens: input, output_tokens: output }]
196 }
197 "message_delta" => {
201 let json: serde_json::Value = match serde_json::from_str(data) {
202 Ok(v) => v,
203 Err(_) => return Vec::new(),
204 };
205 let Some(usage) = json.get("usage") else {
206 return Vec::new();
207 };
208 let input = usage.get("input_tokens").and_then(|n| n.as_u64()).unwrap_or(0);
209 let output = usage.get("output_tokens").and_then(|n| n.as_u64()).unwrap_or(0);
210 if input == 0 && output == 0 {
211 return Vec::new();
212 }
213 vec![StreamEvent::Usage { input_tokens: input, output_tokens: output }]
214 }
215 _ => Vec::new(),
216 }
217}
218
/// Collects [`StreamEvent`]s into a final text, tool-call list and token
/// usage.
///
/// Start from `StreamAccumulator::default()`, feed events with `push`, and
/// consume the result with `finish` or `finish_with_usage`.
#[derive(Default)]
pub struct StreamAccumulator {
    /// Concatenation of every text fragment received so far.
    pub text: String,
    /// Tool name per call slot.
    tool_names: HashMap<usize, String>,
    /// Raw JSON argument text accumulated per call slot.
    tool_args: HashMap<usize, String>,
    /// Call identifier per call slot, for calls that supplied one.
    tool_ids: HashMap<usize, String>,
    /// Largest prompt-side token count reported so far.
    input_tokens: u64,
    /// Largest completion-side token count reported so far.
    output_tokens: u64,
    /// Whether any usage event has been received.
    saw_usage: bool,
}
241
242impl StreamAccumulator {
243 pub fn push(&mut self, event: &StreamEvent) {
244 match event {
245 StreamEvent::TextDelta(t) => self.text.push_str(t),
246 StreamEvent::ToolCallStart { name, index, id } => {
247 self.tool_names.insert(*index, name.clone());
248 self.tool_args.entry(*index).or_default();
249 if let Some(id) = id {
250 self.tool_ids.insert(*index, id.clone());
251 }
252 }
253 StreamEvent::ToolCallDelta {
254 index,
255 arguments_delta,
256 } => {
257 self.tool_args
258 .entry(*index)
259 .or_default()
260 .push_str(arguments_delta);
261 }
262 StreamEvent::Usage { input_tokens, output_tokens } => {
263 self.saw_usage = true;
264 if *input_tokens > self.input_tokens {
269 self.input_tokens = *input_tokens;
270 }
271 if *output_tokens > self.output_tokens {
272 self.output_tokens = *output_tokens;
273 }
274 }
275 StreamEvent::Done { .. } => {}
276 }
277 }
278
279 pub fn finish(self) -> (String, Vec<ToolCall>) {
280 let (text, tool_calls, _) = self.finish_with_usage();
281 (text, tool_calls)
282 }
283
284 pub fn finish_with_usage(self) -> (String, Vec<ToolCall>, Option<TokenUsage>) {
289 let mut tool_calls = Vec::new();
290 let mut indices: Vec<usize> = self.tool_names.keys().copied().collect();
291 indices.sort();
292
293 for idx in indices {
294 let id = self.tool_ids.get(&idx).cloned();
295 let name = self.tool_names.get(&idx).cloned().unwrap_or_default();
296 let args_str = self.tool_args.get(&idx).cloned().unwrap_or_default();
297 let arguments: HashMap<String, serde_json::Value> =
298 serde_json::from_str(&args_str).unwrap_or_default();
299 tool_calls.push(ToolCall { id, name, arguments });
300 }
301
302 let usage = if self.saw_usage {
303 Some(TokenUsage {
304 prompt_tokens: self.input_tokens,
305 completion_tokens: self.output_tokens,
306 total_tokens: self.input_tokens + self.output_tokens,
307 context_window: 0,
311 })
312 } else {
313 None
314 };
315
316 (self.text, tool_calls, usage)
317 }
318}
319
/// Splits a chunk of Server-Sent Events text into `(event_name, data)` pairs.
///
/// Lines are interpreted following the SSE wire format:
///
/// * `event:` sets the name of the event being built;
/// * `data:` adds a line to the event's data, and several `data:` lines in one
///   event are joined with `\n`;
/// * a blank line ends the event and resets the pending name and data;
/// * every other line (comments, unknown fields) is ignored.
///
/// The single space after the colon is optional for both fields. An event
/// without an explicit name is reported as `"message"`. Events whose data is
/// empty are not reported.
///
/// An event still pending when `chunk` ends (no terminating blank line) is
/// reported as well, so a chunk that stops right after its last `data:` line
/// is not lost.
pub fn parse_sse_lines(chunk: &str) -> Vec<(String, String)> {
    /// Returns the value of `line` when it is the SSE field `name`, with the
    /// colon and at most one following space removed.
    fn field_value<'a>(line: &'a str, name: &str) -> Option<&'a str> {
        let rest = line.strip_prefix(name)?.strip_prefix(':')?;
        Some(rest.strip_prefix(' ').unwrap_or(rest))
    }

    /// Emits the pending event when it has data, then resets the pending
    /// name and data so nothing leaks into the next event.
    fn dispatch(
        events: &mut Vec<(String, String)>,
        event_name: &mut String,
        data_lines: &mut Vec<&str>,
    ) {
        let data = data_lines.join("\n");
        data_lines.clear();
        let name = std::mem::take(event_name);
        if data.is_empty() {
            return;
        }
        let name = if name.is_empty() {
            "message".to_string()
        } else {
            name
        };
        events.push((name, data));
    }

    let mut events = Vec::new();
    let mut event_name = String::new();
    let mut data_lines: Vec<&str> = Vec::new();

    for line in chunk.lines() {
        if line.is_empty() {
            dispatch(&mut events, &mut event_name, &mut data_lines);
        } else if let Some(value) = field_value(line, "event") {
            event_name = value.to_string();
        } else if let Some(value) = field_value(line, "data") {
            data_lines.push(value);
        }
    }

    // Flush an event that was not terminated by a blank line.
    dispatch(&mut events, &mut event_name, &mut data_lines);

    events
}
360
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that exactly one event was produced and returns it.
    fn sole_event(mut events: Vec<StreamEvent>) -> StreamEvent {
        assert_eq!(events.len(), 1);
        events.remove(0)
    }

    /// Asserts that `event` is a tool-call start with the given name and slot.
    fn assert_tool_start(event: &StreamEvent, expected_name: &str, expected_index: usize) {
        match event {
            StreamEvent::ToolCallStart { name, index, .. } => {
                assert_eq!(name, expected_name);
                assert_eq!(*index, expected_index);
            }
            other => panic!("expected ToolCallStart, got {:?}", other),
        }
    }

    /// Asserts that `event` is a usage event with the given counters.
    fn assert_usage(event: &StreamEvent, expected_input: u64, expected_output: u64) {
        match event {
            StreamEvent::Usage {
                input_tokens,
                output_tokens,
            } => {
                assert_eq!(*input_tokens, expected_input);
                assert_eq!(*output_tokens, expected_output);
            }
            other => panic!("expected Usage, got {:?}", other),
        }
    }

    /// Asserts that `event` is a text delta carrying `expected`.
    fn assert_text(event: &StreamEvent, expected: &str) {
        match event {
            StreamEvent::TextDelta(t) => assert_eq!(t, expected),
            other => panic!("expected TextDelta, got {:?}", other),
        }
    }

    /// Pushes every event in `events` into `acc`, in order.
    fn feed(acc: &mut StreamAccumulator, events: Vec<StreamEvent>) {
        for event in &events {
            acc.push(event);
        }
    }

    #[test]
    fn parse_openai_text_delta() {
        let events =
            parse_openai_sse_line(r#"data: {"choices":[{"delta":{"content":"Hello"}}]}"#);
        assert_text(&sole_event(events), "Hello");
    }

    #[test]
    fn parse_openai_tool_call_start() {
        let events = parse_openai_sse_line(
            r#"data: {"choices":[{"delta":{"tool_calls":[{"index":0,"function":{"name":"edit_file"}}]}}]}"#,
        );
        assert_tool_start(&sole_event(events), "edit_file", 0);
    }

    #[test]
    fn parse_openai_tool_call_delta() {
        let events = parse_openai_sse_line(
            r#"data: {"choices":[{"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"path\":"}}]}}]}"#,
        );
        match sole_event(events) {
            StreamEvent::ToolCallDelta {
                index,
                arguments_delta,
            } => {
                assert_eq!(index, 0);
                assert!(arguments_delta.contains("path"));
            }
            other => panic!("expected ToolCallDelta, got {:?}", other),
        }
    }

    #[test]
    fn parse_openai_multiple_tool_calls_in_chunk() {
        let events = parse_openai_sse_line(
            r#"data: {"choices":[{"delta":{"tool_calls":[{"index":0,"function":{"name":"read_file"}},{"index":1,"function":{"name":"search"}}]}}]}"#,
        );
        assert_eq!(events.len(), 2);
        assert_tool_start(&events[0], "read_file", 0);
        assert_tool_start(&events[1], "search", 1);
    }

    #[test]
    fn parse_openai_done() {
        let events = parse_openai_sse_line("data: [DONE]");
        assert!(events.is_empty());
    }

    #[test]
    fn parse_anthropic_text_delta() {
        let events = parse_anthropic_sse_line(
            "content_block_delta",
            r#"{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"world"}}"#,
        );
        assert_text(&sole_event(events), "world");
    }

    #[test]
    fn parse_anthropic_tool_start() {
        let events = parse_anthropic_sse_line(
            "content_block_start",
            r#"{"type":"content_block_start","index":1,"content_block":{"type":"tool_use","id":"t1","name":"search","input":{}}}"#,
        );
        assert_tool_start(&sole_event(events), "search", 1);
    }

    #[test]
    fn accumulator_builds_result() {
        let mut acc = StreamAccumulator::default();
        feed(
            &mut acc,
            vec![
                StreamEvent::TextDelta("Hello ".into()),
                StreamEvent::TextDelta("world".into()),
                StreamEvent::ToolCallStart {
                    name: "search".into(),
                    index: 0,
                    id: None,
                },
                StreamEvent::ToolCallDelta {
                    index: 0,
                    arguments_delta: r#"{"q":"test"}"#.into(),
                },
            ],
        );

        let (text, tools) = acc.finish();
        assert_eq!(text, "Hello world");
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name, "search");
        assert!(tools[0].arguments.contains_key("q"));
    }

    #[test]
    fn parse_sse_lines_openai_format() {
        let events = parse_sse_lines(
            "data: {\"choices\":[{\"delta\":{\"content\":\"Hi\"}}]}\n\ndata: [DONE]\n\n",
        );
        assert_eq!(events.len(), 2);
        let (first_name, _) = &events[0];
        let (_, second_data) = &events[1];
        assert_eq!(first_name, "message");
        assert_eq!(second_data, "[DONE]");
    }

    #[test]
    fn parse_sse_lines_anthropic_format() {
        let events = parse_sse_lines(
            "event: content_block_delta\ndata: {\"delta\":{\"type\":\"text_delta\",\"text\":\"Hi\"}}\n\n",
        );
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].0, "content_block_delta");
    }

    #[test]
    fn parse_anthropic_message_start_emits_usage() {
        let events = parse_anthropic_sse_line(
            "message_start",
            r#"{"type":"message_start","message":{"id":"msg_1","role":"assistant","usage":{"input_tokens":245,"output_tokens":1}}}"#,
        );
        assert_usage(&sole_event(events), 245, 1);
    }

    #[test]
    fn parse_anthropic_message_delta_emits_usage() {
        let events = parse_anthropic_sse_line(
            "message_delta",
            r#"{"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":87}}"#,
        );
        assert_usage(&sole_event(events), 0, 87);
    }

    #[test]
    fn parse_anthropic_message_start_without_usage_is_empty() {
        let events = parse_anthropic_sse_line(
            "message_start",
            r#"{"type":"message_start","message":{"id":"msg_1"}}"#,
        );
        assert!(events.is_empty());
    }

    #[test]
    fn accumulator_tracks_usage_across_anthropic_stream() {
        let mut acc = StreamAccumulator::default();

        // The same event sequence a text-only Anthropic response produces:
        // opening usage, an empty text block, three text deltas, final usage.
        let stream = [
            (
                "message_start",
                r#"{"message":{"usage":{"input_tokens":245,"output_tokens":1}}}"#,
            ),
            (
                "content_block_start",
                r#"{"index":0,"content_block":{"type":"text","text":""}}"#,
            ),
            (
                "content_block_delta",
                r#"{"delta":{"type":"text_delta","text":"Hello"}}"#,
            ),
            (
                "content_block_delta",
                r#"{"delta":{"type":"text_delta","text":", "}}"#,
            ),
            (
                "content_block_delta",
                r#"{"delta":{"type":"text_delta","text":"world"}}"#,
            ),
            ("message_delta", r#"{"usage":{"output_tokens":87}}"#),
        ];
        for (event_type, data) in stream {
            feed(&mut acc, parse_anthropic_sse_line(event_type, data));
        }

        let (text, tools, usage) = acc.finish_with_usage();
        assert_eq!(text, "Hello, world");
        assert!(tools.is_empty());
        let usage = usage.expect("provider reported usage; must surface");
        assert_eq!(usage.prompt_tokens, 245);
        assert_eq!(usage.completion_tokens, 87);
        assert_eq!(usage.total_tokens, 332);
    }

    #[test]
    fn parse_openai_final_chunk_emits_usage() {
        let events = parse_openai_sse_line(
            r#"data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[],"usage":{"prompt_tokens":245,"completion_tokens":87,"total_tokens":332}}"#,
        );
        assert_usage(&sole_event(events), 245, 87);
    }

    #[test]
    fn accumulator_tracks_usage_across_openai_stream() {
        let mut acc = StreamAccumulator::default();

        let stream = [
            r#"data: {"choices":[{"delta":{"content":"Hello"}}]}"#,
            r#"data: {"choices":[{"delta":{"content":", "}}]}"#,
            r#"data: {"choices":[{"delta":{"content":"world"}}]}"#,
            r#"data: {"id":"chatcmpl-1","choices":[],"usage":{"prompt_tokens":245,"completion_tokens":87}}"#,
        ];
        for line in stream {
            feed(&mut acc, parse_openai_sse_line(line));
        }

        let (text, tools, usage) = acc.finish_with_usage();
        assert_eq!(text, "Hello, world");
        assert!(tools.is_empty());
        let usage = usage.expect("provider reported usage; must surface");
        assert_eq!(usage.prompt_tokens, 245);
        assert_eq!(usage.completion_tokens, 87);
        assert_eq!(usage.total_tokens, 332);
    }

    #[test]
    fn accumulator_returns_no_usage_when_provider_silent() {
        let mut acc = StreamAccumulator::default();
        acc.push(&StreamEvent::TextDelta("hi".into()));
        let (_text, _tools, usage) = acc.finish_with_usage();
        assert!(usage.is_none());
    }
}