1use std::collections::VecDeque;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::ready;
5
6use futures_util::stream::{SplitSink, SplitStream};
7use futures_util::{SinkExt, Stream, StreamExt};
8use std::task::{Context, Poll};
9use tokio::sync::mpsc;
10use tokio_tungstenite::tungstenite::Message as WsMessage;
11use tokio_tungstenite::MaybeTlsStream;
12use tokio_tungstenite::{tungstenite::protocol::WebSocketConfig, WebSocketStream};
13
14use chromiumoxide_cdp::cdp::browser_protocol::target::SessionId;
15use chromiumoxide_types::{CallId, EventMessage, Message, MethodCall, MethodId};
16
17use crate::error::CdpError;
18use crate::error::Result;
19
/// Transport every CDP WebSocket in this module runs over: a Tokio TCP stream
/// wrapped in `MaybeTlsStream` (the io_uring path below always uses its
/// `Plain` variant).
type ConnectStream = MaybeTlsStream<tokio::net::TcpStream>;
21
/// A CDP connection over a single WebSocket.
///
/// Commands are queued with `submit_command` and written out while the
/// connection is polled as a [`Stream`]; the same polling yields decoded
/// incoming messages. `into_async` can instead split it into a reader half and
/// a spawned writer task.
#[must_use = "streams do nothing unless polled"]
#[derive(Debug)]
pub struct Connection<T: EventMessage> {
    /// Commands submitted but not yet handed to the WebSocket sink.
    pending_commands: VecDeque<MethodCall>,
    /// The underlying WebSocket.
    ws: WebSocketStream<ConnectStream>,
    /// Id given to the next submitted command; wraps around on overflow.
    next_id: usize,
    /// Set when flushing already-sent frames returned `Pending`; while set,
    /// `start_send_next` retries the flush before sending anything new.
    needs_flush: bool,
    /// Ties the connection to the event type `T` it decodes without storing one.
    _marker: PhantomData<T>,
}
37
38lazy_static::lazy_static! {
39 static ref DISABLE_NAGLE: bool = match std::env::var("DISABLE_NAGLE") {
41 Ok(disable_nagle) => disable_nagle == "true",
42 _ => true
43 };
44 static ref WEBSOCKET_DEFAULTS: bool = match std::env::var("WEBSOCKET_DEFAULTS") {
46 Ok(d) => d == "true",
47 _ => false
48 };
49}
50
/// Retry count used by `Connection::connect`, i.e. at most 5 connection
/// attempts in total.
pub const DEFAULT_CONNECTION_RETRIES: u32 = 4;

/// Base delay, in milliseconds, of the backoff between connection attempts:
/// the wait after failed attempt `n` (0-based) is `INITIAL_BACKOFF_MS * 3^n`,
/// capped at `MAX_BACKOFF_MS`.
const INITIAL_BACKOFF_MS: u64 = 50;

/// Upper bound, in milliseconds, on any single backoff sleep between
/// connection attempts.
const MAX_BACKOFF_MS: u64 = 2_000;
59
60impl<T: EventMessage + Unpin> Connection<T> {
61 pub async fn connect(debug_ws_url: impl AsRef<str>) -> Result<Self> {
62 Self::connect_with_retries(debug_ws_url, DEFAULT_CONNECTION_RETRIES).await
63 }
64
65 pub async fn connect_with_retries(debug_ws_url: impl AsRef<str>, retries: u32) -> Result<Self> {
66 let mut config = WebSocketConfig::default();
67
68 config.max_write_buffer_size = 4 * 1024 * 1024;
71
72 if !*WEBSOCKET_DEFAULTS {
73 config.max_message_size = None;
74 config.max_frame_size = None;
75 }
76
77 let url = debug_ws_url.as_ref();
78 let use_uring = crate::uring_fs::is_enabled();
79 let mut last_err = None;
80
81 for attempt in 0..=retries {
82 let result = if use_uring {
83 Self::connect_uring(url, config).await
84 } else {
85 Self::connect_default(url, config).await
86 };
87
88 match result {
89 Ok(ws) => {
90 return Ok(Self {
91 pending_commands: Default::default(),
92 ws,
93 next_id: 0,
94 needs_flush: false,
95 _marker: Default::default(),
96 });
97 }
98 Err(e) => {
99 let should_retry = match &e {
102 CdpError::Io(io_err)
104 if io_err.kind() == std::io::ErrorKind::ConnectionRefused =>
105 {
106 false
107 }
108 CdpError::Ws(tungstenite_err) => !matches!(
111 tungstenite_err,
112 tokio_tungstenite::tungstenite::Error::Http(_)
113 | tokio_tungstenite::tungstenite::Error::HttpFormat(_)
114 ),
115 _ => true,
116 };
117
118 last_err = Some(e);
119
120 if !should_retry {
121 break;
122 }
123
124 if attempt < retries {
125 let backoff_ms =
126 (INITIAL_BACKOFF_MS * 3u64.saturating_pow(attempt)).min(MAX_BACKOFF_MS);
127 tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
128 }
129 }
130 }
131 }
132
133 Err(last_err.unwrap_or_else(|| CdpError::msg("connection failed")))
134 }
135
136 async fn connect_default(
138 url: &str,
139 config: WebSocketConfig,
140 ) -> Result<WebSocketStream<ConnectStream>> {
141 let (ws, _) =
142 tokio_tungstenite::connect_async_with_config(url, Some(config), *DISABLE_NAGLE).await?;
143 Ok(ws)
144 }
145
146 async fn connect_uring(
149 url: &str,
150 config: WebSocketConfig,
151 ) -> Result<WebSocketStream<ConnectStream>> {
152 use tokio_tungstenite::tungstenite::client::IntoClientRequest;
153
154 let request = url.into_client_request()?;
155 let host = request
156 .uri()
157 .host()
158 .ok_or_else(|| CdpError::msg("no host in CDP WebSocket URL"))?;
159 let port = request.uri().port_u16().unwrap_or(9222);
160
161 let addr_str = format!("{}:{}", host, port);
163 let addr: std::net::SocketAddr = match addr_str.parse() {
164 Ok(a) => a,
165 Err(_) => {
166 return Self::connect_default(url, config).await;
168 }
169 };
170
171 let std_stream = crate::uring_fs::tcp_connect(addr)
173 .await
174 .map_err(CdpError::Io)?;
175
176 std_stream.set_nonblocking(true).map_err(CdpError::Io)?;
178 if *DISABLE_NAGLE {
179 let _ = std_stream.set_nodelay(true);
180 }
181
182 let tokio_stream = tokio::net::TcpStream::from_std(std_stream).map_err(CdpError::Io)?;
184
185 let (ws, _) = tokio_tungstenite::client_async_with_config(
187 request,
188 MaybeTlsStream::Plain(tokio_stream),
189 Some(config),
190 )
191 .await?;
192
193 Ok(ws)
194 }
195}
196
197impl<T: EventMessage> Connection<T> {
198 fn next_call_id(&mut self) -> CallId {
199 let id = CallId::new(self.next_id);
200 self.next_id = self.next_id.wrapping_add(1);
201 id
202 }
203
204 pub fn submit_command(
207 &mut self,
208 method: MethodId,
209 session_id: Option<SessionId>,
210 params: serde_json::Value,
211 ) -> serde_json::Result<CallId> {
212 let id = self.next_call_id();
213 let call = MethodCall {
214 id,
215 method,
216 session_id: session_id.map(Into::into),
217 params,
218 };
219 self.pending_commands.push_back(call);
220 Ok(id)
221 }
222
223 fn start_send_next(&mut self, cx: &mut Context<'_>) -> Result<()> {
228 if self.needs_flush {
230 match self.ws.poll_flush_unpin(cx) {
231 Poll::Ready(Ok(())) => self.needs_flush = false,
232 Poll::Ready(Err(e)) => return Err(e.into()),
233 Poll::Pending => return Ok(()),
234 }
235 }
236
237 let mut sent_any = false;
239 while !self.pending_commands.is_empty() {
240 match self.ws.poll_ready_unpin(cx) {
241 Poll::Ready(Ok(())) => {
242 let Some(cmd) = self.pending_commands.pop_front() else {
243 break;
244 };
245 tracing::trace!("Sending {:?}", cmd);
246 let msg = serde_json::to_string(&cmd)?;
247 self.ws.start_send_unpin(msg.into())?;
248 sent_any = true;
249 }
250 _ => break,
251 }
252 }
253
254 if sent_any {
256 match self.ws.poll_flush_unpin(cx) {
257 Poll::Ready(Ok(())) => {}
258 Poll::Ready(Err(e)) => return Err(e.into()),
259 Poll::Pending => self.needs_flush = true,
260 }
261 }
262
263 Ok(())
264 }
265}
266
/// Capacity used for the bounded command channel that `Connection::into_async`
/// creates between `AsyncConnection::cmd_tx` and the spawned writer task.
const WS_CMD_CHANNEL_CAPACITY: usize = 2048;
271
/// A [`Connection`] split by `Connection::into_async` into a reader half and a
/// spawned writer task.
#[derive(Debug)]
pub struct AsyncConnection<T: EventMessage> {
    /// Read half; yields decoded incoming messages.
    pub reader: WsReader<T>,
    /// Sender feeding the writer task, which serializes each `MethodCall` and
    /// writes it as a text frame.
    pub cmd_tx: mpsc::Sender<MethodCall>,
    /// Handle of the spawned writer task; resolves with the task's result.
    pub writer_handle: tokio::task::JoinHandle<Result<()>>,
    /// Next call id to assign, carried over from the original `Connection`.
    pub next_id: usize,
}
284
285impl<T: EventMessage + Unpin> Connection<T> {
286 pub fn into_async(self) -> AsyncConnection<T> {
293 let (ws_sink, ws_stream) = self.ws.split();
294 let (cmd_tx, cmd_rx) = mpsc::channel(WS_CMD_CHANNEL_CAPACITY);
295
296 let writer_handle = tokio::spawn(ws_write_loop(ws_sink, cmd_rx));
297
298 let reader = WsReader {
299 inner: ws_stream,
300 _marker: PhantomData,
301 };
302
303 AsyncConnection {
304 reader,
305 cmd_tx,
306 writer_handle,
307 next_id: self.next_id,
308 }
309 }
310}
311
312async fn ws_write_loop(
314 mut sink: SplitSink<WebSocketStream<ConnectStream>, WsMessage>,
315 mut rx: mpsc::Receiver<MethodCall>,
316) -> Result<()> {
317 while let Some(call) = rx.recv().await {
318 let msg = crate::serde_json::to_string(&call)?;
319 sink.feed(WsMessage::Text(msg.into()))
320 .await
321 .map_err(CdpError::Ws)?;
322
323 while let Ok(call) = rx.try_recv() {
325 let msg = crate::serde_json::to_string(&call)?;
326 sink.feed(WsMessage::Text(msg.into()))
327 .await
328 .map_err(CdpError::Ws)?;
329 }
330
331 sink.flush().await.map_err(CdpError::Ws)?;
333 }
334 Ok(())
335}
336
/// Read half of a CDP WebSocket, produced by `Connection::into_async`.
#[derive(Debug)]
pub struct WsReader<T: EventMessage> {
    /// Stream half of the split WebSocket that frames are read from.
    inner: SplitStream<WebSocketStream<ConnectStream>>,
    /// Records the event type `T` messages are decoded into.
    _marker: PhantomData<T>,
}
346
347impl<T: EventMessage + Unpin> WsReader<T> {
348 pub async fn next_message(&mut self) -> Option<Result<Box<Message<T>>>> {
352 loop {
353 match self.inner.next().await? {
354 Ok(WsMessage::Text(text)) => {
355 match decode_message::<T>(text.as_bytes(), Some(&text)) {
356 Ok(msg) => return Some(Ok(msg)),
357 Err(err) => {
358 tracing::debug!(
359 target: "chromiumoxide::conn::raw_ws::parse_errors",
360 "Dropping malformed text WS frame: {err}",
361 );
362 continue;
363 }
364 }
365 }
366 Ok(WsMessage::Binary(buf)) => match decode_message::<T>(&buf, None) {
367 Ok(msg) => return Some(Ok(msg)),
368 Err(err) => {
369 tracing::debug!(
370 target: "chromiumoxide::conn::raw_ws::parse_errors",
371 "Dropping malformed binary WS frame: {err}",
372 );
373 continue;
374 }
375 },
376 Ok(WsMessage::Close(_)) => return None,
377 Ok(WsMessage::Ping(_)) | Ok(WsMessage::Pong(_)) => continue,
378 Ok(msg) => {
379 tracing::debug!(
380 target: "chromiumoxide::conn::raw_ws::parse_errors",
381 "Unexpected WS message type: {:?}",
382 msg
383 );
384 continue;
385 }
386 Err(err) => return Some(Err(CdpError::Ws(err))),
387 }
388 }
389 }
390}
391
392impl<T: EventMessage + Unpin> Stream for Connection<T> {
393 type Item = Result<Box<Message<T>>>;
394
395 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
396 let pin = self.get_mut();
397
398 if let Err(err) = pin.start_send_next(cx) {
400 return Poll::Ready(Some(Err(err)));
401 }
402
403 const MAX_SKIPS_PER_POLL: u32 = 16;
412 let mut skips: u32 = 0;
413 loop {
414 match ready!(pin.ws.poll_next_unpin(cx)) {
415 Some(Ok(WsMessage::Text(text))) => {
416 match decode_message::<T>(text.as_bytes(), Some(&text)) {
417 Ok(msg) => return Poll::Ready(Some(Ok(msg))),
418 Err(err) => {
419 tracing::debug!(
420 target: "chromiumoxide::conn::raw_ws::parse_errors",
421 "Dropping malformed text WS frame: {err}",
422 );
423 skips += 1;
424 }
425 }
426 }
427 Some(Ok(WsMessage::Binary(buf))) => match decode_message::<T>(&buf, None) {
428 Ok(msg) => return Poll::Ready(Some(Ok(msg))),
429 Err(err) => {
430 tracing::debug!(
431 target: "chromiumoxide::conn::raw_ws::parse_errors",
432 "Dropping malformed binary WS frame: {err}",
433 );
434 skips += 1;
435 }
436 },
437 Some(Ok(WsMessage::Close(_))) => return Poll::Ready(None),
438 Some(Ok(WsMessage::Ping(_))) | Some(Ok(WsMessage::Pong(_))) => {
439 skips += 1;
440 }
441 Some(Ok(msg)) => {
442 tracing::debug!(
443 target: "chromiumoxide::conn::raw_ws::parse_errors",
444 "Unexpected WS message type: {:?}",
445 msg
446 );
447 skips += 1;
448 }
449 Some(Err(err)) => return Poll::Ready(Some(Err(CdpError::Ws(err)))),
450 None => return Poll::Ready(None),
451 }
452
453 if skips >= MAX_SKIPS_PER_POLL {
454 cx.waker().wake_by_ref();
455 return Poll::Pending;
456 }
457 }
458 }
459}
460
461#[cfg(not(feature = "serde_stacker"))]
465fn decode_message<T: EventMessage>(
466 bytes: &[u8],
467 raw_text_for_logging: Option<&str>,
468) -> Result<Box<Message<T>>> {
469 match serde_json::from_slice::<Box<Message<T>>>(bytes) {
470 Ok(msg) => {
471 tracing::trace!("Received {:?}", msg);
472 Ok(msg)
473 }
474 Err(err) => {
475 if let Some(txt) = raw_text_for_logging {
476 let preview = &txt[..txt.len().min(512)];
477 tracing::debug!(
478 target: "chromiumoxide::conn::raw_ws::parse_errors",
479 msg_len = txt.len(),
480 "Skipping unrecognized WS message {err} preview={preview}",
481 );
482 } else {
483 tracing::debug!(
484 target: "chromiumoxide::conn::raw_ws::parse_errors",
485 "Skipping unrecognized binary WS message {err}",
486 );
487 }
488 Err(err.into())
489 }
490 }
491}
492
/// Deserializes one CDP message from a WebSocket frame payload, with
/// serde_json's recursion limit disabled and the deserializer wrapped in
/// `serde_stacker::Deserializer` (presumably so deeply nested payloads grow the
/// stack instead of overflowing it).
///
/// `raw_text_for_logging` is the frame's text for text frames and `None` for
/// binary frames. On failure, at most the first 512 bytes of that text are
/// logged, cut back to the nearest UTF-8 character boundary.
///
/// NOTE(review): unlike `serde_json::from_slice`, this path never calls
/// `Deserializer::end`, so trailing bytes after the JSON value are presumably
/// not rejected.
///
/// # Errors
///
/// Returns the deserialization error, converted into the crate error type,
/// when `bytes` does not deserialize as a `Message<T>`.
#[cfg(feature = "serde_stacker")]
fn decode_message<T: EventMessage>(
    bytes: &[u8],
    raw_text_for_logging: Option<&str>,
) -> Result<Box<Message<T>>> {
    use serde::Deserialize;
    let mut de = serde_json::Deserializer::from_slice(bytes);

    de.disable_recursion_limit();

    let de = serde_stacker::Deserializer::new(&mut de);

    match Box::<Message<T>>::deserialize(de) {
        Ok(msg) => {
            tracing::trace!("Received {:?}", msg);
            Ok(msg)
        }
        Err(err) => {
            if let Some(txt) = raw_text_for_logging {
                // Slicing a `str` at a byte index inside a multi-byte character
                // panics, so back off from the 512-byte cap to the previous char
                // boundary (index 0 always is one, so this terminates).
                let mut end = txt.len().min(512);
                while !txt.is_char_boundary(end) {
                    end -= 1;
                }
                let preview = &txt[..end];
                tracing::debug!(
                    target: "chromiumoxide::conn::raw_ws::parse_errors",
                    msg_len = txt.len(),
                    "Skipping unrecognized WS message {err} preview={preview}",
                );
            } else {
                tracing::debug!(
                    target: "chromiumoxide::conn::raw_ws::parse_errors",
                    "Skipping unrecognized binary WS message {err}",
                );
            }
            Err(err.into())
        }
    }
}
530}