Skip to main content

car_inference/
stream.rs

1//! Streaming inference — SSE parsing for real-time token output.
2//!
3//! Supports streaming from OpenAI-compatible, Anthropic, and Google APIs.
4//! Each provider uses Server-Sent Events (SSE) with different JSON schemas.
5
6use crate::tasks::generate::ToolCall;
7use crate::TokenUsage;
8use std::collections::HashMap;
9
/// Events emitted during a streaming inference response.
///
/// Produced by [`parse_openai_sse_line`] / [`parse_anthropic_sse_line`] and
/// consumed by [`StreamAccumulator`], which folds them into a final
/// `(text, tool_calls, usage)` result.
#[derive(Debug, Clone)]
pub enum StreamEvent {
    /// Partial text token from the model.
    TextDelta(String),
    /// A tool call is starting (name known, arguments pending).
    ToolCallStart {
        /// Tool/function name as reported by the provider.
        name: String,
        /// Provider-assigned position of this call within the response;
        /// used to correlate subsequent [`StreamEvent::ToolCallDelta`]s.
        index: usize,
        /// Provider call id (e.g. OpenAI `tool_calls[].id`, Anthropic
        /// `content_block.id`); absent for providers that don't send one.
        id: Option<String>,
    },
    /// Partial tool call arguments (JSON fragment).
    ToolCallDelta {
        /// Matches the `index` of a prior [`StreamEvent::ToolCallStart`].
        index: usize,
        /// Raw JSON fragment; fragments concatenate into the full
        /// arguments object.
        arguments_delta: String,
    },
    /// Provider-reported cumulative token usage observed mid-stream.
    ///
    /// Anthropic emits this twice: once in `message_start` with the
    /// finalized `input_tokens` (plus a stub `output_tokens: 1`), and
    /// again in `message_delta` at end of stream with the real
    /// `output_tokens`. Consumers should prefer per-field monotonicity
    /// (see [`StreamAccumulator`]) rather than overwriting blindly.
    Usage {
        input_tokens: u64,
        output_tokens: u64,
    },
    /// Stream is complete. Contains the final aggregated result.
    Done {
        text: String,
        tool_calls: Vec<ToolCall>,
    },
}
43
44/// Parse a single SSE data line from an OpenAI-compatible streaming response.
45/// Returns all events found in the line (supports multiple tool calls per chunk).
46pub fn parse_openai_sse_line(line: &str) -> Vec<StreamEvent> {
47    let data = match line.strip_prefix("data: ") {
48        Some(d) => d,
49        None => return Vec::new(),
50    };
51    if data == "[DONE]" {
52        return Vec::new();
53    }
54
55    let json: serde_json::Value = match serde_json::from_str(data) {
56        Ok(v) => v,
57        Err(_) => return Vec::new(),
58    };
59
60    let mut events = Vec::new();
61
62    // choices[].delta — present on every text/tool chunk, absent on the
63    // final usage-only chunk when `stream_options.include_usage=true`.
64    if let Some(delta) = json
65        .get("choices")
66        .and_then(|c| c.as_array())
67        .and_then(|c| c.first())
68        .and_then(|c| c.get("delta"))
69    {
70        if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
71            if !content.is_empty() {
72                events.push(StreamEvent::TextDelta(content.to_string()));
73            }
74        }
75
76        // Tool calls — collect ALL tool call events from this chunk
77        if let Some(tool_calls) = delta.get("tool_calls").and_then(|t| t.as_array()) {
78            for tc in tool_calls {
79                let index = tc.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as usize;
80                if let Some(function) = tc.get("function") {
81                    if let Some(name) = function.get("name").and_then(|n| n.as_str()) {
82                        let id = tc.get("id").and_then(|i| i.as_str()).map(|s| s.to_string());
83                        events.push(StreamEvent::ToolCallStart {
84                            name: name.to_string(),
85                            index,
86                            id,
87                        });
88                    }
89                    if let Some(args) = function.get("arguments").and_then(|a| a.as_str()) {
90                        if !args.is_empty() {
91                            events.push(StreamEvent::ToolCallDelta {
92                                index,
93                                arguments_delta: args.to_string(),
94                            });
95                        }
96                    }
97                }
98            }
99        }
100    }
101
102    // OpenAI sends real usage only when the request sets
103    // `stream_options.include_usage=true`; it arrives in a final chunk
104    // with `"choices": []` and a top-level `"usage"` object.
105    if let Some(usage) = json.get("usage") {
106        let input = usage
107            .get("prompt_tokens")
108            .and_then(|n| n.as_u64())
109            .unwrap_or(0);
110        let output = usage
111            .get("completion_tokens")
112            .and_then(|n| n.as_u64())
113            .unwrap_or(0);
114        if input != 0 || output != 0 {
115            events.push(StreamEvent::Usage {
116                input_tokens: input,
117                output_tokens: output,
118            });
119        }
120    }
121
122    events
123}
124
125/// Parse a single SSE data line from an Anthropic streaming response.
126pub fn parse_anthropic_sse_line(event_type: &str, data: &str) -> Vec<StreamEvent> {
127    match event_type {
128        "content_block_delta" => {
129            let json: serde_json::Value = match serde_json::from_str(data) {
130                Ok(v) => v,
131                Err(_) => return Vec::new(),
132            };
133            let delta = match json.get("delta") {
134                Some(d) => d,
135                None => return Vec::new(),
136            };
137            let delta_type = match delta.get("type").and_then(|t| t.as_str()) {
138                Some(t) => t,
139                None => return Vec::new(),
140            };
141
142            match delta_type {
143                "text_delta" => match delta.get("text").and_then(|t| t.as_str()) {
144                    Some(text) => vec![StreamEvent::TextDelta(text.to_string())],
145                    None => Vec::new(),
146                },
147                "input_json_delta" => match delta.get("partial_json").and_then(|p| p.as_str()) {
148                    Some(partial) => {
149                        let index =
150                            json.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as usize;
151                        vec![StreamEvent::ToolCallDelta {
152                            index,
153                            arguments_delta: partial.to_string(),
154                        }]
155                    }
156                    None => Vec::new(),
157                },
158                _ => Vec::new(),
159            }
160        }
161        "content_block_start" => {
162            let json: serde_json::Value = match serde_json::from_str(data) {
163                Ok(v) => v,
164                Err(_) => return Vec::new(),
165            };
166            let block = match json.get("content_block") {
167                Some(b) => b,
168                None => return Vec::new(),
169            };
170            if block.get("type").and_then(|t| t.as_str()) == Some("tool_use") {
171                if let Some(name) = block.get("name").and_then(|n| n.as_str()) {
172                    let index = json.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as usize;
173                    let id = block
174                        .get("id")
175                        .and_then(|i| i.as_str())
176                        .map(|s| s.to_string());
177                    return vec![StreamEvent::ToolCallStart {
178                        name: name.to_string(),
179                        index,
180                        id,
181                    }];
182                }
183            }
184            Vec::new()
185        }
186        // Beginning of the response — Anthropic reports the finalized
187        // `input_tokens` here along with a stub `output_tokens: 1`.
188        // Shape: `{"message":{"usage":{"input_tokens":123,"output_tokens":1}}}`
189        "message_start" => {
190            let json: serde_json::Value = match serde_json::from_str(data) {
191                Ok(v) => v,
192                Err(_) => return Vec::new(),
193            };
194            let Some(usage) = json.pointer("/message/usage") else {
195                return Vec::new();
196            };
197            let input = usage
198                .get("input_tokens")
199                .and_then(|n| n.as_u64())
200                .unwrap_or(0);
201            let output = usage
202                .get("output_tokens")
203                .and_then(|n| n.as_u64())
204                .unwrap_or(0);
205            if input == 0 && output == 0 {
206                return Vec::new();
207            }
208            vec![StreamEvent::Usage {
209                input_tokens: input,
210                output_tokens: output,
211            }]
212        }
213        // End of the response — Anthropic reports the final
214        // `output_tokens` here (input is already known from
215        // `message_start`). Shape: `{"usage":{"output_tokens":456}}`.
216        "message_delta" => {
217            let json: serde_json::Value = match serde_json::from_str(data) {
218                Ok(v) => v,
219                Err(_) => return Vec::new(),
220            };
221            let Some(usage) = json.get("usage") else {
222                return Vec::new();
223            };
224            let input = usage
225                .get("input_tokens")
226                .and_then(|n| n.as_u64())
227                .unwrap_or(0);
228            let output = usage
229                .get("output_tokens")
230                .and_then(|n| n.as_u64())
231                .unwrap_or(0);
232            if input == 0 && output == 0 {
233                return Vec::new();
234            }
235            vec![StreamEvent::Usage {
236                input_tokens: input,
237                output_tokens: output,
238            }]
239        }
240        _ => Vec::new(),
241    }
242}
243
/// Accumulator for building the final result from stream events.
///
/// Feed every [`StreamEvent`] through [`StreamAccumulator::push`], then call
/// [`StreamAccumulator::finish`] or [`StreamAccumulator::finish_with_usage`]
/// to obtain the aggregated text, tool calls, and (optionally) token usage.
#[derive(Default)]
pub struct StreamAccumulator {
    /// Concatenation of all `TextDelta` fragments seen so far.
    pub text: String,
    // Per-tool-call state, keyed by the provider-assigned call index.
    tool_names: HashMap<usize, String>,
    tool_args: HashMap<usize, String>,
    tool_ids: HashMap<usize, String>,
    /// Highest `input_tokens` value seen in a `Usage` event. Anthropic
    /// only sends this on `message_start`; other providers may send it
    /// multiple times and we keep the largest as the authoritative
    /// count.
    input_tokens: u64,
    /// Highest `output_tokens` value seen in a `Usage` event. For
    /// Anthropic this grows from the `message_start` stub (`1`) to the
    /// final count in `message_delta`, so we track monotonically.
    output_tokens: u64,
    /// Whether any `Usage` event was observed. `false` means the
    /// provider never reported usage (e.g. OpenAI without
    /// `stream_options.include_usage=true`) and [`finish_with_usage`]
    /// should return `None`.
    saw_usage: bool,
}
266
267impl StreamAccumulator {
268    pub fn push(&mut self, event: &StreamEvent) {
269        match event {
270            StreamEvent::TextDelta(t) => self.text.push_str(t),
271            StreamEvent::ToolCallStart { name, index, id } => {
272                self.tool_names.insert(*index, name.clone());
273                self.tool_args.entry(*index).or_default();
274                if let Some(id) = id {
275                    self.tool_ids.insert(*index, id.clone());
276                }
277            }
278            StreamEvent::ToolCallDelta {
279                index,
280                arguments_delta,
281            } => {
282                self.tool_args
283                    .entry(*index)
284                    .or_default()
285                    .push_str(arguments_delta);
286            }
287            StreamEvent::Usage {
288                input_tokens,
289                output_tokens,
290            } => {
291                self.saw_usage = true;
292                // Per-field max: Anthropic's `message_start` carries
293                // real input + stub output=1; `message_delta` carries
294                // only final output. Neither event should be allowed
295                // to clobber the other's authoritative value.
296                if *input_tokens > self.input_tokens {
297                    self.input_tokens = *input_tokens;
298                }
299                if *output_tokens > self.output_tokens {
300                    self.output_tokens = *output_tokens;
301                }
302            }
303            StreamEvent::Done { .. } => {}
304        }
305    }
306
307    pub fn finish(self) -> (String, Vec<ToolCall>) {
308        let (text, tool_calls, _) = self.finish_with_usage();
309        (text, tool_calls)
310    }
311
312    /// Like [`finish`] but also returns the accumulated [`TokenUsage`]
313    /// when the provider reported any. Returns `None` for usage if no
314    /// `Usage` event was observed — callers can fall back to their
315    /// own estimator.
316    pub fn finish_with_usage(self) -> (String, Vec<ToolCall>, Option<TokenUsage>) {
317        let mut tool_calls = Vec::new();
318        let mut indices: Vec<usize> = self.tool_names.keys().copied().collect();
319        indices.sort();
320
321        for idx in indices {
322            let id = self.tool_ids.get(&idx).cloned();
323            let name = self.tool_names.get(&idx).cloned().unwrap_or_default();
324            let args_str = self.tool_args.get(&idx).cloned().unwrap_or_default();
325            let arguments: HashMap<String, serde_json::Value> =
326                serde_json::from_str(&args_str).unwrap_or_default();
327            tool_calls.push(ToolCall {
328                id,
329                name,
330                arguments,
331            });
332        }
333
334        let usage = if self.saw_usage {
335            Some(TokenUsage {
336                prompt_tokens: self.input_tokens,
337                completion_tokens: self.output_tokens,
338                total_tokens: self.input_tokens + self.output_tokens,
339                // Context-window sizing comes from model metadata, not
340                // per-response usage — leave it zero and let the
341                // caller populate it if needed.
342                context_window: 0,
343            })
344        } else {
345            None
346        };
347
348        (self.text, tool_calls, usage)
349    }
350}
351
/// Parse SSE lines from a raw byte stream. Handles both OpenAI and Anthropic
/// formats. Returns `(event_type, data)` pairs; OpenAI doesn't send event
/// types, so those events get the SSE default type `"message"`.
///
/// Follows the SSE event-stream grammar more closely than a naive split:
///
/// * The colon after `event`/`data` may be followed by an optional single
///   space — `data:{...}` and `data: {...}` are both accepted.
/// * Consecutive `data:` lines within one event are joined with `\n` (the
///   spec's concatenation rule) instead of the last line clobbering the
///   earlier ones.
/// * A blank line always terminates the in-progress event; events that
///   carried no data are dropped.
/// * A trailing event not followed by a blank line is still emitted, so
///   callers feeding partial network chunks don't lose the final event.
pub fn parse_sse_lines(chunk: &str) -> Vec<(String, String)> {
    let mut events = Vec::new();
    let mut current_event = String::new();
    let mut current_data = String::new();

    // Flush the in-progress event (if it has data) and reset the buffers.
    fn flush(events: &mut Vec<(String, String)>, event: &mut String, data: &mut String) {
        if !data.is_empty() {
            let event_type = if event.is_empty() {
                "message".to_string()
            } else {
                std::mem::take(event)
            };
            events.push((event_type, std::mem::take(data)));
        }
        event.clear();
        data.clear();
    }

    for line in chunk.lines() {
        if let Some(value) = line.strip_prefix("event:") {
            current_event = value.strip_prefix(' ').unwrap_or(value).to_string();
        } else if let Some(value) = line.strip_prefix("data:") {
            let value = value.strip_prefix(' ').unwrap_or(value);
            // SSE: successive data lines accumulate, separated by '\n'.
            if !current_data.is_empty() {
                current_data.push('\n');
            }
            current_data.push_str(value);
        } else if line.is_empty() {
            flush(&mut events, &mut current_event, &mut current_data);
        }
        // Comment lines (leading ':') and unknown fields are ignored.
    }

    // Handle case where stream doesn't end with empty line.
    flush(&mut events, &mut current_event, &mut current_data);

    events
}
392
#[cfg(test)]
mod tests {
    use super::*;

    // ---- OpenAI SSE parsing ----

    #[test]
    fn parse_openai_text_delta() {
        let line = r#"data: {"choices":[{"delta":{"content":"Hello"}}]}"#;
        let events = parse_openai_sse_line(line);
        assert_eq!(events.len(), 1);
        match &events[0] {
            StreamEvent::TextDelta(t) => assert_eq!(t, "Hello"),
            other => panic!("expected TextDelta, got {:?}", other),
        }
    }

    #[test]
    fn parse_openai_tool_call_start() {
        let line = r#"data: {"choices":[{"delta":{"tool_calls":[{"index":0,"function":{"name":"edit_file"}}]}}]}"#;
        let events = parse_openai_sse_line(line);
        assert_eq!(events.len(), 1);
        match &events[0] {
            StreamEvent::ToolCallStart { name, index, .. } => {
                assert_eq!(name, "edit_file");
                assert_eq!(*index, 0);
            }
            other => panic!("expected ToolCallStart, got {:?}", other),
        }
    }

    #[test]
    fn parse_openai_tool_call_delta() {
        let line = r#"data: {"choices":[{"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"path\":"}}]}}]}"#;
        let events = parse_openai_sse_line(line);
        assert_eq!(events.len(), 1);
        match &events[0] {
            StreamEvent::ToolCallDelta {
                index,
                arguments_delta,
            } => {
                assert_eq!(*index, 0);
                assert!(arguments_delta.contains("path"));
            }
            other => panic!("expected ToolCallDelta, got {:?}", other),
        }
    }

    #[test]
    fn parse_openai_multiple_tool_calls_in_chunk() {
        // When OpenAI sends multiple tool call deltas in a single SSE chunk
        let line = r#"data: {"choices":[{"delta":{"tool_calls":[{"index":0,"function":{"name":"read_file"}},{"index":1,"function":{"name":"search"}}]}}]}"#;
        let events = parse_openai_sse_line(line);
        assert_eq!(events.len(), 2);
        match &events[0] {
            StreamEvent::ToolCallStart { name, index, .. } => {
                assert_eq!(name, "read_file");
                assert_eq!(*index, 0);
            }
            other => panic!("expected ToolCallStart, got {:?}", other),
        }
        match &events[1] {
            StreamEvent::ToolCallStart { name, index, .. } => {
                assert_eq!(name, "search");
                assert_eq!(*index, 1);
            }
            other => panic!("expected ToolCallStart, got {:?}", other),
        }
    }

    #[test]
    fn parse_openai_done() {
        // The [DONE] sentinel terminates the stream and carries no events.
        assert!(parse_openai_sse_line("data: [DONE]").is_empty());
    }

    // ---- Anthropic SSE parsing ----

    #[test]
    fn parse_anthropic_text_delta() {
        let data = r#"{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"world"}}"#;
        let events = parse_anthropic_sse_line("content_block_delta", data);
        assert_eq!(events.len(), 1);
        match &events[0] {
            StreamEvent::TextDelta(t) => assert_eq!(t, "world"),
            other => panic!("expected TextDelta, got {:?}", other),
        }
    }

    #[test]
    fn parse_anthropic_tool_start() {
        let data = r#"{"type":"content_block_start","index":1,"content_block":{"type":"tool_use","id":"t1","name":"search","input":{}}}"#;
        let events = parse_anthropic_sse_line("content_block_start", data);
        assert_eq!(events.len(), 1);
        match &events[0] {
            StreamEvent::ToolCallStart { name, index, .. } => {
                assert_eq!(name, "search");
                assert_eq!(*index, 1);
            }
            other => panic!("expected ToolCallStart, got {:?}", other),
        }
    }

    // ---- Accumulator ----

    #[test]
    fn accumulator_builds_result() {
        let mut acc = StreamAccumulator::default();
        acc.push(&StreamEvent::TextDelta("Hello ".into()));
        acc.push(&StreamEvent::TextDelta("world".into()));
        acc.push(&StreamEvent::ToolCallStart {
            name: "search".into(),
            index: 0,
            id: None,
        });
        acc.push(&StreamEvent::ToolCallDelta {
            index: 0,
            arguments_delta: r#"{"q":"test"}"#.into(),
        });

        let (text, tools) = acc.finish();
        assert_eq!(text, "Hello world");
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name, "search");
        assert!(tools[0].arguments.contains_key("q"));
    }

    // ---- Raw SSE line splitting ----

    #[test]
    fn parse_sse_lines_openai_format() {
        let chunk = "data: {\"choices\":[{\"delta\":{\"content\":\"Hi\"}}]}\n\ndata: [DONE]\n\n";
        let events = parse_sse_lines(chunk);
        assert_eq!(events.len(), 2);
        assert_eq!(events[0].0, "message");
        assert_eq!(events[1].1, "[DONE]");
    }

    #[test]
    fn parse_sse_lines_anthropic_format() {
        let chunk = "event: content_block_delta\ndata: {\"delta\":{\"type\":\"text_delta\",\"text\":\"Hi\"}}\n\n";
        let events = parse_sse_lines(chunk);
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].0, "content_block_delta");
    }

    // ---- Usage reporting ----

    #[test]
    fn parse_anthropic_message_start_emits_usage() {
        let data = r#"{"type":"message_start","message":{"id":"msg_1","role":"assistant","usage":{"input_tokens":245,"output_tokens":1}}}"#;
        let events = parse_anthropic_sse_line("message_start", data);
        assert_eq!(events.len(), 1);
        match &events[0] {
            StreamEvent::Usage {
                input_tokens,
                output_tokens,
            } => {
                assert_eq!(*input_tokens, 245);
                assert_eq!(*output_tokens, 1);
            }
            other => panic!("expected Usage, got {:?}", other),
        }
    }

    #[test]
    fn parse_anthropic_message_delta_emits_usage() {
        let data = r#"{"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":87}}"#;
        let events = parse_anthropic_sse_line("message_delta", data);
        assert_eq!(events.len(), 1);
        match &events[0] {
            StreamEvent::Usage {
                input_tokens,
                output_tokens,
            } => {
                // message_delta carries only output; input reads as 0 here.
                assert_eq!(*input_tokens, 0);
                assert_eq!(*output_tokens, 87);
            }
            other => panic!("expected Usage, got {:?}", other),
        }
    }

    #[test]
    fn parse_anthropic_message_start_without_usage_is_empty() {
        // Some forward-compat payloads may omit usage; don't crash.
        let data = r#"{"type":"message_start","message":{"id":"msg_1"}}"#;
        assert!(parse_anthropic_sse_line("message_start", data).is_empty());
    }

    #[test]
    fn accumulator_tracks_usage_across_anthropic_stream() {
        // Simulate the exact shape of a real Anthropic stream:
        // message_start → content_block_start → content_block_delta × 3 → message_delta.
        let mut acc = StreamAccumulator::default();
        for event in parse_anthropic_sse_line(
            "message_start",
            r#"{"message":{"usage":{"input_tokens":245,"output_tokens":1}}}"#,
        ) {
            acc.push(&event);
        }
        for event in parse_anthropic_sse_line(
            "content_block_start",
            r#"{"index":0,"content_block":{"type":"text","text":""}}"#,
        ) {
            acc.push(&event);
        }
        for (chunk, _) in [
            (r#"{"delta":{"type":"text_delta","text":"Hello"}}"#, ()),
            (r#"{"delta":{"type":"text_delta","text":", "}}"#, ()),
            (r#"{"delta":{"type":"text_delta","text":"world"}}"#, ()),
        ] {
            for event in parse_anthropic_sse_line("content_block_delta", chunk) {
                acc.push(&event);
            }
        }
        for event in parse_anthropic_sse_line("message_delta", r#"{"usage":{"output_tokens":87}}"#)
        {
            acc.push(&event);
        }

        let (text, tools, usage) = acc.finish_with_usage();
        assert_eq!(text, "Hello, world");
        assert!(tools.is_empty());
        let usage = usage.expect("provider reported usage; must surface");
        assert_eq!(usage.prompt_tokens, 245);
        // message_delta output (87) must win over message_start stub (1).
        assert_eq!(usage.completion_tokens, 87);
        assert_eq!(usage.total_tokens, 332);
    }

    #[test]
    fn parse_openai_final_chunk_emits_usage() {
        // OpenAI's final usage chunk when `stream_options.include_usage`
        // is set: `choices` is empty and `usage` carries the real counts.
        let line = r#"data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[],"usage":{"prompt_tokens":245,"completion_tokens":87,"total_tokens":332}}"#;
        let events = parse_openai_sse_line(line);
        assert_eq!(events.len(), 1);
        match &events[0] {
            StreamEvent::Usage {
                input_tokens,
                output_tokens,
            } => {
                assert_eq!(*input_tokens, 245);
                assert_eq!(*output_tokens, 87);
            }
            other => panic!("expected Usage, got {:?}", other),
        }
    }

    #[test]
    fn accumulator_tracks_usage_across_openai_stream() {
        // Simulate a full OpenAI stream with `stream_options.include_usage`:
        // text delta chunks followed by a choiceless usage-only chunk.
        let mut acc = StreamAccumulator::default();
        for line in [
            r#"data: {"choices":[{"delta":{"content":"Hello"}}]}"#,
            r#"data: {"choices":[{"delta":{"content":", "}}]}"#,
            r#"data: {"choices":[{"delta":{"content":"world"}}]}"#,
            r#"data: {"id":"chatcmpl-1","choices":[],"usage":{"prompt_tokens":245,"completion_tokens":87}}"#,
        ] {
            for event in parse_openai_sse_line(line) {
                acc.push(&event);
            }
        }

        let (text, tools, usage) = acc.finish_with_usage();
        assert_eq!(text, "Hello, world");
        assert!(tools.is_empty());
        let usage = usage.expect("provider reported usage; must surface");
        assert_eq!(usage.prompt_tokens, 245);
        assert_eq!(usage.completion_tokens, 87);
        assert_eq!(usage.total_tokens, 332);
    }

    #[test]
    fn accumulator_returns_no_usage_when_provider_silent() {
        // OpenAI without `stream_options.include_usage` — no Usage
        // events. `finish_with_usage` returns None so callers can fall
        // back to their own estimator.
        let mut acc = StreamAccumulator::default();
        acc.push(&StreamEvent::TextDelta("hi".into()));
        let (_, _, usage) = acc.finish_with_usage();
        assert!(usage.is_none());
    }
}