1use std::collections::{HashMap, VecDeque};
2use std::fmt;
3use std::marker::PhantomData;
4use std::pin::Pin;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::sync::Arc;
7use std::task::{Context, Poll};
8
9use futures_util::Stream;
10use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
11
12use chromiumoxide_cdp::cdp::{Event, EventKind, IntoEventKind};
13use chromiumoxide_types::MethodId;
14
/// Identifier assigned to each registered event listener.
pub type ListenerId = u64;

/// Process-wide source of listener ids; the first id handed out is 1.
static NEXT_LISTENER_ID: AtomicU64 = AtomicU64::new(1);
20
/// Identifies one registered listener so that it can later be removed.
///
/// Returned by `EventListeners::add_listener` and accepted by
/// `EventListeners::remove_listener`: `method` selects the listener bucket and `id`
/// selects the listener within it.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct EventListenerHandle {
    /// The event method the listener is subscribed to.
    pub method: MethodId,
    /// The listener's id, allocated from `NEXT_LISTENER_ID` at registration.
    pub id: ListenerId,
}
28
/// Registry of event listeners, grouped by the event method they subscribe to.
///
/// Events are first queued on each matching listener (`start_send` / `try_send_custom`)
/// and are only handed to the listeners' channels when `flush` or `poll` runs.
#[derive(Debug, Default)]
pub struct EventListeners {
    /// Listeners per method. `remove_listener` and the flush pruning drop a method's
    /// entry once its listener list becomes empty.
    listeners: HashMap<MethodId, Vec<EventListener>>,
}
35
36impl EventListeners {
37 pub fn add_listener(&mut self, req: EventListenerRequest) -> EventListenerHandle {
39 let EventListenerRequest {
40 listener,
41 method,
42 kind,
43 } = req;
44
45 let id = NEXT_LISTENER_ID.fetch_add(1, Ordering::Relaxed);
46
47 let subs = self.listeners.entry(method.clone()).or_default();
48 subs.push(EventListener {
49 id,
50 listener,
51 kind,
52 queued_events: Default::default(),
53 });
54
55 EventListenerHandle { method, id }
56 }
57
58 pub fn remove_listener(&mut self, handle: &EventListenerHandle) -> bool {
61 let mut removed = false;
62 let mut became_empty = false;
63
64 if let Some(subs) = self.listeners.get_mut(&handle.method) {
65 let before = subs.len();
66 subs.retain(|s| s.id != handle.id);
67 removed = subs.len() != before;
68 became_empty = subs.is_empty();
69 }
71
72 if became_empty {
73 self.listeners.remove(&handle.method);
74 }
75
76 removed
77 }
78 pub fn remove_all_for_method(&mut self, method: &MethodId) -> usize {
81 self.listeners.remove(method).map(|v| v.len()).unwrap_or(0)
82 }
83
84 pub fn start_send<T: Event>(&mut self, event: T) {
86 if let Some(subscriptions) = self.listeners.get_mut(&T::method_id()) {
87 let event: Arc<dyn Event> = Arc::new(event);
88 subscriptions
89 .iter_mut()
90 .for_each(|sub| sub.start_send(Arc::clone(&event)));
91 }
92 }
93
94 pub fn try_send_custom(
96 &mut self,
97 method: &str,
98 val: serde_json::Value,
99 ) -> serde_json::Result<()> {
100 if let Some(subscriptions) = self.listeners.get_mut(method) {
101 let mut event = None;
102
103 if let Some(json_to_arc_event) = subscriptions
104 .iter()
105 .filter_map(|sub| match &sub.kind {
106 EventKind::Custom(conv) => Some(conv),
107 _ => None,
108 })
109 .next()
110 {
111 event = Some(json_to_arc_event(val)?);
112 }
113
114 if let Some(event) = event {
115 subscriptions
116 .iter_mut()
117 .filter(|sub| sub.kind.is_custom())
118 .for_each(|sub| sub.start_send(Arc::clone(&event)));
119 }
120 }
121 Ok(())
122 }
123
124 pub fn poll(&mut self, cx: &mut Context<'_>) {
129 let _ = cx;
130 let mut any_disconnected = false;
131
132 for subscriptions in self.listeners.values_mut() {
133 subscriptions.retain_mut(|sub| match sub.flush() {
134 Ok(()) => true,
135 Err(_) => {
136 any_disconnected = true;
137 false
138 }
139 });
140 }
141
142 if any_disconnected {
143 self.listeners.retain(|_, v| !v.is_empty());
144 }
145 }
146
147 pub fn flush(&mut self) {
152 let mut any_disconnected = false;
153
154 for subscriptions in self.listeners.values_mut() {
155 subscriptions.retain_mut(|sub| match sub.flush() {
156 Ok(()) => true,
157 Err(_) => {
158 any_disconnected = true;
159 false
160 }
161 });
162 }
163
164 if any_disconnected {
165 self.listeners.retain(|_, v| !v.is_empty());
166 }
167 }
168}
169
/// A request to subscribe to one event method, consumed by `EventListeners::add_listener`.
pub struct EventListenerRequest {
    /// Channel on which the subscribed events are delivered.
    listener: UnboundedSender<Arc<dyn Event>>,
    /// Method the listener subscribes to.
    pub method: MethodId,
    /// The event kind. For `EventKind::Custom` it carries the converter that
    /// `EventListeners::try_send_custom` uses to decode raw JSON.
    pub kind: EventKind,
}
175
176impl EventListenerRequest {
177 pub fn new<T: IntoEventKind>(listener: UnboundedSender<Arc<dyn Event>>) -> Self {
178 Self {
179 listener,
180 method: T::method_id(),
181 kind: T::event_kind(),
182 }
183 }
184}
185
186impl fmt::Debug for EventListenerRequest {
187 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
188 f.debug_struct("EventListenerRequest")
189 .field("method", &self.method)
190 .field("kind", &self.kind)
191 .finish()
192 }
193}
194
/// A single registered listener together with its queue of not-yet-delivered events.
pub struct EventListener {
    /// Unique id; matches the `id` of the `EventListenerHandle` returned at registration.
    pub id: ListenerId,
    /// Channel the queued events are sent to when the listener is flushed.
    listener: UnboundedSender<Arc<dyn Event>>,
    /// Events queued by `start_send` that are still waiting for delivery.
    queued_events: VecDeque<Arc<dyn Event>>,
    /// The listener's event kind; custom listeners are the targets of `try_send_custom`.
    kind: EventKind,
}
214
215impl EventListener {
216 pub fn start_send(&mut self, event: Arc<dyn Event>) {
218 self.queued_events.push_back(event)
219 }
220
221 pub fn flush(&mut self) -> std::result::Result<(), mpsc::error::SendError<Arc<dyn Event>>> {
224 while let Some(event) = self.queued_events.pop_front() {
225 self.listener.send(event)?;
226 }
227 Ok(())
228 }
229}
230
231impl fmt::Debug for EventListener {
232 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
233 f.debug_struct("EventListener")
234 .field("id", &self.id)
235 .finish()
236 }
237}
238
/// A stream of events of type `T`, backed by a channel of type-erased events.
pub struct EventStream<T: IntoEventKind> {
    /// Receiver of type-erased events; items that do not downcast to `T` are skipped
    /// when the stream is polled.
    events: UnboundedReceiver<Arc<dyn Event>>,
    /// Ties the stream to `T` without storing a value of that type.
    _marker: PhantomData<T>,
}
244
245impl<T: IntoEventKind> fmt::Debug for EventStream<T> {
246 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
247 f.debug_struct("EventStream").finish()
248 }
249}
250
251impl<T: IntoEventKind> EventStream<T> {
252 pub fn new(events: UnboundedReceiver<Arc<dyn Event>>) -> Self {
253 Self {
254 events,
255 _marker: PhantomData,
256 }
257 }
258}
259
/// Maximum number of events that fail to downcast to the stream's type that
/// `EventStream::poll_next` discards in a single call.
///
/// Once this budget is spent, the stream wakes its own task and returns `Poll::Pending`.
/// A long run of other-typed events therefore hands control back to the executor
/// periodically instead of being drained in one uninterrupted loop.
const MAX_WRONG_TYPE_PER_POLL: usize = 32;
267
268impl<T: IntoEventKind + Unpin> Stream for EventStream<T> {
269 type Item = Arc<T>;
270
271 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
272 let pin = self.get_mut();
273 for _ in 0..MAX_WRONG_TYPE_PER_POLL {
274 match pin.events.poll_recv(cx) {
275 Poll::Ready(Some(event)) => {
276 if let Ok(e) = event.into_any_arc().downcast() {
277 return Poll::Ready(Some(e));
278 }
279 continue;
281 }
282 Poll::Ready(None) => return Poll::Ready(None),
283 Poll::Pending => return Poll::Pending,
284 }
285 }
286 cx.waker().wake_by_ref();
289 Poll::Pending
290 }
291}
292
#[cfg(test)]
mod tests {
    use futures_util::StreamExt;

    use chromiumoxide_cdp::cdp::browser_protocol::animation::EventAnimationCanceled;
    use chromiumoxide_cdp::cdp::CustomEvent;
    use chromiumoxide_types::{MethodId, MethodType};

    use super::*;

    /// A built-in event sent through the channel comes back out of the stream,
    /// downcast to its concrete type.
    #[tokio::test]
    async fn event_stream() {
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        let mut stream = EventStream::<EventAnimationCanceled>::new(rx);

        let event = EventAnimationCanceled {
            id: "id".to_string(),
        };
        let msg: Arc<dyn Event> = Arc::new(event.clone());
        tx.send(msg).unwrap();
        let next = stream.next().await.unwrap();
        assert_eq!(&*next, &event);
    }

    /// Same round trip as `event_stream`, but for a user-defined `CustomEvent`.
    #[tokio::test]
    async fn custom_event_stream() {
        use serde::Deserialize;

        #[derive(Debug, Clone, Eq, PartialEq, Deserialize)]
        struct MyCustomEvent {
            name: String,
        }

        impl MethodType for MyCustomEvent {
            fn method_id() -> MethodId {
                "Custom.Event".into()
            }
        }
        impl CustomEvent for MyCustomEvent {}

        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        let mut stream = EventStream::<MyCustomEvent>::new(rx);

        let event = MyCustomEvent {
            name: "my event".to_string(),
        };
        let msg: Arc<dyn Event> = Arc::new(event.clone());
        tx.send(msg).unwrap();
        let next = stream.next().await.unwrap();
        assert_eq!(&*next, &event);
    }

    /// After `remove_listener`, an event queued and flushed via `poll` must not reach
    /// the removed listener's channel.
    #[tokio::test]
    async fn remove_listener_immediately_stops_delivery() {
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
        let mut listeners = EventListeners::default();

        let handle =
            listeners.add_listener(EventListenerRequest::new::<EventAnimationCanceled>(tx));
        assert!(listeners.remove_listener(&handle));

        listeners.start_send(EventAnimationCanceled {
            id: "nope".to_string(),
        });

        // Drive `EventListeners::poll` once with a real task context.
        std::future::poll_fn(|cx| {
            listeners.poll(cx);
            Poll::Ready(())
        })
        .await;

        assert!(rx.try_recv().is_err());
    }

    use serde::Deserialize;

    // Two distinct custom event types that share one method id. The poll-budget tests
    // below send `WrongA` values into an `EventStream<RightB>`, where they fail to
    // downcast and are skipped.
    #[derive(Debug, Clone, Eq, PartialEq, Deserialize)]
    struct WrongA {
        a: i32,
    }
    impl MethodType for WrongA {
        fn method_id() -> MethodId {
            "Custom.PollBudget".into()
        }
    }
    impl CustomEvent for WrongA {}

    #[derive(Debug, Clone, Eq, PartialEq, Deserialize)]
    struct RightB {
        b: i32,
    }
    impl MethodType for RightB {
        fn method_id() -> MethodId {
            "Custom.PollBudget".into()
        }
    }
    impl CustomEvent for RightB {}

    /// A flood of wrong-type events (ten times the per-poll budget) must not stop the
    /// stream from eventually yielding the right-type event queued behind it.
    #[tokio::test]
    async fn poll_next_drains_wrong_type_flood() {
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        let mut stream = EventStream::<RightB>::new(rx);

        let flood = MAX_WRONG_TYPE_PER_POLL * 10;
        for i in 0..flood {
            let msg: Arc<dyn Event> = Arc::new(WrongA { a: i as i32 });
            tx.send(msg).unwrap();
        }
        let target = RightB { b: 7 };
        let target_msg: Arc<dyn Event> = Arc::new(target.clone());
        tx.send(target_msg).unwrap();

        let got = tokio::time::timeout(std::time::Duration::from_secs(5), stream.next())
            .await
            .expect("stream must not hang under wrong-type flood")
            .expect("stream should yield the right-type event");
        assert_eq!(&*got, &target);
    }

    /// A single manual poll over more wrong-type events than the budget must return
    /// `Pending` and leave the excess events in the channel.
    #[tokio::test]
    async fn poll_next_returns_pending_after_budget() {
        use std::pin::Pin;
        use std::task::Poll;

        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        let mut stream = EventStream::<RightB>::new(rx);

        let queued = MAX_WRONG_TYPE_PER_POLL + 5;
        for i in 0..queued {
            let msg: Arc<dyn Event> = Arc::new(WrongA { a: i as i32 });
            tx.send(msg).unwrap();
        }

        let waker = futures_util::task::noop_waker();
        let mut cx = Context::from_waker(&waker);
        let res = Pin::new(&mut stream).poll_next(&mut cx);
        assert!(
            matches!(res, Poll::Pending),
            "first poll must yield once the per-poll budget is consumed"
        );

        // `>=` rather than `==`: the receiver could stop early (e.g. on a runtime
        // cooperative budget) and leave more than the minimum behind.
        let mut remaining = 0usize;
        while stream.events.try_recv().is_ok() {
            remaining += 1;
        }
        assert!(
            remaining >= queued - MAX_WRONG_TYPE_PER_POLL,
            "expected at least {} events left after budget poll, found {}",
            queued - MAX_WRONG_TYPE_PER_POLL,
            remaining
        );
    }

    /// The self-wake issued after the budget is spent must let an awaiting consumer
    /// continue and receive the right-type event behind the wrong-type ones.
    #[tokio::test]
    async fn poll_next_resumes_after_budget_yield() {
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        let mut stream = EventStream::<RightB>::new(rx);

        for i in 0..(MAX_WRONG_TYPE_PER_POLL + 5) {
            let msg: Arc<dyn Event> = Arc::new(WrongA { a: i as i32 });
            tx.send(msg).unwrap();
        }
        let target = RightB { b: 99 };
        let target_msg: Arc<dyn Event> = Arc::new(target.clone());
        tx.send(target_msg).unwrap();

        let got = tokio::time::timeout(std::time::Duration::from_secs(5), stream.next())
            .await
            .expect("re-arm must wake the stream after budget yield")
            .expect("right-type event should be delivered");
        assert_eq!(&*got, &target);
    }
}