Skip to main content

fips_core/protocol/
discovery.rs

1//! Discovery messages: LookupRequest and LookupResponse.
2
3use crate::NodeAddr;
4use crate::protocol::error::ProtocolError;
5use crate::protocol::session::{decode_coords, encode_coords};
6use crate::tree::TreeCoordinate;
7use secp256k1::schnorr::Signature;
8
9/// Request to discover a node's coordinates.
10///
11/// Routed through the spanning tree via bloom-filter-guided forwarding.
12/// Each transit node forwards only to tree peers whose bloom filter
13/// contains the target. TTL limits propagation depth.
14#[derive(Clone, Debug)]
15pub struct LookupRequest {
16    /// Unique request identifier.
17    pub request_id: u64,
18    /// Node we're looking for.
19    pub target: NodeAddr,
20    /// Who's asking (for response routing).
21    pub origin: NodeAddr,
22    /// Origin's coordinates (for return path).
23    pub origin_coords: TreeCoordinate,
24    /// Remaining propagation hops.
25    pub ttl: u8,
26    /// Minimum transport MTU the origin requires for a viable route.
27    /// 0 means no requirement.
28    pub min_mtu: u16,
29}
30
31impl LookupRequest {
32    /// Create a new lookup request.
33    pub fn new(
34        request_id: u64,
35        target: NodeAddr,
36        origin: NodeAddr,
37        origin_coords: TreeCoordinate,
38        ttl: u8,
39        min_mtu: u16,
40    ) -> Self {
41        Self {
42            request_id,
43            target,
44            origin,
45            origin_coords,
46            ttl,
47            min_mtu,
48        }
49    }
50
51    /// Generate a new request with a random ID.
52    pub fn generate(
53        target: NodeAddr,
54        origin: NodeAddr,
55        origin_coords: TreeCoordinate,
56        ttl: u8,
57        min_mtu: u16,
58    ) -> Self {
59        use rand::RngExt;
60        let request_id = rand::rng().random();
61        Self::new(request_id, target, origin, origin_coords, ttl, min_mtu)
62    }
63
64    /// Decrement TTL for forwarding.
65    ///
66    /// Returns false if TTL was already 0.
67    pub fn forward(&mut self) -> bool {
68        if self.ttl == 0 {
69            return false;
70        }
71        self.ttl -= 1;
72        true
73    }
74
75    /// Check if this request can still be forwarded.
76    pub fn can_forward(&self) -> bool {
77        self.ttl > 0
78    }
79
80    /// Encode as wire format (includes msg_type byte).
81    ///
82    /// Format: `[0x30][request_id:8][target:16][origin:16][ttl:1][min_mtu:2]`
83    ///         `[origin_coords_cnt:2][origin_coords:16×n]`
84    pub fn encode(&self) -> Vec<u8> {
85        let mut buf = Vec::with_capacity(46 + self.origin_coords.depth() * 16);
86
87        buf.push(0x30); // msg_type
88        buf.extend_from_slice(&self.request_id.to_le_bytes());
89        buf.extend_from_slice(self.target.as_bytes());
90        buf.extend_from_slice(self.origin.as_bytes());
91        buf.push(self.ttl);
92        buf.extend_from_slice(&self.min_mtu.to_le_bytes());
93        encode_coords(&self.origin_coords, &mut buf);
94
95        buf
96    }
97
98    /// Decode from wire format (after msg_type byte has been consumed).
99    pub fn decode(payload: &[u8]) -> Result<Self, ProtocolError> {
100        // Minimum: request_id(8) + target(16) + origin(16) + ttl(1) + min_mtu(2)
101        //          + coords_count(2) = 45 bytes
102        if payload.len() < 45 {
103            return Err(ProtocolError::MessageTooShort {
104                expected: 45,
105                got: payload.len(),
106            });
107        }
108
109        let mut pos = 0;
110
111        let request_id = u64::from_le_bytes(
112            payload[pos..pos + 8]
113                .try_into()
114                .map_err(|_| ProtocolError::Malformed("bad request_id".into()))?,
115        );
116        pos += 8;
117
118        let mut target_bytes = [0u8; 16];
119        target_bytes.copy_from_slice(&payload[pos..pos + 16]);
120        let target = NodeAddr::from_bytes(target_bytes);
121        pos += 16;
122
123        let mut origin_bytes = [0u8; 16];
124        origin_bytes.copy_from_slice(&payload[pos..pos + 16]);
125        let origin = NodeAddr::from_bytes(origin_bytes);
126        pos += 16;
127
128        let ttl = payload[pos];
129        pos += 1;
130
131        let min_mtu = u16::from_le_bytes(
132            payload[pos..pos + 2]
133                .try_into()
134                .map_err(|_| ProtocolError::Malformed("bad min_mtu".into()))?,
135        );
136        pos += 2;
137
138        let (origin_coords, _consumed) = decode_coords(&payload[pos..])?;
139
140        Ok(Self {
141            request_id,
142            target,
143            origin,
144            origin_coords,
145            ttl,
146            min_mtu,
147        })
148    }
149}
150
151/// Response to a lookup request with target's coordinates.
152///
153/// Routed back to the origin using the origin_coords from the request.
154#[derive(Clone, Debug)]
155pub struct LookupResponse {
156    /// Echoed request identifier.
157    pub request_id: u64,
158    /// The target node.
159    pub target: NodeAddr,
160    /// Minimum transport MTU along the response path.
161    ///
162    /// Initialized to `u16::MAX` by the target. Each transit node applies
163    /// `path_mtu = path_mtu.min(outgoing_link_mtu)` when forwarding.
164    /// NOT included in the proof signature (transit annotation).
165    pub path_mtu: u16,
166    /// Target's coordinates in the tree.
167    pub target_coords: TreeCoordinate,
168    /// Proof that target authorized this response (signature over request).
169    pub proof: Signature,
170}
171
172impl LookupResponse {
173    /// Create a new lookup response.
174    ///
175    /// `path_mtu` is initialized to `u16::MAX` by the target; transit
176    /// nodes reduce it as they forward.
177    pub fn new(
178        request_id: u64,
179        target: NodeAddr,
180        target_coords: TreeCoordinate,
181        proof: Signature,
182    ) -> Self {
183        Self {
184            request_id,
185            target,
186            path_mtu: u16::MAX,
187            target_coords,
188            proof,
189        }
190    }
191
192    /// Get the bytes that should be signed as proof.
193    ///
194    /// Format: request_id (8) || target (16) || coords_encoding (2 + 16×n)
195    pub fn proof_bytes(
196        request_id: u64,
197        target: &NodeAddr,
198        target_coords: &TreeCoordinate,
199    ) -> Vec<u8> {
200        let coord_size = 2 + target_coords.entries().len() * 16;
201        let mut bytes = Vec::with_capacity(24 + coord_size);
202        bytes.extend_from_slice(&request_id.to_le_bytes());
203        bytes.extend_from_slice(target.as_bytes());
204        encode_coords(target_coords, &mut bytes);
205        bytes
206    }
207
208    /// Encode as wire format (includes msg_type byte).
209    ///
210    /// Format: `[0x31][request_id:8][target:16][path_mtu:2][target_coords_cnt:2][target_coords:16×n][proof:64]`
211    pub fn encode(&self) -> Vec<u8> {
212        let mut buf = Vec::with_capacity(93 + self.target_coords.depth() * 16);
213
214        buf.push(0x31); // msg_type
215        buf.extend_from_slice(&self.request_id.to_le_bytes());
216        buf.extend_from_slice(self.target.as_bytes());
217        buf.extend_from_slice(&self.path_mtu.to_le_bytes());
218        encode_coords(&self.target_coords, &mut buf);
219        buf.extend_from_slice(self.proof.as_ref());
220
221        buf
222    }
223
224    /// Decode from wire format (after msg_type byte has been consumed).
225    pub fn decode(payload: &[u8]) -> Result<Self, ProtocolError> {
226        // Minimum: request_id(8) + target(16) + path_mtu(2) + coords_count(2) + proof(64) = 92
227        if payload.len() < 92 {
228            return Err(ProtocolError::MessageTooShort {
229                expected: 92,
230                got: payload.len(),
231            });
232        }
233
234        let mut pos = 0;
235
236        let request_id = u64::from_le_bytes(
237            payload[pos..pos + 8]
238                .try_into()
239                .map_err(|_| ProtocolError::Malformed("bad request_id".into()))?,
240        );
241        pos += 8;
242
243        let mut target_bytes = [0u8; 16];
244        target_bytes.copy_from_slice(&payload[pos..pos + 16]);
245        let target = NodeAddr::from_bytes(target_bytes);
246        pos += 16;
247
248        let path_mtu = u16::from_le_bytes(
249            payload[pos..pos + 2]
250                .try_into()
251                .map_err(|_| ProtocolError::Malformed("bad path_mtu".into()))?,
252        );
253        pos += 2;
254
255        let (target_coords, consumed) = decode_coords(&payload[pos..])?;
256        pos += consumed;
257
258        if payload.len() < pos + 64 {
259            return Err(ProtocolError::MessageTooShort {
260                expected: pos + 64,
261                got: payload.len(),
262            });
263        }
264        let proof = Signature::from_slice(&payload[pos..pos + 64])
265            .map_err(|_| ProtocolError::Malformed("bad proof signature".into()))?;
266
267        Ok(Self {
268            request_id,
269            target,
270            path_mtu,
271            target_coords,
272            proof,
273        })
274    }
275}
276
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `NodeAddr` whose first byte is `val` and whose other bytes are zero.
    fn make_node_addr(val: u8) -> NodeAddr {
        let mut bytes = [0u8; 16];
        bytes[0] = val;
        NodeAddr::from_bytes(bytes)
    }

    /// Build tree coordinates from a list of single-byte node ids.
    fn make_coords(ids: &[u8]) -> TreeCoordinate {
        TreeCoordinate::from_addrs(ids.iter().map(|&v| make_node_addr(v)).collect()).unwrap()
    }

    /// Sign the SHA-256 digest of `LookupResponse::proof_bytes(request_id, target, coords)`.
    ///
    /// Uses a constant secret key so the tests are deterministic. 32 random
    /// bytes are not guaranteed to be a valid secp256k1 secret key, because
    /// zero and values >= the curve order are rejected.
    fn make_proof(request_id: u64, target: &NodeAddr, coords: &TreeCoordinate) -> Signature {
        use secp256k1::Secp256k1;
        use sha2::Digest;

        let secp = Secp256k1::new();
        let secret_key = secp256k1::SecretKey::from_slice(&[0x42u8; 32])
            .expect("0x42 repeated is non-zero and below the curve order");
        let keypair = secp256k1::Keypair::from_secret_key(&secp, &secret_key);
        let proof_data = LookupResponse::proof_bytes(request_id, target, coords);
        let digest: [u8; 32] = sha2::Sha256::digest(&proof_data).into();
        secp.sign_schnorr(&digest, &keypair)
    }

    #[test]
    fn test_lookup_request_forward() {
        let target = make_node_addr(1);
        let origin = make_node_addr(2);
        let coords = make_coords(&[2, 0]);

        let mut request = LookupRequest::new(123, target, origin, coords, 5, 0);

        assert!(request.can_forward());
        assert!(request.forward());
        assert_eq!(request.ttl, 4);
    }

    #[test]
    fn test_lookup_request_ttl_exhausted() {
        let target = make_node_addr(1);
        let origin = make_node_addr(2);
        let coords = make_coords(&[2, 0]);

        let mut request = LookupRequest::new(123, target, origin, coords, 1, 0);

        assert!(request.forward());
        assert!(!request.can_forward());
        assert!(!request.forward());
        assert_eq!(request.ttl, 0); // a refused forward must not underflow
    }

    #[test]
    fn test_lookup_request_generate() {
        let target = make_node_addr(1);
        let origin = make_node_addr(2);
        let coords = make_coords(&[2, 0]);

        let req1 = LookupRequest::generate(target, origin, coords.clone(), 5, 0);
        let req2 = LookupRequest::generate(target, origin, coords, 5, 0);

        // Random IDs should differ
        assert_ne!(req1.request_id, req2.request_id);
    }

    #[test]
    fn test_lookup_response_proof_bytes() {
        let target = make_node_addr(42);
        let coords = make_coords(&[42, 1, 0]);
        let bytes = LookupResponse::proof_bytes(12345, &target, &coords);

        // 8 (request_id) + 16 (target) + 2 (count) + 3*16 (coords) = 74
        assert_eq!(bytes.len(), 74);
        assert_eq!(&bytes[0..8], &12345u64.to_le_bytes());
        assert_eq!(&bytes[8..24], target.as_bytes());

        // Verify coordinate encoding is present
        let count = u16::from_le_bytes([bytes[24], bytes[25]]);
        assert_eq!(count, 3); // 3 entries in coords
    }

    #[test]
    fn test_lookup_request_encode_decode_roundtrip() {
        let target = make_node_addr(10);
        let origin = make_node_addr(20);
        let coords = make_coords(&[20, 0]);

        let mut request = LookupRequest::new(12345, target, origin, coords, 8, 1386);
        request.forward();

        let encoded = request.encode();
        assert_eq!(encoded[0], 0x30);

        let decoded = LookupRequest::decode(&encoded[1..]).unwrap();
        assert_eq!(decoded.request_id, 12345);
        assert_eq!(decoded.target, target);
        assert_eq!(decoded.origin, origin);
        assert_eq!(decoded.ttl, 7); // decremented by forward()
        assert_eq!(decoded.min_mtu, 1386);
        // Re-encoding proves the origin coordinates survived the round trip too.
        assert_eq!(decoded.encode(), encoded);
    }

    #[test]
    fn test_lookup_request_decode_too_short() {
        assert!(LookupRequest::decode(&[]).is_err());
        assert!(LookupRequest::decode(&[0u8; 42]).is_err());
        // One byte below the 45-byte fixed portion.
        assert!(LookupRequest::decode(&[0u8; 44]).is_err());
    }

    #[test]
    fn test_lookup_request_min_mtu_boundary_values() {
        let target = make_node_addr(10);
        let origin = make_node_addr(20);
        let coords = make_coords(&[20, 0]);

        for mtu_val in [0u16, 1386, u16::MAX] {
            let request = LookupRequest::new(100, target, origin, coords.clone(), 5, mtu_val);
            let encoded = request.encode();
            let decoded = LookupRequest::decode(&encoded[1..]).unwrap();
            assert_eq!(decoded.min_mtu, mtu_val);
        }
    }

    #[test]
    fn test_lookup_response_encode_decode_roundtrip() {
        let target = make_node_addr(42);
        let coords = make_coords(&[42, 1, 0]);
        let sig = make_proof(999, &target, &coords);

        let response = LookupResponse::new(999, target, coords, sig);

        // Default path_mtu should be u16::MAX
        assert_eq!(response.path_mtu, u16::MAX);

        let encoded = response.encode();
        assert_eq!(encoded[0], 0x31);

        let decoded = LookupResponse::decode(&encoded[1..]).unwrap();
        assert_eq!(decoded.request_id, 999);
        assert_eq!(decoded.target, target);
        assert_eq!(decoded.path_mtu, u16::MAX);
        assert_eq!(decoded.proof, sig);
        // Re-encoding proves the target coordinates survived the round trip too.
        assert_eq!(decoded.encode(), encoded);
    }

    #[test]
    fn test_lookup_response_path_mtu_roundtrip() {
        let target = make_node_addr(42);
        let coords = make_coords(&[42, 1, 0]);
        let sig = make_proof(999, &target, &coords);

        for mtu_val in [0u16, 1280, 1386, 9000, u16::MAX] {
            let mut response = LookupResponse::new(999, target, coords.clone(), sig);
            response.path_mtu = mtu_val;

            let encoded = response.encode();
            let decoded = LookupResponse::decode(&encoded[1..]).unwrap();
            assert_eq!(decoded.path_mtu, mtu_val);
        }
    }

    #[test]
    fn test_lookup_response_path_mtu_not_in_proof_bytes() {
        // Verify that proof_bytes does NOT include path_mtu
        let target = make_node_addr(42);
        let coords = make_coords(&[42, 1, 0]);

        let bytes = LookupResponse::proof_bytes(12345, &target, &coords);

        // proof_bytes format: request_id(8) + target(16) + coords_encoding(2 + 3*16) = 74
        // No path_mtu(2) in here
        assert_eq!(bytes.len(), 74);
    }

    #[test]
    fn test_lookup_response_decode_too_short() {
        assert!(LookupResponse::decode(&[]).is_err());
        assert!(LookupResponse::decode(&[0u8; 50]).is_err());
        // One byte below the 92-byte minimum.
        assert!(LookupResponse::decode(&[0u8; 91]).is_err());
    }

    #[test]
    fn test_lookup_response_decode_truncated_proof() {
        let target = make_node_addr(42);
        let coords = make_coords(&[42, 1, 0]);
        let sig = make_proof(7, &target, &coords);
        let encoded = LookupResponse::new(7, target, coords, sig).encode();

        // Long enough to pass the fixed-size check, but one byte short of the
        // signature that follows the variable-length coordinates.
        let truncated = &encoded[1..encoded.len() - 1];
        assert!(truncated.len() >= 92);
        assert!(LookupResponse::decode(truncated).is_err());
    }
}