Skip to main content

lash_core/provider/
rate_limit.rs

1use super::support::*;
2
3#[derive(Debug)]
4pub struct ProviderRateLimiter {
5    state: Mutex<ProviderRateLimiterState>,
6}
7
8#[derive(Debug)]
9struct ProviderRateLimiterState {
10    policy: ProviderRateLimitPolicy,
11    semaphore: Option<Arc<tokio::sync::Semaphore>>,
12    request_bucket: WindowBucket,
13    token_bucket: WindowBucket,
14}
15
/// Usage counter for one fixed rate-limit window.
#[derive(Debug, Clone)]
struct WindowBucket {
    /// Units (requests or tokens) charged since the window was last reset.
    used: u32,
    /// Instant at or after which the window counts as expired and `used` is reset.
    reset_at: Instant,
}
21
22impl Default for WindowBucket {
23    fn default() -> Self {
24        Self {
25            used: 0,
26            reset_at: Instant::now(),
27        }
28    }
29}
30
31#[derive(Debug)]
32pub struct ProviderRateLimitPermit {
33    _concurrency: Option<tokio::sync::OwnedSemaphorePermit>,
34}
35
36impl ProviderRateLimiter {
37    pub fn new(policy: ProviderRateLimitPolicy) -> Self {
38        let semaphore = policy
39            .max_concurrency
40            .filter(|limit| *limit > 0)
41            .map(tokio::sync::Semaphore::new)
42            .map(Arc::new);
43        Self {
44            state: Mutex::new(ProviderRateLimiterState {
45                policy,
46                semaphore,
47                request_bucket: WindowBucket::default(),
48                token_bucket: WindowBucket::default(),
49            }),
50        }
51    }
52
53    pub fn configure(&self, policy: ProviderRateLimitPolicy) {
54        let mut state = self.state.lock().expect("provider rate limiter lock");
55        if state.policy.max_concurrency != policy.max_concurrency {
56            state.semaphore = policy
57                .max_concurrency
58                .filter(|limit| *limit > 0)
59                .map(tokio::sync::Semaphore::new)
60                .map(Arc::new);
61        }
62        state.policy = policy;
63    }
64
65    pub async fn admit(&self, request: &LlmRequest) -> ProviderRateLimitPermit {
66        let semaphore = self
67            .state
68            .lock()
69            .expect("provider rate limiter lock")
70            .semaphore
71            .clone();
72        let concurrency = match semaphore {
73            Some(semaphore) => Some(semaphore.acquire_owned().await.expect("semaphore open")),
74            None => None,
75        };
76        self.wait_for_buckets(1, estimate_request_tokens(request))
77            .await;
78        ProviderRateLimitPermit {
79            _concurrency: concurrency,
80        }
81    }
82
83    async fn wait_for_buckets(&self, requests: u32, tokens: u32) {
84        loop {
85            let wait = {
86                let mut state = self.state.lock().expect("provider rate limiter lock");
87                let now = Instant::now();
88                let policy = state.policy.clone();
89                let request_wait = bucket_wait(
90                    &mut state.request_bucket,
91                    now,
92                    policy.requests_per_window,
93                    policy.request_window_ms,
94                    requests,
95                );
96                let token_wait = bucket_wait(
97                    &mut state.token_bucket,
98                    now,
99                    policy.tokens_per_window,
100                    policy.token_window_ms,
101                    tokens,
102                );
103                match (request_wait, token_wait) {
104                    (None, None) => return,
105                    (Some(a), Some(b)) => Some(a.max(b)),
106                    (Some(a), None) | (None, Some(a)) => Some(a),
107                }
108            };
109            if let Some(wait) = wait {
110                tokio::time::sleep(wait).await;
111            }
112        }
113    }
114}
115
116fn bucket_wait(
117    bucket: &mut WindowBucket,
118    now: Instant,
119    limit: Option<u32>,
120    window_ms: Option<u64>,
121    cost: u32,
122) -> Option<Duration> {
123    let limit = limit.filter(|limit| *limit > 0)?;
124    let window = Duration::from_millis(window_ms.unwrap_or(60_000).max(1));
125    if now >= bucket.reset_at {
126        bucket.used = 0;
127        bucket.reset_at = now + window;
128    }
129    if bucket.used.saturating_add(cost.min(limit)) <= limit {
130        bucket.used = bucket.used.saturating_add(cost.min(limit));
131        None
132    } else {
133        Some(bucket.reset_at.saturating_duration_since(now))
134    }
135}
136
137fn estimate_request_tokens(request: &LlmRequest) -> u32 {
138    let mut chars = request.model.len();
139    for message in &request.messages {
140        for block in message.blocks.iter() {
141            match block {
142                LlmContentBlock::Text { text, .. } => chars += text.len(),
143                LlmContentBlock::ToolCall { input_json, .. } => chars += input_json.len(),
144                LlmContentBlock::ToolResult { content, .. } => chars += content.len(),
145                LlmContentBlock::Reasoning { text, .. } => chars += text.len(),
146                LlmContentBlock::Image { .. } => chars += 256,
147            }
148        }
149    }
150    chars = chars.saturating_add(request.attachments.iter().map(|a| a.data.len() / 4).sum());
151    ((chars / 4).max(1)).try_into().unwrap_or(u32::MAX)
152}