1use super::m2m::M2MFrame;
29use super::token_native::TokenNativeCodec;
30use super::CompressionResult;
31use crate::codec::tables::{
32 KEY_ABBREV, KEY_EXPAND, MODEL_ABBREV, MODEL_EXPAND, ROLE_ABBREV, ROLE_EXPAND,
33};
34use crate::error::{M2MError, Result};
35use crate::models::Encoding;
36use bytes::Bytes;
37use serde_json::Value;
38
/// Selects how [`StreamingCodec`] re-encodes each SSE `data` payload.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum StreamingMode {
    /// Re-serialize each JSON payload with abbreviated keys (the default).
    #[default]
    Abbreviation,
    /// Compress each JSON payload with the token-native codec.
    TokenNative,
    /// Stream with key abbreviation; the accumulated content can later be
    /// compressed with the token-native codec.
    Hybrid,
    /// Re-serialize each JSON payload without further compression.
    Passthrough,
}
52
53#[derive(Debug, Clone, PartialEq)]
55pub enum SseEvent {
56 Data(Value),
58 Done,
60 Comment(String),
62 Error(String),
64}
65
66#[derive(Debug)]
71pub struct StreamingCodec {
72 accumulated_content: String,
74 chunks_processed: usize,
76 bytes_in: usize,
78 bytes_out: usize,
80 mode: StreamingMode,
82 token_native: TokenNativeCodec,
84}
85
86impl Default for StreamingCodec {
87 fn default() -> Self {
88 Self::new()
89 }
90}
91
92impl StreamingCodec {
93 pub fn new() -> Self {
95 Self {
96 accumulated_content: String::new(),
97 chunks_processed: 0,
98 bytes_in: 0,
99 bytes_out: 0,
100 mode: StreamingMode::Abbreviation,
101 token_native: TokenNativeCodec::default(),
102 }
103 }
104
105 pub fn with_mode(mode: StreamingMode) -> Self {
107 Self {
108 mode,
109 ..Self::new()
110 }
111 }
112
113 pub fn token_native(encoding: Encoding) -> Self {
115 Self {
116 mode: StreamingMode::TokenNative,
117 token_native: TokenNativeCodec::new(encoding),
118 ..Self::new()
119 }
120 }
121
122 pub fn hybrid(encoding: Encoding) -> Self {
124 Self {
125 mode: StreamingMode::Hybrid,
126 token_native: TokenNativeCodec::new(encoding),
127 ..Self::new()
128 }
129 }
130
131 pub fn passthrough() -> Self {
133 Self {
134 mode: StreamingMode::Passthrough,
135 ..Self::new()
136 }
137 }
138
139 pub fn mode(&self) -> StreamingMode {
141 self.mode
142 }
143
144 pub fn parse_sse_line(&self, line: &str) -> Option<SseEvent> {
146 let line = line.trim();
147
148 if line.is_empty() {
149 return None;
150 }
151
152 if line.starts_with(':') {
153 return Some(SseEvent::Comment(line[1..].trim().to_string()));
154 }
155
156 if let Some(data) = line.strip_prefix("data: ") {
157 if data == "[DONE]" {
158 return Some(SseEvent::Done);
159 }
160
161 match serde_json::from_str(data) {
162 Ok(json) => Some(SseEvent::Data(json)),
163 Err(_) => Some(SseEvent::Error(format!("Invalid JSON: {}", data))),
164 }
165 } else if let Some(error) = line.strip_prefix("error: ") {
166 Some(SseEvent::Error(error.to_string()))
167 } else {
168 None
169 }
170 }
171
172 pub fn process_chunk(&mut self, chunk: &[u8]) -> Result<Vec<Bytes>> {
174 let text = std::str::from_utf8(chunk)
175 .map_err(|e| M2MError::Compression(format!("Invalid UTF-8: {}", e)))?;
176
177 self.bytes_in += chunk.len();
178
179 let mut outputs = Vec::new();
180
181 for line in text.lines() {
182 if let Some(event) = self.parse_sse_line(line) {
183 let output = self.process_event(event)?;
184 if let Some(bytes) = output {
185 self.bytes_out += bytes.len();
186 outputs.push(bytes);
187 }
188 }
189 }
190
191 self.chunks_processed += 1;
192 Ok(outputs)
193 }
194
195 fn process_event(&mut self, event: SseEvent) -> Result<Option<Bytes>> {
197 match event {
198 SseEvent::Data(json) => {
199 if let Some(content) = self.extract_delta_content(&json) {
201 self.accumulated_content.push_str(&content);
202 }
203
204 match self.mode {
205 StreamingMode::Passthrough => {
206 Ok(Some(Bytes::from(format!(
208 "data: {}\n\n",
209 serde_json::to_string(&json).unwrap_or_default()
210 ))))
211 },
212 StreamingMode::Abbreviation | StreamingMode::Hybrid => {
213 let compressed = self.compress_sse_json(&json)?;
215 Ok(Some(Bytes::from(format!("data: {}\n\n", compressed))))
216 },
217 StreamingMode::TokenNative => {
218 let json_str = serde_json::to_string(&json)
220 .map_err(|e| M2MError::Compression(e.to_string()))?;
221 let result = self.token_native.compress(&json_str)?;
222 Ok(Some(Bytes::from(format!("data: {}\n\n", result.data))))
223 },
224 }
225 },
226 SseEvent::Done => Ok(Some(Bytes::from_static(b"data: [DONE]\n\n"))),
227 SseEvent::Comment(c) => Ok(Some(Bytes::from(format!(": {}\n", c)))),
228 SseEvent::Error(e) => Ok(Some(Bytes::from(format!("error: {}\n\n", e)))),
229 }
230 }
231
232 fn extract_delta_content(&self, json: &Value) -> Option<String> {
234 json.get("choices")
236 .or_else(|| json.get("C"))?
237 .get(0)?
238 .get("delta")
239 .or_else(|| json.get("D"))?
240 .get("content")
241 .or_else(|| json.get("c"))?
242 .as_str()
243 .map(String::from)
244 }
245
246 fn compress_sse_json(&self, json: &Value) -> Result<String> {
248 let compressed = self.abbreviate_keys(json);
249 serde_json::to_string(&compressed)
250 .map_err(|e| M2MError::Compression(format!("JSON serialization failed: {}", e)))
251 }
252
253 fn abbreviate_keys(&self, value: &Value) -> Value {
255 match value {
256 Value::Object(map) => {
257 let mut new_map = serde_json::Map::new();
258 for (key, val) in map {
259 let key_str = key.as_str();
260 let new_key = KEY_ABBREV.get(key_str).copied().unwrap_or(key_str);
261 let new_val = self.abbreviate_keys(val);
262
263 let new_val = if key == "role" {
265 if let Value::String(role) = &new_val {
266 if let Some(abbrev) = ROLE_ABBREV.get(role.as_str()) {
267 Value::String((*abbrev).to_string())
268 } else {
269 new_val
270 }
271 } else {
272 new_val
273 }
274 } else if key == "model" {
276 if let Value::String(model) = &new_val {
277 if let Some(abbrev) = MODEL_ABBREV.get(model.as_str()) {
278 Value::String((*abbrev).to_string())
279 } else {
280 new_val
281 }
282 } else {
283 new_val
284 }
285 } else {
286 new_val
287 };
288
289 new_map.insert(new_key.to_string(), new_val);
290 }
291 Value::Object(new_map)
292 },
293 Value::Array(arr) => {
294 Value::Array(arr.iter().map(|v| self.abbreviate_keys(v)).collect())
295 },
296 _ => value.clone(),
297 }
298 }
299
300 pub fn accumulated_content(&self) -> &str {
302 &self.accumulated_content
303 }
304
305 pub fn finalize_token_native(&self) -> Result<CompressionResult> {
312 if self.accumulated_content.is_empty() {
313 return Err(M2MError::Compression(
314 "No content accumulated to finalize".to_string(),
315 ));
316 }
317 self.token_native.compress(&self.accumulated_content)
318 }
319
320 pub fn finalize_raw(&self) -> Vec<u8> {
324 self.token_native.compress_raw(&self.accumulated_content)
325 }
326
327 pub fn finalize_m2m(&self, response_json: &str) -> Result<String> {
343 let frame = M2MFrame::new_response(response_json)?;
344 frame.encode_string()
345 }
346
347 pub fn finalize_m2m_binary(&self, response_json: &str) -> Result<Vec<u8>> {
352 let frame = M2MFrame::new_response(response_json)?;
353 frame.encode()
354 }
355
356 pub fn stats(&self) -> StreamingStats {
358 StreamingStats {
359 chunks_processed: self.chunks_processed,
360 bytes_in: self.bytes_in,
361 bytes_out: self.bytes_out,
362 compression_ratio: if self.bytes_out > 0 {
365 self.bytes_in as f64 / self.bytes_out as f64
366 } else {
367 1.0
368 },
369 accumulated_length: self.accumulated_content.len(),
370 }
371 }
372
373 pub fn reset(&mut self) {
375 self.accumulated_content.clear();
376 self.chunks_processed = 0;
377 self.bytes_in = 0;
378 self.bytes_out = 0;
379 }
380}
381
/// Snapshot of a [`StreamingCodec`]'s counters, as returned by its `stats` method.
#[derive(Debug, Clone)]
pub struct StreamingStats {
    /// Number of chunks fully processed.
    pub chunks_processed: usize,
    /// Total raw input bytes received.
    pub bytes_in: usize,
    /// Total bytes of SSE output produced.
    pub bytes_out: usize,
    /// `bytes_in / bytes_out` (`1.0` before any output was produced).
    pub compression_ratio: f64,
    /// Length in bytes of the accumulated delta content.
    pub accumulated_length: usize,
}
396
397#[derive(Debug)]
399pub struct StreamingDecompressor {
400 accumulated_content: String,
402 token_native: TokenNativeCodec,
404}
405
406impl Default for StreamingDecompressor {
407 fn default() -> Self {
408 Self::new()
409 }
410}
411
412impl StreamingDecompressor {
413 pub fn new() -> Self {
415 Self {
416 accumulated_content: String::new(),
417 token_native: TokenNativeCodec::default(),
418 }
419 }
420
421 pub fn with_encoding(encoding: Encoding) -> Self {
423 Self {
424 accumulated_content: String::new(),
425 token_native: TokenNativeCodec::new(encoding),
426 }
427 }
428
429 pub fn decompress_chunk(&mut self, chunk: &[u8]) -> Result<Bytes> {
431 let text = std::str::from_utf8(chunk)
432 .map_err(|e| M2MError::Decompression(format!("Invalid UTF-8: {}", e)))?;
433
434 let mut output = String::new();
435
436 for line in text.lines() {
437 if let Some(data) = line.strip_prefix("data: ") {
438 if data == "[DONE]" {
439 output.push_str("data: [DONE]\n\n");
440 } else if data.starts_with("#TK|") {
441 let decompressed = self.token_native.decompress(data)?;
443 if let Ok(json) = serde_json::from_str::<Value>(&decompressed) {
444 if let Some(content) = self.extract_delta_content(&json) {
446 self.accumulated_content.push_str(&content);
447 }
448 output.push_str(&format!("data: {}\n\n", decompressed));
449 } else {
450 output.push_str(&format!("data: {}\n\n", decompressed));
451 }
452 } else if let Ok(json) = serde_json::from_str::<Value>(data) {
453 let expanded = self.expand_keys(&json);
455
456 if let Some(content) = self.extract_delta_content(&expanded) {
458 self.accumulated_content.push_str(&content);
459 }
460
461 output.push_str(&format!(
462 "data: {}\n\n",
463 serde_json::to_string(&expanded).unwrap_or_default()
464 ));
465 } else {
466 output.push_str(line);
468 output.push_str("\n\n");
469 }
470 } else if !line.is_empty() {
471 output.push_str(line);
472 output.push('\n');
473 }
474 }
475
476 Ok(Bytes::from(output))
477 }
478
479 fn expand_keys(&self, value: &Value) -> Value {
481 match value {
482 Value::Object(map) => {
483 let mut new_map = serde_json::Map::new();
484 for (key, val) in map {
485 let key_str = key.as_str();
486 let new_key = KEY_EXPAND.get(key_str).copied().unwrap_or(key_str);
487 let new_val = self.expand_keys(val);
488
489 let new_val = if new_key == "role" {
491 if let Value::String(role) = &new_val {
492 if let Some(expanded) = ROLE_EXPAND.get(role.as_str()) {
493 Value::String((*expanded).to_string())
494 } else {
495 new_val
496 }
497 } else {
498 new_val
499 }
500 } else if new_key == "model" {
502 if let Value::String(model) = &new_val {
503 if let Some(expanded) = MODEL_EXPAND.get(model.as_str()) {
504 Value::String((*expanded).to_string())
505 } else {
506 new_val
507 }
508 } else {
509 new_val
510 }
511 } else {
512 new_val
513 };
514
515 new_map.insert(new_key.to_string(), new_val);
516 }
517 Value::Object(new_map)
518 },
519 Value::Array(arr) => Value::Array(arr.iter().map(|v| self.expand_keys(v)).collect()),
520 _ => value.clone(),
521 }
522 }
523
524 fn extract_delta_content(&self, json: &Value) -> Option<String> {
526 json.get("choices")
528 .or_else(|| json.get("C"))?
529 .get(0)?
530 .get("delta")
531 .or_else(|| json.get("D"))?
532 .get("content")
533 .or_else(|| json.get("c"))?
534 .as_str()
535 .map(String::from)
536 }
537
538 pub fn accumulated_content(&self) -> &str {
540 &self.accumulated_content
541 }
542}
543
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::Encoding;

    /// Three single-delta SSE lines whose contents concatenate to "Hello world!".
    const HELLO_WORLD_CHUNKS: [&[u8]; 3] = [
        br#"data: {"choices":[{"delta":{"content":"Hello"}}]}"#,
        br#"data: {"choices":[{"delta":{"content":" world"}}]}"#,
        br#"data: {"choices":[{"delta":{"content":"!"}}]}"#,
    ];

    /// Feeds every chunk to `codec`, panicking on the first failure.
    fn feed_all(codec: &mut StreamingCodec, chunks: &[&[u8]]) {
        for chunk in chunks {
            codec.process_chunk(chunk).unwrap();
        }
    }

    /// Views an output frame as UTF-8 text.
    fn as_text(bytes: &[u8]) -> &str {
        std::str::from_utf8(bytes).unwrap()
    }

    #[test]
    fn test_parse_sse_data() {
        let event = StreamingCodec::new()
            .parse_sse_line(r#"data: {"id":"123","choices":[{"delta":{"content":"Hi"}}]}"#);
        assert!(matches!(event, Some(SseEvent::Data(_))));
    }

    #[test]
    fn test_parse_sse_done() {
        assert_eq!(
            StreamingCodec::new().parse_sse_line("data: [DONE]"),
            Some(SseEvent::Done)
        );
    }

    #[test]
    fn test_parse_sse_comment() {
        let expected = SseEvent::Comment(String::from("keep-alive"));
        assert_eq!(
            StreamingCodec::new().parse_sse_line(": keep-alive"),
            Some(expected)
        );
    }

    #[test]
    fn test_compress_sse_chunk() {
        let mut codec = StreamingCodec::new();
        let chunk = concat!(
            r#"data: {"id":"chatcmpl-123","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"}}]}"#,
            "\n\n"
        );

        let outputs = codec.process_chunk(chunk.as_bytes()).unwrap();
        assert_eq!(outputs.len(), 1);

        let text = as_text(&outputs[0]);
        assert!(text.starts_with("data: "));
        // "choices" and "delta" must come out abbreviated.
        for key in ["\"C\":", "\"D\":"] {
            assert!(text.contains(key), "missing {} in {}", key, text);
        }
    }

    #[test]
    fn test_accumulate_content() {
        let mut codec = StreamingCodec::new();
        feed_all(&mut codec, &HELLO_WORLD_CHUNKS);
        assert_eq!(codec.accumulated_content(), "Hello world!");
    }

    #[test]
    fn test_streaming_stats() {
        let mut codec = StreamingCodec::new();
        codec
            .process_chunk(br#"data: {"id":"123","choices":[{"delta":{"content":"Test"}}]}"#)
            .unwrap();

        let StreamingStats {
            chunks_processed,
            bytes_in,
            bytes_out,
            compression_ratio,
            ..
        } = codec.stats();

        assert_eq!(chunks_processed, 1);
        assert!(bytes_in > 0);
        assert!(bytes_out > 0);
        assert!(
            compression_ratio > 1.0,
            "Expected compression ratio > 1.0, got {}",
            compression_ratio
        );
    }

    #[test]
    fn test_decompress_chunk() {
        let mut decompressor = StreamingDecompressor::new();
        let output = decompressor
            .decompress_chunk(br#"data: {"id":"123","C":[{"D":{"c":"Hello"}}]}"#)
            .unwrap();

        let text = as_text(&output);
        for key in ["\"id\":", "\"choices\":", "\"delta\":", "\"content\":"] {
            assert!(text.contains(key), "missing {} in {}", key, text);
        }
    }

    #[test]
    fn test_roundtrip() {
        let original = concat!(
            r#"data: {"id":"chatcmpl-123","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello world"}}]}"#,
            "\n\n"
        );

        let mut codec = StreamingCodec::new();
        let compressed = codec.process_chunk(original.as_bytes()).unwrap();
        assert!(!compressed.is_empty());

        let mut decompressor = StreamingDecompressor::new();
        let decompressed = decompressor.decompress_chunk(&compressed[0]).unwrap();

        let payload_of = |text: &str| -> Value {
            serde_json::from_str(text.strip_prefix("data: ").unwrap().trim()).unwrap()
        };
        let orig_json = payload_of(original);
        let decomp_json = payload_of(as_text(&decompressed));

        assert_eq!(
            orig_json["choices"][0]["delta"]["content"],
            decomp_json["choices"][0]["delta"]["content"]
        );
    }

    #[test]
    fn test_passthrough_mode() {
        let mut codec = StreamingCodec::passthrough();
        let outputs = codec
            .process_chunk(br#"data: {"id":"123","choices":[{"delta":{"content":"Test"}}]}"#)
            .unwrap();

        let text = as_text(&outputs[0]);
        assert!(text.contains("\"id\":"));
        assert!(text.contains("\"choices\":"));
    }

    #[test]
    fn test_token_native_mode() {
        let mut codec = StreamingCodec::token_native(Encoding::Cl100kBase);
        let outputs = codec
            .process_chunk(br#"data: {"id":"123","choices":[{"delta":{"content":"Hello"}}]}"#)
            .unwrap();

        let text = as_text(&outputs[0]);
        assert!(
            text.contains("#TK|C|"),
            "Expected TokenNative format, got: {}",
            text
        );
    }

    #[test]
    fn test_hybrid_mode_finalize() {
        let mut codec = StreamingCodec::hybrid(Encoding::Cl100kBase);
        feed_all(&mut codec, &HELLO_WORLD_CHUNKS);

        let result = codec.finalize_token_native().unwrap();
        assert!(result.data.starts_with("#TK|"));

        let restored = TokenNativeCodec::cl100k().decompress(&result.data).unwrap();
        assert_eq!(restored, "Hello world!");
    }

    #[test]
    fn test_streaming_mode_selection() {
        let cases = [
            (StreamingCodec::new(), StreamingMode::Abbreviation),
            (
                StreamingCodec::token_native(Encoding::Cl100kBase),
                StreamingMode::TokenNative,
            ),
            (
                StreamingCodec::hybrid(Encoding::O200kBase),
                StreamingMode::Hybrid,
            ),
            (StreamingCodec::passthrough(), StreamingMode::Passthrough),
        ];

        for (codec, expected) in cases {
            assert_eq!(codec.mode(), expected);
        }
    }

    #[test]
    fn test_decompress_token_native_chunk() {
        let mut codec = StreamingCodec::token_native(Encoding::Cl100kBase);
        let outputs = codec
            .process_chunk(br#"data: {"id":"123","choices":[{"delta":{"content":"Test"}}]}"#)
            .unwrap();

        let mut decompressor = StreamingDecompressor::new();
        let decompressed = decompressor.decompress_chunk(&outputs[0]).unwrap();

        let text = as_text(&decompressed);
        assert!(
            text.contains("\"choices\":"),
            "Expected expanded JSON, got: {}",
            text
        );
    }

    #[test]
    fn test_finalize_m2m() {
        use crate::codec::m2m::{M2MCodec, M2M_PREFIX};

        let mut codec = StreamingCodec::new();
        feed_all(&mut codec, &HELLO_WORLD_CHUNKS);
        assert_eq!(codec.accumulated_content(), "Hello world!");

        let response_json = r#"{"id":"chatcmpl-123","object":"chat.completion","model":"gpt-4o","choices":[{"index":0,"message":{"role":"assistant","content":"Hello world!"},"finish_reason":"stop"}],"usage":{"prompt_tokens":10,"completion_tokens":3,"total_tokens":13}}"#;

        let encoded = codec.finalize_m2m(response_json).unwrap();
        assert!(encoded.starts_with(M2M_PREFIX));

        let m2m_codec = M2MCodec::new();
        let decoded = m2m_codec.decode_string(&encoded).unwrap();
        assert_eq!(decoded, response_json);
    }

    #[test]
    fn test_finalize_m2m_binary() {
        use crate::codec::m2m::{M2MCodec, M2M_PREFIX};

        let codec = StreamingCodec::new();
        let response_json =
            r#"{"id":"chatcmpl-456","model":"gpt-4o","choices":[{"message":{"content":"Test"}}]}"#;

        let encoded = codec.finalize_m2m_binary(response_json).unwrap();
        assert!(encoded.starts_with(M2M_PREFIX.as_bytes()));

        let m2m_codec = M2MCodec::new();
        let decoded = m2m_codec.decode(&encoded).unwrap();
        assert_eq!(decoded, response_json);
    }
}