Skip to main content

chromiumoxide/
conn.rs

1use std::collections::VecDeque;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::ready;
5
6use futures_util::stream::{SplitSink, SplitStream};
7use futures_util::{SinkExt, Stream, StreamExt};
8use std::task::{Context, Poll};
9use tokio::sync::mpsc;
10use tokio_tungstenite::tungstenite::Message as WsMessage;
11use tokio_tungstenite::MaybeTlsStream;
12use tokio_tungstenite::{tungstenite::protocol::WebSocketConfig, WebSocketStream};
13
14use chromiumoxide_cdp::cdp::browser_protocol::target::SessionId;
15use chromiumoxide_types::{CallId, EventMessage, Message, MethodCall, MethodId};
16
17use crate::error::CdpError;
18use crate::error::Result;
19
/// Transport type underneath the WebSocket: a tokio `TcpStream` wrapped in
/// tokio-tungstenite's `MaybeTlsStream`.
type ConnectStream = MaybeTlsStream<tokio::net::TcpStream>;
21
/// Exchanges the messages with the websocket.
///
/// Outgoing commands are queued with `submit_command` and written out when
/// the connection is polled through its `Stream` implementation; the same
/// implementation decodes incoming frames into `Message<T>` values.
#[must_use = "streams do nothing unless polled"]
#[derive(Debug)]
pub struct Connection<T: EventMessage> {
    /// Queue of commands that were submitted but not yet handed to the
    /// websocket sink.
    pending_commands: VecDeque<MethodCall>,
    /// The websocket of the chromium instance.
    ws: WebSocketStream<ConnectStream>,
    /// Counter used to derive the `CallId` of the next submitted command.
    next_id: usize,
    /// Whether the write buffer has unsent data that needs flushing.
    needs_flush: bool,
    /// Ties the otherwise unused event type `T` to the connection.
    _marker: PhantomData<T>,
}
37
38lazy_static::lazy_static! {
39    /// Nagle's algorithm disabled?
40    static ref DISABLE_NAGLE: bool = match std::env::var("DISABLE_NAGLE") {
41        Ok(disable_nagle) => disable_nagle == "true",
42        _ => true
43    };
44    /// Websocket config defaults
45    static ref WEBSOCKET_DEFAULTS: bool = match std::env::var("WEBSOCKET_DEFAULTS") {
46        Ok(d) => d == "true",
47        _ => false
48    };
49}
50
/// Default number of WebSocket connection retry attempts.
///
/// `Connection::connect` passes this value to
/// `Connection::connect_with_retries`, which makes one initial attempt plus
/// up to this many retries.
pub const DEFAULT_CONNECTION_RETRIES: u32 = 4;

/// Initial backoff delay between connection retries (in milliseconds).
///
/// The delay is multiplied by a power of three for each further attempt.
const INITIAL_BACKOFF_MS: u64 = 50;

/// Maximum backoff delay between connection retries (in milliseconds).
///
/// The computed delay is clamped to this upper bound.
const MAX_BACKOFF_MS: u64 = 2_000;
59
60impl<T: EventMessage + Unpin> Connection<T> {
61    pub async fn connect(debug_ws_url: impl AsRef<str>) -> Result<Self> {
62        Self::connect_with_retries(debug_ws_url, DEFAULT_CONNECTION_RETRIES).await
63    }
64
65    pub async fn connect_with_retries(debug_ws_url: impl AsRef<str>, retries: u32) -> Result<Self> {
66        let mut config = WebSocketConfig::default();
67
68        // Cap the internal write buffer so a slow receiver cannot cause
69        // unbounded memory growth (default is usize::MAX).
70        config.max_write_buffer_size = 4 * 1024 * 1024;
71
72        if !*WEBSOCKET_DEFAULTS {
73            config.max_message_size = None;
74            config.max_frame_size = None;
75        }
76
77        let url = debug_ws_url.as_ref();
78        let use_uring = crate::uring_fs::is_enabled();
79        let mut last_err = None;
80
81        for attempt in 0..=retries {
82            let result = if use_uring {
83                Self::connect_uring(url, config).await
84            } else {
85                Self::connect_default(url, config).await
86            };
87
88            match result {
89                Ok(ws) => {
90                    return Ok(Self {
91                        pending_commands: Default::default(),
92                        ws,
93                        next_id: 0,
94                        needs_flush: false,
95                        _marker: Default::default(),
96                    });
97                }
98                Err(e) => {
99                    // Detect non-retriable errors early to avoid wasting time
100                    // on connections that will never succeed.
101                    let should_retry = match &e {
102                        // Connection refused — nothing is listening on this port.
103                        CdpError::Io(io_err)
104                            if io_err.kind() == std::io::ErrorKind::ConnectionRefused =>
105                        {
106                            false
107                        }
108                        // HTTP response to a WebSocket upgrade (e.g. wrong path
109                        // returns 404 / redirect) — retrying the same URL won't help.
110                        CdpError::Ws(tungstenite_err) => {
111                            !matches!(
112                                tungstenite_err,
113                                tokio_tungstenite::tungstenite::Error::Http(_)
114                                    | tokio_tungstenite::tungstenite::Error::HttpFormat(_)
115                            )
116                        }
117                        _ => true,
118                    };
119
120                    last_err = Some(e);
121
122                    if !should_retry {
123                        break;
124                    }
125
126                    if attempt < retries {
127                        let backoff_ms = (INITIAL_BACKOFF_MS * 3u64.saturating_pow(attempt))
128                            .min(MAX_BACKOFF_MS);
129                        tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
130                    }
131                }
132            }
133        }
134
135        Err(last_err.unwrap_or_else(|| CdpError::msg("connection failed")))
136    }
137
138    /// Default path: let tokio-tungstenite handle TCP connect + WS handshake.
139    async fn connect_default(
140        url: &str,
141        config: WebSocketConfig,
142    ) -> Result<WebSocketStream<ConnectStream>> {
143        let (ws, _) =
144            tokio_tungstenite::connect_async_with_config(url, Some(config), *DISABLE_NAGLE).await?;
145        Ok(ws)
146    }
147
148    /// io_uring path: pre-connect the TCP socket via io_uring, then do WS
149    /// handshake over the pre-connected stream.
150    async fn connect_uring(
151        url: &str,
152        config: WebSocketConfig,
153    ) -> Result<WebSocketStream<ConnectStream>> {
154        use tokio_tungstenite::tungstenite::client::IntoClientRequest;
155
156        let request = url.into_client_request()?;
157        let host = request
158            .uri()
159            .host()
160            .ok_or_else(|| CdpError::msg("no host in CDP WebSocket URL"))?;
161        let port = request.uri().port_u16().unwrap_or(9222);
162
163        // Resolve host → SocketAddr (CDP is always localhost, so this is fast).
164        let addr_str = format!("{}:{}", host, port);
165        let addr: std::net::SocketAddr = match addr_str.parse() {
166            Ok(a) => a,
167            Err(_) => {
168                // Hostname needs DNS — fall back to default path.
169                return Self::connect_default(url, config).await;
170            }
171        };
172
173        // TCP connect via io_uring.
174        let std_stream = crate::uring_fs::tcp_connect(addr)
175            .await
176            .map_err(CdpError::Io)?;
177
178        // Set non-blocking + Nagle.
179        std_stream.set_nonblocking(true).map_err(CdpError::Io)?;
180        if *DISABLE_NAGLE {
181            let _ = std_stream.set_nodelay(true);
182        }
183
184        // Wrap in tokio TcpStream.
185        let tokio_stream = tokio::net::TcpStream::from_std(std_stream).map_err(CdpError::Io)?;
186
187        // WebSocket handshake over the pre-connected stream.
188        let (ws, _) = tokio_tungstenite::client_async_with_config(
189            request,
190            MaybeTlsStream::Plain(tokio_stream),
191            Some(config),
192        )
193        .await?;
194
195        Ok(ws)
196    }
197}
198
199impl<T: EventMessage> Connection<T> {
200    fn next_call_id(&mut self) -> CallId {
201        let id = CallId::new(self.next_id);
202        self.next_id = self.next_id.wrapping_add(1);
203        id
204    }
205
206    /// Queue in the command to send over the socket and return the id for this
207    /// command
208    pub fn submit_command(
209        &mut self,
210        method: MethodId,
211        session_id: Option<SessionId>,
212        params: serde_json::Value,
213    ) -> serde_json::Result<CallId> {
214        let id = self.next_call_id();
215        let call = MethodCall {
216            id,
217            method,
218            session_id: session_id.map(Into::into),
219            params,
220        };
221        self.pending_commands.push_back(call);
222        Ok(id)
223    }
224
225    /// Buffer all queued commands into the WebSocket sink, then flush once.
226    ///
227    /// This batches multiple CDP commands into a single TCP write instead of
228    /// flushing after every individual message.
229    fn start_send_next(&mut self, cx: &mut Context<'_>) -> Result<()> {
230        // Complete any pending flush from a previous poll first.
231        if self.needs_flush {
232            match self.ws.poll_flush_unpin(cx) {
233                Poll::Ready(Ok(())) => self.needs_flush = false,
234                Poll::Ready(Err(e)) => return Err(e.into()),
235                Poll::Pending => return Ok(()),
236            }
237        }
238
239        // Buffer as many queued commands as the sink will accept.
240        let mut sent_any = false;
241        while !self.pending_commands.is_empty() {
242            match self.ws.poll_ready_unpin(cx) {
243                Poll::Ready(Ok(())) => {
244                    let Some(cmd) = self.pending_commands.pop_front() else {
245                        break;
246                    };
247                    tracing::trace!("Sending {:?}", cmd);
248                    let msg = serde_json::to_string(&cmd)?;
249                    self.ws.start_send_unpin(msg.into())?;
250                    sent_any = true;
251                }
252                _ => break,
253            }
254        }
255
256        // Flush the entire batch in one write.
257        if sent_any {
258            match self.ws.poll_flush_unpin(cx) {
259                Poll::Ready(Ok(())) => {}
260                Poll::Ready(Err(e)) => return Err(e.into()),
261                Poll::Pending => self.needs_flush = true,
262            }
263        }
264
265        Ok(())
266    }
267}
268
/// Capacity of the bounded channel feeding the background WS writer task.
/// Large enough that bursts of CDP commands never block the handler, small
/// enough to apply back-pressure before memory grows without bound.
///
/// Used by `Connection::into_async` when creating the command channel.
const WS_CMD_CHANNEL_CAPACITY: usize = 2048;
273
/// Split parts returned by [`Connection::into_async`].
///
/// Reading and writing are decoupled: `reader` is driven by the caller while
/// a spawned task (see `writer_handle`) writes whatever arrives on `cmd_tx`.
#[derive(Debug)]
pub struct AsyncConnection<T: EventMessage> {
    /// WebSocket read stream — yields decoded CDP messages.
    pub reader: WsReader<T>,
    /// Sender half for submitting outgoing CDP commands.
    pub cmd_tx: mpsc::Sender<MethodCall>,
    /// Handle to the background writer task; resolves to the task's result.
    pub writer_handle: tokio::task::JoinHandle<Result<()>>,
    /// Next command-call-id counter (continue numbering from where Connection left off).
    pub next_id: usize,
}
286
287impl<T: EventMessage + Unpin> Connection<T> {
288    /// Consume the connection and split into an async reader + background writer.
289    ///
290    /// The writer task batches outgoing commands: it `recv()`s the first
291    /// command, then drains all immediately-available commands via
292    /// `try_recv()` before flushing the batch to the WebSocket in one
293    /// write.
294    pub fn into_async(self) -> AsyncConnection<T> {
295        let (ws_sink, ws_stream) = self.ws.split();
296        let (cmd_tx, cmd_rx) = mpsc::channel(WS_CMD_CHANNEL_CAPACITY);
297
298        let writer_handle = tokio::spawn(ws_write_loop(ws_sink, cmd_rx));
299
300        let reader = WsReader {
301            inner: ws_stream,
302            _marker: PhantomData,
303        };
304
305        AsyncConnection {
306            reader,
307            cmd_tx,
308            writer_handle,
309            next_id: self.next_id,
310        }
311    }
312}
313
314/// Background task that batches and flushes outgoing CDP commands.
315async fn ws_write_loop(
316    mut sink: SplitSink<WebSocketStream<ConnectStream>, WsMessage>,
317    mut rx: mpsc::Receiver<MethodCall>,
318) -> Result<()> {
319    while let Some(call) = rx.recv().await {
320        let msg = crate::serde_json::to_string(&call)?;
321        sink.feed(WsMessage::Text(msg.into()))
322            .await
323            .map_err(CdpError::Ws)?;
324
325        // Batch: drain all buffered commands without waiting.
326        while let Ok(call) = rx.try_recv() {
327            let msg = crate::serde_json::to_string(&call)?;
328            sink.feed(WsMessage::Text(msg.into()))
329                .await
330                .map_err(CdpError::Ws)?;
331        }
332
333        // Flush the entire batch in one write.
334        sink.flush().await.map_err(CdpError::Ws)?;
335    }
336    Ok(())
337}
338
/// Read half of a split WebSocket connection.
///
/// Decodes incoming WS frames into typed CDP messages, skipping pings/pongs
/// and malformed data frames.
#[derive(Debug)]
pub struct WsReader<T: EventMessage> {
    /// Receiving half of the split WebSocket stream.
    inner: SplitStream<WebSocketStream<ConnectStream>>,
    /// Ties the event type `T` used for decoding to the reader.
    _marker: PhantomData<T>,
}
348
349impl<T: EventMessage + Unpin> WsReader<T> {
350    /// Read the next CDP message from the WebSocket.
351    ///
352    /// Returns `None` when the connection is closed.
353    pub async fn next_message(&mut self) -> Option<Result<Box<Message<T>>>> {
354        loop {
355            match self.inner.next().await? {
356                Ok(WsMessage::Text(text)) => {
357                    match decode_message::<T>(text.as_bytes(), Some(&text)) {
358                        Ok(msg) => return Some(Ok(msg)),
359                        Err(err) => {
360                            tracing::debug!(
361                                target: "chromiumoxide::conn::raw_ws::parse_errors",
362                                "Dropping malformed text WS frame: {err}",
363                            );
364                            continue;
365                        }
366                    }
367                }
368                Ok(WsMessage::Binary(buf)) => match decode_message::<T>(&buf, None) {
369                    Ok(msg) => return Some(Ok(msg)),
370                    Err(err) => {
371                        tracing::debug!(
372                            target: "chromiumoxide::conn::raw_ws::parse_errors",
373                            "Dropping malformed binary WS frame: {err}",
374                        );
375                        continue;
376                    }
377                },
378                Ok(WsMessage::Close(_)) => return None,
379                Ok(WsMessage::Ping(_)) | Ok(WsMessage::Pong(_)) => continue,
380                Ok(msg) => {
381                    tracing::debug!(
382                        target: "chromiumoxide::conn::raw_ws::parse_errors",
383                        "Unexpected WS message type: {:?}",
384                        msg
385                    );
386                    continue;
387                }
388                Err(err) => return Some(Err(CdpError::Ws(err))),
389            }
390        }
391    }
392}
393
394impl<T: EventMessage + Unpin> Stream for Connection<T> {
395    type Item = Result<Box<Message<T>>>;
396
397    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
398        let pin = self.get_mut();
399
400        // Send and flush outgoing messages
401        if let Err(err) = pin.start_send_next(cx) {
402            return Poll::Ready(Some(Err(err)));
403        }
404
405        // Read from the websocket, skipping non-data frames (pings,
406        // pongs, malformed messages) without yielding back to the
407        // executor.  This avoids a full round-trip per skipped frame.
408        loop {
409            match ready!(pin.ws.poll_next_unpin(cx)) {
410                Some(Ok(WsMessage::Text(text))) => {
411                    match decode_message::<T>(text.as_bytes(), Some(&text)) {
412                        Ok(msg) => return Poll::Ready(Some(Ok(msg))),
413                        Err(err) => {
414                            tracing::debug!(
415                                target: "chromiumoxide::conn::raw_ws::parse_errors",
416                                "Dropping malformed text WS frame: {err}",
417                            );
418                            continue;
419                        }
420                    }
421                }
422                Some(Ok(WsMessage::Binary(buf))) => match decode_message::<T>(&buf, None) {
423                    Ok(msg) => return Poll::Ready(Some(Ok(msg))),
424                    Err(err) => {
425                        tracing::debug!(
426                            target: "chromiumoxide::conn::raw_ws::parse_errors",
427                            "Dropping malformed binary WS frame: {err}",
428                        );
429                        continue;
430                    }
431                },
432                Some(Ok(WsMessage::Close(_))) => return Poll::Ready(None),
433                // skip ping, pong, and unexpected types without yielding
434                Some(Ok(WsMessage::Ping(_))) | Some(Ok(WsMessage::Pong(_))) => continue,
435                Some(Ok(msg)) => {
436                    tracing::debug!(
437                        target: "chromiumoxide::conn::raw_ws::parse_errors",
438                        "Unexpected WS message type: {:?}",
439                        msg
440                    );
441                    continue;
442                }
443                Some(Err(err)) => return Poll::Ready(Some(Err(CdpError::Ws(err)))),
444                None => return Poll::Ready(None),
445            }
446        }
447    }
448}
449
450/// Shared decode path for both text and binary WS frames.
451/// `raw_text_for_logging` is only provided for textual frames so we can log the original
452/// payload on parse failure if desired.
453#[cfg(not(feature = "serde_stacker"))]
454fn decode_message<T: EventMessage>(
455    bytes: &[u8],
456    raw_text_for_logging: Option<&str>,
457) -> Result<Box<Message<T>>> {
458    match serde_json::from_slice::<Box<Message<T>>>(bytes) {
459        Ok(msg) => {
460            tracing::trace!("Received {:?}", msg);
461            Ok(msg)
462        }
463        Err(err) => {
464            if let Some(txt) = raw_text_for_logging {
465                let preview = &txt[..txt.len().min(512)];
466                tracing::debug!(
467                    target: "chromiumoxide::conn::raw_ws::parse_errors",
468                    msg_len = txt.len(),
469                    "Skipping unrecognized WS message {err} preview={preview}",
470                );
471            } else {
472                tracing::debug!(
473                    target: "chromiumoxide::conn::raw_ws::parse_errors",
474                    "Skipping unrecognized binary WS message {err}",
475                );
476            }
477            Err(err.into())
478        }
479    }
480}
481
/// Shared decode path for both text and binary WS frames.
///
/// Deserializes `bytes` as JSON into a boxed `Message<T>` with serde_json's
/// recursion limit disabled and the deserializer wrapped in
/// `serde_stacker`. `raw_text_for_logging` is only provided for textual
/// frames so the debug log emitted on a parse failure can include a preview
/// (at most 512 bytes) of the original payload.
///
/// # Errors
///
/// Returns the deserialization error, converted into the crate error type,
/// when `bytes` is not a valid message.
#[cfg(feature = "serde_stacker")]
fn decode_message<T: EventMessage>(
    bytes: &[u8],
    raw_text_for_logging: Option<&str>,
) -> Result<Box<Message<T>>> {
    use serde::Deserialize;
    let mut de = serde_json::Deserializer::from_slice(bytes);

    de.disable_recursion_limit();

    let de = serde_stacker::Deserializer::new(&mut de);

    match Box::<Message<T>>::deserialize(de) {
        Ok(msg) => {
            tracing::trace!("Received {:?}", msg);
            Ok(msg)
        }
        Err(err) => {
            if let Some(txt) = raw_text_for_logging {
                // Cut the preview on a char boundary: slicing a `str` in the
                // middle of a multi-byte character would panic.
                let mut preview_end = txt.len().min(512);
                while !txt.is_char_boundary(preview_end) {
                    preview_end -= 1;
                }
                let preview = &txt[..preview_end];
                tracing::debug!(
                    target: "chromiumoxide::conn::raw_ws::parse_errors",
                    msg_len = txt.len(),
                    "Skipping unrecognized WS message {err} preview={preview}",
                );
            } else {
                tracing::debug!(
                    target: "chromiumoxide::conn::raw_ws::parse_errors",
                    "Skipping unrecognized binary WS message {err}",
                );
            }
            Err(err.into())
        }
    }
}