#![allow(unsafe_code)]
use super::util::IntHasher;
use curl::multi::Socket;
use polling::{AsSource, Event, Events, Poller};
use std::{
collections::{HashMap, HashSet},
hash::BuildHasherDefault,
io,
sync::Arc,
task::Waker,
time::Duration,
};
/// Tracks the sockets passed to `register` together with a [`Poller`], and
/// holds the events produced by the most recent wait.
pub(crate) struct Selector {
    // Shared via `Arc` so wakers created by `waker` can notify the same poller.
    poller: Arc<Poller>,
    // Current interest for every registered socket, keyed by raw socket value.
    sockets: HashMap<Socket, Registration, BuildHasherDefault<IntHasher>>,
    // Registered sockets whose poller add/modify failed with a "bad socket"
    // error; `poll` retries adding them on a later call.
    bad_sockets: HashSet<Socket, BuildHasherDefault<IntHasher>>,
    // Events filled in by the last `Poller::wait` call.
    events: Events,
    // Counter bumped once per `poll`; compared with `Registration::tick` so a
    // socket's poller interest is not re-applied twice for the same poll.
    tick: usize,
}
/// The interest last requested for one socket, plus the selector tick at
/// which that interest was last written to (or retried with) the poller.
struct Registration {
    /// Value of `Selector::tick` at the most recent poller update attempt.
    tick: usize,
    /// Whether read readiness was requested.
    readable: bool,
    /// Whether write readiness was requested.
    writable: bool,
}
impl Selector {
pub(crate) fn new() -> io::Result<Self> {
Ok(Self {
poller: Arc::new(Poller::new()?),
sockets: HashMap::with_hasher(Default::default()),
bad_sockets: HashSet::with_hasher(Default::default()),
events: Events::new(),
tick: 0,
})
}
pub(crate) fn waker(&self) -> Waker {
waker_fn::waker_fn({
let poller_ref = self.poller.clone();
move || {
let _ = poller_ref.notify();
}
})
}
pub(crate) fn register(
&mut self,
socket: Socket,
readable: bool,
writable: bool,
) -> io::Result<()> {
let previous = self.sockets.insert(
socket,
Registration {
readable,
writable,
tick: self.tick,
},
);
let result = if previous.is_some() {
poller_modify(&self.poller, socket, readable, writable)
} else {
poller_add(&self.poller, socket, readable, writable)
};
match result {
Err(error) if is_bad_socket_error(&error) => {
tracing::debug!(
socket,
?error,
"bad socket registered, will try again later"
);
self.bad_sockets.insert(socket);
Ok(())
}
result => result,
}
}
pub(crate) fn deregister(&mut self, socket: Socket) -> io::Result<()> {
if self.sockets.remove(&socket).is_some() {
self.bad_sockets.remove(&socket);
if let Err(e) = self.poller.delete(unsafe { as_source(socket) }) {
if !is_bad_socket_error(&e) && e.kind() != io::ErrorKind::PermissionDenied {
return Err(e);
}
}
}
Ok(())
}
pub(crate) fn poll(&mut self, timeout: Duration) -> io::Result<bool> {
for event in self.events.iter() {
let socket = event.key as Socket;
if let Some(registration) = self.sockets.get_mut(&socket) {
if registration.tick != self.tick {
poller_modify(
&self.poller,
socket,
registration.readable,
registration.writable,
)?;
registration.tick = self.tick;
}
}
}
self.events.clear();
self.bad_sockets.retain({
let sockets = &mut self.sockets;
let poller = &self.poller;
let tick = self.tick;
move |&socket| {
if let Some(registration) = sockets.get_mut(&socket) {
if registration.tick != tick {
registration.tick = tick;
poller_add(poller, socket, registration.readable, registration.writable)
.is_err()
} else {
true
}
} else {
false
}
}
});
self.tick = self.tick.wrapping_add(1);
match self.poller.wait(&mut self.events, Some(timeout)) {
Ok(0) => Ok(false),
Ok(_) => Ok(true),
Err(e) if e.kind() == io::ErrorKind::Interrupted => Ok(false),
Err(e) => Err(e),
}
}
pub(crate) fn events(&self) -> impl Iterator<Item = (Socket, bool, bool)> + '_ {
self.events
.iter()
.map(|event| (event.key as Socket, event.readable, event.writable))
}
}
/// Adds `socket` to `poller` with the given read/write interest.
///
/// If the add fails for any reason, the interest is applied with
/// `Poller::modify` instead. In that case only the modify result is returned.
fn poller_add(poller: &Poller, socket: Socket, readable: bool, writable: bool) -> io::Result<()> {
    // Key events by the raw socket value so they can be mapped back later.
    let interest = Event::new(socket as usize, readable, writable);

    // SAFETY: polling requires the source to be deleted from the poller
    // before it is closed; `Selector::deregister` is assumed to ensure this.
    match unsafe { poller.add(socket, interest) } {
        Ok(()) => Ok(()),
        Err(error) => {
            tracing::debug!(
                socket,
                ?error,
                "failed to add interest for socket, retrying as a modify",
            );
            // SAFETY: NOTE(review) assumes callers pass a socket that is
            // still open, as `as_source` requires.
            poller.modify(unsafe { as_source(socket) }, interest)
        }
    }
}
fn poller_modify(
poller: &Poller,
socket: Socket,
readable: bool,
writable: bool,
) -> io::Result<()> {
let interest = Event::new(socket as usize, readable, writable);
if let Err(error) = poller.modify(unsafe { as_source(socket) }, interest) {
tracing::debug!(
socket,
?error,
"failed to modify interest for socket, retrying as an add",
);
unsafe {
poller.add(socket, interest)?;
}
}
Ok(())
}
/// Returns `true` if `error` indicates that the socket handle itself is
/// unusable by the poller, rather than some other failure.
///
/// This covers `NotFound` and `InvalidInput` errors on any platform, `EBADF`
/// on Unix, and `ERROR_INVALID_HANDLE` / `ERROR_NOT_FOUND` on Windows.
fn is_bad_socket_error(error: &io::Error) -> bool {
    const EBADF: i32 = 9;
    const ERROR_INVALID_HANDLE: i32 = 6;
    const ERROR_NOT_FOUND: i32 = 1168;

    if matches!(
        error.kind(),
        io::ErrorKind::NotFound | io::ErrorKind::InvalidInput
    ) {
        return true;
    }

    let code = error.raw_os_error();
    if cfg!(unix) {
        code == Some(EBADF)
    } else if cfg!(windows) {
        matches!(code, Some(ERROR_INVALID_HANDLE) | Some(ERROR_NOT_FOUND))
    } else {
        false
    }
}
/// Borrows a raw socket handle as an [`AsSource`] for the poller's
/// `modify`/`delete` calls.
///
/// # Safety
///
/// `socket` must be an open socket (a file descriptor on Unix, a `SOCKET` on
/// Windows) that is not `-1` / `INVALID_SOCKET`. It must stay open for as
/// long as the returned value is used. The returned borrow is not tied to any
/// lifetime, so it should only be used transiently.
unsafe fn as_source(socket: Socket) -> impl AsSource {
    #[cfg(unix)]
    use std::os::fd::BorrowedFd as BorrowedSource;
    #[cfg(windows)]
    use std::os::windows::io::BorrowedSocket as BorrowedSource;

    // SAFETY: the requirements are forwarded to the caller via this
    // function's `# Safety` contract.
    unsafe { BorrowedSource::borrow_raw(socket) }
}