lash_core/provider/
rate_limit.rs1use super::support::*;
2
3#[derive(Debug)]
4pub struct ProviderRateLimiter {
5 state: Mutex<ProviderRateLimiterState>,
6 clock: Arc<dyn crate::Clock>,
7}
8
9#[derive(Debug)]
10struct ProviderRateLimiterState {
11 policy: ProviderRateLimitPolicy,
12 semaphore: Option<Arc<tokio::sync::Semaphore>>,
13 request_bucket: WindowBucket,
14 token_bucket: WindowBucket,
15}
16
/// Usage counter for one fixed rate-limit window.
#[derive(Clone, Debug)]
struct WindowBucket {
    /// Units consumed so far in the current window.
    used: u32,
    /// Instant at which the current window ends; once reached, `used` starts over from zero.
    reset_at: std::time::Instant,
}

impl WindowBucket {
    /// Builds an empty bucket whose window ends at `reset_at`.
    fn new(reset_at: std::time::Instant) -> Self {
        WindowBucket { used: 0, reset_at }
    }
}
28
29#[derive(Debug)]
30pub struct ProviderRateLimitPermit {
31 _concurrency: Option<tokio::sync::OwnedSemaphorePermit>,
32}
33
34impl ProviderRateLimiter {
35 pub fn new(policy: ProviderRateLimitPolicy) -> Self {
36 Self::with_clock(policy, Arc::new(crate::SystemClock))
37 }
38
39 pub fn with_clock(policy: ProviderRateLimitPolicy, clock: Arc<dyn crate::Clock>) -> Self {
40 let semaphore = policy
41 .max_concurrency
42 .filter(|limit| *limit > 0)
43 .map(tokio::sync::Semaphore::new)
44 .map(Arc::new);
45 let now = clock.now();
46 Self {
47 state: Mutex::new(ProviderRateLimiterState {
48 policy,
49 semaphore,
50 request_bucket: WindowBucket::new(now),
51 token_bucket: WindowBucket::new(now),
52 }),
53 clock,
54 }
55 }
56
57 pub fn configure(&self, policy: ProviderRateLimitPolicy) {
58 let mut state = self.state.lock().expect("provider rate limiter lock");
59 if state.policy.max_concurrency != policy.max_concurrency {
60 state.semaphore = policy
61 .max_concurrency
62 .filter(|limit| *limit > 0)
63 .map(tokio::sync::Semaphore::new)
64 .map(Arc::new);
65 }
66 state.policy = policy;
67 }
68
69 pub fn clock(&self) -> Arc<dyn crate::Clock> {
70 Arc::clone(&self.clock)
71 }
72
73 pub async fn admit(&self, request: &LlmRequest) -> ProviderRateLimitPermit {
74 let semaphore = self
75 .state
76 .lock()
77 .expect("provider rate limiter lock")
78 .semaphore
79 .clone();
80 let concurrency = match semaphore {
81 Some(semaphore) => Some(semaphore.acquire_owned().await.expect("semaphore open")),
82 None => None,
83 };
84 self.wait_for_buckets(1, estimate_request_tokens(request))
85 .await;
86 ProviderRateLimitPermit {
87 _concurrency: concurrency,
88 }
89 }
90
91 async fn wait_for_buckets(&self, requests: u32, tokens: u32) {
92 loop {
93 let wait = {
94 let mut state = self.state.lock().expect("provider rate limiter lock");
95 let now = self.clock.now();
96 let policy = state.policy.clone();
97 let request_wait = bucket_wait(
98 &mut state.request_bucket,
99 now,
100 policy.requests_per_window,
101 policy.request_window_ms,
102 requests,
103 );
104 let token_wait = bucket_wait(
105 &mut state.token_bucket,
106 now,
107 policy.tokens_per_window,
108 policy.token_window_ms,
109 tokens,
110 );
111 match (request_wait, token_wait) {
112 (None, None) => return,
113 (Some(a), Some(b)) => Some(a.max(b)),
114 (Some(a), None) | (None, Some(a)) => Some(a),
115 }
116 };
117 if let Some(wait) = wait {
118 self.clock.sleep(wait).await;
119 }
120 }
121 }
122}
123
124fn bucket_wait(
125 bucket: &mut WindowBucket,
126 now: std::time::Instant,
127 limit: Option<u32>,
128 window_ms: Option<u64>,
129 cost: u32,
130) -> Option<Duration> {
131 let limit = limit.filter(|limit| *limit > 0)?;
132 let window = Duration::from_millis(window_ms.unwrap_or(60_000).max(1));
133 if now >= bucket.reset_at {
134 bucket.used = 0;
135 bucket.reset_at = now + window;
136 }
137 if bucket.used.saturating_add(cost.min(limit)) <= limit {
138 bucket.used = bucket.used.saturating_add(cost.min(limit));
139 None
140 } else {
141 Some(bucket.reset_at.saturating_duration_since(now))
142 }
143}
144
145fn estimate_request_tokens(request: &LlmRequest) -> u32 {
146 let mut chars = request.model.len();
147 for message in &request.messages {
148 for block in message.blocks.iter() {
149 match block {
150 LlmContentBlock::Text { text, .. } => chars += text.len(),
151 LlmContentBlock::ToolCall { input_json, .. } => chars += input_json.len(),
152 LlmContentBlock::ToolResult { content, .. } => chars += content.len(),
153 LlmContentBlock::Reasoning { text, .. } => chars += text.len(),
154 LlmContentBlock::Image { .. } => chars += 256,
155 }
156 }
157 }
158 chars = chars.saturating_add(request.attachments.iter().map(|a| a.data.len() / 4).sum());
159 ((chars / 4).max(1)).try_into().unwrap_or(u32::MAX)
160}