#![cfg(target_os = "linux")]
#![allow(warnings)]
use std::cell::UnsafeCell;
use std::os::fd::{AsRawFd, RawFd};
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, AtomicUsize, Ordering};
use std::time::Duration;
use crate::driver::{CompletionEntry, Driver, ERROR_TRANSPORT, Interest, SubmitEntry};
/// Lower bound applied to the configured entry count when the driver is sized
/// (`with_config` takes the larger of this value and `config.entries`).
const MIN_EPOLL_SIZE: u32 = 32;
/// Cursors for the submission and completion rings.
///
/// Every cursor is a free-running counter; the driver turns a cursor into a
/// slot index by masking it with its `capacity_mask`.
struct EpollState {
/// Cursor of the first queued submission that `submit` has not processed yet.
submit_head: AtomicUsize,
/// End of the queued range: `submit` processes cursors from `submit_head`
/// up to, but not including, this value.
submit_tail: AtomicUsize,
/// Cursor of the oldest completion that has not been consumed yet.
completion_head: AtomicUsize,
/// Cursor at which `wait_internal` writes the next completion.
/// This one is 32 bits wide, unlike the other cursors, so readers cast when
/// comparing it with `completion_head`.
completion_tail: AtomicU32,
}
struct CompletionQueue {
entries: Box<[Option<CompletionEntry>]>,
}
unsafe impl Send for CompletionQueue {}
unsafe impl Sync for CompletionQueue {}
impl CompletionQueue {
fn new(capacity: usize) -> Self {
Self {
entries: vec![None; capacity].into_boxed_slice(),
}
}
fn get(&self, index: usize) -> Option<&CompletionEntry> {
self.entries[index].as_ref()
}
unsafe fn set(&self, index: usize, entry: Option<CompletionEntry>) {
let ptr = self.entries.as_ptr() as *mut Option<CompletionEntry>;
*ptr.add(index) = entry;
}
}
/// `Driver` implementation backed by an epoll instance, with a submission ring
/// and a completion ring of `capacity` slots each.
pub struct EpollDriver {
/// Descriptor of the epoll instance; closed when the driver is dropped.
epoll_fd: RawFd,
/// Submission slots, handed out by `get_submission` and walked by `submit`.
submit_queue: UnsafeCell<Vec<SubmitEntry>>,
/// Completion slots, filled by `wait_internal` and read by `get_completion`.
completion_queue: CompletionQueue,
/// Number of slots in each ring and in `event_buffer`.
capacity: usize,
/// `capacity - 1`; cursors are ANDed with this to obtain a slot index.
/// Masking only wraps correctly when `capacity` is a power of two.
capacity_mask: usize,
/// Head/tail cursors of both rings.
state: Arc<EpollState>,
/// Scratch array that `epoll_wait` fills with ready events.
event_buffer: UnsafeCell<Vec<libc::epoll_event>>,
}
// The `UnsafeCell` fields make `EpollDriver` `!Sync` by default; these impls
// assert that sharing it across threads is nevertheless sound.
// NOTE(review): `submit_queue` and `event_buffer` are accessed through
// `UnsafeCell` without any lock in this file, so soundness presumably relies on
// callers never running the accessing methods concurrently — confirm.
unsafe impl Send for EpollDriver {}
unsafe impl Sync for EpollDriver {}
impl EpollDriver {
/// Creates a driver using the default `DriverConfig`.
///
/// # Errors
///
/// Propagates whatever error `with_config` reports.
pub fn new() -> std::io::Result<Self> {
    let defaults = crate::driver::DriverConfig::default();
    Self::with_config(defaults)
}
/// Creates a driver sized according to `config`.
///
/// The ring capacity is `config.entries`, raised to at least
/// `MIN_EPOLL_SIZE` and then rounded up to the next power of two. Slot
/// indices are computed as `cursor & (capacity - 1)`, which is only a correct
/// wrap-around for power-of-two capacities.
///
/// If `config.cpu_affinity` is set, the calling thread is pinned to that core
/// on a best-effort basis: a failure is reported on stderr and otherwise
/// ignored.
///
/// # Errors
///
/// Returns `InvalidInput` if the entry count cannot be rounded up to a power
/// of two, or the OS error if the epoll instance cannot be created.
pub fn with_config(config: crate::driver::DriverConfig) -> std::io::Result<Self> {
    // Validate the size before acquiring the descriptor so that the error
    // path has nothing to clean up.
    let size = config
        .entries
        .max(MIN_EPOLL_SIZE)
        .checked_next_power_of_two()
        .ok_or_else(|| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "entries is too large to round up to a power of two",
            )
        })?;
    let capacity = size as usize;
    let capacity_mask = capacity - 1;

    // `epoll_create1` sets close-on-exec atomically (no window in which a
    // concurrent fork/exec could leak the descriptor) and has no size hint
    // that could go negative when cast to `i32`.
    // SAFETY: the call takes no pointers and has no memory-safety preconditions.
    let epoll_fd = unsafe { libc::epoll_create1(libc::EPOLL_CLOEXEC) };
    if epoll_fd < 0 {
        return Err(std::io::Error::last_os_error());
    }

    if let Some(cpu) = config.cpu_affinity {
        // Pinning is an optimisation, not a requirement: warn and carry on.
        if let Err(e) = Self::set_cpu_affinity(cpu) {
            eprintln!("Warning: Failed to set CPU affinity: {}", e);
        }
    }

    Ok(Self {
        epoll_fd,
        // Slots start with fd -1, which `submit` treats as "nothing to do".
        submit_queue: UnsafeCell::new(vec![SubmitEntry::new(-1, 0, 0); capacity]),
        completion_queue: CompletionQueue::new(capacity),
        capacity,
        capacity_mask,
        state: Arc::new(EpollState {
            submit_head: AtomicUsize::new(0),
            submit_tail: AtomicUsize::new(0),
            completion_head: AtomicUsize::new(0),
            completion_tail: AtomicU32::new(0),
        }),
        event_buffer: UnsafeCell::new(vec![libc::epoll_event { events: 0, u64: 0 }; capacity]),
    })
}
/// Restricts the calling thread to a single CPU.
///
/// `core` is reduced modulo `CPU_SETSIZE` before it is placed in the mask, so
/// an out-of-range value selects a different CPU instead of failing.
///
/// # Errors
///
/// Returns the OS error reported by `sched_setaffinity`.
fn set_cpu_affinity(core: usize) -> std::io::Result<()> {
    let slot = core % libc::CPU_SETSIZE as usize;
    // SAFETY: `mask` is a plain bit set for which all-zeroes is a valid value,
    // `slot` is below `CPU_SETSIZE`, and the pointer handed to the kernel
    // refers to a live local of the size that is passed alongside it.
    let rc = unsafe {
        let mut mask: libc::cpu_set_t = std::mem::zeroed();
        libc::CPU_ZERO(&mut mask);
        libc::CPU_SET(slot, &mut mask);
        // pid 0 addresses the calling thread.
        libc::sched_setaffinity(0, std::mem::size_of::<libc::cpu_set_t>(), &mask)
    };
    if rc < 0 {
        Err(std::io::Error::last_os_error())
    } else {
        Ok(())
    }
}
/// Maps a submission cursor to its slot in the submission ring.
#[inline]
fn submit_pos(&self, index: usize) -> usize {
    let mask = self.capacity_mask;
    index & mask
}
/// Maps a completion cursor to its slot in the completion ring.
#[inline]
fn completion_pos(&self, index: usize) -> usize {
    let mask = self.capacity_mask;
    index & mask
}
}
impl Drop for EpollDriver {
    /// Closes the epoll descriptor, unless it is negative.
    fn drop(&mut self) {
        let fd = self.epoll_fd;
        if fd < 0 {
            return;
        }
        // SAFETY: `close` takes no pointers. Its result is deliberately ignored
        // because nothing useful can be done about a failure while dropping.
        unsafe {
            libc::close(fd);
        }
    }
}
impl AsRawFd for EpollDriver {
    /// Returns the descriptor of the underlying epoll instance.
    ///
    /// The driver keeps ownership: the descriptor is closed when the driver
    /// is dropped.
    fn as_raw_fd(&self) -> RawFd {
        let fd: RawFd = self.epoll_fd;
        fd
    }
}
impl Driver for EpollDriver {
/// Arms every queued submission on the epoll instance.
///
/// Each entry with a non-negative `fd` is registered one-shot with
/// `EPOLLRDHUP`, plus `EPOLLIN` for `READ` or `EPOLLOUT` for `WRITE`; its
/// `user_data` is attached to the event. An existing registration is re-armed
/// with `EPOLL_CTL_MOD`; if the descriptor is not registered yet (`ENOENT`)
/// it is added with `EPOLL_CTL_ADD`. Entries with a negative `fd` are skipped.
///
/// Returns the number of entries that were armed.
///
/// # Errors
///
/// Returns the OS error of the `epoll_ctl` call that failed. Entries armed
/// before the failure are consumed, so a retry does not re-arm them; the
/// failing entry stays at the head of the queue.
fn submit(&self) -> std::io::Result<usize> {
    let head = self.state.submit_head.load(Ordering::Acquire);
    let tail = self.state.submit_tail.load(Ordering::Acquire);
    // SAFETY: only shared access is taken here.
    // NOTE(review): this relies on no caller mutating a slot in `[head, tail)`
    // while `submit` runs — confirm against the callers.
    let submit_queue = unsafe { &*self.submit_queue.get() };

    let mut submitted = 0;
    let mut idx = head;
    while idx != tail {
        let entry = &submit_queue[self.submit_pos(idx)];
        if entry.fd >= 0 {
            let mut event = libc::epoll_event {
                events: (libc::EPOLLONESHOT | libc::EPOLLRDHUP) as u32,
                u64: entry.user_data,
            };
            match entry.opcode {
                crate::driver::opcode::READ => event.events |= libc::EPOLLIN as u32,
                crate::driver::opcode::WRITE => event.events |= libc::EPOLLOUT as u32,
                _ => {}
            }

            // SAFETY: `event` is a live local for the duration of each call.
            let rearmed = unsafe {
                libc::epoll_ctl(self.epoll_fd, libc::EPOLL_CTL_MOD, entry.fd, &mut event)
            };
            let armed = if rearmed >= 0 {
                Ok(())
            } else {
                let mod_err = std::io::Error::last_os_error();
                if mod_err.kind() != std::io::ErrorKind::NotFound {
                    Err(mod_err)
                } else {
                    // SAFETY: as above.
                    let added = unsafe {
                        libc::epoll_ctl(
                            self.epoll_fd,
                            libc::EPOLL_CTL_ADD,
                            entry.fd,
                            &mut event,
                        )
                    };
                    if added < 0 {
                        // Report why ADD failed, not the ENOENT from MOD that
                        // merely triggered the fallback.
                        Err(std::io::Error::last_os_error())
                    } else {
                        Ok(())
                    }
                }
            };

            if let Err(err) = armed {
                // Everything before `idx` is already armed; publish that so a
                // retry does not re-arm one-shot registrations that may have
                // fired in the meantime.
                self.state.submit_head.store(idx, Ordering::Release);
                return Err(err);
            }
            submitted += 1;
        }
        idx = idx.wrapping_add(1);
    }
    self.state.submit_head.store(tail, Ordering::Release);
    Ok(submitted)
}
/// Waits for events without a timeout and returns what `wait_internal`
/// reports.
fn wait(&self) -> std::io::Result<usize> {
    let no_timeout: Option<i32> = None;
    self.wait_internal(no_timeout)
}
/// Waits for events for at most `duration`.
///
/// Returns the count reported by `wait_internal` and a flag that is `true`
/// when the completion ring is empty after the wait.
///
/// `epoll_wait` only accepts whole milliseconds. The duration is rounded
/// **up**, so a non-zero sub-millisecond timeout still blocks instead of
/// degenerating into a non-blocking poll (and a busy loop in callers that
/// retry). Durations beyond `i32::MAX` milliseconds are clamped.
fn wait_timeout(&self, duration: Duration) -> std::io::Result<(usize, bool)> {
    let millis = duration
        .checked_add(Duration::from_nanos(999_999))
        .unwrap_or(duration)
        .as_millis();
    let timeout_ms = i32::try_from(millis).unwrap_or(i32::MAX);

    let result = self.wait_internal(Some(timeout_ms))?;

    // The tail cursor is 32 bits wide, so compare in that width.
    let head = self.state.completion_head.load(Ordering::Acquire) as u32;
    let tail = self.state.completion_tail.load(Ordering::Acquire);
    Ok((result, head == tail))
}
fn get_submission(&self) -> Option<&mut SubmitEntry> {
let tail = self.state.submit_tail.load(Ordering::Acquire);
let next_tail = tail + 1;
let head = self.state.submit_head.load(Ordering::Acquire);
if next_tail - head > self.capacity {
return None;
}
let pos = self.submit_pos(tail);
unsafe {
let submit_queue = &mut *self.submit_queue.get();
Some(&mut submit_queue[pos])
}
}
/// Returns the oldest unconsumed completion without removing it, or `None`
/// when the completion ring is empty.
///
/// Call `advance_completion` to release the entry afterwards.
fn get_completion(&self) -> Option<&CompletionEntry> {
    let head = self.state.completion_head.load(Ordering::Acquire);
    let tail = self.state.completion_tail.load(Ordering::Acquire);
    // The tail cursor is 32 bits wide and wraps at `u32::MAX`, while the head
    // is a `usize`. Comparing the zero-extended tail with the full head stops
    // working after the tail wraps on 64-bit targets, so compare the low
    // 32 bits (the same way `wait_timeout` does).
    if head as u32 == tail {
        return None;
    }
    self.completion_queue.get(self.completion_pos(head))
}
/// Releases the oldest completion: clears its slot and advances the head
/// cursor. Does nothing when the completion ring is empty.
fn advance_completion(&self) {
    let head = self.state.completion_head.load(Ordering::Acquire);
    let tail = self.state.completion_tail.load(Ordering::Acquire);
    // Compare in the 32-bit width of the tail cursor: the zero-extended tail
    // no longer matches a `usize` head once the tail has wrapped.
    if head as u32 == tail {
        return;
    }
    let pos = self.completion_pos(head);
    // SAFETY: the slot at `head` belongs to the consumer until the head cursor
    // moves past it. The caller must no longer use a reference previously
    // returned by `get_completion` for this entry.
    unsafe {
        self.completion_queue.set(pos, None);
    }
    self.state
        .completion_head
        .store(head.wrapping_add(1), Ordering::Release);
}
/// Adds `fd` to the epoll instance with the event mask derived from
/// `interest`. Events for this registration carry user data `0`.
///
/// # Errors
///
/// Returns the OS error if `epoll_ctl(EPOLL_CTL_ADD)` fails.
fn register(&self, fd: RawFd, interest: Interest) -> std::io::Result<()> {
    let flags = interest.to_epoll_flags();
    let mut event = libc::epoll_event { events: flags, u64: 0 };
    // SAFETY: `event` is a live local for the duration of the call.
    let rc = unsafe { libc::epoll_ctl(self.epoll_fd, libc::EPOLL_CTL_ADD, fd, &mut event) };
    if rc >= 0 {
        Ok(())
    } else {
        Err(std::io::Error::last_os_error())
    }
}
/// Removes `fd` from the epoll instance.
///
/// # Errors
///
/// Returns the OS error if `epoll_ctl(EPOLL_CTL_DEL)` fails.
fn deregister(&self, fd: RawFd) -> std::io::Result<()> {
    // A zeroed event is passed even though the operation is a removal.
    let mut placeholder = libc::epoll_event { events: 0, u64: 0 };
    // SAFETY: `placeholder` is a live local for the duration of the call.
    let rc = unsafe {
        libc::epoll_ctl(self.epoll_fd, libc::EPOLL_CTL_DEL, fd, &mut placeholder)
    };
    match rc {
        code if code < 0 => Err(std::io::Error::last_os_error()),
        _ => Ok(()),
    }
}
/// Replaces the event mask of an already registered `fd` with the one derived
/// from `interest`. Events for this registration carry user data `0`.
///
/// # Errors
///
/// Returns the OS error if `epoll_ctl(EPOLL_CTL_MOD)` fails.
fn modify(&self, fd: RawFd, interest: Interest) -> std::io::Result<()> {
    let flags = interest.to_epoll_flags();
    let mut event = libc::epoll_event { events: flags, u64: 0 };
    // SAFETY: `event` is a live local for the duration of the call.
    let rc = unsafe { libc::epoll_ctl(self.epoll_fd, libc::EPOLL_CTL_MOD, fd, &mut event) };
    if rc < 0 {
        return Err(std::io::Error::last_os_error());
    }
    Ok(())
}
/// Number of slots in the submission ring.
fn submission_capacity(&self) -> usize {
    let slots = self.capacity;
    slots
}
/// Number of slots in the completion ring (the same value as the submission
/// ring's capacity).
fn completion_capacity(&self) -> usize {
    let slots = self.capacity;
    slots
}
fn supports_operation(&self, opcode: u8) -> bool {
matches!(
opcode,
crate::driver::opcode::READ
| crate::driver::opcode::WRITE
| crate::driver::opcode::CLOSE
)
}
}
impl EpollDriver {
    /// Waits for ready events and appends one completion per event.
    ///
    /// `timeout_ms` is passed to `epoll_wait`; `None` blocks indefinitely.
    /// Each completion carries the event's user data, the raw event mask in
    /// `flags`, and a `result` of `ERROR_TRANSPORT` for `EPOLLERR`/`EPOLLHUP`,
    /// `1` for `EPOLLIN`/`EPOLLOUT`, and `0` otherwise.
    ///
    /// At most as many events are fetched as there are free completion slots,
    /// so completions the consumer has not released yet are never overwritten.
    /// Events that do not fit stay pending in the kernel until a later call.
    /// If the completion ring is already full, the call returns `Ok(0)`
    /// immediately so the caller can drain it.
    ///
    /// Returns the number of completions appended.
    ///
    /// # Errors
    ///
    /// Returns the OS error if `epoll_wait` fails (including `EINTR`).
    fn wait_internal(&self, timeout_ms: Option<i32>) -> std::io::Result<usize> {
        // The tail cursor is 32 bits wide, so occupancy is computed in that width.
        let head = self.state.completion_head.load(Ordering::Acquire) as u32;
        let mut tail = self.state.completion_tail.load(Ordering::Acquire);
        let pending = tail.wrapping_sub(head) as usize;
        let free = self.capacity.saturating_sub(pending);
        if free == 0 {
            return Ok(0);
        }

        // SAFETY: the pointer comes from a live `UnsafeCell`.
        // NOTE(review): exclusive access relies on `wait_internal` never
        // running concurrently on the same driver — confirm against callers.
        let event_buffer = unsafe { &mut *self.event_buffer.get() };
        let max_events = event_buffer.len().min(free).min(i32::MAX as usize) as i32;

        // SAFETY: the buffer holds at least `max_events` initialised elements.
        let rc = unsafe {
            libc::epoll_wait(
                self.epoll_fd,
                event_buffer.as_mut_ptr(),
                max_events,
                timeout_ms.unwrap_or(-1),
            )
        };
        if rc < 0 {
            return Err(std::io::Error::last_os_error());
        }
        let count = rc as usize;

        let error_mask = (libc::EPOLLERR | libc::EPOLLHUP) as u32;
        let ready_mask = (libc::EPOLLIN | libc::EPOLLOUT) as u32;
        for event in &event_buffer[..count] {
            // Copy the fields out: `epoll_event` may be a packed struct, so no
            // references to its fields are taken.
            let events = event.events;
            let user_data = event.u64;
            let result = if events & error_mask != 0 {
                ERROR_TRANSPORT
            } else if events & ready_mask != 0 {
                1
            } else {
                0
            };

            let pos = self.completion_pos(tail as usize);
            // SAFETY: the slot at `tail` is free (occupancy was checked above),
            // so the consumer holds no reference to it.
            unsafe {
                self.completion_queue.set(
                    pos,
                    Some(CompletionEntry {
                        user_data,
                        result,
                        flags: events,
                    }),
                );
            }
            // Publish each entry only after it has been written.
            tail = tail.wrapping_add(1);
            self.state.completion_tail.store(tail, Ordering::Release);
        }
        Ok(count)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A default driver owns a valid descriptor and has 256 slots.
    #[test]
    fn test_epoll_driver_creation() {
        let created = EpollDriver::new();
        assert!(created.is_ok());
        let driver = created.unwrap();
        assert!(driver.epoll_fd >= 0);
        assert_eq!(driver.capacity, 256);
    }

    /// An explicit entry count of 128 becomes the driver's capacity.
    #[test]
    fn test_epoll_driver_with_config() {
        let builder = crate::driver::DriverConfigBuilder::new();
        let config = builder.entries(128).build();
        let created = EpollDriver::with_config(config);
        assert!(created.is_ok());
        let driver = created.unwrap();
        assert_eq!(driver.capacity, 128);
    }

    /// Submission cursors wrap around at the default capacity of 256.
    #[test]
    fn test_ring_buffer_positions() {
        let driver = EpollDriver::new().unwrap();
        let expectations: [(usize, usize); 4] = [(0, 0), (255, 255), (256, 0), (257, 1)];
        for (cursor, slot) in expectations {
            assert_eq!(driver.submit_pos(cursor), slot);
        }
    }
}