sentinel_proxy/inference/streaming.rs

use serde_json::Value;
use tracing::{trace, warn};

use super::tiktoken::tiktoken_manager;
use sentinel_config::InferenceProvider;
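
/// Counts tokens in a streaming (SSE) inference response incrementally,
/// buffering partial lines across chunk boundaries and capturing any
/// API-reported usage figures along the way.
///
/// A minimal usage sketch (illustrative, not from the original source;
/// `chunks` stands in for whatever byte stream the proxy forwards):
///
/// ```ignore
/// let mut counter =
///     StreamingTokenCounter::new(InferenceProvider::OpenAi, Some("gpt-4".to_string()));
/// for chunk in chunks {
///     let r = counter.process_chunk(&chunk);
///     if r.is_done {
///         break;
///     }
/// }
/// let result = counter.finalize();
/// println!("{} output tokens (source: {:?})", result.output_tokens, result.source);
/// ```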
#[derive(Debug)]
pub struct StreamingTokenCounter {
    /// Which provider's SSE format to parse.
    provider: InferenceProvider,
    /// Model name, used for tokenizer selection when falling back to tiktoken.
    model: Option<String>,
    /// All streamed content accumulated so far.
    content_buffer: String,
    /// Set once the stream signals completion.
    completed: bool,
    /// Number of raw chunks seen.
    chunks_processed: u32,
    /// Total bytes seen.
    bytes_processed: u64,
    /// Token usage reported by the API, if any.
    api_usage: Option<ApiUsage>,
    /// Holds a partial SSE line that spans chunk boundaries.
    line_buffer: String,
}

/// Token usage as reported by the inference API itself.
#[derive(Debug, Clone)]
pub struct ApiUsage {
    pub input_tokens: u64,
    pub output_tokens: u64,
    pub total_tokens: u64,
}

/// Result of processing a single chunk of the response stream.
#[derive(Debug)]
pub struct ChunkResult {
    /// Any content extracted from this chunk.
    pub content: Option<String>,
    /// Whether the stream has signalled completion.
    pub is_done: bool,
    /// Usage figures, if this chunk carried any.
    pub usage: Option<ApiUsage>,
}

impl StreamingTokenCounter {
    pub fn new(provider: InferenceProvider, model: Option<String>) -> Self {
        Self {
            provider,
            model,
            content_buffer: String::with_capacity(4096),
            completed: false,
            chunks_processed: 0,
            bytes_processed: 0,
            api_usage: None,
            line_buffer: String::new(),
        }
    }

    /// Processes one chunk of the streaming response body.
    ///
    /// Returns any content extracted from complete SSE lines in this chunk,
    /// whether the stream has finished, and any usage the chunk carried.
    /// Partial lines are buffered until their newline arrives.
    pub fn process_chunk(&mut self, chunk: &[u8]) -> ChunkResult {
        self.chunks_processed += 1;
        self.bytes_processed += chunk.len() as u64;

        // Note: a multi-byte UTF-8 character split across a chunk boundary
        // will fail this check, and that chunk is dropped with a warning.
        let chunk_str = match std::str::from_utf8(chunk) {
            Ok(s) => s,
            Err(_) => {
                warn!("Invalid UTF-8 in SSE chunk");
                return ChunkResult {
                    content: None,
                    is_done: false,
                    usage: None,
                };
            }
        };

        self.line_buffer.push_str(chunk_str);

        let mut result = ChunkResult {
            content: None,
            is_done: false,
            usage: None,
        };

        let mut content_parts = Vec::new();

        // Process every complete line in the buffer; anything after the
        // last newline stays buffered for the next chunk.
        while let Some(newline_pos) = self.line_buffer.find('\n') {
            let line = self.line_buffer[..newline_pos].trim();

            if !line.is_empty() {
                let line_result = self.process_sse_line(line);

                if let Some(content) = line_result.content {
                    content_parts.push(content);
                }
                if line_result.is_done {
                    result.is_done = true;
                    self.completed = true;
                }
                if let Some(usage) = line_result.usage {
                    result.usage = Some(usage.clone());
                    // Merge rather than overwrite: Anthropic reports input
                    // tokens in `message_start` and output tokens in a later
                    // `message_delta`, so the two must be combined.
                    self.api_usage = Some(match self.api_usage.take() {
                        Some(prev) => {
                            let input_tokens = prev.input_tokens.max(usage.input_tokens);
                            let output_tokens = prev.output_tokens.max(usage.output_tokens);
                            ApiUsage {
                                input_tokens,
                                output_tokens,
                                total_tokens: input_tokens + output_tokens,
                            }
                        }
                        None => usage,
                    });
                }
            }

            // Discard the consumed line (and its newline) in place instead
            // of reallocating the remainder of the buffer.
            self.line_buffer.drain(..=newline_pos);
        }

        if !content_parts.is_empty() {
            let combined = content_parts.join("");
            self.content_buffer.push_str(&combined);
            result.content = Some(combined);
        }

        result
    }

    /// Parses a single SSE line such as
    /// `data: {"choices":[{"delta":{"content":"Hi"}}]}` or `data: [DONE]`.
    fn process_sse_line(&self, line: &str) -> ChunkResult {
        // SSE data lines use the "data:" prefix; the space after the colon
        // is optional.
        let data = if let Some(rest) = line.strip_prefix("data:") {
            rest.trim()
        } else {
            return ChunkResult {
                content: None,
                is_done: false,
                usage: None,
            };
        };

        if data == "[DONE]" {
            return ChunkResult {
                content: None,
                is_done: true,
                usage: None,
            };
        }

        let json: Value = match serde_json::from_str(data) {
            Ok(v) => v,
            Err(_) => {
                trace!(data = data, "Failed to parse SSE data as JSON");
                return ChunkResult {
                    content: None,
                    is_done: false,
                    usage: None,
                };
            }
        };

        match self.provider {
            InferenceProvider::OpenAi => self.parse_openai_chunk(&json),
            InferenceProvider::Anthropic => self.parse_anthropic_chunk(&json),
            InferenceProvider::Generic => {
                // Try the OpenAI format first; fall back to the Anthropic
                // format if nothing was recognized.
                let result = self.parse_openai_chunk(&json);
                if result.content.is_some() || result.is_done || result.usage.is_some() {
                    result
                } else {
                    self.parse_anthropic_chunk(&json)
                }
            }
        }
    }

    /// Parses an OpenAI-style streaming chunk.
    fn parse_openai_chunk(&self, json: &Value) -> ChunkResult {
        let mut result = ChunkResult {
            content: None,
            is_done: false,
            usage: None,
        };

        if let Some(choices) = json.get("choices").and_then(|c| c.as_array()) {
            if let Some(first_choice) = choices.first() {
                // A non-null finish_reason marks the end of generation.
                if let Some(finish_reason) = first_choice.get("finish_reason") {
                    if !finish_reason.is_null() {
                        result.is_done = true;
                    }
                }

                if let Some(delta) = first_choice.get("delta") {
                    if let Some(content) = delta.get("content").and_then(|c| c.as_str()) {
                        result.content = Some(content.to_string());
                    }
                }
            }
        }

        // Usage appears on the final chunk when the client requested it
        // (e.g. via stream_options.include_usage).
        if let Some(usage) = json.get("usage") {
            let prompt_tokens = usage
                .get("prompt_tokens")
                .and_then(|t| t.as_u64())
                .unwrap_or(0);
            let completion_tokens = usage
                .get("completion_tokens")
                .and_then(|t| t.as_u64())
                .unwrap_or(0);
            let total_tokens = usage
                .get("total_tokens")
                .and_then(|t| t.as_u64())
                .unwrap_or(prompt_tokens + completion_tokens);

            if total_tokens > 0 {
                result.usage = Some(ApiUsage {
                    input_tokens: prompt_tokens,
                    output_tokens: completion_tokens,
                    total_tokens,
                });
            }
        }

        result
    }

    /// Parses an Anthropic-style streaming event.
    fn parse_anthropic_chunk(&self, json: &Value) -> ChunkResult {
        let mut result = ChunkResult {
            content: None,
            is_done: false,
            usage: None,
        };

        let event_type = json.get("type").and_then(|t| t.as_str()).unwrap_or("");

        match event_type {
            "content_block_delta" => {
                if let Some(delta) = json.get("delta") {
                    if let Some(text) = delta.get("text").and_then(|t| t.as_str()) {
                        result.content = Some(text.to_string());
                    }
                }
            }
            "message_stop" => {
                result.is_done = true;
            }
            // `message_delta` carries the output token count.
            "message_delta" => {
                if let Some(usage) = json.get("usage") {
                    let output_tokens = usage
                        .get("output_tokens")
                        .and_then(|t| t.as_u64())
                        .unwrap_or(0);

                    if output_tokens > 0 {
                        result.usage = Some(ApiUsage {
                            input_tokens: 0,
                            output_tokens,
                            total_tokens: output_tokens,
                        });
                    }
                }
            }
            // `message_start` carries the input token count.
            "message_start" => {
                if let Some(message) = json.get("message") {
                    if let Some(usage) = message.get("usage") {
                        let input_tokens = usage
                            .get("input_tokens")
                            .and_then(|t| t.as_u64())
                            .unwrap_or(0);

                        if input_tokens > 0 {
                            result.usage = Some(ApiUsage {
                                input_tokens,
                                output_tokens: 0,
                                total_tokens: input_tokens,
                            });
                        }
                    }
                }
            }
            _ => {}
        }

        result
    }

    /// Whether the stream has signalled completion.
    pub fn is_completed(&self) -> bool {
        self.completed
    }

    /// The accumulated response content so far.
    pub fn content(&self) -> &str {
        &self.content_buffer
    }

    /// Number of chunks processed so far.
    pub fn chunks_processed(&self) -> u32 {
        self.chunks_processed
    }

    /// Number of bytes processed so far.
    pub fn bytes_processed(&self) -> u64 {
        self.bytes_processed
    }

    /// API-reported usage, if any has been seen.
    pub fn api_usage(&self) -> Option<&ApiUsage> {
        self.api_usage.as_ref()
    }

    /// Produces the final token count for the stream.
    ///
    /// Prefers usage figures reported by the API itself; otherwise falls
    /// back to counting the accumulated content with tiktoken.
    pub fn finalize(&self) -> StreamingTokenResult {
        let manager = tiktoken_manager();

        if let Some(usage) = &self.api_usage {
            trace!(
                input_tokens = usage.input_tokens,
                output_tokens = usage.output_tokens,
                total_tokens = usage.total_tokens,
                chunks = self.chunks_processed,
                "Using API-provided token counts for streaming response"
            );

            return StreamingTokenResult {
                output_tokens: usage.output_tokens,
                input_tokens: Some(usage.input_tokens),
                total_tokens: Some(usage.total_tokens),
                source: TokenCountSource::ApiProvided,
                content_length: self.content_buffer.len(),
            };
        }

        let output_tokens = manager.count_tokens(self.model.as_deref(), &self.content_buffer);

        trace!(
            output_tokens = output_tokens,
            content_len = self.content_buffer.len(),
            chunks = self.chunks_processed,
            model = ?self.model,
            "Counted tokens in streaming response content"
        );

        StreamingTokenResult {
            output_tokens,
            input_tokens: None,
            total_tokens: None,
            source: TokenCountSource::Tiktoken,
            content_length: self.content_buffer.len(),
        }
    }
}

/// Where a final token count came from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenCountSource {
    /// The API reported usage figures in the stream itself.
    ApiProvided,
    /// Counted locally with tiktoken as a fallback.
    Tiktoken,
}

/// Final token accounting for a completed stream.
#[derive(Debug)]
pub struct StreamingTokenResult {
    pub output_tokens: u64,
    pub input_tokens: Option<u64>,
    pub total_tokens: Option<u64>,
    pub source: TokenCountSource,
    pub content_length: usize,
}

/// Returns true if the Content-Type indicates a server-sent-event or
/// newline-delimited JSON stream.
pub fn is_sse_response(content_type: Option<&str>) -> bool {
    content_type.is_some_and(|ct| {
        ct.contains("text/event-stream") || ct.contains("application/x-ndjson")
    })
}
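
// Illustrative call site for `is_sse_response` (hypothetical, not part of
// this module) using an `http`-style header map:
//
//     let ct = headers.get("content-type").and_then(|v| v.to_str().ok());
//     if is_sse_response(ct) {
//         // route the body through a StreamingTokenCounter
//     }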

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_openai_streaming() {
        let mut counter =
            StreamingTokenCounter::new(InferenceProvider::OpenAi, Some("gpt-4".to_string()));

        let chunk1 = b"data: {\"choices\":[{\"delta\":{\"content\":\"Hello\"}}]}\n\n";
        let chunk2 = b"data: {\"choices\":[{\"delta\":{\"content\":\" world\"}}]}\n\n";
        let chunk3 = b"data: {\"choices\":[{\"delta\":{},\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":10,\"completion_tokens\":2,\"total_tokens\":12}}\n\n";
        let chunk4 = b"data: [DONE]\n\n";

        let r1 = counter.process_chunk(chunk1);
        assert_eq!(r1.content, Some("Hello".to_string()));
        assert!(!r1.is_done);

        let r2 = counter.process_chunk(chunk2);
        assert_eq!(r2.content, Some(" world".to_string()));
        assert!(!r2.is_done);

        let r3 = counter.process_chunk(chunk3);
        assert!(r3.is_done);
        assert!(r3.usage.is_some());
        let usage = r3.usage.unwrap();
        assert_eq!(usage.input_tokens, 10);
        assert_eq!(usage.output_tokens, 2);
        assert_eq!(usage.total_tokens, 12);

        let r4 = counter.process_chunk(chunk4);
        assert!(r4.is_done);

        assert_eq!(counter.content(), "Hello world");
        assert!(counter.is_completed());

        let result = counter.finalize();
        assert_eq!(result.output_tokens, 2);
        assert_eq!(result.input_tokens, Some(10));
        assert_eq!(result.source, TokenCountSource::ApiProvided);
    }

    #[test]
    fn test_anthropic_streaming() {
        let mut counter = StreamingTokenCounter::new(
            InferenceProvider::Anthropic,
            Some("claude-3-opus".to_string()),
        );

        let chunk1 = b"event: message_start\ndata: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":25}}}\n\n";
        let chunk2 = b"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"delta\":{\"type\":\"text_delta\",\"text\":\"Hello\"}}\n\n";
        let chunk3 = b"event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"delta\":{\"type\":\"text_delta\",\"text\":\" there\"}}\n\n";
        let chunk4 = b"event: message_delta\ndata: {\"type\":\"message_delta\",\"usage\":{\"output_tokens\":3}}\n\n";
        let chunk5 = b"event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n";

        counter.process_chunk(chunk1);
        let r2 = counter.process_chunk(chunk2);
        assert_eq!(r2.content, Some("Hello".to_string()));

        let r3 = counter.process_chunk(chunk3);
        assert_eq!(r3.content, Some(" there".to_string()));

        let r4 = counter.process_chunk(chunk4);
        assert!(r4.usage.is_some());
        assert_eq!(r4.usage.unwrap().output_tokens, 3);

        let r5 = counter.process_chunk(chunk5);
        assert!(r5.is_done);

        assert_eq!(counter.content(), "Hello there");
        assert!(counter.is_completed());
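
        // Added check (relies on the merged usage bookkeeping in
        // process_chunk): input tokens from message_start and output tokens
        // from message_delta should both survive into the final result.
        let result = counter.finalize();
        assert_eq!(result.source, TokenCountSource::ApiProvided);
        assert_eq!(result.input_tokens, Some(25));
        assert_eq!(result.output_tokens, 3);
        assert_eq!(result.total_tokens, Some(28));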
    }

    #[test]
    fn test_tiktoken_fallback() {
        let mut counter =
            StreamingTokenCounter::new(InferenceProvider::OpenAi, Some("gpt-4".to_string()));

        let chunk1 = b"data: {\"choices\":[{\"delta\":{\"content\":\"Hello world\"}}]}\n\n";
        let chunk2 = b"data: [DONE]\n\n";

        counter.process_chunk(chunk1);
        counter.process_chunk(chunk2);

        let result = counter.finalize();
        assert_eq!(result.source, TokenCountSource::Tiktoken);
        assert!(result.output_tokens > 0);
    }

    #[test]
    fn test_split_chunks() {
        let mut counter =
            StreamingTokenCounter::new(InferenceProvider::OpenAi, Some("gpt-4".to_string()));

        // The first chunk ends mid-line, so no content should be emitted
        // until the rest of the line arrives in the second chunk.
        let chunk1 = b"data: {\"choices\":[{\"delta\":{\"content\":\"He";
        let chunk2 = b"llo\"}}]}\n\ndata: {\"choices\":[{\"delta\":{\"content\":\" world\"}}]}\n\n";

        let r1 = counter.process_chunk(chunk1);
        assert!(r1.content.is_none());

        let r2 = counter.process_chunk(chunk2);
        assert!(r2.content.is_some());
        assert!(counter.content().contains("Hello"));
        assert!(counter.content().contains(" world"));
    }

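    // Added illustrative test (not in the original source): an invalid
    // UTF-8 chunk is skipped with a warning rather than panicking, and the
    // stream keeps working afterwards.
    #[test]
    fn test_invalid_utf8_chunk() {
        let mut counter = StreamingTokenCounter::new(InferenceProvider::OpenAi, None);

        let r1 = counter.process_chunk(&[0xff, 0xfe, 0xfd]);
        assert!(r1.content.is_none());
        assert!(!r1.is_done);

        // A well-formed chunk afterwards is still processed normally.
        let r2 =
            counter.process_chunk(b"data: {\"choices\":[{\"delta\":{\"content\":\"ok\"}}]}\n\n");
        assert_eq!(r2.content, Some("ok".to_string()));
    }
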
    #[test]
    fn test_is_sse_response() {
        assert!(is_sse_response(Some("text/event-stream")));
        assert!(is_sse_response(Some("text/event-stream; charset=utf-8")));
        assert!(is_sse_response(Some("application/x-ndjson")));
        assert!(!is_sse_response(Some("application/json")));
        assert!(!is_sse_response(None));
    }

    #[test]
    fn test_generic_provider() {
        let mut counter = StreamingTokenCounter::new(InferenceProvider::Generic, None);

        let chunk = b"data: {\"choices\":[{\"delta\":{\"content\":\"Test\"}}]}\n\n";
        let result = counter.process_chunk(chunk);
        assert_eq!(result.content, Some("Test".to_string()));
    }
}