forest/libp2p/peer_manager.rs

// Copyright 2019-2025 ChainSafe Systems
// SPDX-License-Identifier: Apache-2.0, MIT

use std::{
    sync::Arc,
    time::{Duration, Instant},
};

use ahash::{HashMap, HashSet};
use flume::{Receiver, Sender};
use parking_lot::RwLock;
use rand::seq::SliceRandom;
use tracing::{debug, trace, warn};

use crate::libp2p::*;

/// New peer multiplier slightly less than 1 to incentivize choosing new peers.
const NEW_PEER_MUL: f64 = 0.9;

/// Defines max number of peers to send each chain exchange request to.
pub(in crate::libp2p) const SHUFFLE_PEERS_PREFIX: usize = 100;

/// Local duration multiplier, affects duration delta change.
const LOCAL_INV_ALPHA: u32 = 5;
/// Global duration multiplier, affects duration delta change.
const GLOBAL_INV_ALPHA: u32 = 20;
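// Both averages below are exponentially weighted moving averages: with an
// inverse alpha of `N`, each new sample moves the running average by 1/N of
// the difference, i.e. `avg += (sample - avg) / N`. The per-peer average (1/5)
// therefore reacts faster to new samples than the global average (1/20).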

#[derive(Debug, Default)]
/// Contains the request stats for a peer.
struct PeerInfo {
    /// Number of successful requests.
    successes: u32,
    /// Number of failed requests.
    failures: u32,
    /// Average response time for the peer.
    average_time: Duration,
}

/// Peer tracking sets; these are handled together to avoid race conditions or
/// deadlocks when updating state.
#[derive(Default)]
struct PeerSets {
    /// Map of full peers available.
    full_peers: HashMap<PeerId, PeerInfo>,
    /// Set of peers to ignore for being incompatible or failing to accept
    /// connections.
    bad_peers: HashSet<PeerId>,
}

/// Thread-safe peer manager which handles peer management for the
/// `ChainExchange` protocol.
pub struct PeerManager {
    /// Full and bad peer sets.
    peers: RwLock<PeerSets>,
    /// Average response time from peers.
    avg_global_time: RwLock<Duration>,
    /// Peer operation sender.
    peer_ops_tx: Sender<PeerOperation>,
    /// Peer operation receiver.
    peer_ops_rx: Receiver<PeerOperation>,
    /// Peer ban list; the key is the peer ID, the value is the ban expiration time.
    peer_ban_list: tokio::sync::RwLock<HashMap<PeerId, Option<Instant>>>,
    /// A set of peers that won't be proactively banned or disconnected from.
    protected_peers: RwLock<HashSet<PeerId>>,
}

impl Default for PeerManager {
    fn default() -> Self {
        let (peer_ops_tx, peer_ops_rx) = flume::unbounded();
        PeerManager {
            peers: Default::default(),
            avg_global_time: Default::default(),
            peer_ops_tx,
            peer_ops_rx,
            peer_ban_list: Default::default(),
            protected_peers: Default::default(),
        }
    }
}

impl PeerManager {
    /// Returns true if the peer is neither marked as bad nor already in the
    /// full peer set.
    pub fn is_peer_new(&self, peer_id: &PeerId) -> bool {
        let peers = self.peers.read();
        !peers.bad_peers.contains(peer_id) && !peers.full_peers.contains_key(peer_id)
    }

    /// Mark peer as active even if we haven't communicated with it yet.
    #[cfg(test)]
    pub fn touch_peer(&self, peer_id: &PeerId) {
        let mut peers = self.peers.write();
        peers.full_peers.entry(*peer_id).or_default();
    }

    /// Sort peers based on a score function with the success rate and latency
    /// of requests.
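    ///
    /// For illustration (hypothetical numbers): with a global average of 0.5s,
    /// a peer with an average response time of 0.2s, 8 successes and 2
    /// failures costs `0.2 + (2/8) * 0.5 = 0.325`, while a brand-new peer with
    /// no recorded requests costs `0.5 * 0.9 = 0.45`. Lower cost sorts first,
    /// so proven fast peers are preferred while new peers still rank ahead of
    /// slow or failing ones.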
    pub(in crate::libp2p) fn sorted_peers(&self) -> Vec<PeerId> {
        let peer_lk = self.peers.read();
        let average_time = self.avg_global_time.read();
        let mut peers: Vec<_> = peer_lk
            .full_peers
            .iter()
            .map(|(&p, info)| {
                let cost = if info.successes + info.failures > 0 {
                    // Calculate cost based on fail rate and latency
                    // Note that when `successes` is zero, the result is `inf`
                    let fail_rate = f64::from(info.failures) / f64::from(info.successes);
                    info.average_time.as_secs_f64() + fail_rate * average_time.as_secs_f64()
                } else {
                    // There have been no failures or successes
                    average_time.as_secs_f64() * NEW_PEER_MUL
                };
                (p, cost)
            })
            .collect();

        // Unstable sort because hashmap iter order doesn't need to be preserved.
        peers.sort_unstable_by(|(_, v1), (_, v2)| v1.total_cmp(v2));

        peers.into_iter().map(|(peer, _)| peer).collect()
    }

    /// Returns a shuffled list of ordered peers from the peer manager. Ordering
    /// is based on failure rate and latency of the peer.
    pub fn top_peers_shuffled(&self) -> Vec<PeerId> {
        let mut peers: Vec<_> = self
            .sorted_peers()
            .into_iter()
            .take(SHUFFLE_PEERS_PREFIX)
            .collect();

        // Shuffle top peers, to avoid sending all requests to the same predictable peer.
        peers.shuffle(&mut crate::utils::rand::forest_rng());
        peers
    }

    /// Logs a global request success. This just updates the average for the
    /// peer manager.
    pub fn log_global_success(&self, dur: Duration) {
        debug!("logging global success");
        let mut avg_global = self.avg_global_time.write();
        if *avg_global == Duration::default() {
            *avg_global = dur;
        } else if dur < *avg_global {
            let delta = (*avg_global - dur) / GLOBAL_INV_ALPHA;
            *avg_global -= delta
        } else {
            let delta = (dur - *avg_global) / GLOBAL_INV_ALPHA;
            *avg_global += delta
        }
    }

    /// Logs a success for the given peer, and updates the average request
    /// duration.
    pub fn log_success(&self, peer: &PeerId, dur: Duration) {
        trace!("logging success for {peer}");
        let mut peers = self.peers.write();
        // Remove the peer from the bad peer set, if present, and update the metric
        if peers.bad_peers.remove(peer) {
            metrics::BAD_PEERS.set(peers.bad_peers.len() as _);
        };
        let peer_stats = peers.full_peers.entry(*peer).or_default();
        peer_stats.successes += 1;
        log_time(peer_stats, dur);
    }

    /// Logs a failure for the given peer, and updates the average request
    /// duration.
    pub fn log_failure(&self, peer: &PeerId, dur: Duration) {
        trace!("logging failure for {peer}");
        let mut peers = self.peers.write();
        if !peers.bad_peers.contains(peer) {
            metrics::PEER_FAILURE_TOTAL.inc();
            let peer_stats = peers.full_peers.entry(*peer).or_default();
            peer_stats.failures += 1;
            log_time(peer_stats, dur);
        }
    }

    /// Marks a peer as bad: removes it from the full peer set and adds it to
    /// the bad peer set.
    pub fn mark_peer_bad(&self, peer_id: PeerId, reason: impl Into<String>) {
        let mut peers = self.peers.write();
        remove_peer(&mut peers, &peer_id);

        // Add peer to bad peer set
        let reason = reason.into();
        tracing::debug!(%peer_id, %reason, "marked peer bad");
        if peers.bad_peers.insert(peer_id) {
            metrics::BAD_PEERS.set(peers.bad_peers.len() as _);
        }
    }

    pub fn unmark_peer_bad(&self, peer_id: &PeerId) {
        let mut peers = self.peers.write();
        if peers.bad_peers.remove(peer_id) {
            metrics::BAD_PEERS.set(peers.bad_peers.len() as _);
        }
    }

    /// Removes a peer from the managed set without marking it as bad.
    pub fn remove_peer(&self, peer_id: &PeerId) {
        let mut peers = self.peers.write();
        remove_peer(&mut peers, peer_id);
    }

    /// Gets the peer operation receiver.
    pub fn peer_ops_rx(&self) -> &Receiver<PeerOperation> {
        &self.peer_ops_rx
    }

    /// Bans a peer, optionally for a limited duration. Protected peers are
    /// never banned.
    pub async fn ban_peer(
        &self,
        peer: PeerId,
        reason: impl Into<String>,
        duration: Option<Duration>,
        get_user_agent: impl Fn(&PeerId) -> Option<String>,
    ) {
        if self.is_peer_protected(&peer) {
            return;
        }
        let mut locked = self.peer_ban_list.write().await;
        locked.insert(peer, duration.and_then(|d| Instant::now().checked_add(d)));
        let user_agent = get_user_agent(&peer);
        if let Err(e) = self
            .peer_ops_tx
            .send_async(PeerOperation::Ban {
                peer,
                user_agent,
                reason: reason.into(),
            })
            .await
        {
            warn!("ban_peer err: {e}");
        }
    }

    /// Bans a peer for the default duration (`1h`).
    pub async fn ban_peer_with_default_duration(
        &self,
        peer: PeerId,
        reason: impl Into<String>,
        get_user_agent: impl Fn(&PeerId) -> Option<String>,
    ) {
        const BAN_PEER_DURATION: Duration = Duration::from_secs(60 * 60); // 1h
        self.ban_peer(peer, reason, Some(BAN_PEER_DURATION), get_user_agent)
            .await
    }

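    /// Background task that periodically (every 60 seconds) lifts expired peer
    /// bans, removing them from the ban list and sending a
    /// [`PeerOperation::Unban`] for each affected peer.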
    pub async fn peer_operation_event_loop_task(self: Arc<Self>) -> anyhow::Result<()> {
        let mut unban_list = vec![];
        loop {
            unban_list.clear();

            let now = Instant::now();
            for (peer, expiration) in self.peer_ban_list.read().await.iter() {
                if let Some(expiration) = expiration
                    && &now > expiration
                {
                    unban_list.push(*peer);
                }
            }
            if !unban_list.is_empty() {
                {
                    let mut locked = self.peer_ban_list.write().await;
                    for peer in unban_list.iter() {
                        locked.remove(peer);
                    }
                }
                for &peer in unban_list.iter() {
                    if let Err(e) = self
                        .peer_ops_tx
                        .send_async(PeerOperation::Unban(peer))
                        .await
                    {
                        warn!("unban_peer err: {e}");
                    }
                }
            }
            tokio::time::sleep(Duration::from_secs(60)).await;
        }
    }

    pub fn peer_count(&self) -> usize {
        self.peers.read().full_peers.len()
    }

    pub fn protect_peer(&self, peer_id: PeerId) {
        self.protected_peers.write().insert(peer_id);
    }

    pub fn unprotect_peer(&self, peer_id: &PeerId) {
        self.protected_peers.write().remove(peer_id);
    }

    pub fn list_protected_peers(&self) -> HashSet<PeerId> {
        self.protected_peers.read().clone()
    }

    pub fn is_peer_protected(&self, peer_id: &PeerId) -> bool {
        self.protected_peers.read().contains(peer_id)
    }
}

fn remove_peer(peers: &mut PeerSets, peer_id: &PeerId) {
    if peers.full_peers.remove(peer_id).is_some() {
        metrics::FULL_PEERS.set(peers.full_peers.len() as _);
    }
    trace!(
        "removing peer {peer_id}, remaining chain exchange peers: {}",
        peers.full_peers.len()
    );
}

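/// Updates the peer's average response time as an exponentially weighted
/// moving average: the average moves toward `dur` by `1/LOCAL_INV_ALPHA` of
/// the difference. For example (hypothetical numbers), with a current average
/// of 200ms and a new sample of 450ms, the delta is `(450 - 200) / 5 = 50ms`,
/// giving a new average of 250ms.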
fn log_time(info: &mut PeerInfo, dur: Duration) {
    if info.average_time == Duration::default() {
        info.average_time = dur;
    } else if dur < info.average_time {
        let delta = (info.average_time - dur) / LOCAL_INV_ALPHA;
        info.average_time -= delta
    } else {
        let delta = (dur - info.average_time) / LOCAL_INV_ALPHA;
        info.average_time += delta
    }
}

pub enum PeerOperation {
    Ban {
        peer: PeerId,
        user_agent: Option<String>,
        reason: String,
    },
    Unban(PeerId),
}
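
// Illustrative usage sketch (not part of the upstream module): shows how
// request logging feeds the scoring used by `sorted_peers` and how marking a
// peer bad removes it from the full peer set. Assumes `PeerId::random()` is
// available through the `crate::libp2p` re-exports.
#[cfg(test)]
mod peer_manager_example {
    use super::*;

    #[test]
    fn faster_peer_sorts_first() {
        let manager = PeerManager::default();
        let fast = PeerId::random();
        let slow = PeerId::random();

        // Seed the global average response time.
        manager.log_global_success(Duration::from_millis(500));

        // Record one fast and one slow successful request.
        manager.log_success(&fast, Duration::from_millis(100));
        manager.log_success(&slow, Duration::from_millis(900));

        // Lower cost (faster, fewer failures) sorts first.
        assert_eq!(manager.sorted_peers().first(), Some(&fast));

        // Marking a peer bad removes it from the full peer set.
        manager.mark_peer_bad(slow, "example reason");
        assert!(!manager.sorted_peers().contains(&slow));
    }
}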