//! WebSocket connection handling for the Chrome DevTools Protocol
//! (chromiumoxide `conn.rs`).

1use std::collections::VecDeque;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::ready;
5
6use futures_util::stream::{FuturesOrdered, SplitSink};
7use futures_util::{SinkExt, Stream, StreamExt};
8use std::future::Future;
9use std::task::{Context, Poll};
10use tokio::sync::mpsc;
11use tokio_tungstenite::tungstenite::Message as WsMessage;
12use tokio_tungstenite::MaybeTlsStream;
13use tokio_tungstenite::{tungstenite::protocol::WebSocketConfig, WebSocketStream};
14
15use chromiumoxide_cdp::cdp::browser_protocol::target::SessionId;
16use chromiumoxide_types::{CallId, EventMessage, Message, MethodCall, MethodId};
17
18use crate::error::CdpError;
19use crate::error::Result;
20
21type ConnectStream = MaybeTlsStream<tokio::net::TcpStream>;
22
23/// Exchanges the messages with the websocket
24#[must_use = "streams do nothing unless polled"]
25#[derive(Debug)]
26pub struct Connection<T: EventMessage> {
27    /// Queue of commands to send.
28    pending_commands: VecDeque<MethodCall>,
29    /// The websocket of the chromium instance
30    ws: WebSocketStream<ConnectStream>,
31    /// The identifier for a specific command
32    next_id: usize,
33    /// Whether the write buffer has unsent data that needs flushing.
34    needs_flush: bool,
35    /// The phantom marker.
36    _marker: PhantomData<T>,
37}
38
39lazy_static::lazy_static! {
40    /// Nagle's algorithm disabled?
41    static ref DISABLE_NAGLE: bool = match std::env::var("DISABLE_NAGLE") {
42        Ok(disable_nagle) => disable_nagle == "true",
43        _ => true
44    };
45    /// Websocket config defaults
46    static ref WEBSOCKET_DEFAULTS: bool = match std::env::var("WEBSOCKET_DEFAULTS") {
47        Ok(d) => d == "true",
48        _ => false
49    };
50}
51
/// Default number of WebSocket connection retry attempts.
pub const DEFAULT_CONNECTION_RETRIES: u32 = 4;

/// Initial backoff delay between connection retries (in milliseconds).
/// Grows by 3x per failed attempt, capped at [`MAX_BACKOFF_MS`].
const INITIAL_BACKOFF_MS: u64 = 50;

/// Maximum backoff delay between connection retries (in milliseconds).
const MAX_BACKOFF_MS: u64 = 2_000;
60
61impl<T: EventMessage + Unpin> Connection<T> {
62    pub async fn connect(debug_ws_url: impl AsRef<str>) -> Result<Self> {
63        Self::connect_with_retries(debug_ws_url, DEFAULT_CONNECTION_RETRIES).await
64    }
65
66    pub async fn connect_with_retries(debug_ws_url: impl AsRef<str>, retries: u32) -> Result<Self> {
67        let mut config = WebSocketConfig::default();
68
69        // Cap the internal write buffer so a slow receiver cannot cause
70        // unbounded memory growth (default is usize::MAX).
71        config.max_write_buffer_size = 4 * 1024 * 1024;
72
73        if !*WEBSOCKET_DEFAULTS {
74            config.max_message_size = None;
75            config.max_frame_size = None;
76        }
77
78        let url = debug_ws_url.as_ref();
79        let use_uring = crate::uring_fs::is_enabled();
80        let mut last_err = None;
81
82        for attempt in 0..=retries {
83            let result = if use_uring {
84                Self::connect_uring(url, config).await
85            } else {
86                Self::connect_default(url, config).await
87            };
88
89            match result {
90                Ok(ws) => {
91                    return Ok(Self {
92                        pending_commands: Default::default(),
93                        ws,
94                        next_id: 0,
95                        needs_flush: false,
96                        _marker: Default::default(),
97                    });
98                }
99                Err(e) => {
100                    // Detect non-retriable errors early to avoid wasting time
101                    // on connections that will never succeed.
102                    let should_retry = match &e {
103                        // Connection refused — nothing is listening on this port.
104                        CdpError::Io(io_err)
105                            if io_err.kind() == std::io::ErrorKind::ConnectionRefused =>
106                        {
107                            false
108                        }
109                        // HTTP response to a WebSocket upgrade (e.g. wrong path
110                        // returns 404 / redirect) — retrying the same URL won't help.
111                        CdpError::Ws(tungstenite_err) => !matches!(
112                            tungstenite_err,
113                            tokio_tungstenite::tungstenite::Error::Http(_)
114                                | tokio_tungstenite::tungstenite::Error::HttpFormat(_)
115                        ),
116                        _ => true,
117                    };
118
119                    last_err = Some(e);
120
121                    if !should_retry {
122                        break;
123                    }
124
125                    if attempt < retries {
126                        let backoff_ms =
127                            (INITIAL_BACKOFF_MS * 3u64.saturating_pow(attempt)).min(MAX_BACKOFF_MS);
128                        tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
129                    }
130                }
131            }
132        }
133
134        Err(last_err.unwrap_or_else(|| CdpError::msg("connection failed")))
135    }
136
137    /// Default path: let tokio-tungstenite handle TCP connect + WS handshake.
138    async fn connect_default(
139        url: &str,
140        config: WebSocketConfig,
141    ) -> Result<WebSocketStream<ConnectStream>> {
142        let (ws, _) =
143            tokio_tungstenite::connect_async_with_config(url, Some(config), *DISABLE_NAGLE).await?;
144        Ok(ws)
145    }
146
147    /// io_uring path: pre-connect the TCP socket via io_uring, then do WS
148    /// handshake over the pre-connected stream.
149    async fn connect_uring(
150        url: &str,
151        config: WebSocketConfig,
152    ) -> Result<WebSocketStream<ConnectStream>> {
153        use tokio_tungstenite::tungstenite::client::IntoClientRequest;
154
155        let request = url.into_client_request()?;
156        let host = request
157            .uri()
158            .host()
159            .ok_or_else(|| CdpError::msg("no host in CDP WebSocket URL"))?;
160        let port = request.uri().port_u16().unwrap_or(9222);
161
162        // Resolve host → SocketAddr (CDP is always localhost, so this is fast).
163        let addr_str = format!("{}:{}", host, port);
164        let addr: std::net::SocketAddr = match addr_str.parse() {
165            Ok(a) => a,
166            Err(_) => {
167                // Hostname needs DNS — fall back to default path.
168                return Self::connect_default(url, config).await;
169            }
170        };
171
172        // TCP connect via io_uring.
173        let std_stream = crate::uring_fs::tcp_connect(addr)
174            .await
175            .map_err(CdpError::Io)?;
176
177        // Set non-blocking + Nagle.
178        std_stream.set_nonblocking(true).map_err(CdpError::Io)?;
179        if *DISABLE_NAGLE {
180            let _ = std_stream.set_nodelay(true);
181        }
182
183        // Wrap in tokio TcpStream.
184        let tokio_stream = tokio::net::TcpStream::from_std(std_stream).map_err(CdpError::Io)?;
185
186        // WebSocket handshake over the pre-connected stream.
187        let (ws, _) = tokio_tungstenite::client_async_with_config(
188            request,
189            MaybeTlsStream::Plain(tokio_stream),
190            Some(config),
191        )
192        .await?;
193
194        Ok(ws)
195    }
196}
197
impl<T: EventMessage> Connection<T> {
    /// Allocate the next command id, wrapping around on `usize` overflow.
    fn next_call_id(&mut self) -> CallId {
        let id = CallId::new(self.next_id);
        self.next_id = self.next_id.wrapping_add(1);
        id
    }

    /// Queue in the command to send over the socket and return the id for this
    /// command.
    ///
    /// The command is only buffered here; it is written to the socket the
    /// next time the connection is polled (see `start_send_next`).
    pub fn submit_command(
        &mut self,
        method: MethodId,
        session_id: Option<SessionId>,
        params: serde_json::Value,
    ) -> serde_json::Result<CallId> {
        let id = self.next_call_id();
        let call = MethodCall {
            id,
            method,
            session_id: session_id.map(Into::into),
            params,
        };
        self.pending_commands.push_back(call);
        Ok(id)
    }

    /// Buffer all queued commands into the WebSocket sink, then flush once.
    ///
    /// This batches multiple CDP commands into a single TCP write instead of
    /// flushing after every individual message.
    ///
    /// Returns `Ok(())` on `Poll::Pending` paths too: a pending flush is
    /// remembered in `needs_flush` and resumed on the next poll.
    fn start_send_next(&mut self, cx: &mut Context<'_>) -> Result<()> {
        // Complete any pending flush from a previous poll first.
        if self.needs_flush {
            match self.ws.poll_flush_unpin(cx) {
                Poll::Ready(Ok(())) => self.needs_flush = false,
                Poll::Ready(Err(e)) => return Err(e.into()),
                Poll::Pending => return Ok(()),
            }
        }

        // Buffer as many queued commands as the sink will accept.
        let mut sent_any = false;
        while !self.pending_commands.is_empty() {
            match self.ws.poll_ready_unpin(cx) {
                Poll::Ready(Ok(())) => {
                    let Some(cmd) = self.pending_commands.pop_front() else {
                        break;
                    };
                    tracing::trace!("Sending {:?}", cmd);
                    let msg = serde_json::to_string(&cmd)?;
                    self.ws.start_send_unpin(msg.into())?;
                    sent_any = true;
                }
                // NOTE(review): this arm also swallows `Ready(Err(_))` from
                // `poll_ready`; the error presumably resurfaces on the next
                // flush/send — confirm that is intentional.
                _ => break,
            }
        }

        // Flush the entire batch in one write.
        if sent_any {
            match self.ws.poll_flush_unpin(cx) {
                Poll::Ready(Ok(())) => {}
                Poll::Ready(Err(e)) => return Err(e.into()),
                Poll::Pending => self.needs_flush = true,
            }
        }

        Ok(())
    }
}
267
/// Capacity of the bounded channel feeding the background WS writer task.
/// Large enough that bursts of CDP commands never block the handler, small
/// enough to apply back-pressure before memory grows without bound.
const WS_CMD_CHANNEL_CAPACITY: usize = 2048;

/// Capacity of the bounded channel from the background WS reader task to
/// the Handler. Keeps decoded CDP messages buffered so the reader task
/// can keep reading the socket while the Handler processes a backlog;
/// applies TCP-level back-pressure on Chrome when the Handler is slow
/// (the reader awaits channel capacity and stops draining the socket).
const WS_READ_CHANNEL_CAPACITY: usize = 1024;

/// Maximum number of in-flight decodes the reader pipeline holds at
/// once. While any of these is still running on the blocking pool,
/// the reader can keep draining the socket and starting new decodes,
/// up to this cap. Applies per-connection; the resulting decoded
/// messages are emitted to the Handler in strict WS arrival order
/// via a `FuturesOrdered` queue — no behavior change versus a
/// serial loop, just concurrent execution of independent decodes.
const MAX_IN_FLIGHT_DECODES: usize = 32;

/// Payload size at/above which `decode_message` runs via
/// `tokio::task::spawn_blocking` instead of inline on the reader task.
///
/// `serde_json::from_slice` is CPU-bound with no `.await` points, so
/// a multi-MB payload can occupy one tokio worker thread for tens of
/// milliseconds. Offloading to the blocking thread pool above a
/// threshold keeps the reader task cooperatively yielding — critical
/// on single-threaded runtimes where the reader shares its worker
/// with the Handler, user tasks, and timers.
///
/// The threshold is chosen so that typical CDP traffic (events,
/// responses, small evaluates) stays on the inline fast path and
/// doesn't pay the ~10-30 µs `spawn_blocking` hand-off cost, while
/// screenshot payloads, wide network events, and huge console
/// payloads take the offloaded path.
const LARGE_FRAME_THRESHOLD: usize = 256 * 1024; // 256 KiB
305
/// Split parts returned by [`Connection::into_async`].
#[derive(Debug)]
pub struct AsyncConnection<T: EventMessage> {
    /// Receive half for decoded CDP messages. Backed by a bounded mpsc
    /// fed by a dedicated background reader task — decode runs on that
    /// task (or the blocking pool for oversized frames), never on the
    /// Handler task, so large CDP responses (multi-MB screenshots, huge
    /// event payloads) cannot stall the Handler's event loop.
    pub reader: WsReader<T>,
    /// Sender half for submitting outgoing CDP commands.
    pub cmd_tx: mpsc::Sender<MethodCall>,
    /// Handle to the background writer task.
    pub writer_handle: tokio::task::JoinHandle<Result<()>>,
    /// Handle to the background reader task (reads + decodes WS frames).
    pub reader_handle: tokio::task::JoinHandle<()>,
    /// Next command-call-id counter (continues numbering from where the
    /// `Connection` left off, so call ids stay unique across the split).
    pub next_id: usize,
}
324
impl<T: EventMessage + Unpin + Send + 'static> Connection<T> {
    /// Consume the connection and split into a background reader + writer
    /// pair, exposing the Handler-facing ends via `AsyncConnection`.
    ///
    /// Two `tokio::spawn`'d tasks are created:
    ///
    /// * `ws_write_loop` — batches outgoing commands and flushes them in
    ///   one write per wakeup.
    /// * `ws_read_loop`  — reads WS frames, decodes them to typed
    ///   `Message<T>`, and forwards them via a bounded mpsc to the
    ///   Handler. Ping/pong/malformed frames are skipped on this task
    ///   and never reach the Handler. Decode (SerDe CPU work) runs off
    ///   the Handler task — inline on the reader for small frames, via
    ///   `tokio::task::spawn_blocking` for frames at or above
    ///   [`LARGE_FRAME_THRESHOLD`] — so the Handler's poll loop never
    ///   stalls for tens of milliseconds on a 10 MB screenshot response.
    ///
    /// Aside from those oversized-frame decodes on the blocking pool,
    /// everything is cooperative `tokio::spawn` work — it scales with
    /// the tokio runtime's worker threads on multi-threaded runtimes,
    /// and interleaves cleanly with the Handler task on single-threaded
    /// runtimes.
    pub fn into_async(self) -> AsyncConnection<T> {
        let (ws_sink, ws_stream) = self.ws.split();
        let (cmd_tx, cmd_rx) = mpsc::channel(WS_CMD_CHANNEL_CAPACITY);
        let (msg_tx, msg_rx) = mpsc::channel::<Result<Box<Message<T>>>>(WS_READ_CHANNEL_CAPACITY);

        let writer_handle = tokio::spawn(ws_write_loop(ws_sink, cmd_rx));
        let reader_handle = tokio::spawn(ws_read_loop::<T, _>(ws_stream, msg_tx));

        let reader = WsReader {
            rx: msg_rx,
            _marker: PhantomData,
        };

        AsyncConnection {
            reader,
            cmd_tx,
            writer_handle,
            reader_handle,
            next_id: self.next_id,
        }
    }
}
368
/// An entry in the reader's decode pipeline.
///
/// Small frames have been decoded inline on the reader task and sit
/// in `Ready(Some(result))` waiting their turn to emit — zero
/// allocation beyond the `Option`. Large frames were offloaded to
/// `tokio::task::spawn_blocking`, so their entry is the
/// corresponding `JoinHandle`.
///
/// A single concrete enum means `FuturesOrdered<InFlightDecode<T>>`
/// can hold either kind without `Box<dyn Future>`, keeping the
/// pipeline cost proportional to the workload.
enum InFlightDecode<T: EventMessage + Send + 'static> {
    /// Small-frame fast path: already decoded inline. `take()`'d
    /// exactly once when `FuturesOrdered` first polls it to Ready.
    Ready(Option<Result<Box<Message<T>>>>),
    /// Large-frame path: decoding on the blocking thread pool.
    Blocking(tokio::task::JoinHandle<Result<Box<Message<T>>>>),
}
387
388impl<T: EventMessage + Send + 'static> Future for InFlightDecode<T> {
389    type Output = Result<Box<Message<T>>>;
390
391    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
392        // Safety: both variants are structurally pin-agnostic —
393        // `Option<Result<..>>` is `Unpin`, and `tokio::task::JoinHandle`
394        // is documented as `Unpin`. So we can project out a `&mut`
395        // without unsafe.
396        match self.get_mut() {
397            InFlightDecode::Ready(slot) => Poll::Ready(
398                slot.take()
399                    .expect("InFlightDecode::Ready polled after completion"),
400            ),
401            InFlightDecode::Blocking(handle) => match Pin::new(handle).poll(cx) {
402                Poll::Ready(Ok(res)) => Poll::Ready(res),
403                Poll::Ready(Err(join_err)) => Poll::Ready(Err(CdpError::msg(format!(
404                    "WS decode blocking task join error: {join_err}"
405                )))),
406                Poll::Pending => Poll::Pending,
407            },
408        }
409    }
410}
411
412/// Emit a single decoded-frame result to the Handler, logging parse
413/// errors. Returns `true` if the channel is still open, `false` if
414/// the Handler has dropped the receiver (caller should exit).
415async fn emit_decoded<T>(
416    tx: &mpsc::Sender<Result<Box<Message<T>>>>,
417    res: Result<Box<Message<T>>>,
418) -> bool
419where
420    T: EventMessage + Send + 'static,
421{
422    match res {
423        Ok(msg) => tx.send(Ok(msg)).await.is_ok(),
424        Err(err) => {
425            tracing::debug!(
426                target: "chromiumoxide::conn::raw_ws::parse_errors",
427                "Dropping malformed WS frame: {err}",
428            );
429            true
430        }
431    }
432}
433
/// Background task that reads frames from the WebSocket, decodes them to
/// typed CDP `Message<T>`, and forwards them to the Handler over a
/// bounded mpsc.
///
/// Runs on a `tokio::spawn`'d task. Small-to-medium frames are
/// decoded inline (fast path); payloads at or above
/// [`LARGE_FRAME_THRESHOLD`] are offloaded to `spawn_blocking` so
/// multi-MB deserialization doesn't monopolise a tokio worker
/// thread — especially important on single-threaded runtimes where
/// the reader, Handler, and user tasks share the same worker.
///
/// Flow per frame:
///
/// * `Text` / `Binary` → `decode_message`; decoded `Ok(msg)` is
///   sent to the Handler. Decode errors are logged and the frame is
///   dropped (same behavior as the legacy inline decode path).
/// * `Close` → loop exits cleanly, dropping `tx`. The Handler's
///   `next_message().await` returns `None` on the next call.
/// * `Ping` / `Pong` / unexpected frame types → skipped silently; they
///   never cross the channel to the Handler.
/// * Transport error → forwarded as `Err(CdpError::Ws(..))`, then the
///   loop exits (the WS half is considered dead after an error).
///
/// Back-pressure: the outbound `tx` is bounded. If the Handler is busy
/// and the channel fills, `tx.send(..).await` parks this task, which
/// stops draining the WS socket. TCP flow control then applies
/// back-pressure to Chrome instead of letting memory grow without bound.
async fn ws_read_loop<T, S>(mut stream: S, tx: mpsc::Sender<Result<Box<Message<T>>>>)
where
    T: EventMessage + Send + 'static,
    S: Stream<Item = std::result::Result<WsMessage, tokio_tungstenite::tungstenite::Error>>
        + Unpin,
{
    // Pipeline of decodes in strict arrival order. Small-frame decodes
    // are produced inline (zero allocation, borrowing the frame body);
    // large-frame decodes are offloaded to `spawn_blocking`. Both
    // variants share a single concrete `InFlightDecode<T>` so the
    // queue avoids `Box<dyn Future>` overhead.
    let mut in_flight: FuturesOrdered<InFlightDecode<T>> = FuturesOrdered::new();

    // Shutdown state. When the stream signals `Close`, transport
    // error, or end-of-stream, we stop reading new frames but keep
    // running the select loop so the emit arm can flush any still
    // in-flight decodes *interleaved with* whatever else the runtime
    // is doing. A pending transport error is surfaced to the Handler
    // only after the in-order flush completes.
    let mut stream_terminated = false;
    let mut pending_err: Option<CdpError> = None;

    loop {
        tokio::select! {
            // Bias: emit already-ready decodes before reading more
            // frames. Keeps the pipeline small in the steady state
            // while still allowing concurrency under burst, and —
            // critically during shutdown — drains the pipeline one
            // ready item at a time inside the select loop instead
            // of blocking in a dedicated drain helper.
            biased;

            // Emit the head of the pipeline as soon as it is ready.
            // `FuturesOrdered::next` preserves submit order, so
            // downstream delivery is byte-identical to a serial
            // loop's ordering guarantee. The `is_empty` guard keeps
            // this arm disabled when there is nothing queued, so the
            // `else` branch below can fire once both arms are idle.
            Some(res) = in_flight.next(), if !in_flight.is_empty() => {
                if !emit_decoded(&tx, res).await {
                    return;
                }
            }

            // Read the next frame if the pipeline has capacity and
            // the stream hasn't terminated. Disabled once the stream
            // signals end (Close / None / Err) so subsequent loop
            // iterations only do emit work.
            maybe_frame = stream.next(),
                if !stream_terminated && in_flight.len() < MAX_IN_FLIGHT_DECODES =>
            {
                match maybe_frame {
                    Some(Ok(WsMessage::Text(text))) => {
                        // Zero-copy enqueue. The small-frame fast
                        // path decodes inline *now* (borrowing
                        // `text`, keeping the `raw_text_for_logging`
                        // preview); the large-frame path moves the
                        // `Utf8Bytes` (`Send + 'static`) directly
                        // into `spawn_blocking` without an
                        // intermediate allocation.
                        if text.len() >= LARGE_FRAME_THRESHOLD {
                            in_flight.push_back(InFlightDecode::Blocking(
                                tokio::task::spawn_blocking(move || {
                                    decode_message::<T>(text.as_bytes(), None)
                                }),
                            ));
                        } else {
                            let res = decode_message::<T>(text.as_bytes(), Some(&text));
                            in_flight.push_back(InFlightDecode::Ready(Some(res)));
                        }
                    }
                    Some(Ok(WsMessage::Binary(buf))) => {
                        // Same shape as Text: move `Bytes`
                        // (`Send + 'static`) into `spawn_blocking`
                        // for large payloads, decode inline for
                        // small ones.
                        if buf.len() >= LARGE_FRAME_THRESHOLD {
                            in_flight.push_back(InFlightDecode::Blocking(
                                tokio::task::spawn_blocking(move || {
                                    decode_message::<T>(&buf, None)
                                }),
                            ));
                        } else {
                            let res = decode_message::<T>(&buf, None);
                            in_flight.push_back(InFlightDecode::Ready(Some(res)));
                        }
                    }
                    Some(Ok(WsMessage::Close(_))) => {
                        stream_terminated = true;
                    }
                    Some(Ok(WsMessage::Ping(_))) | Some(Ok(WsMessage::Pong(_))) => {}
                    Some(Ok(msg)) => {
                        tracing::debug!(
                            target: "chromiumoxide::conn::raw_ws::parse_errors",
                            "Unexpected WS message type: {:?}",
                            msg
                        );
                    }
                    Some(Err(err)) => {
                        // Defer the error until after the already
                        // in-flight decodes have emitted — preserves
                        // the ordering contract that callers see
                        // frames up to the failure point before the
                        // error itself.
                        stream_terminated = true;
                        pending_err = Some(CdpError::Ws(err));
                    }
                    None => {
                        // Stream ended (connection closed without a
                        // `Close` frame). No more input, but
                        // in_flight may still hold pending decodes.
                        stream_terminated = true;
                    }
                }
            }

            // Both arms disabled: `in_flight` is empty AND
            // `stream_terminated`. We have nothing more to do.
            else => {
                break;
            }
        }
    }

    // Surface a deferred transport error only after the in-order flush
    // above has completed; ignore send failure (Handler already gone).
    if let Some(err) = pending_err {
        let _ = tx.send(Err(err)).await;
    }
}
587
588/// Background task that batches and flushes outgoing CDP commands.
589async fn ws_write_loop(
590    mut sink: SplitSink<WebSocketStream<ConnectStream>, WsMessage>,
591    mut rx: mpsc::Receiver<MethodCall>,
592) -> Result<()> {
593    while let Some(call) = rx.recv().await {
594        let msg = crate::serde_json::to_string(&call)?;
595        sink.feed(WsMessage::Text(msg.into()))
596            .await
597            .map_err(CdpError::Ws)?;
598
599        // Batch: drain all buffered commands without waiting.
600        while let Ok(call) = rx.try_recv() {
601            let msg = crate::serde_json::to_string(&call)?;
602            sink.feed(WsMessage::Text(msg.into()))
603                .await
604                .map_err(CdpError::Ws)?;
605        }
606
607        // Flush the entire batch in one write.
608        sink.flush().await.map_err(CdpError::Ws)?;
609    }
610    Ok(())
611}
612
/// Handler-facing read half of the split WebSocket connection.
///
/// Decoded CDP messages are produced by a dedicated background task
/// (see [`ws_read_loop`]) and forwarded over a bounded mpsc. `WsReader`
/// itself is a thin `Receiver` wrapper — calling `next_message()` does
/// a single `rx.recv().await` with no per-message decoding work on the
/// caller's task. This keeps the Handler's poll loop free of CPU-bound
/// deserialize time, which matters for large (multi-MB) CDP responses
/// such as screenshots and wide-header network events.
#[derive(Debug)]
pub struct WsReader<T: EventMessage> {
    /// Channel of decoded messages fed by the background reader task.
    rx: mpsc::Receiver<Result<Box<Message<T>>>>,
    /// Ties the event type `T` to the reader without storing a `T`.
    _marker: PhantomData<T>,
}
627
impl<T: EventMessage + Unpin> WsReader<T> {
    /// Read the next CDP message from the WebSocket.
    ///
    /// Returns `None` when the background reader task has exited
    /// (connection closed or sender dropped). This call does only a
    /// channel `recv` — the actual WS read + JSON decode happens on
    /// the background `ws_read_loop` task.
    pub async fn next_message(&mut self) -> Option<Result<Box<Message<T>>>> {
        self.rx.recv().await
    }
}
639
impl<T: EventMessage + Unpin> Stream for Connection<T> {
    type Item = Result<Box<Message<T>>>;

    /// Drive the connection: flush queued outgoing commands, then yield
    /// the next decoded incoming message (or `None` when the socket
    /// closes, `Some(Err(..))` on transport/send failure).
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let pin = self.get_mut();

        // Send and flush outgoing messages
        if let Err(err) = pin.start_send_next(cx) {
            return Poll::Ready(Some(Err(err)));
        }

        // Read from the websocket, skipping non-data frames (pings,
        // pongs, malformed messages) without yielding back to the
        // executor.  This avoids a full round-trip per skipped frame.
        //
        // Cap consecutive skips so a flood of non-data frames (many
        // pings, malformed/unexpected types) cannot starve the
        // runtime — yield Pending after `MAX_SKIPS_PER_POLL` and
        // self-wake so we resume on the next tick.
        const MAX_SKIPS_PER_POLL: u32 = 16;
        let mut skips: u32 = 0;
        loop {
            // `ready!` propagates `Pending` when no frame is buffered.
            match ready!(pin.ws.poll_next_unpin(cx)) {
                Some(Ok(WsMessage::Text(text))) => {
                    match decode_message::<T>(text.as_bytes(), Some(&text)) {
                        Ok(msg) => return Poll::Ready(Some(Ok(msg))),
                        Err(err) => {
                            tracing::debug!(
                                target: "chromiumoxide::conn::raw_ws::parse_errors",
                                "Dropping malformed text WS frame: {err}",
                            );
                            skips += 1;
                        }
                    }
                }
                Some(Ok(WsMessage::Binary(buf))) => match decode_message::<T>(&buf, None) {
                    Ok(msg) => return Poll::Ready(Some(Ok(msg))),
                    Err(err) => {
                        tracing::debug!(
                            target: "chromiumoxide::conn::raw_ws::parse_errors",
                            "Dropping malformed binary WS frame: {err}",
                        );
                        skips += 1;
                    }
                },
                Some(Ok(WsMessage::Close(_))) => return Poll::Ready(None),
                Some(Ok(WsMessage::Ping(_))) | Some(Ok(WsMessage::Pong(_))) => {
                    skips += 1;
                }
                Some(Ok(msg)) => {
                    tracing::debug!(
                        target: "chromiumoxide::conn::raw_ws::parse_errors",
                        "Unexpected WS message type: {:?}",
                        msg
                    );
                    skips += 1;
                }
                Some(Err(err)) => return Poll::Ready(Some(Err(CdpError::Ws(err)))),
                None => return Poll::Ready(None),
            }

            // Self-wake before yielding so the runtime re-polls us soon.
            if skips >= MAX_SKIPS_PER_POLL {
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }
        }
    }
}
708
709/// Shared decode path for both text and binary WS frames.
710/// `raw_text_for_logging` is only provided for textual frames so we can log the original
711/// payload on parse failure if desired.
712#[cfg(not(feature = "serde_stacker"))]
713fn decode_message<T: EventMessage>(
714    bytes: &[u8],
715    raw_text_for_logging: Option<&str>,
716) -> Result<Box<Message<T>>> {
717    match serde_json::from_slice::<Box<Message<T>>>(bytes) {
718        Ok(msg) => {
719            tracing::trace!("Received {:?}", msg);
720            Ok(msg)
721        }
722        Err(err) => {
723            if let Some(txt) = raw_text_for_logging {
724                let preview = &txt[..txt.len().min(512)];
725                tracing::debug!(
726                    target: "chromiumoxide::conn::raw_ws::parse_errors",
727                    msg_len = txt.len(),
728                    "Skipping unrecognized WS message {err} preview={preview}",
729                );
730            } else {
731                tracing::debug!(
732                    target: "chromiumoxide::conn::raw_ws::parse_errors",
733                    "Skipping unrecognized binary WS message {err}",
734                );
735            }
736            Err(err.into())
737        }
738    }
739}
740
/// Shared decode path for both text and binary WS frames.
/// `raw_text_for_logging` is only provided for textual frames so we can log the original
/// payload on parse failure if desired.
///
/// This variant routes deserialization through `serde_stacker` with the
/// recursion limit disabled, so deeply nested CDP payloads grow the stack on
/// the heap instead of overflowing.
#[cfg(feature = "serde_stacker")]
fn decode_message<T: EventMessage>(
    bytes: &[u8],
    raw_text_for_logging: Option<&str>,
) -> Result<Box<Message<T>>> {
    use serde::Deserialize;
    let mut de = serde_json::Deserializer::from_slice(bytes);

    de.disable_recursion_limit();

    let de = serde_stacker::Deserializer::new(&mut de);

    match Box::<Message<T>>::deserialize(de) {
        Ok(msg) => {
            tracing::trace!("Received {:?}", msg);
            Ok(msg)
        }
        Err(err) => {
            if let Some(txt) = raw_text_for_logging {
                // Truncate the preview to at most 512 bytes, then back the cut
                // point up to the nearest UTF-8 char boundary: slicing a `&str`
                // at a byte index inside a multi-byte character panics.
                let mut end = txt.len().min(512);
                while !txt.is_char_boundary(end) {
                    end -= 1;
                }
                let preview = &txt[..end];
                tracing::debug!(
                    target: "chromiumoxide::conn::raw_ws::parse_errors",
                    msg_len = txt.len(),
                    "Skipping unrecognized WS message {err} preview={preview}",
                );
            } else {
                tracing::debug!(
                    target: "chromiumoxide::conn::raw_ws::parse_errors",
                    "Skipping unrecognized binary WS message {err}",
                );
            }
            Err(err.into())
        }
    }
}
779
#[cfg(test)]
mod ws_read_loop_tests {
    //! Unit tests for the `ws_read_loop` background reader task.
    //!
    //! These tests feed a synthetic `Stream<Item = Result<WsMessage, _>>`
    //! into `ws_read_loop` — no real WebSocket, no Chrome — and observe
    //! what comes out the other side of the mpsc channel.
    //!
    //! The properties under test are the ones that make the reader-task
    //! decoupling safe: FIFO ordering, no-deadlock on a bounded channel
    //! under back-pressure, silent drop of non-data frames, graceful
    //! transport-error propagation, and clean exit on `Close`.
    //!
    //! The typed events are `chromiumoxide_cdp::cdp::CdpEventMessage` —
    //! the same instantiation the real Handler uses — so these tests
    //! exercise the actual decode path (`serde_json::from_slice`), not
    //! a simplified fake.
    use super::*;
    use chromiumoxide_cdp::cdp::CdpEventMessage;
    use chromiumoxide_types::CallId;
    use futures_util::stream;
    use tokio::sync::mpsc;
    use tokio_tungstenite::tungstenite::Message as WsMessage;

    /// Build a CDP `Response` WS frame as text — the smallest valid CDP
    /// message. `id` tags the frame for ordering assertions.
    fn response_frame(id: u64) -> WsMessage {
        // `format!` already yields a `String`; no extra `.to_string()` needed.
        WsMessage::Text(format!(r#"{{"id":{id},"result":{{"ok":true}}}}"#).into())
    }

    /// Build a frame far larger than a typical socket chunk, to exercise
    /// the "large message" path that motivated this refactor. The blob
    /// field pushes serde_json through a big allocation even though the
    /// envelope is tiny.
    fn large_response_frame(id: u64, blob_bytes: usize) -> WsMessage {
        let blob = "x".repeat(blob_bytes);
        WsMessage::Text(format!(r#"{{"id":{id},"result":{{"blob":"{blob}"}}}}"#).into())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn forwards_messages_in_stream_order() {
        let frames = vec![
            Ok(response_frame(1)),
            Ok(response_frame(2)),
            Ok(response_frame(3)),
        ];
        let stream = stream::iter(frames);
        let (tx, mut rx) = mpsc::channel::<Result<Box<Message<CdpEventMessage>>>>(8);
        let task = tokio::spawn(ws_read_loop::<CdpEventMessage, _>(stream, tx));

        for expected in [1u64, 2, 3] {
            let msg = rx.recv().await.expect("msg").expect("decode ok");
            if let Message::Response(resp) = *msg {
                assert_eq!(resp.id, CallId::new(expected as usize));
            } else {
                panic!("expected Response");
            }
        }
        assert!(rx.recv().await.is_none(), "channel must close on EOF");
        task.await.expect("reader task join");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn pings_and_pongs_never_reach_the_handler() {
        let frames = vec![
            Ok(WsMessage::Ping(vec![1, 2, 3].into())),
            Ok(response_frame(7)),
            Ok(WsMessage::Pong(vec![].into())),
            Ok(response_frame(8)),
        ];
        let stream = stream::iter(frames);
        let (tx, mut rx) = mpsc::channel::<Result<Box<Message<CdpEventMessage>>>>(8);
        let task = tokio::spawn(ws_read_loop::<CdpEventMessage, _>(stream, tx));

        for expected in [7u64, 8] {
            let msg = rx.recv().await.expect("msg").expect("decode ok");
            // An `else` arm keeps this from passing vacuously if a Ping/Pong
            // (or anything else) ever leaks through as a non-Response.
            if let Message::Response(resp) = *msg {
                assert_eq!(resp.id, CallId::new(expected as usize));
            } else {
                panic!("expected Response, got a non-Response message");
            }
        }
        assert!(rx.recv().await.is_none());
        task.await.expect("reader task join");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn malformed_frames_do_not_block_subsequent_valid_frames() {
        let frames = vec![
            Ok(WsMessage::Text("{not valid json".to_string().into())),
            Ok(response_frame(42)),
        ];
        let stream = stream::iter(frames);
        let (tx, mut rx) = mpsc::channel::<Result<Box<Message<CdpEventMessage>>>>(8);
        let task = tokio::spawn(ws_read_loop::<CdpEventMessage, _>(stream, tx));

        let msg = rx.recv().await.expect("msg").expect("decode ok");
        if let Message::Response(resp) = *msg {
            assert_eq!(resp.id, CallId::new(42));
        } else {
            panic!("expected Response after malformed frame was skipped");
        }
        assert!(rx.recv().await.is_none());
        task.await.expect("reader task join");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn close_frame_terminates_the_reader() {
        let frames = vec![
            Ok(response_frame(1)),
            Ok(WsMessage::Close(None)),
            Ok(response_frame(2)), // unreachable after Close
        ];
        let stream = stream::iter(frames);
        let (tx, mut rx) = mpsc::channel::<Result<Box<Message<CdpEventMessage>>>>(8);
        let task = tokio::spawn(ws_read_loop::<CdpEventMessage, _>(stream, tx));

        let msg = rx.recv().await.expect("msg").expect("decode ok");
        if let Message::Response(resp) = *msg {
            assert_eq!(resp.id, CallId::new(1));
        } else {
            panic!("expected Response before Close");
        }
        assert!(
            rx.recv().await.is_none(),
            "reader must exit on Close; frames after Close must not appear"
        );
        task.await.expect("reader task join");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn transport_error_is_forwarded_once_then_reader_exits() {
        let frames = vec![
            Ok(response_frame(1)),
            Err(tokio_tungstenite::tungstenite::Error::ConnectionClosed),
            Ok(response_frame(2)),
        ];
        let stream = stream::iter(frames);
        let (tx, mut rx) = mpsc::channel::<Result<Box<Message<CdpEventMessage>>>>(8);
        let task = tokio::spawn(ws_read_loop::<CdpEventMessage, _>(stream, tx));

        let msg = rx.recv().await.expect("msg").expect("ok");
        assert!(matches!(*msg, Message::Response(_)));
        match rx.recv().await {
            Some(Err(CdpError::Ws(_))) => {}
            other => panic!("expected forwarded Ws error, got {other:?}"),
        }
        assert!(rx.recv().await.is_none());
        task.await.expect("reader task join");
    }

    /// Back-pressure property: with the smallest possible channel and
    /// many frames, the reader task awaits capacity after each send and
    /// never deadlocks. This is the core "no deadlock" proof for the
    /// new design — if the reader held anything across its `.await` that
    /// the consumer needed, the consumer's `recv().await` would block
    /// forever. Completion under a 5s watchdog proves it doesn't.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn bounded_channel_does_not_deadlock_under_backpressure() {
        const N: u64 = 512;
        let frames: Vec<_> = (1..=N).map(|id| Ok(response_frame(id))).collect();
        let stream = stream::iter(frames);

        let (tx, mut rx) = mpsc::channel::<Result<Box<Message<CdpEventMessage>>>>(1);
        let task = tokio::spawn(ws_read_loop::<CdpEventMessage, _>(stream, tx));

        let deadline = std::time::Duration::from_secs(5);
        let collected = tokio::time::timeout(deadline, async {
            let mut seen = 0u64;
            while let Some(frame) = rx.recv().await {
                let msg = frame.expect("decode ok");
                if let Message::Response(resp) = *msg {
                    seen += 1;
                    assert_eq!(
                        resp.id,
                        CallId::new(seen as usize),
                        "back-pressure must preserve FIFO order"
                    );
                }
                // Non-Response frames are tolerated here; the final count
                // assertion below catches any frame that failed to arrive.
            }
            seen
        })
        .await
        .expect("reader must make forward progress despite cap-1 back-pressure");

        assert_eq!(collected, N, "all frames must arrive");
        task.await.expect("reader task join");
    }

    /// Large message (>1 MB) is decoded correctly on the background
    /// task. This is the specific scenario the reader-task refactor
    /// was built for — we don't measure time here (benches cover that),
    /// we just prove the end-to-end path works without corruption or
    /// deadlock.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn large_message_decodes_without_corruption() {
        let big = 2 * 1024 * 1024; // 2 MB payload
        let frames = vec![
            Ok(large_response_frame(100, big)),
            Ok(response_frame(101)),
        ];
        let stream = stream::iter(frames);
        let (tx, mut rx) = mpsc::channel::<Result<Box<Message<CdpEventMessage>>>>(4);
        let task = tokio::spawn(ws_read_loop::<CdpEventMessage, _>(stream, tx));

        let first = rx.recv().await.expect("msg").expect("ok");
        if let Message::Response(resp) = *first {
            assert_eq!(resp.id, CallId::new(100));
        } else {
            panic!("expected Response for the large frame");
        }
        let second = rx.recv().await.expect("msg").expect("ok");
        if let Message::Response(resp) = *second {
            assert_eq!(resp.id, CallId::new(101));
        } else {
            panic!("expected Response for the trailing frame");
        }
        assert!(rx.recv().await.is_none());
        task.await.expect("reader task join");
    }

    /// FIFO ordering under the pipelined reader when large-frame
    /// decodes run in parallel via `spawn_blocking`.
    ///
    /// This test submits an interleaved sequence of large and small
    /// frames. Large frames take the `spawn_blocking` path (decode
    /// on the blocking pool, variable completion order); small
    /// frames take the inline path (decode immediately). The
    /// pipeline's `FuturesOrdered` queue must emit them to the
    /// Handler in strict arrival order regardless of which
    /// blocking-pool thread finishes first.
    ///
    /// If the ordering guarantee were ever broken — e.g. by
    /// accidentally swapping `FuturesOrdered` for `FuturesUnordered`
    /// — id sequence checks here would catch it immediately.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn pipelined_large_and_small_frames_keep_fifo_order() {
        let big = 2 * 1024 * 1024; // 2 MB payload — forces spawn_blocking
        let frames = vec![
            Ok(large_response_frame(1, big)),
            Ok(response_frame(2)),
            Ok(response_frame(3)),
            Ok(large_response_frame(4, big)),
            Ok(response_frame(5)),
            Ok(large_response_frame(6, big)),
            Ok(response_frame(7)),
            Ok(response_frame(8)),
        ];
        let expected: Vec<usize> = (1..=8).collect();

        let stream = stream::iter(frames);
        let (tx, mut rx) = mpsc::channel::<Result<Box<Message<CdpEventMessage>>>>(16);
        let task = tokio::spawn(ws_read_loop::<CdpEventMessage, _>(stream, tx));

        let deadline = std::time::Duration::from_secs(10);
        let observed = tokio::time::timeout(deadline, async {
            let mut ids = Vec::with_capacity(expected.len());
            while let Some(frame) = rx.recv().await {
                let msg = frame.expect("decode ok");
                if let Message::Response(resp) = *msg {
                    // Each arrival must carry the next sequential id; any
                    // missing frame is caught by the length check below.
                    ids.push(CallId::new(ids.len() + 1));
                    assert_eq!(
                        resp.id,
                        *ids.last().unwrap(),
                        "pipelined reader must emit frames in strict arrival order \
                         regardless of per-frame decode latency"
                    );
                }
            }
            ids
        })
        .await
        .expect("pipelined reader should make forward progress within 10s");

        assert_eq!(
            observed.len(),
            expected.len(),
            "all {} frames must reach the Handler",
            expected.len()
        );
        task.await.expect("reader task join");
    }
}