1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
use crate::{Events, SysError};
use libc::{close, epoll_create1, epoll_ctl, epoll_wait};
use std::collections::HashMap;

impl From<u32> for Events {
    /// Decodes a raw epoll event mask into [`Events`].
    ///
    /// Only `EPOLLIN` (read), `EPOLLOUT` (write) and `EPOLLERR` (error) are translated; every
    /// other bit in the mask (for example `EPOLLHUP`) is ignored.
    fn from(mask: u32) -> Self {
        // True when every bit of `bit` is present in `mask`.
        let is_set = |bit: u32| (mask & bit) == bit;

        let mut decoded = Events::new();
        if is_set(libc::EPOLLIN as u32) {
            decoded = decoded.with_read();
        }
        if is_set(libc::EPOLLOUT as u32) {
            decoded = decoded.with_write();
        }
        if is_set(libc::EPOLLERR as u32) {
            decoded = decoded.with_error();
        }
        decoded
    }
}

/// Encodes [`Events`] as an epoll event mask.
///
/// Implemented as `From<Events> for u32` rather than `Into<u32> for Events`: the standard
/// library's blanket impl still provides `Into<u32> for Events`, so existing `events.into()`
/// call sites keep compiling, and `u32::from(events)` becomes available as well.
impl From<Events> for u32 {
    /// Maps read → `EPOLLIN`, write → `EPOLLOUT` and error → `EPOLLERR`; no other epoll bits
    /// are ever set.
    fn from(events: Events) -> Self {
        let mut mask = 0u32;
        if events.has_read() {
            mask |= libc::EPOLLIN as u32;
        }
        if events.has_write() {
            mask |= libc::EPOLLOUT as u32;
        }
        if events.has_error() {
            mask |= libc::EPOLLERR as u32;
        }
        mask
    }
}

/// Defines a file I/O event notifier, backed by a Linux epoll instance.
#[derive(Debug)]
pub struct Poller {
    // Descriptor of the epoll instance; -1 means no instance is held (see `Default`).
    epoll_fd: i32,
    // Descriptors successfully registered with the epoll instance, mapped to the events that
    // were requested for each. Its length also bounds how many events one wait may return.
    watches: HashMap<i32, Events>,
}

impl Default for Poller {
    fn default() -> Self {
        Self {
            epoll_fd: -1,
            watches: HashMap::new(),
        }
    }
}

impl Drop for Poller {
    /// Closes the epoll descriptor, if one is held, and resets it to the `-1` sentinel.
    fn drop(&mut self) {
        // Descriptor 0 is a valid epoll fd (the kernel hands out the lowest free number, which
        // is 0 when stdin has been closed), so only a negative value means "nothing to close".
        if self.epoll_fd >= 0 {
            // SAFETY: the field is private and `Poller` is not `Clone`, so this value is the
            // only owner of the descriptor and it is closed exactly once, here.
            unsafe {
                close(self.epoll_fd);
            }
            self.epoll_fd = -1;
        }
    }
}

impl Poller {
    /// Creates a new I/O event notifier backed by a fresh epoll instance.
    ///
    /// The epoll descriptor is created with `EPOLL_CLOEXEC`, so it is not inherited by
    /// programs started through `exec`.
    ///
    /// # Panics
    ///
    /// Panics if `epoll_create1` fails (for example when the process has no free file
    /// descriptors left).
    pub fn new() -> Self {
        // SAFETY: epoll_create1 only takes a flags argument and touches no Rust-owned memory.
        let epoll_fd = unsafe { epoll_create1(libc::EPOLL_CLOEXEC) };
        // 0 is a valid descriptor (e.g. when stdin was closed); only a negative result is an error.
        assert!(epoll_fd >= 0, "epoll_create()");
        Self {
            epoll_fd,
            watches: HashMap::new(),
        }
    }

    /// Adds a file descriptor to the watch list.
    ///
    /// `fd` itself is stored as the event's user data, which is how
    /// [`Poller::pull_events`] reports which descriptor became ready.
    ///
    /// # Errors
    ///
    /// Returns `SysError::last()` if `epoll_ctl(EPOLL_CTL_ADD)` fails (for instance when `fd`
    /// is already registered or does not support polling). The watch list is left unchanged
    /// on error.
    pub fn add(&mut self, fd: i32, events: Events) -> Result<(), SysError> {
        let mut ev = libc::epoll_event {
            events: events.into(),
            u64: fd as u64,
        };
        // SAFETY: `ev` is a fully initialized epoll_event that outlives the call; the kernel
        // copies it and keeps no reference to it.
        let rc = unsafe { epoll_ctl(self.epoll_fd, libc::EPOLL_CTL_ADD, fd, &mut ev) };
        if rc < 0 {
            return Err(SysError::last());
        }
        self.watches.insert(fd, events);
        Ok(())
    }

    /// Removes a file descriptor from the watch list.
    ///
    /// # Errors
    ///
    /// Returns an error built from `ENOENT` if `fd` was never added through this poller, or
    /// `SysError::last()` if `epoll_ctl(EPOLL_CTL_DEL)` fails. In the latter case the
    /// descriptor stays in the watch list.
    pub fn remove(&mut self, fd: i32) -> Result<(), SysError> {
        if !self.watches.contains_key(&fd) {
            return Err(SysError::from(libc::ENOENT));
        }
        // SAFETY: EPOLL_CTL_DEL does not read the event argument, so a null pointer is allowed.
        let rc =
            unsafe { epoll_ctl(self.epoll_fd, libc::EPOLL_CTL_DEL, fd, std::ptr::null_mut()) };
        if rc < 0 {
            return Err(SysError::last());
        }
        self.watches.remove(&fd);
        Ok(())
    }

    /// Pulls all detected I/O events.
    ///
    /// Waits up to `timeout_ms` milliseconds (`-1` waits indefinitely, `0` returns
    /// immediately) and returns the ready descriptors as `(fd, events)` pairs, at most one
    /// entry per registered descriptor. With nothing registered, the call simply waits out the
    /// timeout and returns an empty vector.
    ///
    /// # Errors
    ///
    /// Returns `SysError::last()` if `epoll_wait` fails, including when the wait is
    /// interrupted by a signal.
    ///
    /// # Examples
    ///
    /// ```
    /// let mut poller = Poller::new();
    /// poller.add(0, Events::new().with_read());
    /// for (fd, events) in poller.pull_events(1000).unwrap().iter() {
    ///     println!("Fd={}, Events={}", fd, events);
    /// }
    /// ```
    pub fn pull_events(&self, timeout_ms: i32) -> Result<Vec<(i32, Events)>, SysError> {
        // epoll_wait rejects maxevents <= 0 with EINVAL, so always offer at least one slot;
        // otherwise an empty poller could never wait at all.
        let max_events = i32::try_from(self.watches.len())
            .unwrap_or(i32::MAX)
            .max(1);
        let mut ready: Vec<libc::epoll_event> = Vec::with_capacity(max_events as usize);
        // SAFETY: `ready` has room for `max_events` entries, and epoll_wait writes at most
        // `max_events` of them.
        let nfds = unsafe { epoll_wait(self.epoll_fd, ready.as_mut_ptr(), max_events, timeout_ms) };
        if nfds < 0 {
            return Err(SysError::last());
        }
        // SAFETY: the kernel initialized exactly `nfds` (<= max_events) leading entries.
        unsafe { ready.set_len(nfds as usize) };
        Ok(ready
            .into_iter()
            .map(|ev| {
                // Copy the fields out by value: `epoll_event` may be a packed struct.
                let (data, mask) = (ev.u64, ev.events);
                (data as i32, Events::from(mask))
            })
            .collect())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises add / pull_events / remove against a procfs file.
    #[test]
    fn test_poller() {
        let path = std::ffi::CString::new("/proc/uptime").unwrap();
        // SAFETY: `path` is a valid NUL-terminated string that outlives the call.
        let fd = unsafe { libc::open(path.as_ptr(), libc::O_RDONLY) };
        assert!(fd >= 0, "failed to open /proc/uptime");

        let mut poller = Poller::new();
        assert!(poller.add(fd, Events::new().with_read()).is_ok());
        // The test expects this file to report readiness on every wait, so each call should
        // return exactly the registered descriptor.
        for _ in 0..1000 {
            let ready = poller.pull_events(1000).unwrap();
            assert_eq!(ready.len(), 1);
            assert_eq!(ready[0].0, fd);
        }
        assert!(poller.remove(fd).is_ok());

        // Registering and removing the same descriptor must be repeatable.
        for _ in 0..1000 {
            assert!(poller.add(fd, Events::new().with_read()).is_ok());
            assert!(poller.remove(fd).is_ok());
        }

        // SAFETY: `fd` was opened above and is not used after this point.
        unsafe {
            libc::close(fd);
        }
    }
}