1use std::collections::VecDeque;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::ready;
5
6use futures_util::{SinkExt, Stream, StreamExt};
7use std::task::{Context, Poll};
8use tokio_tungstenite::tungstenite::Message as WsMessage;
9use tokio_tungstenite::MaybeTlsStream;
10use tokio_tungstenite::{tungstenite::protocol::WebSocketConfig, WebSocketStream};
11
12use chromiumoxide_cdp::cdp::browser_protocol::target::SessionId;
13use chromiumoxide_types::{CallId, EventMessage, Message, MethodCall, MethodId};
14
15use crate::error::CdpError;
16use crate::error::Result;
17
18type ConnectStream = MaybeTlsStream<tokio::net::TcpStream>;
19
20#[must_use = "streams do nothing unless polled"]
22#[derive(Debug)]
23pub struct Connection<T: EventMessage> {
24 pending_commands: VecDeque<MethodCall>,
26 ws: WebSocketStream<ConnectStream>,
28 next_id: usize,
30 needs_flush: bool,
32 _marker: PhantomData<T>,
34}
35
36lazy_static::lazy_static! {
37 static ref DISABLE_NAGLE: bool = match std::env::var("DISABLE_NAGLE") {
39 Ok(disable_nagle) => disable_nagle == "true",
40 _ => true
41 };
42 static ref WEBSOCKET_DEFAULTS: bool = match std::env::var("WEBSOCKET_DEFAULTS") {
44 Ok(d) => d == "true",
45 _ => false
46 };
47}
48
49pub const DEFAULT_CONNECTION_RETRIES: u32 = 4;
51
52const INITIAL_BACKOFF_MS: u64 = 50;
54
55impl<T: EventMessage + Unpin> Connection<T> {
56 pub async fn connect(debug_ws_url: impl AsRef<str>) -> Result<Self> {
57 Self::connect_with_retries(debug_ws_url, DEFAULT_CONNECTION_RETRIES).await
58 }
59
60 pub async fn connect_with_retries(debug_ws_url: impl AsRef<str>, retries: u32) -> Result<Self> {
61 let mut config = WebSocketConfig::default();
62
63 if !*WEBSOCKET_DEFAULTS {
64 config.max_message_size = None;
65 config.max_frame_size = None;
66 }
67
68 let url = debug_ws_url.as_ref();
69 let use_uring = crate::uring_fs::is_enabled();
70 let mut last_err = None;
71
72 for attempt in 0..=retries {
73 let result = if use_uring {
74 Self::connect_uring(url, config).await
75 } else {
76 Self::connect_default(url, config).await
77 };
78
79 match result {
80 Ok(ws) => {
81 return Ok(Self {
82 pending_commands: Default::default(),
83 ws,
84 next_id: 0,
85 needs_flush: false,
86 _marker: Default::default(),
87 });
88 }
89 Err(e) => {
90 last_err = Some(e);
91 if attempt < retries {
92 let backoff_ms = INITIAL_BACKOFF_MS * 3u64.saturating_pow(attempt);
93 tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
94 }
95 }
96 }
97 }
98
99 Err(last_err.unwrap_or_else(|| CdpError::msg("connection failed")))
100 }
101
102 async fn connect_default(
104 url: &str,
105 config: WebSocketConfig,
106 ) -> Result<WebSocketStream<ConnectStream>> {
107 let (ws, _) =
108 tokio_tungstenite::connect_async_with_config(url, Some(config), *DISABLE_NAGLE).await?;
109 Ok(ws)
110 }
111
112 async fn connect_uring(
115 url: &str,
116 config: WebSocketConfig,
117 ) -> Result<WebSocketStream<ConnectStream>> {
118 use tokio_tungstenite::tungstenite::client::IntoClientRequest;
119
120 let request = url.into_client_request()?;
121 let host = request
122 .uri()
123 .host()
124 .ok_or_else(|| CdpError::msg("no host in CDP WebSocket URL"))?;
125 let port = request.uri().port_u16().unwrap_or(9222);
126
127 let addr_str = format!("{}:{}", host, port);
129 let addr: std::net::SocketAddr = match addr_str.parse() {
130 Ok(a) => a,
131 Err(_) => {
132 return Self::connect_default(url, config).await;
134 }
135 };
136
137 let std_stream = crate::uring_fs::tcp_connect(addr)
139 .await
140 .map_err(CdpError::Io)?;
141
142 std_stream.set_nonblocking(true).map_err(CdpError::Io)?;
144 if *DISABLE_NAGLE {
145 let _ = std_stream.set_nodelay(true);
146 }
147
148 let tokio_stream = tokio::net::TcpStream::from_std(std_stream).map_err(CdpError::Io)?;
150
151 let (ws, _) = tokio_tungstenite::client_async_with_config(
153 request,
154 MaybeTlsStream::Plain(tokio_stream),
155 Some(config),
156 )
157 .await?;
158
159 Ok(ws)
160 }
161}
162
163impl<T: EventMessage> Connection<T> {
164 fn next_call_id(&mut self) -> CallId {
165 let id = CallId::new(self.next_id);
166 self.next_id = self.next_id.wrapping_add(1);
167 id
168 }
169
170 pub fn submit_command(
173 &mut self,
174 method: MethodId,
175 session_id: Option<SessionId>,
176 params: serde_json::Value,
177 ) -> serde_json::Result<CallId> {
178 let id = self.next_call_id();
179 let call = MethodCall {
180 id,
181 method,
182 session_id: session_id.map(Into::into),
183 params,
184 };
185 self.pending_commands.push_back(call);
186 Ok(id)
187 }
188
189 fn start_send_next(&mut self, cx: &mut Context<'_>) -> Result<()> {
194 if self.needs_flush {
196 match self.ws.poll_flush_unpin(cx) {
197 Poll::Ready(Ok(())) => self.needs_flush = false,
198 Poll::Ready(Err(e)) => return Err(e.into()),
199 Poll::Pending => return Ok(()),
200 }
201 }
202
203 let mut sent_any = false;
205 while !self.pending_commands.is_empty() {
206 match self.ws.poll_ready_unpin(cx) {
207 Poll::Ready(Ok(())) => {
208 let Some(cmd) = self.pending_commands.pop_front() else {
209 break;
210 };
211 tracing::trace!("Sending {:?}", cmd);
212 let msg = serde_json::to_string(&cmd)?;
213 self.ws.start_send_unpin(msg.into())?;
214 sent_any = true;
215 }
216 _ => break,
217 }
218 }
219
220 if sent_any {
222 match self.ws.poll_flush_unpin(cx) {
223 Poll::Ready(Ok(())) => {}
224 Poll::Ready(Err(e)) => return Err(e.into()),
225 Poll::Pending => self.needs_flush = true,
226 }
227 }
228
229 Ok(())
230 }
231}
232
233impl<T: EventMessage + Unpin> Stream for Connection<T> {
234 type Item = Result<Box<Message<T>>>;
235
236 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
237 let pin = self.get_mut();
238
239 if let Err(err) = pin.start_send_next(cx) {
241 return Poll::Ready(Some(Err(err)));
242 }
243
244 match ready!(pin.ws.poll_next_unpin(cx)) {
246 Some(Ok(WsMessage::Text(text))) => {
247 match decode_message::<T>(text.as_bytes(), Some(&text)) {
248 Ok(msg) => Poll::Ready(Some(Ok(msg))),
249 Err(err) => {
250 tracing::debug!(
251 target: "chromiumoxide::conn::raw_ws::parse_errors",
252 "Dropping malformed text WS frame: {err}",
253 );
254 cx.waker().wake_by_ref();
255 Poll::Pending
256 }
257 }
258 }
259 Some(Ok(WsMessage::Binary(buf))) => match decode_message::<T>(&buf, None) {
260 Ok(msg) => Poll::Ready(Some(Ok(msg))),
261 Err(err) => {
262 tracing::debug!(
263 target: "chromiumoxide::conn::raw_ws::parse_errors",
264 "Dropping malformed binary WS frame: {err}",
265 );
266 cx.waker().wake_by_ref();
267 Poll::Pending
268 }
269 },
270 Some(Ok(WsMessage::Close(_))) => Poll::Ready(None),
271 Some(Ok(WsMessage::Ping(_))) | Some(Ok(WsMessage::Pong(_))) => {
273 cx.waker().wake_by_ref();
274 Poll::Pending
275 }
276 Some(Ok(msg)) => {
277 tracing::debug!(
279 target: "chromiumoxide::conn::raw_ws::parse_errors",
280 "Unexpected WS message type: {:?}",
281 msg
282 );
283 cx.waker().wake_by_ref();
284 Poll::Pending
285 }
286 Some(Err(err)) => Poll::Ready(Some(Err(CdpError::Ws(err)))),
287 None => {
288 Poll::Ready(None)
290 }
291 }
292 }
293}
294
295#[cfg(not(feature = "serde_stacker"))]
299fn decode_message<T: EventMessage>(
300 bytes: &[u8],
301 raw_text_for_logging: Option<&str>,
302) -> Result<Box<Message<T>>> {
303 match serde_json::from_slice::<Box<Message<T>>>(bytes) {
304 Ok(msg) => {
305 tracing::trace!("Received {:?}", msg);
306 Ok(msg)
307 }
308 Err(err) => {
309 if let Some(txt) = raw_text_for_logging {
310 let preview = &txt[..txt.len().min(512)];
311 tracing::debug!(
312 target: "chromiumoxide::conn::raw_ws::parse_errors",
313 msg_len = txt.len(),
314 "Skipping unrecognized WS message {err} preview={preview}",
315 );
316 } else {
317 tracing::debug!(
318 target: "chromiumoxide::conn::raw_ws::parse_errors",
319 "Skipping unrecognized binary WS message {err}",
320 );
321 }
322 Err(err.into())
323 }
324 }
325}
326
/// Deserializes one incoming WebSocket payload into a `Message<T>` with
/// serde_json's recursion limit disabled and the deserializer wrapped in
/// `serde_stacker`.
///
/// `raw_text_for_logging` is the payload as text, if available; on failure
/// at most its first 512 bytes (cut back to a character boundary) are
/// included in the debug log.
///
/// # Errors
///
/// Returns the deserialization error, converted into the crate error type.
#[cfg(feature = "serde_stacker")]
fn decode_message<T: EventMessage>(
    bytes: &[u8],
    raw_text_for_logging: Option<&str>,
) -> Result<Box<Message<T>>> {
    use serde::Deserialize;
    let mut de = serde_json::Deserializer::from_slice(bytes);

    de.disable_recursion_limit();

    let de = serde_stacker::Deserializer::new(&mut de);

    match Box::<Message<T>>::deserialize(de) {
        Ok(msg) => {
            tracing::trace!("Received {:?}", msg);
            Ok(msg)
        }
        Err(err) => {
            if let Some(txt) = raw_text_for_logging {
                // Slicing a `str` at a byte offset inside a multi-byte
                // character panics, so back off to the nearest boundary
                // (offset 0 is always one, so the loop terminates).
                let mut preview_len = txt.len().min(512);
                while !txt.is_char_boundary(preview_len) {
                    preview_len -= 1;
                }
                let preview = &txt[..preview_len];
                tracing::debug!(
                    target: "chromiumoxide::conn::raw_ws::parse_errors",
                    msg_len = txt.len(),
                    "Skipping unrecognized WS message {err} preview={preview}",
                );
            } else {
                tracing::debug!(
                    target: "chromiumoxide::conn::raw_ws::parse_errors",
                    "Skipping unrecognized binary WS message {err}",
                );
            }
            Err(err.into())
        }
    }
}