use std::cmp::Ordering;
use itertools::Itertools;
use tox_crypto::*;
use tox_packet::dht::packed_node::*;
use crate::dht::dht_node::*;
use crate::dht::kbucket::*;
use crate::dht::ktree::*;
use crate::dht::ip_port::IsGlobal;
/// Node storage made of a [`Ktree`] plus one additional [`Kbucket`].
///
/// Nodes are offered to the tree first; a node the tree does not accept (and
/// does not already hold) is offered to the extra bucket instead, see
/// [`ForcedKtree::try_add`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ForcedKtree {
    /// Primary storage. Its `pk` field is also used as the base key for every
    /// operation on `kbucket`.
    ktree: Ktree,
    /// Secondary storage for nodes that did not fit into `ktree`.
    kbucket: Kbucket<DhtNode>,
}
impl ForcedKtree {
    /// Create an empty `ForcedKtree` whose tree is built around `pk` and whose
    /// extra bucket has capacity `KBUCKET_DEFAULT_SIZE`.
    pub fn new(pk: &PublicKey) -> Self {
        let ktree = Ktree::new(pk);
        let kbucket = Kbucket::new(KBUCKET_DEFAULT_SIZE);
        ForcedKtree { ktree, kbucket }
    }

    /// Look up a node by its public key.
    ///
    /// The tree is searched first; the extra bucket is consulted only when the
    /// tree has no such node.
    pub fn get_node(&self, pk: &PublicKey) -> Option<&DhtNode> {
        if let Some(node) = self.ktree.get_node(pk) {
            return Some(node);
        }
        self.kbucket.get_node(&self.ktree.pk, pk)
    }

    /// Mutable counterpart of [`ForcedKtree::get_node`], with the same lookup
    /// order (tree first, then the extra bucket).
    pub fn get_node_mut(&mut self, pk: &PublicKey) -> Option<&mut DhtNode> {
        // Copy the base key up front: the tree is mutably borrowed for the
        // lifetime of the returned reference, so it cannot be read afterwards.
        let base_pk = self.ktree.pk;
        match self.ktree.get_node_mut(pk) {
            Some(node) => Some(node),
            None => self.kbucket.get_node_mut(&base_pk, pk),
        }
    }

    /// Try to store `node`, returning `true` if either the tree or the extra
    /// bucket accepted it.
    ///
    /// * If the tree accepts the node, any entry with the same key is removed
    ///   from the extra bucket and the address of the *other* IP family is
    ///   carried over from that old entry to the tree entry.
    /// * If the tree rejects the node and does not already contain its key,
    ///   the node is offered to the extra bucket.
    /// * Otherwise nothing is stored and `false` is returned.
    ///
    /// # Panics
    ///
    /// Panics if the tree reports a successful add but the node cannot be
    /// found in it afterwards.
    pub fn try_add(&mut self, node: PackedNode) -> bool {
        if !self.ktree.try_add(node) {
            // The `true` flag is forwarded to `Kbucket::try_add` unchanged
            // (presumably it enables eviction — confirm against `Kbucket`).
            return !self.ktree.contains(&node.pk)
                && self.kbucket.try_add(&self.ktree.pk, node, true);
        }

        // The node now lives in the tree, so it must not stay in the bucket.
        if let Some(evicted) = self.kbucket.remove(&self.ktree.pk, &node.pk) {
            let promoted = self.ktree.get_node_mut(&node.pk).expect("Node should be added");
            // `node` supplied an address for one IP family only; keep what the
            // removed bucket entry had for the other family.
            if node.saddr.is_ipv4() {
                promoted.assoc6 = evicted.assoc6;
            } else {
                promoted.assoc4 = evicted.assoc4;
            }
        }
        true
    }

    /// Remove the node with key `node_pk` and return it.
    ///
    /// The tree is tried first; the extra bucket is tried only if the tree had
    /// no such node. Returns `None` when neither held it.
    pub fn remove(&mut self, node_pk: &PublicKey) -> Option<DhtNode> {
        let from_tree = self.ktree.remove(node_pk);
        if from_tree.is_some() {
            return from_tree;
        }
        self.kbucket.remove(&self.ktree.pk, node_pk)
    }

    /// Collect nodes close to `pk` from both storages.
    ///
    /// Starts from the tree's own `get_closest(pk, count, only_global)` result
    /// and then offers to it every extra-bucket node that is not bad, can be
    /// converted to a `PackedNode` and, when `only_global` is set, has a
    /// global IP address.
    pub fn get_closest(&self, pk: &PublicKey, count: u8, only_global: bool) -> Kbucket<PackedNode> {
        let mut closest = self.ktree.get_closest(pk, count, only_global);

        let candidates = self.kbucket.iter()
            .filter(|node| !node.is_bad())
            .filter_map(|node| node.to_packed_node())
            .filter(|pn| !only_global || IsGlobal::is_global(&pn.saddr.ip()));

        for pn in candidates {
            // The result of the add is intentionally ignored: a rejected
            // candidate simply does not make it into the result.
            closest.try_add(pk, pn, true);
        }
        closest
    }

    /// Check whether a node with key `pk` is stored in the extra bucket or in
    /// the tree.
    pub fn contains(&self, pk: &PublicKey) -> bool {
        if self.kbucket.contains(&self.ktree.pk, pk) {
            return true;
        }
        self.ktree.contains(pk)
    }

    /// Check whether `new_node` could currently be stored.
    ///
    /// True when the tree can take it, or when the tree does not already hold
    /// its key and the extra bucket can take it.
    pub fn can_add(&self, new_node: &PackedNode) -> bool {
        if self.ktree.can_add(new_node) {
            return true;
        }
        if self.ktree.contains(&new_node.pk) {
            return false;
        }
        self.kbucket.can_add(&self.ktree.pk, new_node, true)
    }

    /// Check whether both the tree and the extra bucket are empty.
    pub fn is_empty(&self) -> bool {
        let tree_empty = self.ktree.is_empty();
        tree_empty && self.kbucket.is_empty()
    }

    /// Iterate over all stored nodes, from the tree and the extra bucket.
    ///
    /// The two underlying iterators are interleaved with `merge_by`, taking
    /// the tree's item whenever it is strictly closer to the own key than the
    /// bucket's item. NOTE(review): the combined sequence is only fully
    /// ordered by distance if each underlying iterator already is — confirm
    /// against `Ktree::iter` and `Kbucket::iter`.
    pub fn iter(&self) -> impl Iterator<Item = &DhtNode> {
        let base_pk = self.ktree.pk;
        let tree_nodes = self.ktree.iter();
        let bucket_nodes = self.kbucket.iter();
        tree_nodes.merge_by(bucket_nodes, move |from_tree, from_bucket| {
            base_pk.distance(&from_tree.pk, &from_bucket.pk) == Ordering::Less
        })
    }

    /// Mutable counterpart of [`ForcedKtree::iter`], using the same merge
    /// rule.
    pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut DhtNode> {
        let base_pk = self.ktree.pk;
        let tree_nodes = self.ktree.iter_mut();
        let bucket_nodes = self.kbucket.iter_mut();
        tree_nodes.merge_by(bucket_nodes, move |from_tree, from_bucket| {
            base_pk.distance(&from_tree.pk, &from_bucket.pk) == Ordering::Less
        })
    }

    /// Check whether every stored node is discarded.
    ///
    /// Returns `true` for an empty `ForcedKtree`, since no node contradicts
    /// the condition.
    pub fn is_all_discarded(&self) -> bool {
        !self.iter().any(|node| !node.is_discarded())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    use std::net::SocketAddr;
    use std::time::Duration;

    /// Node number `i`: key filled with `i + 1`, address `1.2.3.4`, port
    /// `12345 + i`.
    fn node_by_idx(i: u8) -> PackedNode {
        let port = 12345 + u16::from(i);
        let addr = SocketAddr::new("1.2.3.4".parse().unwrap(), port);
        PackedNode::new(addr, &PublicKey([i + 1; PUBLICKEYBYTES]))
    }

    /// Key whose first byte is 255 and whose remaining bytes are `fill`.
    fn far_pk(fill: u8) -> PublicKey {
        let mut bytes = [fill; PUBLICKEYBYTES];
        bytes[0] = 255;
        PublicKey(bytes)
    }

    /// Node number `i` with a "far" key (`far_pk(i + 2)`), address `1.2.3.4`,
    /// port `12345 + i`.
    fn far_node_by_idx(i: u8) -> PackedNode {
        let port = 12345 + u16::from(i);
        let addr = SocketAddr::new("1.2.3.4".parse().unwrap(), port);
        PackedNode::new(addr, &far_pk(i + 2))
    }

    /// Put nodes 0..4 in through the public API and nodes 4..8 directly into
    /// the extra bucket.
    fn fill_tree_and_bucket(forced: &mut ForcedKtree, own_pk: &PublicKey) {
        for i in 0 .. 4 {
            assert!(forced.try_add(node_by_idx(i)));
        }
        for i in 4 .. 8 {
            assert!(forced.kbucket.try_add(own_pk, node_by_idx(i), true));
        }
    }

    #[test]
    fn forced_ktree_try_add() {
        let own_pk = PublicKey([0; PUBLICKEYBYTES]);
        let mut forced = ForcedKtree::new(&own_pk);

        // Eight nodes sharing the first key byte are all accepted.
        for i in 0 .. 8 {
            assert!(forced.try_add(far_node_by_idx(i)));
        }

        // A ninth node with the same first byte but a smaller fill value is
        // accepted as well.
        let closer = PackedNode::new("1.2.3.5:12345".parse().unwrap(), &far_pk(1));
        assert!(forced.try_add(closer));
    }

    #[test]
    fn forced_ktree_remove() {
        let own_pk = PublicKey([0; PUBLICKEYBYTES]);
        let mut forced = ForcedKtree::new(&own_pk);

        let in_bucket = PackedNode::new(
            "1.2.3.4:12345".parse().unwrap(),
            &PublicKey([1; PUBLICKEYBYTES]),
        );
        let via_api = PackedNode::new(
            "1.2.3.4:12345".parse().unwrap(),
            &PublicKey([2; PUBLICKEYBYTES]),
        );

        // Removing from an empty structure yields nothing.
        assert!(forced.remove(&in_bucket.pk).is_none());
        assert!(forced.is_empty());

        // One node goes straight into the extra bucket, the other through the
        // public API.
        assert!(forced.kbucket.try_add(&own_pk, in_bucket, true));
        assert!(!forced.is_empty());
        assert!(forced.try_add(via_api));
        assert!(!forced.is_empty());

        // Both can be removed again, leaving the structure empty.
        assert!(forced.remove(&in_bucket.pk).is_some());
        assert!(!forced.is_empty());
        assert!(forced.remove(&via_api.pk).is_some());
        assert!(forced.is_empty());
    }

    #[test]
    fn forced_ktree_get_closest() {
        let own_pk = PublicKey([0; PUBLICKEYBYTES]);
        let mut forced = ForcedKtree::new(&own_pk);
        fill_tree_and_bucket(&mut forced, &own_pk);

        let lowest_pk = PublicKey([0; PUBLICKEYBYTES]);
        let highest_pk = PublicKey([255; PUBLICKEYBYTES]);

        for count in 1 ..= 4 {
            // Closest to the all-zero key: the first `count` nodes, ascending.
            let closest: Vec<PackedNode> = forced.get_closest(&lowest_pk, count, true).into();
            let expected: Vec<PackedNode> = (0 .. count).map(node_by_idx).collect();
            assert_eq!(closest, expected);

            // Closest to the all-255 key: the last `count` nodes, descending.
            let closest: Vec<PackedNode> = forced.get_closest(&highest_pk, count, true).into();
            let expected: Vec<PackedNode> = (8 - count .. 8).rev().map(node_by_idx).collect();
            assert_eq!(closest, expected);
        }
    }

    #[test]
    fn forced_ktree_contains() {
        crypto_init().unwrap();

        let (own_pk, _) = gen_keypair();
        let mut forced = ForcedKtree::new(&own_pk);

        // The own key is not a stored node.
        assert!(!forced.contains(&own_pk));

        // A node added through the public API becomes visible.
        let via_api = PackedNode::new(
            "1.2.3.5:12345".parse().unwrap(),
            &gen_keypair().0,
        );
        assert!(!forced.contains(&via_api.pk));
        assert!(forced.try_add(via_api));
        assert!(forced.contains(&via_api.pk));

        // A node placed directly into the extra bucket becomes visible too.
        let in_bucket = PackedNode::new(
            "1.2.3.4:12345".parse().unwrap(),
            &PublicKey([1; PUBLICKEYBYTES]),
        );
        assert!(!forced.contains(&in_bucket.pk));
        assert!(forced.kbucket.try_add(&own_pk, in_bucket, true));
        assert!(forced.contains(&in_bucket.pk));
    }

    #[test]
    fn forced_ktree_can_add() {
        crypto_init().unwrap();

        let own_pk = PublicKey([0; PUBLICKEYBYTES]);
        let mut forced = ForcedKtree::new(&own_pk);

        // Each of the sixteen nodes can be added once, and not again
        // afterwards.
        for i in 0 .. 16 {
            let node = far_node_by_idx(i);
            assert!(forced.can_add(&node));
            assert!(forced.try_add(node));
            assert!(!forced.can_add(&node));
        }

        // A node with a smaller fill value than all stored ones is still
        // accepted...
        let closer = PackedNode::new("1.2.3.5:12345".parse().unwrap(), &far_pk(1));
        assert!(forced.can_add(&closer));

        // ...while one with a larger fill value than all stored ones is not.
        let farther = PackedNode::new("1.2.3.5:12345".parse().unwrap(), &far_pk(18));
        assert!(!forced.can_add(&farther));
    }

    #[test]
    fn forced_ktree_iter() {
        let own_pk = PublicKey([0; PUBLICKEYBYTES]);
        let mut forced = ForcedKtree::new(&own_pk);

        // Nothing to iterate over yet.
        assert!(forced.iter().next().is_none());

        fill_tree_and_bucket(&mut forced, &own_pk);

        assert_eq!(forced.iter().count(), 8);

        // Nodes come out in ascending key order: fill bytes 1, 2, ..., 8.
        for (fill, node) in (1u8 ..).zip(forced.iter()) {
            assert_eq!(node.pk, PublicKey([fill; PUBLICKEYBYTES]));
        }
    }

    #[test]
    fn forced_ktree_iter_mut() {
        let own_pk = PublicKey([0; PUBLICKEYBYTES]);
        let mut forced = ForcedKtree::new(&own_pk);

        // Nothing to iterate over yet.
        assert!(forced.iter_mut().next().is_none());

        fill_tree_and_bucket(&mut forced, &own_pk);

        assert_eq!(forced.iter_mut().count(), 8);

        // Nodes come out in ascending key order: fill bytes 1, 2, ..., 8.
        for (fill, node) in (1u8 ..).zip(forced.iter_mut()) {
            assert_eq!(node.pk, PublicKey([fill; PUBLICKEYBYTES]));
        }
    }

    #[tokio::test]
    async fn forced_ktree_is_all_discarded() {
        crypto_init().unwrap();

        let (own_pk, _) = gen_keypair();
        let mut forced = ForcedKtree::new(&own_pk);

        // One node through the public API, one directly in the extra bucket.
        let via_api = PackedNode::new(
            "1.2.3.4:33445".parse().unwrap(),
            &gen_keypair().0,
        );
        assert!(forced.try_add(via_api));

        let in_bucket = PackedNode::new(
            "1.2.3.5:12345".parse().unwrap(),
            &gen_keypair().0,
        );
        assert!(forced.kbucket.try_add(&own_pk, in_bucket, true));

        // Freshly added nodes are not discarded.
        assert!(!forced.is_all_discarded());

        // Move the paused tokio clock just past `KILL_NODE_TIMEOUT`; after
        // that every node is reported as discarded.
        tokio::time::pause();
        tokio::time::advance(KILL_NODE_TIMEOUT + Duration::from_secs(1)).await;
        assert!(forced.is_all_discarded());
    }
}