chromiumoxide/
listeners.rs

1use std::collections::{HashMap, VecDeque};
2use std::fmt;
3use std::marker::PhantomData;
4use std::pin::Pin;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::sync::Arc;
7use std::task::{Context, Poll};
8
9use futures::channel::mpsc::{SendError, UnboundedReceiver, UnboundedSender};
10use futures::{Sink, Stream};
11
12use chromiumoxide_cdp::cdp::{Event, EventKind, IntoEventKind};
13use chromiumoxide_types::MethodId;
14
/// Numeric id assigned to every listener registered through
/// [`EventListeners::add_listener`].
pub type ListenerId = u64;

/// The first id handed out by [`NEXT_LISTENER_ID`]; ids start at 1.
const FIRST_LISTENER_ID: ListenerId = 1;

/// Process-wide counter from which listener ids are drawn, incremented once per registration.
static NEXT_LISTENER_ID: AtomicU64 = AtomicU64::new(FIRST_LISTENER_ID);
20
21/// Handle returned when you register a listener.
22/// Use it to remove a listener immediately.
23#[derive(Debug, Clone, PartialEq, Eq, Hash)]
24pub struct EventListenerHandle {
25    pub method: MethodId,
26    pub id: ListenerId,
27}
28
29/// All the currently active listeners
30#[derive(Debug, Default)]
31pub struct EventListeners {
32    /// Tracks the listeners for each event identified by the key
33    listeners: HashMap<MethodId, Vec<EventListener>>,
34}
35
36impl EventListeners {
37    /// Register a subscription for a method, returning a handle to remove it.
38    pub fn add_listener(&mut self, req: EventListenerRequest) -> EventListenerHandle {
39        let EventListenerRequest {
40            listener,
41            method,
42            kind,
43        } = req;
44
45        let id = NEXT_LISTENER_ID.fetch_add(1, Ordering::Relaxed);
46
47        let subs = self.listeners.entry(method.clone()).or_default();
48        subs.push(EventListener {
49            id,
50            listener,
51            kind,
52            queued_events: Default::default(),
53        });
54
55        EventListenerHandle { method, id }
56    }
57
58    /// Remove a specific listener immediately.
59    /// Returns true if something was removed.
60    pub fn remove_listener(&mut self, handle: &EventListenerHandle) -> bool {
61        let mut removed = false;
62        let mut became_empty = false;
63
64        if let Some(subs) = self.listeners.get_mut(&handle.method) {
65            let before = subs.len();
66            subs.retain(|s| s.id != handle.id);
67            removed = subs.len() != before;
68            became_empty = subs.is_empty();
69            // `subs` borrow ends here (end of this if block)
70        }
71
72        if became_empty {
73            self.listeners.remove(&handle.method);
74        }
75
76        removed
77    }
78    /// Remove all listeners for a given method.
79    /// Returns how many were removed.
80    pub fn remove_all_for_method(&mut self, method: &MethodId) -> usize {
81        self.listeners.remove(method).map(|v| v.len()).unwrap_or(0)
82    }
83
84    /// Queue in an event that should be sent to all listeners.
85    pub fn start_send<T: Event>(&mut self, event: T) {
86        if let Some(subscriptions) = self.listeners.get_mut(&T::method_id()) {
87            let event: Arc<dyn Event> = Arc::new(event);
88            subscriptions
89                .iter_mut()
90                .for_each(|sub| sub.start_send(Arc::clone(&event)));
91        }
92    }
93
94    /// Try to queue a custom event if a listener is registered and the json conversion succeeds.
95    pub fn try_send_custom(
96        &mut self,
97        method: &str,
98        val: serde_json::Value,
99    ) -> serde_json::Result<()> {
100        if let Some(subscriptions) = self.listeners.get_mut(method) {
101            let mut event = None;
102
103            if let Some(json_to_arc_event) = subscriptions
104                .iter()
105                .filter_map(|sub| match &sub.kind {
106                    EventKind::Custom(conv) => Some(conv),
107                    _ => None,
108                })
109                .next()
110            {
111                event = Some(json_to_arc_event(val)?);
112            }
113
114            if let Some(event) = event {
115                subscriptions
116                    .iter_mut()
117                    .filter(|sub| sub.kind.is_custom())
118                    .for_each(|sub| sub.start_send(Arc::clone(&event)));
119            }
120        }
121        Ok(())
122    }
123
124    /// Drains all queued events and does housekeeping when the receiver is dropped.
125    pub fn poll(&mut self, cx: &mut Context<'_>) {
126        for subscriptions in self.listeners.values_mut() {
127            for n in (0..subscriptions.len()).rev() {
128                let mut sub = subscriptions.swap_remove(n);
129                match sub.poll(cx) {
130                    Poll::Ready(Err(err)) => {
131                        // disconnected
132                        if !err.is_disconnected() {
133                            subscriptions.push(sub);
134                        }
135                    }
136                    _ => subscriptions.push(sub),
137                }
138            }
139        }
140
141        self.listeners.retain(|_, v| !v.is_empty());
142    }
143}
144
145pub struct EventListenerRequest {
146    listener: UnboundedSender<Arc<dyn Event>>,
147    pub method: MethodId,
148    pub kind: EventKind,
149}
150
151impl EventListenerRequest {
152    pub fn new<T: IntoEventKind>(listener: UnboundedSender<Arc<dyn Event>>) -> Self {
153        Self {
154            listener,
155            method: T::method_id(),
156            kind: T::event_kind(),
157        }
158    }
159}
160
161impl fmt::Debug for EventListenerRequest {
162    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
163        f.debug_struct("EventListenerRequest")
164            .field("method", &self.method)
165            .field("kind", &self.kind)
166            .finish()
167    }
168}
169
170/// Represents a single event listener
171pub struct EventListener {
172    /// Unique id for this listener (used for immediate removal).
173    pub id: ListenerId,
174    /// the sender half of the event channel
175    listener: UnboundedSender<Arc<dyn Event>>,
176    /// currently queued events
177    queued_events: VecDeque<Arc<dyn Event>>,
178    /// For what kind of event this event is for
179    kind: EventKind,
180}
181
182impl EventListener {
183    /// queue in a new event
184    pub fn start_send(&mut self, event: Arc<dyn Event>) {
185        self.queued_events.push_back(event)
186    }
187
188    /// Drains all queued events and begins sending them to the sink.
189    pub fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), SendError>> {
190        loop {
191            match Sink::poll_ready(Pin::new(&mut self.listener), cx) {
192                Poll::Ready(Ok(_)) => {}
193                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
194                Poll::Pending => return Poll::Pending,
195            }
196
197            if let Some(event) = self.queued_events.pop_front() {
198                if let Err(err) = Sink::start_send(Pin::new(&mut self.listener), event) {
199                    return Poll::Ready(Err(err));
200                }
201            } else {
202                return Poll::Ready(Ok(()));
203            }
204        }
205    }
206}
207
208impl fmt::Debug for EventListener {
209    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
210        f.debug_struct("EventListener")
211            .field("id", &self.id)
212            .finish()
213    }
214}
215
216/// The receiver part of an event subscription
217pub struct EventStream<T: IntoEventKind> {
218    events: UnboundedReceiver<Arc<dyn Event>>,
219    _marker: PhantomData<T>,
220}
221
222impl<T: IntoEventKind> fmt::Debug for EventStream<T> {
223    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
224        f.debug_struct("EventStream").finish()
225    }
226}
227
228impl<T: IntoEventKind> EventStream<T> {
229    pub fn new(events: UnboundedReceiver<Arc<dyn Event>>) -> Self {
230        Self {
231            events,
232            _marker: PhantomData,
233        }
234    }
235}
236
237impl<T: IntoEventKind + Unpin> Stream for EventStream<T> {
238    type Item = Arc<T>;
239
240    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
241        let pin = self.get_mut();
242        match Stream::poll_next(Pin::new(&mut pin.events), cx) {
243            Poll::Ready(Some(event)) => {
244                if let Ok(e) = event.into_any_arc().downcast() {
245                    Poll::Ready(Some(e))
246                } else {
247                    // wrong type for this stream; keep polling
248                    Poll::Pending
249                }
250            }
251            Poll::Ready(None) => Poll::Ready(None),
252            Poll::Pending => Poll::Pending,
253        }
254    }
255}
256
#[cfg(test)]
mod tests {
    use futures::{SinkExt, StreamExt};

    use chromiumoxide_cdp::cdp::browser_protocol::animation::EventAnimationCanceled;
    use chromiumoxide_cdp::cdp::CustomEvent;
    use chromiumoxide_types::{MethodId, MethodType};

    use super::*;

    /// Runs `EventListeners::poll` once inside a task context.
    async fn poll_once(listeners: &mut EventListeners) {
        futures::future::poll_fn(|cx| {
            listeners.poll(cx);
            Poll::Ready(())
        })
        .await;
    }

    #[tokio::test]
    async fn event_stream() {
        let (mut tx, rx) = futures::channel::mpsc::unbounded();
        let mut stream = EventStream::<EventAnimationCanceled>::new(rx);

        let event = EventAnimationCanceled {
            id: "id".to_string(),
        };
        let msg: Arc<dyn Event> = Arc::new(event.clone());
        tx.send(msg).await.unwrap();
        let next = stream.next().await.unwrap();
        assert_eq!(&*next, &event);
    }

    #[tokio::test]
    async fn custom_event_stream() {
        use serde::Deserialize;

        #[derive(Debug, Clone, Eq, PartialEq, Deserialize)]
        struct MyCustomEvent {
            name: String,
        }

        impl MethodType for MyCustomEvent {
            fn method_id() -> MethodId {
                "Custom.Event".into()
            }
        }
        impl CustomEvent for MyCustomEvent {}

        let (mut tx, rx) = futures::channel::mpsc::unbounded();
        let mut stream = EventStream::<MyCustomEvent>::new(rx);

        let event = MyCustomEvent {
            name: "my event".to_string(),
        };
        let msg: Arc<dyn Event> = Arc::new(event.clone());
        tx.send(msg).await.unwrap();
        let next = stream.next().await.unwrap();
        assert_eq!(&*next, &event);
    }

    #[tokio::test]
    async fn remove_listener_immediately_stops_delivery() {
        let (tx, mut rx) = futures::channel::mpsc::unbounded();
        let mut listeners = EventListeners::default();

        let handle =
            listeners.add_listener(EventListenerRequest::new::<EventAnimationCanceled>(tx));
        assert!(listeners.remove_listener(&handle));

        listeners.start_send(EventAnimationCanceled {
            id: "nope".to_string(),
        });
        poll_once(&mut listeners).await;

        // The only sender was dropped together with the listener, so the
        // channel must be closed and empty: nothing was delivered.
        assert!(matches!(rx.try_next(), Ok(None)));
    }

    #[tokio::test]
    async fn poll_delivers_queued_events_to_every_listener() {
        let (tx1, rx1) = futures::channel::mpsc::unbounded();
        let (tx2, rx2) = futures::channel::mpsc::unbounded();
        let mut listeners = EventListeners::default();
        listeners.add_listener(EventListenerRequest::new::<EventAnimationCanceled>(tx1));
        listeners.add_listener(EventListenerRequest::new::<EventAnimationCanceled>(tx2));

        let event = EventAnimationCanceled {
            id: "id".to_string(),
        };
        listeners.start_send(event.clone());
        poll_once(&mut listeners).await;

        let mut first = EventStream::<EventAnimationCanceled>::new(rx1);
        let mut second = EventStream::<EventAnimationCanceled>::new(rx2);
        assert_eq!(&*first.next().await.unwrap(), &event);
        assert_eq!(&*second.next().await.unwrap(), &event);
    }

    #[tokio::test]
    async fn poll_prunes_listeners_whose_receiver_was_dropped() {
        let (tx, rx) = futures::channel::mpsc::unbounded();
        let mut listeners = EventListeners::default();
        let handle =
            listeners.add_listener(EventListenerRequest::new::<EventAnimationCanceled>(tx));
        drop(rx);

        poll_once(&mut listeners).await;

        // The listener was pruned while polling, so there is nothing left to remove.
        assert!(!listeners.remove_listener(&handle));
    }

    #[test]
    fn remove_all_for_method_counts_removed_listeners() {
        let (tx1, _rx1) = futures::channel::mpsc::unbounded();
        let (tx2, _rx2) = futures::channel::mpsc::unbounded();
        let mut listeners = EventListeners::default();
        let handle =
            listeners.add_listener(EventListenerRequest::new::<EventAnimationCanceled>(tx1));
        listeners.add_listener(EventListenerRequest::new::<EventAnimationCanceled>(tx2));

        assert_eq!(listeners.remove_all_for_method(&handle.method), 2);
        assert_eq!(listeners.remove_all_for_method(&handle.method), 0);
        assert!(!listeners.remove_listener(&handle));
    }
}