ethportal_api/types/
distance.rs1use std::{fmt, ops::Deref};
2
3use alloy::primitives::U256;
4use serde::{Deserialize, Serialize};
5use ssz_derive::{Decode, Encode};
6
7#[derive(
9 Copy,
10 Clone,
11 PartialEq,
12 Eq,
13 Default,
14 PartialOrd,
15 Ord,
16 Debug,
17 Encode,
18 Decode,
19 Serialize,
20 Deserialize,
21)]
22#[serde(transparent)]
23#[ssz(struct_behaviour = "transparent")]
24pub struct Distance(U256);
25
26impl fmt::Display for Distance {
27 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28 write!(f, "{}", self.0)
29 }
30}
31
32impl Distance {
33 pub const MAX: Self = Self(U256::MAX);
35 pub const ZERO: Self = Self(U256::ZERO);
37
38 pub fn log2(&self) -> Option<usize> {
43 if self.0.is_zero() {
44 None
45 } else {
46 Some(256 - self.0.leading_zeros())
47 }
48 }
49
50 pub fn big_endian(&self) -> [u8; 32] {
52 self.0.to_be_bytes()
53 }
54
55 pub fn big_endian_u32(&self) -> u32 {
57 let mut be_bytes = [0u8; 4];
58 be_bytes.copy_from_slice(&self.big_endian()[..4]);
59 u32::from_be_bytes(be_bytes)
60 }
61}
62
63impl From<U256> for Distance {
64 fn from(value: U256) -> Self {
65 Distance(value)
66 }
67}
68
69impl Deref for Distance {
70 type Target = U256;
71
72 fn deref(&self) -> &Self::Target {
73 &self.0
74 }
75}
76
77pub trait Metric {
80 fn distance(x: &[u8; 32], y: &[u8; 32]) -> Distance;
82}
83
/// The XOR metric.
///
/// The distance between two points is their bytewise XOR, interpreted as a
/// big-endian unsigned 256-bit integer.
#[derive(Debug)]
pub struct XorMetric;
87
88impl Metric for XorMetric {
89 fn distance(x: &[u8; 32], y: &[u8; 32]) -> Distance {
90 let mut z: [u8; 32] = [0; 32];
91 for i in 0..32 {
92 z[i] = x[i] ^ y[i];
93 }
94 Distance(U256::from_be_slice(z.as_slice()))
95 }
96}
97
#[cfg(test)]
mod test {
    use quickcheck::{quickcheck, Arbitrary, Gen, TestResult};
    use test_log::test;

    use super::*;

    /// A random 32-byte point used as quickcheck input.
    #[derive(Clone, Debug)]
    struct DhtPoint([u8; 32]);

    impl Arbitrary for DhtPoint {
        fn arbitrary(g: &mut Gen) -> Self {
            let mut value = [0; 32];
            for byte in value.iter_mut() {
                *byte = u8::arbitrary(g);
            }
            Self(value)
        }
    }

    #[test]
    fn distance_log2() {
        /// `log2` must return the bit length `n` with `2^(n-1) <= x < 2^n`,
        /// or `None` exactly when `x` is zero.
        fn prop(x: DhtPoint) -> TestResult {
            let x = U256::from_be_slice(&x.0);
            let distance = Distance(x);
            let log2_distance = distance.log2();

            match log2_distance {
                Some(log2) => {
                    let x_floor = U256::from(1u8) << (log2 - 1);

                    if log2 == 256 {
                        TestResult::from_bool(x >= x_floor)
                    } else {
                        let x_ceil = U256::from(1u8) << log2;
                        TestResult::from_bool(x >= x_floor && x < x_ceil)
                    }
                }
                None => TestResult::from_bool(distance.0.is_zero()),
            }
        }
        quickcheck(prop as fn(DhtPoint) -> TestResult);

        // Boundaries around a power of two.
        let point = DhtPoint(U256::from(256).to_be_bytes());
        assert!(!prop(point).is_failure());

        let point = DhtPoint(U256::from(255).to_be_bytes());
        assert!(!prop(point).is_failure());

        let point = DhtPoint(U256::from(257).to_be_bytes());
        assert!(!prop(point).is_failure());

        // Random 32-byte inputs are essentially never zero, so the `None` branch
        // and the extremes must be pinned explicitly.
        let point = DhtPoint([0u8; 32]);
        assert!(!prop(point).is_failure());
        assert_eq!(Distance::ZERO.log2(), None);
        assert_eq!(Distance(U256::from(1u8)).log2(), Some(1));
        assert_eq!(Distance::MAX.log2(), Some(256));
    }

    #[test]
    fn distance_big_endian() {
        fn prop(x: DhtPoint) -> TestResult {
            let x_be_u256 = U256::from_be_slice(&x.0);
            let distance = Distance(x_be_u256);
            let distance_be = distance.big_endian();
            TestResult::from_bool(distance_be == x.0)
        }
        quickcheck(prop as fn(DhtPoint) -> TestResult);
    }

    #[test]
    fn distance_big_endian_u32() {
        /// `big_endian_u32` must equal the first four big-endian bytes as a `u32`.
        fn prop(x: DhtPoint) -> TestResult {
            let distance = Distance(U256::from_be_slice(&x.0));
            let expected = u32::from_be_bytes([x.0[0], x.0[1], x.0[2], x.0[3]]);
            TestResult::from_bool(distance.big_endian_u32() == expected)
        }
        quickcheck(prop as fn(DhtPoint) -> TestResult);

        assert_eq!(Distance::ZERO.big_endian_u32(), 0);
        assert_eq!(Distance::MAX.big_endian_u32(), u32::MAX);
    }

    #[test]
    fn xor_identity() {
        fn prop(x: DhtPoint) -> TestResult {
            let distance = XorMetric::distance(&x.0, &x.0);
            TestResult::from_bool(distance.is_zero())
        }
        quickcheck(prop as fn(DhtPoint) -> TestResult);
    }

    #[test]
    fn xor_symmetry() {
        fn prop(x: DhtPoint, y: DhtPoint) -> TestResult {
            let distance_xy = XorMetric::distance(&x.0, &y.0);
            let distance_yx = XorMetric::distance(&y.0, &x.0);
            TestResult::from_bool(distance_xy == distance_yx)
        }
        quickcheck(prop as fn(DhtPoint, DhtPoint) -> TestResult);
    }

    #[test]
    fn xor_triangle_inequality() {
        fn prop(x: DhtPoint, y: DhtPoint, z: DhtPoint) -> TestResult {
            let distance_xy = XorMetric::distance(&x.0, &y.0);
            let distance_yz = XorMetric::distance(&y.0, &z.0);
            let (xy_plus_yz, overflow) = distance_xy.overflowing_add(*distance_yz);
            if overflow {
                TestResult::discard()
            } else {
                let distance_xz = XorMetric::distance(&x.0, &z.0);
                TestResult::from_bool(xy_plus_yz >= *distance_xz)
            }
        }
        quickcheck(prop as fn(DhtPoint, DhtPoint, DhtPoint) -> TestResult);
    }

    #[test]
    fn xor_known_values() {
        let zero = [0u8; 32];
        let ones = [0xffu8; 32];
        assert_eq!(XorMetric::distance(&zero, &ones), Distance::MAX);

        // The last byte is the least significant one (big-endian interpretation).
        let mut one = [0u8; 32];
        one[31] = 1;
        assert_eq!(
            XorMetric::distance(&zero, &one),
            Distance::from(U256::from(1u8))
        );
    }
}