use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
use tokio::sync::Semaphore;
/// Counts in-flight connections so shutdown logic can wait for them to drain.
///
/// Cloning a tracker is cheap and every clone shares the same counter, so one
/// copy can be handed to each connection task while the shutdown path keeps
/// another.
#[derive(Debug, Clone)]
pub struct ConnectionTracker {
    /// Number of guards from [`ConnectionTracker::track`] currently alive,
    /// shared across all clones of this tracker.
    active: Arc<AtomicUsize>,
}

impl ConnectionTracker {
    /// Longest single sleep [`wait_for_shutdown`](Self::wait_for_shutdown)
    /// takes between checks of the active count.
    const SHUTDOWN_POLL_INTERVAL: Duration = Duration::from_millis(100);

    /// Creates a tracker with no active connections.
    pub fn new() -> Self {
        Self {
            active: Arc::new(AtomicUsize::new(0)),
        }
    }

    /// Registers one active connection and returns a guard for it.
    ///
    /// The count is incremented immediately and decremented when the returned
    /// guard is dropped. Hold the guard for as long as the connection lives;
    /// discarding it right away undoes the registration.
    #[must_use = "the connection is only counted while the returned guard is alive"]
    pub fn track(&self) -> impl Drop + use<> {
        self.active.fetch_add(1, Ordering::AcqRel);
        ConnectionGuard {
            active: Arc::clone(&self.active),
        }
    }

    /// Returns the number of connections currently being tracked.
    pub fn count(&self) -> usize {
        self.active.load(Ordering::Acquire)
    }

    /// Waits until no connections are tracked or until `timeout` elapses.
    ///
    /// Returns `true` as soon as the count is observed at zero, and `false` if
    /// the timeout expires first. A zero `timeout` performs a single check.
    /// The count is polled rather than awaited, at most
    /// `SHUTDOWN_POLL_INTERVAL` apart.
    pub async fn wait_for_shutdown(&self, timeout: Duration) -> bool {
        // Use tokio's clock rather than std's so the deadline follows the
        // runtime's notion of time (e.g. paused/advanced time in tests).
        let start = tokio::time::Instant::now();
        while self.count() > 0 {
            let elapsed = start.elapsed();
            if elapsed >= timeout {
                return false;
            }
            // Never sleep past the deadline: a full poll interval would
            // overshoot short timeouts by up to SHUTDOWN_POLL_INTERVAL.
            let remaining = timeout - elapsed;
            tokio::time::sleep(remaining.min(Self::SHUTDOWN_POLL_INTERVAL)).await;
        }
        true
    }
}

impl Default for ConnectionTracker {
    fn default() -> Self {
        Self::new()
    }
}
/// Handle issued by [`ConnectionTracker::track`]; dropping it releases one
/// unit of the shared active-connection counter.
#[derive(Debug)]
struct ConnectionGuard {
    /// Counter shared with the tracker that created this guard.
    active: Arc<AtomicUsize>,
}

impl Drop for ConnectionGuard {
    fn drop(&mut self) {
        // Decrement with saturation so that, in release builds, an unexpected
        // zero stays at zero instead of wrapping to `usize::MAX`.
        let decrement = |current: usize| Some(current.saturating_sub(1));
        let previous = self
            .active
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, decrement);
        // In debug builds, flag a release that had nothing to release.
        debug_assert!(
            matches!(previous, Ok(before) if before > 0),
            "Connection count underflow"
        );
    }
}
#[derive(Debug, Clone)]
pub struct ConnectionLimiter {
semaphore: Option<Arc<Semaphore>>,
max_connections: usize,
}
impl ConnectionLimiter {
pub fn new(max_connections: usize) -> Self {
let semaphore = if max_connections > 0 {
Some(Arc::new(Semaphore::new(max_connections)))
} else {
None
};
Self {
semaphore,
max_connections,
}
}
pub fn is_enabled(&self) -> bool {
self.semaphore.is_some()
}
pub fn max_connections(&self) -> usize {
self.max_connections
}
pub fn try_acquire(&self) -> Option<tokio::sync::OwnedSemaphorePermit> {
self.semaphore
.as_ref()
.and_then(|sem| sem.clone().try_acquire_owned().ok())
}
pub fn at_capacity(&self) -> bool {
self.semaphore
.as_ref()
.is_some_and(|sem| sem.available_permits() == 0)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- ConnectionTracker ------------------------------------------------

    /// A freshly created tracker reports zero active connections.
    #[test]
    fn test_connection_tracker_new() {
        assert_eq!(ConnectionTracker::new().count(), 0);
    }

    /// Each guard bumps the count, and dropping it undoes exactly that bump.
    #[test]
    fn test_connection_tracker_track_and_drop() {
        let tracker = ConnectionTracker::new();

        let first = tracker.track();
        assert_eq!(tracker.count(), 1);
        let second = tracker.track();
        assert_eq!(tracker.count(), 2);

        drop(first);
        assert_eq!(tracker.count(), 1);
        drop(second);
        assert_eq!(tracker.count(), 0);
    }

    /// Guards taken through either clone are visible through the other.
    #[test]
    fn test_connection_tracker_clone_shares_state() {
        let original = ConnectionTracker::new();
        let cloned = original.clone();

        let _from_original = original.track();
        assert_eq!(cloned.count(), 1);
        let _from_clone = cloned.track();
        assert_eq!(original.count(), 2);
    }

    /// With nothing tracked, waiting succeeds straight away.
    #[tokio::test]
    async fn test_connection_tracker_wait_for_shutdown_immediate() {
        let idle = ConnectionTracker::new();
        assert!(idle.wait_for_shutdown(Duration::from_millis(100)).await);
    }

    /// Waiting succeeds once another task releases the last guard.
    #[tokio::test]
    async fn test_connection_tracker_wait_for_shutdown_with_connections() {
        let tracker = ConnectionTracker::new();
        let guard = tracker.track();

        // Release the only connection shortly after we start waiting.
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;
            drop(guard);
        });

        let drained = tracker.wait_for_shutdown(Duration::from_millis(200)).await;
        assert!(drained);
        assert_eq!(tracker.count(), 0);
    }

    /// Waiting gives up when a guard outlives the timeout.
    #[tokio::test]
    async fn test_connection_tracker_wait_for_shutdown_timeout() {
        let tracker = ConnectionTracker::new();
        let _held = tracker.track();

        let drained = tracker.wait_for_shutdown(Duration::from_millis(50)).await;
        assert!(!drained);
        assert_eq!(tracker.count(), 1);
    }

    // ---- ConnectionLimiter ------------------------------------------------

    /// A zero limit disables limiting entirely.
    #[test]
    fn test_connection_limiter_unlimited() {
        let unlimited = ConnectionLimiter::new(0);
        assert!(!unlimited.is_enabled());
        assert_eq!(unlimited.max_connections(), 0);
        assert!(!unlimited.at_capacity());
    }

    /// A positive limit enables limiting and starts with free capacity.
    #[test]
    fn test_connection_limiter_with_limit() {
        let bounded = ConnectionLimiter::new(10);
        assert!(bounded.is_enabled());
        assert_eq!(bounded.max_connections(), 10);
        assert!(!bounded.at_capacity());
    }

    /// An unlimited limiter never hands out permits.
    #[test]
    fn test_connection_limiter_try_acquire_unlimited() {
        assert!(ConnectionLimiter::new(0).try_acquire().is_none());
    }

    /// Once every permit is held, further attempts are refused.
    #[test]
    fn test_connection_limiter_try_acquire_with_limit() {
        let limiter = ConnectionLimiter::new(2);

        // Keep both permits alive so the limiter is saturated.
        let held: Vec<_> = (0..2).map(|_| limiter.try_acquire()).collect();
        assert!(held.iter().all(Option::is_some));

        assert!(limiter.at_capacity());
        assert!(limiter.try_acquire().is_none());
    }

    /// Dropping a permit frees capacity for the next caller.
    #[test]
    fn test_connection_limiter_permit_release() {
        let limiter = ConnectionLimiter::new(1);

        let only = limiter.try_acquire();
        assert!(only.is_some());
        assert!(limiter.at_capacity());

        drop(only);
        assert!(!limiter.at_capacity());
        assert!(limiter.try_acquire().is_some());
    }
}