Skip to main content

trojan_server/
rate_limit.rs

1//! Per-IP rate limiting for connection throttling.
2
3use std::{
4    collections::HashMap,
5    net::IpAddr,
6    sync::Arc,
7    time::{Duration, Instant},
8};
9
10use parking_lot::RwLock;
11use tokio::sync::Notify;
12use tracing::debug;
13
14/// Rate limiter that tracks connections per IP address.
15pub struct RateLimiter {
16    /// Map of IP -> (connection count, window start time)
17    entries: Arc<RwLock<HashMap<IpAddr, RateLimitEntry>>>,
18    /// Maximum connections allowed per IP in the time window
19    max_connections: u32,
20    /// Time window for rate limiting
21    window: Duration,
22    /// Notify for shutdown
23    shutdown: Arc<Notify>,
24}
25
/// Connection counter for a single IP within its current window.
///
/// Both fields are `Copy`, so the whole entry is a cheap value type.
#[derive(Debug, Clone, Copy)]
struct RateLimitEntry {
    /// Connections admitted since `window_start`.
    count: u32,
    /// When the current counting window for this IP began.
    window_start: Instant,
}
31
32impl RateLimiter {
33    /// Create a new rate limiter.
34    pub fn new(max_connections_per_ip: u32, window_secs: u64) -> Self {
35        Self {
36            entries: Arc::new(RwLock::new(HashMap::new())),
37            max_connections: max_connections_per_ip,
38            window: Duration::from_secs(window_secs),
39            shutdown: Arc::new(Notify::new()),
40        }
41    }
42
43    /// Start the background cleanup task.
44    pub fn start_cleanup_task(&self, cleanup_interval: Duration) {
45        let entries = self.entries.clone();
46        let window = self.window;
47        let shutdown = self.shutdown.clone();
48
49        tokio::spawn(async move {
50            loop {
51                tokio::select! {
52                    _ = shutdown.notified() => {
53                        debug!("rate limiter cleanup task shutting down");
54                        break;
55                    }
56                    _ = tokio::time::sleep(cleanup_interval) => {
57                        let now = Instant::now();
58                        let mut map = entries.write();
59                        let before = map.len();
60                        map.retain(|_, entry| {
61                            now.duration_since(entry.window_start) < window
62                        });
63                        let removed = before - map.len();
64                        if removed > 0 {
65                            debug!(removed, remaining = map.len(), "rate limit entries cleaned up");
66                        }
67                    }
68                }
69            }
70        });
71    }
72
73    /// Check if a connection from the given IP is allowed.
74    /// Returns true if allowed, false if rate limited.
75    pub fn check_and_increment(&self, ip: IpAddr) -> bool {
76        let now = Instant::now();
77        let mut map = self.entries.write();
78
79        if let Some(entry) = map.get_mut(&ip) {
80            // Check if window has expired
81            if now.duration_since(entry.window_start) >= self.window {
82                // Reset window
83                entry.count = 1;
84                entry.window_start = now;
85                true
86            } else if entry.count >= self.max_connections {
87                // Rate limited
88                false
89            } else {
90                // Increment and allow
91                entry.count += 1;
92                true
93            }
94        } else {
95            // New IP, create entry
96            map.insert(
97                ip,
98                RateLimitEntry {
99                    count: 1,
100                    window_start: now,
101                },
102            );
103            true
104        }
105    }
106
107    /// Signal shutdown to cleanup task.
108    pub fn shutdown(&self) {
109        self.shutdown.notify_waiters();
110    }
111}
112
113impl Drop for RateLimiter {
114    fn drop(&mut self) {
115        self.shutdown();
116    }
117}
118
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a literal IP address used as a test fixture.
    fn addr(text: &str) -> IpAddr {
        text.parse().expect("test IP literal must parse")
    }

    #[test]
    fn test_rate_limit_allows_under_limit() {
        let limiter = RateLimiter::new(5, 60);
        let client = addr("127.0.0.1");

        // The first five connections fit inside the quota...
        let admitted = (0..5)
            .filter(|_| limiter.check_and_increment(client))
            .count();
        assert_eq!(admitted, 5);

        // ...and the sixth is refused.
        assert!(!limiter.check_and_increment(client));
    }

    #[test]
    fn test_rate_limit_different_ips() {
        let limiter = RateLimiter::new(2, 60);

        // Each address has its own, independent quota of two.
        for client in [addr("127.0.0.1"), addr("127.0.0.2")] {
            assert!(limiter.check_and_increment(client));
            assert!(limiter.check_and_increment(client));
            assert!(
                !limiter.check_and_increment(client),
                "third connection from {} should be blocked",
                client
            );
        }
    }

    #[test]
    fn test_rate_limit_window_reset() {
        // A zero-length window has always expired, so every check starts a new window.
        let limiter = RateLimiter::new(1, 0);
        let client = addr("127.0.0.1");

        assert!(limiter.check_and_increment(client));
        // This would exceed the limit of 1, but the window has already lapsed and resets.
        assert!(limiter.check_and_increment(client));
    }
}