use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
use std::time::{Duration, Instant};
/// Coarse pressure level reported by `StreamFlowControl::check_backpressure`.
///
/// Variants are declared from least to most severe, so the derived ordering
/// lets callers write comparisons such as `level >= BackpressureLevel::Moderate`.
#[derive(
    Debug, Clone, Copy, Eq, PartialEq, Hash, PartialOrd, Ord, Serialize, Deserialize,
)]
pub enum BackpressureLevel {
    /// Pending messages are below the moderate threshold.
    Healthy,
    /// Pending messages are at or above the moderate threshold but below the critical one.
    Moderate,
    /// Pending messages are at or above the critical threshold.
    Critical,
}
/// Tunable limits and timeouts consumed by `StreamFlowControl`.
///
/// NOTE(review): nothing enforces `moderate_threshold_percent <=
/// critical_threshold_percent`; `check_backpressure` tests the critical
/// threshold first, so if moderate exceeds critical, `Moderate` is never reported.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FlowControlConfig {
    /// Maximum number of outstanding messages per stream: `add_message` is
    /// rejected once this many are pending. Also the base that both
    /// percentage thresholds below are applied to.
    pub max_pending_messages: u32,
    /// Percentage of `max_pending_messages` at or above which
    /// `check_backpressure` reports `Moderate`.
    pub moderate_threshold_percent: u32,
    /// Percentage of `max_pending_messages` at or above which
    /// `check_backpressure` reports `Critical`.
    pub critical_threshold_percent: u32,
    /// Upper bound on unacknowledged tokens: `add_tokens` fails when the
    /// running total would exceed it.
    pub max_unacked_tokens: u32,
    /// Seconds since the last `ack_tokens` call (or since creation, if none)
    /// after which `is_ack_timeout` returns `true`.
    pub ack_timeout_secs: u64,
    /// Seconds since the stream was created after which
    /// `is_inference_timeout` returns `true`.
    pub inference_timeout_secs: u64,
    /// Keepalive interval in seconds.
    /// NOTE(review): not read anywhere in this file; confirm where it is consumed.
    pub keepalive_interval_secs: u64,
}
impl Default for FlowControlConfig {
fn default() -> Self {
Self {
max_pending_messages: 1000,
moderate_threshold_percent: 70,
critical_threshold_percent: 90,
max_unacked_tokens: 10_000,
ack_timeout_secs: 30,
inference_timeout_secs: 300,
keepalive_interval_secs: 30,
}
}
}
/// Per-stream flow-control state.
///
/// Tracks a bounded count of pending messages, a bounded count of
/// unacknowledged tokens, the time of the last acknowledgement and the stream
/// start time. Cloning is cheap: clones share the two counters and the last-ack
/// timestamp (they live behind `Arc`), while `stream_start_time` is copied.
#[derive(Debug, Clone)]
pub struct StreamFlowControl {
    config: FlowControlConfig,
    /// Messages reserved via `add_message` and not yet released by `message_sent`.
    pending_messages: Arc<AtomicU32>,
    /// Tokens reserved via `add_tokens` and not yet released by `ack_tokens`.
    unacked_tokens: Arc<AtomicU32>,
    /// Time of the most recent `ack_tokens` call (initially the creation time).
    last_ack_timestamp: Arc<std::sync::Mutex<Instant>>,
    /// When this flow-control state was created.
    stream_start_time: Instant,
}

impl StreamFlowControl {
    /// Creates flow-control state with both counters at zero and the
    /// stream-start and last-ack timestamps set to the same "now".
    pub fn new(config: FlowControlConfig) -> Self {
        let now = Instant::now();
        Self {
            config,
            pending_messages: Arc::new(AtomicU32::new(0)),
            unacked_tokens: Arc::new(AtomicU32::new(0)),
            last_ack_timestamp: Arc::new(std::sync::Mutex::new(now)),
            stream_start_time: now,
        }
    }

    /// Returns `percent`% of `max_pending_messages`, truncated toward zero.
    ///
    /// Uses `u64` integer arithmetic so the result is exact for every `u32`
    /// input (an `f32` carries only 24 bits of mantissa and would round large
    /// limits). Results above `u32::MAX` (percent > 100) saturate.
    fn pending_threshold(&self, percent: u32) -> u32 {
        let scaled = u64::from(self.config.max_pending_messages) * u64::from(percent) / 100;
        scaled.min(u64::from(u32::MAX)) as u32
    }

    /// Classifies the current pending-message count against the configured
    /// moderate and critical thresholds (both inclusive).
    pub fn check_backpressure(&self) -> BackpressureLevel {
        let pending = self.pending_messages.load(Ordering::Relaxed);
        if pending >= self.pending_threshold(self.config.critical_threshold_percent) {
            BackpressureLevel::Critical
        } else if pending >= self.pending_threshold(self.config.moderate_threshold_percent) {
            BackpressureLevel::Moderate
        } else {
            BackpressureLevel::Healthy
        }
    }

    /// Reserves one pending-message slot.
    ///
    /// # Errors
    /// Returns an error, leaving the counter untouched, when
    /// `max_pending_messages` slots are already in use.
    pub fn add_message(&self) -> Result<(), String> {
        let max = self.config.max_pending_messages;
        // A CAS loop never lets the counter exceed `max`, not even transiently,
        // so concurrent callers cannot be rejected because of another caller's
        // doomed attempt.
        self.pending_messages
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |pending| {
                pending.checked_add(1).filter(|&next| next <= max)
            })
            .map(|_| ())
            .map_err(|_| "Buffer full, backpressure triggered".to_string())
    }

    /// Releases one pending-message slot. Does nothing when none are pending,
    /// so an unbalanced call can no longer wrap the counter to `u32::MAX`.
    pub fn message_sent(&self) {
        // `Err` only means the counter was already zero; nothing to undo then.
        let _ = self
            .pending_messages
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |pending| {
                pending.checked_sub(1)
            });
    }

    /// Reserves `count` unacknowledged tokens.
    ///
    /// # Errors
    /// Returns an error, leaving the counter untouched, when the new total
    /// would exceed `max_unacked_tokens` or overflow `u32`.
    pub fn add_tokens(&self, count: u32) -> Result<(), String> {
        let limit = self.config.max_unacked_tokens;
        self.unacked_tokens
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |unacked| {
                unacked.checked_add(count).filter(|&total| total <= limit)
            })
            .map(|_| ())
            .map_err(|_| "Token limit exceeded, backpressure triggered".to_string())
    }

    /// Acknowledges up to `count` tokens (the counter saturates at zero) and
    /// refreshes the last-ack timestamp, even when `count` is zero.
    pub fn ack_tokens(&self, count: u32) {
        // Read-modify-write in one atomic step: a separate load followed by
        // `fetch_sub` let concurrent acks underflow and wrap the counter.
        // The closure always returns `Some`, so this cannot fail.
        let _ = self
            .unacked_tokens
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |unacked| {
                Some(unacked.saturating_sub(count))
            });
        *self.lock_last_ack() = Instant::now();
    }

    /// Locks the last-ack timestamp, recovering from poisoning: the guarded
    /// value is a plain `Instant` that is only ever overwritten whole, so a
    /// panic in another holder cannot leave it inconsistent.
    fn lock_last_ack(&self) -> std::sync::MutexGuard<'_, Instant> {
        self.last_ack_timestamp
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Returns `true` when more than `ack_timeout_secs` have elapsed since the
    /// last `ack_tokens` call (or since creation if there has been none).
    pub fn is_ack_timeout(&self) -> bool {
        // Copy the `Instant` out so the lock is released immediately.
        let last_ack = *self.lock_last_ack();
        last_ack.elapsed() > Duration::from_secs(self.config.ack_timeout_secs)
    }

    /// Returns `true` when more than `inference_timeout_secs` have elapsed
    /// since this state was created.
    pub fn is_inference_timeout(&self) -> bool {
        self.stream_start_time.elapsed() > Duration::from_secs(self.config.inference_timeout_secs)
    }

    /// Pending messages as a percentage of `max_pending_messages`, rounded to
    /// the nearest integer. Any non-empty buffer reports at least 1, so it is
    /// never shown as 0 % while messages are pending.
    pub fn buffer_utilization_percent(&self) -> u32 {
        let pending = self.pending_messages.load(Ordering::Relaxed);
        if pending == 0 {
            return 0;
        }
        let percent =
            ((pending as f32 / self.config.max_pending_messages as f32) * 100.0).round() as u32;
        percent.max(1)
    }

    /// Current number of unacknowledged tokens (a relaxed snapshot).
    pub fn unacked_token_count(&self) -> u32 {
        self.unacked_tokens.load(Ordering::Relaxed)
    }

    /// Whole seconds elapsed since this state was created.
    pub fn elapsed_secs(&self) -> u64 {
        self.stream_start_time.elapsed().as_secs()
    }
}
/// Shared accounting for connections and pending messages.
///
/// The number of active connections is capped at `max_connections`. Each
/// successful [`acquire_connection`](Self::acquire_connection) hands out a
/// [`ConnectionGuard`] whose drop frees the slot. Clones share both counters
/// (they live behind `Arc`).
#[derive(Debug, Clone)]
pub struct ConnectionPool {
    /// Upper bound on simultaneously held guards.
    max_connections: u32,
    /// Number of guards currently handed out by `acquire_connection`.
    active_connections: Arc<AtomicU32>,
    /// Counter maintained via `add_pending` / `remove_pending`; independent of
    /// the connection count.
    global_pending_messages: Arc<AtomicU32>,
}

impl ConnectionPool {
    /// Creates an empty pool allowing at most `max_connections` guards at once.
    pub fn new(max_connections: u32) -> Self {
        Self {
            max_connections,
            active_connections: Arc::new(AtomicU32::new(0)),
            global_pending_messages: Arc::new(AtomicU32::new(0)),
        }
    }

    /// Reserves one connection slot and returns a guard that releases it on drop.
    ///
    /// # Errors
    /// Returns an error, leaving the counter untouched, when all
    /// `max_connections` slots are in use.
    pub fn acquire_connection(&self) -> Result<ConnectionGuard, String> {
        let max = self.max_connections;
        // A CAS loop never lets the counter exceed `max`, not even transiently,
        // so a racing caller cannot be refused because of another's failed attempt.
        self.active_connections
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |active| {
                active.checked_add(1).filter(|&next| next <= max)
            })
            .map_err(|_| "Connection pool full".to_string())?;
        Ok(ConnectionGuard {
            active_connections: Arc::clone(&self.active_connections),
        })
    }

    /// Current number of held guards (a relaxed snapshot).
    pub fn active_count(&self) -> u32 {
        self.active_connections.load(Ordering::Relaxed)
    }

    /// Active connections as a percentage of capacity, rounded half up.
    /// A zero-capacity pool reports 0.
    pub fn utilization_percent(&self) -> u32 {
        let max = u64::from(self.max_connections);
        if max == 0 {
            return 0;
        }
        let active = u64::from(self.active_connections.load(Ordering::Relaxed));
        // floor(100 * active / max + 0.5), computed exactly in integers.
        let percent = (200 * active + max) / (2 * max);
        percent.min(u64::from(u32::MAX)) as u32
    }

    /// Current value of the global pending-message counter.
    pub fn pending_messages(&self) -> u32 {
        self.global_pending_messages.load(Ordering::Relaxed)
    }

    /// Increments the global pending-message counter (saturating at `u32::MAX`).
    pub fn add_pending(&self) {
        let _ = self
            .global_pending_messages
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |n| n.checked_add(1));
    }

    /// Decrements the global pending-message counter; does nothing at zero, so
    /// an unbalanced call can no longer wrap it to `u32::MAX`.
    pub fn remove_pending(&self) {
        let _ = self
            .global_pending_messages
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |n| n.checked_sub(1));
    }
}
/// RAII handle for one slot reserved by `ConnectionPool::acquire_connection`.
///
/// Dropping the guard releases the slot by decrementing the pool's shared
/// active-connection counter.
#[derive(Debug)]
pub struct ConnectionGuard {
    active_connections: Arc<AtomicU32>,
}

impl Drop for ConnectionGuard {
    /// Returns the reserved slot to the pool.
    fn drop(&mut self) {
        // `acquire_connection` builds a guard only after its increment has
        // succeeded, so this decrement is balanced.
        self.active_connections.fetch_sub(1, Ordering::SeqCst);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_flow_control_creation() {
        let config = FlowControlConfig::default();
        let fc = StreamFlowControl::new(config);
        assert_eq!(fc.check_backpressure(), BackpressureLevel::Healthy);
        assert_eq!(fc.buffer_utilization_percent(), 0);
    }

    #[test]
    fn test_backpressure_levels() {
        let config = FlowControlConfig {
            max_pending_messages: 100,
            moderate_threshold_percent: 70,
            critical_threshold_percent: 90,
            ..Default::default()
        };
        let fc = StreamFlowControl::new(config);
        // 90 messages in total stay below the 100-message cap, so every add must succeed.
        for _ in 0..60 {
            assert!(fc.add_message().is_ok());
        }
        assert_eq!(fc.check_backpressure(), BackpressureLevel::Healthy);
        for _ in 0..15 {
            assert!(fc.add_message().is_ok());
        }
        assert_eq!(fc.check_backpressure(), BackpressureLevel::Moderate);
        for _ in 0..15 {
            assert!(fc.add_message().is_ok());
        }
        assert_eq!(fc.check_backpressure(), BackpressureLevel::Critical);
    }

    #[test]
    fn test_buffer_overflow() {
        let config = FlowControlConfig {
            max_pending_messages: 10,
            ..Default::default()
        };
        let fc = StreamFlowControl::new(config);
        for _ in 0..10 {
            assert!(fc.add_message().is_ok());
        }
        assert!(fc.add_message().is_err());
    }

    #[test]
    fn test_token_management() {
        let config = FlowControlConfig {
            max_unacked_tokens: 1000,
            ..Default::default()
        };
        let fc = StreamFlowControl::new(config);
        assert!(fc.add_tokens(500).is_ok());
        assert_eq!(fc.unacked_token_count(), 500);
        fc.ack_tokens(200);
        assert_eq!(fc.unacked_token_count(), 300);
        assert!(fc.add_tokens(700).is_ok());
        assert_eq!(fc.unacked_token_count(), 1000);
        assert!(fc.add_tokens(1).is_err());
    }

    #[test]
    fn test_connection_pool() {
        let pool = ConnectionPool::new(3);
        let conn1 = pool.acquire_connection().unwrap();
        assert_eq!(pool.active_count(), 1);
        // `_`-prefixed names (not a bare `_`, which would drop immediately) keep
        // the guards alive to the end of the test without `unused_variables`.
        let _conn2 = pool.acquire_connection().unwrap();
        assert_eq!(pool.active_count(), 2);
        let _conn3 = pool.acquire_connection().unwrap();
        assert_eq!(pool.active_count(), 3);
        assert!(pool.acquire_connection().is_err());
        drop(conn1);
        assert_eq!(pool.active_count(), 2);
        let _conn4 = pool.acquire_connection().unwrap();
        assert_eq!(pool.active_count(), 3);
    }

    #[test]
    fn test_timeout_detection() {
        let config = FlowControlConfig {
            ack_timeout_secs: 1,
            ..Default::default()
        };
        let fc = StreamFlowControl::new(config);
        assert!(!fc.is_ack_timeout());
        // The timeout has whole-second granularity, so sleep just past one second.
        std::thread::sleep(Duration::from_millis(1100));
        assert!(fc.is_ack_timeout());
        // Acknowledging zero tokens still refreshes the last-ack timestamp.
        fc.ack_tokens(0);
        assert!(!fc.is_ack_timeout());
    }

    #[test]
    fn test_utilization_percent() {
        let pool = ConnectionPool::new(100);
        assert_eq!(pool.utilization_percent(), 0);
        let _conn1 = pool.acquire_connection().unwrap();
        assert_eq!(pool.utilization_percent(), 1);
        let _conns: Vec<_> = (0..49)
            .map(|_| pool.acquire_connection().unwrap())
            .collect();
        assert_eq!(pool.utilization_percent(), 50);
    }
}