1#![deny(missing_docs)]
45#![allow(clippy::new_without_default)]
46#![allow(clippy::comparison_chain)]
47use std::io;
48use std::io::prelude::*;
49use std::ops::Deref;
50use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
51use std::os::unix::net::UnixStream;
52use std::time::Duration;
53
54pub use interest::Interest;
55
/// Raw `poll(2)` event flags: the type of the `events` and `revents` fields of `libc::pollfd`.
pub type Events = libc::c_short;
58
59pub mod interest {
61 pub type Interest = super::Events;
63
64 pub const READ: Interest = POLLIN | POLLPRI;
66 pub const WRITE: Interest = POLLOUT | libc::POLLWRBAND;
68 pub const ALL: Interest = READ | WRITE;
70 pub const NONE: Interest = 0x0;
72
73 const POLLIN: Interest = libc::POLLIN;
78 const POLLPRI: Interest = libc::POLLPRI;
80 const POLLOUT: Interest = libc::POLLOUT;
82}
83
84#[derive(Debug)]
86pub struct Event<K> {
87 pub key: K,
89 pub source: Source,
91}
92
93impl<K> Deref for Event<K> {
94 type Target = Source;
95
96 fn deref(&self) -> &Self::Target {
97 &self.source
98 }
99}
100
/// How long [`Sources::poll`] may block while waiting for an event.
#[derive(Debug, Clone)]
pub enum Timeout {
    /// Give up once this much time has passed without an event.
    After(Duration),
    /// Wait for an event indefinitely.
    Never,
}

impl Timeout {
    /// A timeout of `seconds` whole seconds.
    pub fn from_secs(seconds: u32) -> Self {
        Duration::from_secs(u64::from(seconds)).into()
    }

    /// A timeout of `milliseconds` milliseconds.
    pub fn from_millis(milliseconds: u32) -> Self {
        Duration::from_millis(u64::from(milliseconds)).into()
    }
}

/// Wraps a duration as [`Timeout::After`].
impl From<Duration> for Timeout {
    fn from(duration: Duration) -> Self {
        Timeout::After(duration)
    }
}

/// Maps `Some(duration)` to [`Timeout::After`] and `None` to [`Timeout::Never`].
impl From<Option<Duration>> for Timeout {
    fn from(duration: Option<Duration>) -> Self {
        duration.map_or(Timeout::Never, Timeout::After)
    }
}
149
150#[repr(C)]
152#[derive(Debug, Copy, Clone, Default)]
153pub struct Source {
154 fd: RawFd,
155 events: Interest,
156 revents: Interest,
157}
158
159impl Source {
160 fn new(fd: impl AsRawFd, events: Interest) -> Self {
161 Self {
162 fd: fd.as_raw_fd(),
163 events,
164 revents: 0,
165 }
166 }
167
168 pub unsafe fn raw<T: FromRawFd>(&self) -> T {
175 T::from_raw_fd(self.fd)
176 }
177
178 pub fn set(&mut self, events: Interest) {
180 self.events |= events;
181 }
182
183 pub fn unset(&mut self, events: Interest) {
185 self.events &= !events;
186 }
187
188 pub fn raw_events(&self) -> Events {
190 self.revents
191 }
192
193 pub fn is_writable(self) -> bool {
195 self.revents & interest::WRITE != 0
196 }
197
198 pub fn is_readable(self) -> bool {
200 self.revents & interest::READ != 0
201 }
202
203 pub fn is_hangup(self) -> bool {
205 self.revents & libc::POLLHUP != 0
206 }
207
208 pub fn is_error(self) -> bool {
213 self.revents & libc::POLLERR != 0
214 }
215
216 pub fn is_invalid(self) -> bool {
218 self.revents & libc::POLLNVAL != 0
219 }
220}
221
222impl AsRawFd for &Source {
223 fn as_raw_fd(&self) -> RawFd {
224 self.fd
225 }
226}
227
228impl AsRawFd for Source {
229 fn as_raw_fd(&self) -> RawFd {
230 self.fd
231 }
232}
233
234#[derive(Debug, Clone)]
236pub struct Sources<K> {
237 index: Vec<K>,
239 list: Vec<Source>,
241}
242
243impl<K> Sources<K> {
244 pub fn new() -> Self {
246 Self {
247 index: vec![],
248 list: vec![],
249 }
250 }
251
252 pub fn with_capacity(cap: usize) -> Self {
255 Self {
256 index: Vec::with_capacity(cap),
257 list: Vec::with_capacity(cap),
258 }
259 }
260
261 pub fn len(&self) -> usize {
263 self.list.len()
264 }
265
266 pub fn is_empty(&self) -> bool {
268 self.list.is_empty()
269 }
270}
271
272impl<S: AsRawFd, K: PartialEq + Eq + Clone> FromIterator<(K, S, Interest)> for Sources<K> {
273 fn from_iter<T: IntoIterator<Item = (K, S, Interest)>>(iter: T) -> Self {
274 let mut sources = Sources::new();
275 for (key, source, interest) in iter {
276 sources.register(key, &source, interest);
277 }
278 sources
279 }
280}
281
282impl<K: Clone + PartialEq> Sources<K> {
283 pub fn register(&mut self, key: K, fd: &impl AsRawFd, events: Interest) {
288 self.insert(key, Source::new(fd.as_raw_fd(), events));
289 }
290
291 pub fn unregister(&mut self, key: &K) {
293 if let Some(ix) = self.find(key) {
294 self.index.swap_remove(ix);
295 self.list.swap_remove(ix);
296 }
297 }
298
299 pub fn set(&mut self, key: &K, events: Interest) -> bool {
301 if let Some(ix) = self.find(key) {
302 self.list[ix].set(events);
303 return true;
304 }
305 false
306 }
307
308 pub fn unset(&mut self, key: &K, events: Interest) -> bool {
310 if let Some(ix) = self.find(key) {
311 self.list[ix].unset(events);
312 return true;
313 }
314 false
315 }
316
317 pub fn get(&mut self, key: &K) -> Option<&Source> {
319 self.find(key).map(move |ix| &self.list[ix])
320 }
321
322 pub fn get_mut(&mut self, key: &K) -> Option<&mut Source> {
324 self.find(key).map(move |ix| &mut self.list[ix])
325 }
326
327 pub fn poll(
336 &mut self,
337 events: &mut impl Extend<Event<K>>,
338 timeout: impl Into<Timeout>,
339 ) -> Result<usize, io::Error> {
340 let timeout = match timeout.into() {
341 Timeout::After(duration) => duration.as_millis() as libc::c_int,
342 Timeout::Never => -1,
343 };
344 if self.list.is_empty() {
346 return Ok(0);
347 }
348
349 loop {
350 let result = unsafe {
352 libc::poll(
353 self.list.as_mut_ptr() as *mut libc::pollfd,
354 self.list.len() as libc::nfds_t,
355 timeout,
356 )
357 };
358
359 events.extend(
360 self.index
361 .iter()
362 .zip(self.list.iter())
363 .filter(|(_, s)| s.revents != 0)
364 .map(|(key, source)| Event {
365 key: key.clone(),
366 source: *source,
367 }),
368 );
369
370 if result == 0 {
371 if self.is_empty() {
372 return Ok(0);
373 } else {
374 return Err(io::ErrorKind::TimedOut.into());
375 }
376 } else if result > 0 {
377 return Ok(result as usize);
378 } else {
379 let err = io::Error::last_os_error();
380 match err.raw_os_error() {
381 Some(libc::EAGAIN) => continue,
384 Some(libc::EINTR) => continue,
387 _ => {
388 return Err(err);
389 }
390 }
391 }
392 }
393 }
394
395 pub fn wait_timeout(
401 &mut self,
402 events: &mut impl Extend<Event<K>>,
403 timeout: Duration,
404 ) -> Result<usize, io::Error> {
405 self.poll(events, timeout)
406 }
407
408 pub fn wait(&mut self, events: &mut impl Extend<Event<K>>) -> Result<usize, io::Error> {
414 self.poll(events, Timeout::Never)
415 }
416
417 fn find(&self, key: &K) -> Option<usize> {
418 self.index.iter().position(|k| k == key)
419 }
420
421 fn insert(&mut self, key: K, source: Source) {
422 self.index.push(key);
423 self.list.push(source);
424 }
425}
426
/// Wakes up a thread blocked in [`Sources::poll`] from another thread.
///
/// Backed by a connected pair of Unix sockets. The read end is the descriptor registered in
/// `Sources`; writing a byte to the write end makes the read end readable.
pub struct Waker {
    // Read end: polled for readability and drained by `Waker::reset`.
    reader: UnixStream,
    // Write end: `Waker::wake` writes a byte here.
    writer: UnixStream,
}
432
433impl AsRawFd for &Waker {
434 fn as_raw_fd(&self) -> RawFd {
435 self.reader.as_raw_fd()
436 }
437}
438
439impl AsRawFd for Waker {
440 fn as_raw_fd(&self) -> RawFd {
441 self.reader.as_raw_fd()
442 }
443}
444
445impl Waker {
446 pub fn register<K: Eq + Clone>(sources: &mut Sources<K>, key: K) -> io::Result<Waker> {
495 let waker = Waker::new()?;
496 sources.insert(key, Source::new(&waker, interest::READ));
497
498 Ok(waker)
499 }
500
501 pub fn new() -> io::Result<Waker> {
503 let (writer, reader) = UnixStream::pair()?;
504
505 reader.set_nonblocking(true)?;
506 writer.set_nonblocking(true)?;
507
508 Ok(Waker { reader, writer })
509 }
510
511 pub fn wake(&self) -> io::Result<()> {
514 use io::ErrorKind::*;
515
516 match (&self.writer).write_all(&[0x1]) {
517 Ok(_) => Ok(()),
518 Err(e) if e.kind() == WouldBlock => {
519 Waker::reset(self.reader.as_raw_fd())?;
520 self.wake()
521 }
522 Err(e) if e.kind() == Interrupted => self.wake(),
523 Err(e) => Err(e),
524 }
525 }
526
527 pub fn reset(fd: impl AsRawFd) -> io::Result<()> {
529 let mut buf = [0u8; 4096];
530
531 loop {
532 match unsafe {
535 libc::read(
536 fd.as_raw_fd(),
537 buf.as_mut_ptr() as *mut libc::c_void,
538 buf.len(),
539 )
540 } {
541 -1 => match io::Error::last_os_error() {
542 e if e.kind() == io::ErrorKind::WouldBlock => return Ok(()),
543 e => return Err(e),
544 },
545 0 => return Ok(()),
546 _ => continue,
547 }
548 }
549 }
550}
551
552pub fn set_nonblocking(fd: &dyn AsRawFd, nonblocking: bool) -> io::Result<i32> {
577 let fd = fd.as_raw_fd();
578
579 let flags = unsafe { libc::fcntl(fd, libc::F_GETFL) };
581 if flags == -1 {
582 return Err(io::Error::last_os_error());
583 }
584
585 let flags = if nonblocking {
586 flags | libc::O_NONBLOCK
587 } else {
588 flags & !libc::O_NONBLOCK
589 };
590
591 match unsafe { libc::fcntl(fd, libc::F_SETFL, flags) } {
593 -1 => Err(io::Error::last_os_error()),
594 result => Ok(result),
595 }
596}
597
// Unit tests. These drive real `poll(2)` calls on Unix socket pairs (and stdout) rather than
// mocks.
#[cfg(test)]
#[allow(clippy::unnecessary_first_then_check)]
mod tests {
    use super::*;

    use std::io;
    use std::thread;
    use std::time::Duration;

    /// Each reader is reported readable -- and it is the only one reported -- once its peer
    /// writes to it.
    #[test]
    fn test_readable() -> io::Result<()> {
        let (writer0, reader0) = UnixStream::pair()?;
        let (writer1, reader1) = UnixStream::pair()?;
        let (writer2, reader2) = UnixStream::pair()?;

        let mut events = Vec::new();
        let mut sources = Sources::new();

        for reader in &[&reader0, &reader1, &reader2] {
            reader.set_nonblocking(true)?;
        }

        sources.register("reader0", &reader0, interest::READ);
        sources.register("reader1", &reader1, interest::READ);
        sources.register("reader2", &reader2, interest::READ);

        // Nothing has been written yet, so a short poll must time out.
        {
            let err = sources
                .poll(&mut events, Timeout::from_millis(1))
                .unwrap_err();

            assert_eq!(err.kind(), io::ErrorKind::TimedOut);
            assert!(events.is_empty());
        }

        let tests = &mut [
            (&writer0, &reader0, "reader0", 0x1u8),
            (&writer1, &reader1, "reader1", 0x2u8),
            (&writer2, &reader2, "reader2", 0x3u8),
        ];

        for (mut writer, mut reader, key, byte) in tests.iter_mut() {
            let mut buf = [0u8; 1];

            // The reader starts out with nothing to read.
            assert!(matches!(
                reader.read(&mut buf[..]),
                Err(err) if err.kind() == io::ErrorKind::WouldBlock
            ));

            writer.write_all(&[*byte])?;

            events.clear();
            sources.poll(&mut events, Timeout::from_millis(1))?;
            assert!(!events.is_empty());

            // Exactly one event, for the reader that was just written to.
            let mut events = events.iter();
            let event = events.next().unwrap();

            assert_eq!(&event.key, key);
            assert!(
                event.is_readable()
                    && !event.is_writable()
                    && !event.is_error()
                    && !event.is_hangup()
            );
            assert!(events.next().is_none());

            // Consume the byte so this reader is not reported again on the next iteration.
            assert_eq!(reader.read(&mut buf[..])?, 1);
            assert_eq!(&buf[..], &[*byte]);
        }
        Ok(())
    }

    /// Polling with nothing registered returns `Ok` instead of an error.
    #[test]
    fn test_empty() -> io::Result<()> {
        let mut events: Vec<Event<()>> = Vec::new();
        let mut sources = Sources::new();

        sources
            .poll(&mut events, Timeout::from_millis(1))
            .expect("no error if nothing registered");

        assert!(events.is_empty());

        Ok(())
    }

    /// A source that never becomes ready makes `poll` fail with `TimedOut` and report nothing.
    // NOTE(review): relies on stdout not being reported readable or hung up during the test,
    // which depends on what stdout is connected to when the tests run.
    #[test]
    fn test_timeout() -> io::Result<()> {
        let mut events = Vec::new();
        let mut sources = Sources::new();

        sources.register((), &io::stdout(), interest::READ);

        let err = sources
            .poll(&mut events, Timeout::from_millis(1))
            .unwrap_err();

        assert_eq!(sources.len(), 1);
        assert_eq!(err.kind(), io::ErrorKind::TimedOut);
        assert!(events.is_empty());

        Ok(())
    }

    /// Writes made on another thread are delivered, and hang-ups are reported once that
    /// thread exits and its writers (moved into the closure) are dropped.
    #[test]
    fn test_threaded() -> io::Result<()> {
        let (writer0, reader0) = UnixStream::pair()?;
        let (writer1, reader1) = UnixStream::pair()?;
        let (writer2, reader2) = UnixStream::pair()?;

        let mut events = Vec::new();
        let mut sources = Sources::new();
        let readers = &[&reader0, &reader1, &reader2];

        for reader in readers {
            reader.set_nonblocking(true)?;
        }

        sources.register("reader0", &reader0, interest::READ);
        sources.register("reader1", &reader1, interest::READ);
        sources.register("reader2", &reader2, interest::READ);

        let handle = thread::spawn(move || {
            // Give the main thread time to start polling before anything is written.
            thread::sleep(Duration::from_millis(8));

            for writer in &mut [&writer1, &writer2, &writer0] {
                writer.write_all(&[1]).unwrap();
                writer.write_all(&[2]).unwrap();
            }
        });

        // Keep polling until as many hang-up events as there are readers have been seen.
        let mut closed = vec![];
        while closed.len() < readers.len() {
            sources.poll(&mut events, Timeout::from_millis(64))?;

            for event in events.drain(..) {
                assert!(event.is_readable());
                assert!(!event.is_writable());
                assert!(!event.is_error());

                if event.is_hangup() {
                    closed.push(event.key.to_owned());
                    continue;
                }

                let mut buf = [0u8; 2];
                let mut reader = match event.key {
                    "reader0" => &reader0,
                    "reader1" => &reader1,
                    "reader2" => &reader2,
                    _ => unreachable!(),
                };
                let n = reader.read(&mut buf[..])?;

                assert_eq!(n, 2);
                assert_eq!(&buf[..], &[1, 2]);
            }
        }
        handle.join().unwrap();

        Ok(())
    }

    /// Unregistered sources stop being reported, the remaining ones are unaffected, and a
    /// key can be registered again after removal.
    #[test]
    fn test_unregister() -> io::Result<()> {
        use std::collections::HashSet;

        let (mut writer0, reader0) = UnixStream::pair()?;
        let (mut writer1, reader1) = UnixStream::pair()?;
        let (writer2, reader2) = UnixStream::pair()?;

        let mut events = Vec::new();
        let mut sources = Sources::new();

        for reader in &[&reader0, &reader1, &reader2] {
            reader.set_nonblocking(true)?;
        }

        sources.register("reader0", &reader0, interest::READ);
        sources.register("reader1", &reader1, interest::READ);
        sources.register("reader2", &reader2, interest::READ);

        {
            let err = sources
                .poll(&mut events, Timeout::from_millis(1))
                .unwrap_err();

            assert_eq!(err.kind(), io::ErrorKind::TimedOut);
            assert!(events.is_empty());
        }

        {
            writer1.write_all(&[0x0])?;

            events.clear();
            sources.poll(&mut events, Timeout::from_millis(1))?;
            let event = events.first().unwrap();

            assert_eq!(event.key, "reader1");
        }

        {
            // Once unregistered, reader1 is not reported even though it has unread data.
            sources.unregister(&"reader1");
            writer1.write_all(&[0x0])?;

            events.clear();
            sources.poll(&mut events, Timeout::from_millis(1)).ok();
            assert!(events.first().is_none());

            for w in &mut [&writer0, &writer1, &writer2] {
                w.write_all(&[0])?;
            }

            sources.poll(&mut events, Timeout::from_millis(1))?;
            let keys = events.iter().map(|e| e.key).collect::<HashSet<_>>();

            assert!(keys.contains(&"reader0"));
            assert!(!keys.contains(&"reader1"));
            assert!(keys.contains(&"reader2"));

            sources.unregister(&"reader0");

            for w in &mut [&writer0, &writer1, &writer2] {
                w.write_all(&[0])?;
            }

            events.clear();
            sources.poll(&mut events, Timeout::from_millis(1))?;
            let keys = events.iter().map(|e| e.key).collect::<HashSet<_>>();

            assert!(!keys.contains(&"reader0"));
            assert!(!keys.contains(&"reader1"));
            assert!(keys.contains(&"reader2"));

            sources.unregister(&"reader2");

            for w in &mut [&writer0, &writer1, &writer2] {
                w.write_all(&[0])?;
            }

            // Nothing is registered any more, so no events can be reported.
            events.clear();
            sources.poll(&mut events, Timeout::from_millis(1)).ok();

            assert!(events.is_empty());
        }

        {
            // A previously unregistered source can be registered again.
            sources.register("reader0", &reader0, interest::READ);
            writer0.write_all(&[0])?;

            sources.poll(&mut events, Timeout::from_millis(1))?;
            let event = events.first().unwrap();

            assert_eq!(event.key, "reader0");
        }

        Ok(())
    }

    /// `set`/`unset` change which events a registered source is polled for.
    #[test]
    fn test_set() -> io::Result<()> {
        let (mut writer0, reader0) = UnixStream::pair()?;
        let (mut writer1, reader1) = UnixStream::pair()?;

        let mut events = Vec::new();
        let mut sources = Sources::new();

        for reader in &[&reader0, &reader1] {
            reader.set_nonblocking(true)?;
        }

        // reader1 starts out with no interest at all.
        sources.register("reader0", &reader0, interest::READ);
        sources.register("reader1", &reader1, interest::NONE);

        {
            writer0.write_all(&[0])?;

            sources.poll(&mut events, Timeout::from_millis(1))?;
            let event = events.first().unwrap();
            assert_eq!(event.key, "reader0");

            // After clearing READ, pending data on reader0 is no longer reported.
            sources.unset(&event.key, interest::READ);
            writer0.write_all(&[0])?;
            events.clear();

            sources.poll(&mut events, Timeout::from_millis(1)).ok();
            assert!(events.first().is_none());
        }

        {
            writer1.write_all(&[0])?;

            sources.poll(&mut events, Timeout::from_millis(1)).ok();
            assert!(events.first().is_none());

            // Adding READ makes reader1's pending data visible.
            sources.set(&"reader1", interest::READ);
            writer1.write_all(&[0])?;

            sources.poll(&mut events, Timeout::from_millis(1))?;
            let event = events.first().unwrap();
            assert_eq!(event.key, "reader1");
        }

        Ok(())
    }

    /// The waker is reported readable when its buffer is full or after `wake`. Repeated wakes
    /// coalesce into a single event, and `reset` clears it.
    #[test]
    fn test_waker() -> io::Result<()> {
        let mut events = Vec::new();
        let mut sources = Sources::new();
        let mut waker = Waker::register(&mut sources, "waker")?;
        let buf = [0; 4096];

        sources.poll(&mut events, Timeout::from_millis(1)).ok();
        assert!(events.first().is_none());

        loop {
            // Fill the socket buffer until a write would block.
            match waker.writer.write(&buf) {
                Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
                    break;
                }
                Err(e) => return Err(e),
                _ => continue,
            }
        }

        sources.poll(&mut events, Timeout::from_millis(1))?;
        let event @ Event { key, .. } = events.first().unwrap();

        assert!(event.is_readable());
        assert!(!event.is_writable() && !event.is_hangup() && !event.is_error());
        assert_eq!(key, &"waker");

        // The buffer is still full here, so `wake` has to drain it before writing.
        waker.wake()?;

        events.clear();
        sources.poll(&mut events, Timeout::from_millis(1))?;
        let event @ Event { key, .. } = events.first().unwrap();

        assert!(event.is_readable());
        assert_eq!(key, &"waker");

        waker.wake()?;
        waker.wake()?;
        waker.wake()?;

        events.clear();
        sources.poll(&mut events, Timeout::from_millis(1))?;
        assert_eq!(events.len(), 1, "multiple wakes count as one");

        let event @ Event { key, .. } = events.first().unwrap();
        assert_eq!(key, &"waker");

        // Drain the waker; afterwards polling must time out again.
        Waker::reset(event.source).unwrap();

        let result = sources.poll(&mut events, Timeout::from_millis(1));
        assert!(
            matches!(
                result.err().map(|e| e.kind()),
                Some(io::ErrorKind::TimedOut)
            ),
            "the waker should only wake once"
        );

        Ok(())
    }

    /// Stress test: a thread sends many channel messages, each followed by a wake. The polling
    /// thread must end up receiving every message.
    #[test]
    fn test_waker_threaded() {
        let mut events = Vec::new();
        let mut sources = Sources::new();
        let waker = Waker::register(&mut sources, "waker").unwrap();
        let (tx, rx) = std::sync::mpsc::channel();
        let iterations = 100_000;
        let handle = std::thread::spawn(move || {
            for _ in 0..iterations {
                tx.send(()).unwrap();
                waker.wake().unwrap();
            }
        });

        let mut wakes = 0;
        let mut received = 0;

        while !handle.is_finished() {
            events.clear();

            let count = sources.poll(&mut events, Timeout::Never).unwrap();
            if count > 0 {
                let event = events.pop().unwrap();
                assert_eq!(event.key, "waker");
                assert!(events.is_empty());

                // Every wake is preceded by a send, so a message is expected here
                // (this blocks until one arrives).
                rx.recv().unwrap();
                received += 1;

                // Drain any further messages already queued.
                while rx.try_recv().is_ok() {
                    received += 1;
                }

                if received == iterations {
                    // NOTE(review): this expects `reset` to fail, presumably because the
                    // waker (moved into the spawned thread) has already been dropped and its
                    // descriptor closed. That ordering is not obviously guaranteed here --
                    // confirm the test is not racy.
                    Waker::reset(event.source).unwrap_err();
                    break;
                }

                Waker::reset(event.source).ok(); wakes += 1;
            }
        }
        handle.join().unwrap();

        assert_eq!(received, iterations);
        assert!(wakes <= received);
    }
}