Skip to main content

oxibonsai_runtime/
streaming.rs

1//! Enhanced SSE streaming with delta tokens, finish reasons, and usage info.
2//!
3//! This module provides OpenAI-compatible Server-Sent Events (SSE) streaming
4//! primitives:
5//!
6//! - [`StreamChunk`] / [`StreamChoice`] / [`StreamDelta`] — wire-format structs
7//!   that match the OpenAI `chat.completion.chunk` schema.
8//! - [`SseFormatter`] — stateless helpers that format SSE event strings.
9//! - [`TokenStream`] — a byte-level buffer that accumulates raw token bytes and
10//!   yields decoded `String`s as soon as a valid UTF-8 sequence is complete.
11//! - [`StreamStats`] — throughput accounting for a single generation request.
12
13use std::time::{SystemTime, UNIX_EPOCH};
14
15// ── Wire-format structs ───────────────────────────────────────────────────────
16
/// A single SSE streaming chunk (OpenAI-compatible delta format).
///
/// One value of this type is serialized to JSON for every `data:` event of a
/// streamed chat completion.
#[derive(Debug, Clone, serde::Serialize)]
pub struct StreamChunk {
    /// Unique completion ID shared across all chunks in one generation.
    pub id: String,
    /// Always `"chat.completion.chunk"`.
    pub object: String,
    /// Unix timestamp in seconds.
    ///
    /// NOTE: `SseFormatter` fills this with the time at which each chunk is
    /// formatted, so chunks of one generation may carry different values.
    pub created: u64,
    /// Model name (e.g. `"bonsai-8b"`).
    pub model: String,
    /// One-element list of choices (multi-choice streaming is not yet supported).
    pub choices: Vec<StreamChoice>,
}
31
/// A single choice within a [`StreamChunk`].
#[derive(Debug, Clone, serde::Serialize)]
pub struct StreamChoice {
    /// Zero-based choice index.
    pub index: usize,
    /// The incremental delta for this chunk.
    pub delta: StreamDelta,
    /// `None` for all chunks except the last; `"stop"` / `"length"` on the last chunk.
    ///
    /// When `None`, the key is omitted from the serialized JSON entirely
    /// (it is not emitted as `null`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
    /// Log-probability information (not yet computed).
    ///
    /// The formatter in this module always leaves this `None`, in which case
    /// the key is omitted from the serialized JSON rather than emitted as `null`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<serde_json::Value>,
}
46
/// The incremental content delta for one chunk.
///
/// Both fields are optional and are left out of the serialized JSON when they
/// are `None`, so a delta with neither field set serializes as `{}`.
#[derive(Debug, Clone, serde::Serialize)]
pub struct StreamDelta {
    /// Set to `"assistant"` on the very first chunk; `None` on subsequent chunks.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<String>,
    /// The token text for this chunk; `None` on the final (finish-reason) chunk.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
}
57
58// ── SseFormatter ─────────────────────────────────────────────────────────────
59
/// Stateless SSE event formatter.
///
/// Produces correctly framed `data: …\n\n` strings for each stage of a
/// streaming generation response. The formatter only holds configuration
/// (model name and the usage flag); it keeps no per-stream state, so it is
/// cheap to clone and can be printed with `{:?}` for diagnostics.
#[derive(Debug, Clone)]
pub struct SseFormatter {
    /// Whether to append a usage chunk after the final delta.
    pub include_usage: bool,
    /// Model name copied into the `model` field of every chunk.
    model_name: String,
}
69
70impl SseFormatter {
71    /// Create a new formatter for the given model.
72    pub fn new(model_name: &str) -> Self {
73        Self {
74            include_usage: false,
75            model_name: model_name.to_owned(),
76        }
77    }
78
79    /// Enable a trailing usage chunk.
80    pub fn with_usage(mut self) -> Self {
81        self.include_usage = true;
82        self
83    }
84
85    /// Return the current Unix timestamp in seconds.
86    fn now_secs() -> u64 {
87        SystemTime::now()
88            .duration_since(UNIX_EPOCH)
89            .unwrap_or_default()
90            .as_secs()
91    }
92
93    /// Format the **first** chunk of a streaming response.
94    ///
95    /// The first chunk carries `role: "assistant"` and an empty content string
96    /// so that clients can render the role indicator immediately.
97    pub fn first_chunk(&self, request_id: &str) -> String {
98        let chunk = StreamChunk {
99            id: request_id.to_owned(),
100            object: "chat.completion.chunk".to_owned(),
101            created: Self::now_secs(),
102            model: self.model_name.clone(),
103            choices: vec![StreamChoice {
104                index: 0,
105                delta: StreamDelta {
106                    role: Some("assistant".to_owned()),
107                    content: Some(String::new()),
108                },
109                finish_reason: None,
110                logprobs: None,
111            }],
112        };
113        Self::format_event(&serde_json::to_string(&chunk).unwrap_or_else(|_| "{}".to_owned()))
114    }
115
116    /// Format a **token delta** chunk carrying `token_text` as the content.
117    pub fn token_chunk(&self, request_id: &str, token_text: &str) -> String {
118        let chunk = StreamChunk {
119            id: request_id.to_owned(),
120            object: "chat.completion.chunk".to_owned(),
121            created: Self::now_secs(),
122            model: self.model_name.clone(),
123            choices: vec![StreamChoice {
124                index: 0,
125                delta: StreamDelta {
126                    role: None,
127                    content: Some(token_text.to_owned()),
128                },
129                finish_reason: None,
130                logprobs: None,
131            }],
132        };
133        Self::format_event(&serde_json::to_string(&chunk).unwrap_or_else(|_| "{}".to_owned()))
134    }
135
136    /// Format the **final** chunk carrying the `finish_reason` and no content.
137    pub fn final_chunk(&self, request_id: &str, finish_reason: &str) -> String {
138        let chunk = StreamChunk {
139            id: request_id.to_owned(),
140            object: "chat.completion.chunk".to_owned(),
141            created: Self::now_secs(),
142            model: self.model_name.clone(),
143            choices: vec![StreamChoice {
144                index: 0,
145                delta: StreamDelta {
146                    role: None,
147                    content: None,
148                },
149                finish_reason: Some(finish_reason.to_owned()),
150                logprobs: None,
151            }],
152        };
153        Self::format_event(&serde_json::to_string(&chunk).unwrap_or_else(|_| "{}".to_owned()))
154    }
155
156    /// The SSE `[DONE]` sentinel that signals stream completion.
157    pub fn done_sentinel() -> &'static str {
158        "data: [DONE]\n\n"
159    }
160
161    /// Wrap arbitrary JSON data in a `data: …\n\n` SSE frame.
162    pub fn format_event(data: &str) -> String {
163        format!("data: {data}\n\n")
164    }
165
166    /// Format a JSON error payload as an SSE event.
167    pub fn error_event(message: &str) -> String {
168        // Escape the message to avoid breaking the JSON.
169        let escaped = message.replace('\\', "\\\\").replace('"', "\\\"");
170        Self::format_event(&format!(r#"{{"error":{{"message":"{escaped}"}}}}"#))
171    }
172}
173
174// ── TokenStream ───────────────────────────────────────────────────────────────
175
/// Byte-level detokenizer buffer with partial-token accumulation.
///
/// Raw model output often arrives as byte slices that do not align with UTF-8
/// character boundaries (e.g. multi-byte CJK characters split across two model
/// tokens).  `TokenStream` accumulates bytes until a complete UTF-8 sequence is
/// available, then returns the decoded string.
///
/// Bytes that can never become valid UTF-8 are replaced with U+FFFD instead of
/// being held back, so a single bad byte cannot stall the stream.
#[derive(Debug)]
pub struct TokenStream {
    /// Bytes received but not yet handed out as text.
    buffer: Vec<u8>,
    /// If `true`, the stream defers flushing until a whitespace boundary is found.
    /// This can be useful for word-level de-tokenization.
    pub flush_at_whitespace: bool,
}

impl Default for TokenStream {
    fn default() -> Self {
        Self::new()
    }
}

impl TokenStream {
    /// Create a new empty `TokenStream`.
    pub fn new() -> Self {
        Self {
            buffer: Vec::new(),
            flush_at_whitespace: false,
        }
    }

    /// Append `bytes` to the internal buffer.
    ///
    /// Returns `Some(text)` with everything that can be decoded so far, or
    /// `None` if more bytes are still needed.
    ///
    /// - A trailing *incomplete* multi-byte sequence is kept in the buffer
    ///   until its remaining bytes arrive (or [`TokenStream::flush`] is called).
    /// - A *definitively invalid* byte sequence is replaced with U+FFFD, the
    ///   same substitution `flush` performs.
    /// - With `flush_at_whitespace` set, nothing is emitted (and nothing is
    ///   removed from the buffer) until the decodable text contains at least
    ///   one whitespace character.
    pub fn push_token_bytes(&mut self, bytes: &[u8]) -> Option<String> {
        self.buffer.extend_from_slice(bytes);

        let (decoded, consumed) = Self::decode_available(&self.buffer);
        let has_pending_tail = consumed < self.buffer.len();

        let ready = if self.flush_at_whitespace {
            decoded.contains(char::is_whitespace)
        } else {
            // Only hold back when all we have is the start of a character.
            !(decoded.is_empty() && has_pending_tail)
        };
        if !ready {
            return None;
        }

        self.buffer.drain(..consumed);
        Some(decoded)
    }

    /// Decode as much of `bytes` as possible.
    ///
    /// Returns the decoded text and the number of leading bytes it accounts
    /// for. Invalid sequences become U+FFFD; an incomplete sequence at the
    /// very end is left unconsumed.
    fn decode_available(bytes: &[u8]) -> (String, usize) {
        let mut decoded = String::with_capacity(bytes.len());
        let mut consumed = 0;

        loop {
            match std::str::from_utf8(&bytes[consumed..]) {
                Ok(rest) => {
                    decoded.push_str(rest);
                    consumed = bytes.len();
                    break;
                }
                Err(err) => {
                    let valid_end = consumed + err.valid_up_to();
                    // This range was just validated, so the re-check cannot fail.
                    if let Ok(prefix) = std::str::from_utf8(&bytes[consumed..valid_end]) {
                        decoded.push_str(prefix);
                    }
                    consumed = valid_end;

                    match err.error_len() {
                        // Bytes that can never be valid: substitute and move on.
                        Some(bad_len) => {
                            decoded.push('\u{FFFD}');
                            consumed += bad_len;
                        }
                        // Input ended mid-character: wait for more bytes.
                        None => break,
                    }
                }
            }
        }

        (decoded, consumed)
    }

    /// Force-flush whatever remains in the buffer as lossy UTF-8.
    ///
    /// Any invalid byte sequences are replaced with U+FFFD (replacement char).
    pub fn flush(&mut self) -> String {
        let text = String::from_utf8_lossy(&self.buffer).into_owned();
        self.buffer.clear();
        text
    }

    /// Returns `true` if the internal buffer is empty.
    pub fn is_empty(&self) -> bool {
        self.buffer.is_empty()
    }
}
262
263// ── StreamStats ───────────────────────────────────────────────────────────────
264
/// Per-request generation throughput statistics.
///
/// All counters start at zero (`Default`); they are filled in by
/// [`StreamStats::finish`] or by assigning the public fields directly.
#[derive(Debug, Default, serde::Serialize)]
pub struct StreamStats {
    /// Total tokens emitted in the completion.
    pub tokens_generated: usize,
    /// Number of tokens in the prompt (prefill phase).
    pub prefill_tokens: usize,
    /// Wall-clock milliseconds until the first token was emitted.
    pub time_to_first_token_ms: u64,
    /// Total wall-clock milliseconds for the entire generation.
    pub total_time_ms: u64,
    /// Tokens-per-second throughput (cached result of [`StreamStats::throughput`]).
    ///
    /// Only refreshed by [`StreamStats::finish`]; assigning the other fields
    /// directly does not update it.
    pub tokens_per_second: f32,
}
279
280impl StreamStats {
281    /// Create a blank `StreamStats`.
282    pub fn new() -> Self {
283        Self::default()
284    }
285
286    /// Record the final statistics after generation completes.
287    pub fn finish(&mut self, tokens: usize, prefill: usize, ttft_ms: u64, total_ms: u64) {
288        self.tokens_generated = tokens;
289        self.prefill_tokens = prefill;
290        self.time_to_first_token_ms = ttft_ms;
291        self.total_time_ms = total_ms;
292        self.tokens_per_second = self.throughput();
293    }
294
295    /// Compute tokens-per-second from recorded statistics.
296    ///
297    /// Returns `0.0` if `total_time_ms` is zero (avoids division by zero).
298    pub fn throughput(&self) -> f32 {
299        if self.total_time_ms == 0 {
300            return 0.0;
301        }
302        self.tokens_generated as f32 / (self.total_time_ms as f32 / 1_000.0)
303    }
304
305    /// Serialize these statistics as an SSE usage chunk.
306    ///
307    /// The payload follows the OpenAI convention of appending a final usage
308    /// chunk before `[DONE]`:
309    ///
310    /// ```json
311    /// {"id":"...","object":"chat.completion.chunk","usage":{"prompt_tokens":…,…}}
312    /// ```
313    pub fn to_usage_chunk(&self, request_id: &str, model: &str) -> String {
314        let payload = serde_json::json!({
315            "id": request_id,
316            "object": "chat.completion.chunk",
317            "model": model,
318            "usage": {
319                "prompt_tokens": self.prefill_tokens,
320                "completion_tokens": self.tokens_generated,
321                "total_tokens": self.prefill_tokens + self.tokens_generated,
322            }
323        });
324        SseFormatter::format_event(
325            &serde_json::to_string(&payload).unwrap_or_else(|_| "{}".to_owned()),
326        )
327    }
328}
329
330// ── Tests ─────────────────────────────────────────────────────────────────────
331
#[cfg(test)]
mod tests {
    use super::*;

    fn make_formatter() -> SseFormatter {
        SseFormatter::new("bonsai-8b")
    }

    /// Strip the SSE framing from `event` and parse the payload as JSON.
    fn parse_event(event: &str) -> serde_json::Value {
        let json_part = event
            .strip_prefix("data: ")
            .expect("must start with data:")
            .trim_end();
        serde_json::from_str(json_part).expect("must be valid JSON")
    }

    // ── SseFormatter ──

    #[test]
    fn test_sse_formatter_first_chunk_has_role() {
        let fmt = make_formatter();
        let v = parse_event(&fmt.first_chunk("req-001"));
        let role = &v["choices"][0]["delta"]["role"];
        assert_eq!(role, "assistant", "first chunk must carry role: assistant");
        assert_eq!(v["id"], "req-001");
        assert_eq!(v["model"], "bonsai-8b");
    }

    #[test]
    fn test_sse_formatter_token_chunk_has_content() {
        let fmt = make_formatter();
        let v = parse_event(&fmt.token_chunk("req-002", "Hello"));
        let content = &v["choices"][0]["delta"]["content"];
        assert_eq!(content, "Hello", "token chunk must carry content");
        // The role key must be absent, not merely null.
        assert!(
            v["choices"][0]["delta"].get("role").is_none(),
            "token chunk must not carry role"
        );
    }

    #[test]
    fn test_sse_formatter_final_chunk_has_finish_reason() {
        let fmt = make_formatter();
        let v = parse_event(&fmt.final_chunk("req-003", "stop"));
        let reason = &v["choices"][0]["finish_reason"];
        assert_eq!(reason, "stop", "final chunk must carry finish_reason");
        assert!(
            v["choices"][0]["delta"].get("content").is_none(),
            "final chunk must not carry content"
        );
    }

    #[test]
    fn test_sse_formatter_done_sentinel() {
        assert_eq!(SseFormatter::done_sentinel(), "data: [DONE]\n\n");
    }

    #[test]
    fn test_sse_format_event() {
        let event = SseFormatter::format_event(r#"{"foo":"bar"}"#);
        assert_eq!(event, "data: {\"foo\":\"bar\"}\n\n");
    }

    #[test]
    fn test_sse_error_event() {
        let event = SseFormatter::error_event("something went wrong");
        assert!(event.starts_with("data: "), "must be an SSE data event");
        assert!(
            event.contains("something went wrong"),
            "must contain the message"
        );
        // Must parse as valid JSON.
        let v = parse_event(&event);
        assert_eq!(v["error"]["message"], "something went wrong");
    }

    #[test]
    fn test_sse_error_event_escapes_quotes_and_backslashes() {
        let message = r#"bad "input" at C:\tmp"#;
        let v = parse_event(&SseFormatter::error_event(message));
        assert_eq!(v["error"]["message"], message);
    }

    // ── TokenStream ──

    #[test]
    fn test_token_stream_ascii_passthrough() {
        let mut ts = TokenStream::new();
        let result = ts.push_token_bytes(b"hello");
        assert_eq!(result, Some("hello".to_owned()));
        assert!(ts.is_empty());
    }

    #[test]
    fn test_token_stream_flush() {
        let mut ts = TokenStream::new();
        // A complete ASCII push is returned immediately and leaves nothing behind.
        assert_eq!(ts.push_token_bytes(b"hi"), Some("hi".to_owned()));
        assert!(ts.is_empty());
        // Now push a partial UTF-8 sequence.
        let partial = &[0xE4u8, 0xB8u8]; // first 2 bytes of a 3-byte CJK char
        let result = ts.push_token_bytes(partial);
        assert!(result.is_none(), "incomplete sequence should return None");
        assert!(!ts.is_empty(), "partial bytes must stay buffered");
        // Force flush — the truncated character becomes one replacement char.
        let flushed = ts.flush();
        assert_eq!(flushed, "\u{FFFD}");
        assert!(ts.is_empty(), "buffer must be empty after flush");
    }

    #[test]
    fn test_token_stream_empty_after_flush() {
        let mut ts = TokenStream::new();
        let flushed = ts.flush(); // flush on empty buffer
        assert!(flushed.is_empty());
        assert!(ts.is_empty());
    }

    #[test]
    fn test_token_stream_multibyte_utf8() {
        let mut ts = TokenStream::new();
        // "中" = U+4E2D = bytes [0xE4, 0xB8, 0xAD]
        let bytes = "中".as_bytes();

        // Push first two bytes — should return None.
        let r1 = ts.push_token_bytes(&bytes[..2]);
        assert!(r1.is_none(), "incomplete UTF-8 should return None");

        // Push the final byte — should now decode.
        let r2 = ts.push_token_bytes(&bytes[2..]);
        assert_eq!(r2, Some("中".to_owned()));
        assert!(ts.is_empty());
    }

    // ── StreamStats ──

    #[test]
    fn test_stream_stats_throughput() {
        let mut stats = StreamStats::new();
        stats.tokens_generated = 100;
        stats.total_time_ms = 2_000; // 2 seconds
        let tps = stats.throughput();
        assert!((tps - 50.0).abs() < 0.01, "expected 50 tps, got {tps}");
    }

    #[test]
    fn test_stream_stats_throughput_zero_time() {
        let stats = StreamStats::new(); // total_time_ms == 0
        assert_eq!(stats.throughput(), 0.0);
    }

    #[test]
    fn test_stream_stats_finish() {
        let mut stats = StreamStats::new();
        stats.finish(200, 50, 120, 4_000);
        assert_eq!(stats.tokens_generated, 200);
        assert_eq!(stats.prefill_tokens, 50);
        assert_eq!(stats.time_to_first_token_ms, 120);
        assert_eq!(stats.total_time_ms, 4_000);
        // throughput = 200 / 4.0 = 50 tps
        assert!((stats.tokens_per_second - 50.0).abs() < 0.01);
    }

    #[test]
    fn test_stream_chunk_serializes_correctly() {
        let chunk = StreamChunk {
            id: "chatcmpl-abc".to_owned(),
            object: "chat.completion.chunk".to_owned(),
            created: 1_700_000_000,
            model: "bonsai-8b".to_owned(),
            choices: vec![StreamChoice {
                index: 0,
                delta: StreamDelta {
                    role: Some("assistant".to_owned()),
                    content: Some("Hi".to_owned()),
                },
                finish_reason: None,
                logprobs: None,
            }],
        };

        let json = serde_json::to_string(&chunk).expect("serialization must succeed");
        let v: serde_json::Value = serde_json::from_str(&json).expect("must parse back to JSON");

        assert_eq!(v["id"], "chatcmpl-abc");
        assert_eq!(v["object"], "chat.completion.chunk");
        assert_eq!(v["created"], 1_700_000_000u64);
        assert_eq!(v["choices"][0]["delta"]["role"], "assistant");
        assert_eq!(v["choices"][0]["delta"]["content"], "Hi");
        // `None` fields are skipped, so the keys must be absent from the JSON.
        assert!(v["choices"][0].get("finish_reason").is_none());
        assert!(v["choices"][0].get("logprobs").is_none());
    }

    #[test]
    fn test_stream_stats_usage_chunk() {
        let mut stats = StreamStats::new();
        stats.finish(10, 5, 50, 1_000);
        let chunk = stats.to_usage_chunk("req-x", "bonsai-8b");
        assert!(chunk.starts_with("data: "));
        let v = parse_event(&chunk);
        assert_eq!(v["id"], "req-x");
        assert_eq!(v["model"], "bonsai-8b");
        assert_eq!(v["usage"]["prompt_tokens"], 5);
        assert_eq!(v["usage"]["completion_tokens"], 10);
        assert_eq!(v["usage"]["total_tokens"], 15);
    }
}