// trojan_server/rate_limit.rs

//! Per-IP rate limiting for connection throttling.

use std::{
    collections::HashMap,
    net::IpAddr,
    sync::{
        atomic::{AtomicBool, Ordering},
        Arc,
    },
    time::{Duration, Instant},
};

use parking_lot::RwLock;
use tokio::sync::Notify;
use tracing::debug;

14/// Rate limiter that tracks connections per IP address.
15#[derive(Debug)]
16pub struct RateLimiter {
17    /// Map of IP -> (connection count, window start time)
18    entries: Arc<RwLock<HashMap<IpAddr, RateLimitEntry>>>,
19    /// Maximum connections allowed per IP in the time window
20    max_connections: u32,
21    /// Time window for rate limiting
22    window: Duration,
23    /// Notify for shutdown
24    shutdown: Arc<Notify>,
25}
26
27#[derive(Clone, Debug)]
28struct RateLimitEntry {
29    count: u32,
30    window_start: Instant,
31}
32
33impl RateLimiter {
34    /// Create a new rate limiter.
35    pub fn new(max_connections_per_ip: u32, window_secs: u64) -> Self {
36        Self {
37            entries: Arc::new(RwLock::new(HashMap::new())),
38            max_connections: max_connections_per_ip,
39            window: Duration::from_secs(window_secs),
40            shutdown: Arc::new(Notify::new()),
41        }
42    }
43
44    /// Start the background cleanup task.
45    pub fn start_cleanup_task(&self, cleanup_interval: Duration) {
46        let entries = self.entries.clone();
47        let window = self.window;
48        let shutdown = self.shutdown.clone();
49
50        tokio::spawn(async move {
51            loop {
52                tokio::select! {
53                    _ = shutdown.notified() => {
54                        debug!("rate limiter cleanup task shutting down");
55                        break;
56                    }
57                    _ = tokio::time::sleep(cleanup_interval) => {
58                        let now = Instant::now();
59                        let mut map = entries.write();
60                        let before = map.len();
61                        map.retain(|_, entry| {
62                            now.duration_since(entry.window_start) < window
63                        });
64                        let removed = before - map.len();
65                        if removed > 0 {
66                            debug!(removed, remaining = map.len(), "rate limit entries cleaned up");
67                        }
68                    }
69                }
70            }
71        });
72    }
73
74    /// Check if a connection from the given IP is allowed.
75    /// Returns true if allowed, false if rate limited.
76    pub fn check_and_increment(&self, ip: IpAddr) -> bool {
77        let now = Instant::now();
78        let mut map = self.entries.write();
79
80        if let Some(entry) = map.get_mut(&ip) {
81            // Check if window has expired
82            if now.duration_since(entry.window_start) >= self.window {
83                // Reset window
84                entry.count = 1;
85                entry.window_start = now;
86                true
87            } else if entry.count >= self.max_connections {
88                // Rate limited
89                false
90            } else {
91                // Increment and allow
92                entry.count += 1;
93                true
94            }
95        } else {
96            // New IP, create entry
97            map.insert(
98                ip,
99                RateLimitEntry {
100                    count: 1,
101                    window_start: now,
102                },
103            );
104            true
105        }
106    }
107
108    /// Signal shutdown to cleanup task.
109    pub fn shutdown(&self) {
110        self.shutdown.notify_waiters();
111    }
112}
113
114impl Drop for RateLimiter {
115    fn drop(&mut self) {
116        self.shutdown();
117    }
118}
119
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a literal address used as a test fixture.
    fn addr(text: &str) -> IpAddr {
        text.parse().expect("test fixture must be a valid IP address")
    }

    #[test]
    fn test_rate_limit_allows_under_limit() {
        let limiter = RateLimiter::new(5, 60);
        let client = addr("127.0.0.1");

        // The first five attempts fit within the quota...
        assert!((0..5).all(|_| limiter.check_and_increment(client)));
        // ...and the sixth exceeds it.
        assert!(!limiter.check_and_increment(client));
    }

    #[test]
    fn test_rate_limit_different_ips() {
        let limiter = RateLimiter::new(2, 60);

        // Every address owns an independent quota of two.
        for client in [addr("127.0.0.1"), addr("127.0.0.2")] {
            assert!(limiter.check_and_increment(client));
            assert!(limiter.check_and_increment(client));
            assert!(
                !limiter.check_and_increment(client),
                "third attempt from {} should be blocked",
                client
            );
        }
    }

    #[test]
    fn test_rate_limit_window_reset() {
        // A zero-length window has always elapsed, so each call opens a fresh
        // window and even a quota of one never blocks.
        let limiter = RateLimiter::new(1, 0);
        let client = addr("127.0.0.1");

        assert!(limiter.check_and_increment(client), "first connection");
        assert!(
            limiter.check_and_increment(client),
            "window elapsed, counter should reset"
        );
    }
}