forest/libp2p/
peer_manager.rs1use std::{
5 sync::Arc,
6 time::{Duration, Instant},
7};
8
9use ahash::{HashMap, HashSet};
10use flume::{Receiver, Sender};
11use parking_lot::RwLock;
12use rand::seq::SliceRandom;
13use tracing::{debug, trace, warn};
14
15use crate::libp2p::*;
16
/// Factor applied to the global average response time to obtain the cost of a peer
/// that has no recorded successes or failures yet. Being below 1 ranks such a peer
/// ahead of a failure-free peer whose average time equals the global average.
const NEW_PEER_MUL: f64 = 0.9;
19
20pub(in crate::libp2p) const SHUFFLE_PEERS_PREFIX: usize = 100;
22
/// Divisor used by `log_time`: each new sample moves a peer's average duration by
/// `1 / LOCAL_INV_ALPHA` of the distance to that sample.
const LOCAL_INV_ALPHA: u32 = 5;

/// Divisor used by `log_global_success`: each new sample moves the global average
/// duration by `1 / GLOBAL_INV_ALPHA` of the distance to that sample.
const GLOBAL_INV_ALPHA: u32 = 20;
27
/// Request statistics recorded for a single peer.
#[derive(Debug, Default)]
struct PeerInfo {
    /// Number of requests logged as successful for this peer.
    successes: u32,
    /// Number of requests logged as failed for this peer.
    failures: u32,
    /// Smoothed request duration; a zero value means no sample has been stored yet.
    average_time: Duration,
}
38
39#[derive(Default)]
42struct PeerSets {
43 full_peers: HashMap<PeerId, PeerInfo>,
45 bad_peers: HashSet<PeerId>,
48}
49
50pub struct PeerManager {
53 peers: RwLock<PeerSets>,
55 avg_global_time: RwLock<Duration>,
57 peer_ops_tx: Sender<PeerOperation>,
59 peer_ops_rx: Receiver<PeerOperation>,
61 peer_ban_list: tokio::sync::RwLock<HashMap<PeerId, Option<Instant>>>,
63 protected_peers: RwLock<HashSet<PeerId>>,
65}
66
67impl Default for PeerManager {
68 fn default() -> Self {
69 let (peer_ops_tx, peer_ops_rx) = flume::unbounded();
70 PeerManager {
71 peers: Default::default(),
72 avg_global_time: Default::default(),
73 peer_ops_tx,
74 peer_ops_rx,
75 peer_ban_list: Default::default(),
76 protected_peers: Default::default(),
77 }
78 }
79}
80
81impl PeerManager {
82 pub fn is_peer_new(&self, peer_id: &PeerId) -> bool {
84 let peers = self.peers.read();
85 !peers.bad_peers.contains(peer_id) && !peers.full_peers.contains_key(peer_id)
86 }
87
/// Test helper: makes sure `peer_id` is tracked, inserting default statistics if it
/// is not present yet. Existing statistics are left untouched.
#[cfg(test)]
pub fn touch_peer(&self, peer_id: &PeerId) {
    self.peers.write().full_peers.entry(*peer_id).or_default();
}
94
95 pub(in crate::libp2p) fn sorted_peers(&self) -> Vec<PeerId> {
98 let peer_lk = self.peers.read();
99 let average_time = self.avg_global_time.read();
100 let mut peers: Vec<_> = peer_lk
101 .full_peers
102 .iter()
103 .map(|(&p, info)| {
104 let cost = if info.successes + info.failures > 0 {
105 let fail_rate = f64::from(info.failures) / f64::from(info.successes);
108 info.average_time.as_secs_f64() + fail_rate * average_time.as_secs_f64()
109 } else {
110 average_time.as_secs_f64() * NEW_PEER_MUL
112 };
113 (p, cost)
114 })
115 .collect();
116
117 peers.sort_unstable_by(|(_, v1), (_, v2)| v1.total_cmp(v2));
119
120 peers.into_iter().map(|(peer, _)| peer).collect()
121 }
122
123 pub fn top_peers_shuffled(&self) -> Vec<PeerId> {
126 let mut peers: Vec<_> = self
127 .sorted_peers()
128 .into_iter()
129 .take(SHUFFLE_PEERS_PREFIX)
130 .collect();
131
132 peers.shuffle(&mut crate::utils::rand::forest_rng());
134 peers
135 }
136
137 pub fn log_global_success(&self, dur: Duration) {
140 debug!("logging global success");
141 let mut avg_global = self.avg_global_time.write();
142 if *avg_global == Duration::default() {
143 *avg_global = dur;
144 } else if dur < *avg_global {
145 let delta = (*avg_global - dur) / GLOBAL_INV_ALPHA;
146 *avg_global -= delta
147 } else {
148 let delta = (dur - *avg_global) / GLOBAL_INV_ALPHA;
149 *avg_global += delta
150 }
151 }
152
153 pub fn log_success(&self, peer: &PeerId, dur: Duration) {
156 trace!("logging success for {peer}");
157 let mut peers = self.peers.write();
158 if peers.bad_peers.remove(peer) {
160 metrics::BAD_PEERS.set(peers.bad_peers.len() as _);
161 };
162 let peer_stats = peers.full_peers.entry(*peer).or_default();
163 peer_stats.successes += 1;
164 log_time(peer_stats, dur);
165 }
166
167 pub fn log_failure(&self, peer: &PeerId, dur: Duration) {
170 trace!("logging failure for {peer}");
171 let mut peers = self.peers.write();
172 if !peers.bad_peers.contains(peer) {
173 metrics::PEER_FAILURE_TOTAL.inc();
174 let peer_stats = peers.full_peers.entry(*peer).or_default();
175 peer_stats.failures += 1;
176 log_time(peer_stats, dur);
177 }
178 }
179
180 pub fn mark_peer_bad(&self, peer_id: PeerId, reason: impl Into<String>) {
183 let mut peers = self.peers.write();
184 remove_peer(&mut peers, &peer_id);
185
186 let reason = reason.into();
188 tracing::debug!(%peer_id, %reason, "marked peer bad");
189 if peers.bad_peers.insert(peer_id) {
190 metrics::BAD_PEERS.set(peers.bad_peers.len() as _);
191 }
192 }
193
194 pub fn unmark_peer_bad(&self, peer_id: &PeerId) {
195 let mut peers = self.peers.write();
196 if peers.bad_peers.remove(peer_id) {
197 metrics::BAD_PEERS.set(peers.bad_peers.len() as _);
198 }
199 }
200
201 pub fn remove_peer(&self, peer_id: &PeerId) {
203 let mut peers = self.peers.write();
204 remove_peer(&mut peers, peer_id);
205 }
206
207 pub fn peer_ops_rx(&self) -> &Receiver<PeerOperation> {
209 &self.peer_ops_rx
210 }
211
212 pub async fn ban_peer(
214 &self,
215 peer: PeerId,
216 reason: impl Into<String>,
217 duration: Option<Duration>,
218 get_user_agent: impl Fn(&PeerId) -> Option<String>,
219 ) {
220 if self.is_peer_protected(&peer) {
221 return;
222 }
223 let mut locked = self.peer_ban_list.write().await;
224 locked.insert(peer, duration.and_then(|d| Instant::now().checked_add(d)));
225 let user_agent = get_user_agent(&peer);
226 if let Err(e) = self
227 .peer_ops_tx
228 .send_async(PeerOperation::Ban {
229 peer,
230 user_agent,
231 reason: reason.into(),
232 })
233 .await
234 {
235 warn!("ban_peer err: {e}");
236 }
237 }
238
239 pub async fn ban_peer_with_default_duration(
241 &self,
242 peer: PeerId,
243 reason: impl Into<String>,
244 get_user_agent: impl Fn(&PeerId) -> Option<String>,
245 ) {
246 const BAN_PEER_DURATION: Duration = Duration::from_secs(60 * 60); self.ban_peer(peer, reason, Some(BAN_PEER_DURATION), get_user_agent)
248 .await
249 }
250
251 pub async fn peer_operation_event_loop_task(self: Arc<Self>) -> anyhow::Result<()> {
252 let mut unban_list = vec![];
253 loop {
254 unban_list.clear();
255
256 let now = Instant::now();
257 for (peer, expiration) in self.peer_ban_list.read().await.iter() {
258 if let Some(expiration) = expiration
259 && &now > expiration
260 {
261 unban_list.push(*peer);
262 }
263 }
264 if !unban_list.is_empty() {
265 {
266 let mut locked = self.peer_ban_list.write().await;
267 for peer in unban_list.iter() {
268 locked.remove(peer);
269 }
270 }
271 for &peer in unban_list.iter() {
272 if let Err(e) = self
273 .peer_ops_tx
274 .send_async(PeerOperation::Unban(peer))
275 .await
276 {
277 warn!("unban_peer err: {e}");
278 }
279 }
280 }
281 tokio::time::sleep(Duration::from_secs(60)).await;
282 }
283 }
284
285 pub fn peer_count(&self) -> usize {
286 self.peers.read().full_peers.len()
287 }
288
289 pub fn protect_peer(&self, peer_id: PeerId) {
290 self.protected_peers.write().insert(peer_id);
291 }
292
293 pub fn unprotect_peer(&self, peer_id: &PeerId) {
294 self.protected_peers.write().remove(peer_id);
295 }
296
297 pub fn list_protected_peers(&self) -> HashSet<PeerId> {
298 self.protected_peers.read().clone()
299 }
300
301 pub fn is_peer_protected(&self, peer_id: &PeerId) -> bool {
302 self.protected_peers.read().contains(peer_id)
303 }
304}
305
306fn remove_peer(peers: &mut PeerSets, peer_id: &PeerId) {
307 if peers.full_peers.remove(peer_id).is_some() {
308 metrics::FULL_PEERS.set(peers.full_peers.len() as _);
309 }
310 trace!(
311 "removing peer {peer_id}, remaining chain exchange peers: {}",
312 peers.full_peers.len()
313 );
314}
315
316fn log_time(info: &mut PeerInfo, dur: Duration) {
317 if info.average_time == Duration::default() {
318 info.average_time = dur;
319 } else if dur < info.average_time {
320 let delta = (info.average_time - dur) / LOCAL_INV_ALPHA;
321 info.average_time -= delta
322 } else {
323 let delta = (dur - info.average_time) / LOCAL_INV_ALPHA;
324 info.average_time += delta
325 }
326}
327
328pub enum PeerOperation {
329 Ban {
330 peer: PeerId,
331 user_agent: Option<String>,
332 reason: String,
333 },
334 Unban(PeerId),
335}