use std::collections::HashSet;
use libp2p::PeerId;
use tokio::time::{Duration, Instant};
/// What the caller should do next, as decided by [`ParentPullBudget::next`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) enum NextPeer {
    /// Contact this peer next; it is not in the set of peers tried so far.
    Peer(PeerId),
    /// Every peer in the supplied mesh has been tried and no refetch has been
    /// recorded yet: refetch the mesh and ask again.
    RefetchMesh,
    /// The cap on additional peers has been reached.
    MaxPeersReached,
    /// The wall-clock budget has elapsed.
    BudgetExhausted,
    /// A refetch was already recorded and the mesh still has no untried peer.
    NoMorePeers,
}
/// Limits how many peers beyond an initial one may be tried, and for how long.
///
/// Both a count cap (`max_additional`) and a wall-clock cap (`budget`) apply.
#[derive(Debug)]
pub(crate) struct ParentPullBudget {
    /// Peers already attempted; seeded with the initial peer.
    tried: HashSet<PeerId>,
    /// Additional attempts recorded so far (the initial peer is not counted).
    attempts: usize,
    /// Upper bound on `attempts`.
    max_additional: usize,
    /// Creation time; elapsed time is measured from here.
    started: Instant,
    /// Maximum wall-clock time allowed since `started`.
    budget: Duration,
    /// Whether a mesh refetch has already been recorded.
    refetched: bool,
}
impl ParentPullBudget {
    /// Creates a budget after a request has already gone to `initial_peer`.
    ///
    /// `initial_peer` is recorded as tried, so [`next`](Self::next) never suggests it.
    /// `max_additional` caps how many further peers may be attempted, and `budget` caps
    /// the wall-clock time measured from this call.
    pub(crate) fn new(initial_peer: PeerId, max_additional: usize, budget: Duration) -> Self {
        let mut tried = HashSet::new();
        let _ = tried.insert(initial_peer);
        Self {
            tried,
            attempts: 0,
            max_additional,
            started: Instant::now(),
            budget,
            refetched: false,
        }
    }

    /// Number of additional attempts recorded via [`record_attempt`](Self::record_attempt),
    /// excluding the initial peer.
    pub(crate) fn attempts(&self) -> usize {
        self.attempts
    }

    /// Number of attempts including the initial peer (`attempts() + 1`).
    pub(crate) fn total_attempts(&self) -> usize {
        self.attempts + 1
    }

    /// Decides what the caller should do next.
    ///
    /// Checks, in order:
    /// 1. the additional-peer cap ([`NextPeer::MaxPeersReached`]);
    /// 2. the time budget ([`NextPeer::BudgetExhausted`]);
    /// 3. the first peer in `mesh_peers` not yet tried ([`NextPeer::Peer`]).
    ///
    /// If every mesh peer has been tried, it returns [`NextPeer::RefetchMesh`] once. After a
    /// refetch has been recorded, it returns [`NextPeer::NoMorePeers`] instead.
    ///
    /// Nothing is recorded here. Call [`record_attempt`](Self::record_attempt) for a peer
    /// that is actually contacted, and [`record_refetch`](Self::record_refetch) after
    /// refetching.
    pub(crate) fn next(&mut self, mesh_peers: &[PeerId]) -> NextPeer {
        if self.attempts >= self.max_additional {
            return NextPeer::MaxPeersReached;
        }
        if self.started.elapsed() >= self.budget {
            return NextPeer::BudgetExhausted;
        }
        // `iter()` yields `&PeerId`, so `find`'s closure receives `&&PeerId`. Destructure one
        // level so `contains` gets `&PeerId`: looking up with `&&PeerId` would need
        // `PeerId: Borrow<&PeerId>`, which does not hold.
        let untried = mesh_peers.iter().find(|&p| !self.tried.contains(p)).copied();
        match untried {
            Some(peer) => NextPeer::Peer(peer),
            None if !self.refetched => NextPeer::RefetchMesh,
            None => NextPeer::NoMorePeers,
        }
    }

    /// Marks `peer` as tried and counts one additional attempt.
    ///
    /// The attempt is counted even if `peer` was already in the tried set.
    pub(crate) fn record_attempt(&mut self, peer: PeerId) {
        let _ = self.tried.insert(peer);
        self.attempts += 1;
    }

    /// Records that the mesh has been refetched.
    ///
    /// From then on, a mesh with no untried peers yields [`NextPeer::NoMorePeers`].
    pub(crate) fn record_refetch(&mut self) {
        self.refetched = true;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a peer id from a fixed ed25519 seed; the same `byte` always yields the same id.
    fn peer(byte: u8) -> PeerId {
        use libp2p::identity::Keypair;
        let kp = Keypair::ed25519_from_bytes([byte; 32]).expect("ed25519 keypair");
        kp.public().to_peer_id()
    }

    #[test]
    fn initial_peer_is_never_suggested() {
        let initial = peer(1);
        let mut budget = ParentPullBudget::new(initial, 3, Duration::from_secs(10));
        let mesh = [initial];
        assert_eq!(budget.next(&mesh), NextPeer::RefetchMesh);
    }

    #[test]
    fn picks_untried_peer_from_mesh() {
        let initial = peer(1);
        let p2 = peer(2);
        let mut budget = ParentPullBudget::new(initial, 3, Duration::from_secs(10));
        assert_eq!(budget.next(&[initial, p2]), NextPeer::Peer(p2));
    }

    #[test]
    fn does_not_retry_same_peer() {
        let initial = peer(1);
        let p2 = peer(2);
        let mut budget = ParentPullBudget::new(initial, 3, Duration::from_secs(10));
        let first = budget.next(&[initial, p2]);
        assert_eq!(first, NextPeer::Peer(p2));
        budget.record_attempt(p2);
        assert_eq!(budget.next(&[initial, p2]), NextPeer::RefetchMesh);
    }

    #[test]
    fn stops_at_max_additional_peers() {
        let initial = peer(1);
        let peers = [peer(2), peer(3), peer(4), peer(5)];
        let mut budget = ParentPullBudget::new(initial, 3, Duration::from_secs(10));
        let mesh: Vec<PeerId> = std::iter::once(initial)
            .chain(peers.iter().copied())
            .collect();
        for expected in &peers[..3] {
            match budget.next(&mesh) {
                NextPeer::Peer(p) => {
                    assert_eq!(&p, expected, "scheduler returned wrong peer");
                    budget.record_attempt(p);
                }
                other => panic!("expected Peer, got {:?}", other),
            }
        }
        assert_eq!(budget.next(&mesh), NextPeer::MaxPeersReached);
        assert_eq!(budget.attempts(), 3);
        assert_eq!(budget.total_attempts(), 4);
    }

    #[test]
    fn budget_exhausted_before_attempts_reached() {
        let initial = peer(1);
        let p2 = peer(2);
        // A zero budget is exhausted at once (`elapsed() >= Duration::ZERO` always holds),
        // so no sleep is needed.
        let mut budget = ParentPullBudget::new(initial, 3, Duration::ZERO);
        assert_eq!(budget.next(&[initial, p2]), NextPeer::BudgetExhausted);
    }

    #[test]
    fn peer_cap_is_checked_before_time_budget() {
        let initial = peer(1);
        let mut budget = ParentPullBudget::new(initial, 0, Duration::ZERO);
        assert_eq!(budget.next(&[initial, peer(2)]), NextPeer::MaxPeersReached);
        assert_eq!(budget.total_attempts(), 1);
    }

    #[test]
    fn recorded_attempt_counts_toward_cap_even_outside_mesh() {
        let initial = peer(1);
        let mut budget = ParentPullBudget::new(initial, 1, Duration::from_secs(10));
        budget.record_attempt(peer(9));
        assert_eq!(budget.next(&[initial, peer(2)]), NextPeer::MaxPeersReached);
        assert_eq!(budget.attempts(), 1);
    }

    #[test]
    fn refetch_then_no_more_peers() {
        let initial = peer(1);
        let mut budget = ParentPullBudget::new(initial, 3, Duration::from_secs(10));
        assert_eq!(budget.next(&[initial]), NextPeer::RefetchMesh);
        budget.record_refetch();
        assert_eq!(budget.next(&[initial]), NextPeer::NoMorePeers);
    }

    #[test]
    fn refetch_followed_by_new_peer_is_accepted() {
        let initial = peer(1);
        let p2 = peer(2);
        let mut budget = ParentPullBudget::new(initial, 3, Duration::from_secs(10));
        assert_eq!(budget.next(&[initial]), NextPeer::RefetchMesh);
        budget.record_refetch();
        assert_eq!(budget.next(&[initial, p2]), NextPeer::Peer(p2));
    }
}