use crate::types::BadPeer;
use sc_network::ReputationChange as Rep;
use sc_network_types::PeerId;
use schnellru::{ByLength, LruMap};
/// Log target for every message emitted by this module.
const LOG_TARGET: &str = "sync::disconnected_peers";

/// Number of disconnects-during-request at which a peer is handed back as a `BadPeer`.
const MAX_NUM_DISCONNECTS: u64 = 3;

/// Base backoff in seconds; the effective backoff is this value times the peer's
/// recorded disconnect count.
const DISCONNECTED_PEER_BACKOFF_SECONDS: u64 = 60;

/// Length limit of the LRU map holding per-peer disconnect state.
const MAX_DISCONNECTED_PEERS_STATE: u32 = 512;
/// Reputation change carried by the `BadPeer` returned from
/// `DisconnectedPeers::on_disconnect_during_request` once a peer reaches
/// `MAX_NUM_DISCONNECTS` disconnects (the code calls this condition `should_ban`).
/// Built with `Rep::new_fatal`; the exact consequence is defined by `sc_network`.
pub const REPUTATION_REPORT: Rep = Rep::new_fatal("Peer disconnected with inflight after backoffs");
/// Disconnect bookkeeping for a single peer.
///
/// Field order matters: the derived `Debug` output is written to the debug log.
#[derive(Debug)]
struct DisconnectedState {
    /// Disconnects recorded so far; starts at 1 and saturates at `u64::MAX`.
    num_disconnects: u64,
    /// Moment the most recent disconnect was recorded.
    last_disconnect: std::time::Instant,
}

impl DisconnectedState {
    /// State for a peer's first recorded disconnect, timestamped now.
    pub fn new() -> Self {
        let last_disconnect = std::time::Instant::now();
        Self { num_disconnects: 1, last_disconnect }
    }

    /// Records one more disconnect (saturating) and moves the timestamp to now.
    pub fn increment(&mut self) {
        self.num_disconnects = u64::saturating_add(self.num_disconnects, 1);
        self.last_disconnect = std::time::Instant::now();
    }

    /// Number of disconnects recorded so far.
    pub fn num_disconnects(&self) -> u64 {
        let Self { num_disconnects, .. } = self;
        *num_disconnects
    }

    /// Timestamp of the most recent recorded disconnect.
    pub fn last_disconnect(&self) -> std::time::Instant {
        let Self { last_disconnect, .. } = self;
        *last_disconnect
    }
}
/// Tracks peers that disconnected while one of our requests to them was in flight.
///
/// Every such disconnect backs the peer off for `backoff_seconds * num_disconnects` seconds
/// (see [`DisconnectedPeers::is_peer_available`]). When a peer accumulates
/// `MAX_NUM_DISCONNECTS` disconnects it is returned as a [`BadPeer`] and its state is dropped.
/// State lives in an LRU map limited to `MAX_DISCONNECTED_PEERS_STATE` entries.
pub struct DisconnectedPeers {
    /// Per-peer disconnect bookkeeping, bounded by an LRU length limit.
    disconnected_peers: LruMap<PeerId, DisconnectedState>,
    /// Base backoff in seconds; multiplied by the peer's disconnect count.
    backoff_seconds: u64,
}

impl Default for DisconnectedPeers {
    /// Equivalent to [`DisconnectedPeers::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl DisconnectedPeers {
    /// Creates an empty tracker limited to `MAX_DISCONNECTED_PEERS_STATE` entries, with a base
    /// backoff of `DISCONNECTED_PEER_BACKOFF_SECONDS`.
    pub fn new() -> Self {
        Self {
            disconnected_peers: LruMap::new(ByLength::new(MAX_DISCONNECTED_PEERS_STATE)),
            backoff_seconds: DISCONNECTED_PEER_BACKOFF_SECONDS,
        }
    }

    /// Records that `peer` disconnected while a request to it was in flight.
    ///
    /// For an untracked peer this starts tracking it (count = 1) and returns `None`.
    /// For a tracked peer the count is incremented and the timestamp refreshed; once the count
    /// reaches `MAX_NUM_DISCONNECTS` the peer's state is removed and
    /// `Some(BadPeer(peer, REPUTATION_REPORT))` is returned for the caller to act on.
    pub fn on_disconnect_during_request(&mut self, peer: PeerId) -> Option<BadPeer> {
        if let Some(state) = self.disconnected_peers.get(&peer) {
            state.increment();

            let should_ban = state.num_disconnects() >= MAX_NUM_DISCONNECTS;
            log::debug!(
                target: LOG_TARGET,
                "Disconnected known peer {peer} state: {state:?}, should ban: {should_ban}",
            );

            if !should_ban {
                return None;
            }

            // Drop the tracked state; what happens to the reported peer is up to the caller.
            // If it disconnects again later, tracking restarts from a count of one.
            self.disconnected_peers.remove(&peer);
            Some(BadPeer(peer, REPUTATION_REPORT))
        } else {
            log::debug!(target: LOG_TARGET, "Added peer {peer} for the first time");
            self.disconnected_peers.insert(peer, DisconnectedState::new());
            None
        }
    }

    /// Returns whether `peer_id` may be queried right now.
    ///
    /// Untracked peers are always available. A tracked peer becomes available once at least
    /// `backoff_seconds * num_disconnects` whole seconds have elapsed since its last recorded
    /// disconnect; at that point its state is removed, so a later disconnect counts from one.
    pub fn is_peer_available(&mut self, peer_id: &PeerId) -> bool {
        let Some(state) = self.disconnected_peers.get(peer_id) else {
            return true;
        };

        // Saturate instead of overflowing. With a large `backoff_seconds`, a plain multiply
        // panics in debug builds and wraps to a tiny backoff in release builds.
        let backoff_secs = self.backoff_seconds.saturating_mul(state.num_disconnects());
        if state.last_disconnect().elapsed().as_secs() >= backoff_secs {
            log::debug!(target: LOG_TARGET, "Peer {peer_id} is available for queries");
            self.disconnected_peers.remove(peer_id);
            true
        } else {
            log::debug!(target: LOG_TARGET, "Peer {peer_id} is backedoff");
            false
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[test]
    fn test_disconnected_peer_state() {
        let mut state = DisconnectedPeers::new();
        let peer = PeerId::random();

        // Untracked peers are always available.
        assert!(state.is_peer_available(&peer));

        // Disconnects below the threshold back the peer off without reporting it.
        for _ in 0..MAX_NUM_DISCONNECTS - 1 {
            assert!(state.on_disconnect_during_request(peer).is_none());
            assert!(!state.is_peer_available(&peer));
        }

        // Reaching the threshold reports the peer and drops its state.
        assert!(state.on_disconnect_during_request(peer).is_some());
        assert!(state.disconnected_peers.get(&peer).is_none());
    }

    #[test]
    fn ensure_backoff_time() {
        const TEST_BACKOFF_SECONDS: u64 = 2;
        let mut state = DisconnectedPeers {
            disconnected_peers: LruMap::new(ByLength::new(1)),
            backoff_seconds: TEST_BACKOFF_SECONDS,
        };
        let peer = PeerId::random();

        assert!(state.on_disconnect_during_request(peer).is_none());
        assert!(!state.is_peer_available(&peer));

        std::thread::sleep(Duration::from_secs(TEST_BACKOFF_SECONDS + 1));
        assert!(state.is_peer_available(&peer));
    }

    #[test]
    fn zero_backoff_releases_peer_immediately() {
        let mut state = DisconnectedPeers {
            disconnected_peers: LruMap::new(ByLength::new(MAX_DISCONNECTED_PEERS_STATE)),
            backoff_seconds: 0,
        };
        let peer = PeerId::random();

        assert!(state.on_disconnect_during_request(peer).is_none());
        assert!(state.is_peer_available(&peer));
        // The availability check cleared the state, so the next disconnect starts over.
        assert!(state.disconnected_peers.get(&peer).is_none());
        assert!(state.on_disconnect_during_request(peer).is_none());
    }

    #[test]
    fn least_recently_used_state_is_evicted() {
        let mut state = DisconnectedPeers {
            disconnected_peers: LruMap::new(ByLength::new(1)),
            backoff_seconds: DISCONNECTED_PEER_BACKOFF_SECONDS,
        };
        let first = PeerId::random();
        let second = PeerId::random();

        assert!(state.on_disconnect_during_request(first).is_none());
        assert!(state.on_disconnect_during_request(second).is_none());

        // The map holds a single entry, so tracking `second` evicted `first`.
        assert!(state.is_peer_available(&first));
        assert!(!state.is_peer_available(&second));
    }
}