// abtc_application/peer_scoring.rs
//! Peer Misbehavior & Ban Score Tracking
//!
//! Implements a peer scoring system modelled on Bitcoin Core's (see
//! `Misbehaving()` in `net_processing.cpp`). Every protocol violation adds to
//! a peer's "ban score". Once the score reaches or exceeds the ban threshold
//! (default 100), the peer's address is added to the ban list for a
//! configurable duration (default 24 hours) and the caller is told, via
//! `ScoreAction::Ban`, to disconnect the peer.
//!
//! ## Violation Categories
//!
//! | Violation                           | Score |
//! |-------------------------------------|-------|
//! | Invalid block header                |  100  |
//! | Invalid block (consensus failure)   |  100  |
//! | Invalid transaction                 |   10  |
//! | Unexpected message during handshake |   10  |
//! | Too many addr messages              |   20  |
//! | Invalid network message             |   20  |
//! | DoS (too many messages, etc.)       |   50  |
//! | Unconnectable block                 |   10  |
//! | Duplicate version message           |   10  |
//! | Custom(n)                           |   n   |
20
21use std::collections::HashMap;
22use std::net::SocketAddr;
23
24// ── Configuration ───────────────────────────────────────────────────
25
/// Ban score at which a peer gets banned: the check is `score >= threshold`.
const DEFAULT_BAN_THRESHOLD: i32 = 100;

/// Default ban length in seconds: one day (24 h × 60 min × 60 s).
const DEFAULT_BAN_DURATION: u64 = 86_400;
31
32// ── Violation types ─────────────────────────────────────────────────
33
/// Categories of protocol violations, each carrying a ban-score increment.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Misbehavior {
    /// Invalid block header (score 100: an immediate ban at the default threshold).
    InvalidBlockHeader,
    /// Block fails consensus validation (score 100: an immediate ban at the default threshold).
    InvalidBlock,
    /// Transaction fails validation.
    InvalidTransaction,
    /// Unexpected message during handshake.
    UnexpectedMessage,
    /// Flooding with addr messages.
    AddrFlood,
    /// Malformed or unrecognised network message.
    InvalidNetworkMessage,
    /// Generic DoS (rate-limit exceeded, etc.).
    DosAttack,
    /// Sent a block that doesn't connect to our chain.
    UnconnectableBlock,
    /// Peer sent a version message after handshake was already complete.
    DuplicateVersion,
    /// Caller-supplied increment for violations not covered above.
    ///
    /// The value is returned verbatim by [`Misbehavior::score`]; zero and
    /// negative values are not rejected.
    Custom(i32),
}

impl Misbehavior {
    /// The ban score increment for this violation.
    pub fn score(&self) -> i32 {
        match self {
            Misbehavior::InvalidBlockHeader | Misbehavior::InvalidBlock => 100,
            Misbehavior::DosAttack => 50,
            Misbehavior::AddrFlood | Misbehavior::InvalidNetworkMessage => 20,
            Misbehavior::InvalidTransaction
            | Misbehavior::UnexpectedMessage
            | Misbehavior::UnconnectableBlock
            | Misbehavior::DuplicateVersion => 10,
            Misbehavior::Custom(score) => *score,
        }
    }

    /// Human-readable reason string (used in logs and ban reasons).
    pub fn reason(&self) -> &'static str {
        match self {
            Misbehavior::InvalidBlockHeader => "invalid block header",
            Misbehavior::InvalidBlock => "invalid block (consensus failure)",
            Misbehavior::InvalidTransaction => "invalid transaction",
            Misbehavior::UnexpectedMessage => "unexpected message during handshake",
            Misbehavior::AddrFlood => "addr message flood",
            Misbehavior::InvalidNetworkMessage => "invalid network message",
            Misbehavior::DosAttack => "DoS attack detected",
            Misbehavior::UnconnectableBlock => "unconnectable block",
            Misbehavior::DuplicateVersion => "duplicate version message",
            Misbehavior::Custom(_) => "custom violation",
        }
    }
}

impl std::fmt::Display for Misbehavior {
    /// Formats as `"<reason> (<signed score>)"`, e.g. `"invalid transaction (+10)"`
    /// or `"custom violation (-5)"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `{:+}` always prints the sign, so a negative Custom score renders as
        // "(-5)" rather than the contradictory "(+-5)" a literal '+' produced.
        write!(f, "{} ({:+})", self.reason(), self.score())
    }
}
98
99// ── Per-peer score tracking ─────────────────────────────────────────
100
101/// Tracking state for a single peer.
102#[derive(Debug, Clone)]
103struct PeerScore {
104    /// Accumulated ban score.
105    score: i32,
106    /// Network address (for banning).
107    addr: SocketAddr,
108    /// Log of violations with timestamps.
109    violations: Vec<(Misbehavior, u64)>,
110}
111
/// A banned address entry.
#[derive(Debug, Clone)]
pub struct BanEntry {
    /// The banned address.
    pub addr: SocketAddr,
    /// Timestamp when the ban was imposed (seconds since epoch).
    pub ban_time: u64,
    /// Duration of the ban in seconds.
    pub ban_duration: u64,
    /// The violation that triggered the ban.
    pub reason: String,
}

impl BanEntry {
    /// Whether this ban has expired at timestamp `now`.
    ///
    /// A ban is active for `now` in `[ban_time, ban_time + ban_duration)`.
    /// The end time saturates at `u64::MAX`, so a very long (e.g. `u64::MAX`,
    /// "permanent") duration neither panics on overflow in debug builds nor
    /// wraps around and expires immediately in release builds.
    pub fn is_expired(&self, now: u64) -> bool {
        now >= self.ban_time.saturating_add(self.ban_duration)
    }
}
131
132// ── The PeerScoring manager ─────────────────────────────────────────
133
/// What the caller must do after a misbehavior has been recorded.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ScoreAction {
    /// Nothing to do: the peer's score is still below the threshold.
    None,
    /// Disconnect the peer; its address has been placed on the ban list.
    Ban {
        /// Identifier of the offending peer.
        peer_id: u64,
        /// Address that was banned.
        addr: SocketAddr,
        /// Human-readable explanation of the ban.
        reason: String,
    },
}
146
147/// Manages per-peer ban scores and a ban list.
148///
149/// Thread-safe usage: wrap in `Arc<RwLock<PeerScoring>>` if shared.
150#[derive(Debug)]
151pub struct PeerScoring {
152    /// Per-peer scores, keyed by peer_id.
153    peer_scores: HashMap<u64, PeerScore>,
154    /// Banned addresses.
155    ban_list: HashMap<SocketAddr, BanEntry>,
156    /// Ban threshold (default: 100).
157    ban_threshold: i32,
158    /// Ban duration in seconds (default: 24h).
159    ban_duration: u64,
160}
161
162impl PeerScoring {
163    /// Create a new peer scoring manager with default settings.
164    pub fn new() -> Self {
165        PeerScoring {
166            peer_scores: HashMap::new(),
167            ban_list: HashMap::new(),
168            ban_threshold: DEFAULT_BAN_THRESHOLD,
169            ban_duration: DEFAULT_BAN_DURATION,
170        }
171    }
172
173    /// Create with custom threshold and ban duration.
174    pub fn with_config(ban_threshold: i32, ban_duration: u64) -> Self {
175        PeerScoring {
176            peer_scores: HashMap::new(),
177            ban_list: HashMap::new(),
178            ban_threshold,
179            ban_duration,
180        }
181    }
182
183    /// Register a new peer (call when peer connects).
184    pub fn register_peer(&mut self, peer_id: u64, addr: SocketAddr) {
185        self.peer_scores.insert(
186            peer_id,
187            PeerScore {
188                score: 0,
189                addr,
190                violations: Vec::new(),
191            },
192        );
193    }
194
195    /// Remove a peer (call when peer disconnects).
196    pub fn remove_peer(&mut self, peer_id: u64) {
197        self.peer_scores.remove(&peer_id);
198    }
199
200    /// Record a misbehavior event for a peer.
201    ///
202    /// Returns an action the caller should take (either None or Ban).
203    pub fn record_misbehavior(
204        &mut self,
205        peer_id: u64,
206        violation: Misbehavior,
207        now: u64,
208    ) -> ScoreAction {
209        let increment = violation.score();
210
211        let peer = match self.peer_scores.get_mut(&peer_id) {
212            Some(p) => p,
213            None => return ScoreAction::None, // unknown peer — ignore
214        };
215
216        peer.score += increment;
217        peer.violations.push((violation, now));
218
219        tracing::debug!(
220            "Peer {} misbehaving: {} (score now {})",
221            peer_id,
222            violation,
223            peer.score
224        );
225
226        if peer.score >= self.ban_threshold {
227            let addr = peer.addr;
228            let reason = format!(
229                "ban score {} >= {} (last: {})",
230                peer.score,
231                self.ban_threshold,
232                violation.reason()
233            );
234
235            // Add to ban list.
236            self.ban_list.insert(
237                addr,
238                BanEntry {
239                    addr,
240                    ban_time: now,
241                    ban_duration: self.ban_duration,
242                    reason: reason.clone(),
243                },
244            );
245
246            tracing::warn!("Banning peer {} ({}): {}", peer_id, addr, reason);
247
248            ScoreAction::Ban {
249                peer_id,
250                addr,
251                reason,
252            }
253        } else {
254            ScoreAction::None
255        }
256    }
257
258    /// Check whether an address is currently banned.
259    pub fn is_banned(&self, addr: &SocketAddr, now: u64) -> bool {
260        match self.ban_list.get(addr) {
261            Some(entry) => !entry.is_expired(now),
262            None => false,
263        }
264    }
265
266    /// Manually ban an address.
267    pub fn ban_address(&mut self, addr: SocketAddr, reason: String, now: u64) {
268        self.ban_list.insert(
269            addr,
270            BanEntry {
271                addr,
272                ban_time: now,
273                ban_duration: self.ban_duration,
274                reason,
275            },
276        );
277    }
278
279    /// Manually unban an address.
280    pub fn unban_address(&mut self, addr: &SocketAddr) -> bool {
281        self.ban_list.remove(addr).is_some()
282    }
283
284    /// Remove expired bans.
285    pub fn sweep_expired_bans(&mut self, now: u64) -> usize {
286        let before = self.ban_list.len();
287        self.ban_list.retain(|_, entry| !entry.is_expired(now));
288        before - self.ban_list.len()
289    }
290
291    /// Get the current score for a peer.
292    pub fn get_score(&self, peer_id: u64) -> i32 {
293        self.peer_scores.get(&peer_id).map(|p| p.score).unwrap_or(0)
294    }
295
296    /// Get all currently banned addresses.
297    pub fn list_bans(&self) -> Vec<&BanEntry> {
298        self.ban_list.values().collect()
299    }
300
301    /// Number of tracked peers.
302    pub fn peer_count(&self) -> usize {
303        self.peer_scores.len()
304    }
305
306    /// Number of banned addresses.
307    pub fn ban_count(&self) -> usize {
308        self.ban_list.len()
309    }
310}
311
312impl Default for PeerScoring {
313    fn default() -> Self {
314        Self::new()
315    }
316}
317
318// ── Tests ───────────────────────────────────────────────────────────
319
#[cfg(test)]
mod tests {
    use super::*;

    /// Loopback socket address on the given port.
    fn addr(port: u16) -> SocketAddr {
        format!("127.0.0.1:{}", port).parse().unwrap()
    }

    #[test]
    fn test_new_peer_score_zero() {
        let mut scoring = PeerScoring::new();
        scoring.register_peer(1, addr(8333));
        assert_eq!(scoring.get_score(1), 0);
    }

    #[test]
    fn test_record_misbehavior_below_threshold() {
        let mut scoring = PeerScoring::new();
        scoring.register_peer(1, addr(8333));

        let action = scoring.record_misbehavior(1, Misbehavior::InvalidTransaction, 1000);
        assert_eq!(action, ScoreAction::None);
        assert_eq!(scoring.get_score(1), 10);
    }

    #[test]
    fn test_record_misbehavior_triggers_ban() {
        let mut scoring = PeerScoring::new();
        scoring.register_peer(1, addr(8333));

        // InvalidBlock = 100, which meets the threshold immediately.
        let action = scoring.record_misbehavior(1, Misbehavior::InvalidBlock, 1000);
        match action {
            ScoreAction::Ban {
                peer_id, addr: a, ..
            } => {
                assert_eq!(peer_id, 1);
                assert_eq!(a, addr(8333));
            }
            ScoreAction::None => panic!("Expected ban"),
        }

        assert!(scoring.is_banned(&addr(8333), 1000));
    }

    #[test]
    fn test_cumulative_score_triggers_ban() {
        let mut scoring = PeerScoring::new();
        scoring.register_peer(1, addr(8333));

        // 10 invalid transactions × 10 = 100 → ban
        for i in 0..9 {
            let action = scoring.record_misbehavior(1, Misbehavior::InvalidTransaction, 1000 + i);
            assert_eq!(action, ScoreAction::None);
        }

        let action = scoring.record_misbehavior(1, Misbehavior::InvalidTransaction, 1009);
        assert!(matches!(action, ScoreAction::Ban { .. }));
        assert_eq!(scoring.get_score(1), 100);
    }

    #[test]
    fn test_ban_expires() {
        let mut scoring = PeerScoring::new();
        scoring.register_peer(1, addr(8333));

        scoring.record_misbehavior(1, Misbehavior::InvalidBlock, 1000);
        assert!(scoring.is_banned(&addr(8333), 1000));

        // Still banned 12 hours later.
        assert!(scoring.is_banned(&addr(8333), 1000 + 12 * 3600));

        // Expired after 24 hours.
        assert!(!scoring.is_banned(&addr(8333), 1000 + DEFAULT_BAN_DURATION));
    }

    #[test]
    fn test_sweep_expired_bans() {
        let mut scoring = PeerScoring::new();
        scoring.register_peer(1, addr(8333));
        scoring.register_peer(2, addr(8334));

        scoring.record_misbehavior(1, Misbehavior::InvalidBlock, 1000);
        scoring.record_misbehavior(2, Misbehavior::InvalidBlock, 2000);

        assert_eq!(scoring.ban_count(), 2);

        // Sweep at t=1000+24h — first ban expired, second still active.
        let swept = scoring.sweep_expired_bans(1000 + DEFAULT_BAN_DURATION);
        assert_eq!(swept, 1);
        assert_eq!(scoring.ban_count(), 1);
        assert!(!scoring.is_banned(&addr(8333), 1000 + DEFAULT_BAN_DURATION));
        assert!(scoring.is_banned(&addr(8334), 1000 + DEFAULT_BAN_DURATION));
    }

    #[test]
    fn test_manual_ban_unban() {
        let mut scoring = PeerScoring::new();

        scoring.ban_address(addr(9999), "manual ban".into(), 1000);
        assert!(scoring.is_banned(&addr(9999), 1000));

        let removed = scoring.unban_address(&addr(9999));
        assert!(removed);
        assert!(!scoring.is_banned(&addr(9999), 1000));
    }

    #[test]
    fn test_unban_unknown_address_returns_false() {
        let mut scoring = PeerScoring::new();
        assert!(!scoring.unban_address(&addr(9999)));
    }

    #[test]
    fn test_manual_ban_uses_configured_duration() {
        let mut scoring = PeerScoring::with_config(100, 3600);
        scoring.ban_address(addr(9999), "manual ban".into(), 1000);

        assert!(scoring.is_banned(&addr(9999), 1000 + 3599));
        assert!(!scoring.is_banned(&addr(9999), 1000 + 3600));
    }

    #[test]
    fn test_list_bans_keeps_expired_until_swept() {
        let mut scoring = PeerScoring::new();
        scoring.ban_address(addr(9999), "manual ban".into(), 1000);

        let later = 1000 + DEFAULT_BAN_DURATION;
        assert!(!scoring.is_banned(&addr(9999), later));
        // Expired entries stay listed until a sweep removes them.
        assert_eq!(scoring.list_bans().len(), 1);
        assert_eq!(scoring.sweep_expired_bans(later), 1);
        assert!(scoring.list_bans().is_empty());
    }

    #[test]
    fn test_remove_peer() {
        let mut scoring = PeerScoring::new();
        scoring.register_peer(1, addr(8333));
        assert_eq!(scoring.peer_count(), 1);

        scoring.remove_peer(1);
        assert_eq!(scoring.peer_count(), 0);
        assert_eq!(scoring.get_score(1), 0); // unknown peer returns 0
    }

    #[test]
    fn test_removed_peer_misbehavior_ignored() {
        let mut scoring = PeerScoring::new();
        scoring.register_peer(1, addr(8333));
        scoring.remove_peer(1);

        let action = scoring.record_misbehavior(1, Misbehavior::InvalidBlock, 1000);
        assert_eq!(action, ScoreAction::None);
        assert!(!scoring.is_banned(&addr(8333), 1000));
    }

    #[test]
    fn test_unknown_peer_misbehavior_ignored() {
        let mut scoring = PeerScoring::new();
        let action = scoring.record_misbehavior(999, Misbehavior::InvalidBlock, 1000);
        assert_eq!(action, ScoreAction::None);
    }

    #[test]
    fn test_custom_threshold() {
        let mut scoring = PeerScoring::with_config(50, 3600);
        scoring.register_peer(1, addr(8333));

        // DosAttack = 50, meets the custom threshold of 50.
        let action = scoring.record_misbehavior(1, Misbehavior::DosAttack, 1000);
        assert!(matches!(action, ScoreAction::Ban { .. }));
    }

    #[test]
    fn test_ban_reason_mentions_last_violation() {
        let mut scoring = PeerScoring::new();
        scoring.register_peer(1, addr(8333));

        match scoring.record_misbehavior(1, Misbehavior::InvalidBlockHeader, 1000) {
            ScoreAction::Ban { reason, .. } => {
                assert!(reason.contains("invalid block header"));
                assert!(reason.contains(">= 100"));
            }
            ScoreAction::None => panic!("Expected ban"),
        }
    }

    #[test]
    fn test_misbehavior_scores() {
        assert_eq!(Misbehavior::InvalidBlockHeader.score(), 100);
        assert_eq!(Misbehavior::InvalidBlock.score(), 100);
        assert_eq!(Misbehavior::InvalidTransaction.score(), 10);
        assert_eq!(Misbehavior::UnexpectedMessage.score(), 10);
        assert_eq!(Misbehavior::AddrFlood.score(), 20);
        assert_eq!(Misbehavior::InvalidNetworkMessage.score(), 20);
        assert_eq!(Misbehavior::DosAttack.score(), 50);
        assert_eq!(Misbehavior::UnconnectableBlock.score(), 10);
        assert_eq!(Misbehavior::DuplicateVersion.score(), 10);
        assert_eq!(Misbehavior::Custom(42).score(), 42);
    }

    #[test]
    fn test_misbehavior_display() {
        let m = Misbehavior::InvalidTransaction;
        let s = format!("{}", m);
        assert!(s.contains("invalid transaction"));
        assert!(s.contains("+10"));
    }

    #[test]
    fn test_list_bans() {
        let mut scoring = PeerScoring::new();
        scoring.register_peer(1, addr(8333));
        scoring.register_peer(2, addr(8334));

        scoring.record_misbehavior(1, Misbehavior::InvalidBlock, 1000);
        scoring.record_misbehavior(2, Misbehavior::InvalidBlockHeader, 1000);

        let bans = scoring.list_bans();
        assert_eq!(bans.len(), 2);
    }
}