use std::net::{IpAddr, Ipv4Addr};
use std::sync::Arc;
use anyhow::Result;
use bytes::{Bytes, BytesMut};
use iroh::EndpointId;
use iroh::endpoint::{Connection, ConnectionError, VarInt};
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use crate::firewall::{self, Direction, SharedFirewall};
use crate::peers::{DeviceUserMap, PeerTable};
use crate::stats::{DropReason, ForwardMetrics};
use crate::tun::{TunReader, TunWriter};
/// Largest packet, in bytes, that `evaluate_inbound` accepts from a peer; anything
/// longer is classified as malformed. `run_mesh` also uses it as the capacity
/// threshold below which it grows its TUN read pool.
const MAX_PEER_DATAGRAM: usize = 1500;
/// Initial capacity, in bytes, of the pooled buffer `run_mesh` reads TUN packets
/// into, and the amount it reserves each time the pool runs low.
const TX_POOL_CHUNK: usize = 64 * 1024;
/// Outcome of screening a packet received from a peer, as produced by
/// `evaluate_inbound`.
///
/// The derives make the decision cheap to pass around, directly comparable in
/// assertions, and printable in logs and test failure messages.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum InboundDecision {
    /// The packet parsed successfully and the firewall did not deny it.
    Accept,
    /// The packet parsed successfully but the inbound firewall rules denied it.
    DropFirewall,
    /// The packet exceeded the size limit or could not be parsed.
    DropMalformed,
}
/// Decides what to do with a packet that arrived from a peer.
///
/// The checks run in this order:
/// 1. packets longer than `MAX_PEER_DATAGRAM` are `DropMalformed`;
/// 2. packets that `firewall::parse_packet_info` cannot parse are `DropMalformed`;
/// 3. packets the firewall denies for `Direction::In` are `DropFirewall`;
/// 4. everything else is `Accept`.
///
/// `peer_id` and `network` are passed through to the firewall evaluation.
pub(crate) fn evaluate_inbound(
    packet: &[u8],
    firewall: &SharedFirewall,
    peer_id: &EndpointId,
    network: &str,
) -> InboundDecision {
    // Size is checked first so oversized input is never handed to the parser.
    if packet.len() > MAX_PEER_DATAGRAM {
        return InboundDecision::DropMalformed;
    }

    match firewall::parse_packet_info(packet) {
        None => InboundDecision::DropMalformed,
        Some(info) => {
            let verdict = firewall.evaluate_packet(Direction::In, &info, peer_id, Some(network));
            if verdict.is_deny() {
                InboundDecision::DropFirewall
            } else {
                InboundDecision::Accept
            }
        }
    }
}
/// Application close code that marks a peer disconnect as intentional: when a
/// connection ends with `ConnectionError::ApplicationClosed` carrying this code,
/// `spawn_peer_reader` reports the `DisconnectEvent` with `intentional` set to `true`.
pub const LEAVE_CODE: u32 = 0x1ea5e;
/// Notification sent by `spawn_peer_reader` on its disconnect channel when reading
/// from a peer connection fails.
pub struct DisconnectEvent {
/// Endpoint id of the peer whose connection was lost.
pub endpoint_id: EndpointId,
/// IPv4 address the reader task was given for that peer.
pub ip: Ipv4Addr,
/// IPv6 address the reader task was given for that peer.
pub ipv6: std::net::Ipv6Addr,
/// Name of the network the connection belonged to.
pub network: String,
/// `true` when the connection was closed by the application with `LEAVE_CODE`;
/// `false` for every other connection error.
pub intentional: bool,
}
/// Returns `true` when the packet is addressed to port 53 on the magic DNS IPv4
/// address (`crate::dns::MAGIC_DNS_V4`), and `false` otherwise.
pub(crate) fn is_magic_dns(info: &firewall::PacketInfo) -> bool {
    if info.dst_port != 53 {
        return false;
    }
    info.dst_ip == IpAddr::V4(crate::dns::MAGIC_DNS_V4)
}
/// Forwards packets read from the local TUN device to mesh peers until cancelled.
///
/// Each loop iteration reads one packet into a pooled buffer and then:
/// * drops it (debug log only) if it does not parse as an IP packet;
/// * hands it to `resolver` on a spawned task if it targets the magic DNS address
///   (`tun_tx` is passed along to the resolver — presumably so it can write replies
///   back to the TUN device; confirm against `Resolver::handle_tun_query`);
/// * otherwise looks up a route for the destination address in `peers`, applies the
///   outbound firewall rules, and sends the packet to the route's connection as a
///   datagram.
///
/// Missing routes, firewall denials and send failures are recorded in `stats` as
/// drops; successful sends are recorded with their byte count.
///
/// # Errors
///
/// Returns the error from `tun.read_into` if a TUN read fails. Returns `Ok(())`
/// once `token` is cancelled.
#[allow(clippy::too_many_arguments)]
pub async fn run_mesh(
    mut tun: TunReader,
    peers: PeerTable,
    firewall: SharedFirewall,
    token: CancellationToken,
    stats: Arc<ForwardMetrics>,
    resolver: Arc<crate::dns_resolver::Resolver>,
    tun_tx: mpsc::Sender<Bytes>,
) -> Result<()> {
    let mut pool = BytesMut::with_capacity(TX_POOL_CHUNK);
    loop {
        // Grow the pool before it becomes too small to hold another full-size packet.
        if pool.capacity() < MAX_PEER_DATAGRAM {
            pool.reserve(TX_POOL_CHUNK);
        }
        let n = tokio::select! {
            _ = token.cancelled() => return Ok(()),
            result = tun.read_into(&mut pool) => result?,
        };
        if n == 0 {
            continue;
        }
        // Detach the first `n` bytes of the pool as an immutable, cheaply shareable
        // buffer (assumes `read_into` left exactly the packet it read there).
        let pkt = pool.split_to(n).freeze();
        tracing::debug!(len = n, first_byte = pkt[0], "TUN read");
        let Some(info) = firewall::parse_packet_info(&pkt) else {
            tracing::debug!(len = n, "not IP, dropping");
            continue;
        };
        if is_magic_dns(&info) {
            let resolver = resolver.clone();
            let tun_tx = tun_tx.clone();
            // `pkt` is owned by this iteration and is not used again on this path,
            // so it moves into the task directly; no extra clone is needed.
            tokio::spawn(async move {
                resolver.handle_tun_query(&pkt, &info, &tun_tx).await;
            });
            continue;
        }
        let lookup = match info.dst_ip {
            IpAddr::V4(v4) => peers.lookup_v4(&v4),
            IpAddr::V6(v6) => peers.lookup_v6(&v6),
        };
        let Some(route) = lookup else {
            tracing::debug!(dst = %info.dst_ip, "no peer for dst");
            stats.record_drop(DropReason::NoPeer);
            continue;
        };
        if firewall
            .evaluate_packet(
                Direction::Out,
                &info,
                &route.endpoint_id,
                Some(&route.network),
            )
            .is_deny()
        {
            tracing::debug!(dst = %info.dst_ip, port = info.dst_port, "firewall denied outbound");
            stats.record_drop(DropReason::Firewall);
            continue;
        }
        tracing::debug!(dst = %info.dst_ip, "routing to peer");
        match route.conn.send_datagram(pkt) {
            Ok(()) => stats.record_tx(n),
            Err(e) => {
                tracing::debug!(dst = %info.dst_ip, error = %e, "datagram send failed");
                stats.record_drop(DropReason::SendFailure);
            }
        }
    }
}
#[allow(clippy::too_many_arguments)]
pub fn spawn_peer_reader(
conn: Connection,
peer_id: EndpointId,
peer_ip: Ipv4Addr,
peer_ipv6: std::net::Ipv6Addr,
network: String,
firewall: SharedFirewall,
tun_tx: mpsc::Sender<Bytes>,
disconnect_tx: mpsc::Sender<DisconnectEvent>,
token: CancellationToken,
stats: Arc<ForwardMetrics>,
device_user_map: DeviceUserMap,
) -> tokio::task::JoinHandle<()> {
use tracing::Instrument as _;
let span = tracing::info_span!("peer", peer = %peer_id.fmt_short(), net = %network);
let reader = async move {
loop {
let datagram = tokio::select! {
_ = token.cancelled() => return,
result = conn.read_datagram() => match result {
Ok(d) => d,
Err(e) => {
let intentional = matches!(
&e,
ConnectionError::ApplicationClosed(ac)
if ac.error_code == VarInt::from_u32(LEAVE_CODE)
);
tracing::warn!(peer = %peer_id.fmt_short(), ip = %peer_ip, error = %e, intentional, "peer connection lost");
let _ = disconnect_tx
.send(DisconnectEvent {
endpoint_id: peer_id,
ip: peer_ip,
ipv6: peer_ipv6,
network: network.clone(),
intentional,
})
.await;
return;
}
},
};
let peer_user = device_user_map.resolve(&peer_id);
match evaluate_inbound(&datagram, &firewall, &peer_user, &network) {
InboundDecision::Accept => {
stats.record_rx(datagram.len());
if tun_tx.send(datagram).await.is_err() {
return;
}
}
InboundDecision::DropFirewall => stats.record_drop(DropReason::Firewall),
InboundDecision::DropMalformed => stats.record_drop(DropReason::Malformed),
}
}
};
tokio::spawn(reader.instrument(span))
}
pub fn spawn_tun_writer(
mut tun: TunWriter,
mut tun_rx: mpsc::Receiver<Bytes>,
) -> tokio::task::JoinHandle<()> {
tokio::spawn(async move {
while let Some(packet) = tun_rx.recv().await {
if let Err(e) = tun.write_packet(&packet).await {
tracing::warn!(error = %e, "TUN write failed");
}
}
})
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::firewall::Action;

    /// A minimal IPv4/TCP header yields the destination address and protocol.
    #[test]
    fn test_parse_packet_valid_ipv4() {
        let mut raw = vec![0u8; 24];
        raw[0] = 0x45; // version 4, header length 5 words
        raw[9] = 6; // protocol: TCP
        raw[16..20].copy_from_slice(&[100, 64, 0, 3]); // destination address
        let parsed = firewall::parse_packet_info(&raw).unwrap();
        assert_eq!(parsed.dst_ip, Ipv4Addr::new(100, 64, 0, 3));
        assert_eq!(parsed.protocol, 6);
    }

    /// Ten bytes cannot hold an IPv4 header, so parsing must fail.
    #[test]
    fn test_parse_packet_too_short() {
        let truncated = [0x45u8; 10];
        assert!(firewall::parse_packet_info(&truncated).is_none());
    }

    /// A bare 40-byte IPv6 header is parsed and reports an IPv6 destination.
    #[test]
    fn test_parse_packet_ipv6() {
        let mut raw = vec![0u8; 40];
        raw[0] = 0x60; // version 6
        raw[6] = 6; // next header: TCP
        raw[24..26].copy_from_slice(&[0x02, 0x01]); // first bytes of the destination
        let parsed = firewall::parse_packet_info(&raw).unwrap();
        assert!(parsed.dst_ip.is_ipv6());
    }

    /// Builds an IPv4/TCP packet to 100.64.0.3 with source port 80 and the given
    /// destination port (only the first four bytes of the TCP header are present).
    fn make_tcp_packet(dst_port: u16) -> Vec<u8> {
        let mut bytes = vec![0u8; 24];
        bytes[0] = 0x45;
        bytes[9] = 6;
        bytes[16..20].copy_from_slice(&[100, 64, 0, 3]);
        bytes[20..22].copy_from_slice(&80u16.to_be_bytes());
        bytes[22..24].copy_from_slice(&dst_port.to_be_bytes());
        bytes
    }

    /// Builds a firewall with the given inbound default, outbound allowed, and the
    /// supplied rules.
    fn inbound_fw(default: Action, rules: Vec<firewall::FirewallRule>) -> SharedFirewall {
        let config = firewall::FirewallConfig {
            default_inbound: default,
            default_outbound: Action::Allow,
            rules,
        };
        SharedFirewall::new(config)
    }

    /// One byte over the size limit is rejected as malformed.
    #[test]
    fn inbound_oversized_datagram_dropped_as_malformed() {
        let fw = SharedFirewall::new(firewall::FirewallConfig::default());
        let peer = iroh::SecretKey::generate().public();
        let oversized = vec![0u8; MAX_PEER_DATAGRAM + 1];
        let decision = evaluate_inbound(&oversized, &fw, &peer, "test-net");
        assert!(matches!(decision, InboundDecision::DropMalformed));
    }

    /// IPv6 packets reach the firewall: with a deny default they are dropped by it.
    #[test]
    fn inbound_ipv6_evaluated_by_firewall() {
        let fw = inbound_fw(Action::Deny, vec![]);
        let peer = iroh::SecretKey::generate().public();
        let mut raw = vec![0u8; 40];
        raw[0] = 0x60;
        raw[6] = 6;
        let decision = evaluate_inbound(&raw, &fw, &peer, "test-net");
        assert!(matches!(decision, InboundDecision::DropFirewall));
    }

    /// An explicit deny rule blocks its port while other ports follow the allow default.
    #[test]
    fn inbound_firewall_denied_port() {
        let peer = iroh::SecretKey::generate().public();
        let deny_ssh = firewall::FirewallRule {
            direction: Direction::In,
            action: Action::Deny,
            protocol: firewall::Protocol::Tcp,
            port: Some(firewall::PortRange { start: 22, end: 22 }),
            peer: firewall::PeerFilter::Any,
            network: None,
            origin: firewall::RuleOrigin::Local,
        };
        let fw = inbound_fw(Action::Allow, vec![deny_ssh]);

        let to_blocked_port = make_tcp_packet(22);
        let to_open_port = make_tcp_packet(80);

        let blocked = evaluate_inbound(&to_blocked_port, &fw, &peer, "test-net");
        assert!(matches!(blocked, InboundDecision::DropFirewall));

        let allowed = evaluate_inbound(&to_open_port, &fw, &peer, "test-net");
        assert!(matches!(allowed, InboundDecision::Accept));
    }

    /// The default firewall configuration denies unsolicited inbound TCP.
    #[test]
    fn inbound_clean_tcp_denied_by_secure_default() {
        let peer = iroh::SecretKey::generate().public();
        let fw = SharedFirewall::new(firewall::FirewallConfig::default());
        let raw = make_tcp_packet(443);
        let decision = evaluate_inbound(&raw, &fw, &peer, "test-net");
        assert!(matches!(decision, InboundDecision::DropFirewall));
    }

    /// The default firewall configuration lets ICMP through.
    #[test]
    fn inbound_icmp_accepted_by_default() {
        let peer = iroh::SecretKey::generate().public();
        let fw = SharedFirewall::new(firewall::FirewallConfig::default());
        let mut raw = vec![0u8; 28];
        raw[0] = 0x45;
        raw[9] = 1; // protocol: ICMP
        raw[16..20].copy_from_slice(&[100, 64, 0, 3]);
        let decision = evaluate_inbound(&raw, &fw, &peer, "test-net");
        assert!(matches!(decision, InboundDecision::Accept));
    }

    /// Only the magic DNS address combined with port 53 satisfies the predicate.
    #[test]
    fn magic_dns_predicate_matches_only_magic_ip_port_53() {
        let info_for = |ip: IpAddr, port: u16| firewall::PacketInfo {
            src_ip: "100.64.0.5".parse().unwrap(),
            dst_ip: ip,
            protocol: 17,
            src_port: 50000,
            dst_port: port,
            tcp_flags: 0,
            icmp_type: 0,
            icmp_id: 0,
        };
        let magic = IpAddr::V4(crate::dns::MAGIC_DNS_V4);

        assert!(is_magic_dns(&info_for(magic, 53)));
        assert!(!is_magic_dns(&info_for(magic, 80)));
        assert!(!is_magic_dns(&info_for("100.64.0.9".parse().unwrap(), 53)));
    }

    /// With a deny default, only the explicitly opened TCP port is accepted.
    #[test]
    fn inbound_tcp_accepted_when_port_explicitly_opened() {
        let peer = iroh::SecretKey::generate().public();
        let allow_8080 = firewall::FirewallRule {
            direction: Direction::In,
            action: Action::Allow,
            protocol: firewall::Protocol::Tcp,
            port: Some(firewall::PortRange {
                start: 8080,
                end: 8080,
            }),
            peer: firewall::PeerFilter::Any,
            network: None,
            origin: firewall::RuleOrigin::Local,
        };
        let fw = inbound_fw(Action::Deny, vec![allow_8080]);

        let opened = evaluate_inbound(&make_tcp_packet(8080), &fw, &peer, "test-net");
        assert!(matches!(opened, InboundDecision::Accept));

        let closed = evaluate_inbound(&make_tcp_packet(9090), &fw, &peer, "test-net");
        assert!(matches!(closed, InboundDecision::DropFirewall));
    }
}