// chromiumoxide/listeners.rs

use std::collections::{HashMap, VecDeque};
use std::fmt;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::task::{ready, Context, Poll};

use futures_util::Stream;
use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};

use chromiumoxide_cdp::cdp::{Event, EventKind, IntoEventKind};
use chromiumoxide_types::MethodId;

/// Identifier that distinguishes one registered listener from every other.
pub type ListenerId = u64;

/// Process-wide source of [`ListenerId`]s; the first id handed out is `1` and
/// every subsequent id is larger than the previous one.
static NEXT_LISTENER_ID: AtomicU64 = AtomicU64::new(1u64);

21/// Handle returned when you register a listener.
22/// Use it to remove a listener immediately.
23#[derive(Debug, Clone, PartialEq, Eq, Hash)]
24pub struct EventListenerHandle {
25    pub method: MethodId,
26    pub id: ListenerId,
27}
28
29/// All the currently active listeners
30#[derive(Debug, Default)]
31pub struct EventListeners {
32    /// Tracks the listeners for each event identified by the key
33    listeners: HashMap<MethodId, Vec<EventListener>>,
34}
35
36impl EventListeners {
37    /// Register a subscription for a method, returning a handle to remove it.
38    pub fn add_listener(&mut self, req: EventListenerRequest) -> EventListenerHandle {
39        let EventListenerRequest {
40            listener,
41            method,
42            kind,
43        } = req;
44
45        let id = NEXT_LISTENER_ID.fetch_add(1, Ordering::Relaxed);
46
47        let subs = self.listeners.entry(method.clone()).or_default();
48        subs.push(EventListener {
49            id,
50            listener,
51            kind,
52            queued_events: Default::default(),
53        });
54
55        EventListenerHandle { method, id }
56    }
57
58    /// Remove a specific listener immediately.
59    /// Returns true if something was removed.
60    pub fn remove_listener(&mut self, handle: &EventListenerHandle) -> bool {
61        let mut removed = false;
62        let mut became_empty = false;
63
64        if let Some(subs) = self.listeners.get_mut(&handle.method) {
65            let before = subs.len();
66            subs.retain(|s| s.id != handle.id);
67            removed = subs.len() != before;
68            became_empty = subs.is_empty();
69            // `subs` borrow ends here (end of this if block)
70        }
71
72        if became_empty {
73            self.listeners.remove(&handle.method);
74        }
75
76        removed
77    }
78    /// Remove all listeners for a given method.
79    /// Returns how many were removed.
80    pub fn remove_all_for_method(&mut self, method: &MethodId) -> usize {
81        self.listeners.remove(method).map(|v| v.len()).unwrap_or(0)
82    }
83
84    /// Queue in an event that should be sent to all listeners.
85    pub fn start_send<T: Event>(&mut self, event: T) {
86        if let Some(subscriptions) = self.listeners.get_mut(&T::method_id()) {
87            let event: Arc<dyn Event> = Arc::new(event);
88            subscriptions
89                .iter_mut()
90                .for_each(|sub| sub.start_send(Arc::clone(&event)));
91        }
92    }
93
94    /// Try to queue a custom event if a listener is registered and the json conversion succeeds.
95    pub fn try_send_custom(
96        &mut self,
97        method: &str,
98        val: serde_json::Value,
99    ) -> serde_json::Result<()> {
100        if let Some(subscriptions) = self.listeners.get_mut(method) {
101            let mut event = None;
102
103            if let Some(json_to_arc_event) = subscriptions
104                .iter()
105                .filter_map(|sub| match &sub.kind {
106                    EventKind::Custom(conv) => Some(conv),
107                    _ => None,
108                })
109                .next()
110            {
111                event = Some(json_to_arc_event(val)?);
112            }
113
114            if let Some(event) = event {
115                subscriptions
116                    .iter_mut()
117                    .filter(|sub| sub.kind.is_custom())
118                    .for_each(|sub| sub.start_send(Arc::clone(&event)));
119            }
120        }
121        Ok(())
122    }
123
124    /// Drains all queued events and does housekeeping when the receiver is dropped.
125    pub fn poll(&mut self, cx: &mut Context<'_>) {
126        let _ = cx;
127        for subscriptions in self.listeners.values_mut() {
128            for n in (0..subscriptions.len()).rev() {
129                let mut sub = subscriptions.swap_remove(n);
130                match sub.flush() {
131                    Ok(()) => subscriptions.push(sub),
132                    Err(_) => {
133                        // disconnected — drop the listener
134                    }
135                }
136            }
137        }
138
139        self.listeners.retain(|_, v| !v.is_empty());
140    }
141}
142
143pub struct EventListenerRequest {
144    listener: UnboundedSender<Arc<dyn Event>>,
145    pub method: MethodId,
146    pub kind: EventKind,
147}
148
149impl EventListenerRequest {
150    pub fn new<T: IntoEventKind>(listener: UnboundedSender<Arc<dyn Event>>) -> Self {
151        Self {
152            listener,
153            method: T::method_id(),
154            kind: T::event_kind(),
155        }
156    }
157}
158
159impl fmt::Debug for EventListenerRequest {
160    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
161        f.debug_struct("EventListenerRequest")
162            .field("method", &self.method)
163            .field("kind", &self.kind)
164            .finish()
165    }
166}
167
168/// Represents a single event listener
169pub struct EventListener {
170    /// Unique id for this listener (used for immediate removal).
171    pub id: ListenerId,
172    /// the sender half of the event channel
173    listener: UnboundedSender<Arc<dyn Event>>,
174    /// currently queued events
175    queued_events: VecDeque<Arc<dyn Event>>,
176    /// For what kind of event this event is for
177    kind: EventKind,
178}
179
180impl EventListener {
181    /// queue in a new event
182    pub fn start_send(&mut self, event: Arc<dyn Event>) {
183        self.queued_events.push_back(event)
184    }
185
186    /// Drains all queued events, sending them synchronously via the unbounded channel.
187    /// Returns `Err` if the receiver has been dropped.
188    pub fn flush(&mut self) -> std::result::Result<(), mpsc::error::SendError<Arc<dyn Event>>> {
189        while let Some(event) = self.queued_events.pop_front() {
190            self.listener.send(event)?;
191        }
192        Ok(())
193    }
194}
195
196impl fmt::Debug for EventListener {
197    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
198        f.debug_struct("EventListener")
199            .field("id", &self.id)
200            .finish()
201    }
202}
203
204/// The receiver part of an event subscription
205pub struct EventStream<T: IntoEventKind> {
206    events: UnboundedReceiver<Arc<dyn Event>>,
207    _marker: PhantomData<T>,
208}
209
210impl<T: IntoEventKind> fmt::Debug for EventStream<T> {
211    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
212        f.debug_struct("EventStream").finish()
213    }
214}
215
216impl<T: IntoEventKind> EventStream<T> {
217    pub fn new(events: UnboundedReceiver<Arc<dyn Event>>) -> Self {
218        Self {
219            events,
220            _marker: PhantomData,
221        }
222    }
223}
224
225impl<T: IntoEventKind + Unpin> Stream for EventStream<T> {
226    type Item = Arc<T>;
227
228    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
229        let pin = self.get_mut();
230        match pin.events.poll_recv(cx) {
231            Poll::Ready(Some(event)) => {
232                if let Ok(e) = event.into_any_arc().downcast() {
233                    Poll::Ready(Some(e))
234                } else {
235                    // wrong type for this stream; keep polling
236                    cx.waker().wake_by_ref();
237                    Poll::Pending
238                }
239            }
240            Poll::Ready(None) => Poll::Ready(None),
241            Poll::Pending => Poll::Pending,
242        }
243    }
244}
245
#[cfg(test)]
mod tests {
    use futures_util::StreamExt;

    use chromiumoxide_cdp::cdp::browser_protocol::animation::EventAnimationCanceled;
    use chromiumoxide_cdp::cdp::CustomEvent;
    use chromiumoxide_types::{MethodId, MethodType};

    use super::*;

    /// Drive `EventListeners::poll` once from inside a task context.
    async fn poll_once(listeners: &mut EventListeners) {
        std::future::poll_fn(|cx| {
            listeners.poll(cx);
            Poll::Ready(())
        })
        .await;
    }

    #[tokio::test]
    async fn event_stream() {
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        let mut stream = EventStream::<EventAnimationCanceled>::new(rx);

        let event = EventAnimationCanceled {
            id: "id".to_string(),
        };
        let msg: Arc<dyn Event> = Arc::new(event.clone());
        tx.send(msg).unwrap();
        let next = stream.next().await.unwrap();
        assert_eq!(&*next, &event);
    }

    #[tokio::test]
    async fn custom_event_stream() {
        use serde::Deserialize;

        #[derive(Debug, Clone, Eq, PartialEq, Deserialize)]
        struct MyCustomEvent {
            name: String,
        }

        impl MethodType for MyCustomEvent {
            fn method_id() -> MethodId {
                "Custom.Event".into()
            }
        }
        impl CustomEvent for MyCustomEvent {}

        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        let mut stream = EventStream::<MyCustomEvent>::new(rx);

        let event = MyCustomEvent {
            name: "my event".to_string(),
        };
        let msg: Arc<dyn Event> = Arc::new(event.clone());
        tx.send(msg).unwrap();
        let next = stream.next().await.unwrap();
        assert_eq!(&*next, &event);
    }

    #[tokio::test]
    async fn remove_listener_immediately_stops_delivery() {
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
        let mut listeners = EventListeners::default();

        let handle =
            listeners.add_listener(EventListenerRequest::new::<EventAnimationCanceled>(tx));
        assert!(listeners.remove_listener(&handle));

        listeners.start_send(EventAnimationCanceled {
            id: "nope".to_string(),
        });

        poll_once(&mut listeners).await;

        // The listener was removed, so nothing should have been sent
        assert!(rx.try_recv().is_err());
    }

    /// `poll` delivers to live listeners and drops the ones whose receiver is gone,
    /// keeping the method entry while at least one listener remains.
    #[tokio::test]
    async fn poll_prunes_disconnected_listeners() {
        let (live_tx, mut live_rx) = tokio::sync::mpsc::unbounded_channel();
        let (dead_tx, dead_rx) = tokio::sync::mpsc::unbounded_channel();
        let mut listeners = EventListeners::default();

        listeners.add_listener(EventListenerRequest::new::<EventAnimationCanceled>(live_tx));
        listeners.add_listener(EventListenerRequest::new::<EventAnimationCanceled>(dead_tx));
        drop(dead_rx);

        listeners.start_send(EventAnimationCanceled {
            id: "alive".to_string(),
        });

        poll_once(&mut listeners).await;

        assert!(live_rx.try_recv().is_ok());
        let remaining: usize = listeners.listeners.values().map(Vec::len).sum();
        assert_eq!(remaining, 1);
    }
}