// car_inference/stream.rs
1//! Streaming inference — SSE parsing for real-time token output.
2//!
3//! Supports streaming from OpenAI-compatible, Anthropic, and Google APIs.
4//! Each provider uses Server-Sent Events (SSE) with different JSON schemas.
5
6use crate::tasks::generate::ToolCall;
7use crate::TokenUsage;
8use std::collections::HashMap;
9
/// Events emitted during a streaming inference response.
///
/// Produced by [`parse_openai_sse_line`] and [`parse_anthropic_sse_line`],
/// consumed by [`StreamAccumulator`].
#[derive(Debug, Clone)]
pub enum StreamEvent {
    /// Partial text token from the model.
    TextDelta(String),
    /// A tool call is starting (name known, arguments pending).
    ///
    /// `index` ties subsequent `ToolCallDelta` fragments to this call;
    /// `id` is the provider-assigned call id when one was sent.
    ToolCallStart { name: String, index: usize, id: Option<String> },
    /// Partial tool call arguments (JSON fragment). Fragments sharing an
    /// `index` concatenate into a single JSON document.
    ToolCallDelta {
        index: usize,
        arguments_delta: String,
    },
    /// Provider-reported cumulative token usage observed mid-stream.
    ///
    /// Anthropic emits this twice: once in `message_start` with the
    /// finalized `input_tokens` (plus a stub `output_tokens: 1`), and
    /// again in `message_delta` at end of stream with the real
    /// `output_tokens`. Consumers should prefer per-field monotonicity
    /// (see [`StreamAccumulator`]) rather than overwriting blindly.
    Usage {
        input_tokens: u64,
        output_tokens: u64,
    },
    /// Stream is complete. Contains the final aggregated result.
    ///
    /// NOTE(review): neither parser in this module constructs `Done`;
    /// presumably the stream driver emits it — confirm at call sites.
    Done {
        text: String,
        tool_calls: Vec<ToolCall>,
    },
}
39
40/// Parse a single SSE data line from an OpenAI-compatible streaming response.
41/// Returns all events found in the line (supports multiple tool calls per chunk).
42pub fn parse_openai_sse_line(line: &str) -> Vec<StreamEvent> {
43    let data = match line.strip_prefix("data: ") {
44        Some(d) => d,
45        None => return Vec::new(),
46    };
47    if data == "[DONE]" {
48        return Vec::new();
49    }
50
51    let json: serde_json::Value = match serde_json::from_str(data) {
52        Ok(v) => v,
53        Err(_) => return Vec::new(),
54    };
55
56    let mut events = Vec::new();
57
58    // choices[].delta — present on every text/tool chunk, absent on the
59    // final usage-only chunk when `stream_options.include_usage=true`.
60    if let Some(delta) = json
61        .get("choices")
62        .and_then(|c| c.as_array())
63        .and_then(|c| c.first())
64        .and_then(|c| c.get("delta"))
65    {
66        if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
67            if !content.is_empty() {
68                events.push(StreamEvent::TextDelta(content.to_string()));
69            }
70        }
71
72        // Tool calls — collect ALL tool call events from this chunk
73        if let Some(tool_calls) = delta.get("tool_calls").and_then(|t| t.as_array()) {
74            for tc in tool_calls {
75                let index = tc.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as usize;
76                if let Some(function) = tc.get("function") {
77                    if let Some(name) = function.get("name").and_then(|n| n.as_str()) {
78                        let id = tc.get("id").and_then(|i| i.as_str()).map(|s| s.to_string());
79                        events.push(StreamEvent::ToolCallStart {
80                            name: name.to_string(),
81                            index,
82                            id,
83                        });
84                    }
85                    if let Some(args) = function.get("arguments").and_then(|a| a.as_str()) {
86                        if !args.is_empty() {
87                            events.push(StreamEvent::ToolCallDelta {
88                                index,
89                                arguments_delta: args.to_string(),
90                            });
91                        }
92                    }
93                }
94            }
95        }
96    }
97
98    // OpenAI sends real usage only when the request sets
99    // `stream_options.include_usage=true`; it arrives in a final chunk
100    // with `"choices": []` and a top-level `"usage"` object.
101    if let Some(usage) = json.get("usage") {
102        let input = usage
103            .get("prompt_tokens")
104            .and_then(|n| n.as_u64())
105            .unwrap_or(0);
106        let output = usage
107            .get("completion_tokens")
108            .and_then(|n| n.as_u64())
109            .unwrap_or(0);
110        if input != 0 || output != 0 {
111            events.push(StreamEvent::Usage {
112                input_tokens: input,
113                output_tokens: output,
114            });
115        }
116    }
117
118    events
119}
120
121/// Parse a single SSE data line from an Anthropic streaming response.
122pub fn parse_anthropic_sse_line(event_type: &str, data: &str) -> Vec<StreamEvent> {
123    match event_type {
124        "content_block_delta" => {
125            let json: serde_json::Value = match serde_json::from_str(data) {
126                Ok(v) => v,
127                Err(_) => return Vec::new(),
128            };
129            let delta = match json.get("delta") {
130                Some(d) => d,
131                None => return Vec::new(),
132            };
133            let delta_type = match delta.get("type").and_then(|t| t.as_str()) {
134                Some(t) => t,
135                None => return Vec::new(),
136            };
137
138            match delta_type {
139                "text_delta" => match delta.get("text").and_then(|t| t.as_str()) {
140                    Some(text) => vec![StreamEvent::TextDelta(text.to_string())],
141                    None => Vec::new(),
142                },
143                "input_json_delta" => match delta.get("partial_json").and_then(|p| p.as_str()) {
144                    Some(partial) => {
145                        let index =
146                            json.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as usize;
147                        vec![StreamEvent::ToolCallDelta {
148                            index,
149                            arguments_delta: partial.to_string(),
150                        }]
151                    }
152                    None => Vec::new(),
153                },
154                _ => Vec::new(),
155            }
156        }
157        "content_block_start" => {
158            let json: serde_json::Value = match serde_json::from_str(data) {
159                Ok(v) => v,
160                Err(_) => return Vec::new(),
161            };
162            let block = match json.get("content_block") {
163                Some(b) => b,
164                None => return Vec::new(),
165            };
166            if block.get("type").and_then(|t| t.as_str()) == Some("tool_use") {
167                if let Some(name) = block.get("name").and_then(|n| n.as_str()) {
168                    let index = json.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as usize;
169                    let id = block.get("id").and_then(|i| i.as_str()).map(|s| s.to_string());
170                    return vec![StreamEvent::ToolCallStart {
171                        name: name.to_string(),
172                        index,
173                        id,
174                    }];
175                }
176            }
177            Vec::new()
178        }
179        // Beginning of the response — Anthropic reports the finalized
180        // `input_tokens` here along with a stub `output_tokens: 1`.
181        // Shape: `{"message":{"usage":{"input_tokens":123,"output_tokens":1}}}`
182        "message_start" => {
183            let json: serde_json::Value = match serde_json::from_str(data) {
184                Ok(v) => v,
185                Err(_) => return Vec::new(),
186            };
187            let Some(usage) = json.pointer("/message/usage") else {
188                return Vec::new();
189            };
190            let input = usage.get("input_tokens").and_then(|n| n.as_u64()).unwrap_or(0);
191            let output = usage.get("output_tokens").and_then(|n| n.as_u64()).unwrap_or(0);
192            if input == 0 && output == 0 {
193                return Vec::new();
194            }
195            vec![StreamEvent::Usage { input_tokens: input, output_tokens: output }]
196        }
197        // End of the response — Anthropic reports the final
198        // `output_tokens` here (input is already known from
199        // `message_start`). Shape: `{"usage":{"output_tokens":456}}`.
200        "message_delta" => {
201            let json: serde_json::Value = match serde_json::from_str(data) {
202                Ok(v) => v,
203                Err(_) => return Vec::new(),
204            };
205            let Some(usage) = json.get("usage") else {
206                return Vec::new();
207            };
208            let input = usage.get("input_tokens").and_then(|n| n.as_u64()).unwrap_or(0);
209            let output = usage.get("output_tokens").and_then(|n| n.as_u64()).unwrap_or(0);
210            if input == 0 && output == 0 {
211                return Vec::new();
212            }
213            vec![StreamEvent::Usage { input_tokens: input, output_tokens: output }]
214        }
215        _ => Vec::new(),
216    }
217}
218
/// Accumulator for building the final result from stream events.
///
/// Feed each [`StreamEvent`] to `push`, then call `finish` or
/// `finish_with_usage` to obtain the aggregated text, tool calls, and
/// (optionally) provider-reported token usage.
#[derive(Default)]
pub struct StreamAccumulator {
    /// Concatenation of all `TextDelta` payloads, in arrival order.
    pub text: String,
    /// Tool-call name per stream index (set by `ToolCallStart`).
    tool_names: HashMap<usize, String>,
    /// Accumulated JSON argument fragments per stream index.
    tool_args: HashMap<usize, String>,
    /// Provider-assigned call id per stream index, when one was sent.
    tool_ids: HashMap<usize, String>,
    /// Highest `input_tokens` value seen in a `Usage` event. Anthropic
    /// only sends this on `message_start`; other providers may send it
    /// multiple times and we keep the largest as the authoritative
    /// count.
    input_tokens: u64,
    /// Highest `output_tokens` value seen in a `Usage` event. For
    /// Anthropic this grows from the `message_start` stub (`1`) to the
    /// final count in `message_delta`, so we track monotonically.
    output_tokens: u64,
    /// Whether any `Usage` event was observed. `false` means the
    /// provider never reported usage (e.g. OpenAI without
    /// `stream_options.include_usage=true`) and [`finish_with_usage`]
    /// should return `None`.
    saw_usage: bool,
}
241
242impl StreamAccumulator {
243    pub fn push(&mut self, event: &StreamEvent) {
244        match event {
245            StreamEvent::TextDelta(t) => self.text.push_str(t),
246            StreamEvent::ToolCallStart { name, index, id } => {
247                self.tool_names.insert(*index, name.clone());
248                self.tool_args.entry(*index).or_default();
249                if let Some(id) = id {
250                    self.tool_ids.insert(*index, id.clone());
251                }
252            }
253            StreamEvent::ToolCallDelta {
254                index,
255                arguments_delta,
256            } => {
257                self.tool_args
258                    .entry(*index)
259                    .or_default()
260                    .push_str(arguments_delta);
261            }
262            StreamEvent::Usage { input_tokens, output_tokens } => {
263                self.saw_usage = true;
264                // Per-field max: Anthropic's `message_start` carries
265                // real input + stub output=1; `message_delta` carries
266                // only final output. Neither event should be allowed
267                // to clobber the other's authoritative value.
268                if *input_tokens > self.input_tokens {
269                    self.input_tokens = *input_tokens;
270                }
271                if *output_tokens > self.output_tokens {
272                    self.output_tokens = *output_tokens;
273                }
274            }
275            StreamEvent::Done { .. } => {}
276        }
277    }
278
279    pub fn finish(self) -> (String, Vec<ToolCall>) {
280        let (text, tool_calls, _) = self.finish_with_usage();
281        (text, tool_calls)
282    }
283
284    /// Like [`finish`] but also returns the accumulated [`TokenUsage`]
285    /// when the provider reported any. Returns `None` for usage if no
286    /// `Usage` event was observed — callers can fall back to their
287    /// own estimator.
288    pub fn finish_with_usage(self) -> (String, Vec<ToolCall>, Option<TokenUsage>) {
289        let mut tool_calls = Vec::new();
290        let mut indices: Vec<usize> = self.tool_names.keys().copied().collect();
291        indices.sort();
292
293        for idx in indices {
294            let id = self.tool_ids.get(&idx).cloned();
295            let name = self.tool_names.get(&idx).cloned().unwrap_or_default();
296            let args_str = self.tool_args.get(&idx).cloned().unwrap_or_default();
297            let arguments: HashMap<String, serde_json::Value> =
298                serde_json::from_str(&args_str).unwrap_or_default();
299            tool_calls.push(ToolCall { id, name, arguments });
300        }
301
302        let usage = if self.saw_usage {
303            Some(TokenUsage {
304                prompt_tokens: self.input_tokens,
305                completion_tokens: self.output_tokens,
306                total_tokens: self.input_tokens + self.output_tokens,
307                // Context-window sizing comes from model metadata, not
308                // per-response usage — leave it zero and let the
309                // caller populate it if needed.
310                context_window: 0,
311            })
312        } else {
313            None
314        };
315
316        (self.text, tool_calls, usage)
317    }
318}
319
/// Parse SSE lines from a raw byte stream. Handles both OpenAI and Anthropic formats.
/// Returns (event_type, data) pairs. OpenAI doesn't send event types (always "message").
///
/// Per the SSE specification, multiple `data:` lines arriving before the
/// blank-line separator belong to a single event and are joined with `\n`
/// (the previous implementation incorrectly kept only the last one).
pub fn parse_sse_lines(chunk: &str) -> Vec<(String, String)> {
    /// Emit the pending event (if any data was buffered) and reset state.
    fn flush(events: &mut Vec<(String, String)>, event: &mut String, data: &mut String) {
        if data.is_empty() {
            return;
        }
        let name = if event.is_empty() {
            // OpenAI streams omit `event:`; the SSE default type is "message".
            "message".to_string()
        } else {
            std::mem::take(event)
        };
        events.push((name, std::mem::take(data)));
        event.clear();
    }

    let mut events = Vec::new();
    let mut current_event = String::new();
    let mut current_data = String::new();

    for line in chunk.lines() {
        if let Some(name) = line.strip_prefix("event: ") {
            current_event = name.to_string();
        } else if let Some(payload) = line.strip_prefix("data: ") {
            // Join successive data lines of one event with '\n' (SSE spec).
            if !current_data.is_empty() {
                current_data.push('\n');
            }
            current_data.push_str(payload);
        } else if line.is_empty() {
            flush(&mut events, &mut current_event, &mut current_data);
        }
    }

    // Handle case where stream doesn't end with empty line
    flush(&mut events, &mut current_event, &mut current_data);

    events
}
360
// Unit tests. The JSON fixtures below reproduce the exact wire shapes of
// OpenAI-compatible and Anthropic SSE chunks; keep them byte-for-byte
// faithful to the provider formats they pin.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_openai_text_delta() {
        let line = r#"data: {"choices":[{"delta":{"content":"Hello"}}]}"#;
        let events = parse_openai_sse_line(line);
        assert_eq!(events.len(), 1);
        match &events[0] {
            StreamEvent::TextDelta(t) => assert_eq!(t, "Hello"),
            other => panic!("expected TextDelta, got {:?}", other),
        }
    }

    #[test]
    fn parse_openai_tool_call_start() {
        let line = r#"data: {"choices":[{"delta":{"tool_calls":[{"index":0,"function":{"name":"edit_file"}}]}}]}"#;
        let events = parse_openai_sse_line(line);
        assert_eq!(events.len(), 1);
        match &events[0] {
            StreamEvent::ToolCallStart { name, index, .. } => {
                assert_eq!(name, "edit_file");
                assert_eq!(*index, 0);
            }
            other => panic!("expected ToolCallStart, got {:?}", other),
        }
    }

    #[test]
    fn parse_openai_tool_call_delta() {
        let line = r#"data: {"choices":[{"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\"path\":"}}]}}]}"#;
        let events = parse_openai_sse_line(line);
        assert_eq!(events.len(), 1);
        match &events[0] {
            StreamEvent::ToolCallDelta {
                index,
                arguments_delta,
            } => {
                assert_eq!(*index, 0);
                assert!(arguments_delta.contains("path"));
            }
            other => panic!("expected ToolCallDelta, got {:?}", other),
        }
    }

    #[test]
    fn parse_openai_multiple_tool_calls_in_chunk() {
        // When OpenAI sends multiple tool call deltas in a single SSE chunk
        let line = r#"data: {"choices":[{"delta":{"tool_calls":[{"index":0,"function":{"name":"read_file"}},{"index":1,"function":{"name":"search"}}]}}]}"#;
        let events = parse_openai_sse_line(line);
        assert_eq!(events.len(), 2);
        match &events[0] {
            StreamEvent::ToolCallStart { name, index, .. } => {
                assert_eq!(name, "read_file");
                assert_eq!(*index, 0);
            }
            other => panic!("expected ToolCallStart, got {:?}", other),
        }
        match &events[1] {
            StreamEvent::ToolCallStart { name, index, .. } => {
                assert_eq!(name, "search");
                assert_eq!(*index, 1);
            }
            other => panic!("expected ToolCallStart, got {:?}", other),
        }
    }

    #[test]
    fn parse_openai_done() {
        // The [DONE] sentinel terminates the stream and carries no events.
        assert!(parse_openai_sse_line("data: [DONE]").is_empty());
    }

    #[test]
    fn parse_anthropic_text_delta() {
        let data = r#"{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"world"}}"#;
        let events = parse_anthropic_sse_line("content_block_delta", data);
        assert_eq!(events.len(), 1);
        match &events[0] {
            StreamEvent::TextDelta(t) => assert_eq!(t, "world"),
            other => panic!("expected TextDelta, got {:?}", other),
        }
    }

    #[test]
    fn parse_anthropic_tool_start() {
        let data = r#"{"type":"content_block_start","index":1,"content_block":{"type":"tool_use","id":"t1","name":"search","input":{}}}"#;
        let events = parse_anthropic_sse_line("content_block_start", data);
        assert_eq!(events.len(), 1);
        match &events[0] {
            StreamEvent::ToolCallStart { name, index, .. } => {
                assert_eq!(name, "search");
                assert_eq!(*index, 1);
            }
            other => panic!("expected ToolCallStart, got {:?}", other),
        }
    }

    #[test]
    fn accumulator_builds_result() {
        let mut acc = StreamAccumulator::default();
        acc.push(&StreamEvent::TextDelta("Hello ".into()));
        acc.push(&StreamEvent::TextDelta("world".into()));
        acc.push(&StreamEvent::ToolCallStart {
            name: "search".into(),
            index: 0,
            id: None,
        });
        acc.push(&StreamEvent::ToolCallDelta {
            index: 0,
            arguments_delta: r#"{"q":"test"}"#.into(),
        });

        let (text, tools) = acc.finish();
        assert_eq!(text, "Hello world");
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name, "search");
        assert!(tools[0].arguments.contains_key("q"));
    }

    #[test]
    fn parse_sse_lines_openai_format() {
        let chunk = "data: {\"choices\":[{\"delta\":{\"content\":\"Hi\"}}]}\n\ndata: [DONE]\n\n";
        let events = parse_sse_lines(chunk);
        assert_eq!(events.len(), 2);
        // OpenAI omits `event:`, so the default type "message" applies.
        assert_eq!(events[0].0, "message");
        assert_eq!(events[1].1, "[DONE]");
    }

    #[test]
    fn parse_sse_lines_anthropic_format() {
        let chunk = "event: content_block_delta\ndata: {\"delta\":{\"type\":\"text_delta\",\"text\":\"Hi\"}}\n\n";
        let events = parse_sse_lines(chunk);
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].0, "content_block_delta");
    }

    #[test]
    fn parse_anthropic_message_start_emits_usage() {
        let data = r#"{"type":"message_start","message":{"id":"msg_1","role":"assistant","usage":{"input_tokens":245,"output_tokens":1}}}"#;
        let events = parse_anthropic_sse_line("message_start", data);
        assert_eq!(events.len(), 1);
        match &events[0] {
            StreamEvent::Usage { input_tokens, output_tokens } => {
                assert_eq!(*input_tokens, 245);
                // Anthropic's message_start carries a stub output count of 1.
                assert_eq!(*output_tokens, 1);
            }
            other => panic!("expected Usage, got {:?}", other),
        }
    }

    #[test]
    fn parse_anthropic_message_delta_emits_usage() {
        let data = r#"{"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":87}}"#;
        let events = parse_anthropic_sse_line("message_delta", data);
        assert_eq!(events.len(), 1);
        match &events[0] {
            StreamEvent::Usage { input_tokens, output_tokens } => {
                // message_delta omits input_tokens; the parser defaults it to 0.
                assert_eq!(*input_tokens, 0);
                assert_eq!(*output_tokens, 87);
            }
            other => panic!("expected Usage, got {:?}", other),
        }
    }

    #[test]
    fn parse_anthropic_message_start_without_usage_is_empty() {
        // Some forward-compat payloads may omit usage; don't crash.
        let data = r#"{"type":"message_start","message":{"id":"msg_1"}}"#;
        assert!(parse_anthropic_sse_line("message_start", data).is_empty());
    }

    #[test]
    fn accumulator_tracks_usage_across_anthropic_stream() {
        // Simulate the exact shape of a real Anthropic stream:
        // message_start → content_block_start → content_block_delta × 3 → message_delta.
        let mut acc = StreamAccumulator::default();
        for event in parse_anthropic_sse_line(
            "message_start",
            r#"{"message":{"usage":{"input_tokens":245,"output_tokens":1}}}"#,
        ) {
            acc.push(&event);
        }
        for event in parse_anthropic_sse_line(
            "content_block_start",
            r#"{"index":0,"content_block":{"type":"text","text":""}}"#,
        ) {
            acc.push(&event);
        }
        // The `()` second tuple element is unused; it only exists to
        // satisfy the destructuring pattern.
        for (chunk, _) in [
            (r#"{"delta":{"type":"text_delta","text":"Hello"}}"#, ()),
            (r#"{"delta":{"type":"text_delta","text":", "}}"#, ()),
            (r#"{"delta":{"type":"text_delta","text":"world"}}"#, ()),
        ] {
            for event in parse_anthropic_sse_line("content_block_delta", chunk) {
                acc.push(&event);
            }
        }
        for event in parse_anthropic_sse_line(
            "message_delta",
            r#"{"usage":{"output_tokens":87}}"#,
        ) {
            acc.push(&event);
        }

        let (text, tools, usage) = acc.finish_with_usage();
        assert_eq!(text, "Hello, world");
        assert!(tools.is_empty());
        let usage = usage.expect("provider reported usage; must surface");
        assert_eq!(usage.prompt_tokens, 245);
        // message_delta output (87) must win over message_start stub (1).
        assert_eq!(usage.completion_tokens, 87);
        assert_eq!(usage.total_tokens, 332);
    }

    #[test]
    fn parse_openai_final_chunk_emits_usage() {
        // OpenAI's final usage chunk when `stream_options.include_usage`
        // is set: `choices` is empty and `usage` carries the real counts.
        let line = r#"data: {"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[],"usage":{"prompt_tokens":245,"completion_tokens":87,"total_tokens":332}}"#;
        let events = parse_openai_sse_line(line);
        assert_eq!(events.len(), 1);
        match &events[0] {
            StreamEvent::Usage {
                input_tokens,
                output_tokens,
            } => {
                assert_eq!(*input_tokens, 245);
                assert_eq!(*output_tokens, 87);
            }
            other => panic!("expected Usage, got {:?}", other),
        }
    }

    #[test]
    fn accumulator_tracks_usage_across_openai_stream() {
        // Simulate a full OpenAI stream with `stream_options.include_usage`:
        // text delta chunks followed by a choiceless usage-only chunk.
        let mut acc = StreamAccumulator::default();
        for line in [
            r#"data: {"choices":[{"delta":{"content":"Hello"}}]}"#,
            r#"data: {"choices":[{"delta":{"content":", "}}]}"#,
            r#"data: {"choices":[{"delta":{"content":"world"}}]}"#,
            r#"data: {"id":"chatcmpl-1","choices":[],"usage":{"prompt_tokens":245,"completion_tokens":87}}"#,
        ] {
            for event in parse_openai_sse_line(line) {
                acc.push(&event);
            }
        }

        let (text, tools, usage) = acc.finish_with_usage();
        assert_eq!(text, "Hello, world");
        assert!(tools.is_empty());
        let usage = usage.expect("provider reported usage; must surface");
        assert_eq!(usage.prompt_tokens, 245);
        assert_eq!(usage.completion_tokens, 87);
        assert_eq!(usage.total_tokens, 332);
    }

    #[test]
    fn accumulator_returns_no_usage_when_provider_silent() {
        // OpenAI without `stream_options.include_usage` — no Usage
        // events. `finish_with_usage` returns None so callers can fall
        // back to their own estimator.
        let mut acc = StreamAccumulator::default();
        acc.push(&StreamEvent::TextDelta("hi".into()));
        let (_, _, usage) = acc.finish_with_usage();
        assert!(usage.is_none());
    }
}