1use std::time::{SystemTime, UNIX_EPOCH};
14
15#[derive(Debug, Clone, serde::Serialize)]
19pub struct StreamChunk {
20 pub id: String,
22 pub object: String,
24 pub created: u64,
26 pub model: String,
28 pub choices: Vec<StreamChoice>,
30}
31
32#[derive(Debug, Clone, serde::Serialize)]
34pub struct StreamChoice {
35 pub index: usize,
37 pub delta: StreamDelta,
39 #[serde(skip_serializing_if = "Option::is_none")]
41 pub finish_reason: Option<String>,
42 #[serde(skip_serializing_if = "Option::is_none")]
44 pub logprobs: Option<serde_json::Value>,
45}
46
47#[derive(Debug, Clone, serde::Serialize)]
49pub struct StreamDelta {
50 #[serde(skip_serializing_if = "Option::is_none")]
52 pub role: Option<String>,
53 #[serde(skip_serializing_if = "Option::is_none")]
55 pub content: Option<String>,
56}
57
/// Renders streaming completion chunks as Server-Sent Events frames.
pub struct SseFormatter {
    /// Flag enabled through `with_usage`.
    ///
    /// NOTE(review): nothing in this file reads the flag; presumably the
    /// caller consults it to decide whether to emit a usage chunk — confirm.
    pub include_usage: bool,
    /// Model name copied into the `model` field of every chunk.
    model_name: String,
}
69
70impl SseFormatter {
71 pub fn new(model_name: &str) -> Self {
73 Self {
74 include_usage: false,
75 model_name: model_name.to_owned(),
76 }
77 }
78
79 pub fn with_usage(mut self) -> Self {
81 self.include_usage = true;
82 self
83 }
84
85 fn now_secs() -> u64 {
87 SystemTime::now()
88 .duration_since(UNIX_EPOCH)
89 .unwrap_or_default()
90 .as_secs()
91 }
92
93 pub fn first_chunk(&self, request_id: &str) -> String {
98 let chunk = StreamChunk {
99 id: request_id.to_owned(),
100 object: "chat.completion.chunk".to_owned(),
101 created: Self::now_secs(),
102 model: self.model_name.clone(),
103 choices: vec![StreamChoice {
104 index: 0,
105 delta: StreamDelta {
106 role: Some("assistant".to_owned()),
107 content: Some(String::new()),
108 },
109 finish_reason: None,
110 logprobs: None,
111 }],
112 };
113 Self::format_event(&serde_json::to_string(&chunk).unwrap_or_else(|_| "{}".to_owned()))
114 }
115
116 pub fn token_chunk(&self, request_id: &str, token_text: &str) -> String {
118 let chunk = StreamChunk {
119 id: request_id.to_owned(),
120 object: "chat.completion.chunk".to_owned(),
121 created: Self::now_secs(),
122 model: self.model_name.clone(),
123 choices: vec![StreamChoice {
124 index: 0,
125 delta: StreamDelta {
126 role: None,
127 content: Some(token_text.to_owned()),
128 },
129 finish_reason: None,
130 logprobs: None,
131 }],
132 };
133 Self::format_event(&serde_json::to_string(&chunk).unwrap_or_else(|_| "{}".to_owned()))
134 }
135
136 pub fn final_chunk(&self, request_id: &str, finish_reason: &str) -> String {
138 let chunk = StreamChunk {
139 id: request_id.to_owned(),
140 object: "chat.completion.chunk".to_owned(),
141 created: Self::now_secs(),
142 model: self.model_name.clone(),
143 choices: vec![StreamChoice {
144 index: 0,
145 delta: StreamDelta {
146 role: None,
147 content: None,
148 },
149 finish_reason: Some(finish_reason.to_owned()),
150 logprobs: None,
151 }],
152 };
153 Self::format_event(&serde_json::to_string(&chunk).unwrap_or_else(|_| "{}".to_owned()))
154 }
155
156 pub fn done_sentinel() -> &'static str {
158 "data: [DONE]\n\n"
159 }
160
161 pub fn format_event(data: &str) -> String {
163 format!("data: {data}\n\n")
164 }
165
166 pub fn error_event(message: &str) -> String {
168 let escaped = message.replace('\\', "\\\\").replace('"', "\\\"");
170 Self::format_event(&format!(r#"{{"error":{{"message":"{escaped}"}}}}"#))
171 }
172}
173
/// Incremental UTF-8 decoder for token bytes that may arrive split in the
/// middle of a multi-byte character.
pub struct TokenStream {
    /// Bytes received but not yet handed out as text.
    buffer: Vec<u8>,
    /// When `true`, decoded text is held back until it contains at least one
    /// whitespace character.
    pub flush_at_whitespace: bool,
}

impl Default for TokenStream {
    fn default() -> Self {
        Self::new()
    }
}

impl TokenStream {
    /// Creates an empty stream with `flush_at_whitespace` disabled.
    pub fn new() -> Self {
        Self {
            buffer: Vec::new(),
            flush_at_whitespace: false,
        }
    }

    /// Appends `bytes` and returns whatever text can be emitted now.
    ///
    /// * An incomplete multi-byte sequence at the end of the buffer is kept
    ///   until later bytes complete it.
    /// * Byte sequences that can never become valid UTF-8 are replaced with
    ///   U+FFFD, the same policy [`TokenStream::flush`] applies, so a single
    ///   bad byte cannot block the stream.
    /// * With `flush_at_whitespace` set, nothing is emitted until the decoded
    ///   text contains whitespace.
    ///
    /// Returns `None` when nothing is ready to emit yet.
    pub fn push_token_bytes(&mut self, bytes: &[u8]) -> Option<String> {
        self.buffer.extend_from_slice(bytes);

        let (text, consumed) = Self::decode_available(&self.buffer);

        // Only an incomplete trailing sequence is buffered: wait for more.
        if consumed == 0 && !self.buffer.is_empty() {
            return None;
        }
        // Keep accumulating; the buffer is decoded again on the next push.
        if self.flush_at_whitespace && !text.contains(char::is_whitespace) {
            return None;
        }

        self.buffer.drain(..consumed);
        Some(text)
    }

    /// Decodes as much of `bytes` as possible, returning the text and the
    /// number of bytes it accounts for. Anything past that count is an
    /// incomplete trailing sequence.
    fn decode_available(bytes: &[u8]) -> (String, usize) {
        let mut text = String::with_capacity(bytes.len());
        let mut consumed = 0;

        loop {
            let rest = &bytes[consumed..];
            match std::str::from_utf8(rest) {
                Ok(valid) => {
                    text.push_str(valid);
                    return (text, bytes.len());
                }
                Err(err) => {
                    let prefix = &rest[..err.valid_up_to()];
                    // `valid_up_to` guarantees the prefix is well-formed.
                    text.push_str(std::str::from_utf8(prefix).unwrap_or_default());
                    consumed += prefix.len();

                    match err.error_len() {
                        // Definitely invalid: substitute and keep decoding.
                        Some(bad_len) => {
                            text.push('\u{FFFD}');
                            consumed += bad_len;
                        }
                        // Input ended mid-character: leave the tail buffered.
                        None => return (text, consumed),
                    }
                }
            }
        }
    }

    /// Drains the buffer, decoding lossily (invalid or incomplete sequences
    /// become U+FFFD).
    pub fn flush(&mut self) -> String {
        let text = String::from_utf8_lossy(&self.buffer).into_owned();
        self.buffer.clear();
        text
    }

    /// `true` when no bytes are waiting in the buffer.
    pub fn is_empty(&self) -> bool {
        self.buffer.is_empty()
    }
}
262
263#[derive(Debug, Default, serde::Serialize)]
267pub struct StreamStats {
268 pub tokens_generated: usize,
270 pub prefill_tokens: usize,
272 pub time_to_first_token_ms: u64,
274 pub total_time_ms: u64,
276 pub tokens_per_second: f32,
278}
279
280impl StreamStats {
281 pub fn new() -> Self {
283 Self::default()
284 }
285
286 pub fn finish(&mut self, tokens: usize, prefill: usize, ttft_ms: u64, total_ms: u64) {
288 self.tokens_generated = tokens;
289 self.prefill_tokens = prefill;
290 self.time_to_first_token_ms = ttft_ms;
291 self.total_time_ms = total_ms;
292 self.tokens_per_second = self.throughput();
293 }
294
295 pub fn throughput(&self) -> f32 {
299 if self.total_time_ms == 0 {
300 return 0.0;
301 }
302 self.tokens_generated as f32 / (self.total_time_ms as f32 / 1_000.0)
303 }
304
305 pub fn to_usage_chunk(&self, request_id: &str, model: &str) -> String {
314 let payload = serde_json::json!({
315 "id": request_id,
316 "object": "chat.completion.chunk",
317 "model": model,
318 "usage": {
319 "prompt_tokens": self.prefill_tokens,
320 "completion_tokens": self.tokens_generated,
321 "total_tokens": self.prefill_tokens + self.tokens_generated,
322 }
323 });
324 SseFormatter::format_event(
325 &serde_json::to_string(&payload).unwrap_or_else(|_| "{}".to_owned()),
326 )
327 }
328}
329
/// Unit tests for the SSE formatter, the UTF-8 token buffer and the stream
/// statistics.
#[cfg(test)]
mod tests {
    use super::*;

    fn make_formatter() -> SseFormatter {
        SseFormatter::new("bonsai-8b")
    }

    /// Strips the SSE `data: ` prefix and trailing blank line, then parses
    /// the remaining payload as JSON.
    fn parse_event(event: &str) -> serde_json::Value {
        let json_part = event
            .strip_prefix("data: ")
            .expect("must start with data:")
            .trim_end();
        serde_json::from_str(json_part).expect("must be valid JSON")
    }

    #[test]
    fn test_sse_formatter_first_chunk_has_role() {
        let fmt = make_formatter();
        let v = parse_event(&fmt.first_chunk("req-001"));
        assert_eq!(
            v["choices"][0]["delta"]["role"], "assistant",
            "first chunk must carry role: assistant"
        );
    }

    #[test]
    fn test_sse_formatter_token_chunk_has_content() {
        let fmt = make_formatter();
        let v = parse_event(&fmt.token_chunk("req-002", "Hello"));
        assert_eq!(
            v["choices"][0]["delta"]["content"], "Hello",
            "token chunk must carry content"
        );
        assert!(
            v["choices"][0]["delta"]["role"].is_null(),
            "token chunk must not carry role"
        );
    }

    #[test]
    fn test_sse_formatter_final_chunk_has_finish_reason() {
        let fmt = make_formatter();
        let v = parse_event(&fmt.final_chunk("req-003", "stop"));
        assert_eq!(
            v["choices"][0]["finish_reason"], "stop",
            "final chunk must carry finish_reason"
        );
    }

    #[test]
    fn test_sse_formatter_done_sentinel() {
        assert_eq!(SseFormatter::done_sentinel(), "data: [DONE]\n\n");
    }

    #[test]
    fn test_sse_format_event() {
        let event = SseFormatter::format_event(r#"{"foo":"bar"}"#);
        assert_eq!(event, "data: {\"foo\":\"bar\"}\n\n");
    }

    #[test]
    fn test_sse_error_event() {
        let event = SseFormatter::error_event("something went wrong");
        assert!(event.starts_with("data: "), "must be an SSE data event");
        assert!(
            event.contains("something went wrong"),
            "must contain the message"
        );
        let v = parse_event(&event);
        assert!(v["error"]["message"].is_string());
    }

    #[test]
    fn test_token_stream_ascii_passthrough() {
        let mut ts = TokenStream::new();
        let result = ts.push_token_bytes(b"hello");
        assert_eq!(result, Some("hello".to_owned()));
        assert!(ts.is_empty());
    }

    #[test]
    fn test_token_stream_flush() {
        let mut ts = TokenStream::new();
        assert_eq!(ts.push_token_bytes(b"hi"), Some("hi".to_owned()));

        // First two bytes of a three-byte character: cannot be decoded yet.
        let partial = &[0xE4u8, 0xB8u8];
        let result = ts.push_token_bytes(partial);
        assert!(result.is_none(), "incomplete sequence should return None");

        // A lossy flush turns the dangling partial sequence into one U+FFFD.
        let flushed = ts.flush();
        assert_eq!(
            flushed, "\u{FFFD}",
            "flush must lossily decode the incomplete sequence"
        );
        assert!(ts.is_empty(), "buffer must be empty after flush");
    }

    #[test]
    fn test_token_stream_empty_after_flush() {
        let mut ts = TokenStream::new();
        let _ = ts.flush();
        assert!(ts.is_empty());
    }

    #[test]
    fn test_token_stream_multibyte_utf8() {
        let mut ts = TokenStream::new();
        let bytes = "中".as_bytes();

        let r1 = ts.push_token_bytes(&bytes[..2]);
        assert!(r1.is_none(), "incomplete UTF-8 should return None");

        let r2 = ts.push_token_bytes(&bytes[2..]);
        assert_eq!(r2, Some("中".to_owned()));
        assert!(ts.is_empty());
    }

    #[test]
    fn test_stream_stats_throughput() {
        let mut stats = StreamStats::new();
        stats.tokens_generated = 100;
        stats.total_time_ms = 2_000;
        let tps = stats.throughput();
        assert!((tps - 50.0).abs() < 0.01, "expected 50 tps, got {tps}");
    }

    #[test]
    fn test_stream_stats_throughput_zero_time() {
        let stats = StreamStats::new();
        assert_eq!(stats.throughput(), 0.0);
    }

    #[test]
    fn test_stream_stats_finish() {
        let mut stats = StreamStats::new();
        stats.finish(200, 50, 120, 4_000);
        assert_eq!(stats.tokens_generated, 200);
        assert_eq!(stats.prefill_tokens, 50);
        assert_eq!(stats.time_to_first_token_ms, 120);
        assert_eq!(stats.total_time_ms, 4_000);
        assert!((stats.tokens_per_second - 50.0).abs() < 0.01);
    }

    #[test]
    fn test_stream_chunk_serializes_correctly() {
        let chunk = StreamChunk {
            id: "chatcmpl-abc".to_owned(),
            object: "chat.completion.chunk".to_owned(),
            created: 1_700_000_000,
            model: "bonsai-8b".to_owned(),
            choices: vec![StreamChoice {
                index: 0,
                delta: StreamDelta {
                    role: Some("assistant".to_owned()),
                    content: Some("Hi".to_owned()),
                },
                finish_reason: None,
                logprobs: None,
            }],
        };

        let json = serde_json::to_string(&chunk).expect("serialization must succeed");
        let v: serde_json::Value = serde_json::from_str(&json).expect("must parse back to JSON");

        assert_eq!(v["id"], "chatcmpl-abc");
        assert_eq!(v["object"], "chat.completion.chunk");
        assert_eq!(v["choices"][0]["delta"]["role"], "assistant");
        assert_eq!(v["choices"][0]["delta"]["content"], "Hi");
        assert!(v["choices"][0]["finish_reason"].is_null());
    }

    #[test]
    fn test_stream_stats_usage_chunk() {
        let mut stats = StreamStats::new();
        stats.finish(10, 5, 50, 1_000);
        let chunk = stats.to_usage_chunk("req-x", "bonsai-8b");
        assert!(chunk.starts_with("data: "));
        let v = parse_event(&chunk);
        assert_eq!(v["usage"]["prompt_tokens"], 5);
        assert_eq!(v["usage"]["completion_tokens"], 10);
        assert_eq!(v["usage"]["total_tokens"], 15);
    }
}