use std::{
collections::VecDeque,
fmt::Display,
num::NonZeroUsize,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
};
use futures::{lock::Mutex, StreamExt};
use generic_array::GenericArray;
use rasi::task::spawn_ok;
use xstack::{events, identity::PeerId, EventSource, ProtocolStream, Switch};
use crate::Result;
// 256-bit unsigned integer used for Kademlia keys and XOR distances, generated by the
// `uint` crate's `construct_uint!` macro from four 64-bit words.
mod uint {
    use uint::construct_uint;
    construct_uint! {
        pub(super) struct U256(4);
    }
}
/// A single k-bucket: at most `K` peers, ordered from least recently seen (front) to
/// most recently seen (back).
#[derive(Default)]
struct KBucket<const K: usize>(VecDeque<PeerId>);

impl<const K: usize> KBucket<K> {
    /// Creates a bucket that holds only `peer_id`.
    fn new(peer_id: PeerId) -> Self {
        let mut queue = VecDeque::with_capacity(K);
        queue.push_back(peer_id);
        Self(queue)
    }

    /// Tries to record `peer_id` as the most recently seen peer of this bucket.
    ///
    /// * Already present: the peer is moved to the back and `Some(peer_id)` is returned.
    /// * Bucket full (`K` peers): the least recently seen peer is removed and returned.
    ///   `peer_id` is **not** inserted; the caller decides which of the two to re-insert.
    /// * Otherwise: `peer_id` is appended and `None` is returned.
    fn try_insert(&mut self, peer_id: PeerId) -> Option<PeerId> {
        if let Some(index) = self.0.iter().position(|item| *item == peer_id) {
            self.0.remove(index);
            self.0.push_back(peer_id);
            Some(peer_id)
        } else if self.0.len() == K {
            self.0.pop_front()
        } else {
            self.0.push_back(peer_id);
            None
        }
    }

    /// Number of peers currently stored in this bucket.
    fn len(&self) -> usize {
        self.0.len()
    }
}
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct KBucketDistance(uint::U256);
impl KBucketDistance {
pub fn k_index(&self) -> Option<u32> {
(256 - self.0.leading_zeros()).checked_sub(1)
}
}
impl Display for KBucketDistance {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "k_distance({:#066x})", self.0)
}
}
/// A 256-bit position in the Kademlia keyspace.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct KBucketKey(uint::U256);

impl KBucketKey {
    /// XOR distance between `self` and `rhs`: symmetric, and zero exactly when the keys
    /// are equal.
    pub fn distance(&self, rhs: &KBucketKey) -> KBucketDistance {
        let (Self(lhs), Self(rhs)) = (self, rhs);
        KBucketDistance(*lhs ^ *rhs)
    }
}
/// Owned bytes hash exactly like the borrowed slice form.
impl From<Vec<u8>> for KBucketKey {
    fn from(value: Vec<u8>) -> Self {
        Self::from(&value[..])
    }
}

/// Borrowed byte vectors hash exactly like the slice form.
impl From<&Vec<u8>> for KBucketKey {
    fn from(value: &Vec<u8>) -> Self {
        Self::from(&value[..])
    }
}

/// Every key is the SHA-256 digest of the input bytes, converted into a `U256`.
impl From<&[u8]> for KBucketKey {
    fn from(value: &[u8]) -> Self {
        use sha2::Digest;
        let digest: [u8; 32] = sha2::Sha256::digest(value).into();
        Self(digest.into())
    }
}

/// Peer ids are keyed by the hash of their byte encoding.
impl From<PeerId> for KBucketKey {
    fn from(value: PeerId) -> Self {
        (&value).into()
    }
}

/// Hashes the result of `PeerId::to_bytes`.
impl From<&PeerId> for KBucketKey {
    fn from(value: &PeerId) -> Self {
        let bytes = value.to_bytes();
        Self::from(bytes.as_slice())
    }
}
#[allow(unused)]
struct RawKBucketTable<const K: usize> {
local_key: KBucketKey,
k_buckets: Vec<KBucket<K>>,
k_index: GenericArray<Option<usize>, generic_array::typenum::U256>,
}
impl<const K: usize> RawKBucketTable<K> {
fn try_insert(&mut self, peer_id: PeerId) -> Option<PeerId> {
let k_index = KBucketKey::from(peer_id)
.distance(&self.local_key)
.k_index()
.expect("Insert local peer's id") as usize;
if let Some(index) = self.k_index[k_index] {
let r = self.k_buckets[index].try_insert(peer_id);
if r.is_none() {
log::trace!("peer_id={}, insert k-bucket({})", peer_id, k_index);
}
r
} else {
log::trace!("peer_id={}, insert k-bucket({})", peer_id, k_index,);
self.k_buckets.push(KBucket::new(peer_id));
self.k_index[k_index] = Some(self.k_buckets.len() - 1);
None
}
}
fn closest(&mut self, key: KBucketKey) -> Result<Vec<PeerId>> {
let k_index = key.distance(&self.local_key).k_index().unwrap_or(1) as usize;
let mut peers = vec![];
if let Some(bucket) = self.bucket(k_index) {
peers = bucket.0.iter().rev().cloned().collect();
}
let mut step = 1usize;
while peers.len() < K {
if k_index >= step {
if let Some(bucket) = self.bucket(k_index - step) {
if bucket.len() + peers.len() > K {
let offset = bucket.len() + peers.len() - K;
for peer_id in bucket.0.iter().collect::<Vec<_>>()[offset..].iter().rev() {
peers.push(**peer_id);
}
break;
}
for peer_id in bucket.0.iter().rev() {
peers.push(*peer_id);
}
}
} else if k_index + step >= self.k_index.len() {
break;
}
if k_index + step < self.k_index.len() {
if let Some(bucket) = self.bucket(k_index + step) {
if bucket.len() + peers.len() > K {
let offset = bucket.len() + peers.len() - K;
for peer_id in bucket.0.iter().collect::<Vec<_>>()[offset..].iter().rev() {
peers.push(**peer_id);
}
break;
}
for peer_id in bucket.0.iter().rev() {
peers.push(*peer_id);
}
}
} else if k_index < step {
break;
}
step += 1;
}
Ok(peers)
}
fn bucket(&self, index: usize) -> Option<&KBucket<K>> {
self.k_index[index].map(|index| &self.k_buckets[index])
}
}
/// A Kademlia routing table; clones share the same underlying state.
#[derive(Clone)]
pub struct KBucketTable<const K: usize> {
    /// Switch used to ping peers when a bucket is full.
    switch: Switch,
    /// Number of peers currently stored; maintained by `insert_prv`.
    len: Arc<AtomicUsize>,
    /// The bucket storage itself, behind an async mutex.
    table: Arc<Mutex<RawKBucketTable<K>>>,
}

impl<const K: usize> KBucketTable<K> {
    /// Inserts `peer_id` into the raw table and keeps `len` in sync with the outcome.
    ///
    /// Returns exactly what `RawKBucketTable::try_insert` returned. The counter is
    /// adjusted while the table lock is still held.
    async fn insert_prv(&self, peer_id: PeerId) -> Option<PeerId> {
        let mut table = self.table.lock().await;
        let outcome = table.try_insert(peer_id);
        match outcome {
            // A new peer was appended.
            None => {
                self.len.fetch_add(1, Ordering::Relaxed);
            }
            // A different peer was removed from a full bucket; `peer_id` was not added.
            Some(evicted) if evicted != peer_id => {
                self.len.fetch_sub(1, Ordering::Relaxed);
            }
            // `peer_id` itself came back: nothing was added or removed.
            Some(_) => {}
        }
        outcome
    }
}
impl<const K: usize> KBucketTable<K> {
    /// Creates a routing table for `switch` and keeps it fed from the switch's events.
    ///
    /// A background task inserts the `PeerId` carried by every `HandshakeSuccess` event.
    ///
    /// # Panics
    ///
    /// Panics if `K == 0`.
    pub async fn bind(switch: &Switch) -> Self {
        assert!(K > 0, "the k must greater than zero");
        let mut event_connected = EventSource::<events::HandshakeSuccess>::bind_with(
            &switch,
            NonZeroUsize::new(100).unwrap(),
        )
        .await;
        let table = Self {
            len: Default::default(),
            table: Arc::new(Mutex::new(RawKBucketTable {
                local_key: switch.local_id().into(),
                k_buckets: Default::default(),
                k_index: Default::default(),
            })),
            switch: switch.clone(),
        };
        let table_cloned = table.clone();
        spawn_ok(async move {
            while let Some((_, peer_id)) = event_connected.next().await {
                table_cloned.insert(peer_id).await;
            }
        });
        table
    }

    /// The bucket capacity `K`.
    pub fn k_const(&self) -> usize {
        K
    }

    /// Number of peers in the table, read without taking the table lock.
    pub fn len(&self) -> usize {
        self.len.load(Ordering::Acquire)
    }

    /// Records that `peer_id` was seen.
    ///
    /// If its bucket is full, the least recently seen peer (LRU) is removed and pinged in a
    /// background task. This follows the Kademlia replacement policy:
    /// * a responsive LRU is put back and `peer_id` is dropped;
    /// * an unresponsive LRU is replaced by `peer_id`.
    ///
    /// Re-inserting a peer that is already present just refreshes it and pings nobody.
    pub async fn insert(&self, peer_id: PeerId) {
        match self.insert_prv(peer_id).await {
            // Appended to a bucket with free room.
            None => {}
            // Refreshed (or ignored): no eviction happened, so there is nothing to decide.
            Some(returned) if returned == peer_id => {}
            Some(lru) => {
                let this = self.clone();
                spawn_ok(async move {
                    if let Err(err) = ProtocolStream::ping_with(&this.switch, &lru).await {
                        log::trace!("ping lru node, {}", err);
                        // The LRU is unreachable: the newly seen peer takes its place.
                        this.insert_prv(peer_id).await;
                    } else {
                        // The LRU is alive: keep it (long-lived peers are preferred).
                        this.insert_prv(lru).await;
                    }
                });
            }
        }
    }

    /// Returns up to `K` known peers near `key`, as computed by `RawKBucketTable::closest`.
    pub async fn closest<Q>(&self, key: Q) -> Result<Vec<PeerId>>
    where
        Q: Into<KBucketKey>,
    {
        let key: KBucketKey = key.into();
        self.table.lock().await.closest(key)
    }
}
// Tests for the key/distance arithmetic, plus an end-to-end smoke test of `KBucketTable`.
#[cfg(test)]
mod tests {
    use std::sync::Once;
    use super::{uint::U256, *};
    use quickcheck::*;
    use rasi_mio::{net::register_mio_network, timer::register_mio_timer};
    use xstack_tcp::TcpTransport;

    // quickcheck generator: the key of a freshly generated random peer id
    // (the `Gen` source is not used).
    impl Arbitrary for KBucketKey {
        fn arbitrary(_: &mut Gen) -> KBucketKey {
            KBucketKey::from(PeerId::random())
        }
    }

    // The XOR distance must not depend on argument order.
    #[test]
    fn distance_symmetry() {
        fn prop(a: KBucketKey, b: KBucketKey) -> bool {
            a.distance(&b) == b.distance(&a)
        }
        quickcheck(prop as fn(_, _) -> _)
    }

    // A default 256-slot `GenericArray<Option<usize>>` (the type of
    // `RawKBucketTable::k_index`) starts out as `None` and supports indexed writes.
    #[test]
    fn test_generic_array() {
        let mut array = GenericArray::<Option<usize>, generic_array::typenum::U256>::default();
        assert_eq!(array[254], None);
        array[254] = Some(1);
        assert_eq!(array[254], Some(1));
    }

    // `k_index` is the position of the highest set bit; a zero distance has no bucket.
    #[test]
    fn k_distance_0() {
        assert_eq!(KBucketDistance(U256::from(0)).k_index(), None);
        assert_eq!(KBucketDistance(U256::from(1)).k_index(), Some(0));
        assert_eq!(KBucketDistance(U256::from(2)).k_index(), Some(1));
        assert_eq!(KBucketDistance(U256::from(3)).k_index(), Some(1));
        assert_eq!(KBucketDistance(U256::from(4)).k_index(), Some(2));
        assert_eq!(KBucketDistance(U256::from(5)).k_index(), Some(2));
        assert_eq!(KBucketDistance(U256::from(6)).k_index(), Some(2));
        assert_eq!(KBucketDistance(U256::from(7)).k_index(), Some(2));
    }

    // A key is at distance zero from itself.
    #[test]
    fn distance_self() {
        let key = KBucketKey::from(PeerId::random());
        let distance = key.distance(&key);
        assert_eq!(distance.0, U256::from(0));
    }

    // Registers the mio network/timer backends once per process, then creates a switch
    // with a TCP transport.
    async fn init() -> Switch {
        static INIT: Once = Once::new();
        INIT.call_once(|| {
            register_mio_network();
            register_mio_timer();
        });
        Switch::new("kad-test")
            .transport(TcpTransport::default())
            .create()
            .await
            .unwrap()
    }

    // Smoke test: `len()` counts inserted peers and `closest` returns min(len, K) peers.
    #[futures_test::test]
    async fn test_table() {
        let switch = init().await;
        let k_bucket_table = KBucketTable::<20>::bind(&switch).await;
        k_bucket_table.insert(PeerId::random()).await;
        assert_eq!(k_bucket_table.len(), 1);
        let closest = k_bucket_table.closest(PeerId::random()).await.unwrap();
        assert_eq!(closest.len(), 1);
        // NOTE(review): the exact count of 20 below assumes no bucket fills up during these
        // inserts (a full bucket would evict a peer). Random ids spread roughly half into the
        // top bucket, so this is very unlikely but not impossible.
        for _ in 1..20 {
            k_bucket_table.insert(PeerId::random()).await;
        }
        assert_eq!(k_bucket_table.len(), 20);
        let closest = k_bucket_table.closest(PeerId::random()).await.unwrap();
        assert_eq!(closest.len(), 20);
        // Insert until the table holds more than K peers. An insert into a full bucket
        // removes its least recently seen peer and defers the final decision to a background
        // ping task, so `len()` can temporarily drop while this loop runs.
        loop {
            k_bucket_table.insert(PeerId::random()).await;
            if k_bucket_table.len() > 20 {
                break;
            }
        }
        let closest = k_bucket_table.closest(PeerId::random()).await.unwrap();
        assert_eq!(closest.len(), 20);
    }
}