use std::collections::VecDeque;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use async_trait::async_trait;
use serde_json::Value;
use crate::error::{CognisError, Result};
use super::base::Runnable;
use super::config::RunnableConfig;
use super::RunnableStream;
/// Configuration for token-bucket rate limiting.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Maximum requests allowed per window; also the steady refill rate.
    pub max_requests: usize,
    /// Length of the rate-limit window, in milliseconds.
    pub window_duration_ms: u64,
    /// Optional bucket-capacity override that permits short bursts.
    pub burst_limit: Option<usize>,
}

impl RateLimitConfig {
    /// Creates a configuration with no burst allowance.
    pub fn new(max_requests: usize, window_duration_ms: u64) -> Self {
        RateLimitConfig {
            burst_limit: None,
            max_requests,
            window_duration_ms,
        }
    }

    /// Builder-style setter: caps the bucket at `burst_limit` tokens.
    pub fn with_burst_limit(self, burst_limit: usize) -> Self {
        Self {
            burst_limit: Some(burst_limit),
            ..self
        }
    }
}

/// Mutable interior of the token bucket, guarded by the limiter's mutex.
#[derive(Debug)]
struct TokenBucketState {
    tokens: usize,
    last_refill: Instant,
}

/// Token-bucket rate limiter.
///
/// Tokens refill continuously at `max_requests` per `window_duration_ms`,
/// up to the bucket capacity (`burst_limit` when set, else `max_requests`).
/// The bucket starts full.
#[derive(Debug)]
pub struct RateLimiter {
    config: RateLimitConfig,
    state: Mutex<TokenBucketState>,
}

impl RateLimiter {
    /// Builds a limiter whose bucket starts at full capacity.
    pub fn new(config: RateLimitConfig) -> Self {
        let initial_tokens = config.burst_limit.unwrap_or(config.max_requests);
        RateLimiter {
            state: Mutex::new(TokenBucketState {
                tokens: initial_tokens,
                last_refill: Instant::now(),
            }),
            config,
        }
    }

    /// Bucket capacity: the burst limit when configured, else `max_requests`.
    fn capacity(&self) -> usize {
        match self.config.burst_limit {
            Some(burst) => burst,
            None => self.config.max_requests,
        }
    }

    /// Attempts to take one token; returns `false` when the bucket is empty.
    pub fn try_acquire(&self) -> bool {
        let mut guard = self.state.lock().unwrap();
        self.refill(&mut guard);
        match guard.tokens {
            0 => false,
            n => {
                guard.tokens = n - 1;
                true
            }
        }
    }

    /// Tokens currently available, after applying any pending refill.
    pub fn available_tokens(&self) -> usize {
        let mut guard = self.state.lock().unwrap();
        self.refill(&mut guard);
        guard.tokens
    }

    /// Refills the bucket to capacity and restarts the refill clock.
    pub fn reset(&self) {
        let mut guard = self.state.lock().unwrap();
        guard.tokens = self.capacity();
        guard.last_refill = Instant::now();
    }

    /// Read access to the limiter's configuration.
    pub fn config(&self) -> &RateLimitConfig {
        &self.config
    }

    /// Credits tokens earned since `last_refill`.
    ///
    /// The refill timestamp only advances when at least one whole token is
    /// credited, so fractional progress toward the next token is never lost.
    fn refill(&self, state: &mut TokenBucketState) {
        // A zero-length window or a zero rate means tokens never refill.
        if self.config.max_requests == 0 || self.config.window_duration_ms == 0 {
            return;
        }
        let now = Instant::now();
        let elapsed_ms = now.duration_since(state.last_refill).as_millis() as u64;
        if elapsed_ms == 0 {
            return;
        }
        // 128-bit intermediate avoids overflow of elapsed * rate.
        let earned = (u128::from(elapsed_ms) * self.config.max_requests as u128
            / u128::from(self.config.window_duration_ms)) as usize;
        if earned == 0 {
            return;
        }
        state.tokens = (state.tokens + earned).min(self.capacity());
        state.last_refill = now;
    }
}
/// Sliding-window rate counter.
///
/// Remembers the timestamp of every accepted request and rejects new ones
/// once `max_requests` timestamps fall inside the trailing `window_ms`
/// window. Expired timestamps are pruned lazily on each query.
#[derive(Debug)]
pub struct SlidingWindowCounter {
    max_requests: usize,
    window_ms: u64,
    timestamps: Mutex<VecDeque<Instant>>,
}

impl SlidingWindowCounter {
    /// Creates a counter allowing `max_requests` per `window_ms` milliseconds.
    pub fn new(max_requests: usize, window_ms: u64) -> Self {
        Self {
            max_requests,
            window_ms,
            timestamps: Mutex::new(VecDeque::new()),
        }
    }

    /// Drops timestamps strictly older than the window, measured from `now`.
    ///
    /// Shared by `record` and `current_count` (previously duplicated inline
    /// in both). The deque is ordered oldest-first, so popping stops at the
    /// first in-window entry.
    fn evict_expired(&self, timestamps: &mut VecDeque<Instant>, now: Instant) {
        let window = Duration::from_millis(self.window_ms);
        while let Some(&front) = timestamps.front() {
            if now.duration_since(front) > window {
                timestamps.pop_front();
            } else {
                break;
            }
        }
    }

    /// Records a request if the window has room.
    ///
    /// Returns `true` when the request was accepted (and its timestamp
    /// stored), `false` when the window is already at `max_requests`.
    pub fn record(&self) -> bool {
        let mut timestamps = self.timestamps.lock().unwrap();
        let now = Instant::now();
        self.evict_expired(&mut timestamps, now);
        if timestamps.len() < self.max_requests {
            timestamps.push_back(now);
            true
        } else {
            false
        }
    }

    /// Number of accepted requests still inside the window.
    pub fn current_count(&self) -> usize {
        let mut timestamps = self.timestamps.lock().unwrap();
        let now = Instant::now();
        self.evict_expired(&mut timestamps, now);
        timestamps.len()
    }

    /// Forgets all recorded requests.
    pub fn reset(&self) {
        self.timestamps.lock().unwrap().clear();
    }
}
/// Wraps an inner [`Runnable`] and rejects invocations once the
/// token-bucket rate limit is exhausted.
pub struct RunnableRateLimit {
    inner: Arc<dyn Runnable>,
    limiter: RateLimiter,
}

impl RunnableRateLimit {
    /// Wraps `inner` with a fresh limiter built from `config`.
    pub fn new(inner: Arc<dyn Runnable>, config: RateLimitConfig) -> Self {
        let limiter = RateLimiter::new(config);
        Self { inner, limiter }
    }

    /// Access to the underlying limiter (e.g. to inspect tokens or reset it).
    pub fn limiter(&self) -> &RateLimiter {
        &self.limiter
    }
}
#[async_trait]
impl Runnable for RunnableRateLimit {
    fn name(&self) -> &str {
        "RunnableRateLimit"
    }

    /// Runs the inner runnable, failing fast when no token is available.
    async fn invoke(&self, input: Value, config: Option<&RunnableConfig>) -> Result<Value> {
        if self.limiter.try_acquire() {
            self.inner.invoke(input, config).await
        } else {
            Err(CognisError::Other("Rate limit exceeded".to_string()))
        }
    }

    /// Invokes each input sequentially; stops at the first failure,
    /// including a rate-limit rejection partway through the batch.
    async fn batch(
        &self,
        inputs: Vec<Value>,
        config: Option<&RunnableConfig>,
    ) -> Result<Vec<Value>> {
        let mut outputs = Vec::with_capacity(inputs.len());
        for item in inputs {
            let value = self.invoke(item, config).await?;
            outputs.push(value);
        }
        Ok(outputs)
    }

    /// Streams from the inner runnable, subject to the same token check.
    async fn stream(
        &self,
        input: Value,
        config: Option<&RunnableConfig>,
    ) -> Result<RunnableStream> {
        if self.limiter.try_acquire() {
            self.inner.stream(input, config).await
        } else {
            Err(CognisError::Other("Rate limit exceeded".to_string()))
        }
    }
}
/// Wraps an inner [`Runnable`] and enforces a minimum interval between
/// consecutive invocations by sleeping before delegating.
pub struct RunnableThrottle {
    inner: Arc<dyn Runnable>,
    min_interval: Duration,
    // Scheduled time of the most recently reserved invocation slot. May be
    // in the future while a caller is still sleeping toward its slot.
    last_invocation: Mutex<Option<Instant>>,
}

impl RunnableThrottle {
    /// Creates a throttle with the interval given in milliseconds.
    pub fn new(inner: Arc<dyn Runnable>, min_interval_ms: u64) -> Self {
        Self::with_duration(inner, Duration::from_millis(min_interval_ms))
    }

    /// Creates a throttle with an explicit [`Duration`] interval.
    pub fn with_duration(inner: Arc<dyn Runnable>, min_interval: Duration) -> Self {
        Self {
            inner,
            min_interval,
            last_invocation: Mutex::new(None),
        }
    }

    /// Sleeps until this call's reserved slot, if one is needed.
    ///
    /// Each caller reserves the slot `min_interval` after the previously
    /// reserved slot (or `now` for the first call, or after an idle gap).
    ///
    /// Fix: the previous implementation computed
    /// `now.duration_since(last_time)` where `last_time` could be a FUTURE
    /// instant reserved by a concurrent caller. `Instant::duration_since`
    /// saturates to zero in that case (and panicked before Rust 1.60), so a
    /// concurrent caller slept only `min_interval` from its own `now` instead
    /// of from the reserved slot, under-spacing invocations. Reserving
    /// `last + min_interval` directly keeps concurrent callers correctly
    /// chained.
    async fn wait_if_needed(&self) {
        // Compute and reserve the slot under the lock, then release the lock
        // before sleeping (never hold a std::sync::Mutex guard across await).
        let sleep_duration = {
            let mut last = self.last_invocation.lock().unwrap();
            let now = Instant::now();
            let slot = match *last {
                Some(prev) => {
                    let next_allowed = prev + self.min_interval;
                    if next_allowed > now {
                        next_allowed
                    } else {
                        // Interval already elapsed; run immediately.
                        now
                    }
                }
                // First invocation runs immediately.
                None => now,
            };
            *last = Some(slot);
            // None/zero when the slot is `now`; Some(delay) when it is ahead.
            slot.checked_duration_since(now)
        };
        match sleep_duration {
            Some(d) if !d.is_zero() => tokio::time::sleep(d).await,
            _ => {}
        }
    }
}
#[async_trait]
impl Runnable for RunnableThrottle {
    fn name(&self) -> &str {
        "RunnableThrottle"
    }

    /// Waits out the throttle interval, then delegates to the inner runnable.
    async fn invoke(&self, input: Value, config: Option<&RunnableConfig>) -> Result<Value> {
        self.wait_if_needed().await;
        self.inner.invoke(input, config).await
    }

    /// Invokes each input in order; the per-call throttle applies to each one.
    async fn batch(
        &self,
        inputs: Vec<Value>,
        config: Option<&RunnableConfig>,
    ) -> Result<Vec<Value>> {
        let mut outputs = Vec::with_capacity(inputs.len());
        for item in inputs {
            let value = self.invoke(item, config).await?;
            outputs.push(value);
        }
        Ok(outputs)
    }

    /// Waits out the throttle interval, then streams from the inner runnable.
    async fn stream(
        &self,
        input: Value,
        config: Option<&RunnableConfig>,
    ) -> Result<RunnableStream> {
        self.wait_if_needed().await;
        self.inner.stream(input, config).await
    }
}
/// Shared counter behind the limiter; guards decrement it on drop.
#[derive(Debug)]
struct ConcurrencyLimiterInner {
    active: usize,
}

/// Caps the number of simultaneously outstanding operations.
///
/// Cloning is cheap; clones share the same underlying counter.
#[derive(Debug, Clone)]
pub struct ConcurrencyLimiter {
    inner: Arc<Mutex<ConcurrencyLimiterInner>>,
    max_concurrent: usize,
}

impl ConcurrencyLimiter {
    /// Creates a limiter permitting at most `max_concurrent` active guards.
    pub fn new(max_concurrent: usize) -> Self {
        let inner = Arc::new(Mutex::new(ConcurrencyLimiterInner { active: 0 }));
        Self {
            inner,
            max_concurrent,
        }
    }

    /// Claims a slot if one is free.
    ///
    /// Returns an RAII [`ConcurrencyGuard`] that releases the slot when
    /// dropped, or `None` when all slots are taken.
    pub fn try_acquire(&self) -> Option<ConcurrencyGuard> {
        let mut shared = self.inner.lock().unwrap();
        if shared.active >= self.max_concurrent {
            return None;
        }
        shared.active += 1;
        Some(ConcurrencyGuard {
            inner: Arc::clone(&self.inner),
        })
    }

    /// Number of slots currently held.
    pub fn active_count(&self) -> usize {
        self.inner.lock().unwrap().active
    }

    /// Configured maximum number of concurrent slots.
    pub fn max_concurrent(&self) -> usize {
        self.max_concurrent
    }
}

/// RAII handle for one concurrency slot; dropping it frees the slot.
#[derive(Debug)]
pub struct ConcurrencyGuard {
    inner: Arc<Mutex<ConcurrencyLimiterInner>>,
}

impl Drop for ConcurrencyGuard {
    fn drop(&mut self) {
        self.inner.lock().unwrap().active -= 1;
    }
}
// Unit tests for the rate-limiting primitives and Runnable wrappers.
// Timing-sensitive tests use generous tolerances (e.g. asserting >= 60ms
// for an 80ms throttle) to reduce flakiness on loaded CI machines.
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
use std::sync::atomic::{AtomicUsize, Ordering};
// Test double: a Runnable that returns its input unchanged.
struct Echo;
#[async_trait]
impl Runnable for Echo {
fn name(&self) -> &str {
"Echo"
}
async fn invoke(&self, input: Value, _config: Option<&RunnableConfig>) -> Result<Value> {
Ok(input)
}
}
// Test double: a Runnable that counts how many times invoke() actually ran,
// used to verify that rejected calls never reach the inner runnable.
struct Counter {
count: AtomicUsize,
}
impl Counter {
fn new() -> Self {
Self {
count: AtomicUsize::new(0),
}
}
fn count(&self) -> usize {
self.count.load(Ordering::SeqCst)
}
}
#[async_trait]
impl Runnable for Counter {
fn name(&self) -> &str {
"Counter"
}
async fn invoke(&self, input: Value, _config: Option<&RunnableConfig>) -> Result<Value> {
self.count.fetch_add(1, Ordering::SeqCst);
Ok(input)
}
}
// ----- RateLimitConfig -----
#[test]
fn test_config_new() {
let c = RateLimitConfig::new(10, 1000);
assert_eq!(c.max_requests, 10);
assert_eq!(c.window_duration_ms, 1000);
assert!(c.burst_limit.is_none());
}
#[test]
fn test_config_with_burst_limit() {
let c = RateLimitConfig::new(10, 1000).with_burst_limit(20);
assert_eq!(c.burst_limit, Some(20));
}
#[test]
fn test_config_builder_chain() {
let c = RateLimitConfig::new(5, 500).with_burst_limit(10);
assert_eq!(c.max_requests, 5);
assert_eq!(c.window_duration_ms, 500);
assert_eq!(c.burst_limit, Some(10));
}
#[test]
fn test_config_clone() {
let c = RateLimitConfig::new(10, 1000).with_burst_limit(15);
let c2 = c.clone();
assert_eq!(c2.max_requests, c.max_requests);
assert_eq!(c2.window_duration_ms, c.window_duration_ms);
assert_eq!(c2.burst_limit, c.burst_limit);
}
#[test]
fn test_config_debug() {
let c = RateLimitConfig::new(10, 1000).with_burst_limit(15);
let s = format!("{:?}", c);
assert!(s.contains("10"));
assert!(s.contains("1000"));
assert!(s.contains("15"));
}
// ----- RateLimiter (token bucket) -----
// Long 60s windows keep refill out of the picture for the exhaustion tests.
#[test]
fn test_limiter_acquire_within_limit() {
let limiter = RateLimiter::new(RateLimitConfig::new(5, 60_000));
for _ in 0..5 {
assert!(limiter.try_acquire());
}
}
#[test]
fn test_limiter_acquire_exhaustion() {
let limiter = RateLimiter::new(RateLimitConfig::new(3, 60_000));
assert!(limiter.try_acquire());
assert!(limiter.try_acquire());
assert!(limiter.try_acquire());
assert!(!limiter.try_acquire());
assert!(!limiter.try_acquire());
}
#[test]
fn test_limiter_available_tokens() {
let limiter = RateLimiter::new(RateLimitConfig::new(5, 60_000));
assert_eq!(limiter.available_tokens(), 5);
limiter.try_acquire();
assert_eq!(limiter.available_tokens(), 4);
limiter.try_acquire();
limiter.try_acquire();
assert_eq!(limiter.available_tokens(), 2);
}
#[test]
fn test_limiter_reset() {
let limiter = RateLimiter::new(RateLimitConfig::new(3, 60_000));
limiter.try_acquire();
limiter.try_acquire();
limiter.try_acquire();
assert_eq!(limiter.available_tokens(), 0);
limiter.reset();
assert_eq!(limiter.available_tokens(), 3);
}
#[test]
fn test_limiter_config_accessor() {
let limiter = RateLimiter::new(RateLimitConfig::new(10, 2000));
assert_eq!(limiter.config().max_requests, 10);
assert_eq!(limiter.config().window_duration_ms, 2000);
}
// Burst limit raises the bucket capacity above max_requests.
#[test]
fn test_limiter_burst_limit() {
let limiter = RateLimiter::new(RateLimitConfig::new(3, 60_000).with_burst_limit(5));
assert_eq!(limiter.available_tokens(), 5);
for _ in 0..5 {
assert!(limiter.try_acquire());
}
assert!(!limiter.try_acquire());
}
// With a 100ms window, sleeping 120ms must credit at least one token.
#[tokio::test]
async fn test_limiter_refill_over_time() {
let limiter = RateLimiter::new(RateLimitConfig::new(10, 100));
for _ in 0..10 {
limiter.try_acquire();
}
assert_eq!(limiter.available_tokens(), 0);
tokio::time::sleep(Duration::from_millis(120)).await;
assert!(limiter.available_tokens() > 0);
}
// Degenerate configs: zero rate never grants; zero window never refills.
#[test]
fn test_limiter_zero_max_requests() {
let limiter = RateLimiter::new(RateLimitConfig::new(0, 1000));
assert_eq!(limiter.available_tokens(), 0);
assert!(!limiter.try_acquire());
}
#[test]
fn test_limiter_zero_window() {
let limiter = RateLimiter::new(RateLimitConfig::new(5, 0));
assert_eq!(limiter.available_tokens(), 5);
for _ in 0..5 {
assert!(limiter.try_acquire());
}
assert!(!limiter.try_acquire());
}
#[test]
fn test_limiter_immediate_reset_and_reuse() {
let limiter = RateLimiter::new(RateLimitConfig::new(2, 60_000));
limiter.try_acquire();
limiter.try_acquire();
assert!(!limiter.try_acquire());
limiter.reset();
assert!(limiter.try_acquire());
assert!(limiter.try_acquire());
assert!(!limiter.try_acquire());
}
#[test]
fn test_burst_limit_capacity() {
let limiter = RateLimiter::new(RateLimitConfig::new(5, 60_000).with_burst_limit(8));
assert_eq!(limiter.available_tokens(), 8);
for i in 0..8 {
assert!(limiter.try_acquire(), "Token {} should be available", i);
}
assert!(!limiter.try_acquire());
}
// ----- SlidingWindowCounter -----
#[test]
fn test_sliding_window_within_limit() {
let counter = SlidingWindowCounter::new(5, 60_000);
for _ in 0..5 {
assert!(counter.record());
}
assert_eq!(counter.current_count(), 5);
}
#[test]
fn test_sliding_window_exceeds_limit() {
let counter = SlidingWindowCounter::new(3, 60_000);
assert!(counter.record());
assert!(counter.record());
assert!(counter.record());
assert!(!counter.record());
assert_eq!(counter.current_count(), 3);
}
#[test]
fn test_sliding_window_reset() {
let counter = SlidingWindowCounter::new(3, 60_000);
counter.record();
counter.record();
counter.record();
assert_eq!(counter.current_count(), 3);
counter.reset();
assert_eq!(counter.current_count(), 0);
assert!(counter.record());
}
// Entries older than the 50ms window must be evicted after a 70ms sleep.
#[tokio::test]
async fn test_sliding_window_expiry() {
let counter = SlidingWindowCounter::new(2, 50);
assert!(counter.record());
assert!(counter.record());
assert!(!counter.record());
tokio::time::sleep(Duration::from_millis(70)).await;
assert_eq!(counter.current_count(), 0);
assert!(counter.record());
}
#[test]
fn test_sliding_window_zero_limit() {
let counter = SlidingWindowCounter::new(0, 1000);
assert!(!counter.record());
assert_eq!(counter.current_count(), 0);
}
#[test]
fn test_sliding_window_very_large_window() {
let counter = SlidingWindowCounter::new(100, 3_600_000);
for _ in 0..100 {
assert!(counter.record());
}
assert!(!counter.record());
assert_eq!(counter.current_count(), 100);
}
// ----- RunnableRateLimit wrapper -----
#[tokio::test]
async fn test_runnable_rate_limit_allows_within_limit() {
let inner = Arc::new(Echo) as Arc<dyn Runnable>;
let config = RateLimitConfig::new(5, 60_000);
let limited = RunnableRateLimit::new(inner, config);
for i in 0..5 {
let r = limited.invoke(json!(i), None).await;
assert!(r.is_ok(), "Request {} should succeed", i);
assert_eq!(r.unwrap(), json!(i));
}
}
#[tokio::test]
async fn test_runnable_rate_limit_rejects_over_limit() {
let inner = Arc::new(Echo) as Arc<dyn Runnable>;
let config = RateLimitConfig::new(2, 60_000);
let limited = RunnableRateLimit::new(inner, config);
assert!(limited.invoke(json!(1), None).await.is_ok());
assert!(limited.invoke(json!(2), None).await.is_ok());
let r = limited.invoke(json!(3), None).await;
assert!(r.is_err());
assert!(format!("{}", r.unwrap_err()).contains("Rate limit exceeded"));
}
#[tokio::test]
async fn test_runnable_rate_limit_name() {
let inner = Arc::new(Echo) as Arc<dyn Runnable>;
let limited = RunnableRateLimit::new(inner, RateLimitConfig::new(5, 1000));
assert_eq!(limited.name(), "RunnableRateLimit");
}
// batch() stops at the first rejection: exactly 3 of 5 inputs reach the
// inner runnable before the limit trips.
#[tokio::test]
async fn test_runnable_rate_limit_batch_respects_limit() {
let counter = Arc::new(Counter::new());
let limited = RunnableRateLimit::new(
counter.clone() as Arc<dyn Runnable>,
RateLimitConfig::new(3, 60_000),
);
let inputs = vec![json!(1), json!(2), json!(3), json!(4), json!(5)];
let r = limited.batch(inputs, None).await;
assert!(r.is_err());
assert_eq!(counter.count(), 3);
}
#[tokio::test]
async fn test_runnable_rate_limit_stream_rejects() {
let inner = Arc::new(Echo) as Arc<dyn Runnable>;
let limited = RunnableRateLimit::new(inner, RateLimitConfig::new(1, 60_000));
assert!(limited.stream(json!(1), None).await.is_ok());
assert!(limited.stream(json!(2), None).await.is_err());
}
#[tokio::test]
async fn test_runnable_rate_limit_limiter_accessor() {
let inner = Arc::new(Echo) as Arc<dyn Runnable>;
let limited = RunnableRateLimit::new(inner, RateLimitConfig::new(10, 1000));
assert_eq!(limited.limiter().config().max_requests, 10);
}
#[tokio::test]
async fn test_runnable_rate_limit_single_request() {
let inner = Arc::new(Echo) as Arc<dyn Runnable>;
let limited = RunnableRateLimit::new(inner, RateLimitConfig::new(1, 60_000));
let r = limited.invoke(json!("only one"), None).await;
assert!(r.is_ok());
assert_eq!(r.unwrap(), json!("only one"));
assert!(limited.invoke(json!("too many"), None).await.is_err());
}
// ----- RunnableThrottle wrapper -----
// The first call has no predecessor, so it must not be delayed.
#[tokio::test]
async fn test_throttle_first_call_immediate() {
let inner = Arc::new(Echo) as Arc<dyn Runnable>;
let throttled = RunnableThrottle::new(inner, 200);
let start = Instant::now();
let r = throttled.invoke(json!("fast"), None).await;
let elapsed = start.elapsed();
assert!(r.is_ok());
assert_eq!(r.unwrap(), json!("fast"));
assert!(
elapsed < Duration::from_millis(50),
"First call should be immediate, took {:?}",
elapsed
);
}
// Lower bound of 60ms (not the full 80ms) leaves slack for timer jitter.
#[tokio::test]
async fn test_throttle_enforces_min_interval() {
let inner = Arc::new(Echo) as Arc<dyn Runnable>;
let throttled = RunnableThrottle::new(inner, 80);
throttled.invoke(json!(1), None).await.unwrap();
let start = Instant::now();
let r = throttled.invoke(json!(2), None).await;
let elapsed = start.elapsed();
assert!(r.is_ok());
assert!(
elapsed >= Duration::from_millis(60),
"Expected throttle delay of ~80ms, got {:?}",
elapsed
);
}
// Sleeping past the interval means the next call should go through at once.
#[tokio::test]
async fn test_throttle_no_delay_after_interval() {
let inner = Arc::new(Echo) as Arc<dyn Runnable>;
let throttled = RunnableThrottle::new(inner, 30);
throttled.invoke(json!(1), None).await.unwrap();
tokio::time::sleep(Duration::from_millis(50)).await;
let start = Instant::now();
let r = throttled.invoke(json!(2), None).await;
let elapsed = start.elapsed();
assert!(r.is_ok());
assert!(
elapsed < Duration::from_millis(20),
"Should not delay after interval has passed, took {:?}",
elapsed
);
}
#[tokio::test]
async fn test_throttle_name() {
let inner = Arc::new(Echo) as Arc<dyn Runnable>;
let throttled = RunnableThrottle::new(inner, 100);
assert_eq!(throttled.name(), "RunnableThrottle");
}
#[tokio::test]
async fn test_throttle_zero_interval() {
let inner = Arc::new(Echo) as Arc<dyn Runnable>;
let throttled = RunnableThrottle::new(inner, 0);
for i in 0..10 {
let r = throttled.invoke(json!(i), None).await;
assert!(r.is_ok());
}
}
// Three sequential batch items with a 20ms interval imply >= ~40ms of delay;
// 30ms lower bound allows for timer slack.
#[tokio::test]
async fn test_throttle_batch() {
let counter = Arc::new(Counter::new());
let throttled = RunnableThrottle::new(counter.clone() as Arc<dyn Runnable>, 20);
let inputs = vec![json!(1), json!(2), json!(3)];
let start = Instant::now();
let results = throttled.batch(inputs, None).await;
let elapsed = start.elapsed();
assert!(results.is_ok());
assert_eq!(results.unwrap().len(), 3);
assert_eq!(counter.count(), 3);
assert!(
elapsed >= Duration::from_millis(30),
"Expected throttle delays, got {:?}",
elapsed
);
}
#[tokio::test]
async fn test_throttle_with_duration() {
let inner = Arc::new(Echo) as Arc<dyn Runnable>;
let throttled = RunnableThrottle::with_duration(inner, Duration::from_millis(50));
throttled.invoke(json!(1), None).await.unwrap();
let start = Instant::now();
throttled.invoke(json!(2), None).await.unwrap();
let elapsed = start.elapsed();
assert!(
elapsed >= Duration::from_millis(30),
"Expected ~50ms delay, got {:?}",
elapsed
);
}
// ----- ConcurrencyLimiter / ConcurrencyGuard -----
#[test]
fn test_concurrency_limiter_basic() {
let limiter = ConcurrencyLimiter::new(2);
assert_eq!(limiter.max_concurrent(), 2);
assert_eq!(limiter.active_count(), 0);
let _g1 = limiter.try_acquire().unwrap();
assert_eq!(limiter.active_count(), 1);
let _g2 = limiter.try_acquire().unwrap();
assert_eq!(limiter.active_count(), 2);
assert!(limiter.try_acquire().is_none());
}
// Guard dropped at end of inner scope frees the slot.
#[test]
fn test_concurrency_limiter_release_on_drop() {
let limiter = ConcurrencyLimiter::new(1);
{
let _g = limiter.try_acquire().unwrap();
assert_eq!(limiter.active_count(), 1);
assert!(limiter.try_acquire().is_none());
}
assert_eq!(limiter.active_count(), 0);
assert!(limiter.try_acquire().is_some());
}
#[test]
fn test_concurrency_guard_raii() {
let limiter = ConcurrencyLimiter::new(3);
let g1 = limiter.try_acquire().unwrap();
let g2 = limiter.try_acquire().unwrap();
let g3 = limiter.try_acquire().unwrap();
assert_eq!(limiter.active_count(), 3);
assert!(limiter.try_acquire().is_none());
drop(g2);
assert_eq!(limiter.active_count(), 2);
let _g4 = limiter.try_acquire().unwrap();
assert_eq!(limiter.active_count(), 3);
drop(g1);
drop(g3);
assert_eq!(limiter.active_count(), 1);
}
#[test]
fn test_concurrency_limiter_zero() {
let limiter = ConcurrencyLimiter::new(0);
assert!(limiter.try_acquire().is_none());
assert_eq!(limiter.active_count(), 0);
}
// Clones share the same counter, so a clone observes the original's guard.
#[test]
fn test_concurrency_limiter_clone() {
let limiter = ConcurrencyLimiter::new(2);
let cloned = limiter.clone();
let _g = limiter.try_acquire().unwrap();
assert_eq!(cloned.active_count(), 1);
}
#[test]
fn test_concurrency_limiter_large_limit() {
let limiter = ConcurrencyLimiter::new(1000);
let mut guards = Vec::new();
for _ in 0..1000 {
guards.push(limiter.try_acquire().unwrap());
}
assert_eq!(limiter.active_count(), 1000);
assert!(limiter.try_acquire().is_none());
guards.clear();
assert_eq!(limiter.active_count(), 0);
}
#[test]
fn test_concurrency_guard_multiple_drop() {
let limiter = ConcurrencyLimiter::new(5);
let guards: Vec<_> = (0..5).map(|_| limiter.try_acquire().unwrap()).collect();
assert_eq!(limiter.active_count(), 5);
drop(guards);
assert_eq!(limiter.active_count(), 0);
let _g = limiter.try_acquire().unwrap();
assert_eq!(limiter.active_count(), 1);
}
}