Skip to main content

llmkit_tower/
rate_limit.rs

1//! Token-bucket rate limiting, per provider.
2
3use std::sync::Arc;
4use std::time::{Duration, Instant};
5
6use async_trait::async_trait;
7use llmkit_core::{
8    ChatRequest, ChatResponse, ChatStream, CostEstimate, EmbedRequest, EmbedResponse, LlmProvider,
9    LlmResult,
10};
11use tokio::sync::Mutex;
12
13use crate::layer::LlmLayer;
14
/// Layer that wraps a provider in a token-bucket rate limiter.
///
/// The bucket holds at most `capacity` tokens and regains them at a constant
/// rate, chosen so that an empty bucket is full again after the configured
/// window.
#[derive(Debug, Clone, Copy)]
pub struct RateLimitLayer {
    capacity: f64,
    refill_per_sec: f64,
}

impl RateLimitLayer {
    /// Build a limiter holding `capacity` tokens that refill completely over `window`.
    ///
    /// For example, `token_bucket(60_000, Duration::from_secs(60))` allows
    /// roughly 60k tokens per minute. A zero `window` is replaced by the
    /// smallest positive `f64` so the rate computation never divides by zero.
    pub fn token_bucket(capacity: u64, window: Duration) -> Self {
        let capacity = capacity as f64;
        let window_secs = f64::max(window.as_secs_f64(), f64::MIN_POSITIVE);
        let refill_per_sec = capacity / window_secs;
        Self {
            capacity,
            refill_per_sec,
        }
    }
}
31
32impl LlmLayer for RateLimitLayer {
33    type Provider = RateLimit;
34    fn layer(self, inner: Arc<dyn LlmProvider>) -> RateLimit {
35        RateLimit {
36            inner,
37            bucket: Arc::new(Mutex::new(Bucket {
38                tokens: self.capacity,
39                capacity: self.capacity,
40                refill_per_sec: self.refill_per_sec,
41                last: Instant::now(),
42            })),
43        }
44    }
45}
46
/// Mutable token-bucket state, shared behind a mutex by every caller of one
/// rate-limited provider.
///
/// `tokens` may be negative: a negative balance is debt owed by callers that
/// have already been granted a reservation and are sleeping until the refill
/// covers it.
struct Bucket {
    /// Current balance; negative while earlier callers are still waiting.
    tokens: f64,
    /// Maximum balance; refills never push `tokens` above this.
    capacity: f64,
    /// Tokens credited per second of elapsed time.
    refill_per_sec: f64,
    /// When `tokens` was last brought up to date by `refill`.
    last: Instant,
}

impl Bucket {
    /// Credit the tokens accrued since the last call, capped at `capacity`.
    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last).as_secs_f64();
        self.tokens = (self.tokens + elapsed * self.refill_per_sec).min(self.capacity);
        self.last = now;
    }

    /// Reserve `cost` tokens immediately and return how long the caller must
    /// wait before that reservation is covered by the refill.
    ///
    /// `cost` is clamped to `capacity` so an oversized request can still be
    /// served once the bucket is full instead of waiting forever.
    ///
    /// The tokens are deducted up front, possibly driving the balance
    /// negative. This makes concurrent callers queue behind each other. If
    /// the balance were reset to zero instead, every waiter would compute its
    /// wait from the same empty bucket, and they would all wake together and
    /// exceed the configured rate.
    fn time_until_available(&mut self, cost: f64) -> Duration {
        self.refill();
        let cost = cost.min(self.capacity);
        self.tokens -= cost;
        if self.tokens >= 0.0 {
            return Duration::ZERO;
        }
        // The outstanding debt (including earlier waiters' reservations) is
        // repaid at `refill_per_sec`.
        let wait = -self.tokens / self.refill_per_sec;
        // A zero or tiny rate can yield an infinite/huge wait that
        // `Duration::from_secs_f64` would panic on; saturate instead.
        Duration::try_from_secs_f64(wait).unwrap_or(Duration::MAX)
    }
}
78
79/// Provider produced by [`RateLimitLayer`].
80pub struct RateLimit {
81    inner: Arc<dyn LlmProvider>,
82    bucket: Arc<Mutex<Bucket>>,
83}
84
85impl RateLimit {
86    /// Estimate the token cost of a request for bucket accounting.
87    fn estimated_cost(req: &ChatRequest) -> f64 {
88        let chars: usize = req
89            .messages
90            .iter()
91            .filter_map(|m| m.content.as_text())
92            .map(|t| t.len())
93            .sum::<usize>()
94            + req.system.as_deref().map(str::len).unwrap_or(0);
95        let prompt = (chars / 4) as f64;
96        prompt + req.max_tokens.unwrap_or(256) as f64
97    }
98
99    async fn acquire(&self, cost: f64) {
100        let wait = {
101            let mut bucket = self.bucket.lock().await;
102            bucket.time_until_available(cost)
103        };
104        if !wait.is_zero() {
105            tokio::time::sleep(wait).await;
106        }
107    }
108}
109
110#[async_trait]
111impl LlmProvider for RateLimit {
112    async fn chat(&self, req: ChatRequest) -> LlmResult<ChatResponse> {
113        self.acquire(Self::estimated_cost(&req)).await;
114        self.inner.chat(req).await
115    }
116
117    async fn chat_stream(&self, req: ChatRequest) -> LlmResult<ChatStream> {
118        self.acquire(Self::estimated_cost(&req)).await;
119        self.inner.chat_stream(req).await
120    }
121
122    async fn embed(&self, req: EmbedRequest) -> LlmResult<EmbedResponse> {
123        let cost: f64 = req.input.iter().map(|s| (s.len() / 4) as f64).sum();
124        self.acquire(cost).await;
125        self.inner.embed(req).await
126    }
127
128    fn name(&self) -> &'static str {
129        self.inner.name()
130    }
131
132    fn model(&self) -> &str {
133        self.inner.model()
134    }
135
136    fn estimate_cost(&self, req: &ChatRequest) -> Option<CostEstimate> {
137        self.inner.estimate_cost(req)
138    }
139}