1use std::collections::HashMap;
22use std::net::SocketAddr;
23
/// Cumulative misbehavior score at or above which a peer is banned.
const DEFAULT_BAN_THRESHOLD: i32 = 100;

/// Default ban length: 86 400 units, i.e. one day when the `now` timestamps
/// passed to the scorer are in seconds.
const DEFAULT_BAN_DURATION: u64 = 60 * 60 * 24;
31
/// Kinds of peer misbehavior, each mapped to a fixed penalty by [`Misbehavior::score`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Misbehavior {
    /// A block header failed validation.
    InvalidBlockHeader,
    /// A block failed validation (consensus failure).
    InvalidBlock,
    /// A transaction failed validation.
    InvalidTransaction,
    /// A message arrived that was not expected during the handshake.
    UnexpectedMessage,
    /// The peer flooded us with `addr` messages.
    AddrFlood,
    /// A network message was rejected as invalid.
    InvalidNetworkMessage,
    /// Behavior classified as a denial-of-service attack.
    DosAttack,
    /// A block that could not be connected.
    UnconnectableBlock,
    /// A duplicate `version` message.
    DuplicateVersion,
    /// Caller-defined violation. The wrapped value is used verbatim as the
    /// penalty, so it may also be zero or negative.
    Custom(i32),
}

impl Misbehavior {
    /// Penalty added to a peer's cumulative score for this violation.
    pub fn score(&self) -> i32 {
        match self {
            Misbehavior::InvalidBlockHeader => 100,
            Misbehavior::InvalidBlock => 100,
            Misbehavior::InvalidTransaction => 10,
            Misbehavior::UnexpectedMessage => 10,
            Misbehavior::AddrFlood => 20,
            Misbehavior::InvalidNetworkMessage => 20,
            Misbehavior::DosAttack => 50,
            Misbehavior::UnconnectableBlock => 10,
            Misbehavior::DuplicateVersion => 10,
            Misbehavior::Custom(score) => *score,
        }
    }

    /// Short human-readable description of the violation, used in logs and
    /// ban reasons.
    pub fn reason(&self) -> &'static str {
        match self {
            Misbehavior::InvalidBlockHeader => "invalid block header",
            Misbehavior::InvalidBlock => "invalid block (consensus failure)",
            Misbehavior::InvalidTransaction => "invalid transaction",
            Misbehavior::UnexpectedMessage => "unexpected message during handshake",
            Misbehavior::AddrFlood => "addr message flood",
            Misbehavior::InvalidNetworkMessage => "invalid network message",
            Misbehavior::DosAttack => "DoS attack detected",
            Misbehavior::UnconnectableBlock => "unconnectable block",
            Misbehavior::DuplicateVersion => "duplicate version message",
            Misbehavior::Custom(_) => "custom violation",
        }
    }
}

impl std::fmt::Display for Misbehavior {
    /// Formats as `"<reason> (<signed score>)"`, e.g. `"invalid transaction (+10)"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `{:+}` always emits the sign: positive penalties render as `+10`, and
        // negative `Custom` penalties as `-5` rather than the malformed `+-5`.
        write!(f, "{} ({:+})", self.reason(), self.score())
    }
}
98
99#[derive(Debug, Clone)]
103struct PeerScore {
104 score: i32,
106 addr: SocketAddr,
108 violations: Vec<(Misbehavior, u64)>,
110}
111
/// A single ban on a network address.
#[derive(Debug, Clone)]
pub struct BanEntry {
    /// The banned address.
    pub addr: SocketAddr,
    /// Timestamp at which the ban was issued (same units as the `now` values
    /// passed to [`BanEntry::is_expired`]).
    pub ban_time: u64,
    /// How long the ban lasts, in the same units as `ban_time`.
    pub ban_duration: u64,
    /// Human-readable explanation of why the ban was issued.
    pub reason: String,
}

impl BanEntry {
    /// Returns `true` once `now` has reached `ban_time + ban_duration`.
    ///
    /// The end time saturates at `u64::MAX`. A very long duration (e.g.
    /// `u64::MAX` for an effectively permanent ban) therefore neither panics
    /// on overflow in debug builds nor wraps around in release builds. Wrapping
    /// would make the ban look already expired.
    pub fn is_expired(&self, now: u64) -> bool {
        now >= self.ban_time.saturating_add(self.ban_duration)
    }
}
131
/// Outcome of recording a misbehavior against a peer.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ScoreAction {
    /// No ban was triggered.
    None,
    /// The peer's score reached the ban threshold and its address was banned.
    Ban {
        /// Identifier of the offending peer.
        peer_id: u64,
        /// Address that was placed on the ban list.
        addr: SocketAddr,
        /// Human-readable explanation of the ban.
        reason: String,
    },
}
146
147#[derive(Debug)]
151pub struct PeerScoring {
152 peer_scores: HashMap<u64, PeerScore>,
154 ban_list: HashMap<SocketAddr, BanEntry>,
156 ban_threshold: i32,
158 ban_duration: u64,
160}
161
162impl PeerScoring {
163 pub fn new() -> Self {
165 PeerScoring {
166 peer_scores: HashMap::new(),
167 ban_list: HashMap::new(),
168 ban_threshold: DEFAULT_BAN_THRESHOLD,
169 ban_duration: DEFAULT_BAN_DURATION,
170 }
171 }
172
173 pub fn with_config(ban_threshold: i32, ban_duration: u64) -> Self {
175 PeerScoring {
176 peer_scores: HashMap::new(),
177 ban_list: HashMap::new(),
178 ban_threshold,
179 ban_duration,
180 }
181 }
182
183 pub fn register_peer(&mut self, peer_id: u64, addr: SocketAddr) {
185 self.peer_scores.insert(
186 peer_id,
187 PeerScore {
188 score: 0,
189 addr,
190 violations: Vec::new(),
191 },
192 );
193 }
194
195 pub fn remove_peer(&mut self, peer_id: u64) {
197 self.peer_scores.remove(&peer_id);
198 }
199
200 pub fn record_misbehavior(
204 &mut self,
205 peer_id: u64,
206 violation: Misbehavior,
207 now: u64,
208 ) -> ScoreAction {
209 let increment = violation.score();
210
211 let peer = match self.peer_scores.get_mut(&peer_id) {
212 Some(p) => p,
213 None => return ScoreAction::None, };
215
216 peer.score += increment;
217 peer.violations.push((violation, now));
218
219 tracing::debug!(
220 "Peer {} misbehaving: {} (score now {})",
221 peer_id,
222 violation,
223 peer.score
224 );
225
226 if peer.score >= self.ban_threshold {
227 let addr = peer.addr;
228 let reason = format!(
229 "ban score {} >= {} (last: {})",
230 peer.score,
231 self.ban_threshold,
232 violation.reason()
233 );
234
235 self.ban_list.insert(
237 addr,
238 BanEntry {
239 addr,
240 ban_time: now,
241 ban_duration: self.ban_duration,
242 reason: reason.clone(),
243 },
244 );
245
246 tracing::warn!("Banning peer {} ({}): {}", peer_id, addr, reason);
247
248 ScoreAction::Ban {
249 peer_id,
250 addr,
251 reason,
252 }
253 } else {
254 ScoreAction::None
255 }
256 }
257
258 pub fn is_banned(&self, addr: &SocketAddr, now: u64) -> bool {
260 match self.ban_list.get(addr) {
261 Some(entry) => !entry.is_expired(now),
262 None => false,
263 }
264 }
265
266 pub fn ban_address(&mut self, addr: SocketAddr, reason: String, now: u64) {
268 self.ban_list.insert(
269 addr,
270 BanEntry {
271 addr,
272 ban_time: now,
273 ban_duration: self.ban_duration,
274 reason,
275 },
276 );
277 }
278
279 pub fn unban_address(&mut self, addr: &SocketAddr) -> bool {
281 self.ban_list.remove(addr).is_some()
282 }
283
284 pub fn sweep_expired_bans(&mut self, now: u64) -> usize {
286 let before = self.ban_list.len();
287 self.ban_list.retain(|_, entry| !entry.is_expired(now));
288 before - self.ban_list.len()
289 }
290
291 pub fn get_score(&self, peer_id: u64) -> i32 {
293 self.peer_scores.get(&peer_id).map(|p| p.score).unwrap_or(0)
294 }
295
296 pub fn list_bans(&self) -> Vec<&BanEntry> {
298 self.ban_list.values().collect()
299 }
300
301 pub fn peer_count(&self) -> usize {
303 self.peer_scores.len()
304 }
305
306 pub fn ban_count(&self) -> usize {
308 self.ban_list.len()
309 }
310}
311
312impl Default for PeerScoring {
313 fn default() -> Self {
314 Self::new()
315 }
316}
317
#[cfg(test)]
mod tests {
    use super::*;

    /// Loopback socket address on `port`.
    fn addr(port: u16) -> SocketAddr {
        SocketAddr::from(([127, 0, 0, 1], port))
    }

    /// Default-configured scorer with every `(peer_id, port)` pair pre-registered.
    fn scoring_with(peers: &[(u64, u16)]) -> PeerScoring {
        let mut scoring = PeerScoring::new();
        for &(id, port) in peers {
            scoring.register_peer(id, addr(port));
        }
        scoring
    }

    #[test]
    fn test_new_peer_score_zero() {
        let scoring = scoring_with(&[(1, 8333)]);
        assert_eq!(scoring.get_score(1), 0);
    }

    #[test]
    fn test_record_misbehavior_below_threshold() {
        let mut scoring = scoring_with(&[(1, 8333)]);
        let outcome = scoring.record_misbehavior(1, Misbehavior::InvalidTransaction, 1000);
        assert_eq!(outcome, ScoreAction::None);
        assert_eq!(scoring.get_score(1), 10);
    }

    #[test]
    fn test_record_misbehavior_triggers_ban() {
        let mut scoring = scoring_with(&[(1, 8333)]);
        let outcome = scoring.record_misbehavior(1, Misbehavior::InvalidBlock, 1000);

        if let ScoreAction::Ban { peer_id, addr: banned, .. } = outcome {
            assert_eq!(peer_id, 1);
            assert_eq!(banned, addr(8333));
        } else {
            panic!("Expected ban");
        }

        assert!(scoring.is_banned(&addr(8333), 1000));
    }

    #[test]
    fn test_cumulative_score_triggers_ban() {
        let mut scoring = scoring_with(&[(1, 8333)]);

        // Nine 10-point offenses leave the peer at 90, just below the threshold.
        for now in 1000..1009 {
            let outcome = scoring.record_misbehavior(1, Misbehavior::InvalidTransaction, now);
            assert_eq!(outcome, ScoreAction::None);
        }

        // The tenth brings the score to exactly 100 and triggers the ban.
        let outcome = scoring.record_misbehavior(1, Misbehavior::InvalidTransaction, 1009);
        assert!(matches!(outcome, ScoreAction::Ban { .. }));
        assert_eq!(scoring.get_score(1), 100);
    }

    #[test]
    fn test_ban_expires() {
        let mut scoring = scoring_with(&[(1, 8333)]);
        let start = 1000;
        let target = addr(8333);

        scoring.record_misbehavior(1, Misbehavior::InvalidBlock, start);
        assert!(scoring.is_banned(&target, start));

        // Still banned halfway through the default duration.
        assert!(scoring.is_banned(&target, start + 12 * 3600));

        // The ban lifts exactly at ban_time + duration.
        assert!(!scoring.is_banned(&target, start + DEFAULT_BAN_DURATION));
    }

    #[test]
    fn test_sweep_expired_bans() {
        let mut scoring = scoring_with(&[(1, 8333), (2, 8334)]);
        scoring.record_misbehavior(1, Misbehavior::InvalidBlock, 1000);
        scoring.record_misbehavior(2, Misbehavior::InvalidBlock, 2000);
        assert_eq!(scoring.ban_count(), 2);

        // At this instant only the earlier ban has run its course.
        let sweep_at = 1000 + DEFAULT_BAN_DURATION;
        assert_eq!(scoring.sweep_expired_bans(sweep_at), 1);
        assert_eq!(scoring.ban_count(), 1);
        assert!(!scoring.is_banned(&addr(8333), sweep_at));
        assert!(scoring.is_banned(&addr(8334), sweep_at));
    }

    #[test]
    fn test_manual_ban_unban() {
        let mut scoring = PeerScoring::new();
        let target = addr(9999);

        scoring.ban_address(target, String::from("manual ban"), 1000);
        assert!(scoring.is_banned(&target, 1000));

        assert!(scoring.unban_address(&target));
        assert!(!scoring.is_banned(&target, 1000));
    }

    #[test]
    fn test_remove_peer() {
        let mut scoring = scoring_with(&[(1, 8333)]);
        assert_eq!(scoring.peer_count(), 1);

        scoring.remove_peer(1);
        assert_eq!(scoring.peer_count(), 0);
        // Unregistered peers report a zero score.
        assert_eq!(scoring.get_score(1), 0);
    }

    #[test]
    fn test_unknown_peer_misbehavior_ignored() {
        let mut scoring = PeerScoring::new();
        assert_eq!(
            scoring.record_misbehavior(999, Misbehavior::InvalidBlock, 1000),
            ScoreAction::None
        );
    }

    #[test]
    fn test_custom_threshold() {
        let mut scoring = PeerScoring::with_config(50, 3600);
        scoring.register_peer(1, addr(8333));

        // A single 50-point offense meets the lowered threshold.
        let outcome = scoring.record_misbehavior(1, Misbehavior::DosAttack, 1000);
        assert!(matches!(outcome, ScoreAction::Ban { .. }));
    }

    #[test]
    fn test_misbehavior_scores() {
        let expected = [
            (Misbehavior::InvalidBlockHeader, 100),
            (Misbehavior::InvalidBlock, 100),
            (Misbehavior::InvalidTransaction, 10),
            (Misbehavior::UnexpectedMessage, 10),
            (Misbehavior::AddrFlood, 20),
            (Misbehavior::InvalidNetworkMessage, 20),
            (Misbehavior::DosAttack, 50),
            (Misbehavior::UnconnectableBlock, 10),
            (Misbehavior::DuplicateVersion, 10),
            (Misbehavior::Custom(42), 42),
        ];
        for &(violation, score) in expected.iter() {
            assert_eq!(violation.score(), score, "{:?}", violation);
        }
    }

    #[test]
    fn test_misbehavior_display() {
        let rendered = Misbehavior::InvalidTransaction.to_string();
        assert!(rendered.contains("invalid transaction"));
        assert!(rendered.contains("+10"));
    }

    #[test]
    fn test_list_bans() {
        let mut scoring = scoring_with(&[(1, 8333), (2, 8334)]);
        scoring.record_misbehavior(1, Misbehavior::InvalidBlock, 1000);
        scoring.record_misbehavior(2, Misbehavior::InvalidBlockHeader, 1000);
        assert_eq!(scoring.list_bans().len(), 2);
    }
}