Skip to main content

rns_core/holepunch/
engine.rs

//! HolePunchEngine: pure-logic state machine for NAT hole punching.
//!
//! Follows the asymmetric protocol from direct-link-protocol.md:
//!
//! Phase 1: A probes facilitator T → learns A_pub
//! Phase 2: A sends UPGRADE_REQUEST {facilitator: T, initiator_public: A_pub} → B
//!          B responds with UPGRADE_ACCEPT or UPGRADE_REJECT
//! Phase 3: B probes T (from request) → learns B_pub
//!          B sends UPGRADE_READY {responder_public: B_pub} → A
//! Phase 4: Both punch simultaneously
//! Phase 5: Direct link established
//!
//! Methods return `Vec<HolePunchAction>` instead of performing I/O.
14
15use alloc::vec::Vec;
16
17use rns_crypto::hkdf::hkdf;
18use rns_crypto::Rng;
19
20use crate::msgpack::{self, Value};
21
22use super::types::*;
23
24/// Derives a 32-byte punch token from the link's derived key and session ID.
25///
26/// `HKDF(ikm=derived_key, salt=session_id, info="rns-holepunch-v1")[:32]`
27pub fn derive_punch_token(
28    derived_key: &[u8],
29    session_id: &[u8; 16],
30) -> Result<[u8; 32], HolePunchError> {
31    let result = hkdf(32, derived_key, Some(session_id), Some(b"rns-holepunch-v1"))
32        .map_err(|_| HolePunchError::NoDerivedKey)?;
33    let mut token = [0u8; 32];
34    token.copy_from_slice(&result);
35    Ok(token)
36}
37
38/// The hole-punch state machine for a single link.
39pub struct HolePunchEngine {
40    link_id: [u8; 16],
41    session_id: [u8; 16],
42    state: HolePunchState,
43    is_initiator: bool,
44
45    /// Our discovered public endpoint.
46    our_public_endpoint: Option<Endpoint>,
47
48    /// Peer's public endpoint.
49    /// Initiator: received from UPGRADE_READY.
50    /// Responder: received from UPGRADE_REQUEST.
51    peer_public_endpoint: Option<Endpoint>,
52
53    /// Facilitator (STUN probe) address.
54    /// Initiator: our configured probe address.
55    /// Responder: received from UPGRADE_REQUEST.
56    facilitator_addr: Option<Endpoint>,
57
58    /// Punch token derived from link key + session_id.
59    punch_token: [u8; 32],
60
61    /// Probe service address for endpoint discovery (configured on this node).
62    probe_addr: Option<Endpoint>,
63
64    /// Timestamp of state entry (for timeout tracking).
65    state_entered_at: f64,
66}
67
68impl HolePunchEngine {
69    /// Create a new engine for the given link. Does not start any session.
70    pub fn new(
71        link_id: [u8; 16],
72        probe_addr: Option<Endpoint>,
73    ) -> Self {
74        HolePunchEngine {
75            link_id,
76            session_id: [0u8; 16],
77            state: HolePunchState::Idle,
78            is_initiator: false,
79            our_public_endpoint: None,
80            peer_public_endpoint: None,
81            facilitator_addr: None,
82            punch_token: [0u8; 32],
83            probe_addr,
84            state_entered_at: 0.0,
85        }
86    }
87
88    pub fn state(&self) -> HolePunchState {
89        self.state
90    }
91
92    pub fn session_id(&self) -> &[u8; 16] {
93        &self.session_id
94    }
95
96    pub fn is_initiator(&self) -> bool {
97        self.is_initiator
98    }
99
100    pub fn punch_token(&self) -> &[u8; 32] {
101        &self.punch_token
102    }
103
104    /// Peer's discovered public endpoint.
105    pub fn peer_public_endpoint(&self) -> Option<&Endpoint> {
106        self.peer_public_endpoint.as_ref()
107    }
108
109    /// Propose a direct connection (initiator side).
110    ///
111    /// Per the spec, the initiator first discovers its own public endpoint
112    /// (Phase 1) before sending the upgrade request.
113    ///
114    /// Transitions: Idle -> Discovering.
115    pub fn propose(
116        &mut self,
117        derived_key: &[u8],
118        now: f64,
119        rng: &mut dyn Rng,
120    ) -> Result<Vec<HolePunchAction>, HolePunchError> {
121        if self.state != HolePunchState::Idle {
122            return Err(HolePunchError::InvalidState);
123        }
124
125        // Generate session ID
126        let mut session_id = [0u8; 16];
127        rng.fill_bytes(&mut session_id);
128        self.session_id = session_id;
129        self.is_initiator = true;
130
131        // Derive punch token
132        self.punch_token = derive_punch_token(derived_key, &session_id)?;
133
134        let probe_addr = self.probe_addr.clone().ok_or(HolePunchError::NoProbeAddr)?;
135        self.facilitator_addr = Some(probe_addr.clone());
136
137        // Phase 1: discover our public endpoint first
138        self.state = HolePunchState::Discovering;
139        self.state_entered_at = now;
140
141        Ok(alloc::vec![HolePunchAction::DiscoverEndpoints { probe_addr }])
142    }
143
144    /// Called when endpoint discovery completes.
145    ///
146    /// For initiator: Discovering -> Proposing (sends UPGRADE_REQUEST with facilitator + our addr).
147    /// For responder: Discovering -> Punching (sends UPGRADE_READY with our addr, starts punch).
148    pub fn endpoints_discovered(
149        &mut self,
150        public_endpoint: Endpoint,
151        now: f64,
152    ) -> Result<Vec<HolePunchAction>, HolePunchError> {
153        if self.state != HolePunchState::Discovering {
154            return Err(HolePunchError::InvalidState);
155        }
156
157        self.our_public_endpoint = Some(public_endpoint.clone());
158
159        if self.is_initiator {
160            // Initiator: Phase 1 complete -> Phase 2: send UPGRADE_REQUEST
161            let facilitator = self.facilitator_addr.clone()
162                .ok_or(HolePunchError::NoProbeAddr)?;
163
164            let payload = encode_upgrade_request(
165                &self.session_id,
166                &facilitator,
167                &public_endpoint,
168            );
169
170            self.state = HolePunchState::Proposing;
171            self.state_entered_at = now;
172
173            Ok(alloc::vec![HolePunchAction::SendSignal {
174                link_id: self.link_id,
175                msgtype: UPGRADE_REQUEST,
176                payload,
177            }])
178        } else {
179            // Responder: Phase 3 complete -> send UPGRADE_READY, start punching
180            let payload = encode_upgrade_ready(&self.session_id, &public_endpoint);
181
182            let peer_public = self.peer_public_endpoint.clone()
183                .ok_or(HolePunchError::InvalidState)?;
184
185            self.state = HolePunchState::Punching;
186            self.state_entered_at = now;
187
188            Ok(alloc::vec![
189                HolePunchAction::SendSignal {
190                    link_id: self.link_id,
191                    msgtype: UPGRADE_READY,
192                    payload,
193                },
194                HolePunchAction::StartUdpPunch {
195                    peer_public,
196                    punch_token: self.punch_token,
197                    session_id: self.session_id,
198                },
199            ])
200        }
201    }
202
203    /// Handle an incoming signaling message.
204    ///
205    /// `derived_key` is needed when handling UPGRADE_REQUEST (responder side).
206    pub fn handle_signal(
207        &mut self,
208        msgtype: u16,
209        payload: &[u8],
210        derived_key: Option<&[u8]>,
211        now: f64,
212    ) -> Result<Vec<HolePunchAction>, HolePunchError> {
213        match msgtype {
214            UPGRADE_REQUEST => self.handle_upgrade_request(payload, derived_key, now),
215            UPGRADE_ACCEPT => self.handle_upgrade_accept(payload, now),
216            UPGRADE_REJECT => self.handle_upgrade_reject(payload, now),
217            UPGRADE_READY => self.handle_upgrade_ready(payload, now),
218            UPGRADE_COMPLETE => self.handle_upgrade_complete(payload, now),
219            _ => Err(HolePunchError::InvalidPayload),
220        }
221    }
222
223    /// Called when the punch phase succeeds.
224    ///
225    /// Transitions: Punching -> Connected.
226    pub fn punch_succeeded(
227        &mut self,
228        now: f64,
229    ) -> Result<Vec<HolePunchAction>, HolePunchError> {
230        if self.state != HolePunchState::Punching {
231            return Err(HolePunchError::InvalidState);
232        }
233
234        self.state = HolePunchState::Connected;
235        self.state_entered_at = now;
236
237        Ok(alloc::vec![
238            HolePunchAction::Succeeded {
239                session_id: self.session_id,
240            },
241        ])
242    }
243
244    /// Called when the punch phase fails.
245    ///
246    /// Transitions: Punching -> Failed.
247    pub fn punch_failed(
248        &mut self,
249        now: f64,
250    ) -> Result<Vec<HolePunchAction>, HolePunchError> {
251        if self.state != HolePunchState::Punching {
252            return Err(HolePunchError::InvalidState);
253        }
254
255        self.state = HolePunchState::Failed;
256        self.state_entered_at = now;
257
258        Ok(alloc::vec![
259            HolePunchAction::Failed {
260                session_id: self.session_id,
261                reason: FAIL_TIMEOUT,
262            },
263        ])
264    }
265
266    /// Periodic tick: check timeouts.
267    pub fn tick(&mut self, now: f64) -> Vec<HolePunchAction> {
268        let elapsed = now - self.state_entered_at;
269        match self.state {
270            HolePunchState::Discovering if elapsed > DISCOVER_TIMEOUT => {
271                self.state = HolePunchState::Failed;
272                self.state_entered_at = now;
273                alloc::vec![HolePunchAction::Failed {
274                    session_id: self.session_id,
275                    reason: FAIL_PROBE,
276                }]
277            }
278            HolePunchState::Proposing if elapsed > PROPOSE_TIMEOUT => {
279                self.state = HolePunchState::Failed;
280                self.state_entered_at = now;
281                alloc::vec![HolePunchAction::Failed {
282                    session_id: self.session_id,
283                    reason: FAIL_TIMEOUT,
284                }]
285            }
286            HolePunchState::WaitingReady if elapsed > READY_TIMEOUT => {
287                self.state = HolePunchState::Failed;
288                self.state_entered_at = now;
289                alloc::vec![HolePunchAction::Failed {
290                    session_id: self.session_id,
291                    reason: FAIL_TIMEOUT,
292                }]
293            }
294            HolePunchState::Punching if elapsed > PUNCH_TIMEOUT => {
295                self.state = HolePunchState::Failed;
296                self.state_entered_at = now;
297                alloc::vec![HolePunchAction::Failed {
298                    session_id: self.session_id,
299                    reason: FAIL_TIMEOUT,
300                }]
301            }
302            _ => Vec::new(),
303        }
304    }
305
306    /// Build a reject response for a request payload without creating a full session.
307    ///
308    /// Used when the policy rejects all proposals.
309    pub fn build_reject(
310        link_id: [u8; 16],
311        request_payload: &[u8],
312        reason: u8,
313    ) -> Result<HolePunchAction, HolePunchError> {
314        let (session_id, _, _) = decode_upgrade_request(request_payload)?;
315        let payload = encode_upgrade_reject(&session_id, reason);
316        Ok(HolePunchAction::SendSignal {
317            link_id,
318            msgtype: UPGRADE_REJECT,
319            payload,
320        })
321    }
322
323    /// Reset engine back to Idle state for reuse.
324    pub fn reset(&mut self) {
325        self.state = HolePunchState::Idle;
326        self.session_id = [0u8; 16];
327        self.is_initiator = false;
328        self.our_public_endpoint = None;
329        self.peer_public_endpoint = None;
330        self.facilitator_addr = None;
331        self.punch_token = [0u8; 32];
332        self.state_entered_at = 0.0;
333    }
334
335    // --- Private handlers ---
336
337    /// Responder receives UPGRADE_REQUEST.
338    ///
339    /// Spec Phase 2: B evaluates the request, sends UPGRADE_ACCEPT, then
340    /// begins Phase 3 (STUN discovery using facilitator from the request).
341    fn handle_upgrade_request(
342        &mut self,
343        payload: &[u8],
344        derived_key: Option<&[u8]>,
345        now: f64,
346    ) -> Result<Vec<HolePunchAction>, HolePunchError> {
347        if self.state != HolePunchState::Idle {
348            // Already busy — reject
349            let (session_id, _, _) = decode_upgrade_request(payload)?;
350            let reject_payload = encode_upgrade_reject(&session_id, REJECT_BUSY);
351            return Ok(alloc::vec![HolePunchAction::SendSignal {
352                link_id: self.link_id,
353                msgtype: UPGRADE_REJECT,
354                payload: reject_payload,
355            }]);
356        }
357
358        let derived_key = derived_key.ok_or(HolePunchError::NoDerivedKey)?;
359        let (session_id, facilitator, initiator_public) = decode_upgrade_request(payload)?;
360
361        self.session_id = session_id;
362        self.is_initiator = false;
363        self.punch_token = derive_punch_token(derived_key, &session_id)?;
364
365        // Store A's public address (we'll punch this later)
366        self.peer_public_endpoint = Some(initiator_public);
367
368        // Use facilitator from the request for our own STUN discovery
369        self.facilitator_addr = Some(facilitator.clone());
370
371        self.state = HolePunchState::Discovering;
372        self.state_entered_at = now;
373
374        // Send UPGRADE_ACCEPT, then discover our endpoint using facilitator from request
375        let accept_payload = encode_upgrade_accept(&session_id);
376
377        Ok(alloc::vec![
378            HolePunchAction::SendSignal {
379                link_id: self.link_id,
380                msgtype: UPGRADE_ACCEPT,
381                payload: accept_payload,
382            },
383            HolePunchAction::DiscoverEndpoints { probe_addr: facilitator },
384        ])
385    }
386
387    /// Initiator receives UPGRADE_ACCEPT.
388    ///
389    /// Transitions: Proposing -> WaitingReady (waiting for B's UPGRADE_READY).
390    fn handle_upgrade_accept(
391        &mut self,
392        payload: &[u8],
393        now: f64,
394    ) -> Result<Vec<HolePunchAction>, HolePunchError> {
395        if self.state != HolePunchState::Proposing || !self.is_initiator {
396            return Err(HolePunchError::InvalidState);
397        }
398
399        let session_id = decode_upgrade_accept(payload)?;
400        if session_id != self.session_id {
401            return Err(HolePunchError::SessionMismatch);
402        }
403
404        self.state = HolePunchState::WaitingReady;
405        self.state_entered_at = now;
406
407        Ok(Vec::new())
408    }
409
410    /// Initiator receives UPGRADE_REJECT.
411    fn handle_upgrade_reject(
412        &mut self,
413        payload: &[u8],
414        now: f64,
415    ) -> Result<Vec<HolePunchAction>, HolePunchError> {
416        if self.state != HolePunchState::Proposing {
417            return Err(HolePunchError::InvalidState);
418        }
419
420        let (session_id, reason) = decode_upgrade_reject(payload)?;
421        if session_id != self.session_id {
422            return Err(HolePunchError::SessionMismatch);
423        }
424
425        self.state = HolePunchState::Failed;
426        self.state_entered_at = now;
427
428        Ok(alloc::vec![HolePunchAction::Failed {
429            session_id: self.session_id,
430            reason,
431        }])
432    }
433
434    /// Initiator receives UPGRADE_READY from responder.
435    ///
436    /// Spec Phase 3 complete: B has discovered its endpoint and sent it to A.
437    /// Both sides now start punching (Phase 4).
438    fn handle_upgrade_ready(
439        &mut self,
440        payload: &[u8],
441        now: f64,
442    ) -> Result<Vec<HolePunchAction>, HolePunchError> {
443        if self.state != HolePunchState::WaitingReady || !self.is_initiator {
444            return Err(HolePunchError::InvalidState);
445        }
446
447        let (session_id, responder_public) = decode_upgrade_ready(payload)?;
448        if session_id != self.session_id {
449            return Err(HolePunchError::SessionMismatch);
450        }
451
452        self.peer_public_endpoint = Some(responder_public.clone());
453
454        self.state = HolePunchState::Punching;
455        self.state_entered_at = now;
456
457        Ok(alloc::vec![HolePunchAction::StartUdpPunch {
458            peer_public: responder_public,
459            punch_token: self.punch_token,
460            session_id: self.session_id,
461        }])
462    }
463
464    /// Receives UPGRADE_COMPLETE (over direct UDP channel after punch succeeds).
465    fn handle_upgrade_complete(
466        &mut self,
467        payload: &[u8],
468        now: f64,
469    ) -> Result<Vec<HolePunchAction>, HolePunchError> {
470        if self.state != HolePunchState::Punching && self.state != HolePunchState::Connected {
471            return Err(HolePunchError::InvalidState);
472        }
473
474        let session_id = decode_session_only(payload)?;
475        if session_id != self.session_id {
476            return Err(HolePunchError::SessionMismatch);
477        }
478
479        if self.state == HolePunchState::Connected {
480            // Already connected — peer is confirming
481            return Ok(Vec::new());
482        }
483
484        self.state = HolePunchState::Connected;
485        self.state_entered_at = now;
486
487        Ok(alloc::vec![HolePunchAction::Succeeded {
488            session_id: self.session_id,
489        }])
490    }
491}
492
493// --- Msgpack encode/decode helpers ---
494
495fn encode_upgrade_request(
496    session_id: &[u8; 16],
497    facilitator: &Endpoint,
498    initiator_public: &Endpoint,
499) -> Vec<u8> {
500    let val = Value::Map(alloc::vec![
501        (Value::Str(alloc::string::String::from("s")), Value::Bin(session_id.to_vec())),
502        (Value::Str(alloc::string::String::from("f")), encode_endpoint(facilitator)),
503        (Value::Str(alloc::string::String::from("a")), encode_endpoint(initiator_public)),
504    ]);
505    msgpack::pack(&val)
506}
507
508fn decode_upgrade_request(data: &[u8]) -> Result<([u8; 16], Endpoint, Endpoint), HolePunchError> {
509    let (val, _) = msgpack::unpack(data).map_err(|_| HolePunchError::InvalidPayload)?;
510    let session_id = extract_session_id(&val)?;
511    let facilitator = val
512        .map_get("f")
513        .and_then(decode_endpoint)
514        .ok_or(HolePunchError::InvalidPayload)?;
515    let initiator_public = val
516        .map_get("a")
517        .and_then(decode_endpoint)
518        .ok_or(HolePunchError::InvalidPayload)?;
519    Ok((session_id, facilitator, initiator_public))
520}
521
522fn encode_upgrade_accept(session_id: &[u8; 16]) -> Vec<u8> {
523    let val = Value::Map(alloc::vec![
524        (Value::Str(alloc::string::String::from("s")), Value::Bin(session_id.to_vec())),
525    ]);
526    msgpack::pack(&val)
527}
528
529fn decode_upgrade_accept(data: &[u8]) -> Result<[u8; 16], HolePunchError> {
530    let (val, _) = msgpack::unpack(data).map_err(|_| HolePunchError::InvalidPayload)?;
531    extract_session_id(&val)
532}
533
534fn encode_upgrade_reject(session_id: &[u8; 16], reason: u8) -> Vec<u8> {
535    let val = Value::Map(alloc::vec![
536        (Value::Str(alloc::string::String::from("s")), Value::Bin(session_id.to_vec())),
537        (Value::Str(alloc::string::String::from("r")), Value::UInt(reason as u64)),
538    ]);
539    msgpack::pack(&val)
540}
541
542fn decode_upgrade_reject(data: &[u8]) -> Result<([u8; 16], u8), HolePunchError> {
543    let (val, _) = msgpack::unpack(data).map_err(|_| HolePunchError::InvalidPayload)?;
544    let session_id = extract_session_id(&val)?;
545    let reason = val
546        .map_get("r")
547        .and_then(|v| v.as_uint())
548        .ok_or(HolePunchError::InvalidPayload)? as u8;
549    Ok((session_id, reason))
550}
551
552fn encode_upgrade_ready(session_id: &[u8; 16], responder_public: &Endpoint) -> Vec<u8> {
553    let val = Value::Map(alloc::vec![
554        (Value::Str(alloc::string::String::from("s")), Value::Bin(session_id.to_vec())),
555        (Value::Str(alloc::string::String::from("a")), encode_endpoint(responder_public)),
556    ]);
557    msgpack::pack(&val)
558}
559
560fn decode_upgrade_ready(data: &[u8]) -> Result<([u8; 16], Endpoint), HolePunchError> {
561    let (val, _) = msgpack::unpack(data).map_err(|_| HolePunchError::InvalidPayload)?;
562    let session_id = extract_session_id(&val)?;
563    let responder_public = val
564        .map_get("a")
565        .and_then(decode_endpoint)
566        .ok_or(HolePunchError::InvalidPayload)?;
567    Ok((session_id, responder_public))
568}
569
570fn encode_endpoint(ep: &Endpoint) -> Value {
571    Value::Array(alloc::vec![
572        Value::Bin(ep.addr.clone()),
573        Value::UInt(ep.port as u64),
574    ])
575}
576
577fn decode_endpoint(val: &Value) -> Option<Endpoint> {
578    let arr = val.as_array()?;
579    if arr.len() < 2 {
580        return None;
581    }
582    let addr = arr[0].as_bin()?.to_vec();
583    let port = arr[1].as_uint()? as u16;
584    Some(Endpoint { addr, port })
585}
586
/// Test-only helper: packs a msgpack map holding just the key `"s"`
/// (session ID, binary).
#[cfg(test)]
fn encode_session_only(session_id: &[u8; 16]) -> Vec<u8> {
    let session_key = Value::Str(alloc::string::String::from("s"));
    let session_val = Value::Bin(session_id.to_vec());
    msgpack::pack(&Value::Map(alloc::vec![(session_key, session_val)]))
}
594
595fn decode_session_only(data: &[u8]) -> Result<[u8; 16], HolePunchError> {
596    let (val, _) = msgpack::unpack(data).map_err(|_| HolePunchError::InvalidPayload)?;
597    extract_session_id(&val)
598}
599
600fn extract_session_id(val: &Value) -> Result<[u8; 16], HolePunchError> {
601    let bin = val
602        .map_get("s")
603        .and_then(|v| v.as_bin())
604        .ok_or(HolePunchError::InvalidPayload)?;
605    if bin.len() != 16 {
606        return Err(HolePunchError::InvalidPayload);
607    }
608    let mut id = [0u8; 16];
609    id.copy_from_slice(bin);
610    Ok(id)
611}
612
613#[cfg(test)]
614mod tests {
615    use super::*;
616    use rns_crypto::FixedRng;
617
618    fn make_rng(seed: u8) -> FixedRng {
619        FixedRng::new(&[seed; 128])
620    }
621
622    fn test_derived_key() -> Vec<u8> {
623        vec![0xAA; 32]
624    }
625
626    fn test_probe_addr() -> Endpoint {
627        Endpoint {
628            addr: vec![127, 0, 0, 1],
629            port: 4343,
630        }
631    }
632
633    fn test_public_addr_a() -> Endpoint {
634        Endpoint {
635            addr: vec![1, 2, 3, 4],
636            port: 41000,
637        }
638    }
639
640    fn test_public_addr_b() -> Endpoint {
641        Endpoint {
642            addr: vec![5, 6, 7, 8],
643            port: 52000,
644        }
645    }
646
647    #[test]
648    fn test_propose_initiator_discovers_first() {
649        let link_id = [0x11; 16];
650        let derived_key = test_derived_key();
651        let mut rng = make_rng(0x42);
652
653        let mut initiator = HolePunchEngine::new(link_id, Some(test_probe_addr()));
654        let actions = initiator.propose(&derived_key, 100.0, &mut rng).unwrap();
655
656        // Should transition to Discovering, not Proposing
657        assert_eq!(initiator.state(), HolePunchState::Discovering);
658        assert_eq!(actions.len(), 1);
659        assert!(matches!(&actions[0], HolePunchAction::DiscoverEndpoints { .. }));
660    }
661
662    #[test]
663    fn test_initiator_sends_request_after_discovery() {
664        let link_id = [0x11; 16];
665        let derived_key = test_derived_key();
666        let mut rng = make_rng(0x42);
667
668        let mut initiator = HolePunchEngine::new(link_id, Some(test_probe_addr()));
669        initiator.propose(&derived_key, 100.0, &mut rng).unwrap();
670
671        // Initiator discovers its public endpoint
672        let actions = initiator
673            .endpoints_discovered(test_public_addr_a(), 101.0)
674            .unwrap();
675
676        // Should transition to Proposing and send UPGRADE_REQUEST
677        assert_eq!(initiator.state(), HolePunchState::Proposing);
678        assert_eq!(actions.len(), 1);
679        match &actions[0] {
680            HolePunchAction::SendSignal { msgtype, payload, .. } => {
681                assert_eq!(*msgtype, UPGRADE_REQUEST);
682                // Verify payload contains facilitator and initiator_public
683                let (sid, facilitator, init_pub) = decode_upgrade_request(payload).unwrap();
684                assert_eq!(sid, *initiator.session_id());
685                assert_eq!(facilitator, test_probe_addr());
686                assert_eq!(init_pub, test_public_addr_a());
687            }
688            _ => panic!("Expected SendSignal(UPGRADE_REQUEST)"),
689        }
690    }
691
692    #[test]
693    fn test_full_asymmetric_flow() {
694        let link_id = [0x22; 16];
695        let derived_key = test_derived_key();
696        let mut rng = make_rng(0x42);
697
698        // Phase 1: Initiator discovers its endpoint
699        let mut initiator = HolePunchEngine::new(link_id, Some(test_probe_addr()));
700        initiator.propose(&derived_key, 100.0, &mut rng).unwrap();
701        let actions = initiator
702            .endpoints_discovered(test_public_addr_a(), 101.0)
703            .unwrap();
704
705        let request_payload = match &actions[0] {
706            HolePunchAction::SendSignal { payload, .. } => payload.clone(),
707            _ => panic!(),
708        };
709
710        // Phase 2: Responder receives UPGRADE_REQUEST
711        let mut responder = HolePunchEngine::new(link_id, None); // no probe_addr needed, uses facilitator from request
712        let actions = responder
713            .handle_signal(UPGRADE_REQUEST, &request_payload, Some(&derived_key), 102.0)
714            .unwrap();
715
716        assert_eq!(responder.state(), HolePunchState::Discovering);
717        assert_eq!(actions.len(), 2); // UPGRADE_ACCEPT + DiscoverEndpoints
718
719        let accept_payload = match &actions[0] {
720            HolePunchAction::SendSignal { msgtype, payload, .. } => {
721                assert_eq!(*msgtype, UPGRADE_ACCEPT);
722                payload.clone()
723            }
724            _ => panic!("Expected UPGRADE_ACCEPT"),
725        };
726
727        // B discovers using facilitator from request
728        match &actions[1] {
729            HolePunchAction::DiscoverEndpoints { probe_addr } => {
730                assert_eq!(*probe_addr, test_probe_addr()); // facilitator from request
731            }
732            _ => panic!("Expected DiscoverEndpoints"),
733        }
734
735        // Initiator receives UPGRADE_ACCEPT -> WaitingReady
736        let actions = initiator
737            .handle_signal(UPGRADE_ACCEPT, &accept_payload, None, 103.0)
738            .unwrap();
739        assert_eq!(initiator.state(), HolePunchState::WaitingReady);
740        assert!(actions.is_empty()); // Just waiting
741
742        // Phase 3: Responder discovers its endpoint, sends UPGRADE_READY
743        let actions = responder
744            .endpoints_discovered(test_public_addr_b(), 104.0)
745            .unwrap();
746
747        assert_eq!(responder.state(), HolePunchState::Punching);
748        assert_eq!(actions.len(), 2); // UPGRADE_READY + StartUdpPunch
749
750        let ready_payload = match &actions[0] {
751            HolePunchAction::SendSignal { msgtype, payload, .. } => {
752                assert_eq!(*msgtype, UPGRADE_READY);
753                payload.clone()
754            }
755            _ => panic!("Expected UPGRADE_READY"),
756        };
757        assert!(matches!(&actions[1], HolePunchAction::StartUdpPunch { .. }));
758
759        // Phase 4: Initiator receives UPGRADE_READY -> Punching
760        let actions = initiator
761            .handle_signal(UPGRADE_READY, &ready_payload, None, 105.0)
762            .unwrap();
763
764        assert_eq!(initiator.state(), HolePunchState::Punching);
765        assert_eq!(actions.len(), 1);
766        match &actions[0] {
767            HolePunchAction::StartUdpPunch { peer_public, .. } => {
768                assert_eq!(*peer_public, test_public_addr_b());
769            }
770            _ => panic!("Expected StartUdpPunch"),
771        }
772
773        // Both derive the same punch token
774        assert_eq!(initiator.punch_token(), responder.punch_token());
775    }
776
777    #[test]
778    fn test_punch_success() {
779        let link_id = [0x33; 16];
780        let derived_key = test_derived_key();
781        let mut rng = make_rng(0x42);
782
783        let mut engine = HolePunchEngine::new(link_id, Some(test_probe_addr()));
784        engine.propose(&derived_key, 100.0, &mut rng).unwrap();
785        engine.state = HolePunchState::Punching;
786
787        let actions = engine.punch_succeeded(105.0).unwrap();
788        assert_eq!(engine.state(), HolePunchState::Connected);
789        assert_eq!(actions.len(), 1);
790        assert!(matches!(&actions[0], HolePunchAction::Succeeded { .. }));
791    }
792
793    #[test]
794    fn test_punch_failed() {
795        let link_id = [0x44; 16];
796        let derived_key = test_derived_key();
797        let mut rng = make_rng(0x42);
798
799        let mut engine = HolePunchEngine::new(link_id, Some(test_probe_addr()));
800        engine.propose(&derived_key, 100.0, &mut rng).unwrap();
801        engine.state = HolePunchState::Punching;
802
803        let actions = engine.punch_failed(120.0).unwrap();
804        assert_eq!(engine.state(), HolePunchState::Failed);
805        assert_eq!(actions.len(), 1);
806        assert!(matches!(&actions[0], HolePunchAction::Failed { .. }));
807    }
808
809    #[test]
810    fn test_reject_when_busy() {
811        let link_id = [0x55; 16];
812        let derived_key = test_derived_key();
813        let mut rng = make_rng(0x42);
814
815        // Create a request payload
816        let mut proposer = HolePunchEngine::new(link_id, Some(test_probe_addr()));
817        proposer.propose(&derived_key, 100.0, &mut rng).unwrap();
818        let actions = proposer.endpoints_discovered(test_public_addr_a(), 101.0).unwrap();
819        let request_payload = match &actions[0] {
820            HolePunchAction::SendSignal { payload, .. } => payload.clone(),
821            _ => panic!(),
822        };
823
824        // Responder is already busy (set to Discovering manually)
825        let mut responder = HolePunchEngine::new(link_id, Some(test_probe_addr()));
826        responder.state = HolePunchState::Discovering;
827
828        let actions = responder
829            .handle_signal(UPGRADE_REQUEST, &request_payload, Some(&derived_key), 102.0)
830            .unwrap();
831
832        // Should reject with REJECT_BUSY
833        assert_eq!(actions.len(), 1);
834        match &actions[0] {
835            HolePunchAction::SendSignal { msgtype, .. } => {
836                assert_eq!(*msgtype, UPGRADE_REJECT);
837            }
838            _ => panic!("Expected UPGRADE_REJECT"),
839        }
840    }
841
/// An initiator sitting in `Proposing` that is handed an UPGRADE_REJECT must
/// end up `Failed` and emit exactly one `Failed` action carrying the reject
/// reason from the payload.
#[test]
fn test_initiator_receives_reject() {
    let derived_key = test_derived_key();
    let mut rng = make_rng(0x42);

    // Drive the initiator through propose + endpoint discovery so that it is
    // waiting for the peer's answer.
    let mut engine = HolePunchEngine::new([0x66; 16], Some(test_probe_addr()));
    engine.propose(&derived_key, 100.0, &mut rng).unwrap();
    engine
        .endpoints_discovered(test_public_addr_a(), 101.0)
        .unwrap();
    assert_eq!(engine.state(), HolePunchState::Proposing);

    // Build the rejection for the session the initiator generated.
    let sid = *engine.session_id();
    let rejection = encode_upgrade_reject(&sid, REJECT_POLICY);

    // No derived key is passed along with the reject signal.
    let emitted = engine
        .handle_signal(UPGRADE_REJECT, &rejection, None, 102.0)
        .unwrap();

    assert_eq!(engine.state(), HolePunchState::Failed);
    assert_eq!(emitted.len(), 1);
    match &emitted[0] {
        HolePunchAction::Failed { reason, .. } => assert!(*reason == REJECT_POLICY),
        _ => panic!("Expected Failed action"),
    }
}
864
/// While in `Discovering`, `tick` emits nothing shortly before
/// `DISCOVER_TIMEOUT` has passed and fails the engine with `FAIL_PROBE`
/// shortly after.
#[test]
fn test_discover_timeout() {
    let proposed_at = 100.0;
    let key = test_derived_key();
    let mut rng = make_rng(0x42);

    let mut hp = HolePunchEngine::new([0x77; 16], Some(test_probe_addr()));
    hp.propose(&key, proposed_at, &mut rng).unwrap();
    assert_eq!(hp.state(), HolePunchState::Discovering);

    // Just short of the deadline: no actions.
    let early = hp.tick(proposed_at + DISCOVER_TIMEOUT - 1.0);
    assert!(early.is_empty());

    // Just past the deadline: the engine gives up.
    let late = hp.tick(proposed_at + DISCOVER_TIMEOUT + 1.0);
    assert_eq!(hp.state(), HolePunchState::Failed);
    let first = &late[0];
    assert!(matches!(first, HolePunchAction::Failed { reason, .. } if *reason == FAIL_PROBE));
}
884
/// An engine left in `Proposing` past `PROPOSE_TIMEOUT` must fail with
/// `FAIL_TIMEOUT`.
#[test]
fn test_propose_timeout() {
    let key = test_derived_key();
    let mut rng = make_rng(0x42);
    let discovered_at = 101.0;

    let mut hp = HolePunchEngine::new([0x88; 16], Some(test_probe_addr()));
    hp.propose(&key, 100.0, &mut rng).unwrap();
    hp.endpoints_discovered(test_public_addr_a(), discovered_at)
        .unwrap();
    assert_eq!(hp.state(), HolePunchState::Proposing);

    // Tick past the deadline, measured from the endpoint-discovery time.
    let emitted = hp.tick(discovered_at + PROPOSE_TIMEOUT + 1.0);
    assert_eq!(hp.state(), HolePunchState::Failed);
    match &emitted[0] {
        HolePunchAction::Failed { reason, .. } => assert!(*reason == FAIL_TIMEOUT),
        _ => panic!("Expected Failed action"),
    }
}
901
/// An engine in `WaitingReady` that is ticked past `READY_TIMEOUT` must fail
/// with `FAIL_TIMEOUT`.
#[test]
fn test_waiting_ready_timeout() {
    let key = test_derived_key();
    let mut rng = make_rng(0x42);

    let mut hp = HolePunchEngine::new([0x99; 16], Some(test_probe_addr()));
    hp.propose(&key, 200.0, &mut rng).unwrap();
    hp.endpoints_discovered(test_public_addr_a(), 201.0).unwrap();

    // Place the engine in WaitingReady by hand and pin the moment it got
    // there, so the deadline below is computed from a known timestamp.
    let entered_at = 202.0;
    hp.state = HolePunchState::WaitingReady;
    hp.state_entered_at = entered_at;

    let emitted = hp.tick(entered_at + READY_TIMEOUT + 1.0);
    assert_eq!(hp.state(), HolePunchState::Failed);
    let first = &emitted[0];
    assert!(matches!(first, HolePunchAction::Failed { reason, .. } if *reason == FAIL_TIMEOUT));
}
919
/// An engine in `Punching` emits nothing on `tick` shortly before
/// `PUNCH_TIMEOUT` has elapsed since `state_entered_at`, and is `Failed`
/// shortly after.
#[test]
fn test_punch_timeout() {
    let link_id = [0xAA; 16];
    let derived_key = test_derived_key();
    let mut rng = make_rng(0x42);

    let mut engine = HolePunchEngine::new(link_id, Some(test_probe_addr()));
    engine.propose(&derived_key, 100.0, &mut rng).unwrap();
    // Jump straight into the punching phase by setting the state by hand;
    // only the timeout handling is under test here.
    engine.state = HolePunchState::Punching;
    engine.state_entered_at = 200.0;

    // Before timeout
    let actions = engine.tick(200.0 + PUNCH_TIMEOUT - 1.0);
    assert!(actions.is_empty());

    // After timeout: only the resulting state is checked, so the returned
    // actions are discarded explicitly instead of being bound to a local
    // that is never read.
    // NOTE(review): the other timeout tests in this module also assert the
    // `Failed` action's reason — consider doing the same here once the
    // expected reason for a punch timeout is confirmed.
    let _ = engine.tick(200.0 + PUNCH_TIMEOUT + 1.0);
    assert_eq!(engine.state(), HolePunchState::Failed);
}
939
/// Every signalling payload encoder must be inverted by its matching decoder.
#[test]
fn test_message_serialization_roundtrip() {
    let expected_sid = [0xAB; 16];

    // UPGRADE_REQUEST: session id + facilitator + initiator public endpoint.
    {
        let probe = test_probe_addr();
        let initiator_pub = test_public_addr_a();
        let wire = encode_upgrade_request(&expected_sid, &probe, &initiator_pub);
        let (sid, decoded_probe, decoded_pub) = decode_upgrade_request(&wire).unwrap();
        assert_eq!(sid, expected_sid);
        assert_eq!(decoded_probe, probe);
        assert_eq!(decoded_pub, initiator_pub);
    }

    // UPGRADE_ACCEPT: session id only.
    {
        let wire = encode_upgrade_accept(&expected_sid);
        let sid = decode_upgrade_accept(&wire).unwrap();
        assert_eq!(sid, expected_sid);
    }

    // UPGRADE_REJECT: session id + reason code.
    {
        let wire = encode_upgrade_reject(&expected_sid, REJECT_POLICY);
        let (sid, reason) = decode_upgrade_reject(&wire).unwrap();
        assert_eq!(sid, expected_sid);
        assert_eq!(reason, REJECT_POLICY);
    }

    // UPGRADE_READY: session id + responder public endpoint.
    {
        let responder_pub = test_public_addr_b();
        let wire = encode_upgrade_ready(&expected_sid, &responder_pub);
        let (sid, decoded_pub) = decode_upgrade_ready(&wire).unwrap();
        assert_eq!(sid, expected_sid);
        assert_eq!(decoded_pub, responder_pub);
    }

    // Session only (UPGRADE_COMPLETE)
    {
        let wire = encode_session_only(&expected_sid);
        let sid = decode_session_only(&wire).unwrap();
        assert_eq!(sid, expected_sid);
    }
}
976
/// `derive_punch_token` must return the same token for the same
/// (key, session id) pair and a different one when the session id changes.
#[test]
fn test_punch_token_derivation_consistency() {
    let key: Vec<u8> = vec![0xBB; 32];
    // Small helper so each derivation below differs only in the session id.
    let token_for = |sid: [u8; 16]| derive_punch_token(&key, &sid).unwrap();

    // Same inputs twice -> identical output.
    let first = token_for([0xCC; 16]);
    let repeated = token_for([0xCC; 16]);
    assert_eq!(first, repeated);

    // Different session_id -> different token
    let other_session = token_for([0xDD; 16]);
    assert_ne!(first, other_session);
}
991
/// `reset` must bring an engine that is in `Discovering` back to `Idle` and
/// zero its session id.
#[test]
fn test_reset() {
    let key = test_derived_key();
    let mut rng = make_rng(0x42);
    let mut hp = HolePunchEngine::new([0xBB; 16], Some(test_probe_addr()));

    // Move the engine out of its initial state first.
    hp.propose(&key, 100.0, &mut rng).unwrap();
    assert_eq!(hp.state(), HolePunchState::Discovering);

    hp.reset();

    assert_eq!(hp.state(), HolePunchState::Idle);
    assert_eq!(*hp.session_id(), [0u8; 16]);
}
1006
/// The associated function `HolePunchEngine::build_reject` must turn a
/// request payload into a `SendSignal` action of type UPGRADE_REJECT without
/// needing an engine instance.
#[test]
fn test_build_reject_static() {
    let link_id = [0xCC; 16];
    let key = test_derived_key();
    let mut rng = make_rng(0x42);

    // Produce a genuine UPGRADE_REQUEST payload via a proposing engine.
    let mut proposer = HolePunchEngine::new(link_id, Some(test_probe_addr()));
    proposer.propose(&key, 100.0, &mut rng).unwrap();
    let proposed = proposer
        .endpoints_discovered(test_public_addr_a(), 101.0)
        .unwrap();
    let request = if let HolePunchAction::SendSignal { payload, .. } = &proposed[0] {
        payload.clone()
    } else {
        panic!()
    };

    let reject = HolePunchEngine::build_reject(link_id, &request, REJECT_POLICY).unwrap();
    if let HolePunchAction::SendSignal { msgtype, .. } = reject {
        assert_eq!(msgtype, UPGRADE_REJECT);
    } else {
        panic!("Expected SendSignal(UPGRADE_REJECT)");
    }
}
1029
/// A responder constructed without a probe address must still accept an
/// UPGRADE_REQUEST: it answers with UPGRADE_ACCEPT and starts endpoint
/// discovery, relying on the facilitator named in the request.
#[test]
fn test_responder_needs_no_probe_addr() {
    let link_id = [0xDD; 16];
    let key = test_derived_key();
    let mut rng = make_rng(0x42);

    // Let a fully configured initiator produce the request payload.
    let mut initiator = HolePunchEngine::new(link_id, Some(test_probe_addr()));
    initiator.propose(&key, 100.0, &mut rng).unwrap();
    let proposed = initiator
        .endpoints_discovered(test_public_addr_a(), 101.0)
        .unwrap();
    let request = if let HolePunchAction::SendSignal { payload, .. } = &proposed[0] {
        payload.clone()
    } else {
        panic!()
    };

    // Responder has NO probe_addr configured
    let mut responder = HolePunchEngine::new(link_id, None);
    let replies = responder
        .handle_signal(UPGRADE_REQUEST, &request, Some(&key), 102.0)
        .unwrap();

    // Should still work — uses facilitator from the request
    assert_eq!(responder.state(), HolePunchState::Discovering);
    assert_eq!(replies.len(), 2);
    let (accept, discover) = (&replies[0], &replies[1]);
    assert!(matches!(accept, HolePunchAction::SendSignal { msgtype, .. } if *msgtype == UPGRADE_ACCEPT));
    assert!(matches!(discover, HolePunchAction::DiscoverEndpoints { .. }));
}
1058}