Skip to main content

chromiumoxide/
conn.rs

1use std::collections::VecDeque;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::ready;
5
6use futures::stream::Stream;
7use futures::task::{Context, Poll};
8use futures::{SinkExt, StreamExt};
9use tokio_tungstenite::tungstenite::Message as WsMessage;
10use tokio_tungstenite::MaybeTlsStream;
11use tokio_tungstenite::{tungstenite::protocol::WebSocketConfig, WebSocketStream};
12
13use chromiumoxide_cdp::cdp::browser_protocol::target::SessionId;
14use chromiumoxide_types::{CallId, EventMessage, Message, MethodCall, MethodId};
15
16use crate::error::CdpError;
17use crate::error::Result;
18
19type ConnectStream = MaybeTlsStream<tokio::net::TcpStream>;
20
/// Exchanges CDP messages with the browser over a websocket.
///
/// Outgoing commands are queued with [`Connection::submit_command`] and are
/// written to the socket while the connection is polled as a [`Stream`]; the
/// stream yields the decoded incoming messages.
#[must_use = "streams do nothing unless polled"]
#[derive(Debug)]
pub struct Connection<T: EventMessage> {
    /// Queue of commands waiting to be written to the websocket.
    pending_commands: VecDeque<MethodCall>,
    /// The websocket connected to the chromium instance.
    ws: WebSocketStream<ConnectStream>,
    /// Counter used to assign the next [`CallId`]; wraps around on overflow.
    next_id: usize,
    /// Set once a written frame has been accepted by the sink and still has to
    /// be flushed; cleared when a flush completes.
    needs_flush: bool,
    /// The command that was handed to the sink via `start_send` and is waiting
    /// for the sink to report readiness.
    pending_flush: Option<MethodCall>,
    /// Ties the connection to the event type `T` it decodes.
    _marker: PhantomData<T>,
}
38
39lazy_static::lazy_static! {
40    /// Nagle's algorithm disabled?
41    static ref DISABLE_NAGLE: bool = match std::env::var("DISABLE_NAGLE") {
42        Ok(disable_nagle) => disable_nagle == "true",
43        _ => true
44    };
45    /// Websocket config defaults
46    static ref WEBSOCKET_DEFAULTS: bool = match std::env::var("WEBSOCKET_DEFAULTS") {
47        Ok(d) => d == "true",
48        _ => false
49    };
50}
51
52impl<T: EventMessage + Unpin> Connection<T> {
53    pub async fn connect(debug_ws_url: impl AsRef<str>) -> Result<Self> {
54        let mut config = WebSocketConfig::default();
55
56        if !*WEBSOCKET_DEFAULTS {
57            config.max_message_size = None;
58            config.max_frame_size = None;
59        }
60
61        let ws = if crate::uring_fs::is_enabled() {
62            Self::connect_uring(debug_ws_url.as_ref(), config).await?
63        } else {
64            Self::connect_default(debug_ws_url.as_ref(), config).await?
65        };
66
67        Ok(Self {
68            pending_commands: Default::default(),
69            ws,
70            next_id: 0,
71            needs_flush: false,
72            pending_flush: None,
73            _marker: Default::default(),
74        })
75    }
76
77    /// Default path: let tokio-tungstenite handle TCP connect + WS handshake.
78    async fn connect_default(
79        url: &str,
80        config: WebSocketConfig,
81    ) -> Result<WebSocketStream<ConnectStream>> {
82        let (ws, _) =
83            tokio_tungstenite::connect_async_with_config(url, Some(config), *DISABLE_NAGLE).await?;
84        Ok(ws)
85    }
86
87    /// io_uring path: pre-connect the TCP socket via io_uring, then do WS
88    /// handshake over the pre-connected stream.
89    async fn connect_uring(
90        url: &str,
91        config: WebSocketConfig,
92    ) -> Result<WebSocketStream<ConnectStream>> {
93        use tokio_tungstenite::tungstenite::client::IntoClientRequest;
94
95        let request = url.into_client_request()?;
96        let host = request
97            .uri()
98            .host()
99            .ok_or_else(|| CdpError::msg("no host in CDP WebSocket URL"))?;
100        let port = request.uri().port_u16().unwrap_or(9222);
101
102        // Resolve host → SocketAddr (CDP is always localhost, so this is fast).
103        let addr_str = format!("{}:{}", host, port);
104        let addr: std::net::SocketAddr = match addr_str.parse() {
105            Ok(a) => a,
106            Err(_) => {
107                // Hostname needs DNS — fall back to default path.
108                return Self::connect_default(url, config).await;
109            }
110        };
111
112        // TCP connect via io_uring.
113        let std_stream = crate::uring_fs::tcp_connect(addr)
114            .await
115            .map_err(CdpError::Io)?;
116
117        // Set non-blocking + Nagle.
118        std_stream.set_nonblocking(true).map_err(CdpError::Io)?;
119        if *DISABLE_NAGLE {
120            let _ = std_stream.set_nodelay(true);
121        }
122
123        // Wrap in tokio TcpStream.
124        let tokio_stream = tokio::net::TcpStream::from_std(std_stream).map_err(CdpError::Io)?;
125
126        // WebSocket handshake over the pre-connected stream.
127        let (ws, _) = tokio_tungstenite::client_async_with_config(
128            request,
129            MaybeTlsStream::Plain(tokio_stream),
130            Some(config),
131        )
132        .await?;
133
134        Ok(ws)
135    }
136}
137
138impl<T: EventMessage> Connection<T> {
139    fn next_call_id(&mut self) -> CallId {
140        let id = CallId::new(self.next_id);
141        self.next_id = self.next_id.wrapping_add(1);
142        id
143    }
144
145    /// Queue in the command to send over the socket and return the id for this
146    /// command
147    pub fn submit_command(
148        &mut self,
149        method: MethodId,
150        session_id: Option<SessionId>,
151        params: serde_json::Value,
152    ) -> serde_json::Result<CallId> {
153        let id = self.next_call_id();
154        let call = MethodCall {
155            id,
156            method,
157            session_id: session_id.map(Into::into),
158            params,
159        };
160        self.pending_commands.push_back(call);
161        Ok(id)
162    }
163
164    /// flush any processed message and start sending the next over the conn
165    /// sink
166    fn start_send_next(&mut self, cx: &mut Context<'_>) -> Result<()> {
167        if self.needs_flush {
168            if let Poll::Ready(Ok(())) = self.ws.poll_flush_unpin(cx) {
169                self.needs_flush = false;
170            }
171        }
172        if self.pending_flush.is_none() && !self.needs_flush {
173            if let Some(cmd) = self.pending_commands.pop_front() {
174                tracing::trace!("Sending {:?}", cmd);
175                let msg = serde_json::to_string(&cmd)?;
176                self.ws.start_send_unpin(msg.into())?;
177                self.pending_flush = Some(cmd);
178            }
179        }
180        Ok(())
181    }
182}
183
184impl<T: EventMessage + Unpin> Stream for Connection<T> {
185    type Item = Result<Box<Message<T>>>;
186
187    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
188        let pin = self.get_mut();
189
190        // flush pending outgoing messages
191        loop {
192            if let Err(err) = pin.start_send_next(cx) {
193                return Poll::Ready(Some(Err(err)));
194            }
195
196            if let Some(call) = pin.pending_flush.take() {
197                if pin.ws.poll_ready_unpin(cx).is_ready() {
198                    pin.needs_flush = true;
199                    // try another flush in this same poll
200                    continue;
201                } else {
202                    pin.pending_flush = Some(call);
203                }
204            }
205
206            break;
207        }
208
209        // read from the websocket
210        match ready!(pin.ws.poll_next_unpin(cx)) {
211            Some(Ok(WsMessage::Text(text))) => {
212                match decode_message::<T>(text.as_bytes(), Some(&text)) {
213                    Ok(msg) => Poll::Ready(Some(Ok(msg))),
214                    Err(err) => {
215                        tracing::debug!(
216                            target: "chromiumoxide::conn::raw_ws::parse_errors",
217                            "Dropping malformed text WS frame: {err}",
218                        );
219                        cx.waker().wake_by_ref();
220                        Poll::Pending
221                    }
222                }
223            }
224            Some(Ok(WsMessage::Binary(buf))) => match decode_message::<T>(&buf, None) {
225                Ok(msg) => Poll::Ready(Some(Ok(msg))),
226                Err(err) => {
227                    tracing::debug!(
228                        target: "chromiumoxide::conn::raw_ws::parse_errors",
229                        "Dropping malformed binary WS frame: {err}",
230                    );
231                    cx.waker().wake_by_ref();
232                    Poll::Pending
233                }
234            },
235            Some(Ok(WsMessage::Close(_))) => Poll::Ready(None),
236            // ignore ping and pong
237            Some(Ok(WsMessage::Ping(_))) | Some(Ok(WsMessage::Pong(_))) => {
238                cx.waker().wake_by_ref();
239                Poll::Pending
240            }
241            Some(Ok(msg)) => {
242                // Unexpected WS message type, but not fatal.
243                tracing::debug!(
244                    target: "chromiumoxide::conn::raw_ws::parse_errors",
245                    "Unexpected WS message type: {:?}",
246                    msg
247                );
248                cx.waker().wake_by_ref();
249                Poll::Pending
250            }
251            Some(Err(err)) => Poll::Ready(Some(Err(CdpError::Ws(err)))),
252            None => {
253                // ws connection closed
254                Poll::Ready(None)
255            }
256        }
257    }
258}
259
260/// Shared decode path for both text and binary WS frames.
261/// `raw_text_for_logging` is only provided for textual frames so we can log the original
262/// payload on parse failure if desired.
263#[cfg(not(feature = "serde_stacker"))]
264fn decode_message<T: EventMessage>(
265    bytes: &[u8],
266    raw_text_for_logging: Option<&str>,
267) -> Result<Box<Message<T>>> {
268    match serde_json::from_slice::<Box<Message<T>>>(bytes) {
269        Ok(msg) => {
270            tracing::trace!("Received {:?}", msg);
271            Ok(msg)
272        }
273        Err(err) => {
274            if let Some(txt) = raw_text_for_logging {
275                let preview = &txt[..txt.len().min(512)];
276                tracing::debug!(
277                    target: "chromiumoxide::conn::raw_ws::parse_errors",
278                    msg_len = txt.len(),
279                    "Skipping unrecognized WS message {err} preview={preview}",
280                );
281            } else {
282                tracing::debug!(
283                    target: "chromiumoxide::conn::raw_ws::parse_errors",
284                    "Skipping unrecognized binary WS message {err}",
285                );
286            }
287            Err(err.into())
288        }
289    }
290}
291
/// Shared decode path for both text and binary WS frames, using
/// `serde_stacker` so deeply nested payloads do not overflow the stack.
///
/// `raw_text_for_logging` is only provided for textual frames, so that a
/// preview of the original payload can be logged on parse failure.
///
/// # Errors
///
/// Returns the deserialization error, converted into the crate error type,
/// when `bytes` is not a valid `Message<T>`.
#[cfg(feature = "serde_stacker")]
fn decode_message<T: EventMessage>(
    bytes: &[u8],
    raw_text_for_logging: Option<&str>,
) -> Result<Box<Message<T>>> {
    use serde::Deserialize;

    /// Maximum number of bytes of a malformed text frame included in the log.
    const PREVIEW_MAX_BYTES: usize = 512;

    let mut de = serde_json::Deserializer::from_slice(bytes);

    // Recursion depth is instead bounded by serde_stacker growing the stack.
    de.disable_recursion_limit();

    let de = serde_stacker::Deserializer::new(&mut de);

    match Box::<Message<T>>::deserialize(de) {
        Ok(msg) => {
            tracing::trace!("Received {:?}", msg);
            Ok(msg)
        }
        Err(err) => {
            if let Some(txt) = raw_text_for_logging {
                // Slicing a `str` inside a multi-byte character panics, so back
                // off to the nearest char boundary at or below the limit.
                let mut end = txt.len().min(PREVIEW_MAX_BYTES);
                while !txt.is_char_boundary(end) {
                    end -= 1;
                }
                let preview = &txt[..end];
                tracing::debug!(
                    target: "chromiumoxide::conn::raw_ws::parse_errors",
                    msg_len = txt.len(),
                    "Skipping unrecognized WS message {err} preview={preview}",
                );
            } else {
                tracing::debug!(
                    target: "chromiumoxide::conn::raw_ws::parse_errors",
                    "Skipping unrecognized binary WS message {err}",
                );
            }
            Err(err.into())
        }
    }
}
325            }
326            Err(err.into())
327        }
328    }
329}