trojan_server/
rate_limit.rs1use std::{
4 collections::HashMap,
5 net::IpAddr,
6 sync::Arc,
7 time::{Duration, Instant},
8};
9
10use parking_lot::RwLock;
11use tokio::sync::Notify;
12use tracing::debug;
13
14pub struct RateLimiter {
16 entries: Arc<RwLock<HashMap<IpAddr, RateLimitEntry>>>,
18 max_connections: u32,
20 window: Duration,
22 shutdown: Arc<Notify>,
24}
25
/// Bookkeeping for a single remote address.
#[derive(Clone)]
struct RateLimitEntry {
    /// Connections counted since `window_start`.
    count: u32,

    /// Moment at which the current counting window began.
    window_start: Instant,
}
31
32impl RateLimiter {
33 pub fn new(max_connections_per_ip: u32, window_secs: u64) -> Self {
35 Self {
36 entries: Arc::new(RwLock::new(HashMap::new())),
37 max_connections: max_connections_per_ip,
38 window: Duration::from_secs(window_secs),
39 shutdown: Arc::new(Notify::new()),
40 }
41 }
42
43 pub fn start_cleanup_task(&self, cleanup_interval: Duration) {
45 let entries = self.entries.clone();
46 let window = self.window;
47 let shutdown = self.shutdown.clone();
48
49 tokio::spawn(async move {
50 loop {
51 tokio::select! {
52 _ = shutdown.notified() => {
53 debug!("rate limiter cleanup task shutting down");
54 break;
55 }
56 _ = tokio::time::sleep(cleanup_interval) => {
57 let now = Instant::now();
58 let mut map = entries.write();
59 let before = map.len();
60 map.retain(|_, entry| {
61 now.duration_since(entry.window_start) < window
62 });
63 let removed = before - map.len();
64 if removed > 0 {
65 debug!(removed, remaining = map.len(), "rate limit entries cleaned up");
66 }
67 }
68 }
69 }
70 });
71 }
72
73 pub fn check_and_increment(&self, ip: IpAddr) -> bool {
76 let now = Instant::now();
77 let mut map = self.entries.write();
78
79 if let Some(entry) = map.get_mut(&ip) {
80 if now.duration_since(entry.window_start) >= self.window {
82 entry.count = 1;
84 entry.window_start = now;
85 true
86 } else if entry.count >= self.max_connections {
87 false
89 } else {
90 entry.count += 1;
92 true
93 }
94 } else {
95 map.insert(
97 ip,
98 RateLimitEntry {
99 count: 1,
100 window_start: now,
101 },
102 );
103 true
104 }
105 }
106
107 pub fn shutdown(&self) {
109 self.shutdown.notify_waiters();
110 }
111}
112
113impl Drop for RateLimiter {
114 fn drop(&mut self) {
115 self.shutdown();
116 }
117}
118
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a textual address, panicking on malformed test input.
    fn addr(text: &str) -> IpAddr {
        text.parse().unwrap()
    }

    /// Exactly `max` attempts are admitted inside one window; the next is refused.
    #[test]
    fn test_rate_limit_allows_under_limit() {
        let limiter = RateLimiter::new(5, 60);
        let client = addr("127.0.0.1");

        let mut admitted = 0;
        while admitted < 5 {
            assert!(limiter.check_and_increment(client));
            admitted += 1;
        }

        assert!(!limiter.check_and_increment(client));
    }

    /// Counters are tracked per address and do not influence each other.
    #[test]
    fn test_rate_limit_different_ips() {
        let limiter = RateLimiter::new(2, 60);
        let first = addr("127.0.0.1");
        let second = addr("127.0.0.2");

        // Exhaust the first address.
        assert!(limiter.check_and_increment(first));
        assert!(limiter.check_and_increment(first));
        assert!(!limiter.check_and_increment(first));

        // The second address still has its full budget.
        assert!(limiter.check_and_increment(second));
        assert!(limiter.check_and_increment(second));
        assert!(!limiter.check_and_increment(second));
    }

    /// With a zero-length window every attempt starts a new window and is admitted.
    #[test]
    fn test_rate_limit_window_reset() {
        let limiter = RateLimiter::new(1, 0);
        let client = addr("127.0.0.1");

        assert!(limiter.check_and_increment(client));
        assert!(limiter.check_and_increment(client));
    }
}