use std::collections::HashSet;
use std::net::SocketAddr;
use irontide_core::{AddressFamily, Id20};
use crate::compact::CompactNodeInfo;
use crate::routing_table::K;
/// Batch size the unit tests pass as `alpha` to `IterativeLookup::next_to_query`.
///
/// Compiled only in test builds.
/// NOTE(review): presumably mirrors the lookup concurrency used by production
/// callers — confirm against the call site, which is not visible here.
#[cfg(test)]
const ALPHA: usize = 3;
/// State of a single iterative lookup toward `target`.
///
/// `C` is a caller-supplied payload stored alongside the lookup; none of the
/// methods on this type read or modify it.
pub(crate) struct IterativeLookup<C> {
/// ID that candidates are ranked against (via `xor_distance`) in `feed_nodes`.
pub target: Id20,
/// Candidate nodes. `feed_nodes` keeps this list free of duplicate IDs,
/// sorted by distance to `target`, and capped at `K * 2` entries.
pub closest: Vec<CompactNodeInfo>,
/// IDs of nodes that `next_to_query` has already handed out.
pub queried: HashSet<Id20>,
/// Caller-owned payload; opaque to this type.
pub callbacks: C,
}
impl<C> IterativeLookup<C> {
    /// Creates a lookup for `target` with no candidates and nothing queried yet.
    ///
    /// `callbacks` is stored as-is and never inspected by this type.
    pub fn new(target: Id20, callbacks: C) -> Self {
        let closest = Vec::new();
        let queried = HashSet::new();
        Self {
            target,
            closest,
            queried,
            callbacks,
        }
    }

    /// Picks the next batch of nodes to contact.
    ///
    /// Walks `closest` in its stored order and returns copies of the first
    /// `alpha` entries whose IDs are not in `queried` (fewer if not enough
    /// remain). Every returned node's ID is then recorded in `queried`, so a
    /// node is handed out at most once.
    pub fn next_to_query(&mut self, alpha: usize) -> Vec<CompactNodeInfo> {
        let mut batch: Vec<CompactNodeInfo> = Vec::new();
        for candidate in &self.closest {
            if batch.len() >= alpha {
                break;
            }
            if self.queried.contains(&candidate.id) {
                continue;
            }
            batch.push(*candidate);
        }
        // Mark only after the batch is complete; selection above must see the
        // `queried` set as it was on entry.
        self.queried.extend(batch.iter().map(|picked| picked.id));
        batch
    }

    /// Merges newly learned nodes into the candidate list.
    ///
    /// A node is accepted only if its address belongs to `family` and no
    /// candidate with the same ID is already present (including ones accepted
    /// earlier in this same call). Afterwards the list is re-sorted by
    /// `xor_distance` to the target, smallest first, and cut down to at most
    /// `K * 2` entries.
    pub fn feed_nodes(&mut self, new_nodes: Vec<CompactNodeInfo>, family: AddressFamily) {
        for incoming in new_nodes {
            let addr: &SocketAddr = &incoming.addr;
            let right_family = match family {
                AddressFamily::V4 => addr.is_ipv4(),
                AddressFamily::V6 => addr.is_ipv6(),
            };
            if !right_family {
                continue;
            }
            let already_known = self.closest.iter().any(|known| known.id == incoming.id);
            if already_known {
                continue;
            }
            self.closest.push(incoming);
        }

        // Copy the target out so the sort closure does not borrow `self`
        // while `self.closest` is mutably borrowed.
        let anchor = self.target;
        self.closest
            .sort_by_key(|candidate| candidate.id.xor_distance(&anchor));
        self.closest.truncate(K * 2);
    }

    /// Reports whether the lookup has nothing left to do.
    ///
    /// Returns `true` only when `has_pending` is `false` and every candidate
    /// in `closest` has already been handed out by `next_to_query`.
    #[cfg(test)]
    pub fn is_exhausted(&self, has_pending: bool) -> bool {
        if has_pending {
            return false;
        }
        self.closest
            .iter()
            .all(|candidate| self.queried.contains(&candidate.id))
    }
}
/// Test-only callback payload that stores one token entry per node ID.
///
/// NOTE(review): in this file it is only exercised by the unit tests;
/// presumably it mirrors get_peers token bookkeeping done elsewhere — confirm.
#[cfg(test)]
pub(crate) struct GetPeersCallbacks {
/// Maps a node ID to an (address, token bytes) pair.
pub tokens: std::collections::HashMap<Id20, (SocketAddr, Vec<u8>)>,
}
/// Callback payload carrying two round counters.
///
/// NOTE(review): nothing visible in this file reads these fields; presumably
/// the caller advances `round` until it reaches `max_rounds` — confirm
/// against the call site.
pub(crate) struct FindNodeCallbacks {
/// Presumably the current round number — confirm.
pub round: u8,
/// Presumably the upper bound on the number of rounds — confirm.
pub max_rounds: u8,
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a node with the given raw ID on the IPv4 loopback address.
    fn make_node(id_bytes: [u8; 20], port: u16) -> CompactNodeInfo {
        let addr = SocketAddr::from(([127, 0, 0, 1], port));
        CompactNodeInfo {
            id: Id20(id_bytes),
            addr,
        }
    }

    /// `next_to_query` hands out at most `ALPHA` nodes per call and never
    /// returns the same node twice.
    #[test]
    fn iterative_lookup_next_to_query_alpha() {
        let mut lookup = IterativeLookup::new(
            Id20::ZERO,
            FindNodeCallbacks {
                round: 0,
                max_rounds: 6,
            },
        );

        // Seed five candidates directly, bypassing `feed_nodes`.
        for last_byte in 1..=5u8 {
            let mut raw = [0u8; 20];
            raw[19] = last_byte;
            let port = 6880 + u16::from(last_byte);
            lookup.closest.push(make_node(raw, port));
        }

        let first = lookup.next_to_query(ALPHA);
        assert_eq!(first.len(), ALPHA);
        assert_eq!(lookup.queried.len(), 3);

        let second = lookup.next_to_query(ALPHA);
        assert_eq!(second.len(), 2);
        assert_eq!(lookup.queried.len(), 5);

        let third = lookup.next_to_query(ALPHA);
        assert!(third.is_empty());
    }

    /// `is_exhausted` is true only with no pending work and no unqueried
    /// candidates.
    #[test]
    fn iterative_lookup_exhausted() {
        let mut lookup = IterativeLookup::new(
            Id20::ZERO,
            FindNodeCallbacks {
                round: 0,
                max_rounds: 6,
            },
        );

        // Empty lookup: exhaustion depends solely on the pending flag.
        assert!(lookup.is_exhausted(false));
        assert!(!lookup.is_exhausted(true));

        // One unqueried candidate keeps the lookup alive.
        let mut raw = [0u8; 20];
        raw[19] = 1;
        lookup.closest.push(make_node(raw, 6881));
        assert!(!lookup.is_exhausted(false));

        // Once it has been handed out, only pending work keeps it alive.
        let _ = lookup.next_to_query(ALPHA);
        assert_eq!(lookup.queried.len(), 1);
        assert!(lookup.is_exhausted(false));
        assert!(!lookup.is_exhausted(true));
    }

    /// The `tokens` map on `GetPeersCallbacks` stores one entry per node ID
    /// and a later insert for the same ID replaces the earlier one.
    #[test]
    fn lookup_get_peers_token_tracking() {
        let id_ending_in = |last_byte: u8| {
            let mut raw = [0u8; 20];
            raw[19] = last_byte;
            Id20(raw)
        };

        let mut lookup = IterativeLookup::new(
            Id20::ZERO,
            GetPeersCallbacks {
                tokens: HashMap::new(),
            },
        );

        let node1_id = id_ending_in(1);
        let node2_id = id_ending_in(2);
        let addr1: SocketAddr = "1.2.3.4:6881".parse().unwrap();
        let addr2: SocketAddr = "5.6.7.8:6882".parse().unwrap();

        let first_entry = (addr1, b"token_a".to_vec());
        let second_entry = (addr2, b"token_b".to_vec());
        lookup.callbacks.tokens.insert(node1_id, first_entry);
        lookup.callbacks.tokens.insert(node2_id, second_entry);
        assert_eq!(lookup.callbacks.tokens.len(), 2);

        let (stored_addr, stored_token) = lookup
            .callbacks
            .tokens
            .get(&node1_id)
            .expect("node1 token missing");
        assert_eq!(*stored_addr, addr1);
        assert_eq!(stored_token.as_slice(), &b"token_a"[..]);

        let (stored_addr, stored_token) = lookup
            .callbacks
            .tokens
            .get(&node2_id)
            .expect("node2 token missing");
        assert_eq!(*stored_addr, addr2);
        assert_eq!(stored_token.as_slice(), &b"token_b"[..]);

        // Re-inserting under an existing ID overwrites without growing the map.
        let replacement = (addr1, b"token_a_v2".to_vec());
        lookup.callbacks.tokens.insert(node1_id, replacement);
        assert_eq!(lookup.callbacks.tokens.len(), 2);

        let (_, updated_token) = lookup
            .callbacks
            .tokens
            .get(&node1_id)
            .expect("node1 updated token missing");
        assert_eq!(updated_token.as_slice(), &b"token_a_v2"[..]);
    }
}