Skip to main content

chromiumoxide/
conn.rs

1use std::collections::VecDeque;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::ready;
5
6use futures::stream::Stream;
7use futures::task::{Context, Poll};
8use futures::{SinkExt, StreamExt};
9use tokio_tungstenite::tungstenite::Message as WsMessage;
10use tokio_tungstenite::MaybeTlsStream;
11use tokio_tungstenite::{tungstenite::protocol::WebSocketConfig, WebSocketStream};
12
13use chromiumoxide_cdp::cdp::browser_protocol::target::SessionId;
14use chromiumoxide_types::{CallId, EventMessage, Message, MethodCall, MethodId};
15
16use crate::error::CdpError;
17use crate::error::Result;
18
19type ConnectStream = MaybeTlsStream<tokio::net::TcpStream>;
20
21/// Exchanges the messages with the websocket
22#[must_use = "streams do nothing unless polled"]
23#[derive(Debug)]
24pub struct Connection<T: EventMessage> {
25    /// Queue of commands to send.
26    pending_commands: VecDeque<MethodCall>,
27    /// The websocket of the chromium instance
28    ws: WebSocketStream<ConnectStream>,
29    /// The identifier for a specific command
30    next_id: usize,
31    /// Whether the write buffer has unsent data that needs flushing.
32    needs_flush: bool,
33    /// The phantom marker.
34    _marker: PhantomData<T>,
35}
36
37lazy_static::lazy_static! {
38    /// Nagle's algorithm disabled?
39    static ref DISABLE_NAGLE: bool = match std::env::var("DISABLE_NAGLE") {
40        Ok(disable_nagle) => disable_nagle == "true",
41        _ => true
42    };
43    /// Websocket config defaults
44    static ref WEBSOCKET_DEFAULTS: bool = match std::env::var("WEBSOCKET_DEFAULTS") {
45        Ok(d) => d == "true",
46        _ => false
47    };
48}
49
50impl<T: EventMessage + Unpin> Connection<T> {
51    pub async fn connect(debug_ws_url: impl AsRef<str>) -> Result<Self> {
52        let mut config = WebSocketConfig::default();
53
54        if !*WEBSOCKET_DEFAULTS {
55            config.max_message_size = None;
56            config.max_frame_size = None;
57        }
58
59        let ws = if crate::uring_fs::is_enabled() {
60            Self::connect_uring(debug_ws_url.as_ref(), config).await?
61        } else {
62            Self::connect_default(debug_ws_url.as_ref(), config).await?
63        };
64
65        Ok(Self {
66            pending_commands: Default::default(),
67            ws,
68            next_id: 0,
69            needs_flush: false,
70            _marker: Default::default(),
71        })
72    }
73
74    /// Default path: let tokio-tungstenite handle TCP connect + WS handshake.
75    async fn connect_default(
76        url: &str,
77        config: WebSocketConfig,
78    ) -> Result<WebSocketStream<ConnectStream>> {
79        let (ws, _) =
80            tokio_tungstenite::connect_async_with_config(url, Some(config), *DISABLE_NAGLE).await?;
81        Ok(ws)
82    }
83
84    /// io_uring path: pre-connect the TCP socket via io_uring, then do WS
85    /// handshake over the pre-connected stream.
86    async fn connect_uring(
87        url: &str,
88        config: WebSocketConfig,
89    ) -> Result<WebSocketStream<ConnectStream>> {
90        use tokio_tungstenite::tungstenite::client::IntoClientRequest;
91
92        let request = url.into_client_request()?;
93        let host = request
94            .uri()
95            .host()
96            .ok_or_else(|| CdpError::msg("no host in CDP WebSocket URL"))?;
97        let port = request.uri().port_u16().unwrap_or(9222);
98
99        // Resolve host → SocketAddr (CDP is always localhost, so this is fast).
100        let addr_str = format!("{}:{}", host, port);
101        let addr: std::net::SocketAddr = match addr_str.parse() {
102            Ok(a) => a,
103            Err(_) => {
104                // Hostname needs DNS — fall back to default path.
105                return Self::connect_default(url, config).await;
106            }
107        };
108
109        // TCP connect via io_uring.
110        let std_stream = crate::uring_fs::tcp_connect(addr)
111            .await
112            .map_err(CdpError::Io)?;
113
114        // Set non-blocking + Nagle.
115        std_stream.set_nonblocking(true).map_err(CdpError::Io)?;
116        if *DISABLE_NAGLE {
117            let _ = std_stream.set_nodelay(true);
118        }
119
120        // Wrap in tokio TcpStream.
121        let tokio_stream = tokio::net::TcpStream::from_std(std_stream).map_err(CdpError::Io)?;
122
123        // WebSocket handshake over the pre-connected stream.
124        let (ws, _) = tokio_tungstenite::client_async_with_config(
125            request,
126            MaybeTlsStream::Plain(tokio_stream),
127            Some(config),
128        )
129        .await?;
130
131        Ok(ws)
132    }
133}
134
135impl<T: EventMessage> Connection<T> {
136    fn next_call_id(&mut self) -> CallId {
137        let id = CallId::new(self.next_id);
138        self.next_id = self.next_id.wrapping_add(1);
139        id
140    }
141
142    /// Queue in the command to send over the socket and return the id for this
143    /// command
144    pub fn submit_command(
145        &mut self,
146        method: MethodId,
147        session_id: Option<SessionId>,
148        params: serde_json::Value,
149    ) -> serde_json::Result<CallId> {
150        let id = self.next_call_id();
151        let call = MethodCall {
152            id,
153            method,
154            session_id: session_id.map(Into::into),
155            params,
156        };
157        self.pending_commands.push_back(call);
158        Ok(id)
159    }
160
161    /// Buffer all queued commands into the WebSocket sink, then flush once.
162    ///
163    /// This batches multiple CDP commands into a single TCP write instead of
164    /// flushing after every individual message.
165    fn start_send_next(&mut self, cx: &mut Context<'_>) -> Result<()> {
166        // Complete any pending flush from a previous poll first.
167        if self.needs_flush {
168            match self.ws.poll_flush_unpin(cx) {
169                Poll::Ready(Ok(())) => self.needs_flush = false,
170                Poll::Ready(Err(e)) => return Err(e.into()),
171                Poll::Pending => return Ok(()),
172            }
173        }
174
175        // Buffer as many queued commands as the sink will accept.
176        let mut sent_any = false;
177        while !self.pending_commands.is_empty() {
178            match self.ws.poll_ready_unpin(cx) {
179                Poll::Ready(Ok(())) => {
180                    let cmd = self.pending_commands.pop_front().unwrap();
181                    tracing::trace!("Sending {:?}", cmd);
182                    let msg = serde_json::to_string(&cmd)?;
183                    self.ws.start_send_unpin(msg.into())?;
184                    sent_any = true;
185                }
186                _ => break,
187            }
188        }
189
190        // Flush the entire batch in one write.
191        if sent_any {
192            match self.ws.poll_flush_unpin(cx) {
193                Poll::Ready(Ok(())) => {}
194                Poll::Ready(Err(e)) => return Err(e.into()),
195                Poll::Pending => self.needs_flush = true,
196            }
197        }
198
199        Ok(())
200    }
201}
202
203impl<T: EventMessage + Unpin> Stream for Connection<T> {
204    type Item = Result<Box<Message<T>>>;
205
206    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
207        let pin = self.get_mut();
208
209        // Send and flush outgoing messages
210        if let Err(err) = pin.start_send_next(cx) {
211            return Poll::Ready(Some(Err(err)));
212        }
213
214        // read from the websocket
215        match ready!(pin.ws.poll_next_unpin(cx)) {
216            Some(Ok(WsMessage::Text(text))) => {
217                match decode_message::<T>(text.as_bytes(), Some(&text)) {
218                    Ok(msg) => Poll::Ready(Some(Ok(msg))),
219                    Err(err) => {
220                        tracing::debug!(
221                            target: "chromiumoxide::conn::raw_ws::parse_errors",
222                            "Dropping malformed text WS frame: {err}",
223                        );
224                        cx.waker().wake_by_ref();
225                        Poll::Pending
226                    }
227                }
228            }
229            Some(Ok(WsMessage::Binary(buf))) => match decode_message::<T>(&buf, None) {
230                Ok(msg) => Poll::Ready(Some(Ok(msg))),
231                Err(err) => {
232                    tracing::debug!(
233                        target: "chromiumoxide::conn::raw_ws::parse_errors",
234                        "Dropping malformed binary WS frame: {err}",
235                    );
236                    cx.waker().wake_by_ref();
237                    Poll::Pending
238                }
239            },
240            Some(Ok(WsMessage::Close(_))) => Poll::Ready(None),
241            // ignore ping and pong
242            Some(Ok(WsMessage::Ping(_))) | Some(Ok(WsMessage::Pong(_))) => {
243                cx.waker().wake_by_ref();
244                Poll::Pending
245            }
246            Some(Ok(msg)) => {
247                // Unexpected WS message type, but not fatal.
248                tracing::debug!(
249                    target: "chromiumoxide::conn::raw_ws::parse_errors",
250                    "Unexpected WS message type: {:?}",
251                    msg
252                );
253                cx.waker().wake_by_ref();
254                Poll::Pending
255            }
256            Some(Err(err)) => Poll::Ready(Some(Err(CdpError::Ws(err)))),
257            None => {
258                // ws connection closed
259                Poll::Ready(None)
260            }
261        }
262    }
263}
264
265/// Shared decode path for both text and binary WS frames.
266/// `raw_text_for_logging` is only provided for textual frames so we can log the original
267/// payload on parse failure if desired.
268#[cfg(not(feature = "serde_stacker"))]
269fn decode_message<T: EventMessage>(
270    bytes: &[u8],
271    raw_text_for_logging: Option<&str>,
272) -> Result<Box<Message<T>>> {
273    match serde_json::from_slice::<Box<Message<T>>>(bytes) {
274        Ok(msg) => {
275            tracing::trace!("Received {:?}", msg);
276            Ok(msg)
277        }
278        Err(err) => {
279            if let Some(txt) = raw_text_for_logging {
280                let preview = &txt[..txt.len().min(512)];
281                tracing::debug!(
282                    target: "chromiumoxide::conn::raw_ws::parse_errors",
283                    msg_len = txt.len(),
284                    "Skipping unrecognized WS message {err} preview={preview}",
285                );
286            } else {
287                tracing::debug!(
288                    target: "chromiumoxide::conn::raw_ws::parse_errors",
289                    "Skipping unrecognized binary WS message {err}",
290                );
291            }
292            Err(err.into())
293        }
294    }
295}
296
/// Shared decode path for both text and binary WS frames.
/// `raw_text_for_logging` is only provided for textual frames so we can log the original
/// payload on parse failure if desired.
///
/// This variant disables serde_json's recursion limit and deserializes through
/// `serde_stacker`, presumably so deeply nested payloads do not overflow the stack.
///
/// # Errors
/// Returns the deserialization error (converted into the crate error) when `bytes` is not
/// a valid `Message<T>`; a debug log with up to 512 bytes of the text payload is emitted first.
#[cfg(feature = "serde_stacker")]
fn decode_message<T: EventMessage>(
    bytes: &[u8],
    raw_text_for_logging: Option<&str>,
) -> Result<Box<Message<T>>> {
    use serde::Deserialize;
    let mut de = serde_json::Deserializer::from_slice(bytes);

    de.disable_recursion_limit();

    let de = serde_stacker::Deserializer::new(&mut de);

    match Box::<Message<T>>::deserialize(de) {
        Ok(msg) => {
            tracing::trace!("Received {:?}", msg);
            Ok(msg)
        }
        Err(err) => {
            if let Some(txt) = raw_text_for_logging {
                // Keep at most 512 bytes, backing off to a char boundary: slicing a `str`
                // in the middle of a multi-byte UTF-8 sequence panics. Index 0 is always
                // a boundary, so the loop terminates.
                let mut end = txt.len().min(512);
                while !txt.is_char_boundary(end) {
                    end -= 1;
                }
                let preview = &txt[..end];
                tracing::debug!(
                    target: "chromiumoxide::conn::raw_ws::parse_errors",
                    msg_len = txt.len(),
                    "Skipping unrecognized WS message {err} preview={preview}",
                );
            } else {
                tracing::debug!(
                    target: "chromiumoxide::conn::raw_ws::parse_errors",
                    "Skipping unrecognized binary WS message {err}",
                );
            }
            Err(err.into())
        }
    }
}