1use crate::NodeAddr;
4use crate::protocol::error::ProtocolError;
5use crate::protocol::session::{decode_coords, encode_coords};
6use crate::tree::TreeCoordinate;
7use secp256k1::schnorr::Signature;
8
9#[derive(Clone, Debug)]
15pub struct LookupRequest {
16 pub request_id: u64,
18 pub target: NodeAddr,
20 pub origin: NodeAddr,
22 pub origin_coords: TreeCoordinate,
24 pub ttl: u8,
26 pub min_mtu: u16,
29}
30
31impl LookupRequest {
32 pub fn new(
34 request_id: u64,
35 target: NodeAddr,
36 origin: NodeAddr,
37 origin_coords: TreeCoordinate,
38 ttl: u8,
39 min_mtu: u16,
40 ) -> Self {
41 Self {
42 request_id,
43 target,
44 origin,
45 origin_coords,
46 ttl,
47 min_mtu,
48 }
49 }
50
51 pub fn generate(
53 target: NodeAddr,
54 origin: NodeAddr,
55 origin_coords: TreeCoordinate,
56 ttl: u8,
57 min_mtu: u16,
58 ) -> Self {
59 use rand::RngExt;
60 let request_id = rand::rng().random();
61 Self::new(request_id, target, origin, origin_coords, ttl, min_mtu)
62 }
63
64 pub fn forward(&mut self) -> bool {
68 if self.ttl == 0 {
69 return false;
70 }
71 self.ttl -= 1;
72 true
73 }
74
75 pub fn can_forward(&self) -> bool {
77 self.ttl > 0
78 }
79
80 pub fn encode(&self) -> Vec<u8> {
85 let mut buf = Vec::with_capacity(46 + self.origin_coords.depth() * 16);
86
87 buf.push(0x30); buf.extend_from_slice(&self.request_id.to_le_bytes());
89 buf.extend_from_slice(self.target.as_bytes());
90 buf.extend_from_slice(self.origin.as_bytes());
91 buf.push(self.ttl);
92 buf.extend_from_slice(&self.min_mtu.to_le_bytes());
93 encode_coords(&self.origin_coords, &mut buf);
94
95 buf
96 }
97
98 pub fn decode(payload: &[u8]) -> Result<Self, ProtocolError> {
100 if payload.len() < 45 {
103 return Err(ProtocolError::MessageTooShort {
104 expected: 45,
105 got: payload.len(),
106 });
107 }
108
109 let mut pos = 0;
110
111 let request_id = u64::from_le_bytes(
112 payload[pos..pos + 8]
113 .try_into()
114 .map_err(|_| ProtocolError::Malformed("bad request_id".into()))?,
115 );
116 pos += 8;
117
118 let mut target_bytes = [0u8; 16];
119 target_bytes.copy_from_slice(&payload[pos..pos + 16]);
120 let target = NodeAddr::from_bytes(target_bytes);
121 pos += 16;
122
123 let mut origin_bytes = [0u8; 16];
124 origin_bytes.copy_from_slice(&payload[pos..pos + 16]);
125 let origin = NodeAddr::from_bytes(origin_bytes);
126 pos += 16;
127
128 let ttl = payload[pos];
129 pos += 1;
130
131 let min_mtu = u16::from_le_bytes(
132 payload[pos..pos + 2]
133 .try_into()
134 .map_err(|_| ProtocolError::Malformed("bad min_mtu".into()))?,
135 );
136 pos += 2;
137
138 let (origin_coords, _consumed) = decode_coords(&payload[pos..])?;
139
140 Ok(Self {
141 request_id,
142 target,
143 origin,
144 origin_coords,
145 ttl,
146 min_mtu,
147 })
148 }
149}
150
151#[derive(Clone, Debug)]
155pub struct LookupResponse {
156 pub request_id: u64,
158 pub target: NodeAddr,
160 pub path_mtu: u16,
166 pub target_coords: TreeCoordinate,
168 pub proof: Signature,
170}
171
172impl LookupResponse {
173 pub fn new(
178 request_id: u64,
179 target: NodeAddr,
180 target_coords: TreeCoordinate,
181 proof: Signature,
182 ) -> Self {
183 Self {
184 request_id,
185 target,
186 path_mtu: u16::MAX,
187 target_coords,
188 proof,
189 }
190 }
191
192 pub fn proof_bytes(
196 request_id: u64,
197 target: &NodeAddr,
198 target_coords: &TreeCoordinate,
199 ) -> Vec<u8> {
200 let coord_size = 2 + target_coords.entries().len() * 16;
201 let mut bytes = Vec::with_capacity(24 + coord_size);
202 bytes.extend_from_slice(&request_id.to_le_bytes());
203 bytes.extend_from_slice(target.as_bytes());
204 encode_coords(target_coords, &mut bytes);
205 bytes
206 }
207
208 pub fn encode(&self) -> Vec<u8> {
212 let mut buf = Vec::with_capacity(93 + self.target_coords.depth() * 16);
213
214 buf.push(0x31); buf.extend_from_slice(&self.request_id.to_le_bytes());
216 buf.extend_from_slice(self.target.as_bytes());
217 buf.extend_from_slice(&self.path_mtu.to_le_bytes());
218 encode_coords(&self.target_coords, &mut buf);
219 buf.extend_from_slice(self.proof.as_ref());
220
221 buf
222 }
223
224 pub fn decode(payload: &[u8]) -> Result<Self, ProtocolError> {
226 if payload.len() < 92 {
228 return Err(ProtocolError::MessageTooShort {
229 expected: 92,
230 got: payload.len(),
231 });
232 }
233
234 let mut pos = 0;
235
236 let request_id = u64::from_le_bytes(
237 payload[pos..pos + 8]
238 .try_into()
239 .map_err(|_| ProtocolError::Malformed("bad request_id".into()))?,
240 );
241 pos += 8;
242
243 let mut target_bytes = [0u8; 16];
244 target_bytes.copy_from_slice(&payload[pos..pos + 16]);
245 let target = NodeAddr::from_bytes(target_bytes);
246 pos += 16;
247
248 let path_mtu = u16::from_le_bytes(
249 payload[pos..pos + 2]
250 .try_into()
251 .map_err(|_| ProtocolError::Malformed("bad path_mtu".into()))?,
252 );
253 pos += 2;
254
255 let (target_coords, consumed) = decode_coords(&payload[pos..])?;
256 pos += consumed;
257
258 if payload.len() < pos + 64 {
259 return Err(ProtocolError::MessageTooShort {
260 expected: pos + 64,
261 got: payload.len(),
262 });
263 }
264 let proof = Signature::from_slice(&payload[pos..pos + 64])
265 .map_err(|_| ProtocolError::Malformed("bad proof signature".into()))?;
266
267 Ok(Self {
268 request_id,
269 target,
270 path_mtu,
271 target_coords,
272 proof,
273 })
274 }
275}
276
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `NodeAddr` whose first byte is `val` and whose remaining bytes are zero.
    fn make_node_addr(val: u8) -> NodeAddr {
        let mut bytes = [0u8; 16];
        bytes[0] = val;
        NodeAddr::from_bytes(bytes)
    }

    /// Builds tree coordinates from a list of single-byte node ids.
    fn make_coords(ids: &[u8]) -> TreeCoordinate {
        TreeCoordinate::from_addrs(ids.iter().map(|&v| make_node_addr(v)).collect()).unwrap()
    }

    /// Signs SHA-256(`LookupResponse::proof_bytes(request_id, target, coords)`) with a
    /// freshly generated random key.
    fn sign_proof(request_id: u64, target: &NodeAddr, coords: &TreeCoordinate) -> Signature {
        use secp256k1::Secp256k1;
        use sha2::Digest;

        let secp = Secp256k1::new();
        let mut secret_bytes = [0u8; 32];
        rand::Rng::fill_bytes(&mut rand::rng(), &mut secret_bytes);
        // Fails only if the bytes are zero or >= the curve order (probability ~2^-128).
        let secret_key = secp256k1::SecretKey::from_slice(&secret_bytes)
            .expect("random 32 bytes are a valid secret key with overwhelming probability");
        let keypair = secp256k1::Keypair::from_secret_key(&secp, &secret_key);
        let proof_data = LookupResponse::proof_bytes(request_id, target, coords);
        let digest: [u8; 32] = sha2::Sha256::digest(&proof_data).into();
        secp.sign_schnorr(&digest, &keypair)
    }

    #[test]
    fn test_lookup_request_forward() {
        let target = make_node_addr(1);
        let origin = make_node_addr(2);
        let coords = make_coords(&[2, 0]);

        let mut request = LookupRequest::new(123, target, origin, coords, 5, 0);

        assert!(request.can_forward());
        assert!(request.forward());
        assert_eq!(request.ttl, 4);
    }

    #[test]
    fn test_lookup_request_ttl_exhausted() {
        let target = make_node_addr(1);
        let origin = make_node_addr(2);
        let coords = make_coords(&[2, 0]);

        let mut request = LookupRequest::new(123, target, origin, coords, 1, 0);

        assert!(request.forward());
        assert!(!request.can_forward());
        assert!(!request.forward());
        // A failed forward must not wrap the TTL around.
        assert_eq!(request.ttl, 0);
    }

    #[test]
    fn test_lookup_request_generate() {
        let target = make_node_addr(1);
        let origin = make_node_addr(2);
        let coords = make_coords(&[2, 0]);

        let req1 = LookupRequest::generate(target, origin, coords.clone(), 5, 0);
        let req2 = LookupRequest::generate(target, origin, coords, 5, 0);

        assert_ne!(req1.request_id, req2.request_id);
    }

    #[test]
    fn test_lookup_response_proof_bytes() {
        let target = make_node_addr(42);
        let coords = make_coords(&[42, 1, 0]);
        let bytes = LookupResponse::proof_bytes(12345, &target, &coords);

        // 8 (request_id) + 16 (target) + 2 (coord count) + 3 * 16 (coord entries)
        assert_eq!(bytes.len(), 74);
        assert_eq!(&bytes[0..8], &12345u64.to_le_bytes());
        assert_eq!(&bytes[8..24], target.as_bytes());

        let count = u16::from_le_bytes([bytes[24], bytes[25]]);
        assert_eq!(count, 3);
    }

    #[test]
    fn test_lookup_request_encode_decode_roundtrip() {
        let target = make_node_addr(10);
        let origin = make_node_addr(20);
        let coords = make_coords(&[20, 0]);

        let mut request = LookupRequest::new(12345, target, origin, coords, 8, 1386);
        request.forward();

        let encoded = request.encode();
        assert_eq!(encoded[0], 0x30);

        let decoded = LookupRequest::decode(&encoded[1..]).unwrap();
        assert_eq!(decoded.request_id, 12345);
        assert_eq!(decoded.target, target);
        assert_eq!(decoded.origin, origin);
        assert_eq!(decoded.ttl, 7);
        assert_eq!(decoded.min_mtu, 1386);
        // Byte-exact re-encoding also proves the coordinates survived the roundtrip.
        assert_eq!(decoded.encode(), encoded);
    }

    #[test]
    fn test_lookup_request_decode_too_short() {
        assert!(LookupRequest::decode(&[]).is_err());
        assert!(LookupRequest::decode(&[0u8; 42]).is_err());
        // One byte below the 45-byte minimum.
        assert!(matches!(
            LookupRequest::decode(&[0u8; 44]),
            Err(ProtocolError::MessageTooShort {
                expected: 45,
                got: 44
            })
        ));
    }

    #[test]
    fn test_lookup_request_min_mtu_boundary_values() {
        let target = make_node_addr(10);
        let origin = make_node_addr(20);
        let coords = make_coords(&[20, 0]);

        for mtu_val in [0u16, 1386, u16::MAX] {
            let request = LookupRequest::new(100, target, origin, coords.clone(), 5, mtu_val);
            let encoded = request.encode();
            let decoded = LookupRequest::decode(&encoded[1..]).unwrap();
            assert_eq!(decoded.min_mtu, mtu_val);
        }
    }

    #[test]
    fn test_lookup_response_encode_decode_roundtrip() {
        let target = make_node_addr(42);
        let coords = make_coords(&[42, 1, 0]);
        let sig = sign_proof(999, &target, &coords);

        let response = LookupResponse::new(999, target, coords, sig);
        assert_eq!(response.path_mtu, u16::MAX);

        let encoded = response.encode();
        assert_eq!(encoded[0], 0x31);

        let decoded = LookupResponse::decode(&encoded[1..]).unwrap();
        assert_eq!(decoded.request_id, 999);
        assert_eq!(decoded.target, target);
        assert_eq!(decoded.path_mtu, u16::MAX);
        assert_eq!(decoded.proof, sig);
        // Byte-exact re-encoding also proves the coordinates survived the roundtrip.
        assert_eq!(decoded.encode(), encoded);
    }

    #[test]
    fn test_lookup_response_path_mtu_roundtrip() {
        let target = make_node_addr(42);
        let coords = make_coords(&[42, 1, 0]);
        let sig = sign_proof(999, &target, &coords);

        for mtu_val in [0u16, 1280, 1386, 9000, u16::MAX] {
            let mut response = LookupResponse::new(999, target, coords.clone(), sig);
            response.path_mtu = mtu_val;

            let encoded = response.encode();
            let decoded = LookupResponse::decode(&encoded[1..]).unwrap();
            assert_eq!(decoded.path_mtu, mtu_val);
        }
    }

    #[test]
    fn test_lookup_response_path_mtu_not_in_proof_bytes() {
        let target = make_node_addr(42);
        let coords = make_coords(&[42, 1, 0]);

        let bytes = LookupResponse::proof_bytes(12345, &target, &coords);

        // Same length as request_id + target + coords alone: no room for a path_mtu field.
        assert_eq!(bytes.len(), 74);
    }

    #[test]
    fn test_lookup_response_decode_too_short() {
        assert!(LookupResponse::decode(&[]).is_err());
        assert!(LookupResponse::decode(&[0u8; 50]).is_err());
        // One byte below the 92-byte minimum.
        assert!(matches!(
            LookupResponse::decode(&[0u8; 91]),
            Err(ProtocolError::MessageTooShort {
                expected: 92,
                got: 91
            })
        ));
    }

    #[test]
    fn test_lookup_response_decode_truncated_signature() {
        let target = make_node_addr(42);
        let coords = make_coords(&[42, 1, 0]);
        let sig = sign_proof(7, &target, &coords);

        let encoded = LookupResponse::new(7, target, coords, sig).encode();
        // Drop the message-type byte and the final signature byte. The result still clears
        // the 92-byte minimum, so the post-coordinates signature check must catch it.
        let payload = &encoded[1..encoded.len() - 1];
        assert!(payload.len() >= 92);
        assert!(matches!(
            LookupResponse::decode(payload),
            Err(ProtocolError::MessageTooShort { .. })
        ));
    }
}