#![deny(missing_docs)]
#![allow(clippy::new_without_default)]
#![allow(clippy::comparison_chain)]
use std::io;
use std::io::prelude::*;
use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
use std::os::unix::net::UnixStream;
use std::time;
pub use interest::Interest;
/// Readiness interests that a source can be registered with.
///
/// Every constant is a bitmask of `poll(2)` event flags; masks combine with `|`.
pub mod interest {
    // Short private aliases for the raw `libc` flags composed below.
    const POLLIN: Interest = libc::POLLIN;
    const POLLPRI: Interest = libc::POLLPRI;
    const POLLOUT: Interest = libc::POLLOUT;

    /// A bitmask of `poll(2)` event flags (the type of `pollfd.events`).
    pub type Interest = libc::c_short;

    /// No readiness interest at all.
    pub const NONE: Interest = 0;
    /// Interest in incoming data, including priority data (`POLLIN | POLLPRI`).
    pub const READ: Interest = POLLIN | POLLPRI;
    /// Interest in writability, including the priority band (`POLLOUT | POLLWRBAND`).
    pub const WRITE: Interest = POLLOUT | libc::POLLWRBAND;
    /// Interest in both reading and writing.
    pub const ALL: Interest = READ | WRITE;
}
/// A readiness event for one registered source, decoded from the `revents` that `poll(2)`
/// reported for it.
#[derive(Debug)]
pub struct Event<'a> {
    /// Any of the [`interest::WRITE`] flags was reported.
    pub writable: bool,
    /// Any of the [`interest::READ`] flags was reported.
    pub readable: bool,
    /// `POLLHUP` (hang-up) was reported.
    pub hangup: bool,
    /// `POLLERR` (error condition) was reported.
    pub errored: bool,
    /// `POLLNVAL` (invalid file descriptor) was reported.
    pub invalid: bool,
    /// The registered source this event was reported for.
    pub source: &'a Source,
}
impl<'a> Event<'a> {
pub fn source<T: FromRawFd>(&self) -> T {
unsafe { T::from_raw_fd(self.source.fd) }
}
pub fn is_err(&self) -> bool {
self.errored || self.invalid
}
}
impl<'a> From<&'a Source> for Event<'a> {
fn from(source: &'a Source) -> Self {
let revents = source.revents;
Self {
readable: revents & interest::READ != 0,
writable: revents & interest::WRITE != 0,
hangup: revents & libc::POLLHUP != 0,
errored: revents & libc::POLLERR != 0,
invalid: revents & libc::POLLNVAL != 0,
source,
}
}
}
/// The result of waiting on a set of sources: a snapshot of the polled sources and how many
/// of them became ready.
#[derive(Debug)]
pub struct Events<K> {
    // Ready count as returned by the last successful `poll(2)` call (0 after a timeout).
    count: usize,
    // Snapshot of the registered sources handed to `poll(2)`; their `revents` hold the result.
    sources: Sources<K>,
}
impl<K: Eq + Clone> Events<K> {
pub fn new() -> Self {
Self {
count: 0,
sources: Sources::new(),
}
}
pub fn with_capacity(cap: usize) -> Self {
Self {
count: 0,
sources: Sources::with_capacity(cap),
}
}
pub fn iter<'a>(&'a self) -> impl Iterator<Item = (&'a K, Event<'a>)> + 'a {
self.sources
.index
.iter()
.zip(self.sources.list.iter())
.filter(|(_, d)| d.revents != 0)
.map(|(key, source)| (key, Event::from(source)))
}
pub fn is_empty(&self) -> bool {
self.count == 0
}
pub fn len(&self) -> usize {
self.count
}
fn initialize(&mut self, sources: Sources<K>) {
self.count = 0;
self.sources = sources;
}
}
/// A registered file descriptor, its requested interest, and the events last reported for it.
///
/// This is `#[repr(C)]` with fields in exactly the order of `libc::pollfd` (`fd`, `events`,
/// `revents`). A buffer of sources is passed to `poll(2)` as a `pollfd` array, so the fields
/// must not be reordered, added, or removed.
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct Source {
    // The polled file descriptor (`pollfd.fd`).
    fd: RawFd,
    // Requested events (`pollfd.events`).
    events: Interest,
    // Events reported by the last `poll(2)` call (`pollfd.revents`).
    revents: Interest,
}
impl Source {
fn new(fd: RawFd, events: Interest) -> Self {
Self {
fd,
events,
revents: 0,
}
}
pub fn set(&mut self, events: Interest) {
self.events |= events;
}
pub fn unset(&mut self, events: Interest) {
self.events &= !events;
}
}
/// A registry of file descriptors to poll, each identified by a key of type `K`.
#[derive(Debug, Clone)]
pub struct Sources<K> {
    // Keys, parallel to `list`: `index[i]` is the key of `list[i]`.
    index: Vec<K>,
    // Sources laid out as a `pollfd` array for `poll(2)`.
    list: Vec<Source>,
}
impl<K: Eq + Clone> Sources<K> {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            index: vec![],
            list: vec![],
        }
    }

    /// Creates an empty registry with room for `cap` sources before reallocating.
    pub fn with_capacity(cap: usize) -> Self {
        Self {
            index: Vec::with_capacity(cap),
            list: Vec::with_capacity(cap),
        }
    }

    /// Returns the number of registered sources.
    pub fn len(&self) -> usize {
        self.list.len()
    }

    /// Returns `true` if no sources are registered.
    pub fn is_empty(&self) -> bool {
        self.list.is_empty()
    }

    /// Registers the file descriptor of `fd` under `key` with the given interest.
    ///
    /// Only the raw descriptor is stored, so the object behind `fd` must outlive the
    /// registration. Keys are not checked for uniqueness. If a key is registered twice, the
    /// key-based methods ([`Sources::set`], [`Sources::unset`], [`Sources::get_mut`],
    /// [`Sources::unregister`]) act on the first matching entry.
    pub fn register(&mut self, key: K, fd: &impl AsRawFd, events: Interest) {
        self.insert(key, Source::new(fd.as_raw_fd(), events));
    }

    /// Removes the first source registered under `key`, if any.
    ///
    /// Uses `swap_remove`, so the order of the remaining sources may change.
    pub fn unregister(&mut self, key: &K) {
        if let Some(ix) = self.find(key) {
            self.index.swap_remove(ix);
            self.list.swap_remove(ix);
        }
    }

    /// Adds `events` to the interest of the source registered under `key`.
    ///
    /// Returns `false` if no source is registered under `key`.
    pub fn set(&mut self, key: &K, events: Interest) -> bool {
        if let Some(ix) = self.find(key) {
            self.list[ix].set(events);
            return true;
        }
        false
    }

    /// Removes `events` from the interest of the source registered under `key`.
    ///
    /// Returns `false` if no source is registered under `key`.
    pub fn unset(&mut self, key: &K, events: Interest) -> bool {
        if let Some(ix) = self.find(key) {
            self.list[ix].unset(events);
            return true;
        }
        false
    }

    /// Returns a mutable reference to the first source registered under `key`.
    pub fn get_mut(&mut self, key: &K) -> Option<&mut Source> {
        self.find(key).map(move |ix| &mut self.list[ix])
    }

    /// Waits until at least one registered source is ready, or until `timeout` elapses.
    ///
    /// `events` is overwritten with a snapshot of the registered sources and what `poll(2)`
    /// reported for them.
    ///
    /// `poll(2)` has millisecond resolution. `timeout` is rounded *up* to the next whole
    /// millisecond, so a non-zero timeout never degrades into a non-blocking poll. It is also
    /// capped at `c_int::MAX` milliseconds instead of wrapping around.
    ///
    /// # Errors
    ///
    /// * [`io::ErrorKind::TimedOut`] if the timeout elapsed with no source ready while at
    ///   least one source is registered. With nothing registered, the call returns `Ok(())`.
    /// * The OS error reported by `poll(2)` otherwise, for example `Interrupted` on `EINTR`.
    pub fn wait_timeout(
        &mut self,
        events: &mut Events<K>,
        timeout: time::Duration,
    ) -> Result<(), io::Error> {
        let timeout = Self::timeout_millis(timeout);
        events.initialize(self.clone());

        let result = self.poll(events, timeout);
        if result == 0 {
            if self.is_empty() {
                Ok(())
            } else {
                Err(io::ErrorKind::TimedOut.into())
            }
        } else if result > 0 {
            events.count = result as usize;
            Ok(())
        } else {
            Err(io::Error::last_os_error())
        }
    }

    /// Waits, without a timeout, until at least one registered source is ready.
    ///
    /// `events` is overwritten with a snapshot of the registered sources and what `poll(2)`
    /// reported for them.
    ///
    /// # Errors
    ///
    /// Returns the OS error reported by `poll(2)`, for example `Interrupted` on `EINTR`.
    pub fn wait(&mut self, events: &mut Events<K>) -> Result<(), io::Error> {
        events.initialize(self.clone());

        let result = self.poll(events, -1);
        if result < 0 {
            Err(io::Error::last_os_error())
        } else {
            events.count = result as usize;
            Ok(())
        }
    }

    /// Calls `poll(2)` on the snapshot held by `events` and returns its raw result.
    fn poll(&mut self, events: &mut Events<K>, timeout: i32) -> i32 {
        // SAFETY: `Source` is `#[repr(C)]` with the same field layout as `libc::pollfd`, so the
        // vector's buffer is a valid array of `len()` pollfds for the duration of the call.
        unsafe {
            libc::poll(
                events.sources.list.as_mut_ptr() as *mut libc::pollfd,
                events.sources.list.len() as libc::nfds_t,
                timeout,
            )
        }
    }

    /// Returns the position of the first source registered under `key`.
    fn find(&self, key: &K) -> Option<usize> {
        self.index.iter().position(|k| k == key)
    }

    /// Appends `source` under `key`, keeping `index` and `list` parallel.
    fn insert(&mut self, key: K, source: Source) {
        self.index.push(key);
        self.list.push(source);
    }

    /// Converts `timeout` into a `poll(2)` timeout in milliseconds.
    ///
    /// The value is rounded up to the next whole millisecond and saturates at `c_int::MAX`.
    fn timeout_millis(timeout: time::Duration) -> libc::c_int {
        // Adding just under 1ms before truncating to whole milliseconds rounds up. The add can
        // only overflow near `Duration::MAX`, where the un-rounded value saturates anyway.
        let millis = timeout
            .checked_add(time::Duration::from_nanos(999_999))
            .unwrap_or(timeout)
            .as_millis();
        // A plain `as` cast would truncate, possibly to a negative value, which `poll(2)`
        // interprets as "wait forever".
        if millis > libc::c_int::MAX as u128 {
            libc::c_int::MAX
        } else {
            millis as libc::c_int
        }
    }
}
/// Wakes up a [`Sources::wait`] / [`Sources::wait_timeout`] call.
///
/// Backed by a non-blocking Unix socket pair. The reading end is registered as a source with
/// read interest, and waking writes a byte to the other end.
pub struct Waker {
    // Reading end; its descriptor is the one registered in `Sources`.
    reader: UnixStream,
    // Writing end; a byte written here makes `reader` readable.
    writer: UnixStream,
}
impl Waker {
    /// Creates a waker and registers its reading end in `sources` under `key` with
    /// [`interest::READ`].
    ///
    /// Both ends of the underlying socket pair are put in non-blocking mode.
    ///
    /// # Errors
    ///
    /// Fails if the socket pair cannot be created or switched to non-blocking mode.
    pub fn new<K: Eq + Clone>(sources: &mut Sources<K>, key: K) -> io::Result<Waker> {
        let (writer, reader) = UnixStream::pair()?;
        let fd = reader.as_raw_fd();
        reader.set_nonblocking(true)?;
        writer.set_nonblocking(true)?;
        sources.insert(key, Source::new(fd, interest::READ));
        Ok(Waker { reader, writer })
    }

    /// Wakes the waiter by making the registered reading end readable.
    ///
    /// Wakes coalesce: several calls before the waiter drains the waker with
    /// [`Waker::reset`] are observed as one readable source. If the socket buffer is full of
    /// unconsumed wakes, they are drained first so the write can succeed.
    ///
    /// # Errors
    ///
    /// Returns any I/O error from the socket other than `WouldBlock` and `Interrupted`.
    pub fn wake(&self) -> io::Result<()> {
        // A loop rather than recursion, so retries never grow the stack.
        loop {
            match (&self.writer).write_all(&[0x1]) {
                Ok(()) => return Ok(()),
                // The buffer is full of pending wakes: discard them and try again.
                Err(e) if e.kind() == io::ErrorKind::WouldBlock => self.unblock()?,
                Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
                Err(e) => return Err(e),
            }
        }
    }

    /// Drains all pending wakes so the waker's source stops being reported as readable.
    ///
    /// Call this on the waiting side after handling a wake event. Otherwise the reading end
    /// keeps its buffered bytes, stays readable, and every later wait returns immediately.
    ///
    /// # Errors
    ///
    /// Returns any I/O error from reading the socket other than `WouldBlock`.
    pub fn reset(&self) -> io::Result<()> {
        self.unblock()
    }

    /// Reads and discards everything currently buffered in the reading end.
    fn unblock(&self) -> io::Result<()> {
        let mut buf = [0; 4096];
        loop {
            match (&self.reader).read(&mut buf) {
                // The writing end is closed and nothing is left to read.
                Ok(0) => return Ok(()),
                Ok(_) => continue,
                // Nothing more is buffered right now.
                Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Ok(()),
                Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
                Err(e) => return Err(e),
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io;
    use std::thread;
    use std::time::Duration;

    #[test]
    fn test_readable() -> io::Result<()> {
        let (writer0, reader0) = UnixStream::pair()?;
        let (writer1, reader1) = UnixStream::pair()?;
        let (writer2, reader2) = UnixStream::pair()?;
        let mut events = Events::new();
        let mut sources = Sources::new();
        for reader in &[&reader0, &reader1, &reader2] {
            reader.set_nonblocking(true)?;
        }
        sources.register("reader0", &reader0, interest::READ);
        sources.register("reader1", &reader1, interest::READ);
        sources.register("reader2", &reader2, interest::READ);
        {
            let err = sources
                .wait_timeout(&mut events, Duration::from_millis(1))
                .unwrap_err();
            assert_eq!(err.kind(), io::ErrorKind::TimedOut);
            assert!(events.is_empty());
        }
        let tests = [
            (&writer0, &reader0, "reader0", 0x1_u8),
            (&writer1, &reader1, "reader1", 0x2_u8),
            (&writer2, &reader2, "reader2", 0x3_u8),
        ];
        // Destructure by value (all fields are `Copy`) so the pattern is valid in every edition.
        for &(mut writer, mut reader, key, byte) in tests.iter() {
            let mut buf = [0u8; 1];
            assert!(matches!(
                reader.read(&mut buf[..]),
                Err(err) if err.kind() == io::ErrorKind::WouldBlock
            ));
            writer.write_all(&[byte])?;
            sources.wait_timeout(&mut events, Duration::from_millis(1))?;
            assert!(!events.is_empty());
            let mut ready = events.iter();
            let (k, event) = ready.next().unwrap();
            assert_eq!(k, &key);
            assert!(event.readable && !event.writable && !event.errored && !event.hangup);
            assert!(ready.next().is_none());
            assert_eq!(reader.read(&mut buf[..])?, 1);
            assert_eq!(&buf[..], &[byte]);
        }
        Ok(())
    }

    #[test]
    fn test_empty() -> io::Result<()> {
        let mut events: Events<()> = Events::new();
        let mut sources = Sources::new();
        sources
            .wait_timeout(&mut events, time::Duration::from_millis(1))
            .expect("no error if nothing registered");
        assert!(events.is_empty());
        Ok(())
    }

    #[test]
    fn test_timeout() -> io::Result<()> {
        let mut events = Events::new();
        let mut sources = Sources::new();
        // Poll a socket nobody writes to rather than stdin. When the harness runs with stdin
        // redirected (e.g. from /dev/null in CI), stdin is immediately readable and the wait
        // would not time out. `_writer` stays alive so the reader does not see a hang-up.
        let (_writer, reader) = UnixStream::pair()?;
        sources.register((), &reader, interest::READ);
        let err = sources
            .wait_timeout(&mut events, Duration::from_millis(1))
            .unwrap_err();
        assert_eq!(sources.len(), 1);
        assert_eq!(err.kind(), io::ErrorKind::TimedOut);
        assert!(events.is_empty());
        Ok(())
    }

    #[test]
    fn test_threaded() -> io::Result<()> {
        let (writer0, reader0) = UnixStream::pair()?;
        let (writer1, reader1) = UnixStream::pair()?;
        let (writer2, reader2) = UnixStream::pair()?;
        let mut events = Events::new();
        let mut sources = Sources::new();
        let readers = &[&reader0, &reader1, &reader2];
        for reader in readers {
            reader.set_nonblocking(true)?;
        }
        sources.register("reader0", &reader0, interest::READ);
        sources.register("reader1", &reader1, interest::READ);
        sources.register("reader2", &reader2, interest::READ);
        let handle = thread::spawn(move || {
            thread::sleep(Duration::from_millis(8));
            for writer in &mut [&writer1, &writer2, &writer0] {
                // Send both bytes in a single write so the reader can never wake up and
                // observe only the first one.
                writer.write_all(&[1, 2]).unwrap();
            }
        });
        let mut closed = vec![];
        while closed.len() < readers.len() {
            sources.wait_timeout(&mut events, Duration::from_millis(64))?;
            for (key, event) in events.iter() {
                assert!(event.readable);
                assert!(!event.writable);
                assert!(!event.errored);
                if event.hangup {
                    closed.push(key.clone());
                    continue;
                }
                let mut buf = [0u8; 2];
                let mut reader = match key {
                    &"reader0" => &reader0,
                    &"reader1" => &reader1,
                    &"reader2" => &reader2,
                    _ => unreachable!(),
                };
                let n = reader.read(&mut buf[..])?;
                assert_eq!(n, 2);
                assert_eq!(&buf[..], &[1, 2]);
            }
        }
        handle.join().unwrap();
        Ok(())
    }

    #[test]
    fn test_unregister() -> io::Result<()> {
        use std::collections::HashSet;

        let (mut writer0, reader0) = UnixStream::pair()?;
        let (mut writer1, reader1) = UnixStream::pair()?;
        let (writer2, reader2) = UnixStream::pair()?;
        let mut events = Events::new();
        let mut sources = Sources::new();
        for reader in &[&reader0, &reader1, &reader2] {
            reader.set_nonblocking(true)?;
        }
        sources.register("reader0", &reader0, interest::READ);
        sources.register("reader1", &reader1, interest::READ);
        sources.register("reader2", &reader2, interest::READ);
        {
            let err = sources
                .wait_timeout(&mut events, Duration::from_millis(1))
                .unwrap_err();
            assert_eq!(err.kind(), io::ErrorKind::TimedOut);
            assert!(events.is_empty());
        }
        {
            writer1.write_all(&[0x0])?;
            sources.wait_timeout(&mut events, Duration::from_millis(1))?;
            let (key, _) = events.iter().next().unwrap();
            assert_eq!(key, &"reader1");
        }
        {
            sources.unregister(&"reader1");
            writer1.write_all(&[0x0])?;
            sources
                .wait_timeout(&mut events, Duration::from_millis(1))
                .ok();
            assert!(events.iter().next().is_none());
            for w in &mut [&writer0, &writer1, &writer2] {
                w.write_all(&[0])?;
            }
            sources.wait_timeout(&mut events, Duration::from_millis(1))?;
            let keys = events.iter().map(|(k, _)| k).collect::<HashSet<_>>();
            assert!(keys.contains(&"reader0"));
            assert!(!keys.contains(&"reader1"));
            assert!(keys.contains(&"reader2"));
            sources.unregister(&"reader0");
            for w in &mut [&writer0, &writer1, &writer2] {
                w.write_all(&[0])?;
            }
            sources.wait_timeout(&mut events, Duration::from_millis(1))?;
            let keys = events.iter().map(|(k, _)| k).collect::<HashSet<_>>();
            assert!(!keys.contains(&"reader0"));
            assert!(!keys.contains(&"reader1"));
            assert!(keys.contains(&"reader2"));
            sources.unregister(&"reader2");
            for w in &mut [&writer0, &writer1, &writer2] {
                w.write_all(&[0])?;
            }
            sources
                .wait_timeout(&mut events, Duration::from_millis(1))
                .ok();
            assert!(events.is_empty());
        }
        {
            sources.register("reader0", &reader0, interest::READ);
            writer0.write_all(&[0])?;
            sources.wait_timeout(&mut events, Duration::from_millis(1))?;
            let (key, _) = events.iter().next().unwrap();
            assert_eq!(key, &"reader0");
        }
        Ok(())
    }

    #[test]
    fn test_set() -> io::Result<()> {
        let (mut writer0, reader0) = UnixStream::pair()?;
        let (mut writer1, reader1) = UnixStream::pair()?;
        let mut events = Events::new();
        let mut sources = Sources::new();
        for reader in &[&reader0, &reader1] {
            reader.set_nonblocking(true)?;
        }
        sources.register("reader0", &reader0, interest::READ);
        sources.register("reader1", &reader1, interest::NONE);
        {
            writer0.write_all(&[0])?;
            sources.wait_timeout(&mut events, Duration::from_millis(1))?;
            let (key, _) = events.iter().next().unwrap();
            assert_eq!(key, &"reader0");
            sources.unset(key, interest::READ);
            writer0.write_all(&[0])?;
            sources
                .wait_timeout(&mut events, Duration::from_millis(1))
                .ok();
            assert!(events.iter().next().is_none());
        }
        {
            writer1.write_all(&[0])?;
            sources
                .wait_timeout(&mut events, Duration::from_millis(1))
                .ok();
            assert!(events.iter().next().is_none());
            sources.set(&"reader1", interest::READ);
            writer1.write_all(&[0])?;
            sources.wait_timeout(&mut events, Duration::from_millis(1))?;
            let (key, _) = events.iter().next().unwrap();
            assert_eq!(key, &"reader1");
        }
        Ok(())
    }

    #[test]
    fn test_waker() -> io::Result<()> {
        let mut events = Events::new();
        let mut sources = Sources::new();
        let mut waker = Waker::new(&mut sources, "waker")?;
        let buf = [0; 4096];
        sources
            .wait_timeout(&mut events, Duration::from_millis(1))
            .ok();
        assert!(events.iter().next().is_none());
        // Fill the socket buffer until it would block; partial writes are expected here.
        loop {
            match waker.writer.write(&buf) {
                Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
                    break;
                }
                Err(e) => return Err(e),
                _ => continue,
            }
        }
        sources.wait_timeout(&mut events, Duration::from_millis(1))?;
        let (key, event) = events.iter().next().unwrap();
        assert!(event.readable);
        assert!(!event.writable && !event.hangup && !event.errored);
        assert_eq!(key, &"waker");
        waker.wake()?;
        sources.wait_timeout(&mut events, Duration::from_millis(1))?;
        let (key, event) = events.iter().next().unwrap();
        assert!(event.readable);
        assert_eq!(key, &"waker");
        waker.wake()?;
        waker.wake()?;
        waker.wake()?;
        sources.wait_timeout(&mut events, Duration::from_millis(1))?;
        assert_eq!(events.iter().count(), 1, "multiple wakes count as one");
        Ok(())
    }
}