chromiumoxide/conn.rs
1use std::collections::VecDeque;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::ready;
5
6use futures_util::stream::{FuturesOrdered, SplitSink};
7use futures_util::{SinkExt, Stream, StreamExt};
8use std::future::Future;
9use std::task::{Context, Poll};
10use tokio::sync::mpsc;
11use tokio_tungstenite::tungstenite::Message as WsMessage;
12use tokio_tungstenite::MaybeTlsStream;
13use tokio_tungstenite::{tungstenite::protocol::WebSocketConfig, WebSocketStream};
14
15use chromiumoxide_cdp::cdp::browser_protocol::target::SessionId;
16use chromiumoxide_types::{CallId, EventMessage, Message, MethodCall, MethodId};
17
18use crate::error::CdpError;
19use crate::error::Result;
20
21type ConnectStream = MaybeTlsStream<tokio::net::TcpStream>;
22
23/// Exchanges the messages with the websocket
24#[must_use = "streams do nothing unless polled"]
25#[derive(Debug)]
26pub struct Connection<T: EventMessage> {
27 /// Queue of commands to send.
28 pending_commands: VecDeque<MethodCall>,
29 /// The websocket of the chromium instance
30 ws: WebSocketStream<ConnectStream>,
31 /// The identifier for a specific command
32 next_id: usize,
33 /// Whether the write buffer has unsent data that needs flushing.
34 needs_flush: bool,
35 /// The phantom marker.
36 _marker: PhantomData<T>,
37}
38
39lazy_static::lazy_static! {
40 /// Nagle's algorithm disabled?
41 static ref DISABLE_NAGLE: bool = match std::env::var("DISABLE_NAGLE") {
42 Ok(disable_nagle) => disable_nagle == "true",
43 _ => true
44 };
45 /// Websocket config defaults
46 static ref WEBSOCKET_DEFAULTS: bool = match std::env::var("WEBSOCKET_DEFAULTS") {
47 Ok(d) => d == "true",
48 _ => false
49 };
50}
51
/// How many times [`Connection::connect`] retries a failed WebSocket
/// connection, i.e. up to `DEFAULT_CONNECTION_RETRIES + 1` attempts in total.
pub const DEFAULT_CONNECTION_RETRIES: u32 = 4;

/// Delay before the first retry, in milliseconds. Every further retry waits
/// three times longer than the previous one, up to [`MAX_BACKOFF_MS`].
const INITIAL_BACKOFF_MS: u64 = 50;

/// Ceiling for the delay between two connection attempts, in milliseconds.
pub(crate) const MAX_BACKOFF_MS: u64 = 2000;
60
61impl<T: EventMessage + Unpin> Connection<T> {
62 pub async fn connect(debug_ws_url: impl AsRef<str>) -> Result<Self> {
63 Self::connect_with_retries(debug_ws_url, DEFAULT_CONNECTION_RETRIES).await
64 }
65
66 pub async fn connect_with_retries(debug_ws_url: impl AsRef<str>, retries: u32) -> Result<Self> {
67 let mut config = WebSocketConfig::default();
68
69 // Cap the internal write buffer so a slow receiver cannot cause
70 // unbounded memory growth (default is usize::MAX).
71 config.max_write_buffer_size = 4 * 1024 * 1024;
72
73 if !*WEBSOCKET_DEFAULTS {
74 config.max_message_size = None;
75 config.max_frame_size = None;
76 }
77
78 let url = debug_ws_url.as_ref();
79 let use_uring = crate::uring_fs::is_enabled();
80 let mut last_err = None;
81
82 for attempt in 0..=retries {
83 let result = if use_uring {
84 Self::connect_uring(url, config).await
85 } else {
86 Self::connect_default(url, config).await
87 };
88
89 match result {
90 Ok(ws) => {
91 return Ok(Self {
92 pending_commands: Default::default(),
93 ws,
94 next_id: 0,
95 needs_flush: false,
96 _marker: Default::default(),
97 });
98 }
99 Err(e) => {
100 // Detect non-retriable errors early to avoid wasting time
101 // on connections that will never succeed.
102 let should_retry = match &e {
103 // Connection refused — nothing is listening on this port.
104 CdpError::Io(io_err)
105 if io_err.kind() == std::io::ErrorKind::ConnectionRefused =>
106 {
107 false
108 }
109 // HTTP response to a WebSocket upgrade (e.g. wrong path
110 // returns 404 / redirect) — retrying the same URL won't help.
111 CdpError::Ws(tungstenite_err) => !matches!(
112 tungstenite_err,
113 tokio_tungstenite::tungstenite::Error::Http(_)
114 | tokio_tungstenite::tungstenite::Error::HttpFormat(_)
115 ),
116 _ => true,
117 };
118
119 last_err = Some(e);
120
121 if !should_retry {
122 break;
123 }
124
125 if attempt < retries {
126 let backoff_ms =
127 (INITIAL_BACKOFF_MS * 3u64.saturating_pow(attempt)).min(MAX_BACKOFF_MS);
128 tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
129 }
130 }
131 }
132 }
133
134 Err(last_err.unwrap_or_else(|| CdpError::msg("connection failed")))
135 }
136
137 /// Default path: let tokio-tungstenite handle TCP connect + WS handshake.
138 async fn connect_default(
139 url: &str,
140 config: WebSocketConfig,
141 ) -> Result<WebSocketStream<ConnectStream>> {
142 let (ws, _) =
143 tokio_tungstenite::connect_async_with_config(url, Some(config), *DISABLE_NAGLE).await?;
144 Ok(ws)
145 }
146
147 /// io_uring path: pre-connect the TCP socket via io_uring, then do WS
148 /// handshake over the pre-connected stream.
149 async fn connect_uring(
150 url: &str,
151 config: WebSocketConfig,
152 ) -> Result<WebSocketStream<ConnectStream>> {
153 use tokio_tungstenite::tungstenite::client::IntoClientRequest;
154
155 let request = url.into_client_request()?;
156 let host = request
157 .uri()
158 .host()
159 .ok_or_else(|| CdpError::msg("no host in CDP WebSocket URL"))?;
160 let port = request.uri().port_u16().unwrap_or(9222);
161
162 // Resolve host → SocketAddr (CDP is always localhost, so this is fast).
163 let addr_str = format!("{}:{}", host, port);
164 let addr: std::net::SocketAddr = match addr_str.parse() {
165 Ok(a) => a,
166 Err(_) => {
167 // Hostname needs DNS — fall back to default path.
168 return Self::connect_default(url, config).await;
169 }
170 };
171
172 // TCP connect via io_uring.
173 let std_stream = crate::uring_fs::tcp_connect(addr)
174 .await
175 .map_err(CdpError::Io)?;
176
177 // Set non-blocking + Nagle.
178 std_stream.set_nonblocking(true).map_err(CdpError::Io)?;
179 if *DISABLE_NAGLE {
180 let _ = std_stream.set_nodelay(true);
181 }
182
183 // Wrap in tokio TcpStream.
184 let tokio_stream = tokio::net::TcpStream::from_std(std_stream).map_err(CdpError::Io)?;
185
186 // WebSocket handshake over the pre-connected stream.
187 let (ws, _) = tokio_tungstenite::client_async_with_config(
188 request,
189 MaybeTlsStream::Plain(tokio_stream),
190 Some(config),
191 )
192 .await?;
193
194 Ok(ws)
195 }
196}
197
198impl<T: EventMessage> Connection<T> {
199 fn next_call_id(&mut self) -> CallId {
200 let id = CallId::new(self.next_id);
201 self.next_id = self.next_id.wrapping_add(1);
202 id
203 }
204
205 /// Queue in the command to send over the socket and return the id for this
206 /// command
207 pub fn submit_command(
208 &mut self,
209 method: MethodId,
210 session_id: Option<SessionId>,
211 params: serde_json::Value,
212 ) -> serde_json::Result<CallId> {
213 let id = self.next_call_id();
214 let call = MethodCall {
215 id,
216 method,
217 session_id: session_id.map(Into::into),
218 params,
219 };
220 self.pending_commands.push_back(call);
221 Ok(id)
222 }
223
224 /// Buffer all queued commands into the WebSocket sink, then flush once.
225 ///
226 /// This batches multiple CDP commands into a single TCP write instead of
227 /// flushing after every individual message.
228 fn start_send_next(&mut self, cx: &mut Context<'_>) -> Result<()> {
229 // Complete any pending flush from a previous poll first.
230 if self.needs_flush {
231 match self.ws.poll_flush_unpin(cx) {
232 Poll::Ready(Ok(())) => self.needs_flush = false,
233 Poll::Ready(Err(e)) => return Err(e.into()),
234 Poll::Pending => return Ok(()),
235 }
236 }
237
238 // Buffer as many queued commands as the sink will accept.
239 let mut sent_any = false;
240 while !self.pending_commands.is_empty() {
241 match self.ws.poll_ready_unpin(cx) {
242 Poll::Ready(Ok(())) => {
243 let Some(cmd) = self.pending_commands.pop_front() else {
244 break;
245 };
246 tracing::trace!("Sending {:?}", cmd);
247 let msg = serde_json::to_string(&cmd)?;
248 self.ws.start_send_unpin(msg.into())?;
249 sent_any = true;
250 }
251 _ => break,
252 }
253 }
254
255 // Flush the entire batch in one write.
256 if sent_any {
257 match self.ws.poll_flush_unpin(cx) {
258 Poll::Ready(Ok(())) => {}
259 Poll::Ready(Err(e)) => return Err(e.into()),
260 Poll::Pending => self.needs_flush = true,
261 }
262 }
263
264 Ok(())
265 }
266}
267
/// Bound on the command channel between the Handler and the background
/// WebSocket writer task. Roomy enough that ordinary bursts of CDP commands
/// never make the Handler wait, yet finite so a stalled writer eventually
/// pushes back instead of buffering without limit.
const WS_CMD_CHANNEL_CAPACITY: usize = 2048;

/// Bound on the channel that carries decoded CDP messages from the
/// background reader task to the Handler. If the Handler falls behind and
/// the channel fills, the reader waits for room and stops draining the
/// socket, so TCP flow control slows Chrome down instead of memory growing.
const WS_READ_CHANNEL_CAPACITY: usize = 1024;

/// Upper limit, per connection, on decodes the reader keeps queued or
/// running at the same time. Results still reach the Handler in WebSocket
/// arrival order (they flow through a `FuturesOrdered`); the cap only bounds
/// how far reading may run ahead of emitting.
const MAX_IN_FLIGHT_DECODES: usize = 32;

/// Frame size in bytes from which the reader decodes on tokio's blocking
/// pool (`tokio::task::spawn_blocking`) instead of inline on its own task.
///
/// JSON deserialization is pure CPU work with no await points, so parsing a
/// multi-megabyte payload inline would hold a runtime worker for the whole
/// parse. That hurts most on a current-thread runtime, where the reader
/// shares its only worker with the Handler, user tasks and timers. Typical
/// CDP traffic stays below the threshold and skips the hand-off cost, while
/// screenshots and very large events exceed it.
const LARGE_FRAME_THRESHOLD: usize = 256 * 1024; // 256 KiB
305
306/// Split parts returned by [`Connection::into_async`].
307#[derive(Debug)]
308pub struct AsyncConnection<T: EventMessage> {
309 /// Receive half for decoded CDP messages. Backed by a bounded mpsc
310 /// fed by a dedicated background reader task — decode runs on that
311 /// task, never on the Handler task, so large CDP responses (multi-MB
312 /// screenshots, huge event payloads) cannot stall the Handler's
313 /// event loop.
314 pub reader: WsReader<T>,
315 /// Sender half for submitting outgoing CDP commands.
316 pub cmd_tx: mpsc::Sender<MethodCall>,
317 /// Handle to the background writer task.
318 pub writer_handle: tokio::task::JoinHandle<Result<()>>,
319 /// Handle to the background reader task (reads + decodes WS frames).
320 pub reader_handle: tokio::task::JoinHandle<()>,
321 /// Next command-call-id counter (continue numbering from where Connection left off).
322 pub next_id: usize,
323}
324
325impl<T: EventMessage + Unpin + Send + 'static> Connection<T> {
326 /// Consume the connection and split into a background reader + writer
327 /// pair, exposing the Handler-facing ends via `AsyncConnection`.
328 ///
329 /// Two `tokio::spawn`'d tasks are created:
330 ///
331 /// * `ws_write_loop` — batches outgoing commands and flushes them in
332 /// one write per wakeup.
333 /// * `ws_read_loop` — reads WS frames, decodes them to typed
334 /// `Message<T>`, and forwards them via a bounded mpsc to the
335 /// Handler. Ping/pong/malformed frames are skipped on this task
336 /// and never reach the Handler. Large-message decode (SerDe CPU
337 /// work) runs here, **not** on the Handler task, so the Handler's
338 /// poll loop never stalls for tens of milliseconds on a 10 MB
339 /// screenshot response.
340 ///
341 /// The design uses only `tokio::spawn` (cooperative async) — no
342 /// `spawn_blocking` or blocking thread-pool — so it scales with the
343 /// tokio runtime's worker threads on multi-threaded runtimes, and
344 /// interleaves cleanly with the Handler task on single-threaded
345 /// runtimes.
346 pub fn into_async(self) -> AsyncConnection<T> {
347 let (ws_sink, ws_stream) = self.ws.split();
348 let (cmd_tx, cmd_rx) = mpsc::channel(WS_CMD_CHANNEL_CAPACITY);
349 let (msg_tx, msg_rx) = mpsc::channel::<Result<Box<Message<T>>>>(WS_READ_CHANNEL_CAPACITY);
350
351 // Replay any commands queued via `submit_command` before the
352 // split — most notably the boot `Target.setDiscoverTargets`
353 // pushed by `Handler::new`. Without this, real Chrome never
354 // emits `Target.targetCreated` and `new_page` hangs forever.
355 // Capacity is `WS_CMD_CHANNEL_CAPACITY`, so the boot batch fits
356 // easily — `try_send` would only fail in a pathological case
357 // and we'd lose those commands either way.
358 for call in self.pending_commands {
359 let _ = cmd_tx.try_send(call);
360 }
361
362 let writer_handle = tokio::spawn(ws_write_loop(ws_sink, cmd_rx));
363 let reader_handle = tokio::spawn(ws_read_loop::<T, _>(ws_stream, msg_tx));
364
365 let reader = WsReader {
366 rx: msg_rx,
367 _marker: PhantomData,
368 };
369
370 AsyncConnection {
371 reader,
372 cmd_tx,
373 writer_handle,
374 reader_handle,
375 next_id: self.next_id,
376 }
377 }
378}
379
380/// An entry in the reader's decode pipeline.
381///
382/// Small frames have been decoded inline on the reader task and sit
383/// in `Ready(Some(result))` waiting their turn to emit — zero
384/// allocation beyond the `Option`. Large frames were offloaded to
385/// `tokio::task::spawn_blocking`, so their entry is the
386/// corresponding `JoinHandle`.
387///
388/// A single concrete enum means `FuturesOrdered<InFlightDecode<T>>`
389/// can hold either kind without `Box<dyn Future>`, keeping the
390/// pipeline cost-proportional to the workload.
391enum InFlightDecode<T: EventMessage + Send + 'static> {
392 /// Small-frame fast path: already decoded inline. `take()`'d
393 /// exactly once when `FuturesOrdered` first polls it to Ready.
394 Ready(Option<Result<Box<Message<T>>>>),
395 /// Large-frame path: decoding on the blocking thread pool.
396 Blocking(tokio::task::JoinHandle<Result<Box<Message<T>>>>),
397}
398
399impl<T: EventMessage + Send + 'static> Future for InFlightDecode<T> {
400 type Output = Result<Box<Message<T>>>;
401
402 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
403 // Safety: both variants are structurally pin-agnostic —
404 // `Option<Result<..>>` is `Unpin`, and `tokio::task::JoinHandle`
405 // is documented as `Unpin`. So we can project out a `&mut`
406 // without unsafe.
407 match self.get_mut() {
408 InFlightDecode::Ready(slot) => Poll::Ready(
409 slot.take()
410 .expect("InFlightDecode::Ready polled after completion"),
411 ),
412 InFlightDecode::Blocking(handle) => match Pin::new(handle).poll(cx) {
413 Poll::Ready(Ok(res)) => Poll::Ready(res),
414 Poll::Ready(Err(join_err)) => Poll::Ready(Err(CdpError::msg(format!(
415 "WS decode blocking task join error: {join_err}"
416 )))),
417 Poll::Pending => Poll::Pending,
418 },
419 }
420 }
421}
422
423/// Emit a single decoded-frame result to the Handler, logging parse
424/// errors. Returns `true` if the channel is still open, `false` if
425/// the Handler has dropped the receiver (caller should exit).
426async fn emit_decoded<T>(
427 tx: &mpsc::Sender<Result<Box<Message<T>>>>,
428 res: Result<Box<Message<T>>>,
429) -> bool
430where
431 T: EventMessage + Send + 'static,
432{
433 match res {
434 Ok(msg) => tx.send(Ok(msg)).await.is_ok(),
435 Err(err) => {
436 tracing::debug!(
437 target: "chromiumoxide::conn::raw_ws::parse_errors",
438 "Dropping malformed WS frame: {err}",
439 );
440 true
441 }
442 }
443}
444
/// Background task that reads frames from the WebSocket, decodes them to
/// typed CDP `Message<T>`, and forwards them to the Handler over a
/// bounded mpsc.
///
/// Runs on a `tokio::spawn`'d task. Frames smaller than
/// [`LARGE_FRAME_THRESHOLD`] are decoded inline (fast path). Payloads at or
/// above it are offloaded to `spawn_blocking`, so multi-MB deserialization
/// doesn't monopolise a tokio worker thread. That matters most on
/// single-threaded runtimes, where the reader, Handler, and user tasks share
/// the same worker. Up to [`MAX_IN_FLIGHT_DECODES`] decodes may be pending at
/// once. Results are still emitted strictly in frame arrival order.
///
/// Flow per frame:
///
/// * `Text` / `Binary` → `decode_message`; decoded `Ok(msg)` is sent to the
///   Handler. Decode errors (including a failed join of a blocking decode
///   task) are logged by `emit_decoded` and the frame is dropped.
/// * `Close` → no further frames are read. Decodes already in the pipeline
///   are still emitted, then the loop exits and drops `tx`. The Handler's
///   `next_message().await` returns `None` once the channel is drained.
/// * `Ping` / `Pong` / unexpected frame types → skipped; they never cross
///   the channel to the Handler (unexpected types are logged at debug).
/// * Transport error → reading stops, the in-flight decodes are emitted
///   first, then the error is forwarded once as `Err(CdpError::Ws(..))` and
///   the task exits (the WS half is considered dead after an error).
/// * End of stream (`None`) → handled like `Close`.
///
/// If the Handler drops its receiver, the task returns as soon as an emit
/// fails. Any blocking decodes still running are left detached.
///
/// Back-pressure: the outbound `tx` is bounded. If the Handler is busy
/// and the channel fills, `tx.send(..).await` parks this task, which
/// stops draining the WS socket. TCP flow control then applies
/// back-pressure to Chrome instead of letting memory grow without bound.
async fn ws_read_loop<T, S>(mut stream: S, tx: mpsc::Sender<Result<Box<Message<T>>>>)
where
    T: EventMessage + Send + 'static,
    S: Stream<Item = std::result::Result<WsMessage, tokio_tungstenite::tungstenite::Error>> + Unpin,
{
    // Pipeline of decodes in strict arrival order. Small-frame decodes
    // are produced inline (decoded directly from the borrowed frame body);
    // large-frame decodes are offloaded to `spawn_blocking`. Both
    // variants share a single concrete `InFlightDecode<T>` so the
    // queue avoids `Box<dyn Future>` overhead.
    let mut in_flight: FuturesOrdered<InFlightDecode<T>> = FuturesOrdered::new();

    // Shutdown state. When the stream signals `Close`, transport
    // error, or end-of-stream, we stop reading new frames but keep
    // running the select loop so the emit arm can flush any still
    // in-flight decodes *interleaved with* whatever else the runtime
    // is doing. A pending transport error is surfaced to the Handler
    // only after the in-order flush completes.
    let mut stream_terminated = false;
    let mut pending_err: Option<CdpError> = None;

    loop {
        tokio::select! {
            // Bias: emit already-ready decodes before reading more
            // frames. Keeps the pipeline small in the steady state
            // while still allowing concurrency under burst, and —
            // critically during shutdown — drains the pipeline one
            // ready item at a time inside the select loop instead
            // of blocking in a dedicated drain helper.
            biased;

            // Emit the head of the pipeline as soon as it is ready.
            // `FuturesOrdered::next` yields in submit order (a finished
            // later entry waits for the head), so downstream delivery
            // order matches WS arrival order.
            Some(res) = in_flight.next(), if !in_flight.is_empty() => {
                // `false` means the Handler dropped its receiver.
                if !emit_decoded(&tx, res).await {
                    return;
                }
            }

            // Read the next frame if the pipeline has capacity and
            // the stream hasn't terminated. Disabled once the stream
            // signals end (Close / None / Err) so subsequent loop
            // iterations only do emit work.
            maybe_frame = stream.next(),
                if !stream_terminated && in_flight.len() < MAX_IN_FLIGHT_DECODES =>
            {
                match maybe_frame {
                    Some(Ok(WsMessage::Text(text))) => {
                        // The small-frame fast path decodes inline
                        // *now* (borrowing `text`, and passing it as
                        // `raw_text_for_logging` so a parse failure
                        // logs a preview). The large-frame path moves
                        // the frame body into `spawn_blocking` without
                        // an intermediate copy and passes no logging
                        // text.
                        if text.len() >= LARGE_FRAME_THRESHOLD {
                            in_flight.push_back(InFlightDecode::Blocking(
                                tokio::task::spawn_blocking(move || {
                                    decode_message::<T>(text.as_bytes(), None)
                                }),
                            ));
                        } else {
                            let res = decode_message::<T>(text.as_bytes(), Some(&text));
                            in_flight.push_back(InFlightDecode::Ready(Some(res)));
                        }
                    }
                    Some(Ok(WsMessage::Binary(buf))) => {
                        // Same shape as Text: move the buffer into
                        // `spawn_blocking` for large payloads, decode
                        // inline for small ones. Binary frames never get
                        // a logging preview.
                        if buf.len() >= LARGE_FRAME_THRESHOLD {
                            in_flight.push_back(InFlightDecode::Blocking(
                                tokio::task::spawn_blocking(move || {
                                    decode_message::<T>(&buf, None)
                                }),
                            ));
                        } else {
                            let res = decode_message::<T>(&buf, None);
                            in_flight.push_back(InFlightDecode::Ready(Some(res)));
                        }
                    }
                    Some(Ok(WsMessage::Close(_))) => {
                        // Stop reading; frames after Close are never
                        // read. In-flight decodes still drain via the
                        // emit arm.
                        stream_terminated = true;
                    }
                    Some(Ok(WsMessage::Ping(_))) | Some(Ok(WsMessage::Pong(_))) => {}
                    Some(Ok(msg)) => {
                        tracing::debug!(
                            target: "chromiumoxide::conn::raw_ws::parse_errors",
                            "Unexpected WS message type: {:?}",
                            msg
                        );
                    }
                    Some(Err(err)) => {
                        // Defer the error until after the already
                        // in-flight decodes have emitted — preserves
                        // the ordering contract that callers see
                        // frames up to the failure point before the
                        // error itself.
                        stream_terminated = true;
                        pending_err = Some(CdpError::Ws(err));
                    }
                    None => {
                        // Stream ended (connection closed without a
                        // `Close` frame). No more input, but
                        // in_flight may still hold pending decodes.
                        stream_terminated = true;
                    }
                }
            }

            // Both arms disabled: `in_flight` is empty AND
            // `stream_terminated`. We have nothing more to do.
            else => {
                break;
            }
        }
    }

    // Pipeline fully drained: report a deferred transport error, if any.
    // A send failure (Handler gone) is irrelevant at this point.
    if let Some(err) = pending_err {
        let _ = tx.send(Err(err)).await;
    }
}
597
598/// Background task that batches and flushes outgoing CDP commands.
599async fn ws_write_loop(
600 mut sink: SplitSink<WebSocketStream<ConnectStream>, WsMessage>,
601 mut rx: mpsc::Receiver<MethodCall>,
602) -> Result<()> {
603 while let Some(call) = rx.recv().await {
604 let msg = crate::serde_json::to_string(&call)?;
605 sink.feed(WsMessage::Text(msg.into()))
606 .await
607 .map_err(CdpError::Ws)?;
608
609 // Batch: drain all buffered commands without waiting.
610 while let Ok(call) = rx.try_recv() {
611 let msg = crate::serde_json::to_string(&call)?;
612 sink.feed(WsMessage::Text(msg.into()))
613 .await
614 .map_err(CdpError::Ws)?;
615 }
616
617 // Flush the entire batch in one write.
618 sink.flush().await.map_err(CdpError::Ws)?;
619 }
620
621 // Cmd channel closed → the Handler is shutting down. Send a graceful
622 // WebSocket Close frame so the remote endpoint (esp. for
623 // `Browser::connect()` to a remote DevTools URL, where there is no
624 // child process whose exit closes the socket) tears the connection
625 // down promptly instead of waiting for an idle timeout. Errors are
626 // expected during shutdown (e.g. `AlreadyClosed` if Chrome closed
627 // first) and are intentionally ignored.
628 let _ = sink.close().await;
629 Ok(())
630}
631
632/// Handler-facing read half of the split WebSocket connection.
633///
634/// Decoded CDP messages are produced by a dedicated background task
635/// (see [`ws_read_loop`]) and forwarded over a bounded mpsc. `WsReader`
636/// itself is a thin `Receiver` wrapper — calling `next_message()` does
637/// a single `rx.recv().await` with no per-message decoding work on the
638/// caller's task. This keeps the Handler's poll loop free of CPU-bound
639/// deserialize time, which matters for large (multi-MB) CDP responses
640/// such as screenshots and wide-header network events.
641#[derive(Debug)]
642pub struct WsReader<T: EventMessage> {
643 rx: mpsc::Receiver<Result<Box<Message<T>>>>,
644 _marker: PhantomData<T>,
645}
646
647impl<T: EventMessage + Unpin> WsReader<T> {
648 /// Read the next CDP message from the WebSocket.
649 ///
650 /// Returns `None` when the background reader task has exited
651 /// (connection closed or sender dropped). This call does only a
652 /// channel `recv` — the actual WS read + JSON decode happens on
653 /// the background `ws_read_loop` task.
654 pub async fn next_message(&mut self) -> Option<Result<Box<Message<T>>>> {
655 self.rx.recv().await
656 }
657}
658
659impl<T: EventMessage + Unpin> Stream for Connection<T> {
660 type Item = Result<Box<Message<T>>>;
661
662 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
663 let pin = self.get_mut();
664
665 // Send and flush outgoing messages
666 if let Err(err) = pin.start_send_next(cx) {
667 return Poll::Ready(Some(Err(err)));
668 }
669
670 // Read from the websocket, skipping non-data frames (pings,
671 // pongs, malformed messages) without yielding back to the
672 // executor. This avoids a full round-trip per skipped frame.
673 //
674 // Cap consecutive skips so a flood of non-data frames (many
675 // pings, malformed/unexpected types) cannot starve the
676 // runtime — yield Pending after `MAX_SKIPS_PER_POLL` and
677 // self-wake so we resume on the next tick.
678 const MAX_SKIPS_PER_POLL: u32 = 16;
679 let mut skips: u32 = 0;
680 loop {
681 match ready!(pin.ws.poll_next_unpin(cx)) {
682 Some(Ok(WsMessage::Text(text))) => {
683 match decode_message::<T>(text.as_bytes(), Some(&text)) {
684 Ok(msg) => return Poll::Ready(Some(Ok(msg))),
685 Err(err) => {
686 tracing::debug!(
687 target: "chromiumoxide::conn::raw_ws::parse_errors",
688 "Dropping malformed text WS frame: {err}",
689 );
690 skips += 1;
691 }
692 }
693 }
694 Some(Ok(WsMessage::Binary(buf))) => match decode_message::<T>(&buf, None) {
695 Ok(msg) => return Poll::Ready(Some(Ok(msg))),
696 Err(err) => {
697 tracing::debug!(
698 target: "chromiumoxide::conn::raw_ws::parse_errors",
699 "Dropping malformed binary WS frame: {err}",
700 );
701 skips += 1;
702 }
703 },
704 Some(Ok(WsMessage::Close(_))) => return Poll::Ready(None),
705 Some(Ok(WsMessage::Ping(_))) | Some(Ok(WsMessage::Pong(_))) => {
706 skips += 1;
707 }
708 Some(Ok(msg)) => {
709 tracing::debug!(
710 target: "chromiumoxide::conn::raw_ws::parse_errors",
711 "Unexpected WS message type: {:?}",
712 msg
713 );
714 skips += 1;
715 }
716 Some(Err(err)) => return Poll::Ready(Some(Err(CdpError::Ws(err)))),
717 None => return Poll::Ready(None),
718 }
719
720 if skips >= MAX_SKIPS_PER_POLL {
721 cx.waker().wake_by_ref();
722 return Poll::Pending;
723 }
724 }
725 }
726}
727
728/// Shared decode path for both text and binary WS frames.
729/// `raw_text_for_logging` is only provided for textual frames so we can log the original
730/// payload on parse failure if desired.
731#[cfg(not(feature = "serde_stacker"))]
732fn decode_message<T: EventMessage>(
733 bytes: &[u8],
734 raw_text_for_logging: Option<&str>,
735) -> Result<Box<Message<T>>> {
736 match serde_json::from_slice::<Box<Message<T>>>(bytes) {
737 Ok(msg) => {
738 tracing::trace!("Received {:?}", msg);
739 Ok(msg)
740 }
741 Err(err) => {
742 if let Some(txt) = raw_text_for_logging {
743 let preview = &txt[..txt.len().min(512)];
744 tracing::debug!(
745 target: "chromiumoxide::conn::raw_ws::parse_errors",
746 msg_len = txt.len(),
747 "Skipping unrecognized WS message {err} preview={preview}",
748 );
749 } else {
750 tracing::debug!(
751 target: "chromiumoxide::conn::raw_ws::parse_errors",
752 "Skipping unrecognized binary WS message {err}",
753 );
754 }
755 Err(err.into())
756 }
757 }
758}
759
/// Shared decode path for both text and binary WS frames, with
/// `serde_stacker` to avoid stack overflow on deeply nested payloads (the
/// `serde_json` recursion limit is disabled).
///
/// `raw_text_for_logging` is only provided for textual frames, so that a
/// parse failure can log a short preview (at most 512 bytes) of the payload.
///
/// # Errors
///
/// Returns the deserialization error (converted into the crate error) when
/// the payload is not a valid `Message<T>`; the failure is also logged at
/// debug level.
#[cfg(feature = "serde_stacker")]
fn decode_message<T: EventMessage>(
    bytes: &[u8],
    raw_text_for_logging: Option<&str>,
) -> Result<Box<Message<T>>> {
    use serde::Deserialize;
    let mut de = serde_json::Deserializer::from_slice(bytes);

    de.disable_recursion_limit();

    let de = serde_stacker::Deserializer::new(&mut de);

    match Box::<Message<T>>::deserialize(de) {
        Ok(msg) => {
            tracing::trace!("Received {:?}", msg);
            Ok(msg)
        }
        Err(err) => {
            if let Some(txt) = raw_text_for_logging {
                // Cut at most 512 bytes, backing off to a char boundary:
                // slicing a `str` inside a multi-byte UTF-8 sequence panics,
                // which would take down the reader for one malformed frame.
                let mut end = txt.len().min(512);
                while !txt.is_char_boundary(end) {
                    end -= 1;
                }
                let preview = &txt[..end];
                tracing::debug!(
                    target: "chromiumoxide::conn::raw_ws::parse_errors",
                    msg_len = txt.len(),
                    "Skipping unrecognized WS message {err} preview={preview}",
                );
            } else {
                tracing::debug!(
                    target: "chromiumoxide::conn::raw_ws::parse_errors",
                    "Skipping unrecognized binary WS message {err}",
                );
            }
            Err(err.into())
        }
    }
}
798
799#[cfg(test)]
800mod ws_read_loop_tests {
801 //! Unit tests for the `ws_read_loop` background reader task.
802 //!
803 //! These tests feed a synthetic `Stream<Item = Result<WsMessage, _>>`
804 //! into `ws_read_loop` — no real WebSocket, no Chrome — and observe
805 //! what comes out the other side of the mpsc channel.
806 //!
807 //! The properties under test are the ones that make the reader-task
808 //! decoupling safe: FIFO ordering, no-deadlock on a bounded channel
809 //! under back-pressure, silent drop of non-data frames, graceful
810 //! transport-error propagation, and clean exit on `Close`.
811 //!
812 //! The typed events are `chromiumoxide_cdp::cdp::CdpEventMessage` —
813 //! the same instantiation the real Handler uses — so these tests
814 //! exercise the actual decode path (`serde_json::from_slice`), not
815 //! a simplified fake.
816 use super::*;
817 use chromiumoxide_cdp::cdp::CdpEventMessage;
818 use chromiumoxide_types::CallId;
819 use futures_util::stream;
820 use tokio::sync::mpsc;
821 use tokio_tungstenite::tungstenite::Message as WsMessage;
822
823 /// Build a CDP `Response` WS frame as text — the smallest valid CDP
824 /// message. `id` tags the frame for ordering assertions.
825 fn response_frame(id: u64) -> WsMessage {
826 WsMessage::Text(
827 format!(r#"{{"id":{id},"result":{{"ok":true}}}}"#)
828 .to_string()
829 .into(),
830 )
831 }
832
833 /// Build a frame far larger than a typical socket chunk, to exercise
834 /// the "large message" path that motivated this refactor. The blob
835 /// field pushes serde_json through a big allocation even though the
836 /// envelope is tiny.
837 fn large_response_frame(id: u64, blob_bytes: usize) -> WsMessage {
838 let blob = "x".repeat(blob_bytes);
839 WsMessage::Text(
840 format!(r#"{{"id":{id},"result":{{"blob":"{blob}"}}}}"#)
841 .to_string()
842 .into(),
843 )
844 }
845
846 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
847 async fn forwards_messages_in_stream_order() {
848 let frames = vec![
849 Ok(response_frame(1)),
850 Ok(response_frame(2)),
851 Ok(response_frame(3)),
852 ];
853 let stream = stream::iter(frames);
854 let (tx, mut rx) = mpsc::channel::<Result<Box<Message<CdpEventMessage>>>>(8);
855 let task = tokio::spawn(ws_read_loop::<CdpEventMessage, _>(stream, tx));
856
857 for expected in [1u64, 2, 3] {
858 let msg = rx.recv().await.expect("msg").expect("decode ok");
859 if let Message::Response(resp) = *msg {
860 assert_eq!(resp.id, CallId::new(expected as usize));
861 } else {
862 panic!("expected Response");
863 }
864 }
865 assert!(rx.recv().await.is_none(), "channel must close on EOF");
866 task.await.expect("reader task join");
867 }
868
869 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
870 async fn pings_and_pongs_never_reach_the_handler() {
871 let frames = vec![
872 Ok(WsMessage::Ping(vec![1, 2, 3].into())),
873 Ok(response_frame(7)),
874 Ok(WsMessage::Pong(vec![].into())),
875 Ok(response_frame(8)),
876 ];
877 let stream = stream::iter(frames);
878 let (tx, mut rx) = mpsc::channel::<Result<Box<Message<CdpEventMessage>>>>(8);
879 let task = tokio::spawn(ws_read_loop::<CdpEventMessage, _>(stream, tx));
880
881 for expected in [7u64, 8] {
882 let msg = rx.recv().await.expect("msg").expect("decode ok");
883 if let Message::Response(resp) = *msg {
884 assert_eq!(resp.id, CallId::new(expected as usize));
885 }
886 }
887 assert!(rx.recv().await.is_none());
888 task.await.expect("reader task join");
889 }
890
891 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
892 async fn malformed_frames_do_not_block_subsequent_valid_frames() {
893 let frames = vec![
894 Ok(WsMessage::Text("{not valid json".to_string().into())),
895 Ok(response_frame(42)),
896 ];
897 let stream = stream::iter(frames);
898 let (tx, mut rx) = mpsc::channel::<Result<Box<Message<CdpEventMessage>>>>(8);
899 let task = tokio::spawn(ws_read_loop::<CdpEventMessage, _>(stream, tx));
900
901 let msg = rx.recv().await.expect("msg").expect("decode ok");
902 if let Message::Response(resp) = *msg {
903 assert_eq!(resp.id, CallId::new(42));
904 }
905 assert!(rx.recv().await.is_none());
906 task.await.expect("reader task join");
907 }
908
909 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
910 async fn close_frame_terminates_the_reader() {
911 let frames = vec![
912 Ok(response_frame(1)),
913 Ok(WsMessage::Close(None)),
914 Ok(response_frame(2)), // unreachable after Close
915 ];
916 let stream = stream::iter(frames);
917 let (tx, mut rx) = mpsc::channel::<Result<Box<Message<CdpEventMessage>>>>(8);
918 let task = tokio::spawn(ws_read_loop::<CdpEventMessage, _>(stream, tx));
919
920 let msg = rx.recv().await.expect("msg").expect("decode ok");
921 if let Message::Response(resp) = *msg {
922 assert_eq!(resp.id, CallId::new(1));
923 }
924 assert!(
925 rx.recv().await.is_none(),
926 "reader must exit on Close; frames after Close must not appear"
927 );
928 task.await.expect("reader task join");
929 }
930
931 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
932 async fn transport_error_is_forwarded_once_then_reader_exits() {
933 let frames = vec![
934 Ok(response_frame(1)),
935 Err(tokio_tungstenite::tungstenite::Error::ConnectionClosed),
936 Ok(response_frame(2)),
937 ];
938 let stream = stream::iter(frames);
939 let (tx, mut rx) = mpsc::channel::<Result<Box<Message<CdpEventMessage>>>>(8);
940 let task = tokio::spawn(ws_read_loop::<CdpEventMessage, _>(stream, tx));
941
942 let msg = rx.recv().await.expect("msg").expect("ok");
943 assert!(matches!(*msg, Message::Response(_)));
944 match rx.recv().await {
945 Some(Err(CdpError::Ws(_))) => {}
946 other => panic!("expected forwarded Ws error, got {other:?}"),
947 }
948 assert!(rx.recv().await.is_none());
949 task.await.expect("reader task join");
950 }
951
952 /// Back-pressure property: with the smallest possible channel and
953 /// many frames, the reader task awaits capacity after each send and
954 /// never deadlocks. This is the core "no deadlock" proof for the
955 /// new design — if the reader held anything across its `.await` that
956 /// the consumer needed, the consumer's `recv().await` would block
957 /// forever. Completion under a 5s watchdog proves it doesn't.
958 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
959 async fn bounded_channel_does_not_deadlock_under_backpressure() {
960 const N: u64 = 512;
961 let frames: Vec<_> = (1..=N).map(|id| Ok(response_frame(id))).collect();
962 let stream = stream::iter(frames);
963
964 let (tx, mut rx) = mpsc::channel::<Result<Box<Message<CdpEventMessage>>>>(1);
965 let task = tokio::spawn(ws_read_loop::<CdpEventMessage, _>(stream, tx));
966
967 let deadline = std::time::Duration::from_secs(5);
968 let collected = tokio::time::timeout(deadline, async {
969 let mut seen = 0u64;
970 while let Some(frame) = rx.recv().await {
971 let msg = frame.expect("decode ok");
972 if let Message::Response(resp) = *msg {
973 seen += 1;
974 assert_eq!(
975 resp.id,
976 CallId::new(seen as usize),
977 "back-pressure must preserve FIFO order"
978 );
979 }
980 }
981 seen
982 })
983 .await
984 .expect("reader must make forward progress despite cap-1 back-pressure");
985
986 assert_eq!(collected, N, "all frames must arrive");
987 task.await.expect("reader task join");
988 }
989
990 /// Large message (>1 MB) is decoded correctly on the background
991 /// task. This is the specific scenario the reader-task refactor
992 /// was built for — we don't measure time here (benches cover that),
993 /// we just prove the end-to-end path works without corruption or
994 /// deadlock.
995 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
996 async fn large_message_decodes_without_corruption() {
997 let big = 2 * 1024 * 1024; // 2 MB payload
998 let frames = vec![Ok(large_response_frame(100, big)), Ok(response_frame(101))];
999 let stream = stream::iter(frames);
1000 let (tx, mut rx) = mpsc::channel::<Result<Box<Message<CdpEventMessage>>>>(4);
1001 let task = tokio::spawn(ws_read_loop::<CdpEventMessage, _>(stream, tx));
1002
1003 let first = rx.recv().await.expect("msg").expect("ok");
1004 if let Message::Response(resp) = *first {
1005 assert_eq!(resp.id, CallId::new(100));
1006 }
1007 let second = rx.recv().await.expect("msg").expect("ok");
1008 if let Message::Response(resp) = *second {
1009 assert_eq!(resp.id, CallId::new(101));
1010 }
1011 assert!(rx.recv().await.is_none());
1012 task.await.expect("reader task join");
1013 }
1014
1015 /// FIFO ordering under the pipelined reader when large-frame
1016 /// decodes run in parallel via `spawn_blocking`.
1017 ///
1018 /// This test submits an interleaved sequence of large and small
1019 /// frames. Large frames take the `spawn_blocking` path (decode
1020 /// on the blocking pool, variable completion order); small
1021 /// frames take the inline path (decode immediately). The
1022 /// pipeline's `FuturesOrdered` queue must emit them to the
1023 /// Handler in strict arrival order regardless of which
1024 /// blocking-pool thread finishes first.
1025 ///
1026 /// If the ordering guarantee were ever broken — e.g. by
1027 /// accidentally swapping `FuturesOrdered` for `FuturesUnordered`
1028 /// — id sequence checks here would catch it immediately.
1029 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
1030 async fn pipelined_large_and_small_frames_keep_fifo_order() {
1031 let big = 2 * 1024 * 1024; // 2 MB payload — forces spawn_blocking
1032 let frames = vec![
1033 Ok(large_response_frame(1, big)),
1034 Ok(response_frame(2)),
1035 Ok(response_frame(3)),
1036 Ok(large_response_frame(4, big)),
1037 Ok(response_frame(5)),
1038 Ok(large_response_frame(6, big)),
1039 Ok(response_frame(7)),
1040 Ok(response_frame(8)),
1041 ];
1042 let expected: Vec<usize> = (1..=8).collect();
1043
1044 let stream = stream::iter(frames);
1045 let (tx, mut rx) = mpsc::channel::<Result<Box<Message<CdpEventMessage>>>>(16);
1046 let task = tokio::spawn(ws_read_loop::<CdpEventMessage, _>(stream, tx));
1047
1048 let deadline = std::time::Duration::from_secs(10);
1049 let observed = tokio::time::timeout(deadline, async {
1050 let mut ids = Vec::with_capacity(expected.len());
1051 while let Some(frame) = rx.recv().await {
1052 let msg = frame.expect("decode ok");
1053 if let Message::Response(resp) = *msg {
1054 ids.push(CallId::new(ids.len() + 1));
1055 assert_eq!(
1056 resp.id,
1057 *ids.last().unwrap(),
1058 "pipelined reader must emit frames in strict arrival order \
1059 regardless of per-frame decode latency"
1060 );
1061 }
1062 }
1063 ids
1064 })
1065 .await
1066 .expect("pipelined reader should make forward progress within 10s");
1067
1068 assert_eq!(
1069 observed.len(),
1070 expected.len(),
1071 "all {} frames must reach the Handler",
1072 expected.len()
1073 );
1074 task.await.expect("reader task join");
1075 }
1076}