#![deny(missing_docs)]
#![allow(clippy::new_without_default)]
#![allow(clippy::comparison_chain)]
use std::io;
use std::io::prelude::*;
use std::ops::Deref;
use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
use std::os::unix::net::UnixStream;
use std::time::Duration;
pub use interest::Interest;
/// Raw bitmask of `poll(2)` event flags, stored in the `events`/`revents` fields of a
/// [`Source`].
pub type Events = libc::c_short;
/// Interest bitmasks used when registering sources and querying poll results.
pub mod interest {
    /// A bitmask of `poll(2)` event flags.
    pub type Interest = super::Events;

    // Raw `poll(2)` flags that the public masks below are composed from.
    const POLLIN: Interest = libc::POLLIN;
    const POLLPRI: Interest = libc::POLLPRI;
    const POLLOUT: Interest = libc::POLLOUT;
    const POLLWRBAND: Interest = libc::POLLWRBAND;

    /// Read readiness: `POLLIN` or `POLLPRI`.
    pub const READ: Interest = POLLIN | POLLPRI;
    /// Write readiness: `POLLOUT` or `POLLWRBAND`.
    pub const WRITE: Interest = POLLOUT | POLLWRBAND;
    /// Both read and write readiness.
    pub const ALL: Interest = WRITE | READ;
    /// No read or write interest. Note that `poll(2)` may still report
    /// `POLLERR`, `POLLHUP` or `POLLNVAL` for such a descriptor.
    pub const NONE: Interest = 0;
}
/// A readiness event reported by `Sources::poll` for one registered source.
///
/// Dereferences to [`Source`], so readiness can be checked directly,
/// e.g. `event.is_readable()`.
#[derive(Debug)]
pub struct Event<K> {
    /// The key the source was registered under (a clone of it).
    pub key: K,
    /// A copy of the source at poll time, including the events returned by `poll(2)`.
    pub source: Source,
}
impl<K> Deref for Event<K> {
type Target = Source;
fn deref(&self) -> &Self::Target {
&self.source
}
}
/// How long a poll may wait for events.
///
/// `Timeout` is a small value type (it only wraps a `Duration`), so it is `Copy` and
/// comparable, following the Rust API guidelines on common traits.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Timeout {
    /// Wait at most the given duration.
    After(Duration),
    /// Wait indefinitely.
    Never,
}
impl Timeout {
pub fn from_secs(seconds: u32) -> Self {
Self::After(Duration::from_secs(seconds as u64))
}
pub fn from_millis(milliseconds: u32) -> Self {
Self::After(Duration::from_millis(milliseconds as u64))
}
}
impl From<Duration> for Timeout {
fn from(duration: Duration) -> Self {
Self::After(duration)
}
}
impl From<Option<Duration>> for Timeout {
fn from(duration: Option<Duration>) -> Self {
match duration {
Some(duration) => Self::from(duration),
None => Self::Never,
}
}
}
/// A registered file descriptor together with its requested and returned poll events.
///
/// The struct is `#[repr(C)]` with fields in the same order as `libc::pollfd`
/// (`fd`, `events`, `revents`). `Sources::poll` passes its `Vec<Source>` to `poll(2)`
/// as a `*mut libc::pollfd`, so the field order and types must not change.
#[repr(C)]
#[derive(Debug, Copy, Clone, Default)]
pub struct Source {
    /// The file descriptor being polled.
    fd: RawFd,
    /// Requested events (`pollfd::events`); modified via `set`/`unset`.
    events: Interest,
    /// Events returned by the most recent `poll(2)` call (`pollfd::revents`).
    revents: Interest,
}
impl Source {
    /// Creates a source for `fd` with the given interest and no returned events.
    fn new(fd: RawFd, events: Interest) -> Self {
        Self {
            fd,
            events,
            ..Self::default()
        }
    }

    /// Builds a `T` that takes ownership of this source's raw file descriptor.
    ///
    /// # Safety
    ///
    /// This calls `T::from_raw_fd`, so the requirements of `FromRawFd` apply. The
    /// descriptor must be open. The returned `T` will typically close it when dropped,
    /// so the caller must ensure nothing else also owns or closes it.
    pub unsafe fn raw<T: FromRawFd>(&self) -> T {
        T::from_raw_fd(self.fd)
    }

    /// Adds `events` to the set of requested events.
    pub fn set(&mut self, events: Interest) {
        self.events |= events;
    }

    /// Removes `events` from the set of requested events.
    pub fn unset(&mut self, events: Interest) {
        self.events &= !events;
    }

    /// The raw event bits returned by the most recent poll (`revents`).
    pub fn raw_events(&self) -> Events {
        self.revents
    }

    /// Whether any of the bits in `mask` were returned by the most recent poll.
    fn has_revents(self, mask: Interest) -> bool {
        self.revents & mask != 0
    }

    /// Whether the last poll reported write readiness (any bit of `interest::WRITE`).
    pub fn is_writable(self) -> bool {
        self.has_revents(interest::WRITE)
    }

    /// Whether the last poll reported read readiness (any bit of `interest::READ`).
    pub fn is_readable(self) -> bool {
        self.has_revents(interest::READ)
    }

    /// Whether the last poll reported a hangup (`POLLHUP`).
    pub fn is_hangup(self) -> bool {
        self.has_revents(libc::POLLHUP)
    }

    /// Whether the last poll reported an error condition (`POLLERR`).
    pub fn is_error(self) -> bool {
        self.has_revents(libc::POLLERR)
    }

    /// Whether the last poll reported an invalid descriptor (`POLLNVAL`).
    pub fn is_invalid(self) -> bool {
        self.has_revents(libc::POLLNVAL)
    }
}
impl AsRawFd for &Source {
fn as_raw_fd(&self) -> RawFd {
self.fd
}
}
impl AsRawFd for Source {
fn as_raw_fd(&self) -> RawFd {
self.fd
}
}
/// A set of file descriptors to poll, each identified by a key of type `K`.
#[derive(Debug, Clone)]
pub struct Sources<K> {
    /// Keys, kept parallel to `list`: `index[i]` identifies `list[i]`.
    index: Vec<K>,
    /// The sources, stored contiguously so the vector can be handed to `poll(2)` as-is.
    list: Vec<Source>,
}
impl<K> Sources<K> {
    /// Creates an empty set of sources.
    pub fn new() -> Self {
        Self::with_capacity(0)
    }

    /// Creates an empty set with room for `cap` sources before reallocating.
    pub fn with_capacity(cap: usize) -> Self {
        let index = Vec::with_capacity(cap);
        let list = Vec::with_capacity(cap);
        Self { index, list }
    }

    /// Number of registered sources.
    pub fn len(&self) -> usize {
        self.list.len()
    }

    /// Whether no sources are registered.
    pub fn is_empty(&self) -> bool {
        self.list.is_empty()
    }
}
impl<K: Clone + PartialEq> Sources<K> {
    /// Registers the descriptor of `fd` under `key` with the given interest.
    ///
    /// Only the raw descriptor is stored; `Sources` neither takes ownership of nor closes
    /// `fd`. Keys are not checked for uniqueness. Key-based operations (`unregister`,
    /// `set`, `get`, ...) act on the first matching registration.
    pub fn register(&mut self, key: K, fd: &impl AsRawFd, events: Interest) {
        self.insert(key, Source::new(fd.as_raw_fd(), events));
    }

    /// Removes the source registered under `key`, if any.
    ///
    /// Uses `swap_remove`, so the order of the remaining sources may change.
    pub fn unregister(&mut self, key: &K) {
        if let Some(ix) = self.find(key) {
            self.index.swap_remove(ix);
            self.list.swap_remove(ix);
        }
    }

    /// Adds `events` to the interest of the source registered under `key`.
    ///
    /// Returns `false` if no such source is registered.
    pub fn set(&mut self, key: &K, events: Interest) -> bool {
        if let Some(ix) = self.find(key) {
            self.list[ix].set(events);
            return true;
        }
        false
    }

    /// Removes `events` from the interest of the source registered under `key`.
    ///
    /// Returns `false` if no such source is registered.
    pub fn unset(&mut self, key: &K, events: Interest) -> bool {
        if let Some(ix) = self.find(key) {
            self.list[ix].unset(events);
            return true;
        }
        false
    }

    /// Returns the source registered under `key`, if any.
    ///
    /// This only reads, so it takes `&self`. Callers holding `&mut Sources` can still
    /// call it.
    pub fn get(&self, key: &K) -> Option<&Source> {
        self.find(key).map(|ix| &self.list[ix])
    }

    /// Returns a mutable reference to the source registered under `key`, if any.
    pub fn get_mut(&mut self, key: &K) -> Option<&mut Source> {
        self.find(key).map(move |ix| &mut self.list[ix])
    }

    /// Waits for readiness on the registered sources.
    ///
    /// For every source whose returned events are non-empty, an [`Event`] is appended to
    /// `events`. Existing contents of `events` are kept.
    ///
    /// `timeout` may be a [`Timeout`], a `Duration`, or an `Option<Duration>` (`None`
    /// waits indefinitely). Durations are rounded *up* to whole milliseconds, so a
    /// non-zero sub-millisecond timeout still waits rather than returning immediately.
    /// They are also capped at `c_int::MAX` milliseconds instead of wrapping to a
    /// negative (infinite) value.
    ///
    /// Returns the number of sources with non-empty returned events. Returns `Ok(0)`
    /// when no sources are registered.
    ///
    /// # Errors
    ///
    /// * `io::ErrorKind::TimedOut` if the timeout elapsed and at least one source is
    ///   registered.
    /// * The OS error from `poll(2)` if the call failed (e.g. `Interrupted` on `EINTR`).
    ///   No events are reported in that case.
    pub fn poll(
        &mut self,
        events: &mut impl Extend<Event<K>>,
        timeout: impl Into<Timeout>,
    ) -> Result<usize, io::Error> {
        // `self.list` is handed to the kernel as a `pollfd` array below, so `Source` must
        // have exactly the same size and alignment as `libc::pollfd`.
        const _: () = assert!(
            std::mem::size_of::<Source>() == std::mem::size_of::<libc::pollfd>()
                && std::mem::align_of::<Source>() == std::mem::align_of::<libc::pollfd>()
        );

        /// Converts a timeout into the millisecond argument expected by `poll(2)`.
        fn to_poll_millis(timeout: Timeout) -> libc::c_int {
            match timeout {
                Timeout::After(duration) => {
                    let mut millis = duration.as_millis();
                    // Round up so that e.g. 500µs waits 1ms instead of busy-polling.
                    if duration.subsec_nanos() % 1_000_000 != 0 {
                        millis += 1;
                    }
                    // Clamp instead of letting `as` wrap into a negative (infinite) wait.
                    millis.min(libc::c_int::MAX as u128) as libc::c_int
                }
                Timeout::Never => -1,
            }
        }

        let timeout = to_poll_millis(timeout.into());

        // SAFETY: `Source` is `#[repr(C)]` with the same layout as `libc::pollfd`
        // (checked at compile time above). The pointer and length describe
        // `self.list` exactly, and `self.list` is exclusively borrowed for the call.
        let result = unsafe {
            libc::poll(
                self.list.as_mut_ptr() as *mut libc::pollfd,
                self.list.len() as libc::nfds_t,
                timeout,
            )
        };

        if result < 0 {
            // Read errno before running any other code (user `Clone`/`Extend` impls,
            // allocation) that could overwrite it. Also skip reporting events, since
            // `revents` is unspecified when `poll(2)` fails.
            return Err(io::Error::last_os_error());
        }

        if result == 0 {
            return if self.is_empty() {
                Ok(0)
            } else {
                Err(io::ErrorKind::TimedOut.into())
            };
        }

        events.extend(
            self.index
                .iter()
                .zip(self.list.iter())
                .filter(|(_, s)| s.revents != 0)
                .map(|(key, source)| Event {
                    key: key.clone(),
                    source: *source,
                }),
        );

        Ok(result as usize)
    }

    /// Like [`Sources::poll`] with a bounded `timeout`.
    pub fn wait_timeout(
        &mut self,
        events: &mut impl Extend<Event<K>>,
        timeout: Duration,
    ) -> Result<usize, io::Error> {
        self.poll(events, timeout)
    }

    /// Like [`Sources::poll`], waiting indefinitely.
    pub fn wait(&mut self, events: &mut impl Extend<Event<K>>) -> Result<usize, io::Error> {
        self.poll(events, Timeout::Never)
    }

    /// Index of the first source registered under `key` (linear scan).
    fn find(&self, key: &K) -> Option<usize> {
        self.index.iter().position(|k| k == key)
    }

    /// Appends `source` under `key`, keeping `index` and `list` parallel.
    fn insert(&mut self, key: K, source: Source) {
        self.index.push(key);
        self.list.push(source);
    }
}
/// Wakes up a `Sources::poll` call, e.g. one blocked with `Timeout::Never`,
/// possibly from another thread.
///
/// Backed by a non-blocking `UnixStream` pair. `Waker::new` registers the read end with
/// the `Sources` under a caller-chosen key, and `wake` writes a byte to the other end.
pub struct Waker {
    /// Read end; its descriptor is registered for `interest::READ` and drained by `reset`.
    reader: UnixStream,
    /// Write end; `wake` writes a single byte to it.
    writer: UnixStream,
}
impl Waker {
    /// Creates a waker and registers its read end with `sources` under `key`.
    ///
    /// Both ends of the underlying socket pair are set to non-blocking mode.
    ///
    /// # Errors
    ///
    /// Returns any error from creating the socket pair or setting non-blocking mode.
    pub fn new<K: Eq + Clone>(sources: &mut Sources<K>, key: K) -> io::Result<Waker> {
        let (writer, reader) = UnixStream::pair()?;
        let fd = reader.as_raw_fd();

        reader.set_nonblocking(true)?;
        writer.set_nonblocking(true)?;

        sources.insert(key, Source::new(fd, interest::READ));

        Ok(Waker { reader, writer })
    }

    /// Makes the waker's source readable, waking up a pending poll.
    ///
    /// If the socket buffer is full (`WouldBlock`), the read end is drained with
    /// [`Waker::reset`] and the write is retried. Retries are done in a loop rather than by
    /// recursion, so repeated retries cannot grow the stack.
    ///
    /// # Errors
    ///
    /// Returns any write error other than `WouldBlock`/`Interrupted`, or an error from
    /// draining the read end.
    pub fn wake(&self) -> io::Result<()> {
        loop {
            match (&self.writer).write_all(&[0x1]) {
                Ok(()) => return Ok(()),
                Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
                    Waker::reset(self.reader.as_raw_fd())?;
                }
                // `write_all` already retries on `Interrupted`; kept as a defensive retry.
                Err(e) if e.kind() == io::ErrorKind::Interrupted => {}
                Err(e) => return Err(e),
            }
        }
    }

    /// Drains all pending bytes from `fd` so the waker's source is no longer readable.
    ///
    /// Reads until the descriptor reports `WouldBlock` or end-of-file. A read interrupted
    /// by a signal (`EINTR`) is retried instead of being reported as an error.
    ///
    /// # Errors
    ///
    /// Returns any other error from `read(2)`, e.g. when `fd` has been closed.
    pub fn reset(fd: impl AsRawFd) -> io::Result<()> {
        let fd = fd.as_raw_fd();
        let mut buf = [0u8; 4096];

        loop {
            // SAFETY: `buf` is a valid, writable buffer of exactly `buf.len()` bytes for the
            // duration of the call.
            let n = unsafe { libc::read(fd, buf.as_mut_ptr() as *mut libc::c_void, buf.len()) };

            match n {
                -1 => {
                    let err = io::Error::last_os_error();
                    match err.kind() {
                        io::ErrorKind::WouldBlock => return Ok(()),
                        io::ErrorKind::Interrupted => continue,
                        _ => return Err(err),
                    }
                }
                0 => return Ok(()),
                _ => continue,
            }
        }
    }
}
/// Sets (`nonblocking == true`) or clears the `O_NONBLOCK` flag on `fd`.
///
/// The other file status flags are preserved.
///
/// Returns the value returned by `fcntl(F_SETFL)`.
///
/// # Errors
///
/// Returns the OS error if either `fcntl` call fails.
pub fn set_nonblocking(fd: &dyn AsRawFd, nonblocking: bool) -> io::Result<i32> {
    let raw = fd.as_raw_fd();

    // SAFETY: `F_GETFL` takes no extra argument and only reads the descriptor's flags.
    let current = unsafe { libc::fcntl(raw, libc::F_GETFL) };
    if current == -1 {
        return Err(io::Error::last_os_error());
    }

    let updated = match nonblocking {
        true => current | libc::O_NONBLOCK,
        false => current & !libc::O_NONBLOCK,
    };

    // SAFETY: `F_SETFL` takes a single integer flags argument, which `updated` is.
    let result = unsafe { libc::fcntl(raw, libc::F_SETFL, updated) };
    if result == -1 {
        Err(io::Error::last_os_error())
    } else {
        Ok(result)
    }
}
#[cfg(test)]
mod tests {
    // Integration-style tests for `Sources`, `Timeout` and `Waker`, using `UnixStream`
    // pairs as pollable descriptors.
    use super::*;
    use std::io;
    use std::thread;
    use std::time::Duration;

    // Each reader is reported readable only after its own writer writes, and only that
    // reader is reported.
    #[test]
    fn test_readable() -> io::Result<()> {
        let (writer0, reader0) = UnixStream::pair()?;
        let (writer1, reader1) = UnixStream::pair()?;
        let (writer2, reader2) = UnixStream::pair()?;
        let mut events = Vec::new();
        let mut sources = Sources::new();
        for reader in &[&reader0, &reader1, &reader2] {
            reader.set_nonblocking(true)?;
        }
        sources.register("reader0", &reader0, interest::READ);
        sources.register("reader1", &reader1, interest::READ);
        sources.register("reader2", &reader2, interest::READ);
        {
            // Nothing has been written yet, so the poll must time out with no events.
            let err = sources
                .poll(&mut events, Timeout::from_millis(1))
                .unwrap_err();
            assert_eq!(err.kind(), io::ErrorKind::TimedOut);
            assert!(events.is_empty());
        }
        let tests = &mut [
            (&writer0, &reader0, "reader0", 0x1u8),
            (&writer1, &reader1, "reader1", 0x2u8),
            (&writer2, &reader2, "reader2", 0x3u8),
        ];
        for (mut writer, mut reader, key, byte) in tests.iter_mut() {
            let mut buf = [0u8; 1];
            // No data yet: the non-blocking read reports `WouldBlock`.
            assert!(matches!(
                reader.read(&mut buf[..]),
                Err(err) if err.kind() == io::ErrorKind::WouldBlock
            ));
            writer.write_all(&[*byte])?;
            events.clear();
            sources.poll(&mut events, Timeout::from_millis(1))?;
            assert!(!events.is_empty());
            // Exactly one event, for the reader that was just written to.
            let mut events = events.iter();
            let event = events.next().unwrap();
            assert_eq!(&event.key, key);
            assert!(
                event.is_readable()
                    && !event.is_writable()
                    && !event.is_error()
                    && !event.is_hangup()
            );
            assert!(events.next().is_none());
            // Consume the byte so this reader is not reported on the next iteration.
            assert_eq!(reader.read(&mut buf[..])?, 1);
            assert_eq!(&buf[..], &[*byte]);
        }
        Ok(())
    }

    // Polling with nothing registered is not an error.
    #[test]
    fn test_empty() -> io::Result<()> {
        let mut events: Vec<Event<()>> = Vec::new();
        let mut sources = Sources::new();
        sources
            .poll(&mut events, Timeout::from_millis(1))
            .expect("no error if nothing registered");
        assert!(events.is_empty());
        Ok(())
    }

    // A registered source that never becomes ready yields `TimedOut`.
    // NOTE(review): assumes stdout does not become readable within 1ms in the test
    // environment.
    #[test]
    fn test_timeout() -> io::Result<()> {
        let mut events = Vec::new();
        let mut sources = Sources::new();
        sources.register((), &io::stdout(), interest::READ);
        let err = sources
            .poll(&mut events, Timeout::from_millis(1))
            .unwrap_err();
        assert_eq!(sources.len(), 1);
        assert_eq!(err.kind(), io::ErrorKind::TimedOut);
        assert!(events.is_empty());
        Ok(())
    }

    // Writes arrive from another thread; the loop ends once every reader has reported a
    // hangup. The writers are moved into the thread and dropped when it returns.
    #[test]
    fn test_threaded() -> io::Result<()> {
        let (writer0, reader0) = UnixStream::pair()?;
        let (writer1, reader1) = UnixStream::pair()?;
        let (writer2, reader2) = UnixStream::pair()?;
        let mut events = Vec::new();
        let mut sources = Sources::new();
        let readers = &[&reader0, &reader1, &reader2];
        for reader in readers {
            reader.set_nonblocking(true)?;
        }
        sources.register("reader0", &reader0, interest::READ);
        sources.register("reader1", &reader1, interest::READ);
        sources.register("reader2", &reader2, interest::READ);
        let handle = thread::spawn(move || {
            // Delay so the main thread is likely already polling when data arrives.
            thread::sleep(Duration::from_millis(8));
            for writer in &mut [&writer1, &writer2, &writer0] {
                writer.write_all(&[1]).unwrap();
                writer.write_all(&[2]).unwrap();
            }
        });
        let mut closed = vec![];
        // NOTE(review): hung-up readers stay registered and are reported on every later
        // poll, so the same key may be pushed to `closed` more than once.
        while closed.len() < readers.len() {
            sources.poll(&mut events, Timeout::from_millis(64))?;
            for event in events.drain(..) {
                assert!(event.is_readable());
                assert!(!event.is_writable());
                assert!(!event.is_error());
                if event.is_hangup() {
                    closed.push(event.key.to_owned());
                    continue;
                }
                let mut buf = [0u8; 2];
                let mut reader = match event.key {
                    "reader0" => &reader0,
                    "reader1" => &reader1,
                    "reader2" => &reader2,
                    _ => unreachable!(),
                };
                let n = reader.read(&mut buf[..])?;
                assert_eq!(n, 2);
                assert_eq!(&buf[..], &[1, 2]);
            }
        }
        handle.join().unwrap();
        Ok(())
    }

    // Unregistered sources are no longer reported, and re-registering works.
    #[test]
    fn test_unregister() -> io::Result<()> {
        use std::collections::HashSet;
        let (mut writer0, reader0) = UnixStream::pair()?;
        let (mut writer1, reader1) = UnixStream::pair()?;
        let (writer2, reader2) = UnixStream::pair()?;
        let mut events = Vec::new();
        let mut sources = Sources::new();
        for reader in &[&reader0, &reader1, &reader2] {
            reader.set_nonblocking(true)?;
        }
        sources.register("reader0", &reader0, interest::READ);
        sources.register("reader1", &reader1, interest::READ);
        sources.register("reader2", &reader2, interest::READ);
        {
            let err = sources
                .poll(&mut events, Timeout::from_millis(1))
                .unwrap_err();
            assert_eq!(err.kind(), io::ErrorKind::TimedOut);
            assert!(events.is_empty());
        }
        {
            writer1.write_all(&[0x0])?;
            events.clear();
            sources.poll(&mut events, Timeout::from_millis(1))?;
            let event = events.first().unwrap();
            assert_eq!(event.key, "reader1");
        }
        {
            // After unregistering, data on reader1 must not produce an event
            // (the poll is expected to time out, hence `.ok()`).
            sources.unregister(&"reader1");
            writer1.write_all(&[0x0])?;
            events.clear();
            sources.poll(&mut events, Timeout::from_millis(1)).ok();
            assert!(events.first().is_none());
            for w in &mut [&writer0, &writer1, &writer2] {
                w.write_all(&[0])?;
            }
            sources.poll(&mut events, Timeout::from_millis(1))?;
            let keys = events.iter().map(|e| e.key).collect::<HashSet<_>>();
            assert!(keys.contains(&"reader0"));
            assert!(!keys.contains(&"reader1"));
            assert!(keys.contains(&"reader2"));
            sources.unregister(&"reader0");
            for w in &mut [&writer0, &writer1, &writer2] {
                w.write_all(&[0])?;
            }
            events.clear();
            sources.poll(&mut events, Timeout::from_millis(1))?;
            let keys = events.iter().map(|e| e.key).collect::<HashSet<_>>();
            assert!(!keys.contains(&"reader0"));
            assert!(!keys.contains(&"reader1"));
            assert!(keys.contains(&"reader2"));
            sources.unregister(&"reader2");
            for w in &mut [&writer0, &writer1, &writer2] {
                w.write_all(&[0])?;
            }
            events.clear();
            sources.poll(&mut events, Timeout::from_millis(1)).ok();
            assert!(events.is_empty());
        }
        {
            // Re-register reader0; it still holds unread data, so it is reported again.
            sources.register("reader0", &reader0, interest::READ);
            writer0.write_all(&[0])?;
            sources.poll(&mut events, Timeout::from_millis(1))?;
            let event = events.first().unwrap();
            assert_eq!(event.key, "reader0");
        }
        Ok(())
    }

    // `set`/`unset` toggle whether a source's readiness is reported.
    #[test]
    fn test_set() -> io::Result<()> {
        let (mut writer0, reader0) = UnixStream::pair()?;
        let (mut writer1, reader1) = UnixStream::pair()?;
        let mut events = Vec::new();
        let mut sources = Sources::new();
        for reader in &[&reader0, &reader1] {
            reader.set_nonblocking(true)?;
        }
        sources.register("reader0", &reader0, interest::READ);
        sources.register("reader1", &reader1, interest::NONE);
        {
            writer0.write_all(&[0])?;
            sources.poll(&mut events, Timeout::from_millis(1))?;
            let event = events.first().unwrap();
            assert_eq!(event.key, "reader0");
            // Dropping READ interest silences reader0 even though it has data.
            sources.unset(&event.key, interest::READ);
            writer0.write_all(&[0])?;
            events.clear();
            sources.poll(&mut events, Timeout::from_millis(1)).ok();
            assert!(events.first().is_none());
        }
        {
            // reader1 was registered with no interest, so it is not reported until READ
            // is set.
            writer1.write_all(&[0])?;
            sources.poll(&mut events, Timeout::from_millis(1)).ok();
            assert!(events.first().is_none());
            sources.set(&"reader1", interest::READ);
            writer1.write_all(&[0])?;
            sources.poll(&mut events, Timeout::from_millis(1))?;
            let event = events.first().unwrap();
            assert_eq!(event.key, "reader1");
        }
        Ok(())
    }

    // The waker becomes readable when its buffer has data, `wake` succeeds even on a full
    // buffer, repeated wakes coalesce into one event, and `reset` clears readiness.
    #[test]
    fn test_waker() -> io::Result<()> {
        let mut events = Vec::new();
        let mut sources = Sources::new();
        let mut waker = Waker::new(&mut sources, "waker")?;
        let buf = [0; 4096];
        sources.poll(&mut events, Timeout::from_millis(1)).ok();
        assert!(events.first().is_none());
        // Fill the socket buffer until writes would block.
        loop {
            match waker.writer.write(&buf) {
                Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
                    break;
                }
                Err(e) => return Err(e),
                _ => continue,
            }
        }
        sources.poll(&mut events, Timeout::from_millis(1))?;
        let event @ Event { key, .. } = events.first().unwrap();
        assert!(event.is_readable());
        assert!(!event.is_writable() && !event.is_hangup() && !event.is_error());
        assert_eq!(key, &"waker");
        // The buffer is full here, so `wake` has to drain it before writing.
        waker.wake()?;
        events.clear();
        sources.poll(&mut events, Timeout::from_millis(1))?;
        let event @ Event { key, .. } = events.first().unwrap();
        assert!(event.is_readable());
        assert_eq!(key, &"waker");
        waker.wake()?;
        waker.wake()?;
        waker.wake()?;
        events.clear();
        sources.poll(&mut events, Timeout::from_millis(1))?;
        assert_eq!(events.len(), 1, "multiple wakes count as one");
        let event @ Event { key, .. } = events.first().unwrap();
        assert_eq!(key, &"waker");
        // After draining, the waker is no longer readable and the poll times out.
        Waker::reset(&event.source).unwrap();
        let result = sources.poll(&mut events, Timeout::from_millis(1));
        assert!(
            matches!(
                result.err().map(|e| e.kind()),
                Some(io::ErrorKind::TimedOut)
            ),
            "the waker should only wake once"
        );
        Ok(())
    }

    // Stress test: a sender thread sends on a channel and wakes the poller; every message
    // must be received, with no more wake-ups than messages.
    #[test]
    fn test_waker_threaded() {
        let mut events = Vec::new();
        let mut sources = Sources::new();
        let waker = Waker::new(&mut sources, "waker").unwrap();
        let (tx, rx) = std::sync::mpsc::channel();
        let iterations = 100_000;
        let handle = std::thread::spawn(move || {
            for _ in 0..iterations {
                tx.send(()).unwrap();
                waker.wake().unwrap();
            }
        });
        let mut wakes = 0;
        let mut received = 0;
        while !handle.is_finished() {
            events.clear();
            let count = sources.poll(&mut events, Timeout::Never).unwrap();
            if count > 0 {
                let event = events.pop().unwrap();
                assert_eq!(event.key, "waker");
                assert!(events.is_empty());
                rx.recv().unwrap();
                received += 1;
                while rx.try_recv().is_ok() {
                    received += 1;
                }
                if received == iterations {
                    // Expects the sender thread to have dropped the waker (closing its
                    // descriptors), so draining fails.
                    // NOTE(review): this depends on timing; the thread may still be inside
                    // its final `wake()` here — potentially flaky.
                    Waker::reset(event.source).unwrap_err();
                    break;
                }
                Waker::reset(event.source).ok();
                wakes += 1;
            }
        }
        handle.join().unwrap();
        assert_eq!(received, iterations);
        assert!(wakes <= received);
    }
}