1use std::collections::VecDeque;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::ready;
5
6use futures_util::stream::{SplitSink, SplitStream};
7use futures_util::{SinkExt, Stream, StreamExt};
8use std::task::{Context, Poll};
9use tokio::sync::mpsc;
10use tokio_tungstenite::tungstenite::Message as WsMessage;
11use tokio_tungstenite::MaybeTlsStream;
12use tokio_tungstenite::{tungstenite::protocol::WebSocketConfig, WebSocketStream};
13
14use chromiumoxide_cdp::cdp::browser_protocol::target::SessionId;
15use chromiumoxide_types::{CallId, EventMessage, Message, MethodCall, MethodId};
16
17use crate::error::CdpError;
18use crate::error::Result;
19
20type ConnectStream = MaybeTlsStream<tokio::net::TcpStream>;
21
22#[must_use = "streams do nothing unless polled"]
24#[derive(Debug)]
25pub struct Connection<T: EventMessage> {
26 pending_commands: VecDeque<MethodCall>,
28 ws: WebSocketStream<ConnectStream>,
30 next_id: usize,
32 needs_flush: bool,
34 _marker: PhantomData<T>,
36}
37
38lazy_static::lazy_static! {
39 static ref DISABLE_NAGLE: bool = match std::env::var("DISABLE_NAGLE") {
41 Ok(disable_nagle) => disable_nagle == "true",
42 _ => true
43 };
44 static ref WEBSOCKET_DEFAULTS: bool = match std::env::var("WEBSOCKET_DEFAULTS") {
46 Ok(d) => d == "true",
47 _ => false
48 };
49}
50
/// Number of retries `Connection::connect` allows after the first failed attempt.
pub const DEFAULT_CONNECTION_RETRIES: u32 = 4;

/// Delay before the first retry, in milliseconds. Each later retry multiplies
/// it by a further factor of three.
const INITIAL_BACKOFF_MS: u64 = 50;

/// Upper bound on the delay between two connection attempts, in milliseconds.
const MAX_BACKOFF_MS: u64 = 2_000;
59
60impl<T: EventMessage + Unpin> Connection<T> {
61 pub async fn connect(debug_ws_url: impl AsRef<str>) -> Result<Self> {
62 Self::connect_with_retries(debug_ws_url, DEFAULT_CONNECTION_RETRIES).await
63 }
64
65 pub async fn connect_with_retries(debug_ws_url: impl AsRef<str>, retries: u32) -> Result<Self> {
66 let mut config = WebSocketConfig::default();
67
68 config.max_write_buffer_size = 4 * 1024 * 1024;
71
72 if !*WEBSOCKET_DEFAULTS {
73 config.max_message_size = None;
74 config.max_frame_size = None;
75 }
76
77 let url = debug_ws_url.as_ref();
78 let use_uring = crate::uring_fs::is_enabled();
79 let mut last_err = None;
80
81 for attempt in 0..=retries {
82 let result = if use_uring {
83 Self::connect_uring(url, config).await
84 } else {
85 Self::connect_default(url, config).await
86 };
87
88 match result {
89 Ok(ws) => {
90 return Ok(Self {
91 pending_commands: Default::default(),
92 ws,
93 next_id: 0,
94 needs_flush: false,
95 _marker: Default::default(),
96 });
97 }
98 Err(e) => {
99 let should_retry = match &e {
102 CdpError::Io(io_err)
104 if io_err.kind() == std::io::ErrorKind::ConnectionRefused =>
105 {
106 false
107 }
108 CdpError::Ws(tungstenite_err) => {
111 !matches!(
112 tungstenite_err,
113 tokio_tungstenite::tungstenite::Error::Http(_)
114 | tokio_tungstenite::tungstenite::Error::HttpFormat(_)
115 )
116 }
117 _ => true,
118 };
119
120 last_err = Some(e);
121
122 if !should_retry {
123 break;
124 }
125
126 if attempt < retries {
127 let backoff_ms = (INITIAL_BACKOFF_MS * 3u64.saturating_pow(attempt))
128 .min(MAX_BACKOFF_MS);
129 tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
130 }
131 }
132 }
133 }
134
135 Err(last_err.unwrap_or_else(|| CdpError::msg("connection failed")))
136 }
137
138 async fn connect_default(
140 url: &str,
141 config: WebSocketConfig,
142 ) -> Result<WebSocketStream<ConnectStream>> {
143 let (ws, _) =
144 tokio_tungstenite::connect_async_with_config(url, Some(config), *DISABLE_NAGLE).await?;
145 Ok(ws)
146 }
147
148 async fn connect_uring(
151 url: &str,
152 config: WebSocketConfig,
153 ) -> Result<WebSocketStream<ConnectStream>> {
154 use tokio_tungstenite::tungstenite::client::IntoClientRequest;
155
156 let request = url.into_client_request()?;
157 let host = request
158 .uri()
159 .host()
160 .ok_or_else(|| CdpError::msg("no host in CDP WebSocket URL"))?;
161 let port = request.uri().port_u16().unwrap_or(9222);
162
163 let addr_str = format!("{}:{}", host, port);
165 let addr: std::net::SocketAddr = match addr_str.parse() {
166 Ok(a) => a,
167 Err(_) => {
168 return Self::connect_default(url, config).await;
170 }
171 };
172
173 let std_stream = crate::uring_fs::tcp_connect(addr)
175 .await
176 .map_err(CdpError::Io)?;
177
178 std_stream.set_nonblocking(true).map_err(CdpError::Io)?;
180 if *DISABLE_NAGLE {
181 let _ = std_stream.set_nodelay(true);
182 }
183
184 let tokio_stream = tokio::net::TcpStream::from_std(std_stream).map_err(CdpError::Io)?;
186
187 let (ws, _) = tokio_tungstenite::client_async_with_config(
189 request,
190 MaybeTlsStream::Plain(tokio_stream),
191 Some(config),
192 )
193 .await?;
194
195 Ok(ws)
196 }
197}
198
199impl<T: EventMessage> Connection<T> {
200 fn next_call_id(&mut self) -> CallId {
201 let id = CallId::new(self.next_id);
202 self.next_id = self.next_id.wrapping_add(1);
203 id
204 }
205
206 pub fn submit_command(
209 &mut self,
210 method: MethodId,
211 session_id: Option<SessionId>,
212 params: serde_json::Value,
213 ) -> serde_json::Result<CallId> {
214 let id = self.next_call_id();
215 let call = MethodCall {
216 id,
217 method,
218 session_id: session_id.map(Into::into),
219 params,
220 };
221 self.pending_commands.push_back(call);
222 Ok(id)
223 }
224
225 fn start_send_next(&mut self, cx: &mut Context<'_>) -> Result<()> {
230 if self.needs_flush {
232 match self.ws.poll_flush_unpin(cx) {
233 Poll::Ready(Ok(())) => self.needs_flush = false,
234 Poll::Ready(Err(e)) => return Err(e.into()),
235 Poll::Pending => return Ok(()),
236 }
237 }
238
239 let mut sent_any = false;
241 while !self.pending_commands.is_empty() {
242 match self.ws.poll_ready_unpin(cx) {
243 Poll::Ready(Ok(())) => {
244 let Some(cmd) = self.pending_commands.pop_front() else {
245 break;
246 };
247 tracing::trace!("Sending {:?}", cmd);
248 let msg = serde_json::to_string(&cmd)?;
249 self.ws.start_send_unpin(msg.into())?;
250 sent_any = true;
251 }
252 _ => break,
253 }
254 }
255
256 if sent_any {
258 match self.ws.poll_flush_unpin(cx) {
259 Poll::Ready(Ok(())) => {}
260 Poll::Ready(Err(e)) => return Err(e.into()),
261 Poll::Pending => self.needs_flush = true,
262 }
263 }
264
265 Ok(())
266 }
267}
268
/// Capacity of the bounded channel that carries commands to the WebSocket writer task.
const WS_CMD_CHANNEL_CAPACITY: usize = 2_048;
273
274#[derive(Debug)]
276pub struct AsyncConnection<T: EventMessage> {
277 pub reader: WsReader<T>,
279 pub cmd_tx: mpsc::Sender<MethodCall>,
281 pub writer_handle: tokio::task::JoinHandle<Result<()>>,
283 pub next_id: usize,
285}
286
287impl<T: EventMessage + Unpin> Connection<T> {
288 pub fn into_async(self) -> AsyncConnection<T> {
295 let (ws_sink, ws_stream) = self.ws.split();
296 let (cmd_tx, cmd_rx) = mpsc::channel(WS_CMD_CHANNEL_CAPACITY);
297
298 let writer_handle = tokio::spawn(ws_write_loop(ws_sink, cmd_rx));
299
300 let reader = WsReader {
301 inner: ws_stream,
302 _marker: PhantomData,
303 };
304
305 AsyncConnection {
306 reader,
307 cmd_tx,
308 writer_handle,
309 next_id: self.next_id,
310 }
311 }
312}
313
314async fn ws_write_loop(
316 mut sink: SplitSink<WebSocketStream<ConnectStream>, WsMessage>,
317 mut rx: mpsc::Receiver<MethodCall>,
318) -> Result<()> {
319 while let Some(call) = rx.recv().await {
320 let msg = crate::serde_json::to_string(&call)?;
321 sink.feed(WsMessage::Text(msg.into()))
322 .await
323 .map_err(CdpError::Ws)?;
324
325 while let Ok(call) = rx.try_recv() {
327 let msg = crate::serde_json::to_string(&call)?;
328 sink.feed(WsMessage::Text(msg.into()))
329 .await
330 .map_err(CdpError::Ws)?;
331 }
332
333 sink.flush().await.map_err(CdpError::Ws)?;
335 }
336 Ok(())
337}
338
339#[derive(Debug)]
344pub struct WsReader<T: EventMessage> {
345 inner: SplitStream<WebSocketStream<ConnectStream>>,
346 _marker: PhantomData<T>,
347}
348
349impl<T: EventMessage + Unpin> WsReader<T> {
350 pub async fn next_message(&mut self) -> Option<Result<Box<Message<T>>>> {
354 loop {
355 match self.inner.next().await? {
356 Ok(WsMessage::Text(text)) => {
357 match decode_message::<T>(text.as_bytes(), Some(&text)) {
358 Ok(msg) => return Some(Ok(msg)),
359 Err(err) => {
360 tracing::debug!(
361 target: "chromiumoxide::conn::raw_ws::parse_errors",
362 "Dropping malformed text WS frame: {err}",
363 );
364 continue;
365 }
366 }
367 }
368 Ok(WsMessage::Binary(buf)) => match decode_message::<T>(&buf, None) {
369 Ok(msg) => return Some(Ok(msg)),
370 Err(err) => {
371 tracing::debug!(
372 target: "chromiumoxide::conn::raw_ws::parse_errors",
373 "Dropping malformed binary WS frame: {err}",
374 );
375 continue;
376 }
377 },
378 Ok(WsMessage::Close(_)) => return None,
379 Ok(WsMessage::Ping(_)) | Ok(WsMessage::Pong(_)) => continue,
380 Ok(msg) => {
381 tracing::debug!(
382 target: "chromiumoxide::conn::raw_ws::parse_errors",
383 "Unexpected WS message type: {:?}",
384 msg
385 );
386 continue;
387 }
388 Err(err) => return Some(Err(CdpError::Ws(err))),
389 }
390 }
391 }
392}
393
394impl<T: EventMessage + Unpin> Stream for Connection<T> {
395 type Item = Result<Box<Message<T>>>;
396
397 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
398 let pin = self.get_mut();
399
400 if let Err(err) = pin.start_send_next(cx) {
402 return Poll::Ready(Some(Err(err)));
403 }
404
405 loop {
409 match ready!(pin.ws.poll_next_unpin(cx)) {
410 Some(Ok(WsMessage::Text(text))) => {
411 match decode_message::<T>(text.as_bytes(), Some(&text)) {
412 Ok(msg) => return Poll::Ready(Some(Ok(msg))),
413 Err(err) => {
414 tracing::debug!(
415 target: "chromiumoxide::conn::raw_ws::parse_errors",
416 "Dropping malformed text WS frame: {err}",
417 );
418 continue;
419 }
420 }
421 }
422 Some(Ok(WsMessage::Binary(buf))) => match decode_message::<T>(&buf, None) {
423 Ok(msg) => return Poll::Ready(Some(Ok(msg))),
424 Err(err) => {
425 tracing::debug!(
426 target: "chromiumoxide::conn::raw_ws::parse_errors",
427 "Dropping malformed binary WS frame: {err}",
428 );
429 continue;
430 }
431 },
432 Some(Ok(WsMessage::Close(_))) => return Poll::Ready(None),
433 Some(Ok(WsMessage::Ping(_))) | Some(Ok(WsMessage::Pong(_))) => continue,
435 Some(Ok(msg)) => {
436 tracing::debug!(
437 target: "chromiumoxide::conn::raw_ws::parse_errors",
438 "Unexpected WS message type: {:?}",
439 msg
440 );
441 continue;
442 }
443 Some(Err(err)) => return Poll::Ready(Some(Err(CdpError::Ws(err)))),
444 None => return Poll::Ready(None),
445 }
446 }
447 }
448}
449
450#[cfg(not(feature = "serde_stacker"))]
454fn decode_message<T: EventMessage>(
455 bytes: &[u8],
456 raw_text_for_logging: Option<&str>,
457) -> Result<Box<Message<T>>> {
458 match serde_json::from_slice::<Box<Message<T>>>(bytes) {
459 Ok(msg) => {
460 tracing::trace!("Received {:?}", msg);
461 Ok(msg)
462 }
463 Err(err) => {
464 if let Some(txt) = raw_text_for_logging {
465 let preview = &txt[..txt.len().min(512)];
466 tracing::debug!(
467 target: "chromiumoxide::conn::raw_ws::parse_errors",
468 msg_len = txt.len(),
469 "Skipping unrecognized WS message {err} preview={preview}",
470 );
471 } else {
472 tracing::debug!(
473 target: "chromiumoxide::conn::raw_ws::parse_errors",
474 "Skipping unrecognized binary WS message {err}",
475 );
476 }
477 Err(err.into())
478 }
479 }
480}
481
/// Deserializes `bytes` as a JSON-encoded `Message<T>` with serde_json's
/// recursion limit disabled and the deserializer wrapped in
/// `serde_stacker::Deserializer`.
///
/// On failure the error is logged at debug level. When
/// `raw_text_for_logging` is provided, the log entry also carries the text
/// length and a preview of at most 512 bytes of it.
///
/// # Errors
///
/// Returns the deserialization error converted into the crate error type.
#[cfg(feature = "serde_stacker")]
fn decode_message<T: EventMessage>(
    bytes: &[u8],
    raw_text_for_logging: Option<&str>,
) -> Result<Box<Message<T>>> {
    use serde::Deserialize;
    let mut de = serde_json::Deserializer::from_slice(bytes);

    de.disable_recursion_limit();

    let de = serde_stacker::Deserializer::new(&mut de);

    match Box::<Message<T>>::deserialize(de) {
        Ok(msg) => {
            tracing::trace!("Received {:?}", msg);
            Ok(msg)
        }
        Err(err) => {
            if let Some(txt) = raw_text_for_logging {
                // Back off to a char boundary: slicing a `str` in the middle
                // of a multi-byte character panics.
                let mut end = txt.len().min(512);
                while !txt.is_char_boundary(end) {
                    end -= 1;
                }
                let preview = &txt[..end];
                tracing::debug!(
                    target: "chromiumoxide::conn::raw_ws::parse_errors",
                    msg_len = txt.len(),
                    "Skipping unrecognized WS message {err} preview={preview}",
                );
            } else {
                tracing::debug!(
                    target: "chromiumoxide::conn::raw_ws::parse_errors",
                    "Skipping unrecognized binary WS message {err}",
                );
            }
            Err(err.into())
        }
    }
}