Skip to main content

crabllm_proxy/ext/
rate_limit.rs

1use crate::{PREFIX_KEYS, PREFIX_RATE_LIMIT};
2use crabllm_core::{
3    BoxFuture, ExtensionError, KeyConfig, KeyRateLimit, RequestContext, Storage, storage_key,
4};
5use std::{
6    sync::Arc,
7    time::{SystemTime, UNIX_EPOCH},
8};
9
10pub struct RateLimit {
11    storage: Arc<dyn Storage>,
12    requests_per_minute: u64,
13    tokens_per_minute: Option<u64>,
14}
15
16impl RateLimit {
17    pub fn new(config: &serde_json::Value, storage: Arc<dyn Storage>) -> Result<Self, String> {
18        let rpm = config
19            .get("requests_per_minute")
20            .and_then(|v| v.as_i64())
21            .ok_or("rate_limit: missing or invalid 'requests_per_minute'")?;
22
23        if rpm <= 0 {
24            return Err("rate_limit: 'requests_per_minute' must be positive".to_string());
25        }
26
27        let tpm = config
28            .get("tokens_per_minute")
29            .and_then(|v| v.as_i64())
30            .map(|v| {
31                if v <= 0 {
32                    Err("rate_limit: 'tokens_per_minute' must be positive".to_string())
33                } else {
34                    Ok(v as u64)
35                }
36            })
37            .transpose()?;
38
39        Ok(Self {
40            storage,
41            requests_per_minute: rpm as u64,
42            tokens_per_minute: tpm,
43        })
44    }
45
46    /// Look up per-key rate limit from storage. Returns the per-key
47    /// override merged with global defaults, or the global defaults
48    /// if the key has no override.
49    async fn limits_for(&self, principal: &str) -> (u64, Option<u64>) {
50        if principal == "__global" {
51            return (self.requests_per_minute, self.tokens_per_minute);
52        }
53
54        let skey = storage_key(&PREFIX_KEYS, principal.as_bytes());
55        let rl = self
56            .storage
57            .get(&skey)
58            .await
59            .ok()
60            .flatten()
61            .and_then(|bytes| serde_json::from_slice::<KeyConfig>(&bytes).ok())
62            .and_then(|kc| kc.rate_limit);
63
64        match rl {
65            Some(KeyRateLimit {
66                requests_per_minute,
67                tokens_per_minute,
68            }) => (
69                requests_per_minute.unwrap_or(self.requests_per_minute),
70                tokens_per_minute.or(self.tokens_per_minute),
71            ),
72            None => (self.requests_per_minute, self.tokens_per_minute),
73        }
74    }
75}
76
/// Returns the number of whole minutes elapsed since the Unix epoch.
///
/// This value serves as the bucket id for the per-minute counters. If the
/// system clock reads earlier than the epoch, the elapsed time is treated as
/// zero, so the function returns `0` instead of panicking.
fn current_minute() -> u64 {
    const SECS_PER_MINUTE: u64 = 60;

    let elapsed_secs = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|elapsed| elapsed.as_secs())
        .unwrap_or(0);

    elapsed_secs / SECS_PER_MINUTE
}
84
85impl crabllm_core::Extension for RateLimit {
86    fn name(&self) -> &str {
87        "rate_limit"
88    }
89
90    fn prefix(&self) -> crabllm_core::Prefix {
91        PREFIX_RATE_LIMIT
92    }
93
94    fn on_request(&self, ctx: &RequestContext) -> BoxFuture<'_, Result<(), ExtensionError>> {
95        let principal = ctx.principal.as_deref().unwrap_or("__global").to_string();
96
97        Box::pin(async move {
98            let (rpm_limit, tpm_limit) = self.limits_for(&principal).await;
99            let minute = current_minute();
100
101            // Check RPM.
102            let rpm_suffix = format!("{principal}:{minute}");
103            let rpm_key = self.storage_key(rpm_suffix.as_bytes());
104            let count = self
105                .storage
106                .increment(&rpm_key, 1)
107                .await
108                .map_err(|e| ExtensionError::new(500, e.to_string(), "server_error"))?;
109
110            if count as u64 > rpm_limit {
111                return Err(ExtensionError::new(
112                    429,
113                    "rate limit exceeded (RPM)",
114                    "rate_limit_error",
115                ));
116            }
117
118            // Check TPM.
119            if let Some(limit) = tpm_limit {
120                let tpm_suffix = format!("{principal}:tpm:{minute}");
121                let tpm_key = self.storage_key(tpm_suffix.as_bytes());
122                let tokens = self
123                    .storage
124                    .increment(&tpm_key, 0)
125                    .await
126                    .map_err(|e| ExtensionError::new(500, e.to_string(), "server_error"))?;
127
128                if tokens as u64 > limit {
129                    return Err(ExtensionError::new(
130                        429,
131                        "rate limit exceeded (TPM)",
132                        "rate_limit_error",
133                    ));
134                }
135            }
136
137            Ok(())
138        })
139    }
140
141    fn on_response(
142        &self,
143        ctx: &RequestContext,
144        _raw_request: &[u8],
145        raw_response: &[u8],
146    ) -> BoxFuture<'_, ()> {
147        let usage = crabllm_core::Usage::from(raw_response);
148        let total_tokens = usage.total_tokens() as i64;
149        if total_tokens == 0 {
150            return Box::pin(async {});
151        }
152
153        let principal = ctx.principal.as_deref().unwrap_or("__global");
154        let minute = current_minute();
155        let tpm_suffix = format!("{principal}:tpm:{minute}");
156        let tpm_key = self.storage_key(tpm_suffix.as_bytes());
157
158        Box::pin(async move {
159            let _ = self.storage.increment(&tpm_key, total_tokens).await;
160        })
161    }
162
163    fn on_chunk(&self, ctx: &RequestContext, raw_chunk: &[u8]) -> BoxFuture<'_, ()> {
164        let usage = crabllm_core::Usage::from(raw_chunk);
165        let total_tokens = usage.total_tokens() as i64;
166        if total_tokens == 0 {
167            return Box::pin(async {});
168        }
169
170        let principal = ctx.principal.as_deref().unwrap_or("__global");
171        let minute = current_minute();
172        let tpm_suffix = format!("{principal}:tpm:{minute}");
173        let tpm_key = self.storage_key(tpm_suffix.as_bytes());
174
175        Box::pin(async move {
176            let _ = self.storage.increment(&tpm_key, total_tokens).await;
177        })
178    }
179}