1use alloc::vec::Vec;
16
17use rns_crypto::hkdf::hkdf;
18use rns_crypto::Rng;
19
20use crate::msgpack::{self, Value};
21
22use super::types::*;
23
24pub fn derive_punch_token(
28 derived_key: &[u8],
29 session_id: &[u8; 16],
30) -> Result<[u8; 32], HolePunchError> {
31 let result = hkdf(32, derived_key, Some(session_id), Some(b"rns-holepunch-v1"))
32 .map_err(|_| HolePunchError::NoDerivedKey)?;
33 let mut token = [0u8; 32];
34 token.copy_from_slice(&result);
35 Ok(token)
36}
37
38pub struct HolePunchEngine {
40 link_id: [u8; 16],
41 session_id: [u8; 16],
42 state: HolePunchState,
43 is_initiator: bool,
44
45 our_public_endpoint: Option<Endpoint>,
47
48 peer_public_endpoint: Option<Endpoint>,
52
53 facilitator_addr: Option<Endpoint>,
57
58 punch_token: [u8; 32],
60
61 probe_addr: Option<Endpoint>,
63
64 probe_protocol: ProbeProtocol,
66
67 state_entered_at: f64,
69}
70
71impl HolePunchEngine {
72 pub fn new(
74 link_id: [u8; 16],
75 probe_addr: Option<Endpoint>,
76 probe_protocol: ProbeProtocol,
77 ) -> Self {
78 HolePunchEngine {
79 link_id,
80 session_id: [0u8; 16],
81 state: HolePunchState::Idle,
82 is_initiator: false,
83 our_public_endpoint: None,
84 peer_public_endpoint: None,
85 facilitator_addr: None,
86 punch_token: [0u8; 32],
87 probe_addr,
88 probe_protocol,
89 state_entered_at: 0.0,
90 }
91 }
92
93 pub fn state(&self) -> HolePunchState {
94 self.state
95 }
96
97 pub fn session_id(&self) -> &[u8; 16] {
98 &self.session_id
99 }
100
101 pub fn is_initiator(&self) -> bool {
102 self.is_initiator
103 }
104
105 pub fn punch_token(&self) -> &[u8; 32] {
106 &self.punch_token
107 }
108
109 pub fn set_facilitator(&mut self, addr: Endpoint) {
115 self.facilitator_addr = Some(addr);
116 }
117
118 pub fn peer_public_endpoint(&self) -> Option<&Endpoint> {
120 self.peer_public_endpoint.as_ref()
121 }
122
123 pub fn propose(
130 &mut self,
131 derived_key: &[u8],
132 now: f64,
133 rng: &mut dyn Rng,
134 ) -> Result<Vec<HolePunchAction>, HolePunchError> {
135 if self.state != HolePunchState::Idle {
136 return Err(HolePunchError::InvalidState);
137 }
138
139 let mut session_id = [0u8; 16];
141 rng.fill_bytes(&mut session_id);
142 self.session_id = session_id;
143 self.is_initiator = true;
144
145 self.punch_token = derive_punch_token(derived_key, &session_id)?;
147
148 let probe_addr = self.probe_addr.clone().ok_or(HolePunchError::NoProbeAddr)?;
149 self.facilitator_addr = Some(probe_addr.clone());
150
151 self.state = HolePunchState::Discovering;
153 self.state_entered_at = now;
154
155 Ok(alloc::vec![HolePunchAction::DiscoverEndpoints {
156 probe_addr,
157 protocol: self.probe_protocol
158 }])
159 }
160
161 pub fn endpoints_discovered(
166 &mut self,
167 public_endpoint: Endpoint,
168 now: f64,
169 ) -> Result<Vec<HolePunchAction>, HolePunchError> {
170 if self.state != HolePunchState::Discovering {
171 return Err(HolePunchError::InvalidState);
172 }
173
174 self.our_public_endpoint = Some(public_endpoint.clone());
175
176 if self.is_initiator {
177 let facilitator = self
179 .facilitator_addr
180 .clone()
181 .ok_or(HolePunchError::NoProbeAddr)?;
182
183 let payload = encode_upgrade_request(
184 &self.session_id,
185 &facilitator,
186 &public_endpoint,
187 self.probe_protocol,
188 );
189
190 self.state = HolePunchState::Proposing;
191 self.state_entered_at = now;
192
193 Ok(alloc::vec![HolePunchAction::SendSignal {
194 link_id: self.link_id,
195 msgtype: UPGRADE_REQUEST,
196 payload,
197 }])
198 } else {
199 let payload = encode_upgrade_ready(&self.session_id, &public_endpoint);
201
202 let peer_public = self
203 .peer_public_endpoint
204 .clone()
205 .ok_or(HolePunchError::InvalidState)?;
206
207 self.state = HolePunchState::Punching;
208 self.state_entered_at = now;
209
210 Ok(alloc::vec![
211 HolePunchAction::SendSignal {
212 link_id: self.link_id,
213 msgtype: UPGRADE_READY,
214 payload,
215 },
216 HolePunchAction::StartUdpPunch {
217 peer_public,
218 punch_token: self.punch_token,
219 session_id: self.session_id,
220 },
221 ])
222 }
223 }
224
225 pub fn handle_signal(
229 &mut self,
230 msgtype: u16,
231 payload: &[u8],
232 derived_key: Option<&[u8]>,
233 now: f64,
234 ) -> Result<Vec<HolePunchAction>, HolePunchError> {
235 match msgtype {
236 UPGRADE_REQUEST => self.handle_upgrade_request(payload, derived_key, now),
237 UPGRADE_ACCEPT => self.handle_upgrade_accept(payload, now),
238 UPGRADE_REJECT => self.handle_upgrade_reject(payload, now),
239 UPGRADE_READY => self.handle_upgrade_ready(payload, now),
240 UPGRADE_COMPLETE => self.handle_upgrade_complete(payload, now),
241 _ => Err(HolePunchError::InvalidPayload),
242 }
243 }
244
245 pub fn punch_succeeded(&mut self, now: f64) -> Result<Vec<HolePunchAction>, HolePunchError> {
249 if self.state != HolePunchState::Punching {
250 return Err(HolePunchError::InvalidState);
251 }
252
253 self.state = HolePunchState::Connected;
254 self.state_entered_at = now;
255
256 Ok(alloc::vec![HolePunchAction::Succeeded {
257 session_id: self.session_id,
258 },])
259 }
260
261 pub fn punch_failed(&mut self, now: f64) -> Result<Vec<HolePunchAction>, HolePunchError> {
265 if self.state != HolePunchState::Punching {
266 return Err(HolePunchError::InvalidState);
267 }
268
269 self.state = HolePunchState::Failed;
270 self.state_entered_at = now;
271
272 Ok(alloc::vec![HolePunchAction::Failed {
273 session_id: self.session_id,
274 reason: FAIL_TIMEOUT,
275 },])
276 }
277
278 pub fn tick(&mut self, now: f64) -> Vec<HolePunchAction> {
280 let elapsed = now - self.state_entered_at;
281 match self.state {
282 HolePunchState::Discovering if elapsed > DISCOVER_TIMEOUT => {
283 self.state = HolePunchState::Failed;
284 self.state_entered_at = now;
285 alloc::vec![HolePunchAction::Failed {
286 session_id: self.session_id,
287 reason: FAIL_PROBE,
288 }]
289 }
290 HolePunchState::Proposing if elapsed > PROPOSE_TIMEOUT => {
291 self.state = HolePunchState::Failed;
292 self.state_entered_at = now;
293 alloc::vec![HolePunchAction::Failed {
294 session_id: self.session_id,
295 reason: FAIL_TIMEOUT,
296 }]
297 }
298 HolePunchState::WaitingReady if elapsed > READY_TIMEOUT => {
299 self.state = HolePunchState::Failed;
300 self.state_entered_at = now;
301 alloc::vec![HolePunchAction::Failed {
302 session_id: self.session_id,
303 reason: FAIL_TIMEOUT,
304 }]
305 }
306 HolePunchState::Punching if elapsed > PUNCH_TIMEOUT => {
307 self.state = HolePunchState::Failed;
308 self.state_entered_at = now;
309 alloc::vec![HolePunchAction::Failed {
310 session_id: self.session_id,
311 reason: FAIL_TIMEOUT,
312 }]
313 }
314 _ => Vec::new(),
315 }
316 }
317
318 pub fn build_reject(
322 link_id: [u8; 16],
323 request_payload: &[u8],
324 reason: u8,
325 ) -> Result<HolePunchAction, HolePunchError> {
326 let (session_id, _, _, _) = decode_upgrade_request(request_payload)?;
327 let payload = encode_upgrade_reject(&session_id, reason);
328 Ok(HolePunchAction::SendSignal {
329 link_id,
330 msgtype: UPGRADE_REJECT,
331 payload,
332 })
333 }
334
335 pub fn reset(&mut self) {
337 self.state = HolePunchState::Idle;
338 self.session_id = [0u8; 16];
339 self.is_initiator = false;
340 self.our_public_endpoint = None;
341 self.peer_public_endpoint = None;
342 self.facilitator_addr = None;
343 self.punch_token = [0u8; 32];
344 self.probe_protocol = ProbeProtocol::Rnsp;
345 self.state_entered_at = 0.0;
346 }
347
348 fn handle_upgrade_request(
355 &mut self,
356 payload: &[u8],
357 derived_key: Option<&[u8]>,
358 now: f64,
359 ) -> Result<Vec<HolePunchAction>, HolePunchError> {
360 if self.state != HolePunchState::Idle {
361 let (session_id, _, _, _) = decode_upgrade_request(payload)?;
363 let reject_payload = encode_upgrade_reject(&session_id, REJECT_BUSY);
364 return Ok(alloc::vec![HolePunchAction::SendSignal {
365 link_id: self.link_id,
366 msgtype: UPGRADE_REJECT,
367 payload: reject_payload,
368 }]);
369 }
370
371 let derived_key = derived_key.ok_or(HolePunchError::NoDerivedKey)?;
372 let (session_id, facilitator, initiator_public, protocol) =
373 decode_upgrade_request(payload)?;
374
375 self.session_id = session_id;
376 self.is_initiator = false;
377 self.probe_protocol = protocol;
378 self.punch_token = derive_punch_token(derived_key, &session_id)?;
379
380 self.peer_public_endpoint = Some(initiator_public);
382
383 self.facilitator_addr = Some(facilitator.clone());
385
386 self.state = HolePunchState::Discovering;
387 self.state_entered_at = now;
388
389 let accept_payload = encode_upgrade_accept(&session_id);
391
392 Ok(alloc::vec![
393 HolePunchAction::SendSignal {
394 link_id: self.link_id,
395 msgtype: UPGRADE_ACCEPT,
396 payload: accept_payload,
397 },
398 HolePunchAction::DiscoverEndpoints {
399 probe_addr: facilitator,
400 protocol
401 },
402 ])
403 }
404
405 fn handle_upgrade_accept(
409 &mut self,
410 payload: &[u8],
411 now: f64,
412 ) -> Result<Vec<HolePunchAction>, HolePunchError> {
413 if self.state != HolePunchState::Proposing || !self.is_initiator {
414 return Err(HolePunchError::InvalidState);
415 }
416
417 let session_id = decode_upgrade_accept(payload)?;
418 if session_id != self.session_id {
419 return Err(HolePunchError::SessionMismatch);
420 }
421
422 self.state = HolePunchState::WaitingReady;
423 self.state_entered_at = now;
424
425 Ok(Vec::new())
426 }
427
428 fn handle_upgrade_reject(
430 &mut self,
431 payload: &[u8],
432 now: f64,
433 ) -> Result<Vec<HolePunchAction>, HolePunchError> {
434 if self.state != HolePunchState::Proposing {
435 return Err(HolePunchError::InvalidState);
436 }
437
438 let (session_id, reason) = decode_upgrade_reject(payload)?;
439 if session_id != self.session_id {
440 return Err(HolePunchError::SessionMismatch);
441 }
442
443 self.state = HolePunchState::Failed;
444 self.state_entered_at = now;
445
446 Ok(alloc::vec![HolePunchAction::Failed {
447 session_id: self.session_id,
448 reason,
449 }])
450 }
451
452 fn handle_upgrade_ready(
457 &mut self,
458 payload: &[u8],
459 now: f64,
460 ) -> Result<Vec<HolePunchAction>, HolePunchError> {
461 if self.state != HolePunchState::WaitingReady || !self.is_initiator {
462 return Err(HolePunchError::InvalidState);
463 }
464
465 let (session_id, responder_public) = decode_upgrade_ready(payload)?;
466 if session_id != self.session_id {
467 return Err(HolePunchError::SessionMismatch);
468 }
469
470 self.peer_public_endpoint = Some(responder_public.clone());
471
472 self.state = HolePunchState::Punching;
473 self.state_entered_at = now;
474
475 Ok(alloc::vec![HolePunchAction::StartUdpPunch {
476 peer_public: responder_public,
477 punch_token: self.punch_token,
478 session_id: self.session_id,
479 }])
480 }
481
482 fn handle_upgrade_complete(
484 &mut self,
485 payload: &[u8],
486 now: f64,
487 ) -> Result<Vec<HolePunchAction>, HolePunchError> {
488 if self.state != HolePunchState::Punching && self.state != HolePunchState::Connected {
489 return Err(HolePunchError::InvalidState);
490 }
491
492 let session_id = decode_session_only(payload)?;
493 if session_id != self.session_id {
494 return Err(HolePunchError::SessionMismatch);
495 }
496
497 if self.state == HolePunchState::Connected {
498 return Ok(Vec::new());
500 }
501
502 self.state = HolePunchState::Connected;
503 self.state_entered_at = now;
504
505 Ok(alloc::vec![HolePunchAction::Succeeded {
506 session_id: self.session_id,
507 }])
508 }
509}
510
511fn encode_upgrade_request(
514 session_id: &[u8; 16],
515 facilitator: &Endpoint,
516 initiator_public: &Endpoint,
517 protocol: ProbeProtocol,
518) -> Vec<u8> {
519 let mut fields = alloc::vec![
520 (
521 Value::Str(alloc::string::String::from("s")),
522 Value::Bin(session_id.to_vec())
523 ),
524 (
525 Value::Str(alloc::string::String::from("f")),
526 encode_endpoint(facilitator)
527 ),
528 (
529 Value::Str(alloc::string::String::from("a")),
530 encode_endpoint(initiator_public)
531 ),
532 ];
533 if protocol != ProbeProtocol::Rnsp {
535 fields.push((
536 Value::Str(alloc::string::String::from("p")),
537 Value::UInt(protocol as u64),
538 ));
539 }
540 let val = Value::Map(fields);
541 msgpack::pack(&val)
542}
543
544fn decode_upgrade_request(
545 data: &[u8],
546) -> Result<([u8; 16], Endpoint, Endpoint, ProbeProtocol), HolePunchError> {
547 let (val, _) = msgpack::unpack(data).map_err(|_| HolePunchError::InvalidPayload)?;
548 let session_id = extract_session_id(&val)?;
549 let facilitator = val
550 .map_get("f")
551 .and_then(decode_endpoint)
552 .ok_or(HolePunchError::InvalidPayload)?;
553 let initiator_public = val
554 .map_get("a")
555 .and_then(decode_endpoint)
556 .ok_or(HolePunchError::InvalidPayload)?;
557 let protocol = val
559 .map_get("p")
560 .and_then(|v| v.as_uint())
561 .map(|p| match p {
562 1 => ProbeProtocol::Stun,
563 _ => ProbeProtocol::Rnsp,
564 })
565 .unwrap_or(ProbeProtocol::Rnsp);
566 Ok((session_id, facilitator, initiator_public, protocol))
567}
568
569fn encode_upgrade_accept(session_id: &[u8; 16]) -> Vec<u8> {
570 let val = Value::Map(alloc::vec![(
571 Value::Str(alloc::string::String::from("s")),
572 Value::Bin(session_id.to_vec())
573 ),]);
574 msgpack::pack(&val)
575}
576
577fn decode_upgrade_accept(data: &[u8]) -> Result<[u8; 16], HolePunchError> {
578 let (val, _) = msgpack::unpack(data).map_err(|_| HolePunchError::InvalidPayload)?;
579 extract_session_id(&val)
580}
581
582fn encode_upgrade_reject(session_id: &[u8; 16], reason: u8) -> Vec<u8> {
583 let val = Value::Map(alloc::vec![
584 (
585 Value::Str(alloc::string::String::from("s")),
586 Value::Bin(session_id.to_vec())
587 ),
588 (
589 Value::Str(alloc::string::String::from("r")),
590 Value::UInt(reason as u64)
591 ),
592 ]);
593 msgpack::pack(&val)
594}
595
596fn decode_upgrade_reject(data: &[u8]) -> Result<([u8; 16], u8), HolePunchError> {
597 let (val, _) = msgpack::unpack(data).map_err(|_| HolePunchError::InvalidPayload)?;
598 let session_id = extract_session_id(&val)?;
599 let reason = val
600 .map_get("r")
601 .and_then(|v| v.as_uint())
602 .ok_or(HolePunchError::InvalidPayload)? as u8;
603 Ok((session_id, reason))
604}
605
606fn encode_upgrade_ready(session_id: &[u8; 16], responder_public: &Endpoint) -> Vec<u8> {
607 let val = Value::Map(alloc::vec![
608 (
609 Value::Str(alloc::string::String::from("s")),
610 Value::Bin(session_id.to_vec())
611 ),
612 (
613 Value::Str(alloc::string::String::from("a")),
614 encode_endpoint(responder_public)
615 ),
616 ]);
617 msgpack::pack(&val)
618}
619
620fn decode_upgrade_ready(data: &[u8]) -> Result<([u8; 16], Endpoint), HolePunchError> {
621 let (val, _) = msgpack::unpack(data).map_err(|_| HolePunchError::InvalidPayload)?;
622 let session_id = extract_session_id(&val)?;
623 let responder_public = val
624 .map_get("a")
625 .and_then(decode_endpoint)
626 .ok_or(HolePunchError::InvalidPayload)?;
627 Ok((session_id, responder_public))
628}
629
630fn encode_endpoint(ep: &Endpoint) -> Value {
631 Value::Array(alloc::vec![
632 Value::Bin(ep.addr.clone()),
633 Value::UInt(ep.port as u64),
634 ])
635}
636
637fn decode_endpoint(val: &Value) -> Option<Endpoint> {
638 let arr = val.as_array()?;
639 if arr.len() < 2 {
640 return None;
641 }
642 let addr = arr[0].as_bin()?.to_vec();
643 let port = arr[1].as_uint()? as u16;
644 Some(Endpoint { addr, port })
645}
646
/// Encodes a payload that carries only the session id under key `"s"`.
///
/// The wire format is identical to UPGRADE_ACCEPT, so this reuses that encoder.
#[cfg(test)]
fn encode_session_only(session_id: &[u8; 16]) -> Vec<u8> {
    encode_upgrade_accept(session_id)
}
655
656fn decode_session_only(data: &[u8]) -> Result<[u8; 16], HolePunchError> {
657 let (val, _) = msgpack::unpack(data).map_err(|_| HolePunchError::InvalidPayload)?;
658 extract_session_id(&val)
659}
660
661fn extract_session_id(val: &Value) -> Result<[u8; 16], HolePunchError> {
662 let bin = val
663 .map_get("s")
664 .and_then(|v| v.as_bin())
665 .ok_or(HolePunchError::InvalidPayload)?;
666 if bin.len() != 16 {
667 return Err(HolePunchError::InvalidPayload);
668 }
669 let mut id = [0u8; 16];
670 id.copy_from_slice(bin);
671 Ok(id)
672}
673
/// Unit tests for the hole-punch state machine and its msgpack wire codecs.
#[cfg(test)]
mod tests {
    use super::*;
    use rns_crypto::FixedRng;

    // Builds a `FixedRng` from 128 copies of `seed`. This presumably yields a
    // repeatable byte stream, so session ids are stable across runs (assumption).
    fn make_rng(seed: u8) -> FixedRng {
        FixedRng::new(&[seed; 128])
    }

    // Stand-in for the link-derived key that both peers share.
    fn test_derived_key() -> Vec<u8> {
        vec![0xAA; 32]
    }

    // Probe / facilitator endpoint: 127.0.0.1:4343.
    fn test_probe_addr() -> Endpoint {
        Endpoint {
            addr: vec![127, 0, 0, 1],
            port: 4343,
        }
    }

    // Public endpoint reported for the initiator (side "A").
    fn test_public_addr_a() -> Endpoint {
        Endpoint {
            addr: vec![1, 2, 3, 4],
            port: 41000,
        }
    }

    // Public endpoint reported for the responder (side "B").
    fn test_public_addr_b() -> Endpoint {
        Endpoint {
            addr: vec![5, 6, 7, 8],
            port: 52000,
        }
    }

    // `propose` moves to Discovering and requests discovery before any signal is sent.
    #[test]
    fn test_propose_initiator_discovers_first() {
        let link_id = [0x11; 16];
        let derived_key = test_derived_key();
        let mut rng = make_rng(0x42);

        let mut initiator =
            HolePunchEngine::new(link_id, Some(test_probe_addr()), ProbeProtocol::Rnsp);
        let actions = initiator.propose(&derived_key, 100.0, &mut rng).unwrap();

        assert_eq!(initiator.state(), HolePunchState::Discovering);
        // Exactly one action: the discovery request.
        assert_eq!(actions.len(), 1);
        assert!(matches!(
            &actions[0],
            HolePunchAction::DiscoverEndpoints { .. }
        ));
    }

    // After discovery the initiator sends an UPGRADE_REQUEST that carries the
    // session id, the facilitator, and its own public endpoint.
    #[test]
    fn test_initiator_sends_request_after_discovery() {
        let link_id = [0x11; 16];
        let derived_key = test_derived_key();
        let mut rng = make_rng(0x42);

        let mut initiator =
            HolePunchEngine::new(link_id, Some(test_probe_addr()), ProbeProtocol::Rnsp);
        initiator.propose(&derived_key, 100.0, &mut rng).unwrap();

        let actions = initiator
            .endpoints_discovered(test_public_addr_a(), 101.0)
            .unwrap();

        assert_eq!(initiator.state(), HolePunchState::Proposing);
        assert_eq!(actions.len(), 1);
        match &actions[0] {
            HolePunchAction::SendSignal {
                msgtype, payload, ..
            } => {
                assert_eq!(*msgtype, UPGRADE_REQUEST);
                // Decode the wire payload to check what was actually sent.
                let (sid, facilitator, init_pub, _proto) = decode_upgrade_request(payload).unwrap();
                assert_eq!(sid, *initiator.session_id());
                assert_eq!(facilitator, test_probe_addr());
                assert_eq!(init_pub, test_public_addr_a());
            }
            _ => panic!("Expected SendSignal(UPGRADE_REQUEST)"),
        }
    }

    // End-to-end handshake between two engines, relaying each emitted payload by
    // hand: REQUEST -> ACCEPT -> (responder discovery) -> READY -> both punching.
    #[test]
    fn test_full_asymmetric_flow() {
        let link_id = [0x22; 16];
        let derived_key = test_derived_key();
        let mut rng = make_rng(0x42);

        // Initiator: propose, then report discovery to produce the request.
        let mut initiator =
            HolePunchEngine::new(link_id, Some(test_probe_addr()), ProbeProtocol::Rnsp);
        initiator.propose(&derived_key, 100.0, &mut rng).unwrap();
        let actions = initiator
            .endpoints_discovered(test_public_addr_a(), 101.0)
            .unwrap();

        let request_payload = match &actions[0] {
            HolePunchAction::SendSignal { payload, .. } => payload.clone(),
            _ => panic!(),
        };

        // Responder has no probe address of its own; it uses the initiator's facilitator.
        let mut responder = HolePunchEngine::new(link_id, None, ProbeProtocol::Rnsp);
        let actions = responder
            .handle_signal(UPGRADE_REQUEST, &request_payload, Some(&derived_key), 102.0)
            .unwrap();

        assert_eq!(responder.state(), HolePunchState::Discovering);
        // ACCEPT first, then the discovery request.
        assert_eq!(actions.len(), 2);
        let accept_payload = match &actions[0] {
            HolePunchAction::SendSignal {
                msgtype, payload, ..
            } => {
                assert_eq!(*msgtype, UPGRADE_ACCEPT);
                payload.clone()
            }
            _ => panic!("Expected UPGRADE_ACCEPT"),
        };

        match &actions[1] {
            HolePunchAction::DiscoverEndpoints { probe_addr, .. } => {
                assert_eq!(*probe_addr, test_probe_addr());
            }
            _ => panic!("Expected DiscoverEndpoints"),
        }

        // Initiator receives ACCEPT and waits for READY without emitting anything.
        let actions = initiator
            .handle_signal(UPGRADE_ACCEPT, &accept_payload, None, 103.0)
            .unwrap();
        assert_eq!(initiator.state(), HolePunchState::WaitingReady);
        assert!(actions.is_empty());

        // Responder finishes discovery: sends READY and starts punching immediately.
        let actions = responder
            .endpoints_discovered(test_public_addr_b(), 104.0)
            .unwrap();

        assert_eq!(responder.state(), HolePunchState::Punching);
        assert_eq!(actions.len(), 2);
        let ready_payload = match &actions[0] {
            HolePunchAction::SendSignal {
                msgtype, payload, ..
            } => {
                assert_eq!(*msgtype, UPGRADE_READY);
                payload.clone()
            }
            _ => panic!("Expected UPGRADE_READY"),
        };
        assert!(matches!(&actions[1], HolePunchAction::StartUdpPunch { .. }));

        // Initiator receives READY and punches towards the responder's endpoint.
        let actions = initiator
            .handle_signal(UPGRADE_READY, &ready_payload, None, 105.0)
            .unwrap();

        assert_eq!(initiator.state(), HolePunchState::Punching);
        assert_eq!(actions.len(), 1);
        match &actions[0] {
            HolePunchAction::StartUdpPunch { peer_public, .. } => {
                assert_eq!(*peer_public, test_public_addr_b());
            }
            _ => panic!("Expected StartUdpPunch"),
        }

        // Both sides derived the same token independently from key + session id.
        assert_eq!(initiator.punch_token(), responder.punch_token());
    }

    // Punching -> Connected on success. The state is forced directly to skip the handshake.
    #[test]
    fn test_punch_success() {
        let link_id = [0x33; 16];
        let derived_key = test_derived_key();
        let mut rng = make_rng(0x42);

        let mut engine =
            HolePunchEngine::new(link_id, Some(test_probe_addr()), ProbeProtocol::Rnsp);
        engine.propose(&derived_key, 100.0, &mut rng).unwrap();
        engine.state = HolePunchState::Punching;

        let actions = engine.punch_succeeded(105.0).unwrap();
        assert_eq!(engine.state(), HolePunchState::Connected);
        assert_eq!(actions.len(), 1);
        assert!(matches!(&actions[0], HolePunchAction::Succeeded { .. }));
    }

    // Punching -> Failed when the caller reports a failed punch.
    #[test]
    fn test_punch_failed() {
        let link_id = [0x44; 16];
        let derived_key = test_derived_key();
        let mut rng = make_rng(0x42);

        let mut engine =
            HolePunchEngine::new(link_id, Some(test_probe_addr()), ProbeProtocol::Rnsp);
        engine.propose(&derived_key, 100.0, &mut rng).unwrap();
        engine.state = HolePunchState::Punching;

        let actions = engine.punch_failed(120.0).unwrap();
        assert_eq!(engine.state(), HolePunchState::Failed);
        assert_eq!(actions.len(), 1);
        assert!(matches!(&actions[0], HolePunchAction::Failed { .. }));
    }

    // A non-idle engine answers an incoming UPGRADE_REQUEST with UPGRADE_REJECT.
    #[test]
    fn test_reject_when_busy() {
        let link_id = [0x55; 16];
        let derived_key = test_derived_key();
        let mut rng = make_rng(0x42);

        let mut proposer =
            HolePunchEngine::new(link_id, Some(test_probe_addr()), ProbeProtocol::Rnsp);
        proposer.propose(&derived_key, 100.0, &mut rng).unwrap();
        let actions = proposer
            .endpoints_discovered(test_public_addr_a(), 101.0)
            .unwrap();
        let request_payload = match &actions[0] {
            HolePunchAction::SendSignal { payload, .. } => payload.clone(),
            _ => panic!(),
        };

        // Force the responder into a busy state.
        let mut responder =
            HolePunchEngine::new(link_id, Some(test_probe_addr()), ProbeProtocol::Rnsp);
        responder.state = HolePunchState::Discovering;

        let actions = responder
            .handle_signal(UPGRADE_REQUEST, &request_payload, Some(&derived_key), 102.0)
            .unwrap();

        assert_eq!(actions.len(), 1);
        match &actions[0] {
            HolePunchAction::SendSignal { msgtype, .. } => {
                assert_eq!(*msgtype, UPGRADE_REJECT);
            }
            _ => panic!("Expected UPGRADE_REJECT"),
        }
    }

    // A Proposing initiator fails with the peer-supplied reject reason.
    #[test]
    fn test_initiator_receives_reject() {
        let link_id = [0x66; 16];
        let derived_key = test_derived_key();
        let mut rng = make_rng(0x42);

        let mut initiator =
            HolePunchEngine::new(link_id, Some(test_probe_addr()), ProbeProtocol::Rnsp);
        initiator.propose(&derived_key, 100.0, &mut rng).unwrap();
        initiator
            .endpoints_discovered(test_public_addr_a(), 101.0)
            .unwrap();
        assert_eq!(initiator.state(), HolePunchState::Proposing);

        let session_id = *initiator.session_id();
        let reject_payload = encode_upgrade_reject(&session_id, REJECT_POLICY);

        let actions = initiator
            .handle_signal(UPGRADE_REJECT, &reject_payload, None, 102.0)
            .unwrap();

        assert_eq!(initiator.state(), HolePunchState::Failed);
        assert_eq!(actions.len(), 1);
        assert!(
            matches!(&actions[0], HolePunchAction::Failed { reason, .. } if *reason == REJECT_POLICY)
        );
    }

    // Discovery times out only after DISCOVER_TIMEOUT, and is reported as FAIL_PROBE.
    #[test]
    fn test_discover_timeout() {
        let link_id = [0x77; 16];
        let derived_key = test_derived_key();
        let mut rng = make_rng(0x42);

        let mut engine =
            HolePunchEngine::new(link_id, Some(test_probe_addr()), ProbeProtocol::Rnsp);
        engine.propose(&derived_key, 100.0, &mut rng).unwrap();
        assert_eq!(engine.state(), HolePunchState::Discovering);

        // Just before the deadline: nothing happens.
        let actions = engine.tick(100.0 + DISCOVER_TIMEOUT - 1.0);
        assert!(actions.is_empty());

        // Past the deadline: the session fails with FAIL_PROBE.
        let actions = engine.tick(100.0 + DISCOVER_TIMEOUT + 1.0);
        assert_eq!(engine.state(), HolePunchState::Failed);
        assert!(
            matches!(&actions[0], HolePunchAction::Failed { reason, .. } if *reason == FAIL_PROBE)
        );
    }

    // No ACCEPT/REJECT within PROPOSE_TIMEOUT -> Failed(FAIL_TIMEOUT).
    #[test]
    fn test_propose_timeout() {
        let link_id = [0x88; 16];
        let derived_key = test_derived_key();
        let mut rng = make_rng(0x42);

        let mut engine =
            HolePunchEngine::new(link_id, Some(test_probe_addr()), ProbeProtocol::Rnsp);
        engine.propose(&derived_key, 100.0, &mut rng).unwrap();
        engine
            .endpoints_discovered(test_public_addr_a(), 101.0)
            .unwrap();
        assert_eq!(engine.state(), HolePunchState::Proposing);

        let actions = engine.tick(101.0 + PROPOSE_TIMEOUT + 1.0);
        assert_eq!(engine.state(), HolePunchState::Failed);
        assert!(
            matches!(&actions[0], HolePunchAction::Failed { reason, .. } if *reason == FAIL_TIMEOUT)
        );
    }

    // No READY within READY_TIMEOUT -> Failed(FAIL_TIMEOUT).
    // WaitingReady and its entry time are forced directly.
    #[test]
    fn test_waiting_ready_timeout() {
        let link_id = [0x99; 16];
        let derived_key = test_derived_key();
        let mut rng = make_rng(0x42);

        let mut engine =
            HolePunchEngine::new(link_id, Some(test_probe_addr()), ProbeProtocol::Rnsp);
        engine.propose(&derived_key, 200.0, &mut rng).unwrap();
        engine
            .endpoints_discovered(test_public_addr_a(), 201.0)
            .unwrap();
        engine.state = HolePunchState::WaitingReady;
        engine.state_entered_at = 202.0;

        let actions = engine.tick(202.0 + READY_TIMEOUT + 1.0);
        assert_eq!(engine.state(), HolePunchState::Failed);
        assert!(
            matches!(&actions[0], HolePunchAction::Failed { reason, .. } if *reason == FAIL_TIMEOUT)
        );
    }

    // Punching times out only after PUNCH_TIMEOUT has elapsed since entering the state.
    #[test]
    fn test_punch_timeout() {
        let link_id = [0xAA; 16];
        let derived_key = test_derived_key();
        let mut rng = make_rng(0x42);

        let mut engine =
            HolePunchEngine::new(link_id, Some(test_probe_addr()), ProbeProtocol::Rnsp);
        engine.propose(&derived_key, 100.0, &mut rng).unwrap();
        engine.state = HolePunchState::Punching;
        engine.state_entered_at = 200.0;

        let actions = engine.tick(200.0 + PUNCH_TIMEOUT - 1.0);
        assert!(actions.is_empty());

        let _actions = engine.tick(200.0 + PUNCH_TIMEOUT + 1.0);
        assert_eq!(engine.state(), HolePunchState::Failed);
    }

    // Every encoder/decoder pair round-trips its fields.
    #[test]
    fn test_message_serialization_roundtrip() {
        let session_id = [0xAB; 16];

        // UPGRADE_REQUEST
        let facilitator = test_probe_addr();
        let init_pub = test_public_addr_a();
        let data =
            encode_upgrade_request(&session_id, &facilitator, &init_pub, ProbeProtocol::Rnsp);
        let (sid, f, a, proto) = decode_upgrade_request(&data).unwrap();
        assert_eq!(sid, session_id);
        assert_eq!(f, facilitator);
        assert_eq!(a, init_pub);
        assert_eq!(proto, ProbeProtocol::Rnsp);

        // UPGRADE_ACCEPT
        let data = encode_upgrade_accept(&session_id);
        let sid = decode_upgrade_accept(&data).unwrap();
        assert_eq!(sid, session_id);

        // UPGRADE_REJECT
        let data = encode_upgrade_reject(&session_id, REJECT_POLICY);
        let (sid, r) = decode_upgrade_reject(&data).unwrap();
        assert_eq!(sid, session_id);
        assert_eq!(r, REJECT_POLICY);

        // UPGRADE_READY
        let resp_pub = test_public_addr_b();
        let data = encode_upgrade_ready(&session_id, &resp_pub);
        let (sid, rp) = decode_upgrade_ready(&data).unwrap();
        assert_eq!(sid, session_id);
        assert_eq!(rp, resp_pub);

        // Session-only payload (as used for UPGRADE_COMPLETE)
        let data = encode_session_only(&session_id);
        let sid = decode_session_only(&data).unwrap();
        assert_eq!(sid, session_id);
    }

    // Token derivation is deterministic and depends on the session id.
    #[test]
    fn test_punch_token_derivation_consistency() {
        let derived_key = vec![0xBB; 32];
        let session_id = [0xCC; 16];

        let token1 = derive_punch_token(&derived_key, &session_id).unwrap();
        let token2 = derive_punch_token(&derived_key, &session_id).unwrap();
        assert_eq!(token1, token2);

        // A different session id must give a different token.
        let session_id2 = [0xDD; 16];
        let token3 = derive_punch_token(&derived_key, &session_id2).unwrap();
        assert_ne!(token1, token3);
    }

    // `reset` returns to Idle and clears the session id.
    #[test]
    fn test_reset() {
        let link_id = [0xBB; 16];
        let derived_key = test_derived_key();
        let mut rng = make_rng(0x42);

        let mut engine =
            HolePunchEngine::new(link_id, Some(test_probe_addr()), ProbeProtocol::Rnsp);
        engine.propose(&derived_key, 100.0, &mut rng).unwrap();
        assert_eq!(engine.state(), HolePunchState::Discovering);

        engine.reset();
        assert_eq!(engine.state(), HolePunchState::Idle);
        assert_eq!(engine.session_id(), &[0u8; 16]);
    }

    // `build_reject` answers a request payload without an engine instance.
    #[test]
    fn test_build_reject_static() {
        let link_id = [0xCC; 16];
        let derived_key = test_derived_key();
        let mut rng = make_rng(0x42);

        let mut proposer =
            HolePunchEngine::new(link_id, Some(test_probe_addr()), ProbeProtocol::Rnsp);
        proposer.propose(&derived_key, 100.0, &mut rng).unwrap();
        let actions = proposer
            .endpoints_discovered(test_public_addr_a(), 101.0)
            .unwrap();
        let request_payload = match &actions[0] {
            HolePunchAction::SendSignal { payload, .. } => payload.clone(),
            _ => panic!(),
        };

        let action =
            HolePunchEngine::build_reject(link_id, &request_payload, REJECT_POLICY).unwrap();
        match action {
            HolePunchAction::SendSignal { msgtype, .. } => {
                assert_eq!(msgtype, UPGRADE_REJECT);
            }
            _ => panic!("Expected SendSignal(UPGRADE_REJECT)"),
        }
    }

    // An engine built without a probe address can still act as responder.
    #[test]
    fn test_responder_needs_no_probe_addr() {
        let link_id = [0xDD; 16];
        let derived_key = test_derived_key();
        let mut rng = make_rng(0x42);

        let mut initiator =
            HolePunchEngine::new(link_id, Some(test_probe_addr()), ProbeProtocol::Rnsp);
        initiator.propose(&derived_key, 100.0, &mut rng).unwrap();
        let actions = initiator
            .endpoints_discovered(test_public_addr_a(), 101.0)
            .unwrap();
        let request_payload = match &actions[0] {
            HolePunchAction::SendSignal { payload, .. } => payload.clone(),
            _ => panic!(),
        };

        let mut responder = HolePunchEngine::new(link_id, None, ProbeProtocol::Rnsp);
        let actions = responder
            .handle_signal(UPGRADE_REQUEST, &request_payload, Some(&derived_key), 102.0)
            .unwrap();

        assert_eq!(responder.state(), HolePunchState::Discovering);
        assert_eq!(actions.len(), 2);
        assert!(
            matches!(&actions[0], HolePunchAction::SendSignal { msgtype, .. } if *msgtype == UPGRADE_ACCEPT)
        );
        assert!(matches!(
            &actions[1],
            HolePunchAction::DiscoverEndpoints { .. }
        ));
    }

    // A STUN protocol choice survives the UPGRADE_REQUEST round trip.
    #[test]
    fn test_stun_protocol_in_upgrade_request_roundtrip() {
        let session_id = [0xAB; 16];
        let facilitator = test_probe_addr();
        let init_pub = test_public_addr_a();

        let data =
            encode_upgrade_request(&session_id, &facilitator, &init_pub, ProbeProtocol::Stun);
        let (sid, f, a, proto) = decode_upgrade_request(&data).unwrap();
        assert_eq!(sid, session_id);
        assert_eq!(f, facilitator);
        assert_eq!(a, init_pub);
        assert_eq!(proto, ProbeProtocol::Stun);
    }

    // RNSP requests round-trip as RNSP.
    // NOTE(review): despite its name, this test never checks that the "p" key
    // is absent from the encoded map; it only checks the decoded values.
    #[test]
    fn test_rnsp_protocol_omits_p_field() {
        let session_id = [0xAB; 16];
        let facilitator = test_probe_addr();
        let init_pub = test_public_addr_a();

        let data =
            encode_upgrade_request(&session_id, &facilitator, &init_pub, ProbeProtocol::Rnsp);
        let (sid, f, a, proto) = decode_upgrade_request(&data).unwrap();
        assert_eq!(sid, session_id);
        assert_eq!(f, facilitator);
        assert_eq!(a, init_pub);
        assert_eq!(proto, ProbeProtocol::Rnsp);
    }

    // A hand-built request without a "p" key (older format) decodes as RNSP.
    #[test]
    fn test_backward_compat_decode_without_p_field() {
        let session_id = [0xAB; 16];
        let facilitator = test_probe_addr();
        let init_pub = test_public_addr_a();

        let val = Value::Map(alloc::vec![
            (
                Value::Str(alloc::string::String::from("s")),
                Value::Bin(session_id.to_vec())
            ),
            (
                Value::Str(alloc::string::String::from("f")),
                encode_endpoint(&facilitator)
            ),
            (
                Value::Str(alloc::string::String::from("a")),
                encode_endpoint(&init_pub)
            ),
        ]);
        let data = msgpack::pack(&val);

        let (sid, f, a, proto) = decode_upgrade_request(&data).unwrap();
        assert_eq!(sid, session_id);
        assert_eq!(f, facilitator);
        assert_eq!(a, init_pub);
        assert_eq!(proto, ProbeProtocol::Rnsp);
    }

    // A responder configured for RNSP follows the protocol named in a STUN
    // initiator's request for its own discovery.
    #[test]
    fn test_stun_initiator_responder_gets_stun_protocol() {
        let link_id = [0xEE; 16];
        let derived_key = test_derived_key();
        let mut rng = make_rng(0x42);

        let mut initiator =
            HolePunchEngine::new(link_id, Some(test_probe_addr()), ProbeProtocol::Stun);
        let actions = initiator.propose(&derived_key, 100.0, &mut rng).unwrap();

        match &actions[0] {
            HolePunchAction::DiscoverEndpoints { protocol, .. } => {
                assert_eq!(*protocol, ProbeProtocol::Stun);
            }
            _ => panic!("Expected DiscoverEndpoints"),
        }

        let actions = initiator
            .endpoints_discovered(test_public_addr_a(), 101.0)
            .unwrap();
        let request_payload = match &actions[0] {
            HolePunchAction::SendSignal { payload, .. } => payload.clone(),
            _ => panic!(),
        };

        let mut responder = HolePunchEngine::new(link_id, None, ProbeProtocol::Rnsp);
        let actions = responder
            .handle_signal(UPGRADE_REQUEST, &request_payload, Some(&derived_key), 102.0)
            .unwrap();

        match &actions[1] {
            HolePunchAction::DiscoverEndpoints { protocol, .. } => {
                assert_eq!(*protocol, ProbeProtocol::Stun);
            }
            _ => panic!("Expected DiscoverEndpoints"),
        }
    }
}