Skip to main content

chromiumoxide/
listeners.rs

1use std::collections::{HashMap, VecDeque};
2use std::fmt;
3use std::marker::PhantomData;
4use std::pin::Pin;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::sync::Arc;
7use std::task::{Context, Poll};
8
9use futures_util::Stream;
10use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
11
12use chromiumoxide_cdp::cdp::{Event, EventKind, IntoEventKind};
13use chromiumoxide_types::MethodId;
14
/// Numeric identifier handed out to every registered listener.
pub type ListenerId = u64;

/// Counter that supplies fresh [`ListenerId`]s; the first id issued is `1`.
static NEXT_LISTENER_ID: AtomicU64 = AtomicU64::new(1_u64);
20
21/// Handle returned when you register a listener.
22/// Use it to remove a listener immediately.
23#[derive(Debug, Clone, PartialEq, Eq, Hash)]
24pub struct EventListenerHandle {
25    pub method: MethodId,
26    pub id: ListenerId,
27}
28
29/// All the currently active listeners
30#[derive(Debug, Default)]
31pub struct EventListeners {
32    /// Tracks the listeners for each event identified by the key
33    listeners: HashMap<MethodId, Vec<EventListener>>,
34}
35
36impl EventListeners {
37    /// Register a subscription for a method, returning a handle to remove it.
38    pub fn add_listener(&mut self, req: EventListenerRequest) -> EventListenerHandle {
39        let EventListenerRequest {
40            listener,
41            method,
42            kind,
43        } = req;
44
45        let id = NEXT_LISTENER_ID.fetch_add(1, Ordering::Relaxed);
46
47        let subs = self.listeners.entry(method.clone()).or_default();
48        subs.push(EventListener {
49            id,
50            listener,
51            kind,
52            queued_events: Default::default(),
53        });
54
55        EventListenerHandle { method, id }
56    }
57
58    /// Remove a specific listener immediately.
59    /// Returns true if something was removed.
60    pub fn remove_listener(&mut self, handle: &EventListenerHandle) -> bool {
61        let mut removed = false;
62        let mut became_empty = false;
63
64        if let Some(subs) = self.listeners.get_mut(&handle.method) {
65            let before = subs.len();
66            subs.retain(|s| s.id != handle.id);
67            removed = subs.len() != before;
68            became_empty = subs.is_empty();
69            // `subs` borrow ends here (end of this if block)
70        }
71
72        if became_empty {
73            self.listeners.remove(&handle.method);
74        }
75
76        removed
77    }
78    /// Remove all listeners for a given method.
79    /// Returns how many were removed.
80    pub fn remove_all_for_method(&mut self, method: &MethodId) -> usize {
81        self.listeners.remove(method).map(|v| v.len()).unwrap_or(0)
82    }
83
84    /// Queue in an event that should be sent to all listeners.
85    pub fn start_send<T: Event>(&mut self, event: T) {
86        if let Some(subscriptions) = self.listeners.get_mut(&T::method_id()) {
87            let event: Arc<dyn Event> = Arc::new(event);
88            subscriptions
89                .iter_mut()
90                .for_each(|sub| sub.start_send(Arc::clone(&event)));
91        }
92    }
93
94    /// Try to queue a custom event if a listener is registered and the json conversion succeeds.
95    pub fn try_send_custom(
96        &mut self,
97        method: &str,
98        val: serde_json::Value,
99    ) -> serde_json::Result<()> {
100        if let Some(subscriptions) = self.listeners.get_mut(method) {
101            let mut event = None;
102
103            if let Some(json_to_arc_event) = subscriptions
104                .iter()
105                .filter_map(|sub| match &sub.kind {
106                    EventKind::Custom(conv) => Some(conv),
107                    _ => None,
108                })
109                .next()
110            {
111                event = Some(json_to_arc_event(val)?);
112            }
113
114            if let Some(event) = event {
115                subscriptions
116                    .iter_mut()
117                    .filter(|sub| sub.kind.is_custom())
118                    .for_each(|sub| sub.start_send(Arc::clone(&event)));
119            }
120        }
121        Ok(())
122    }
123
124    /// Drains all queued events and does housekeeping when the receiver is dropped.
125    ///
126    /// Uses `retain_mut` for a single-pass flush + prune instead of the
127    /// swap-remove/push pattern, avoiding per-listener Vec reshuffling.
128    pub fn poll(&mut self, cx: &mut Context<'_>) {
129        let _ = cx;
130        let mut any_disconnected = false;
131
132        for subscriptions in self.listeners.values_mut() {
133            subscriptions.retain_mut(|sub| match sub.flush() {
134                Ok(()) => true,
135                Err(_) => {
136                    any_disconnected = true;
137                    false
138                }
139            });
140        }
141
142        if any_disconnected {
143            self.listeners.retain(|_, v| !v.is_empty());
144        }
145    }
146
147    /// Flush all queued events without requiring a waker `Context`.
148    ///
149    /// Identical to [`poll`](Self::poll) but usable from the
150    /// `Handler::run()` async path where no `Context` is available.
151    pub fn flush(&mut self) {
152        let mut any_disconnected = false;
153
154        for subscriptions in self.listeners.values_mut() {
155            subscriptions.retain_mut(|sub| match sub.flush() {
156                Ok(()) => true,
157                Err(_) => {
158                    any_disconnected = true;
159                    false
160                }
161            });
162        }
163
164        if any_disconnected {
165            self.listeners.retain(|_, v| !v.is_empty());
166        }
167    }
168}
169
170pub struct EventListenerRequest {
171    listener: UnboundedSender<Arc<dyn Event>>,
172    pub method: MethodId,
173    pub kind: EventKind,
174}
175
176impl EventListenerRequest {
177    pub fn new<T: IntoEventKind>(listener: UnboundedSender<Arc<dyn Event>>) -> Self {
178        Self {
179            listener,
180            method: T::method_id(),
181            kind: T::event_kind(),
182        }
183    }
184}
185
186impl fmt::Debug for EventListenerRequest {
187    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
188        f.debug_struct("EventListenerRequest")
189            .field("method", &self.method)
190            .field("kind", &self.kind)
191            .finish()
192    }
193}
194
195/// Represents a single event listener.
196///
197/// Uses an unbounded channel intentionally: `flush()` is called
198/// synchronously from the handler's poll loop and cannot await a
199/// bounded channel's back-pressure. Bounding the channel would
200/// require either dropping events (behaviour change) or making
201/// the flush path async (large refactor). Consumers that register
202/// a listener must poll the resulting `EventStream` to prevent
203/// unbounded growth.
204pub struct EventListener {
205    /// Unique id for this listener (used for immediate removal).
206    pub id: ListenerId,
207    /// the sender half of the event channel
208    listener: UnboundedSender<Arc<dyn Event>>,
209    /// currently queued events
210    queued_events: VecDeque<Arc<dyn Event>>,
211    /// For what kind of event this event is for
212    kind: EventKind,
213}
214
215impl EventListener {
216    /// queue in a new event
217    pub fn start_send(&mut self, event: Arc<dyn Event>) {
218        self.queued_events.push_back(event)
219    }
220
221    /// Drains all queued events, sending them synchronously via the unbounded channel.
222    /// Returns `Err` if the receiver has been dropped.
223    pub fn flush(&mut self) -> std::result::Result<(), mpsc::error::SendError<Arc<dyn Event>>> {
224        while let Some(event) = self.queued_events.pop_front() {
225            self.listener.send(event)?;
226        }
227        Ok(())
228    }
229}
230
231impl fmt::Debug for EventListener {
232    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
233        f.debug_struct("EventListener")
234            .field("id", &self.id)
235            .finish()
236    }
237}
238
239/// The receiver part of an event subscription
240pub struct EventStream<T: IntoEventKind> {
241    events: UnboundedReceiver<Arc<dyn Event>>,
242    _marker: PhantomData<T>,
243}
244
245impl<T: IntoEventKind> fmt::Debug for EventStream<T> {
246    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
247        f.debug_struct("EventStream").finish()
248    }
249}
250
251impl<T: IntoEventKind> EventStream<T> {
252    pub fn new(events: UnboundedReceiver<Arc<dyn Event>>) -> Self {
253        Self {
254            events,
255            _marker: PhantomData,
256        }
257    }
258}
259
260impl<T: IntoEventKind + Unpin> Stream for EventStream<T> {
261    type Item = Arc<T>;
262
263    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
264        let pin = self.get_mut();
265        loop {
266            match pin.events.poll_recv(cx) {
267                Poll::Ready(Some(event)) => {
268                    if let Ok(e) = event.into_any_arc().downcast() {
269                        return Poll::Ready(Some(e));
270                    }
271                    // wrong type — try the next message in the channel
272                    continue;
273                }
274                Poll::Ready(None) => return Poll::Ready(None),
275                Poll::Pending => return Poll::Pending,
276            }
277        }
278    }
279}
280
#[cfg(test)]
mod tests {
    use futures_util::StreamExt;
    use serde::Deserialize;

    use chromiumoxide_cdp::cdp::browser_protocol::animation::EventAnimationCanceled;
    use chromiumoxide_cdp::cdp::CustomEvent;
    use chromiumoxide_types::{MethodId, MethodType};

    use super::*;

    /// Minimal custom event shared by the custom-event tests.
    #[derive(Debug, Clone, Eq, PartialEq, Deserialize)]
    struct MyCustomEvent {
        name: String,
    }

    impl MethodType for MyCustomEvent {
        fn method_id() -> MethodId {
            "Custom.Event".into()
        }
    }
    impl CustomEvent for MyCustomEvent {}

    #[tokio::test]
    async fn event_stream() {
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        let mut stream = EventStream::<EventAnimationCanceled>::new(rx);

        let event = EventAnimationCanceled {
            id: "id".to_string(),
        };
        let msg: Arc<dyn Event> = Arc::new(event.clone());
        tx.send(msg).unwrap();
        let next = stream.next().await.unwrap();
        assert_eq!(&*next, &event);
    }

    #[tokio::test]
    async fn custom_event_stream() {
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        let mut stream = EventStream::<MyCustomEvent>::new(rx);

        let event = MyCustomEvent {
            name: "my event".to_string(),
        };
        let msg: Arc<dyn Event> = Arc::new(event.clone());
        tx.send(msg).unwrap();
        let next = stream.next().await.unwrap();
        assert_eq!(&*next, &event);
    }

    #[tokio::test]
    async fn remove_listener_immediately_stops_delivery() {
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
        let mut listeners = EventListeners::default();

        let handle =
            listeners.add_listener(EventListenerRequest::new::<EventAnimationCanceled>(tx));
        assert!(listeners.remove_listener(&handle));

        listeners.start_send(EventAnimationCanceled {
            id: "nope".to_string(),
        });

        std::future::poll_fn(|cx| {
            listeners.poll(cx);
            Poll::Ready(())
        })
        .await;

        // The listener was removed, so nothing should have been sent
        assert!(rx.try_recv().is_err());
    }

    #[tokio::test]
    async fn flush_delivers_queued_events() {
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        let mut stream = EventStream::<EventAnimationCanceled>::new(rx);
        let mut listeners = EventListeners::default();
        listeners.add_listener(EventListenerRequest::new::<EventAnimationCanceled>(tx));

        let event = EventAnimationCanceled {
            id: "queued".to_string(),
        };
        listeners.start_send(event.clone());
        // `flush` is the Context-free path used by `Handler::run()`.
        listeners.flush();

        let next = stream.next().await.unwrap();
        assert_eq!(&*next, &event);
    }

    #[test]
    fn dropped_receiver_is_pruned_on_flush() {
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        let mut listeners = EventListeners::default();
        let handle =
            listeners.add_listener(EventListenerRequest::new::<EventAnimationCanceled>(tx));
        drop(rx);

        listeners.start_send(EventAnimationCanceled {
            id: "gone".to_string(),
        });
        listeners.flush();

        // The failed send pruned the listener, so there is nothing left to remove.
        assert!(!listeners.remove_listener(&handle));
    }

    #[test]
    fn remove_all_for_method_counts_listeners() {
        let (tx, _rx) = tokio::sync::mpsc::unbounded_channel::<Arc<dyn Event>>();
        let mut listeners = EventListeners::default();
        let handle = listeners.add_listener(EventListenerRequest::new::<EventAnimationCanceled>(
            tx.clone(),
        ));
        listeners.add_listener(EventListenerRequest::new::<EventAnimationCanceled>(tx));

        assert_eq!(listeners.remove_all_for_method(&handle.method), 2);
        assert_eq!(listeners.remove_all_for_method(&handle.method), 0);
    }

    #[tokio::test]
    async fn try_send_custom_delivers_to_custom_listeners() {
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        let mut stream = EventStream::<MyCustomEvent>::new(rx);
        let mut listeners = EventListeners::default();
        listeners.add_listener(EventListenerRequest::new::<MyCustomEvent>(tx));

        // JSON that doesn't match the event's shape is reported, not delivered.
        assert!(listeners
            .try_send_custom("Custom.Event", serde_json::json!({ "unexpected": 1 }))
            .is_err());

        listeners
            .try_send_custom("Custom.Event", serde_json::json!({ "name": "custom" }))
            .unwrap();
        listeners.flush();

        let next = stream.next().await.unwrap();
        assert_eq!(
            &*next,
            &MyCustomEvent {
                name: "custom".to_string()
            }
        );
    }
}