use std::collections::HashMap;
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, TcpListener};
use std::os::unix::io::{AsRawFd, RawFd};
use std::sync::Arc;
use tokio::sync::Mutex;
use crate::seccomp::notif::{read_child_mem, write_child_mem, NotifAction};
use crate::seccomp::state::NetworkState;
use crate::sys::structs::{SeccompNotif, AF_INET, AF_INET6};
/// Boxed hook type stored in `PortMap::on_bind`; it is handed a `virtual -> real` port map.
pub type BindCallback = Box<dyn Fn(&HashMap<u16, u16>) + Send + Sync>;

/// Bookkeeping for the ports the sandboxed child binds and the host ports backing them.
///
/// Only remapped ports (virtual != real) appear in the two direction maps; every recorded real
/// port appears in `bound_ports`, remapped or not.
#[derive(Default)]
pub struct PortMap {
    /// Child-visible (virtual) port -> host (real) port, for remapped ports only.
    pub virtual_to_real: HashMap<u16, u16>,
    /// Host (real) port -> child-visible (virtual) port; written alongside `virtual_to_real`.
    pub real_to_virtual: HashMap<u16, u16>,
    /// Every real port that has been recorded as bound.
    pub bound_ports: std::collections::HashSet<u16>,
    /// Optional hook invoked after each recorded bind.
    pub on_bind: Option<BindCallback>,
}
impl PortMap {
    /// Creates an empty map with no `on_bind` hook (equivalent to `PortMap::default()`).
    pub fn new() -> Self {
        Default::default()
    }

    /// Records that the child's `virtual_port` is served by host port `real_port`.
    ///
    /// `real_port` is always added to `bound_ports`; the two direction maps are only written
    /// when the ports differ. If an `on_bind` hook is set, it is then called with a
    /// `virtual -> real` map covering every bound real port (identity for non-remapped ports)
    /// plus every entry of `virtual_to_real`.
    pub fn record_bind(&mut self, virtual_port: u16, real_port: u16) {
        self.bound_ports.insert(real_port);
        if virtual_port != real_port {
            self.virtual_to_real.insert(virtual_port, real_port);
            self.real_to_virtual.insert(real_port, virtual_port);
        }

        if let Some(callback) = &self.on_bind {
            let mut snapshot: HashMap<u16, u16> = HashMap::with_capacity(self.bound_ports.len());
            for &real in &self.bound_ports {
                let virt = self.real_to_virtual.get(&real).copied().unwrap_or(real);
                snapshot.insert(virt, real);
            }
            // Remapped entries are applied last so they win over any identity entry.
            for (&virt, &real) in &self.virtual_to_real {
                snapshot.insert(virt, real);
            }
            callback(&snapshot);
        }
    }

    /// Host port backing `virtual_port`, if that port was remapped.
    pub fn get_real(&self, virtual_port: u16) -> Option<u16> {
        self.virtual_to_real.get(&virtual_port).copied()
    }

    /// Child-visible port for host port `real_port`, if that port was remapped.
    pub fn get_virtual(&self, real_port: u16) -> Option<u16> {
        self.real_to_virtual.get(&real_port).copied()
    }

    /// Picks (and records) the real port to use for a bind of `virtual_port`.
    ///
    /// The order of preference is:
    /// 1. an existing remapping;
    /// 2. `virtual_port` itself, when a loopback probe bind of it succeeds;
    /// 3. a fresh ephemeral port from `allocate_real_port`.
    ///
    /// Any newly chosen port is recorded through `record_bind`. Returns `None` only when the
    /// ephemeral fallback fails. Identity choices are not stored in `virtual_to_real`, so a
    /// repeated call for the same port probes again.
    pub fn allocate_or_reserve(&mut self, virtual_port: u16, family: u32) -> Option<u16> {
        if let Some(real) = self.get_real(virtual_port) {
            return Some(real);
        }
        let real = try_reserve_port(virtual_port, family).or_else(|| allocate_real_port(family))?;
        self.record_bind(virtual_port, real);
        Some(real)
    }
}
/// Returns the port stored in a raw `AF_INET`/`AF_INET6` socket address.
///
/// The first two bytes are read as the native-endian address family. The next two bytes are
/// read as the big-endian (network order) port. Yields `None` for buffers shorter than 4 bytes
/// and for any other family.
fn extract_port(bytes: &[u8]) -> Option<u16> {
    match bytes {
        [fam_lo, fam_hi, port_hi, port_lo, ..] => {
            let family = u32::from(u16::from_ne_bytes([*fam_lo, *fam_hi]));
            let is_inet = family == AF_INET || family == AF_INET6;
            is_inet.then(|| u16::from_be_bytes([*port_hi, *port_lo]))
        }
        _ => None,
    }
}
/// Overwrites bytes 2..4 of a raw socket address with `port` in network (big-endian) order.
///
/// Buffers shorter than 4 bytes are left untouched.
fn set_port_in_sockaddr(bytes: &mut [u8], port: u16) {
    if let Some(port_field) = bytes.get_mut(2..4) {
        port_field.copy_from_slice(&port.to_be_bytes());
    }
}
/// Probes whether `port` can be bound on the loopback address of the given family.
///
/// Uses `::1` when `family == AF_INET6` and `127.0.0.1` otherwise. Returns `Some(port)` if a
/// TCP listener could be bound there, and `None` for port 0 or when the bind fails. The probe
/// listener is dropped immediately, so the port is not held for the caller.
fn try_reserve_port(port: u16, family: u32) -> Option<u16> {
    if port == 0 {
        return None;
    }
    let probe_addr = match family {
        f if f == AF_INET6 => SocketAddr::from((Ipv6Addr::LOCALHOST, port)),
        _ => SocketAddr::from((Ipv4Addr::LOCALHOST, port)),
    };
    TcpListener::bind(probe_addr).map(|_listener| port).ok()
}
/// Asks the OS for an ephemeral TCP port on the loopback address of the given family.
///
/// Uses `::1` when `family == AF_INET6` and `127.0.0.1` otherwise. Returns `None` if the bind or
/// the local-address lookup fails. The listener is dropped before returning, so the port is only
/// likely (not guaranteed) to still be free when the caller uses it.
fn allocate_real_port(family: u32) -> Option<u16> {
    let ephemeral_request = if family == AF_INET6 {
        SocketAddr::from((Ipv6Addr::LOCALHOST, 0))
    } else {
        SocketAddr::from((Ipv4Addr::LOCALHOST, 0))
    };
    TcpListener::bind(ephemeral_request)
        .and_then(|listener| listener.local_addr())
        .map(|bound| bound.port())
        .ok()
}
pub(crate) async fn handle_bind(
notif: &SeccompNotif,
network: &Arc<Mutex<NetworkState>>,
notif_fd: RawFd,
) -> NotifAction {
let sockfd = notif.data.args[0] as i32;
let addr_ptr = notif.data.args[1];
let addr_len = notif.data.args[2] as usize;
if addr_ptr == 0 || addr_len < 4 {
return NotifAction::Continue;
}
let read_len = addr_len.min(128);
let mut bytes = match read_child_mem(notif_fd, notif.id, notif.pid, addr_ptr, read_len) {
Ok(b) => b,
Err(_) => return NotifAction::Errno(libc::EIO),
};
let family = u16::from_ne_bytes([bytes[0], bytes[1]]) as u32;
if let Some(virtual_port) = extract_port(&bytes) {
if virtual_port == 0 {
let dup_fd = match crate::seccomp::notif::dup_fd_from_pid(notif.pid, sockfd) {
Ok(fd) => fd,
Err(_) => return NotifAction::Errno(libc::ENOSYS),
};
let ret = unsafe {
libc::bind(dup_fd.as_raw_fd(), bytes.as_ptr() as *const libc::sockaddr, addr_len as libc::socklen_t)
};
return if ret == 0 {
NotifAction::ReturnValue(0)
} else {
NotifAction::Errno(unsafe { *libc::__errno_location() })
};
}
let mut ns = network.lock().await;
if let Some(real_port) = ns.port_map.allocate_or_reserve(virtual_port, family) {
if real_port != virtual_port {
set_port_in_sockaddr(&mut bytes, real_port);
}
}
drop(ns);
let dup_fd = match crate::seccomp::notif::dup_fd_from_pid(notif.pid, sockfd) {
Ok(fd) => fd,
Err(_) => return NotifAction::Errno(libc::ENOSYS),
};
let ret = unsafe {
libc::bind(dup_fd.as_raw_fd(), bytes.as_ptr() as *const libc::sockaddr, addr_len as libc::socklen_t)
};
if ret == 0 {
NotifAction::ReturnValue(0)
} else {
NotifAction::Errno(unsafe { *libc::__errno_location() })
}
} else {
let dup_fd = match crate::seccomp::notif::dup_fd_from_pid(notif.pid, sockfd) {
Ok(fd) => fd,
Err(_) => return NotifAction::Errno(libc::ENOSYS),
};
let ret = unsafe {
libc::bind(dup_fd.as_raw_fd(), bytes.as_ptr() as *const libc::sockaddr, addr_len as libc::socklen_t)
};
if ret == 0 {
NotifAction::ReturnValue(0)
} else {
NotifAction::Errno(unsafe { *libc::__errno_location() })
}
}
}
pub(crate) async fn handle_getsockname(
notif: &SeccompNotif,
network: &Arc<Mutex<NetworkState>>,
notif_fd: RawFd,
) -> NotifAction {
let addr_ptr = notif.data.args[1];
let addrlen_ptr = notif.data.args[2];
if addr_ptr == 0 || addrlen_ptr == 0 {
return NotifAction::Continue;
}
let addrlen_bytes = match read_child_mem(notif_fd, notif.id, notif.pid, addrlen_ptr, 4) {
Ok(b) if b.len() >= 4 => b,
_ => return NotifAction::Continue,
};
let addr_len = u32::from_ne_bytes(addrlen_bytes[..4].try_into().unwrap()) as usize;
if addr_len < 4 {
return NotifAction::Continue;
}
let read_len = addr_len.min(128);
let mut bytes = match read_child_mem(notif_fd, notif.id, notif.pid, addr_ptr, read_len) {
Ok(b) => b,
Err(_) => return NotifAction::Continue,
};
if let Some(real_port) = extract_port(&bytes) {
let ns = network.lock().await;
if let Some(virtual_port) = ns.port_map.get_virtual(real_port) {
set_port_in_sockaddr(&mut bytes, virtual_port);
drop(ns);
let _ = write_child_mem(notif_fd, notif.id, notif.pid, addr_ptr, &bytes);
}
}
NotifAction::Continue
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Writes `family` (native endian, as a 16-bit value) into the first two bytes of `buf`.
    fn put_family(buf: &mut [u8], family: u32) {
        buf[..2].copy_from_slice(&(family as u16).to_ne_bytes());
    }

    #[test]
    fn test_port_map_identity() {
        let mut map = PortMap::new();
        map.record_bind(8080, 8080);
        assert!(map.bound_ports.contains(&8080));
        assert_eq!(map.get_real(8080), None);
        assert_eq!(map.get_virtual(8080), None);
    }

    #[test]
    fn test_port_map_remap() {
        let mut map = PortMap::new();
        map.record_bind(8080, 9090);
        assert!(map.bound_ports.contains(&9090));
        assert_eq!(map.get_real(8080), Some(9090));
        assert_eq!(map.get_virtual(9090), Some(8080));
    }

    #[test]
    fn test_extract_port_ipv4() {
        let mut sa = vec![0u8; 16];
        put_family(&mut sa, AF_INET);
        sa[2..4].copy_from_slice(&[0x1F, 0x90]);
        assert_eq!(extract_port(&sa), Some(8080));
    }

    #[test]
    fn test_extract_port_ipv6() {
        let mut sa = vec![0u8; 28];
        put_family(&mut sa, AF_INET6);
        sa[2..4].copy_from_slice(&[0x00, 0x50]);
        assert_eq!(extract_port(&sa), Some(80));
    }

    #[test]
    fn test_extract_port_unix() {
        let sa: Vec<u8> = vec![1, 0, 0, 0, 0, 0, 0, 0];
        assert_eq!(extract_port(&sa), None);
    }

    #[test]
    fn test_extract_port_too_short() {
        let sa: Vec<u8> = vec![2, 0];
        assert_eq!(extract_port(&sa), None);
    }

    #[test]
    fn test_set_port_in_sockaddr() {
        let mut sa = vec![0u8; 16];
        set_port_in_sockaddr(&mut sa, 443);
        assert_eq!(sa[2..4], [0x01, 0xBB]);
    }

    #[test]
    fn test_try_reserve_port_zero() {
        assert!(try_reserve_port(0, AF_INET).is_none());
    }

    #[test]
    fn test_allocate_real_port() {
        let port = allocate_real_port(AF_INET).expect("a loopback port should be allocatable");
        assert!(port > 0);
    }

    #[test]
    fn test_port_map_allocate_or_reserve() {
        let mut map = PortMap::new();
        let real = map
            .allocate_or_reserve(18080, AF_INET)
            .expect("a real port should be chosen");
        assert!(map.bound_ports.contains(&real));
    }

    #[test]
    fn test_port_map_allocate_or_reserve_cached() {
        let mut map = PortMap::new();
        let first = map.allocate_or_reserve(18081, AF_INET).unwrap();
        let second = map.allocate_or_reserve(18081, AF_INET).unwrap();
        assert_eq!(first, second);
    }
}