Skip to main content

chromiumoxide/
listeners.rs

1use std::collections::{HashMap, VecDeque};
2use std::fmt;
3use std::marker::PhantomData;
4use std::pin::Pin;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::sync::Arc;
7use std::task::{Context, Poll};
8
9use futures_util::Stream;
10use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
11
12use chromiumoxide_cdp::cdp::{Event, EventKind, IntoEventKind};
13use chromiumoxide_types::MethodId;
14
/// Numeric identity of a registered listener, unique within the process.
pub type ListenerId = u64;

/// Process-wide source of [`ListenerId`]s; the first id handed out is `1`.
static NEXT_LISTENER_ID: AtomicU64 = AtomicU64::new(1);
20
21/// Handle returned when you register a listener.
22/// Use it to remove a listener immediately.
23#[derive(Debug, Clone, PartialEq, Eq, Hash)]
24pub struct EventListenerHandle {
25    pub method: MethodId,
26    pub id: ListenerId,
27}
28
29/// All the currently active listeners
30#[derive(Debug, Default)]
31pub struct EventListeners {
32    /// Tracks the listeners for each event identified by the key
33    listeners: HashMap<MethodId, Vec<EventListener>>,
34}
35
36impl EventListeners {
37    /// Register a subscription for a method, returning a handle to remove it.
38    pub fn add_listener(&mut self, req: EventListenerRequest) -> EventListenerHandle {
39        let EventListenerRequest {
40            listener,
41            method,
42            kind,
43        } = req;
44
45        let id = NEXT_LISTENER_ID.fetch_add(1, Ordering::Relaxed);
46
47        let subs = self.listeners.entry(method.clone()).or_default();
48        subs.push(EventListener {
49            id,
50            listener,
51            kind,
52            queued_events: Default::default(),
53        });
54
55        EventListenerHandle { method, id }
56    }
57
58    /// Remove a specific listener immediately.
59    /// Returns true if something was removed.
60    pub fn remove_listener(&mut self, handle: &EventListenerHandle) -> bool {
61        let mut removed = false;
62        let mut became_empty = false;
63
64        if let Some(subs) = self.listeners.get_mut(&handle.method) {
65            let before = subs.len();
66            subs.retain(|s| s.id != handle.id);
67            removed = subs.len() != before;
68            became_empty = subs.is_empty();
69            // `subs` borrow ends here (end of this if block)
70        }
71
72        if became_empty {
73            self.listeners.remove(&handle.method);
74        }
75
76        removed
77    }
78    /// Remove all listeners for a given method.
79    /// Returns how many were removed.
80    pub fn remove_all_for_method(&mut self, method: &MethodId) -> usize {
81        self.listeners.remove(method).map(|v| v.len()).unwrap_or(0)
82    }
83
84    /// Queue in an event that should be sent to all listeners.
85    pub fn start_send<T: Event>(&mut self, event: T) {
86        if let Some(subscriptions) = self.listeners.get_mut(&T::method_id()) {
87            let event: Arc<dyn Event> = Arc::new(event);
88            subscriptions
89                .iter_mut()
90                .for_each(|sub| sub.start_send(Arc::clone(&event)));
91        }
92    }
93
94    /// Try to queue a custom event if a listener is registered and the json conversion succeeds.
95    pub fn try_send_custom(
96        &mut self,
97        method: &str,
98        val: serde_json::Value,
99    ) -> serde_json::Result<()> {
100        if let Some(subscriptions) = self.listeners.get_mut(method) {
101            let mut event = None;
102
103            if let Some(json_to_arc_event) = subscriptions
104                .iter()
105                .filter_map(|sub| match &sub.kind {
106                    EventKind::Custom(conv) => Some(conv),
107                    _ => None,
108                })
109                .next()
110            {
111                event = Some(json_to_arc_event(val)?);
112            }
113
114            if let Some(event) = event {
115                subscriptions
116                    .iter_mut()
117                    .filter(|sub| sub.kind.is_custom())
118                    .for_each(|sub| sub.start_send(Arc::clone(&event)));
119            }
120        }
121        Ok(())
122    }
123
124    /// Drains all queued events and does housekeeping when the receiver is dropped.
125    ///
126    /// Uses `retain_mut` for a single-pass flush + prune instead of the
127    /// swap-remove/push pattern, avoiding per-listener Vec reshuffling.
128    pub fn poll(&mut self, cx: &mut Context<'_>) {
129        let _ = cx;
130        let mut any_disconnected = false;
131
132        for subscriptions in self.listeners.values_mut() {
133            subscriptions.retain_mut(|sub| match sub.flush() {
134                Ok(()) => true,
135                Err(_) => {
136                    any_disconnected = true;
137                    false
138                }
139            });
140        }
141
142        if any_disconnected {
143            self.listeners.retain(|_, v| !v.is_empty());
144        }
145    }
146
147    /// Flush all queued events without requiring a waker `Context`.
148    ///
149    /// Identical to [`poll`](Self::poll) but usable from the
150    /// `Handler::run()` async path where no `Context` is available.
151    pub fn flush(&mut self) {
152        let mut any_disconnected = false;
153
154        for subscriptions in self.listeners.values_mut() {
155            subscriptions.retain_mut(|sub| match sub.flush() {
156                Ok(()) => true,
157                Err(_) => {
158                    any_disconnected = true;
159                    false
160                }
161            });
162        }
163
164        if any_disconnected {
165            self.listeners.retain(|_, v| !v.is_empty());
166        }
167    }
168}
169
170pub struct EventListenerRequest {
171    listener: UnboundedSender<Arc<dyn Event>>,
172    pub method: MethodId,
173    pub kind: EventKind,
174}
175
176impl EventListenerRequest {
177    pub fn new<T: IntoEventKind>(listener: UnboundedSender<Arc<dyn Event>>) -> Self {
178        Self {
179            listener,
180            method: T::method_id(),
181            kind: T::event_kind(),
182        }
183    }
184}
185
186impl fmt::Debug for EventListenerRequest {
187    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
188        f.debug_struct("EventListenerRequest")
189            .field("method", &self.method)
190            .field("kind", &self.kind)
191            .finish()
192    }
193}
194
195/// Represents a single event listener.
196///
197/// Uses an unbounded channel intentionally: `flush()` is called
198/// synchronously from the handler's poll loop and cannot await a
199/// bounded channel's back-pressure. Bounding the channel would
200/// require either dropping events (behaviour change) or making
201/// the flush path async (large refactor). Consumers that register
202/// a listener must poll the resulting `EventStream` to prevent
203/// unbounded growth.
204pub struct EventListener {
205    /// Unique id for this listener (used for immediate removal).
206    pub id: ListenerId,
207    /// the sender half of the event channel
208    listener: UnboundedSender<Arc<dyn Event>>,
209    /// currently queued events
210    queued_events: VecDeque<Arc<dyn Event>>,
211    /// For what kind of event this event is for
212    kind: EventKind,
213}
214
215impl EventListener {
216    /// queue in a new event
217    pub fn start_send(&mut self, event: Arc<dyn Event>) {
218        self.queued_events.push_back(event)
219    }
220
221    /// Drains all queued events, sending them synchronously via the unbounded channel.
222    /// Returns `Err` if the receiver has been dropped.
223    pub fn flush(&mut self) -> std::result::Result<(), mpsc::error::SendError<Arc<dyn Event>>> {
224        while let Some(event) = self.queued_events.pop_front() {
225            self.listener.send(event)?;
226        }
227        Ok(())
228    }
229}
230
231impl fmt::Debug for EventListener {
232    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
233        f.debug_struct("EventListener")
234            .field("id", &self.id)
235            .finish()
236    }
237}
238
239/// The receiver part of an event subscription
240pub struct EventStream<T: IntoEventKind> {
241    events: UnboundedReceiver<Arc<dyn Event>>,
242    _marker: PhantomData<T>,
243}
244
245impl<T: IntoEventKind> fmt::Debug for EventStream<T> {
246    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
247        f.debug_struct("EventStream").finish()
248    }
249}
250
251impl<T: IntoEventKind> EventStream<T> {
252    pub fn new(events: UnboundedReceiver<Arc<dyn Event>>) -> Self {
253        Self {
254            events,
255            _marker: PhantomData,
256        }
257    }
258}
259
260/// Per-poll budget: how many wrong-type events may be drained inside one
261/// `poll_next` call before we cooperatively yield.  Wrong-type events only
262/// occur when two custom-event listeners share a `method_id` but expect
263/// different Rust types (see `EventListeners::try_send_custom`); without a
264/// cap, a steady producer of those mismatches could keep one tokio worker
265/// inside `poll_next` for an unbounded number of synchronous iterations.
266const MAX_WRONG_TYPE_PER_POLL: usize = 32;
267
268impl<T: IntoEventKind + Unpin> Stream for EventStream<T> {
269    type Item = Arc<T>;
270
271    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
272        let pin = self.get_mut();
273        for _ in 0..MAX_WRONG_TYPE_PER_POLL {
274            match pin.events.poll_recv(cx) {
275                Poll::Ready(Some(event)) => {
276                    if let Ok(e) = event.into_any_arc().downcast() {
277                        return Poll::Ready(Some(e));
278                    }
279                    // wrong type — drop and try the next message
280                    continue;
281                }
282                Poll::Ready(None) => return Poll::Ready(None),
283                Poll::Pending => return Poll::Pending,
284            }
285        }
286        // Hit the per-poll cap.  Re-arm ourselves so the runtime
287        // re-polls us, then yield — other tasks get a chance to run.
288        cx.waker().wake_by_ref();
289        Poll::Pending
290    }
291}
292
#[cfg(test)]
mod tests {
    use futures_util::StreamExt;
    use serde::Deserialize;

    use chromiumoxide_cdp::cdp::browser_protocol::animation::EventAnimationCanceled;
    use chromiumoxide_cdp::cdp::CustomEvent;
    use chromiumoxide_types::{MethodId, MethodType};

    use super::*;

    #[tokio::test]
    async fn event_stream() {
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        let mut stream = EventStream::<EventAnimationCanceled>::new(rx);

        let event = EventAnimationCanceled {
            id: "id".to_string(),
        };
        let msg: Arc<dyn Event> = Arc::new(event.clone());
        tx.send(msg).unwrap();
        let next = stream.next().await.unwrap();
        assert_eq!(&*next, &event);
    }

    #[tokio::test]
    async fn custom_event_stream() {
        #[derive(Debug, Clone, Eq, PartialEq, Deserialize)]
        struct MyCustomEvent {
            name: String,
        }

        impl MethodType for MyCustomEvent {
            fn method_id() -> MethodId {
                "Custom.Event".into()
            }
        }
        impl CustomEvent for MyCustomEvent {}

        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        let mut stream = EventStream::<MyCustomEvent>::new(rx);

        let event = MyCustomEvent {
            name: "my event".to_string(),
        };
        let msg: Arc<dyn Event> = Arc::new(event.clone());
        tx.send(msg).unwrap();
        let next = stream.next().await.unwrap();
        assert_eq!(&*next, &event);
    }

    #[tokio::test]
    async fn remove_listener_immediately_stops_delivery() {
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
        let mut listeners = EventListeners::default();

        let handle =
            listeners.add_listener(EventListenerRequest::new::<EventAnimationCanceled>(tx));
        assert!(listeners.remove_listener(&handle));

        listeners.start_send(EventAnimationCanceled {
            id: "nope".to_string(),
        });

        std::future::poll_fn(|cx| {
            listeners.poll(cx);
            Poll::Ready(())
        })
        .await;

        // The listener was removed, so nothing should have been sent
        assert!(rx.try_recv().is_err());
    }

    /// A listener whose receiver was dropped is pruned on flush, and its
    /// now-empty method entry is removed from the map.
    #[test]
    fn flush_prunes_listener_whose_receiver_was_dropped() {
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        let mut listeners = EventListeners::default();

        let handle =
            listeners.add_listener(EventListenerRequest::new::<EventAnimationCanceled>(tx));
        drop(rx);

        listeners.start_send(EventAnimationCanceled {
            id: "gone".to_string(),
        });
        listeners.flush();

        assert!(!listeners.remove_listener(&handle));
        assert!(listeners.listeners.is_empty());
    }

    // ---------------------------------------------------------------
    // Per-poll budget regression tests for `EventStream::poll_next`.
    //
    // Wrong-type messages reach a stream when two custom-event listeners
    // share a `method_id` but have different Rust types (the dispatcher
    // sends one converted `Arc<dyn Event>` to all of them; the second
    // listener's `downcast` then fails). Without a per-poll cap, a steady
    // stream of those would loop synchronously inside one `poll_next`
    // call, blocking the worker.
    // ---------------------------------------------------------------

    #[derive(Debug, Clone, Eq, PartialEq, Deserialize)]
    struct WrongA {
        a: i32,
    }
    impl MethodType for WrongA {
        fn method_id() -> MethodId {
            "Custom.PollBudget".into()
        }
    }
    impl CustomEvent for WrongA {}

    #[derive(Debug, Clone, Eq, PartialEq, Deserialize)]
    struct RightB {
        b: i32,
    }
    impl MethodType for RightB {
        fn method_id() -> MethodId {
            "Custom.PollBudget".into()
        }
    }
    impl CustomEvent for RightB {}

    /// `try_send_custom` converts the JSON payload with the registered custom
    /// converter and delivers it once the listeners are flushed. A method with
    /// no listeners is a no-op that still succeeds.
    #[tokio::test]
    async fn try_send_custom_delivers_converted_event() {
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        let mut stream = EventStream::<RightB>::new(rx);
        let mut listeners = EventListeners::default();
        listeners.add_listener(EventListenerRequest::new::<RightB>(tx));

        assert!(listeners
            .try_send_custom("Custom.Nobody", serde_json::json!(null))
            .is_ok());
        listeners
            .try_send_custom("Custom.PollBudget", serde_json::json!({ "b": 3 }))
            .expect("payload matches RightB");
        listeners.flush();

        let got = stream
            .next()
            .await
            .expect("converted event should be delivered");
        assert_eq!(&*got, &RightB { b: 3 });
    }

    /// A flood of wrong-type events larger than `MAX_WRONG_TYPE_PER_POLL`
    /// is fully drained and the trailing right-type event is still
    /// delivered — proving the per-poll cap doesn't lose events.
    #[tokio::test]
    async fn poll_next_drains_wrong_type_flood() {
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        let mut stream = EventStream::<RightB>::new(rx);

        // Use ~10x the budget so the stream must yield-and-resume several
        // times across separate poll calls.
        let flood = MAX_WRONG_TYPE_PER_POLL * 10;
        for i in 0..flood {
            let msg: Arc<dyn Event> = Arc::new(WrongA { a: i as i32 });
            tx.send(msg).unwrap();
        }
        let target = RightB { b: 7 };
        let target_msg: Arc<dyn Event> = Arc::new(target.clone());
        tx.send(target_msg).unwrap();

        let got = tokio::time::timeout(std::time::Duration::from_secs(5), stream.next())
            .await
            .expect("stream must not hang under wrong-type flood")
            .expect("stream should yield the right-type event");
        assert_eq!(&*got, &target);
    }

    /// One `poll_next` call must consume at most `MAX_WRONG_TYPE_PER_POLL`
    /// wrong-type messages and then return `Pending` (re-arming itself via
    /// the waker). With strictly more wrong-type events queued than the
    /// budget, the first poll must NOT keep going to completion.
    #[tokio::test]
    async fn poll_next_returns_pending_after_budget() {
        // `Pin`, `Poll` and `Context` come from `super::*`.
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        let mut stream = EventStream::<RightB>::new(rx);

        // Strictly more wrong-type events than the budget, with no
        // right-type event queued yet. Without the cap, the loop would
        // synchronously drain everything and then block on `Pending`.
        let queued = MAX_WRONG_TYPE_PER_POLL + 5;
        for i in 0..queued {
            let msg: Arc<dyn Event> = Arc::new(WrongA { a: i as i32 });
            tx.send(msg).unwrap();
        }

        let waker = futures_util::task::noop_waker();
        let mut cx = Context::from_waker(&waker);
        let res = Pin::new(&mut stream).poll_next(&mut cx);
        assert!(
            matches!(res, Poll::Pending),
            "first poll must yield once the per-poll budget is consumed"
        );

        // The exact remaining count isn't part of the contract, but at
        // least `queued - MAX_WRONG_TYPE_PER_POLL` events must still be
        // sitting in the channel — the cap really did stop early.
        let mut remaining = 0usize;
        while stream.events.try_recv().is_ok() {
            remaining += 1;
        }
        assert!(
            remaining >= queued - MAX_WRONG_TYPE_PER_POLL,
            "expected at least {} events left after budget poll, found {}",
            queued - MAX_WRONG_TYPE_PER_POLL,
            remaining
        );
    }

    /// After yielding under the budget cap, a follow-up poll must resume
    /// draining and ultimately deliver a trailing right-type event — the
    /// re-arm via `wake_by_ref` is what keeps the stream live.
    #[tokio::test]
    async fn poll_next_resumes_after_budget_yield() {
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        let mut stream = EventStream::<RightB>::new(rx);

        // > 1 budget worth of wrong types, then the right-type tail.
        for i in 0..(MAX_WRONG_TYPE_PER_POLL + 5) {
            let msg: Arc<dyn Event> = Arc::new(WrongA { a: i as i32 });
            tx.send(msg).unwrap();
        }
        let target = RightB { b: 99 };
        let target_msg: Arc<dyn Event> = Arc::new(target.clone());
        tx.send(target_msg).unwrap();

        // Awaiting `next()` exercises the wake-and-resume path — if the
        // re-arm were missing, this would hang.
        let got = tokio::time::timeout(std::time::Duration::from_secs(5), stream.next())
            .await
            .expect("re-arm must wake the stream after budget yield")
            .expect("right-type event should be delivered");
        assert_eq!(&*got, &target);
    }
}