Skip to main content

chromiumoxide/
conn.rs

1use std::collections::VecDeque;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::ready;
5
6use futures_util::{SinkExt, Stream, StreamExt};
7use std::task::{Context, Poll};
8use tokio_tungstenite::tungstenite::Message as WsMessage;
9use tokio_tungstenite::MaybeTlsStream;
10use tokio_tungstenite::{tungstenite::protocol::WebSocketConfig, WebSocketStream};
11
12use chromiumoxide_cdp::cdp::browser_protocol::target::SessionId;
13use chromiumoxide_types::{CallId, EventMessage, Message, MethodCall, MethodId};
14
15use crate::error::CdpError;
16use crate::error::Result;
17
18type ConnectStream = MaybeTlsStream<tokio::net::TcpStream>;
19
20/// Exchanges the messages with the websocket
21#[must_use = "streams do nothing unless polled"]
22#[derive(Debug)]
23pub struct Connection<T: EventMessage> {
24    /// Queue of commands to send.
25    pending_commands: VecDeque<MethodCall>,
26    /// The websocket of the chromium instance
27    ws: WebSocketStream<ConnectStream>,
28    /// The identifier for a specific command
29    next_id: usize,
30    /// Whether the write buffer has unsent data that needs flushing.
31    needs_flush: bool,
32    /// The phantom marker.
33    _marker: PhantomData<T>,
34}
35
36lazy_static::lazy_static! {
37    /// Nagle's algorithm disabled?
38    static ref DISABLE_NAGLE: bool = match std::env::var("DISABLE_NAGLE") {
39        Ok(disable_nagle) => disable_nagle == "true",
40        _ => true
41    };
42    /// Websocket config defaults
43    static ref WEBSOCKET_DEFAULTS: bool = match std::env::var("WEBSOCKET_DEFAULTS") {
44        Ok(d) => d == "true",
45        _ => false
46    };
47}
48
49/// Default number of WebSocket connection retry attempts.
50pub const DEFAULT_CONNECTION_RETRIES: u32 = 4;
51
52/// Initial backoff delay between connection retries (in milliseconds).
53const INITIAL_BACKOFF_MS: u64 = 50;
54
55impl<T: EventMessage + Unpin> Connection<T> {
56    pub async fn connect(debug_ws_url: impl AsRef<str>) -> Result<Self> {
57        Self::connect_with_retries(debug_ws_url, DEFAULT_CONNECTION_RETRIES).await
58    }
59
60    pub async fn connect_with_retries(
61        debug_ws_url: impl AsRef<str>,
62        retries: u32,
63    ) -> Result<Self> {
64        let mut config = WebSocketConfig::default();
65
66        if !*WEBSOCKET_DEFAULTS {
67            config.max_message_size = None;
68            config.max_frame_size = None;
69        }
70
71        let url = debug_ws_url.as_ref();
72        let use_uring = crate::uring_fs::is_enabled();
73        let mut last_err = None;
74
75        for attempt in 0..=retries {
76            let result = if use_uring {
77                Self::connect_uring(url, config).await
78            } else {
79                Self::connect_default(url, config).await
80            };
81
82            match result {
83                Ok(ws) => {
84                    return Ok(Self {
85                        pending_commands: Default::default(),
86                        ws,
87                        next_id: 0,
88                        needs_flush: false,
89                        _marker: Default::default(),
90                    });
91                }
92                Err(e) => {
93                    last_err = Some(e);
94                    if attempt < retries {
95                        let backoff_ms = INITIAL_BACKOFF_MS * 3u64.saturating_pow(attempt);
96                        tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
97                    }
98                }
99            }
100        }
101
102        Err(last_err.unwrap_or_else(|| CdpError::msg("connection failed")))
103    }
104
105    /// Default path: let tokio-tungstenite handle TCP connect + WS handshake.
106    async fn connect_default(
107        url: &str,
108        config: WebSocketConfig,
109    ) -> Result<WebSocketStream<ConnectStream>> {
110        let (ws, _) =
111            tokio_tungstenite::connect_async_with_config(url, Some(config), *DISABLE_NAGLE).await?;
112        Ok(ws)
113    }
114
115    /// io_uring path: pre-connect the TCP socket via io_uring, then do WS
116    /// handshake over the pre-connected stream.
117    async fn connect_uring(
118        url: &str,
119        config: WebSocketConfig,
120    ) -> Result<WebSocketStream<ConnectStream>> {
121        use tokio_tungstenite::tungstenite::client::IntoClientRequest;
122
123        let request = url.into_client_request()?;
124        let host = request
125            .uri()
126            .host()
127            .ok_or_else(|| CdpError::msg("no host in CDP WebSocket URL"))?;
128        let port = request.uri().port_u16().unwrap_or(9222);
129
130        // Resolve host → SocketAddr (CDP is always localhost, so this is fast).
131        let addr_str = format!("{}:{}", host, port);
132        let addr: std::net::SocketAddr = match addr_str.parse() {
133            Ok(a) => a,
134            Err(_) => {
135                // Hostname needs DNS — fall back to default path.
136                return Self::connect_default(url, config).await;
137            }
138        };
139
140        // TCP connect via io_uring.
141        let std_stream = crate::uring_fs::tcp_connect(addr)
142            .await
143            .map_err(CdpError::Io)?;
144
145        // Set non-blocking + Nagle.
146        std_stream.set_nonblocking(true).map_err(CdpError::Io)?;
147        if *DISABLE_NAGLE {
148            let _ = std_stream.set_nodelay(true);
149        }
150
151        // Wrap in tokio TcpStream.
152        let tokio_stream = tokio::net::TcpStream::from_std(std_stream).map_err(CdpError::Io)?;
153
154        // WebSocket handshake over the pre-connected stream.
155        let (ws, _) = tokio_tungstenite::client_async_with_config(
156            request,
157            MaybeTlsStream::Plain(tokio_stream),
158            Some(config),
159        )
160        .await?;
161
162        Ok(ws)
163    }
164}
165
166impl<T: EventMessage> Connection<T> {
167    fn next_call_id(&mut self) -> CallId {
168        let id = CallId::new(self.next_id);
169        self.next_id = self.next_id.wrapping_add(1);
170        id
171    }
172
173    /// Queue in the command to send over the socket and return the id for this
174    /// command
175    pub fn submit_command(
176        &mut self,
177        method: MethodId,
178        session_id: Option<SessionId>,
179        params: serde_json::Value,
180    ) -> serde_json::Result<CallId> {
181        let id = self.next_call_id();
182        let call = MethodCall {
183            id,
184            method,
185            session_id: session_id.map(Into::into),
186            params,
187        };
188        self.pending_commands.push_back(call);
189        Ok(id)
190    }
191
192    /// Buffer all queued commands into the WebSocket sink, then flush once.
193    ///
194    /// This batches multiple CDP commands into a single TCP write instead of
195    /// flushing after every individual message.
196    fn start_send_next(&mut self, cx: &mut Context<'_>) -> Result<()> {
197        // Complete any pending flush from a previous poll first.
198        if self.needs_flush {
199            match self.ws.poll_flush_unpin(cx) {
200                Poll::Ready(Ok(())) => self.needs_flush = false,
201                Poll::Ready(Err(e)) => return Err(e.into()),
202                Poll::Pending => return Ok(()),
203            }
204        }
205
206        // Buffer as many queued commands as the sink will accept.
207        let mut sent_any = false;
208        while !self.pending_commands.is_empty() {
209            match self.ws.poll_ready_unpin(cx) {
210                Poll::Ready(Ok(())) => {
211                    let cmd = self.pending_commands.pop_front().unwrap();
212                    tracing::trace!("Sending {:?}", cmd);
213                    let msg = serde_json::to_string(&cmd)?;
214                    self.ws.start_send_unpin(msg.into())?;
215                    sent_any = true;
216                }
217                _ => break,
218            }
219        }
220
221        // Flush the entire batch in one write.
222        if sent_any {
223            match self.ws.poll_flush_unpin(cx) {
224                Poll::Ready(Ok(())) => {}
225                Poll::Ready(Err(e)) => return Err(e.into()),
226                Poll::Pending => self.needs_flush = true,
227            }
228        }
229
230        Ok(())
231    }
232}
233
234impl<T: EventMessage + Unpin> Stream for Connection<T> {
235    type Item = Result<Box<Message<T>>>;
236
237    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
238        let pin = self.get_mut();
239
240        // Send and flush outgoing messages
241        if let Err(err) = pin.start_send_next(cx) {
242            return Poll::Ready(Some(Err(err)));
243        }
244
245        // read from the websocket
246        match ready!(pin.ws.poll_next_unpin(cx)) {
247            Some(Ok(WsMessage::Text(text))) => {
248                match decode_message::<T>(text.as_bytes(), Some(&text)) {
249                    Ok(msg) => Poll::Ready(Some(Ok(msg))),
250                    Err(err) => {
251                        tracing::debug!(
252                            target: "chromiumoxide::conn::raw_ws::parse_errors",
253                            "Dropping malformed text WS frame: {err}",
254                        );
255                        cx.waker().wake_by_ref();
256                        Poll::Pending
257                    }
258                }
259            }
260            Some(Ok(WsMessage::Binary(buf))) => match decode_message::<T>(&buf, None) {
261                Ok(msg) => Poll::Ready(Some(Ok(msg))),
262                Err(err) => {
263                    tracing::debug!(
264                        target: "chromiumoxide::conn::raw_ws::parse_errors",
265                        "Dropping malformed binary WS frame: {err}",
266                    );
267                    cx.waker().wake_by_ref();
268                    Poll::Pending
269                }
270            },
271            Some(Ok(WsMessage::Close(_))) => Poll::Ready(None),
272            // ignore ping and pong
273            Some(Ok(WsMessage::Ping(_))) | Some(Ok(WsMessage::Pong(_))) => {
274                cx.waker().wake_by_ref();
275                Poll::Pending
276            }
277            Some(Ok(msg)) => {
278                // Unexpected WS message type, but not fatal.
279                tracing::debug!(
280                    target: "chromiumoxide::conn::raw_ws::parse_errors",
281                    "Unexpected WS message type: {:?}",
282                    msg
283                );
284                cx.waker().wake_by_ref();
285                Poll::Pending
286            }
287            Some(Err(err)) => Poll::Ready(Some(Err(CdpError::Ws(err)))),
288            None => {
289                // ws connection closed
290                Poll::Ready(None)
291            }
292        }
293    }
294}
295
296/// Shared decode path for both text and binary WS frames.
297/// `raw_text_for_logging` is only provided for textual frames so we can log the original
298/// payload on parse failure if desired.
299#[cfg(not(feature = "serde_stacker"))]
300fn decode_message<T: EventMessage>(
301    bytes: &[u8],
302    raw_text_for_logging: Option<&str>,
303) -> Result<Box<Message<T>>> {
304    match serde_json::from_slice::<Box<Message<T>>>(bytes) {
305        Ok(msg) => {
306            tracing::trace!("Received {:?}", msg);
307            Ok(msg)
308        }
309        Err(err) => {
310            if let Some(txt) = raw_text_for_logging {
311                let preview = &txt[..txt.len().min(512)];
312                tracing::debug!(
313                    target: "chromiumoxide::conn::raw_ws::parse_errors",
314                    msg_len = txt.len(),
315                    "Skipping unrecognized WS message {err} preview={preview}",
316                );
317            } else {
318                tracing::debug!(
319                    target: "chromiumoxide::conn::raw_ws::parse_errors",
320                    "Skipping unrecognized binary WS message {err}",
321                );
322            }
323            Err(err.into())
324        }
325    }
326}
327
/// Decodes one CDP message from a websocket frame payload, using
/// `serde_stacker` so that deeply nested payloads do not overflow the stack.
///
/// Shared by text and binary frames. `raw_text_for_logging` is provided only
/// for text frames so that a bounded preview of the original payload can be
/// logged when parsing fails. Like `serde_json::from_slice`, trailing
/// non-whitespace after the JSON value is rejected.
///
/// # Errors
///
/// Returns the `serde_json` parse error, converted into [`CdpError`], when
/// `bytes` is not a valid `Message<T>`.
#[cfg(feature = "serde_stacker")]
fn decode_message<T: EventMessage>(
    bytes: &[u8],
    raw_text_for_logging: Option<&str>,
) -> Result<Box<Message<T>>> {
    use serde::Deserialize;

    let mut json_de = serde_json::Deserializer::from_slice(bytes);
    // Depth is handled by serde_stacker growing the stack on demand, so
    // serde_json's own recursion limit is switched off.
    json_de.disable_recursion_limit();

    let stacked = serde_stacker::Deserializer::new(&mut json_de);
    // `end()` checks that only whitespace follows the value, matching what
    // `serde_json::from_slice` does in the non-stacker build.
    let parsed = Box::<Message<T>>::deserialize(stacked)
        .and_then(|msg| json_de.end().map(|()| msg));

    match parsed {
        Ok(msg) => {
            tracing::trace!("Received {:?}", msg);
            Ok(msg)
        }
        Err(err) => {
            if let Some(txt) = raw_text_for_logging {
                // Cut the preview at the last char boundary at or below 512
                // bytes. Slicing a `str` inside a multi-byte UTF-8 sequence
                // panics, and CDP payloads can contain arbitrary Unicode.
                let cut = (0..=txt.len().min(512))
                    .rev()
                    .find(|&i| txt.is_char_boundary(i))
                    .unwrap_or(0);
                let preview = &txt[..cut];
                tracing::debug!(
                    target: "chromiumoxide::conn::raw_ws::parse_errors",
                    msg_len = txt.len(),
                    "Skipping unrecognized WS message {err} preview={preview}",
                );
            } else {
                tracing::debug!(
                    target: "chromiumoxide::conn::raw_ws::parse_errors",
                    "Skipping unrecognized binary WS message {err}",
                );
            }
            Err(err.into())
        }
    }
}