use std::collections::HashMap;
use std::os::unix::io::{AsRawFd, OwnedFd, RawFd};
use std::sync::Arc;
use tokio::sync::Mutex;
use crate::seccomp::notif::{read_child_mem, write_child_mem, NotifAction};
use crate::seccomp::state::NetworkState;
use crate::sys::structs::{SeccompNotif, AF_INET, AF_INET6};
/// Bookkeeping for ports recorded through [`PortMap::record_bind`].
///
/// A *virtual* port is the number the caller asked for; a *real* port is the
/// number reported as actually bound. The two tables are only populated when
/// those numbers differ.
#[derive(Default)]
pub struct PortMap {
    /// Virtual port -> real port, for binds whose two numbers differ.
    pub virtual_to_real: HashMap<u16, u16>,
    /// Real port -> virtual port, written together with `virtual_to_real`.
    pub real_to_virtual: HashMap<u16, u16>,
    /// Every real port handed to `record_bind`, remapped or not.
    pub bound_ports: std::collections::HashSet<u16>,
    /// Optional observer invoked at the end of every `record_bind` call with a
    /// virtual -> real snapshot of all recorded ports.
    // The boxed closure type is written inline, which trips clippy's
    // type-complexity lint; a type alias would only be used once.
    #[allow(clippy::type_complexity)]
    pub on_bind: Option<Box<dyn Fn(&HashMap<u16, u16>) + Send + Sync>>,
}

impl PortMap {
    /// Creates an empty map with no observer installed.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records that `virtual_port` is now served by `real_port`.
    ///
    /// `real_port` is always added to `bound_ports`; the translation tables are
    /// touched only when the two numbers differ. If an `on_bind` observer is
    /// set, it is then called with a snapshot mapping each known port's virtual
    /// number (or the real number when no translation exists) to its real one.
    pub fn record_bind(&mut self, virtual_port: u16, real_port: u16) {
        self.bound_ports.insert(real_port);
        if virtual_port != real_port {
            self.virtual_to_real.insert(virtual_port, real_port);
            self.real_to_virtual.insert(real_port, virtual_port);
        }

        if let Some(notify) = &self.on_bind {
            let mut snapshot: HashMap<u16, u16> = HashMap::new();
            // First pass: one entry per bound real port, keyed by its virtual
            // alias when one exists.
            for &real in &self.bound_ports {
                let key = self.get_virtual(real).unwrap_or(real);
                snapshot.insert(key, real);
            }
            // Second pass: explicit translations take precedence over anything
            // derived above.
            for (&virt, &real) in &self.virtual_to_real {
                snapshot.insert(virt, real);
            }
            notify(&snapshot);
        }
    }

    /// Returns the real port recorded for `virtual_port`, if it was remapped.
    pub fn get_real(&self, virtual_port: u16) -> Option<u16> {
        self.virtual_to_real.get(&virtual_port).copied()
    }

    /// Returns the virtual port recorded for `real_port`, if it was remapped.
    pub fn get_virtual(&self, real_port: u16) -> Option<u16> {
        self.real_to_virtual.get(&real_port).copied()
    }
}
/// Reads the port out of a raw socket address.
///
/// The first two bytes are interpreted as a native-endian address family and
/// the next two as a big-endian port. Returns `None` when fewer than four
/// bytes are available or the family is neither `AF_INET` nor `AF_INET6`.
fn extract_port(bytes: &[u8]) -> Option<u16> {
    let header = bytes.get(..4)?;
    let family = u32::from(u16::from_ne_bytes([header[0], header[1]]));
    if family == AF_INET || family == AF_INET6 {
        Some(u16::from_be_bytes([header[2], header[3]]))
    } else {
        None
    }
}
/// Overwrites bytes 2 and 3 of a raw socket address with `port` in big-endian
/// order. Buffers shorter than four bytes are left untouched.
fn set_port_in_sockaddr(bytes: &mut [u8], port: u16) {
    if let Some(port_field) = bytes.get_mut(2..4) {
        port_field.copy_from_slice(&port.to_be_bytes());
    }
}
/// Services an intercepted `bind` call on behalf of the notifying process.
///
/// Steps, in order:
/// 1. Returns [`NotifAction::Continue`] when the address pointer is null or the
///    length is below four bytes (too short to carry family + port).
/// 2. Copies at most 128 bytes of the address out of the child and duplicates
///    the child's socket into this process.
/// 3. Addresses without an `AF_INET`/`AF_INET6` port, or with port 0, are
///    bound unchanged via `bind_verbatim`.
/// 4. A port listed in `bind_deny_ports` is refused with `EACCES` when the
///    socket's protocol is TCP.
/// 5. Otherwise the bind is attempted on the previously recorded real port for
///    this virtual port, or on the virtual port itself. On `EADDRINUSE` the
///    bind is retried with port 0 and the resulting port is recorded as the
///    real port for the requested virtual one.
///
/// Any other `bind` failure is reported to the child with its errno.
pub(crate) async fn handle_bind(
    notif: &SeccompNotif,
    network: &Arc<Mutex<NetworkState>>,
    notif_fd: RawFd,
) -> NotifAction {
    // Upper bound on how much of the child's sockaddr is copied.
    const MAX_SOCKADDR_LEN: usize = 128;

    let sockfd = notif.data.args[0] as i32;
    let addr_ptr = notif.data.args[1];
    let addr_len = notif.data.args[2] as usize;
    if addr_ptr == 0 || addr_len < 4 {
        return NotifAction::Continue;
    }

    let read_len = addr_len.min(MAX_SOCKADDR_LEN);
    let mut bytes = match read_child_mem(notif_fd, notif.id, notif.pid, addr_ptr, read_len) {
        // A short read would leave `bind` reading past the end of the buffer,
        // so treat it like a failed read. NOTE(review): whether read_child_mem
        // can return fewer bytes than requested is not visible here.
        Ok(b) if b.len() >= read_len => b,
        _ => return NotifAction::Errno(libc::EIO),
    };

    let dup_fd = match crate::seccomp::notif::dup_fd_from_pid(notif.pid, sockfd) {
        Ok(fd) => fd,
        Err(e) => return NotifAction::Errno(e.raw_os_error().unwrap_or(libc::EBADF)),
    };

    // The child claimed more bytes than were copied. The kernel rejects
    // over-long socket addresses with EINVAL; answer the same way here instead
    // of handing `bind` a length larger than the local buffer.
    if addr_len > bytes.len() {
        return NotifAction::Errno(libc::EINVAL);
    }

    let virtual_port = match extract_port(&bytes) {
        Some(p) if p != 0 => p,
        _ => return bind_verbatim(&dup_fd, &bytes, addr_len),
    };

    // Read both pieces of shared state under a single lock acquisition.
    let (denied, cached_real) = {
        let ns = network.lock().await;
        (
            ns.bind_deny_ports.contains(&virtual_port),
            ns.port_map.get_real(virtual_port),
        )
    };
    if denied
        && crate::network::query_socket_protocol(dup_fd.as_raw_fd())
            == Some(crate::network::Protocol::Tcp)
    {
        return NotifAction::Errno(libc::EACCES);
    }

    // Prefer the real port recorded earlier for this virtual port.
    let attempt_port = cached_real.unwrap_or(virtual_port);
    set_port_in_sockaddr(&mut bytes, attempt_port);
    match bind_within_buffer(dup_fd.as_raw_fd(), &bytes, addr_len) {
        Ok(()) => {
            if cached_real.is_none() {
                network
                    .lock()
                    .await
                    .port_map
                    .record_bind(virtual_port, virtual_port);
            }
            return NotifAction::ReturnValue(0);
        }
        // Port taken: fall through to the port-0 retry below.
        Err(libc::EADDRINUSE) => {}
        Err(errno) => return NotifAction::Errno(errno),
    }

    // Port 0 asks the kernel to choose a free port.
    set_port_in_sockaddr(&mut bytes, 0);
    if let Err(errno) = bind_within_buffer(dup_fd.as_raw_fd(), &bytes, addr_len) {
        return NotifAction::Errno(errno);
    }
    let real_port = match query_local_port(&dup_fd) {
        Some(p) => p,
        None => return NotifAction::Errno(libc::EIO),
    };
    network
        .lock()
        .await
        .port_map
        .record_bind(virtual_port, real_port);
    NotifAction::ReturnValue(0)
}

/// Calls `bind(2)` on `fd` with the first `len` bytes of `addr`.
///
/// Returns `Err(EINVAL)` without calling into libc when `len` does not fit in
/// `addr`; otherwise returns the errno of a failed `bind`.
fn bind_within_buffer(fd: RawFd, addr: &[u8], len: usize) -> Result<(), i32> {
    if len > addr.len() {
        return Err(libc::EINVAL);
    }
    let sock_len = libc::socklen_t::try_from(len).map_err(|_| libc::EINVAL)?;
    // SAFETY: `addr` is a live slice and `len <= addr.len()` was checked above,
    // so `bind` only reads bytes that belong to the slice.
    let ret = unsafe { libc::bind(fd, addr.as_ptr().cast::<libc::sockaddr>(), sock_len) };
    if ret == 0 {
        Ok(())
    } else {
        Err(std::io::Error::last_os_error()
            .raw_os_error()
            .unwrap_or(libc::EIO))
    }
}
/// Binds `fd` to the first `len` bytes of `addr` without rewriting anything.
///
/// Returns `ReturnValue(0)` on success and `Errno` with the failing errno
/// otherwise. A `len` larger than `addr` is answered with `EINVAL` instead of
/// letting `bind` read past the end of the slice.
fn bind_verbatim(fd: &OwnedFd, addr: &[u8], len: usize) -> NotifAction {
    if len > addr.len() {
        return NotifAction::Errno(libc::EINVAL);
    }
    // SAFETY: `addr` is a live slice and `len <= addr.len()` was checked above,
    // so `bind` only reads bytes that belong to the slice.
    let ret = unsafe {
        libc::bind(
            fd.as_raw_fd(),
            addr.as_ptr().cast::<libc::sockaddr>(),
            len as libc::socklen_t,
        )
    };
    if ret == 0 {
        NotifAction::ReturnValue(0)
    } else {
        NotifAction::Errno(
            std::io::Error::last_os_error()
                .raw_os_error()
                .unwrap_or(libc::EIO),
        )
    }
}
/// Asks the kernel which local port `fd` is bound to.
///
/// Returns `None` when `getsockname` fails or the socket's address is not an
/// `AF_INET`/`AF_INET6` address (as decided by `extract_port`).
fn query_local_port(fd: &OwnedFd) -> Option<u16> {
    let capacity = std::mem::size_of::<libc::sockaddr_storage>();
    // SAFETY: `sockaddr_storage` is plain data for which all-zero is valid.
    let mut storage: libc::sockaddr_storage = unsafe { std::mem::zeroed() };
    let mut len = capacity as libc::socklen_t;
    // SAFETY: `storage` is writable for `capacity` bytes and `len` holds
    // exactly that capacity, so the kernel cannot write past it.
    let ret = unsafe {
        libc::getsockname(
            fd.as_raw_fd(),
            std::ptr::addr_of_mut!(storage).cast::<libc::sockaddr>(),
            &mut len,
        )
    };
    if ret != 0 {
        return None;
    }
    // getsockname may report a length larger than the buffer it was given
    // (truncated address); never view more bytes than `storage` holds.
    let filled = (len as usize).min(capacity);
    // SAFETY: `filled <= size_of::<sockaddr_storage>()`, and `storage` is a
    // fully initialised local that outlives the slice.
    let bytes =
        unsafe { std::slice::from_raw_parts(std::ptr::addr_of!(storage).cast::<u8>(), filled) };
    extract_port(bytes)
}
/// Services an intercepted `getsockname` call, reporting the virtual port.
///
/// The real socket address is queried on a duplicate of the child's socket.
/// If its port has a virtual alias in the port map, the alias is substituted
/// before the address is copied back. The address is truncated to the length
/// the child supplied, and the full address length is then written to the
/// child's `addrlen`.
///
/// Errors: `EFAULT` for null pointers or failed child-memory access, `EINVAL`
/// for a negative `addrlen`, the duplication error (default `EBADF`) when the
/// descriptor cannot be duplicated, or the errno of a failed `getsockname`.
pub(crate) async fn handle_getsockname(
    notif: &SeccompNotif,
    network: &Arc<Mutex<NetworkState>>,
    notif_fd: RawFd,
) -> NotifAction {
    let sockfd = notif.data.args[0] as i32;
    let addr_ptr = notif.data.args[1];
    let addrlen_ptr = notif.data.args[2];
    if addr_ptr == 0 || addrlen_ptr == 0 {
        return NotifAction::Errno(libc::EFAULT);
    }

    let addrlen_bytes = match read_child_mem(notif_fd, notif.id, notif.pid, addrlen_ptr, 4) {
        Ok(b) if b.len() >= 4 => b,
        _ => return NotifAction::Errno(libc::EFAULT),
    };
    let mut raw_len = [0u8; 4];
    raw_len.copy_from_slice(&addrlen_bytes[..4]);
    // Interpreted as signed so that a negative length can be rejected below
    // instead of being mistaken for a very large buffer.
    let requested_len = i32::from_ne_bytes(raw_len);

    let dup_fd = match crate::seccomp::notif::dup_fd_from_pid(notif.pid, sockfd) {
        Ok(fd) => fd,
        Err(e) => return NotifAction::Errno(e.raw_os_error().unwrap_or(libc::EBADF)),
    };

    let storage_len = std::mem::size_of::<libc::sockaddr_storage>();
    // SAFETY: `sockaddr_storage` is plain data for which all-zero is valid.
    let mut storage: libc::sockaddr_storage = unsafe { std::mem::zeroed() };
    // Query into the full local buffer; truncation to the child's length is
    // applied afterwards, which yields the same leading bytes.
    let mut actual_len = storage_len as libc::socklen_t;
    // SAFETY: `storage` is writable for `storage_len` bytes and `actual_len`
    // holds exactly that capacity.
    let ret = unsafe {
        libc::getsockname(
            dup_fd.as_raw_fd(),
            std::ptr::addr_of_mut!(storage).cast::<libc::sockaddr>(),
            &mut actual_len,
        )
    };
    if ret != 0 {
        return NotifAction::Errno(
            std::io::Error::last_os_error()
                .raw_os_error()
                .unwrap_or(libc::EIO),
        );
    }

    // Checked only after the socket lookup succeeded, so descriptor errors
    // keep precedence over a bad length.
    let addr_len = match usize::try_from(requested_len) {
        Ok(n) => n,
        Err(_) => return NotifAction::Errno(libc::EINVAL),
    };

    let to_write = addr_len.min(actual_len as usize).min(storage_len);
    // SAFETY: `to_write <= storage_len` and `storage` is a fully initialised
    // local; the bytes are copied out immediately.
    let mut bytes =
        unsafe { std::slice::from_raw_parts(std::ptr::addr_of!(storage).cast::<u8>(), to_write) }
            .to_vec();

    // Present the virtual port to the child when this real port was remapped.
    if let Some(real_port) = extract_port(&bytes) {
        let ns = network.lock().await;
        if let Some(virtual_port) = ns.port_map.get_virtual(real_port) {
            set_port_in_sockaddr(&mut bytes, virtual_port);
        }
    }

    if !bytes.is_empty()
        && write_child_mem(notif_fd, notif.id, notif.pid, addr_ptr, &bytes).is_err()
    {
        return NotifAction::Errno(libc::EFAULT);
    }
    // Report the untruncated address length back to the child.
    let actual = (actual_len as u32).to_ne_bytes();
    if write_child_mem(notif_fd, notif.id, notif.pid, addrlen_ptr, &actual).is_err() {
        return NotifAction::Errno(libc::EFAULT);
    }
    NotifAction::ReturnValue(0)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a zero-filled buffer of `len` bytes whose first two bytes hold
    /// `family` (native endian) and whose next two hold `port` (big endian).
    fn sockaddr_bytes(family: u16, port: u16, len: usize) -> Vec<u8> {
        let mut buf = vec![0u8; len];
        buf[..2].copy_from_slice(&family.to_ne_bytes());
        buf[2..4].copy_from_slice(&port.to_be_bytes());
        buf
    }

    #[test]
    fn test_port_map_identity() {
        let mut pm = PortMap::new();
        pm.record_bind(8080, 8080);

        assert!(pm.bound_ports.contains(&8080));
        // An identity bind must not create translation entries.
        assert_eq!(pm.get_real(8080), None);
        assert_eq!(pm.get_virtual(8080), None);
    }

    #[test]
    fn test_port_map_remap() {
        let mut pm = PortMap::new();
        pm.record_bind(8080, 9090);

        assert!(pm.bound_ports.contains(&9090));
        assert_eq!(pm.get_real(8080), Some(9090));
        assert_eq!(pm.get_virtual(9090), Some(8080));
    }

    #[test]
    fn test_extract_port_ipv4() {
        // 0x1F90 == 8080
        let buf = sockaddr_bytes(AF_INET as u16, 0x1F90, 16);
        assert_eq!(extract_port(&buf), Some(8080));
    }

    #[test]
    fn test_extract_port_ipv6() {
        // 0x0050 == 80
        let buf = sockaddr_bytes(AF_INET6 as u16, 0x0050, 28);
        assert_eq!(extract_port(&buf), Some(80));
    }

    #[test]
    fn test_extract_port_unix() {
        let buf = [1u8, 0, 0, 0, 0, 0, 0, 0];
        assert_eq!(extract_port(&buf), None);
    }

    #[test]
    fn test_extract_port_too_short() {
        let buf = [2u8, 0];
        assert_eq!(extract_port(&buf), None);
    }

    #[test]
    fn test_set_port_in_sockaddr() {
        let mut buf = vec![0u8; 16];
        set_port_in_sockaddr(&mut buf, 443);

        // 443 == 0x01BB, stored big endian.
        assert_eq!(buf[2], 0x01);
        assert_eq!(buf[3], 0xBB);
    }
}