Skip to main content

chromiumoxide/
conn.rs

1use std::collections::VecDeque;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::ready;
5
6use futures_util::stream::{SplitSink, SplitStream};
7use futures_util::{SinkExt, Stream, StreamExt};
8use std::task::{Context, Poll};
9use tokio::sync::mpsc;
10use tokio_tungstenite::tungstenite::Message as WsMessage;
11use tokio_tungstenite::MaybeTlsStream;
12use tokio_tungstenite::{tungstenite::protocol::WebSocketConfig, WebSocketStream};
13
14use chromiumoxide_cdp::cdp::browser_protocol::target::SessionId;
15use chromiumoxide_types::{CallId, EventMessage, Message, MethodCall, MethodId};
16
17use crate::error::CdpError;
18use crate::error::Result;
19
20type ConnectStream = MaybeTlsStream<tokio::net::TcpStream>;
21
22/// Exchanges the messages with the websocket
23#[must_use = "streams do nothing unless polled"]
24#[derive(Debug)]
25pub struct Connection<T: EventMessage> {
26    /// Queue of commands to send.
27    pending_commands: VecDeque<MethodCall>,
28    /// The websocket of the chromium instance
29    ws: WebSocketStream<ConnectStream>,
30    /// The identifier for a specific command
31    next_id: usize,
32    /// Whether the write buffer has unsent data that needs flushing.
33    needs_flush: bool,
34    /// The phantom marker.
35    _marker: PhantomData<T>,
36}
37
38lazy_static::lazy_static! {
39    /// Nagle's algorithm disabled?
40    static ref DISABLE_NAGLE: bool = match std::env::var("DISABLE_NAGLE") {
41        Ok(disable_nagle) => disable_nagle == "true",
42        _ => true
43    };
44    /// Websocket config defaults
45    static ref WEBSOCKET_DEFAULTS: bool = match std::env::var("WEBSOCKET_DEFAULTS") {
46        Ok(d) => d == "true",
47        _ => false
48    };
49}
50
51/// Default number of WebSocket connection retry attempts.
52pub const DEFAULT_CONNECTION_RETRIES: u32 = 4;
53
54/// Initial backoff delay between connection retries (in milliseconds).
55const INITIAL_BACKOFF_MS: u64 = 50;
56
57impl<T: EventMessage + Unpin> Connection<T> {
58    pub async fn connect(debug_ws_url: impl AsRef<str>) -> Result<Self> {
59        Self::connect_with_retries(debug_ws_url, DEFAULT_CONNECTION_RETRIES).await
60    }
61
62    pub async fn connect_with_retries(debug_ws_url: impl AsRef<str>, retries: u32) -> Result<Self> {
63        let mut config = WebSocketConfig::default();
64
65        // Cap the internal write buffer so a slow receiver cannot cause
66        // unbounded memory growth (default is usize::MAX).
67        config.max_write_buffer_size = 4 * 1024 * 1024;
68
69        if !*WEBSOCKET_DEFAULTS {
70            config.max_message_size = None;
71            config.max_frame_size = None;
72        }
73
74        let url = debug_ws_url.as_ref();
75        let use_uring = crate::uring_fs::is_enabled();
76        let mut last_err = None;
77
78        for attempt in 0..=retries {
79            let result = if use_uring {
80                Self::connect_uring(url, config).await
81            } else {
82                Self::connect_default(url, config).await
83            };
84
85            match result {
86                Ok(ws) => {
87                    return Ok(Self {
88                        pending_commands: Default::default(),
89                        ws,
90                        next_id: 0,
91                        needs_flush: false,
92                        _marker: Default::default(),
93                    });
94                }
95                Err(e) => {
96                    last_err = Some(e);
97                    if attempt < retries {
98                        let backoff_ms = INITIAL_BACKOFF_MS * 3u64.saturating_pow(attempt);
99                        tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
100                    }
101                }
102            }
103        }
104
105        Err(last_err.unwrap_or_else(|| CdpError::msg("connection failed")))
106    }
107
108    /// Default path: let tokio-tungstenite handle TCP connect + WS handshake.
109    async fn connect_default(
110        url: &str,
111        config: WebSocketConfig,
112    ) -> Result<WebSocketStream<ConnectStream>> {
113        let (ws, _) =
114            tokio_tungstenite::connect_async_with_config(url, Some(config), *DISABLE_NAGLE).await?;
115        Ok(ws)
116    }
117
118    /// io_uring path: pre-connect the TCP socket via io_uring, then do WS
119    /// handshake over the pre-connected stream.
120    async fn connect_uring(
121        url: &str,
122        config: WebSocketConfig,
123    ) -> Result<WebSocketStream<ConnectStream>> {
124        use tokio_tungstenite::tungstenite::client::IntoClientRequest;
125
126        let request = url.into_client_request()?;
127        let host = request
128            .uri()
129            .host()
130            .ok_or_else(|| CdpError::msg("no host in CDP WebSocket URL"))?;
131        let port = request.uri().port_u16().unwrap_or(9222);
132
133        // Resolve host → SocketAddr (CDP is always localhost, so this is fast).
134        let addr_str = format!("{}:{}", host, port);
135        let addr: std::net::SocketAddr = match addr_str.parse() {
136            Ok(a) => a,
137            Err(_) => {
138                // Hostname needs DNS — fall back to default path.
139                return Self::connect_default(url, config).await;
140            }
141        };
142
143        // TCP connect via io_uring.
144        let std_stream = crate::uring_fs::tcp_connect(addr)
145            .await
146            .map_err(CdpError::Io)?;
147
148        // Set non-blocking + Nagle.
149        std_stream.set_nonblocking(true).map_err(CdpError::Io)?;
150        if *DISABLE_NAGLE {
151            let _ = std_stream.set_nodelay(true);
152        }
153
154        // Wrap in tokio TcpStream.
155        let tokio_stream = tokio::net::TcpStream::from_std(std_stream).map_err(CdpError::Io)?;
156
157        // WebSocket handshake over the pre-connected stream.
158        let (ws, _) = tokio_tungstenite::client_async_with_config(
159            request,
160            MaybeTlsStream::Plain(tokio_stream),
161            Some(config),
162        )
163        .await?;
164
165        Ok(ws)
166    }
167}
168
169impl<T: EventMessage> Connection<T> {
170    fn next_call_id(&mut self) -> CallId {
171        let id = CallId::new(self.next_id);
172        self.next_id = self.next_id.wrapping_add(1);
173        id
174    }
175
176    /// Queue in the command to send over the socket and return the id for this
177    /// command
178    pub fn submit_command(
179        &mut self,
180        method: MethodId,
181        session_id: Option<SessionId>,
182        params: serde_json::Value,
183    ) -> serde_json::Result<CallId> {
184        let id = self.next_call_id();
185        let call = MethodCall {
186            id,
187            method,
188            session_id: session_id.map(Into::into),
189            params,
190        };
191        self.pending_commands.push_back(call);
192        Ok(id)
193    }
194
195    /// Buffer all queued commands into the WebSocket sink, then flush once.
196    ///
197    /// This batches multiple CDP commands into a single TCP write instead of
198    /// flushing after every individual message.
199    fn start_send_next(&mut self, cx: &mut Context<'_>) -> Result<()> {
200        // Complete any pending flush from a previous poll first.
201        if self.needs_flush {
202            match self.ws.poll_flush_unpin(cx) {
203                Poll::Ready(Ok(())) => self.needs_flush = false,
204                Poll::Ready(Err(e)) => return Err(e.into()),
205                Poll::Pending => return Ok(()),
206            }
207        }
208
209        // Buffer as many queued commands as the sink will accept.
210        let mut sent_any = false;
211        while !self.pending_commands.is_empty() {
212            match self.ws.poll_ready_unpin(cx) {
213                Poll::Ready(Ok(())) => {
214                    let Some(cmd) = self.pending_commands.pop_front() else {
215                        break;
216                    };
217                    tracing::trace!("Sending {:?}", cmd);
218                    let msg = serde_json::to_string(&cmd)?;
219                    self.ws.start_send_unpin(msg.into())?;
220                    sent_any = true;
221                }
222                _ => break,
223            }
224        }
225
226        // Flush the entire batch in one write.
227        if sent_any {
228            match self.ws.poll_flush_unpin(cx) {
229                Poll::Ready(Ok(())) => {}
230                Poll::Ready(Err(e)) => return Err(e.into()),
231                Poll::Pending => self.needs_flush = true,
232            }
233        }
234
235        Ok(())
236    }
237}
238
239/// Capacity of the bounded channel feeding the background WS writer task.
240/// Large enough that bursts of CDP commands never block the handler, small
241/// enough to apply back-pressure before memory grows without bound.
242const WS_CMD_CHANNEL_CAPACITY: usize = 2048;
243
244/// Split parts returned by [`Connection::into_async`].
245#[derive(Debug)]
246pub struct AsyncConnection<T: EventMessage> {
247    /// WebSocket read stream — yields decoded CDP messages.
248    pub reader: WsReader<T>,
249    /// Sender half for submitting outgoing CDP commands.
250    pub cmd_tx: mpsc::Sender<MethodCall>,
251    /// Handle to the background writer task.
252    pub writer_handle: tokio::task::JoinHandle<Result<()>>,
253    /// Next command-call-id counter (continue numbering from where Connection left off).
254    pub next_id: usize,
255}
256
257impl<T: EventMessage + Unpin> Connection<T> {
258    /// Consume the connection and split into an async reader + background writer.
259    ///
260    /// The writer task batches outgoing commands: it `recv()`s the first
261    /// command, then drains all immediately-available commands via
262    /// `try_recv()` before flushing the batch to the WebSocket in one
263    /// write.
264    pub fn into_async(self) -> AsyncConnection<T> {
265        let (ws_sink, ws_stream) = self.ws.split();
266        let (cmd_tx, cmd_rx) = mpsc::channel(WS_CMD_CHANNEL_CAPACITY);
267
268        let writer_handle = tokio::spawn(ws_write_loop(ws_sink, cmd_rx));
269
270        let reader = WsReader {
271            inner: ws_stream,
272            _marker: PhantomData,
273        };
274
275        AsyncConnection {
276            reader,
277            cmd_tx,
278            writer_handle,
279            next_id: self.next_id,
280        }
281    }
282}
283
284/// Background task that batches and flushes outgoing CDP commands.
285async fn ws_write_loop(
286    mut sink: SplitSink<WebSocketStream<ConnectStream>, WsMessage>,
287    mut rx: mpsc::Receiver<MethodCall>,
288) -> Result<()> {
289    while let Some(call) = rx.recv().await {
290        let msg = crate::serde_json::to_string(&call)?;
291        sink.feed(WsMessage::Text(msg.into()))
292            .await
293            .map_err(CdpError::Ws)?;
294
295        // Batch: drain all buffered commands without waiting.
296        while let Ok(call) = rx.try_recv() {
297            let msg = crate::serde_json::to_string(&call)?;
298            sink.feed(WsMessage::Text(msg.into()))
299                .await
300                .map_err(CdpError::Ws)?;
301        }
302
303        // Flush the entire batch in one write.
304        sink.flush().await.map_err(CdpError::Ws)?;
305    }
306    Ok(())
307}
308
309/// Read half of a split WebSocket connection.
310///
311/// Decodes incoming WS frames into typed CDP messages, skipping pings/pongs
312/// and malformed data frames.
313#[derive(Debug)]
314pub struct WsReader<T: EventMessage> {
315    inner: SplitStream<WebSocketStream<ConnectStream>>,
316    _marker: PhantomData<T>,
317}
318
319impl<T: EventMessage + Unpin> WsReader<T> {
320    /// Read the next CDP message from the WebSocket.
321    ///
322    /// Returns `None` when the connection is closed.
323    pub async fn next_message(&mut self) -> Option<Result<Box<Message<T>>>> {
324        loop {
325            match self.inner.next().await? {
326                Ok(WsMessage::Text(text)) => {
327                    match decode_message::<T>(text.as_bytes(), Some(&text)) {
328                        Ok(msg) => return Some(Ok(msg)),
329                        Err(err) => {
330                            tracing::debug!(
331                                target: "chromiumoxide::conn::raw_ws::parse_errors",
332                                "Dropping malformed text WS frame: {err}",
333                            );
334                            continue;
335                        }
336                    }
337                }
338                Ok(WsMessage::Binary(buf)) => match decode_message::<T>(&buf, None) {
339                    Ok(msg) => return Some(Ok(msg)),
340                    Err(err) => {
341                        tracing::debug!(
342                            target: "chromiumoxide::conn::raw_ws::parse_errors",
343                            "Dropping malformed binary WS frame: {err}",
344                        );
345                        continue;
346                    }
347                },
348                Ok(WsMessage::Close(_)) => return None,
349                Ok(WsMessage::Ping(_)) | Ok(WsMessage::Pong(_)) => continue,
350                Ok(msg) => {
351                    tracing::debug!(
352                        target: "chromiumoxide::conn::raw_ws::parse_errors",
353                        "Unexpected WS message type: {:?}",
354                        msg
355                    );
356                    continue;
357                }
358                Err(err) => return Some(Err(CdpError::Ws(err))),
359            }
360        }
361    }
362}
363
364impl<T: EventMessage + Unpin> Stream for Connection<T> {
365    type Item = Result<Box<Message<T>>>;
366
367    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
368        let pin = self.get_mut();
369
370        // Send and flush outgoing messages
371        if let Err(err) = pin.start_send_next(cx) {
372            return Poll::Ready(Some(Err(err)));
373        }
374
375        // Read from the websocket, skipping non-data frames (pings,
376        // pongs, malformed messages) without yielding back to the
377        // executor.  This avoids a full round-trip per skipped frame.
378        loop {
379            match ready!(pin.ws.poll_next_unpin(cx)) {
380                Some(Ok(WsMessage::Text(text))) => {
381                    match decode_message::<T>(text.as_bytes(), Some(&text)) {
382                        Ok(msg) => return Poll::Ready(Some(Ok(msg))),
383                        Err(err) => {
384                            tracing::debug!(
385                                target: "chromiumoxide::conn::raw_ws::parse_errors",
386                                "Dropping malformed text WS frame: {err}",
387                            );
388                            continue;
389                        }
390                    }
391                }
392                Some(Ok(WsMessage::Binary(buf))) => match decode_message::<T>(&buf, None) {
393                    Ok(msg) => return Poll::Ready(Some(Ok(msg))),
394                    Err(err) => {
395                        tracing::debug!(
396                            target: "chromiumoxide::conn::raw_ws::parse_errors",
397                            "Dropping malformed binary WS frame: {err}",
398                        );
399                        continue;
400                    }
401                },
402                Some(Ok(WsMessage::Close(_))) => return Poll::Ready(None),
403                // skip ping, pong, and unexpected types without yielding
404                Some(Ok(WsMessage::Ping(_))) | Some(Ok(WsMessage::Pong(_))) => continue,
405                Some(Ok(msg)) => {
406                    tracing::debug!(
407                        target: "chromiumoxide::conn::raw_ws::parse_errors",
408                        "Unexpected WS message type: {:?}",
409                        msg
410                    );
411                    continue;
412                }
413                Some(Err(err)) => return Poll::Ready(Some(Err(CdpError::Ws(err)))),
414                None => return Poll::Ready(None),
415            }
416        }
417    }
418}
419
420/// Shared decode path for both text and binary WS frames.
421/// `raw_text_for_logging` is only provided for textual frames so we can log the original
422/// payload on parse failure if desired.
423#[cfg(not(feature = "serde_stacker"))]
424fn decode_message<T: EventMessage>(
425    bytes: &[u8],
426    raw_text_for_logging: Option<&str>,
427) -> Result<Box<Message<T>>> {
428    match serde_json::from_slice::<Box<Message<T>>>(bytes) {
429        Ok(msg) => {
430            tracing::trace!("Received {:?}", msg);
431            Ok(msg)
432        }
433        Err(err) => {
434            if let Some(txt) = raw_text_for_logging {
435                let preview = &txt[..txt.len().min(512)];
436                tracing::debug!(
437                    target: "chromiumoxide::conn::raw_ws::parse_errors",
438                    msg_len = txt.len(),
439                    "Skipping unrecognized WS message {err} preview={preview}",
440                );
441            } else {
442                tracing::debug!(
443                    target: "chromiumoxide::conn::raw_ws::parse_errors",
444                    "Skipping unrecognized binary WS message {err}",
445                );
446            }
447            Err(err.into())
448        }
449    }
450}
451
/// Shared decode path for both text and binary WS frames.
///
/// `bytes` is the raw frame payload, parsed as JSON into a [`Message`] with
/// the recursion limit disabled and `serde_stacker` wrapping the
/// deserializer. `raw_text_for_logging` is only provided for textual frames
/// so that a preview of the original payload (at most 512 bytes) can be
/// logged when parsing fails.
///
/// # Errors
///
/// Returns the JSON error, converted into the crate error type, when the
/// payload does not parse.
#[cfg(feature = "serde_stacker")]
fn decode_message<T: EventMessage>(
    bytes: &[u8],
    raw_text_for_logging: Option<&str>,
) -> Result<Box<Message<T>>> {
    use serde::Deserialize;
    let mut de = serde_json::Deserializer::from_slice(bytes);

    de.disable_recursion_limit();

    let de = serde_stacker::Deserializer::new(&mut de);

    match Box::<Message<T>>::deserialize(de) {
        Ok(msg) => {
            tracing::trace!("Received {:?}", msg);
            Ok(msg)
        }
        Err(err) => {
            if let Some(txt) = raw_text_for_logging {
                // Cut the preview on a character boundary: slicing a `str` in
                // the middle of a multi-byte character would panic. Index 0
                // is always a boundary, so the loop terminates.
                let mut end = txt.len().min(512);
                while !txt.is_char_boundary(end) {
                    end -= 1;
                }
                let preview = &txt[..end];
                tracing::debug!(
                    target: "chromiumoxide::conn::raw_ws::parse_errors",
                    msg_len = txt.len(),
                    "Skipping unrecognized WS message {err} preview={preview}",
                );
            } else {
                tracing::debug!(
                    target: "chromiumoxide::conn::raw_ws::parse_errors",
                    "Skipping unrecognized binary WS message {err}",
                );
            }
            Err(err.into())
        }
    }
}