1use std::collections::HashMap;
7use std::net::IpAddr;
8use std::sync::Mutex;
9use std::time::Instant;
10
11use axum::extract::{ConnectInfo, Request};
12use axum::http::StatusCode;
13use axum::middleware::Next;
14use axum::response::{IntoResponse, Response};
15use serde::Serialize;
16
/// Settings for the per-client token bucket.
///
/// Every bucket starts full with `capacity` tokens. Each request takes one
/// token, and tokens flow back in at `refill_per_second`, never exceeding
/// `capacity`.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Maximum tokens a bucket holds, i.e. the largest burst one client can
    /// send before being refused.
    pub capacity: u32,
    /// Tokens restored to a bucket per second of elapsed time (may be
    /// fractional).
    pub refill_per_second: f64,
}

impl Default for RateLimitConfig {
    /// 1000 requests per minute: a 1000-token bucket that regains 1000 tokens
    /// every 60 seconds.
    fn default() -> Self {
        const PER_MINUTE: u32 = 1000;
        RateLimitConfig {
            capacity: PER_MINUTE,
            refill_per_second: f64::from(PER_MINUTE) / 60.0,
        }
    }
}
35
36struct Bucket {
38 tokens: f64,
39 last_refill: Instant,
40}
41
42impl Bucket {
43 fn new(capacity: u32) -> Self {
44 Self {
45 tokens: capacity as f64,
46 last_refill: Instant::now(),
47 }
48 }
49
50 fn try_consume(&mut self, config: &RateLimitConfig) -> Option<u32> {
52 let now = Instant::now();
54 let elapsed = now.duration_since(self.last_refill).as_secs_f64();
55 self.tokens =
56 (self.tokens + elapsed * config.refill_per_second).min(config.capacity as f64);
57 self.last_refill = now;
58
59 if self.tokens >= 1.0 {
60 self.tokens -= 1.0;
61 Some(self.tokens as u32)
62 } else {
63 None
64 }
65 }
66
67 fn retry_after(&self, config: &RateLimitConfig) -> u32 {
69 let needed = 1.0 - self.tokens;
70 if needed <= 0.0 {
71 return 0;
72 }
73 (needed / config.refill_per_second).ceil() as u32
74 }
75}
76
77pub struct RateLimiter {
79 config: RateLimitConfig,
80 buckets: Mutex<HashMap<IpAddr, Bucket>>,
81}
82
83impl RateLimiter {
84 pub fn new(config: RateLimitConfig) -> Self {
85 Self {
86 config,
87 buckets: Mutex::new(HashMap::new()),
88 }
89 }
90
91 pub fn check(&self, ip: IpAddr) -> Result<u32, u32> {
94 let mut buckets = self.buckets.lock().unwrap();
95 let bucket = buckets
96 .entry(ip)
97 .or_insert_with(|| Bucket::new(self.config.capacity));
98
99 match bucket.try_consume(&self.config) {
100 Some(remaining) => Ok(remaining),
101 None => Err(bucket.retry_after(&self.config)),
102 }
103 }
104
105 pub fn capacity(&self) -> u32 {
107 self.config.capacity
108 }
109
110 pub fn cleanup(&self) {
113 let mut buckets = self.buckets.lock().unwrap();
114 let config = &self.config;
115 buckets.retain(|_, bucket| {
116 let elapsed = bucket.last_refill.elapsed().as_secs_f64();
117 let tokens = bucket.tokens + elapsed * config.refill_per_second;
118 tokens < config.capacity as f64
120 });
121 }
122}
123
124#[derive(Serialize)]
126struct RateLimitExceeded {
127 error: String,
128 message: String,
129 retry_after: u32,
130}
131
132fn extract_client_ip(request: &Request) -> IpAddr {
134 if let Some(forwarded) = request.headers().get("x-forwarded-for") {
136 if let Ok(value) = forwarded.to_str() {
137 if let Some(first) = value.split(',').next() {
139 if let Ok(ip) = first.trim().parse::<IpAddr>() {
140 return ip;
141 }
142 }
143 }
144 }
145
146 if let Some(real_ip) = request.headers().get("x-real-ip") {
148 if let Ok(value) = real_ip.to_str() {
149 if let Ok(ip) = value.trim().parse::<IpAddr>() {
150 return ip;
151 }
152 }
153 }
154
155 if let Some(connect_info) = request
157 .extensions()
158 .get::<ConnectInfo<std::net::SocketAddr>>()
159 {
160 return connect_info.0.ip();
161 }
162
163 IpAddr::V4(std::net::Ipv4Addr::LOCALHOST)
165}
166
167pub async fn rate_limit_middleware(
172 axum::extract::State(limiter): axum::extract::State<std::sync::Arc<RateLimiter>>,
173 request: Request,
174 next: Next,
175) -> Response {
176 let client_ip = extract_client_ip(&request);
177
178 match limiter.check(client_ip) {
179 Ok(remaining) => {
180 let mut response = next.run(request).await;
181 let headers = response.headers_mut();
182 headers.insert(
183 "x-ratelimit-limit",
184 limiter.capacity().to_string().parse().unwrap(),
185 );
186 headers.insert(
187 "x-ratelimit-remaining",
188 remaining.to_string().parse().unwrap(),
189 );
190 response
191 }
192 Err(retry_after) => {
193 let body = RateLimitExceeded {
194 error: "rate_limit_exceeded".to_string(),
195 message: format!(
196 "rate limit exceeded: {} requests per minute",
197 limiter.capacity()
198 ),
199 retry_after,
200 };
201 let mut response = (StatusCode::TOO_MANY_REQUESTS, axum::Json(body)).into_response();
202 let headers = response.headers_mut();
203 headers.insert("retry-after", retry_after.to_string().parse().unwrap());
204 headers.insert(
205 "x-ratelimit-limit",
206 limiter.capacity().to_string().parse().unwrap(),
207 );
208 headers.insert("x-ratelimit-remaining", "0".parse().unwrap());
209 response
210 }
211 }
212}
213
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn bucket_allows_up_to_capacity() {
        let config = RateLimitConfig {
            capacity: 3,
            refill_per_second: 1.0,
        };
        let mut bucket = Bucket::new(3);

        // Three tokens are available; the fourth request must be refused.
        assert!(bucket.try_consume(&config).is_some());
        assert!(bucket.try_consume(&config).is_some());
        assert!(bucket.try_consume(&config).is_some());
        assert!(bucket.try_consume(&config).is_none());
    }

    #[test]
    fn check_reports_remaining_tokens() {
        // A zero refill rate makes the counts exact and timing-independent.
        let limiter = RateLimiter::new(RateLimitConfig {
            capacity: 3,
            refill_per_second: 0.0,
        });
        let ip: IpAddr = "3.3.3.3".parse().unwrap();

        assert_eq!(limiter.capacity(), 3);
        assert_eq!(limiter.check(ip), Ok(2));
        assert_eq!(limiter.check(ip), Ok(1));
        assert_eq!(limiter.check(ip), Ok(0));
        assert!(limiter.check(ip).is_err());
    }

    #[test]
    fn rate_limiter_per_ip_isolation() {
        let limiter = RateLimiter::new(RateLimitConfig {
            capacity: 2,
            refill_per_second: 0.0, // no refill, so exhaustion is permanent
        });

        let ip1: IpAddr = "1.1.1.1".parse().unwrap();
        let ip2: IpAddr = "2.2.2.2".parse().unwrap();

        assert!(limiter.check(ip1).is_ok());
        assert!(limiter.check(ip1).is_ok());
        assert!(limiter.check(ip1).is_err()); // ip1 is exhausted...
        assert!(limiter.check(ip2).is_ok()); // ...but ip2 has its own bucket
    }

    #[test]
    fn retry_after_calculated() {
        let config = RateLimitConfig {
            capacity: 1,
            refill_per_second: 1.0,
        };
        let mut bucket = Bucket::new(1);

        assert_eq!(bucket.try_consume(&config), Some(0));
        // Refill is clamped to capacity, so the bucket is exactly empty here:
        // one token at 1 token/s is exactly one second away.
        assert_eq!(bucket.retry_after(&config), 1);
    }

    #[test]
    fn cleanup_evicts_only_full_buckets() {
        let limiter = RateLimiter::new(RateLimitConfig {
            capacity: 2,
            refill_per_second: 0.0,
        });
        let used: IpAddr = "4.4.4.4".parse().unwrap();
        let idle: IpAddr = "5.5.5.5".parse().unwrap();

        assert!(limiter.check(used).is_ok());
        // A full bucket is indistinguishable from a fresh one, so it may go.
        limiter
            .buckets
            .lock()
            .unwrap()
            .insert(idle, Bucket::new(2));

        limiter.cleanup();

        let buckets = limiter.buckets.lock().unwrap();
        assert!(buckets.contains_key(&used));
        assert!(!buckets.contains_key(&idle));
    }
}