Skip to main content

chromiumoxide/
conn.rs

1use std::collections::VecDeque;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::ready;
5
6use futures_util::stream::{SplitSink, SplitStream};
7use futures_util::{SinkExt, Stream, StreamExt};
8use std::task::{Context, Poll};
9use tokio::sync::mpsc;
10use tokio_tungstenite::tungstenite::Message as WsMessage;
11use tokio_tungstenite::MaybeTlsStream;
12use tokio_tungstenite::{tungstenite::protocol::WebSocketConfig, WebSocketStream};
13
14use chromiumoxide_cdp::cdp::browser_protocol::target::SessionId;
15use chromiumoxide_types::{CallId, EventMessage, Message, MethodCall, MethodId};
16
17use crate::error::CdpError;
18use crate::error::Result;
19
20type ConnectStream = MaybeTlsStream<tokio::net::TcpStream>;
21
22/// Exchanges the messages with the websocket
23#[must_use = "streams do nothing unless polled"]
24#[derive(Debug)]
25pub struct Connection<T: EventMessage> {
26    /// Queue of commands to send.
27    pending_commands: VecDeque<MethodCall>,
28    /// The websocket of the chromium instance
29    ws: WebSocketStream<ConnectStream>,
30    /// The identifier for a specific command
31    next_id: usize,
32    /// Whether the write buffer has unsent data that needs flushing.
33    needs_flush: bool,
34    /// The phantom marker.
35    _marker: PhantomData<T>,
36}
37
38lazy_static::lazy_static! {
39    /// Nagle's algorithm disabled?
40    static ref DISABLE_NAGLE: bool = match std::env::var("DISABLE_NAGLE") {
41        Ok(disable_nagle) => disable_nagle == "true",
42        _ => true
43    };
44    /// Websocket config defaults
45    static ref WEBSOCKET_DEFAULTS: bool = match std::env::var("WEBSOCKET_DEFAULTS") {
46        Ok(d) => d == "true",
47        _ => false
48    };
49}
50
/// Retry budget used by `Connection::connect`.
///
/// This counts retries *after* the first attempt, so a connection is tried
/// at most `DEFAULT_CONNECTION_RETRIES + 1` times.
pub const DEFAULT_CONNECTION_RETRIES: u32 = 4;

/// Delay before the first retry, in milliseconds; later delays grow from here.
const INITIAL_BACKOFF_MS: u64 = 50;

/// Ceiling for the delay between two connection attempts, in milliseconds.
const MAX_BACKOFF_MS: u64 = 2_000;
59
60impl<T: EventMessage + Unpin> Connection<T> {
61    pub async fn connect(debug_ws_url: impl AsRef<str>) -> Result<Self> {
62        Self::connect_with_retries(debug_ws_url, DEFAULT_CONNECTION_RETRIES).await
63    }
64
65    pub async fn connect_with_retries(debug_ws_url: impl AsRef<str>, retries: u32) -> Result<Self> {
66        let mut config = WebSocketConfig::default();
67
68        // Cap the internal write buffer so a slow receiver cannot cause
69        // unbounded memory growth (default is usize::MAX).
70        config.max_write_buffer_size = 4 * 1024 * 1024;
71
72        if !*WEBSOCKET_DEFAULTS {
73            config.max_message_size = None;
74            config.max_frame_size = None;
75        }
76
77        let url = debug_ws_url.as_ref();
78        let use_uring = crate::uring_fs::is_enabled();
79        let mut last_err = None;
80
81        for attempt in 0..=retries {
82            let result = if use_uring {
83                Self::connect_uring(url, config).await
84            } else {
85                Self::connect_default(url, config).await
86            };
87
88            match result {
89                Ok(ws) => {
90                    return Ok(Self {
91                        pending_commands: Default::default(),
92                        ws,
93                        next_id: 0,
94                        needs_flush: false,
95                        _marker: Default::default(),
96                    });
97                }
98                Err(e) => {
99                    // Detect non-retriable errors early to avoid wasting time
100                    // on connections that will never succeed.
101                    let should_retry = match &e {
102                        // Connection refused — nothing is listening on this port.
103                        CdpError::Io(io_err)
104                            if io_err.kind() == std::io::ErrorKind::ConnectionRefused =>
105                        {
106                            false
107                        }
108                        // HTTP response to a WebSocket upgrade (e.g. wrong path
109                        // returns 404 / redirect) — retrying the same URL won't help.
110                        CdpError::Ws(tungstenite_err) => !matches!(
111                            tungstenite_err,
112                            tokio_tungstenite::tungstenite::Error::Http(_)
113                                | tokio_tungstenite::tungstenite::Error::HttpFormat(_)
114                        ),
115                        _ => true,
116                    };
117
118                    last_err = Some(e);
119
120                    if !should_retry {
121                        break;
122                    }
123
124                    if attempt < retries {
125                        let backoff_ms =
126                            (INITIAL_BACKOFF_MS * 3u64.saturating_pow(attempt)).min(MAX_BACKOFF_MS);
127                        tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
128                    }
129                }
130            }
131        }
132
133        Err(last_err.unwrap_or_else(|| CdpError::msg("connection failed")))
134    }
135
136    /// Default path: let tokio-tungstenite handle TCP connect + WS handshake.
137    async fn connect_default(
138        url: &str,
139        config: WebSocketConfig,
140    ) -> Result<WebSocketStream<ConnectStream>> {
141        let (ws, _) =
142            tokio_tungstenite::connect_async_with_config(url, Some(config), *DISABLE_NAGLE).await?;
143        Ok(ws)
144    }
145
146    /// io_uring path: pre-connect the TCP socket via io_uring, then do WS
147    /// handshake over the pre-connected stream.
148    async fn connect_uring(
149        url: &str,
150        config: WebSocketConfig,
151    ) -> Result<WebSocketStream<ConnectStream>> {
152        use tokio_tungstenite::tungstenite::client::IntoClientRequest;
153
154        let request = url.into_client_request()?;
155        let host = request
156            .uri()
157            .host()
158            .ok_or_else(|| CdpError::msg("no host in CDP WebSocket URL"))?;
159        let port = request.uri().port_u16().unwrap_or(9222);
160
161        // Resolve host → SocketAddr (CDP is always localhost, so this is fast).
162        let addr_str = format!("{}:{}", host, port);
163        let addr: std::net::SocketAddr = match addr_str.parse() {
164            Ok(a) => a,
165            Err(_) => {
166                // Hostname needs DNS — fall back to default path.
167                return Self::connect_default(url, config).await;
168            }
169        };
170
171        // TCP connect via io_uring.
172        let std_stream = crate::uring_fs::tcp_connect(addr)
173            .await
174            .map_err(CdpError::Io)?;
175
176        // Set non-blocking + Nagle.
177        std_stream.set_nonblocking(true).map_err(CdpError::Io)?;
178        if *DISABLE_NAGLE {
179            let _ = std_stream.set_nodelay(true);
180        }
181
182        // Wrap in tokio TcpStream.
183        let tokio_stream = tokio::net::TcpStream::from_std(std_stream).map_err(CdpError::Io)?;
184
185        // WebSocket handshake over the pre-connected stream.
186        let (ws, _) = tokio_tungstenite::client_async_with_config(
187            request,
188            MaybeTlsStream::Plain(tokio_stream),
189            Some(config),
190        )
191        .await?;
192
193        Ok(ws)
194    }
195}
196
197impl<T: EventMessage> Connection<T> {
198    fn next_call_id(&mut self) -> CallId {
199        let id = CallId::new(self.next_id);
200        self.next_id = self.next_id.wrapping_add(1);
201        id
202    }
203
204    /// Queue in the command to send over the socket and return the id for this
205    /// command
206    pub fn submit_command(
207        &mut self,
208        method: MethodId,
209        session_id: Option<SessionId>,
210        params: serde_json::Value,
211    ) -> serde_json::Result<CallId> {
212        let id = self.next_call_id();
213        let call = MethodCall {
214            id,
215            method,
216            session_id: session_id.map(Into::into),
217            params,
218        };
219        self.pending_commands.push_back(call);
220        Ok(id)
221    }
222
223    /// Buffer all queued commands into the WebSocket sink, then flush once.
224    ///
225    /// This batches multiple CDP commands into a single TCP write instead of
226    /// flushing after every individual message.
227    fn start_send_next(&mut self, cx: &mut Context<'_>) -> Result<()> {
228        // Complete any pending flush from a previous poll first.
229        if self.needs_flush {
230            match self.ws.poll_flush_unpin(cx) {
231                Poll::Ready(Ok(())) => self.needs_flush = false,
232                Poll::Ready(Err(e)) => return Err(e.into()),
233                Poll::Pending => return Ok(()),
234            }
235        }
236
237        // Buffer as many queued commands as the sink will accept.
238        let mut sent_any = false;
239        while !self.pending_commands.is_empty() {
240            match self.ws.poll_ready_unpin(cx) {
241                Poll::Ready(Ok(())) => {
242                    let Some(cmd) = self.pending_commands.pop_front() else {
243                        break;
244                    };
245                    tracing::trace!("Sending {:?}", cmd);
246                    let msg = serde_json::to_string(&cmd)?;
247                    self.ws.start_send_unpin(msg.into())?;
248                    sent_any = true;
249                }
250                _ => break,
251            }
252        }
253
254        // Flush the entire batch in one write.
255        if sent_any {
256            match self.ws.poll_flush_unpin(cx) {
257                Poll::Ready(Ok(())) => {}
258                Poll::Ready(Err(e)) => return Err(e.into()),
259                Poll::Pending => self.needs_flush = true,
260            }
261        }
262
263        Ok(())
264    }
265}
266
/// Number of commands the bounded channel in front of the background WS
/// writer task can hold.
///
/// Sized generously so ordinary bursts of CDP commands fit without waiting,
/// while a full channel still pushes back on senders instead of letting the
/// queue grow without bound.
const WS_CMD_CHANNEL_CAPACITY: usize = 2048;
271
272/// Split parts returned by [`Connection::into_async`].
273#[derive(Debug)]
274pub struct AsyncConnection<T: EventMessage> {
275    /// WebSocket read stream — yields decoded CDP messages.
276    pub reader: WsReader<T>,
277    /// Sender half for submitting outgoing CDP commands.
278    pub cmd_tx: mpsc::Sender<MethodCall>,
279    /// Handle to the background writer task.
280    pub writer_handle: tokio::task::JoinHandle<Result<()>>,
281    /// Next command-call-id counter (continue numbering from where Connection left off).
282    pub next_id: usize,
283}
284
285impl<T: EventMessage + Unpin> Connection<T> {
286    /// Consume the connection and split into an async reader + background writer.
287    ///
288    /// The writer task batches outgoing commands: it `recv()`s the first
289    /// command, then drains all immediately-available commands via
290    /// `try_recv()` before flushing the batch to the WebSocket in one
291    /// write.
292    pub fn into_async(self) -> AsyncConnection<T> {
293        let (ws_sink, ws_stream) = self.ws.split();
294        let (cmd_tx, cmd_rx) = mpsc::channel(WS_CMD_CHANNEL_CAPACITY);
295
296        let writer_handle = tokio::spawn(ws_write_loop(ws_sink, cmd_rx));
297
298        let reader = WsReader {
299            inner: ws_stream,
300            _marker: PhantomData,
301        };
302
303        AsyncConnection {
304            reader,
305            cmd_tx,
306            writer_handle,
307            next_id: self.next_id,
308        }
309    }
310}
311
312/// Background task that batches and flushes outgoing CDP commands.
313async fn ws_write_loop(
314    mut sink: SplitSink<WebSocketStream<ConnectStream>, WsMessage>,
315    mut rx: mpsc::Receiver<MethodCall>,
316) -> Result<()> {
317    while let Some(call) = rx.recv().await {
318        let msg = crate::serde_json::to_string(&call)?;
319        sink.feed(WsMessage::Text(msg.into()))
320            .await
321            .map_err(CdpError::Ws)?;
322
323        // Batch: drain all buffered commands without waiting.
324        while let Ok(call) = rx.try_recv() {
325            let msg = crate::serde_json::to_string(&call)?;
326            sink.feed(WsMessage::Text(msg.into()))
327                .await
328                .map_err(CdpError::Ws)?;
329        }
330
331        // Flush the entire batch in one write.
332        sink.flush().await.map_err(CdpError::Ws)?;
333    }
334    Ok(())
335}
336
337/// Read half of a split WebSocket connection.
338///
339/// Decodes incoming WS frames into typed CDP messages, skipping pings/pongs
340/// and malformed data frames.
341#[derive(Debug)]
342pub struct WsReader<T: EventMessage> {
343    inner: SplitStream<WebSocketStream<ConnectStream>>,
344    _marker: PhantomData<T>,
345}
346
347impl<T: EventMessage + Unpin> WsReader<T> {
348    /// Read the next CDP message from the WebSocket.
349    ///
350    /// Returns `None` when the connection is closed.
351    pub async fn next_message(&mut self) -> Option<Result<Box<Message<T>>>> {
352        loop {
353            match self.inner.next().await? {
354                Ok(WsMessage::Text(text)) => {
355                    match decode_message::<T>(text.as_bytes(), Some(&text)) {
356                        Ok(msg) => return Some(Ok(msg)),
357                        Err(err) => {
358                            tracing::debug!(
359                                target: "chromiumoxide::conn::raw_ws::parse_errors",
360                                "Dropping malformed text WS frame: {err}",
361                            );
362                            continue;
363                        }
364                    }
365                }
366                Ok(WsMessage::Binary(buf)) => match decode_message::<T>(&buf, None) {
367                    Ok(msg) => return Some(Ok(msg)),
368                    Err(err) => {
369                        tracing::debug!(
370                            target: "chromiumoxide::conn::raw_ws::parse_errors",
371                            "Dropping malformed binary WS frame: {err}",
372                        );
373                        continue;
374                    }
375                },
376                Ok(WsMessage::Close(_)) => return None,
377                Ok(WsMessage::Ping(_)) | Ok(WsMessage::Pong(_)) => continue,
378                Ok(msg) => {
379                    tracing::debug!(
380                        target: "chromiumoxide::conn::raw_ws::parse_errors",
381                        "Unexpected WS message type: {:?}",
382                        msg
383                    );
384                    continue;
385                }
386                Err(err) => return Some(Err(CdpError::Ws(err))),
387            }
388        }
389    }
390}
391
392impl<T: EventMessage + Unpin> Stream for Connection<T> {
393    type Item = Result<Box<Message<T>>>;
394
395    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
396        let pin = self.get_mut();
397
398        // Send and flush outgoing messages
399        if let Err(err) = pin.start_send_next(cx) {
400            return Poll::Ready(Some(Err(err)));
401        }
402
403        // Read from the websocket, skipping non-data frames (pings,
404        // pongs, malformed messages) without yielding back to the
405        // executor.  This avoids a full round-trip per skipped frame.
406        loop {
407            match ready!(pin.ws.poll_next_unpin(cx)) {
408                Some(Ok(WsMessage::Text(text))) => {
409                    match decode_message::<T>(text.as_bytes(), Some(&text)) {
410                        Ok(msg) => return Poll::Ready(Some(Ok(msg))),
411                        Err(err) => {
412                            tracing::debug!(
413                                target: "chromiumoxide::conn::raw_ws::parse_errors",
414                                "Dropping malformed text WS frame: {err}",
415                            );
416                            continue;
417                        }
418                    }
419                }
420                Some(Ok(WsMessage::Binary(buf))) => match decode_message::<T>(&buf, None) {
421                    Ok(msg) => return Poll::Ready(Some(Ok(msg))),
422                    Err(err) => {
423                        tracing::debug!(
424                            target: "chromiumoxide::conn::raw_ws::parse_errors",
425                            "Dropping malformed binary WS frame: {err}",
426                        );
427                        continue;
428                    }
429                },
430                Some(Ok(WsMessage::Close(_))) => return Poll::Ready(None),
431                // skip ping, pong, and unexpected types without yielding
432                Some(Ok(WsMessage::Ping(_))) | Some(Ok(WsMessage::Pong(_))) => continue,
433                Some(Ok(msg)) => {
434                    tracing::debug!(
435                        target: "chromiumoxide::conn::raw_ws::parse_errors",
436                        "Unexpected WS message type: {:?}",
437                        msg
438                    );
439                    continue;
440                }
441                Some(Err(err)) => return Poll::Ready(Some(Err(CdpError::Ws(err)))),
442                None => return Poll::Ready(None),
443            }
444        }
445    }
446}
447
448/// Shared decode path for both text and binary WS frames.
449/// `raw_text_for_logging` is only provided for textual frames so we can log the original
450/// payload on parse failure if desired.
451#[cfg(not(feature = "serde_stacker"))]
452fn decode_message<T: EventMessage>(
453    bytes: &[u8],
454    raw_text_for_logging: Option<&str>,
455) -> Result<Box<Message<T>>> {
456    match serde_json::from_slice::<Box<Message<T>>>(bytes) {
457        Ok(msg) => {
458            tracing::trace!("Received {:?}", msg);
459            Ok(msg)
460        }
461        Err(err) => {
462            if let Some(txt) = raw_text_for_logging {
463                let preview = &txt[..txt.len().min(512)];
464                tracing::debug!(
465                    target: "chromiumoxide::conn::raw_ws::parse_errors",
466                    msg_len = txt.len(),
467                    "Skipping unrecognized WS message {err} preview={preview}",
468                );
469            } else {
470                tracing::debug!(
471                    target: "chromiumoxide::conn::raw_ws::parse_errors",
472                    "Skipping unrecognized binary WS message {err}",
473                );
474            }
475            Err(err.into())
476        }
477    }
478}
479
480/// Shared decode path for both text and binary WS frames.
481/// `raw_text_for_logging` is only provided for textual frames so we can log the original
482/// payload on parse failure if desired.
483#[cfg(feature = "serde_stacker")]
484fn decode_message<T: EventMessage>(
485    bytes: &[u8],
486    raw_text_for_logging: Option<&str>,
487) -> Result<Box<Message<T>>> {
488    use serde::Deserialize;
489    let mut de = serde_json::Deserializer::from_slice(bytes);
490
491    de.disable_recursion_limit();
492
493    let de = serde_stacker::Deserializer::new(&mut de);
494
495    match Box::<Message<T>>::deserialize(de) {
496        Ok(msg) => {
497            tracing::trace!("Received {:?}", msg);
498            Ok(msg)
499        }
500        Err(err) => {
501            if let Some(txt) = raw_text_for_logging {
502                let preview = &txt[..txt.len().min(512)];
503                tracing::debug!(
504                    target: "chromiumoxide::conn::raw_ws::parse_errors",
505                    msg_len = txt.len(),
506                    "Skipping unrecognized WS message {err} preview={preview}",
507                );
508            } else {
509                tracing::debug!(
510                    target: "chromiumoxide::conn::raw_ws::parse_errors",
511                    "Skipping unrecognized binary WS message {err}",
512                );
513            }
514            Err(err.into())
515        }
516    }
517}