#![cfg(feature = "io-uring")]
mod cqe_processor;
mod eventfd_poller;
mod external_op_tracker;
mod handler_manager;
mod internal_op_tracker;
mod main_loop;
mod multishot_reader;
mod sqe_builder;
use crate::io_uring_backend::buffer_manager::BufferRingManager;
use crate::io_uring_backend::connection_handler::{
HandlerSqeBlueprint, HandlerUpstreamEvent, ProtocolHandlerFactory, WorkerIoConfig
};
use crate::io_uring_backend::ops::UringOpRequest;
use crate::io_uring_backend::send_buffer_pool::SendBufferPool;
use crate::io_uring_backend::signaling_op_sender::SignalingOpSender;
use crate::io_uring_backend::UserData;
use crate::uring::{global_state, UringConfig};
use crate::{Msg, ZmqError};
use std::collections::{HashMap, VecDeque};
use std::fmt;
use std::mem;
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::os::unix::io::{AsRawFd, RawFd};
use std::sync::Arc;
use fibre::mpmc::{unbounded, AsyncSender, Sender as SyncSender, Receiver as SyncReceiver};
use fibre::mpsc;
use io_uring::opcode;
use io_uring::IoUring;
use tracing::{debug, error, info, trace, warn};
pub(crate) use eventfd_poller::EventFdPoller;
pub(crate) use external_op_tracker::{ExternalOpContext, ExternalOpTracker};
pub(crate) use handler_manager::HandlerManager;
pub(crate) use internal_op_tracker::{InternalOpPayload, InternalOpTracker, InternalOpType};
pub(crate) use multishot_reader::MultishotReader;
/// Work queued for a single file descriptor, stored per fd in
/// `UringWorker::work_map`.
#[derive(Debug, Default)]
struct FdWork {
    /// Queue of SQE blueprint batches.
    // NOTE(review): the name suggests each inner `Vec` must be submitted as
    // one unit — confirm against the code that drains this queue.
    pub(crate) atomic_batches: VecDeque<Vec<HandlerSqeBlueprint>>,
    /// Queue of shared message vectors.
    // NOTE(review): presumably application payloads waiting to be handed to
    // the connection handler — confirm against the consumer.
    pub(crate) app_data: VecDeque<Arc<Vec<Msg>>>,
}
/// Lifecycle state of a [`UringWorker`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum WorkerState {
    /// Initial state assigned when the worker is constructed.
    Running,
    /// Entered through `UringWorker::transition_to_draining`, which also
    /// submits cancellations for in-flight internal operations.
    Draining,
    // NOTE(review): presumably the phase after draining in which remaining
    // resources are released — the transition is not in this file section.
    CleaningUp,
    // NOTE(review): presumably the terminal state — confirm in the main loop.
    Stopped,
}
/// State owned by the dedicated io_uring worker thread.
///
/// An instance is built inside the thread started by
/// `UringWorker::spawn_with_config` and then driven by
/// `main_loop::run_worker_loop`.
// Field order is kept as-is on purpose: it determines drop order.
pub struct UringWorker {
    /// Current lifecycle state; starts as `WorkerState::Running`.
    pub(crate) state: WorkerState,
    /// The io_uring instance this worker submits to.
    ring: IoUring,
    /// Receiving end of the unbounded operation-request channel whose sending
    /// side is handed out wrapped in a `SignalingOpSender`.
    op_rx: SyncReceiver<UringOpRequest>,
    /// Pending work, keyed by file descriptor.
    pub(crate) work_map: HashMap<RawFd, FdWork>,
    /// Default provided-buffer receive ring; `None` when it is not configured
    /// or failed to initialise.
    buffer_manager: Option<BufferRingManager>,
    /// Built from the protocol handler factories and the shared I/O config.
    handler_manager: HandlerManager,
    // NOTE(review): presumably tracks operations requested through `op_rx`
    // that are still in flight — confirm in `external_op_tracker`.
    external_op_tracker: ExternalOpTracker,
    /// Tracks worker-internal operations; its keys are the user-data values
    /// cancelled when the worker starts draining.
    internal_op_tracker: InternalOpTracker,
    /// Shared I/O configuration carrying the upstream event sender.
    worker_io_config: Arc<WorkerIoConfig>,
    /// `Some(0)` when the default buffer ring was initialised, else `None`.
    default_buffer_ring_group_id_val: Option<u16>,
    // NOTE(review): presumably fds whose close has been initiated and still
    // need another processing pass — confirm in the main loop.
    fds_needing_close_initiated_pass: VecDeque<RawFd>,
    /// Poller built around the eventfd shared with the `SignalingOpSender`.
    pub(crate) event_fd_poller: EventFdPoller,
    /// Send buffer pool used for zero-copy sends; `None` when zero-copy is
    /// disabled or the pool could not be created.
    send_buffer_pool: Option<Arc<SendBufferPool>>,
    /// Per-fd bounded receivers delivering shared message vectors.
    pub(crate) fd_to_mpsc_rx: HashMap<RawFd, Arc<mpsc::BoundedReceiver<Arc<Vec<Msg>>>>>,
    /// Effective zero-copy setting after pool initialisation was attempted.
    cfg_send_zerocopy_enabled: bool,
    /// Send buffer count copied from the configuration.
    cfg_send_buffer_count: usize,
    /// Send buffer size copied from the configuration.
    cfg_send_buffer_size: usize,
}
/// Compact diagnostic view of the worker: prints summary values (fd, queue
/// lengths, presence flags) instead of dumping whole fields, and marks the
/// output as non-exhaustive.
impl fmt::Debug for UringWorker {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut view = f.debug_struct("UringWorker");

        // Ring and operation channel.
        view.field("ring_fd", &self.ring.as_raw_fd());
        view.field("op_rx_len", &self.op_rx.len());
        view.field("op_rx_is_closed", &self.op_rx.is_closed());

        // Buffers and in-flight bookkeeping.
        view.field("buffer_manager_is_some", &self.buffer_manager.is_some());
        view.field(
            "external_op_tracker_len",
            &self.external_op_tracker.in_flight.len(),
        );
        view.field(
            "internal_op_tracker_len",
            &self.internal_op_tracker.op_to_details.len(),
        );
        view.field(
            "default_buffer_ring_group_id_val",
            &self.default_buffer_ring_group_id_val,
        );

        // Wake-up plumbing and per-fd channels.
        view.field("event_fd_poller", &self.event_fd_poller);
        view.field("fd_to_mpsc_rx_len", &self.fd_to_mpsc_rx.len());

        view.finish_non_exhaustive()
    }
}
impl UringWorker {
pub fn spawn_with_config(
config: UringConfig,
factories: Vec<Arc<dyn ProtocolHandlerFactory>>,
upstream_event_tx: SyncSender<(RawFd, HandlerUpstreamEvent)>,
) -> Result<(SignalingOpSender, std::thread::JoinHandle<Result<(), ZmqError>>), ZmqError> {
let (op_tx_sync, op_rx) = unbounded::<UringOpRequest>();
let op_tx_async_for_signaler: AsyncSender<UringOpRequest> = op_tx_sync.to_async();
let worker_io_config = Arc::new(WorkerIoConfig { upstream_event_tx });
let event_fd_master_instance =
eventfd::EventFD::new(0, eventfd::EfdFlags::EFD_CLOEXEC | eventfd::EfdFlags::EFD_NONBLOCK).map_err(|e| {
error!("Failed to create master EventFD for UringWorker: {}", e);
ZmqError::Internal(format!("Master EventFD creation failed: {}", e))
})?;
let signaling_op_sender = SignalingOpSender::new(op_tx_async_for_signaler, event_fd_master_instance.clone());
let worker_thread_join_handle = std::thread::Builder::new()
.name("rzmq-io-uring-worker".into())
.spawn(move || { match IoUring::new(config.ring_entries) {
Ok(ring) => {
let mut internal_tracker = InternalOpTracker::new();
let event_fd_poller_instance = EventFdPoller::new_with_fd(
event_fd_master_instance,
&mut internal_tracker,
);
let mut worker_send_buffer_pool: Option<Arc<SendBufferPool>> = None;
let mut effective_send_zerocopy_enabled_for_worker = config.default_send_zerocopy;
if config.default_send_zerocopy {
if config.default_send_buffer_count > 0 && config.default_send_buffer_size > 0 {
match SendBufferPool::new(&ring, config.default_send_buffer_count, config.default_send_buffer_size) {
Ok(pool) => {
info!("UringWorker: SendBufferPool initialized (count: {}, size: {}). Zero-copy send enabled.", config.default_send_buffer_count, config.default_send_buffer_size);
worker_send_buffer_pool = Some(Arc::new(pool));
}
Err(e) => {
error!("UringWorker: Failed to initialize SendBufferPool from config: {}. Disabling ZC send for this worker.", e);
effective_send_zerocopy_enabled_for_worker = false;
}
}
} else {
info!("UringWorker: Zero-copy send requested by config, but pool count/size is zero. Disabling ZC send for this worker.");
effective_send_zerocopy_enabled_for_worker = false;
}
} else {
info!("UringWorker: Zero-copy send not enabled by config.");
}
let mut default_worker_buffer_manager: Option<BufferRingManager> = None;
let mut default_worker_bgid_val: Option<u16> = None;
if config.default_recv_buffer_count > 0 && config.default_recv_buffer_size > 0 {
match BufferRingManager::new(&ring, config.default_recv_buffer_count as u16, 0, config.default_recv_buffer_size) {
Ok(bm) => {
info!("UringWorker: Default BufferRingManager (bgid 0) initialized (count: {}, size: {}). Provided-buffer reads are enabled.", config.default_recv_buffer_count, config.default_recv_buffer_size);
default_worker_buffer_manager = Some(bm);
default_worker_bgid_val = Some(0); }
Err(e) => {
error!("UringWorker: Failed to initialize default BufferRingManager from config: {}. Provided-buffer reads may fail.", e);
}
}
} else {
info!("UringWorker: Default provided-buffer recv ring not configured (count/size is zero).");
}
let mut worker = UringWorker {
state: WorkerState::Running,
ring,
op_rx,
work_map: HashMap::new(),
buffer_manager: default_worker_buffer_manager,
handler_manager: HandlerManager::new(factories, worker_io_config.clone()),
external_op_tracker: ExternalOpTracker::new(),
internal_op_tracker: internal_tracker,
event_fd_poller: event_fd_poller_instance,
worker_io_config,
default_buffer_ring_group_id_val: default_worker_bgid_val,
fds_needing_close_initiated_pass: VecDeque::new(),
send_buffer_pool: worker_send_buffer_pool,
fd_to_mpsc_rx: HashMap::new(),
cfg_send_zerocopy_enabled: effective_send_zerocopy_enabled_for_worker,
cfg_send_buffer_count: config.default_send_buffer_count,
cfg_send_buffer_size: config.default_send_buffer_size,
};
{
let loop_result = main_loop::run_worker_loop(&mut worker);
if loop_result.is_err() {
warn!("UringWorker: Loop exited with error, cleanup might be partial.");
}
}
info!("rzmq-io-uring-worker OS thread (PID: {}) finished.", std::process::id());
Ok(()) }
Err(e) => {
error!("Failed to initialize IoUring (entries: {}): {}. UringWorker thread cannot start.", config.ring_entries, e);
Err(ZmqError::Internal(format!("IoUring init failed: {}", e)))
}
}
})
.map_err(|e| ZmqError::Internal(format!("Failed to spawn UringWorker thread: {:?}", e)))?;
Ok((signaling_op_sender, worker_thread_join_handle))
}
}
impl UringWorker {
    /// Moves the worker into `WorkerState::Draining`.
    ///
    /// Steps, in order:
    /// 1. Record the new state.
    /// 2. Take the global parsed-message sender out of its mutex and drop it.
    /// 3. Queue one `AsyncCancel` per operation currently known to the
    ///    internal op tracker, stopping early if the submission queue fills.
    /// 4. Submit the queued cancellations; a submit error is only logged.
    fn transition_to_draining(&mut self) {
        info!("UringWorker: Transitioning to DRAINING state.");
        self.state = WorkerState::Draining;

        if let Some(tx) = global_state::get_global_parsed_msg_tx_mutex().lock().take() {
            drop(tx);
            debug!("UringWorker: Dropped upstream TX channel to signal processor shutdown.");
        }

        // Snapshot the keys so the tracker is not borrowed while pushing.
        let internal_ops_to_cancel: Vec<UserData> =
            self.internal_op_tracker.op_to_details.keys().copied().collect();
        info!(
            "UringWorker: Draining state - Cancelling {} in-flight internal operations.",
            internal_ops_to_cancel.len()
        );

        {
            // `&mut self` gives exclusive access to the ring, so the safe
            // accessor is enough; no `unsafe submission_shared()` is needed.
            let mut sq = self.ring.submission();
            for op_ud in internal_ops_to_cancel {
                if sq.is_full() {
                    warn!("UringWorker draining transition: SQ full, cannot submit all cancel ops. CQE processing will need to handle the rest.");
                    break;
                }
                trace!(
                    "UringWorker draining transition: Submitting AsyncCancel for internal op_ud {}",
                    op_ud
                );
                // NOTE(review): the cancel SQE itself carries user_data 0;
                // presumably the CQE processor ignores completions with that
                // tag — confirm in `cqe_processor`.
                let cancel_sqe = opcode::AsyncCancel::new(op_ud).build().user_data(0);
                // SAFETY: the entry is built only from the target's user-data
                // value and references no caller-owned memory that would have
                // to outlive the operation.
                if let Err(e) = unsafe { sq.push(&cancel_sqe) } {
                    // Not expected after the `is_full` check, but do not
                    // silently lose the failure.
                    warn!(
                        "UringWorker draining transition: Failed to push AsyncCancel for internal op_ud {}: {}",
                        op_ud, e
                    );
                    break;
                }
            }
            // `sq` mutably borrows the ring; it is released at the end of this
            // scope, before the `submitter()` call below.
        }

        if let Err(e) = self.ring.submitter().submit() {
            warn!(
                "UringWorker draining transition: Error submitting cancellation SQEs: {}",
                e
            );
        }
    }
}
/// Encodes `addr` into `storage` as a C `sockaddr_in` / `sockaddr_in6` and
/// returns the length of the structure that was written.
///
/// The whole `storage` is zeroed first. Port and IPv4 address are written in
/// network byte order. The IPv6 `flowinfo` and `scope_id` are copied as-is,
/// which matches what the Rust standard library does for these two fields.
pub(crate) fn socket_addr_to_sockaddr_storage(
    addr: &SocketAddr,
    storage: &mut libc::sockaddr_storage,
) -> libc::socklen_t {
    let storage_ptr: *mut libc::sockaddr_storage = storage;

    // SAFETY: `storage_ptr` comes from a live `&mut sockaddr_storage`, so it is
    // valid and aligned for a write of one element; the all-zero bit pattern
    // is a valid `sockaddr_storage`.
    unsafe { std::ptr::write_bytes(storage_ptr, 0, 1) };

    match addr {
        SocketAddr::V4(v4_addr) => {
            // SAFETY: `sockaddr_storage` is defined to be large enough and
            // suitably aligned for every socket address type, and the memory
            // was zero-initialised above.
            let sin = unsafe { &mut *storage_ptr.cast::<libc::sockaddr_in>() };
            sin.sin_family = libc::AF_INET as libc::sa_family_t;
            sin.sin_port = v4_addr.port().to_be();
            sin.sin_addr = libc::in_addr {
                // `u32::from(Ipv4Addr)` is the host-order numeric value;
                // `to_be` turns it into the network-order field. (The previous
                // `from_ne_bytes(octets).to_be()` swapped already
                // network-ordered bytes on little-endian hosts.)
                s_addr: u32::from(*v4_addr.ip()).to_be(),
            };
            mem::size_of::<libc::sockaddr_in>() as libc::socklen_t
        }
        SocketAddr::V6(v6_addr) => {
            // SAFETY: same reasoning as the IPv4 arm, for `sockaddr_in6`.
            let sin6 = unsafe { &mut *storage_ptr.cast::<libc::sockaddr_in6>() };
            sin6.sin6_family = libc::AF_INET6 as libc::sa_family_t;
            sin6.sin6_port = v6_addr.port().to_be();
            sin6.sin6_addr = libc::in6_addr {
                s6_addr: v6_addr.ip().octets(),
            };
            sin6.sin6_flowinfo = v6_addr.flowinfo();
            sin6.sin6_scope_id = v6_addr.scope_id();
            mem::size_of::<libc::sockaddr_in6>() as libc::socklen_t
        }
    }
}
/// Returns a placeholder address: the unspecified IPv6 address (`::`) with
/// port 0, flow info 0 and scope id 0.
// Kept although it may currently have no callers in every build configuration.
#[allow(dead_code)]
pub(crate) fn dummy_socket_addr() -> SocketAddr {
    let unspecified_v6 = SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0);
    SocketAddr::V6(unspecified_v6)
}
pub(crate) fn get_peer_local_addr(fd: RawFd) -> Result<(SocketAddr, SocketAddr), std::io::Error> {
let mut peer_storage: libc::sockaddr_storage = unsafe { mem::zeroed() };
let mut peer_addrlen = mem::size_of_val(&peer_storage) as libc::socklen_t;
if unsafe {
libc::getpeername(
fd,
&mut peer_storage as *mut _ as *mut libc::sockaddr,
&mut peer_addrlen,
)
} != 0
{
return Err(std::io::Error::last_os_error());
}
let peer_saddr = sockaddr_storage_to_socket_addr(&peer_storage, peer_addrlen)?;
let mut local_storage: libc::sockaddr_storage = unsafe { mem::zeroed() };
let mut local_addrlen = mem::size_of_val(&local_storage) as libc::socklen_t;
if unsafe {
libc::getsockname(
fd,
&mut local_storage as *mut _ as *mut libc::sockaddr,
&mut local_addrlen,
)
} != 0
{
return Err(std::io::Error::last_os_error());
}
let local_saddr = sockaddr_storage_to_socket_addr(&local_storage, local_addrlen)?;
Ok((peer_saddr, local_saddr))
}
/// Decodes a C socket address held in `storage` into a [`SocketAddr`].
///
/// `len` is the number of valid bytes in `storage`, as reported by the call
/// that filled it in. Port and IPv4 address are converted from network byte
/// order. The IPv6 `flowinfo` and `scope_id` are copied as-is: this mirrors
/// `socket_addr_to_sockaddr_storage` and the Rust standard library, and the
/// scope id is a host-order interface index, so byte-swapping it would
/// corrupt it on little-endian hosts.
///
/// # Errors
///
/// Returns `ErrorKind::InvalidInput` when `len` is too small for the address
/// family found in `storage`, or when the family is neither `AF_INET` nor
/// `AF_INET6`.
pub(crate) fn sockaddr_storage_to_socket_addr(
    storage: &libc::sockaddr_storage,
    len: libc::socklen_t,
) -> std::io::Result<SocketAddr> {
    let invalid_input =
        |msg: &'static str| std::io::Error::new(std::io::ErrorKind::InvalidInput, msg);
    let len = len as usize;
    let raw: *const libc::sockaddr_storage = storage;

    match storage.ss_family as libc::c_int {
        libc::AF_INET => {
            if len < mem::size_of::<libc::sockaddr_in>() {
                return Err(invalid_input("sockaddr_in length too small"));
            }
            // SAFETY: `raw` comes from a live reference; `sockaddr_storage` is
            // large enough and suitably aligned for `sockaddr_in`, and the
            // family tag says that is what it holds.
            let sa = unsafe { &*raw.cast::<libc::sockaddr_in>() };
            let ip = Ipv4Addr::from(u32::from_be(sa.sin_addr.s_addr));
            let port = u16::from_be(sa.sin_port);
            Ok(SocketAddr::V4(SocketAddrV4::new(ip, port)))
        }
        libc::AF_INET6 => {
            if len < mem::size_of::<libc::sockaddr_in6>() {
                return Err(invalid_input("sockaddr_in6 length too small"));
            }
            // SAFETY: same reasoning as the IPv4 arm, for `sockaddr_in6`.
            let sa = unsafe { &*raw.cast::<libc::sockaddr_in6>() };
            let ip = Ipv6Addr::from(sa.sin6_addr.s6_addr);
            let port = u16::from_be(sa.sin6_port);
            Ok(SocketAddr::V6(SocketAddrV6::new(
                ip,
                port,
                sa.sin6_flowinfo,
                sa.sin6_scope_id,
            )))
        }
        _ => Err(invalid_input("invalid socket address family")),
    }
}