Skip to main content

chromiumoxide/
conn.rs

1use std::collections::VecDeque;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::ready;
5
6use futures_util::{SinkExt, Stream, StreamExt};
7use std::task::{Context, Poll};
8use tokio_tungstenite::tungstenite::Message as WsMessage;
9use tokio_tungstenite::MaybeTlsStream;
10use tokio_tungstenite::{tungstenite::protocol::WebSocketConfig, WebSocketStream};
11
12use chromiumoxide_cdp::cdp::browser_protocol::target::SessionId;
13use chromiumoxide_types::{CallId, EventMessage, Message, MethodCall, MethodId};
14
15use crate::error::CdpError;
16use crate::error::Result;
17
/// Transport underneath the WebSocket: a tokio TCP stream wrapped in
/// tokio-tungstenite's `MaybeTlsStream`.
type ConnectStream = MaybeTlsStream<tokio::net::TcpStream>;
19
/// Exchanges the messages with the websocket.
///
/// Outgoing commands are queued with `submit_command` and only written to the
/// socket while the connection is polled as a `Stream`; the same poll also
/// yields the decoded incoming messages.
#[must_use = "streams do nothing unless polled"]
#[derive(Debug)]
pub struct Connection<T: EventMessage> {
    /// Queue of commands to send, drained front-to-back when the stream is polled.
    pending_commands: VecDeque<MethodCall>,
    /// The websocket of the chromium instance
    ws: WebSocketStream<ConnectStream>,
    /// Counter used to build the `CallId` of the next submitted command
    /// (wraps around on overflow).
    next_id: usize,
    /// Whether the write buffer has unsent data that needs flushing, i.e. a
    /// previous flush attempt returned `Poll::Pending`.
    needs_flush: bool,
    /// The phantom marker tying the connection to the event type `T` that
    /// incoming messages are decoded into.
    _marker: PhantomData<T>,
}
35
36lazy_static::lazy_static! {
37    /// Nagle's algorithm disabled?
38    static ref DISABLE_NAGLE: bool = match std::env::var("DISABLE_NAGLE") {
39        Ok(disable_nagle) => disable_nagle == "true",
40        _ => true
41    };
42    /// Websocket config defaults
43    static ref WEBSOCKET_DEFAULTS: bool = match std::env::var("WEBSOCKET_DEFAULTS") {
44        Ok(d) => d == "true",
45        _ => false
46    };
47}
48
/// Default number of WebSocket connection retry attempts.
///
/// This counts retries, not attempts: the initial attempt is always made, so
/// `Connection::connect` tries up to `DEFAULT_CONNECTION_RETRIES + 1` times.
pub const DEFAULT_CONNECTION_RETRIES: u32 = 4;

/// Initial backoff delay between connection retries (in milliseconds).
///
/// The delay is tripled after each failed attempt (50, 150, 450, ... ms).
const INITIAL_BACKOFF_MS: u64 = 50;
54
55impl<T: EventMessage + Unpin> Connection<T> {
56    pub async fn connect(debug_ws_url: impl AsRef<str>) -> Result<Self> {
57        Self::connect_with_retries(debug_ws_url, DEFAULT_CONNECTION_RETRIES).await
58    }
59
60    pub async fn connect_with_retries(debug_ws_url: impl AsRef<str>, retries: u32) -> Result<Self> {
61        let mut config = WebSocketConfig::default();
62
63        if !*WEBSOCKET_DEFAULTS {
64            config.max_message_size = None;
65            config.max_frame_size = None;
66        }
67
68        let url = debug_ws_url.as_ref();
69        let use_uring = crate::uring_fs::is_enabled();
70        let mut last_err = None;
71
72        for attempt in 0..=retries {
73            let result = if use_uring {
74                Self::connect_uring(url, config).await
75            } else {
76                Self::connect_default(url, config).await
77            };
78
79            match result {
80                Ok(ws) => {
81                    return Ok(Self {
82                        pending_commands: Default::default(),
83                        ws,
84                        next_id: 0,
85                        needs_flush: false,
86                        _marker: Default::default(),
87                    });
88                }
89                Err(e) => {
90                    last_err = Some(e);
91                    if attempt < retries {
92                        let backoff_ms = INITIAL_BACKOFF_MS * 3u64.saturating_pow(attempt);
93                        tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
94                    }
95                }
96            }
97        }
98
99        Err(last_err.unwrap_or_else(|| CdpError::msg("connection failed")))
100    }
101
102    /// Default path: let tokio-tungstenite handle TCP connect + WS handshake.
103    async fn connect_default(
104        url: &str,
105        config: WebSocketConfig,
106    ) -> Result<WebSocketStream<ConnectStream>> {
107        let (ws, _) =
108            tokio_tungstenite::connect_async_with_config(url, Some(config), *DISABLE_NAGLE).await?;
109        Ok(ws)
110    }
111
112    /// io_uring path: pre-connect the TCP socket via io_uring, then do WS
113    /// handshake over the pre-connected stream.
114    async fn connect_uring(
115        url: &str,
116        config: WebSocketConfig,
117    ) -> Result<WebSocketStream<ConnectStream>> {
118        use tokio_tungstenite::tungstenite::client::IntoClientRequest;
119
120        let request = url.into_client_request()?;
121        let host = request
122            .uri()
123            .host()
124            .ok_or_else(|| CdpError::msg("no host in CDP WebSocket URL"))?;
125        let port = request.uri().port_u16().unwrap_or(9222);
126
127        // Resolve host → SocketAddr (CDP is always localhost, so this is fast).
128        let addr_str = format!("{}:{}", host, port);
129        let addr: std::net::SocketAddr = match addr_str.parse() {
130            Ok(a) => a,
131            Err(_) => {
132                // Hostname needs DNS — fall back to default path.
133                return Self::connect_default(url, config).await;
134            }
135        };
136
137        // TCP connect via io_uring.
138        let std_stream = crate::uring_fs::tcp_connect(addr)
139            .await
140            .map_err(CdpError::Io)?;
141
142        // Set non-blocking + Nagle.
143        std_stream.set_nonblocking(true).map_err(CdpError::Io)?;
144        if *DISABLE_NAGLE {
145            let _ = std_stream.set_nodelay(true);
146        }
147
148        // Wrap in tokio TcpStream.
149        let tokio_stream = tokio::net::TcpStream::from_std(std_stream).map_err(CdpError::Io)?;
150
151        // WebSocket handshake over the pre-connected stream.
152        let (ws, _) = tokio_tungstenite::client_async_with_config(
153            request,
154            MaybeTlsStream::Plain(tokio_stream),
155            Some(config),
156        )
157        .await?;
158
159        Ok(ws)
160    }
161}
162
163impl<T: EventMessage> Connection<T> {
164    fn next_call_id(&mut self) -> CallId {
165        let id = CallId::new(self.next_id);
166        self.next_id = self.next_id.wrapping_add(1);
167        id
168    }
169
170    /// Queue in the command to send over the socket and return the id for this
171    /// command
172    pub fn submit_command(
173        &mut self,
174        method: MethodId,
175        session_id: Option<SessionId>,
176        params: serde_json::Value,
177    ) -> serde_json::Result<CallId> {
178        let id = self.next_call_id();
179        let call = MethodCall {
180            id,
181            method,
182            session_id: session_id.map(Into::into),
183            params,
184        };
185        self.pending_commands.push_back(call);
186        Ok(id)
187    }
188
189    /// Buffer all queued commands into the WebSocket sink, then flush once.
190    ///
191    /// This batches multiple CDP commands into a single TCP write instead of
192    /// flushing after every individual message.
193    fn start_send_next(&mut self, cx: &mut Context<'_>) -> Result<()> {
194        // Complete any pending flush from a previous poll first.
195        if self.needs_flush {
196            match self.ws.poll_flush_unpin(cx) {
197                Poll::Ready(Ok(())) => self.needs_flush = false,
198                Poll::Ready(Err(e)) => return Err(e.into()),
199                Poll::Pending => return Ok(()),
200            }
201        }
202
203        // Buffer as many queued commands as the sink will accept.
204        let mut sent_any = false;
205        while !self.pending_commands.is_empty() {
206            match self.ws.poll_ready_unpin(cx) {
207                Poll::Ready(Ok(())) => {
208                    let cmd = self.pending_commands.pop_front().unwrap();
209                    tracing::trace!("Sending {:?}", cmd);
210                    let msg = serde_json::to_string(&cmd)?;
211                    self.ws.start_send_unpin(msg.into())?;
212                    sent_any = true;
213                }
214                _ => break,
215            }
216        }
217
218        // Flush the entire batch in one write.
219        if sent_any {
220            match self.ws.poll_flush_unpin(cx) {
221                Poll::Ready(Ok(())) => {}
222                Poll::Ready(Err(e)) => return Err(e.into()),
223                Poll::Pending => self.needs_flush = true,
224            }
225        }
226
227        Ok(())
228    }
229}
230
231impl<T: EventMessage + Unpin> Stream for Connection<T> {
232    type Item = Result<Box<Message<T>>>;
233
234    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
235        let pin = self.get_mut();
236
237        // Send and flush outgoing messages
238        if let Err(err) = pin.start_send_next(cx) {
239            return Poll::Ready(Some(Err(err)));
240        }
241
242        // read from the websocket
243        match ready!(pin.ws.poll_next_unpin(cx)) {
244            Some(Ok(WsMessage::Text(text))) => {
245                match decode_message::<T>(text.as_bytes(), Some(&text)) {
246                    Ok(msg) => Poll::Ready(Some(Ok(msg))),
247                    Err(err) => {
248                        tracing::debug!(
249                            target: "chromiumoxide::conn::raw_ws::parse_errors",
250                            "Dropping malformed text WS frame: {err}",
251                        );
252                        cx.waker().wake_by_ref();
253                        Poll::Pending
254                    }
255                }
256            }
257            Some(Ok(WsMessage::Binary(buf))) => match decode_message::<T>(&buf, None) {
258                Ok(msg) => Poll::Ready(Some(Ok(msg))),
259                Err(err) => {
260                    tracing::debug!(
261                        target: "chromiumoxide::conn::raw_ws::parse_errors",
262                        "Dropping malformed binary WS frame: {err}",
263                    );
264                    cx.waker().wake_by_ref();
265                    Poll::Pending
266                }
267            },
268            Some(Ok(WsMessage::Close(_))) => Poll::Ready(None),
269            // ignore ping and pong
270            Some(Ok(WsMessage::Ping(_))) | Some(Ok(WsMessage::Pong(_))) => {
271                cx.waker().wake_by_ref();
272                Poll::Pending
273            }
274            Some(Ok(msg)) => {
275                // Unexpected WS message type, but not fatal.
276                tracing::debug!(
277                    target: "chromiumoxide::conn::raw_ws::parse_errors",
278                    "Unexpected WS message type: {:?}",
279                    msg
280                );
281                cx.waker().wake_by_ref();
282                Poll::Pending
283            }
284            Some(Err(err)) => Poll::Ready(Some(Err(CdpError::Ws(err)))),
285            None => {
286                // ws connection closed
287                Poll::Ready(None)
288            }
289        }
290    }
291}
292
293/// Shared decode path for both text and binary WS frames.
294/// `raw_text_for_logging` is only provided for textual frames so we can log the original
295/// payload on parse failure if desired.
296#[cfg(not(feature = "serde_stacker"))]
297fn decode_message<T: EventMessage>(
298    bytes: &[u8],
299    raw_text_for_logging: Option<&str>,
300) -> Result<Box<Message<T>>> {
301    match serde_json::from_slice::<Box<Message<T>>>(bytes) {
302        Ok(msg) => {
303            tracing::trace!("Received {:?}", msg);
304            Ok(msg)
305        }
306        Err(err) => {
307            if let Some(txt) = raw_text_for_logging {
308                let preview = &txt[..txt.len().min(512)];
309                tracing::debug!(
310                    target: "chromiumoxide::conn::raw_ws::parse_errors",
311                    msg_len = txt.len(),
312                    "Skipping unrecognized WS message {err} preview={preview}",
313                );
314            } else {
315                tracing::debug!(
316                    target: "chromiumoxide::conn::raw_ws::parse_errors",
317                    "Skipping unrecognized binary WS message {err}",
318                );
319            }
320            Err(err.into())
321        }
322    }
323}
324
/// Shared decode path for both text and binary WS frames.
/// `raw_text_for_logging` is only provided for textual frames so we can log the original
/// payload on parse failure if desired.
///
/// This variant deserializes through `serde_stacker` with serde_json's
/// recursion limit disabled. On failure a debug event is emitted (with a
/// preview of at most 512 bytes of the text payload, when available) and the
/// JSON error is returned.
#[cfg(feature = "serde_stacker")]
fn decode_message<T: EventMessage>(
    bytes: &[u8],
    raw_text_for_logging: Option<&str>,
) -> Result<Box<Message<T>>> {
    use serde::Deserialize;
    let mut de = serde_json::Deserializer::from_slice(bytes);

    de.disable_recursion_limit();

    let de = serde_stacker::Deserializer::new(&mut de);

    match Box::<Message<T>>::deserialize(de) {
        Ok(msg) => {
            tracing::trace!("Received {:?}", msg);
            Ok(msg)
        }
        Err(err) => {
            if let Some(txt) = raw_text_for_logging {
                // Cut the preview on a char boundary: slicing a `str` in the
                // middle of a multi-byte UTF-8 sequence panics.
                let mut preview_end = txt.len().min(512);
                while !txt.is_char_boundary(preview_end) {
                    preview_end -= 1;
                }
                let preview = &txt[..preview_end];
                tracing::debug!(
                    target: "chromiumoxide::conn::raw_ws::parse_errors",
                    msg_len = txt.len(),
                    "Skipping unrecognized WS message {err} preview={preview}",
                );
            } else {
                tracing::debug!(
                    target: "chromiumoxide::conn::raw_ws::parse_errors",
                    "Skipping unrecognized binary WS message {err}",
                );
            }
            Err(err.into())
        }
    }
}