#![cfg(native)]
use super::error::SignalError;
use super::signal::Signal;
use parking_lot::Mutex;
use std::collections::VecDeque;
use std::sync::Arc;
use std::time::{Duration, Instant};
/// Algorithm used to decide whether another emission is currently allowed.
///
/// Every strategy is bounded by the configured `max_emissions`; the window-based
/// strategies additionally use the configured `window_size`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum ThrottleStrategy {
    /// At most `max_emissions` per window; the window restarts on the first check
    /// made after it has fully elapsed.
    FixedWindow,
    /// At most `max_emissions` within any trailing `window_size` interval.
    SlidingWindow,
    /// Bucket holding up to `max_emissions` tokens; each emission spends one token.
    TokenBucket {
        /// Tokens added back per second.
        refill_rate: u32,
    },
    /// Rate-based limiter; evaluated with the same bucket arithmetic as
    /// [`ThrottleStrategy::TokenBucket`].
    LeakyBucket {
        /// Capacity regained per second.
        leak_rate: u32,
    },
}
/// Settings for a throttle: the strategy, the emission limit, the window length and
/// what happens to items that exceed the limit.
///
/// Built fluently from [`ThrottleConfig::new`] (or `Default`) via the `with_*`
/// methods; each builder consumes the config and returns the updated copy.
#[derive(Debug, Clone)]
pub struct ThrottleConfig {
    strategy: ThrottleStrategy,
    max_emissions: u32,
    window_size: Duration,
    drop_on_limit: bool,
}

impl ThrottleConfig {
    /// Returns the default configuration (identical to `ThrottleConfig::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Replaces the throttling strategy.
    pub fn with_strategy(self, strategy: ThrottleStrategy) -> Self {
        Self { strategy, ..self }
    }

    /// Replaces the emission limit (per window, or bucket capacity for the
    /// bucket strategies).
    pub fn with_max_emissions(self, max: u32) -> Self {
        Self {
            max_emissions: max,
            ..self
        }
    }

    /// Replaces the window length used by the window-based strategies.
    pub fn with_window_size(self, window: Duration) -> Self {
        Self {
            window_size: window,
            ..self
        }
    }

    /// Chooses whether over-limit items are dropped (`true`) or queued (`false`).
    pub fn with_drop_on_limit(self, drop: bool) -> Self {
        Self {
            drop_on_limit: drop,
            ..self
        }
    }

    /// The configured strategy.
    pub fn strategy(&self) -> ThrottleStrategy {
        self.strategy
    }

    /// The configured emission limit.
    pub fn max_emissions(&self) -> u32 {
        self.max_emissions
    }

    /// The configured window length.
    pub fn window_size(&self) -> Duration {
        self.window_size
    }

    /// Whether over-limit items are dropped rather than queued.
    pub fn drop_on_limit(&self) -> bool {
        self.drop_on_limit
    }
}

impl Default for ThrottleConfig {
    /// Fixed window admitting 100 emissions per 1-second window; over-limit items
    /// are dropped.
    fn default() -> Self {
        Self {
            strategy: ThrottleStrategy::FixedWindow,
            max_emissions: 100,
            window_size: Duration::from_secs(1),
            drop_on_limit: true,
        }
    }
}
/// Mutable limiter bookkeeping shared by all throttling strategies, plus the FIFO
/// backlog used when over-limit items are queued instead of dropped.
///
/// Only the fields belonging to the configured strategy are touched by a check.
struct ThrottleState<T> {
    /// Timestamps of admitted emissions inside the sliding window, oldest first.
    emissions: VecDeque<Instant>,
    /// Tokens currently available (bucket strategies); never exceeds `max_emissions`.
    tokens: f64,
    /// Instant at which tokens were last replenished.
    last_refill: Instant,
    /// Start of the current fixed window.
    window_start: Instant,
    /// Emissions admitted in the current fixed window.
    window_count: u32,
    /// Items waiting to be emitted, oldest first.
    queue: VecDeque<T>,
}

impl<T> ThrottleState<T> {
    /// Creates fresh state: empty windows, a full bucket of `max_emissions` tokens
    /// and an empty queue.
    fn new(max_emissions: u32) -> Self {
        let now = Instant::now();
        Self {
            emissions: VecDeque::new(),
            tokens: f64::from(max_emissions),
            last_refill: now,
            window_start: now,
            window_count: 0,
            queue: VecDeque::new(),
        }
    }

    /// Fixed window: restarts the window once `window_size` has elapsed since it
    /// began, then admits while fewer than `max_emissions` were counted in it.
    fn can_emit_fixed_window(&mut self, config: &ThrottleConfig) -> bool {
        let now = Instant::now();
        if now.duration_since(self.window_start) >= config.window_size {
            self.window_start = now;
            self.window_count = 0;
        }
        if self.window_count < config.max_emissions {
            self.window_count += 1;
            true
        } else {
            false
        }
    }

    /// Sliding window: forgets emissions older than `window_size`, then admits while
    /// fewer than `max_emissions` remain in the trailing window.
    fn can_emit_sliding_window(&mut self, config: &ThrottleConfig) -> bool {
        let now = Instant::now();
        // Compare ages rather than computing `now - window_size`: subtracting a large
        // window from an `Instant` panics when the result would precede the clock's
        // origin (e.g. `Duration::MAX`, or a window longer than the process clock has
        // been running on some platforms).
        while let Some(&oldest) = self.emissions.front() {
            if now.saturating_duration_since(oldest) <= config.window_size {
                break;
            }
            self.emissions.pop_front();
        }
        if self.emissions.len() < config.max_emissions as usize {
            self.emissions.push_back(now);
            true
        } else {
            false
        }
    }

    /// Token bucket: adds `refill_rate` tokens per elapsed second (capped at
    /// `max_emissions`), then spends one token if at least one is available.
    fn can_emit_token_bucket(&mut self, refill_rate: u32, max_emissions: u32) -> bool {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill).as_secs_f64();
        let tokens_to_add = elapsed * f64::from(refill_rate);
        self.tokens = (self.tokens + tokens_to_add).min(f64::from(max_emissions));
        self.last_refill = now;
        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            true
        } else {
            false
        }
    }

    /// Returns whether one more emission is allowed right now under `config`'s
    /// strategy. A `true` result records the emission (consumes a slot or token).
    fn can_emit(&mut self, config: &ThrottleConfig) -> bool {
        match config.strategy {
            ThrottleStrategy::FixedWindow => self.can_emit_fixed_window(config),
            ThrottleStrategy::SlidingWindow => self.can_emit_sliding_window(config),
            ThrottleStrategy::TokenBucket { refill_rate } => {
                self.can_emit_token_bucket(refill_rate, config.max_emissions)
            }
            // The leaky bucket shares the token-bucket arithmetic, with `leak_rate`
            // acting as the per-second replenish rate.
            ThrottleStrategy::LeakyBucket { leak_rate } => {
                self.can_emit_token_bucket(leak_rate, config.max_emissions)
            }
        }
    }

    /// Appends `item` to the back of the backlog.
    fn enqueue(&mut self, item: T) {
        self.queue.push_back(item);
    }

    /// Removes and returns the oldest queued item, if any.
    fn dequeue(&mut self) -> Option<T> {
        self.queue.pop_front()
    }

    /// Number of items currently queued.
    fn queue_len(&self) -> usize {
        self.queue.len()
    }
}
/// Rate-limits emissions on a [`Signal`] according to a [`ThrottleConfig`].
///
/// Items over the limit are either dropped and counted (see
/// [`SignalThrottle::dropped_count`]) or, when `drop_on_limit` is `false`, queued and
/// re-emitted by a background Tokio task. Clones share the same limiter state,
/// queue and drop counter.
pub struct SignalThrottle<T: Send + Sync + 'static> {
    signal: Signal<T>,
    config: ThrottleConfig,
    /// Limiter bookkeeping plus the FIFO backlog used in queue mode.
    state: Arc<Mutex<ThrottleState<T>>>,
    /// Number of items discarded because the limit was reached.
    dropped_count: Arc<Mutex<u64>>,
}

impl<T: Send + Sync + 'static> SignalThrottle<T> {
    /// Wraps `signal` with the throttling behaviour described by `config`.
    ///
    /// # Panics
    ///
    /// In queue mode (`drop_on_limit == false`) this starts the queue processor
    /// with `tokio::spawn`, which panics if no Tokio runtime is active.
    pub fn new(signal: Signal<T>, config: ThrottleConfig) -> Self {
        let state = Arc::new(Mutex::new(ThrottleState::new(config.max_emissions)));
        let throttle = Self {
            signal,
            config,
            state,
            dropped_count: Arc::new(Mutex::new(0)),
        };
        if !throttle.config.drop_on_limit {
            throttle.start_queue_processor();
        }
        throttle
    }

    /// Emits `item` if the throttle currently permits it.
    ///
    /// Otherwise the item is dropped (incrementing the drop counter) or, in queue
    /// mode, appended to the backlog. In queue mode an item is also queued whenever
    /// a backlog already exists, so new items never jump ahead of items still
    /// waiting in the queue.
    ///
    /// # Errors
    ///
    /// Returns the error from the underlying signal when the item is emitted
    /// immediately; dropped or queued items always yield `Ok(())`.
    pub async fn send(&self, item: T) -> Result<(), SignalError> {
        let queue_mode = !self.config.drop_on_limit;
        // Decide under a single lock acquisition; the guard is released before any
        // `.await` below.
        let admitted = {
            let mut state = self.state.lock();
            if queue_mode && state.queue_len() > 0 {
                // Preserve FIFO order: the background processor owns the backlog.
                state.enqueue(item);
                None
            } else if state.can_emit(&self.config) {
                Some(item)
            } else if queue_mode {
                state.enqueue(item);
                None
            } else {
                // Release the state lock before touching the counter lock.
                drop(state);
                *self.dropped_count.lock() += 1;
                None
            }
        };
        match admitted {
            Some(item) => self.signal.send(item).await,
            None => Ok(()),
        }
    }

    /// Number of items dropped because the limit was reached.
    pub fn dropped_count(&self) -> u64 {
        *self.dropped_count.lock()
    }

    /// Number of items currently waiting in the queue.
    pub fn queue_length(&self) -> usize {
        self.state.lock().queue_len()
    }

    /// Clears the drop counter and restores the limiter to its initial state.
    ///
    /// Any items still waiting in the queue are discarded.
    pub fn reset(&self) {
        *self.dropped_count.lock() = 0;
        *self.state.lock() = ThrottleState::new(self.config.max_emissions);
    }

    /// Spawns the background task that drains the queue as the limiter allows.
    ///
    /// Every 100 ms the task emits queued items, oldest first, for as long as the
    /// limiter admits them. Limiter capacity is only consumed when an item is
    /// actually waiting. The task exits once every `SignalThrottle` handle has been
    /// dropped and the backlog has been fully drained.
    fn start_queue_processor(&self) {
        let state = Arc::clone(&self.state);
        let signal = self.signal.clone();
        let config = self.config.clone();
        tokio::spawn(async move {
            let mut interval = tokio::time::interval(Duration::from_millis(100));
            interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
            loop {
                interval.tick().await;
                loop {
                    let next = {
                        let mut guard = state.lock();
                        // Check the backlog first so an idle tick does not burn a
                        // window slot or token.
                        if guard.queue_len() > 0 && guard.can_emit(&config) {
                            guard.dequeue()
                        } else {
                            None
                        }
                    };
                    match next {
                        Some(item) => {
                            if let Err(e) = signal.send(item).await {
                                eprintln!("Failed to send throttled signal: {}", e);
                            }
                        }
                        None => break,
                    }
                }
                // Once every handle is gone this task holds the only strong
                // reference, so nothing can enqueue anymore: stop when drained.
                if Arc::strong_count(&state) == 1 && state.lock().queue_len() == 0 {
                    break;
                }
            }
        });
    }
}

impl<T: Send + Sync + 'static> Clone for SignalThrottle<T> {
    /// Returns a handle sharing this throttle's signal, limiter state, queue and
    /// drop counter.
    fn clone(&self) -> Self {
        Self {
            signal: self.signal.clone(),
            config: self.config.clone(),
            state: Arc::clone(&self.state),
            dropped_count: Arc::clone(&self.dropped_count),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::signals::SignalName;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Repeatedly awaits `condition` every `interval` until it returns `true`, or
    /// fails once `timeout` has elapsed.
    async fn poll_until<F, Fut>(
        timeout: std::time::Duration,
        interval: std::time::Duration,
        mut condition: F,
    ) -> Result<(), String>
    where
        F: FnMut() -> Fut,
        Fut: std::future::Future<Output = bool>,
    {
        let began = std::time::Instant::now();
        loop {
            if began.elapsed() >= timeout {
                return Err(format!("Timeout after {:?} waiting for condition", timeout));
            }
            if condition().await {
                return Ok(());
            }
            tokio::time::sleep(interval).await;
        }
    }

    /// Creates an `i32` signal with one connected handler that counts invocations,
    /// returning the signal and the shared counter.
    fn counting_signal(name: &'static str) -> (Signal<i32>, Arc<AtomicUsize>) {
        let signal = Signal::<i32>::new(SignalName::custom(name));
        let hits = Arc::new(AtomicUsize::new(0));
        let handler_hits = Arc::clone(&hits);
        signal.connect(move |_| {
            let hits = Arc::clone(&handler_hits);
            async move {
                hits.fetch_add(1, Ordering::SeqCst);
                Ok(())
            }
        });
        (signal, hits)
    }

    #[test]
    fn test_throttle_config() {
        let config = ThrottleConfig::new()
            .with_strategy(ThrottleStrategy::SlidingWindow)
            .with_max_emissions(50)
            .with_window_size(Duration::from_millis(500))
            .with_drop_on_limit(false);
        assert_eq!(config.strategy(), ThrottleStrategy::SlidingWindow);
        assert_eq!(config.max_emissions(), 50);
        assert_eq!(config.window_size(), Duration::from_millis(500));
        assert!(!config.drop_on_limit());
    }

    #[tokio::test]
    async fn test_fixed_window_throttle() {
        let (signal, counter) = counting_signal("test_fixed");
        let throttle = SignalThrottle::new(
            signal,
            ThrottleConfig::new()
                .with_strategy(ThrottleStrategy::FixedWindow)
                .with_max_emissions(5)
                .with_window_size(Duration::from_millis(500)),
        );
        for value in 0..10 {
            throttle.send(value).await.unwrap();
        }
        poll_until(Duration::from_millis(100), Duration::from_millis(10), || async {
            counter.load(Ordering::SeqCst) == 5
        })
        .await
        .expect("5 signals should be processed within 100ms");
        assert_eq!(counter.load(Ordering::SeqCst), 5);
        assert_eq!(throttle.dropped_count(), 5);
    }

    #[tokio::test]
    async fn test_sliding_window_throttle() {
        let (signal, counter) = counting_signal("test_sliding");
        let throttle = SignalThrottle::new(
            signal,
            ThrottleConfig::new()
                .with_strategy(ThrottleStrategy::SlidingWindow)
                .with_max_emissions(3)
                .with_window_size(Duration::from_millis(200)),
        );
        for value in 0..5 {
            throttle.send(value).await.unwrap();
        }
        poll_until(Duration::from_millis(50), Duration::from_millis(10), || async {
            counter.load(Ordering::SeqCst) == 3
        })
        .await
        .expect("3 signals should be processed within 50ms");

        // Let the whole window expire so a fresh emission is admitted.
        tokio::time::sleep(Duration::from_millis(200)).await;
        throttle.send(99).await.unwrap();
        poll_until(Duration::from_millis(50), Duration::from_millis(10), || async {
            counter.load(Ordering::SeqCst) == 4
        })
        .await
        .expect("4th signal should be processed within 50ms");
    }

    #[tokio::test]
    async fn test_token_bucket_throttle() {
        let (signal, counter) = counting_signal("test_token");
        let throttle = SignalThrottle::new(
            signal,
            ThrottleConfig::new()
                .with_strategy(ThrottleStrategy::TokenBucket { refill_rate: 10 })
                .with_max_emissions(10),
        );
        for value in 0..10 {
            throttle.send(value).await.unwrap();
        }
        poll_until(Duration::from_millis(50), Duration::from_millis(10), || async {
            counter.load(Ordering::SeqCst) == 10
        })
        .await
        .expect("10 signals should be processed within 50ms");

        // The bucket is now empty, so the next item must be dropped.
        throttle.send(100).await.unwrap();
        poll_until(Duration::from_millis(50), Duration::from_millis(10), || async {
            throttle.dropped_count() == 1
        })
        .await
        .expect("11th signal should be dropped within 50ms");
    }

    #[tokio::test]
    async fn test_queue_mode() {
        let (signal, counter) = counting_signal("test_queue");
        let throttle = SignalThrottle::new(
            signal,
            ThrottleConfig::new()
                .with_strategy(ThrottleStrategy::FixedWindow)
                .with_max_emissions(5)
                .with_window_size(Duration::from_millis(300))
                .with_drop_on_limit(false),
        );
        for value in 0..10 {
            throttle.send(value).await.unwrap();
        }
        poll_until(Duration::from_millis(100), Duration::from_millis(10), || async {
            counter.load(Ordering::SeqCst) == 5
        })
        .await
        .expect("5 signals should be processed immediately within 100ms");
        assert_eq!(counter.load(Ordering::SeqCst), 5);
        assert!(throttle.queue_length() > 0);

        // The background processor drains the backlog once the window resets.
        poll_until(Duration::from_millis(500), Duration::from_millis(20), || async {
            counter.load(Ordering::SeqCst) >= 9
        })
        .await
        .expect("Queue should be processed within 500ms");
        assert!(counter.load(Ordering::SeqCst) >= 9);
        assert_eq!(throttle.dropped_count(), 0);
    }
}