use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::sync::Arc;
use std::sync::OnceLock;
use std::sync::atomic::{AtomicBool, Ordering};
use anyhow::Result;
use bytes::{Bytes, BytesMut};
use iroh::EndpointId;
use iroh::endpoint::{Connection, ConnectionError, VarInt};
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use crate::firewall::{self, Direction, SharedFirewall};
use crate::peers::{DeviceUserMap, PeerTable};
use crate::stats::{DropReason, ForwardMetrics};
use crate::tun::{TunReader, TunWriter};
/// Largest datagram, in bytes, accepted from a peer: `evaluate_inbound` drops
/// anything longer as malformed. `run_mesh` also uses it as the capacity
/// threshold below which it tops up its TUN read buffer.
const MAX_PEER_DATAGRAM: usize = 1500;
/// Size, in bytes, of the buffer `run_mesh` allocates up front and of every
/// additional reservation it makes when that buffer runs low.
const TX_POOL_CHUNK: usize = 64 * 1024;
/// Settings for rewriting SSH traffic between `crate::ssh::SSH_PORT` and the
/// port stored in `listen_port`.
struct SshNat {
    /// Runtime switch; port rewriting is only performed while this is `true`.
    active: AtomicBool,
    /// IPv4 address that packet addresses are compared against to decide
    /// whether a packet is "ours".
    v4: Ipv4Addr,
    /// IPv6 counterpart of `v4`.
    v6: Ipv6Addr,
    /// Port that replaces `SSH_PORT` as destination on inbound packets and is
    /// replaced by `SSH_PORT` as source on outbound packets.
    listen_port: u16,
}
/// Process-wide SSH NAT settings; stays empty until `init_ssh_nat` stores a value.
static SSH_NAT: OnceLock<SshNat> = OnceLock::new();
/// Stores the SSH NAT settings in the process-wide slot, initially inactive.
///
/// Only the first call has any effect: once a value is stored, later calls
/// are silently ignored. Use [`set_ssh_nat_active`] to switch rewriting on.
pub fn init_ssh_nat(v4: Ipv4Addr, v6: Ipv6Addr, listen_port: u16) {
    let settings = SshNat {
        v4,
        v6,
        listen_port,
        // Rewriting starts disabled and has to be enabled explicitly.
        active: AtomicBool::new(false),
    };
    // `set` only fails when a value is already present; that value is kept.
    let _ = SSH_NAT.set(settings);
}
/// Turns SSH port rewriting on or off.
///
/// Does nothing when the NAT settings have not been initialised yet.
pub fn set_ssh_nat_active(on: bool) {
    let Some(settings) = SSH_NAT.get() else {
        return;
    };
    settings.active.store(on, Ordering::Relaxed);
}
/// Returns the NAT settings, but only if they are initialised and currently
/// switched on; otherwise `None`.
fn ssh_nat() -> Option<&'static SshNat> {
    match SSH_NAT.get() {
        Some(settings) if settings.active.load(Ordering::Relaxed) => Some(settings),
        _ => None,
    }
}
impl SshNat {
    /// Reports whether `ip` equals the configured address of the same family.
    fn is_ours(&self, ip: IpAddr) -> bool {
        // An `IpAddr` only compares equal to a value of the same variant, so
        // at most one side of the `||` can ever be true.
        ip == IpAddr::V4(self.v4) || ip == IpAddr::V6(self.v6)
    }
}
/// Incrementally updates a 16-bit one's-complement checksum after one 16-bit
/// word of the covered data changed from `old` to `new`.
///
/// Computes `~(~check + ~old + new)` using one's-complement addition, i.e.
/// carries out of bit 15 are folded back into the low 16 bits.
fn csum_replace2(check: u16, old: u16, new: u16) -> u16 {
    let mut acc = u32::from(!check) + u32::from(!old) + u32::from(new);
    loop {
        let carry = acc >> 16;
        if carry == 0 {
            break;
        }
        // End-around carry: add the overflow back into the low word.
        acc = (acc & 0xffff) + carry;
    }
    // `acc` fits in 16 bits here, so the cast does not lose information.
    !(acc as u16)
}
/// Rewrites the SSH port of a TCP packet in place and incrementally patches
/// the TCP checksum.
///
/// * `inbound == true`: a packet addressed to one of our NAT addresses on
///   `crate::ssh::SSH_PORT` gets its destination port replaced by the
///   configured listen port.
/// * `inbound == false`: a packet sent from one of our NAT addresses with the
///   listen port as its source gets the source port set back to `SSH_PORT`.
///
/// Returns `true` only if `pkt` was modified. The packet is left untouched
/// when NAT is inactive, `info` is not TCP, the IP header is malformed, the
/// packet is too short to hold the TCP checksum, the addresses/ports in `info`
/// do not match, or the port bytes actually present in `pkt` disagree with
/// `info`.
fn rewrite_ssh_port(pkt: &mut [u8], info: &firewall::PacketInfo, inbound: bool) -> bool {
    // Smallest legal IPv4 header: an IHL of five 32-bit words.
    const IPV4_MIN_HEADER: usize = 20;
    // Fixed IPv6 header length; extension headers are not walked here.
    const IPV6_HEADER: usize = 40;
    // Field offsets inside the TCP header.
    const TCP_SRC_PORT: usize = 0;
    const TCP_DST_PORT: usize = 2;
    const TCP_CHECKSUM: usize = 16;

    let Some(nat) = ssh_nat() else { return false };
    if info.protocol != 6 {
        return false;
    }
    let ihl = match pkt.first().map(|b| b >> 4) {
        Some(4) => {
            let ihl = usize::from(pkt[0] & 0x0f) * 4;
            // A header length below the minimum would place the "TCP header"
            // inside the IP header, so the writes below would corrupt it.
            if ihl < IPV4_MIN_HEADER {
                return false;
            }
            ihl
        }
        Some(6) => IPV6_HEADER,
        _ => return false,
    };
    // The TCP header must be present up to and including the checksum field.
    if pkt.len() < ihl + TCP_CHECKSUM + 2 {
        return false;
    }
    let (port_off, old, new) = if inbound {
        if !nat.is_ours(info.dst_ip) || info.dst_port != crate::ssh::SSH_PORT {
            return false;
        }
        (ihl + TCP_DST_PORT, crate::ssh::SSH_PORT, nat.listen_port)
    } else {
        if !nat.is_ours(info.src_ip) || info.src_port != nat.listen_port {
            return false;
        }
        (ihl + TCP_SRC_PORT, nat.listen_port, crate::ssh::SSH_PORT)
    };
    // The checksum delta below is only valid if the bytes being replaced
    // really hold `old`. If the parser that produced `info` located the TCP
    // header somewhere else, leave the packet alone instead of corrupting it.
    if pkt[port_off..port_off + 2] != old.to_be_bytes() {
        return false;
    }
    pkt[port_off..port_off + 2].copy_from_slice(&new.to_be_bytes());
    let ck_off = ihl + TCP_CHECKSUM;
    let old_ck = u16::from_be_bytes([pkt[ck_off], pkt[ck_off + 1]]);
    let new_ck = csum_replace2(old_ck, old, new);
    pkt[ck_off..ck_off + 2].copy_from_slice(&new_ck.to_be_bytes());
    true
}
/// Verdict returned by `evaluate_inbound` for a datagram received from a peer.
pub(crate) enum InboundDecision {
    /// The datagram passed the size, parse, source-address and firewall checks.
    Accept,
    /// The firewall denied the packet; carries the parsed packet info so the
    /// caller can use it (e.g. to build a reject reply).
    DropFirewall(firewall::PacketInfo),
    /// The datagram exceeded `MAX_PEER_DATAGRAM` or could not be parsed.
    DropMalformed,
    /// The packet's source address is not the address assigned to the peer.
    DropSpoof,
}
/// Decides what to do with a datagram received from a peer.
///
/// Checks run in this order, and the first failure determines the verdict:
/// 1. length above `MAX_PEER_DATAGRAM` -> `DropMalformed`
/// 2. packet cannot be parsed -> `DropMalformed`
/// 3. source address differs from `peer_ip` / `peer_ipv6` -> `DropSpoof`
/// 4. firewall denies the inbound packet -> `DropFirewall` (with packet info)
///
/// Anything that survives all four is `Accept`.
pub(crate) fn evaluate_inbound(
    packet: &[u8],
    firewall: &SharedFirewall,
    peer_id: &EndpointId,
    peer_ip: Ipv4Addr,
    peer_ipv6: std::net::Ipv6Addr,
    network: &str,
) -> InboundDecision {
    if packet.len() > MAX_PEER_DATAGRAM {
        return InboundDecision::DropMalformed;
    }

    let info = match firewall::parse_packet_info(packet) {
        Some(parsed) => parsed,
        None => return InboundDecision::DropMalformed,
    };

    // A peer may only send packets from the address it was assigned.
    let spoofed = match info.src_ip {
        IpAddr::V4(src) => src != peer_ip,
        IpAddr::V6(src) => src != peer_ipv6,
    };
    if spoofed {
        return InboundDecision::DropSpoof;
    }

    let denied = firewall
        .evaluate_packet(Direction::In, &info, peer_id, Some(network))
        .is_deny();
    if denied {
        InboundDecision::DropFirewall(info)
    } else {
        InboundDecision::Accept
    }
}
/// Application close code that `spawn_peer_reader` treats as an intentional
/// disconnect when a peer connection ends with it.
pub const LEAVE_CODE: u32 = 0x1ea5e;
/// NOTE(review): not referenced in this file; presumably the close code used
/// when dropping a misbehaving peer — confirm against the call sites.
pub const ABUSE_CODE: u32 = 0xab05e;
/// Notification sent by a peer reader task when reading from its connection
/// fails.
pub struct DisconnectEvent {
    /// Identity of the peer whose connection was lost.
    pub endpoint_id: EndpointId,
    /// IPv4 address the peer reader was started with.
    pub ip: Ipv4Addr,
    /// IPv6 address the peer reader was started with.
    pub ipv6: std::net::Ipv6Addr,
    /// Name of the network the peer reader was started with.
    pub network: String,
    /// `true` when the connection was closed by the application with
    /// `LEAVE_CODE`; `false` for any other connection error.
    pub intentional: bool,
}
/// Bundle of shared handles consumed by `spawn_peer_reader`.
pub struct ForwardCtx {
    /// Firewall consulted for every inbound datagram and for the reject setting.
    pub firewall: SharedFirewall,
    /// Channel that accepted datagrams are sent on.
    pub tun_tx: mpsc::Sender<Bytes>,
    /// Channel on which a `DisconnectEvent` is sent when the connection fails.
    pub disconnect_tx: mpsc::Sender<DisconnectEvent>,
    /// Cancelling this token stops the reader task.
    pub token: CancellationToken,
    /// Counters for received bytes, drops and rejects.
    pub stats: Arc<ForwardMetrics>,
    /// Resolves the peer's endpoint id to the identity handed to the firewall.
    pub device_user_map: DeviceUserMap,
}
/// Returns `true` when the packet is addressed to port 53 on the magic DNS
/// IPv4 address (`crate::dns::MAGIC_DNS_V4`).
pub(crate) fn is_magic_dns(info: &firewall::PacketInfo) -> bool {
    let magic_addr = IpAddr::V4(crate::dns::MAGIC_DNS_V4);
    info.dst_ip == magic_addr && info.dst_port == 53
}
#[allow(clippy::too_many_arguments)]
pub async fn run_mesh(
mut tun: TunReader,
peers: PeerTable,
firewall: SharedFirewall,
token: CancellationToken,
stats: Arc<ForwardMetrics>,
resolver: Arc<crate::dns_resolver::Resolver>,
tun_tx: mpsc::Sender<Bytes>,
) -> Result<()> {
let mut pool = BytesMut::with_capacity(TX_POOL_CHUNK);
loop {
if pool.capacity() < MAX_PEER_DATAGRAM {
pool.reserve(TX_POOL_CHUNK);
}
let n = tokio::select! {
_ = token.cancelled() => return Ok(()),
result = tun.read_into(&mut pool) => result?,
};
if n == 0 {
continue;
}
let pkt = pool.split_to(n).freeze();
tracing::debug!(len = n, first_byte = pkt[0], "TUN read");
let Some(info) = firewall::parse_packet_info(&pkt) else {
tracing::debug!(len = n, "not IP, dropping");
continue;
};
if is_magic_dns(&info) {
let resolver = resolver.clone();
let tun_tx = tun_tx.clone();
let pkt = pkt.clone();
tokio::spawn(async move {
resolver.handle_tun_query(&pkt, &info, &tun_tx).await;
});
continue; }
let lookup = match info.dst_ip {
IpAddr::V4(v4) => peers.lookup_v4(&v4),
IpAddr::V6(v6) => peers.lookup_v6(&v6),
};
let Some(route) = lookup else {
tracing::debug!(dst = %info.dst_ip, "no peer for dst");
stats.record_drop(DropReason::NoPeer);
continue;
};
if firewall
.evaluate_packet(
Direction::Out,
&info,
&route.endpoint_id,
Some(&route.network),
)
.is_deny()
{
tracing::debug!(dst = %info.dst_ip, port = info.dst_port, "firewall denied outbound");
stats.record_drop(DropReason::Firewall);
if firewall.reject_enabled()
&& let Some(reply) = crate::reject::build_reject(&pkt, &info)
{
stats.record_reject();
let _ = tun_tx.send(reply).await;
}
continue;
}
tracing::debug!(dst = %info.dst_ip, "routing to peer");
let pkt = if ssh_nat().is_some_and(|n| info.protocol == 6 && info.src_port == n.listen_port)
{
let mut v = pkt.to_vec();
rewrite_ssh_port(&mut v, &info, false);
Bytes::from(v)
} else {
pkt
};
match route.conn.send_datagram(pkt) {
Ok(()) => stats.record_tx(n),
Err(e) => {
tracing::debug!(dst = %info.dst_ip, error = %e, "datagram send failed");
stats.record_drop(DropReason::SendFailure);
}
}
}
}
pub fn spawn_peer_reader(
conn: Connection,
peer_id: EndpointId,
peer_ip: Ipv4Addr,
peer_ipv6: std::net::Ipv6Addr,
network: String,
ctx: ForwardCtx,
) -> tokio::task::JoinHandle<()> {
let ForwardCtx {
firewall,
tun_tx,
disconnect_tx,
token,
stats,
device_user_map,
} = ctx;
use tracing::Instrument as _;
let span = tracing::info_span!("peer", peer = %peer_id.fmt_short(), net = %network);
let reader = async move {
loop {
let datagram = tokio::select! {
_ = token.cancelled() => return,
result = conn.read_datagram() => match result {
Ok(d) => d,
Err(e) => {
let intentional = matches!(
&e,
ConnectionError::ApplicationClosed(ac)
if ac.error_code == VarInt::from_u32(LEAVE_CODE)
);
tracing::warn!(peer = %peer_id.fmt_short(), ip = %peer_ip, error = %e, intentional, "peer connection lost");
let _ = disconnect_tx
.send(DisconnectEvent {
endpoint_id: peer_id,
ip: peer_ip,
ipv6: peer_ipv6,
network: network.clone(),
intentional,
})
.await;
return;
}
},
};
let peer_user = device_user_map.resolve(&peer_id);
match evaluate_inbound(
&datagram, &firewall, &peer_user, peer_ip, peer_ipv6, &network,
) {
InboundDecision::Accept => {
stats.record_rx(datagram.len());
let datagram = match ssh_nat() {
Some(_) => match firewall::parse_packet_info(&datagram) {
Some(info) if info.protocol == 6 && info.dst_port == crate::ssh::SSH_PORT => {
let mut v = datagram.to_vec();
rewrite_ssh_port(&mut v, &info, true);
Bytes::from(v)
}
_ => datagram,
},
None => datagram,
};
if tun_tx.send(datagram).await.is_err() {
return;
}
}
InboundDecision::DropFirewall(info) => {
stats.record_drop(DropReason::Firewall);
if firewall.reject_enabled()
&& let Some(reply) = crate::reject::build_reject(&datagram, &info)
{
stats.record_reject();
let _ = conn.send_datagram(reply);
}
}
InboundDecision::DropMalformed => stats.record_drop(DropReason::Malformed),
InboundDecision::DropSpoof => {
stats.record_drop(DropReason::Spoof);
tracing::debug!(
peer = %peer_id.fmt_short(),
"dropped inbound packet with spoofed source IP"
);
}
}
}
};
tokio::spawn(reader.instrument(span))
}
/// Spawns the task that drains `tun_rx` and writes each packet to the TUN
/// device.
///
/// Packets received while `active` is `false` are discarded. A failed write
/// is logged and the task keeps going; it ends once every sender of the
/// channel has been dropped.
pub fn spawn_tun_writer(
    mut tun: TunWriter,
    mut tun_rx: mpsc::Receiver<Bytes>,
    active: Arc<std::sync::atomic::AtomicBool>,
) -> tokio::task::JoinHandle<()> {
    tokio::spawn(async move {
        while let Some(packet) = tun_rx.recv().await {
            let enabled = active.load(Ordering::Relaxed);
            if !enabled {
                continue;
            }
            let outcome = tun.write_packet(&packet).await;
            if let Err(e) = outcome {
                tracing::warn!(error = %e, "TUN write failed");
            }
        }
    })
}
// Unit tests for packet parsing, the inbound decision logic, the SSH port
// rewrite and the magic-DNS predicate. Packets are hand-built byte vectors;
// the comments below name the header fields the raw offsets refer to.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::firewall::Action;
    /// A minimal IPv4/TCP header parses and yields destination and protocol.
    #[test]
    fn test_parse_packet_valid_ipv4() {
        let mut packet = vec![0u8; 24];
        // Byte 0: version 4, IHL 5 (20-byte header).
        packet[0] = 0x45;
        // Byte 9: protocol (6 = TCP); bytes 16..20: destination address.
        packet[9] = 6; packet[16] = 100;
        packet[17] = 64;
        packet[18] = 0;
        packet[19] = 3;
        let info = firewall::parse_packet_info(&packet).unwrap();
        assert_eq!(info.dst_ip, Ipv4Addr::new(100, 64, 0, 3));
        assert_eq!(info.protocol, 6);
    }
    /// Ten bytes are not enough for the parser to accept the packet.
    #[test]
    fn test_parse_packet_too_short() {
        assert!(firewall::parse_packet_info(&[0x45; 10]).is_none());
    }
    /// A 40-byte buffer with version nibble 6 is parsed as IPv6.
    #[test]
    fn test_parse_packet_ipv6() {
        let mut packet = vec![0u8; 40];
        // Byte 0: version 6; byte 6: next header (6 = TCP); bytes 24.. start
        // the destination address.
        packet[0] = 0x60; packet[6] = 6; packet[24] = 0x02;
        packet[25] = 0x01;
        let info = firewall::parse_packet_info(&packet).unwrap();
        assert!(info.dst_ip.is_ipv6());
    }
    // Peer addresses used by the inbound tests. `TEST_V4` matches the source
    // address written by `make_tcp_packet`.
    const TEST_V4: Ipv4Addr = Ipv4Addr::new(100, 64, 0, 5);
    const TEST_V6: std::net::Ipv6Addr = std::net::Ipv6Addr::UNSPECIFIED;
    /// Builds a 24-byte IPv4/TCP packet from 100.64.0.5 to 100.64.0.3 with
    /// source port 80 and the given destination port.
    fn make_tcp_packet(dst_port: u16) -> Vec<u8> {
        let mut p = vec![0u8; 24];
        // Bytes 12..16: source address; 16..20: destination address;
        // 20..22: TCP source port; 22..24: TCP destination port.
        p[0] = 0x45; p[9] = 6; p[12..16].copy_from_slice(&[100, 64, 0, 5]); p[16..20].copy_from_slice(&[100, 64, 0, 3]); p[20] = 0;
        p[21] = 80; p[22] = (dst_port >> 8) as u8;
        p[23] = dst_port as u8;
        p
    }
    /// Firewall with the given inbound default and rules; outbound is allowed
    /// and reject replies are disabled.
    fn inbound_fw(default: Action, rules: Vec<firewall::FirewallRule>) -> SharedFirewall {
        SharedFirewall::new(firewall::FirewallConfig {
            default_inbound: default,
            default_outbound: Action::Allow,
            reject: false,
            rules,
        })
    }
    /// One byte over `MAX_PEER_DATAGRAM` is dropped as malformed.
    #[test]
    fn inbound_oversized_datagram_dropped_as_malformed() {
        let fw = SharedFirewall::new(firewall::FirewallConfig::default());
        let peer = iroh::SecretKey::generate().public();
        let huge = vec![0u8; MAX_PEER_DATAGRAM + 1];
        assert!(matches!(
            evaluate_inbound(&huge, &fw, &peer, TEST_V4, TEST_V6, "test-net"),
            InboundDecision::DropMalformed
        ));
    }
    /// IPv6 packets reach the firewall: the all-zero source equals `TEST_V6`,
    /// so the spoof check passes and the deny-by-default firewall drops it.
    #[test]
    fn inbound_ipv6_evaluated_by_firewall() {
        let fw = inbound_fw(Action::Deny, vec![]);
        let peer = iroh::SecretKey::generate().public();
        let mut pkt = vec![0u8; 40];
        pkt[0] = 0x60; pkt[6] = 6; assert!(matches!(
            evaluate_inbound(&pkt, &fw, &peer, TEST_V4, TEST_V6, "test-net"),
            InboundDecision::DropFirewall(_)
        ));
    }
    /// An explicit deny rule for TCP/22 blocks that port while the allow
    /// default lets TCP/80 through.
    #[test]
    fn inbound_firewall_denied_port() {
        let peer = iroh::SecretKey::generate().public();
        let fw = inbound_fw(
            Action::Allow,
            vec![firewall::FirewallRule {
                direction: Direction::In,
                action: Action::Deny,
                protocol: firewall::Protocol::Tcp,
                port: Some(firewall::PortRange { start: 22, end: 22 }),
                peer: firewall::PeerFilter::Any,
                network: None,
                origin: firewall::RuleOrigin::Local,
            }],
        );
        let blocked = make_tcp_packet(22);
        let allowed = make_tcp_packet(80);
        assert!(matches!(
            evaluate_inbound(&blocked, &fw, &peer, TEST_V4, TEST_V6, "test-net"),
            InboundDecision::DropFirewall(_)
        ));
        assert!(matches!(
            evaluate_inbound(&allowed, &fw, &peer, TEST_V4, TEST_V6, "test-net"),
            InboundDecision::Accept
        ));
    }
    /// The default firewall configuration denies unsolicited inbound TCP.
    #[test]
    fn inbound_clean_tcp_denied_by_secure_default() {
        let peer = iroh::SecretKey::generate().public();
        let fw = SharedFirewall::new(firewall::FirewallConfig::default());
        let pkt = make_tcp_packet(443);
        assert!(matches!(
            evaluate_inbound(&pkt, &fw, &peer, TEST_V4, TEST_V6, "test-net"),
            InboundDecision::DropFirewall(_)
        ));
    }
    /// The default firewall configuration accepts this ICMP packet
    /// (protocol 1, all other ICMP bytes zero).
    #[test]
    fn inbound_icmp_accepted_by_default() {
        let peer = iroh::SecretKey::generate().public();
        let fw = SharedFirewall::new(firewall::FirewallConfig::default());
        let mut pkt = vec![0u8; 28];
        pkt[0] = 0x45; pkt[9] = 1; pkt[12..16].copy_from_slice(&[100, 64, 0, 5]); pkt[16..20].copy_from_slice(&[100, 64, 0, 3]); assert!(matches!(
            evaluate_inbound(&pkt, &fw, &peer, TEST_V4, TEST_V6, "test-net"),
            InboundDecision::Accept
        ));
    }
    /// Reference TCP checksum for an IPv4 packet with a 20-byte IP header:
    /// pseudo-header (addresses, protocol 6, TCP length) plus every 16-bit
    /// word of the TCP segment except the checksum field itself.
    /// A trailing odd byte is not summed, so callers must pass an even-length
    /// segment.
    fn tcp_csum_v4(pkt: &[u8]) -> u16 {
        let tcp = &pkt[20..];
        let mut sum = 0u32;
        // Source and destination address words of the IP header.
        for off in [12, 14, 16, 18] {
            sum += u16::from_be_bytes([pkt[off], pkt[off + 1]]) as u32;
        }
        sum += 6; sum += tcp.len() as u32;
        let mut i = 0;
        while i + 1 < tcp.len() {
            // Offset 16 within the TCP header is the checksum field.
            if i != 16 {
                sum += u16::from_be_bytes([tcp[i], tcp[i + 1]]) as u32;
            }
            i += 2;
        }
        // Fold carries back into the low 16 bits.
        while (sum >> 16) != 0 {
            sum = (sum & 0xffff) + (sum >> 16);
        }
        !(sum as u16)
    }
    /// Inbound rewrite changes destination port 22 to the listen port and the
    /// incrementally patched checksum equals a full recomputation; once NAT is
    /// switched off the rewrite is a no-op.
    ///
    /// Note: this test initialises and toggles the process-wide NAT settings.
    #[test]
    fn ssh_nat_rewrites_port_and_keeps_checksum_valid() {
        let v4 = Ipv4Addr::new(100, 88, 0, 1);
        init_ssh_nat(v4, Ipv6Addr::LOCALHOST, 41384);
        set_ssh_nat_active(true);
        // 20-byte IPv4 header followed by a 20-byte TCP header.
        let mut pkt = vec![0u8; 40];
        pkt[0] = 0x45;
        // Source 100.88.0.9:5000 -> our address:22; byte 32 is the TCP data
        // offset (0x50 = 5 words); the checksum lives at bytes 36..38.
        pkt[9] = 6; pkt[12..16].copy_from_slice(&[100, 88, 0, 9]); pkt[16..20].copy_from_slice(&v4.octets()); pkt[20..22].copy_from_slice(&5000u16.to_be_bytes()); pkt[22..24].copy_from_slice(&22u16.to_be_bytes()); pkt[32] = 0x50; let ck = tcp_csum_v4(&pkt);
        pkt[36..38].copy_from_slice(&ck.to_be_bytes());
        let info = firewall::parse_packet_info(&pkt).unwrap();
        assert!(rewrite_ssh_port(&mut pkt, &info, true));
        let info2 = firewall::parse_packet_info(&pkt).unwrap();
        assert_eq!(info2.dst_port, 41384, "dest port rewritten 22 -> listen");
        let field = u16::from_be_bytes([pkt[36], pkt[37]]);
        assert_eq!(field, tcp_csum_v4(&pkt), "checksum stays valid after rewrite");
        set_ssh_nat_active(false);
        let mut pkt2 = pkt.clone();
        let info3 = firewall::parse_packet_info(&pkt2).unwrap();
        assert!(!rewrite_ssh_port(&mut pkt2, &info3, true));
    }
    /// Replacing a word and then replacing it back restores the checksum.
    #[test]
    fn csum_replace2_round_trips() {
        let c = 0x1234u16;
        assert_eq!(csum_replace2(csum_replace2(c, 22, 41384), 41384, 22), c);
    }
    /// A packet whose source address differs from the peer's assigned address
    /// is dropped as spoofed; the same packet from the right peer is accepted.
    #[test]
    fn inbound_spoofed_source_ip_dropped() {
        let peer = iroh::SecretKey::generate().public();
        let fw = inbound_fw(Action::Allow, vec![]);
        let pkt = make_tcp_packet(80); assert!(matches!(
            evaluate_inbound(
                &pkt,
                &fw,
                &peer,
                Ipv4Addr::new(100, 64, 0, 9),
                TEST_V6,
                "test-net"
            ),
            InboundDecision::DropSpoof
        ));
        assert!(matches!(
            evaluate_inbound(&pkt, &fw, &peer, TEST_V4, TEST_V6, "test-net"),
            InboundDecision::Accept
        ));
    }
    /// `is_magic_dns` requires both the magic address and port 53.
    #[test]
    fn magic_dns_predicate_matches_only_magic_ip_port_53() {
        let mk = |ip: std::net::IpAddr, port: u16| firewall::PacketInfo {
            src_ip: "100.64.0.5".parse().unwrap(),
            dst_ip: ip,
            protocol: 17,
            src_port: 50000,
            dst_port: port,
            tcp_flags: 0,
            icmp_type: 0,
            icmp_id: 0,
        };
        assert!(is_magic_dns(&mk(
            std::net::IpAddr::V4(crate::dns::MAGIC_DNS_V4),
            53
        )));
        assert!(!is_magic_dns(&mk(
            std::net::IpAddr::V4(crate::dns::MAGIC_DNS_V4),
            80
        )));
        assert!(!is_magic_dns(&mk("100.64.0.9".parse().unwrap(), 53)));
    }
    /// With a deny default, only the explicitly opened TCP port is accepted.
    #[test]
    fn inbound_tcp_accepted_when_port_explicitly_opened() {
        let peer = iroh::SecretKey::generate().public();
        let fw = inbound_fw(
            Action::Deny,
            vec![firewall::FirewallRule {
                direction: Direction::In,
                action: Action::Allow,
                protocol: firewall::Protocol::Tcp,
                port: Some(firewall::PortRange {
                    start: 8080,
                    end: 8080,
                }),
                peer: firewall::PeerFilter::Any,
                network: None,
                origin: firewall::RuleOrigin::Local,
            }],
        );
        assert!(matches!(
            evaluate_inbound(
                &make_tcp_packet(8080),
                &fw,
                &peer,
                TEST_V4,
                TEST_V6,
                "test-net"
            ),
            InboundDecision::Accept
        ));
        assert!(matches!(
            evaluate_inbound(
                &make_tcp_packet(9090),
                &fw,
                &peer,
                TEST_V4,
                TEST_V6,
                "test-net"
            ),
            InboundDecision::DropFirewall(_)
        ));
    }
}