crabllm_proxy/ext/
rate_limit.rs1use crate::{PREFIX_KEYS, PREFIX_RATE_LIMIT};
2use crabllm_core::{
3 BoxFuture, ExtensionError, KeyConfig, KeyRateLimit, RequestContext, Storage, storage_key,
4};
5use std::{
6 sync::Arc,
7 time::{SystemTime, UNIX_EPOCH},
8};
9
10pub struct RateLimit {
11 storage: Arc<dyn Storage>,
12 requests_per_minute: u64,
13 tokens_per_minute: Option<u64>,
14}
15
16impl RateLimit {
17 pub fn new(config: &serde_json::Value, storage: Arc<dyn Storage>) -> Result<Self, String> {
18 let rpm = config
19 .get("requests_per_minute")
20 .and_then(|v| v.as_i64())
21 .ok_or("rate_limit: missing or invalid 'requests_per_minute'")?;
22
23 if rpm <= 0 {
24 return Err("rate_limit: 'requests_per_minute' must be positive".to_string());
25 }
26
27 let tpm = config
28 .get("tokens_per_minute")
29 .and_then(|v| v.as_i64())
30 .map(|v| {
31 if v <= 0 {
32 Err("rate_limit: 'tokens_per_minute' must be positive".to_string())
33 } else {
34 Ok(v as u64)
35 }
36 })
37 .transpose()?;
38
39 Ok(Self {
40 storage,
41 requests_per_minute: rpm as u64,
42 tokens_per_minute: tpm,
43 })
44 }
45
46 async fn limits_for(&self, principal: &str) -> (u64, Option<u64>) {
50 if principal == "__global" {
51 return (self.requests_per_minute, self.tokens_per_minute);
52 }
53
54 let skey = storage_key(&PREFIX_KEYS, principal.as_bytes());
55 let rl = self
56 .storage
57 .get(&skey)
58 .await
59 .ok()
60 .flatten()
61 .and_then(|bytes| serde_json::from_slice::<KeyConfig>(&bytes).ok())
62 .and_then(|kc| kc.rate_limit);
63
64 match rl {
65 Some(KeyRateLimit {
66 requests_per_minute,
67 tokens_per_minute,
68 }) => (
69 requests_per_minute.unwrap_or(self.requests_per_minute),
70 tokens_per_minute.or(self.tokens_per_minute),
71 ),
72 None => (self.requests_per_minute, self.tokens_per_minute),
73 }
74 }
75}
76
/// Returns the number of whole minutes elapsed since the Unix epoch.
///
/// If the system clock reports a time before the epoch, the result is `0`.
fn current_minute() -> u64 {
    const SECS_PER_MINUTE: u64 = 60;

    let secs_since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_secs());

    secs_since_epoch / SECS_PER_MINUTE
}
84
85impl crabllm_core::Extension for RateLimit {
86 fn name(&self) -> &str {
87 "rate_limit"
88 }
89
90 fn prefix(&self) -> crabllm_core::Prefix {
91 PREFIX_RATE_LIMIT
92 }
93
94 fn on_request(&self, ctx: &RequestContext) -> BoxFuture<'_, Result<(), ExtensionError>> {
95 let principal = ctx.principal.as_deref().unwrap_or("__global").to_string();
96
97 Box::pin(async move {
98 let (rpm_limit, tpm_limit) = self.limits_for(&principal).await;
99 let minute = current_minute();
100
101 let rpm_suffix = format!("{principal}:{minute}");
103 let rpm_key = self.storage_key(rpm_suffix.as_bytes());
104 let count = self
105 .storage
106 .increment(&rpm_key, 1)
107 .await
108 .map_err(|e| ExtensionError::new(500, e.to_string(), "server_error"))?;
109
110 if count as u64 > rpm_limit {
111 return Err(ExtensionError::new(
112 429,
113 "rate limit exceeded (RPM)",
114 "rate_limit_error",
115 ));
116 }
117
118 if let Some(limit) = tpm_limit {
120 let tpm_suffix = format!("{principal}:tpm:{minute}");
121 let tpm_key = self.storage_key(tpm_suffix.as_bytes());
122 let tokens = self
123 .storage
124 .increment(&tpm_key, 0)
125 .await
126 .map_err(|e| ExtensionError::new(500, e.to_string(), "server_error"))?;
127
128 if tokens as u64 > limit {
129 return Err(ExtensionError::new(
130 429,
131 "rate limit exceeded (TPM)",
132 "rate_limit_error",
133 ));
134 }
135 }
136
137 Ok(())
138 })
139 }
140
141 fn on_response(
142 &self,
143 ctx: &RequestContext,
144 _raw_request: &[u8],
145 raw_response: &[u8],
146 ) -> BoxFuture<'_, ()> {
147 let usage = crabllm_core::Usage::from(raw_response);
148 let total_tokens = usage.total_tokens() as i64;
149 if total_tokens == 0 {
150 return Box::pin(async {});
151 }
152
153 let principal = ctx.principal.as_deref().unwrap_or("__global");
154 let minute = current_minute();
155 let tpm_suffix = format!("{principal}:tpm:{minute}");
156 let tpm_key = self.storage_key(tpm_suffix.as_bytes());
157
158 Box::pin(async move {
159 let _ = self.storage.increment(&tpm_key, total_tokens).await;
160 })
161 }
162
163 fn on_chunk(&self, ctx: &RequestContext, raw_chunk: &[u8]) -> BoxFuture<'_, ()> {
164 let usage = crabllm_core::Usage::from(raw_chunk);
165 let total_tokens = usage.total_tokens() as i64;
166 if total_tokens == 0 {
167 return Box::pin(async {});
168 }
169
170 let principal = ctx.principal.as_deref().unwrap_or("__global");
171 let minute = current_minute();
172 let tpm_suffix = format!("{principal}:tpm:{minute}");
173 let tpm_key = self.storage_key(tpm_suffix.as_bytes());
174
175 Box::pin(async move {
176 let _ = self.storage.increment(&tpm_key, total_tokens).await;
177 })
178 }
179}