use serde::{Deserialize, Serialize};
use std::time::{Duration, Instant};
/// HTTP response compression formats this module can negotiate from the
/// client's `Accept-Encoding` header.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)]
pub enum CompressionFormat {
    /// No compression; `header_value` renders this as an empty string.
    None,
    /// gzip compression (header token "gzip").
    Gzip,
    /// DEFLATE compression (header token "deflate").
    Deflate,
    /// Brotli compression (header token "br").
    Brotli,
}
impl CompressionFormat {
    /// Parses an `Accept-Encoding` header into the supported formats, in
    /// this module's fixed preference order (gzip, deflate, brotli).
    ///
    /// Entries are matched as whole comma-separated tokens rather than by
    /// substring search, and an entry the client explicitly refuses with a
    /// zero quality value (e.g. `gzip;q=0`) is skipped — the previous
    /// substring check would still offer such an encoding. Falls back to
    /// `[CompressionFormat::None]` when nothing supported is acceptable.
    pub fn from_accept_encoding(header: &str) -> Vec<Self> {
        let (mut gzip, mut deflate, mut brotli) = (false, false, false);
        for entry in header.split(',') {
            // An entry looks like "gzip" or "gzip;q=0.8": a coding token
            // followed by optional ';'-separated parameters.
            let mut parts = entry.split(';');
            let token = parts.next().unwrap_or("").trim();
            // A quality value of exactly zero means "not acceptable".
            let refused = parts.any(|param| {
                param
                    .trim()
                    .strip_prefix("q=")
                    .and_then(|q| q.trim().parse::<f32>().ok())
                    .map_or(false, |q| q == 0.0)
            });
            if refused {
                continue;
            }
            match token {
                // "x-gzip" is the historical alias for "gzip"; the old
                // substring check accepted it implicitly, so keep doing so.
                "gzip" | "x-gzip" => gzip = true,
                "deflate" => deflate = true,
                "br" => brotli = true,
                _ => {}
            }
        }
        // Emit in the same fixed priority order as before, deduplicated.
        let mut formats = Vec::new();
        if gzip {
            formats.push(CompressionFormat::Gzip);
        }
        if deflate {
            formats.push(CompressionFormat::Deflate);
        }
        if brotli {
            formats.push(CompressionFormat::Brotli);
        }
        if formats.is_empty() {
            formats.push(CompressionFormat::None);
        }
        formats
    }

    /// The encoding token for this format as it appears in HTTP headers;
    /// empty string for `None`.
    pub fn header_value(&self) -> &'static str {
        match self {
            CompressionFormat::None => "",
            CompressionFormat::Gzip => "gzip",
            CompressionFormat::Deflate => "deflate",
            CompressionFormat::Brotli => "br",
        }
    }
}
/// Event-type names used when emitting SSE frames for a token stream.
#[derive(Debug, Clone)]
pub struct SSEConfig {
    /// Event name for streamed token payloads.
    pub event_type: String,
    /// Event name signalling the end of a stream.
    pub completion_event_type: String,
    /// Event name for error notifications.
    pub error_event_type: String,
    /// Event name for periodic keep-alive frames.
    pub keepalive_event_type: String,
}

impl Default for SSEConfig {
    /// The conventional event names: "token", "complete", "error",
    /// and "heartbeat".
    fn default() -> Self {
        SSEConfig {
            event_type: String::from("token"),
            completion_event_type: String::from("complete"),
            error_event_type: String::from("error"),
            keepalive_event_type: String::from("heartbeat"),
        }
    }
}
/// A single Server-Sent Events message, built fluently and serialized
/// with [`SSEMessage::to_sse_format`].
#[derive(Debug, Clone)]
pub struct SSEMessage {
    /// Event name emitted in the `event:` field.
    pub event: String,
    /// Payload emitted in one or more `data:` fields.
    pub data: String,
    /// Optional event id (`id:` field) used by clients for resumption.
    pub id: Option<String>,
    /// Optional client reconnection delay in milliseconds (`retry:` field).
    pub retry: Option<u32>,
}

impl SSEMessage {
    /// Creates a message with the given event name and data payload.
    pub fn new(event: String, data: String) -> Self {
        Self {
            event,
            data,
            id: None,
            retry: None,
        }
    }

    /// Sets the event id (builder style).
    pub fn with_id(mut self, id: String) -> Self {
        self.id = Some(id);
        self
    }

    /// Sets the client reconnection delay in milliseconds (builder style).
    pub fn with_retry(mut self, retry_ms: u32) -> Self {
        self.retry = Some(retry_ms);
        self
    }

    /// Serializes the message to the SSE wire format, terminated by the
    /// blank line that ends an event.
    ///
    /// A multi-line payload is emitted as one `data:` field per line, as
    /// the SSE format requires; previously a raw newline inside `data`
    /// was written into a single `data:` line, corrupting the stream.
    pub fn to_sse_format(&self) -> String {
        use std::fmt::Write;
        let mut output = String::with_capacity(self.data.len() + 32);
        // Writing into a String cannot fail, so the results are ignored.
        let _ = writeln!(output, "event: {}", self.event);
        for line in self.data.split('\n') {
            let _ = writeln!(output, "data: {}", line);
        }
        if let Some(id) = &self.id {
            let _ = writeln!(output, "id: {}", id);
        }
        if let Some(retry) = self.retry {
            let _ = writeln!(output, "retry: {}", retry);
        }
        output.push('\n');
        output
    }
}
/// Accumulates streamed tokens so they can be sent in small batches,
/// trading a little latency for fewer network writes.
#[derive(Debug)]
pub struct TokenBatcher {
    /// Number of buffered tokens that triggers a size-based flush.
    batch_size: usize,
    /// Maximum time a non-empty buffer may wait before a time-based flush.
    max_wait: Duration,
    /// Tokens accumulated since the last flush.
    buffer: Vec<String>,
    /// Moment of the last flush (or of construction).
    last_flush: Instant,
}

impl TokenBatcher {
    /// Creates a batcher that flushes after `batch_size` tokens or after
    /// `max_wait_ms` milliseconds, whichever comes first.
    pub fn new(batch_size: usize, max_wait_ms: u64) -> Self {
        Self {
            batch_size,
            max_wait: Duration::from_millis(max_wait_ms),
            buffer: Vec::with_capacity(batch_size),
            last_flush: Instant::now(),
        }
    }

    /// Appends a token to the current batch.
    pub fn add_token(&mut self, token: String) {
        self.buffer.push(token);
    }

    /// Returns `true` when the batch should be emitted: the buffer has
    /// reached `batch_size`, or buffered tokens have waited longer than
    /// the configured maximum.
    ///
    /// An empty buffer never needs flushing; the previous version
    /// reported `true` for an empty buffer once the wait elapsed,
    /// prompting callers to emit spurious empty batches.
    pub fn should_flush(&self) -> bool {
        !self.buffer.is_empty()
            && (self.buffer.len() >= self.batch_size
                || self.last_flush.elapsed() > self.max_wait)
    }

    /// Concatenates and returns the buffered tokens, clearing the buffer
    /// and restarting the wait timer.
    pub fn flush(&mut self) -> String {
        let batched = self.buffer.concat();
        self.buffer.clear();
        self.last_flush = Instant::now();
        batched
    }

    /// Number of tokens currently buffered.
    pub fn len(&self) -> usize {
        self.buffer.len()
    }

    /// Whether the buffer currently holds no tokens.
    pub fn is_empty(&self) -> bool {
        self.buffer.is_empty()
    }
}
/// Tracks two streaming deadlines: an overall inference budget and a
/// maximum allowed gap between consecutive tokens.
#[derive(Debug, Clone)]
pub struct TimeoutManager {
    /// Total time allowed for the whole inference run.
    inference_timeout: Duration,
    /// Maximum allowed silence between two tokens.
    token_timeout: Duration,
    /// When the run started.
    start_time: Instant,
    /// When the most recent token was observed.
    last_token_time: Instant,
}

impl TimeoutManager {
    /// Creates a manager; both clocks start at the moment of construction.
    pub fn new(inference_timeout_secs: u64, token_timeout_secs: u64) -> Self {
        let started = Instant::now();
        TimeoutManager {
            inference_timeout: Duration::from_secs(inference_timeout_secs),
            token_timeout: Duration::from_secs(token_timeout_secs),
            start_time: started,
            last_token_time: started,
        }
    }

    /// True once the total run time has exceeded the inference budget.
    pub fn is_inference_timeout(&self) -> bool {
        Instant::now().duration_since(self.start_time) > self.inference_timeout
    }

    /// True once the silence since the last token exceeds the token budget.
    pub fn is_token_timeout(&self) -> bool {
        Instant::now().duration_since(self.last_token_time) > self.token_timeout
    }

    /// Marks a token as just received, restarting the token-gap clock.
    pub fn record_token(&mut self) {
        self.last_token_time = Instant::now();
    }

    /// Whole seconds elapsed since the run started.
    pub fn elapsed_secs(&self) -> u64 {
        self.start_time.elapsed().as_secs()
    }

    /// Milliseconds elapsed since the last recorded token.
    pub fn time_since_last_token_ms(&self) -> u64 {
        self.last_token_time.elapsed().as_millis() as u64
    }
}
/// Schedules periodic keep-alive events on an otherwise idle connection.
#[derive(Debug)]
pub struct KeepAlive {
    /// Minimum gap between two keep-alive sends.
    interval: Duration,
    /// When the last keep-alive was sent (or when the timer was reset).
    last_sent: Instant,
    /// How many keep-alives have been sent so far.
    count: u32,
}

impl KeepAlive {
    /// Creates a keep-alive timer that becomes due `interval_secs`
    /// seconds after construction.
    pub fn new(interval_secs: u64) -> Self {
        KeepAlive {
            interval: Duration::from_secs(interval_secs),
            last_sent: Instant::now(),
            count: 0,
        }
    }

    /// True once the configured interval has elapsed since the last send.
    pub fn should_send_keepalive(&self) -> bool {
        Instant::now().duration_since(self.last_sent) > self.interval
    }

    /// Records that a keep-alive was sent and returns the running total.
    pub fn send_keepalive(&mut self) -> u32 {
        self.count += 1;
        self.last_sent = Instant::now();
        self.count
    }

    /// Total number of keep-alives sent so far.
    pub fn count(&self) -> u32 {
        self.count
    }

    /// Restarts the interval timer without sending or counting anything.
    pub fn reset(&mut self) {
        self.last_sent = Instant::now();
    }
}
/// Aggregate configuration for the streaming optimizations in this module.
#[derive(Debug, Clone)]
pub struct StreamingOptimizationConfig {
    /// Compression format to apply to the response stream.
    pub compression: CompressionFormat,
    /// SSE event-type names.
    pub sse_config: SSEConfig,
    /// Token count that triggers a size-based batch flush.
    pub batch_size: usize,
    /// Maximum milliseconds a partial batch may wait before flushing.
    pub batch_max_wait_ms: u64,
    /// Total inference time budget, in seconds.
    pub inference_timeout_secs: u64,
    /// Maximum allowed seconds between consecutive tokens.
    pub token_timeout_secs: u64,
    /// Seconds between keep-alive events on an idle connection.
    pub keepalive_interval_secs: u64,
    // NOTE(review): presumably toggles TCP_NODELAY on the transport —
    // applied by server code not visible in this file; confirm.
    pub tcp_nodelay: bool,
}
impl Default for StreamingOptimizationConfig {
fn default() -> Self {
Self {
compression: CompressionFormat::None,
sse_config: SSEConfig::default(),
batch_size: 3,
batch_max_wait_ms: 50,
inference_timeout_secs: 300,
token_timeout_secs: 30,
keepalive_interval_secs: 30,
tcp_nodelay: true,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // A typical Accept-Encoding header yields all three supported codecs.
    #[test]
    fn test_compression_format_parsing() {
        let formats = CompressionFormat::from_accept_encoding("gzip, deflate, br");
        assert!(formats.contains(&CompressionFormat::Gzip));
        assert!(formats.contains(&CompressionFormat::Deflate));
        assert!(formats.contains(&CompressionFormat::Brotli));
    }

    // All four SSE fields appear in the serialized wire format.
    #[test]
    fn test_sse_message_formatting() {
        let msg = SSEMessage::new("token".to_string(), "Hello".to_string())
            .with_id("123".to_string())
            .with_retry(1000);
        let formatted = msg.to_sse_format();
        assert!(formatted.contains("event: token"));
        assert!(formatted.contains("data: Hello"));
        assert!(formatted.contains("id: 123"));
        assert!(formatted.contains("retry: 1000"));
    }

    // Size-based flush: triggers once batch_size tokens are buffered.
    #[test]
    fn test_token_batcher() {
        let mut batcher = TokenBatcher::new(3, 100);
        assert!(!batcher.should_flush());
        batcher.add_token("Hello".to_string());
        batcher.add_token(" ".to_string());
        assert!(!batcher.should_flush());
        batcher.add_token("World".to_string());
        assert!(batcher.should_flush());
        let batched = batcher.flush();
        assert_eq!(batched, "Hello World");
        assert!(batcher.is_empty());
    }

    // Time-based flush: triggers after max_wait_ms even under batch_size.
    #[test]
    fn test_token_batcher_timeout() {
        let mut batcher = TokenBatcher::new(100, 50);
        batcher.add_token("token1".to_string());
        assert!(!batcher.should_flush());
        std::thread::sleep(Duration::from_millis(100));
        assert!(batcher.should_flush());
    }

    // Token timeout (1 s) fires after a pause; the 5 s inference budget
    // does not.
    #[test]
    fn test_timeout_manager() {
        let tm = TimeoutManager::new(5, 1);
        assert!(!tm.is_inference_timeout());
        assert!(!tm.is_token_timeout());
        std::thread::sleep(Duration::from_millis(1100));
        assert!(tm.is_token_timeout());
    }

    // A keep-alive becomes due after the interval and is rescheduled
    // once sent.
    #[test]
    fn test_keepalive() {
        let mut ka = KeepAlive::new(1);
        assert!(!ka.should_send_keepalive());
        std::thread::sleep(Duration::from_millis(1100));
        assert!(ka.should_send_keepalive());
        ka.send_keepalive();
        assert_eq!(ka.count(), 1);
        assert!(!ka.should_send_keepalive());
    }

    // Default configuration matches the documented values.
    #[test]
    fn test_default_config() {
        let config = StreamingOptimizationConfig::default();
        assert_eq!(config.batch_size, 3);
        assert_eq!(config.inference_timeout_secs, 300);
        assert!(config.tcp_nodelay);
    }
}