// chromiumoxide/conn.rs

1use std::collections::VecDeque;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::ready;
5
6use futures::stream::Stream;
7use futures::task::{Context, Poll};
8use futures::{SinkExt, StreamExt};
9use tokio_tungstenite::tungstenite::Message as WsMessage;
10use tokio_tungstenite::MaybeTlsStream;
11use tokio_tungstenite::{tungstenite::protocol::WebSocketConfig, WebSocketStream};
12
13use chromiumoxide_cdp::cdp::browser_protocol::target::SessionId;
14use chromiumoxide_types::{CallId, EventMessage, Message, MethodCall, MethodId};
15
16use crate::error::CdpError;
17use crate::error::Result;
18
19type ConnectStream = MaybeTlsStream<tokio::net::TcpStream>;
20
/// Exchanges the messages with the websocket.
///
/// Commands are queued with [`Connection::submit_command`] and written to the
/// socket whenever the connection is polled as a [`Stream`]; every item the
/// stream yields is a message decoded from an incoming websocket frame.
#[must_use = "streams do nothing unless polled"]
#[derive(Debug)]
pub struct Connection<T: EventMessage> {
    /// Queue of commands to send, in submission order.
    pending_commands: VecDeque<MethodCall>,
    /// The websocket of the chromium instance
    ws: WebSocketStream<ConnectStream>,
    /// The id handed to the next submitted command; incremented (wrapping on
    /// overflow) each time an id is taken.
    next_id: usize,
    /// Whether the write buffer has unsent data that needs flushing.
    needs_flush: bool,
    /// The phantom marker tying the connection to its event type `T`.
    _marker: PhantomData<T>,
}
36
37lazy_static::lazy_static! {
38    /// Nagle's algorithm disabled?
39    static ref DISABLE_NAGLE: bool = match std::env::var("DISABLE_NAGLE") {
40        Ok(disable_nagle) => disable_nagle == "true",
41        _ => true
42    };
43    /// Websocket config defaults
44    static ref WEBSOCKET_DEFAULTS: bool = match std::env::var("WEBSOCKET_DEFAULTS") {
45        Ok(d) => d == "true",
46        _ => false
47    };
48}
49
/// Default number of WebSocket connection retry attempts.
///
/// Used by [`Connection::connect`]. Retries come on top of the initial
/// attempt, so up to `DEFAULT_CONNECTION_RETRIES + 1` attempts are made.
pub const DEFAULT_CONNECTION_RETRIES: u32 = 4;

/// Initial backoff delay between connection retries (in milliseconds).
///
/// The delay after failed attempt `n` (zero-based) is this value multiplied
/// by `3^n`.
const INITIAL_BACKOFF_MS: u64 = 50;
55
56impl<T: EventMessage + Unpin> Connection<T> {
57    pub async fn connect(debug_ws_url: impl AsRef<str>) -> Result<Self> {
58        Self::connect_with_retries(debug_ws_url, DEFAULT_CONNECTION_RETRIES).await
59    }
60
61    pub async fn connect_with_retries(
62        debug_ws_url: impl AsRef<str>,
63        retries: u32,
64    ) -> Result<Self> {
65        let mut config = WebSocketConfig::default();
66
67        if !*WEBSOCKET_DEFAULTS {
68            config.max_message_size = None;
69            config.max_frame_size = None;
70        }
71
72        let url = debug_ws_url.as_ref();
73        let use_uring = crate::uring_fs::is_enabled();
74        let mut last_err = None;
75
76        for attempt in 0..=retries {
77            let result = if use_uring {
78                Self::connect_uring(url, config).await
79            } else {
80                Self::connect_default(url, config).await
81            };
82
83            match result {
84                Ok(ws) => {
85                    return Ok(Self {
86                        pending_commands: Default::default(),
87                        ws,
88                        next_id: 0,
89                        needs_flush: false,
90                        _marker: Default::default(),
91                    });
92                }
93                Err(e) => {
94                    last_err = Some(e);
95                    if attempt < retries {
96                        let backoff_ms = INITIAL_BACKOFF_MS * 3u64.saturating_pow(attempt);
97                        tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
98                    }
99                }
100            }
101        }
102
103        Err(last_err.unwrap_or_else(|| CdpError::msg("connection failed")))
104    }
105
106    /// Default path: let tokio-tungstenite handle TCP connect + WS handshake.
107    async fn connect_default(
108        url: &str,
109        config: WebSocketConfig,
110    ) -> Result<WebSocketStream<ConnectStream>> {
111        let (ws, _) =
112            tokio_tungstenite::connect_async_with_config(url, Some(config), *DISABLE_NAGLE).await?;
113        Ok(ws)
114    }
115
116    /// io_uring path: pre-connect the TCP socket via io_uring, then do WS
117    /// handshake over the pre-connected stream.
118    async fn connect_uring(
119        url: &str,
120        config: WebSocketConfig,
121    ) -> Result<WebSocketStream<ConnectStream>> {
122        use tokio_tungstenite::tungstenite::client::IntoClientRequest;
123
124        let request = url.into_client_request()?;
125        let host = request
126            .uri()
127            .host()
128            .ok_or_else(|| CdpError::msg("no host in CDP WebSocket URL"))?;
129        let port = request.uri().port_u16().unwrap_or(9222);
130
131        // Resolve host → SocketAddr (CDP is always localhost, so this is fast).
132        let addr_str = format!("{}:{}", host, port);
133        let addr: std::net::SocketAddr = match addr_str.parse() {
134            Ok(a) => a,
135            Err(_) => {
136                // Hostname needs DNS — fall back to default path.
137                return Self::connect_default(url, config).await;
138            }
139        };
140
141        // TCP connect via io_uring.
142        let std_stream = crate::uring_fs::tcp_connect(addr)
143            .await
144            .map_err(CdpError::Io)?;
145
146        // Set non-blocking + Nagle.
147        std_stream.set_nonblocking(true).map_err(CdpError::Io)?;
148        if *DISABLE_NAGLE {
149            let _ = std_stream.set_nodelay(true);
150        }
151
152        // Wrap in tokio TcpStream.
153        let tokio_stream = tokio::net::TcpStream::from_std(std_stream).map_err(CdpError::Io)?;
154
155        // WebSocket handshake over the pre-connected stream.
156        let (ws, _) = tokio_tungstenite::client_async_with_config(
157            request,
158            MaybeTlsStream::Plain(tokio_stream),
159            Some(config),
160        )
161        .await?;
162
163        Ok(ws)
164    }
165}
166
167impl<T: EventMessage> Connection<T> {
168    fn next_call_id(&mut self) -> CallId {
169        let id = CallId::new(self.next_id);
170        self.next_id = self.next_id.wrapping_add(1);
171        id
172    }
173
174    /// Queue in the command to send over the socket and return the id for this
175    /// command
176    pub fn submit_command(
177        &mut self,
178        method: MethodId,
179        session_id: Option<SessionId>,
180        params: serde_json::Value,
181    ) -> serde_json::Result<CallId> {
182        let id = self.next_call_id();
183        let call = MethodCall {
184            id,
185            method,
186            session_id: session_id.map(Into::into),
187            params,
188        };
189        self.pending_commands.push_back(call);
190        Ok(id)
191    }
192
193    /// Buffer all queued commands into the WebSocket sink, then flush once.
194    ///
195    /// This batches multiple CDP commands into a single TCP write instead of
196    /// flushing after every individual message.
197    fn start_send_next(&mut self, cx: &mut Context<'_>) -> Result<()> {
198        // Complete any pending flush from a previous poll first.
199        if self.needs_flush {
200            match self.ws.poll_flush_unpin(cx) {
201                Poll::Ready(Ok(())) => self.needs_flush = false,
202                Poll::Ready(Err(e)) => return Err(e.into()),
203                Poll::Pending => return Ok(()),
204            }
205        }
206
207        // Buffer as many queued commands as the sink will accept.
208        let mut sent_any = false;
209        while !self.pending_commands.is_empty() {
210            match self.ws.poll_ready_unpin(cx) {
211                Poll::Ready(Ok(())) => {
212                    let cmd = self.pending_commands.pop_front().unwrap();
213                    tracing::trace!("Sending {:?}", cmd);
214                    let msg = serde_json::to_string(&cmd)?;
215                    self.ws.start_send_unpin(msg.into())?;
216                    sent_any = true;
217                }
218                _ => break,
219            }
220        }
221
222        // Flush the entire batch in one write.
223        if sent_any {
224            match self.ws.poll_flush_unpin(cx) {
225                Poll::Ready(Ok(())) => {}
226                Poll::Ready(Err(e)) => return Err(e.into()),
227                Poll::Pending => self.needs_flush = true,
228            }
229        }
230
231        Ok(())
232    }
233}
234
235impl<T: EventMessage + Unpin> Stream for Connection<T> {
236    type Item = Result<Box<Message<T>>>;
237
238    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
239        let pin = self.get_mut();
240
241        // Send and flush outgoing messages
242        if let Err(err) = pin.start_send_next(cx) {
243            return Poll::Ready(Some(Err(err)));
244        }
245
246        // read from the websocket
247        match ready!(pin.ws.poll_next_unpin(cx)) {
248            Some(Ok(WsMessage::Text(text))) => {
249                match decode_message::<T>(text.as_bytes(), Some(&text)) {
250                    Ok(msg) => Poll::Ready(Some(Ok(msg))),
251                    Err(err) => {
252                        tracing::debug!(
253                            target: "chromiumoxide::conn::raw_ws::parse_errors",
254                            "Dropping malformed text WS frame: {err}",
255                        );
256                        cx.waker().wake_by_ref();
257                        Poll::Pending
258                    }
259                }
260            }
261            Some(Ok(WsMessage::Binary(buf))) => match decode_message::<T>(&buf, None) {
262                Ok(msg) => Poll::Ready(Some(Ok(msg))),
263                Err(err) => {
264                    tracing::debug!(
265                        target: "chromiumoxide::conn::raw_ws::parse_errors",
266                        "Dropping malformed binary WS frame: {err}",
267                    );
268                    cx.waker().wake_by_ref();
269                    Poll::Pending
270                }
271            },
272            Some(Ok(WsMessage::Close(_))) => Poll::Ready(None),
273            // ignore ping and pong
274            Some(Ok(WsMessage::Ping(_))) | Some(Ok(WsMessage::Pong(_))) => {
275                cx.waker().wake_by_ref();
276                Poll::Pending
277            }
278            Some(Ok(msg)) => {
279                // Unexpected WS message type, but not fatal.
280                tracing::debug!(
281                    target: "chromiumoxide::conn::raw_ws::parse_errors",
282                    "Unexpected WS message type: {:?}",
283                    msg
284                );
285                cx.waker().wake_by_ref();
286                Poll::Pending
287            }
288            Some(Err(err)) => Poll::Ready(Some(Err(CdpError::Ws(err)))),
289            None => {
290                // ws connection closed
291                Poll::Ready(None)
292            }
293        }
294    }
295}
296
297/// Shared decode path for both text and binary WS frames.
298/// `raw_text_for_logging` is only provided for textual frames so we can log the original
299/// payload on parse failure if desired.
300#[cfg(not(feature = "serde_stacker"))]
301fn decode_message<T: EventMessage>(
302    bytes: &[u8],
303    raw_text_for_logging: Option<&str>,
304) -> Result<Box<Message<T>>> {
305    match serde_json::from_slice::<Box<Message<T>>>(bytes) {
306        Ok(msg) => {
307            tracing::trace!("Received {:?}", msg);
308            Ok(msg)
309        }
310        Err(err) => {
311            if let Some(txt) = raw_text_for_logging {
312                let preview = &txt[..txt.len().min(512)];
313                tracing::debug!(
314                    target: "chromiumoxide::conn::raw_ws::parse_errors",
315                    msg_len = txt.len(),
316                    "Skipping unrecognized WS message {err} preview={preview}",
317                );
318            } else {
319                tracing::debug!(
320                    target: "chromiumoxide::conn::raw_ws::parse_errors",
321                    "Skipping unrecognized binary WS message {err}",
322                );
323            }
324            Err(err.into())
325        }
326    }
327}
328
/// Shared decode path for both text and binary WS frames.
///
/// `bytes` is the raw frame payload, parsed as JSON into a [`Message`]. This
/// variant disables serde_json's recursion limit and deserializes through
/// `serde_stacker`, so deeply nested payloads are not rejected for their depth.
/// `raw_text_for_logging` is only provided for textual frames so that a
/// preview of the original payload (at most 512 bytes) can be logged on parse
/// failure.
///
/// # Errors
///
/// Returns the deserialization error, converted into the crate error type,
/// when the payload is not a recognized message.
#[cfg(feature = "serde_stacker")]
fn decode_message<T: EventMessage>(
    bytes: &[u8],
    raw_text_for_logging: Option<&str>,
) -> Result<Box<Message<T>>> {
    use serde::Deserialize;

    /// Upper bound, in bytes, of the payload preview written to the log.
    const PREVIEW_MAX_BYTES: usize = 512;

    let mut de = serde_json::Deserializer::from_slice(bytes);

    de.disable_recursion_limit();

    let de = serde_stacker::Deserializer::new(&mut de);

    match Box::<Message<T>>::deserialize(de) {
        Ok(msg) => {
            tracing::trace!("Received {:?}", msg);
            Ok(msg)
        }
        Err(err) => {
            if let Some(txt) = raw_text_for_logging {
                // Slicing a `str` inside a multi-byte character panics, so back
                // up to the nearest char boundary. Offset 0 is always a
                // boundary, so the loop terminates.
                let mut end = txt.len().min(PREVIEW_MAX_BYTES);
                while !txt.is_char_boundary(end) {
                    end -= 1;
                }
                let preview = &txt[..end];
                tracing::debug!(
                    target: "chromiumoxide::conn::raw_ws::parse_errors",
                    msg_len = txt.len(),
                    "Skipping unrecognized WS message {err} preview={preview}",
                );
            } else {
                tracing::debug!(
                    target: "chromiumoxide::conn::raw_ws::parse_errors",
                    "Skipping unrecognized binary WS message {err}",
                );
            }
            Err(err.into())
        }
    }
}