1use alloc::vec::Vec;
16
17use rns_crypto::hkdf::hkdf;
18use rns_crypto::Rng;
19
20use crate::msgpack::{self, Value};
21
22use super::types::*;
23
24pub fn derive_punch_token(
28 derived_key: &[u8],
29 session_id: &[u8; 16],
30) -> Result<[u8; 32], HolePunchError> {
31 let result = hkdf(32, derived_key, Some(session_id), Some(b"rns-holepunch-v1"))
32 .map_err(|_| HolePunchError::NoDerivedKey)?;
33 let mut token = [0u8; 32];
34 token.copy_from_slice(&result);
35 Ok(token)
36}
37
38pub struct HolePunchEngine {
40 link_id: [u8; 16],
41 session_id: [u8; 16],
42 state: HolePunchState,
43 is_initiator: bool,
44
45 our_public_endpoint: Option<Endpoint>,
47
48 peer_public_endpoint: Option<Endpoint>,
52
53 facilitator_addr: Option<Endpoint>,
57
58 punch_token: [u8; 32],
60
61 probe_addr: Option<Endpoint>,
63
64 state_entered_at: f64,
66}
67
68impl HolePunchEngine {
69 pub fn new(
71 link_id: [u8; 16],
72 probe_addr: Option<Endpoint>,
73 ) -> Self {
74 HolePunchEngine {
75 link_id,
76 session_id: [0u8; 16],
77 state: HolePunchState::Idle,
78 is_initiator: false,
79 our_public_endpoint: None,
80 peer_public_endpoint: None,
81 facilitator_addr: None,
82 punch_token: [0u8; 32],
83 probe_addr,
84 state_entered_at: 0.0,
85 }
86 }
87
88 pub fn state(&self) -> HolePunchState {
89 self.state
90 }
91
92 pub fn session_id(&self) -> &[u8; 16] {
93 &self.session_id
94 }
95
96 pub fn is_initiator(&self) -> bool {
97 self.is_initiator
98 }
99
100 pub fn punch_token(&self) -> &[u8; 32] {
101 &self.punch_token
102 }
103
104 pub fn peer_public_endpoint(&self) -> Option<&Endpoint> {
106 self.peer_public_endpoint.as_ref()
107 }
108
109 pub fn propose(
116 &mut self,
117 derived_key: &[u8],
118 now: f64,
119 rng: &mut dyn Rng,
120 ) -> Result<Vec<HolePunchAction>, HolePunchError> {
121 if self.state != HolePunchState::Idle {
122 return Err(HolePunchError::InvalidState);
123 }
124
125 let mut session_id = [0u8; 16];
127 rng.fill_bytes(&mut session_id);
128 self.session_id = session_id;
129 self.is_initiator = true;
130
131 self.punch_token = derive_punch_token(derived_key, &session_id)?;
133
134 let probe_addr = self.probe_addr.clone().ok_or(HolePunchError::NoProbeAddr)?;
135 self.facilitator_addr = Some(probe_addr.clone());
136
137 self.state = HolePunchState::Discovering;
139 self.state_entered_at = now;
140
141 Ok(alloc::vec![HolePunchAction::DiscoverEndpoints { probe_addr }])
142 }
143
144 pub fn endpoints_discovered(
149 &mut self,
150 public_endpoint: Endpoint,
151 now: f64,
152 ) -> Result<Vec<HolePunchAction>, HolePunchError> {
153 if self.state != HolePunchState::Discovering {
154 return Err(HolePunchError::InvalidState);
155 }
156
157 self.our_public_endpoint = Some(public_endpoint.clone());
158
159 if self.is_initiator {
160 let facilitator = self.facilitator_addr.clone()
162 .ok_or(HolePunchError::NoProbeAddr)?;
163
164 let payload = encode_upgrade_request(
165 &self.session_id,
166 &facilitator,
167 &public_endpoint,
168 );
169
170 self.state = HolePunchState::Proposing;
171 self.state_entered_at = now;
172
173 Ok(alloc::vec![HolePunchAction::SendSignal {
174 link_id: self.link_id,
175 msgtype: UPGRADE_REQUEST,
176 payload,
177 }])
178 } else {
179 let payload = encode_upgrade_ready(&self.session_id, &public_endpoint);
181
182 let peer_public = self.peer_public_endpoint.clone()
183 .ok_or(HolePunchError::InvalidState)?;
184
185 self.state = HolePunchState::Punching;
186 self.state_entered_at = now;
187
188 Ok(alloc::vec![
189 HolePunchAction::SendSignal {
190 link_id: self.link_id,
191 msgtype: UPGRADE_READY,
192 payload,
193 },
194 HolePunchAction::StartUdpPunch {
195 peer_public,
196 punch_token: self.punch_token,
197 session_id: self.session_id,
198 },
199 ])
200 }
201 }
202
203 pub fn handle_signal(
207 &mut self,
208 msgtype: u16,
209 payload: &[u8],
210 derived_key: Option<&[u8]>,
211 now: f64,
212 ) -> Result<Vec<HolePunchAction>, HolePunchError> {
213 match msgtype {
214 UPGRADE_REQUEST => self.handle_upgrade_request(payload, derived_key, now),
215 UPGRADE_ACCEPT => self.handle_upgrade_accept(payload, now),
216 UPGRADE_REJECT => self.handle_upgrade_reject(payload, now),
217 UPGRADE_READY => self.handle_upgrade_ready(payload, now),
218 UPGRADE_COMPLETE => self.handle_upgrade_complete(payload, now),
219 _ => Err(HolePunchError::InvalidPayload),
220 }
221 }
222
223 pub fn punch_succeeded(
227 &mut self,
228 now: f64,
229 ) -> Result<Vec<HolePunchAction>, HolePunchError> {
230 if self.state != HolePunchState::Punching {
231 return Err(HolePunchError::InvalidState);
232 }
233
234 self.state = HolePunchState::Connected;
235 self.state_entered_at = now;
236
237 Ok(alloc::vec![
238 HolePunchAction::Succeeded {
239 session_id: self.session_id,
240 },
241 ])
242 }
243
244 pub fn punch_failed(
248 &mut self,
249 now: f64,
250 ) -> Result<Vec<HolePunchAction>, HolePunchError> {
251 if self.state != HolePunchState::Punching {
252 return Err(HolePunchError::InvalidState);
253 }
254
255 self.state = HolePunchState::Failed;
256 self.state_entered_at = now;
257
258 Ok(alloc::vec![
259 HolePunchAction::Failed {
260 session_id: self.session_id,
261 reason: FAIL_TIMEOUT,
262 },
263 ])
264 }
265
266 pub fn tick(&mut self, now: f64) -> Vec<HolePunchAction> {
268 let elapsed = now - self.state_entered_at;
269 match self.state {
270 HolePunchState::Discovering if elapsed > DISCOVER_TIMEOUT => {
271 self.state = HolePunchState::Failed;
272 self.state_entered_at = now;
273 alloc::vec![HolePunchAction::Failed {
274 session_id: self.session_id,
275 reason: FAIL_PROBE,
276 }]
277 }
278 HolePunchState::Proposing if elapsed > PROPOSE_TIMEOUT => {
279 self.state = HolePunchState::Failed;
280 self.state_entered_at = now;
281 alloc::vec![HolePunchAction::Failed {
282 session_id: self.session_id,
283 reason: FAIL_TIMEOUT,
284 }]
285 }
286 HolePunchState::WaitingReady if elapsed > READY_TIMEOUT => {
287 self.state = HolePunchState::Failed;
288 self.state_entered_at = now;
289 alloc::vec![HolePunchAction::Failed {
290 session_id: self.session_id,
291 reason: FAIL_TIMEOUT,
292 }]
293 }
294 HolePunchState::Punching if elapsed > PUNCH_TIMEOUT => {
295 self.state = HolePunchState::Failed;
296 self.state_entered_at = now;
297 alloc::vec![HolePunchAction::Failed {
298 session_id: self.session_id,
299 reason: FAIL_TIMEOUT,
300 }]
301 }
302 _ => Vec::new(),
303 }
304 }
305
306 pub fn build_reject(
310 link_id: [u8; 16],
311 request_payload: &[u8],
312 reason: u8,
313 ) -> Result<HolePunchAction, HolePunchError> {
314 let (session_id, _, _) = decode_upgrade_request(request_payload)?;
315 let payload = encode_upgrade_reject(&session_id, reason);
316 Ok(HolePunchAction::SendSignal {
317 link_id,
318 msgtype: UPGRADE_REJECT,
319 payload,
320 })
321 }
322
323 pub fn reset(&mut self) {
325 self.state = HolePunchState::Idle;
326 self.session_id = [0u8; 16];
327 self.is_initiator = false;
328 self.our_public_endpoint = None;
329 self.peer_public_endpoint = None;
330 self.facilitator_addr = None;
331 self.punch_token = [0u8; 32];
332 self.state_entered_at = 0.0;
333 }
334
335 fn handle_upgrade_request(
342 &mut self,
343 payload: &[u8],
344 derived_key: Option<&[u8]>,
345 now: f64,
346 ) -> Result<Vec<HolePunchAction>, HolePunchError> {
347 if self.state != HolePunchState::Idle {
348 let (session_id, _, _) = decode_upgrade_request(payload)?;
350 let reject_payload = encode_upgrade_reject(&session_id, REJECT_BUSY);
351 return Ok(alloc::vec![HolePunchAction::SendSignal {
352 link_id: self.link_id,
353 msgtype: UPGRADE_REJECT,
354 payload: reject_payload,
355 }]);
356 }
357
358 let derived_key = derived_key.ok_or(HolePunchError::NoDerivedKey)?;
359 let (session_id, facilitator, initiator_public) = decode_upgrade_request(payload)?;
360
361 self.session_id = session_id;
362 self.is_initiator = false;
363 self.punch_token = derive_punch_token(derived_key, &session_id)?;
364
365 self.peer_public_endpoint = Some(initiator_public);
367
368 self.facilitator_addr = Some(facilitator.clone());
370
371 self.state = HolePunchState::Discovering;
372 self.state_entered_at = now;
373
374 let accept_payload = encode_upgrade_accept(&session_id);
376
377 Ok(alloc::vec![
378 HolePunchAction::SendSignal {
379 link_id: self.link_id,
380 msgtype: UPGRADE_ACCEPT,
381 payload: accept_payload,
382 },
383 HolePunchAction::DiscoverEndpoints { probe_addr: facilitator },
384 ])
385 }
386
387 fn handle_upgrade_accept(
391 &mut self,
392 payload: &[u8],
393 now: f64,
394 ) -> Result<Vec<HolePunchAction>, HolePunchError> {
395 if self.state != HolePunchState::Proposing || !self.is_initiator {
396 return Err(HolePunchError::InvalidState);
397 }
398
399 let session_id = decode_upgrade_accept(payload)?;
400 if session_id != self.session_id {
401 return Err(HolePunchError::SessionMismatch);
402 }
403
404 self.state = HolePunchState::WaitingReady;
405 self.state_entered_at = now;
406
407 Ok(Vec::new())
408 }
409
410 fn handle_upgrade_reject(
412 &mut self,
413 payload: &[u8],
414 now: f64,
415 ) -> Result<Vec<HolePunchAction>, HolePunchError> {
416 if self.state != HolePunchState::Proposing {
417 return Err(HolePunchError::InvalidState);
418 }
419
420 let (session_id, reason) = decode_upgrade_reject(payload)?;
421 if session_id != self.session_id {
422 return Err(HolePunchError::SessionMismatch);
423 }
424
425 self.state = HolePunchState::Failed;
426 self.state_entered_at = now;
427
428 Ok(alloc::vec![HolePunchAction::Failed {
429 session_id: self.session_id,
430 reason,
431 }])
432 }
433
434 fn handle_upgrade_ready(
439 &mut self,
440 payload: &[u8],
441 now: f64,
442 ) -> Result<Vec<HolePunchAction>, HolePunchError> {
443 if self.state != HolePunchState::WaitingReady || !self.is_initiator {
444 return Err(HolePunchError::InvalidState);
445 }
446
447 let (session_id, responder_public) = decode_upgrade_ready(payload)?;
448 if session_id != self.session_id {
449 return Err(HolePunchError::SessionMismatch);
450 }
451
452 self.peer_public_endpoint = Some(responder_public.clone());
453
454 self.state = HolePunchState::Punching;
455 self.state_entered_at = now;
456
457 Ok(alloc::vec![HolePunchAction::StartUdpPunch {
458 peer_public: responder_public,
459 punch_token: self.punch_token,
460 session_id: self.session_id,
461 }])
462 }
463
464 fn handle_upgrade_complete(
466 &mut self,
467 payload: &[u8],
468 now: f64,
469 ) -> Result<Vec<HolePunchAction>, HolePunchError> {
470 if self.state != HolePunchState::Punching && self.state != HolePunchState::Connected {
471 return Err(HolePunchError::InvalidState);
472 }
473
474 let session_id = decode_session_only(payload)?;
475 if session_id != self.session_id {
476 return Err(HolePunchError::SessionMismatch);
477 }
478
479 if self.state == HolePunchState::Connected {
480 return Ok(Vec::new());
482 }
483
484 self.state = HolePunchState::Connected;
485 self.state_entered_at = now;
486
487 Ok(alloc::vec![HolePunchAction::Succeeded {
488 session_id: self.session_id,
489 }])
490 }
491}
492
493fn encode_upgrade_request(
496 session_id: &[u8; 16],
497 facilitator: &Endpoint,
498 initiator_public: &Endpoint,
499) -> Vec<u8> {
500 let val = Value::Map(alloc::vec![
501 (Value::Str(alloc::string::String::from("s")), Value::Bin(session_id.to_vec())),
502 (Value::Str(alloc::string::String::from("f")), encode_endpoint(facilitator)),
503 (Value::Str(alloc::string::String::from("a")), encode_endpoint(initiator_public)),
504 ]);
505 msgpack::pack(&val)
506}
507
508fn decode_upgrade_request(data: &[u8]) -> Result<([u8; 16], Endpoint, Endpoint), HolePunchError> {
509 let (val, _) = msgpack::unpack(data).map_err(|_| HolePunchError::InvalidPayload)?;
510 let session_id = extract_session_id(&val)?;
511 let facilitator = val
512 .map_get("f")
513 .and_then(decode_endpoint)
514 .ok_or(HolePunchError::InvalidPayload)?;
515 let initiator_public = val
516 .map_get("a")
517 .and_then(decode_endpoint)
518 .ok_or(HolePunchError::InvalidPayload)?;
519 Ok((session_id, facilitator, initiator_public))
520}
521
522fn encode_upgrade_accept(session_id: &[u8; 16]) -> Vec<u8> {
523 let val = Value::Map(alloc::vec![
524 (Value::Str(alloc::string::String::from("s")), Value::Bin(session_id.to_vec())),
525 ]);
526 msgpack::pack(&val)
527}
528
529fn decode_upgrade_accept(data: &[u8]) -> Result<[u8; 16], HolePunchError> {
530 let (val, _) = msgpack::unpack(data).map_err(|_| HolePunchError::InvalidPayload)?;
531 extract_session_id(&val)
532}
533
534fn encode_upgrade_reject(session_id: &[u8; 16], reason: u8) -> Vec<u8> {
535 let val = Value::Map(alloc::vec![
536 (Value::Str(alloc::string::String::from("s")), Value::Bin(session_id.to_vec())),
537 (Value::Str(alloc::string::String::from("r")), Value::UInt(reason as u64)),
538 ]);
539 msgpack::pack(&val)
540}
541
542fn decode_upgrade_reject(data: &[u8]) -> Result<([u8; 16], u8), HolePunchError> {
543 let (val, _) = msgpack::unpack(data).map_err(|_| HolePunchError::InvalidPayload)?;
544 let session_id = extract_session_id(&val)?;
545 let reason = val
546 .map_get("r")
547 .and_then(|v| v.as_uint())
548 .ok_or(HolePunchError::InvalidPayload)? as u8;
549 Ok((session_id, reason))
550}
551
552fn encode_upgrade_ready(session_id: &[u8; 16], responder_public: &Endpoint) -> Vec<u8> {
553 let val = Value::Map(alloc::vec![
554 (Value::Str(alloc::string::String::from("s")), Value::Bin(session_id.to_vec())),
555 (Value::Str(alloc::string::String::from("a")), encode_endpoint(responder_public)),
556 ]);
557 msgpack::pack(&val)
558}
559
560fn decode_upgrade_ready(data: &[u8]) -> Result<([u8; 16], Endpoint), HolePunchError> {
561 let (val, _) = msgpack::unpack(data).map_err(|_| HolePunchError::InvalidPayload)?;
562 let session_id = extract_session_id(&val)?;
563 let responder_public = val
564 .map_get("a")
565 .and_then(decode_endpoint)
566 .ok_or(HolePunchError::InvalidPayload)?;
567 Ok((session_id, responder_public))
568}
569
570fn encode_endpoint(ep: &Endpoint) -> Value {
571 Value::Array(alloc::vec![
572 Value::Bin(ep.addr.clone()),
573 Value::UInt(ep.port as u64),
574 ])
575}
576
577fn decode_endpoint(val: &Value) -> Option<Endpoint> {
578 let arr = val.as_array()?;
579 if arr.len() < 2 {
580 return None;
581 }
582 let addr = arr[0].as_bin()?.to_vec();
583 let port = arr[1].as_uint()? as u16;
584 Some(Endpoint { addr, port })
585}
586
/// Test-only encoder for a payload that carries just the session id (`{"s": bin}`).
/// It is the counterpart of `decode_session_only`.
#[cfg(test)]
fn encode_session_only(session_id: &[u8; 16]) -> Vec<u8> {
    let entry = (Value::Str("s".into()), Value::Bin(session_id.to_vec()));
    let map = Value::Map(alloc::vec![entry]);
    msgpack::pack(&map)
}
594
595fn decode_session_only(data: &[u8]) -> Result<[u8; 16], HolePunchError> {
596 let (val, _) = msgpack::unpack(data).map_err(|_| HolePunchError::InvalidPayload)?;
597 extract_session_id(&val)
598}
599
600fn extract_session_id(val: &Value) -> Result<[u8; 16], HolePunchError> {
601 let bin = val
602 .map_get("s")
603 .and_then(|v| v.as_bin())
604 .ok_or(HolePunchError::InvalidPayload)?;
605 if bin.len() != 16 {
606 return Err(HolePunchError::InvalidPayload);
607 }
608 let mut id = [0u8; 16];
609 id.copy_from_slice(bin);
610 Ok(id)
611}
612
// Unit tests for the hole-punch state machine and its signalling codecs.
// Several tests assign private fields (`state`, `state_entered_at`) directly to jump
// the engine into a given state without replaying the whole handshake.
#[cfg(test)]
mod tests {
    use super::*;
    use rns_crypto::FixedRng;

    // Deterministic RNG seeded with 128 copies of `seed`; presumably it replays those
    // bytes, making session ids reproducible. TODO confirm FixedRng semantics.
    fn make_rng(seed: u8) -> FixedRng {
        FixedRng::new(&[seed; 128])
    }

    // Stand-in for the link's derived key; both peers use the same value.
    fn test_derived_key() -> Vec<u8> {
        vec![0xAA; 32]
    }

    // Probe/facilitator endpoint configured on the initiator.
    fn test_probe_addr() -> Endpoint {
        Endpoint {
            addr: vec![127, 0, 0, 1],
            port: 4343,
        }
    }

    // Public endpoint "discovered" by the initiator.
    fn test_public_addr_a() -> Endpoint {
        Endpoint {
            addr: vec![1, 2, 3, 4],
            port: 41000,
        }
    }

    // Public endpoint "discovered" by the responder.
    fn test_public_addr_b() -> Endpoint {
        Endpoint {
            addr: vec![5, 6, 7, 8],
            port: 52000,
        }
    }

    // propose() must start with discovery, not by signalling the peer.
    #[test]
    fn test_propose_initiator_discovers_first() {
        let link_id = [0x11; 16];
        let derived_key = test_derived_key();
        let mut rng = make_rng(0x42);

        let mut initiator = HolePunchEngine::new(link_id, Some(test_probe_addr()));
        let actions = initiator.propose(&derived_key, 100.0, &mut rng).unwrap();

        assert_eq!(initiator.state(), HolePunchState::Discovering);
        assert_eq!(actions.len(), 1);
        assert!(matches!(&actions[0], HolePunchAction::DiscoverEndpoints { .. }));
    }

    // After discovery the initiator sends UPGRADE_REQUEST carrying the facilitator
    // and its own public endpoint.
    #[test]
    fn test_initiator_sends_request_after_discovery() {
        let link_id = [0x11; 16];
        let derived_key = test_derived_key();
        let mut rng = make_rng(0x42);

        let mut initiator = HolePunchEngine::new(link_id, Some(test_probe_addr()));
        initiator.propose(&derived_key, 100.0, &mut rng).unwrap();

        let actions = initiator
            .endpoints_discovered(test_public_addr_a(), 101.0)
            .unwrap();

        assert_eq!(initiator.state(), HolePunchState::Proposing);
        assert_eq!(actions.len(), 1);
        match &actions[0] {
            HolePunchAction::SendSignal { msgtype, payload, .. } => {
                assert_eq!(*msgtype, UPGRADE_REQUEST);
                let (sid, facilitator, init_pub) = decode_upgrade_request(payload).unwrap();
                assert_eq!(sid, *initiator.session_id());
                assert_eq!(facilitator, test_probe_addr());
                assert_eq!(init_pub, test_public_addr_a());
            }
            _ => panic!("Expected SendSignal(UPGRADE_REQUEST)"),
        }
    }

    // End-to-end handshake: REQUEST -> ACCEPT -> READY, with both sides punching
    // and agreeing on the punch token.
    #[test]
    fn test_full_asymmetric_flow() {
        let link_id = [0x22; 16];
        let derived_key = test_derived_key();
        let mut rng = make_rng(0x42);

        let mut initiator = HolePunchEngine::new(link_id, Some(test_probe_addr()));
        initiator.propose(&derived_key, 100.0, &mut rng).unwrap();
        let actions = initiator
            .endpoints_discovered(test_public_addr_a(), 101.0)
            .unwrap();

        let request_payload = match &actions[0] {
            HolePunchAction::SendSignal { payload, .. } => payload.clone(),
            _ => panic!(),
        };

        // Responder has no probe address of its own; it uses the facilitator from
        // the request.
        let mut responder = HolePunchEngine::new(link_id, None);
        let actions = responder
            .handle_signal(UPGRADE_REQUEST, &request_payload, Some(&derived_key), 102.0)
            .unwrap();

        assert_eq!(responder.state(), HolePunchState::Discovering);
        assert_eq!(actions.len(), 2);
        let accept_payload = match &actions[0] {
            HolePunchAction::SendSignal { msgtype, payload, .. } => {
                assert_eq!(*msgtype, UPGRADE_ACCEPT);
                payload.clone()
            }
            _ => panic!("Expected UPGRADE_ACCEPT"),
        };

        match &actions[1] {
            HolePunchAction::DiscoverEndpoints { probe_addr } => {
                assert_eq!(*probe_addr, test_probe_addr());
            }
            _ => panic!("Expected DiscoverEndpoints"),
        }

        // Initiator learns of the acceptance and waits for READY.
        let actions = initiator
            .handle_signal(UPGRADE_ACCEPT, &accept_payload, None, 103.0)
            .unwrap();
        assert_eq!(initiator.state(), HolePunchState::WaitingReady);
        assert!(actions.is_empty());
        // Responder finishes discovery: sends READY and starts punching at once.
        let actions = responder
            .endpoints_discovered(test_public_addr_b(), 104.0)
            .unwrap();

        assert_eq!(responder.state(), HolePunchState::Punching);
        assert_eq!(actions.len(), 2);
        let ready_payload = match &actions[0] {
            HolePunchAction::SendSignal { msgtype, payload, .. } => {
                assert_eq!(*msgtype, UPGRADE_READY);
                payload.clone()
            }
            _ => panic!("Expected UPGRADE_READY"),
        };
        assert!(matches!(&actions[1], HolePunchAction::StartUdpPunch { .. }));

        // Initiator receives READY and punches towards the responder's endpoint.
        let actions = initiator
            .handle_signal(UPGRADE_READY, &ready_payload, None, 105.0)
            .unwrap();

        assert_eq!(initiator.state(), HolePunchState::Punching);
        assert_eq!(actions.len(), 1);
        match &actions[0] {
            HolePunchAction::StartUdpPunch { peer_public, .. } => {
                assert_eq!(*peer_public, test_public_addr_b());
            }
            _ => panic!("Expected StartUdpPunch"),
        }

        // Same key + same session id => same token on both sides.
        assert_eq!(initiator.punch_token(), responder.punch_token());
    }

    // Forces `Punching` directly, then reports success.
    #[test]
    fn test_punch_success() {
        let link_id = [0x33; 16];
        let derived_key = test_derived_key();
        let mut rng = make_rng(0x42);

        let mut engine = HolePunchEngine::new(link_id, Some(test_probe_addr()));
        engine.propose(&derived_key, 100.0, &mut rng).unwrap();
        engine.state = HolePunchState::Punching;

        let actions = engine.punch_succeeded(105.0).unwrap();
        assert_eq!(engine.state(), HolePunchState::Connected);
        assert_eq!(actions.len(), 1);
        assert!(matches!(&actions[0], HolePunchAction::Succeeded { .. }));
    }

    // Forces `Punching` directly, then reports failure.
    #[test]
    fn test_punch_failed() {
        let link_id = [0x44; 16];
        let derived_key = test_derived_key();
        let mut rng = make_rng(0x42);

        let mut engine = HolePunchEngine::new(link_id, Some(test_probe_addr()));
        engine.propose(&derived_key, 100.0, &mut rng).unwrap();
        engine.state = HolePunchState::Punching;

        let actions = engine.punch_failed(120.0).unwrap();
        assert_eq!(engine.state(), HolePunchState::Failed);
        assert_eq!(actions.len(), 1);
        assert!(matches!(&actions[0], HolePunchAction::Failed { .. }));
    }

    // A non-idle engine answers a REQUEST with a single UPGRADE_REJECT.
    #[test]
    fn test_reject_when_busy() {
        let link_id = [0x55; 16];
        let derived_key = test_derived_key();
        let mut rng = make_rng(0x42);

        let mut proposer = HolePunchEngine::new(link_id, Some(test_probe_addr()));
        proposer.propose(&derived_key, 100.0, &mut rng).unwrap();
        let actions = proposer.endpoints_discovered(test_public_addr_a(), 101.0).unwrap();
        let request_payload = match &actions[0] {
            HolePunchAction::SendSignal { payload, .. } => payload.clone(),
            _ => panic!(),
        };

        let mut responder = HolePunchEngine::new(link_id, Some(test_probe_addr()));
        // Pretend the responder is already busy with its own attempt.
        responder.state = HolePunchState::Discovering;

        let actions = responder
            .handle_signal(UPGRADE_REQUEST, &request_payload, Some(&derived_key), 102.0)
            .unwrap();

        assert_eq!(actions.len(), 1);
        match &actions[0] {
            HolePunchAction::SendSignal { msgtype, .. } => {
                assert_eq!(*msgtype, UPGRADE_REJECT);
            }
            _ => panic!("Expected UPGRADE_REJECT"),
        }
    }

    // A REJECT while proposing fails the session with the peer's reason code.
    #[test]
    fn test_initiator_receives_reject() {
        let link_id = [0x66; 16];
        let derived_key = test_derived_key();
        let mut rng = make_rng(0x42);

        let mut initiator = HolePunchEngine::new(link_id, Some(test_probe_addr()));
        initiator.propose(&derived_key, 100.0, &mut rng).unwrap();
        initiator.endpoints_discovered(test_public_addr_a(), 101.0).unwrap();
        assert_eq!(initiator.state(), HolePunchState::Proposing);

        let session_id = *initiator.session_id();
        let reject_payload = encode_upgrade_reject(&session_id, REJECT_POLICY);

        let actions = initiator
            .handle_signal(UPGRADE_REJECT, &reject_payload, None, 102.0)
            .unwrap();

        assert_eq!(initiator.state(), HolePunchState::Failed);
        assert_eq!(actions.len(), 1);
        assert!(matches!(&actions[0], HolePunchAction::Failed { reason, .. } if *reason == REJECT_POLICY));
    }

    // Discovery timeout: nothing just before the deadline, FAIL_PROBE just after.
    #[test]
    fn test_discover_timeout() {
        let link_id = [0x77; 16];
        let derived_key = test_derived_key();
        let mut rng = make_rng(0x42);

        let mut engine = HolePunchEngine::new(link_id, Some(test_probe_addr()));
        engine.propose(&derived_key, 100.0, &mut rng).unwrap();
        assert_eq!(engine.state(), HolePunchState::Discovering);

        let actions = engine.tick(100.0 + DISCOVER_TIMEOUT - 1.0);
        assert!(actions.is_empty());

        let actions = engine.tick(100.0 + DISCOVER_TIMEOUT + 1.0);
        assert_eq!(engine.state(), HolePunchState::Failed);
        assert!(matches!(&actions[0], HolePunchAction::Failed { reason, .. } if *reason == FAIL_PROBE));
    }

    // Propose timeout is measured from entering `Proposing` (t = 101).
    #[test]
    fn test_propose_timeout() {
        let link_id = [0x88; 16];
        let derived_key = test_derived_key();
        let mut rng = make_rng(0x42);

        let mut engine = HolePunchEngine::new(link_id, Some(test_probe_addr()));
        engine.propose(&derived_key, 100.0, &mut rng).unwrap();
        engine.endpoints_discovered(test_public_addr_a(), 101.0).unwrap();
        assert_eq!(engine.state(), HolePunchState::Proposing);

        let actions = engine.tick(101.0 + PROPOSE_TIMEOUT + 1.0);
        assert_eq!(engine.state(), HolePunchState::Failed);
        assert!(matches!(&actions[0], HolePunchAction::Failed { reason, .. } if *reason == FAIL_TIMEOUT));
    }

    // Forces `WaitingReady` entered at t = 202 and lets it expire.
    #[test]
    fn test_waiting_ready_timeout() {
        let link_id = [0x99; 16];
        let derived_key = test_derived_key();
        let mut rng = make_rng(0x42);

        let mut engine = HolePunchEngine::new(link_id, Some(test_probe_addr()));
        engine.propose(&derived_key, 200.0, &mut rng).unwrap();
        engine.endpoints_discovered(test_public_addr_a(), 201.0).unwrap();
        engine.state = HolePunchState::WaitingReady;
        engine.state_entered_at = 202.0;

        let actions = engine.tick(202.0 + READY_TIMEOUT + 1.0);
        assert_eq!(engine.state(), HolePunchState::Failed);
        assert!(matches!(&actions[0], HolePunchAction::Failed { reason, .. } if *reason == FAIL_TIMEOUT));
    }

    // Forces `Punching` entered at t = 200; checks both sides of the deadline.
    #[test]
    fn test_punch_timeout() {
        let link_id = [0xAA; 16];
        let derived_key = test_derived_key();
        let mut rng = make_rng(0x42);

        let mut engine = HolePunchEngine::new(link_id, Some(test_probe_addr()));
        engine.propose(&derived_key, 100.0, &mut rng).unwrap();
        engine.state = HolePunchState::Punching;
        engine.state_entered_at = 200.0;

        let actions = engine.tick(200.0 + PUNCH_TIMEOUT - 1.0);
        assert!(actions.is_empty());

        let actions = engine.tick(200.0 + PUNCH_TIMEOUT + 1.0);
        assert_eq!(engine.state(), HolePunchState::Failed);
    }

    // Every encoder/decoder pair must round-trip its fields unchanged.
    #[test]
    fn test_message_serialization_roundtrip() {
        let session_id = [0xAB; 16];

        let facilitator = test_probe_addr();
        let init_pub = test_public_addr_a();
        let data = encode_upgrade_request(&session_id, &facilitator, &init_pub);
        let (sid, f, a) = decode_upgrade_request(&data).unwrap();
        assert_eq!(sid, session_id);
        assert_eq!(f, facilitator);
        assert_eq!(a, init_pub);

        let data = encode_upgrade_accept(&session_id);
        let sid = decode_upgrade_accept(&data).unwrap();
        assert_eq!(sid, session_id);

        let data = encode_upgrade_reject(&session_id, REJECT_POLICY);
        let (sid, r) = decode_upgrade_reject(&data).unwrap();
        assert_eq!(sid, session_id);
        assert_eq!(r, REJECT_POLICY);

        let resp_pub = test_public_addr_b();
        let data = encode_upgrade_ready(&session_id, &resp_pub);
        let (sid, rp) = decode_upgrade_ready(&data).unwrap();
        assert_eq!(sid, session_id);
        assert_eq!(rp, resp_pub);

        let data = encode_session_only(&session_id);
        let sid = decode_session_only(&data).unwrap();
        assert_eq!(sid, session_id);
    }

    // Token derivation is deterministic and depends on the session id.
    #[test]
    fn test_punch_token_derivation_consistency() {
        let derived_key = vec![0xBB; 32];
        let session_id = [0xCC; 16];

        let token1 = derive_punch_token(&derived_key, &session_id).unwrap();
        let token2 = derive_punch_token(&derived_key, &session_id).unwrap();
        assert_eq!(token1, token2);

        let session_id2 = [0xDD; 16];
        let token3 = derive_punch_token(&derived_key, &session_id2).unwrap();
        assert_ne!(token1, token3);
    }

    // reset() returns to Idle and clears the session id.
    #[test]
    fn test_reset() {
        let link_id = [0xBB; 16];
        let derived_key = test_derived_key();
        let mut rng = make_rng(0x42);

        let mut engine = HolePunchEngine::new(link_id, Some(test_probe_addr()));
        engine.propose(&derived_key, 100.0, &mut rng).unwrap();
        assert_eq!(engine.state(), HolePunchState::Discovering);

        engine.reset();
        assert_eq!(engine.state(), HolePunchState::Idle);
        assert_eq!(engine.session_id(), &[0u8; 16]);
    }

    // build_reject() works without an engine instance.
    #[test]
    fn test_build_reject_static() {
        let link_id = [0xCC; 16];
        let derived_key = test_derived_key();
        let mut rng = make_rng(0x42);

        let mut proposer = HolePunchEngine::new(link_id, Some(test_probe_addr()));
        proposer.propose(&derived_key, 100.0, &mut rng).unwrap();
        let actions = proposer.endpoints_discovered(test_public_addr_a(), 101.0).unwrap();
        let request_payload = match &actions[0] {
            HolePunchAction::SendSignal { payload, .. } => payload.clone(),
            _ => panic!(),
        };

        let action = HolePunchEngine::build_reject(link_id, &request_payload, REJECT_POLICY).unwrap();
        match action {
            HolePunchAction::SendSignal { msgtype, .. } => {
                assert_eq!(msgtype, UPGRADE_REJECT);
            }
            _ => panic!("Expected SendSignal(UPGRADE_REJECT)"),
        }
    }

    // A responder constructed without a probe address can still accept a request.
    #[test]
    fn test_responder_needs_no_probe_addr() {
        let link_id = [0xDD; 16];
        let derived_key = test_derived_key();
        let mut rng = make_rng(0x42);

        let mut initiator = HolePunchEngine::new(link_id, Some(test_probe_addr()));
        initiator.propose(&derived_key, 100.0, &mut rng).unwrap();
        let actions = initiator.endpoints_discovered(test_public_addr_a(), 101.0).unwrap();
        let request_payload = match &actions[0] {
            HolePunchAction::SendSignal { payload, .. } => payload.clone(),
            _ => panic!(),
        };

        let mut responder = HolePunchEngine::new(link_id, None);
        let actions = responder
            .handle_signal(UPGRADE_REQUEST, &request_payload, Some(&derived_key), 102.0)
            .unwrap();

        assert_eq!(responder.state(), HolePunchState::Discovering);
        assert_eq!(actions.len(), 2);
        assert!(matches!(&actions[0], HolePunchAction::SendSignal { msgtype, .. } if *msgtype == UPGRADE_ACCEPT));
        assert!(matches!(&actions[1], HolePunchAction::DiscoverEndpoints { .. }));
    }
}