Skip to main content

rns_core/holepunch/
engine.rs

1//! HolePunchEngine: pure-logic state machine for NAT hole punching.
2//!
3//! Follows the asymmetric protocol from direct-link-protocol.md:
4//!
5//! Phase 1: A probes facilitator T → learns A_pub
6//! Phase 2: A sends UPGRADE_REQUEST {facilitator: T, initiator_public: A_pub} → B
7//!          B responds with UPGRADE_ACCEPT or UPGRADE_REJECT
8//! Phase 3: B probes T (from request) → learns B_pub
9//!          B sends UPGRADE_READY {responder_public: B_pub} → A
10//! Phase 4: Both punch simultaneously
11//! Phase 5: Direct link established
12//!
13//! Methods return `Vec<HolePunchAction>` instead of performing I/O.
14
15use alloc::vec::Vec;
16
17use rns_crypto::hkdf::hkdf;
18use rns_crypto::Rng;
19
20use crate::msgpack::{self, Value};
21
22use super::types::*;
23
24/// Derives a 32-byte punch token from the link's derived key and session ID.
25///
26/// `HKDF(ikm=derived_key, salt=session_id, info="rns-holepunch-v1")[:32]`
27pub fn derive_punch_token(
28    derived_key: &[u8],
29    session_id: &[u8; 16],
30) -> Result<[u8; 32], HolePunchError> {
31    let result = hkdf(32, derived_key, Some(session_id), Some(b"rns-holepunch-v1"))
32        .map_err(|_| HolePunchError::NoDerivedKey)?;
33    let mut token = [0u8; 32];
34    token.copy_from_slice(&result);
35    Ok(token)
36}
37
38/// The hole-punch state machine for a single link.
39pub struct HolePunchEngine {
40    link_id: [u8; 16],
41    session_id: [u8; 16],
42    state: HolePunchState,
43    is_initiator: bool,
44
45    /// Our discovered public endpoint.
46    our_public_endpoint: Option<Endpoint>,
47
48    /// Peer's public endpoint.
49    /// Initiator: received from UPGRADE_READY.
50    /// Responder: received from UPGRADE_REQUEST.
51    peer_public_endpoint: Option<Endpoint>,
52
53    /// Facilitator (STUN probe) address.
54    /// Initiator: our configured probe address.
55    /// Responder: received from UPGRADE_REQUEST.
56    facilitator_addr: Option<Endpoint>,
57
58    /// Punch token derived from link key + session_id.
59    punch_token: [u8; 32],
60
61    /// Probe service address for endpoint discovery (configured on this node).
62    probe_addr: Option<Endpoint>,
63
64    /// Protocol to use for endpoint discovery.
65    probe_protocol: ProbeProtocol,
66
67    /// Timestamp of state entry (for timeout tracking).
68    state_entered_at: f64,
69}
70
71impl HolePunchEngine {
72    /// Create a new engine for the given link. Does not start any session.
73    pub fn new(
74        link_id: [u8; 16],
75        probe_addr: Option<Endpoint>,
76        probe_protocol: ProbeProtocol,
77    ) -> Self {
78        HolePunchEngine {
79            link_id,
80            session_id: [0u8; 16],
81            state: HolePunchState::Idle,
82            is_initiator: false,
83            our_public_endpoint: None,
84            peer_public_endpoint: None,
85            facilitator_addr: None,
86            punch_token: [0u8; 32],
87            probe_addr,
88            probe_protocol,
89            state_entered_at: 0.0,
90        }
91    }
92
93    pub fn state(&self) -> HolePunchState {
94        self.state
95    }
96
97    pub fn session_id(&self) -> &[u8; 16] {
98        &self.session_id
99    }
100
101    pub fn is_initiator(&self) -> bool {
102        self.is_initiator
103    }
104
105    pub fn punch_token(&self) -> &[u8; 32] {
106        &self.punch_token
107    }
108
109    /// Override the facilitator address.
110    ///
111    /// Used when the orchestrator discovers that a different probe server
112    /// succeeded (failover). Must be called before `endpoints_discovered()`
113    /// so the UPGRADE_REQUEST carries the correct facilitator.
114    pub fn set_facilitator(&mut self, addr: Endpoint) {
115        self.facilitator_addr = Some(addr);
116    }
117
118    /// Peer's discovered public endpoint.
119    pub fn peer_public_endpoint(&self) -> Option<&Endpoint> {
120        self.peer_public_endpoint.as_ref()
121    }
122
123    /// Propose a direct connection (initiator side).
124    ///
125    /// Per the spec, the initiator first discovers its own public endpoint
126    /// (Phase 1) before sending the upgrade request.
127    ///
128    /// Transitions: Idle -> Discovering.
129    pub fn propose(
130        &mut self,
131        derived_key: &[u8],
132        now: f64,
133        rng: &mut dyn Rng,
134    ) -> Result<Vec<HolePunchAction>, HolePunchError> {
135        if self.state != HolePunchState::Idle {
136            return Err(HolePunchError::InvalidState);
137        }
138
139        // Generate session ID
140        let mut session_id = [0u8; 16];
141        rng.fill_bytes(&mut session_id);
142        self.session_id = session_id;
143        self.is_initiator = true;
144
145        // Derive punch token
146        self.punch_token = derive_punch_token(derived_key, &session_id)?;
147
148        let probe_addr = self.probe_addr.clone().ok_or(HolePunchError::NoProbeAddr)?;
149        self.facilitator_addr = Some(probe_addr.clone());
150
151        // Phase 1: discover our public endpoint first
152        self.state = HolePunchState::Discovering;
153        self.state_entered_at = now;
154
155        Ok(alloc::vec![HolePunchAction::DiscoverEndpoints {
156            probe_addr,
157            protocol: self.probe_protocol
158        }])
159    }
160
161    /// Called when endpoint discovery completes.
162    ///
163    /// For initiator: Discovering -> Proposing (sends UPGRADE_REQUEST with facilitator + our addr).
164    /// For responder: Discovering -> Punching (sends UPGRADE_READY with our addr, starts punch).
165    pub fn endpoints_discovered(
166        &mut self,
167        public_endpoint: Endpoint,
168        now: f64,
169    ) -> Result<Vec<HolePunchAction>, HolePunchError> {
170        if self.state != HolePunchState::Discovering {
171            return Err(HolePunchError::InvalidState);
172        }
173
174        self.our_public_endpoint = Some(public_endpoint.clone());
175
176        if self.is_initiator {
177            // Initiator: Phase 1 complete -> Phase 2: send UPGRADE_REQUEST
178            let facilitator = self
179                .facilitator_addr
180                .clone()
181                .ok_or(HolePunchError::NoProbeAddr)?;
182
183            let payload = encode_upgrade_request(
184                &self.session_id,
185                &facilitator,
186                &public_endpoint,
187                self.probe_protocol,
188            );
189
190            self.state = HolePunchState::Proposing;
191            self.state_entered_at = now;
192
193            Ok(alloc::vec![HolePunchAction::SendSignal {
194                link_id: self.link_id,
195                msgtype: UPGRADE_REQUEST,
196                payload,
197            }])
198        } else {
199            // Responder: Phase 3 complete -> send UPGRADE_READY, start punching
200            let payload = encode_upgrade_ready(&self.session_id, &public_endpoint);
201
202            let peer_public = self
203                .peer_public_endpoint
204                .clone()
205                .ok_or(HolePunchError::InvalidState)?;
206
207            self.state = HolePunchState::Punching;
208            self.state_entered_at = now;
209
210            Ok(alloc::vec![
211                HolePunchAction::SendSignal {
212                    link_id: self.link_id,
213                    msgtype: UPGRADE_READY,
214                    payload,
215                },
216                HolePunchAction::StartUdpPunch {
217                    peer_public,
218                    punch_token: self.punch_token,
219                    session_id: self.session_id,
220                },
221            ])
222        }
223    }
224
225    /// Handle an incoming signaling message.
226    ///
227    /// `derived_key` is needed when handling UPGRADE_REQUEST (responder side).
228    pub fn handle_signal(
229        &mut self,
230        msgtype: u16,
231        payload: &[u8],
232        derived_key: Option<&[u8]>,
233        now: f64,
234    ) -> Result<Vec<HolePunchAction>, HolePunchError> {
235        match msgtype {
236            UPGRADE_REQUEST => self.handle_upgrade_request(payload, derived_key, now),
237            UPGRADE_ACCEPT => self.handle_upgrade_accept(payload, now),
238            UPGRADE_REJECT => self.handle_upgrade_reject(payload, now),
239            UPGRADE_READY => self.handle_upgrade_ready(payload, now),
240            UPGRADE_COMPLETE => self.handle_upgrade_complete(payload, now),
241            _ => Err(HolePunchError::InvalidPayload),
242        }
243    }
244
245    /// Called when the punch phase succeeds.
246    ///
247    /// Transitions: Punching -> Connected.
248    pub fn punch_succeeded(&mut self, now: f64) -> Result<Vec<HolePunchAction>, HolePunchError> {
249        if self.state != HolePunchState::Punching {
250            return Err(HolePunchError::InvalidState);
251        }
252
253        self.state = HolePunchState::Connected;
254        self.state_entered_at = now;
255
256        Ok(alloc::vec![HolePunchAction::Succeeded {
257            session_id: self.session_id,
258        },])
259    }
260
261    /// Called when the punch phase fails.
262    ///
263    /// Transitions: Punching -> Failed.
264    pub fn punch_failed(&mut self, now: f64) -> Result<Vec<HolePunchAction>, HolePunchError> {
265        if self.state != HolePunchState::Punching {
266            return Err(HolePunchError::InvalidState);
267        }
268
269        self.state = HolePunchState::Failed;
270        self.state_entered_at = now;
271
272        Ok(alloc::vec![HolePunchAction::Failed {
273            session_id: self.session_id,
274            reason: FAIL_TIMEOUT,
275        },])
276    }
277
278    /// Periodic tick: check timeouts.
279    pub fn tick(&mut self, now: f64) -> Vec<HolePunchAction> {
280        let elapsed = now - self.state_entered_at;
281        match self.state {
282            HolePunchState::Discovering if elapsed > DISCOVER_TIMEOUT => {
283                self.state = HolePunchState::Failed;
284                self.state_entered_at = now;
285                alloc::vec![HolePunchAction::Failed {
286                    session_id: self.session_id,
287                    reason: FAIL_PROBE,
288                }]
289            }
290            HolePunchState::Proposing if elapsed > PROPOSE_TIMEOUT => {
291                self.state = HolePunchState::Failed;
292                self.state_entered_at = now;
293                alloc::vec![HolePunchAction::Failed {
294                    session_id: self.session_id,
295                    reason: FAIL_TIMEOUT,
296                }]
297            }
298            HolePunchState::WaitingReady if elapsed > READY_TIMEOUT => {
299                self.state = HolePunchState::Failed;
300                self.state_entered_at = now;
301                alloc::vec![HolePunchAction::Failed {
302                    session_id: self.session_id,
303                    reason: FAIL_TIMEOUT,
304                }]
305            }
306            HolePunchState::Punching if elapsed > PUNCH_TIMEOUT => {
307                self.state = HolePunchState::Failed;
308                self.state_entered_at = now;
309                alloc::vec![HolePunchAction::Failed {
310                    session_id: self.session_id,
311                    reason: FAIL_TIMEOUT,
312                }]
313            }
314            _ => Vec::new(),
315        }
316    }
317
318    /// Build a reject response for a request payload without creating a full session.
319    ///
320    /// Used when the policy rejects all proposals.
321    pub fn build_reject(
322        link_id: [u8; 16],
323        request_payload: &[u8],
324        reason: u8,
325    ) -> Result<HolePunchAction, HolePunchError> {
326        let (session_id, _, _, _) = decode_upgrade_request(request_payload)?;
327        let payload = encode_upgrade_reject(&session_id, reason);
328        Ok(HolePunchAction::SendSignal {
329            link_id,
330            msgtype: UPGRADE_REJECT,
331            payload,
332        })
333    }
334
335    /// Reset engine back to Idle state for reuse.
336    pub fn reset(&mut self) {
337        self.state = HolePunchState::Idle;
338        self.session_id = [0u8; 16];
339        self.is_initiator = false;
340        self.our_public_endpoint = None;
341        self.peer_public_endpoint = None;
342        self.facilitator_addr = None;
343        self.punch_token = [0u8; 32];
344        self.probe_protocol = ProbeProtocol::Rnsp;
345        self.state_entered_at = 0.0;
346    }
347
348    // --- Private handlers ---
349
350    /// Responder receives UPGRADE_REQUEST.
351    ///
352    /// Spec Phase 2: B evaluates the request, sends UPGRADE_ACCEPT, then
353    /// begins Phase 3 (STUN discovery using facilitator from the request).
354    fn handle_upgrade_request(
355        &mut self,
356        payload: &[u8],
357        derived_key: Option<&[u8]>,
358        now: f64,
359    ) -> Result<Vec<HolePunchAction>, HolePunchError> {
360        if self.state != HolePunchState::Idle {
361            // Already busy — reject
362            let (session_id, _, _, _) = decode_upgrade_request(payload)?;
363            let reject_payload = encode_upgrade_reject(&session_id, REJECT_BUSY);
364            return Ok(alloc::vec![HolePunchAction::SendSignal {
365                link_id: self.link_id,
366                msgtype: UPGRADE_REJECT,
367                payload: reject_payload,
368            }]);
369        }
370
371        let derived_key = derived_key.ok_or(HolePunchError::NoDerivedKey)?;
372        let (session_id, facilitator, initiator_public, protocol) =
373            decode_upgrade_request(payload)?;
374
375        self.session_id = session_id;
376        self.is_initiator = false;
377        self.probe_protocol = protocol;
378        self.punch_token = derive_punch_token(derived_key, &session_id)?;
379
380        // Store A's public address (we'll punch this later)
381        self.peer_public_endpoint = Some(initiator_public);
382
383        // Use facilitator from the request for our own STUN discovery
384        self.facilitator_addr = Some(facilitator.clone());
385
386        self.state = HolePunchState::Discovering;
387        self.state_entered_at = now;
388
389        // Send UPGRADE_ACCEPT, then discover our endpoint using facilitator from request
390        let accept_payload = encode_upgrade_accept(&session_id);
391
392        Ok(alloc::vec![
393            HolePunchAction::SendSignal {
394                link_id: self.link_id,
395                msgtype: UPGRADE_ACCEPT,
396                payload: accept_payload,
397            },
398            HolePunchAction::DiscoverEndpoints {
399                probe_addr: facilitator,
400                protocol
401            },
402        ])
403    }
404
405    /// Initiator receives UPGRADE_ACCEPT.
406    ///
407    /// Transitions: Proposing -> WaitingReady (waiting for B's UPGRADE_READY).
408    fn handle_upgrade_accept(
409        &mut self,
410        payload: &[u8],
411        now: f64,
412    ) -> Result<Vec<HolePunchAction>, HolePunchError> {
413        if self.state != HolePunchState::Proposing || !self.is_initiator {
414            return Err(HolePunchError::InvalidState);
415        }
416
417        let session_id = decode_upgrade_accept(payload)?;
418        if session_id != self.session_id {
419            return Err(HolePunchError::SessionMismatch);
420        }
421
422        self.state = HolePunchState::WaitingReady;
423        self.state_entered_at = now;
424
425        Ok(Vec::new())
426    }
427
428    /// Initiator receives UPGRADE_REJECT.
429    fn handle_upgrade_reject(
430        &mut self,
431        payload: &[u8],
432        now: f64,
433    ) -> Result<Vec<HolePunchAction>, HolePunchError> {
434        if self.state != HolePunchState::Proposing {
435            return Err(HolePunchError::InvalidState);
436        }
437
438        let (session_id, reason) = decode_upgrade_reject(payload)?;
439        if session_id != self.session_id {
440            return Err(HolePunchError::SessionMismatch);
441        }
442
443        self.state = HolePunchState::Failed;
444        self.state_entered_at = now;
445
446        Ok(alloc::vec![HolePunchAction::Failed {
447            session_id: self.session_id,
448            reason,
449        }])
450    }
451
452    /// Initiator receives UPGRADE_READY from responder.
453    ///
454    /// Spec Phase 3 complete: B has discovered its endpoint and sent it to A.
455    /// Both sides now start punching (Phase 4).
456    fn handle_upgrade_ready(
457        &mut self,
458        payload: &[u8],
459        now: f64,
460    ) -> Result<Vec<HolePunchAction>, HolePunchError> {
461        if self.state != HolePunchState::WaitingReady || !self.is_initiator {
462            return Err(HolePunchError::InvalidState);
463        }
464
465        let (session_id, responder_public) = decode_upgrade_ready(payload)?;
466        if session_id != self.session_id {
467            return Err(HolePunchError::SessionMismatch);
468        }
469
470        self.peer_public_endpoint = Some(responder_public.clone());
471
472        self.state = HolePunchState::Punching;
473        self.state_entered_at = now;
474
475        Ok(alloc::vec![HolePunchAction::StartUdpPunch {
476            peer_public: responder_public,
477            punch_token: self.punch_token,
478            session_id: self.session_id,
479        }])
480    }
481
482    /// Receives UPGRADE_COMPLETE (over direct UDP channel after punch succeeds).
483    fn handle_upgrade_complete(
484        &mut self,
485        payload: &[u8],
486        now: f64,
487    ) -> Result<Vec<HolePunchAction>, HolePunchError> {
488        if self.state != HolePunchState::Punching && self.state != HolePunchState::Connected {
489            return Err(HolePunchError::InvalidState);
490        }
491
492        let session_id = decode_session_only(payload)?;
493        if session_id != self.session_id {
494            return Err(HolePunchError::SessionMismatch);
495        }
496
497        if self.state == HolePunchState::Connected {
498            // Already connected — peer is confirming
499            return Ok(Vec::new());
500        }
501
502        self.state = HolePunchState::Connected;
503        self.state_entered_at = now;
504
505        Ok(alloc::vec![HolePunchAction::Succeeded {
506            session_id: self.session_id,
507        }])
508    }
509}
510
511// --- Msgpack encode/decode helpers ---
512
513fn encode_upgrade_request(
514    session_id: &[u8; 16],
515    facilitator: &Endpoint,
516    initiator_public: &Endpoint,
517    protocol: ProbeProtocol,
518) -> Vec<u8> {
519    let mut fields = alloc::vec![
520        (
521            Value::Str(alloc::string::String::from("s")),
522            Value::Bin(session_id.to_vec())
523        ),
524        (
525            Value::Str(alloc::string::String::from("f")),
526            encode_endpoint(facilitator)
527        ),
528        (
529            Value::Str(alloc::string::String::from("a")),
530            encode_endpoint(initiator_public)
531        ),
532    ];
533    // Only include "p" when not RNSP (backward compat: old nodes don't send it)
534    if protocol != ProbeProtocol::Rnsp {
535        fields.push((
536            Value::Str(alloc::string::String::from("p")),
537            Value::UInt(protocol as u64),
538        ));
539    }
540    let val = Value::Map(fields);
541    msgpack::pack(&val)
542}
543
544fn decode_upgrade_request(
545    data: &[u8],
546) -> Result<([u8; 16], Endpoint, Endpoint, ProbeProtocol), HolePunchError> {
547    let (val, _) = msgpack::unpack(data).map_err(|_| HolePunchError::InvalidPayload)?;
548    let session_id = extract_session_id(&val)?;
549    let facilitator = val
550        .map_get("f")
551        .and_then(decode_endpoint)
552        .ok_or(HolePunchError::InvalidPayload)?;
553    let initiator_public = val
554        .map_get("a")
555        .and_then(decode_endpoint)
556        .ok_or(HolePunchError::InvalidPayload)?;
557    // Fallback to Rnsp when "p" is absent (old nodes don't send it)
558    let protocol = val
559        .map_get("p")
560        .and_then(|v| v.as_uint())
561        .map(|p| match p {
562            1 => ProbeProtocol::Stun,
563            _ => ProbeProtocol::Rnsp,
564        })
565        .unwrap_or(ProbeProtocol::Rnsp);
566    Ok((session_id, facilitator, initiator_public, protocol))
567}
568
569fn encode_upgrade_accept(session_id: &[u8; 16]) -> Vec<u8> {
570    let val = Value::Map(alloc::vec![(
571        Value::Str(alloc::string::String::from("s")),
572        Value::Bin(session_id.to_vec())
573    ),]);
574    msgpack::pack(&val)
575}
576
577fn decode_upgrade_accept(data: &[u8]) -> Result<[u8; 16], HolePunchError> {
578    let (val, _) = msgpack::unpack(data).map_err(|_| HolePunchError::InvalidPayload)?;
579    extract_session_id(&val)
580}
581
582fn encode_upgrade_reject(session_id: &[u8; 16], reason: u8) -> Vec<u8> {
583    let val = Value::Map(alloc::vec![
584        (
585            Value::Str(alloc::string::String::from("s")),
586            Value::Bin(session_id.to_vec())
587        ),
588        (
589            Value::Str(alloc::string::String::from("r")),
590            Value::UInt(reason as u64)
591        ),
592    ]);
593    msgpack::pack(&val)
594}
595
596fn decode_upgrade_reject(data: &[u8]) -> Result<([u8; 16], u8), HolePunchError> {
597    let (val, _) = msgpack::unpack(data).map_err(|_| HolePunchError::InvalidPayload)?;
598    let session_id = extract_session_id(&val)?;
599    let reason = val
600        .map_get("r")
601        .and_then(|v| v.as_uint())
602        .ok_or(HolePunchError::InvalidPayload)? as u8;
603    Ok((session_id, reason))
604}
605
606fn encode_upgrade_ready(session_id: &[u8; 16], responder_public: &Endpoint) -> Vec<u8> {
607    let val = Value::Map(alloc::vec![
608        (
609            Value::Str(alloc::string::String::from("s")),
610            Value::Bin(session_id.to_vec())
611        ),
612        (
613            Value::Str(alloc::string::String::from("a")),
614            encode_endpoint(responder_public)
615        ),
616    ]);
617    msgpack::pack(&val)
618}
619
620fn decode_upgrade_ready(data: &[u8]) -> Result<([u8; 16], Endpoint), HolePunchError> {
621    let (val, _) = msgpack::unpack(data).map_err(|_| HolePunchError::InvalidPayload)?;
622    let session_id = extract_session_id(&val)?;
623    let responder_public = val
624        .map_get("a")
625        .and_then(decode_endpoint)
626        .ok_or(HolePunchError::InvalidPayload)?;
627    Ok((session_id, responder_public))
628}
629
630fn encode_endpoint(ep: &Endpoint) -> Value {
631    Value::Array(alloc::vec![
632        Value::Bin(ep.addr.clone()),
633        Value::UInt(ep.port as u64),
634    ])
635}
636
637fn decode_endpoint(val: &Value) -> Option<Endpoint> {
638    let arr = val.as_array()?;
639    if arr.len() < 2 {
640        return None;
641    }
642    let addr = arr[0].as_bin()?.to_vec();
643    let port = arr[1].as_uint()? as u16;
644    Some(Endpoint { addr, port })
645}
646
/// Test helper: encode a payload carrying only `{"s": session_id}`.
#[cfg(test)]
fn encode_session_only(session_id: &[u8; 16]) -> Vec<u8> {
    let sid_key = Value::Str(alloc::string::String::from("s"));
    let sid_val = Value::Bin(session_id.to_vec());
    msgpack::pack(&Value::Map(alloc::vec![(sid_key, sid_val)]))
}
655
656fn decode_session_only(data: &[u8]) -> Result<[u8; 16], HolePunchError> {
657    let (val, _) = msgpack::unpack(data).map_err(|_| HolePunchError::InvalidPayload)?;
658    extract_session_id(&val)
659}
660
661fn extract_session_id(val: &Value) -> Result<[u8; 16], HolePunchError> {
662    let bin = val
663        .map_get("s")
664        .and_then(|v| v.as_bin())
665        .ok_or(HolePunchError::InvalidPayload)?;
666    if bin.len() != 16 {
667        return Err(HolePunchError::InvalidPayload);
668    }
669    let mut id = [0u8; 16];
670    id.copy_from_slice(bin);
671    Ok(id)
672}
673
674#[cfg(test)]
675mod tests {
676    use super::*;
677    use rns_crypto::FixedRng;
678
679    fn make_rng(seed: u8) -> FixedRng {
680        FixedRng::new(&[seed; 128])
681    }
682
683    fn test_derived_key() -> Vec<u8> {
684        vec![0xAA; 32]
685    }
686
687    fn test_probe_addr() -> Endpoint {
688        Endpoint {
689            addr: vec![127, 0, 0, 1],
690            port: 4343,
691        }
692    }
693
694    fn test_public_addr_a() -> Endpoint {
695        Endpoint {
696            addr: vec![1, 2, 3, 4],
697            port: 41000,
698        }
699    }
700
701    fn test_public_addr_b() -> Endpoint {
702        Endpoint {
703            addr: vec![5, 6, 7, 8],
704            port: 52000,
705        }
706    }
707
708    #[test]
709    fn test_propose_initiator_discovers_first() {
710        let link_id = [0x11; 16];
711        let derived_key = test_derived_key();
712        let mut rng = make_rng(0x42);
713
714        let mut initiator =
715            HolePunchEngine::new(link_id, Some(test_probe_addr()), ProbeProtocol::Rnsp);
716        let actions = initiator.propose(&derived_key, 100.0, &mut rng).unwrap();
717
718        // Should transition to Discovering, not Proposing
719        assert_eq!(initiator.state(), HolePunchState::Discovering);
720        assert_eq!(actions.len(), 1);
721        assert!(matches!(
722            &actions[0],
723            HolePunchAction::DiscoverEndpoints { .. }
724        ));
725    }
726
727    #[test]
728    fn test_initiator_sends_request_after_discovery() {
729        let link_id = [0x11; 16];
730        let derived_key = test_derived_key();
731        let mut rng = make_rng(0x42);
732
733        let mut initiator =
734            HolePunchEngine::new(link_id, Some(test_probe_addr()), ProbeProtocol::Rnsp);
735        initiator.propose(&derived_key, 100.0, &mut rng).unwrap();
736
737        // Initiator discovers its public endpoint
738        let actions = initiator
739            .endpoints_discovered(test_public_addr_a(), 101.0)
740            .unwrap();
741
742        // Should transition to Proposing and send UPGRADE_REQUEST
743        assert_eq!(initiator.state(), HolePunchState::Proposing);
744        assert_eq!(actions.len(), 1);
745        match &actions[0] {
746            HolePunchAction::SendSignal {
747                msgtype, payload, ..
748            } => {
749                assert_eq!(*msgtype, UPGRADE_REQUEST);
750                // Verify payload contains facilitator and initiator_public
751                let (sid, facilitator, init_pub, _proto) = decode_upgrade_request(payload).unwrap();
752                assert_eq!(sid, *initiator.session_id());
753                assert_eq!(facilitator, test_probe_addr());
754                assert_eq!(init_pub, test_public_addr_a());
755            }
756            _ => panic!("Expected SendSignal(UPGRADE_REQUEST)"),
757        }
758    }
759
760    #[test]
761    fn test_full_asymmetric_flow() {
762        let link_id = [0x22; 16];
763        let derived_key = test_derived_key();
764        let mut rng = make_rng(0x42);
765
766        // Phase 1: Initiator discovers its endpoint
767        let mut initiator =
768            HolePunchEngine::new(link_id, Some(test_probe_addr()), ProbeProtocol::Rnsp);
769        initiator.propose(&derived_key, 100.0, &mut rng).unwrap();
770        let actions = initiator
771            .endpoints_discovered(test_public_addr_a(), 101.0)
772            .unwrap();
773
774        let request_payload = match &actions[0] {
775            HolePunchAction::SendSignal { payload, .. } => payload.clone(),
776            _ => panic!(),
777        };
778
779        // Phase 2: Responder receives UPGRADE_REQUEST
780        let mut responder = HolePunchEngine::new(link_id, None, ProbeProtocol::Rnsp); // no probe_addr needed, uses facilitator from request
781        let actions = responder
782            .handle_signal(UPGRADE_REQUEST, &request_payload, Some(&derived_key), 102.0)
783            .unwrap();
784
785        assert_eq!(responder.state(), HolePunchState::Discovering);
786        assert_eq!(actions.len(), 2); // UPGRADE_ACCEPT + DiscoverEndpoints
787
788        let accept_payload = match &actions[0] {
789            HolePunchAction::SendSignal {
790                msgtype, payload, ..
791            } => {
792                assert_eq!(*msgtype, UPGRADE_ACCEPT);
793                payload.clone()
794            }
795            _ => panic!("Expected UPGRADE_ACCEPT"),
796        };
797
798        // B discovers using facilitator from request
799        match &actions[1] {
800            HolePunchAction::DiscoverEndpoints { probe_addr, .. } => {
801                assert_eq!(*probe_addr, test_probe_addr()); // facilitator from request
802            }
803            _ => panic!("Expected DiscoverEndpoints"),
804        }
805
806        // Initiator receives UPGRADE_ACCEPT -> WaitingReady
807        let actions = initiator
808            .handle_signal(UPGRADE_ACCEPT, &accept_payload, None, 103.0)
809            .unwrap();
810        assert_eq!(initiator.state(), HolePunchState::WaitingReady);
811        assert!(actions.is_empty()); // Just waiting
812
813        // Phase 3: Responder discovers its endpoint, sends UPGRADE_READY
814        let actions = responder
815            .endpoints_discovered(test_public_addr_b(), 104.0)
816            .unwrap();
817
818        assert_eq!(responder.state(), HolePunchState::Punching);
819        assert_eq!(actions.len(), 2); // UPGRADE_READY + StartUdpPunch
820
821        let ready_payload = match &actions[0] {
822            HolePunchAction::SendSignal {
823                msgtype, payload, ..
824            } => {
825                assert_eq!(*msgtype, UPGRADE_READY);
826                payload.clone()
827            }
828            _ => panic!("Expected UPGRADE_READY"),
829        };
830        assert!(matches!(&actions[1], HolePunchAction::StartUdpPunch { .. }));
831
832        // Phase 4: Initiator receives UPGRADE_READY -> Punching
833        let actions = initiator
834            .handle_signal(UPGRADE_READY, &ready_payload, None, 105.0)
835            .unwrap();
836
837        assert_eq!(initiator.state(), HolePunchState::Punching);
838        assert_eq!(actions.len(), 1);
839        match &actions[0] {
840            HolePunchAction::StartUdpPunch { peer_public, .. } => {
841                assert_eq!(*peer_public, test_public_addr_b());
842            }
843            _ => panic!("Expected StartUdpPunch"),
844        }
845
846        // Both derive the same punch token
847        assert_eq!(initiator.punch_token(), responder.punch_token());
848    }
849
850    #[test]
851    fn test_punch_success() {
852        let link_id = [0x33; 16];
853        let derived_key = test_derived_key();
854        let mut rng = make_rng(0x42);
855
856        let mut engine =
857            HolePunchEngine::new(link_id, Some(test_probe_addr()), ProbeProtocol::Rnsp);
858        engine.propose(&derived_key, 100.0, &mut rng).unwrap();
859        engine.state = HolePunchState::Punching;
860
861        let actions = engine.punch_succeeded(105.0).unwrap();
862        assert_eq!(engine.state(), HolePunchState::Connected);
863        assert_eq!(actions.len(), 1);
864        assert!(matches!(&actions[0], HolePunchAction::Succeeded { .. }));
865    }
866
867    #[test]
868    fn test_punch_failed() {
869        let link_id = [0x44; 16];
870        let derived_key = test_derived_key();
871        let mut rng = make_rng(0x42);
872
873        let mut engine =
874            HolePunchEngine::new(link_id, Some(test_probe_addr()), ProbeProtocol::Rnsp);
875        engine.propose(&derived_key, 100.0, &mut rng).unwrap();
876        engine.state = HolePunchState::Punching;
877
878        let actions = engine.punch_failed(120.0).unwrap();
879        assert_eq!(engine.state(), HolePunchState::Failed);
880        assert_eq!(actions.len(), 1);
881        assert!(matches!(&actions[0], HolePunchAction::Failed { .. }));
882    }
883
884    #[test]
885    fn test_reject_when_busy() {
886        let link_id = [0x55; 16];
887        let derived_key = test_derived_key();
888        let mut rng = make_rng(0x42);
889
890        // Create a request payload
891        let mut proposer =
892            HolePunchEngine::new(link_id, Some(test_probe_addr()), ProbeProtocol::Rnsp);
893        proposer.propose(&derived_key, 100.0, &mut rng).unwrap();
894        let actions = proposer
895            .endpoints_discovered(test_public_addr_a(), 101.0)
896            .unwrap();
897        let request_payload = match &actions[0] {
898            HolePunchAction::SendSignal { payload, .. } => payload.clone(),
899            _ => panic!(),
900        };
901
902        // Responder is already busy (set to Discovering manually)
903        let mut responder =
904            HolePunchEngine::new(link_id, Some(test_probe_addr()), ProbeProtocol::Rnsp);
905        responder.state = HolePunchState::Discovering;
906
907        let actions = responder
908            .handle_signal(UPGRADE_REQUEST, &request_payload, Some(&derived_key), 102.0)
909            .unwrap();
910
911        // Should reject with REJECT_BUSY
912        assert_eq!(actions.len(), 1);
913        match &actions[0] {
914            HolePunchAction::SendSignal { msgtype, .. } => {
915                assert_eq!(*msgtype, UPGRADE_REJECT);
916            }
917            _ => panic!("Expected UPGRADE_REJECT"),
918        }
919    }
920
921    #[test]
922    fn test_initiator_receives_reject() {
923        let link_id = [0x66; 16];
924        let derived_key = test_derived_key();
925        let mut rng = make_rng(0x42);
926
927        let mut initiator =
928            HolePunchEngine::new(link_id, Some(test_probe_addr()), ProbeProtocol::Rnsp);
929        initiator.propose(&derived_key, 100.0, &mut rng).unwrap();
930        initiator
931            .endpoints_discovered(test_public_addr_a(), 101.0)
932            .unwrap();
933        assert_eq!(initiator.state(), HolePunchState::Proposing);
934
935        let session_id = *initiator.session_id();
936        let reject_payload = encode_upgrade_reject(&session_id, REJECT_POLICY);
937
938        let actions = initiator
939            .handle_signal(UPGRADE_REJECT, &reject_payload, None, 102.0)
940            .unwrap();
941
942        assert_eq!(initiator.state(), HolePunchState::Failed);
943        assert_eq!(actions.len(), 1);
944        assert!(
945            matches!(&actions[0], HolePunchAction::Failed { reason, .. } if *reason == REJECT_POLICY)
946        );
947    }
948
949    #[test]
950    fn test_discover_timeout() {
951        let link_id = [0x77; 16];
952        let derived_key = test_derived_key();
953        let mut rng = make_rng(0x42);
954
955        let mut engine =
956            HolePunchEngine::new(link_id, Some(test_probe_addr()), ProbeProtocol::Rnsp);
957        engine.propose(&derived_key, 100.0, &mut rng).unwrap();
958        assert_eq!(engine.state(), HolePunchState::Discovering);
959
960        // Before timeout
961        let actions = engine.tick(100.0 + DISCOVER_TIMEOUT - 1.0);
962        assert!(actions.is_empty());
963
964        // After timeout
965        let actions = engine.tick(100.0 + DISCOVER_TIMEOUT + 1.0);
966        assert_eq!(engine.state(), HolePunchState::Failed);
967        assert!(
968            matches!(&actions[0], HolePunchAction::Failed { reason, .. } if *reason == FAIL_PROBE)
969        );
970    }
971
972    #[test]
973    fn test_propose_timeout() {
974        let link_id = [0x88; 16];
975        let derived_key = test_derived_key();
976        let mut rng = make_rng(0x42);
977
978        let mut engine =
979            HolePunchEngine::new(link_id, Some(test_probe_addr()), ProbeProtocol::Rnsp);
980        engine.propose(&derived_key, 100.0, &mut rng).unwrap();
981        engine
982            .endpoints_discovered(test_public_addr_a(), 101.0)
983            .unwrap();
984        assert_eq!(engine.state(), HolePunchState::Proposing);
985
986        // After timeout
987        let actions = engine.tick(101.0 + PROPOSE_TIMEOUT + 1.0);
988        assert_eq!(engine.state(), HolePunchState::Failed);
989        assert!(
990            matches!(&actions[0], HolePunchAction::Failed { reason, .. } if *reason == FAIL_TIMEOUT)
991        );
992    }
993
994    #[test]
995    fn test_waiting_ready_timeout() {
996        let link_id = [0x99; 16];
997        let derived_key = test_derived_key();
998        let mut rng = make_rng(0x42);
999
1000        let mut engine =
1001            HolePunchEngine::new(link_id, Some(test_probe_addr()), ProbeProtocol::Rnsp);
1002        engine.propose(&derived_key, 200.0, &mut rng).unwrap();
1003        engine
1004            .endpoints_discovered(test_public_addr_a(), 201.0)
1005            .unwrap();
1006        engine.state = HolePunchState::WaitingReady;
1007        engine.state_entered_at = 202.0;
1008
1009        // After timeout
1010        let actions = engine.tick(202.0 + READY_TIMEOUT + 1.0);
1011        assert_eq!(engine.state(), HolePunchState::Failed);
1012        assert!(
1013            matches!(&actions[0], HolePunchAction::Failed { reason, .. } if *reason == FAIL_TIMEOUT)
1014        );
1015    }
1016
1017    #[test]
1018    fn test_punch_timeout() {
1019        let link_id = [0xAA; 16];
1020        let derived_key = test_derived_key();
1021        let mut rng = make_rng(0x42);
1022
1023        let mut engine =
1024            HolePunchEngine::new(link_id, Some(test_probe_addr()), ProbeProtocol::Rnsp);
1025        engine.propose(&derived_key, 100.0, &mut rng).unwrap();
1026        engine.state = HolePunchState::Punching;
1027        engine.state_entered_at = 200.0;
1028
1029        // Before timeout
1030        let actions = engine.tick(200.0 + PUNCH_TIMEOUT - 1.0);
1031        assert!(actions.is_empty());
1032
1033        // After timeout
1034        let _actions = engine.tick(200.0 + PUNCH_TIMEOUT + 1.0);
1035        assert_eq!(engine.state(), HolePunchState::Failed);
1036    }
1037
1038    #[test]
1039    fn test_message_serialization_roundtrip() {
1040        let session_id = [0xAB; 16];
1041
1042        // UPGRADE_REQUEST (RNSP)
1043        let facilitator = test_probe_addr();
1044        let init_pub = test_public_addr_a();
1045        let data =
1046            encode_upgrade_request(&session_id, &facilitator, &init_pub, ProbeProtocol::Rnsp);
1047        let (sid, f, a, proto) = decode_upgrade_request(&data).unwrap();
1048        assert_eq!(sid, session_id);
1049        assert_eq!(f, facilitator);
1050        assert_eq!(a, init_pub);
1051        assert_eq!(proto, ProbeProtocol::Rnsp);
1052
1053        // UPGRADE_ACCEPT
1054        let data = encode_upgrade_accept(&session_id);
1055        let sid = decode_upgrade_accept(&data).unwrap();
1056        assert_eq!(sid, session_id);
1057
1058        // UPGRADE_REJECT
1059        let data = encode_upgrade_reject(&session_id, REJECT_POLICY);
1060        let (sid, r) = decode_upgrade_reject(&data).unwrap();
1061        assert_eq!(sid, session_id);
1062        assert_eq!(r, REJECT_POLICY);
1063
1064        // UPGRADE_READY
1065        let resp_pub = test_public_addr_b();
1066        let data = encode_upgrade_ready(&session_id, &resp_pub);
1067        let (sid, rp) = decode_upgrade_ready(&data).unwrap();
1068        assert_eq!(sid, session_id);
1069        assert_eq!(rp, resp_pub);
1070
1071        // Session only (UPGRADE_COMPLETE)
1072        let data = encode_session_only(&session_id);
1073        let sid = decode_session_only(&data).unwrap();
1074        assert_eq!(sid, session_id);
1075    }
1076
1077    #[test]
1078    fn test_punch_token_derivation_consistency() {
1079        let derived_key = vec![0xBB; 32];
1080        let session_id = [0xCC; 16];
1081
1082        let token1 = derive_punch_token(&derived_key, &session_id).unwrap();
1083        let token2 = derive_punch_token(&derived_key, &session_id).unwrap();
1084        assert_eq!(token1, token2);
1085
1086        // Different session_id -> different token
1087        let session_id2 = [0xDD; 16];
1088        let token3 = derive_punch_token(&derived_key, &session_id2).unwrap();
1089        assert_ne!(token1, token3);
1090    }
1091
1092    #[test]
1093    fn test_reset() {
1094        let link_id = [0xBB; 16];
1095        let derived_key = test_derived_key();
1096        let mut rng = make_rng(0x42);
1097
1098        let mut engine =
1099            HolePunchEngine::new(link_id, Some(test_probe_addr()), ProbeProtocol::Rnsp);
1100        engine.propose(&derived_key, 100.0, &mut rng).unwrap();
1101        assert_eq!(engine.state(), HolePunchState::Discovering);
1102
1103        engine.reset();
1104        assert_eq!(engine.state(), HolePunchState::Idle);
1105        assert_eq!(engine.session_id(), &[0u8; 16]);
1106    }
1107
1108    #[test]
1109    fn test_build_reject_static() {
1110        let link_id = [0xCC; 16];
1111        let derived_key = test_derived_key();
1112        let mut rng = make_rng(0x42);
1113
1114        let mut proposer =
1115            HolePunchEngine::new(link_id, Some(test_probe_addr()), ProbeProtocol::Rnsp);
1116        proposer.propose(&derived_key, 100.0, &mut rng).unwrap();
1117        let actions = proposer
1118            .endpoints_discovered(test_public_addr_a(), 101.0)
1119            .unwrap();
1120        let request_payload = match &actions[0] {
1121            HolePunchAction::SendSignal { payload, .. } => payload.clone(),
1122            _ => panic!(),
1123        };
1124
1125        let action =
1126            HolePunchEngine::build_reject(link_id, &request_payload, REJECT_POLICY).unwrap();
1127        match action {
1128            HolePunchAction::SendSignal { msgtype, .. } => {
1129                assert_eq!(msgtype, UPGRADE_REJECT);
1130            }
1131            _ => panic!("Expected SendSignal(UPGRADE_REJECT)"),
1132        }
1133    }
1134
1135    #[test]
1136    fn test_responder_needs_no_probe_addr() {
1137        // Responder uses facilitator from UPGRADE_REQUEST, doesn't need its own
1138        let link_id = [0xDD; 16];
1139        let derived_key = test_derived_key();
1140        let mut rng = make_rng(0x42);
1141
1142        // Build a request
1143        let mut initiator =
1144            HolePunchEngine::new(link_id, Some(test_probe_addr()), ProbeProtocol::Rnsp);
1145        initiator.propose(&derived_key, 100.0, &mut rng).unwrap();
1146        let actions = initiator
1147            .endpoints_discovered(test_public_addr_a(), 101.0)
1148            .unwrap();
1149        let request_payload = match &actions[0] {
1150            HolePunchAction::SendSignal { payload, .. } => payload.clone(),
1151            _ => panic!(),
1152        };
1153
1154        // Responder has NO probe_addr configured
1155        let mut responder = HolePunchEngine::new(link_id, None, ProbeProtocol::Rnsp);
1156        let actions = responder
1157            .handle_signal(UPGRADE_REQUEST, &request_payload, Some(&derived_key), 102.0)
1158            .unwrap();
1159
1160        // Should still work — uses facilitator from the request
1161        assert_eq!(responder.state(), HolePunchState::Discovering);
1162        assert_eq!(actions.len(), 2);
1163        assert!(
1164            matches!(&actions[0], HolePunchAction::SendSignal { msgtype, .. } if *msgtype == UPGRADE_ACCEPT)
1165        );
1166        assert!(matches!(
1167            &actions[1],
1168            HolePunchAction::DiscoverEndpoints { .. }
1169        ));
1170    }
1171
1172    #[test]
1173    fn test_stun_protocol_in_upgrade_request_roundtrip() {
1174        let session_id = [0xAB; 16];
1175        let facilitator = test_probe_addr();
1176        let init_pub = test_public_addr_a();
1177
1178        // Encode with STUN protocol
1179        let data =
1180            encode_upgrade_request(&session_id, &facilitator, &init_pub, ProbeProtocol::Stun);
1181        let (sid, f, a, proto) = decode_upgrade_request(&data).unwrap();
1182        assert_eq!(sid, session_id);
1183        assert_eq!(f, facilitator);
1184        assert_eq!(a, init_pub);
1185        assert_eq!(proto, ProbeProtocol::Stun);
1186    }
1187
1188    #[test]
1189    fn test_rnsp_protocol_omits_p_field() {
1190        let session_id = [0xAB; 16];
1191        let facilitator = test_probe_addr();
1192        let init_pub = test_public_addr_a();
1193
1194        // Encode with RNSP (default) — should NOT include "p" field
1195        let data =
1196            encode_upgrade_request(&session_id, &facilitator, &init_pub, ProbeProtocol::Rnsp);
1197        let (sid, f, a, proto) = decode_upgrade_request(&data).unwrap();
1198        assert_eq!(sid, session_id);
1199        assert_eq!(f, facilitator);
1200        assert_eq!(a, init_pub);
1201        assert_eq!(proto, ProbeProtocol::Rnsp);
1202    }
1203
1204    #[test]
1205    fn test_backward_compat_decode_without_p_field() {
1206        // Simulate old node payload that has no "p" field
1207        let session_id = [0xAB; 16];
1208        let facilitator = test_probe_addr();
1209        let init_pub = test_public_addr_a();
1210
1211        // Manually encode without "p" field (old format)
1212        let val = Value::Map(alloc::vec![
1213            (
1214                Value::Str(alloc::string::String::from("s")),
1215                Value::Bin(session_id.to_vec())
1216            ),
1217            (
1218                Value::Str(alloc::string::String::from("f")),
1219                encode_endpoint(&facilitator)
1220            ),
1221            (
1222                Value::Str(alloc::string::String::from("a")),
1223                encode_endpoint(&init_pub)
1224            ),
1225        ]);
1226        let data = msgpack::pack(&val);
1227
1228        let (sid, f, a, proto) = decode_upgrade_request(&data).unwrap();
1229        assert_eq!(sid, session_id);
1230        assert_eq!(f, facilitator);
1231        assert_eq!(a, init_pub);
1232        assert_eq!(proto, ProbeProtocol::Rnsp); // Default fallback
1233    }
1234
1235    #[test]
1236    fn test_stun_initiator_responder_gets_stun_protocol() {
1237        let link_id = [0xEE; 16];
1238        let derived_key = test_derived_key();
1239        let mut rng = make_rng(0x42);
1240
1241        // Initiator uses STUN
1242        let mut initiator =
1243            HolePunchEngine::new(link_id, Some(test_probe_addr()), ProbeProtocol::Stun);
1244        let actions = initiator.propose(&derived_key, 100.0, &mut rng).unwrap();
1245
1246        // DiscoverEndpoints should carry Stun protocol
1247        match &actions[0] {
1248            HolePunchAction::DiscoverEndpoints { protocol, .. } => {
1249                assert_eq!(*protocol, ProbeProtocol::Stun);
1250            }
1251            _ => panic!("Expected DiscoverEndpoints"),
1252        }
1253
1254        let actions = initiator
1255            .endpoints_discovered(test_public_addr_a(), 101.0)
1256            .unwrap();
1257        let request_payload = match &actions[0] {
1258            HolePunchAction::SendSignal { payload, .. } => payload.clone(),
1259            _ => panic!(),
1260        };
1261
1262        // Responder decodes and gets Stun protocol
1263        let mut responder = HolePunchEngine::new(link_id, None, ProbeProtocol::Rnsp);
1264        let actions = responder
1265            .handle_signal(UPGRADE_REQUEST, &request_payload, Some(&derived_key), 102.0)
1266            .unwrap();
1267
1268        // Responder's DiscoverEndpoints should carry Stun protocol (from request)
1269        match &actions[1] {
1270            HolePunchAction::DiscoverEndpoints { protocol, .. } => {
1271                assert_eq!(*protocol, ProbeProtocol::Stun);
1272            }
1273            _ => panic!("Expected DiscoverEndpoints"),
1274        }
1275    }
1276}