use ahash::{HashMap, HashMapExt};
use core::mem::MaybeUninit;
use lever::sync::prelude::*;
use pin_utils::unsafe_pinned;
use std::future::Future;
use std::io;
use std::os::unix::io::AsRawFd;
use std::pin::Pin;
use std::sync::atomic::{AtomicU64, Ordering};
use std::task::{Context, Poll};
use std::time::Duration;
/// Calls `libc::$fn` with the argument list `$args` inside `unsafe`.
///
/// A return value of `-1` becomes `Err(std::io::Error::last_os_error())`; any other
/// value is passed through unchanged as `Ok(value)`.
macro_rules! syscall {
    ($fn:ident $args:tt) => {{
        match unsafe { libc::$fn $args } {
            -1 => Err(std::io::Error::last_os_error()),
            ret => Ok(ret),
        }
    }};
}
use crate::config::NucleiConfig;
use crossbeam_channel::{unbounded, Receiver, Sender};
use rustix_uring::cqueue::{more, sock_nonempty};
use rustix_uring::{
cqueue::Entry as CQEntry, squeue::Entry as SQEntry, CompletionQueue, IoUring, SubmissionQueue,
Submitter,
};
use socket2::SockAddr;
use std::mem;
use std::os::unix::net::SocketAddr as UnixSocketAddr;
fn max_len() -> usize {
if cfg!(target_os = "macos") {
<libc::c_int>::max_value() as usize - 1
} else {
<libc::ssize_t>::max_value() as usize
}
}
pub(crate) fn shim_recv_from<A: AsRawFd>(
fd: A,
buf: &mut [u8],
flags: libc::c_int,
) -> io::Result<(usize, SockAddr)> {
let mut storage: libc::sockaddr_storage = unsafe { mem::zeroed() };
let mut addrlen = mem::size_of_val(&storage) as libc::socklen_t;
let n = syscall!(recvfrom(
fd.as_raw_fd() as _,
buf.as_mut_ptr() as *mut libc::c_void,
std::cmp::min(buf.len(), max_len()),
flags,
&mut storage as *mut _ as *mut _,
&mut addrlen,
))?;
let addr = unsafe { SockAddr::from_raw_parts(&storage as *const _ as *const _, addrlen) };
Ok((n as usize, addr))
}
/// Stand-in for the private fields of `std::os::unix::net::SocketAddr`.
///
/// `shim_to_af_unix` builds one of these and `mem::transmute`s it into std's type.
/// NOTE(review): the transmute only checks that the sizes match; it presumably
/// relies on std's private struct having the same two fields in the same order.
/// Neither type is `#[repr(C)]`, so do not reorder or change these fields.
struct FakeUnixSocketAddr {
    /// Raw Unix-domain address bytes.
    addr: libc::sockaddr_un,
    /// Number of meaningful bytes in `addr`.
    len: libc::socklen_t,
}
/// Converts a [`SockAddr`] holding a Unix-domain address into std's [`UnixSocketAddr`].
///
/// Length normalisation:
/// * a length shorter than the offset of `sun_path` (including `0`) carries no path
///   and is treated as an unnamed address;
/// * an address whose path bytes are all NUL is also treated as unnamed;
/// * the length is clamped to `size_of::<sockaddr_un>()`, so neither the copy below
///   nor std's path accessors can read past the structure.
///
/// # Errors
///
/// Returns [`io::ErrorKind::InvalidData`] if an address that is long enough to carry a
/// path does not have family `AF_UNIX`.
pub(crate) fn shim_to_af_unix(sockaddr: &SockAddr) -> io::Result<UnixSocketAddr> {
    // SAFETY: NOTE(review) assumes the storage behind `SockAddr` is at least
    // `size_of::<sockaddr_un>()` bytes and suitably aligned, which holds for a
    // `sockaddr_storage`-backed address. Confirm for the socket2 version in use.
    let addr = unsafe { &*(sockaddr.as_ptr() as *const libc::sockaddr_un) };

    // Measure from the struct itself (not from the local reference variable) so the
    // offset is correct on every platform's `sockaddr_un` layout.
    let sun_path_offset =
        addr.sun_path.as_ptr() as usize - (addr as *const libc::sockaddr_un as usize);

    let mut len = sockaddr.len();
    if (len as usize) < sun_path_offset {
        // Too short to hold any path, and the family may not even have been written,
        // so check this before inspecting `sun_family`.
        len = sun_path_offset as libc::socklen_t;
    } else {
        if addr.sun_family != libc::AF_UNIX as libc::sa_family_t {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "socket is not AF_UNIX type",
            ));
        }
        let path_len = (len as usize)
            .saturating_sub(sun_path_offset)
            .min(addr.sun_path.len());
        if path_len > 0 && addr.sun_path[..path_len].iter().all(|&c| c == 0) {
            // A path made only of NUL bytes names nothing: report it as unnamed.
            len = sun_path_offset as libc::socklen_t;
        }
    }

    // Never claim (or copy) more bytes than a `sockaddr_un` can hold.
    let copy_len = (len as usize).min(mem::size_of::<libc::sockaddr_un>());

    let raw = unsafe {
        let mut init = MaybeUninit::<libc::sockaddr_un>::zeroed();
        // Copy as bytes: `as_ptr()` yields `*const sockaddr`, and a typed copy would
        // move `count * size_of::<sockaddr>()` bytes instead of `count` bytes.
        std::ptr::copy_nonoverlapping(
            sockaddr.as_ptr() as *const u8,
            init.as_mut_ptr() as *mut u8,
            copy_len,
        );
        init.assume_init()
    };

    // SAFETY: NOTE(review) relies on `FakeUnixSocketAddr` matching the private layout
    // of std's `SocketAddr`; `transmute` only guarantees the sizes are equal.
    let unix_addr: UnixSocketAddr = unsafe {
        mem::transmute(FakeUnixSocketAddr {
            addr: raw,
            len: copy_len as libc::socklen_t,
        })
    };
    Ok(unix_addr)
}
/// io_uring-backed proactor.
///
/// Holds the queues and submitter split from the ring stored in `IO_URING`, and routes
/// each completion back to the channel registered for its `user_data` id.
/// NOTE: fields drop in declaration order; the queue halves borrow the ring in
/// `IO_URING`, so keep that in mind before reordering.
pub struct SysProactor {
    /// Submission queue; locked while `register_io` pushes and syncs an entry.
    pub(crate) sq: TTas<SubmissionQueue<'static>>,
    /// Completion queue; `wait` holds this lock while draining completions.
    pub(crate) cq: TTas<CompletionQueue<'static>>,
    /// Issues submit / submit-and-wait / register calls against the ring.
    sbmt: Submitter<'static>,
    /// Senders for in-flight submissions, keyed by the `user_data` id set on their SQE.
    submitters: TTas<HashMap<u64, Sender<i32>>>,
    /// Counter used to hand out a fresh `user_data` id per `register_io` call.
    submitter_id: AtomicU64,
    /// When true, `wait` skips `submit_and_wait` and only re-reads the completion queue.
    aggressive_poll: bool,
}
/// The three `'static` halves of a split `IoUring`.
///
/// Note the order here is `(SubmissionQueue, CompletionQueue, Submitter)`, which differs
/// from the `(Submitter, SubmissionQueue, CompletionQueue)` order produced by `split()`.
/// NOTE(review): not referenced in this part of the file — confirm it is still used.
pub type RingTypes = (
    SubmissionQueue<'static>,
    CompletionQueue<'static>,
    Submitter<'static>,
);
/// Storage for the io_uring instance created by `SysProactor::new`.
///
/// `SysProactor::new` splits this ring into queues and a submitter with `'static`
/// lifetimes, so the value must not be replaced or dropped while a `SysProactor` built
/// from it is still alive.
pub(crate) static mut IO_URING: Option<IoUring> = None;
impl SysProactor {
pub(crate) fn new(config: NucleiConfig) -> io::Result<SysProactor> {
unsafe {
let mut rb = IoUring::builder();
config
.iouring
.sqpoll_wake_interval
.map(|e| rb.setup_sqpoll(e));
if config.iouring.iopoll_enabled {
rb.setup_iopoll();
}
let ring = rb
.build(config.iouring.queue_len)
.expect("nuclei: uring can't be initialized");
IO_URING = Some(ring);
let (sbmt, sq, cq) = IO_URING.as_mut().unwrap().split();
match (
config.iouring.per_numa_bounded_worker_count,
config.iouring.per_numa_unbounded_worker_count,
) {
(Some(bw), Some(ubw)) => sbmt.register_iowq_max_workers(&mut [bw, ubw])?,
(None, Some(ubw)) => sbmt.register_iowq_max_workers(&mut [0, ubw])?,
(Some(bw), None) => sbmt.register_iowq_max_workers(&mut [bw, 0])?,
(None, None) => sbmt.register_iowq_max_workers(&mut [0, 0])?,
}
Ok(SysProactor {
sq: TTas::new(sq),
cq: TTas::new(cq),
sbmt,
submitters: TTas::new(HashMap::with_capacity(config.iouring.queue_len as usize)),
submitter_id: AtomicU64::default(),
aggressive_poll: config.iouring.aggressive_poll,
})
}
}
pub(crate) fn register_files_sparse(&self, n: u32) -> io::Result<()> {
Ok(self.sbmt.register_files_sparse(n)?)
}
pub(crate) fn register_io(&self, mut sqe: SQEntry) -> io::Result<CompletionChan> {
let id = self.submitter_id.fetch_add(1, Ordering::Relaxed);
let (tx, rx) = unbounded::<i32>();
sqe = sqe.user_data(id);
let mut subguard = self.submitters.lock();
subguard.insert(id, tx);
drop(subguard);
let mut sq = self.sq.lock();
unsafe {
sq.push(&sqe).expect("nuclei: submission queue is full");
}
sq.sync();
drop(sq);
self.sbmt.submit()?;
Ok(CompletionChan { rx })
}
pub(crate) fn wake(&self) -> io::Result<()> {
Ok(())
}
pub(crate) fn wait(
&self,
max_event_size: usize,
duration: Option<Duration>,
) -> io::Result<usize> {
let mut cq = self.cq.lock();
let mut acc: usize = 0;
'sock: loop {
if !self.aggressive_poll {
self.sbmt.submit_and_wait(1)?;
}
cq.sync();
for cqe in cq.by_ref() {
if more(cqe.flags()) {
self.cqe_completion_multi(&cqe)?;
} else {
self.cqe_completion_single(&cqe)?;
}
acc += 1;
if !sock_nonempty(cqe.flags()) || !more(cqe.flags()) {
break 'sock;
}
}
}
Ok(acc)
}
fn cqe_completion_multi(&self, cqe: &CQEntry) -> io::Result<()> {
let udata = cqe.user_data();
let res: i32 = cqe.result();
let sbmts = self.submitters.lock();
sbmts.get(&udata).map(|s| s.send(res));
Ok(())
}
fn cqe_completion_single(&self, cqe: &CQEntry) -> io::Result<()> {
let udata = cqe.user_data();
let res: i32 = cqe.result();
let mut sbmts = self.submitters.lock();
let x = sbmts.remove(&udata);
x.map(|s| s.send(res));
Ok(())
}
}
/// Handle returned by `SysProactor::register_io` for one submission.
///
/// It receives the raw `i32` result(s) that the proactor forwards from that
/// submission's completion entries.
#[derive(Clone)]
pub(crate) struct CompletionChan {
    /// Receiving end of the channel whose sender `register_io` stores in
    /// `SysProactor::submitters` under the submission's `user_data` id.
    rx: Receiver<i32>,
}
impl CompletionChan {
pub fn get_rx(&self) -> Receiver<i32> {
self.rx.clone()
}
}
impl CompletionChan {
    // Generates a private pin-projection method `rx(self: Pin<&mut Self>)` returning
    // `Pin<&mut Receiver<i32>>`, which `poll` uses.
    // NOTE(review): `unsafe_pinned!` is sound only if `rx` is never moved out of a
    // pinned `CompletionChan` (e.g. by a `Drop` impl). Nothing in this file does that,
    // but re-check if `Drop`/`Unpin` impls are added.
    unsafe_pinned!(rx: Receiver<i32>);
}
impl Future for CompletionChan {
    type Output = io::Result<i32>;

    /// Resolves with the raw result forwarded for this submission.
    ///
    /// If every sender has gone away, resolves to an error of kind `TimedOut`.
    /// NOTE(review): `recv()` blocks the polling thread until a value arrives; this
    /// future never returns `Poll::Pending` and never registers the waker.
    fn poll(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<Self::Output> {
        match self.rx().recv() {
            Ok(value) => Poll::Ready(Ok(value)),
            Err(_) => Poll::Ready(Err(io::Error::new(
                io::ErrorKind::TimedOut,
                "sender has been cancelled",
            ))),
        }
    }
}