llmkit_tower/
rate_limit.rs1use std::sync::Arc;
4use std::time::{Duration, Instant};
5
6use async_trait::async_trait;
7use llmkit_core::{
8 ChatRequest, ChatResponse, ChatStream, CostEstimate, EmbedRequest, EmbedResponse, LlmProvider,
9 LlmResult,
10};
11use tokio::sync::Mutex;
12
13use crate::layer::LlmLayer;
14
/// Token-bucket parameters used to build a rate-limited provider.
///
/// The layer itself carries no runtime state; it only records how large the
/// bucket is and how quickly it refills.
#[derive(Clone, Copy, Debug)]
pub struct RateLimitLayer {
    /// Maximum number of tokens the bucket can hold.
    capacity: f64,
    /// Number of tokens restored per second.
    refill_per_sec: f64,
}

impl RateLimitLayer {
    /// Creates a layer whose bucket holds `capacity` tokens and regains a
    /// full bucket's worth of tokens over each `window`.
    ///
    /// A zero-length `window` is replaced by the smallest positive normal
    /// `f64` number of seconds, so the rate is never computed with a zero
    /// divisor.
    pub fn token_bucket(capacity: u64, window: Duration) -> Self {
        let capacity = capacity as f64;
        let window_secs = f64::max(window.as_secs_f64(), f64::MIN_POSITIVE);
        let refill_per_sec = capacity / window_secs;
        Self { capacity, refill_per_sec }
    }
}
31
32impl LlmLayer for RateLimitLayer {
33 type Provider = RateLimit;
34 fn layer(self, inner: Arc<dyn LlmProvider>) -> RateLimit {
35 RateLimit {
36 inner,
37 bucket: Arc::new(Mutex::new(Bucket {
38 tokens: self.capacity,
39 capacity: self.capacity,
40 refill_per_sec: self.refill_per_sec,
41 last: Instant::now(),
42 })),
43 }
44 }
45}
46
/// Mutable token-bucket state.
struct Bucket {
    /// Current token balance. It never exceeds `capacity`, and it may be
    /// negative: a negative balance is the amount already promised to callers
    /// that are still waiting for it to be refilled.
    tokens: f64,
    /// Upper bound on `tokens`.
    capacity: f64,
    /// Tokens credited per second of elapsed time.
    refill_per_sec: f64,
    /// Instant at which `tokens` was last brought up to date.
    last: Instant,
}

impl Bucket {
    /// Credits the tokens earned since `last`, capped at `capacity`, and
    /// moves `last` forward to now.
    fn refill(&mut self) {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last).as_secs_f64();
        self.tokens = (self.tokens + elapsed * self.refill_per_sec).min(self.capacity);
        self.last = now;
    }

    /// Reserves `cost` tokens and returns how long the caller must wait
    /// before the reservation is covered by refilled tokens.
    ///
    /// `cost` is clamped to `0.0..=capacity`, so an oversized request waits
    /// for at most one full bucket and a negative cost cannot add tokens.
    ///
    /// Returns `Duration::ZERO` when the balance already covers the cost.
    /// Otherwise the balance is driven negative and the returned wait is the
    /// time needed to refill that debt.
    fn time_until_available(&mut self, cost: f64) -> Duration {
        self.refill();
        let cost = cost.min(self.capacity).max(0.0);

        // Always charge the full cost, even when it overdraws the balance.
        // Resetting the balance to zero instead would let the tokens that
        // refill during this caller's wait be spent a second time by later
        // callers, so concurrent waiters would all get the same short delay.
        self.tokens -= cost;
        if self.tokens >= 0.0 {
            return Duration::ZERO;
        }

        // The debt includes earlier reservations, so this caller queues
        // behind every caller that is still waiting.
        let debt = -self.tokens;
        Duration::from_secs_f64(debt / self.refill_per_sec)
    }
}
78
79pub struct RateLimit {
81 inner: Arc<dyn LlmProvider>,
82 bucket: Arc<Mutex<Bucket>>,
83}
84
85impl RateLimit {
86 fn estimated_cost(req: &ChatRequest) -> f64 {
88 let chars: usize = req
89 .messages
90 .iter()
91 .filter_map(|m| m.content.as_text())
92 .map(|t| t.len())
93 .sum::<usize>()
94 + req.system.as_deref().map(str::len).unwrap_or(0);
95 let prompt = (chars / 4) as f64;
96 prompt + req.max_tokens.unwrap_or(256) as f64
97 }
98
99 async fn acquire(&self, cost: f64) {
100 let wait = {
101 let mut bucket = self.bucket.lock().await;
102 bucket.time_until_available(cost)
103 };
104 if !wait.is_zero() {
105 tokio::time::sleep(wait).await;
106 }
107 }
108}
109
110#[async_trait]
111impl LlmProvider for RateLimit {
112 async fn chat(&self, req: ChatRequest) -> LlmResult<ChatResponse> {
113 self.acquire(Self::estimated_cost(&req)).await;
114 self.inner.chat(req).await
115 }
116
117 async fn chat_stream(&self, req: ChatRequest) -> LlmResult<ChatStream> {
118 self.acquire(Self::estimated_cost(&req)).await;
119 self.inner.chat_stream(req).await
120 }
121
122 async fn embed(&self, req: EmbedRequest) -> LlmResult<EmbedResponse> {
123 let cost: f64 = req.input.iter().map(|s| (s.len() / 4) as f64).sum();
124 self.acquire(cost).await;
125 self.inner.embed(req).await
126 }
127
128 fn name(&self) -> &'static str {
129 self.inner.name()
130 }
131
132 fn model(&self) -> &str {
133 self.inner.model()
134 }
135
136 fn estimate_cost(&self, req: &ChatRequest) -> Option<CostEstimate> {
137 self.inner.estimate_cost(req)
138 }
139}