use std::sync::Arc;
use std::time::{Duration, Instant};
use async_trait::async_trait;
use tokio::sync::Mutex;
use cognis_core::Result;
use cognis_llm::chat::ChatResponse;
use super::{Middleware, MiddlewareCtx, Next};
/// Admission gate that is awaited before a request is allowed to proceed.
///
/// `acquire` returns no value: the implementations in this file wait inside
/// the call until the request fits their budget, then return.
#[async_trait]
pub trait RateLimiter: Send + Sync {
/// Waits until `estimated_tokens` units of capacity have been granted.
async fn acquire(&self, estimated_tokens: u64);
}
/// Token-bucket rate limiter: permits refill over time at a fixed rate, up to
/// a burst capacity, and each `acquire` withdraws permits from the bucket.
pub struct TokenBucket {
// Bucket state behind an async mutex so `acquire` can work through `&self`.
inner: Mutex<TokenBucketState>,
}
/// Mutable bookkeeping for a [`TokenBucket`].
struct TokenBucketState {
// Permits currently available; fractional because refill is continuous.
permits: f64,
// Ceiling that `permits` is clamped to on every refill (the burst size).
capacity: f64,
// Permits added per second of elapsed time.
rate_per_sec: f64,
// Instant at which `permits` was last brought up to date.
last_refill: Instant,
}
impl TokenBucket {
    /// Builds a bucket that starts full.
    ///
    /// `burst` is both the initial number of permits and the ceiling the
    /// bucket refills to; `rate_per_sec` is how many permits are added per
    /// second of elapsed time.
    pub fn new(rate_per_sec: f64, burst: u64) -> Self {
        let capacity = burst as f64;
        let state = TokenBucketState {
            last_refill: Instant::now(),
            rate_per_sec,
            capacity,
            permits: capacity,
        };
        Self {
            inner: Mutex::new(state),
        }
    }
}
#[async_trait]
impl RateLimiter for TokenBucket {
    /// Withdraws `estimated_tokens` permits (at least one), sleeping until the
    /// bucket has refilled enough to cover them.
    ///
    /// The balance is capped at the burst capacity, so a request larger than
    /// the capacity could never be covered and would otherwise wait forever.
    /// Such a request is granted as soon as the bucket is completely full and
    /// leaves the balance negative; later callers then wait until that debt is
    /// repaid, which preserves the long-run rate.
    ///
    /// A rate that is zero, negative or NaN is treated as "never refills": the
    /// caller keeps waiting (re-checking periodically) instead of panicking
    /// while computing the sleep duration.
    async fn acquire(&self, estimated_tokens: u64) {
        // Bounds on a single sleep between re-checks. The upper bound keeps the
        // value passed to `Duration::from_secs_f64` finite and in range; the
        // loop simply sleeps again if more time is still required.
        const MIN_WAIT_SECS: f64 = 0.001;
        const MAX_WAIT_SECS: f64 = 3600.0;

        let needed = estimated_tokens.max(1) as f64;
        loop {
            let wait = {
                let mut s = self.inner.lock().await;
                let now = Instant::now();
                let elapsed = now.duration_since(s.last_refill).as_secs_f64();
                // The comparison is false for NaN as well as for values <= 0.
                let rate = if s.rate_per_sec > 0.0 {
                    s.rate_per_sec
                } else {
                    0.0
                };
                s.permits = (s.permits + elapsed * rate).min(s.capacity);
                s.last_refill = now;
                // The bucket never holds more than `capacity`, so that is the
                // most any request can be asked to wait for.
                let threshold = needed.min(s.capacity);
                if s.permits >= threshold {
                    // May go negative for oversized requests (see above).
                    s.permits -= needed;
                    None
                } else {
                    let deficit = threshold - s.permits;
                    // `deficit / 0.0` is +inf; the clamp turns it into the
                    // maximum single wait instead of a panic.
                    let secs = (deficit / rate).max(MIN_WAIT_SECS).min(MAX_WAIT_SECS);
                    Some(Duration::from_secs_f64(secs))
                }
            };
            match wait {
                None => return,
                Some(d) => tokio::time::sleep(d).await,
            }
        }
    }
}
/// Sliding-window rate limiter: tracks the units admitted during the most
/// recent `window` and compares their sum against `capacity`.
pub struct SlidingWindowLimiter {
// Timestamped admissions plus their running total, behind an async mutex.
inner: Mutex<SlidingWindowState>,
// Budget of units that may be in the window at once.
capacity: u64,
// How long an admission keeps counting against the budget.
window: Duration,
}
/// Mutable bookkeeping for a [`SlidingWindowLimiter`].
struct SlidingWindowState {
// Admissions in arrival order: (time admitted, units charged).
events: std::collections::VecDeque<(Instant, u64)>,
// Running sum of the units held in `events`.
used: u64,
}
impl SlidingWindowLimiter {
    /// Creates a limiter with an empty history, a budget of `capacity` units
    /// and a window length of `window`.
    pub fn new(capacity: u64, window: Duration) -> Self {
        let state = SlidingWindowState {
            used: 0,
            events: std::collections::VecDeque::new(),
        };
        Self {
            window,
            capacity,
            inner: Mutex::new(state),
        }
    }
}
#[async_trait]
impl RateLimiter for SlidingWindowLimiter {
    /// Admits `estimated_tokens` units (at least one) once they fit in the
    /// current window, sleeping until the oldest recorded admission expires
    /// whenever they do not.
    ///
    /// A request larger than the whole budget can never fit, so it is admitted
    /// when the window is empty. It is still recorded, so the requests that
    /// follow wait for it to age out rather than passing unthrottled.
    async fn acquire(&self, estimated_tokens: u64) {
        let need = estimated_tokens.max(1);
        loop {
            let wait = {
                let mut s = self.inner.lock().await;
                let now = Instant::now();
                // Drop admissions that have aged out of the window.
                while let Some(&(t, n)) = s.events.front() {
                    if now.duration_since(t) < self.window {
                        break;
                    }
                    s.events.pop_front();
                    s.used = s.used.saturating_sub(n);
                }
                // Saturating so an enormous estimate cannot wrap past the
                // capacity comparison.
                let total = s.used.saturating_add(need);
                let oldest = s.events.front().map(|&(t, _)| t);
                match oldest {
                    // Over budget while something is still in the window:
                    // wait for the oldest admission to expire, then re-check.
                    Some(t) if total > self.capacity => {
                        let elapsed = now.duration_since(t);
                        // Eviction above guarantees `elapsed < window`; the
                        // fallback is purely defensive.
                        Some(
                            self.window
                                .checked_sub(elapsed)
                                .unwrap_or(Duration::from_millis(1)),
                        )
                    }
                    // Either the request fits, or the window is empty and
                    // waiting longer cannot help. Record it in both cases.
                    _ => {
                        s.events.push_back((now, need));
                        s.used = total;
                        None
                    }
                }
            };
            match wait {
                None => return,
                Some(d) => tokio::time::sleep(d).await,
            }
        }
    }
}
/// Budget limiter: keeps a running total of units spent and admits a request
/// only while the total stays within `cap`. Nothing in this type lowers the
/// total over time; only `reset` and `refund` do.
pub struct CostBasedLimiter {
// Spent counter behind an async mutex.
inner: Mutex<CostState>,
// Maximum total that `spent` may reach.
cap: u64,
}
/// Mutable bookkeeping for a [`CostBasedLimiter`].
struct CostState {
// Units charged so far and not yet refunded or reset.
spent: u64,
}
impl CostBasedLimiter {
    /// Creates a limiter with nothing spent against a budget of `cap` units.
    pub fn new(cap: u64) -> Self {
        let inner = Mutex::new(CostState { spent: 0 });
        Self { inner, cap }
    }

    /// Sets the spent counter back to zero, making the whole budget available.
    pub async fn reset(&self) {
        let mut state = self.inner.lock().await;
        state.spent = 0;
    }

    /// Gives `units` back to the budget; the counter never drops below zero.
    pub async fn refund(&self, units: u64) {
        let mut state = self.inner.lock().await;
        let remaining = state.spent.saturating_sub(units);
        state.spent = remaining;
    }

    /// Returns the number of units currently counted as spent.
    pub async fn spent(&self) -> u64 {
        let state = self.inner.lock().await;
        state.spent
    }
}
#[async_trait]
impl RateLimiter for CostBasedLimiter {
    /// Charges `estimated_tokens` units (at least one) against the budget,
    /// polling every 50 ms until the charge fits under the cap.
    ///
    /// The budget does not replenish on its own, so a blocked caller proceeds
    /// only after `reset` or `refund` frees enough room. A single charge
    /// larger than the cap never fits and therefore waits indefinitely.
    async fn acquire(&self, estimated_tokens: u64) {
        let cost = estimated_tokens.max(1);
        loop {
            {
                let mut s = self.inner.lock().await;
                // `checked_add` so a huge estimate cannot wrap around and slip
                // under the cap; an overflowing sum is simply "does not fit".
                if let Some(total) = s.spent.checked_add(cost) {
                    if total <= self.cap {
                        s.spent = total;
                        return;
                    }
                }
            }
            // The lock is released before sleeping so `reset`/`refund` can run.
            tokio::time::sleep(Duration::from_millis(50)).await;
        }
    }
}
/// Combines several limiters; a request must be granted by every one of them.
pub struct CompositeLimiter {
// Inner limiters, consulted in insertion order.
limiters: Vec<Arc<dyn RateLimiter>>,
}
impl CompositeLimiter {
    /// Creates a composite that holds no inner limiters yet.
    pub fn new() -> Self {
        let limiters = Vec::new();
        Self { limiters }
    }

    /// Builder-style append: consumes the composite and returns it with
    /// `limiter` added after the ones already present.
    pub fn push(self, limiter: Arc<dyn RateLimiter>) -> Self {
        let mut limiters = self.limiters;
        limiters.push(limiter);
        Self { limiters }
    }
}
impl Default for CompositeLimiter {
    /// Returns a composite with no inner limiters.
    fn default() -> Self {
        Self {
            limiters: Vec::new(),
        }
    }
}
#[async_trait]
impl RateLimiter for CompositeLimiter {
    /// Awaits every inner limiter one after another, in insertion order,
    /// passing the same `estimated_tokens` to each.
    ///
    /// Limiters earlier in the list have already granted (and charged) the
    /// request by the time a later one is awaited.
    async fn acquire(&self, estimated_tokens: u64) {
        for limiter in self.limiters.iter() {
            limiter.acquire(estimated_tokens).await;
        }
    }
}
/// Middleware that waits on a [`RateLimiter`] before forwarding a request.
pub struct RateLimit {
// Limiter awaited before each request is passed on.
limiter: Arc<dyn RateLimiter>,
// Computes the cost handed to the limiter from the request context.
estimator: Arc<dyn Fn(&MiddlewareCtx) -> u64 + Send + Sync>,
}
impl RateLimit {
    /// Wraps `limiter` in a middleware that uses `default_estimator` to price
    /// each request.
    pub fn new(limiter: Arc<dyn RateLimiter>) -> Self {
        let estimator: Arc<dyn Fn(&MiddlewareCtx) -> u64 + Send + Sync> =
            Arc::new(default_estimator);
        Self { limiter, estimator }
    }

    /// Builder-style override: replaces the cost estimator with `f` and
    /// returns the middleware.
    pub fn with_estimator<F>(self, f: F) -> Self
    where
        F: Fn(&MiddlewareCtx) -> u64 + Send + Sync + 'static,
    {
        Self {
            estimator: Arc::new(f),
            ..self
        }
    }
}
/// Default request cost: the total number of `char`s across the content of
/// every message in `ctx`.
fn default_estimator(ctx: &MiddlewareCtx) -> u64 {
    let mut total: u64 = 0;
    for message in ctx.messages.iter() {
        total += message.content().chars().count() as u64;
    }
    total
}
#[async_trait]
impl Middleware for RateLimit {
    /// Prices the request with the configured estimator, waits for the limiter
    /// to grant that many units, then hands `ctx` to the next stage and
    /// returns its result unchanged.
    async fn call(&self, ctx: MiddlewareCtx, next: Arc<dyn Next>) -> Result<ChatResponse> {
        let estimated_tokens: u64 = (self.estimator)(&ctx);
        self.limiter.acquire(estimated_tokens).await;
        next.invoke(ctx).await
    }

    /// Name under which this middleware identifies itself.
    fn name(&self) -> &str {
        "RateLimit"
    }
}
// Tests for the limiters and the `RateLimit` middleware. Several of them
// measure wall-clock time around real `tokio::time::sleep` calls.
#[cfg(test)]
mod tests {
use super::super::tests_util::*;
use super::*;
use crate::middleware::MiddlewarePipeline;
use cognis_core::Message;
use cognis_llm::chat::ChatOptions;
use cognis_llm::Client;
// A bucket that starts with 100 permits grants 10 without a noticeable wait.
#[tokio::test]
async fn token_bucket_acquires_immediately_when_permits_available() {
let b = TokenBucket::new(1000.0, 100);
let start = Instant::now();
b.acquire(10).await;
assert!(start.elapsed() < Duration::from_millis(100));
}
// After the first call drains all 10 permits, 5 more at 50 permits/s need
// about 100 ms of refill; the assertion uses a looser 50 ms lower bound.
#[tokio::test]
async fn token_bucket_blocks_when_drained() {
let b = TokenBucket::new(50.0, 10); b.acquire(10).await;
let start = Instant::now();
b.acquire(5).await;
assert!(start.elapsed() >= Duration::from_millis(50));
}
// With a generous bucket the middleware forwards the request and the
// provider's reply comes back unchanged.
#[tokio::test]
async fn middleware_passes_through_when_under_limit() {
let provider = make_recording_provider("ok");
let pipe = MiddlewarePipeline::new()
.push(RateLimit::new(Arc::new(TokenBucket::new(100000.0, 100))))
.build(Client::new(provider.clone()));
let r = pipe
.invoke(
vec![Message::human("hi")],
Vec::new(),
ChatOptions::default(),
)
.await
.unwrap();
assert_eq!(r.message.content(), "ok");
}
// Two 5-unit admissions fill the 10-unit budget, so the third call has to
// wait for most of the 100 ms window before it is admitted.
#[tokio::test]
async fn sliding_window_admits_until_cap_then_blocks_until_window_passes() {
let l = SlidingWindowLimiter::new(10, Duration::from_millis(100));
l.acquire(5).await;
l.acquire(5).await;
let start = Instant::now();
l.acquire(1).await;
assert!(
start.elapsed() >= Duration::from_millis(80),
"expected wait, got {:?}",
start.elapsed()
);
}
// Charges that fit under the cap are added to the spent counter.
#[tokio::test]
async fn cost_based_limiter_admits_until_cap() {
let l = CostBasedLimiter::new(100);
l.acquire(40).await;
l.acquire(40).await;
assert_eq!(l.spent().await, 80);
}
// With the budget exhausted a further acquire stays pending; `reset` frees
// the budget and the pending call then completes and is charged.
#[tokio::test]
async fn cost_based_limiter_blocks_then_unblocks_on_reset() {
let l = Arc::new(CostBasedLimiter::new(50));
l.acquire(50).await;
let l2 = l.clone();
let h = tokio::spawn(async move { l2.acquire(10).await });
tokio::time::sleep(Duration::from_millis(60)).await;
assert!(!h.is_finished(), "should be blocked while over budget");
l.reset().await;
h.await.unwrap();
assert_eq!(l.spent().await, 10);
}
// `refund` lowers the spent counter, making room for a charge that would
// not have fit before it.
#[tokio::test]
async fn cost_based_limiter_refund_releases_capacity() {
let l = CostBasedLimiter::new(100);
l.acquire(80).await;
l.refund(50).await;
assert_eq!(l.spent().await, 30);
l.acquire(60).await;
assert_eq!(l.spent().await, 90);
}
// Smoke test: a composite of two roomy limiters completes an acquire.
#[tokio::test]
async fn composite_limiter_runs_every_inner() {
let token: Arc<dyn RateLimiter> = Arc::new(TokenBucket::new(1000.0, 100));
let cost: Arc<dyn RateLimiter> = Arc::new(CostBasedLimiter::new(1000));
let comp = CompositeLimiter::new().push(token).push(cost);
comp.acquire(10).await;
}
}