// trojan_server/rate_limit.rs

use std::{
    collections::{hash_map::Entry, HashMap},
    net::IpAddr,
    sync::Arc,
    time::{Duration, Instant},
};

use parking_lot::RwLock;
use tokio::sync::Notify;
use tracing::debug;

/// Per-IP fixed-window connection rate limiter.
///
/// The entry map lives behind an `Arc` so that the background cleanup task started by
/// `start_cleanup_task` can share it with the limiter.
#[derive(Debug)]
pub struct RateLimiter {
    /// Per-IP counters for the window currently in progress.
    entries: Arc<RwLock<HashMap<IpAddr, RateLimitEntry>>>,
    /// Per-IP connection limit compared against in `check_and_increment`.
    max_connections: u32,
    /// Length of each rate-limit window.
    window: Duration,
    /// Signal used to stop the background cleanup task.
    shutdown: Arc<Notify>,
}

/// Counter state tracked for a single IP address.
#[derive(Debug, Clone)]
struct RateLimitEntry {
    /// Number of attempts admitted since `window_start`.
    count: u32,
    /// Moment the current window was opened.
    window_start: Instant,
}

33impl RateLimiter {
34 pub fn new(max_connections_per_ip: u32, window_secs: u64) -> Self {
36 Self {
37 entries: Arc::new(RwLock::new(HashMap::new())),
38 max_connections: max_connections_per_ip,
39 window: Duration::from_secs(window_secs),
40 shutdown: Arc::new(Notify::new()),
41 }
42 }
43
44 pub fn start_cleanup_task(&self, cleanup_interval: Duration) {
46 let entries = self.entries.clone();
47 let window = self.window;
48 let shutdown = self.shutdown.clone();
49
50 tokio::spawn(async move {
51 loop {
52 tokio::select! {
53 _ = shutdown.notified() => {
54 debug!("rate limiter cleanup task shutting down");
55 break;
56 }
57 _ = tokio::time::sleep(cleanup_interval) => {
58 let now = Instant::now();
59 let mut map = entries.write();
60 let before = map.len();
61 map.retain(|_, entry| {
62 now.duration_since(entry.window_start) < window
63 });
64 let removed = before - map.len();
65 if removed > 0 {
66 debug!(removed, remaining = map.len(), "rate limit entries cleaned up");
67 }
68 }
69 }
70 }
71 });
72 }
73
74 pub fn check_and_increment(&self, ip: IpAddr) -> bool {
77 let now = Instant::now();
78 let mut map = self.entries.write();
79
80 if let Some(entry) = map.get_mut(&ip) {
81 if now.duration_since(entry.window_start) >= self.window {
83 entry.count = 1;
85 entry.window_start = now;
86 true
87 } else if entry.count >= self.max_connections {
88 false
90 } else {
91 entry.count += 1;
93 true
94 }
95 } else {
96 map.insert(
98 ip,
99 RateLimitEntry {
100 count: 1,
101 window_start: now,
102 },
103 );
104 true
105 }
106 }
107
108 pub fn shutdown(&self) {
110 self.shutdown.notify_waiters();
111 }
112}
113
114impl Drop for RateLimiter {
115 fn drop(&mut self) {
116 self.shutdown();
117 }
118}
119
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a literal test address.
    fn addr(text: &str) -> IpAddr {
        text.parse().unwrap()
    }

    #[test]
    fn test_rate_limit_allows_under_limit() {
        let limiter = RateLimiter::new(5, 60);
        let ip = addr("127.0.0.1");

        // Every attempt up to the limit is admitted...
        for attempt in 1..=5 {
            assert!(
                limiter.check_and_increment(ip),
                "attempt {attempt} should be allowed"
            );
        }

        // ...and the next one inside the same window is rejected.
        assert!(!limiter.check_and_increment(ip));
    }

    #[test]
    fn test_rate_limit_different_ips() {
        let limiter = RateLimiter::new(2, 60);
        let first = addr("127.0.0.1");
        let second = addr("127.0.0.2");

        // Each address has an independent budget: exhausting the first leaves the
        // second untouched.
        for ip in [first, second] {
            assert!(limiter.check_and_increment(ip));
            assert!(limiter.check_and_increment(ip));
            assert!(!limiter.check_and_increment(ip));
        }
    }

    #[test]
    fn test_rate_limit_window_reset() {
        // A zero-length window has always elapsed, so every attempt opens a new one.
        let limiter = RateLimiter::new(1, 0);
        let ip = addr("127.0.0.1");

        assert!(limiter.check_and_increment(ip));
        assert!(limiter.check_and_increment(ip));
    }
}