use std::sync::Arc;
use std::time::{Duration, Instant};
use async_trait::async_trait;
use llmkit_core::{
ChatRequest, ChatResponse, ChatStream, CostEstimate, EmbedRequest, EmbedResponse, LlmProvider,
LlmResult,
};
use tokio::sync::Mutex;
use crate::layer::LlmLayer;
/// Configuration for a token-bucket rate limiter.
///
/// Holds only the two numbers needed to build a bucket: how many tokens it
/// can hold and how fast it refills.
#[derive(Debug, Clone, Copy)]
pub struct RateLimitLayer {
    /// Maximum number of tokens the bucket can hold.
    capacity: f64,
    /// Tokens added back to the bucket per second.
    refill_per_sec: f64,
}

impl RateLimitLayer {
    /// Builds a limiter that allows `capacity` tokens per `window`.
    ///
    /// The refill rate is `capacity / window` in tokens per second. The
    /// window length is floored at `f64::MIN_POSITIVE` seconds so a zero
    /// `window` never divides by zero.
    pub fn token_bucket(capacity: u64, window: Duration) -> Self {
        let capacity = capacity as f64;
        let window_secs = f64::max(window.as_secs_f64(), f64::MIN_POSITIVE);
        Self {
            capacity,
            refill_per_sec: capacity / window_secs,
        }
    }
}
impl LlmLayer for RateLimitLayer {
    type Provider = RateLimit;

    /// Wraps `inner` in a [`RateLimit`] whose bucket starts full and whose
    /// refill clock starts at the moment of this call.
    fn layer(self, inner: Arc<dyn LlmProvider>) -> RateLimit {
        let Self {
            capacity,
            refill_per_sec,
        } = self;
        let full_bucket = Bucket {
            tokens: capacity,
            capacity,
            refill_per_sec,
            last: Instant::now(),
        };
        RateLimit {
            inner,
            bucket: Arc::new(Mutex::new(full_bucket)),
        }
    }
}
/// Token-bucket state.
///
/// The balance is allowed to go negative: a negative balance is debt owed by
/// callers that have already been told how long to wait, so later callers
/// queue behind them instead of sharing the same refill.
struct Bucket {
    /// Current balance; negative while granted requests are still waiting.
    tokens: f64,
    /// Upper bound on the balance, and on the cost of any single request.
    capacity: f64,
    /// Tokens credited per second of elapsed time.
    refill_per_sec: f64,
    /// Instant of the most recent refill.
    last: Instant,
}

impl Bucket {
    /// Credits tokens for the time elapsed since the last refill, capped at
    /// `capacity`, and moves the refill clock to now.
    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last).as_secs_f64();
        self.tokens = (self.tokens + elapsed * self.refill_per_sec).min(self.capacity);
        self.last = now;
    }

    /// Reserves `cost` tokens and returns how long the caller must wait
    /// before the reservation is covered.
    ///
    /// `cost` is clamped to `capacity` so an oversized request can still be
    /// served by a full bucket. The cost is always debited, even when the
    /// bucket cannot cover it yet; the returned wait is the time the refill
    /// rate needs to bring the balance back to zero. Returns
    /// `Duration::ZERO` when the balance covers the cost immediately, and
    /// `Duration::MAX` when the wait is not representable (for example a
    /// zero refill rate).
    fn time_until_available(&mut self, cost: f64) -> Duration {
        self.refill();
        let cost = cost.min(self.capacity);
        // Always debit. Resetting the balance to zero here would forget the
        // shortfall and let the next caller spend the tokens that refill
        // while this caller is still sleeping.
        self.tokens -= cost;
        if self.tokens >= 0.0 {
            return Duration::ZERO;
        }
        let wait = -self.tokens / self.refill_per_sec;
        // `from_secs_f64` panics on infinite or overflowing input; saturate
        // instead so a degenerate rate cannot crash the caller.
        Duration::try_from_secs_f64(wait).unwrap_or(Duration::MAX)
    }
}
/// Provider wrapper that delays each call until the token bucket can cover
/// the call's estimated cost, then forwards it to `inner`.
pub struct RateLimit {
    /// The wrapped provider that actually serves requests.
    inner: Arc<dyn LlmProvider>,
    /// Bucket state behind an async mutex.
    bucket: Arc<Mutex<Bucket>>,
}

impl RateLimit {
    /// Rough cost of a chat request.
    ///
    /// Adds the byte length of every message that has a text form to the
    /// byte length of the system prompt (if any), divides by four with
    /// integer truncation, and adds `max_tokens` (256 when unset).
    fn estimated_cost(req: &ChatRequest) -> f64 {
        let message_bytes: usize = req
            .messages
            .iter()
            .filter_map(|m| m.content.as_text())
            .map(|t| t.len())
            .sum();
        let system_bytes = req.system.as_deref().map_or(0, str::len);
        let prompt_estimate = ((message_bytes + system_bytes) / 4) as f64;
        let completion_budget = req.max_tokens.unwrap_or(256) as f64;
        prompt_estimate + completion_budget
    }

    /// Charges `cost` against the bucket and sleeps for whatever delay the
    /// bucket reports.
    ///
    /// The lock guard is a temporary of the `let` statement, so the mutex is
    /// released before sleeping and other callers are not blocked on it.
    async fn acquire(&self, cost: f64) {
        let delay = self.bucket.lock().await.time_until_available(cost);
        if delay.is_zero() {
            return;
        }
        tokio::time::sleep(delay).await;
    }
}
#[async_trait]
impl LlmProvider for RateLimit {
    /// Waits for the bucket to cover the request's estimated cost, then
    /// forwards the request to the inner provider.
    async fn chat(&self, req: ChatRequest) -> LlmResult<ChatResponse> {
        let cost = Self::estimated_cost(&req);
        self.acquire(cost).await;
        self.inner.chat(req).await
    }

    /// Same as [`chat`](Self::chat) but for the streaming call. The cost is
    /// charged once, up front, before the stream is opened.
    async fn chat_stream(&self, req: ChatRequest) -> LlmResult<ChatStream> {
        let cost = Self::estimated_cost(&req);
        self.acquire(cost).await;
        self.inner.chat_stream(req).await
    }

    /// Waits for the bucket to cover the batch's estimated cost (total input
    /// bytes divided by four), then forwards to the inner provider.
    async fn embed(&self, req: EmbedRequest) -> LlmResult<EmbedResponse> {
        // Sum the lengths first and divide once, as `estimated_cost` does.
        // Dividing each input separately truncates per item, so a batch of
        // inputs shorter than four bytes would be charged nothing.
        let total_bytes: usize = req.input.iter().map(|s| s.len()).sum();
        self.acquire((total_bytes / 4) as f64).await;
        self.inner.embed(req).await
    }

    /// Name reported by the inner provider.
    fn name(&self) -> &'static str {
        self.inner.name()
    }

    /// Model reported by the inner provider.
    fn model(&self) -> &str {
        self.inner.model()
    }

    /// Delegates cost estimation to the inner provider unchanged.
    fn estimate_cost(&self, req: &ChatRequest) -> Option<CostEstimate> {
        self.inner.estimate_cost(req)
    }
}