1#![cfg(target_os = "linux")]
8#![allow(warnings)]
9
10use std::cell::UnsafeCell;
11use std::os::fd::{AsRawFd, RawFd};
12use std::sync::Arc;
13use std::sync::atomic::{AtomicU32, AtomicUsize, Ordering};
14use std::time::Duration;
15
16use crate::driver::{CompletionEntry, Driver, ERROR_TRANSPORT, Interest, SubmitEntry};
17
/// Lower bound on the ring capacity: `config.entries` is raised to at least this value
/// when a driver is constructed.
const MIN_EPOLL_SIZE: u32 = 1 << 5;
20
/// Head/tail counters for the submission and completion rings.
///
/// Counters are advanced by one per entry. A counter is mapped to a slot with
/// `index & capacity_mask`. The completion tail is a `u32`, so it wraps at 2^32.
struct EpollState {
    /// Index of the next completion to hand out via `get_completion`.
    completion_head: AtomicUsize,
    /// Index one past the last completion written by the wait path (stored as `u32`).
    completion_tail: AtomicU32,
    /// Index of the next submission entry to be processed by `submit`.
    submit_head: AtomicUsize,
    /// Index of the next free submission slot, as read by `get_submission`.
    submit_tail: AtomicUsize,
}
33
34struct CompletionQueue {
37 entries: Box<[Option<CompletionEntry>]>,
39}
40
41unsafe impl Send for CompletionQueue {}
44unsafe impl Sync for CompletionQueue {}
45
46impl CompletionQueue {
47 fn new(capacity: usize) -> Self {
50 Self {
51 entries: vec![None; capacity].into_boxed_slice(),
52 }
53 }
54
55 fn get(&self, index: usize) -> Option<&CompletionEntry> {
58 self.entries[index].as_ref()
59 }
60
61 unsafe fn set(&self, index: usize, entry: Option<CompletionEntry>) {
69 let ptr = self.entries.as_ptr() as *mut Option<CompletionEntry>;
72 *ptr.add(index) = entry;
73 }
74}
75
76pub struct EpollDriver {
82 epoll_fd: RawFd,
84 submit_queue: UnsafeCell<Vec<SubmitEntry>>,
86 completion_queue: CompletionQueue,
88 capacity: usize,
90 capacity_mask: usize,
92 state: Arc<EpollState>,
94 event_buffer: UnsafeCell<Vec<libc::epoll_event>>,
96}
97
98unsafe impl Send for EpollDriver {}
101
102unsafe impl Sync for EpollDriver {}
105
106impl EpollDriver {
107 pub fn new() -> std::io::Result<Self> {
115 Self::with_config(crate::driver::DriverConfig::default())
116 }
117
118 pub fn with_config(config: crate::driver::DriverConfig) -> std::io::Result<Self> {
130 let size = config.entries.max(MIN_EPOLL_SIZE);
133 let epoll_fd = unsafe {
134 libc::epoll_create(size as i32)
137 };
138
139 if epoll_fd < 0 {
140 return Err(std::io::Error::last_os_error());
141 }
142
143 unsafe {
146 let flags = libc::fcntl(epoll_fd, libc::F_GETFD);
147 if flags >= 0 {
148 libc::fcntl(epoll_fd, libc::F_SETFD, flags | libc::FD_CLOEXEC);
149 }
150 }
151
152 if let Some(_core) = config.cpu_affinity {
155 if let Err(e) = Self::set_cpu_affinity(_core) {
156 eprintln!("Warning: Failed to set CPU affinity: {}", e);
159 }
160 }
161
162 let capacity = size as usize;
163 let capacity_mask = capacity - 1;
164
165 Ok(Self {
166 epoll_fd,
167 submit_queue: UnsafeCell::new(vec![SubmitEntry::new(-1, 0, 0); capacity]),
168 completion_queue: CompletionQueue::new(capacity),
169 capacity,
170 capacity_mask,
171 state: Arc::new(EpollState {
172 submit_head: AtomicUsize::new(0),
173 submit_tail: AtomicUsize::new(0),
174 completion_head: AtomicUsize::new(0),
175 completion_tail: AtomicU32::new(0),
176 }),
177 event_buffer: UnsafeCell::new(vec![libc::epoll_event { events: 0, u64: 0 }; capacity]),
178 })
179 }
180
181 fn set_cpu_affinity(core: usize) -> std::io::Result<()> {
184 #[cfg(target_os = "linux")]
185 unsafe {
186 let mut cpu_set: libc::cpu_set_t = std::mem::zeroed();
187 libc::CPU_ZERO(&mut cpu_set);
188 libc::CPU_SET(core % libc::CPU_SETSIZE as usize, &mut cpu_set);
189
190 let result =
191 libc::sched_setaffinity(0, size_of::<libc::cpu_set_t>(), &cpu_set);
192
193 if result < 0 {
194 return Err(std::io::Error::last_os_error());
195 }
196 }
197
198 Ok(())
199 }
200
201 #[inline]
204 fn submit_pos(&self, index: usize) -> usize {
205 index & self.capacity_mask
206 }
207
208 #[inline]
211 fn completion_pos(&self, index: usize) -> usize {
212 index & self.capacity_mask
213 }
214}
215
216impl Drop for EpollDriver {
217 fn drop(&mut self) {
218 if self.epoll_fd >= 0 {
219 unsafe {
220 libc::close(self.epoll_fd);
221 }
222 }
223 }
224}
225
226impl AsRawFd for EpollDriver {
227 fn as_raw_fd(&self) -> RawFd {
228 self.epoll_fd
229 }
230}
231
232impl Driver for EpollDriver {
233 fn submit(&self) -> std::io::Result<usize> {
234 let mut submitted = 0;
235 let head = self.state.submit_head.load(Ordering::Acquire);
236 let tail = self.state.submit_tail.load(Ordering::Acquire);
237
238 let mut idx = head;
241 while idx != tail {
242 let pos = self.submit_pos(idx);
243 let submit_queue = unsafe { &*self.submit_queue.get() };
244 let entry = &submit_queue[pos];
245
246 if entry.fd >= 0 {
247 let mut event = libc::epoll_event {
250 events: (libc::EPOLLONESHOT | libc::EPOLLRDHUP) as u32,
251 u64: entry.user_data,
252 };
253
254 match entry.opcode {
257 crate::driver::opcode::READ => event.events |= libc::EPOLLIN as u32,
258 crate::driver::opcode::WRITE => event.events |= libc::EPOLLOUT as u32,
259 _ => {},
260 }
261
262 let op = libc::EPOLL_CTL_MOD;
263 let result = unsafe { libc::epoll_ctl(self.epoll_fd, op, entry.fd, &mut event) };
264
265 if result < 0 {
266 let err = std::io::Error::last_os_error();
267 if err.kind() == std::io::ErrorKind::NotFound {
270 let add_result = unsafe {
271 libc::epoll_ctl(
272 self.epoll_fd,
273 libc::EPOLL_CTL_ADD,
274 entry.fd,
275 &mut event,
276 )
277 };
278 if add_result < 0 {
279 return Err(err);
280 }
281 } else {
282 return Err(err);
283 }
284 }
285
286 submitted += 1;
287 }
288
289 idx += 1;
290 }
291
292 self.state.submit_head.store(tail, Ordering::Release);
295
296 Ok(submitted)
297 }
298
299 fn wait(&self) -> std::io::Result<usize> {
300 self.wait_internal(None)
301 }
302
303 fn wait_timeout(&self, duration: Duration) -> std::io::Result<(usize, bool)> {
304 let timeout_ms = duration.as_millis().min(i32::MAX as u128) as i32;
305 let result = self.wait_internal(Some(timeout_ms))?;
306
307 let head = self.state.completion_head.load(Ordering::Acquire) as u32;
310 let tail = self.state.completion_tail.load(Ordering::Acquire);
311
312 Ok((result, head == tail))
313 }
314
315 fn get_submission(&self) -> Option<&mut SubmitEntry> {
316 let tail = self.state.submit_tail.load(Ordering::Acquire);
317 let next_tail = tail + 1;
318 let head = self.state.submit_head.load(Ordering::Acquire);
319
320 if next_tail - head > self.capacity {
323 return None;
324 }
325
326 let pos = self.submit_pos(tail);
327 unsafe {
330 let submit_queue = &mut *self.submit_queue.get();
331 Some(&mut submit_queue[pos])
332 }
333 }
334
335 fn get_completion(&self) -> Option<&CompletionEntry> {
336 let head = self.state.completion_head.load(Ordering::Acquire);
337 let tail = self.state.completion_tail.load(Ordering::Acquire) as usize;
338
339 if head == tail {
340 return None;
341 }
342
343 let pos = self.completion_pos(head);
344 self.completion_queue.get(pos)
345 }
346
347 fn advance_completion(&self) {
348 let head = self.state.completion_head.load(Ordering::Acquire);
349 let tail = self.state.completion_tail.load(Ordering::Acquire) as usize;
350
351 if head != tail {
352 let pos = self.completion_pos(head);
353 unsafe {
356 self.completion_queue.set(pos, None);
357 }
358
359 let new_head = head + 1;
360 self.state
361 .completion_head
362 .store(new_head, Ordering::Release);
363 }
364 }
365
366 fn register(&self, fd: RawFd, interest: Interest) -> std::io::Result<()> {
367 let mut event = libc::epoll_event {
368 events: interest.to_epoll_flags(),
369 u64: 0,
370 };
371
372 let result = unsafe { libc::epoll_ctl(self.epoll_fd, libc::EPOLL_CTL_ADD, fd, &mut event) };
373
374 if result < 0 {
375 Err(std::io::Error::last_os_error())
376 } else {
377 Ok(())
378 }
379 }
380
381 fn deregister(&self, fd: RawFd) -> std::io::Result<()> {
382 let mut event = libc::epoll_event { events: 0, u64: 0 };
383
384 let result = unsafe { libc::epoll_ctl(self.epoll_fd, libc::EPOLL_CTL_DEL, fd, &mut event) };
385
386 if result < 0 {
387 Err(std::io::Error::last_os_error())
388 } else {
389 Ok(())
390 }
391 }
392
393 fn modify(&self, fd: RawFd, interest: Interest) -> std::io::Result<()> {
394 let mut event = libc::epoll_event {
395 events: interest.to_epoll_flags(),
396 u64: 0,
397 };
398
399 let result = unsafe { libc::epoll_ctl(self.epoll_fd, libc::EPOLL_CTL_MOD, fd, &mut event) };
400
401 if result < 0 {
402 Err(std::io::Error::last_os_error())
403 } else {
404 Ok(())
405 }
406 }
407
408 fn submission_capacity(&self) -> usize {
409 self.capacity
410 }
411
412 fn completion_capacity(&self) -> usize {
413 self.capacity
414 }
415
416 fn supports_operation(&self, opcode: u8) -> bool {
417 matches!(
418 opcode,
419 crate::driver::opcode::READ
420 | crate::driver::opcode::WRITE
421 | crate::driver::opcode::CLOSE
422 )
423 }
424}
425
426impl EpollDriver {
427 fn wait_internal(&self, timeout_ms: Option<i32>) -> std::io::Result<usize> {
430 let event_buffer = unsafe { &mut *self.event_buffer.get() };
431 let ptr = event_buffer.as_mut_ptr();
432 let len = event_buffer.len() as i32;
433
434 let result = unsafe { libc::epoll_wait(self.epoll_fd, ptr, len, timeout_ms.unwrap_or(-1)) };
435
436 if result < 0 {
437 return Err(std::io::Error::last_os_error());
438 }
439
440 let count = result as usize;
441
442 for i in 0..count {
445 let event = &event_buffer[i];
446 let tail = self.state.completion_tail.load(Ordering::Acquire) as usize;
447 let pos = self.completion_pos(tail);
448
449 let result = if event.events & (libc::EPOLLERR | libc::EPOLLHUP) as u32 != 0 {
452 ERROR_TRANSPORT
453 } else if event.events & libc::EPOLLIN as u32 != 0 {
454 1 } else if event.events & libc::EPOLLOUT as u32 != 0 {
456 1 } else {
458 0
459 };
460
461 unsafe {
462 self.completion_queue.set(
463 pos,
464 Some(CompletionEntry {
465 user_data: event.u64,
466 result,
467 flags: event.events,
468 }),
469 );
470 }
471
472 self.state
473 .completion_tail
474 .store((tail + 1) as u32, Ordering::Release);
475 }
476
477 Ok(count)
478 }
479}
480
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration opens an epoll descriptor and sizes the rings to 256.
    #[test]
    fn test_epoll_driver_creation() {
        let created = EpollDriver::new();
        assert!(created.is_ok());
        let driver = match created {
            Ok(driver) => driver,
            Err(e) => panic!("EpollDriver::new failed: {e}"),
        };

        assert!(driver.epoll_fd >= 0);
        assert_eq!(driver.capacity, 256);
    }

    /// An explicit 128-entry configuration is honoured as the ring capacity.
    #[test]
    fn test_epoll_driver_with_config() {
        let built = EpollDriver::with_config(
            crate::driver::DriverConfigBuilder::new().entries(128).build(),
        );
        assert!(built.is_ok());

        let capacity = built.map(|driver| driver.capacity).unwrap();
        assert_eq!(capacity, 128);
    }

    /// Submission indices wrap around the 256-slot ring.
    #[test]
    fn test_ring_buffer_positions() {
        let driver = EpollDriver::new().unwrap();

        let expectations = [(0, 0), (255, 255), (256, 0), (257, 1)];
        for (index, slot) in expectations {
            assert_eq!(driver.submit_pos(index), slot, "index {index}");
        }
    }
}
521}