1use crate::tasks::generate::ToolCall;
7use crate::TokenUsage;
8use std::collections::HashMap;
9
/// One incremental event decoded from a provider's streaming response.
///
/// Produced by [`parse_openai_sse_line`] and [`parse_anthropic_sse_line`] and folded
/// into a final result by [`StreamAccumulator`].
#[derive(Debug, Clone)]
pub enum StreamEvent {
    /// A fragment of assistant text; [`StreamAccumulator`] appends fragments in
    /// arrival order.
    TextDelta(String),
    /// The start of a tool call.
    ToolCallStart {
        /// Name of the tool being invoked.
        name: String,
        /// Slot that later [`StreamEvent::ToolCallDelta`]s for this call refer to: the
        /// `tool_calls[].index` of an OpenAI stream or the content-block `index` of an
        /// Anthropic stream.
        index: usize,
        /// Provider-assigned call id, when the stream includes one.
        id: Option<String>,
    },
    /// A fragment of a tool call's arguments. [`StreamAccumulator`] concatenates the
    /// fragments for one `index` and parses the result as a JSON object.
    ToolCallDelta {
        /// Slot of the tool call this fragment belongs to.
        index: usize,
        /// Raw, possibly incomplete, JSON text to append.
        arguments_delta: String,
    },
    /// Token counts reported by the provider. The parsers report a count the provider
    /// omitted as 0. Reports may repeat during one stream; [`StreamAccumulator`] keeps
    /// the largest value seen for each count.
    Usage {
        /// Prompt / input tokens.
        input_tokens: u64,
        /// Completion / output tokens.
        output_tokens: u64,
    },
    /// Presumably the complete text and tool calls of a finished stream — TODO confirm
    /// with producers. [`parse_openai_sse_line`] and [`parse_anthropic_sse_line`] never
    /// emit it, and [`StreamAccumulator::push`] ignores it.
    Done {
        /// Full response text.
        text: String,
        /// Tool calls carried by the result.
        tool_calls: Vec<ToolCall>,
    },
}
43
44pub fn parse_openai_sse_line(line: &str) -> Vec<StreamEvent> {
47 let data = match line.strip_prefix("data: ") {
48 Some(d) => d,
49 None => return Vec::new(),
50 };
51 if data == "[DONE]" {
52 return Vec::new();
53 }
54
55 let json: serde_json::Value = match serde_json::from_str(data) {
56 Ok(v) => v,
57 Err(_) => return Vec::new(),
58 };
59
60 let mut events = Vec::new();
61
62 if let Some(delta) = json
65 .get("choices")
66 .and_then(|c| c.as_array())
67 .and_then(|c| c.first())
68 .and_then(|c| c.get("delta"))
69 {
70 if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
71 if !content.is_empty() {
72 events.push(StreamEvent::TextDelta(content.to_string()));
73 }
74 }
75
76 if let Some(tool_calls) = delta.get("tool_calls").and_then(|t| t.as_array()) {
78 for tc in tool_calls {
79 let index = tc.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as usize;
80 if let Some(function) = tc.get("function") {
81 if let Some(name) = function.get("name").and_then(|n| n.as_str()) {
82 let id = tc.get("id").and_then(|i| i.as_str()).map(|s| s.to_string());
83 events.push(StreamEvent::ToolCallStart {
84 name: name.to_string(),
85 index,
86 id,
87 });
88 }
89 if let Some(args) = function.get("arguments").and_then(|a| a.as_str()) {
90 if !args.is_empty() {
91 events.push(StreamEvent::ToolCallDelta {
92 index,
93 arguments_delta: args.to_string(),
94 });
95 }
96 }
97 }
98 }
99 }
100 }
101
102 if let Some(usage) = json.get("usage") {
106 let input = usage
107 .get("prompt_tokens")
108 .and_then(|n| n.as_u64())
109 .unwrap_or(0);
110 let output = usage
111 .get("completion_tokens")
112 .and_then(|n| n.as_u64())
113 .unwrap_or(0);
114 if input != 0 || output != 0 {
115 events.push(StreamEvent::Usage {
116 input_tokens: input,
117 output_tokens: output,
118 });
119 }
120 }
121
122 events
123}
124
125pub fn parse_anthropic_sse_line(event_type: &str, data: &str) -> Vec<StreamEvent> {
127 match event_type {
128 "content_block_delta" => {
129 let json: serde_json::Value = match serde_json::from_str(data) {
130 Ok(v) => v,
131 Err(_) => return Vec::new(),
132 };
133 let delta = match json.get("delta") {
134 Some(d) => d,
135 None => return Vec::new(),
136 };
137 let delta_type = match delta.get("type").and_then(|t| t.as_str()) {
138 Some(t) => t,
139 None => return Vec::new(),
140 };
141
142 match delta_type {
143 "text_delta" => match delta.get("text").and_then(|t| t.as_str()) {
144 Some(text) => vec![StreamEvent::TextDelta(text.to_string())],
145 None => Vec::new(),
146 },
147 "input_json_delta" => match delta.get("partial_json").and_then(|p| p.as_str()) {
148 Some(partial) => {
149 let index =
150 json.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as usize;
151 vec![StreamEvent::ToolCallDelta {
152 index,
153 arguments_delta: partial.to_string(),
154 }]
155 }
156 None => Vec::new(),
157 },
158 _ => Vec::new(),
159 }
160 }
161 "content_block_start" => {
162 let json: serde_json::Value = match serde_json::from_str(data) {
163 Ok(v) => v,
164 Err(_) => return Vec::new(),
165 };
166 let block = match json.get("content_block") {
167 Some(b) => b,
168 None => return Vec::new(),
169 };
170 if block.get("type").and_then(|t| t.as_str()) == Some("tool_use") {
171 if let Some(name) = block.get("name").and_then(|n| n.as_str()) {
172 let index = json.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as usize;
173 let id = block
174 .get("id")
175 .and_then(|i| i.as_str())
176 .map(|s| s.to_string());
177 return vec![StreamEvent::ToolCallStart {
178 name: name.to_string(),
179 index,
180 id,
181 }];
182 }
183 }
184 Vec::new()
185 }
186 "message_start" => {
190 let json: serde_json::Value = match serde_json::from_str(data) {
191 Ok(v) => v,
192 Err(_) => return Vec::new(),
193 };
194 let Some(usage) = json.pointer("/message/usage") else {
195 return Vec::new();
196 };
197 let input = usage
198 .get("input_tokens")
199 .and_then(|n| n.as_u64())
200 .unwrap_or(0);
201 let output = usage
202 .get("output_tokens")
203 .and_then(|n| n.as_u64())
204 .unwrap_or(0);
205 if input == 0 && output == 0 {
206 return Vec::new();
207 }
208 vec![StreamEvent::Usage {
209 input_tokens: input,
210 output_tokens: output,
211 }]
212 }
213 "message_delta" => {
217 let json: serde_json::Value = match serde_json::from_str(data) {
218 Ok(v) => v,
219 Err(_) => return Vec::new(),
220 };
221 let Some(usage) = json.get("usage") else {
222 return Vec::new();
223 };
224 let input = usage
225 .get("input_tokens")
226 .and_then(|n| n.as_u64())
227 .unwrap_or(0);
228 let output = usage
229 .get("output_tokens")
230 .and_then(|n| n.as_u64())
231 .unwrap_or(0);
232 if input == 0 && output == 0 {
233 return Vec::new();
234 }
235 vec![StreamEvent::Usage {
236 input_tokens: input,
237 output_tokens: output,
238 }]
239 }
240 _ => Vec::new(),
241 }
242}
243
/// Folds a sequence of [`StreamEvent`]s into the final response text, tool calls and,
/// when the provider reported it, token usage.
///
/// Start from `StreamAccumulator::default()`, pass every event to `push`, then call
/// `finish` or `finish_with_usage`.
#[derive(Default)]
pub struct StreamAccumulator {
    /// Assistant text accumulated so far, in arrival order.
    pub text: String,

    // --- token usage ---
    /// Set by any `Usage` event; `finish_with_usage` returns no usage when it is unset.
    saw_usage: bool,
    /// Largest input-token count reported so far.
    input_tokens: u64,
    /// Largest output-token count reported so far.
    output_tokens: u64,

    // --- tool calls, keyed by slot index ---
    /// Call id per slot, recorded only when a start event carried one.
    tool_ids: HashMap<usize, String>,
    /// Tool name per slot; only slots present here become tool calls.
    tool_names: HashMap<usize, String>,
    /// Concatenated argument JSON fragments per slot.
    tool_args: HashMap<usize, String>,
}
266
267impl StreamAccumulator {
268 pub fn push(&mut self, event: &StreamEvent) {
269 match event {
270 StreamEvent::TextDelta(t) => self.text.push_str(t),
271 StreamEvent::ToolCallStart { name, index, id } => {
272 self.tool_names.insert(*index, name.clone());
273 self.tool_args.entry(*index).or_default();
274 if let Some(id) = id {
275 self.tool_ids.insert(*index, id.clone());
276 }
277 }
278 StreamEvent::ToolCallDelta {
279 index,
280 arguments_delta,
281 } => {
282 self.tool_args
283 .entry(*index)
284 .or_default()
285 .push_str(arguments_delta);
286 }
287 StreamEvent::Usage {
288 input_tokens,
289 output_tokens,
290 } => {
291 self.saw_usage = true;
292 if *input_tokens > self.input_tokens {
297 self.input_tokens = *input_tokens;
298 }
299 if *output_tokens > self.output_tokens {
300 self.output_tokens = *output_tokens;
301 }
302 }
303 StreamEvent::Done { .. } => {}
304 }
305 }
306
307 pub fn finish(self) -> (String, Vec<ToolCall>) {
308 let (text, tool_calls, _) = self.finish_with_usage();
309 (text, tool_calls)
310 }
311
312 pub fn finish_with_usage(self) -> (String, Vec<ToolCall>, Option<TokenUsage>) {
317 let mut tool_calls = Vec::new();
318 let mut indices: Vec<usize> = self.tool_names.keys().copied().collect();
319 indices.sort();
320
321 for idx in indices {
322 let id = self.tool_ids.get(&idx).cloned();
323 let name = self.tool_names.get(&idx).cloned().unwrap_or_default();
324 let args_str = self.tool_args.get(&idx).cloned().unwrap_or_default();
325 let arguments: HashMap<String, serde_json::Value> =
326 serde_json::from_str(&args_str).unwrap_or_default();
327 tool_calls.push(ToolCall {
328 id,
329 name,
330 arguments,
331 });
332 }
333
334 let usage = if self.saw_usage {
335 Some(TokenUsage {
336 prompt_tokens: self.input_tokens,
337 completion_tokens: self.output_tokens,
338 total_tokens: self.input_tokens + self.output_tokens,
339 context_window: 0,
343 })
344 } else {
345 None
346 };
347
348 (self.text, tool_calls, usage)
349 }
350}
351
/// Splits a chunk of Server-Sent Events text into `(event_type, data)` pairs.
///
/// Follows the field rules of the SSE specification:
///
/// * A blank line dispatches the pending event. Events whose data is empty are
///   dropped. Either way the pending event type is reset, so it cannot leak into the
///   next event.
/// * `event:` sets the event type; an event without one is reported as `"message"`.
/// * Every `data:` line contributes to the event's data; multiple lines are joined
///   with `\n`.
/// * The single space after the colon is optional: `data:x` and `data: x` are
///   equivalent.
/// * Lines starting with `:` are comments. `id:`, `retry:` and unknown fields are
///   ignored.
///
/// A trailing event that is not terminated by a blank line is still returned.
pub fn parse_sse_lines(chunk: &str) -> Vec<(String, String)> {
    /// Emits the pending event if it carries data, then resets the pending state.
    fn dispatch(
        events: &mut Vec<(String, String)>,
        event_type: &mut String,
        data_lines: &mut Vec<&str>,
    ) {
        let data = data_lines.join("\n");
        if !data.is_empty() {
            let name = if event_type.is_empty() {
                "message".to_string()
            } else {
                std::mem::take(event_type)
            };
            events.push((name, data));
        }
        event_type.clear();
        data_lines.clear();
    }

    let mut events = Vec::new();
    let mut event_type = String::new();
    let mut data_lines: Vec<&str> = Vec::new();

    for line in chunk.lines() {
        if line.is_empty() {
            dispatch(&mut events, &mut event_type, &mut data_lines);
            continue;
        }
        // `field: value`; a line without a colon is a field with an empty value.
        let (field, value) = line.split_once(':').unwrap_or((line, ""));
        let value = value.strip_prefix(' ').unwrap_or(value);
        match field {
            "event" => event_type = value.to_string(),
            "data" => data_lines.push(value),
            // Comments (empty field name), `id`, `retry`, and unknown fields.
            _ => {}
        }
    }

    // Flush an event left unterminated at the end of the chunk.
    dispatch(&mut events, &mut event_type, &mut data_lines);
    events
}
392
#[cfg(test)]
mod tests {
    use super::*;

    /// Pushes every event into `acc`, in order.
    fn feed(acc: &mut StreamAccumulator, events: Vec<StreamEvent>) {
        for event in &events {
            acc.push(event);
        }
    }

    #[test]
    fn parse_openai_text_delta() {
        let line = r#"data: {"choices":[{"delta":{"content":"Hello"}}]}"#;
        let events = parse_openai_sse_line(line);
        assert_eq!(events.len(), 1);
        match &events[0] {
            StreamEvent::TextDelta(t) => assert_eq!(t, "Hello"),
            other => panic!("expected TextDelta, got {:?}", other),
        }
    }

    #[test]
    fn parse_openai_tool_call_start() {
        let line = r#"data: {"choices":[{"delta":{"tool_calls":[{"index":0,"function":{"name":"edit_file"}}]}}]}"#;
        let events = parse_openai_sse_line(line);
        assert_eq!(events.len(), 1);
        match &events[0] {
            StreamEvent::ToolCallStart { name, index, .. } => {
                assert_eq!(name, "edit_file");
                assert_eq!(*index, 0);
            }
            other => panic!("expected ToolCallStart, got {:?}", other),
        }
    }

    #[test]
    fn parse_openai_tool_call_delta() {
        let line = r#"data: {"choices":[{"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"path\":"}}]}}]}"#;
        let events = parse_openai_sse_line(line);
        assert_eq!(events.len(), 1);
        match &events[0] {
            StreamEvent::ToolCallDelta {
                index,
                arguments_delta,
            } => {
                assert_eq!(*index, 0);
                assert!(arguments_delta.contains("path"));
            }
            other => panic!("expected ToolCallDelta, got {:?}", other),
        }
    }

    #[test]
    fn parse_openai_tool_call_start_with_id_and_arguments() {
        let line = r#"data: {"choices":[{"delta":{"tool_calls":[{"index":2,"id":"call_1","function":{"name":"search","arguments":"{}"}}]}}]}"#;
        let events = parse_openai_sse_line(line);
        assert_eq!(events.len(), 2);
        match &events[0] {
            StreamEvent::ToolCallStart { name, index, id } => {
                assert_eq!(name, "search");
                assert_eq!(*index, 2);
                assert_eq!(id.as_deref(), Some("call_1"));
            }
            other => panic!("expected ToolCallStart, got {:?}", other),
        }
        match &events[1] {
            StreamEvent::ToolCallDelta {
                index,
                arguments_delta,
            } => {
                assert_eq!(*index, 2);
                assert_eq!(arguments_delta, "{}");
            }
            other => panic!("expected ToolCallDelta, got {:?}", other),
        }
    }

    #[test]
    fn parse_openai_multiple_tool_calls_in_chunk() {
        let line = r#"data: {"choices":[{"delta":{"tool_calls":[{"index":0,"function":{"name":"read_file"}},{"index":1,"function":{"name":"search"}}]}}]}"#;
        let events = parse_openai_sse_line(line);
        assert_eq!(events.len(), 2);
        match &events[0] {
            StreamEvent::ToolCallStart { name, index, .. } => {
                assert_eq!(name, "read_file");
                assert_eq!(*index, 0);
            }
            other => panic!("expected ToolCallStart, got {:?}", other),
        }
        match &events[1] {
            StreamEvent::ToolCallStart { name, index, .. } => {
                assert_eq!(name, "search");
                assert_eq!(*index, 1);
            }
            other => panic!("expected ToolCallStart, got {:?}", other),
        }
    }

    #[test]
    fn parse_openai_done() {
        assert!(parse_openai_sse_line("data: [DONE]").is_empty());
    }

    #[test]
    fn parse_openai_ignores_non_data_and_malformed_lines() {
        assert!(parse_openai_sse_line("").is_empty());
        assert!(parse_openai_sse_line(": keep-alive").is_empty());
        assert!(parse_openai_sse_line("event: message").is_empty());
        assert!(parse_openai_sse_line("data: {not json").is_empty());
    }

    #[test]
    fn parse_anthropic_text_delta() {
        let data = r#"{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"world"}}"#;
        let events = parse_anthropic_sse_line("content_block_delta", data);
        assert_eq!(events.len(), 1);
        match &events[0] {
            StreamEvent::TextDelta(t) => assert_eq!(t, "world"),
            other => panic!("expected TextDelta, got {:?}", other),
        }
    }

    #[test]
    fn parse_anthropic_tool_start() {
        let data = r#"{"type":"content_block_start","index":1,"content_block":{"type":"tool_use","id":"t1","name":"search","input":{}}}"#;
        let events = parse_anthropic_sse_line("content_block_start", data);
        assert_eq!(events.len(), 1);
        match &events[0] {
            StreamEvent::ToolCallStart { name, index, .. } => {
                assert_eq!(name, "search");
                assert_eq!(*index, 1);
            }
            other => panic!("expected ToolCallStart, got {:?}", other),
        }
    }

    #[test]
    fn parse_anthropic_input_json_delta() {
        let data = r#"{"type":"content_block_delta","index":1,"delta":{"type":"input_json_delta","partial_json":"{\"q\":"}}"#;
        let events = parse_anthropic_sse_line("content_block_delta", data);
        assert_eq!(events.len(), 1);
        match &events[0] {
            StreamEvent::ToolCallDelta {
                index,
                arguments_delta,
            } => {
                assert_eq!(*index, 1);
                assert_eq!(arguments_delta, "{\"q\":");
            }
            other => panic!("expected ToolCallDelta, got {:?}", other),
        }
    }

    #[test]
    fn parse_anthropic_ignores_unhandled_events() {
        assert!(parse_anthropic_sse_line("ping", r#"{"type":"ping"}"#).is_empty());
        assert!(parse_anthropic_sse_line("message_stop", r#"{"type":"message_stop"}"#).is_empty());
        assert!(parse_anthropic_sse_line(
            "content_block_start",
            r#"{"index":0,"content_block":{"type":"text","text":""}}"#,
        )
        .is_empty());
        assert!(parse_anthropic_sse_line("content_block_delta", "{not json").is_empty());
    }

    #[test]
    fn accumulator_builds_result() {
        let mut acc = StreamAccumulator::default();
        feed(
            &mut acc,
            vec![
                StreamEvent::TextDelta("Hello ".into()),
                StreamEvent::TextDelta("world".into()),
                StreamEvent::ToolCallStart {
                    name: "search".into(),
                    index: 0,
                    id: None,
                },
                StreamEvent::ToolCallDelta {
                    index: 0,
                    arguments_delta: r#"{"q":"test"}"#.into(),
                },
            ],
        );

        let (text, tools) = acc.finish();
        assert_eq!(text, "Hello world");
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name, "search");
        assert!(tools[0].arguments.contains_key("q"));
    }

    #[test]
    fn accumulator_orders_tool_calls_by_index_and_keeps_ids() {
        let mut acc = StreamAccumulator::default();
        feed(
            &mut acc,
            vec![
                StreamEvent::ToolCallStart {
                    name: "second".into(),
                    index: 1,
                    id: Some("id_1".into()),
                },
                StreamEvent::ToolCallStart {
                    name: "first".into(),
                    index: 0,
                    id: None,
                },
                StreamEvent::ToolCallDelta {
                    index: 1,
                    arguments_delta: "{\"n\":".into(),
                },
                StreamEvent::ToolCallDelta {
                    index: 1,
                    arguments_delta: "2}".into(),
                },
                StreamEvent::ToolCallDelta {
                    index: 0,
                    arguments_delta: "{oops".into(),
                },
            ],
        );

        let (text, tools) = acc.finish();
        assert!(text.is_empty());
        assert_eq!(tools.len(), 2);
        assert_eq!(tools[0].name, "first");
        assert!(tools[0].id.is_none());
        assert!(
            tools[0].arguments.is_empty(),
            "malformed arguments degrade to an empty map"
        );
        assert_eq!(tools[1].name, "second");
        assert_eq!(tools[1].id.as_deref(), Some("id_1"));
        assert_eq!(
            tools[1].arguments.get("n").and_then(|v| v.as_u64()),
            Some(2)
        );
    }

    #[test]
    fn accumulator_drops_arguments_without_a_start() {
        let mut acc = StreamAccumulator::default();
        acc.push(&StreamEvent::ToolCallDelta {
            index: 3,
            arguments_delta: "{}".into(),
        });
        let (_, tools) = acc.finish();
        assert!(tools.is_empty());
    }

    #[test]
    fn parse_sse_lines_openai_format() {
        let chunk = "data: {\"choices\":[{\"delta\":{\"content\":\"Hi\"}}]}\n\ndata: [DONE]\n\n";
        let events = parse_sse_lines(chunk);
        assert_eq!(events.len(), 2);
        assert_eq!(events[0].0, "message");
        assert_eq!(events[1].1, "[DONE]");
    }

    #[test]
    fn parse_sse_lines_anthropic_format() {
        let chunk = "event: content_block_delta\ndata: {\"delta\":{\"type\":\"text_delta\",\"text\":\"Hi\"}}\n\n";
        let events = parse_sse_lines(chunk);
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].0, "content_block_delta");
    }

    #[test]
    fn parse_sse_lines_flushes_unterminated_event() {
        let events = parse_sse_lines("event: message_stop\ndata: {}");
        assert_eq!(
            events,
            vec![("message_stop".to_string(), "{}".to_string())]
        );
    }

    #[test]
    fn parse_anthropic_message_start_emits_usage() {
        let data = r#"{"type":"message_start","message":{"id":"msg_1","role":"assistant","usage":{"input_tokens":245,"output_tokens":1}}}"#;
        let events = parse_anthropic_sse_line("message_start", data);
        assert_eq!(events.len(), 1);
        match &events[0] {
            StreamEvent::Usage {
                input_tokens,
                output_tokens,
            } => {
                assert_eq!(*input_tokens, 245);
                assert_eq!(*output_tokens, 1);
            }
            other => panic!("expected Usage, got {:?}", other),
        }
    }

    #[test]
    fn parse_anthropic_message_delta_emits_usage() {
        let data = r#"{"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":87}}"#;
        let events = parse_anthropic_sse_line("message_delta", data);
        assert_eq!(events.len(), 1);
        match &events[0] {
            StreamEvent::Usage {
                input_tokens,
                output_tokens,
            } => {
                assert_eq!(*input_tokens, 0);
                assert_eq!(*output_tokens, 87);
            }
            other => panic!("expected Usage, got {:?}", other),
        }
    }

    #[test]
    fn parse_anthropic_message_start_without_usage_is_empty() {
        let data = r#"{"type":"message_start","message":{"id":"msg_1"}}"#;
        assert!(parse_anthropic_sse_line("message_start", data).is_empty());
    }

    #[test]
    fn accumulator_tracks_usage_across_anthropic_stream() {
        let mut acc = StreamAccumulator::default();
        feed(
            &mut acc,
            parse_anthropic_sse_line(
                "message_start",
                r#"{"message":{"usage":{"input_tokens":245,"output_tokens":1}}}"#,
            ),
        );
        feed(
            &mut acc,
            parse_anthropic_sse_line(
                "content_block_start",
                r#"{"index":0,"content_block":{"type":"text","text":""}}"#,
            ),
        );
        for chunk in [
            r#"{"delta":{"type":"text_delta","text":"Hello"}}"#,
            r#"{"delta":{"type":"text_delta","text":", "}}"#,
            r#"{"delta":{"type":"text_delta","text":"world"}}"#,
        ] {
            feed(&mut acc, parse_anthropic_sse_line("content_block_delta", chunk));
        }
        feed(
            &mut acc,
            parse_anthropic_sse_line("message_delta", r#"{"usage":{"output_tokens":87}}"#),
        );

        let (text, tools, usage) = acc.finish_with_usage();
        assert_eq!(text, "Hello, world");
        assert!(tools.is_empty());
        let usage = usage.expect("provider reported usage; must surface");
        assert_eq!(usage.prompt_tokens, 245);
        assert_eq!(usage.completion_tokens, 87);
        assert_eq!(usage.total_tokens, 332);
    }

    #[test]
    fn parse_openai_final_chunk_emits_usage() {
        let line = r#"data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[],"usage":{"prompt_tokens":245,"completion_tokens":87,"total_tokens":332}}"#;
        let events = parse_openai_sse_line(line);
        assert_eq!(events.len(), 1);
        match &events[0] {
            StreamEvent::Usage {
                input_tokens,
                output_tokens,
            } => {
                assert_eq!(*input_tokens, 245);
                assert_eq!(*output_tokens, 87);
            }
            other => panic!("expected Usage, got {:?}", other),
        }
    }

    #[test]
    fn accumulator_tracks_usage_across_openai_stream() {
        let mut acc = StreamAccumulator::default();
        for line in [
            r#"data: {"choices":[{"delta":{"content":"Hello"}}]}"#,
            r#"data: {"choices":[{"delta":{"content":", "}}]}"#,
            r#"data: {"choices":[{"delta":{"content":"world"}}]}"#,
            r#"data: {"id":"chatcmpl-1","choices":[],"usage":{"prompt_tokens":245,"completion_tokens":87}}"#,
        ] {
            feed(&mut acc, parse_openai_sse_line(line));
        }

        let (text, tools, usage) = acc.finish_with_usage();
        assert_eq!(text, "Hello, world");
        assert!(tools.is_empty());
        let usage = usage.expect("provider reported usage; must surface");
        assert_eq!(usage.prompt_tokens, 245);
        assert_eq!(usage.completion_tokens, 87);
        assert_eq!(usage.total_tokens, 332);
    }

    #[test]
    fn accumulator_returns_no_usage_when_provider_silent() {
        let mut acc = StreamAccumulator::default();
        acc.push(&StreamEvent::TextDelta("hi".into()));
        let (_, _, usage) = acc.finish_with_usage();
        assert!(usage.is_none());
    }
}