Skip to main content

chromiumoxide/
listeners.rs

1use std::collections::{HashMap, VecDeque};
2use std::fmt;
3use std::marker::PhantomData;
4use std::pin::Pin;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::sync::Arc;
7use std::task::{Context, Poll};
8
9use futures_util::Stream;
10use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
11
12use chromiumoxide_cdp::cdp::{Event, EventKind, IntoEventKind};
13use chromiumoxide_types::MethodId;
14
/// Process-unique identifier assigned to each registered listener.
pub type ListenerId = u64;

/// Source of fresh [`ListenerId`]s; starts at 1 and is bumped once per registration.
static NEXT_LISTENER_ID: AtomicU64 = AtomicU64::new(1);
20
21/// Handle returned when you register a listener.
22/// Use it to remove a listener immediately.
23#[derive(Debug, Clone, PartialEq, Eq, Hash)]
24pub struct EventListenerHandle {
25    pub method: MethodId,
26    pub id: ListenerId,
27}
28
29/// All the currently active listeners
30#[derive(Debug, Default)]
31pub struct EventListeners {
32    /// Tracks the listeners for each event identified by the key
33    listeners: HashMap<MethodId, Vec<EventListener>>,
34}
35
36impl EventListeners {
37    /// Register a subscription for a method, returning a handle to remove it.
38    pub fn add_listener(&mut self, req: EventListenerRequest) -> EventListenerHandle {
39        let EventListenerRequest {
40            listener,
41            method,
42            kind,
43        } = req;
44
45        let id = NEXT_LISTENER_ID.fetch_add(1, Ordering::Relaxed);
46
47        let subs = self.listeners.entry(method.clone()).or_default();
48        subs.push(EventListener {
49            id,
50            listener,
51            kind,
52            queued_events: Default::default(),
53        });
54
55        EventListenerHandle { method, id }
56    }
57
58    /// Remove a specific listener immediately.
59    /// Returns true if something was removed.
60    pub fn remove_listener(&mut self, handle: &EventListenerHandle) -> bool {
61        let mut removed = false;
62        let mut became_empty = false;
63
64        if let Some(subs) = self.listeners.get_mut(&handle.method) {
65            let before = subs.len();
66            subs.retain(|s| s.id != handle.id);
67            removed = subs.len() != before;
68            became_empty = subs.is_empty();
69            // `subs` borrow ends here (end of this if block)
70        }
71
72        if became_empty {
73            self.listeners.remove(&handle.method);
74        }
75
76        removed
77    }
78    /// Remove all listeners for a given method.
79    /// Returns how many were removed.
80    pub fn remove_all_for_method(&mut self, method: &MethodId) -> usize {
81        self.listeners.remove(method).map(|v| v.len()).unwrap_or(0)
82    }
83
84    /// Queue in an event that should be sent to all listeners.
85    pub fn start_send<T: Event>(&mut self, event: T) {
86        if let Some(subscriptions) = self.listeners.get_mut(&T::method_id()) {
87            let event: Arc<dyn Event> = Arc::new(event);
88            subscriptions
89                .iter_mut()
90                .for_each(|sub| sub.start_send(Arc::clone(&event)));
91        }
92    }
93
94    /// Try to queue a custom event if a listener is registered and the json conversion succeeds.
95    pub fn try_send_custom(
96        &mut self,
97        method: &str,
98        val: serde_json::Value,
99    ) -> serde_json::Result<()> {
100        if let Some(subscriptions) = self.listeners.get_mut(method) {
101            let mut event = None;
102
103            if let Some(json_to_arc_event) = subscriptions
104                .iter()
105                .filter_map(|sub| match &sub.kind {
106                    EventKind::Custom(conv) => Some(conv),
107                    _ => None,
108                })
109                .next()
110            {
111                event = Some(json_to_arc_event(val)?);
112            }
113
114            if let Some(event) = event {
115                subscriptions
116                    .iter_mut()
117                    .filter(|sub| sub.kind.is_custom())
118                    .for_each(|sub| sub.start_send(Arc::clone(&event)));
119            }
120        }
121        Ok(())
122    }
123
124    /// Drains all queued events and does housekeeping when the receiver is dropped.
125    pub fn poll(&mut self, cx: &mut Context<'_>) {
126        let _ = cx;
127        for subscriptions in self.listeners.values_mut() {
128            for n in (0..subscriptions.len()).rev() {
129                let mut sub = subscriptions.swap_remove(n);
130                match sub.flush() {
131                    Ok(()) => subscriptions.push(sub),
132                    Err(_) => {
133                        // disconnected — drop the listener
134                    }
135                }
136            }
137        }
138
139        self.listeners.retain(|_, v| !v.is_empty());
140    }
141}
142
143pub struct EventListenerRequest {
144    listener: UnboundedSender<Arc<dyn Event>>,
145    pub method: MethodId,
146    pub kind: EventKind,
147}
148
149impl EventListenerRequest {
150    pub fn new<T: IntoEventKind>(listener: UnboundedSender<Arc<dyn Event>>) -> Self {
151        Self {
152            listener,
153            method: T::method_id(),
154            kind: T::event_kind(),
155        }
156    }
157}
158
159impl fmt::Debug for EventListenerRequest {
160    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
161        f.debug_struct("EventListenerRequest")
162            .field("method", &self.method)
163            .field("kind", &self.kind)
164            .finish()
165    }
166}
167
168/// Represents a single event listener
169pub struct EventListener {
170    /// Unique id for this listener (used for immediate removal).
171    pub id: ListenerId,
172    /// the sender half of the event channel
173    listener: UnboundedSender<Arc<dyn Event>>,
174    /// currently queued events
175    queued_events: VecDeque<Arc<dyn Event>>,
176    /// For what kind of event this event is for
177    kind: EventKind,
178}
179
180impl EventListener {
181    /// queue in a new event
182    pub fn start_send(&mut self, event: Arc<dyn Event>) {
183        self.queued_events.push_back(event)
184    }
185
186    /// Drains all queued events, sending them synchronously via the unbounded channel.
187    /// Returns `Err` if the receiver has been dropped.
188    pub fn flush(
189        &mut self,
190    ) -> std::result::Result<(), mpsc::error::SendError<Arc<dyn Event>>> {
191        while let Some(event) = self.queued_events.pop_front() {
192            self.listener.send(event)?;
193        }
194        Ok(())
195    }
196}
197
198impl fmt::Debug for EventListener {
199    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
200        f.debug_struct("EventListener")
201            .field("id", &self.id)
202            .finish()
203    }
204}
205
206/// The receiver part of an event subscription
207pub struct EventStream<T: IntoEventKind> {
208    events: UnboundedReceiver<Arc<dyn Event>>,
209    _marker: PhantomData<T>,
210}
211
212impl<T: IntoEventKind> fmt::Debug for EventStream<T> {
213    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
214        f.debug_struct("EventStream").finish()
215    }
216}
217
218impl<T: IntoEventKind> EventStream<T> {
219    pub fn new(events: UnboundedReceiver<Arc<dyn Event>>) -> Self {
220        Self {
221            events,
222            _marker: PhantomData,
223        }
224    }
225}
226
227impl<T: IntoEventKind + Unpin> Stream for EventStream<T> {
228    type Item = Arc<T>;
229
230    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
231        let pin = self.get_mut();
232        match pin.events.poll_recv(cx) {
233            Poll::Ready(Some(event)) => {
234                if let Ok(e) = event.into_any_arc().downcast() {
235                    Poll::Ready(Some(e))
236                } else {
237                    // wrong type for this stream; keep polling
238                    cx.waker().wake_by_ref();
239                    Poll::Pending
240                }
241            }
242            Poll::Ready(None) => Poll::Ready(None),
243            Poll::Pending => Poll::Pending,
244        }
245    }
246}
247
#[cfg(test)]
mod tests {
    // Unit tests for event streams and the listener registry.
    use futures_util::StreamExt;

    use chromiumoxide_cdp::cdp::browser_protocol::animation::EventAnimationCanceled;
    use chromiumoxide_cdp::cdp::CustomEvent;
    use chromiumoxide_types::{MethodId, MethodType};

    use super::*;

    #[tokio::test]
    async fn event_stream() {
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        let mut stream = EventStream::<EventAnimationCanceled>::new(rx);

        let event = EventAnimationCanceled {
            id: "id".to_string(),
        };
        let msg: Arc<dyn Event> = Arc::new(event.clone());
        tx.send(msg).unwrap();
        let next = stream.next().await.unwrap();
        assert_eq!(&*next, &event);
    }

    #[tokio::test]
    async fn custom_event_stream() {
        use serde::Deserialize;

        #[derive(Debug, Clone, Eq, PartialEq, Deserialize)]
        struct MyCustomEvent {
            name: String,
        }

        impl MethodType for MyCustomEvent {
            fn method_id() -> MethodId {
                "Custom.Event".into()
            }
        }
        impl CustomEvent for MyCustomEvent {}

        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        let mut stream = EventStream::<MyCustomEvent>::new(rx);

        let event = MyCustomEvent {
            name: "my event".to_string(),
        };
        let msg: Arc<dyn Event> = Arc::new(event.clone());
        tx.send(msg).unwrap();
        let next = stream.next().await.unwrap();
        assert_eq!(&*next, &event);
    }

    #[tokio::test]
    async fn remove_listener_immediately_stops_delivery() {
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
        let mut listeners = EventListeners::default();

        let handle =
            listeners.add_listener(EventListenerRequest::new::<EventAnimationCanceled>(tx));
        assert!(listeners.remove_listener(&handle));

        listeners.start_send(EventAnimationCanceled {
            id: "nope".to_string(),
        });

        std::future::poll_fn(|cx| {
            listeners.poll(cx);
            Poll::Ready(())
        })
        .await;

        // The listener was removed, so nothing should have been sent
        assert!(rx.try_recv().is_err());
    }

    #[tokio::test]
    async fn poll_delivers_queued_events_and_prunes_closed_listeners() {
        let (live_tx, mut live_rx) = tokio::sync::mpsc::unbounded_channel();
        let (dead_tx, dead_rx) = tokio::sync::mpsc::unbounded_channel::<Arc<dyn Event>>();
        drop(dead_rx);

        let mut listeners = EventListeners::default();
        listeners.add_listener(EventListenerRequest::new::<EventAnimationCanceled>(live_tx));
        listeners.add_listener(EventListenerRequest::new::<EventAnimationCanceled>(dead_tx));

        listeners.start_send(EventAnimationCanceled {
            id: "delivered".to_string(),
        });

        std::future::poll_fn(|cx| {
            listeners.poll(cx);
            Poll::Ready(())
        })
        .await;

        // The live listener received exactly one event.
        assert!(live_rx.try_recv().is_ok());
        assert!(live_rx.try_recv().is_err());

        // The listener whose receiver was dropped is pruned; the live one remains.
        let remaining = listeners
            .listeners
            .get(&EventAnimationCanceled::method_id())
            .map(Vec::len);
        assert_eq!(remaining, Some(1));
    }
}