// chromiumoxide/conn.rs — WebSocket connection to the browser's CDP endpoint.

1use std::collections::VecDeque;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::ready;
5
6use futures_util::stream::SplitSink;
7use futures_util::{SinkExt, Stream, StreamExt};
8use std::task::{Context, Poll};
9use tokio::sync::mpsc;
10use tokio_tungstenite::tungstenite::Message as WsMessage;
11use tokio_tungstenite::MaybeTlsStream;
12use tokio_tungstenite::{tungstenite::protocol::WebSocketConfig, WebSocketStream};
13
14use chromiumoxide_cdp::cdp::browser_protocol::target::SessionId;
15use chromiumoxide_types::{CallId, EventMessage, Message, MethodCall, MethodId};
16
17use crate::error::CdpError;
18use crate::error::Result;
19
20type ConnectStream = MaybeTlsStream<tokio::net::TcpStream>;
21
/// Exchanges the messages with the websocket
///
/// Owns the WS transport plus a FIFO of not-yet-written commands. The
/// `Stream` impl below drives both directions: it flushes queued commands
/// and yields decoded incoming CDP messages.
#[must_use = "streams do nothing unless polled"]
#[derive(Debug)]
pub struct Connection<T: EventMessage> {
    /// Queue of commands to send.
    pending_commands: VecDeque<MethodCall>,
    /// The websocket of the chromium instance
    ws: WebSocketStream<ConnectStream>,
    /// The identifier for a specific command
    next_id: usize,
    /// Whether the write buffer has unsent data that needs flushing.
    needs_flush: bool,
    /// The phantom marker.
    _marker: PhantomData<T>,
}
37
// NOTE(review): `std::sync::LazyLock` could replace `lazy_static!` once the
// crate's MSRV allows it — confirm MSRV before switching.
lazy_static::lazy_static! {
    /// Nagle's algorithm disabled?
    ///
    /// Defaults to `true` (TCP_NODELAY set) unless the `DISABLE_NAGLE`
    /// env var is present with a value other than "true".
    static ref DISABLE_NAGLE: bool = match std::env::var("DISABLE_NAGLE") {
        Ok(disable_nagle) => disable_nagle == "true",
        _ => true
    };
    /// Websocket config defaults
    ///
    /// When `false` (the default), message/frame size limits are lifted
    /// in `connect_with_retries`; set `WEBSOCKET_DEFAULTS=true` to keep
    /// tungstenite's built-in limits.
    static ref WEBSOCKET_DEFAULTS: bool = match std::env::var("WEBSOCKET_DEFAULTS") {
        Ok(d) => d == "true",
        _ => false
    };
}

/// Default number of WebSocket connection retry attempts.
pub const DEFAULT_CONNECTION_RETRIES: u32 = 4;

/// Initial backoff delay between connection retries (in milliseconds).
const INITIAL_BACKOFF_MS: u64 = 50;

/// Maximum backoff delay between connection retries (in milliseconds).
const MAX_BACKOFF_MS: u64 = 2_000;
59
impl<T: EventMessage + Unpin> Connection<T> {
    /// Connect to the browser's CDP debug WebSocket endpoint, retrying up
    /// to [`DEFAULT_CONNECTION_RETRIES`] times with exponential backoff.
    pub async fn connect(debug_ws_url: impl AsRef<str>) -> Result<Self> {
        Self::connect_with_retries(debug_ws_url, DEFAULT_CONNECTION_RETRIES).await
    }

    /// Connect to `debug_ws_url`, making up to `retries + 1` attempts.
    ///
    /// Backoff grows geometrically (factor 3) from [`INITIAL_BACKOFF_MS`],
    /// capped at [`MAX_BACKOFF_MS`]. Errors that retrying cannot fix
    /// (connection refused, HTTP-level handshake rejection) end the loop
    /// early and the last error is returned.
    pub async fn connect_with_retries(debug_ws_url: impl AsRef<str>, retries: u32) -> Result<Self> {
        let mut config = WebSocketConfig::default();

        // Cap the internal write buffer so a slow receiver cannot cause
        // unbounded memory growth (default is usize::MAX).
        config.max_write_buffer_size = 4 * 1024 * 1024;

        if !*WEBSOCKET_DEFAULTS {
            // Lift tungstenite's message/frame size limits so large CDP
            // payloads (e.g. screenshot responses) are not rejected.
            config.max_message_size = None;
            config.max_frame_size = None;
        }

        let url = debug_ws_url.as_ref();
        let use_uring = crate::uring_fs::is_enabled();
        let mut last_err = None;

        for attempt in 0..=retries {
            let result = if use_uring {
                Self::connect_uring(url, config).await
            } else {
                Self::connect_default(url, config).await
            };

            match result {
                Ok(ws) => {
                    return Ok(Self {
                        pending_commands: Default::default(),
                        ws,
                        next_id: 0,
                        needs_flush: false,
                        _marker: Default::default(),
                    });
                }
                Err(e) => {
                    // Detect non-retriable errors early to avoid wasting time
                    // on connections that will never succeed.
                    let should_retry = match &e {
                        // Connection refused — nothing is listening on this port.
                        CdpError::Io(io_err)
                            if io_err.kind() == std::io::ErrorKind::ConnectionRefused =>
                        {
                            false
                        }
                        // HTTP response to a WebSocket upgrade (e.g. wrong path
                        // returns 404 / redirect) — retrying the same URL won't help.
                        CdpError::Ws(tungstenite_err) => !matches!(
                            tungstenite_err,
                            tokio_tungstenite::tungstenite::Error::Http(_)
                                | tokio_tungstenite::tungstenite::Error::HttpFormat(_)
                        ),
                        _ => true,
                    };

                    last_err = Some(e);

                    if !should_retry {
                        break;
                    }

                    if attempt < retries {
                        // Geometric backoff: 50ms, 150ms, 450ms, ... capped at 2s.
                        let backoff_ms =
                            (INITIAL_BACKOFF_MS * 3u64.saturating_pow(attempt)).min(MAX_BACKOFF_MS);
                        tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
                    }
                }
            }
        }

        Err(last_err.unwrap_or_else(|| CdpError::msg("connection failed")))
    }

    /// Default path: let tokio-tungstenite handle TCP connect + WS handshake.
    async fn connect_default(
        url: &str,
        config: WebSocketConfig,
    ) -> Result<WebSocketStream<ConnectStream>> {
        let (ws, _) =
            tokio_tungstenite::connect_async_with_config(url, Some(config), *DISABLE_NAGLE).await?;
        Ok(ws)
    }

    /// io_uring path: pre-connect the TCP socket via io_uring, then do WS
    /// handshake over the pre-connected stream.
    async fn connect_uring(
        url: &str,
        config: WebSocketConfig,
    ) -> Result<WebSocketStream<ConnectStream>> {
        use tokio_tungstenite::tungstenite::client::IntoClientRequest;

        let request = url.into_client_request()?;
        let host = request
            .uri()
            .host()
            .ok_or_else(|| CdpError::msg("no host in CDP WebSocket URL"))?;
        // NOTE(review): 9222 is the conventional CDP debug port; a ws:// URL
        // without an explicit port would normally imply 80 — confirm the
        // CDP-specific default is intended here.
        let port = request.uri().port_u16().unwrap_or(9222);

        // Resolve host → SocketAddr (CDP is always localhost, so this is fast).
        let addr_str = format!("{}:{}", host, port);
        let addr: std::net::SocketAddr = match addr_str.parse() {
            Ok(a) => a,
            Err(_) => {
                // Hostname needs DNS — fall back to default path.
                return Self::connect_default(url, config).await;
            }
        };

        // TCP connect via io_uring.
        let std_stream = crate::uring_fs::tcp_connect(addr)
            .await
            .map_err(CdpError::Io)?;

        // Set non-blocking + Nagle. Nodelay failure is non-fatal (best effort).
        std_stream.set_nonblocking(true).map_err(CdpError::Io)?;
        if *DISABLE_NAGLE {
            let _ = std_stream.set_nodelay(true);
        }

        // Wrap in tokio TcpStream.
        let tokio_stream = tokio::net::TcpStream::from_std(std_stream).map_err(CdpError::Io)?;

        // WebSocket handshake over the pre-connected stream.
        let (ws, _) = tokio_tungstenite::client_async_with_config(
            request,
            MaybeTlsStream::Plain(tokio_stream),
            Some(config),
        )
        .await?;

        Ok(ws)
    }
}
196
197impl<T: EventMessage> Connection<T> {
198    fn next_call_id(&mut self) -> CallId {
199        let id = CallId::new(self.next_id);
200        self.next_id = self.next_id.wrapping_add(1);
201        id
202    }
203
204    /// Queue in the command to send over the socket and return the id for this
205    /// command
206    pub fn submit_command(
207        &mut self,
208        method: MethodId,
209        session_id: Option<SessionId>,
210        params: serde_json::Value,
211    ) -> serde_json::Result<CallId> {
212        let id = self.next_call_id();
213        let call = MethodCall {
214            id,
215            method,
216            session_id: session_id.map(Into::into),
217            params,
218        };
219        self.pending_commands.push_back(call);
220        Ok(id)
221    }
222
223    /// Buffer all queued commands into the WebSocket sink, then flush once.
224    ///
225    /// This batches multiple CDP commands into a single TCP write instead of
226    /// flushing after every individual message.
227    fn start_send_next(&mut self, cx: &mut Context<'_>) -> Result<()> {
228        // Complete any pending flush from a previous poll first.
229        if self.needs_flush {
230            match self.ws.poll_flush_unpin(cx) {
231                Poll::Ready(Ok(())) => self.needs_flush = false,
232                Poll::Ready(Err(e)) => return Err(e.into()),
233                Poll::Pending => return Ok(()),
234            }
235        }
236
237        // Buffer as many queued commands as the sink will accept.
238        let mut sent_any = false;
239        while !self.pending_commands.is_empty() {
240            match self.ws.poll_ready_unpin(cx) {
241                Poll::Ready(Ok(())) => {
242                    let Some(cmd) = self.pending_commands.pop_front() else {
243                        break;
244                    };
245                    tracing::trace!("Sending {:?}", cmd);
246                    let msg = serde_json::to_string(&cmd)?;
247                    self.ws.start_send_unpin(msg.into())?;
248                    sent_any = true;
249                }
250                _ => break,
251            }
252        }
253
254        // Flush the entire batch in one write.
255        if sent_any {
256            match self.ws.poll_flush_unpin(cx) {
257                Poll::Ready(Ok(())) => {}
258                Poll::Ready(Err(e)) => return Err(e.into()),
259                Poll::Pending => self.needs_flush = true,
260            }
261        }
262
263        Ok(())
264    }
265}
266
/// Capacity of the bounded channel feeding the background WS writer task.
/// Large enough that bursts of CDP commands never block the handler, small
/// enough to apply back-pressure before memory grows without bound.
const WS_CMD_CHANNEL_CAPACITY: usize = 2048;

/// Capacity of the bounded channel from the background WS reader task to
/// the Handler. Keeps decoded CDP messages buffered so the reader task
/// can keep reading the socket while the Handler processes a backlog;
/// applies TCP-level back-pressure on Chrome when the Handler is slow
/// (the reader awaits channel capacity, stops draining the socket).
const WS_READ_CHANNEL_CAPACITY: usize = 1024;
278
/// Split parts returned by [`Connection::into_async`].
#[derive(Debug)]
pub struct AsyncConnection<T: EventMessage> {
    /// Receive half for decoded CDP messages. Backed by a bounded mpsc
    /// fed by a dedicated background reader task — decode runs on that
    /// task, never on the Handler task, so large CDP responses (multi-MB
    /// screenshots, huge event payloads) cannot stall the Handler's
    /// event loop.
    pub reader: WsReader<T>,
    /// Sender half for submitting outgoing CDP commands. Dropping every
    /// clone of this sender makes the writer task's `recv()` return
    /// `None`, which ends its loop cleanly.
    pub cmd_tx: mpsc::Sender<MethodCall>,
    /// Handle to the background writer task.
    pub writer_handle: tokio::task::JoinHandle<Result<()>>,
    /// Handle to the background reader task (reads + decodes WS frames).
    pub reader_handle: tokio::task::JoinHandle<()>,
    /// Next command-call-id counter (continue numbering from where Connection left off).
    pub next_id: usize,
}
297
298impl<T: EventMessage + Unpin + Send + 'static> Connection<T> {
299    /// Consume the connection and split into a background reader + writer
300    /// pair, exposing the Handler-facing ends via `AsyncConnection`.
301    ///
302    /// Two `tokio::spawn`'d tasks are created:
303    ///
304    /// * `ws_write_loop` — batches outgoing commands and flushes them in
305    ///   one write per wakeup.
306    /// * `ws_read_loop`  — reads WS frames, decodes them to typed
307    ///   `Message<T>`, and forwards them via a bounded mpsc to the
308    ///   Handler. Ping/pong/malformed frames are skipped on this task
309    ///   and never reach the Handler. Large-message decode (SerDe CPU
310    ///   work) runs here, **not** on the Handler task, so the Handler's
311    ///   poll loop never stalls for tens of milliseconds on a 10 MB
312    ///   screenshot response.
313    ///
314    /// The design uses only `tokio::spawn` (cooperative async) — no
315    /// `spawn_blocking` or blocking thread-pool — so it scales with the
316    /// tokio runtime's worker threads on multi-threaded runtimes, and
317    /// interleaves cleanly with the Handler task on single-threaded
318    /// runtimes.
319    pub fn into_async(self) -> AsyncConnection<T> {
320        let (ws_sink, ws_stream) = self.ws.split();
321        let (cmd_tx, cmd_rx) = mpsc::channel(WS_CMD_CHANNEL_CAPACITY);
322        let (msg_tx, msg_rx) = mpsc::channel::<Result<Box<Message<T>>>>(WS_READ_CHANNEL_CAPACITY);
323
324        let writer_handle = tokio::spawn(ws_write_loop(ws_sink, cmd_rx));
325        let reader_handle = tokio::spawn(ws_read_loop::<T, _>(ws_stream, msg_tx));
326
327        let reader = WsReader {
328            rx: msg_rx,
329            _marker: PhantomData,
330        };
331
332        AsyncConnection {
333            reader,
334            cmd_tx,
335            writer_handle,
336            reader_handle,
337            next_id: self.next_id,
338        }
339    }
340}
341
342/// Background task that reads frames from the WebSocket, decodes them to
343/// typed CDP `Message<T>`, and forwards them to the Handler over a
344/// bounded mpsc.
345///
346/// Runs on a `tokio::spawn`'d task — **not** `spawn_blocking` — so CPU
347/// time for JSON decode is charged to a regular tokio worker and not the
348/// blocking thread pool. On a multi-threaded runtime, the decode can run
349/// on a different worker than the Handler, giving true parallelism for
350/// large messages. On a single-threaded runtime, it cooperates with the
351/// Handler via `.await` points on the send channel.
352///
353/// Flow per frame:
354///
355/// * `Text` / `Binary` → `decode_message::<T>`; decoded `Ok(msg)` is
356///   sent to the Handler. Decode errors are logged and the frame is
357///   dropped (same behavior as the legacy inline decode path).
358/// * `Close` → loop exits cleanly, dropping `tx`. The Handler's
359///   `next_message().await` returns `None` on the next call.
360/// * `Ping` / `Pong` / unexpected frame types → skipped silently; they
361///   never cross the channel to the Handler.
362/// * Transport error → forwarded as `Err(CdpError::Ws(..))`, then the
363///   loop exits (the WS half is considered dead after an error).
364///
365/// Back-pressure: the outbound `tx` is bounded. If the Handler is busy
366/// and the channel fills, `tx.send(..).await` parks this task, which
367/// stops draining the WS socket. TCP flow control then applies
368/// back-pressure to Chrome instead of letting memory grow without bound.
369async fn ws_read_loop<T, S>(mut stream: S, tx: mpsc::Sender<Result<Box<Message<T>>>>)
370where
371    T: EventMessage,
372    S: Stream<Item = std::result::Result<WsMessage, tokio_tungstenite::tungstenite::Error>>
373        + Unpin,
374{
375    while let Some(frame) = stream.next().await {
376        match frame {
377            Ok(WsMessage::Text(text)) => {
378                match decode_message::<T>(text.as_bytes(), Some(&text)) {
379                    Ok(msg) => {
380                        if tx.send(Ok(msg)).await.is_err() {
381                            return;
382                        }
383                    }
384                    Err(err) => {
385                        tracing::debug!(
386                            target: "chromiumoxide::conn::raw_ws::parse_errors",
387                            "Dropping malformed text WS frame: {err}",
388                        );
389                    }
390                }
391            }
392            Ok(WsMessage::Binary(buf)) => match decode_message::<T>(&buf, None) {
393                Ok(msg) => {
394                    if tx.send(Ok(msg)).await.is_err() {
395                        return;
396                    }
397                }
398                Err(err) => {
399                    tracing::debug!(
400                        target: "chromiumoxide::conn::raw_ws::parse_errors",
401                        "Dropping malformed binary WS frame: {err}",
402                    );
403                }
404            },
405            Ok(WsMessage::Close(_)) => return,
406            Ok(WsMessage::Ping(_)) | Ok(WsMessage::Pong(_)) => {}
407            Ok(msg) => {
408                tracing::debug!(
409                    target: "chromiumoxide::conn::raw_ws::parse_errors",
410                    "Unexpected WS message type: {:?}",
411                    msg
412                );
413            }
414            Err(err) => {
415                // Forward the error once, then exit. The Handler will
416                // observe it on the next `next_message()` call.
417                let _ = tx.send(Err(CdpError::Ws(err))).await;
418                return;
419            }
420        }
421    }
422}
423
424/// Background task that batches and flushes outgoing CDP commands.
425async fn ws_write_loop(
426    mut sink: SplitSink<WebSocketStream<ConnectStream>, WsMessage>,
427    mut rx: mpsc::Receiver<MethodCall>,
428) -> Result<()> {
429    while let Some(call) = rx.recv().await {
430        let msg = crate::serde_json::to_string(&call)?;
431        sink.feed(WsMessage::Text(msg.into()))
432            .await
433            .map_err(CdpError::Ws)?;
434
435        // Batch: drain all buffered commands without waiting.
436        while let Ok(call) = rx.try_recv() {
437            let msg = crate::serde_json::to_string(&call)?;
438            sink.feed(WsMessage::Text(msg.into()))
439                .await
440                .map_err(CdpError::Ws)?;
441        }
442
443        // Flush the entire batch in one write.
444        sink.flush().await.map_err(CdpError::Ws)?;
445    }
446    Ok(())
447}
448
/// Handler-facing read half of the split WebSocket connection.
///
/// Decoded CDP messages are produced by a dedicated background task
/// (see [`ws_read_loop`]) and forwarded over a bounded mpsc. `WsReader`
/// itself is a thin `Receiver` wrapper — calling `next_message()` does
/// a single `rx.recv().await` with no per-message decoding work on the
/// caller's task. This keeps the Handler's poll loop free of CPU-bound
/// deserialize time, which matters for large (multi-MB) CDP responses
/// such as screenshots and wide-header network events.
#[derive(Debug)]
pub struct WsReader<T: EventMessage> {
    // Channel fed by the background `ws_read_loop` task.
    rx: mpsc::Receiver<Result<Box<Message<T>>>>,
    // NOTE(review): `T` already appears in `rx`'s type, so this marker
    // looks redundant — confirm before removing (construction site is
    // `Connection::into_async`).
    _marker: PhantomData<T>,
}
463
impl<T: EventMessage + Unpin> WsReader<T> {
    /// Read the next CDP message from the WebSocket.
    ///
    /// Returns `None` when the background reader task has exited
    /// (connection closed or sender dropped). This call does only a
    /// channel `recv` — the actual WS read + JSON decode happens on
    /// the background `ws_read_loop` task.
    pub async fn next_message(&mut self) -> Option<Result<Box<Message<T>>>> {
        self.rx.recv().await
    }
}
475
impl<T: EventMessage + Unpin> Stream for Connection<T> {
    type Item = Result<Box<Message<T>>>;

    /// Drive the connection: first push/flush queued outgoing commands,
    /// then poll the socket for the next decodable CDP message.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let pin = self.get_mut();

        // Send and flush outgoing messages
        if let Err(err) = pin.start_send_next(cx) {
            return Poll::Ready(Some(Err(err)));
        }

        // Read from the websocket, skipping non-data frames (pings,
        // pongs, malformed messages) without yielding back to the
        // executor.  This avoids a full round-trip per skipped frame.
        //
        // Cap consecutive skips so a flood of non-data frames (many
        // pings, malformed/unexpected types) cannot starve the
        // runtime — yield Pending after `MAX_SKIPS_PER_POLL` and
        // self-wake so we resume on the next tick.
        const MAX_SKIPS_PER_POLL: u32 = 16;
        let mut skips: u32 = 0;
        loop {
            // `ready!` short-circuits with `Poll::Pending` when no frame
            // is available (the WS has already registered the waker).
            match ready!(pin.ws.poll_next_unpin(cx)) {
                Some(Ok(WsMessage::Text(text))) => {
                    match decode_message::<T>(text.as_bytes(), Some(&text)) {
                        Ok(msg) => return Poll::Ready(Some(Ok(msg))),
                        Err(err) => {
                            tracing::debug!(
                                target: "chromiumoxide::conn::raw_ws::parse_errors",
                                "Dropping malformed text WS frame: {err}",
                            );
                            skips += 1;
                        }
                    }
                }
                Some(Ok(WsMessage::Binary(buf))) => match decode_message::<T>(&buf, None) {
                    Ok(msg) => return Poll::Ready(Some(Ok(msg))),
                    Err(err) => {
                        tracing::debug!(
                            target: "chromiumoxide::conn::raw_ws::parse_errors",
                            "Dropping malformed binary WS frame: {err}",
                        );
                        skips += 1;
                    }
                },
                // Peer-initiated close ends the stream.
                Some(Ok(WsMessage::Close(_))) => return Poll::Ready(None),
                Some(Ok(WsMessage::Ping(_))) | Some(Ok(WsMessage::Pong(_))) => {
                    skips += 1;
                }
                Some(Ok(msg)) => {
                    tracing::debug!(
                        target: "chromiumoxide::conn::raw_ws::parse_errors",
                        "Unexpected WS message type: {:?}",
                        msg
                    );
                    skips += 1;
                }
                Some(Err(err)) => return Poll::Ready(Some(Err(CdpError::Ws(err)))),
                None => return Poll::Ready(None),
            }

            if skips >= MAX_SKIPS_PER_POLL {
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }
        }
    }
}
544
545/// Shared decode path for both text and binary WS frames.
546/// `raw_text_for_logging` is only provided for textual frames so we can log the original
547/// payload on parse failure if desired.
548#[cfg(not(feature = "serde_stacker"))]
549fn decode_message<T: EventMessage>(
550    bytes: &[u8],
551    raw_text_for_logging: Option<&str>,
552) -> Result<Box<Message<T>>> {
553    match serde_json::from_slice::<Box<Message<T>>>(bytes) {
554        Ok(msg) => {
555            tracing::trace!("Received {:?}", msg);
556            Ok(msg)
557        }
558        Err(err) => {
559            if let Some(txt) = raw_text_for_logging {
560                let preview = &txt[..txt.len().min(512)];
561                tracing::debug!(
562                    target: "chromiumoxide::conn::raw_ws::parse_errors",
563                    msg_len = txt.len(),
564                    "Skipping unrecognized WS message {err} preview={preview}",
565                );
566            } else {
567                tracing::debug!(
568                    target: "chromiumoxide::conn::raw_ws::parse_errors",
569                    "Skipping unrecognized binary WS message {err}",
570                );
571            }
572            Err(err.into())
573        }
574    }
575}
576
/// Shared decode path for both text and binary WS frames.
/// `raw_text_for_logging` is only provided for textual frames so we can log the original
/// payload on parse failure if desired.
///
/// This variant lifts serde_json's recursion limit and grows the stack on
/// demand via `serde_stacker`, for deeply nested CDP payloads.
#[cfg(feature = "serde_stacker")]
fn decode_message<T: EventMessage>(
    bytes: &[u8],
    raw_text_for_logging: Option<&str>,
) -> Result<Box<Message<T>>> {
    use serde::Deserialize;
    let mut de = serde_json::Deserializer::from_slice(bytes);

    de.disable_recursion_limit();

    let de = serde_stacker::Deserializer::new(&mut de);

    match Box::<Message<T>>::deserialize(de) {
        Ok(msg) => {
            tracing::trace!("Received {:?}", msg);
            Ok(msg)
        }
        Err(err) => {
            if let Some(txt) = raw_text_for_logging {
                // BUGFIX: `&txt[..512]` panics if byte 512 falls inside a
                // multi-byte UTF-8 character. Back the cut point up to the
                // nearest char boundary before slicing (0 is always a
                // boundary, so the loop terminates).
                let mut end = txt.len().min(512);
                while !txt.is_char_boundary(end) {
                    end -= 1;
                }
                let preview = &txt[..end];
                tracing::debug!(
                    target: "chromiumoxide::conn::raw_ws::parse_errors",
                    msg_len = txt.len(),
                    "Skipping unrecognized WS message {err} preview={preview}",
                );
            } else {
                tracing::debug!(
                    target: "chromiumoxide::conn::raw_ws::parse_errors",
                    "Skipping unrecognized binary WS message {err}",
                );
            }
            Err(err.into())
        }
    }
}
615
#[cfg(test)]
mod ws_read_loop_tests {
    //! Unit tests for the `ws_read_loop` background reader task.
    //!
    //! These tests feed a synthetic `Stream<Item = Result<WsMessage, _>>`
    //! into `ws_read_loop` — no real WebSocket, no Chrome — and observe
    //! what comes out the other side of the mpsc channel.
    //!
    //! The properties under test are the ones that make the reader-task
    //! decoupling safe: FIFO ordering, no-deadlock on a bounded channel
    //! under back-pressure, silent drop of non-data frames, graceful
    //! transport-error propagation, and clean exit on `Close`.
    //!
    //! The typed events are `chromiumoxide_cdp::cdp::CdpEventMessage` —
    //! the same instantiation the real Handler uses — so these tests
    //! exercise the actual decode path (`serde_json::from_slice`), not
    //! a simplified fake.
    use super::*;
    use chromiumoxide_cdp::cdp::CdpEventMessage;
    use chromiumoxide_types::CallId;
    use futures_util::stream;
    use tokio::sync::mpsc;
    use tokio_tungstenite::tungstenite::Message as WsMessage;

    /// Build a CDP `Response` WS frame as text — the smallest valid CDP
    /// message. `id` tags the frame for ordering assertions.
    fn response_frame(id: u64) -> WsMessage {
        // `format!` already yields a `String`; the previous `.to_string()`
        // on it was a redundant second allocation.
        WsMessage::Text(format!(r#"{{"id":{id},"result":{{"ok":true}}}}"#).into())
    }

    /// Build a frame far larger than a typical socket chunk, to exercise
    /// the "large message" path that motivated this refactor. The blob
    /// field pushes serde_json through a big allocation even though the
    /// envelope is tiny.
    fn large_response_frame(id: u64, blob_bytes: usize) -> WsMessage {
        let blob = "x".repeat(blob_bytes);
        WsMessage::Text(format!(r#"{{"id":{id},"result":{{"blob":"{blob}"}}}}"#).into())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn forwards_messages_in_stream_order() {
        let frames = vec![
            Ok(response_frame(1)),
            Ok(response_frame(2)),
            Ok(response_frame(3)),
        ];
        let stream = stream::iter(frames);
        let (tx, mut rx) = mpsc::channel::<Result<Box<Message<CdpEventMessage>>>>(8);
        let task = tokio::spawn(ws_read_loop::<CdpEventMessage, _>(stream, tx));

        for expected in [1u64, 2, 3] {
            let msg = rx.recv().await.expect("msg").expect("decode ok");
            if let Message::Response(resp) = *msg {
                assert_eq!(resp.id, CallId::new(expected as usize));
            } else {
                panic!("expected Response");
            }
        }
        assert!(rx.recv().await.is_none(), "channel must close on EOF");
        task.await.expect("reader task join");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn pings_and_pongs_never_reach_the_handler() {
        let frames = vec![
            Ok(WsMessage::Ping(vec![1, 2, 3].into())),
            Ok(response_frame(7)),
            Ok(WsMessage::Pong(vec![].into())),
            Ok(response_frame(8)),
        ];
        let stream = stream::iter(frames);
        let (tx, mut rx) = mpsc::channel::<Result<Box<Message<CdpEventMessage>>>>(8);
        let task = tokio::spawn(ws_read_loop::<CdpEventMessage, _>(stream, tx));

        for expected in [7u64, 8] {
            let msg = rx.recv().await.expect("msg").expect("decode ok");
            if let Message::Response(resp) = *msg {
                assert_eq!(resp.id, CallId::new(expected as usize));
            }
        }
        assert!(rx.recv().await.is_none());
        task.await.expect("reader task join");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn malformed_frames_do_not_block_subsequent_valid_frames() {
        let frames = vec![
            // `&str` converts directly into the Text payload; the previous
            // intermediate `.to_string()` was unnecessary.
            Ok(WsMessage::Text("{not valid json".into())),
            Ok(response_frame(42)),
        ];
        let stream = stream::iter(frames);
        let (tx, mut rx) = mpsc::channel::<Result<Box<Message<CdpEventMessage>>>>(8);
        let task = tokio::spawn(ws_read_loop::<CdpEventMessage, _>(stream, tx));

        let msg = rx.recv().await.expect("msg").expect("decode ok");
        if let Message::Response(resp) = *msg {
            assert_eq!(resp.id, CallId::new(42));
        }
        assert!(rx.recv().await.is_none());
        task.await.expect("reader task join");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn close_frame_terminates_the_reader() {
        let frames = vec![
            Ok(response_frame(1)),
            Ok(WsMessage::Close(None)),
            Ok(response_frame(2)), // unreachable after Close
        ];
        let stream = stream::iter(frames);
        let (tx, mut rx) = mpsc::channel::<Result<Box<Message<CdpEventMessage>>>>(8);
        let task = tokio::spawn(ws_read_loop::<CdpEventMessage, _>(stream, tx));

        let msg = rx.recv().await.expect("msg").expect("decode ok");
        if let Message::Response(resp) = *msg {
            assert_eq!(resp.id, CallId::new(1));
        }
        assert!(
            rx.recv().await.is_none(),
            "reader must exit on Close; frames after Close must not appear"
        );
        task.await.expect("reader task join");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn transport_error_is_forwarded_once_then_reader_exits() {
        let frames = vec![
            Ok(response_frame(1)),
            Err(tokio_tungstenite::tungstenite::Error::ConnectionClosed),
            Ok(response_frame(2)),
        ];
        let stream = stream::iter(frames);
        let (tx, mut rx) = mpsc::channel::<Result<Box<Message<CdpEventMessage>>>>(8);
        let task = tokio::spawn(ws_read_loop::<CdpEventMessage, _>(stream, tx));

        let msg = rx.recv().await.expect("msg").expect("ok");
        assert!(matches!(*msg, Message::Response(_)));
        match rx.recv().await {
            Some(Err(CdpError::Ws(_))) => {}
            other => panic!("expected forwarded Ws error, got {other:?}"),
        }
        assert!(rx.recv().await.is_none());
        task.await.expect("reader task join");
    }

    /// Back-pressure property: with the smallest possible channel and
    /// many frames, the reader task awaits capacity after each send and
    /// never deadlocks. This is the core "no deadlock" proof for the
    /// new design — if the reader held anything across its `.await` that
    /// the consumer needed, the consumer's `recv().await` would block
    /// forever. Completion under a 5s watchdog proves it doesn't.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn bounded_channel_does_not_deadlock_under_backpressure() {
        const N: u64 = 512;
        let frames: Vec<_> = (1..=N).map(|id| Ok(response_frame(id))).collect();
        let stream = stream::iter(frames);

        let (tx, mut rx) = mpsc::channel::<Result<Box<Message<CdpEventMessage>>>>(1);
        let task = tokio::spawn(ws_read_loop::<CdpEventMessage, _>(stream, tx));

        let deadline = std::time::Duration::from_secs(5);
        let collected = tokio::time::timeout(deadline, async {
            let mut seen = 0u64;
            while let Some(frame) = rx.recv().await {
                let msg = frame.expect("decode ok");
                if let Message::Response(resp) = *msg {
                    seen += 1;
                    assert_eq!(
                        resp.id,
                        CallId::new(seen as usize),
                        "back-pressure must preserve FIFO order"
                    );
                }
            }
            seen
        })
        .await
        .expect("reader must make forward progress despite cap-1 back-pressure");

        assert_eq!(collected, N, "all frames must arrive");
        task.await.expect("reader task join");
    }

    /// Large message (>1 MB) is decoded correctly on the background
    /// task. This is the specific scenario the reader-task refactor
    /// was built for — we don't measure time here (benches cover that),
    /// we just prove the end-to-end path works without corruption or
    /// deadlock.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn large_message_decodes_without_corruption() {
        let big = 2 * 1024 * 1024; // 2 MB payload
        let frames = vec![
            Ok(large_response_frame(100, big)),
            Ok(response_frame(101)),
        ];
        let stream = stream::iter(frames);
        let (tx, mut rx) = mpsc::channel::<Result<Box<Message<CdpEventMessage>>>>(4);
        let task = tokio::spawn(ws_read_loop::<CdpEventMessage, _>(stream, tx));

        let first = rx.recv().await.expect("msg").expect("ok");
        if let Message::Response(resp) = *first {
            assert_eq!(resp.id, CallId::new(100));
        }
        let second = rx.recv().await.expect("msg").expect("ok");
        if let Message::Response(resp) = *second {
            assert_eq!(resp.id, CallId::new(101));
        }
        assert!(rx.recv().await.is_none());
        task.await.expect("reader task join");
    }
}