use {
crate::{
BorrowedQueue, Connection, Queue,
utils::{eventfd::Eventfd, executor::TaskId, os_error::OsError, poller},
},
parking_lot::{Condvar, Mutex},
run_on_drop::on_drop,
std::{
convert::Infallible,
future::poll_fn,
io,
os::fd::{AsFd, AsRawFd, BorrowedFd, RawFd},
sync::{
Arc,
atomic::{AtomicBool, Ordering::Relaxed},
},
task::{Poll, Waker},
},
};
#[cfg(test)]
mod tests;
/// A handle whose file descriptor (an eventfd) is signaled by a background task when one
/// of a set of queues has events (see `Connection::create_watcher`).
///
/// Cloning is cheap: all clones share the same `QueueWatcherData` through an `Arc`, and the
/// background task is cancelled when the last clone is dropped.
#[derive(Clone)]
pub struct QueueWatcher {
    data: Arc<QueueWatcherData>,
}
/// State shared by all clones of a [`QueueWatcher`]; its `Drop` impl cancels the
/// background task and waits for it to release its state.
struct QueueWatcherData {
    /// Id of the background task registered with the connection's executor.
    task_id: TaskId,
    /// Connection whose executor runs the background task.
    connection: Connection,
    /// State shared with the background task.
    data: Arc<QueueWatcherShared>,
}
/// State shared between a [`QueueWatcher`] and its background task.
struct QueueWatcherShared {
    /// Bumped by the background task; its fd is what `QueueWatcher` exposes via `AsFd`.
    eventfd: Eventfd,
    /// Set after `data.last_error` has been populated; `QueueWatcher::reset` checks it
    /// before locking `data`.
    has_error: AtomicBool,
    /// Mutable state, protected by a mutex.
    data: Mutex<QueueWatcherMutable>,
    /// Notified (together with setting `data.cancelled`) when the task's state is dropped.
    cancellation: Condvar,
}
/// Mutable watcher state, protected by `QueueWatcherShared::data`.
#[derive(Default)]
struct QueueWatcherMutable {
    /// Set by the background task each time it passes its reset checkpoint and cleared by
    /// `QueueWatcher::reset`; while it is set, the task parks at the checkpoint.
    wait_for_reset: bool,
    /// Waker of the task parked at the reset checkpoint; taken and woken by `reset`.
    waker: Option<Waker>,
    /// Error that terminated the background task, if any.
    last_error: Option<OsError>,
    /// Set once the task's captured state has been dropped; `QueueWatcherData::drop`
    /// waits for this.
    cancelled: bool,
}
impl Connection {
    /// Flushes outgoing requests and then waits until one of `queues` has events.
    ///
    /// # Panics
    ///
    /// Panics if any queue in `queues` belongs to a different connection.
    ///
    /// # Errors
    ///
    /// Returns an error if flushing fails, if the connection is already in an error
    /// state, or if polling/reading the connection fails.
    pub async fn wait_for_events(&self, queues: &[&BorrowedQueue]) -> io::Result<()> {
        self.flush()?;
        self.wait_for_events_without_flush(queues).await
    }

    /// Like [`Connection::wait_for_events`] but does not flush the connection first.
    ///
    /// The read lock is acquired through the first queue; the remaining queues are only
    /// checked for already-queued events. After events have been read from the
    /// connection, the loop starts over and re-checks the queues.
    ///
    /// NOTE(review): with an empty `queues` slice no read lock is taken and no events are
    /// read, so this only ever returns with an error — confirm callers never pass `&[]`.
    pub(crate) async fn wait_for_events_without_flush(
        &self,
        queues: &[&BorrowedQueue],
    ) -> io::Result<()> {
        for queue in queues {
            if queue.connection() != self {
                wrong_con();
            }
        }
        loop {
            let mut lock = None;
            if let Some((first, other)) = queues.split_first() {
                // `None` means no read is required; presumably the first queue already
                // has events (TODO confirm against `acquire_read_lock_async`).
                let Some(l) = self.acquire_read_lock_async(first).await else {
                    return Ok(());
                };
                lock = Some(l);
                for queue in other {
                    if self.queue_has_events(queue) {
                        return Ok(());
                    }
                }
            }
            self.data.data.ensure_no_error()?;
            let poll_data = self.data.poller.data.clone();
            // Wait for the connection to become readable and, if we hold the read lock,
            // read events into the queues; this future runs on the connection's executor.
            self.data
                .executor
                .execute::<io::Result<()>, _>(async move {
                    poller::readable(&poll_data).await?;
                    if let Some(lock) = lock {
                        lock.read_events().await?;
                    }
                    Ok(())
                })
                .await?;
        }
    }

    /// Creates a [`QueueWatcher`] for the given queues.
    ///
    /// A background task on the connection's executor repeatedly waits (without
    /// flushing) until one of the queues has events and then bumps the watcher's eventfd.
    /// Between bumps the task parks until [`QueueWatcher::reset`] clears
    /// `wait_for_reset`. The flag starts cleared, so the first pass does not park. If the
    /// task fails, the error is stored, the eventfd is bumped, and the error is reported
    /// by [`QueueWatcher::reset`].
    ///
    /// `owned` queues are cloned into the task; `borrowed` queues are moved into it.
    ///
    /// # Panics
    ///
    /// Panics if any queue belongs to a different connection.
    ///
    /// # Errors
    ///
    /// Returns an error if the eventfd cannot be created.
    pub fn create_watcher(
        &self,
        owned: &[&Queue],
        borrowed: impl IntoIterator<Item = BorrowedQueue>,
    ) -> io::Result<QueueWatcher> {
        self.create_watcher_(owned, borrowed.into_iter().collect())
    }

    /// Non-generic implementation of [`Connection::create_watcher`].
    fn create_watcher_(
        &self,
        owned: &[&Queue],
        borrowed: Vec<BorrowedQueue>,
    ) -> io::Result<QueueWatcher> {
        for q in owned {
            if q.connection() != self {
                wrong_con();
            }
        }
        for q in &borrowed {
            if q.connection() != self {
                wrong_con();
            }
        }
        let shared = Arc::new(QueueWatcherShared {
            eventfd: Eventfd::new()?,
            has_error: Default::default(),
            data: Default::default(),
            cancellation: Default::default(),
        });
        // Everything the background task owns. Fields drop in declaration order, so the
        // connection and queues are released before `_f` marks the watcher as cancelled
        // and notifies `QueueWatcherData::drop`, which waits for that signal.
        struct CancelData<F> {
            connection: Connection,
            shared: Arc<QueueWatcherShared>,
            owned: Vec<Queue>,
            borrowed: Vec<BorrowedQueue>,
            _f: F,
        }
        let cancel_data = CancelData {
            connection: self.clone(),
            shared: shared.clone(),
            owned: owned.iter().map(|q| (*q).clone()).collect(),
            borrowed,
            _f: on_drop({
                let shared = shared.clone();
                move || {
                    shared.data.lock().cancelled = true;
                    shared.cancellation.notify_all();
                }
            }),
        };
        let task_id = self.data.executor.add(async move {
            // Move the whole struct into the task so that `_f` only runs once the task's
            // future has been dropped (completed or cancelled).
            let cancel_data = cancel_data;
            let mut qs = vec![];
            for q in &cancel_data.owned {
                qs.push(&**q);
            }
            for q in &cancel_data.borrowed {
                qs.push(q);
            }
            let res: io::Result<Infallible> = async {
                loop {
                    cancel_data
                        .connection
                        .wait_for_events_without_flush(&qs)
                        .await?;
                    cancel_data.shared.eventfd.bump()?;
                    // Reset checkpoint: proceed if `reset` cleared `wait_for_reset` since
                    // the previous pass (and re-arm the flag); otherwise park the waker for
                    // `reset` to wake.
                    poll_fn(|ctx| {
                        let d = &mut *cancel_data.shared.data.lock();
                        if d.wait_for_reset {
                            d.waker = Some(ctx.waker().clone());
                            Poll::Pending
                        } else {
                            d.wait_for_reset = true;
                            d.waker = None;
                            Poll::Ready(())
                        }
                    })
                    .await;
                }
            }
            .await;
            // The loop can only be left through `?`, so `res` is always an error.
            let e = res.unwrap_err();
            cancel_data.shared.data.lock().last_error = Some(e.into());
            cancel_data.shared.has_error.store(true, Relaxed);
            // Signal the eventfd so a caller blocked on the watcher's fd wakes up, calls
            // `reset`, and observes the error. Without this the caller could wait
            // forever. A failing bump is ignored: the error is already recorded.
            let _ = cancel_data.shared.eventfd.bump();
        });
        let data = Arc::new(QueueWatcherData {
            task_id,
            connection: self.clone(),
            data: shared,
        });
        Ok(QueueWatcher { data })
    }
}
impl QueueWatcher {
    /// Re-arms the watcher after its file descriptor has become readable.
    ///
    /// Clears the eventfd, clears `wait_for_reset`, and wakes the background task if it is
    /// parked waiting for this call.
    ///
    /// # Errors
    ///
    /// - If the background task has recorded an error, that error is returned. When the
    ///   error is already flagged before the eventfd is touched, the eventfd is not
    ///   cleared. When it is only found after clearing, the eventfd is bumped again.
    /// - Any error from clearing the eventfd is returned.
    pub fn reset(&self) -> io::Result<()> {
        let shared = &*self.data.data;

        // Fast path: once an error has been flagged, report it without clearing the fd.
        let early = if shared.has_error.load(Relaxed) {
            shared.data.lock().last_error
        } else {
            None
        };
        if let Some(err) = early {
            return Err(err.into());
        }

        shared.eventfd.clear()?;

        let mut guard = shared.data.lock();
        if let Some(err) = guard.last_error {
            // The error was recorded after the fast-path check; the clear above may have
            // consumed its signal, so signal again (best effort) before reporting it.
            let _ = shared.eventfd.bump();
            return Err(err.into());
        }
        guard.wait_for_reset = false;
        // The waker is woken while the lock is still held.
        if let Some(waker) = guard.waker.take() {
            waker.wake();
        }
        Ok(())
    }
}
impl Drop for QueueWatcherData {
    /// Cancels the background task, then blocks until `cancelled` is set by the on-drop
    /// hook that `create_watcher_` moves into the task. This guarantees that the task no
    /// longer holds the queues once the watcher is gone.
    ///
    /// NOTE(review): if the executor only drops cancelled futures on its own thread,
    /// dropping the last watcher from that thread would block forever — confirm.
    fn drop(&mut self) {
        let shared = &*self.data;
        self.connection.data.executor.cancel(self.task_id);
        let mut guard = shared.data.lock();
        // Re-check the flag after every wakeup; it may already be set before we wait.
        loop {
            if guard.cancelled {
                break;
            }
            shared.cancellation.wait(&mut guard);
        }
    }
}
impl AsFd for QueueWatcher {
    /// Borrows the watcher's eventfd, which the background task bumps to signal the
    /// watcher.
    fn as_fd(&self) -> BorrowedFd<'_> {
        let eventfd = &self.data.data.eventfd;
        eventfd.as_fd()
    }
}
impl AsRawFd for QueueWatcher {
    /// Returns the raw fd of the watcher's eventfd (same fd as [`AsFd::as_fd`]).
    fn as_raw_fd(&self) -> RawFd {
        let fd = AsFd::as_fd(self);
        fd.as_raw_fd()
    }
}
impl AsRawFd for &'_ QueueWatcher {
fn as_raw_fd(&self) -> RawFd {
self.as_fd().as_raw_fd()
}
}
/// Panics because a queue belonging to a different connection was passed in.
///
/// Kept out of line and marked cold so the connection checks stay cheap on the hot path.
#[cold]
fn wrong_con() -> ! {
    panic!("queue does not belong to this connection")
}