use std::{
convert::{TryFrom, TryInto},
fmt,
mem,
str::FromStr,
};
use super::{NodeId, node_id::NodeIdError};
/// Alias for [`XorDistance`]; the two names refer to the same type and are interchangeable.
pub type NodeDistance = XorDistance;
/// A distance value stored as a single `u128`.
///
/// Equality and ordering are derived, so two distances compare exactly like their inner
/// integers. The derived `Default` is the inner integer's default (zero).
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Default)]
pub struct XorDistance(u128);
impl XorDistance {
pub fn new() -> Self {
Self(0)
}
pub fn from_node_ids(x: &NodeId, y: &NodeId) -> Self {
let arr = x ^ y;
arr[..]
.try_into()
.expect("unreachable panic: NodeId::byte_size() <= NodeDistance::byte_size()")
}
pub const fn max_distance() -> Self {
Self(u128::MAX)
}
pub const fn zero() -> Self {
Self(0)
}
pub const fn byte_size() -> usize {
mem::size_of::<u128>()
}
pub fn get_bucket_index(&self) -> u8 {
((u8::try_from(Self::byte_size()).unwrap() * 8) - u8::try_from(self.0.leading_zeros()).unwrap())
.saturating_sub(1)
}
pub fn to_bytes(&self) -> [u8; Self::byte_size()] {
self.0.to_be_bytes()
}
pub fn as_u128(&self) -> u128 {
self.0
}
}
impl TryFrom<&[u8]> for XorDistance {
type Error = NodeIdError;
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
if bytes.len() > Self::byte_size() {
return Err(NodeIdError::IncorrectByteCount);
}
let mut buf = [0; Self::byte_size()];
let offset = Self::byte_size() - bytes.len();
buf.get_mut(offset..)
.ok_or(NodeIdError::IncorrectByteCount)?
.copy_from_slice(bytes);
Ok(XorDistance(u128::from_be_bytes(buf)))
}
}
impl fmt::Display for NodeDistance {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.0)
}
}
impl FromStr for XorDistance {
type Err = std::num::ParseIntError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
u128::from_str_radix(s, 16).map(XorDistance)
}
}
impl fmt::Debug for XorDistance {
    /// Writes an abbreviated, human-readable form: `XorDist: <n><suffix>`.
    ///
    /// The value is divided by 1000 repeatedly (truncating) until it is below 1000 or nine
    /// divisions have been made; `<suffix>` names the scale that was divided away. Values
    /// below 1000 are printed in full with an empty suffix. There is no separator between
    /// the number and the suffix.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // SCALE_SUFFIXES[k] labels a value that has been divided by 1000^k.
        const SCALE_SUFFIXES: [&str; 10] = [
            "",
            "thousand",
            "million",
            "billion",
            "trillion",
            "quadrillion",
            "quintillion",
            "sextillion",
            "septillion",
            // 1000^9 = 10^27. This was previously labelled "e24" while the value was
            // divided by 10^27, which under-reported large distances by a factor of 1000.
            "e27",
        ];

        let mut scaled = self.0;
        let mut group = 0;
        // Stop at the last suffix so the index always stays inside the table.
        while scaled >= 1000 && group + 1 < SCALE_SUFFIXES.len() {
            scaled /= 1000;
            group += 1;
        }
        write!(f, "XorDist: {}{}", scaled, SCALE_SUFFIXES[group])
    }
}
#[cfg(test)]
mod test {
    use rand::rngs::OsRng;

    use super::*;
    use crate::types::CommsPublicKey;

    mod ord {
        use super::*;

        /// A set bit in the first byte must compare greater than a set bit in the last byte.
        #[test]
        fn it_uses_big_endian_ordering() {
            let mut low_bytes = [0u8; 13];
            low_bytes[12] = 1;
            let mut high_bytes = [0u8; 13];
            high_bytes[0] = 1;

            let a = NodeDistance::try_from(&low_bytes[..]).unwrap();
            let b = NodeDistance::try_from(&high_bytes[..]).unwrap();
            assert!(a < b);
        }
    }

    mod get_bucket_index {
        use super::*;

        /// Checks fixed distances, including both extremes, against hand-computed indices.
        #[test]
        fn it_returns_the_correct_index() {
            // Bucket index of the distance encoded by `bytes`.
            fn bucket_of(bytes: &[u8]) -> u8 {
                NodeDistance::try_from(bytes).unwrap().get_bucket_index()
            }

            // Asserts the bucket for a 13-byte distance whose only non-zero byte is the last.
            fn check_for_dist(lsb_dist: u8, expected: u8) {
                let mut bytes = [0u8; 13];
                bytes[12] = lsb_dist;
                assert_eq!(bucket_of(&bytes[..]), expected, "Failed for dist = {lsb_dist}");
            }

            assert_eq!(NodeDistance::max_distance().get_bucket_index(), 127);
            assert_eq!(NodeDistance::zero().get_bucket_index(), 0);

            check_for_dist(1, 0);
            // Each half-open range of last-byte values shares one bucket.
            for (dists, expected) in [(2..4, 1), (4..8, 2), (8..16, 3)] {
                for lsb in dists {
                    check_for_dist(lsb, expected);
                }
            }

            // Highest set bit is bit 6 of the byte that has two bytes after it.
            let mut bytes = [0u8; 13];
            bytes[10] = 0b01000001;
            assert_eq!(bucket_of(&bytes[..]), 8 * 2 + 7 - 1);

            // Top bit of a 13-byte input: 12 * 8 + 7.
            let mut bytes = [0u8; 13];
            bytes[0] = 0b10000000;
            assert_eq!(bucket_of(&bytes[..]), 103);
        }

        /// For random node-ID pairs, the distance must satisfy `2^i <= dist < 2^(i + 1)`
        /// where `i` is the reported bucket index.
        #[test]
        fn correctness_fuzzing() {
            let random_node_id = || {
                let (_, pk) = CommsPublicKey::random_keypair(&mut OsRng);
                NodeId::from_public_key(&pk)
            };

            for _ in 0..100 {
                let a = random_node_id();
                let b = random_node_id();
                let distance = NodeDistance::from_node_ids(&a, &b);
                let i = u32::from(distance.get_bucket_index());
                let dist = distance.as_u128();
                assert!(2u128.pow(i) <= dist, "Failed for {dist}, i = {i}");
                assert!(dist < 2u128.pow(i + 1), "Failed for {dist}, i = {i}");
            }
        }
    }
}