Skip to main content

chromiumoxide/
conn.rs

1use std::collections::VecDeque;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::ready;
5
6use futures_util::{SinkExt, Stream, StreamExt};
7use std::task::{Context, Poll};
8use tokio_tungstenite::tungstenite::Message as WsMessage;
9use tokio_tungstenite::MaybeTlsStream;
10use tokio_tungstenite::{tungstenite::protocol::WebSocketConfig, WebSocketStream};
11
12use chromiumoxide_cdp::cdp::browser_protocol::target::SessionId;
13use chromiumoxide_types::{CallId, EventMessage, Message, MethodCall, MethodId};
14
15use crate::error::CdpError;
16use crate::error::Result;
17
/// Byte stream underlying the CDP websocket: a tokio TCP stream that may be
/// plain or TLS-wrapped (see `MaybeTlsStream`).
type ConnectStream = MaybeTlsStream<tokio::net::TcpStream>;
19
/// Exchanges CDP messages with the browser over a single websocket.
///
/// Commands are queued by `submit_command` and written to the socket while the
/// connection is polled as a [`Stream`]; the stream yields the decoded incoming
/// messages.
#[must_use = "streams do nothing unless polled"]
#[derive(Debug)]
pub struct Connection<T: EventMessage> {
    /// Queue of commands waiting to be serialized and written to the socket.
    pending_commands: VecDeque<MethodCall>,
    /// The websocket connected to the chromium instance.
    ws: WebSocketStream<ConnectStream>,
    /// The id assigned to the next submitted command (wraps on overflow).
    next_id: usize,
    /// Whether the sink holds data from an earlier poll whose flush returned
    /// `Pending` and must be completed before more commands are buffered.
    needs_flush: bool,
    /// Ties the connection to the event type `T` without storing a value of it.
    _marker: PhantomData<T>,
}
35
36lazy_static::lazy_static! {
37    /// Nagle's algorithm disabled?
38    static ref DISABLE_NAGLE: bool = match std::env::var("DISABLE_NAGLE") {
39        Ok(disable_nagle) => disable_nagle == "true",
40        _ => true
41    };
42    /// Websocket config defaults
43    static ref WEBSOCKET_DEFAULTS: bool = match std::env::var("WEBSOCKET_DEFAULTS") {
44        Ok(d) => d == "true",
45        _ => false
46    };
47}
48
/// Default number of WebSocket connection retry attempts.
///
/// This counts retries after the initial attempt, so the default allows up to
/// five connection attempts in total.
pub const DEFAULT_CONNECTION_RETRIES: u32 = 4;

/// Initial backoff delay between connection retries (in milliseconds).
///
/// The delay before retry `n` (0-based) is this value multiplied by `3^n`.
const INITIAL_BACKOFF_MS: u64 = 50;
54
55impl<T: EventMessage + Unpin> Connection<T> {
56    pub async fn connect(debug_ws_url: impl AsRef<str>) -> Result<Self> {
57        Self::connect_with_retries(debug_ws_url, DEFAULT_CONNECTION_RETRIES).await
58    }
59
60    pub async fn connect_with_retries(debug_ws_url: impl AsRef<str>, retries: u32) -> Result<Self> {
61        let mut config = WebSocketConfig::default();
62
63        if !*WEBSOCKET_DEFAULTS {
64            config.max_message_size = None;
65            config.max_frame_size = None;
66        }
67
68        let url = debug_ws_url.as_ref();
69        let use_uring = crate::uring_fs::is_enabled();
70        let mut last_err = None;
71
72        for attempt in 0..=retries {
73            let result = if use_uring {
74                Self::connect_uring(url, config).await
75            } else {
76                Self::connect_default(url, config).await
77            };
78
79            match result {
80                Ok(ws) => {
81                    return Ok(Self {
82                        pending_commands: Default::default(),
83                        ws,
84                        next_id: 0,
85                        needs_flush: false,
86                        _marker: Default::default(),
87                    });
88                }
89                Err(e) => {
90                    last_err = Some(e);
91                    if attempt < retries {
92                        let backoff_ms = INITIAL_BACKOFF_MS * 3u64.saturating_pow(attempt);
93                        tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
94                    }
95                }
96            }
97        }
98
99        Err(last_err.unwrap_or_else(|| CdpError::msg("connection failed")))
100    }
101
102    /// Default path: let tokio-tungstenite handle TCP connect + WS handshake.
103    async fn connect_default(
104        url: &str,
105        config: WebSocketConfig,
106    ) -> Result<WebSocketStream<ConnectStream>> {
107        let (ws, _) =
108            tokio_tungstenite::connect_async_with_config(url, Some(config), *DISABLE_NAGLE).await?;
109        Ok(ws)
110    }
111
112    /// io_uring path: pre-connect the TCP socket via io_uring, then do WS
113    /// handshake over the pre-connected stream.
114    async fn connect_uring(
115        url: &str,
116        config: WebSocketConfig,
117    ) -> Result<WebSocketStream<ConnectStream>> {
118        use tokio_tungstenite::tungstenite::client::IntoClientRequest;
119
120        let request = url.into_client_request()?;
121        let host = request
122            .uri()
123            .host()
124            .ok_or_else(|| CdpError::msg("no host in CDP WebSocket URL"))?;
125        let port = request.uri().port_u16().unwrap_or(9222);
126
127        // Resolve host → SocketAddr (CDP is always localhost, so this is fast).
128        let addr_str = format!("{}:{}", host, port);
129        let addr: std::net::SocketAddr = match addr_str.parse() {
130            Ok(a) => a,
131            Err(_) => {
132                // Hostname needs DNS — fall back to default path.
133                return Self::connect_default(url, config).await;
134            }
135        };
136
137        // TCP connect via io_uring.
138        let std_stream = crate::uring_fs::tcp_connect(addr)
139            .await
140            .map_err(CdpError::Io)?;
141
142        // Set non-blocking + Nagle.
143        std_stream.set_nonblocking(true).map_err(CdpError::Io)?;
144        if *DISABLE_NAGLE {
145            let _ = std_stream.set_nodelay(true);
146        }
147
148        // Wrap in tokio TcpStream.
149        let tokio_stream = tokio::net::TcpStream::from_std(std_stream).map_err(CdpError::Io)?;
150
151        // WebSocket handshake over the pre-connected stream.
152        let (ws, _) = tokio_tungstenite::client_async_with_config(
153            request,
154            MaybeTlsStream::Plain(tokio_stream),
155            Some(config),
156        )
157        .await?;
158
159        Ok(ws)
160    }
161}
162
163impl<T: EventMessage> Connection<T> {
164    fn next_call_id(&mut self) -> CallId {
165        let id = CallId::new(self.next_id);
166        self.next_id = self.next_id.wrapping_add(1);
167        id
168    }
169
170    /// Queue in the command to send over the socket and return the id for this
171    /// command
172    pub fn submit_command(
173        &mut self,
174        method: MethodId,
175        session_id: Option<SessionId>,
176        params: serde_json::Value,
177    ) -> serde_json::Result<CallId> {
178        let id = self.next_call_id();
179        let call = MethodCall {
180            id,
181            method,
182            session_id: session_id.map(Into::into),
183            params,
184        };
185        self.pending_commands.push_back(call);
186        Ok(id)
187    }
188
189    /// Buffer all queued commands into the WebSocket sink, then flush once.
190    ///
191    /// This batches multiple CDP commands into a single TCP write instead of
192    /// flushing after every individual message.
193    fn start_send_next(&mut self, cx: &mut Context<'_>) -> Result<()> {
194        // Complete any pending flush from a previous poll first.
195        if self.needs_flush {
196            match self.ws.poll_flush_unpin(cx) {
197                Poll::Ready(Ok(())) => self.needs_flush = false,
198                Poll::Ready(Err(e)) => return Err(e.into()),
199                Poll::Pending => return Ok(()),
200            }
201        }
202
203        // Buffer as many queued commands as the sink will accept.
204        let mut sent_any = false;
205        while !self.pending_commands.is_empty() {
206            match self.ws.poll_ready_unpin(cx) {
207                Poll::Ready(Ok(())) => {
208                    let Some(cmd) = self.pending_commands.pop_front() else {
209                        break;
210                    };
211                    tracing::trace!("Sending {:?}", cmd);
212                    let msg = serde_json::to_string(&cmd)?;
213                    self.ws.start_send_unpin(msg.into())?;
214                    sent_any = true;
215                }
216                _ => break,
217            }
218        }
219
220        // Flush the entire batch in one write.
221        if sent_any {
222            match self.ws.poll_flush_unpin(cx) {
223                Poll::Ready(Ok(())) => {}
224                Poll::Ready(Err(e)) => return Err(e.into()),
225                Poll::Pending => self.needs_flush = true,
226            }
227        }
228
229        Ok(())
230    }
231}
232
233impl<T: EventMessage + Unpin> Stream for Connection<T> {
234    type Item = Result<Box<Message<T>>>;
235
236    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
237        let pin = self.get_mut();
238
239        // Send and flush outgoing messages
240        if let Err(err) = pin.start_send_next(cx) {
241            return Poll::Ready(Some(Err(err)));
242        }
243
244        // read from the websocket
245        match ready!(pin.ws.poll_next_unpin(cx)) {
246            Some(Ok(WsMessage::Text(text))) => {
247                match decode_message::<T>(text.as_bytes(), Some(&text)) {
248                    Ok(msg) => Poll::Ready(Some(Ok(msg))),
249                    Err(err) => {
250                        tracing::debug!(
251                            target: "chromiumoxide::conn::raw_ws::parse_errors",
252                            "Dropping malformed text WS frame: {err}",
253                        );
254                        cx.waker().wake_by_ref();
255                        Poll::Pending
256                    }
257                }
258            }
259            Some(Ok(WsMessage::Binary(buf))) => match decode_message::<T>(&buf, None) {
260                Ok(msg) => Poll::Ready(Some(Ok(msg))),
261                Err(err) => {
262                    tracing::debug!(
263                        target: "chromiumoxide::conn::raw_ws::parse_errors",
264                        "Dropping malformed binary WS frame: {err}",
265                    );
266                    cx.waker().wake_by_ref();
267                    Poll::Pending
268                }
269            },
270            Some(Ok(WsMessage::Close(_))) => Poll::Ready(None),
271            // ignore ping and pong
272            Some(Ok(WsMessage::Ping(_))) | Some(Ok(WsMessage::Pong(_))) => {
273                cx.waker().wake_by_ref();
274                Poll::Pending
275            }
276            Some(Ok(msg)) => {
277                // Unexpected WS message type, but not fatal.
278                tracing::debug!(
279                    target: "chromiumoxide::conn::raw_ws::parse_errors",
280                    "Unexpected WS message type: {:?}",
281                    msg
282                );
283                cx.waker().wake_by_ref();
284                Poll::Pending
285            }
286            Some(Err(err)) => Poll::Ready(Some(Err(CdpError::Ws(err)))),
287            None => {
288                // ws connection closed
289                Poll::Ready(None)
290            }
291        }
292    }
293}
294
295/// Shared decode path for both text and binary WS frames.
296/// `raw_text_for_logging` is only provided for textual frames so we can log the original
297/// payload on parse failure if desired.
298#[cfg(not(feature = "serde_stacker"))]
299fn decode_message<T: EventMessage>(
300    bytes: &[u8],
301    raw_text_for_logging: Option<&str>,
302) -> Result<Box<Message<T>>> {
303    match serde_json::from_slice::<Box<Message<T>>>(bytes) {
304        Ok(msg) => {
305            tracing::trace!("Received {:?}", msg);
306            Ok(msg)
307        }
308        Err(err) => {
309            if let Some(txt) = raw_text_for_logging {
310                let preview = &txt[..txt.len().min(512)];
311                tracing::debug!(
312                    target: "chromiumoxide::conn::raw_ws::parse_errors",
313                    msg_len = txt.len(),
314                    "Skipping unrecognized WS message {err} preview={preview}",
315                );
316            } else {
317                tracing::debug!(
318                    target: "chromiumoxide::conn::raw_ws::parse_errors",
319                    "Skipping unrecognized binary WS message {err}",
320                );
321            }
322            Err(err.into())
323        }
324    }
325}
326
327/// Shared decode path for both text and binary WS frames.
328/// `raw_text_for_logging` is only provided for textual frames so we can log the original
329/// payload on parse failure if desired.
330#[cfg(feature = "serde_stacker")]
331fn decode_message<T: EventMessage>(
332    bytes: &[u8],
333    raw_text_for_logging: Option<&str>,
334) -> Result<Box<Message<T>>> {
335    use serde::Deserialize;
336    let mut de = serde_json::Deserializer::from_slice(bytes);
337
338    de.disable_recursion_limit();
339
340    let de = serde_stacker::Deserializer::new(&mut de);
341
342    match Box::<Message<T>>::deserialize(de) {
343        Ok(msg) => {
344            tracing::trace!("Received {:?}", msg);
345            Ok(msg)
346        }
347        Err(err) => {
348            if let Some(txt) = raw_text_for_logging {
349                let preview = &txt[..txt.len().min(512)];
350                tracing::debug!(
351                    target: "chromiumoxide::conn::raw_ws::parse_errors",
352                    msg_len = txt.len(),
353                    "Skipping unrecognized WS message {err} preview={preview}",
354                );
355            } else {
356                tracing::debug!(
357                    target: "chromiumoxide::conn::raw_ws::parse_errors",
358                    "Skipping unrecognized binary WS message {err}",
359                );
360            }
361            Err(err.into())
362        }
363    }
364}