use std::cell::Cell;
use std::io;
use std::os::fd::{AsFd, BorrowedFd};
use std::sync::{mpsc, Arc};
use std::time::Duration;
use nix::sys::eventfd::{EfdFlags, EventFd};
use crate::subscriber::{Handler, HasInterest};
use crate::thin::ThinBoxSubscriber;
use crate::{interest, Event, EventpOps, EventpOpsAdd, Interest, Pinned};
/// A unit of work queued by a [`RemoteEndpoint`]: a one-shot closure that the
/// reactor invokes with pinned access to its event loop `Ep`.
///
/// The lifetime of the `Pinned` argument is higher-ranked, so the reactor may
/// pass a short-lived reborrow of its event loop.
type BoxFn<Ep> = Box<dyn for<'a> FnOnce(Pinned<'a, Ep>) + Send>;
pub fn remote_endpoint<Ep>() -> io::Result<Pair<Ep>> {
let eventfd = EventFd::from_flags(EfdFlags::EFD_CLOEXEC | EfdFlags::EFD_NONBLOCK)
.map_err(io::Error::from)?;
let eventfd = Arc::new(eventfd);
let (tx, rx) = mpsc::channel();
let subscriber = Subscriber {
eventfd: Arc::clone(&eventfd),
interest: Cell::new(interest().read()),
rx,
};
let endpoint = RemoteEndpoint { eventfd, tx };
Ok(Pair {
subscriber,
endpoint,
})
}
/// The two halves produced by [`remote_endpoint`].
///
/// The [`Subscriber`] is registered with the event loop. The
/// [`RemoteEndpoint`] lets other threads submit closures to that loop.
pub struct Pair<Ep> {
/// Reactor-side half. Submitted closures only run once this is registered
/// (e.g. via [`Pair::register_into`]) and the event loop is being driven.
pub subscriber: Subscriber<Ep>,
/// Caller-side half; can be cloned and handed to other threads.
pub endpoint: RemoteEndpoint<Ep>,
}
/// Reactor-side half of a remote endpoint.
///
/// It is exposed to the event loop through its eventfd ([`AsFd`]). On an
/// event, [`Handler::handle`] runs every queued closure on the thread driving
/// the loop. Once it is dropped, later [`RemoteEndpoint`] calls fail with
/// `BrokenPipe`.
pub struct Subscriber<Ep> {
/// Eventfd shared with all endpoints. Endpoints write to it to wake the
/// reactor; `handle` reads it to reset the counter.
eventfd: Arc<EventFd>,
/// Readiness interest exposed through [`HasInterest`]. It is initialised to
/// read interest by [`remote_endpoint`].
interest: Cell<Interest>,
/// Receiving end of the queue of closures submitted by endpoints.
rx: mpsc::Receiver<BoxFn<Ep>>,
}
/// Caller-side handle for running closures on the reactor that owns the
/// matching [`Subscriber`].
///
/// Cloning is cheap because both fields are shared handles. All clones feed
/// the same queue.
pub struct RemoteEndpoint<Ep> {
/// Same eventfd as the subscriber; written after each enqueue to wake it.
eventfd: Arc<EventFd>,
/// Sending end of the closure queue.
tx: mpsc::Sender<BoxFn<Ep>>,
}
impl<Ep: EventpOps> Pair<Ep> {
    /// Registers the subscriber half with `eventp` and returns the endpoint
    /// half.
    ///
    /// The subscriber is wrapped in a [`ThinBoxSubscriber`] before being added.
    ///
    /// # Errors
    ///
    /// Propagates any error from `eventp.add`. In that case the endpoint half
    /// is dropped together with `self` and never handed out.
    pub fn register_into<R>(self, eventp: &mut R) -> io::Result<RemoteEndpoint<Ep>>
    where
        R: EventpOpsAdd<Ep>,
    {
        eventp.add(ThinBoxSubscriber::new(self.subscriber))?;
        Ok(self.endpoint)
    }
}
impl<Ep> AsFd for Subscriber<Ep> {
    /// Exposes the shared eventfd so the event loop can watch it.
    fn as_fd(&self) -> BorrowedFd<'_> {
        let Self { eventfd, .. } = self;
        AsFd::as_fd(&**eventfd)
    }
}
impl<Ep> HasInterest for Subscriber<Ep> {
    /// Returns the readiness-interest cell.
    ///
    /// It is a `Cell`, so callers can read or replace it through a shared
    /// reference.
    fn interest(&self) -> &Cell<Interest> {
        let Self { interest, .. } = self;
        interest
    }
}
impl<Ep: EventpOps> Handler<Ep> for Subscriber<Ep> {
    /// Runs, in queue order, every closure submitted by the remote endpoints.
    ///
    /// The eventfd counter is reset *before* the queue is drained. Senders
    /// enqueue first and write the eventfd second, so work enqueued after the
    /// drain leaves the fd readable again rather than being lost.
    fn handle(&mut self, _event: Event, mut eventp: Pinned<'_, Ep>) {
        // Best-effort reset: on this non-blocking fd the read can fail (e.g.
        // EAGAIN on a spurious wakeup). The queue below is the source of truth.
        let _ = self.eventfd.read();
        for job in self.rx.try_iter() {
            job(eventp.as_mut());
        }
    }
}
/// Error returned when the [`Subscriber`] half is gone. This covers both
/// failing to enqueue a closure and losing the reply channel.
fn err_subscriber_dropped() -> io::Error {
    const MSG: &str = "`remote_endpoint::Subscriber` dropped";
    io::Error::new(io::ErrorKind::BrokenPipe, MSG)
}
// Shared body of the `call_blocking*` methods on `RemoteEndpoint`.
//
// Expands to a block that:
// 1. creates a `oneshot` reply channel, binding its receiver to `$rx`;
// 2. enqueues a boxed wrapper that runs `$f` on the reactor and sends the
//    result back (a failed enqueue maps to `err_subscriber_dropped()`);
// 3. writes the eventfd to wake the reactor;
// 4. evaluates `$rx_expr` to wait for the reply. `Ok(v)` yields `v` (the
//    closure's own `io::Result<T>`) as the block's value; `Err($rx_err)`
//    returns `Err($err_map)` from the enclosing function.
//
// It uses `?` and `return`, so it must be expanded inside a function that
// returns `io::Result<_>`. `$self` must have `tx` and `eventfd` fields.
macro_rules! call_variant {
($self:ident, $f:ident, |$rx:ident| $rx_expr:expr, |$rx_err:ident| $err_map:expr) => {{
let (tx, $rx) = oneshot::channel();
$self
.tx
.send(Box::new(move |ep| {
// The receiver may already be gone (e.g. the caller timed out), so a
// failed reply send is deliberately ignored.
let _ = tx.send($f(ep));
}))
.map_err(|_| err_subscriber_dropped())?;
// NOTE(review): if this write fails, the closure is already queued. It may
// still run on a later wakeup even though the caller sees an error.
$self.eventfd.write(1).map_err(io::Error::from)?;
match $rx_expr {
Ok(v) => v,
Err($rx_err) => return Err($err_map),
}
}};
}
impl<Ep> RemoteEndpoint<Ep> {
pub async fn call_blocking_async<F, T>(&self, f: F) -> io::Result<T>
where
F: 'static + FnOnce(Pinned<'_, Ep>) -> io::Result<T> + Send,
T: 'static + Send,
{
call_variant!(self, f, |rx| rx.await, |_e| err_subscriber_dropped())
}
pub fn call_blocking<F, T>(&self, f: F) -> io::Result<T>
where
F: 'static + FnOnce(Pinned<'_, Ep>) -> io::Result<T> + Send,
T: 'static + Send,
{
call_variant!(self, f, |rx| rx.recv(), |_e| err_subscriber_dropped())
}
pub fn call_blocking_with_timeout<F, T>(&self, f: F, timeout: Duration) -> io::Result<T>
where
F: 'static + FnOnce(Pinned<'_, Ep>) -> io::Result<T> + Send,
T: 'static + Send,
{
call_variant!(self, f, |rx| rx.recv_timeout(timeout), |e| match e {
oneshot::RecvTimeoutError::Timeout => {
io::Error::new(io::ErrorKind::TimedOut, "remote call timed out")
}
oneshot::RecvTimeoutError::Disconnected => err_subscriber_dropped(),
})
}
pub fn call_nonblocking<F>(&self, f: F) -> io::Result<()>
where
F: 'static + FnOnce(Pinned<'_, Ep>) + Send,
{
self.tx
.send(Box::new(f))
.map_err(|_| err_subscriber_dropped())?;
self.eventfd.write(1).map_err(io::Error::from)?;
Ok(())
}
}
impl<Ep> Clone for RemoteEndpoint<Ep> {
fn clone(&self) -> Self {
Self {
eventfd: self.eventfd.clone(),
tx: self.tx.clone(),
}
}
}
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
    use std::sync::{Arc as StdArc, Barrier, Mutex};
    use std::thread;

    use nix::sys::epoll::EpollTimeout;

    use super::*;
    use crate::Eventp;
    #[cfg(feature = "mock")]
    use crate::MockEventp;

    /// Compiles only if `T: Send`.
    const fn assert_send<T: Send>() {}
    /// Compiles only if `T: Sync`. It must be bounded on `Sync`; a `Send`
    /// bound here would make every `assert_sync` check below vacuous.
    const fn assert_sync<T: Sync>() {}

    // Endpoints are meant to be cloned and shared across threads.
    const _: () = {
        assert_send::<RemoteEndpoint<Eventp>>();
        assert_sync::<RemoteEndpoint<Eventp>>();
        #[cfg(feature = "mock")]
        assert_send::<RemoteEndpoint<MockEventp>>();
        #[cfg(feature = "mock")]
        assert_sync::<RemoteEndpoint<MockEventp>>();
    };

    /// Upper bound on one reactor iteration, so the stop flag is re-checked
    /// regularly.
    fn poll_timeout() -> EpollTimeout {
        EpollTimeout::from(500u16)
    }

    /// Spawns a thread that owns an `Eventp` with a registered remote
    /// subscriber and drives it until the returned flag is set.
    fn spawn_reactor() -> (
        RemoteEndpoint<Eventp>,
        thread::JoinHandle<()>,
        StdArc<AtomicBool>,
    ) {
        let stop = StdArc::new(AtomicBool::new(false));
        let stop_for_thread = StdArc::clone(&stop);
        let (tx, rx) = mpsc::channel();
        let handle = thread::spawn(move || {
            let mut eventp = Eventp::default();
            let endpoint = remote_endpoint()
                .unwrap()
                .register_into(&mut eventp)
                .unwrap();
            tx.send(endpoint).expect("main thread receiving endpoint");
            while !stop_for_thread.load(Ordering::Acquire) {
                eventp.run_once_with_timeout(poll_timeout()).unwrap();
            }
        });
        let endpoint = rx.recv().expect("reactor thread sending endpoint");
        (endpoint, handle, stop)
    }

    /// Signals the reactor thread to stop and joins it.
    fn shutdown(stop: StdArc<AtomicBool>, handle: thread::JoinHandle<()>) {
        stop.store(true, Ordering::Release);
        handle.join().expect("reactor thread panicked");
    }

    #[test]
    fn call_blocking_runs_closure_on_reactor_thread() {
        let (endpoint, handle, stop) = spawn_reactor();
        let reactor_tid = handle.thread().id();
        let observed_tid = endpoint
            .call_blocking(move |_| Ok(thread::current().id()))
            .unwrap();
        assert_eq!(observed_tid, reactor_tid);

        // Errors returned by the closure are passed through unchanged.
        let err = endpoint
            .call_blocking(|_| -> io::Result<()> {
                Err(io::Error::new(io::ErrorKind::PermissionDenied, "denied"))
            })
            .unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::PermissionDenied);
        shutdown(stop, handle);
    }

    #[test]
    fn call_blocking_with_timeout_elapses_when_reactor_idle() {
        let pair = remote_endpoint::<Eventp>().unwrap();
        let endpoint = pair.endpoint.clone();
        // Keep the subscriber alive but never drive it.
        let _keep = pair.subscriber;
        let err = endpoint
            .call_blocking_with_timeout(|_| Ok(()), Duration::from_millis(50))
            .unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::TimedOut);
    }

    #[test]
    fn call_nonblocking_executes_and_drains_batch() {
        let (endpoint, handle, stop) = spawn_reactor();
        let counter = StdArc::new(AtomicU32::new(0));
        for _ in 0..16 {
            let c = StdArc::clone(&counter);
            endpoint
                .call_nonblocking(move |_| {
                    c.fetch_add(1, Ordering::Relaxed);
                })
                .unwrap();
        }
        // A blocking round-trip queued after the batch acts as a barrier.
        endpoint.call_blocking(|_| Ok(())).unwrap();
        assert_eq!(counter.load(Ordering::Relaxed), 16);
        shutdown(stop, handle);
    }

    #[test]
    fn call_nonblocking_preserves_submission_order() {
        let (endpoint, handle, stop) = spawn_reactor();
        let log = StdArc::new(Mutex::new(Vec::new()));
        for i in 0..8u32 {
            let log = StdArc::clone(&log);
            endpoint
                .call_nonblocking(move |_| log.lock().unwrap().push(i))
                .unwrap();
        }
        endpoint.call_blocking(|_| Ok(())).unwrap();
        assert_eq!(*log.lock().unwrap(), (0..8u32).collect::<Vec<_>>());
        shutdown(stop, handle);
    }

    #[test]
    fn endpoint_returns_error_after_subscriber_dropped() {
        let pair = remote_endpoint::<Eventp>().unwrap();
        let endpoint = pair.endpoint.clone();
        drop(pair.subscriber);
        let e1 = endpoint.call_nonblocking(|_| {}).unwrap_err();
        assert_eq!(e1.kind(), io::ErrorKind::BrokenPipe);
        let e2 = endpoint.call_blocking(|_| Ok(())).unwrap_err();
        assert_eq!(e2.kind(), io::ErrorKind::BrokenPipe);
        let e3 = endpoint
            .call_blocking_with_timeout(|_| Ok(()), Duration::from_millis(10))
            .unwrap_err();
        assert_eq!(e3.kind(), io::ErrorKind::BrokenPipe);
    }

    #[test]
    fn cloned_endpoints_serve_multiple_threads() {
        let (endpoint, handle, stop) = spawn_reactor();
        let n = 8usize;
        let barrier = StdArc::new(Barrier::new(n));
        let counter = StdArc::new(AtomicU32::new(0));
        let workers: Vec<_> = (0..n)
            .map(|_| {
                let ep = endpoint.clone();
                let b = StdArc::clone(&barrier);
                let c = StdArc::clone(&counter);
                thread::spawn(move || {
                    b.wait();
                    let v = ep
                        .call_blocking(move |_| Ok(c.fetch_add(1, Ordering::Relaxed) + 1))
                        .unwrap();
                    assert!((1..=n as u32).contains(&v));
                })
            })
            .collect();
        for w in workers {
            w.join().unwrap();
        }
        assert_eq!(counter.load(Ordering::Relaxed), n as u32);
        shutdown(stop, handle);
    }

    #[test]
    fn closure_can_mutate_reactor_state() {
        let (endpoint, handle, stop) = spawn_reactor();
        let err = endpoint
            .call_blocking(|mut ep| ep.delete(424242))
            .unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::NotFound);
        shutdown(stop, handle);
    }
}