// lago_api/rate_limit.rs

//! Simple in-memory token bucket rate limiter for HTTP routes.
//!
//! Designed for single-instance deployments (Railway). Uses a per-IP
//! token bucket with configurable capacity and refill rate. State lives in
//! process memory, so limits are not shared between replicas.
use std::collections::HashMap;
use std::net::IpAddr;
use std::sync::{Mutex, MutexGuard, PoisonError};
use std::time::Instant;

use axum::extract::{ConnectInfo, Request};
use axum::http::{HeaderValue, StatusCode};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use serde::Serialize;
/// Tuning knobs for the per-IP token bucket.
///
/// `capacity` bounds how many requests one client may burst back-to-back;
/// `refill_per_second` sets the sustained rate at which spent tokens return.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Bucket size: the most requests a single IP can make without waiting.
    pub capacity: u32,
    /// Tokens returned to the bucket each second (may be fractional).
    pub refill_per_second: f64,
}
26impl Default for RateLimitConfig {
27    fn default() -> Self {
28        Self {
29            // 1000 req/min = ~16.67 req/sec
30            capacity: 1000,
31            refill_per_second: 1000.0 / 60.0,
32        }
33    }
34}
35
/// Token state for a single client IP.
struct Bucket {
    /// Tokens currently available; fractional because refill is continuous.
    tokens: f64,
    /// The moment `tokens` was last brought up to date.
    last_refill: Instant,
}
42impl Bucket {
43    fn new(capacity: u32) -> Self {
44        Self {
45            tokens: capacity as f64,
46            last_refill: Instant::now(),
47        }
48    }
49
50    /// Try to consume one token. Returns the remaining tokens if successful.
51    fn try_consume(&mut self, config: &RateLimitConfig) -> Option<u32> {
52        // Refill tokens based on elapsed time
53        let now = Instant::now();
54        let elapsed = now.duration_since(self.last_refill).as_secs_f64();
55        self.tokens =
56            (self.tokens + elapsed * config.refill_per_second).min(config.capacity as f64);
57        self.last_refill = now;
58
59        if self.tokens >= 1.0 {
60            self.tokens -= 1.0;
61            Some(self.tokens as u32)
62        } else {
63            None
64        }
65    }
66
67    /// Seconds until the next token is available.
68    fn retry_after(&self, config: &RateLimitConfig) -> u32 {
69        let needed = 1.0 - self.tokens;
70        if needed <= 0.0 {
71            return 0;
72        }
73        (needed / config.refill_per_second).ceil() as u32
74    }
75}
76
77/// Thread-safe rate limiter state shared across requests.
78pub struct RateLimiter {
79    config: RateLimitConfig,
80    buckets: Mutex<HashMap<IpAddr, Bucket>>,
81}
82
83impl RateLimiter {
84    pub fn new(config: RateLimitConfig) -> Self {
85        Self {
86            config,
87            buckets: Mutex::new(HashMap::new()),
88        }
89    }
90
91    /// Try to consume a token for the given IP.
92    /// Returns `Ok(remaining)` or `Err(retry_after_secs)`.
93    pub fn check(&self, ip: IpAddr) -> Result<u32, u32> {
94        let mut buckets = self.buckets.lock().unwrap();
95        let bucket = buckets
96            .entry(ip)
97            .or_insert_with(|| Bucket::new(self.config.capacity));
98
99        match bucket.try_consume(&self.config) {
100            Some(remaining) => Ok(remaining),
101            None => Err(bucket.retry_after(&self.config)),
102        }
103    }
104
105    /// Get the configured capacity for rate limit headers.
106    pub fn capacity(&self) -> u32 {
107        self.config.capacity
108    }
109
110    /// Periodically clean up expired buckets (buckets that are full).
111    /// Call this from a background task if needed.
112    pub fn cleanup(&self) {
113        let mut buckets = self.buckets.lock().unwrap();
114        let config = &self.config;
115        buckets.retain(|_, bucket| {
116            let elapsed = bucket.last_refill.elapsed().as_secs_f64();
117            let tokens = bucket.tokens + elapsed * config.refill_per_second;
118            // Remove buckets that have been idle long enough to be full
119            tokens < config.capacity as f64
120        });
121    }
122}
123
124/// JSON body for 429 responses.
125#[derive(Serialize)]
126struct RateLimitExceeded {
127    error: String,
128    message: String,
129    retry_after: u32,
130}
131
132/// Extract client IP from the request, checking common proxy headers.
133fn extract_client_ip(request: &Request) -> IpAddr {
134    // Check X-Forwarded-For header (Railway sets this)
135    if let Some(forwarded) = request.headers().get("x-forwarded-for") {
136        if let Ok(value) = forwarded.to_str() {
137            // Take the first IP (original client)
138            if let Some(first) = value.split(',').next() {
139                if let Ok(ip) = first.trim().parse::<IpAddr>() {
140                    return ip;
141                }
142            }
143        }
144    }
145
146    // Check X-Real-IP
147    if let Some(real_ip) = request.headers().get("x-real-ip") {
148        if let Ok(value) = real_ip.to_str() {
149            if let Ok(ip) = value.trim().parse::<IpAddr>() {
150                return ip;
151            }
152        }
153    }
154
155    // Fall back to connection info
156    if let Some(connect_info) = request
157        .extensions()
158        .get::<ConnectInfo<std::net::SocketAddr>>()
159    {
160        return connect_info.0.ip();
161    }
162
163    // Last resort: localhost
164    IpAddr::V4(std::net::Ipv4Addr::LOCALHOST)
165}
166
167/// Axum middleware that applies rate limiting per client IP.
168///
169/// Uses the `RateLimiter` from the request extensions. Set rate limit
170/// headers on all responses, return 429 when exceeded.
171pub async fn rate_limit_middleware(
172    axum::extract::State(limiter): axum::extract::State<std::sync::Arc<RateLimiter>>,
173    request: Request,
174    next: Next,
175) -> Response {
176    let client_ip = extract_client_ip(&request);
177
178    match limiter.check(client_ip) {
179        Ok(remaining) => {
180            let mut response = next.run(request).await;
181            let headers = response.headers_mut();
182            headers.insert(
183                "x-ratelimit-limit",
184                limiter.capacity().to_string().parse().unwrap(),
185            );
186            headers.insert(
187                "x-ratelimit-remaining",
188                remaining.to_string().parse().unwrap(),
189            );
190            response
191        }
192        Err(retry_after) => {
193            let body = RateLimitExceeded {
194                error: "rate_limit_exceeded".to_string(),
195                message: format!(
196                    "rate limit exceeded: {} requests per minute",
197                    limiter.capacity()
198                ),
199                retry_after,
200            };
201            let mut response = (StatusCode::TOO_MANY_REQUESTS, axum::Json(body)).into_response();
202            let headers = response.headers_mut();
203            headers.insert("retry-after", retry_after.to_string().parse().unwrap());
204            headers.insert(
205                "x-ratelimit-limit",
206                limiter.capacity().to_string().parse().unwrap(),
207            );
208            headers.insert("x-ratelimit-remaining", "0".parse().unwrap());
209            response
210        }
211    }
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::{Duration, Instant};

    /// An `Instant` `secs` seconds in the past, or `None` if the monotonic
    /// clock cannot represent it (e.g. on a host that booted moments ago).
    /// Lets tests simulate elapsed time without sleeping.
    fn seconds_ago(secs: u64) -> Option<Instant> {
        Instant::now().checked_sub(Duration::from_secs(secs))
    }

    #[test]
    fn bucket_allows_up_to_capacity() {
        let config = RateLimitConfig {
            capacity: 3,
            refill_per_second: 1.0,
        };
        let mut bucket = Bucket::new(3);

        assert!(bucket.try_consume(&config).is_some()); // 2 left
        assert!(bucket.try_consume(&config).is_some()); // 1 left
        assert!(bucket.try_consume(&config).is_some()); // 0 left
        assert!(bucket.try_consume(&config).is_none()); // empty
    }

    #[test]
    fn bucket_refills_over_time_up_to_capacity() {
        let config = RateLimitConfig {
            capacity: 2,
            refill_per_second: 1.0,
        };
        let mut bucket = Bucket::new(2);
        assert!(bucket.try_consume(&config).is_some());
        assert!(bucket.try_consume(&config).is_some());
        assert!(bucket.try_consume(&config).is_none());

        // Backdate the last refill instead of sleeping. Five seconds at 1/s
        // would add five tokens, but the bucket is capped at its capacity of 2.
        let five_secs_ago = match seconds_ago(5) {
            Some(t) => t,
            None => return,
        };
        bucket.last_refill = five_secs_ago;
        assert_eq!(bucket.try_consume(&config), Some(1));
    }

    #[test]
    fn rate_limiter_per_ip_isolation() {
        let limiter = RateLimiter::new(RateLimitConfig {
            capacity: 2,
            refill_per_second: 0.0, // No refill for testing
        });

        let ip1: IpAddr = "1.1.1.1".parse().unwrap();
        let ip2: IpAddr = "2.2.2.2".parse().unwrap();

        assert!(limiter.check(ip1).is_ok());
        assert!(limiter.check(ip1).is_ok());
        assert!(limiter.check(ip1).is_err()); // ip1 exhausted

        assert!(limiter.check(ip2).is_ok()); // ip2 still has tokens
    }

    #[test]
    fn retry_after_calculated() {
        let config = RateLimitConfig {
            capacity: 1,
            refill_per_second: 1.0,
        };
        let mut bucket = Bucket::new(1);

        // Refill is capped at capacity, so spending the only token leaves exactly 0.
        assert_eq!(bucket.try_consume(&config), Some(0));
        let retry = bucket.retry_after(&config);
        assert_eq!(retry, 1, "one token at 1 token/s needs exactly one second");
    }

    #[test]
    fn cleanup_drops_only_refilled_buckets() {
        let limiter = RateLimiter::new(RateLimitConfig {
            capacity: 2,
            refill_per_second: 1.0,
        });
        let idle: IpAddr = "1.1.1.1".parse().unwrap();
        let busy: IpAddr = "2.2.2.2".parse().unwrap();

        assert!(limiter.check(idle).is_ok());
        assert!(limiter.check(busy).is_ok());
        assert!(limiter.check(busy).is_ok());

        // Pretend `idle` has been quiet long enough to refill completely.
        let long_ago = match seconds_ago(10) {
            Some(t) => t,
            None => return,
        };
        limiter
            .buckets
            .lock()
            .unwrap()
            .get_mut(&idle)
            .unwrap()
            .last_refill = long_ago;

        limiter.cleanup();

        let buckets = limiter.buckets.lock().unwrap();
        assert!(!buckets.contains_key(&idle), "refilled bucket should be dropped");
        assert!(buckets.contains_key(&busy), "drained bucket must be kept");
    }
}