use crate::error::LlmError;
use crate::llm::{Capabilities, ChatRequest, ChatResponse, ChunkStream, Embedding, LlmClient};
use async_trait::async_trait;
use parking_lot::Mutex;
use std::sync::Arc;
use tokio::time::{Duration, Instant};
/// Limits enforced by [`RateLimitedClient`].
///
/// The requests-per-minute limit always applies. The tokens-per-minute limit is
/// opt-in via [`RateLimitConfig::with_tokens_per_minute`]. A limit of `0` is
/// treated as `1` by the limiter.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Maximum number of calls per minute.
    pub requests_per_minute: u32,
    /// Maximum number of (estimated) tokens per minute; `None` disables the token limit.
    pub tokens_per_minute: Option<u32>,
}

impl RateLimitConfig {
    /// Builds a config that limits only requests per minute.
    pub fn new(requests_per_minute: u32) -> Self {
        Self {
            tokens_per_minute: None,
            requests_per_minute,
        }
    }

    /// Returns this config with a tokens-per-minute limit added (or replaced).
    pub fn with_tokens_per_minute(self, tokens_per_minute: u32) -> Self {
        Self {
            tokens_per_minute: Some(tokens_per_minute),
            ..self
        }
    }
}
/// Token bucket that refills continuously over a one-minute window.
///
/// The bucket starts full with `capacity` tokens and regains `capacity / 60`
/// tokens per second, never holding more than `capacity`.
#[derive(Debug)]
struct Bucket {
    /// Tokens currently available (fractional because refill is continuous).
    tokens: f64,
    /// Maximum number of tokens the bucket can hold.
    capacity: f64,
    /// Refill rate in tokens per second (`capacity / 60`).
    refill_per_sec: f64,
    /// When tokens were last credited by `refill`.
    last_refill: Instant,
}

impl Bucket {
    /// Shortest wait ever reported by [`Bucket::try_consume`].
    ///
    /// Prevents a tiny deficit from rounding to a zero `Duration`, which would
    /// let a caller's retry loop spin without the clock advancing (for example
    /// under tokio's paused test clock).
    const MIN_WAIT: Duration = Duration::from_millis(1);

    /// Creates a full bucket for `capacity` tokens per minute.
    ///
    /// A `capacity` of `0` is treated as `1` so the refill rate is never zero.
    fn new(capacity: u32) -> Self {
        let cap = f64::from(capacity.max(1));
        Self {
            tokens: cap,
            capacity: cap,
            refill_per_sec: cap / 60.0,
            last_refill: Instant::now(),
        }
    }

    /// Credits the tokens earned since the last refill, capped at `capacity`.
    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed = now
            .saturating_duration_since(self.last_refill)
            .as_secs_f64();
        if elapsed > 0.0 {
            let added = elapsed * self.refill_per_sec;
            self.tokens = (self.tokens + added).min(self.capacity);
            self.last_refill = now;
        }
    }

    /// Tries to take `n` tokens from the bucket.
    ///
    /// A request larger than `capacity` is clamped to `capacity`. It waits until
    /// the bucket is full and then drains it. Without the clamp the bucket could
    /// never hold enough tokens, and a retrying caller would wait forever.
    ///
    /// Returns `Ok(())` when the tokens were taken. Otherwise returns `Err(wait)`
    /// with how long to wait before retrying: at least [`Self::MIN_WAIT`] and at
    /// most about one minute.
    fn try_consume(&mut self, n: f64) -> Result<(), Duration> {
        self.refill();
        let n = n.min(self.capacity);
        if self.tokens >= n {
            self.tokens -= n;
            return Ok(());
        }
        // `needed <= capacity`, so the wait is bounded by one full refill (60 s).
        let needed = n - self.tokens;
        let wait = Duration::from_secs_f64(needed / self.refill_per_sec);
        Err(wait.max(Self::MIN_WAIT))
    }
}
/// [`LlmClient`] decorator that throttles calls to an inner client using
/// token buckets.
///
/// Waiters are not queued. After sleeping, each one simply retries, so the
/// order in which concurrent callers are admitted is not guaranteed.
pub struct RateLimitedClient {
    /// Client that serves requests once the limits allow them.
    inner: Arc<dyn LlmClient>,
    /// Requests-per-minute bucket.
    rpm: Mutex<Bucket>,
    /// Tokens-per-minute bucket; `None` when no token limit was configured.
    tpm: Option<Mutex<Bucket>>,
}

impl RateLimitedClient {
    /// Wraps `inner` with the limits described by `config`.
    pub fn new(inner: Arc<dyn LlmClient>, config: RateLimitConfig) -> Self {
        Self {
            inner,
            rpm: Mutex::new(Bucket::new(config.requests_per_minute)),
            tpm: config
                .tokens_per_minute
                .map(|limit| Mutex::new(Bucket::new(limit))),
        }
    }

    /// Sleeps until `n` tokens can be taken from `bucket`, then takes them.
    ///
    /// The lock is released before every sleep so other callers can proceed.
    async fn wait_for(bucket: &Mutex<Bucket>, n: f64) {
        loop {
            let verdict = {
                let mut guard = bucket.lock();
                guard.try_consume(n)
            };
            let Err(delay) = verdict else { return };
            tokio::time::sleep(delay).await;
        }
    }

    /// Blocks until one request slot is available.
    async fn acquire_rpm(&self) {
        Self::wait_for(&self.rpm, 1.0).await;
    }

    /// Blocks until `tokens` (at least 1) fit in the token budget.
    ///
    /// Returns immediately when no token limit is configured.
    async fn acquire_tpm(&self, tokens: u32) {
        if let Some(tpm) = self.tpm.as_ref() {
            Self::wait_for(tpm, f64::from(tokens.max(1))).await;
        }
    }
}
/// Rough upper estimate of the tokens a chat request will consume.
///
/// Prompt size is approximated as the total byte length of all message
/// contents divided by four. The completion size is `max_tokens`, or 256 when
/// that is unset. The sum saturates at `u32::MAX` instead of overflowing.
fn estimate_tokens(req: &ChatRequest) -> u32 {
    const BYTES_PER_TOKEN: usize = 4;
    const DEFAULT_COMPLETION_TOKENS: u32 = 256;

    let prompt_bytes = req
        .messages
        .iter()
        .map(|message| message.content.len())
        .sum::<usize>();
    let prompt_tokens: u32 = (prompt_bytes / BYTES_PER_TOKEN)
        .try_into()
        .unwrap_or(u32::MAX);
    let completion_tokens = req.max_tokens.unwrap_or(DEFAULT_COMPLETION_TOKENS);
    prompt_tokens.saturating_add(completion_tokens)
}
#[async_trait]
impl LlmClient for RateLimitedClient {
    /// Reports the wrapped client's name.
    fn name(&self) -> &str {
        self.inner.name()
    }

    /// Reports the wrapped client's capabilities.
    fn capabilities(&self) -> &Capabilities {
        self.inner.capabilities()
    }

    /// Waits for a request slot and then for the estimated token budget (when a
    /// token limit is configured), then forwards to the inner client.
    async fn complete(&self, req: ChatRequest) -> Result<ChatResponse, LlmError> {
        self.acquire_rpm().await;
        self.acquire_tpm(estimate_tokens(&req)).await;
        self.inner.complete(req).await
    }

    /// Same gating as `complete`, applied once before the stream is opened.
    async fn stream(&self, req: ChatRequest) -> Result<ChunkStream, LlmError> {
        self.acquire_rpm().await;
        self.acquire_tpm(estimate_tokens(&req)).await;
        self.inner.stream(req).await
    }

    /// Waits for a request slot only. Embedding calls are not charged against
    /// the token limit.
    async fn embed(&self, texts: &[String]) -> Result<Vec<Embedding>, LlmError> {
        self.acquire_rpm().await;
        self.inner.embed(texts).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::{FakeLlmClient, FakeLlmStep};
    use crate::ChatRequest;
    use std::sync::Arc;

    /// Fake inner client scripted with six `"ok"` text steps.
    fn fake() -> Arc<dyn LlmClient> {
        let steps: Vec<_> = (0..6).map(|_| FakeLlmStep::Text("ok".into())).collect();
        Arc::new(FakeLlmClient::new("fake").with_steps(steps))
    }

    /// Chat request with no messages and default settings.
    fn empty_request() -> ChatRequest {
        ChatRequest::new(vec![])
    }

    /// Chat request with no messages and `max_tokens` set.
    fn request_with_max_tokens(max_tokens: u32) -> ChatRequest {
        let mut req = empty_request();
        req.max_tokens = Some(max_tokens);
        req
    }

    /// Time elapsed since `start` according to tokio's clock.
    fn since(start: Instant) -> Duration {
        Instant::now().saturating_duration_since(start)
    }

    #[tokio::test(start_paused = true)]
    async fn calls_within_rpm_capacity_do_not_block() {
        let limited = RateLimitedClient::new(fake(), RateLimitConfig::new(60));
        let start = Instant::now();
        for _ in 0..5 {
            limited.complete(empty_request()).await.unwrap();
        }
        assert!(
            since(start) < Duration::from_secs(1),
            "expected near-instant completion under capacity"
        );
    }

    #[tokio::test(start_paused = true)]
    async fn blocks_when_rpm_bucket_drained_then_refills() {
        let limited = RateLimitedClient::new(fake(), RateLimitConfig::new(2));
        for _ in 0..2 {
            limited.complete(empty_request()).await.unwrap();
        }
        let start = Instant::now();
        limited.complete(empty_request()).await.unwrap();
        let waited = since(start);
        assert!(
            waited >= Duration::from_secs(29),
            "expected ≥29s wait, got {waited:?}"
        );
    }

    #[tokio::test(start_paused = true)]
    async fn tpm_disabled_lets_large_requests_through() {
        let limited = RateLimitedClient::new(fake(), RateLimitConfig::new(60));
        let req = request_with_max_tokens(100_000);
        let start = Instant::now();
        limited.complete(req).await.unwrap();
        assert!(since(start) < Duration::from_secs(1));
    }

    #[tokio::test(start_paused = true)]
    async fn tpm_gates_independently_of_rpm() {
        let config = RateLimitConfig::new(1000).with_tokens_per_minute(1000);
        let limited = RateLimitedClient::new(fake(), config);
        let req = request_with_max_tokens(600);
        limited.complete(req.clone()).await.unwrap();
        let start = Instant::now();
        limited.complete(req).await.unwrap();
        let waited = since(start);
        assert!(
            waited >= Duration::from_secs(11),
            "expected TPM gate to delay ≥11s, got {waited:?}"
        );
    }

    #[tokio::test(start_paused = true)]
    async fn concurrent_callers_serialize_through_bucket() {
        let limited = Arc::new(RateLimitedClient::new(fake(), RateLimitConfig::new(2)));
        let start = Instant::now();
        let handles: Vec<_> = (0..4)
            .map(|_| {
                let client = Arc::clone(&limited);
                tokio::spawn(async move {
                    client.complete(empty_request()).await.unwrap();
                })
            })
            .collect();
        for handle in handles {
            handle.await.unwrap();
        }
        let waited = since(start);
        assert!(
            waited >= Duration::from_secs(59),
            "expected ≥59s, got {waited:?}"
        );
    }

    #[tokio::test(start_paused = true)]
    async fn refill_caps_at_capacity() {
        let limited = RateLimitedClient::new(fake(), RateLimitConfig::new(10));
        for _ in 0..3 {
            limited.complete(empty_request()).await.unwrap();
        }
        tokio::time::sleep(Duration::from_secs(300)).await;
        let start = Instant::now();
        for _ in 0..10 {
            // The fake may run out of scripted steps; only the timing matters here.
            let _ = limited.complete(empty_request()).await;
        }
        assert!(
            since(start) < Duration::from_secs(1),
            "10 calls within capacity must not block after a long idle"
        );
    }
}