ethportal_api/types/distance.rs

1use std::{fmt, ops::Deref};
2
3use alloy::primitives::U256;
4use serde::{Deserialize, Serialize};
5use ssz_derive::{Decode, Encode};
6
7/// Represents a distance between two keys in the DHT key space.
8#[derive(
9    Copy,
10    Clone,
11    PartialEq,
12    Eq,
13    Default,
14    PartialOrd,
15    Ord,
16    Debug,
17    Encode,
18    Decode,
19    Serialize,
20    Deserialize,
21)]
22#[serde(transparent)]
23#[ssz(struct_behaviour = "transparent")]
24pub struct Distance(U256);
25
26impl fmt::Display for Distance {
27    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28        write!(f, "{}", self.0)
29    }
30}
31
32impl Distance {
33    /// The maximum value.
34    pub const MAX: Self = Self(U256::MAX);
35    /// The minimum value.
36    pub const ZERO: Self = Self(U256::ZERO);
37
38    /// Returns the integer base-2 logarithm of `self`.
39    ///
40    /// Returns `None` is `self` is zero, because the logarithm of zero is undefined. Otherwise,
41    /// returns `Some(log2)` where `log2` is in the range [1, 256].
42    pub fn log2(&self) -> Option<usize> {
43        if self.0.is_zero() {
44            None
45        } else {
46            Some(256 - self.0.leading_zeros())
47        }
48    }
49
50    /// Returns the big-endian representation of `self`.
51    pub fn big_endian(&self) -> [u8; 32] {
52        self.0.to_be_bytes()
53    }
54
55    /// Returns the top 4 bytes representation of `self`.
56    pub fn big_endian_u32(&self) -> u32 {
57        let mut be_bytes = [0u8; 4];
58        be_bytes.copy_from_slice(&self.big_endian()[..4]);
59        u32::from_be_bytes(be_bytes)
60    }
61}
62
63impl From<U256> for Distance {
64    fn from(value: U256) -> Self {
65        Distance(value)
66    }
67}
68
69impl Deref for Distance {
70    type Target = U256;
71
72    fn deref(&self) -> &Self::Target {
73        &self.0
74    }
75}
76
77/// Types whose values represent a metric (distance function) that defines a notion of distance
78/// between two elements in the DHT key space.
79pub trait Metric {
80    /// Returns the distance between two elements in the DHT key space.
81    fn distance(x: &[u8; 32], y: &[u8; 32]) -> Distance;
82}
83
84/// The XOR metric defined in the Kademlia paper.
85#[derive(Debug)]
86pub struct XorMetric;
87
88impl Metric for XorMetric {
89    fn distance(x: &[u8; 32], y: &[u8; 32]) -> Distance {
90        let mut z: [u8; 32] = [0; 32];
91        for i in 0..32 {
92            z[i] = x[i] ^ y[i];
93        }
94        Distance(U256::from_be_slice(z.as_slice()))
95    }
96}
97
#[cfg(test)]
mod test {
    use quickcheck::{quickcheck, Arbitrary, Gen, TestResult};
    use test_log::test;

    use super::*;

    /// Wrapper type around a 256-bit identifier in the DHT key space.
    ///
    /// Wraps a `[u8; 32]` because quickcheck does not provide an implementation of Arbitrary for
    /// that type.
    #[derive(Clone, Debug)]
    struct DhtPoint([u8; 32]);

    impl Arbitrary for DhtPoint {
        fn arbitrary(g: &mut Gen) -> Self {
            // Each of the 32 bytes is drawn independently from the generator.
            Self(std::array::from_fn(|_| u8::arbitrary(g)))
        }
    }

    // For all non-zero x, 2^(log2 - 1) <= x < 2^log2; log2 is undefined for zero.
    #[test]
    fn distance_log2() {
        fn prop(x: DhtPoint) -> TestResult {
            let x = U256::from_be_slice(&x.0);
            let distance = Distance(x);
            let log2_distance = distance.log2();

            match log2_distance {
                Some(log2) => {
                    let x_floor = U256::from(1u8) << (log2 - 1);

                    if log2 == 256 {
                        TestResult::from_bool(x >= x_floor)
                    } else {
                        let x_ceil = U256::from(1u8) << log2;
                        TestResult::from_bool(x >= x_floor && x < x_ceil)
                    }
                }
                None => TestResult::from_bool(distance.0.is_zero()),
            }
        }
        quickcheck(prop as fn(DhtPoint) -> TestResult);

        // 256 (2^8).
        let point = DhtPoint(U256::from(256).to_be_bytes());
        assert!(!prop(point).is_failure());

        // 255 (2^8 - 1).
        let point = DhtPoint(U256::from(255).to_be_bytes());
        assert!(!prop(point).is_failure());

        // 257 (2^8 + 1).
        let point = DhtPoint(U256::from(257).to_be_bytes());
        assert!(!prop(point).is_failure());

        // Boundaries: zero has no logarithm, one is the smallest value with one, and MAX has
        // all 256 bits set.
        assert_eq!(Distance::ZERO.log2(), None);
        assert_eq!(Distance::from(U256::from(1)).log2(), Some(1));
        assert_eq!(Distance::MAX.log2(), Some(256));
    }

    #[test]
    fn distance_big_endian() {
        fn prop(x: DhtPoint) -> TestResult {
            let x_be_u256 = U256::from_be_slice(&x.0);
            let distance = Distance(x_be_u256);
            let distance_be = distance.big_endian();
            TestResult::from_bool(distance_be == x.0)
        }
        quickcheck(prop as fn(DhtPoint) -> TestResult);
    }

    // For all x, big_endian_u32(x) is the first four big-endian bytes of x.
    #[test]
    fn distance_big_endian_u32() {
        fn prop(x: DhtPoint) -> TestResult {
            let distance = Distance(U256::from_be_slice(&x.0));
            let expected = u32::from_be_bytes([x.0[0], x.0[1], x.0[2], x.0[3]]);
            TestResult::from_bool(distance.big_endian_u32() == expected)
        }
        quickcheck(prop as fn(DhtPoint) -> TestResult);
    }

    // For all x, distance(x, x) = 0.
    #[test]
    fn xor_identity() {
        fn prop(x: DhtPoint) -> TestResult {
            let distance = XorMetric::distance(&x.0, &x.0);
            TestResult::from_bool(distance.is_zero())
        }
        quickcheck(prop as fn(DhtPoint) -> TestResult);
    }

    // For all x, y, distance(x, y) = distance(y, x).
    #[test]
    fn xor_symmetry() {
        fn prop(x: DhtPoint, y: DhtPoint) -> TestResult {
            let distance_xy = XorMetric::distance(&x.0, &y.0);
            let distance_yx = XorMetric::distance(&y.0, &x.0);
            TestResult::from_bool(distance_xy == distance_yx)
        }
        quickcheck(prop as fn(DhtPoint, DhtPoint) -> TestResult);
    }

    // For all x, y, z, distance(x, y) + distance(y, z) >= distance(x, z).
    #[test]
    fn xor_triangle_inequality() {
        fn prop(x: DhtPoint, y: DhtPoint, z: DhtPoint) -> TestResult {
            let distance_xy = XorMetric::distance(&x.0, &y.0);
            let distance_yz = XorMetric::distance(&y.0, &z.0);
            let (xy_plus_yz, overflow) = distance_xy.overflowing_add(*distance_yz);
            if overflow {
                // The sum does not fit in 256 bits, so the inequality cannot be checked here.
                TestResult::discard()
            } else {
                let distance_xz = XorMetric::distance(&x.0, &z.0);
                TestResult::from_bool(xy_plus_yz >= *distance_xz)
            }
        }
        quickcheck(prop as fn(DhtPoint, DhtPoint, DhtPoint) -> TestResult);
    }
}