sentinel_proxy/inference/streaming.rs

//! Streaming token counting for SSE (Server-Sent Events) responses.
//!
//! LLM APIs use SSE for streaming responses. This module:
//! - Parses SSE chunks to extract content deltas
//! - Accumulates text content across chunks
//! - Provides final token count using tiktoken
//!
//! # SSE Formats
//!
//! ## OpenAI
//! ```text
//! data: {"id":"...","choices":[{"delta":{"content":"Hello"}}]}
//! data: {"id":"...","choices":[{"delta":{"content":" world"}}]}
//! data: [DONE]
//! ```
//!
//! ## Anthropic
//! ```text
//! event: content_block_delta
//! data: {"type":"content_block_delta","delta":{"type":"text_delta","text":"Hello"}}
//! ```
//!
//! # Usage
//!
//! ```ignore
//! let mut counter = StreamingTokenCounter::new(InferenceProvider::OpenAi, Some("gpt-4".to_string()));
//! counter.process_chunk(chunk1);
//! counter.process_chunk(chunk2);
//! let result = counter.finalize();
//! ```

use serde_json::Value;
use tracing::{trace, warn};

use super::tiktoken::tiktoken_manager;
use sentinel_config::InferenceProvider;

/// Streaming token counter for SSE responses.
///
/// Accumulates content from SSE chunks and provides final token count.
#[derive(Debug)]
pub struct StreamingTokenCounter {
    /// Provider type for format detection
    provider: InferenceProvider,
    /// Model name for tiktoken encoding selection
    model: Option<String>,
    /// Accumulated content text
    content_buffer: String,
    /// Whether the stream has completed
    completed: bool,
    /// Number of chunks processed
    chunks_processed: u32,
    /// Bytes processed
    bytes_processed: u64,
    /// Final usage from API (if provided in stream)
    api_usage: Option<ApiUsage>,
    /// Partial SSE line buffer (for chunks that split across boundaries)
    line_buffer: String,
}

/// Usage information from API response (when provided).
#[derive(Debug, Clone)]
pub struct ApiUsage {
    pub input_tokens: u64,
    pub output_tokens: u64,
    pub total_tokens: u64,
}

/// Result of processing an SSE chunk.
#[derive(Debug)]
pub struct ChunkResult {
    /// Content extracted from this chunk
    pub content: Option<String>,
    /// Whether this chunk indicates stream completion
    pub is_done: bool,
    /// Usage info if present in this chunk
    pub usage: Option<ApiUsage>,
}

impl StreamingTokenCounter {
    /// Create a new streaming token counter.
    pub fn new(provider: InferenceProvider, model: Option<String>) -> Self {
        Self {
            provider,
            model,
            content_buffer: String::with_capacity(4096),
            completed: false,
            chunks_processed: 0,
            bytes_processed: 0,
            api_usage: None,
            line_buffer: String::new(),
        }
    }

    /// Process an SSE chunk from the response body.
    ///
    /// Extracts content deltas and accumulates them.
    /// Returns information about what was extracted.
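    ///
    /// A minimal sketch of a single-chunk call (OpenAI-style stream):
    ///
    /// ```ignore
    /// let mut counter = StreamingTokenCounter::new(InferenceProvider::OpenAi, None);
    /// let result = counter.process_chunk(b"data: {\"choices\":[{\"delta\":{\"content\":\"Hi\"}}]}\n\n");
    /// assert_eq!(result.content.as_deref(), Some("Hi"));
    /// ```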
    pub fn process_chunk(&mut self, chunk: &[u8]) -> ChunkResult {
        self.chunks_processed += 1;
        self.bytes_processed += chunk.len() as u64;

        let chunk_str = match std::str::from_utf8(chunk) {
            Ok(s) => s,
            Err(_) => {
                // Note: a multi-byte UTF-8 character split across chunk
                // boundaries also lands here, since each chunk is validated
                // independently; the offending chunk is skipped, not buffered.
                warn!("Invalid UTF-8 in SSE chunk");
                return ChunkResult {
                    content: None,
                    is_done: false,
                    usage: None,
                };
            }
        };

        // Append to line buffer and process complete lines
        self.line_buffer.push_str(chunk_str);

        let mut result = ChunkResult {
            content: None,
            is_done: false,
            usage: None,
        };

        let mut content_parts = Vec::new();

        // Process complete lines
        while let Some(newline_pos) = self.line_buffer.find('\n') {
            let line = self.line_buffer[..newline_pos].trim();

            if !line.is_empty() {
                let line_result = self.process_sse_line(line);

                if let Some(content) = line_result.content {
                    content_parts.push(content);
                }
                if line_result.is_done {
                    result.is_done = true;
                    self.completed = true;
                }
                if let Some(usage) = line_result.usage {
                    result.usage = Some(usage.clone());
                    // Merge rather than overwrite: Anthropic splits usage across
                    // message_start (input) and message_delta (output).
                    let prev = self.api_usage.take().unwrap_or(ApiUsage {
                        input_tokens: 0,
                        output_tokens: 0,
                        total_tokens: 0,
                    });
                    let input_tokens = usage.input_tokens.max(prev.input_tokens);
                    let output_tokens = usage.output_tokens.max(prev.output_tokens);
                    self.api_usage = Some(ApiUsage {
                        input_tokens,
                        output_tokens,
                        total_tokens: usage.total_tokens.max(input_tokens + output_tokens),
                    });
                }
            }

            // Remove processed line from buffer
            self.line_buffer = self.line_buffer[newline_pos + 1..].to_string();
        }

        if !content_parts.is_empty() {
            let combined = content_parts.join("");
            self.content_buffer.push_str(&combined);
            result.content = Some(combined);
        }

        result
    }

    /// Process a single SSE line.
    fn process_sse_line(&self, line: &str) -> ChunkResult {
        // SSE lines look like "data: {...}", "event: ...", or ": comment";
        // only data fields carry payload.
        let Some(data) = line.strip_prefix("data:") else {
            // Skip event lines, comments, and other non-data fields.
            return ChunkResult {
                content: None,
                is_done: false,
                usage: None,
            };
        };

        let data = data.trim();

        // Check for stream completion marker
        if data == "[DONE]" {
            return ChunkResult {
                content: None,
                is_done: true,
                usage: None,
            };
        }

        // Parse JSON
        let json: Value = match serde_json::from_str(data) {
            Ok(v) => v,
            Err(_) => {
                trace!(data = data, "Failed to parse SSE data as JSON");
                return ChunkResult {
                    content: None,
                    is_done: false,
                    usage: None,
                };
            }
        };

        match self.provider {
            InferenceProvider::OpenAi => self.parse_openai_chunk(&json),
            InferenceProvider::Anthropic => self.parse_anthropic_chunk(&json),
            InferenceProvider::Generic => {
                // Try OpenAI format first, then Anthropic
                let result = self.parse_openai_chunk(&json);
                if result.content.is_some() || result.is_done || result.usage.is_some() {
                    result
                } else {
                    self.parse_anthropic_chunk(&json)
                }
            }
        }
    }

    /// Parse OpenAI streaming chunk format.
    ///
    /// Format: {"choices":[{"delta":{"content":"..."}}],"usage":{...}}
    fn parse_openai_chunk(&self, json: &Value) -> ChunkResult {
        let mut result = ChunkResult {
            content: None,
            is_done: false,
            usage: None,
        };

        // Extract content from choices[0].delta.content
        if let Some(choices) = json.get("choices").and_then(|c| c.as_array()) {
            if let Some(first_choice) = choices.first() {
                // Check for finish_reason indicating completion
                if let Some(finish_reason) = first_choice.get("finish_reason") {
                    if !finish_reason.is_null() {
                        result.is_done = true;
                    }
                }

                // Extract delta content
                if let Some(delta) = first_choice.get("delta") {
                    if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
                        result.content = Some(content.to_string());
                    }
                }
            }
        }

        // Extract usage if present. OpenAI sends usage in the final chunk when
        // the request opts in via `stream_options: {"include_usage": true}`.
        if let Some(usage) = json.get("usage") {
            let prompt_tokens = usage
                .get("prompt_tokens")
                .and_then(|t| t.as_u64())
                .unwrap_or(0);
            let completion_tokens = usage
                .get("completion_tokens")
                .and_then(|t| t.as_u64())
                .unwrap_or(0);
            let total_tokens = usage
                .get("total_tokens")
                .and_then(|t| t.as_u64())
                .unwrap_or(prompt_tokens + completion_tokens);

            if total_tokens > 0 {
                result.usage = Some(ApiUsage {
                    input_tokens: prompt_tokens,
                    output_tokens: completion_tokens,
                    total_tokens,
                });
            }
        }

        result
    }

    /// Parse Anthropic streaming chunk format.
    ///
    /// Format: {"type":"content_block_delta","delta":{"type":"text_delta","text":"..."}}
    fn parse_anthropic_chunk(&self, json: &Value) -> ChunkResult {
        let mut result = ChunkResult {
            content: None,
            is_done: false,
            usage: None,
        };

        let event_type = json.get("type").and_then(|t| t.as_str()).unwrap_or("");

        match event_type {
            "content_block_delta" => {
                // Extract text from delta
                if let Some(delta) = json.get("delta") {
                    if let Some(text) = delta.get("text").and_then(|t| t.as_str()) {
                        result.content = Some(text.to_string());
                    }
                }
            }
            "message_stop" => {
                result.is_done = true;
            }
            "message_delta" => {
                // Anthropic includes output usage in message_delta at the end
                if let Some(usage) = json.get("usage") {
                    let output_tokens = usage
                        .get("output_tokens")
                        .and_then(|t| t.as_u64())
                        .unwrap_or(0);

                    if output_tokens > 0 {
                        result.usage = Some(ApiUsage {
                            // Not in message_delta; merged with message_start
                            // usage by the caller.
                            input_tokens: 0,
                            output_tokens,
                            total_tokens: output_tokens,
                        });
                    }
                }
            }
            "message_start" => {
                // Anthropic includes input tokens in message_start
                if let Some(message) = json.get("message") {
                    if let Some(usage) = message.get("usage") {
                        let input_tokens = usage
                            .get("input_tokens")
                            .and_then(|t| t.as_u64())
                            .unwrap_or(0);

                        if input_tokens > 0 {
                            result.usage = Some(ApiUsage {
                                input_tokens,
                                output_tokens: 0,
                                total_tokens: input_tokens,
                            });
                        }
                    }
                }
            }
            _ => {}
        }

        result
    }

    /// Check if the stream has completed.
    pub fn is_completed(&self) -> bool {
        self.completed
    }

    /// Get the accumulated content so far.
    pub fn content(&self) -> &str {
        &self.content_buffer
    }

    /// Get the number of chunks processed.
    pub fn chunks_processed(&self) -> u32 {
        self.chunks_processed
    }

    /// Get the bytes processed.
    pub fn bytes_processed(&self) -> u64 {
        self.bytes_processed
    }

    /// Get API-provided usage if available.
    pub fn api_usage(&self) -> Option<&ApiUsage> {
        self.api_usage.as_ref()
    }

    /// Finalize and get the output token count.
    ///
    /// Uses API-provided usage if available, otherwise counts tokens
    /// in the accumulated content using tiktoken.
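    ///
    /// A minimal sketch of typical use after the stream ends:
    ///
    /// ```ignore
    /// let result = counter.finalize();
    /// match result.source {
    ///     TokenCountSource::ApiProvided => { /* exact counts from the API */ }
    ///     TokenCountSource::Tiktoken => { /* output_tokens estimated locally */ }
    /// }
    /// ```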
    pub fn finalize(&self) -> StreamingTokenResult {
        // Prefer API-provided usage
        if let Some(usage) = &self.api_usage {
            trace!(
                input_tokens = usage.input_tokens,
                output_tokens = usage.output_tokens,
                total_tokens = usage.total_tokens,
                chunks = self.chunks_processed,
                "Using API-provided token counts for streaming response"
            );

            return StreamingTokenResult {
                output_tokens: usage.output_tokens,
                input_tokens: Some(usage.input_tokens),
                total_tokens: Some(usage.total_tokens),
                source: TokenCountSource::ApiProvided,
                content_length: self.content_buffer.len(),
            };
        }

        // No API usage: count tokens in the accumulated content. The tokenizer
        // is only fetched on this path, so the API-provided case stays cheap.
        let manager = tiktoken_manager();
        let output_tokens = manager.count_tokens(self.model.as_deref(), &self.content_buffer);

        trace!(
            output_tokens = output_tokens,
            content_len = self.content_buffer.len(),
            chunks = self.chunks_processed,
            model = ?self.model,
            "Counted tokens in streaming response content"
        );

        StreamingTokenResult {
            output_tokens,
            input_tokens: None,
            total_tokens: None,
            source: TokenCountSource::Tiktoken,
            content_length: self.content_buffer.len(),
        }
    }
}

/// Source of token count.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenCountSource {
    /// Token count provided by the API in the stream
    ApiProvided,
    /// Token count calculated using tiktoken
    Tiktoken,
}

/// Result of streaming token counting.
#[derive(Debug)]
pub struct StreamingTokenResult {
    /// Output tokens (completion tokens)
    pub output_tokens: u64,
    /// Input tokens (prompt tokens) if known
    pub input_tokens: Option<u64>,
    /// Total tokens if known
    pub total_tokens: Option<u64>,
    /// Source of the token count
    pub source: TokenCountSource,
    /// Length of accumulated content in bytes
    pub content_length: usize,
}

/// Check if a response appears to be SSE based on content type.
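///
/// A quick sketch of the matching behavior (mirrored by the unit tests below):
///
/// ```ignore
/// assert!(is_sse_response(Some("text/event-stream; charset=utf-8")));
/// assert!(!is_sse_response(Some("application/json")));
/// ```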
pub fn is_sse_response(content_type: Option<&str>) -> bool {
    content_type.is_some_and(|ct| {
        ct.contains("text/event-stream") || ct.contains("application/x-ndjson")
    })
}

// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_openai_streaming() {
        let mut counter =
            StreamingTokenCounter::new(InferenceProvider::OpenAi, Some("gpt-4".to_string()));

        // Simulate OpenAI SSE chunks
        let chunk1 = b"data: {\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\n";
        let chunk2 = b"data: {\"choices\":[{\"delta\":{\"content\":\" world\"}}]}\n\n";
        let chunk3 = b"data: {\"choices\":[{\"delta\":{},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":10,\"completion_tokens\":2,\"total_tokens\":12}}\n\n";
        let chunk4 = b"data: [DONE]\n\n";

        let r1 = counter.process_chunk(chunk1);
        assert_eq!(r1.content, Some("Hello".to_string()));
        assert!(!r1.is_done);

        let r2 = counter.process_chunk(chunk2);
        assert_eq!(r2.content, Some(" world".to_string()));
        assert!(!r2.is_done);

        let r3 = counter.process_chunk(chunk3);
        assert!(r3.is_done);
        assert!(r3.usage.is_some());
        let usage = r3.usage.unwrap();
        assert_eq!(usage.input_tokens, 10);
        assert_eq!(usage.output_tokens, 2);
        assert_eq!(usage.total_tokens, 12);

        let r4 = counter.process_chunk(chunk4);
        assert!(r4.is_done);

        assert_eq!(counter.content(), "Hello world");
        assert!(counter.is_completed());

        let result = counter.finalize();
        assert_eq!(result.output_tokens, 2);
        assert_eq!(result.input_tokens, Some(10));
        assert_eq!(result.source, TokenCountSource::ApiProvided);
    }

    #[test]
    fn test_anthropic_streaming() {
        let mut counter =
            StreamingTokenCounter::new(InferenceProvider::Anthropic, Some("claude-3-opus".to_string()));

        // Simulate Anthropic SSE chunks
        let chunk1 = b"event: message_start\ndata: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":25}}}\n\n";
        let chunk2 = b"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"delta\":{\"type\":\"text_delta\",\"text\":\"Hello\"}}\n\n";
        let chunk3 = b"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"delta\":{\"type\":\"text_delta\",\"text\":\" there\"}}\n\n";
        let chunk4 = b"event: message_delta\ndata: {\"type\":\"message_delta\",\"usage\":{\"output_tokens\":3}}\n\n";
        let chunk5 = b"event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n";

        counter.process_chunk(chunk1);
        let r2 = counter.process_chunk(chunk2);
        assert_eq!(r2.content, Some("Hello".to_string()));

        let r3 = counter.process_chunk(chunk3);
        assert_eq!(r3.content, Some(" there".to_string()));

        let r4 = counter.process_chunk(chunk4);
        assert!(r4.usage.is_some());
        assert_eq!(r4.usage.unwrap().output_tokens, 3);

        let r5 = counter.process_chunk(chunk5);
        assert!(r5.is_done);

        assert_eq!(counter.content(), "Hello there");
        assert!(counter.is_completed());
    }
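
    // A focused check of the usage-merge behavior: Anthropic reports input
    // tokens in message_start and output tokens in message_delta, and
    // finalize() should surface both sides.
    #[test]
    fn test_anthropic_usage_merge() {
        let mut counter = StreamingTokenCounter::new(InferenceProvider::Anthropic, None);
        counter.process_chunk(b"data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":25}}}\n\n");
        counter.process_chunk(b"data: {\"type\":\"message_delta\",\"usage\":{\"output_tokens\":3}}\n\n");

        let result = counter.finalize();
        assert_eq!(result.source, TokenCountSource::ApiProvided);
        assert_eq!(result.input_tokens, Some(25));
        assert_eq!(result.output_tokens, 3);
        assert_eq!(result.total_tokens, Some(28));
    }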

    #[test]
    fn test_tiktoken_fallback() {
        let mut counter =
            StreamingTokenCounter::new(InferenceProvider::OpenAi, Some("gpt-4".to_string()));

        // Chunks without usage info
        let chunk1 = b"data: {\"choices\":[{\"delta\":{\"content\":\"Hello world\"}}]}\n\n";
        let chunk2 = b"data: [DONE]\n\n";

        counter.process_chunk(chunk1);
        counter.process_chunk(chunk2);

        let result = counter.finalize();
        assert_eq!(result.source, TokenCountSource::Tiktoken);
        // "Hello world" encodes to 2 tokens with cl100k_base; assert only
        // non-zero to stay robust if the encoding changes.
        assert!(result.output_tokens > 0);
    }

    #[test]
    fn test_split_chunks() {
        let mut counter =
            StreamingTokenCounter::new(InferenceProvider::OpenAi, Some("gpt-4".to_string()));

        // Data split across chunk boundaries
        let chunk1 = b"data: {\"choices\":[{\"delta\":{\"content\":\"He";
        let chunk2 = b"llo\"}}]}\n\ndata: {\"choices\":[{\"delta\":{\"content\":\" world\"}}]}\n\n";

        let r1 = counter.process_chunk(chunk1);
        assert!(r1.content.is_none()); // No complete line yet

        let r2 = counter.process_chunk(chunk2);
        // Should get both "Hello" and " world" as the lines complete
        assert!(r2.content.is_some());
        assert!(counter.content().contains("Hello"));
        assert!(counter.content().contains(" world"));
    }
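
    // A sketch of CRLF tolerance: each line is trimmed before parsing, so
    // "\r\n"-terminated SSE lines behave the same as plain "\n".
    #[test]
    fn test_crlf_line_endings() {
        let mut counter = StreamingTokenCounter::new(InferenceProvider::OpenAi, None);
        let result =
            counter.process_chunk(b"data: {\"choices\":[{\"delta\":{\"content\":\"Hi\"}}]}\r\n\r\n");
        assert_eq!(result.content, Some("Hi".to_string()));
    }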

    #[test]
    fn test_is_sse_response() {
        assert!(is_sse_response(Some("text/event-stream")));
        assert!(is_sse_response(Some("text/event-stream; charset=utf-8")));
        assert!(is_sse_response(Some("application/x-ndjson")));
        assert!(!is_sse_response(Some("application/json")));
        assert!(!is_sse_response(None));
    }

    #[test]
    fn test_generic_provider() {
        let mut counter = StreamingTokenCounter::new(InferenceProvider::Generic, None);

        // Should handle OpenAI format
        let chunk = b"data: {\"choices\":[{\"delta\":{\"content\":\"Test\"}}]}\n\n";
        let result = counter.process_chunk(chunk);
        assert_eq!(result.content, Some("Test".to_string()));
    }
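
    // A minimal sketch of the invalid-UTF-8 path: the chunk is counted but
    // skipped, and process_chunk returns an empty result instead of panicking.
    #[test]
    fn test_invalid_utf8_chunk() {
        let mut counter = StreamingTokenCounter::new(InferenceProvider::OpenAi, None);
        let result = counter.process_chunk(&[0xFF, 0xFE, 0xFD]);
        assert!(result.content.is_none());
        assert!(!result.is_done);
        assert_eq!(counter.chunks_processed(), 1);
    }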
}