use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use tokio::sync::{Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
/// Unit-budget backpressure gate: tracks an outstanding amount against a
/// fixed ceiling, plus an independent semaphore permit pool for bounding
/// concurrency. (Units are caller-defined — presumably bytes, given the
/// 64 KiB default — TODO confirm.)
#[derive(Debug)]
pub struct Backpressure {
    /// Maximum total amount that may be outstanding at once.
    max_size: usize,
    /// Amount currently outstanding (moved by `try_acquire`/`release`).
    current: AtomicUsize,
    /// Signalled by `release` to wake tasks parked in `acquire`.
    space_available: Notify,
    /// Independent pool of `max_size` permits backing the `*_permit` methods.
    semaphore: Arc<Semaphore>,
}
impl Backpressure {
#[must_use]
pub fn new(max_size: usize) -> Self {
Self {
max_size,
current: AtomicUsize::new(0),
space_available: Notify::new(),
semaphore: Arc::new(Semaphore::new(max_size)),
}
}
pub fn try_acquire(&self, amount: usize) -> bool {
let current = self.current.load(Ordering::Acquire);
if current + amount <= self.max_size {
self.current.fetch_add(amount, Ordering::Release);
true
} else {
false
}
}
pub async fn acquire(&self, amount: usize) {
loop {
if self.try_acquire(amount) {
return;
}
self.space_available.notified().await;
}
}
pub fn release(&self, amount: usize) {
self.current.fetch_sub(amount, Ordering::Release);
self.space_available.notify_one();
}
#[must_use]
pub fn current_size(&self) -> usize {
self.current.load(Ordering::Acquire)
}
#[must_use]
pub const fn max_size(&self) -> usize {
self.max_size
}
#[must_use]
pub fn is_full(&self) -> bool {
self.current_size() >= self.max_size
}
#[must_use]
pub fn available(&self) -> usize {
self.max_size.saturating_sub(self.current_size())
}
pub fn try_acquire_permit(&self) -> Result<OwnedSemaphorePermit, TryAcquireError> {
self.semaphore.clone().try_acquire_owned()
}
pub async fn acquire_permit(&self) -> OwnedSemaphorePermit {
self.semaphore
.clone()
.acquire_owned()
.await
.expect("semaphore should not be closed")
}
#[must_use]
pub fn available_permits(&self) -> usize {
self.semaphore.available_permits()
}
}
impl Default for Backpressure {
    /// Defaults to a 64 KiB budget.
    fn default() -> Self {
        // 65_536 == 64 * 1024.
        Self::new(65_536)
    }
}
/// Fixed-window rate limiter: at most `max_ops` operations are admitted per
/// interval, and the counter is reset wholesale when an interval elapses.
#[derive(Debug)]
pub struct RateLimiter {
    /// Operations admitted per window.
    max_ops: usize,
    /// Window length in milliseconds.
    interval_ms: u64,
    /// Operations admitted so far in the current window.
    current: AtomicUsize,
    /// Start of the current window; mutex-guarded so only one thread
    /// performs the reset decision at a time.
    last_reset: std::sync::Mutex<std::time::Instant>,
}
impl RateLimiter {
    /// Builds a limiter admitting `max_ops` operations per `interval`.
    #[must_use]
    pub fn new(max_ops: usize, interval: std::time::Duration) -> Self {
        Self {
            max_ops,
            interval_ms: interval.as_millis() as u64,
            current: AtomicUsize::new(0),
            last_reset: std::sync::Mutex::new(std::time::Instant::now()),
        }
    }

    /// Attempts to admit one operation; `false` when the window is exhausted.
    pub fn try_acquire(&self) -> bool {
        self.maybe_reset();
        // Optimistically claim a slot, then hand it back if we overshot.
        let prior = self.current.fetch_add(1, Ordering::AcqRel);
        if prior >= self.max_ops {
            self.current.fetch_sub(1, Ordering::Release);
            return false;
        }
        true
    }

    /// Admits one operation, sleeping out the window whenever it is full.
    pub async fn acquire(&self) {
        loop {
            if self.try_acquire() {
                return;
            }
            tokio::time::sleep(self.time_until_reset()).await;
        }
    }

    /// Zeroes the counter and restarts the window once `interval_ms` elapsed.
    fn maybe_reset(&self) {
        let mut window_start = self
            .last_reset
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        if window_start.elapsed().as_millis() as u64 >= self.interval_ms {
            self.current.store(0, Ordering::Release);
            *window_start = std::time::Instant::now();
        }
    }

    /// Time remaining until the current window ends (zero if already over).
    #[allow(clippy::significant_drop_tightening)]
    fn time_until_reset(&self) -> std::time::Duration {
        let window_start = self
            .last_reset
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        // Saturates to zero when the window has already expired.
        std::time::Duration::from_millis(self.interval_ms).saturating_sub(window_start.elapsed())
    }
}
/// Token bucket: holds up to `capacity` tokens, continuously refilled at
/// `refill_rate` tokens per second.
#[derive(Debug)]
pub struct TokenBucket {
    /// Maximum tokens the bucket can hold.
    capacity: usize,
    /// Tokens currently available.
    tokens: AtomicUsize,
    /// Refill rate in tokens per second.
    refill_rate: f64,
    /// Timestamp of the last credited refill; mutex-guarded so only one
    /// thread credits tokens at a time.
    last_refill: std::sync::Mutex<std::time::Instant>,
}
impl TokenBucket {
    /// Creates a bucket that starts full with `capacity` tokens and refills
    /// at `refill_rate` tokens per second (`refill_rate` should be > 0).
    #[must_use]
    pub fn new(capacity: usize, refill_rate: f64) -> Self {
        Self {
            capacity,
            tokens: AtomicUsize::new(capacity),
            refill_rate,
            last_refill: std::sync::Mutex::new(std::time::Instant::now()),
        }
    }

    /// Attempts to take `count` tokens; `false` if too few are available.
    pub fn try_consume(&self, count: usize) -> bool {
        self.refill();
        // CAS loop: retry whenever another consumer moves the count under us.
        loop {
            let current = self.tokens.load(Ordering::Acquire);
            if current < count {
                return false;
            }
            if self
                .tokens
                .compare_exchange(
                    current,
                    current - count,
                    Ordering::AcqRel,
                    Ordering::Acquire,
                )
                .is_ok()
            {
                return true;
            }
        }
    }

    /// Takes `count` tokens, sleeping until the refill makes them available.
    pub async fn consume(&self, count: usize) {
        while !self.try_consume(count) {
            let wait_time = self.time_for_tokens(count);
            tokio::time::sleep(wait_time).await;
        }
    }

    /// Credits tokens accrued since the last refill.
    fn refill(&self) {
        let mut last_refill = self
            .last_refill
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let elapsed = last_refill.elapsed().as_secs_f64();
        // Truncation is deliberate: fractional tokens stay banked as elapsed
        // time until they amount to a whole token.
        let new_tokens = (elapsed * self.refill_rate) as usize;
        if new_tokens > 0 {
            // Single atomic read-modify-write: with the old separate
            // load/store, a concurrent `try_consume` CAS landing in between
            // had its withdrawal silently overwritten.
            let _ = self
                .tokens
                .fetch_update(Ordering::AcqRel, Ordering::Acquire, |cur| {
                    Some(cur.saturating_add(new_tokens).min(self.capacity))
                });
            // Advance the clock by exactly the time the credited tokens
            // represent, rather than jumping to `now()`: the jump discarded
            // the fractional remainder on every refill, so frequent polling
            // paid out below `refill_rate`.
            *last_refill +=
                std::time::Duration::from_secs_f64(new_tokens as f64 / self.refill_rate);
        }
    }

    /// Estimated wait until `count` tokens could be available (zero if they
    /// already are). Assumes `refill_rate > 0`; a non-positive rate would
    /// make `from_secs_f64` panic, as in the original code.
    fn time_for_tokens(&self, count: usize) -> std::time::Duration {
        let current = self.tokens.load(Ordering::Acquire);
        if current >= count {
            return std::time::Duration::ZERO;
        }
        let needed = count - current;
        std::time::Duration::from_secs_f64(needed as f64 / self.refill_rate)
    }

    /// Tokens currently available, after crediting any pending refill.
    #[must_use]
    pub fn tokens(&self) -> usize {
        self.refill();
        self.tokens.load(Ordering::Acquire)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn backpressure_acquire() {
        // A budget of 100: two 50s fill it, one more unit is refused until
        // space is released.
        let bp = Backpressure::new(100);
        assert!(bp.try_acquire(50));
        assert!(bp.try_acquire(50));
        assert!(!bp.try_acquire(1));
        bp.release(50);
        assert!(bp.try_acquire(50));
    }

    #[test]
    fn backpressure_permits() {
        let bp = Backpressure::new(3);
        let first = bp.try_acquire_permit().unwrap();
        let second = bp.try_acquire_permit().unwrap();
        let third = bp.try_acquire_permit().unwrap();
        // Pool exhausted: a fourth attempt fails until a permit is dropped.
        assert!(bp.try_acquire_permit().is_err());
        assert_eq!(bp.available_permits(), 0);
        drop(first);
        assert_eq!(bp.available_permits(), 1);
        let _reclaimed = bp.try_acquire_permit().unwrap();
        drop(second);
        drop(third);
    }

    #[tokio::test]
    async fn backpressure_async_permit() {
        let bp = Backpressure::new(2);
        let held_a = bp.acquire_permit().await;
        let held_b = bp.acquire_permit().await;
        assert_eq!(bp.available_permits(), 0);
        drop(held_a);
        assert_eq!(bp.available_permits(), 1);
        drop(held_b);
        assert_eq!(bp.available_permits(), 2);
    }

    #[test]
    fn rate_limiter_basic() {
        // Exactly five admissions fit in one window; the sixth is refused.
        let limiter = RateLimiter::new(5, std::time::Duration::from_secs(1));
        for _ in 0..5 {
            assert!(limiter.try_acquire());
        }
        assert!(!limiter.try_acquire());
    }

    #[test]
    fn token_bucket_basic() {
        // The bucket starts full: 10 tokens drain in two batches of five.
        let bucket = TokenBucket::new(10, 5.0);
        assert!(bucket.try_consume(5));
        assert!(bucket.try_consume(5));
        assert!(!bucket.try_consume(1));
    }
}