use std::sync::Arc;
use ahash::AHashMap;
use libp2p::PeerId;
use rand::{thread_rng, Rng};
use tokio::sync::RwLock;
/// Tracks how many times each peer has been recorded as a responder, so that
/// peers can later be picked with probability proportional to that count.
///
/// The state lives behind an `Arc`, so cloning a tracker is cheap and every
/// clone reads and updates the same underlying counts.
#[derive(Clone, Debug, Default)]
pub struct PeerResponseTracker {
    /// Per-peer response counts, guarded by an async (tokio) read/write lock.
    first_responder: Arc<RwLock<AHashMap<PeerId, usize>>>,
}
impl PeerResponseTracker {
    /// Count reported for a peer that has no entry in the map.
    ///
    /// A peer recorded once via [`Self::received_block_from`] also has count 1
    /// (`0` default + 1), so unseen and once-seen peers weigh the same.
    const UNSEEN_PEER_COUNT: usize = 1;

    /// Increments the response count recorded for `from`, inserting it at 0
    /// first if it has not been seen before.
    pub async fn received_block_from(&self, from: &PeerId) {
        let mut counts = self.first_responder.write().await;
        *counts.entry(*from).or_default() += 1;
    }

    /// Picks one of `peers` at random, weighted by each peer's recorded count
    /// (see [`Self::get_peer_count`]).
    ///
    /// Returns `None` only when `peers` is empty. Duplicate entries in `peers`
    /// are weighted independently, so a repeated peer is proportionally more
    /// likely to be chosen.
    pub async fn choose(&self, peers: &[PeerId]) -> Option<PeerId> {
        if peers.is_empty() {
            return None;
        }

        // Take every weight under a single read lock. The original code
        // re-locked per peer, twice. A concurrent writer could then change
        // counts between the totaling pass and the selection pass, making
        // the probabilities inconsistent. The guard is dropped at the end of
        // this block, before the RNG is used.
        let weights: Vec<f64> = {
            let counts = self.first_responder.read().await;
            peers
                .iter()
                .map(|peer| {
                    counts
                        .get(peer)
                        .copied()
                        .unwrap_or(Self::UNSEEN_PEER_COUNT) as f64
                })
                .collect()
        };
        let total: f64 = weights.iter().sum();

        // Scale the uniform draw into [0, total) instead of normalising each
        // weight, which avoids a division per peer.
        let target = thread_rng().gen::<f64>() * total;

        let mut cumulative = 0.0;
        for (peer, &weight) in peers.iter().zip(&weights) {
            cumulative += weight;
            if cumulative > target {
                return Some(*peer);
            }
        }

        // Only reachable through floating-point rounding at the top end.
        peers.last().copied()
    }

    /// Returns the recorded count for `peer`, or 1 if the peer has never been
    /// recorded.
    pub async fn get_peer_count(&self, peer: &PeerId) -> usize {
        self.first_responder
            .read()
            .await
            .get(peer)
            .copied()
            .unwrap_or(Self::UNSEEN_PEER_COUNT)
    }
}