1use std::collections::HashMap;
9use std::collections::HashSet;
10use std::net::IpAddr;
11
/// Tuning knobs for [`BanManager`].
#[derive(Debug, Clone)]
pub struct BanConfig {
    /// Strike threshold for an automatic ban: `BanManager::record_strike`
    /// bans an IP on the strike that brings its count to at least this value.
    pub max_failures: u32,
    /// Whether parole handling is enabled. Within this module the flag is only
    /// stored and exposed through `BanManager::use_parole`.
    // NOTE(review): the actual parole behavior is decided by callers — confirm.
    pub use_parole: bool,
}

impl Default for BanConfig {
    /// Three strikes before a ban, with parole enabled.
    fn default() -> Self {
        let max_failures = 3;
        let use_parole = true;
        BanConfig {
            max_failures,
            use_parole,
        }
    }
}
29
30#[derive(Debug)]
34pub struct BanManager {
35 config: BanConfig,
36 banned: HashSet<IpAddr>,
37 strikes: HashMap<IpAddr, u32>,
38}
39
40impl BanManager {
41 pub fn new(config: BanConfig) -> Self {
43 Self {
44 config,
45 banned: HashSet::new(),
46 strikes: HashMap::new(),
47 }
48 }
49
50 pub fn is_banned(&self, ip: &IpAddr) -> bool {
52 self.banned.contains(ip)
53 }
54
55 pub fn record_strike(&mut self, ip: IpAddr) -> bool {
59 if self.banned.contains(&ip) {
60 return false; }
62 let count = self.strikes.entry(ip).or_insert(0);
63 *count += 1;
64 if *count >= self.config.max_failures {
65 self.banned.insert(ip);
66 true
67 } else {
68 false
69 }
70 }
71
72 pub fn ban(&mut self, ip: IpAddr) {
74 self.banned.insert(ip);
75 }
76
77 pub fn unban(&mut self, ip: &IpAddr) -> bool {
81 self.strikes.remove(ip);
82 self.banned.remove(ip)
83 }
84
85 pub fn banned_list(&self) -> &HashSet<IpAddr> {
87 &self.banned
88 }
89
90 pub fn strikes_map(&self) -> &HashMap<IpAddr, u32> {
92 &self.strikes
93 }
94
95 pub fn use_parole(&self) -> bool {
97 self.config.use_parole
98 }
99
100 #[allow(dead_code)] pub fn restore(
103 config: BanConfig,
104 banned: HashSet<IpAddr>,
105 strikes: HashMap<IpAddr, u32>,
106 ) -> Self {
107 Self {
108 config,
109 banned,
110 strikes,
111 }
112 }
113}
114
/// Crate-internal record of parole-related peers.
// NOTE(review): field meanings below are inferred from names only — confirm
// against the code that builds and reads this struct.
#[derive(Clone, Debug)]
pub(crate) struct ParoleState {
    /// Set of peer addresses, presumably those that originally contributed.
    pub original_contributors: HashSet<IpAddr>,
    /// The peer currently on parole, if any (presumably).
    pub parole_peer: Option<IpAddr>,
}
126
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a test IP literal; a bad literal is a bug in the test itself.
    fn addr(s: &str) -> IpAddr {
        s.parse().expect("test IP literal must be valid")
    }

    /// Config with the given strike threshold and parole enabled.
    fn config(max_failures: u32) -> BanConfig {
        BanConfig {
            max_failures,
            use_parole: true,
        }
    }

    #[test]
    fn ban_manager_empty() {
        let mgr = BanManager::new(BanConfig::default());
        let ip = addr("10.0.0.1");
        assert!(!mgr.is_banned(&ip));
        assert!(mgr.banned_list().is_empty());
        assert!(mgr.strikes_map().is_empty());
    }

    #[test]
    fn record_strike_below_threshold() {
        let mut mgr = BanManager::new(config(3));
        let ip = addr("10.0.0.1");

        assert!(!mgr.record_strike(ip));
        assert!(!mgr.record_strike(ip));
        assert!(!mgr.is_banned(&ip));
        assert_eq!(mgr.strikes_map().get(&ip), Some(&2));
    }

    #[test]
    fn record_strike_hits_threshold() {
        let mut mgr = BanManager::new(config(3));
        let ip = addr("10.0.0.1");

        assert!(!mgr.record_strike(ip));
        assert!(!mgr.record_strike(ip));
        // The third strike is the one that triggers the ban.
        assert!(mgr.record_strike(ip));
        assert!(mgr.is_banned(&ip));

        // Further strikes on a banned IP never report a new ban.
        assert!(!mgr.record_strike(ip));
    }

    #[test]
    fn manual_ban_unban() {
        let mut mgr = BanManager::new(BanConfig::default());
        let ip = addr("192.168.1.1");

        mgr.ban(ip);
        assert!(mgr.is_banned(&ip));

        assert!(mgr.unban(&ip));
        assert!(!mgr.is_banned(&ip));

        // Unbanning an IP that is not banned reports `false`.
        assert!(!mgr.unban(&ip));
    }

    #[test]
    fn unban_clears_strikes() {
        let mut mgr = BanManager::new(config(3));
        let ip = addr("10.0.0.5");

        mgr.record_strike(ip);
        mgr.record_strike(ip);
        assert_eq!(mgr.strikes_map().get(&ip), Some(&2));

        mgr.ban(ip);
        mgr.unban(&ip);
        assert!(mgr.strikes_map().get(&ip).is_none());
        assert!(!mgr.is_banned(&ip));
    }

    #[test]
    fn strike_not_counted_after_manual_ban() {
        let mut mgr = BanManager::new(config(3));
        let ip = addr("10.0.0.7");

        mgr.ban(ip);
        assert!(!mgr.record_strike(ip));
        assert!(mgr.strikes_map().get(&ip).is_none());
    }

    #[test]
    fn use_parole_reflects_config() {
        assert!(BanManager::new(config(3)).use_parole());
        let off = BanManager::new(BanConfig {
            max_failures: 3,
            use_parole: false,
        });
        assert!(!off.use_parole());
    }

    #[test]
    fn restore_preserves_state() {
        let ip1 = addr("10.0.0.1");
        let ip2 = addr("10.0.0.2");

        let banned = HashSet::from([ip1]);
        let strikes = HashMap::from([(ip1, 3), (ip2, 1)]);

        let mgr = BanManager::restore(BanConfig::default(), banned, strikes);
        assert!(mgr.is_banned(&ip1));
        assert!(!mgr.is_banned(&ip2));
        assert_eq!(mgr.strikes_map().get(&ip1), Some(&3));
        assert_eq!(mgr.strikes_map().get(&ip2), Some(&1));
    }
}