#![forbid(unsafe_code)]
use std::os::fd::RawFd;
use libc::c_int;
use libseccomp::ScmpNotifResp;
use nix::{errno::Errno, sys::socket::SockFlag};
use crate::{
cache::UnixVal,
compat::{AddressFamily, SockType, AF_MAX, SOCK_TYPE_MASK},
confine::is_valid_ptr,
cookie::{safe_socket, safe_socketpair},
ip::SocketCall,
kernel::net::sandbox_addr_unnamed,
req::UNotifyEventRequest,
sandbox::{Flags, NetlinkFamily, Options, SandboxGuard},
};
pub(crate) fn handle_socket(
request: &UNotifyEventRequest,
args: &[u64; 6],
flags: Flags,
options: Options,
netlink_families: NetlinkFamily,
) -> Result<ScmpNotifResp, Errno> {
let allow_unsafe_socket = options.allow_unsafe_socket();
let allow_unsupp_socket = options.allow_unsupp_socket();
let allow_unsafe_kcapi = options.allow_unsafe_kcapi();
let force_cloexec = flags.force_cloexec();
let force_rand_fd = flags.force_rand_fd();
#[expect(clippy::cast_possible_truncation)]
let stype = args[1] as c_int;
let sflag = SockFlag::from_bits(stype & !SOCK_TYPE_MASK).ok_or(Errno::EINVAL)?;
#[expect(clippy::cast_possible_truncation)]
let domain = AddressFamily::from_raw(args[0] as c_int);
if !(0..AF_MAX).contains(&domain.as_raw()) {
return Err(Errno::EAFNOSUPPORT);
}
let stype = match SockType::try_from(stype) {
Err(Errno::EINVAL) => return Err(Errno::EINVAL),
_ if domain == AddressFamily::Unspec => return Err(Errno::EAFNOSUPPORT),
Err(errno) => return Err(errno),
Ok(stype) => stype,
};
let stype = if domain == AddressFamily::Unix && stype == SockType::Raw {
SockType::Datagram
} else {
stype
};
#[expect(clippy::cast_possible_truncation)]
let proto = args[2] as c_int;
if !allow_unsupp_socket {
match domain {
AddressFamily::Unix | AddressFamily::Inet | AddressFamily::Inet6 => {}
AddressFamily::Alg if allow_unsafe_kcapi => {}
AddressFamily::Netlink => {
#[expect(clippy::cast_possible_truncation)]
let nlfam = args[2] as i32;
if !(0..=NetlinkFamily::max()).contains(&nlfam) {
return Err(Errno::EPROTONOSUPPORT);
}
let nlfam = NetlinkFamily::from_bits(1 << nlfam).ok_or(Errno::EPROTONOSUPPORT)?;
if !netlink_families.contains(nlfam) {
return Err(Errno::EPROTONOSUPPORT);
}
}
AddressFamily::Packet if !allow_unsafe_socket => return Err(Errno::EACCES),
AddressFamily::Packet => {}
_ => return Err(Errno::EAFNOSUPPORT),
}
} else if !allow_unsafe_kcapi && domain == AddressFamily::Alg {
return Err(Errno::EAFNOSUPPORT);
} else if !allow_unsafe_socket
&& (domain == AddressFamily::Packet
|| (domain != AddressFamily::Netlink && stype.is_unsafe()))
{
return Err(Errno::EACCES);
} else {
}
let cloexec = force_cloexec || sflag.contains(SockFlag::SOCK_CLOEXEC);
let sflag = sflag | SockFlag::SOCK_CLOEXEC;
let req = request.scmpreq;
request.cache.add_sys_block(req, false)?;
let result = safe_socket(domain, stype, sflag, proto);
request.cache.del_sys_block(req.id)?;
let fd = result?;
request.send_fd(fd, cloexec, force_rand_fd)
}
/// Emulate `socketpair(2)` on behalf of the sandboxed process.
///
/// Processing steps:
/// 1. Copy the needed flags and options out of `sandbox`.
/// 2. Validate the raw arguments.
/// 3. For UNIX pairs, run the `sandbox_addr_unnamed` access check.
/// 4. Drop the guard before creating the pair.
/// 5. Install both descriptors into the caller and write them to the
///    two-element descriptor array pointed to by `args[3]`.
///
/// # Errors
///
/// - `EINVAL` for invalid flag bits, or when `SockType::try_from` rejects the
///   type with `EINVAL`.
/// - `EAFNOSUPPORT` for an out-of-range family or `AF_UNSPEC`.
/// - `EPROTONOSUPPORT` for a UNIX pair whose protocol is neither 0 nor
///   `AF_UNIX`.
/// - `EOPNOTSUPP` for `AF_TIPC` unless unsupported sockets are allowed.
/// - `EFAULT` when `args[3]` is not a valid pointer for the caller's arch.
/// - Any error from the access check, socket creation, descriptor
///   installation, or writing the caller's memory.
pub(crate) fn handle_socketpair(
    request: &UNotifyEventRequest,
    sandbox: SandboxGuard,
    args: &[u64; 6],
    call: SocketCall,
) -> Result<ScmpNotifResp, Errno> {
    // Copy the settings needed later out of the sandbox guard.
    let (force_cloexec, force_rand_fd) = {
        let flags = *sandbox.flags;
        (flags.force_cloexec(), flags.force_rand_fd())
    };
    let allow_unsupp_socket = {
        let options = *sandbox.options;
        options.allow_unsupp_socket()
    };

    // Bits outside SOCK_TYPE_MASK must form a valid SockFlag set.
    #[expect(clippy::cast_possible_truncation)]
    let raw_type = args[1] as c_int;
    let sflag = SockFlag::from_bits(raw_type & !SOCK_TYPE_MASK).ok_or(Errno::EINVAL)?;

    #[expect(clippy::cast_possible_truncation)]
    let domain = AddressFamily::from_raw(args[0] as c_int);
    let raw_domain = domain.as_raw();
    if raw_domain < 0 || raw_domain >= AF_MAX {
        return Err(Errno::EAFNOSUPPORT);
    }

    // Precedence: EINVAL type error, then AF_UNSPEC, then other type errors.
    let stype = match (SockType::try_from(raw_type), domain == AddressFamily::Unspec) {
        (Err(Errno::EINVAL), _) => return Err(Errno::EINVAL),
        (_, true) => return Err(Errno::EAFNOSUPPORT),
        (Err(errno), false) => return Err(errno),
        (Ok(stype), false) => stype,
    };

    #[expect(clippy::cast_possible_truncation)]
    let proto = args[2] as c_int;

    // Linux treats SOCK_RAW UNIX sockets as SOCK_DGRAM.
    let stype = if domain != AddressFamily::Unix || stype != SockType::Raw {
        stype
    } else {
        SockType::Datagram
    };

    if domain == AddressFamily::Unix {
        if proto != 0 && proto != libc::AF_UNIX {
            return Err(Errno::EPROTONOSUPPORT);
        }
        // Only UNIX pairs go through the unnamed-address access check.
        sandbox_addr_unnamed(request, &sandbox, call)?;
    } else if domain == AddressFamily::Tipc && !allow_unsupp_socket {
        return Err(Errno::EOPNOTSUPP);
    }
    drop(sandbox);

    let fdptr = args[3];
    if !is_valid_ptr(fdptr, request.scmpreq.data.arch) {
        return Err(Errno::EFAULT);
    }

    // Honour the caller's close-on-exec request when installing the
    // descriptors. Syd's own copies always get SOCK_CLOEXEC.
    let cloexec = force_cloexec || sflag.contains(SockFlag::SOCK_CLOEXEC);
    let sflag = sflag | SockFlag::SOCK_CLOEXEC;

    // The pair result is only inspected after del_sys_block has run.
    let req = request.scmpreq;
    request.cache.add_sys_block(req, false)?;
    let pair = safe_socketpair(domain, stype, proto, sflag);
    request.cache.del_sys_block(req.id)?;
    let (fd0, fd1) = pair?;

    const FD_SIZE: usize = size_of::<RawFd>();

    // Probe the destination with zeros so an unwritable array fails before
    // any descriptor is handed to the caller.
    request.write_mem_all(&[0u8; 2 * FD_SIZE], fdptr)?;

    if domain == AddressFamily::Unix {
        // Best-effort bookkeeping: failures here are deliberately ignored.
        let pid = request.scmpreq.pid();
        for fd in [&fd0, &fd1] {
            let _ = request.add_unix(fd, pid, UnixVal::default());
        }
    }

    // NOTE(review): if installing fd1 fails, fd0 has presumably already been
    // installed in the caller and is not reclaimed here.
    let newfd0 = request.add_fd(fd0, cloexec, force_rand_fd)?;
    let newfd1 = request.add_fd(fd1, cloexec, force_rand_fd)?;

    let mut fds = [0u8; 2 * FD_SIZE];
    let (first, second) = fds.split_at_mut(FD_SIZE);
    first.copy_from_slice(&newfd0.to_ne_bytes());
    second.copy_from_slice(&newfd1.to_ne_bytes());
    request.write_mem_all(&fds, fdptr)?;

    Ok(request.return_syscall(0))
}