lash_core/provider/
rate_limit.rs1use super::support::*;
2
3#[derive(Debug)]
4pub struct ProviderRateLimiter {
5 state: Mutex<ProviderRateLimiterState>,
6}
7
8#[derive(Debug)]
9struct ProviderRateLimiterState {
10 policy: ProviderRateLimitPolicy,
11 semaphore: Option<Arc<tokio::sync::Semaphore>>,
12 request_bucket: WindowBucket,
13 token_bucket: WindowBucket,
14}
15
/// Usage counter for one fixed accounting window.
#[derive(Clone, Debug)]
struct WindowBucket {
    /// Units charged so far in the current window.
    used: u32,
    /// Instant at which the current window ends and `used` starts over.
    reset_at: Instant,
}
21
22impl Default for WindowBucket {
23 fn default() -> Self {
24 Self {
25 used: 0,
26 reset_at: Instant::now(),
27 }
28 }
29}
30
31#[derive(Debug)]
32pub struct ProviderRateLimitPermit {
33 _concurrency: Option<tokio::sync::OwnedSemaphorePermit>,
34}
35
36impl ProviderRateLimiter {
37 pub fn new(policy: ProviderRateLimitPolicy) -> Self {
38 let semaphore = policy
39 .max_concurrency
40 .filter(|limit| *limit > 0)
41 .map(tokio::sync::Semaphore::new)
42 .map(Arc::new);
43 Self {
44 state: Mutex::new(ProviderRateLimiterState {
45 policy,
46 semaphore,
47 request_bucket: WindowBucket::default(),
48 token_bucket: WindowBucket::default(),
49 }),
50 }
51 }
52
53 pub fn configure(&self, policy: ProviderRateLimitPolicy) {
54 let mut state = self.state.lock().expect("provider rate limiter lock");
55 if state.policy.max_concurrency != policy.max_concurrency {
56 state.semaphore = policy
57 .max_concurrency
58 .filter(|limit| *limit > 0)
59 .map(tokio::sync::Semaphore::new)
60 .map(Arc::new);
61 }
62 state.policy = policy;
63 }
64
65 pub async fn admit(&self, request: &LlmRequest) -> ProviderRateLimitPermit {
66 let semaphore = self
67 .state
68 .lock()
69 .expect("provider rate limiter lock")
70 .semaphore
71 .clone();
72 let concurrency = match semaphore {
73 Some(semaphore) => Some(semaphore.acquire_owned().await.expect("semaphore open")),
74 None => None,
75 };
76 self.wait_for_buckets(1, estimate_request_tokens(request))
77 .await;
78 ProviderRateLimitPermit {
79 _concurrency: concurrency,
80 }
81 }
82
83 async fn wait_for_buckets(&self, requests: u32, tokens: u32) {
84 loop {
85 let wait = {
86 let mut state = self.state.lock().expect("provider rate limiter lock");
87 let now = Instant::now();
88 let policy = state.policy.clone();
89 let request_wait = bucket_wait(
90 &mut state.request_bucket,
91 now,
92 policy.requests_per_window,
93 policy.request_window_ms,
94 requests,
95 );
96 let token_wait = bucket_wait(
97 &mut state.token_bucket,
98 now,
99 policy.tokens_per_window,
100 policy.token_window_ms,
101 tokens,
102 );
103 match (request_wait, token_wait) {
104 (None, None) => return,
105 (Some(a), Some(b)) => Some(a.max(b)),
106 (Some(a), None) | (None, Some(a)) => Some(a),
107 }
108 };
109 if let Some(wait) = wait {
110 tokio::time::sleep(wait).await;
111 }
112 }
113 }
114}
115
116fn bucket_wait(
117 bucket: &mut WindowBucket,
118 now: Instant,
119 limit: Option<u32>,
120 window_ms: Option<u64>,
121 cost: u32,
122) -> Option<Duration> {
123 let limit = limit.filter(|limit| *limit > 0)?;
124 let window = Duration::from_millis(window_ms.unwrap_or(60_000).max(1));
125 if now >= bucket.reset_at {
126 bucket.used = 0;
127 bucket.reset_at = now + window;
128 }
129 if bucket.used.saturating_add(cost.min(limit)) <= limit {
130 bucket.used = bucket.used.saturating_add(cost.min(limit));
131 None
132 } else {
133 Some(bucket.reset_at.saturating_duration_since(now))
134 }
135}
136
137fn estimate_request_tokens(request: &LlmRequest) -> u32 {
138 let mut chars = request.model.len();
139 for message in &request.messages {
140 for block in message.blocks.iter() {
141 match block {
142 LlmContentBlock::Text { text, .. } => chars += text.len(),
143 LlmContentBlock::ToolCall { input_json, .. } => chars += input_json.len(),
144 LlmContentBlock::ToolResult { content, .. } => chars += content.len(),
145 LlmContentBlock::Reasoning { text, .. } => chars += text.len(),
146 LlmContentBlock::Image { .. } => chars += 256,
147 }
148 }
149 }
150 chars = chars.saturating_add(request.attachments.iter().map(|a| a.data.len() / 4).sum());
151 ((chars / 4).max(1)).try_into().unwrap_or(u32::MAX)
152}