chromiumoxide/
listeners.rs1use std::collections::{HashMap, VecDeque};
2use std::fmt;
3use std::marker::PhantomData;
4use std::pin::Pin;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::sync::Arc;
7use std::task::{Context, Poll};
8
9use futures_util::Stream;
10use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
11
12use chromiumoxide_cdp::cdp::{Event, EventKind, IntoEventKind};
13use chromiumoxide_types::MethodId;
14
/// Numeric identifier assigned to each registered event listener.
pub type ListenerId = u64;

/// First id handed out by [`NEXT_LISTENER_ID`]; `0` is never assigned.
const FIRST_LISTENER_ID: ListenerId = 1;

/// Process-wide counter from which listener ids are drawn.
static NEXT_LISTENER_ID: AtomicU64 = AtomicU64::new(FIRST_LISTENER_ID);
20
21#[derive(Debug, Clone, PartialEq, Eq, Hash)]
24pub struct EventListenerHandle {
25 pub method: MethodId,
26 pub id: ListenerId,
27}
28
29#[derive(Debug, Default)]
31pub struct EventListeners {
32 listeners: HashMap<MethodId, Vec<EventListener>>,
34}
35
36impl EventListeners {
37 pub fn add_listener(&mut self, req: EventListenerRequest) -> EventListenerHandle {
39 let EventListenerRequest {
40 listener,
41 method,
42 kind,
43 } = req;
44
45 let id = NEXT_LISTENER_ID.fetch_add(1, Ordering::Relaxed);
46
47 let subs = self.listeners.entry(method.clone()).or_default();
48 subs.push(EventListener {
49 id,
50 listener,
51 kind,
52 queued_events: Default::default(),
53 });
54
55 EventListenerHandle { method, id }
56 }
57
58 pub fn remove_listener(&mut self, handle: &EventListenerHandle) -> bool {
61 let mut removed = false;
62 let mut became_empty = false;
63
64 if let Some(subs) = self.listeners.get_mut(&handle.method) {
65 let before = subs.len();
66 subs.retain(|s| s.id != handle.id);
67 removed = subs.len() != before;
68 became_empty = subs.is_empty();
69 }
71
72 if became_empty {
73 self.listeners.remove(&handle.method);
74 }
75
76 removed
77 }
78 pub fn remove_all_for_method(&mut self, method: &MethodId) -> usize {
81 self.listeners.remove(method).map(|v| v.len()).unwrap_or(0)
82 }
83
84 pub fn start_send<T: Event>(&mut self, event: T) {
86 if let Some(subscriptions) = self.listeners.get_mut(&T::method_id()) {
87 let event: Arc<dyn Event> = Arc::new(event);
88 subscriptions
89 .iter_mut()
90 .for_each(|sub| sub.start_send(Arc::clone(&event)));
91 }
92 }
93
94 pub fn try_send_custom(
96 &mut self,
97 method: &str,
98 val: serde_json::Value,
99 ) -> serde_json::Result<()> {
100 if let Some(subscriptions) = self.listeners.get_mut(method) {
101 let mut event = None;
102
103 if let Some(json_to_arc_event) = subscriptions
104 .iter()
105 .filter_map(|sub| match &sub.kind {
106 EventKind::Custom(conv) => Some(conv),
107 _ => None,
108 })
109 .next()
110 {
111 event = Some(json_to_arc_event(val)?);
112 }
113
114 if let Some(event) = event {
115 subscriptions
116 .iter_mut()
117 .filter(|sub| sub.kind.is_custom())
118 .for_each(|sub| sub.start_send(Arc::clone(&event)));
119 }
120 }
121 Ok(())
122 }
123
124 pub fn poll(&mut self, cx: &mut Context<'_>) {
126 let _ = cx;
127 for subscriptions in self.listeners.values_mut() {
128 for n in (0..subscriptions.len()).rev() {
129 let mut sub = subscriptions.swap_remove(n);
130 match sub.flush() {
131 Ok(()) => subscriptions.push(sub),
132 Err(_) => {
133 }
135 }
136 }
137 }
138
139 self.listeners.retain(|_, v| !v.is_empty());
140 }
141}
142
143pub struct EventListenerRequest {
144 listener: UnboundedSender<Arc<dyn Event>>,
145 pub method: MethodId,
146 pub kind: EventKind,
147}
148
149impl EventListenerRequest {
150 pub fn new<T: IntoEventKind>(listener: UnboundedSender<Arc<dyn Event>>) -> Self {
151 Self {
152 listener,
153 method: T::method_id(),
154 kind: T::event_kind(),
155 }
156 }
157}
158
159impl fmt::Debug for EventListenerRequest {
160 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
161 f.debug_struct("EventListenerRequest")
162 .field("method", &self.method)
163 .field("kind", &self.kind)
164 .finish()
165 }
166}
167
168pub struct EventListener {
170 pub id: ListenerId,
172 listener: UnboundedSender<Arc<dyn Event>>,
174 queued_events: VecDeque<Arc<dyn Event>>,
176 kind: EventKind,
178}
179
180impl EventListener {
181 pub fn start_send(&mut self, event: Arc<dyn Event>) {
183 self.queued_events.push_back(event)
184 }
185
186 pub fn flush(
189 &mut self,
190 ) -> std::result::Result<(), mpsc::error::SendError<Arc<dyn Event>>> {
191 while let Some(event) = self.queued_events.pop_front() {
192 self.listener.send(event)?;
193 }
194 Ok(())
195 }
196}
197
198impl fmt::Debug for EventListener {
199 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
200 f.debug_struct("EventListener")
201 .field("id", &self.id)
202 .finish()
203 }
204}
205
206pub struct EventStream<T: IntoEventKind> {
208 events: UnboundedReceiver<Arc<dyn Event>>,
209 _marker: PhantomData<T>,
210}
211
212impl<T: IntoEventKind> fmt::Debug for EventStream<T> {
213 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
214 f.debug_struct("EventStream").finish()
215 }
216}
217
218impl<T: IntoEventKind> EventStream<T> {
219 pub fn new(events: UnboundedReceiver<Arc<dyn Event>>) -> Self {
220 Self {
221 events,
222 _marker: PhantomData,
223 }
224 }
225}
226
227impl<T: IntoEventKind + Unpin> Stream for EventStream<T> {
228 type Item = Arc<T>;
229
230 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
231 let pin = self.get_mut();
232 match pin.events.poll_recv(cx) {
233 Poll::Ready(Some(event)) => {
234 if let Ok(e) = event.into_any_arc().downcast() {
235 Poll::Ready(Some(e))
236 } else {
237 cx.waker().wake_by_ref();
239 Poll::Pending
240 }
241 }
242 Poll::Ready(None) => Poll::Ready(None),
243 Poll::Pending => Poll::Pending,
244 }
245 }
246}
247
#[cfg(test)]
mod tests {
    use futures_util::StreamExt;

    use chromiumoxide_cdp::cdp::browser_protocol::animation::EventAnimationCanceled;
    use chromiumoxide_cdp::cdp::CustomEvent;
    use chromiumoxide_types::{MethodId, MethodType};

    use super::*;

    /// A CDP event pushed into the channel comes out of the stream with its
    /// concrete type.
    #[tokio::test]
    async fn event_stream() {
        let (sender, receiver) = tokio::sync::mpsc::unbounded_channel();
        let mut stream = EventStream::<EventAnimationCanceled>::new(receiver);

        let expected = EventAnimationCanceled {
            id: "id".to_string(),
        };
        sender
            .send(Arc::new(expected.clone()) as Arc<dyn Event>)
            .unwrap();

        let received = stream.next().await.unwrap();
        assert_eq!(&*received, &expected);
    }

    /// A user-defined custom event round-trips through the stream the same way.
    #[tokio::test]
    async fn custom_event_stream() {
        use serde::Deserialize;

        #[derive(Debug, Clone, Eq, PartialEq, Deserialize)]
        struct MyCustomEvent {
            name: String,
        }

        impl MethodType for MyCustomEvent {
            fn method_id() -> MethodId {
                "Custom.Event".into()
            }
        }
        impl CustomEvent for MyCustomEvent {}

        let (sender, receiver) = tokio::sync::mpsc::unbounded_channel();
        let mut stream = EventStream::<MyCustomEvent>::new(receiver);

        let expected = MyCustomEvent {
            name: "my event".to_string(),
        };
        sender
            .send(Arc::new(expected.clone()) as Arc<dyn Event>)
            .unwrap();

        let received = stream.next().await.unwrap();
        assert_eq!(&*received, &expected);
    }

    /// Once a listener is removed, later events never reach its channel.
    #[tokio::test]
    async fn remove_listener_immediately_stops_delivery() {
        let (sender, mut receiver) = tokio::sync::mpsc::unbounded_channel();
        let mut listeners = EventListeners::default();

        let request = EventListenerRequest::new::<EventAnimationCanceled>(sender);
        let handle = listeners.add_listener(request);
        assert!(listeners.remove_listener(&handle));

        listeners.start_send(EventAnimationCanceled {
            id: "nope".to_string(),
        });

        // Run one flush pass with a real task context.
        std::future::poll_fn(|cx| {
            listeners.poll(cx);
            Poll::Ready(())
        })
        .await;

        assert!(receiver.try_recv().is_err());
    }
}