crabtalk_proxy/ext/
rate_limit.rs1use crabtalk_core::{
2 BoxFuture, ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse, ExtensionError,
3 Prefix, RequestContext, Storage,
4};
5use std::{
6 sync::Arc,
7 time::{SystemTime, UNIX_EPOCH},
8};
9
10pub struct RateLimit {
11 storage: Arc<dyn Storage>,
12 requests_per_minute: u64,
13 tokens_per_minute: Option<u64>,
14}
15
16impl RateLimit {
17 pub fn new(config: &serde_json::Value, storage: Arc<dyn Storage>) -> Result<Self, String> {
18 let rpm = config
19 .get("requests_per_minute")
20 .and_then(|v| v.as_i64())
21 .ok_or("rate_limit: missing or invalid 'requests_per_minute'")?;
22
23 if rpm <= 0 {
24 return Err("rate_limit: 'requests_per_minute' must be positive".to_string());
25 }
26
27 let tpm = config
28 .get("tokens_per_minute")
29 .and_then(|v| v.as_i64())
30 .map(|v| {
31 if v <= 0 {
32 Err("rate_limit: 'tokens_per_minute' must be positive".to_string())
33 } else {
34 Ok(v as u64)
35 }
36 })
37 .transpose()?;
38
39 Ok(Self {
40 storage,
41 requests_per_minute: rpm as u64,
42 tokens_per_minute: tpm,
43 })
44 }
45}
46
/// Returns the number of whole minutes elapsed since the Unix epoch.
///
/// The value identifies the current one-minute counting window. If the
/// system clock reports a time before the epoch, window `0` is returned
/// instead of failing.
fn current_minute() -> u64 {
    const SECONDS_PER_MINUTE: u64 = 60;

    let epoch_seconds = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_secs());

    epoch_seconds / SECONDS_PER_MINUTE
}
54
55impl crabtalk_core::Extension for RateLimit {
56 fn name(&self) -> &str {
57 "rate_limit"
58 }
59
60 fn prefix(&self) -> Prefix {
61 *b"rlim"
62 }
63
64 fn on_request(&self, ctx: &RequestContext) -> BoxFuture<'_, Result<(), ExtensionError>> {
65 let key_name = ctx.key_name.as_deref().unwrap_or("__global");
66 let minute = current_minute();
67
68 let rpm_suffix = format!("{key_name}:{minute}");
69 let rpm_key = self.storage_key(rpm_suffix.as_bytes());
70 let rpm_limit = self.requests_per_minute;
71
72 let tpm_limit = self.tokens_per_minute;
73 let tpm_key = tpm_limit.map(|_| {
74 let tpm_suffix = format!("{key_name}:tpm:{minute}");
75 self.storage_key(tpm_suffix.as_bytes())
76 });
77
78 Box::pin(async move {
79 let count = self
81 .storage
82 .increment(&rpm_key, 1)
83 .await
84 .map_err(|e| ExtensionError::new(500, e.to_string(), "server_error"))?;
85
86 if count as u64 > rpm_limit {
87 return Err(ExtensionError::new(
88 429,
89 "rate limit exceeded (RPM)",
90 "rate_limit_error",
91 ));
92 }
93
94 if let (Some(limit), Some(key)) = (tpm_limit, &tpm_key) {
96 let tokens = self
97 .storage
98 .increment(key, 0)
99 .await
100 .map_err(|e| ExtensionError::new(500, e.to_string(), "server_error"))?;
101
102 if tokens as u64 > limit {
103 return Err(ExtensionError::new(
104 429,
105 "rate limit exceeded (TPM)",
106 "rate_limit_error",
107 ));
108 }
109 }
110
111 Ok(())
112 })
113 }
114
115 fn on_response(
116 &self,
117 ctx: &RequestContext,
118 _request: &ChatCompletionRequest,
119 response: &ChatCompletionResponse,
120 ) -> BoxFuture<'_, ()> {
121 if self.tokens_per_minute.is_none() {
122 return Box::pin(async {});
123 }
124
125 let total_tokens = response
126 .usage
127 .as_ref()
128 .map(|u| u.total_tokens as i64)
129 .unwrap_or(0);
130
131 if total_tokens == 0 {
132 return Box::pin(async {});
133 }
134
135 let key_name = ctx.key_name.as_deref().unwrap_or("__global");
136 let minute = current_minute();
137 let tpm_suffix = format!("{key_name}:tpm:{minute}");
138 let tpm_key = self.storage_key(tpm_suffix.as_bytes());
139
140 Box::pin(async move {
141 let _ = self.storage.increment(&tpm_key, total_tokens).await;
142 })
143 }
144
145 fn on_chunk(&self, ctx: &RequestContext, chunk: &ChatCompletionChunk) -> BoxFuture<'_, ()> {
146 if self.tokens_per_minute.is_none() {
147 return Box::pin(async {});
148 }
149
150 let total_tokens = chunk
151 .usage
152 .as_ref()
153 .map(|u| u.total_tokens as i64)
154 .unwrap_or(0);
155
156 if total_tokens == 0 {
157 return Box::pin(async {});
158 }
159
160 let key_name = ctx.key_name.as_deref().unwrap_or("__global");
161 let minute = current_minute();
162 let tpm_suffix = format!("{key_name}:tpm:{minute}");
163 let tpm_key = self.storage_key(tpm_suffix.as_bytes());
164
165 Box::pin(async move {
166 let _ = self.storage.increment(&tpm_key, total_tokens).await;
167 })
168 }
169}