use std::time::Duration;
use crate::error::{Error, Result};
use crate::event::Event;
pub struct Subscription {
rx: EventReceiver,
_cancel: CancelHandle,
}
impl Subscription {
pub fn new(rx: EventReceiver, cancel: CancelHandle) -> Self {
Self {
rx,
_cancel: cancel,
}
}
pub fn try_recv(&self) -> Option<Event> {
self.rx.try_recv()
}
pub fn recv(&self, timeout: Duration) -> Result<Event> {
self.rx
.recv_timeout(timeout)
.ok_or(Error::Timeout { elapsed: timeout })
}
pub fn recv_status(&self, timeout: Duration) -> RecvStatus {
self.rx.recv_timeout_status(timeout)
}
pub fn wait_for(&self, predicate: impl Fn(&Event) -> bool, timeout: Duration) -> Result<Event> {
let start = std::time::Instant::now();
loop {
let remaining = timeout.saturating_sub(start.elapsed());
if remaining.is_zero() {
return Err(Error::Timeout {
elapsed: start.elapsed(),
});
}
let poll = remaining.min(Duration::from_millis(10));
match self.rx.recv_timeout_status(poll) {
RecvStatus::Event(event) => {
if predicate(&event) {
return Ok(*event);
}
}
RecvStatus::Timeout => continue,
RecvStatus::Disconnected => {
return Err(Error::Timeout {
elapsed: start.elapsed(),
});
}
}
}
}
pub fn iter(&self) -> SubscriptionIter<'_> {
SubscriptionIter { sub: self }
}
}
/// Outcome of a bounded wait on an [`EventReceiver`].
///
/// Unlike the `Option` returned by `EventReceiver::recv_timeout`, this keeps
/// "nothing arrived yet" separate from "the sending side is gone".
pub enum RecvStatus {
    /// An event was received. It is boxed, presumably to keep the enum small
    /// relative to `Event` — TODO confirm.
    Event(Box<Event>),
    /// The wait ended without an event while the stream was still connected.
    Timeout,
    /// The underlying channel reported that it is disconnected.
    Disconnected,
}
/// Blocking iterator over a [`Subscription`]'s events, created by
/// `Subscription::iter`.
///
/// `next` keeps waiting until an event arrives, and yields `None` once the
/// underlying stream reports a disconnect.
pub struct SubscriptionIter<'a> {
    sub: &'a Subscription,
}

impl Iterator for SubscriptionIter<'_> {
    type Item = Event;

    fn next(&mut self) -> Option<Event> {
        // Wait in bounded slices; a slice that times out simply starts another.
        const SLICE: Duration = Duration::from_millis(100);
        loop {
            let received = match self.sub.rx.recv_timeout_status(SLICE) {
                RecvStatus::Timeout => continue,
                RecvStatus::Disconnected => None,
                RecvStatus::Event(boxed) => Some(*boxed),
            };
            return received;
        }
    }
}
/// Thin wrapper around a [`std::sync::mpsc::Receiver`] of [`Event`]s.
pub struct EventReceiver {
    rx: std::sync::mpsc::Receiver<Event>,
}

impl EventReceiver {
    /// Wraps an existing channel receiver.
    pub fn new(rx: std::sync::mpsc::Receiver<Event>) -> Self {
        EventReceiver { rx }
    }

    /// Takes a queued event without blocking.
    ///
    /// Returns `None` both when the channel is empty and when it is
    /// disconnected.
    pub fn try_recv(&self) -> Option<Event> {
        let attempt = self.rx.try_recv();
        attempt.ok()
    }

    /// Waits up to `timeout` for an event.
    ///
    /// Returns `None` on timeout and on disconnect alike.
    pub fn recv_timeout(&self, timeout: Duration) -> Option<Event> {
        let attempt = self.rx.recv_timeout(timeout);
        attempt.ok()
    }

    /// Waits up to `timeout` for an event and reports which of the three
    /// outcomes occurred.
    pub fn recv_timeout_status(&self, timeout: Duration) -> RecvStatus {
        use std::sync::mpsc::RecvTimeoutError;
        self.rx.recv_timeout(timeout).map_or_else(
            |failure| match failure {
                RecvTimeoutError::Timeout => RecvStatus::Timeout,
                RecvTimeoutError::Disconnected => RecvStatus::Disconnected,
            },
            |event| RecvStatus::Event(Box::new(event)),
        )
    }
}
/// Holds an optional callback that is invoked at most once, from `Drop`.
pub struct CancelHandle {
    // `None` for `noop` handles, and after `drop` has taken the callback.
    cancel_fn: Option<Box<dyn FnOnce() + Send>>,
}

impl CancelHandle {
    /// Creates a handle that calls `cancel_fn` when the handle is dropped.
    pub fn new(cancel_fn: impl FnOnce() + Send + 'static) -> Self {
        let callback: Box<dyn FnOnce() + Send> = Box::new(cancel_fn);
        CancelHandle {
            cancel_fn: Some(callback),
        }
    }

    /// Creates a handle whose drop does nothing.
    pub fn noop() -> Self {
        CancelHandle { cancel_fn: None }
    }
}

impl Drop for CancelHandle {
    fn drop(&mut self) {
        // Move the `FnOnce` out from behind `&mut self` so it can be called
        // by value; the slot is left as `None`.
        if let Some(callback) = std::mem::take(&mut self.cancel_fn) {
            callback();
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::event::{Event, EventKind};
    use std::sync::mpsc;

    /// Builds a minimal `FocusChanged` event. The field values are arbitrary
    /// placeholders; these tests only need *an* event to flow through.
    fn make_event() -> Event {
        Event {
            kind: EventKind::FocusChanged,
            target: None,
            app_name: "test".into(),
            app_pid: 0,
            timestamp: std::time::Instant::now(),
        }
    }

    /// `recv_timeout_status` must report three distinct outcomes:
    /// - `Timeout` while the sender is alive but idle;
    /// - `Event` once something is queued;
    /// - `Disconnected` after the sender is dropped and the queue is empty.
    #[test]
    fn recv_timeout_status_distinguishes_timeout_and_disconnect() {
        let (tx, rx) = mpsc::channel::<Event>();
        let receiver = EventReceiver::new(rx);
        // Sender still alive and nothing queued: must time out, not disconnect.
        match receiver.recv_timeout_status(Duration::from_millis(10)) {
            RecvStatus::Timeout => {}
            RecvStatus::Event(_) => panic!("unexpected event"),
            RecvStatus::Disconnected => panic!("should not be disconnected yet"),
        }
        tx.send(make_event()).unwrap();
        // The queued event is delivered (and consumed) by this call.
        match receiver.recv_timeout_status(Duration::from_millis(10)) {
            RecvStatus::Event(_) => {}
            RecvStatus::Timeout => panic!("expected Event, got Timeout"),
            RecvStatus::Disconnected => panic!("expected Event, got Disconnected"),
        }
        // Dropping the only sender while the queue is empty disconnects the channel.
        drop(tx);
        match receiver.recv_timeout_status(Duration::from_millis(10)) {
            RecvStatus::Disconnected => {}
            RecvStatus::Timeout => panic!("expected Disconnected, got Timeout"),
            RecvStatus::Event(_) => panic!("unexpected event"),
        }
    }

    /// On a disconnected stream, `wait_for` must return `Error::Timeout`
    /// promptly instead of waiting out its 5 s timeout.
    #[test]
    fn wait_for_short_circuits_on_disconnect() {
        let (tx, rx) = mpsc::channel::<Event>();
        // Drop the sender up front so the stream is disconnected from the start.
        drop(tx);
        let sub = Subscription::new(EventReceiver::new(rx), CancelHandle::noop());
        let start = std::time::Instant::now();
        // The predicate accepts everything, so any returned event would be a bug.
        let err = sub
            .wait_for(|_| true, Duration::from_secs(5))
            .expect_err("wait_for must not return an event on a disconnected stream");
        let elapsed = start.elapsed();
        match err {
            Error::Timeout { .. } => {}
            other => panic!("expected Timeout error, got {other:?}"),
        }
        // 1 s is well below the 5 s timeout; presumably the slack is there to
        // tolerate slow test machines.
        assert!(
            elapsed < Duration::from_secs(1),
            "wait_for blocked for {elapsed:?} — should short-circuit on disconnect"
        );
    }

    /// `SubscriptionIter` must yield the event buffered before the disconnect
    /// and then end. Iterating on a helper thread turns a hang into a test
    /// failure instead of blocking the whole suite.
    #[test]
    fn subscription_iter_terminates_on_disconnect() {
        let (tx, rx) = mpsc::channel::<Event>();
        tx.send(make_event()).unwrap();
        drop(tx);
        let sub = Subscription::new(EventReceiver::new(rx), CancelHandle::noop());
        let (done_tx, done_rx) = mpsc::channel::<Vec<Event>>();
        std::thread::spawn(move || {
            let collected: Vec<Event> = sub.iter().collect();
            let _ = done_tx.send(collected);
        });
        // If the iterator never ends, this wait times out and the test fails.
        // The spawned thread is not joined and is left running in that case.
        let collected = done_rx
            .recv_timeout(Duration::from_secs(2))
            .expect("SubscriptionIter did not terminate after sender was dropped");
        assert_eq!(collected.len(), 1, "expected the one buffered event");
    }
}