1mod bucket;
70mod entry;
71#[allow(clippy::ptr_offset_with_cast)]
72#[allow(clippy::assign_op_pattern)]
73mod key;
74
75use std::{collections::VecDeque, num::NonZeroUsize, time::Duration};
76
77use bucket::KBucket;
78pub use bucket::NodeStatus;
79pub use entry::*;
80use web_time::Instant;
81
/// Number of buckets in the routing table: one per bit of a 256-bit distance.
const NUM_BUCKETS: usize = 1 << 8;
84
/// Configuration for a `KBucketsTable` and the `KBucket`s it creates.
///
/// Construct with `KBucketConfig::default()` and adjust via the setters.
#[derive(Debug, Clone, Copy)]
pub(crate) struct KBucketConfig {
    /// Number of entries per bucket. Defaults to `K_VALUE`.
    /// NOTE(review): presumably the bucket capacity; it is enforced inside `KBucket`,
    /// which is not visible here — confirm there.
    bucket_size: usize,
    /// Timeout applied to pending entries by `KBucket`. Defaults to 60 seconds.
    /// NOTE(review): presumably how long a pending node waits before it can be applied —
    /// confirm in `KBucket`.
    pending_timeout: Duration,
}
95
96impl Default for KBucketConfig {
97 fn default() -> Self {
98 KBucketConfig {
99 bucket_size: K_VALUE.get(),
100 pending_timeout: Duration::from_secs(60),
101 }
102 }
103}
104
105impl KBucketConfig {
106 pub(crate) fn set_bucket_size(&mut self, bucket_size: NonZeroUsize) {
108 self.bucket_size = bucket_size.get();
109 }
110
111 pub(crate) fn set_pending_timeout(&mut self, pending_timeout: Duration) {
115 self.pending_timeout = pending_timeout;
116 }
117}
118
/// A routing table of `NUM_BUCKETS` k-buckets organised around a local key.
///
/// A key is stored in the bucket selected by its distance to `local_key`
/// (see `BucketIndex::new`).
#[derive(Debug, Clone)]
pub(crate) struct KBucketsTable<TKey, TVal> {
    /// The key of the local node; every bucket index is relative to it.
    local_key: TKey,
    /// The buckets, indexed by `BucketIndex::get()`; `new` creates `NUM_BUCKETS` of them.
    buckets: Vec<KBucket<TKey, TVal>>,
    /// Copy of `KBucketConfig::bucket_size`, used to size per-bucket results in
    /// `closest_keys` and `closest`.
    bucket_size: usize,
    /// FIFO queue of pending nodes that buckets applied while being accessed through
    /// the table; drained by `take_applied_pending`.
    applied_pending: VecDeque<AppliedPending<TKey, TVal>>,
}
132
/// Index of a bucket in `KBucketsTable::buckets`.
///
/// Bucket `i` covers the distances `d` with `d.ilog2() == i`, i.e. the inclusive range
/// `[2^i, 2^(i+1) - 1]` (see `BucketIndex::range`).
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
struct BucketIndex(usize);
137
138impl BucketIndex {
139 fn new(d: &Distance) -> Option<BucketIndex> {
147 d.ilog2().map(|i| BucketIndex(i as usize))
148 }
149
150 fn get(&self) -> usize {
152 self.0
153 }
154
155 fn range(&self) -> (Distance, Distance) {
158 let min = Distance(U256::pow(U256::from(2), U256::from(self.0)));
159 if self.0 == usize::from(u8::MAX) {
160 (min, Distance(U256::MAX))
161 } else {
162 let max = Distance(U256::pow(U256::from(2), U256::from(self.0 + 1)) - 1);
163 (min, max)
164 }
165 }
166
167 fn rand_distance(&self, rng: &mut impl rand::Rng) -> Distance {
169 let mut bytes = [0u8; 32];
170 let quot = self.0 / 8;
171 for i in 0..quot {
172 bytes[31 - i] = rng.gen();
173 }
174 let rem = (self.0 % 8) as u32;
175 let lower = usize::pow(2, rem);
176 let upper = usize::pow(2, rem + 1);
177 bytes[31 - quot] = rng.gen_range(lower..upper) as u8;
178 Distance(U256::from(bytes))
179 }
180}
181
182impl<TKey, TVal> KBucketsTable<TKey, TVal>
183where
184 TKey: Clone + AsRef<KeyBytes>,
185 TVal: Clone,
186{
187 pub(crate) fn new(local_key: TKey, config: KBucketConfig) -> Self {
190 KBucketsTable {
191 local_key,
192 buckets: (0..NUM_BUCKETS).map(|_| KBucket::new(config)).collect(),
193 bucket_size: config.bucket_size,
194 applied_pending: VecDeque::new(),
195 }
196 }
197
198 pub(crate) fn local_key(&self) -> &TKey {
200 &self.local_key
201 }
202
203 pub(crate) fn entry<'a>(&'a mut self, key: &'a TKey) -> Option<Entry<'a, TKey, TVal>> {
208 let index = BucketIndex::new(&self.local_key.as_ref().distance(key))?;
209
210 let bucket = &mut self.buckets[index.get()];
211 if let Some(applied) = bucket.apply_pending() {
212 self.applied_pending.push_back(applied)
213 }
214 Some(Entry::new(bucket, key))
215 }
216
217 pub(crate) fn iter(&mut self) -> impl Iterator<Item = KBucketRef<'_, TKey, TVal>> + '_ {
222 let applied_pending = &mut self.applied_pending;
223 self.buckets.iter_mut().enumerate().map(move |(i, b)| {
224 if let Some(applied) = b.apply_pending() {
225 applied_pending.push_back(applied)
226 }
227 KBucketRef {
228 index: BucketIndex(i),
229 bucket: b,
230 }
231 })
232 }
233
234 pub(crate) fn bucket<K>(&mut self, key: &K) -> Option<KBucketRef<'_, TKey, TVal>>
238 where
239 K: AsRef<KeyBytes>,
240 {
241 let d = self.local_key.as_ref().distance(key);
242 if let Some(index) = BucketIndex::new(&d) {
243 let bucket = &mut self.buckets[index.0];
244 if let Some(applied) = bucket.apply_pending() {
245 self.applied_pending.push_back(applied)
246 }
247 Some(KBucketRef { bucket, index })
248 } else {
249 None
250 }
251 }
252
253 pub(crate) fn take_applied_pending(&mut self) -> Option<AppliedPending<TKey, TVal>> {
266 self.applied_pending.pop_front()
267 }
268
269 pub(crate) fn closest_keys<'a, T>(
272 &'a mut self,
273 target: &'a T,
274 ) -> impl Iterator<Item = TKey> + 'a
275 where
276 T: AsRef<KeyBytes>,
277 {
278 let distance = self.local_key.as_ref().distance(target);
279 let bucket_size = self.bucket_size;
280 ClosestIter {
281 target,
282 iter: None,
283 table: self,
284 buckets_iter: ClosestBucketsIter::new(distance),
285 fmap: move |b: &KBucket<TKey, _>| -> Vec<_> {
286 let mut vec = Vec::with_capacity(bucket_size);
287 vec.extend(b.iter().map(|(n, _)| n.key.clone()));
288 vec
289 },
290 }
291 }
292
293 pub(crate) fn closest<'a, T>(
296 &'a mut self,
297 target: &'a T,
298 ) -> impl Iterator<Item = EntryView<TKey, TVal>> + 'a
299 where
300 T: Clone + AsRef<KeyBytes>,
301 TVal: Clone,
302 {
303 let distance = self.local_key.as_ref().distance(target);
304 let bucket_size = self.bucket_size;
305 ClosestIter {
306 target,
307 iter: None,
308 table: self,
309 buckets_iter: ClosestBucketsIter::new(distance),
310 fmap: move |b: &KBucket<_, TVal>| -> Vec<_> {
311 b.iter()
312 .take(bucket_size)
313 .map(|(n, status)| EntryView {
314 node: n.clone(),
315 status,
316 })
317 .collect()
318 },
319 }
320 }
321
322 pub(crate) fn count_nodes_between<T>(&mut self, target: &T) -> usize
328 where
329 T: AsRef<KeyBytes>,
330 {
331 let local_key = self.local_key.clone();
332 let distance = target.as_ref().distance(&local_key);
333 let mut iter = ClosestBucketsIter::new(distance).take_while(|i| i.get() != 0);
334 if let Some(i) = iter.next() {
335 let num_first = self.buckets[i.get()]
336 .iter()
337 .filter(|(n, _)| n.key.as_ref().distance(&local_key) <= distance)
338 .count();
339 let num_rest: usize = iter.map(|i| self.buckets[i.get()].num_entries()).sum();
340 num_first + num_rest
341 } else {
342 0
343 }
344 }
345}
346
/// Iterator over items derived from the buckets of a `KBucketsTable`, ordered by
/// distance to a target.
///
/// Buckets are visited in the order produced by `buckets_iter`. Each visited bucket is
/// turned into a batch by `fmap`, sorted by distance to `target`, and drained before the
/// next bucket is visited.
struct ClosestIter<'a, TTarget, TKey, TVal, TMap, TOut> {
    /// The target that distances are measured against when sorting a batch.
    target: &'a TTarget,
    /// The table whose buckets are visited. It is mutable because visiting a bucket may
    /// apply its pending node and queue the result in `applied_pending`.
    table: &'a mut KBucketsTable<TKey, TVal>,
    /// Yields the indices of the buckets to visit, in order.
    buckets_iter: ClosestBucketsIter,
    /// The sorted batch of the bucket currently being drained, if any.
    iter: Option<std::vec::IntoIter<TOut>>,
    /// Maps a bucket to the (not yet sorted) items it contributes.
    fmap: TMap,
}
366
/// Iterator over bucket indices, in the order in which `ClosestIter` visits the buckets
/// for a given distance.
struct ClosestBucketsIter {
    /// The distance between the local key and the target. Its bits decide the visiting order.
    distance: Distance,
    /// The current state of the iteration.
    state: ClosestBucketsIterState,
}
376
/// Operating states of a `ClosestBucketsIter`.
enum ClosestBucketsIterState {
    /// Initial state. Holds the index of the highest set bit of the distance, or
    /// `BucketIndex(0)` if `BucketIndex::new` yields none. That index is yielded first,
    /// after which the iterator moves to `ZoomIn` with the same index.
    Start(BucketIndex),
    /// "Zooming in": the next index is the nearest lower index whose bit in the distance
    /// is set (`ClosestBucketsIter::next_in`). When there is none, the iterator moves on
    /// to the zoom-out phase.
    ZoomIn(BucketIndex),
    /// "Zooming out": the next index is the nearest higher index whose bit in the
    /// distance is clear (`ClosestBucketsIter::next_out`). When there is none, the
    /// iterator moves to `Done`.
    ZoomOut(BucketIndex),
    /// The iteration is finished; `next` keeps returning `None`.
    Done,
}
397
398impl ClosestBucketsIter {
399 fn new(distance: Distance) -> Self {
400 let state = match BucketIndex::new(&distance) {
401 Some(i) => ClosestBucketsIterState::Start(i),
402 None => ClosestBucketsIterState::Start(BucketIndex(0)),
403 };
404 Self { distance, state }
405 }
406
407 fn next_in(&self, i: BucketIndex) -> Option<BucketIndex> {
408 (0..i.get()).rev().find_map(|i| {
409 if self.distance.0.bit(i) {
410 Some(BucketIndex(i))
411 } else {
412 None
413 }
414 })
415 }
416
417 fn next_out(&self, i: BucketIndex) -> Option<BucketIndex> {
418 (i.get() + 1..NUM_BUCKETS).find_map(|i| {
419 if !self.distance.0.bit(i) {
420 Some(BucketIndex(i))
421 } else {
422 None
423 }
424 })
425 }
426}
427
428impl Iterator for ClosestBucketsIter {
429 type Item = BucketIndex;
430
431 fn next(&mut self) -> Option<Self::Item> {
432 match self.state {
433 ClosestBucketsIterState::Start(i) => {
434 self.state = ClosestBucketsIterState::ZoomIn(i);
435 Some(i)
436 }
437 ClosestBucketsIterState::ZoomIn(i) => {
438 if let Some(i) = self.next_in(i) {
439 self.state = ClosestBucketsIterState::ZoomIn(i);
440 Some(i)
441 } else {
442 let i = BucketIndex(0);
443 self.state = ClosestBucketsIterState::ZoomOut(i);
444 Some(i)
445 }
446 }
447 ClosestBucketsIterState::ZoomOut(i) => {
448 if let Some(i) = self.next_out(i) {
449 self.state = ClosestBucketsIterState::ZoomOut(i);
450 Some(i)
451 } else {
452 self.state = ClosestBucketsIterState::Done;
453 None
454 }
455 }
456 ClosestBucketsIterState::Done => None,
457 }
458 }
459}
460
461impl<TTarget, TKey, TVal, TMap, TOut> Iterator for ClosestIter<'_, TTarget, TKey, TVal, TMap, TOut>
462where
463 TTarget: AsRef<KeyBytes>,
464 TKey: Clone + AsRef<KeyBytes>,
465 TVal: Clone,
466 TMap: Fn(&KBucket<TKey, TVal>) -> Vec<TOut>,
467 TOut: AsRef<KeyBytes>,
468{
469 type Item = TOut;
470
471 fn next(&mut self) -> Option<Self::Item> {
472 loop {
473 match &mut self.iter {
474 Some(iter) => match iter.next() {
475 Some(k) => return Some(k),
476 None => self.iter = None,
477 },
478 None => {
479 if let Some(i) = self.buckets_iter.next() {
480 let bucket = &mut self.table.buckets[i.get()];
481 if let Some(applied) = bucket.apply_pending() {
482 self.table.applied_pending.push_back(applied)
483 }
484 let mut v = (self.fmap)(bucket);
485 v.sort_by(|a, b| {
486 self.target
487 .as_ref()
488 .distance(a.as_ref())
489 .cmp(&self.target.as_ref().distance(b.as_ref()))
490 });
491 self.iter = Some(v.into_iter());
492 } else {
493 return None;
494 }
495 }
496 }
497 }
498 }
499}
500
/// A mutable reference to one bucket of a `KBucketsTable`, together with its index.
pub struct KBucketRef<'a, TKey, TVal> {
    /// Index of the bucket within the table; determines the distance range it covers.
    index: BucketIndex,
    /// The referenced bucket.
    bucket: &'a mut KBucket<TKey, TVal>,
}
506
507impl<'a, TKey, TVal> KBucketRef<'a, TKey, TVal>
508where
509 TKey: Clone + AsRef<KeyBytes>,
510 TVal: Clone,
511{
512 pub fn range(&self) -> (Distance, Distance) {
515 self.index.range()
516 }
517
518 pub fn is_empty(&self) -> bool {
520 self.num_entries() == 0
521 }
522
523 pub fn num_entries(&self) -> usize {
525 self.bucket.num_entries()
526 }
527
528 pub fn has_pending(&self) -> bool {
530 self.bucket.pending().is_some_and(|n| !n.is_ready())
531 }
532
533 pub fn contains(&self, d: &Distance) -> bool {
535 BucketIndex::new(d).is_some_and(|i| i == self.index)
536 }
537
538 pub fn rand_distance(&self, rng: &mut impl rand::Rng) -> Distance {
545 self.index.rand_distance(rng)
546 }
547
548 pub fn iter(&'a self) -> impl Iterator<Item = EntryRefView<'a, TKey, TVal>> {
550 self.bucket.iter().map(move |(n, status)| EntryRefView {
551 node: NodeRefView {
552 key: &n.key,
553 value: &n.value,
554 },
555 status,
556 })
557 }
558}
559
#[cfg(test)]
mod tests {
    use libp2p_identity::PeerId;
    use quickcheck::*;

    use super::*;

    type TestTable = KBucketsTable<KeyBytes, ()>;

    impl Arbitrary for TestTable {
        /// Builds a table around a random local key with a random pending timeout (1..360 s).
        ///
        /// Buckets are filled from the highest index down. Each bucket receives a random
        /// number of nodes, capped by `bucket_size` and by what remains of a random total
        /// below 100. Every node's key is derived from a random distance inside that bucket,
        /// and every insert must succeed.
        fn arbitrary(g: &mut Gen) -> TestTable {
            let local_key = Key::from(PeerId::random());
            let timeout = Duration::from_secs(g.gen_range(1..360));
            let mut config = KBucketConfig::default();
            config.set_pending_timeout(timeout);
            let bucket_size = config.bucket_size;
            let mut table = TestTable::new(local_key.into(), config);
            let mut num_total = g.gen_range(0..100);
            for (i, b) in &mut table.buckets.iter_mut().enumerate().rev() {
                let ix = BucketIndex(i);
                let num = g.gen_range(0..usize::min(bucket_size, num_total) + 1);
                num_total -= num;
                for _ in 0..num {
                    let distance = ix.rand_distance(&mut rand::thread_rng());
                    let key = local_key.for_distance(distance);
                    let node = Node { key, value: () };
                    let status = NodeStatus::arbitrary(g);
                    match b.insert(node, status) {
                        InsertResult::Inserted => {}
                        _ => panic!(),
                    }
                }
            }
            table
        }
    }

    /// The bucket ranges, visited in `iter()` order, must be contiguous and together cover
    /// every distance from 1 up to `U256::MAX`.
    #[test]
    fn buckets_are_non_overlapping_and_exhaustive() {
        let local_key = Key::from(PeerId::random());
        let timeout = Duration::from_secs(0);
        let mut config = KBucketConfig::default();
        config.set_pending_timeout(timeout);
        let mut table = KBucketsTable::<KeyBytes, ()>::new(local_key.into(), config);

        let mut prev_max = U256::from(0);

        for bucket in table.iter() {
            let (min, max) = bucket.range();
            // Each range must start right after the previous one ends.
            assert_eq!(Distance(prev_max + U256::from(1)), min);
            prev_max = max.0;
        }

        assert_eq!(U256::MAX, prev_max);
    }

    /// For any bucket index, `contains` holds for both ends of `range()`. It does not hold
    /// for the distances just outside that range, where those exist.
    #[test]
    fn bucket_contains_range() {
        fn prop(ix: u8) {
            let index = BucketIndex(ix as usize);
            let mut config = KBucketConfig::default();
            config.set_pending_timeout(Duration::from_secs(0));
            let mut bucket = KBucket::<Key<PeerId>, ()>::new(config);
            let bucket_ref = KBucketRef {
                index,
                bucket: &mut bucket,
            };

            let (min, max) = bucket_ref.range();

            assert!(min <= max);

            assert!(bucket_ref.contains(&min));
            assert!(bucket_ref.contains(&max));

            if min != Distance(0.into()) {
                // The distance just below `min` belongs to the previous bucket.
                assert!(!bucket_ref.contains(&Distance(min.0 - 1)));
            }

            if max != Distance(U256::MAX) {
                // The distance just above `max` belongs to the next bucket.
                assert!(!bucket_ref.contains(&Distance(max.0 + 1)));
            }
        }

        quickcheck(prop as fn(_));
    }

    /// A random distance for bucket `ix` must lie within `[2^ix, 2^(ix+1) - 1]`.
    ///
    /// NOTE(review): for `ix == 255`, `checked_pow` overflows, so `upper` becomes
    /// `U256::MAX - 1`. That is one below `BucketIndex::range`'s maximum of `U256::MAX`.
    /// A draw of exactly `U256::MAX` would therefore fail, with negligible probability.
    #[test]
    fn rand_distance() {
        fn prop(ix: u8) -> bool {
            let d = BucketIndex(ix as usize).rand_distance(&mut rand::thread_rng());
            let n = U256::from(<[u8; 32]>::from(d.0));
            let b = U256::from(2);
            let e = U256::from(ix);
            let lower = b.pow(e);
            let upper = b.checked_pow(e + U256::from(1)).unwrap_or(U256::MAX) - U256::from(1);
            lower <= n && n <= upper
        }
        quickcheck(prop as fn(_) -> _);
    }

    /// A node inserted through an absent `Entry` is afterwards returned by `closest_keys`.
    #[test]
    fn entry_inserted() {
        let local_key = Key::from(PeerId::random());
        let other_id = Key::from(PeerId::random());

        let mut table = KBucketsTable::<_, ()>::new(local_key, KBucketConfig::default());
        if let Some(Entry::Absent(entry)) = table.entry(&other_id) {
            match entry.insert((), NodeStatus::Connected) {
                InsertResult::Inserted => (),
                _ => panic!(),
            }
        } else {
            panic!()
        }

        let res = table.closest_keys(&other_id).collect::<Vec<_>>();
        assert_eq!(res.len(), 1);
        assert_eq!(res[0], other_id);
    }

    /// The local key itself has no bucket, so `entry` returns `None` for it.
    #[test]
    fn entry_self() {
        let local_key = Key::from(PeerId::random());
        let mut table = KBucketsTable::<_, ()>::new(local_key, KBucketConfig::default());

        assert!(table.entry(&local_key).is_none())
    }

    /// `closest_keys` must return every key in the table, sorted by distance to the target.
    ///
    /// The table first receives 100 successful inserts of random peers; inserts that do
    /// not succeed are skipped. The result is then compared against all keys in the table,
    /// sorted by distance, for 10 random targets.
    #[test]
    fn closest() {
        let local_key = Key::from(PeerId::random());
        let mut table = KBucketsTable::<_, ()>::new(local_key, KBucketConfig::default());
        let mut count = 0;
        loop {
            if count == 100 {
                break;
            }
            let key = Key::from(PeerId::random());
            if let Some(Entry::Absent(e)) = table.entry(&key) {
                match e.insert((), NodeStatus::Connected) {
                    InsertResult::Inserted => count += 1,
                    _ => continue,
                }
            } else {
                panic!("entry exists")
            }
        }

        let mut expected_keys: Vec<_> = table
            .buckets
            .iter()
            .flat_map(|t| t.iter().map(|(n, _)| n.key))
            .collect();

        for _ in 0..10 {
            let target_key = Key::from(PeerId::random());
            let keys = table.closest_keys(&target_key).collect::<Vec<_>>();
            // The expected order is all keys sorted by distance to the target.
            expected_keys.sort_by_key(|k| k.distance(&target_key));
            assert_eq!(keys, expected_keys);
        }
    }

    /// A pending node becomes an applied entry once its ready time has passed.
    ///
    /// Random disconnected nodes are inserted until one insert reports `Full`. The same key
    /// is then inserted as connected, which must yield `Pending` with the disconnected node
    /// it would replace. After the pending node's ready time is moved into the past, the
    /// next `entry` accesses show:
    /// - the new node is present and connected;
    /// - the evicted node is absent;
    /// - `take_applied_pending` returns the expected `AppliedPending` exactly once.
    #[test]
    fn applied_pending() {
        let local_key = Key::from(PeerId::random());
        let mut config = KBucketConfig::default();
        config.set_pending_timeout(Duration::from_millis(1));
        let mut table = KBucketsTable::<_, ()>::new(local_key, config);
        let expected_applied;
        let full_bucket_index;
        loop {
            let key = Key::from(PeerId::random());
            if let Some(Entry::Absent(e)) = table.entry(&key) {
                match e.insert((), NodeStatus::Disconnected) {
                    InsertResult::Full => {
                        if let Some(Entry::Absent(e)) = table.entry(&key) {
                            match e.insert((), NodeStatus::Connected) {
                                InsertResult::Pending { disconnected } => {
                                    expected_applied = AppliedPending {
                                        inserted: Node { key, value: () },
                                        evicted: Some(Node {
                                            key: disconnected,
                                            value: (),
                                        }),
                                    };
                                    full_bucket_index = BucketIndex::new(&key.distance(&local_key));
                                    break;
                                }
                                _ => panic!(),
                            }
                        } else {
                            panic!()
                        }
                    }
                    _ => continue,
                }
            } else {
                panic!("entry exists")
            }
        }

        // Make the pending entry of the full bucket ready by moving its ready time into the past.
        let full_bucket = &mut table.buckets[full_bucket_index.unwrap().get()];
        let elapsed = Instant::now().checked_sub(Duration::from_secs(1)).unwrap();
        full_bucket.pending_mut().unwrap().set_ready_at(elapsed);

        match table.entry(&expected_applied.inserted.key) {
            Some(Entry::Present(_, NodeStatus::Connected)) => {}
            x => panic!("Unexpected entry: {x:?}"),
        }

        match table.entry(&expected_applied.evicted.as_ref().unwrap().key) {
            Some(Entry::Absent(_)) => {}
            x => panic!("Unexpected entry: {x:?}"),
        }

        assert_eq!(Some(expected_applied), table.take_applied_pending());
        assert_eq!(None, table.take_applied_pending());
    }

    /// Checks that `count_nodes_between` is monotonic in the distance.
    ///
    /// For every bucket index `i` yielded by `ClosestBucketsIter`, bit `i` of the distance
    /// to the target is flipped to obtain `d`, and a key is derived via `for_distance(d)`:
    /// - if bit `i` was set, `d` is smaller, so the count for that key must not exceed the
    ///   count for the target;
    /// - otherwise `d` is larger, so its count must be at least the count for the target.
    #[test]
    fn count_nodes_between() {
        fn prop(mut table: TestTable, target: Key<PeerId>) -> bool {
            let num_to_target = table.count_nodes_between(&target);
            let distance = table.local_key.distance(&target);
            let base2 = U256::from(2);
            let mut iter = ClosestBucketsIter::new(distance);
            iter.all(|i| {
                // Flip the bit of the distance that corresponds to bucket `i`.
                let d = Distance(distance.0 ^ (base2.pow(U256::from(i.get()))));
                let k = table.local_key.for_distance(d);
                if distance.0.bit(i.get()) {
                    // Bit `i` was set, so flipping it yields a smaller distance.
                    d < distance && table.count_nodes_between(&k) <= num_to_target
                } else {
                    // Bit `i` was clear, so flipping it yields a larger distance.
                    d > distance && table.count_nodes_between(&k) >= num_to_target
                }
            })
        }

        QuickCheck::new()
            .tests(10)
            .quickcheck(prop as fn(_, _) -> _)
    }
}