use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
/// Per-connection floor that [`GlobalBandwidthManager::new`] falls back to when its
/// `min_per_connection` argument is `None`.
pub const DEFAULT_MIN_PER_CONNECTION: usize = 1_000_000;

/// Default aggregate budget intended to be shared by all connections.
///
/// NOTE(review): nothing in this file reads this constant; presumably callers pass it
/// as `total_limit` to [`GlobalBandwidthManager::new`] — confirm at call sites.
pub const DEFAULT_TOTAL_LIMIT: usize = 50_000_000;
/// Divides a fixed aggregate bandwidth budget evenly across the currently registered
/// connections, never handing any single connection less than a configured floor.
///
/// The connection count lives in an atomic counter, so one instance can be shared
/// (for example behind an `Arc`) and registered against from several threads.
/// Units are not fixed by this type; `total_limit` and `min_per_connection` only need
/// to use the same unit.
#[derive(Debug)]
pub struct GlobalBandwidthManager {
    /// Aggregate budget that is split among all registered connections.
    total_limit: usize,
    /// Number of connections currently registered.
    connection_count: AtomicUsize,
    /// Lower bound for each per-connection share. Because of this floor, the sum of
    /// all shares can exceed `total_limit`.
    min_per_connection: usize,
}
impl GlobalBandwidthManager {
    /// Creates a manager with an aggregate budget of `total_limit` and no registered
    /// connections.
    ///
    /// `min_per_connection` is the floor applied to every computed per-connection rate;
    /// `None` selects [`DEFAULT_MIN_PER_CONNECTION`].
    pub fn new(total_limit: usize, min_per_connection: Option<usize>) -> Self {
        let floor = min_per_connection.unwrap_or(DEFAULT_MIN_PER_CONNECTION);
        Self {
            connection_count: AtomicUsize::new(0),
            min_per_connection: floor,
            total_limit,
        }
    }

    /// Records one additional connection and returns the per-connection rate for the
    /// new total, which includes the caller's connection.
    ///
    /// The returned value is a snapshot. Registrations and unregistrations by other
    /// callers change the share afterwards, so re-query
    /// [`Self::current_per_connection_rate`] for an up-to-date value.
    pub fn register_connection(&self) -> usize {
        let previous = self.connection_count.fetch_add(1, Ordering::AcqRel);
        self.compute_per_connection_rate(previous + 1)
    }

    /// Records that one connection has gone away and returns the per-connection rate
    /// for the connections that remain.
    ///
    /// Releasing the last connection returns the rate a single connection would get.
    /// If the count is already zero, it stays at zero and a `tracing` warning is
    /// emitted, since this usually indicates a double unregister. The single-connection
    /// rate is returned in that case too.
    pub fn unregister_connection(&self) -> usize {
        // `checked_sub` yields `None` at zero, which makes `fetch_update` give up
        // without storing, so the counter can never wrap around below zero.
        let outcome = self
            .connection_count
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
                current.checked_sub(1)
            });

        let Ok(previous) = outcome else {
            tracing::warn!(
                "unregister_connection called when count was already 0; \
                 possible double-unregister bug"
            );
            return self.compute_per_connection_rate(1);
        };

        // `Ok` means the closure returned `Some`, so `previous >= 1` and this cannot
        // underflow. A remaining count of zero is clamped to one by the rate helper.
        self.compute_per_connection_rate(previous - 1)
    }

    /// Returns the per-connection rate for the current connection count.
    ///
    /// With nothing registered, this is the rate a single connection would receive.
    pub fn current_per_connection_rate(&self) -> usize {
        let active = self.connection_count.load(Ordering::Acquire);
        // The helper treats a count of zero as one.
        self.compute_per_connection_rate(active)
    }

    /// Returns the number of currently registered connections.
    pub fn connection_count(&self) -> usize {
        self.connection_count.load(Ordering::Acquire)
    }

    /// Returns the aggregate budget shared by all connections.
    pub fn total_limit(&self) -> usize {
        self.total_limit
    }

    /// Returns the floor applied to every per-connection rate.
    pub fn min_per_connection(&self) -> usize {
        self.min_per_connection
    }

    /// Splits `total_limit` evenly across `count` connections (a count of zero is
    /// treated as one). If the share falls below `min_per_connection`, the floor is
    /// returned instead.
    fn compute_per_connection_rate(&self, count: usize) -> usize {
        let divisor = if count == 0 { 1 } else { count };
        let even_share = self.total_limit / divisor;
        std::cmp::max(even_share, self.min_per_connection)
    }
}
/// RAII registration of a single connection with a shared [`GlobalBandwidthManager`].
///
/// Constructing a handle increments the manager's connection count, dropping it
/// decrements the count, and cloning it registers one more connection.
#[derive(Debug)]
pub struct ConnectionBandwidthHandle {
    /// Manager this handle's connection is registered with.
    manager: Arc<GlobalBandwidthManager>,
}
impl ConnectionBandwidthHandle {
    /// Registers one connection with `manager` and wraps that registration in a handle.
    ///
    /// The rate returned by the registration is discarded; call
    /// [`Self::current_rate`] to read the up-to-date share.
    pub fn new(manager: Arc<GlobalBandwidthManager>) -> Self {
        // Register before building `Self`, so a handle (and its `Drop`) only ever
        // exists for a connection that was actually counted.
        let _initial_rate = manager.register_connection();
        Self { manager }
    }

    /// Returns the manager's per-connection rate for its current connection count.
    pub fn current_rate(&self) -> usize {
        self.manager().current_per_connection_rate()
    }

    /// Borrows the manager this handle is registered with.
    pub fn manager(&self) -> &GlobalBandwidthManager {
        self.manager.as_ref()
    }
}
impl Drop for ConnectionBandwidthHandle {
    // Gives back the connection slot taken when this handle was created. The
    // remaining-connections rate returned by the manager is not needed here.
    fn drop(&mut self) {
        let _remaining_rate = self.manager.unregister_connection();
    }
}
impl Clone for ConnectionBandwidthHandle {
fn clone(&self) -> Self {
Self::new(Arc::clone(&self.manager))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_single_connection() {
        let manager = GlobalBandwidthManager::new(50_000_000, Some(1_000_000));
        let rate = manager.register_connection();
        assert_eq!(rate, 50_000_000);
        assert_eq!(manager.connection_count(), 1);
        manager.unregister_connection();
        assert_eq!(manager.connection_count(), 0);
    }

    #[test]
    fn test_multiple_connections_fair_share() {
        let manager = GlobalBandwidthManager::new(50_000_000, Some(1_000_000));
        let rate1 = manager.register_connection();
        assert_eq!(rate1, 50_000_000);
        let rate2 = manager.register_connection();
        assert_eq!(rate2, 25_000_000);
        assert_eq!(manager.current_per_connection_rate(), 25_000_000);
        assert_eq!(manager.connection_count(), 2);
        // Integer division truncates: 50M / 3 = 16_666_666.
        let rate3 = manager.register_connection();
        assert_eq!(rate3, 16_666_666);
        manager.unregister_connection();
        assert_eq!(manager.current_per_connection_rate(), 25_000_000);
    }

    #[test]
    fn test_minimum_enforcement() {
        let manager = GlobalBandwidthManager::new(10_000_000, Some(5_000_000));
        for _ in 0..10 {
            manager.register_connection();
        }
        // 10M / 10 = 1M is below the 5M floor, so the floor wins.
        assert_eq!(manager.current_per_connection_rate(), 5_000_000);
    }

    #[test]
    fn test_default_minimum_and_accessors() {
        let manager = GlobalBandwidthManager::new(DEFAULT_TOTAL_LIMIT, None);
        assert_eq!(manager.total_limit(), DEFAULT_TOTAL_LIMIT);
        assert_eq!(manager.min_per_connection(), DEFAULT_MIN_PER_CONNECTION);
        assert_eq!(manager.connection_count(), 0);
        // With nothing registered, the reported rate is what a sole connection gets.
        assert_eq!(manager.current_per_connection_rate(), DEFAULT_TOTAL_LIMIT);
    }

    #[test]
    fn test_unregister_returns_rate_for_remaining_connections() {
        let manager = GlobalBandwidthManager::new(60_000_000, Some(1_000_000));
        for _ in 0..3 {
            manager.register_connection();
        }
        assert_eq!(manager.unregister_connection(), 30_000_000);
        assert_eq!(manager.unregister_connection(), 60_000_000);
        // Releasing the last connection reports the sole-connection rate instead of
        // dividing by zero.
        assert_eq!(manager.unregister_connection(), 60_000_000);
        assert_eq!(manager.connection_count(), 0);
    }

    #[test]
    fn test_unregister_result_respects_minimum() {
        let manager = GlobalBandwidthManager::new(10_000_000, Some(4_000_000));
        for _ in 0..4 {
            manager.register_connection();
        }
        // 10M / 3 = 3_333_333 is below the 4M floor.
        assert_eq!(manager.unregister_connection(), 4_000_000);
        // 10M / 2 = 5M is above the floor.
        assert_eq!(manager.unregister_connection(), 5_000_000);
    }

    #[test]
    fn test_connection_handle_auto_unregister() {
        let manager = Arc::new(GlobalBandwidthManager::new(50_000_000, None));
        {
            let _handle1 = ConnectionBandwidthHandle::new(Arc::clone(&manager));
            assert_eq!(manager.connection_count(), 1);
            {
                let _handle2 = ConnectionBandwidthHandle::new(Arc::clone(&manager));
                assert_eq!(manager.connection_count(), 2);
            }
            assert_eq!(manager.connection_count(), 1);
        }
        assert_eq!(manager.connection_count(), 0);
    }

    #[test]
    fn test_handle_current_rate() {
        let manager = Arc::new(GlobalBandwidthManager::new(50_000_000, None));
        let handle1 = ConnectionBandwidthHandle::new(Arc::clone(&manager));
        assert_eq!(handle1.current_rate(), 50_000_000);
        let handle2 = ConnectionBandwidthHandle::new(Arc::clone(&manager));
        assert_eq!(handle1.current_rate(), 25_000_000);
        assert_eq!(handle2.current_rate(), 25_000_000);
        drop(handle2);
        assert_eq!(handle1.current_rate(), 50_000_000);
    }

    #[test]
    fn test_underflow_protection() {
        let manager = GlobalBandwidthManager::new(50_000_000, None);
        // Unregistering with nothing registered must not wrap the counter; it reports
        // the single-connection rate instead.
        let rate = manager.unregister_connection();
        assert_eq!(rate, 50_000_000);
        assert_eq!(manager.connection_count(), 0);
        manager.unregister_connection();
        manager.unregister_connection();
        assert_eq!(manager.connection_count(), 0);
    }

    #[test]
    fn test_handle_clone_registers_new_connection() {
        let manager = Arc::new(GlobalBandwidthManager::new(50_000_000, None));
        let handle1 = ConnectionBandwidthHandle::new(Arc::clone(&manager));
        assert_eq!(manager.connection_count(), 1);
        let handle2 = handle1.clone();
        assert_eq!(
            manager.connection_count(),
            2,
            "Clone should register new connection"
        );
        assert_eq!(handle1.current_rate(), 25_000_000);
        assert_eq!(handle2.current_rate(), 25_000_000);
        drop(handle1);
        assert_eq!(manager.connection_count(), 1);
        assert_eq!(handle2.current_rate(), 50_000_000);
        drop(handle2);
        assert_eq!(manager.connection_count(), 0);
    }

    #[test]
    fn test_concurrent_register_unregister() {
        use std::thread;
        let manager = Arc::new(GlobalBandwidthManager::new(50_000_000, Some(1_000_000)));
        let mut handles = vec![];
        for _ in 0..50 {
            let mgr = Arc::clone(&manager);
            let handle = thread::spawn(move || {
                for _ in 0..100 {
                    let rate = mgr.register_connection();
                    assert!(rate > 0, "Rate should be positive");
                    std::hint::black_box(rate);
                    mgr.unregister_connection();
                }
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.join().expect("Thread should not panic");
        }
        // Every register is paired with an unregister, so the count must return to zero.
        assert_eq!(
            manager.connection_count(),
            0,
            "All connections should be unregistered after concurrent operations"
        );
    }

    #[test]
    fn test_concurrent_rate_queries_during_churn() {
        use std::sync::atomic::{AtomicBool, Ordering as AtomicOrdering};
        use std::thread;
        use std::time::Duration;
        let manager = Arc::new(GlobalBandwidthManager::new(50_000_000, Some(1_000_000)));
        let running = Arc::new(AtomicBool::new(true));
        let mut handles = vec![];
        // One thread keeps registering and unregistering until told to stop.
        let mgr_clone = Arc::clone(&manager);
        let running_clone = Arc::clone(&running);
        let spawner = thread::spawn(move || {
            while running_clone.load(AtomicOrdering::Acquire) {
                mgr_clone.register_connection();
                thread::sleep(Duration::from_micros(10));
                mgr_clone.unregister_connection();
            }
        });
        // Readers are capped at 1000 iterations so the test stays bounded.
        for _ in 0..10 {
            let mgr = Arc::clone(&manager);
            let running_clone = Arc::clone(&running);
            let handle = thread::spawn(move || {
                let mut iterations = 0;
                while running_clone.load(AtomicOrdering::Acquire) && iterations < 1000 {
                    let rate = mgr.current_per_connection_rate();
                    assert!(rate > 0, "Rate should always be positive");
                    iterations += 1;
                }
            });
            handles.push(handle);
        }
        thread::sleep(Duration::from_millis(50));
        running.store(false, AtomicOrdering::Release);
        for handle in handles {
            handle.join().expect("Query thread should not panic");
        }
        spawner.join().expect("Spawner thread should not panic");
    }

    #[test]
    fn test_large_connection_count() {
        let manager = GlobalBandwidthManager::new(50_000_000, Some(1_000));
        for _ in 0..10_000 {
            manager.register_connection();
        }
        assert_eq!(manager.connection_count(), 10_000);
        let rate = manager.current_per_connection_rate();
        assert_eq!(rate, 5_000, "Fair share should be 50M / 10K = 5KB");
        let manager2 = GlobalBandwidthManager::new(50_000_000, Some(10_000));
        for _ in 0..10_000 {
            manager2.register_connection();
        }
        assert_eq!(
            manager2.current_per_connection_rate(),
            10_000,
            "Minimum should be enforced"
        );
    }

    #[test]
    fn test_zero_total_limit() {
        // The floor still applies even when there is no aggregate budget at all.
        let manager = GlobalBandwidthManager::new(0, Some(1_000_000));
        let rate = manager.register_connection();
        assert_eq!(rate, 1_000_000);
    }
}