1use crc::{CRC_32_ISCSI, Crc};
3use rand::Rng;
4use serde::{Deserialize, Serialize};
5use std::convert::TryInto;
6use std::{
7 fmt::{self, Debug, Display, Formatter},
8 net::{IpAddr, Ipv4Addr, SocketAddr},
9 str::FromStr,
10};
11
/// Length of an [`Id`] in bytes (160 bits).
pub const ID_SIZE: usize = 20;
/// Number of bits in an [`Id`]; also the largest value [`Id::distance`] can return.
pub const MAX_DISTANCE: u8 = 8 * ID_SIZE as u8;
15
16const IPV4_MASK: u32 = 0x030f3fff;
17const CASTAGNOLI: Crc<u32> = Crc::<u32>::new(&CRC_32_ISCSI);
18
19#[derive(Clone, Copy, PartialEq, Ord, PartialOrd, Eq, Hash, Serialize, Deserialize)]
20pub struct Id([u8; ID_SIZE]);
22
23impl Id {
24 pub fn random() -> Id {
26 let mut bytes: [u8; 20] = [0; 20];
27 rand::rng().fill_bytes(&mut bytes);
28
29 Id(bytes)
30 }
31
32 pub fn from_bytes<T: AsRef<[u8]>>(bytes: T) -> Result<Id, InvalidIdSize> {
34 let bytes = bytes.as_ref();
35 if bytes.len() != ID_SIZE {
36 return Err(InvalidIdSize(bytes.len()));
37 }
38
39 let mut tmp: [u8; ID_SIZE] = [0; ID_SIZE];
40 tmp[..ID_SIZE].clone_from_slice(&bytes[..ID_SIZE]);
41
42 Ok(Id(tmp))
43 }
44
45 pub fn distance(&self, other: &Id) -> u8 {
53 MAX_DISTANCE - self.xor(other).leading_zeros()
54 }
55
56 pub fn leading_zeros(&self) -> u8 {
58 for (i, byte) in self.0.iter().enumerate() {
59 if *byte != 0 {
60 return (i as u32 * 8 + byte.leading_zeros()) as u8;
62 }
63 }
64
65 160
66 }
67
68 pub fn xor(&self, other: &Id) -> Id {
70 let mut result = [0_u8; 20];
71
72 for (i, (a, b)) in self.0.iter().zip(other.0).enumerate() {
73 result[i] = a ^ b;
74 }
75
76 result.into()
77 }
78
79 pub fn as_bytes(&self) -> &[u8; 20] {
81 &self.0
82 }
83
84 pub fn from_addr(addr: &SocketAddr) -> Id {
86 let ip = addr.ip();
87
88 Id::from_ip(ip)
89 }
90
91 pub fn from_ip(ip: IpAddr) -> Id {
93 match ip {
94 IpAddr::V4(addr) => Id::from_ipv4(addr),
95 IpAddr::V6(_addr) => unimplemented!("Ipv6 is not supported"),
96 }
97 }
98
99 pub fn from_ipv4(ipv4: Ipv4Addr) -> Id {
101 let mut bytes = [0_u8; 21];
102 rand::rng().fill_bytes(&mut bytes);
103
104 from_ipv4_and_r(bytes[1..].try_into().expect("infallible"), ipv4, bytes[0])
105 }
106
107 pub fn is_valid_for_ip(&self, ipv4: Ipv4Addr) -> bool {
109 if ipv4.is_private() || ipv4.is_link_local() || ipv4.is_loopback() {
110 return true;
111 }
112
113 let expected = first_21_bits(&id_prefix_ipv4(ipv4, self.0[ID_SIZE - 1]));
114
115 self.first_21_bits() == expected
116 }
117
118 pub(crate) fn first_21_bits(&self) -> [u8; 3] {
119 first_21_bits(&self.0)
120 }
121}
122
/// Returns the first 21 bits of `bytes` as three bytes, with the low three
/// bits of the third byte cleared.
///
/// # Panics
///
/// Panics if `bytes` is shorter than three bytes.
fn first_21_bits(bytes: &[u8]) -> [u8; 3] {
    let mut prefix = [0_u8; 3];
    prefix.copy_from_slice(&bytes[..3]);
    // 8 + 8 + 5 bits: keep only the top five bits of the third byte.
    prefix[2] &= 0b1111_1000;
    prefix
}
126
127fn from_ipv4_and_r(bytes: [u8; 20], ip: Ipv4Addr, r: u8) -> Id {
128 let mut bytes = bytes;
129 let prefix = id_prefix_ipv4(ip, r);
130
131 bytes[0] = prefix[0];
133 bytes[1] = prefix[1];
134 bytes[2] = (prefix[2] & 0xf8) | (bytes[2] & 0x7);
136
137 bytes[ID_SIZE - 1] = r;
139
140 Id(bytes)
141}
142
143fn id_prefix_ipv4(ip: Ipv4Addr, r: u8) -> [u8; 3] {
144 let r32: u32 = r.into();
145 let ip_int: u32 = u32::from_be_bytes(ip.octets());
146 let masked_ip: u32 = (ip_int & IPV4_MASK) | (r32 << 29);
147
148 let mut digest = CASTAGNOLI.digest();
149 digest.update(&masked_ip.to_be_bytes());
150
151 let crc = digest.finalize();
152
153 crc.to_be_bytes()[..3]
154 .try_into()
155 .expect("Failed to convert bytes 0-2 of the crc into a 3-byte array")
156}
157
158impl Display for Id {
159 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
160 #[allow(clippy::format_collect)]
161 let hex_chars: String = self.0.iter().map(|byte| format!("{byte:02x}")).collect();
162
163 write!(f, "{hex_chars}")
164 }
165}
166
167impl From<[u8; ID_SIZE]> for Id {
168 fn from(bytes: [u8; ID_SIZE]) -> Id {
169 Id(bytes)
170 }
171}
172
173impl From<&[u8; ID_SIZE]> for Id {
174 fn from(bytes: &[u8; ID_SIZE]) -> Id {
175 Id(*bytes)
176 }
177}
178
179impl From<Id> for [u8; ID_SIZE] {
180 fn from(value: Id) -> Self {
181 value.0
182 }
183}
184
185impl FromStr for Id {
186 type Err = DecodeIdError;
187
188 fn from_str(s: &str) -> Result<Id, DecodeIdError> {
189 if !s.len().is_multiple_of(2) {
190 return Err(DecodeIdError::OddNumberOfCharacters);
191 }
192
193 let mut bytes = Vec::with_capacity(s.len() / 2);
194
195 for i in 0..s.len() / 2 {
196 let byte_str = &s[i * 2..(i * 2) + 2];
197 if let Ok(byte) = u8::from_str_radix(byte_str, 16) {
198 bytes.push(byte);
199 } else {
200 return Err(DecodeIdError::InvalidHexCharacter(byte_str.into()));
201 }
202 }
203
204 Ok(Id::from_bytes(bytes)?)
205 }
206}
207
208impl Debug for Id {
209 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
210 write!(f, "Id({self})")
211 }
212}
213
/// Error returned by [`Id::from_bytes`] when the input is not exactly 20 bytes
/// long. Holds the length that was actually supplied.
#[derive(Debug)]
pub struct InvalidIdSize(usize);

impl std::error::Error for InvalidIdSize {}

impl std::fmt::Display for InvalidIdSize {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let Self(actual) = self;
        write!(f, "Invalid Id size, expected 20, got {0}", actual)
    }
}
224
/// Error returned when parsing an [`Id`] from a hex string fails.
#[n0_error::stack_error(derive, from_sources, std_sources)]
pub enum DecodeIdError {
    /// The hex decoded, but not into exactly `ID_SIZE` bytes.
    #[error(transparent)]
    InvalidIdSize(InvalidIdSize),

    /// The input has an odd byte length, so it cannot be split into hex pairs.
    #[error("Hex encoding should contain an even number of hex characters")]
    OddNumberOfCharacters,

    /// A two-character chunk was not valid hexadecimal; holds that chunk.
    #[error("Invalid Id encoding")]
    InvalidHexCharacter(String),
}
240
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn distance_to_self() {
        let id = Id::random();

        assert_eq!(id.distance(&id), 0)
    }

    #[test]
    fn distance_to_id() {
        // First bytes XOR to 0x06 ^ 0x03 = 0b0000_0101: five leading zeros,
        // so the distance is 160 - 5.
        let id = Id::from_str("0639A1E24FBB8AB277DF033476AB0DE10FAB3BDC").unwrap();
        let target = Id::from_str("035b1aeb9737ade1a80933594f405d3f772aa08e").unwrap();

        assert_eq!(id.distance(&target), 155)
    }

    #[test]
    fn distance_to_random_id() {
        let (id, target) = (Id::random(), Id::random());

        assert_ne!(id.distance(&target), 0)
    }

    #[test]
    fn distance_to_furthest() {
        let id = Id::random();
        // Flipping every bit guarantees the most significant bit differs.
        let bytes = *id.as_bytes();
        let opposite = bytes.map(|byte| !byte);
        let target = Id::from_bytes(opposite).unwrap();

        assert_eq!(id.distance(&target), MAX_DISTANCE)
    }

    #[test]
    fn from_u8_20() {
        let bytes = [8_u8; 20];

        let id = Id::from(bytes);

        assert_eq!(id.as_bytes(), &bytes);
    }

    #[test]
    fn from_ipv4() {
        let vectors: [(Ipv4Addr, u8, [u8; 3]); 5] = [
            (Ipv4Addr::new(124, 31, 75, 21), 1, [0x5f, 0xbf, 0xbf]),
            (Ipv4Addr::new(21, 75, 31, 124), 86, [0x5a, 0x3c, 0xe9]),
            (Ipv4Addr::new(65, 23, 51, 170), 22, [0xa5, 0xd4, 0x32]),
            (Ipv4Addr::new(84, 124, 73, 14), 65, [0x1b, 0x03, 0x21]),
            (Ipv4Addr::new(43, 213, 53, 83), 90, [0xe5, 0x6f, 0x6c]),
        ];

        for (ip, r, expected_prefix) in vectors {
            let filler = *Id::random().as_bytes();
            let result = from_ipv4_and_r(filler, ip, r);

            assert_eq!(
                first_21_bits(result.as_bytes()),
                first_21_bits(&expected_prefix)
            );
            assert_eq!(result.as_bytes()[ID_SIZE - 1], r);
        }
    }

    #[test]
    fn is_valid_for_ipv4() {
        let valid_vectors: [(Ipv4Addr, &str); 5] = [
            (
                Ipv4Addr::new(124, 31, 75, 21),
                "5fbfbff10c5d6a4ec8a88e4c6ab4c28b95eee401",
            ),
            (
                Ipv4Addr::new(21, 75, 31, 124),
                "5a3ce9c14e7a08645677bbd1cfe7d8f956d53256",
            ),
            (
                Ipv4Addr::new(65, 23, 51, 170),
                "a5d43220bc8f112a3d426c84764f8c2a1150e616",
            ),
            (
                Ipv4Addr::new(84, 124, 73, 14),
                "1b0321dd1bb1fe518101ceef99462b947a01ff41",
            ),
            (
                Ipv4Addr::new(43, 213, 53, 83),
                "e56f6cbf5b7c4be0237986d5243b87aa6d51305a",
            ),
        ];

        for (ip, hex) in valid_vectors {
            let id: Id = hex.parse().unwrap();

            assert!(id.is_valid_for_ip(ip));
        }
    }
}