1use std::collections::VecDeque;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::ready;
5
6use futures_util::stream::{SplitSink, SplitStream};
7use futures_util::{SinkExt, Stream, StreamExt};
8use std::task::{Context, Poll};
9use tokio::sync::mpsc;
10use tokio_tungstenite::tungstenite::Message as WsMessage;
11use tokio_tungstenite::MaybeTlsStream;
12use tokio_tungstenite::{tungstenite::protocol::WebSocketConfig, WebSocketStream};
13
14use chromiumoxide_cdp::cdp::browser_protocol::target::SessionId;
15use chromiumoxide_types::{CallId, EventMessage, Message, MethodCall, MethodId};
16
17use crate::error::CdpError;
18use crate::error::Result;
19
20type ConnectStream = MaybeTlsStream<tokio::net::TcpStream>;
21
22#[must_use = "streams do nothing unless polled"]
24#[derive(Debug)]
25pub struct Connection<T: EventMessage> {
26 pending_commands: VecDeque<MethodCall>,
28 ws: WebSocketStream<ConnectStream>,
30 next_id: usize,
32 needs_flush: bool,
34 _marker: PhantomData<T>,
36}
37
38lazy_static::lazy_static! {
39 static ref DISABLE_NAGLE: bool = match std::env::var("DISABLE_NAGLE") {
41 Ok(disable_nagle) => disable_nagle == "true",
42 _ => true
43 };
44 static ref WEBSOCKET_DEFAULTS: bool = match std::env::var("WEBSOCKET_DEFAULTS") {
46 Ok(d) => d == "true",
47 _ => false
48 };
49}
50
/// Number of retries `Connection::connect` performs after the first failed
/// attempt.
pub const DEFAULT_CONNECTION_RETRIES: u32 = 4;

/// Delay, in milliseconds, before the first retry; it is multiplied by three
/// for every further attempt.
const INITIAL_BACKOFF_MS: u64 = 50;

/// Upper bound, in milliseconds, on the delay between two attempts.
const MAX_BACKOFF_MS: u64 = 2_000;
59
60impl<T: EventMessage + Unpin> Connection<T> {
61 pub async fn connect(debug_ws_url: impl AsRef<str>) -> Result<Self> {
62 Self::connect_with_retries(debug_ws_url, DEFAULT_CONNECTION_RETRIES).await
63 }
64
65 pub async fn connect_with_retries(debug_ws_url: impl AsRef<str>, retries: u32) -> Result<Self> {
66 let mut config = WebSocketConfig::default();
67
68 config.max_write_buffer_size = 4 * 1024 * 1024;
71
72 if !*WEBSOCKET_DEFAULTS {
73 config.max_message_size = None;
74 config.max_frame_size = None;
75 }
76
77 let url = debug_ws_url.as_ref();
78 let use_uring = crate::uring_fs::is_enabled();
79 let mut last_err = None;
80
81 for attempt in 0..=retries {
82 let result = if use_uring {
83 Self::connect_uring(url, config).await
84 } else {
85 Self::connect_default(url, config).await
86 };
87
88 match result {
89 Ok(ws) => {
90 return Ok(Self {
91 pending_commands: Default::default(),
92 ws,
93 next_id: 0,
94 needs_flush: false,
95 _marker: Default::default(),
96 });
97 }
98 Err(e) => {
99 let should_retry = match &e {
102 CdpError::Io(io_err)
104 if io_err.kind() == std::io::ErrorKind::ConnectionRefused =>
105 {
106 false
107 }
108 CdpError::Ws(tungstenite_err) => !matches!(
111 tungstenite_err,
112 tokio_tungstenite::tungstenite::Error::Http(_)
113 | tokio_tungstenite::tungstenite::Error::HttpFormat(_)
114 ),
115 _ => true,
116 };
117
118 last_err = Some(e);
119
120 if !should_retry {
121 break;
122 }
123
124 if attempt < retries {
125 let backoff_ms =
126 (INITIAL_BACKOFF_MS * 3u64.saturating_pow(attempt)).min(MAX_BACKOFF_MS);
127 tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
128 }
129 }
130 }
131 }
132
133 Err(last_err.unwrap_or_else(|| CdpError::msg("connection failed")))
134 }
135
136 async fn connect_default(
138 url: &str,
139 config: WebSocketConfig,
140 ) -> Result<WebSocketStream<ConnectStream>> {
141 let (ws, _) =
142 tokio_tungstenite::connect_async_with_config(url, Some(config), *DISABLE_NAGLE).await?;
143 Ok(ws)
144 }
145
146 async fn connect_uring(
149 url: &str,
150 config: WebSocketConfig,
151 ) -> Result<WebSocketStream<ConnectStream>> {
152 use tokio_tungstenite::tungstenite::client::IntoClientRequest;
153
154 let request = url.into_client_request()?;
155 let host = request
156 .uri()
157 .host()
158 .ok_or_else(|| CdpError::msg("no host in CDP WebSocket URL"))?;
159 let port = request.uri().port_u16().unwrap_or(9222);
160
161 let addr_str = format!("{}:{}", host, port);
163 let addr: std::net::SocketAddr = match addr_str.parse() {
164 Ok(a) => a,
165 Err(_) => {
166 return Self::connect_default(url, config).await;
168 }
169 };
170
171 let std_stream = crate::uring_fs::tcp_connect(addr)
173 .await
174 .map_err(CdpError::Io)?;
175
176 std_stream.set_nonblocking(true).map_err(CdpError::Io)?;
178 if *DISABLE_NAGLE {
179 let _ = std_stream.set_nodelay(true);
180 }
181
182 let tokio_stream = tokio::net::TcpStream::from_std(std_stream).map_err(CdpError::Io)?;
184
185 let (ws, _) = tokio_tungstenite::client_async_with_config(
187 request,
188 MaybeTlsStream::Plain(tokio_stream),
189 Some(config),
190 )
191 .await?;
192
193 Ok(ws)
194 }
195}
196
197impl<T: EventMessage> Connection<T> {
198 fn next_call_id(&mut self) -> CallId {
199 let id = CallId::new(self.next_id);
200 self.next_id = self.next_id.wrapping_add(1);
201 id
202 }
203
204 pub fn submit_command(
207 &mut self,
208 method: MethodId,
209 session_id: Option<SessionId>,
210 params: serde_json::Value,
211 ) -> serde_json::Result<CallId> {
212 let id = self.next_call_id();
213 let call = MethodCall {
214 id,
215 method,
216 session_id: session_id.map(Into::into),
217 params,
218 };
219 self.pending_commands.push_back(call);
220 Ok(id)
221 }
222
223 fn start_send_next(&mut self, cx: &mut Context<'_>) -> Result<()> {
228 if self.needs_flush {
230 match self.ws.poll_flush_unpin(cx) {
231 Poll::Ready(Ok(())) => self.needs_flush = false,
232 Poll::Ready(Err(e)) => return Err(e.into()),
233 Poll::Pending => return Ok(()),
234 }
235 }
236
237 let mut sent_any = false;
239 while !self.pending_commands.is_empty() {
240 match self.ws.poll_ready_unpin(cx) {
241 Poll::Ready(Ok(())) => {
242 let Some(cmd) = self.pending_commands.pop_front() else {
243 break;
244 };
245 tracing::trace!("Sending {:?}", cmd);
246 let msg = serde_json::to_string(&cmd)?;
247 self.ws.start_send_unpin(msg.into())?;
248 sent_any = true;
249 }
250 _ => break,
251 }
252 }
253
254 if sent_any {
256 match self.ws.poll_flush_unpin(cx) {
257 Poll::Ready(Ok(())) => {}
258 Poll::Ready(Err(e)) => return Err(e.into()),
259 Poll::Pending => self.needs_flush = true,
260 }
261 }
262
263 Ok(())
264 }
265}
266
/// Capacity of the bounded channel that carries method calls to the
/// background WebSocket writer task.
const WS_CMD_CHANNEL_CAPACITY: usize = 2048;
271
272#[derive(Debug)]
274pub struct AsyncConnection<T: EventMessage> {
275 pub reader: WsReader<T>,
277 pub cmd_tx: mpsc::Sender<MethodCall>,
279 pub writer_handle: tokio::task::JoinHandle<Result<()>>,
281 pub next_id: usize,
283}
284
285impl<T: EventMessage + Unpin> Connection<T> {
286 pub fn into_async(self) -> AsyncConnection<T> {
293 let (ws_sink, ws_stream) = self.ws.split();
294 let (cmd_tx, cmd_rx) = mpsc::channel(WS_CMD_CHANNEL_CAPACITY);
295
296 let writer_handle = tokio::spawn(ws_write_loop(ws_sink, cmd_rx));
297
298 let reader = WsReader {
299 inner: ws_stream,
300 _marker: PhantomData,
301 };
302
303 AsyncConnection {
304 reader,
305 cmd_tx,
306 writer_handle,
307 next_id: self.next_id,
308 }
309 }
310}
311
312async fn ws_write_loop(
314 mut sink: SplitSink<WebSocketStream<ConnectStream>, WsMessage>,
315 mut rx: mpsc::Receiver<MethodCall>,
316) -> Result<()> {
317 while let Some(call) = rx.recv().await {
318 let msg = crate::serde_json::to_string(&call)?;
319 sink.feed(WsMessage::Text(msg.into()))
320 .await
321 .map_err(CdpError::Ws)?;
322
323 while let Ok(call) = rx.try_recv() {
325 let msg = crate::serde_json::to_string(&call)?;
326 sink.feed(WsMessage::Text(msg.into()))
327 .await
328 .map_err(CdpError::Ws)?;
329 }
330
331 sink.flush().await.map_err(CdpError::Ws)?;
333 }
334 Ok(())
335}
336
337#[derive(Debug)]
342pub struct WsReader<T: EventMessage> {
343 inner: SplitStream<WebSocketStream<ConnectStream>>,
344 _marker: PhantomData<T>,
345}
346
347impl<T: EventMessage + Unpin> WsReader<T> {
348 pub async fn next_message(&mut self) -> Option<Result<Box<Message<T>>>> {
352 loop {
353 match self.inner.next().await? {
354 Ok(WsMessage::Text(text)) => {
355 match decode_message::<T>(text.as_bytes(), Some(&text)) {
356 Ok(msg) => return Some(Ok(msg)),
357 Err(err) => {
358 tracing::debug!(
359 target: "chromiumoxide::conn::raw_ws::parse_errors",
360 "Dropping malformed text WS frame: {err}",
361 );
362 continue;
363 }
364 }
365 }
366 Ok(WsMessage::Binary(buf)) => match decode_message::<T>(&buf, None) {
367 Ok(msg) => return Some(Ok(msg)),
368 Err(err) => {
369 tracing::debug!(
370 target: "chromiumoxide::conn::raw_ws::parse_errors",
371 "Dropping malformed binary WS frame: {err}",
372 );
373 continue;
374 }
375 },
376 Ok(WsMessage::Close(_)) => return None,
377 Ok(WsMessage::Ping(_)) | Ok(WsMessage::Pong(_)) => continue,
378 Ok(msg) => {
379 tracing::debug!(
380 target: "chromiumoxide::conn::raw_ws::parse_errors",
381 "Unexpected WS message type: {:?}",
382 msg
383 );
384 continue;
385 }
386 Err(err) => return Some(Err(CdpError::Ws(err))),
387 }
388 }
389 }
390}
391
392impl<T: EventMessage + Unpin> Stream for Connection<T> {
393 type Item = Result<Box<Message<T>>>;
394
395 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
396 let pin = self.get_mut();
397
398 if let Err(err) = pin.start_send_next(cx) {
400 return Poll::Ready(Some(Err(err)));
401 }
402
403 loop {
407 match ready!(pin.ws.poll_next_unpin(cx)) {
408 Some(Ok(WsMessage::Text(text))) => {
409 match decode_message::<T>(text.as_bytes(), Some(&text)) {
410 Ok(msg) => return Poll::Ready(Some(Ok(msg))),
411 Err(err) => {
412 tracing::debug!(
413 target: "chromiumoxide::conn::raw_ws::parse_errors",
414 "Dropping malformed text WS frame: {err}",
415 );
416 continue;
417 }
418 }
419 }
420 Some(Ok(WsMessage::Binary(buf))) => match decode_message::<T>(&buf, None) {
421 Ok(msg) => return Poll::Ready(Some(Ok(msg))),
422 Err(err) => {
423 tracing::debug!(
424 target: "chromiumoxide::conn::raw_ws::parse_errors",
425 "Dropping malformed binary WS frame: {err}",
426 );
427 continue;
428 }
429 },
430 Some(Ok(WsMessage::Close(_))) => return Poll::Ready(None),
431 Some(Ok(WsMessage::Ping(_))) | Some(Ok(WsMessage::Pong(_))) => continue,
433 Some(Ok(msg)) => {
434 tracing::debug!(
435 target: "chromiumoxide::conn::raw_ws::parse_errors",
436 "Unexpected WS message type: {:?}",
437 msg
438 );
439 continue;
440 }
441 Some(Err(err)) => return Poll::Ready(Some(Err(CdpError::Ws(err)))),
442 None => return Poll::Ready(None),
443 }
444 }
445 }
446}
447
448#[cfg(not(feature = "serde_stacker"))]
452fn decode_message<T: EventMessage>(
453 bytes: &[u8],
454 raw_text_for_logging: Option<&str>,
455) -> Result<Box<Message<T>>> {
456 match serde_json::from_slice::<Box<Message<T>>>(bytes) {
457 Ok(msg) => {
458 tracing::trace!("Received {:?}", msg);
459 Ok(msg)
460 }
461 Err(err) => {
462 if let Some(txt) = raw_text_for_logging {
463 let preview = &txt[..txt.len().min(512)];
464 tracing::debug!(
465 target: "chromiumoxide::conn::raw_ws::parse_errors",
466 msg_len = txt.len(),
467 "Skipping unrecognized WS message {err} preview={preview}",
468 );
469 } else {
470 tracing::debug!(
471 target: "chromiumoxide::conn::raw_ws::parse_errors",
472 "Skipping unrecognized binary WS message {err}",
473 );
474 }
475 Err(err.into())
476 }
477 }
478}
479
/// Decodes a WebSocket payload into a [`Message`] through `serde_stacker`,
/// with `serde_json`'s recursion limit disabled.
///
/// `bytes` is the raw payload. `raw_text_for_logging` is the same payload as
/// text when it came from a text frame; it is only used to log a preview of
/// at most 512 bytes when decoding fails.
///
/// # Errors
///
/// Returns the deserialization error, converted into the crate error type,
/// when the payload cannot be deserialized. The failure is also logged at
/// debug level.
#[cfg(feature = "serde_stacker")]
fn decode_message<T: EventMessage>(
    bytes: &[u8],
    raw_text_for_logging: Option<&str>,
) -> Result<Box<Message<T>>> {
    use serde::Deserialize;
    let mut de = serde_json::Deserializer::from_slice(bytes);

    // The recursion limit is lifted here; the `serde_stacker` wrapper below
    // takes over the handling of deeply nested input.
    de.disable_recursion_limit();

    let de = serde_stacker::Deserializer::new(&mut de);

    match Box::<Message<T>>::deserialize(de) {
        Ok(msg) => {
            tracing::trace!("Received {:?}", msg);
            Ok(msg)
        }
        Err(err) => {
            if let Some(txt) = raw_text_for_logging {
                // Slicing a `str` in the middle of a multi-byte character
                // panics, so back off to the nearest char boundary at or
                // below the 512-byte limit. Index 0 is always a boundary, so
                // the loop terminates.
                let mut end = txt.len().min(512);
                while !txt.is_char_boundary(end) {
                    end -= 1;
                }
                let preview = &txt[..end];
                tracing::debug!(
                    target: "chromiumoxide::conn::raw_ws::parse_errors",
                    msg_len = txt.len(),
                    "Skipping unrecognized WS message {err} preview={preview}",
                );
            } else {
                tracing::debug!(
                    target: "chromiumoxide::conn::raw_ws::parse_errors",
                    "Skipping unrecognized binary WS message {err}",
                );
            }
            Err(err.into())
        }
    }
}