use std::collections::HashMap;
use std::net::{SocketAddr, UdpSocket};
use std::thread;
use rns_core::holepunch::types::is_holepunch_msgtype;
use rns_core::holepunch::{
Endpoint, HolePunchAction, HolePunchEngine, HolePunchState, ProbeProtocol, REJECT_POLICY,
UPGRADE_REQUEST,
};
use crate::event::{Event, EventSender};
use crate::time;
use super::probe;
use super::puncher::{self, PunchHandle};
use super::udp_direct;
pub use crate::common::event::HolePunchPolicy;
struct HolePunchSession {
engine: HolePunchEngine,
socket: Option<UdpSocket>,
punch_handle: Option<PunchHandle>,
last_proposal: f64,
}
pub struct HolePunchManager {
sessions: HashMap<[u8; 16], HolePunchSession>, link_to_session: HashMap<[u8; 16], [u8; 16]>, policy: HolePunchPolicy,
probe_addrs: Vec<SocketAddr>,
probe_protocol: ProbeProtocol,
device: Option<String>,
next_interface_id: u64,
}
pub enum HolePunchManagerAction {
SendChannelMessage {
link_id: [u8; 16],
msgtype: u16,
payload: Vec<u8>,
},
DirectConnectEstablished {
link_id: [u8; 16],
session_id: [u8; 16],
interface_id: rns_core::transport::types::InterfaceId,
rtt: f64,
mtu: u32,
},
DirectConnectFailed {
link_id: [u8; 16],
session_id: [u8; 16],
reason: u8,
},
}
impl HolePunchManager {
pub fn new(
probe_addrs: Vec<SocketAddr>,
probe_protocol: ProbeProtocol,
device: Option<String>,
) -> Self {
HolePunchManager {
sessions: HashMap::new(),
link_to_session: HashMap::new(),
policy: HolePunchPolicy::default(),
probe_addrs,
probe_protocol,
device,
next_interface_id: 50000, }
}
pub fn set_policy(&mut self, policy: HolePunchPolicy) {
self.policy = policy;
}
pub fn policy(&self) -> HolePunchPolicy {
self.policy
}
pub fn propose(
&mut self,
link_id: [u8; 16],
derived_key: &[u8],
rng: &mut dyn rns_crypto::Rng,
tx: &EventSender,
) -> Vec<HolePunchManagerAction> {
let now = time::now();
if let Some(session_id) = self.link_to_session.get(&link_id) {
if let Some(session) = self.sessions.get(session_id) {
let elapsed = now - session.last_proposal;
if elapsed < rns_core::holepunch::types::PROPOSAL_COOLDOWN {
log::debug!(
"Hole punch proposal rate limited for link {:02x?}",
&link_id[..4]
);
return Vec::new();
}
}
}
if let Some(old_session_id) = self.link_to_session.remove(&link_id) {
self.sessions.remove(&old_session_id);
}
let probe_endpoint = self.probe_addrs.first().map(|addr| Endpoint {
addr: match addr {
SocketAddr::V4(v4) => v4.ip().octets().to_vec(),
SocketAddr::V6(v6) => v6.ip().octets().to_vec(),
},
port: addr.port(),
});
let mut engine = HolePunchEngine::new(link_id, probe_endpoint, self.probe_protocol);
let actions = match engine.propose(derived_key, now, rng) {
Ok(a) => a,
Err(e) => {
log::warn!("Failed to propose hole punch: {}", e);
return Vec::new();
}
};
let session_id = *engine.session_id();
self.link_to_session.insert(link_id, session_id);
self.sessions.insert(
session_id,
HolePunchSession {
engine,
socket: None,
punch_handle: None,
last_proposal: now,
},
);
let mgr_actions = convert_engine_actions(link_id, &actions);
self.start_endpoint_discovery_from_actions(link_id, &actions, tx);
mgr_actions
}
pub fn handle_signal(
&mut self,
link_id: [u8; 16],
msgtype: u16,
payload: Vec<u8>,
derived_key: Option<&[u8]>,
tx: &EventSender,
) -> (bool, Vec<HolePunchManagerAction>) {
if !is_holepunch_msgtype(msgtype) {
return (false, Vec::new());
}
if msgtype == UPGRADE_REQUEST {
match self.policy {
HolePunchPolicy::Reject => {
log::debug!("Rejecting hole punch proposal (policy=Reject)");
match HolePunchEngine::build_reject(link_id, &payload, REJECT_POLICY) {
Ok(action) => {
let mgr_actions = convert_engine_actions(link_id, &[action]);
return (true, mgr_actions);
}
Err(e) => {
log::warn!("Failed to build reject for proposal: {}", e);
return (true, Vec::new());
}
}
}
HolePunchPolicy::AcceptAll => {
}
HolePunchPolicy::AskApp => {
}
}
let mut engine = HolePunchEngine::new(link_id, None, ProbeProtocol::Rnsp);
let now = time::now();
let actions = match engine.handle_signal(msgtype, &payload, derived_key, now) {
Ok(a) => a,
Err(e) => {
log::warn!("Error handling UPGRADE_REQUEST: {}", e);
return (true, Vec::new());
}
};
let session_id = *engine.session_id();
self.link_to_session.insert(link_id, session_id);
self.sessions.insert(
session_id,
HolePunchSession {
engine,
socket: None,
punch_handle: None,
last_proposal: now,
},
);
let mgr_actions = convert_engine_actions(link_id, &actions);
self.start_endpoint_discovery_from_actions(link_id, &actions, tx);
return (true, mgr_actions);
}
let session_id = match self.link_to_session.get(&link_id) {
Some(s) => *s,
None => {
log::debug!("No hole punch session for link {:02x?}", &link_id[..4]);
return (true, Vec::new());
}
};
let session = match self.sessions.get_mut(&session_id) {
Some(s) => s,
None => return (true, Vec::new()),
};
let now = time::now();
let actions = match session
.engine
.handle_signal(msgtype, &payload, derived_key, now)
{
Ok(a) => a,
Err(e) => {
log::warn!("Error handling signal 0x{:04x}: {}", msgtype, e);
return (true, Vec::new());
}
};
let mgr_actions = convert_engine_actions(link_id, &actions);
self.start_punch_from_actions(&session_id, &actions);
(true, mgr_actions)
}
pub fn tick(&mut self, tx: &EventSender) -> Vec<HolePunchManagerAction> {
let now = time::now();
let mut all_actions = Vec::new();
let session_link_pairs: Vec<([u8; 16], Option<[u8; 16]>)> = self
.sessions
.keys()
.map(|sid| {
let link_id = self
.link_to_session
.iter()
.find(|(_, v)| *v == sid)
.map(|(k, _)| *k);
(*sid, link_id)
})
.collect();
struct PunchCompletion {
session_id: [u8; 16],
link_id: [u8; 16],
succeeded: bool,
socket: Option<std::net::UdpSocket>,
peer_addr: Option<SocketAddr>,
rtt: Option<f64>,
}
let mut completions: Vec<PunchCompletion> = Vec::new();
for (session_id, link_id) in &session_link_pairs {
if let Some(session) = self.sessions.get_mut(session_id) {
let punch_done = session
.punch_handle
.as_ref()
.map(|h| !h.is_running())
.unwrap_or(false);
if punch_done {
let succeeded = session
.punch_handle
.as_ref()
.map(|h| h.succeeded())
.unwrap_or(false);
if let Some(link_id) = link_id {
if succeeded {
if let Some(handle) = session.punch_handle.take() {
if let Some(result) = handle.join() {
let rtt_secs = result.rtt.as_secs_f64();
completions.push(PunchCompletion {
session_id: *session_id,
link_id: *link_id,
succeeded: true,
socket: Some(result.socket),
peer_addr: Some(result.peer_addr),
rtt: Some(rtt_secs),
});
}
}
} else {
session.punch_handle.take();
completions.push(PunchCompletion {
session_id: *session_id,
link_id: *link_id,
succeeded: false,
socket: None,
peer_addr: None,
rtt: None,
});
}
}
}
}
}
for completion in completions {
if completion.succeeded {
if let (Some(socket), Some(peer_addr)) = (completion.socket, completion.peer_addr) {
let interface_id =
rns_core::transport::types::InterfaceId(self.next_interface_id);
self.next_interface_id += 1;
let mut iface_ok = false;
if let Some(session) = self.sessions.get(&completion.session_id) {
let session_id = *session.engine.session_id();
let punch_token = *session.engine.punch_token();
match udp_direct::start_direct_interface(
socket,
peer_addr,
interface_id,
session_id,
punch_token,
tx.clone(),
) {
Ok((writer, info)) => {
log::info!("Direct UDP interface registered: {}", info.name);
let _ = tx.send(Event::InterfaceUp(
interface_id,
Some(writer),
Some(info),
));
iface_ok = true;
}
Err(e) => {
log::warn!("Failed to start direct interface: {}", e);
}
}
}
if let Some(session) = self.sessions.get_mut(&completion.session_id) {
let engine_actions = match session.engine.punch_succeeded(now) {
Ok(a) => a,
Err(_) => Vec::new(),
};
for action in engine_actions {
match action {
HolePunchAction::Succeeded { session_id } if iface_ok => {
all_actions.push(
HolePunchManagerAction::DirectConnectEstablished {
link_id: completion.link_id,
session_id,
interface_id,
rtt: completion.rtt.unwrap_or(0.0),
mtu: 1400,
},
);
}
HolePunchAction::Succeeded { session_id } => {
all_actions.push(HolePunchManagerAction::DirectConnectFailed {
link_id: completion.link_id,
session_id,
reason: rns_core::holepunch::types::FAIL_TIMEOUT,
});
}
_ => {
let mgr = convert_engine_actions(completion.link_id, &[action]);
all_actions.extend(mgr);
}
}
}
}
}
} else {
if let Some(session) = self.sessions.get_mut(&completion.session_id) {
let engine_actions = match session.engine.punch_failed(now) {
Ok(a) => a,
Err(_) => Vec::new(),
};
let mgr = convert_engine_actions(completion.link_id, engine_actions.as_slice());
all_actions.extend(mgr);
}
}
}
for (session_id, link_id) in &session_link_pairs {
if let Some(link_id) = link_id {
if let Some(session) = self.sessions.get_mut(session_id) {
let timeout_actions = session.engine.tick(now);
if !timeout_actions.is_empty() {
let mgr = convert_engine_actions(*link_id, timeout_actions.as_slice());
all_actions.extend(mgr);
}
}
}
}
let failed_sessions: Vec<[u8; 16]> = self
.sessions
.iter()
.filter(|(_, s)| s.engine.state() == HolePunchState::Failed)
.map(|(id, _)| *id)
.collect();
for session_id in failed_sessions {
self.sessions.remove(&session_id);
self.link_to_session.retain(|_, v| *v != session_id);
}
all_actions
}
pub fn link_closed(&mut self, link_id: &[u8; 16]) {
if let Some(session_id) = self.link_to_session.remove(link_id) {
if let Some(mut session) = self.sessions.remove(&session_id) {
if let Some(handle) = session.punch_handle.take() {
handle.cancel();
}
}
}
}
pub fn abort_all_sessions(&mut self) {
for (_, mut session) in self.sessions.drain() {
if let Some(handle) = session.punch_handle.take() {
handle.cancel();
}
}
self.link_to_session.clear();
}
pub fn is_holepunch_message(msgtype: u16) -> bool {
is_holepunch_msgtype(msgtype)
}
fn start_endpoint_discovery_from_actions(
&self,
link_id: [u8; 16],
actions: &[HolePunchAction],
tx: &EventSender,
) {
for action in actions {
if let HolePunchAction::DiscoverEndpoints {
probe_addr,
protocol,
} = action
{
let session_id = match self.link_to_session.get(&link_id) {
Some(s) => *s,
None => continue,
};
let session = match self.sessions.get(&session_id) {
Some(s) => s,
None => continue,
};
let servers: Vec<SocketAddr> = if session.engine.is_initiator() {
self.probe_addrs.clone()
} else {
match endpoint_to_socket_addr(probe_addr) {
Some(a) => vec![a],
None => {
log::warn!("Invalid probe endpoint: {:?}", probe_addr);
continue;
}
}
};
if servers.is_empty() {
log::warn!(
"No probe servers available for session {:02x?}",
&session_id[..4]
);
continue;
}
let tx_clone = tx.clone();
let session_id_copy = session_id;
let device_clone = self.device.clone();
let protocol_copy = *protocol;
if let Err(e) =
thread::Builder::new()
.name("probe-worker".into())
.spawn(move || {
run_probe_worker(
servers,
protocol_copy,
session_id_copy,
link_id,
tx_clone,
device_clone,
);
})
{
log::warn!("Failed to spawn probe worker: {}", e);
}
}
}
}
fn start_punch_from_actions(&mut self, session_id: &[u8; 16], actions: &[HolePunchAction]) {
for action in actions {
if let HolePunchAction::StartUdpPunch {
peer_public,
punch_token,
session_id: sid,
} = action
{
self.start_punch_for_session(session_id, peer_public, punch_token, sid);
}
}
}
fn start_punch_for_session(
&mut self,
session_id: &[u8; 16],
peer_public: &Endpoint,
punch_token: &[u8; 32],
_engine_session_id: &[u8; 16],
) {
let session = match self.sessions.get_mut(session_id) {
Some(s) => s,
None => return,
};
if session.punch_handle.is_some() {
return; }
let socket = match session.socket.take() {
Some(s) => s,
None => {
log::warn!("No socket available for punching");
return;
}
};
let peer_addr = match endpoint_to_socket_addr(peer_public) {
Some(a) => a,
None => {
log::warn!("Invalid peer endpoint for punch: {:?}", peer_public);
session.socket = Some(socket);
return;
}
};
let punch_token_copy = *punch_token;
let session_id_copy = *session.engine.session_id();
log::info!(
"Starting UDP hole punch for session {:02x?} to {}",
&session_id_copy[..4],
peer_addr
);
match puncher::start_udp_punch(
socket,
vec![peer_addr],
vec![],
session_id_copy,
punch_token_copy,
) {
Ok(handle) => {
session.punch_handle = Some(handle);
}
Err(e) => {
log::warn!("Failed to start UDP punch: {}", e);
}
}
}
pub fn handle_probe_result(
&mut self,
link_id: [u8; 16],
session_id: [u8; 16],
observed_addr: SocketAddr,
socket: UdpSocket,
probe_server: SocketAddr,
) -> Vec<HolePunchManagerAction> {
let session = match self.sessions.get_mut(&session_id) {
Some(s) => s,
None => return Vec::new(),
};
session.socket = Some(socket);
if session.engine.is_initiator() {
let first = self.probe_addrs.first().copied();
if first.map_or(true, |f| f != probe_server) {
let facilitator_ep = Endpoint {
addr: match probe_server {
SocketAddr::V4(v4) => v4.ip().octets().to_vec(),
SocketAddr::V6(v6) => v6.ip().octets().to_vec(),
},
port: probe_server.port(),
};
session.engine.set_facilitator(facilitator_ep);
}
}
if let Some(idx) = self.probe_addrs.iter().position(|a| *a == probe_server) {
if idx > 0 {
self.probe_addrs[..=idx].rotate_right(1);
}
}
let public_endpoint = Endpoint {
addr: match observed_addr {
SocketAddr::V4(v4) => v4.ip().octets().to_vec(),
SocketAddr::V6(v6) => v6.ip().octets().to_vec(),
},
port: observed_addr.port(),
};
let now = time::now();
let actions = match session.engine.endpoints_discovered(public_endpoint, now) {
Ok(a) => a,
Err(e) => {
log::warn!("Error in endpoints_discovered: {}", e);
return Vec::new();
}
};
let mgr_actions = convert_engine_actions(link_id, &actions);
let should_punch = actions
.iter()
.any(|a| matches!(a, HolePunchAction::StartUdpPunch { .. }));
if should_punch {
for action in &actions {
if let HolePunchAction::StartUdpPunch {
peer_public,
punch_token,
session_id: sid,
} = action
{
self.start_punch_for_session(&session_id, peer_public, punch_token, sid);
}
}
}
mgr_actions
}
pub fn handle_probe_failed(
&mut self,
link_id: [u8; 16],
session_id: [u8; 16],
) -> Vec<HolePunchManagerAction> {
if let Some(session) = self.sessions.get_mut(&session_id) {
session.engine.reset();
vec![HolePunchManagerAction::DirectConnectFailed {
link_id,
session_id,
reason: rns_core::holepunch::types::FAIL_PROBE,
}]
} else {
Vec::new()
}
}
}
fn run_probe_worker(
servers: Vec<SocketAddr>,
protocol: ProbeProtocol,
session_id: [u8; 16],
link_id: [u8; 16],
tx: EventSender,
device: Option<String>,
) {
match probe::probe_endpoint_failover(
&servers,
protocol,
std::time::Duration::from_secs(3),
device.as_deref(),
) {
Ok((observed, socket, probe_server)) => {
log::info!(
"Probe discovered endpoint: {} via server {} for session {:02x?}",
observed,
probe_server,
&session_id[..4]
);
let _ = tx.send(Event::HolePunchProbeResult {
link_id,
session_id,
observed_addr: observed,
socket,
probe_server,
});
}
Err(e) => {
log::warn!(
"Probe failed for session {:02x?} (tried {} servers): {}",
&session_id[..4],
servers.len(),
e
);
let _ = tx.send(Event::HolePunchProbeFailed {
link_id,
session_id,
});
}
}
}
fn convert_engine_actions(
link_id: [u8; 16],
actions: &[HolePunchAction],
) -> Vec<HolePunchManagerAction> {
let mut mgr_actions = Vec::new();
for action in actions {
match action {
HolePunchAction::SendSignal {
link_id,
msgtype,
payload,
} => {
mgr_actions.push(HolePunchManagerAction::SendChannelMessage {
link_id: *link_id,
msgtype: *msgtype,
payload: payload.clone(),
});
}
HolePunchAction::DiscoverEndpoints { .. } => {
}
HolePunchAction::StartUdpPunch { .. } => {
}
HolePunchAction::Succeeded { .. } => {
}
HolePunchAction::Failed { session_id, reason } => {
mgr_actions.push(HolePunchManagerAction::DirectConnectFailed {
link_id,
session_id: *session_id,
reason: *reason,
});
}
}
}
mgr_actions
}
pub fn endpoint_to_socket_addr(ep: &Endpoint) -> Option<SocketAddr> {
match ep.addr.len() {
4 => {
let ip = std::net::Ipv4Addr::new(ep.addr[0], ep.addr[1], ep.addr[2], ep.addr[3]);
Some(SocketAddr::new(ip.into(), ep.port))
}
16 => {
let mut octets = [0u8; 16];
octets.copy_from_slice(&ep.addr);
let ip = std::net::Ipv6Addr::from(octets);
Some(SocketAddr::new(ip.into(), ep.port))
}
_ => None,
}
}
impl HolePunchManager {
pub fn session_count(&self) -> usize {
self.sessions.len()
}
}
#[cfg(test)]
impl HolePunchManager {
fn has_session_for_link(&self, link_id: &[u8; 16]) -> bool {
self.link_to_session.contains_key(link_id)
}
}
#[cfg(test)]
mod tests {
use super::*;
use rns_core::holepunch::{UPGRADE_ACCEPT, UPGRADE_READY, UPGRADE_REJECT};
use rns_crypto::FixedRng;
fn make_rng(seed: u8) -> FixedRng {
FixedRng::new(&[seed; 128])
}
fn test_derived_key() -> Vec<u8> {
vec![0xAA; 32]
}
fn make_probe_addr() -> SocketAddr {
"127.0.0.1:4343".parse().unwrap()
}
fn make_tx() -> (EventSender, crate::event::EventReceiver) {
crate::event::channel()
}
#[test]
fn test_propose_creates_session_in_discovering() {
let mut mgr = HolePunchManager::new(vec![make_probe_addr()], ProbeProtocol::Rnsp, None);
let link_id = [0x11; 16];
let mut rng = make_rng(0x42);
let (tx, _rx) = make_tx();
let actions = mgr.propose(link_id, &test_derived_key(), &mut rng, &tx);
assert_eq!(mgr.session_count(), 1);
assert!(mgr.has_session_for_link(&link_id));
assert!(actions.is_empty());
}
#[test]
fn test_propose_rate_limited() {
let mut mgr = HolePunchManager::new(vec![make_probe_addr()], ProbeProtocol::Rnsp, None);
let link_id = [0x22; 16];
let mut rng = make_rng(0x42);
let (tx, _rx) = make_tx();
mgr.propose(link_id, &test_derived_key(), &mut rng, &tx);
let actions = mgr.propose(link_id, &test_derived_key(), &mut rng, &tx);
assert!(actions.is_empty());
}
#[test]
fn test_reject_policy_sends_reject() {
let mut mgr = HolePunchManager::new(vec![make_probe_addr()], ProbeProtocol::Rnsp, None);
mgr.set_policy(HolePunchPolicy::Reject);
let link_id = [0x33; 16];
let mut rng = make_rng(0x42);
let (tx, _rx) = make_tx();
let mut proposer = HolePunchEngine::new(
link_id,
Some(Endpoint {
addr: vec![127, 0, 0, 1],
port: 4343,
}),
ProbeProtocol::Rnsp,
);
proposer
.propose(&test_derived_key(), 100.0, &mut rng)
.unwrap();
let discover_actions = proposer
.endpoints_discovered(
Endpoint {
addr: vec![1, 2, 3, 4],
port: 41000,
},
101.0,
)
.unwrap();
let request_payload = match &discover_actions[0] {
HolePunchAction::SendSignal { payload, .. } => payload.clone(),
_ => panic!("Expected SendSignal"),
};
let (handled, actions) = mgr.handle_signal(
link_id,
UPGRADE_REQUEST,
request_payload,
Some(&test_derived_key()),
&tx,
);
assert!(handled);
assert_eq!(actions.len(), 1);
assert!(
matches!(&actions[0], HolePunchManagerAction::SendChannelMessage { msgtype, .. } if *msgtype == UPGRADE_REJECT)
);
assert_eq!(mgr.session_count(), 0);
}
#[test]
fn test_accept_policy_creates_session() {
let mut mgr = HolePunchManager::new(vec![make_probe_addr()], ProbeProtocol::Rnsp, None);
mgr.set_policy(HolePunchPolicy::AcceptAll);
let link_id = [0x44; 16];
let mut rng = make_rng(0x42);
let (tx, _rx) = make_tx();
let mut proposer = HolePunchEngine::new(
link_id,
Some(Endpoint {
addr: vec![127, 0, 0, 1],
port: 4343,
}),
ProbeProtocol::Rnsp,
);
proposer
.propose(&test_derived_key(), 100.0, &mut rng)
.unwrap();
let discover_actions = proposer
.endpoints_discovered(
Endpoint {
addr: vec![1, 2, 3, 4],
port: 41000,
},
101.0,
)
.unwrap();
let request_payload = match &discover_actions[0] {
HolePunchAction::SendSignal { payload, .. } => payload.clone(),
_ => panic!("Expected SendSignal"),
};
let (handled, actions) = mgr.handle_signal(
link_id,
UPGRADE_REQUEST,
request_payload,
Some(&test_derived_key()),
&tx,
);
assert!(handled);
assert!(actions.iter().any(|a| matches!(a, HolePunchManagerAction::SendChannelMessage { msgtype, .. } if *msgtype == UPGRADE_ACCEPT)));
assert_eq!(mgr.session_count(), 1);
assert!(mgr.has_session_for_link(&link_id));
}
#[test]
fn test_non_holepunch_message_not_handled() {
let mut mgr = HolePunchManager::new(vec![], ProbeProtocol::Rnsp, None);
let (tx, _rx) = make_tx();
let (handled, actions) = mgr.handle_signal([0x55; 16], 0x0001, vec![1, 2, 3], None, &tx);
assert!(!handled);
assert!(actions.is_empty());
}
#[test]
fn test_link_closed_cleans_up() {
let mut mgr = HolePunchManager::new(vec![make_probe_addr()], ProbeProtocol::Rnsp, None);
let link_id = [0x66; 16];
let mut rng = make_rng(0x42);
let (tx, _rx) = make_tx();
mgr.propose(link_id, &test_derived_key(), &mut rng, &tx);
assert_eq!(mgr.session_count(), 1);
mgr.link_closed(&link_id);
assert_eq!(mgr.session_count(), 0);
assert!(!mgr.has_session_for_link(&link_id));
}
#[test]
fn test_abort_all_sessions_cleans_up_all_links() {
let mut mgr = HolePunchManager::new(vec![make_probe_addr()], ProbeProtocol::Rnsp, None);
let mut rng = make_rng(0x55);
let (tx, _rx) = make_tx();
let link_a = [0x11; 16];
let link_b = [0x22; 16];
mgr.propose(link_a, &test_derived_key(), &mut rng, &tx);
mgr.propose(link_b, &[0xBB; 32], &mut rng, &tx);
assert!(mgr.session_count() >= 1);
mgr.abort_all_sessions();
assert_eq!(mgr.session_count(), 0);
assert!(!mgr.has_session_for_link(&link_a));
assert!(!mgr.has_session_for_link(&link_b));
}
#[test]
fn test_handle_probe_failed_with_session() {
let mut mgr = HolePunchManager::new(vec![make_probe_addr()], ProbeProtocol::Rnsp, None);
let link_id = [0x77; 16];
let mut rng = make_rng(0x42);
let (tx, _rx) = make_tx();
mgr.propose(link_id, &test_derived_key(), &mut rng, &tx);
let session_id = *mgr.link_to_session.get(&link_id).unwrap();
let actions = mgr.handle_probe_failed(link_id, session_id);
assert_eq!(actions.len(), 1);
assert!(
matches!(&actions[0], HolePunchManagerAction::DirectConnectFailed { reason, .. }
if *reason == rns_core::holepunch::types::FAIL_PROBE)
);
}
#[test]
fn test_handle_probe_failed_without_session() {
let mut mgr = HolePunchManager::new(vec![], ProbeProtocol::Rnsp, None);
let actions = mgr.handle_probe_failed([0x88; 16], [0x99; 16]);
assert!(actions.is_empty());
}
#[test]
fn test_handle_probe_result_initiator_sends_request() {
let mut mgr = HolePunchManager::new(vec![make_probe_addr()], ProbeProtocol::Rnsp, None);
let link_id = [0xAA; 16];
let mut rng = make_rng(0x42);
let (tx, _rx) = make_tx();
mgr.propose(link_id, &test_derived_key(), &mut rng, &tx);
let session_id = *mgr.link_to_session.get(&link_id).unwrap();
let probe_socket = UdpSocket::bind("127.0.0.1:0").unwrap();
let observed: SocketAddr = "1.2.3.4:41000".parse().unwrap();
let actions = mgr.handle_probe_result(
link_id,
session_id,
observed,
probe_socket,
make_probe_addr(),
);
assert!(actions.iter().any(|a| matches!(a,
HolePunchManagerAction::SendChannelMessage { msgtype, .. }
if *msgtype == UPGRADE_REQUEST
)));
}
#[test]
fn test_handle_probe_result_responder_sends_ready() {
let mut mgr = HolePunchManager::new(vec![make_probe_addr()], ProbeProtocol::Rnsp, None);
let link_id = [0xBB; 16];
let mut rng = make_rng(0x42);
let (tx, _rx) = make_tx();
let mut proposer = HolePunchEngine::new(
link_id,
Some(Endpoint {
addr: vec![127, 0, 0, 1],
port: 4343,
}),
ProbeProtocol::Rnsp,
);
proposer
.propose(&test_derived_key(), 100.0, &mut rng)
.unwrap();
let discover_actions = proposer
.endpoints_discovered(
Endpoint {
addr: vec![1, 2, 3, 4],
port: 41000,
},
101.0,
)
.unwrap();
let request_payload = match &discover_actions[0] {
HolePunchAction::SendSignal { payload, .. } => payload.clone(),
_ => panic!("Expected SendSignal"),
};
mgr.handle_signal(
link_id,
UPGRADE_REQUEST,
request_payload,
Some(&test_derived_key()),
&tx,
);
let session_id = *mgr.link_to_session.get(&link_id).unwrap();
let probe_socket = UdpSocket::bind("127.0.0.1:0").unwrap();
let observed: SocketAddr = "5.6.7.8:52000".parse().unwrap();
let actions = mgr.handle_probe_result(
link_id,
session_id,
observed,
probe_socket,
make_probe_addr(),
);
assert!(actions.iter().any(|a| matches!(a,
HolePunchManagerAction::SendChannelMessage { msgtype, .. }
if *msgtype == UPGRADE_READY
)));
}
#[test]
fn test_endpoint_to_socket_addr_ipv4() {
let ep = Endpoint {
addr: vec![10, 0, 0, 1],
port: 8080,
};
let addr = endpoint_to_socket_addr(&ep).unwrap();
assert_eq!(addr, "10.0.0.1:8080".parse::<SocketAddr>().unwrap());
}
#[test]
fn test_endpoint_to_socket_addr_ipv6() {
let ep = Endpoint {
addr: vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
port: 9090,
};
let addr = endpoint_to_socket_addr(&ep).unwrap();
assert_eq!(addr, "[::1]:9090".parse::<SocketAddr>().unwrap());
}
#[test]
fn test_endpoint_to_socket_addr_invalid() {
let ep = Endpoint {
addr: vec![1, 2, 3],
port: 80,
};
assert!(endpoint_to_socket_addr(&ep).is_none());
}
#[test]
fn test_policy_default_is_accept_all() {
let mgr = HolePunchManager::new(vec![], ProbeProtocol::Rnsp, None);
assert_eq!(mgr.policy(), HolePunchPolicy::AcceptAll);
}
#[test]
fn test_set_policy() {
let mut mgr = HolePunchManager::new(vec![], ProbeProtocol::Rnsp, None);
mgr.set_policy(HolePunchPolicy::Reject);
assert_eq!(mgr.policy(), HolePunchPolicy::Reject);
}
#[test]
fn test_is_holepunch_message() {
assert!(HolePunchManager::is_holepunch_message(0xFE00));
assert!(HolePunchManager::is_holepunch_message(0xFE04));
assert!(!HolePunchManager::is_holepunch_message(0x0000));
assert!(!HolePunchManager::is_holepunch_message(0xFE05));
}
#[test]
fn test_tick_empty_is_noop() {
let mut mgr = HolePunchManager::new(vec![], ProbeProtocol::Rnsp, None);
let (tx, _rx) = make_tx();
let actions = mgr.tick(&tx);
assert!(actions.is_empty());
}
#[test]
fn test_propose_without_probe_addr() {
let mut mgr = HolePunchManager::new(vec![], ProbeProtocol::Rnsp, None); let link_id = [0xCC; 16];
let mut rng = make_rng(0x42);
let (tx, _rx) = make_tx();
let actions = mgr.propose(link_id, &test_derived_key(), &mut rng, &tx);
assert!(actions.is_empty());
assert_eq!(mgr.session_count(), 0);
}
}