use std::collections::{HashMap, HashSet};
use std::io;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::os::unix::io::{AsRawFd, RawFd};
use std::sync::Arc;
use serde::{Deserialize, Serialize};
use crate::error::SandboxError;
use crate::seccomp::ctx::SupervisorCtx;
use crate::seccomp::notif::{read_child_mem, write_child_mem, NotifAction};
use crate::sys::structs::{SeccompNotif, AF_INET, AF_INET6, ECONNREFUSED};
/// Upper bound (64 MiB) on payload bytes the supervisor copies out of the child in order to
/// send them on its behalf; requests exceeding it are rejected with `EMSGSIZE`.
const MAX_SEND_BUF: usize = 64 << 20;
/// Transport protocol a `--net-allow` rule applies to.
///
/// Serialised in lowercase (`"tcp"`, `"udp"`, `"icmp"`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Protocol {
/// Sockets whose `SO_PROTOCOL` is `IPPROTO_TCP`.
Tcp,
/// Sockets whose `SO_PROTOCOL` is `IPPROTO_UDP`.
Udp,
/// Sockets whose `SO_PROTOCOL` is `IPPROTO_ICMP` or `IPPROTO_ICMPV6`; rules carry no ports.
Icmp,
}
impl Protocol {
    /// Maps a `--net-allow` scheme name to its [`Protocol`].
    ///
    /// Only the exact lowercase names `tcp`, `udp` and `icmp` are recognised; anything else
    /// (including other casings) yields `None`.
    fn parse(s: &str) -> Option<Self> {
        const SCHEMES: [(&str, Protocol); 3] = [
            ("tcp", Protocol::Tcp),
            ("udp", Protocol::Udp),
            ("icmp", Protocol::Icmp),
        ];
        SCHEMES
            .iter()
            .find(|(name, _)| *name == s)
            .map(|&(_, proto)| proto)
    }
}
/// One parsed `--net-allow` rule: a protocol, an optional host and the ports it may reach.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct NetAllow {
/// Protocol the rule applies to; TCP when absent from serialised input.
#[serde(default = "default_protocol_tcp")]
pub protocol: Protocol,
/// Host name or address literal as written; `None` means any destination.
pub host: Option<String>,
/// Explicit destination ports; empty for `*`-port rules and for ICMP rules.
pub ports: Vec<u16>,
/// `true` when the rule used the `*` port wildcard.
#[serde(default)]
pub all_ports: bool,
}
/// Serde default for `NetAllow::protocol`: rules without an explicit protocol are TCP.
fn default_protocol_tcp() -> Protocol {
Protocol::Tcp
}
impl NetAllow {
pub fn parse(s: &str) -> Result<Self, SandboxError> {
let (protocol, rest) = match s.split_once("://") {
Some((scheme, body)) => {
let proto = Protocol::parse(scheme).ok_or_else(|| {
SandboxError::Invalid(format!(
"--net-allow: unknown scheme `{}://` in `{}` (expected tcp, udp, icmp)",
scheme, s
))
})?;
(proto, body)
}
None => (Protocol::Tcp, s),
};
if protocol == Protocol::Icmp {
return Self::parse_icmp(rest, s);
}
let (host_part, port_part) = rest.rsplit_once(':').ok_or_else(|| {
SandboxError::Invalid(format!(
"--net-allow: expected `host:port` or `:port`, got `{}`",
s
))
})?;
let host = match host_part {
"" | "*" => None,
h => Some(h.to_string()),
};
let mut ports = Vec::new();
let mut saw_wildcard = false;
for p in port_part.split(',') {
let p = p.trim();
if p == "*" {
saw_wildcard = true;
continue;
}
let n: u16 = p.parse().map_err(|_| {
SandboxError::Invalid(format!("--net-allow: invalid port `{}` in `{}`", p, s))
})?;
if n == 0 {
return Err(SandboxError::Invalid(format!(
"--net-allow: port 0 is not valid in `{}`",
s
)));
}
ports.push(n);
}
if saw_wildcard && !ports.is_empty() {
return Err(SandboxError::Invalid(format!(
"--net-allow: cannot mix `*` with concrete ports in `{}`",
s
)));
}
if !saw_wildcard && ports.is_empty() {
return Err(SandboxError::Invalid(format!(
"--net-allow: at least one port required in `{}`",
s
)));
}
Ok(NetAllow {
protocol,
host,
ports,
all_ports: saw_wildcard,
})
}
fn parse_icmp(body: &str, full: &str) -> Result<Self, SandboxError> {
if body.contains(':') {
return Err(SandboxError::Invalid(format!(
"--net-allow: icmp rules take no port, got `{}`",
full
)));
}
if body.is_empty() {
return Err(SandboxError::Invalid(format!(
"--net-allow: icmp rule needs a host or `*`, got `{}`",
full
)));
}
let host = match body {
"*" => None,
h => Some(h.to_string()),
};
Ok(NetAllow {
protocol: Protocol::Icmp,
host,
ports: Vec::new(),
all_ports: false,
})
}
}
/// Extracts the destination IP address from a raw `sockaddr` byte buffer.
///
/// The address family is read from the first two bytes in native byte order. For `AF_INET`
/// the address is at bytes 4..8 (`sockaddr_in`), and for `AF_INET6` it is at bytes 8..24
/// (`sockaddr_in6`). Any other family, or a buffer too short for its family, yields `None`.
fn parse_ip_from_sockaddr(bytes: &[u8]) -> Option<IpAddr> {
    let family = match bytes {
        [lo, hi, ..] => u16::from_ne_bytes([*lo, *hi]) as u32,
        _ => return None,
    };
    if family == AF_INET {
        let quad = bytes.get(4..8)?;
        Some(IpAddr::V4(Ipv4Addr::new(quad[0], quad[1], quad[2], quad[3])))
    } else if family == AF_INET6 {
        let raw = bytes.get(8..24)?;
        let mut octets = [0u8; 16];
        octets.copy_from_slice(raw);
        Some(IpAddr::V6(Ipv6Addr::from(octets)))
    } else {
        None
    }
}
/// Reads the big-endian port (bytes 2..4) from a raw `sockaddr_in`/`sockaddr_in6` buffer.
///
/// Returns `None` when the buffer is shorter than four bytes, or when its native-endian
/// family field is neither `AF_INET` nor `AF_INET6`.
fn parse_port_from_sockaddr(bytes: &[u8]) -> Option<u16> {
    let header = bytes.get(..4)?;
    let family = u16::from_ne_bytes([header[0], header[1]]) as u32;
    let is_inet = family == AF_INET || family == AF_INET6;
    is_inet.then(|| u16::from_be_bytes([header[2], header[3]]))
}
/// Overwrites the port field (bytes 2..4, network byte order) of a raw `sockaddr` buffer.
///
/// The address family is not checked. Buffers shorter than four bytes are left untouched.
fn set_port_in_sockaddr(bytes: &mut [u8], port: u16) {
    if let Some(port_field) = bytes.get_mut(2..4) {
        port_field.copy_from_slice(&port.to_be_bytes());
    }
}
/// Returns the transport protocol of socket `fd`, as reported by `getsockopt(SO_PROTOCOL)`.
///
/// `IPPROTO_ICMP` and `IPPROTO_ICMPV6` both map to [`Protocol::Icmp`]. Returns `None` when the
/// query fails or the protocol is not TCP, UDP or ICMP.
fn query_socket_protocol(fd: RawFd) -> Option<Protocol> {
    let mut raw_proto: libc::c_int = 0;
    let mut opt_len = std::mem::size_of::<libc::c_int>() as libc::socklen_t;
    // SAFETY: `raw_proto` is a live, writable c_int and `opt_len` holds exactly its size.
    let status = unsafe {
        libc::getsockopt(
            fd,
            libc::SOL_SOCKET,
            libc::SO_PROTOCOL,
            (&mut raw_proto as *mut libc::c_int).cast::<libc::c_void>(),
            &mut opt_len,
        )
    };
    if status != 0 {
        return None;
    }
    let protocol = match raw_proto {
        libc::IPPROTO_TCP => Protocol::Tcp,
        libc::IPPROTO_UDP => Protocol::Udp,
        libc::IPPROTO_ICMP | libc::IPPROTO_ICMPV6 => Protocol::Icmp,
        _ => return None,
    };
    Some(protocol)
}
/// Performs `connect(2)` on behalf of the sandboxed child after enforcing network policy.
///
/// The destination `sockaddr` is copied out of the child exactly once. Every later step uses
/// that single snapshot:
/// - the policy check;
/// - the optional HTTP-ACL proxy redirect;
/// - the optional loopback port remap;
/// - the final `connect` on a duplicate of the child's socket.
///
/// Addresses that are not `AF_INET`/`AF_INET6` are returned as [`NotifAction::Continue`].
///
/// Returns `ReturnValue(0)` on success. Otherwise it returns an `Errno`:
/// - `EINVAL` for an address longer than `sockaddr_storage`;
/// - `EIO` when child memory cannot be read;
/// - `ECONNREFUSED` when policy denies the destination or the socket protocol is unsupported;
/// - `EAFNOSUPPORT` when an IPv4 socket would have to reach an IPv6 proxy;
/// - otherwise, the errno reported by the underlying syscalls.
async fn connect_on_behalf(
    notif: &SeccompNotif,
    ctx: &Arc<SupervisorCtx>,
    notif_fd: RawFd,
) -> NotifAction {
    let args = &notif.data.args;
    let sockfd = args[0] as i32;
    let addr_ptr = args[1];
    let addr_len = args[2] as u32;
    // connect(2) rejects lengths above sizeof(sockaddr_storage) with EINVAL. That includes
    // negative C ints, which appear here as huge u32 values. Checking before the copy also
    // stops a hostile length from making the supervisor read an arbitrarily large buffer.
    if addr_len as usize > std::mem::size_of::<libc::sockaddr_storage>() {
        return NotifAction::Errno(libc::EINVAL);
    }
    let addr_bytes =
        match read_child_mem(notif_fd, notif.id, notif.pid, addr_ptr, addr_len as usize) {
            Ok(b) => b,
            Err(_) => return NotifAction::Errno(libc::EIO),
        };
    let ip = match parse_ip_from_sockaddr(&addr_bytes) {
        Some(ip) => ip,
        // NOTE(review): Continue presumably lets the kernel run the original syscall, which
        // re-reads the address from child memory. A racing thread could therefore swap in an
        // IP address after this check. Confirm such sockets cannot be AF_INET/AF_INET6.
        None => return NotifAction::Continue,
    };
    let dest_port = parse_port_from_sockaddr(&addr_bytes);
    let dup_fd = match crate::seccomp::notif::dup_fd_from_pid(notif.pid, sockfd) {
        Ok(fd) => fd,
        Err(e) => return NotifAction::Errno(e.raw_os_error().unwrap_or(libc::EBADF)),
    };
    let protocol = match query_socket_protocol(dup_fd.as_raw_fd()) {
        Some(p) => p,
        None => return NotifAction::Errno(ECONNREFUSED),
    };

    let ns = ctx.network.lock().await;
    let live_policy = {
        let pfs = ctx.policy_fn.lock().await;
        pfs.live_policy.clone()
    };
    let effective = ns.effective_network_policy(notif.pid, protocol, live_policy.as_ref());
    match (effective, dest_port) {
        (crate::seccomp::notif::NetworkPolicy::Unrestricted, _) => {}
        (policy, Some(p)) => {
            if !policy.allows(ip, p) {
                return NotifAction::Errno(ECONNREFUSED);
            }
        }
        (_, None) => return NotifAction::Errno(ECONNREFUSED),
    }
    // Snapshot what is needed from the network state so the lock is released before the
    // socket syscalls below.
    let http_acl_addr = ns.http_acl_addr;
    let http_acl_intercept = dest_port.map_or(false, |p| ns.http_acl_ports.contains(&p));
    let http_acl_orig_dest = ns.http_acl_orig_dest.clone();
    let remapped_loopback_port = if ctx.policy.port_remap && ip.is_loopback() {
        dest_port.and_then(|p| ns.port_map.get_real(p))
    } else {
        None
    };
    drop(ns);

    let is_ipv6 = ip.is_ipv6();
    // Redirect to the HTTP-ACL proxy only when one is configured and this port is intercepted.
    let proxy = http_acl_addr.filter(|_| http_acl_intercept);
    let connect_addr = match proxy {
        Some(proxy_addr) => match proxy_sockaddr_bytes(proxy_addr, is_ipv6) {
            Some(encoded) => encoded,
            None => return NotifAction::Errno(libc::EAFNOSUPPORT),
        },
        None => {
            let mut direct = addr_bytes;
            if let Some(real_port) = remapped_loopback_port {
                set_port_in_sockaddr(&mut direct, real_port);
            }
            direct
        }
    };

    if proxy.is_some() {
        // Record the destination the child really asked for, keyed by the local address the
        // proxy will see as the peer. Presumably the proxy uses this to recover the
        // destination — TODO confirm against the proxy side.
        if let Some(ref orig_dest_map) = http_acl_orig_dest {
            if let Some(local_addr) = bind_ephemeral_local_addr(dup_fd.as_raw_fd(), is_ipv6) {
                if let Ok(mut map) = orig_dest_map.write() {
                    map.insert(local_addr, ip);
                }
            }
        }
    }

    // The length comes from the buffer itself, so a short copy from the child can never make
    // the kernel read past the end of `connect_addr`.
    // SAFETY: `connect_addr` is a live buffer of exactly the given length holding a sockaddr,
    // and `dup_fd` is an open descriptor owned by this function.
    let ret = unsafe {
        libc::connect(
            dup_fd.as_raw_fd(),
            connect_addr.as_ptr() as *const libc::sockaddr,
            connect_addr.len() as libc::socklen_t,
        )
    };
    if ret == 0 {
        NotifAction::ReturnValue(0)
    } else {
        NotifAction::Errno(
            std::io::Error::last_os_error()
                .raw_os_error()
                .unwrap_or(libc::EIO),
        )
    }
}

/// Encodes `proxy` as raw `sockaddr` bytes for `connect(2)` on a socket of the child's family.
///
/// For an IPv6 socket, an IPv4 proxy is expressed as an IPv4-mapped IPv6 address. Returns
/// `None` when an IPv4 socket would have to reach an IPv6 proxy.
fn proxy_sockaddr_bytes(proxy: std::net::SocketAddr, ipv6_socket: bool) -> Option<Vec<u8>> {
    if ipv6_socket {
        // SAFETY: `sockaddr_in6` is a plain C struct for which all-zero bytes are valid.
        let mut sa6: libc::sockaddr_in6 = unsafe { std::mem::zeroed() };
        sa6.sin6_family = libc::AF_INET6 as u16;
        sa6.sin6_port = proxy.port().to_be();
        sa6.sin6_addr.s6_addr = match proxy {
            std::net::SocketAddr::V4(v4) => v4.ip().to_ipv6_mapped().octets(),
            std::net::SocketAddr::V6(v6) => v6.ip().octets(),
        };
        // SAFETY: reads exactly the bytes of the fully initialised local `sa6`.
        let bytes = unsafe {
            std::slice::from_raw_parts(
                &sa6 as *const libc::sockaddr_in6 as *const u8,
                std::mem::size_of::<libc::sockaddr_in6>(),
            )
        };
        Some(bytes.to_vec())
    } else {
        let v4 = match proxy {
            std::net::SocketAddr::V4(v4) => v4,
            std::net::SocketAddr::V6(_) => return None,
        };
        // SAFETY: `sockaddr_in` is a plain C struct for which all-zero bytes are valid.
        let mut sa: libc::sockaddr_in = unsafe { std::mem::zeroed() };
        sa.sin_family = libc::AF_INET as u16;
        sa.sin_port = v4.port().to_be();
        // `octets()` is already in network order, so keep the bytes as they are.
        sa.sin_addr.s_addr = u32::from_ne_bytes(v4.ip().octets());
        // SAFETY: reads exactly the bytes of the fully initialised local `sa`.
        let bytes = unsafe {
            std::slice::from_raw_parts(
                &sa as *const libc::sockaddr_in as *const u8,
                std::mem::size_of::<libc::sockaddr_in>(),
            )
        };
        Some(bytes.to_vec())
    }
}

/// Binds `fd` to an ephemeral port on the wildcard address of the given family and returns
/// the local address the socket ends up with.
///
/// The `bind` result is deliberately ignored. On a socket the child already bound, `bind`
/// fails, but `getsockname` still reports the existing local address. Returns `None` only
/// when `getsockname` fails.
fn bind_ephemeral_local_addr(fd: RawFd, ipv6: bool) -> Option<std::net::SocketAddr> {
    if ipv6 {
        let sa6_len = std::mem::size_of::<libc::sockaddr_in6>() as libc::socklen_t;
        // SAFETY: all-zero is a valid `sockaddr_in6` (wildcard address, port 0).
        let mut bind_sa6: libc::sockaddr_in6 = unsafe { std::mem::zeroed() };
        bind_sa6.sin6_family = libc::AF_INET6 as u16;
        // SAFETY: the pointer/length pair describes the local `bind_sa6`.
        unsafe {
            libc::bind(fd, &bind_sa6 as *const _ as *const libc::sockaddr, sa6_len);
        }
        // SAFETY: all-zero is a valid `sockaddr_in6`; it is only used as an output buffer.
        let mut local_sa6: libc::sockaddr_in6 = unsafe { std::mem::zeroed() };
        let mut local_len = sa6_len;
        // SAFETY: `local_sa6` is writable for `local_len` bytes.
        let rc = unsafe {
            libc::getsockname(
                fd,
                &mut local_sa6 as *mut _ as *mut libc::sockaddr,
                &mut local_len,
            )
        };
        if rc != 0 {
            return None;
        }
        Some(std::net::SocketAddr::V6(std::net::SocketAddrV6::new(
            Ipv6Addr::from(local_sa6.sin6_addr.s6_addr),
            u16::from_be(local_sa6.sin6_port),
            0,
            0,
        )))
    } else {
        let sa_len = std::mem::size_of::<libc::sockaddr_in>() as libc::socklen_t;
        // SAFETY: all-zero is a valid `sockaddr_in` (INADDR_ANY, port 0).
        let mut bind_sa: libc::sockaddr_in = unsafe { std::mem::zeroed() };
        bind_sa.sin_family = libc::AF_INET as u16;
        // SAFETY: the pointer/length pair describes the local `bind_sa`.
        unsafe {
            libc::bind(fd, &bind_sa as *const _ as *const libc::sockaddr, sa_len);
        }
        // SAFETY: all-zero is a valid `sockaddr_in`; it is only used as an output buffer.
        let mut local_sa: libc::sockaddr_in = unsafe { std::mem::zeroed() };
        let mut local_len = sa_len;
        // SAFETY: `local_sa` is writable for `local_len` bytes.
        let rc = unsafe {
            libc::getsockname(
                fd,
                &mut local_sa as *mut _ as *mut libc::sockaddr,
                &mut local_len,
            )
        };
        if rc != 0 {
            return None;
        }
        Some(std::net::SocketAddr::V4(std::net::SocketAddrV4::new(
            Ipv4Addr::from(u32::from_be(local_sa.sin_addr.s_addr)),
            u16::from_be(local_sa.sin_port),
        )))
    }
}
/// Performs `sendto(2)` on behalf of the child after checking the destination against policy.
///
/// Calls without a destination pointer, and calls whose destination is not
/// `AF_INET`/`AF_INET6`, are returned as [`NotifAction::Continue`]. Otherwise the destination
/// and payload are copied out of the child and sent through a duplicate of its socket.
///
/// Returns the number of bytes sent, or one of these errnos:
/// - `EMSGSIZE` for payloads above [`MAX_SEND_BUF`];
/// - `EINVAL` for an address longer than `sockaddr_storage`;
/// - `EIO` when child memory cannot be read;
/// - `EFAULT` when the payload copy comes back short;
/// - `ECONNREFUSED` on a policy denial;
/// - otherwise, the errno from `sendto`.
async fn sendto_on_behalf(
    notif: &SeccompNotif,
    ctx: &Arc<SupervisorCtx>,
    notif_fd: RawFd,
) -> NotifAction {
    let args = &notif.data.args;
    let sockfd = args[0] as i32;
    let buf_ptr = args[1];
    let buf_len = args[2] as usize;
    if buf_len > MAX_SEND_BUF {
        return NotifAction::Errno(libc::EMSGSIZE);
    }
    let flags = args[3] as i32;
    let addr_ptr = args[4];
    let addr_len = args[5] as u32;
    if addr_ptr == 0 {
        // No explicit destination. The pointer is a syscall argument, not child memory, so it
        // cannot change after this check.
        return NotifAction::Continue;
    }
    // Mirror the kernel's EINVAL for oversized (or negative) address lengths. This also stops
    // the supervisor from copying an arbitrarily large buffer out of the child.
    if addr_len as usize > std::mem::size_of::<libc::sockaddr_storage>() {
        return NotifAction::Errno(libc::EINVAL);
    }
    let addr_bytes =
        match read_child_mem(notif_fd, notif.id, notif.pid, addr_ptr, addr_len as usize) {
            Ok(b) => b,
            Err(_) => return NotifAction::Errno(libc::EIO),
        };
    let ip = match parse_ip_from_sockaddr(&addr_bytes) {
        Some(ip) => ip,
        None => return NotifAction::Continue,
    };
    let dest_port = parse_port_from_sockaddr(&addr_bytes);
    let dup_fd = match crate::seccomp::notif::dup_fd_from_pid(notif.pid, sockfd) {
        Ok(fd) => fd,
        Err(e) => return NotifAction::Errno(e.raw_os_error().unwrap_or(libc::EBADF)),
    };
    let protocol = match query_socket_protocol(dup_fd.as_raw_fd()) {
        Some(p) => p,
        None => return NotifAction::Errno(ECONNREFUSED),
    };

    let ns = ctx.network.lock().await;
    let live_policy = {
        let pfs = ctx.policy_fn.lock().await;
        pfs.live_policy.clone()
    };
    let effective = ns.effective_network_policy(notif.pid, protocol, live_policy.as_ref());
    if !matches!(effective, crate::seccomp::notif::NetworkPolicy::Unrestricted) {
        match dest_port {
            Some(p) if !effective.allows(ip, p) => {
                return NotifAction::Errno(ECONNREFUSED);
            }
            None => return NotifAction::Errno(ECONNREFUSED),
            Some(_) => {}
        }
    }
    drop(ns);

    let data = match read_child_mem(notif_fd, notif.id, notif.pid, buf_ptr, buf_len) {
        Ok(b) if b.len() == buf_len => b,
        // A short copy would silently truncate the datagram; report a fault instead.
        Ok(_) => return NotifAction::Errno(libc::EFAULT),
        Err(_) => return NotifAction::Errno(libc::EIO),
    };
    // SAFETY: `data` and `addr_bytes` are live buffers whose lengths are passed alongside
    // their pointers; `dup_fd` is an open descriptor owned by this function.
    let ret = unsafe {
        libc::sendto(
            dup_fd.as_raw_fd(),
            data.as_ptr() as *const libc::c_void,
            data.len(),
            flags,
            addr_bytes.as_ptr() as *const libc::sockaddr,
            addr_bytes.len() as libc::socklen_t,
        )
    };
    if ret >= 0 {
        NotifAction::ReturnValue(ret as i64)
    } else {
        NotifAction::Errno(
            std::io::Error::last_os_error()
                .raw_os_error()
                .unwrap_or(libc::EIO),
        )
    }
}
/// Performs `sendmsg(2)` on behalf of the child when its `msghdr` names an IP destination.
///
/// [`prescan_msghdr`] decides the route:
/// - No destination, or a non-IP destination: returned as [`NotifAction::Continue`].
/// - An IP destination: sent through a duplicate of the child's socket by
///   [`send_msghdr_on_behalf`], which also enforces network policy.
///
/// NOTE(review): on `Continue` the kernel re-reads the `msghdr` from child memory. A racing
/// thread could therefore swap in an IP destination after the prescan. Confirm this race is
/// acceptable or mitigated elsewhere.
async fn sendmsg_on_behalf(
    notif: &SeccompNotif,
    ctx: &Arc<SupervisorCtx>,
    notif_fd: RawFd,
) -> NotifAction {
    let args = &notif.data.args;
    let sockfd = args[0] as i32;
    let msghdr_ptr = args[1];
    let flags = args[2] as i32;
    match prescan_msghdr(notif, notif_fd, msghdr_ptr) {
        PrescanResult::ContinueWholeCall => return NotifAction::Continue,
        PrescanResult::Errno(e) => return NotifAction::Errno(e),
        PrescanResult::OnBehalf => {}
    }
    let dup_fd = match crate::seccomp::notif::dup_fd_from_pid(notif.pid, sockfd) {
        Ok(fd) => fd,
        Err(e) => return NotifAction::Errno(e.raw_os_error().unwrap_or(libc::EBADF)),
    };
    let protocol = match query_socket_protocol(dup_fd.as_raw_fd()) {
        Some(p) => p,
        None => return NotifAction::Errno(ECONNREFUSED),
    };
    match send_msghdr_on_behalf(notif, ctx, notif_fd, &dup_fd, protocol, msghdr_ptr, flags).await {
        Ok(n) => NotifAction::ReturnValue(n as i64),
        Err(errno) => NotifAction::Errno(errno),
    }
}
/// Outcome of `prescan_msghdr` for one `msghdr`.
#[derive(Clone, Copy)]
enum PrescanResult {
/// The destination is an `AF_INET`/`AF_INET6` address: send the message on the child's behalf.
OnBehalf,
/// No destination or a non-IP family; callers turn this into `NotifAction::Continue`.
ContinueWholeCall,
/// Fail the syscall with this errno.
Errno(i32),
}
/// Classifies the `msghdr` at `msghdr_ptr` in the child by its destination (`msg_name`).
///
/// - No destination, or a non-IP address family: [`PrescanResult::ContinueWholeCall`].
/// - An `AF_INET`/`AF_INET6` destination: [`PrescanResult::OnBehalf`].
/// - `EFAULT` when the header cannot be read.
/// - `EINVAL` for a negative `msg_namelen`, matching the kernel.
/// - `EIO` when the address cannot be read.
///
/// Assumes the 64-bit `struct msghdr` layout: 56 bytes, `msg_name` at offset 0 and
/// `msg_namelen` at offset 8.
fn prescan_msghdr(
    notif: &SeccompNotif,
    notif_fd: RawFd,
    msghdr_ptr: u64,
) -> PrescanResult {
    let msghdr_bytes = match read_child_mem(notif_fd, notif.id, notif.pid, msghdr_ptr, 56) {
        Ok(b) if b.len() >= 56 => b,
        _ => return PrescanResult::Errno(libc::EFAULT),
    };
    let msg_name_ptr = u64::from_ne_bytes(msghdr_bytes[0..8].try_into().unwrap());
    if msg_name_ptr == 0 {
        return PrescanResult::ContinueWholeCall;
    }
    // The kernel treats `msg_namelen` as a C int. It rejects negative values and truncates
    // anything above sizeof(sockaddr_storage), so never copy more than that.
    let msg_namelen = i32::from_ne_bytes(msghdr_bytes[8..12].try_into().unwrap());
    if msg_namelen < 0 {
        return PrescanResult::Errno(libc::EINVAL);
    }
    let name_len = (msg_namelen as usize).min(std::mem::size_of::<libc::sockaddr_storage>());
    let addr_bytes = match read_child_mem(notif_fd, notif.id, notif.pid, msg_name_ptr, name_len) {
        Ok(b) => b,
        Err(_) => return PrescanResult::Errno(libc::EIO),
    };
    if parse_ip_from_sockaddr(&addr_bytes).is_none() {
        return PrescanResult::ContinueWholeCall;
    }
    PrescanResult::OnBehalf
}
/// Sends one `msghdr` from the child through `dup_fd`, after enforcing network policy on its
/// destination. Returns the byte count from `sendmsg(2)` or an errno.
///
/// The header, destination, iovec array, payload and control data are all copied out of the
/// child first, so `sendmsg` only ever sees the checked snapshot. The layout assumed is the
/// 64-bit `struct msghdr` (56 bytes) with 16-byte `struct iovec` entries.
///
/// Errors mirror the kernel where possible:
/// - `EFAULT` for an unreadable header, a short copy, or a NULL iovec base with nonzero length.
/// - `EINVAL` for a negative `msg_namelen`.
/// - `EMSGSIZE` for more than 1024 (`UIO_MAXIOV`) iovecs, or a total payload above
///   [`MAX_SEND_BUF`].
/// - `ENOBUFS` for control data above 4096 bytes.
/// - `EIO` when child memory cannot be read.
/// - `EAFNOSUPPORT` for a non-IP destination.
/// - `ECONNREFUSED` when policy denies the destination.
async fn send_msghdr_on_behalf(
    notif: &SeccompNotif,
    ctx: &Arc<SupervisorCtx>,
    notif_fd: RawFd,
    dup_fd: &std::os::unix::io::OwnedFd,
    protocol: Protocol,
    msghdr_ptr: u64,
    flags: i32,
) -> Result<isize, i32> {
    // Kernel limit on iovecs per message (UIO_MAXIOV).
    const MAX_IOVECS: u64 = 1024;
    // Largest control buffer copied out of the child.
    const MAX_CONTROL_LEN: u64 = 4096;
    // Size of a 64-bit `struct iovec` (`iov_base` + `iov_len`).
    const IOVEC_SIZE: usize = 16;

    let msghdr_bytes = match read_child_mem(notif_fd, notif.id, notif.pid, msghdr_ptr, 56) {
        Ok(b) if b.len() >= 56 => b,
        _ => return Err(libc::EFAULT),
    };
    let msg_name_ptr = u64::from_ne_bytes(msghdr_bytes[0..8].try_into().unwrap());
    let msg_namelen = i32::from_ne_bytes(msghdr_bytes[8..12].try_into().unwrap());
    let msg_iov_ptr = u64::from_ne_bytes(msghdr_bytes[16..24].try_into().unwrap());
    let msg_iovlen = u64::from_ne_bytes(msghdr_bytes[24..32].try_into().unwrap());
    let msg_control_ptr = u64::from_ne_bytes(msghdr_bytes[32..40].try_into().unwrap());
    let msg_controllen = u64::from_ne_bytes(msghdr_bytes[40..48].try_into().unwrap());

    // `msg_namelen` is a C int: negative values are EINVAL, and anything above
    // sizeof(sockaddr_storage) is truncated by the kernel, so never copy more than that.
    if msg_namelen < 0 {
        return Err(libc::EINVAL);
    }
    let name_len = (msg_namelen as usize).min(std::mem::size_of::<libc::sockaddr_storage>());
    let addr_bytes = match read_child_mem(notif_fd, notif.id, notif.pid, msg_name_ptr, name_len) {
        Ok(b) => b,
        Err(_) => return Err(libc::EIO),
    };
    let ip = match parse_ip_from_sockaddr(&addr_bytes) {
        Some(ip) => ip,
        None => return Err(libc::EAFNOSUPPORT),
    };
    let dest_port = parse_port_from_sockaddr(&addr_bytes);

    let ns = ctx.network.lock().await;
    let live_policy = {
        let pfs = ctx.policy_fn.lock().await;
        pfs.live_policy.clone()
    };
    let effective = ns.effective_network_policy(notif.pid, protocol, live_policy.as_ref());
    if !matches!(effective, crate::seccomp::notif::NetworkPolicy::Unrestricted) {
        match dest_port {
            Some(p) if !effective.allows(ip, p) => return Err(ECONNREFUSED),
            None => return Err(ECONNREFUSED),
            Some(_) => {}
        }
    }
    drop(ns);

    // Refuse oversized iovec arrays like the kernel does. Silently truncating them would
    // send only part of the message.
    if msg_iovlen > MAX_IOVECS {
        return Err(libc::EMSGSIZE);
    }
    let iovlen = msg_iovlen as usize;
    let mut data_bufs: Vec<Vec<u8>> = Vec::with_capacity(iovlen);
    if iovlen > 0 {
        let iov_size = iovlen * IOVEC_SIZE;
        let iov_bytes = match read_child_mem(notif_fd, notif.id, notif.pid, msg_iov_ptr, iov_size) {
            Ok(b) if b.len() >= iov_size => b,
            Ok(_) => return Err(libc::EFAULT),
            Err(_) => return Err(libc::EIO),
        };
        // Bound the sum of all iovecs, not each one individually. Otherwise 1024 iovecs could
        // make the supervisor buffer 1024 x MAX_SEND_BUF bytes.
        let mut total: usize = 0;
        for entry in iov_bytes[..iov_size].chunks_exact(IOVEC_SIZE) {
            let iov_base = u64::from_ne_bytes(entry[0..8].try_into().unwrap());
            let iov_len = u64::from_ne_bytes(entry[8..16].try_into().unwrap());
            total = match usize::try_from(iov_len).ok().and_then(|l| total.checked_add(l)) {
                Some(t) if t <= MAX_SEND_BUF => t,
                _ => return Err(libc::EMSGSIZE),
            };
            if iov_len == 0 {
                data_bufs.push(Vec::new());
                continue;
            }
            if iov_base == 0 {
                return Err(libc::EFAULT);
            }
            let iov_len = iov_len as usize;
            let buf = match read_child_mem(notif_fd, notif.id, notif.pid, iov_base, iov_len) {
                Ok(b) if b.len() >= iov_len => b,
                Ok(_) => return Err(libc::EFAULT),
                Err(_) => return Err(libc::EIO),
            };
            data_bufs.push(buf);
        }
    }
    let mut local_iovs: Vec<libc::iovec> = data_bufs
        .iter()
        .map(|buf| libc::iovec {
            iov_base: buf.as_ptr() as *mut libc::c_void,
            iov_len: buf.len(),
        })
        .collect();

    // Truncating or dropping ancillary data would silently change what is sent, so report
    // problems instead.
    let control_buf = if msg_control_ptr != 0 && msg_controllen > 0 {
        if msg_controllen > MAX_CONTROL_LEN {
            return Err(libc::ENOBUFS);
        }
        let len = msg_controllen as usize;
        match read_child_mem(notif_fd, notif.id, notif.pid, msg_control_ptr, len) {
            Ok(b) if b.len() >= len => Some(b),
            Ok(_) => return Err(libc::EFAULT),
            Err(_) => return Err(libc::EIO),
        }
    } else {
        None
    };

    // SAFETY: all-zero is a valid `msghdr` (null pointers, zero lengths).
    let mut msg: libc::msghdr = unsafe { std::mem::zeroed() };
    msg.msg_name = addr_bytes.as_ptr() as *mut libc::c_void;
    msg.msg_namelen = addr_bytes.len() as libc::socklen_t;
    msg.msg_iov = local_iovs.as_mut_ptr();
    // `as _` because the field types differ between libc flavours (size_t vs int/socklen_t).
    msg.msg_iovlen = local_iovs.len() as _;
    if let Some(ref ctrl) = control_buf {
        msg.msg_control = ctrl.as_ptr() as *mut libc::c_void;
        msg.msg_controllen = ctrl.len() as _;
    }
    // SAFETY: every pointer in `msg` refers to a buffer (`addr_bytes`, `local_iovs`,
    // `data_bufs`, `control_buf`) that outlives this call, with matching lengths.
    let ret = unsafe { libc::sendmsg(dup_fd.as_raw_fd(), &msg, flags) };
    if ret >= 0 {
        Ok(ret)
    } else {
        Err(std::io::Error::last_os_error()
            .raw_os_error()
            .unwrap_or(libc::EIO))
    }
}
/// Size of one 64-bit `struct mmsghdr`: a 56-byte `msghdr`, then `msg_len`, then padding.
const MMSGHDR_SIZE: usize = 64;
/// Offset of the `msg_len` field within `struct mmsghdr`.
const MSG_LEN_OFFSET: usize = 56;
/// Upper bound applied to the number of `sendmmsg` entries handled per call.
const MAX_MMSGHDR_ENTRIES: usize = 256;
/// Performs `sendmmsg(2)` on behalf of the child, checking every IP destination against policy.
///
/// Entries are classified with [`prescan_msghdr`] first:
/// - If the batch starts with IP-addressed entries, that leading run is sent on the child's
///   behalf and the number sent is returned. The run is capped at [`MAX_MMSGHDR_ENTRIES`].
///   Later entries are left for the child to retry, which is an ordinary `sendmmsg` short
///   count.
/// - If no entry the kernel would process has an IP destination, the call is returned as
///   [`NotifAction::Continue`].
/// - If the batch starts with a non-IP entry but contains an IP-addressed one later, it is
///   refused with `ECONNREFUSED`. Continuing it would let the kernel send the IP entries
///   without a policy check.
///
/// Each successful entry's `msg_len` is written back into the child's `mmsghdr`. Returns the
/// number of messages sent, or the errno of the first failure when nothing was sent.
async fn sendmmsg_on_behalf(
    notif: &SeccompNotif,
    ctx: &Arc<SupervisorCtx>,
    notif_fd: RawFd,
) -> NotifAction {
    // The kernel clamps `vlen` to UIO_MAXIOV. Every entry it would process must be vetted
    // before the call may be continued.
    const KERNEL_MAX_VLEN: usize = 1024;

    let args = &notif.data.args;
    let sockfd = args[0] as i32;
    let msgvec_ptr = args[1];
    let vlen = (args[2] as u32 as usize).min(KERNEL_MAX_VLEN);
    let flags = args[3] as i32;
    if vlen == 0 {
        return NotifAction::ReturnValue(0);
    }

    // Addresses of the leading IP-addressed entries, which are the ones sent on behalf.
    let mut on_behalf_entries: Vec<u64> = Vec::new();
    for i in 0..vlen {
        let entry_ptr = match msgvec_ptr.checked_add((i * MMSGHDR_SIZE) as u64) {
            Some(p) => p,
            None => return NotifAction::Errno(libc::EFAULT),
        };
        match prescan_msghdr(notif, notif_fd, entry_ptr) {
            PrescanResult::OnBehalf if on_behalf_entries.len() == i => {
                on_behalf_entries.push(entry_ptr);
                if on_behalf_entries.len() == MAX_MMSGHDR_ENTRIES {
                    break;
                }
            }
            // An IP destination after a non-IP prefix. The batch can neither be continued
            // safely nor split on behalf, so fail closed.
            PrescanResult::OnBehalf => return NotifAction::Errno(ECONNREFUSED),
            // The IP-addressed prefix ends here; only that prefix is sent.
            PrescanResult::ContinueWholeCall if !on_behalf_entries.is_empty() => break,
            PrescanResult::ContinueWholeCall => {}
            PrescanResult::Errno(e) => return NotifAction::Errno(e),
        }
    }
    if on_behalf_entries.is_empty() {
        // NOTE(review): the kernel re-reads the mmsghdr array after Continue, so a racing
        // thread could still swap in IP destinations (same caveat as sendmsg).
        return NotifAction::Continue;
    }

    let dup_fd = match crate::seccomp::notif::dup_fd_from_pid(notif.pid, sockfd) {
        Ok(fd) => fd,
        Err(e) => return NotifAction::Errno(e.raw_os_error().unwrap_or(libc::EBADF)),
    };
    let protocol = match query_socket_protocol(dup_fd.as_raw_fd()) {
        Some(p) => p,
        None => return NotifAction::Errno(ECONNREFUSED),
    };
    let mut sent: usize = 0;
    let mut first_errno: Option<i32> = None;
    for &entry_ptr in &on_behalf_entries {
        match send_msghdr_on_behalf(notif, ctx, notif_fd, &dup_fd, protocol, entry_ptr, flags).await {
            Ok(n) => {
                // Best effort: the datagram has already been sent, so failing to report its
                // length back to the child does not change the result.
                if let Some(len_ptr) = entry_ptr.checked_add(MSG_LEN_OFFSET as u64) {
                    let _ = write_child_mem(
                        notif_fd,
                        notif.id,
                        notif.pid,
                        len_ptr,
                        &(n as u32).to_ne_bytes(),
                    );
                }
                sent += 1;
            }
            Err(errno) => {
                first_errno = Some(errno);
                break;
            }
        }
    }
    if sent > 0 {
        NotifAction::ReturnValue(sent as i64)
    } else {
        NotifAction::Errno(first_errno.unwrap_or(ECONNREFUSED))
    }
}
/// Dispatches a network-related seccomp notification to its on-behalf handler.
///
/// `connect`, `sendto`, `sendmsg` and `sendmmsg` each go to their dedicated handler. Any other
/// syscall number is returned as [`NotifAction::Continue`].
pub(crate) async fn handle_net(
    notif: &SeccompNotif,
    ctx: &Arc<SupervisorCtx>,
    notif_fd: RawFd,
) -> NotifAction {
    match notif.data.nr as i64 {
        libc::SYS_connect => connect_on_behalf(notif, ctx, notif_fd).await,
        libc::SYS_sendto => sendto_on_behalf(notif, ctx, notif_fd).await,
        libc::SYS_sendmsg => sendmsg_on_behalf(notif, ctx, notif_fd).await,
        libc::SYS_sendmmsg => sendmmsg_on_behalf(notif, ctx, notif_fd).await,
        _ => NotifAction::Continue,
    }
}
/// Allow-list for one protocol after rule host names have been resolved to IP addresses.
pub struct ResolvedNetAllow {
/// Allowed destination ports per resolved IP.
pub per_ip: HashMap<IpAddr, HashSet<u16>>,
/// Resolved IPs on which every port is allowed (`resolve_net_allow` also gives each a `per_ip` entry).
pub per_ip_all_ports: HashSet<IpAddr>,
/// Ports allowed towards any destination IP.
pub any_ip_ports: HashSet<u16>,
/// Whether every port on every destination IP is allowed.
pub any_ip_all_ports: bool,
}
/// Resolved allow-lists for every protocol, plus generated hosts-file content.
pub struct ResolvedNetAllowSet {
/// Allow-list applied to TCP sockets.
pub tcp: ResolvedNetAllow,
/// Allow-list applied to UDP sockets.
pub udp: ResolvedNetAllow,
/// Allow-list applied to ICMP/ICMPv6 sockets (ports are not restricted).
pub icmp: ResolvedNetAllow,
/// Hosts-file text (localhost entries plus one `<ip> <host>` line per resolved address);
/// `None` when no rule named a concrete host.
pub etc_hosts: Option<String>,
}
pub async fn resolve_net_allow(
rules: &[NetAllow],
) -> io::Result<ResolvedNetAllowSet> {
let mut etc_hosts = String::from("127.0.0.1 localhost\n::1 localhost\n");
let mut has_concrete_host = false;
let per_proto = |target: Protocol| async move {
let mut per_ip: HashMap<IpAddr, HashSet<u16>> = HashMap::new();
let mut per_ip_all_ports: HashSet<IpAddr> = HashSet::new();
let mut any_ip_ports: HashSet<u16> = HashSet::new();
let mut any_ip_all_ports = false;
let mut local_etc_hosts = String::new();
let mut local_has_concrete = false;
for rule in rules.iter().filter(|r| r.protocol == target) {
match &rule.host {
None => {
if rule.all_ports || target == Protocol::Icmp {
any_ip_all_ports = true;
} else {
for &p in &rule.ports {
any_ip_ports.insert(p);
}
}
}
Some(host) => {
local_has_concrete = true;
let addr = format!("{}:0", host);
let resolved = tokio::net::lookup_host(addr.as_str()).await.map_err(|e| {
io::Error::new(
e.kind(),
format!("failed to resolve host '{}': {}", host, e),
)
})?;
for socket_addr in resolved {
let ip = socket_addr.ip();
if rule.all_ports || target == Protocol::Icmp {
per_ip_all_ports.insert(ip);
per_ip.entry(ip).or_default();
} else {
let entry = per_ip.entry(ip).or_default();
for &p in &rule.ports {
entry.insert(p);
}
}
local_etc_hosts.push_str(&format!("{} {}\n", ip, host));
}
}
}
}
Ok::<_, io::Error>((
ResolvedNetAllow {
per_ip,
per_ip_all_ports,
any_ip_ports,
any_ip_all_ports,
},
local_etc_hosts,
local_has_concrete,
))
};
let (tcp, tcp_eh, tcp_concrete) = per_proto(Protocol::Tcp).await?;
let (udp, udp_eh, udp_concrete) = per_proto(Protocol::Udp).await?;
let (icmp, icmp_eh, icmp_concrete) = per_proto(Protocol::Icmp).await?;
for chunk in [tcp_eh, udp_eh, icmp_eh] {
etc_hosts.push_str(&chunk);
}
has_concrete_host |= tcp_concrete || udp_concrete || icmp_concrete;
Ok(ResolvedNetAllowSet {
tcp,
udp,
icmp,
etc_hosts: if has_concrete_host { Some(etc_hosts) } else { None },
})
}
// Unit tests for `NetAllow::parse` (pure string parsing) and `resolve_net_allow`.
// The resolver tests call `tokio::net::lookup_host` on "localhost", so they depend on the
// test host resolving that name.
#[cfg(test)]
mod tests {
use super::*;
// --- NetAllow::parse: host/port forms (scheme omitted, so TCP) ---
#[test]
fn netallow_parse_concrete_host_port() {
let r = NetAllow::parse("example.com:443").unwrap();
assert_eq!(r.host.as_deref(), Some("example.com"));
assert_eq!(r.ports, vec![443]);
assert!(!r.all_ports);
}
// An empty host and `*` both mean "any host".
#[test]
fn netallow_parse_any_host_port() {
let r = NetAllow::parse(":8080").unwrap();
assert_eq!(r.host, None);
assert_eq!(r.ports, vec![8080]);
assert!(!r.all_ports);
let r = NetAllow::parse("*:8080").unwrap();
assert_eq!(r.host, None);
assert_eq!(r.ports, vec![8080]);
assert!(!r.all_ports);
}
#[test]
fn netallow_parse_multiple_ports() {
let r = NetAllow::parse("github.com:22,80,443").unwrap();
assert_eq!(r.host.as_deref(), Some("github.com"));
assert_eq!(r.ports, vec![22, 80, 443]);
assert!(!r.all_ports);
}
// --- NetAllow::parse: `*` port wildcard ---
#[test]
fn netallow_parse_wildcard_any_host_any_port_colon() {
let r = NetAllow::parse(":*").unwrap();
assert_eq!(r.host, None);
assert!(r.ports.is_empty());
assert!(r.all_ports);
}
#[test]
fn netallow_parse_wildcard_any_host_any_port_star() {
let r = NetAllow::parse("*:*").unwrap();
assert_eq!(r.host, None);
assert!(r.ports.is_empty());
assert!(r.all_ports);
}
#[test]
fn netallow_parse_wildcard_concrete_host_any_port() {
let r = NetAllow::parse("example.com:*").unwrap();
assert_eq!(r.host.as_deref(), Some("example.com"));
assert!(r.ports.is_empty());
assert!(r.all_ports);
}
// --- NetAllow::parse: rejected inputs ---
#[test]
fn netallow_parse_rejects_mixed_wildcard_and_concrete() {
let err = NetAllow::parse("example.com:80,*").unwrap_err();
assert!(format!("{}", err).contains("cannot mix"));
let err = NetAllow::parse("example.com:*,80").unwrap_err();
assert!(format!("{}", err).contains("cannot mix"));
}
#[test]
fn netallow_parse_rejects_port_zero() {
let err = NetAllow::parse("example.com:0").unwrap_err();
assert!(format!("{}", err).contains("port 0"));
}
// An empty port list fails in the port parser, before the "at least one port" check.
#[test]
fn netallow_parse_rejects_empty_port() {
let err = NetAllow::parse("example.com:").unwrap_err();
assert!(format!("{}", err).contains("invalid port"));
}
#[test]
fn netallow_parse_rejects_no_colon() {
let err = NetAllow::parse("example.com").unwrap_err();
assert!(format!("{}", err).contains("expected"));
}
#[test]
fn netallow_parse_repeated_wildcard_is_idempotent() {
let r = NetAllow::parse(":*,*").unwrap();
assert!(r.all_ports);
assert!(r.ports.is_empty());
}
// --- NetAllow::parse: protocol schemes ---
#[test]
fn netallow_bare_form_defaults_to_tcp() {
let r = NetAllow::parse("example.com:443").unwrap();
assert_eq!(r.protocol, Protocol::Tcp);
}
#[test]
fn netallow_explicit_tcp_scheme() {
let r = NetAllow::parse("tcp://example.com:443").unwrap();
assert_eq!(r.protocol, Protocol::Tcp);
assert_eq!(r.host.as_deref(), Some("example.com"));
assert_eq!(r.ports, vec![443]);
}
#[test]
fn netallow_udp_scheme_with_host_port() {
let r = NetAllow::parse("udp://1.1.1.1:53").unwrap();
assert_eq!(r.protocol, Protocol::Udp);
assert_eq!(r.host.as_deref(), Some("1.1.1.1"));
assert_eq!(r.ports, vec![53]);
}
#[test]
fn netallow_udp_wildcard_any_anywhere() {
let r = NetAllow::parse("udp://*:*").unwrap();
assert_eq!(r.protocol, Protocol::Udp);
assert_eq!(r.host, None);
assert!(r.all_ports);
}
// ICMP rules name a host (or `*`) and never carry ports.
#[test]
fn netallow_icmp_scheme_with_host() {
let r = NetAllow::parse("icmp://github.com").unwrap();
assert_eq!(r.protocol, Protocol::Icmp);
assert_eq!(r.host.as_deref(), Some("github.com"));
assert!(r.ports.is_empty());
assert!(!r.all_ports);
}
#[test]
fn netallow_icmp_wildcard() {
let r = NetAllow::parse("icmp://*").unwrap();
assert_eq!(r.protocol, Protocol::Icmp);
assert_eq!(r.host, None);
}
#[test]
fn netallow_icmp_rejects_port() {
let err = NetAllow::parse("icmp://github.com:80").unwrap_err();
assert!(format!("{}", err).contains("icmp rules take no port"));
}
#[test]
fn netallow_icmp_rejects_empty_body() {
let err = NetAllow::parse("icmp://").unwrap_err();
assert!(format!("{}", err).contains("needs a host or `*`"));
}
#[test]
fn netallow_unknown_scheme_rejected() {
for spec in ["sctp://host:1234", "icmp-raw://*"] {
let err = NetAllow::parse(spec).unwrap_err();
assert!(format!("{}", err).contains("unknown scheme"), "spec: {}", spec);
}
}
// --- resolve_net_allow ---
// No rules: every set is empty and no hosts file is generated.
#[tokio::test]
async fn test_resolve_net_allow_empty() {
let resolved = resolve_net_allow(&[]).await.unwrap();
assert!(resolved.tcp.per_ip.is_empty());
assert!(resolved.tcp.any_ip_ports.is_empty());
assert!(resolved.udp.per_ip.is_empty());
assert!(resolved.icmp.per_ip.is_empty());
assert!(resolved.etc_hosts.is_none());
}
// A concrete host yields per-IP ports and a hosts-file entry.
#[tokio::test]
async fn test_resolve_net_allow_concrete_host() {
let rules = vec![NetAllow {
protocol: Protocol::Tcp,
host: Some("localhost".to_string()),
ports: vec![80, 443],
all_ports: false,
}];
let resolved = resolve_net_allow(&rules).await.unwrap();
assert!(!resolved.tcp.per_ip.is_empty());
for ports in resolved.tcp.per_ip.values() {
assert!(ports.contains(&80));
assert!(ports.contains(&443));
}
assert!(resolved.udp.per_ip.is_empty());
assert!(resolved.icmp.per_ip.is_empty());
assert!(resolved.etc_hosts.as_deref().unwrap_or("").contains("localhost"));
}
// Wildcard hosts populate only the "any IP" sets and need no hosts file.
#[tokio::test]
async fn test_resolve_net_allow_any_ip() {
let rules = vec![NetAllow {
protocol: Protocol::Tcp,
host: None,
ports: vec![8080],
all_ports: false,
}];
let resolved = resolve_net_allow(&rules).await.unwrap();
assert!(resolved.tcp.per_ip.is_empty());
assert!(resolved.tcp.any_ip_ports.contains(&8080));
assert!(!resolved.tcp.any_ip_all_ports);
assert!(resolved.etc_hosts.is_none());
}
#[tokio::test]
async fn test_resolve_net_allow_any_ip_all_ports() {
let rules = vec![NetAllow {
protocol: Protocol::Tcp,
host: None,
ports: vec![],
all_ports: true,
}];
let resolved = resolve_net_allow(&rules).await.unwrap();
assert!(resolved.tcp.any_ip_all_ports);
assert!(resolved.tcp.per_ip.is_empty());
assert!(resolved.tcp.per_ip_all_ports.is_empty());
assert!(resolved.tcp.any_ip_ports.is_empty());
assert!(!resolved.udp.any_ip_all_ports);
assert!(!resolved.icmp.any_ip_all_ports);
}
// All-port concrete hosts are marked in per_ip_all_ports and still get a per_ip entry.
#[tokio::test]
async fn test_resolve_net_allow_concrete_host_all_ports() {
let rules = vec![NetAllow {
protocol: Protocol::Tcp,
host: Some("localhost".to_string()),
ports: vec![],
all_ports: true,
}];
let resolved = resolve_net_allow(&rules).await.unwrap();
assert!(!resolved.tcp.any_ip_all_ports);
assert!(
!resolved.tcp.per_ip_all_ports.is_empty(),
"localhost should resolve to at least one IP marked as any-port"
);
for ip in resolved.tcp.per_ip_all_ports.iter() {
assert!(resolved.tcp.per_ip.contains_key(ip));
}
assert!(resolved.etc_hosts.is_some());
}
#[tokio::test]
async fn test_resolve_net_allow_mixed_wildcard_and_concrete() {
let rules = vec![
NetAllow {
protocol: Protocol::Tcp,
host: None,
ports: vec![],
all_ports: true,
},
NetAllow {
protocol: Protocol::Tcp,
host: Some("localhost".to_string()),
ports: vec![22],
all_ports: false,
},
];
let resolved = resolve_net_allow(&rules).await.unwrap();
assert!(resolved.tcp.any_ip_all_ports);
assert!(!resolved.tcp.per_ip.is_empty());
}
// Rules for one protocol must never leak into another protocol's sets.
#[tokio::test]
async fn test_resolve_per_protocol_isolation() {
let rules = vec![
NetAllow {
protocol: Protocol::Tcp,
host: Some("localhost".to_string()),
ports: vec![443],
all_ports: false,
},
NetAllow {
protocol: Protocol::Udp,
host: None,
ports: vec![53],
all_ports: false,
},
];
let resolved = resolve_net_allow(&rules).await.unwrap();
assert!(
!resolved.tcp.per_ip.is_empty(),
"TCP rule should populate tcp set"
);
assert!(
resolved.udp.any_ip_ports.contains(&53),
"UDP rule should populate udp set"
);
for ports in resolved.tcp.per_ip.values() {
assert!(!ports.contains(&53), "UDP port leaked into TCP set");
}
assert!(!resolved.udp.any_ip_ports.contains(&443), "TCP port leaked into UDP set");
}
// ICMP rules are treated as "all ports", both for concrete hosts and for wildcards.
#[tokio::test]
async fn test_resolve_icmp_no_ports() {
let rules = vec![NetAllow {
protocol: Protocol::Icmp,
host: Some("localhost".to_string()),
ports: vec![],
all_ports: false,
}];
let resolved = resolve_net_allow(&rules).await.unwrap();
assert!(
!resolved.icmp.per_ip.is_empty(),
"icmp host should populate per_ip"
);
assert!(
!resolved.icmp.per_ip_all_ports.is_empty(),
"icmp host should mark per_ip_all_ports (no port check)"
);
assert!(resolved.icmp.any_ip_ports.is_empty());
assert!(resolved.tcp.per_ip.is_empty());
assert!(resolved.udp.per_ip.is_empty());
}
#[tokio::test]
async fn test_resolve_icmp_wildcard() {
let rules = vec![NetAllow {
protocol: Protocol::Icmp,
host: None,
ports: vec![],
all_ports: false,
}];
let resolved = resolve_net_allow(&rules).await.unwrap();
assert!(resolved.icmp.any_ip_all_ports);
assert!(!resolved.tcp.any_ip_all_ports);
}
}