Skip to main content

lash_core/provider/
rate_limit.rs

1use super::support::*;
2
/// Rate limiter placed in front of calls to a provider.
///
/// Combines an optional concurrency cap with per-window request and token budgets. All mutable
/// bookkeeping lives behind `state`; `clock` supplies the current time and the sleeps used while
/// a caller is throttled.
#[derive(Debug)]
pub struct ProviderRateLimiter {
    /// Active policy, concurrency semaphore and the two window counters.
    state: Mutex<ProviderRateLimiterState>,
    /// Time source used for window accounting and for sleeping while throttled.
    clock: Arc<dyn crate::Clock>,
}
8
/// Mutable bookkeeping of a `ProviderRateLimiter`, guarded by the limiter's mutex.
#[derive(Debug)]
struct ProviderRateLimiterState {
    /// Policy currently in force; replaced wholesale by `configure`.
    policy: ProviderRateLimitPolicy,
    /// Concurrency gate; `None` when the policy sets no positive `max_concurrency`.
    semaphore: Option<Arc<tokio::sync::Semaphore>>,
    /// Window counter charged one unit per admitted request.
    request_bucket: WindowBucket,
    /// Window counter charged with the estimated token cost of each admitted request.
    token_bucket: WindowBucket,
}
16
/// Usage counter for one fixed accounting window.
#[derive(Clone, Debug)]
struct WindowBucket {
    /// Units consumed so far in the current window.
    used: u32,
    /// Instant at which the current window ends and `used` may be cleared.
    reset_at: std::time::Instant,
}

impl WindowBucket {
    /// Builds an empty bucket whose window ends at `reset_at`.
    fn new(reset_at: std::time::Instant) -> Self {
        WindowBucket { reset_at, used: 0 }
    }
}
28
/// Proof of admission returned by `ProviderRateLimiter::admit`.
///
/// Owns the concurrency slot (if any) that was acquired for the request, so the slot stays
/// occupied for as long as this value is alive.
#[derive(Debug)]
pub struct ProviderRateLimitPermit {
    /// Never read; held only for its lifetime. `None` when no concurrency limit was active at
    /// admission time.
    _concurrency: Option<tokio::sync::OwnedSemaphorePermit>,
}
33
34impl ProviderRateLimiter {
35    pub fn new(policy: ProviderRateLimitPolicy) -> Self {
36        Self::with_clock(policy, Arc::new(crate::SystemClock))
37    }
38
39    pub fn with_clock(policy: ProviderRateLimitPolicy, clock: Arc<dyn crate::Clock>) -> Self {
40        let semaphore = policy
41            .max_concurrency
42            .filter(|limit| *limit > 0)
43            .map(tokio::sync::Semaphore::new)
44            .map(Arc::new);
45        let now = clock.now();
46        Self {
47            state: Mutex::new(ProviderRateLimiterState {
48                policy,
49                semaphore,
50                request_bucket: WindowBucket::new(now),
51                token_bucket: WindowBucket::new(now),
52            }),
53            clock,
54        }
55    }
56
57    pub fn configure(&self, policy: ProviderRateLimitPolicy) {
58        let mut state = self.state.lock().expect("provider rate limiter lock");
59        if state.policy.max_concurrency != policy.max_concurrency {
60            state.semaphore = policy
61                .max_concurrency
62                .filter(|limit| *limit > 0)
63                .map(tokio::sync::Semaphore::new)
64                .map(Arc::new);
65        }
66        state.policy = policy;
67    }
68
69    pub fn clock(&self) -> Arc<dyn crate::Clock> {
70        Arc::clone(&self.clock)
71    }
72
73    pub async fn admit(&self, request: &LlmRequest) -> ProviderRateLimitPermit {
74        let semaphore = self
75            .state
76            .lock()
77            .expect("provider rate limiter lock")
78            .semaphore
79            .clone();
80        let concurrency = match semaphore {
81            Some(semaphore) => Some(semaphore.acquire_owned().await.expect("semaphore open")),
82            None => None,
83        };
84        self.wait_for_buckets(1, estimate_request_tokens(request))
85            .await;
86        ProviderRateLimitPermit {
87            _concurrency: concurrency,
88        }
89    }
90
91    async fn wait_for_buckets(&self, requests: u32, tokens: u32) {
92        loop {
93            let wait = {
94                let mut state = self.state.lock().expect("provider rate limiter lock");
95                let now = self.clock.now();
96                let policy = state.policy.clone();
97                let request_wait = bucket_wait(
98                    &mut state.request_bucket,
99                    now,
100                    policy.requests_per_window,
101                    policy.request_window_ms,
102                    requests,
103                );
104                let token_wait = bucket_wait(
105                    &mut state.token_bucket,
106                    now,
107                    policy.tokens_per_window,
108                    policy.token_window_ms,
109                    tokens,
110                );
111                match (request_wait, token_wait) {
112                    (None, None) => return,
113                    (Some(a), Some(b)) => Some(a.max(b)),
114                    (Some(a), None) | (None, Some(a)) => Some(a),
115                }
116            };
117            if let Some(wait) = wait {
118                self.clock.sleep(wait).await;
119            }
120        }
121    }
122}
123
124fn bucket_wait(
125    bucket: &mut WindowBucket,
126    now: std::time::Instant,
127    limit: Option<u32>,
128    window_ms: Option<u64>,
129    cost: u32,
130) -> Option<Duration> {
131    let limit = limit.filter(|limit| *limit > 0)?;
132    let window = Duration::from_millis(window_ms.unwrap_or(60_000).max(1));
133    if now >= bucket.reset_at {
134        bucket.used = 0;
135        bucket.reset_at = now + window;
136    }
137    if bucket.used.saturating_add(cost.min(limit)) <= limit {
138        bucket.used = bucket.used.saturating_add(cost.min(limit));
139        None
140    } else {
141        Some(bucket.reset_at.saturating_duration_since(now))
142    }
143}
144
145fn estimate_request_tokens(request: &LlmRequest) -> u32 {
146    let mut chars = request.model.len();
147    for message in &request.messages {
148        for block in message.blocks.iter() {
149            match block {
150                LlmContentBlock::Text { text, .. } => chars += text.len(),
151                LlmContentBlock::ToolCall { input_json, .. } => chars += input_json.len(),
152                LlmContentBlock::ToolResult { content, .. } => chars += content.len(),
153                LlmContentBlock::Reasoning { text, .. } => chars += text.len(),
154                LlmContentBlock::Image { .. } => chars += 256,
155            }
156        }
157    }
158    chars = chars.saturating_add(request.attachments.iter().map(|a| a.data.len() / 4).sum());
159    ((chars / 4).max(1)).try_into().unwrap_or(u32::MAX)
160}