use alloc::vec::Vec;
use core::{fmt, ops::Add, time::Duration};
use sha2::{Digest as _, Sha256};
/// K-buckets: a set of `(K, V)` entries grouped by the log2 of the XOR distance between the
/// SHA-256 hash of their key and the SHA-256 hash of the local key.
///
/// Each bucket holds at most `ENTRIES_PER_BUCKET` entries. `TNow` is the type used to represent
/// instants when managing a bucket's pending entry.
pub struct KBuckets<K, V, TNow, const ENTRIES_PER_BUCKET: usize> {
    /// The local key, paired with its hash, against which every distance is computed.
    local_key: (K, Key),
    /// `buckets[d]` holds the entries whose log2 distance to the local key is `d`.
    buckets: Vec<Bucket<K, V, TNow, ENTRIES_PER_BUCKET>>,
    /// Duration added to the current instant whenever a bucket's pending-entry deadline is
    /// (re)armed.
    pending_timeout: Duration,
}
impl<K, V, TNow, const ENTRIES_PER_BUCKET: usize> KBuckets<K, V, TNow, ENTRIES_PER_BUCKET>
where
    K: Clone + PartialEq + AsRef<[u8]>,
    TNow: Clone + Add<Duration, Output = TNow> + Ord,
{
    /// Creates empty k-buckets centered on `local_key`.
    ///
    /// Distances are computed between SHA-256 hashes of the keys' bytes, and one bucket is
    /// allocated per possible log2 distance (256 buckets).
    ///
    /// When a bucket becomes full while holding at least one disconnected entry, its last entry
    /// is protected for `pending_timeout`. During that time, inserting a new disconnected entry
    /// into that bucket fails.
    pub fn new(local_key: K, pending_timeout: Duration) -> Self {
        let local_key_hashed = Key::new(local_key.as_ref());
        KBuckets {
            local_key: (local_key, local_key_hashed),
            buckets: (0..256)
                .map(|_| Bucket {
                    entries: arrayvec::ArrayVec::new(),
                    num_connected_entries: 0,
                    pending_entry: None,
                })
                .collect(),
            pending_timeout,
        }
    }

    /// Returns the local key passed to [`KBuckets::new`].
    pub fn local_key(&self) -> &K {
        &self.local_key.0
    }

    /// Returns the value associated with `key`, if it is present.
    ///
    /// Always returns `None` for a key hashing identically to the local key, as such a key is
    /// never stored.
    pub fn get(&self, key: &K) -> Option<&V> {
        // `None` means the key hashes to the same value as the local key.
        let distance = distance_log2(&self.local_key.1, &Key::new(key.as_ref()))?;
        self.buckets[usize::from(distance)].get(key)
    }

    /// Mutable counterpart of [`KBuckets::get`].
    pub fn get_mut(&mut self, key: &K) -> Option<&mut V> {
        let distance = distance_log2(&self.local_key.1, &Key::new(key.as_ref()))?;
        self.buckets[usize::from(distance)].get_mut(key)
    }

    /// Returns an [`Entry`] describing whether `key` is the local key, absent, or present.
    pub fn entry<'a>(&'a mut self, key: &'a K) -> Entry<'a, K, V, TNow, ENTRIES_PER_BUCKET>
    where
        K: Clone,
    {
        let key_hashed = Key::new(key.as_ref());
        let distance = match distance_log2(&self.local_key.1, &key_hashed) {
            Some(d) => d,
            None => return Entry::LocalKey,
        };

        // A shared lookup is enough to test for presence.
        if self.buckets[usize::from(distance)].get(key).is_some() {
            Entry::Occupied(OccupiedEntry {
                inner: self,
                key,
                distance,
            })
        } else {
            Entry::Vacant(VacantEntry {
                inner: self,
                key,
                distance,
            })
        }
    }

    /// Returns every entry, ordered by increasing XOR distance to `target`.
    ///
    /// The distance is taken between the SHA-256 hash of each entry's key and the SHA-256 hash
    /// of `target`. An entry whose key hashes identically to `target` comes first. Entries at
    /// equal distance keep the relative order of [`KBuckets::iter_ordered`].
    pub fn closest_entries(&self, target: &K) -> impl Iterator<Item = (&K, &V)> {
        let target_hashed = Key::new(target.as_ref());
        let mut list = self.iter_ordered().collect::<Vec<_>>();
        // Sort by the full 256-bit XOR distance rather than only its log2, so that entries
        // within the same log2 range are ordered as well. This is a refinement of ordering by
        // log2 distance.
        // `sort_by_cached_key` hashes each key once instead of on every comparison, and is
        // stable.
        list.sort_by_cached_key(|(key, _)| {
            let key_hashed = Key::new(key.as_ref());
            let mut xor_distance = [0u8; 32];
            for ((out, a), b) in xor_distance
                .iter_mut()
                .zip(&key_hashed.digest)
                .zip(&target_hashed.digest)
            {
                *out = a ^ b;
            }
            // Byte 0 is the most significant byte (as in `distance_log2`), so lexicographic
            // array comparison matches numeric comparison.
            xor_distance
        });
        list.into_iter()
    }
}
impl<K, V, TNow, const ENTRIES_PER_BUCKET: usize> KBuckets<K, V, TNow, ENTRIES_PER_BUCKET> {
    /// Iterates over every `(key, value)` pair, bucket by bucket in increasing bucket index,
    /// i.e. increasing log2 distance to the local key.
    pub fn iter_ordered(&self) -> impl Iterator<Item = (&K, &V)> {
        self.buckets
            .iter()
            .flat_map(|bucket| bucket.entries.iter())
            .map(|(key, value)| (key, value))
    }

    /// Same traversal as [`KBuckets::iter_ordered`], yielding mutable access to the values.
    pub fn iter_mut_ordered(&mut self) -> impl Iterator<Item = (&K, &mut V)> {
        self.buckets
            .iter_mut()
            .flat_map(|bucket| bucket.entries.iter_mut())
            .map(|(key, value)| (&*key, value))
    }
}
impl<K, V, TNow, const ENTRIES_PER_BUCKET: usize> fmt::Debug
    for KBuckets<K, V, TNow, ENTRIES_PER_BUCKET>
where
    K: fmt::Debug,
    V: fmt::Debug,
{
    /// Formats the k-buckets as a flat list of `(key, value)` pairs, in the order produced by
    /// `iter_ordered`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut list = f.debug_list();
        for pair in self.iter_ordered() {
            list.entry(&pair);
        }
        list.finish()
    }
}
/// Outcome of [`KBuckets::entry`]: where a key stands relative to the k-buckets.
pub enum Entry<'a, K, V, TNow, const ENTRIES_PER_BUCKET: usize> {
    /// The key hashes identically to the local key and can never be stored.
    LocalKey,
    /// The key is not present in the bucket matching its distance.
    Vacant(VacantEntry<'a, K, V, TNow, ENTRIES_PER_BUCKET>),
    /// The key is present in the bucket matching its distance.
    Occupied(OccupiedEntry<'a, K, V, TNow, ENTRIES_PER_BUCKET>),
}
impl<'a, K, V, TNow, const ENTRIES_PER_BUCKET: usize> Entry<'a, K, V, TNow, ENTRIES_PER_BUCKET>
where
    K: Clone + PartialEq + AsRef<[u8]>,
    TNow: Clone + Add<Duration, Output = TNow> + Ord,
{
    /// Returns the occupied entry, or `None` if the key is vacant or is the local key.
    pub fn into_occupied(self) -> Option<OccupiedEntry<'a, K, V, TNow, ENTRIES_PER_BUCKET>> {
        if let Entry::Occupied(occupied) = self {
            Some(occupied)
        } else {
            None
        }
    }

    /// Returns the existing entry if the key is present, otherwise attempts to insert `value`
    /// with the given connection `state`.
    ///
    /// Any entry evicted by the insertion is dropped.
    ///
    /// # Errors
    ///
    /// - [`OrInsertError::LocalKey`] if the key is the local key.
    /// - [`OrInsertError::Full`] if the key is absent and its bucket refuses the insertion.
    pub fn or_insert(
        self,
        value: V,
        now: &TNow,
        state: PeerState,
    ) -> Result<OccupiedEntry<'a, K, V, TNow, ENTRIES_PER_BUCKET>, OrInsertError> {
        let vacant = match self {
            Entry::Occupied(occupied) => return Ok(occupied),
            Entry::LocalKey => return Err(OrInsertError::LocalKey),
            Entry::Vacant(vacant) => vacant,
        };

        match vacant.insert(value, now, state) {
            Ok(inserted) => Ok(inserted.entry),
            Err(InsertError::Full) => Err(OrInsertError::Full),
        }
    }
}
/// Error returned by [`Entry::or_insert`].
#[derive(Debug, Clone, PartialEq, Eq, derive_more::Display, derive_more::Error)]
pub enum OrInsertError {
    /// The key is absent and its bucket refused the insertion.
    Full,
    /// The key is the local key, which cannot be inserted.
    LocalKey,
}
/// A key absent from the k-buckets, obtained through [`KBuckets::entry`].
pub struct VacantEntry<'a, K, V, TNow, const ENTRIES_PER_BUCKET: usize> {
    inner: &'a mut KBuckets<K, V, TNow, ENTRIES_PER_BUCKET>,
    key: &'a K,
    /// Log2 distance between the key and the local key; also the index of its bucket.
    distance: u8,
}
/// Successful outcome of [`VacantEntry::insert`].
pub struct InsertResult<'a, K, V, TNow, const ENTRIES_PER_BUCKET: usize> {
    /// The entry that has just been inserted.
    pub entry: OccupiedEntry<'a, K, V, TNow, ENTRIES_PER_BUCKET>,
    /// The entry evicted from the bucket to make room for the new one, if any.
    pub removed_entry: Option<(K, V)>,
}
impl<'a, K, V, TNow, const ENTRIES_PER_BUCKET: usize>
    VacantEntry<'a, K, V, TNow, ENTRIES_PER_BUCKET>
where
    K: Clone + PartialEq + AsRef<[u8]>,
    TNow: Clone + Add<Duration, Output = TNow> + Ord,
{
    /// Inserts `value` under this entry's key, in the bucket matching its distance to the local
    /// key.
    ///
    /// Within a bucket, connected entries occupy the first `num_connected_entries` slots and
    /// disconnected entries follow. When the bucket is full and at least one entry is
    /// disconnected, `pending_entry` holds the instant until which the last entry cannot be
    /// replaced by a new disconnected entry.
    ///
    /// - `Connected`: accepted as long as fewer than `ENTRIES_PER_BUCKET` entries are connected.
    ///   If the bucket is full, its last (disconnected) entry is evicted.
    /// - `Disconnected`: appended if there is room. If the bucket is full, it replaces the last
    ///   entry only when that entry is disconnected and the pending deadline is `<= now`.
    ///
    /// Returns the newly-occupied entry together with the evicted entry, if any.
    ///
    /// # Errors
    ///
    /// Returns [`InsertError::Full`] if the bucket cannot accept the new entry.
    pub fn insert(
        self,
        value: V,
        now: &TNow,
        state: PeerState,
    ) -> Result<InsertResult<'a, K, V, TNow, ENTRIES_PER_BUCKET>, InsertError> {
        let bucket = &mut self.inner.buckets[usize::from(self.distance)];

        let removed_entry = match state {
            PeerState::Connected if bucket.num_connected_entries < ENTRIES_PER_BUCKET => {
                // Not every entry is connected, so if the bucket is full its last entry is a
                // disconnected one; evict it to make room.
                let previous_entry = if bucket.entries.is_full() {
                    let popped = bucket.entries.pop();
                    debug_assert!(popped.is_some());
                    popped
                } else {
                    None
                };

                bucket
                    .entries
                    .insert(bucket.num_connected_entries, (self.key.clone(), value));
                bucket.num_connected_entries += 1;

                if bucket.num_connected_entries == ENTRIES_PER_BUCKET {
                    // Every entry is connected: nothing is pending anymore.
                    bucket.pending_entry = None;
                } else if bucket.entries.is_full() {
                    // The bucket is full and its last entry is disconnected. This also covers
                    // the case where this very insertion filled the bucket. Without it,
                    // `pending_entry` would stay `None` while the bucket is full, and the next
                    // disconnected insertion would panic.
                    bucket.pending_entry = Some(now.clone() + self.inner.pending_timeout);
                }

                previous_entry
            }
            PeerState::Connected => {
                debug_assert!(bucket.entries.is_full());
                debug_assert_eq!(bucket.num_connected_entries, ENTRIES_PER_BUCKET);
                debug_assert!(bucket.pending_entry.is_none());
                return Err(InsertError::Full);
            }
            PeerState::Disconnected if bucket.entries.is_full() => {
                // Connected entries are never evicted in favor of a disconnected one.
                if bucket.num_connected_entries == ENTRIES_PER_BUCKET {
                    return Err(InsertError::Full);
                }

                let pending_until = bucket
                    .pending_entry
                    .as_ref()
                    .expect("a full bucket with a disconnected entry has a pending deadline");
                if *pending_until > *now {
                    // The last entry is still protected.
                    return Err(InsertError::Full);
                }

                let previous_entry = bucket.entries.pop();
                bucket.entries.push((self.key.clone(), value));
                // The new last entry becomes the protected one.
                bucket.pending_entry = Some(now.clone() + self.inner.pending_timeout);
                previous_entry
            }
            PeerState::Disconnected => {
                debug_assert!(!bucket.entries.is_full());
                debug_assert!(bucket.pending_entry.is_none());
                bucket.entries.push((self.key.clone(), value));
                if bucket.entries.is_full() {
                    bucket.pending_entry = Some(now.clone() + self.inner.pending_timeout);
                }
                None
            }
        };

        Ok(InsertResult {
            entry: OccupiedEntry {
                inner: self.inner,
                key: self.key,
                distance: self.distance,
            },
            removed_entry,
        })
    }
}
/// Error returned by [`VacantEntry::insert`].
#[derive(Debug, Clone, PartialEq, Eq, derive_more::Display, derive_more::Error)]
pub enum InsertError {
    /// The bucket cannot accept the new entry.
    Full,
}
/// A key present in the k-buckets, obtained through [`KBuckets::entry`] or after a successful
/// insertion.
pub struct OccupiedEntry<'a, K, V, TNow, const ENTRIES_PER_BUCKET: usize> {
    inner: &'a mut KBuckets<K, V, TNow, ENTRIES_PER_BUCKET>,
    key: &'a K,
    /// Log2 distance between the key and the local key; also the index of its bucket.
    distance: u8,
}
impl<'a, K, V, TNow, const ENTRIES_PER_BUCKET: usize>
    OccupiedEntry<'a, K, V, TNow, ENTRIES_PER_BUCKET>
where
    K: Clone + PartialEq + AsRef<[u8]>,
    TNow: Clone + Add<Duration, Output = TNow> + Ord,
{
    /// Updates the connection state of this entry.
    ///
    /// The entry is moved within its bucket so that connected entries stay in front of
    /// disconnected ones. Setting the state the entry already has does nothing.
    pub fn set_state(&mut self, now: &TNow, state: PeerState) {
        let bucket = &mut self.inner.buckets[usize::from(self.distance)];

        // An `OccupiedEntry` only exists for a key present in its bucket.
        let position = bucket
            .entries
            .iter()
            .position(|(k, _)| *k == *self.key)
            .unwrap();

        match state {
            PeerState::Connected if position >= bucket.num_connected_entries => {
                debug_assert!(bucket.num_connected_entries < ENTRIES_PER_BUCKET);
                let entry = bucket.entries.remove(position);
                bucket.entries.insert(bucket.num_connected_entries, entry);
                bucket.num_connected_entries += 1;

                if bucket.num_connected_entries == ENTRIES_PER_BUCKET {
                    // Every entry is now connected, so there is no pending entry anymore.
                    // Re-arming the deadline here would break the `pending_entry.is_none()`
                    // expectations of `VacantEntry::insert` and of the `Disconnected` arm below.
                    bucket.pending_entry = None;
                } else if position == bucket.entries.capacity() - 1 {
                    // The entry was the last (pending) one of a full bucket. The disconnected
                    // entry now in the last slot becomes pending.
                    debug_assert!(bucket.pending_entry.is_some());
                    bucket.pending_entry = Some(now.clone() + self.inner.pending_timeout);
                }
            }
            PeerState::Disconnected if position < bucket.num_connected_entries => {
                let entry = bucket.entries.remove(position);
                bucket.num_connected_entries -= 1;
                bucket.entries.insert(bucket.num_connected_entries, entry);

                // `num_connected_entries` can only reach `capacity - 1` here if every entry was
                // connected. In that case the bucket is full and the entry just moved to the
                // last slot becomes pending.
                if bucket.num_connected_entries == bucket.entries.capacity() - 1 {
                    debug_assert!(bucket.pending_entry.is_none());
                    bucket.pending_entry = Some(now.clone() + self.inner.pending_timeout);
                }
            }
            // The entry is already in the requested state.
            PeerState::Connected | PeerState::Disconnected => {}
        }
    }

    /// Returns a mutable reference to the value of this entry.
    pub fn get_mut(&mut self) -> &mut V {
        // An `OccupiedEntry` only exists for a key present in its bucket.
        self.inner.buckets[usize::from(self.distance)]
            .get_mut(self.key)
            .unwrap()
    }
}
/// Connection state of an entry.
///
/// Within a bucket, connected entries are kept in front of disconnected ones. Only a
/// disconnected entry can be evicted to make room for a new one.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PeerState {
    /// The peer is connected.
    Connected,
    /// The peer is disconnected.
    Disconnected,
}
/// A single k-bucket holding up to `ENTRIES_PER_BUCKET` entries.
///
/// Connected entries occupy the first `num_connected_entries` slots of `entries`; the
/// remaining slots hold disconnected entries.
struct Bucket<K, V, TNow, const ENTRIES_PER_BUCKET: usize> {
    entries: arrayvec::ArrayVec<(K, V), ENTRIES_PER_BUCKET>,
    /// Number of connected entries stored at the front of `entries`.
    num_connected_entries: usize,
    /// Deadline armed when the bucket becomes full while holding a disconnected entry.
    ///
    /// A new disconnected entry may replace the last entry of a full bucket only once this
    /// instant is `<= now`.
    pending_entry: Option<TNow>,
}
impl<K, V, TNow, const ENTRIES_PER_BUCKET: usize> Bucket<K, V, TNow, ENTRIES_PER_BUCKET>
where
    K: PartialEq,
{
    /// Linear scan for the value stored under `key` in this bucket.
    fn get(&self, key: &K) -> Option<&V> {
        self.entries
            .iter()
            .find(|(entry_key, _)| entry_key == key)
            .map(|(_, value)| value)
    }

    /// Mutable counterpart of `get`.
    fn get_mut(&mut self, key: &K) -> Option<&mut V> {
        self.entries
            .iter_mut()
            .find(|(entry_key, _)| entry_key == key)
            .map(|(_, value)| value)
    }
}
/// SHA-256 digest of a key's bytes; XOR distances are computed between these digests.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct Key {
    digest: [u8; 32],
}
impl Key {
fn new(value: &[u8]) -> Self {
Self {
digest: Sha256::digest(value).into(),
}
}
#[cfg(test)] fn from_sha256_hash(hash: [u8; 32]) -> Self {
Self { digest: hash }
}
}
/// Returns `floor(log2(a XOR b))`, treating both digests as big-endian 256-bit integers.
///
/// This is the index (0 = least significant, 255 = most significant) of the highest bit that
/// differs between `a` and `b`. Returns `None` when the digests are identical, i.e. the
/// distance is zero.
fn distance_log2(a: &Key, b: &Key) -> Option<u8> {
    a.digest
        .iter()
        .zip(b.digest.iter())
        .enumerate()
        .find_map(|(byte_index, (lhs, rhs))| {
            let diff = lhs ^ rhs;
            if diff == 0 {
                return None;
            }
            // Bits contained in the less significant bytes that follow this one.
            let lower_bits = u32::try_from((31 - byte_index) * 8).unwrap();
            // Position of the highest set bit inside this byte (0..=7).
            let bit_in_byte = 7 - diff.leading_zeros();
            Some(u8::try_from(lower_bits + bit_in_byte).unwrap())
        })
}
#[cfg(test)]
mod tests {
    use super::{distance_log2, Entry, KBuckets, Key, PeerState};
    use core::time::Duration;
    use sha2::{Digest as _, Sha256};

    /// Builds a `Key` whose digest is all zeroes except byte `index`, which is set to `value`.
    fn key_with_byte(index: usize, value: u8) -> Key {
        let mut digest = [0u8; 32];
        digest[index] = value;
        Key::from_sha256_hash(digest)
    }

    #[test]
    fn basic_distance_1() {
        let zero = Key::from_sha256_hash([0; 32]);
        let lowest_bit = key_with_byte(31, 1);
        assert_eq!(distance_log2(&zero, &lowest_bit), Some(0));
    }

    #[test]
    fn basic_distance_2() {
        let zero = Key::from_sha256_hash([0; 32]);
        let highest_bit = key_with_byte(0, 0x80);
        assert_eq!(distance_log2(&zero, &highest_bit), Some(255));
    }

    #[test]
    fn basic_distance_3() {
        let zero = Key::from_sha256_hash([0; 32]);
        // Only the most significant differing bit matters, whatever the lower bytes contain.
        let noisy = Key::from_sha256_hash([
            0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 128, 6, 5, 7, 94, 103, 94, 26, 20, 0, 0,
            1, 37, 198, 200, 57, 33, 32,
        ]);
        assert_eq!(distance_log2(&zero, &noisy), Some(255));
    }

    #[test]
    fn distance_of_zero() {
        let first = Key::new(&[1, 2, 3, 4]);
        let second = Key::new(&[1, 2, 3, 4]);
        assert_eq!(distance_log2(&first, &second), None);
    }

    #[test]
    fn nodes_kicked_out() {
        let local_key = vec![0u8; 4];
        let local_key_hash = Sha256::digest(&local_key);

        // Produces random keys whose hash differs from the local key's hash in the most
        // significant bit, so that they all land in the farthest bucket.
        let far_key = || -> Vec<u8> {
            loop {
                let candidate: [u8; 32] = rand::random();
                let candidate_hash = Sha256::digest(candidate);
                if ((local_key_hash[0] ^ candidate_hash[0]) & 0x80) != 0 {
                    return candidate.to_vec();
                }
            }
        };

        let mut buckets = KBuckets::<_, _, _, 4>::new(local_key, Duration::from_secs(1));

        // Fill the farthest bucket with disconnected entries.
        for _ in 0..4 {
            match buckets.entry(&far_key()) {
                Entry::Vacant(vacant) => {
                    vacant
                        .insert((), &Duration::new(0, 0), PeerState::Disconnected)
                        .unwrap();
                }
                _ => panic!(),
            }
        }

        // The bucket is full and its last entry is still protected at t = 0.
        match buckets.entry(&far_key()) {
            Entry::Vacant(vacant) => {
                let outcome = vacant.insert((), &Duration::new(0, 0), PeerState::Disconnected);
                assert!(outcome.is_err());
            }
            _ => panic!(),
        }

        // At t = 2s the 1s protection has expired, so the last entry gets replaced.
        match buckets.entry(&far_key()) {
            Entry::Vacant(vacant) => {
                let outcome = vacant.insert((), &Duration::new(2, 0), PeerState::Disconnected);
                assert!(outcome.is_ok());
            }
            _ => panic!(),
        };
    }
}