llm_registry_api/
rate_limit.rs

1//! Rate limiting middleware
2//!
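//! This module provides rate limiting functionality using the token bucket algorithm.
//! Buckets are currently held in process memory, so each service instance enforces its own
//! limits; a shared store such as Redis would be needed to enforce them across instances.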
5
6use axum::{
7    body::Body,
8    extract::{ConnectInfo, Request, State},
9    http::{HeaderValue, StatusCode},
10    middleware::Next,
11    response::{IntoResponse, Response},
12};
13use serde::{Deserialize, Serialize};
14use std::net::SocketAddr;
15use std::sync::Arc;
16use std::time::{SystemTime, UNIX_EPOCH};
17use tracing::{debug, warn};
18
19use crate::error::ErrorResponse;
20
21/// Rate limit configuration
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct RateLimitConfig {
24    /// Maximum number of requests allowed
25    pub max_requests: u32,
26
27    /// Time window in seconds
28    pub window_secs: u64,
29
30    /// Whether rate limiting is enabled
31    pub enabled: bool,
32
33    /// Rate limit by IP address
34    pub by_ip: bool,
35
36    /// Rate limit by user ID (from JWT)
37    pub by_user: bool,
38
39    /// Custom identifier header (e.g., API key)
40    #[serde(skip_serializing_if = "Option::is_none")]
41    pub identifier_header: Option<String>,
42}
43
44impl Default for RateLimitConfig {
45    fn default() -> Self {
46        Self {
47            max_requests: 100,
48            window_secs: 60,
49            enabled: true,
50            by_ip: true,
51            by_user: true,
52            identifier_header: None,
53        }
54    }
55}
56
57impl RateLimitConfig {
58    /// Create a new rate limit configuration
59    pub fn new(max_requests: u32, window_secs: u64) -> Self {
60        Self {
61            max_requests,
62            window_secs,
63            ..Default::default()
64        }
65    }
66
67    /// Disable rate limiting
68    pub fn disabled() -> Self {
69        Self {
70            enabled: false,
71            ..Default::default()
72        }
73    }
74
75    /// Set max requests
76    pub fn with_max_requests(mut self, max_requests: u32) -> Self {
77        self.max_requests = max_requests;
78        self
79    }
80
81    /// Set window in seconds
82    pub fn with_window_secs(mut self, window_secs: u64) -> Self {
83        self.window_secs = window_secs;
84        self
85    }
86
87    /// Enable/disable rate limiting by IP
88    pub fn with_by_ip(mut self, by_ip: bool) -> Self {
89        self.by_ip = by_ip;
90        self
91    }
92
93    /// Enable/disable rate limiting by user
94    pub fn with_by_user(mut self, by_user: bool) -> Self {
95        self.by_user = by_user;
96        self
97    }
98
99    /// Set custom identifier header
100    pub fn with_identifier_header(mut self, header: impl Into<String>) -> Self {
101        self.identifier_header = Some(header.into());
102        self
103    }
104}
105
106/// Rate limiter state
107#[derive(Clone)]
108pub struct RateLimiterState {
109    config: Arc<RateLimitConfig>,
110    // In-memory storage for rate limiting (in production, use Redis)
111    storage: Arc<tokio::sync::RwLock<std::collections::HashMap<String, TokenBucket>>>,
112}
113
114impl RateLimiterState {
115    /// Create a new rate limiter state
116    pub fn new(config: RateLimitConfig) -> Self {
117        Self {
118            config: Arc::new(config),
119            storage: Arc::new(tokio::sync::RwLock::new(std::collections::HashMap::new())),
120        }
121    }
122
123    /// Get configuration
124    pub fn config(&self) -> &RateLimitConfig {
125        &self.config
126    }
127}
128
/// Token bucket for rate limiting a single identifier.
///
/// The bucket starts full and refills at `refill_rate` tokens per second, measured against
/// the wall clock with whole-second resolution.
#[derive(Debug, Clone)]
struct TokenBucket {
    /// Number of tokens currently available
    tokens: f64,

    /// Wall-clock time of the last refill, in whole seconds since the Unix epoch
    last_refill: u64,

    /// Maximum tokens (capacity)
    capacity: f64,

    /// Refill rate (tokens per second)
    refill_rate: f64,
}

impl TokenBucket {
    /// Create a full bucket of `capacity` tokens that refills completely over `window_secs`.
    ///
    /// NOTE(review): `window_secs == 0` makes the refill rate infinite (any elapsed second
    /// refills the bucket to capacity) — confirm that is the intended meaning.
    fn new(capacity: u32, window_secs: u64) -> Self {
        let capacity = f64::from(capacity);
        Self {
            tokens: capacity,
            last_refill: Self::current_time_secs(),
            capacity,
            refill_rate: capacity / window_secs as f64,
        }
    }

    /// Current wall-clock time in whole seconds since the Unix epoch.
    ///
    /// Returns 0 instead of panicking if the system clock reads earlier than the epoch.
    fn current_time_secs() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map_or(0, |since_epoch| since_epoch.as_secs())
    }

    /// Add the tokens earned since `last_refill`, capped at `capacity`.
    ///
    /// Tokens are credited per whole elapsed second. `SystemTime` is not monotonic. If the
    /// clock has stepped backwards, the bucket is re-anchored to the new time without
    /// crediting anything. Plain subtraction here would underflow: a panic in debug builds,
    /// an instant full refill in release builds.
    fn refill(&mut self) {
        let now = Self::current_time_secs();

        match now.checked_sub(self.last_refill) {
            // Still within the same second as the last refill: nothing to add yet.
            Some(0) => {}
            Some(elapsed) => {
                let new_tokens = elapsed as f64 * self.refill_rate;
                self.tokens = (self.tokens + new_tokens).min(self.capacity);
                self.last_refill = now;
            }
            // Clock went backwards: restart the accounting from `now`.
            None => self.last_refill = now,
        }
    }

    /// Refill, then take `count` tokens if that many are available.
    ///
    /// Returns `true` when the tokens were taken. Returns `false` when too few are
    /// available; the refilled count is then left as is.
    fn try_consume(&mut self, count: f64) -> bool {
        self.refill();

        if self.tokens >= count {
            self.tokens -= count;
            true
        } else {
            false
        }
    }

    /// Seconds until at least one token is available, based on the token count as of the
    /// last refill (0 if a token is available now).
    fn time_until_available(&self) -> u64 {
        if self.tokens >= 1.0 {
            return 0;
        }

        let tokens_needed = 1.0 - self.tokens;
        (tokens_needed / self.refill_rate).ceil() as u64
    }
}
199
200/// Rate limiting middleware
201///
202/// This middleware implements rate limiting using the token bucket algorithm.
203/// It can rate limit by IP address, user ID, or custom identifier.
204///
205/// # Example
206///
207/// ```rust,no_run
208/// use axum::{Router, routing::get, middleware};
209/// use llm_registry_api::rate_limit::{rate_limit, RateLimiterState, RateLimitConfig};
210///
211/// # async fn example() {
212/// let config = RateLimitConfig::new(100, 60); // 100 requests per minute
213/// let rate_limiter = RateLimiterState::new(config);
214///
215/// let app = Router::new()
216///     .route("/api/assets", get(|| async { "OK" }))
217///     .layer(middleware::from_fn_with_state(rate_limiter, rate_limit));
218/// # }
219/// ```
220pub async fn rate_limit(
221    State(limiter): State<RateLimiterState>,
222    request: Request,
223    next: Next,
224) -> Result<Response, RateLimitError> {
225    // Skip if rate limiting is disabled
226    if !limiter.config.enabled {
227        return Ok(next.run(request).await);
228    }
229
230    // Extract identifier for rate limiting
231    let identifier = extract_identifier(&request, &limiter.config);
232
233    debug!("Rate limiting for identifier: {}", identifier);
234
235    // Check rate limit
236    let allowed = check_rate_limit(&limiter, &identifier).await;
237
238    if !allowed {
239        warn!("Rate limit exceeded for identifier: {}", identifier);
240        return Err(RateLimitError::LimitExceeded {
241            retry_after: limiter.config.window_secs,
242        });
243    }
244
245    // Continue processing request
246    let mut response = next.run(request).await;
247
248    // Add rate limit headers
249    add_rate_limit_headers(&mut response, &limiter.config);
250
251    Ok(response)
252}
253
254/// Extract identifier for rate limiting
255fn extract_identifier(request: &Request<Body>, config: &RateLimitConfig) -> String {
256    let mut parts = Vec::new();
257
258    // Extract IP address
259    if config.by_ip {
260        if let Some(ConnectInfo(addr)) = request.extensions().get::<ConnectInfo<SocketAddr>>() {
261            parts.push(format!("ip:{}", addr.ip()));
262        }
263    }
264
265    // Extract user ID from auth extension
266    if config.by_user {
267        if let Some(user) = request.extensions().get::<crate::auth::AuthUser>() {
268            parts.push(format!("user:{}", user.user_id()));
269        }
270    }
271
272    // Extract custom identifier from header
273    if let Some(header_name) = &config.identifier_header {
274        if let Some(value) = request.headers().get(header_name) {
275            if let Ok(value_str) = value.to_str() {
276                parts.push(format!("custom:{}", value_str));
277            }
278        }
279    }
280
281    // If no identifier could be extracted, use a default
282    if parts.is_empty() {
283        parts.push("anonymous".to_string());
284    }
285
286    parts.join("|")
287}
288
289/// Check rate limit for an identifier
290async fn check_rate_limit(limiter: &RateLimiterState, identifier: &str) -> bool {
291    let mut storage = limiter.storage.write().await;
292
293    let bucket = storage
294        .entry(identifier.to_string())
295        .or_insert_with(|| {
296            TokenBucket::new(limiter.config.max_requests, limiter.config.window_secs)
297        });
298
299    bucket.try_consume(1.0)
300}
301
302/// Add rate limit headers to response
303fn add_rate_limit_headers(response: &mut Response, config: &RateLimitConfig) {
304    // Add standard rate limit headers
305    response.headers_mut().insert(
306        "X-RateLimit-Limit",
307        HeaderValue::from_str(&config.max_requests.to_string()).unwrap(),
308    );
309
310    response.headers_mut().insert(
311        "X-RateLimit-Window",
312        HeaderValue::from_str(&config.window_secs.to_string()).unwrap(),
313    );
314}
315
/// Errors produced by the [`rate_limit`] middleware.
#[derive(Debug)]
pub enum RateLimitError {
    /// The caller has exhausted its request allowance.
    LimitExceeded {
        /// Suggested number of seconds to wait before retrying (sent as `Retry-After`).
        retry_after: u64,
    },
}
325
326impl IntoResponse for RateLimitError {
327    fn into_response(self) -> Response {
328        match self {
329            RateLimitError::LimitExceeded { retry_after } => {
330                let error_response = ErrorResponse {
331                    status: 429,
332                    error: "Rate limit exceeded".to_string(),
333                    code: Some("RATE_LIMIT_EXCEEDED".to_string()),
334                    timestamp: chrono::Utc::now(),
335                };
336
337                let mut response = (StatusCode::TOO_MANY_REQUESTS, axum::Json(error_response))
338                    .into_response();
339
340                // Add Retry-After header
341                response.headers_mut().insert(
342                    "Retry-After",
343                    HeaderValue::from_str(&retry_after.to_string()).unwrap(),
344                );
345
346                response
347            }
348        }
349    }
350}
351
352impl std::fmt::Display for RateLimitError {
353    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
354        match self {
355            RateLimitError::LimitExceeded { retry_after } => {
356                write!(f, "Rate limit exceeded. Retry after {} seconds", retry_after)
357            }
358        }
359    }
360}
361
362impl std::error::Error for RateLimitError {}
363
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rate_limit_config() {
        let config = RateLimitConfig::new(100, 60);
        assert_eq!(config.max_requests, 100);
        assert_eq!(config.window_secs, 60);
        assert!(config.enabled);
    }

    #[test]
    fn test_rate_limit_config_builder() {
        let config = RateLimitConfig::default()
            .with_max_requests(200)
            .with_window_secs(120)
            .with_by_ip(false)
            .with_identifier_header("X-API-Key");

        assert_eq!(config.max_requests, 200);
        assert_eq!(config.window_secs, 120);
        assert!(!config.by_ip);
        assert_eq!(
            config.identifier_header,
            Some("X-API-Key".to_string())
        );
    }

    #[test]
    fn test_token_bucket_creation() {
        let bucket = TokenBucket::new(100, 60);
        assert_eq!(bucket.capacity, 100.0);
        assert_eq!(bucket.tokens, 100.0);
    }

    #[test]
    fn test_token_bucket_consume() {
        // Slow refill (1/6 token per second): a clock tick mid-test cannot add a whole token.
        let mut bucket = TokenBucket::new(10, 60);

        // Should be able to consume up to capacity
        for _ in 0..10 {
            assert!(bucket.try_consume(1.0));
        }

        // Should fail after exhausting tokens
        assert!(!bucket.try_consume(1.0));
    }

    #[test]
    fn test_token_bucket_refill() {
        let mut bucket = TokenBucket::new(10, 10); // 1 token per second

        // Start from an empty bucket by setting the fields directly rather than draining it
        // with `try_consume`. At 1 token per second, a wall-clock second ticking over during
        // the test would otherwise credit an extra token and break exact assertions.
        bucket.tokens = 0.0;
        let anchor = bucket.last_refill - 5; // pretend the last refill was 5 seconds ago
        bucket.last_refill = anchor;

        bucket.refill();

        // One token per whole second since `anchor`: normally exactly 5, or 6 if a second
        // boundary was crossed after the bucket was created.
        let elapsed = bucket.last_refill - anchor;
        assert!(elapsed >= 5);
        assert_eq!(bucket.tokens, elapsed as f64);
    }

    #[tokio::test]
    async fn test_rate_limiter_state() {
        let config = RateLimitConfig::new(5, 60);
        let limiter = RateLimiterState::new(config);

        // Should allow requests up to the limit
        for _ in 0..5 {
            assert!(check_rate_limit(&limiter, "test-user").await);
        }

        // Should deny additional requests
        assert!(!check_rate_limit(&limiter, "test-user").await);

        // Different identifier should have its own limit
        assert!(check_rate_limit(&limiter, "other-user").await);
    }

    #[test]
    fn test_disabled_rate_limit() {
        let config = RateLimitConfig::disabled();
        assert!(!config.enabled);
    }
}