1use std::collections::VecDeque;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::ready;
5
6use futures_util::stream::SplitSink;
7use futures_util::{SinkExt, Stream, StreamExt};
8use std::task::{Context, Poll};
9use tokio::sync::mpsc;
10use tokio_tungstenite::tungstenite::Message as WsMessage;
11use tokio_tungstenite::MaybeTlsStream;
12use tokio_tungstenite::{tungstenite::protocol::WebSocketConfig, WebSocketStream};
13
14use chromiumoxide_cdp::cdp::browser_protocol::target::SessionId;
15use chromiumoxide_types::{CallId, EventMessage, Message, MethodCall, MethodId};
16
17use crate::error::CdpError;
18use crate::error::Result;
19
/// Transport beneath the CDP WebSocket: a tokio TCP stream wrapped in
/// tokio-tungstenite's `MaybeTlsStream` (plain or TLS).
type ConnectStream = MaybeTlsStream<tokio::net::TcpStream>;
21
/// A CDP connection over a single WebSocket, consumed as a [`Stream`] of
/// decoded messages.
///
/// Commands queued with `submit_command` are only written to the socket when
/// the stream is polled.
#[must_use = "streams do nothing unless polled"]
#[derive(Debug)]
pub struct Connection<T: EventMessage> {
    /// Commands submitted but not yet handed to the WebSocket sink.
    pending_commands: VecDeque<MethodCall>,
    /// The underlying WebSocket.
    ws: WebSocketStream<ConnectStream>,
    /// Id assigned to the next submitted command (wraps on overflow).
    next_id: usize,
    /// Set when commands were written but the follow-up flush returned `Pending`.
    needs_flush: bool,
    _marker: PhantomData<T>,
}
37
38lazy_static::lazy_static! {
39 static ref DISABLE_NAGLE: bool = match std::env::var("DISABLE_NAGLE") {
41 Ok(disable_nagle) => disable_nagle == "true",
42 _ => true
43 };
44 static ref WEBSOCKET_DEFAULTS: bool = match std::env::var("WEBSOCKET_DEFAULTS") {
46 Ok(d) => d == "true",
47 _ => false
48 };
49}
50
/// Retries `Connection::connect` makes after the first failed attempt.
pub const DEFAULT_CONNECTION_RETRIES: u32 = 4;

/// Sleep before the first retry, in milliseconds; tripled on each later attempt.
const INITIAL_BACKOFF_MS: u64 = 50;

/// Upper bound on the sleep between connection attempts, in milliseconds.
const MAX_BACKOFF_MS: u64 = 2_000;
59
60impl<T: EventMessage + Unpin> Connection<T> {
61 pub async fn connect(debug_ws_url: impl AsRef<str>) -> Result<Self> {
62 Self::connect_with_retries(debug_ws_url, DEFAULT_CONNECTION_RETRIES).await
63 }
64
65 pub async fn connect_with_retries(debug_ws_url: impl AsRef<str>, retries: u32) -> Result<Self> {
66 let mut config = WebSocketConfig::default();
67
68 config.max_write_buffer_size = 4 * 1024 * 1024;
71
72 if !*WEBSOCKET_DEFAULTS {
73 config.max_message_size = None;
74 config.max_frame_size = None;
75 }
76
77 let url = debug_ws_url.as_ref();
78 let use_uring = crate::uring_fs::is_enabled();
79 let mut last_err = None;
80
81 for attempt in 0..=retries {
82 let result = if use_uring {
83 Self::connect_uring(url, config).await
84 } else {
85 Self::connect_default(url, config).await
86 };
87
88 match result {
89 Ok(ws) => {
90 return Ok(Self {
91 pending_commands: Default::default(),
92 ws,
93 next_id: 0,
94 needs_flush: false,
95 _marker: Default::default(),
96 });
97 }
98 Err(e) => {
99 let should_retry = match &e {
102 CdpError::Io(io_err)
104 if io_err.kind() == std::io::ErrorKind::ConnectionRefused =>
105 {
106 false
107 }
108 CdpError::Ws(tungstenite_err) => !matches!(
111 tungstenite_err,
112 tokio_tungstenite::tungstenite::Error::Http(_)
113 | tokio_tungstenite::tungstenite::Error::HttpFormat(_)
114 ),
115 _ => true,
116 };
117
118 last_err = Some(e);
119
120 if !should_retry {
121 break;
122 }
123
124 if attempt < retries {
125 let backoff_ms =
126 (INITIAL_BACKOFF_MS * 3u64.saturating_pow(attempt)).min(MAX_BACKOFF_MS);
127 tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
128 }
129 }
130 }
131 }
132
133 Err(last_err.unwrap_or_else(|| CdpError::msg("connection failed")))
134 }
135
136 async fn connect_default(
138 url: &str,
139 config: WebSocketConfig,
140 ) -> Result<WebSocketStream<ConnectStream>> {
141 let (ws, _) =
142 tokio_tungstenite::connect_async_with_config(url, Some(config), *DISABLE_NAGLE).await?;
143 Ok(ws)
144 }
145
146 async fn connect_uring(
149 url: &str,
150 config: WebSocketConfig,
151 ) -> Result<WebSocketStream<ConnectStream>> {
152 use tokio_tungstenite::tungstenite::client::IntoClientRequest;
153
154 let request = url.into_client_request()?;
155 let host = request
156 .uri()
157 .host()
158 .ok_or_else(|| CdpError::msg("no host in CDP WebSocket URL"))?;
159 let port = request.uri().port_u16().unwrap_or(9222);
160
161 let addr_str = format!("{}:{}", host, port);
163 let addr: std::net::SocketAddr = match addr_str.parse() {
164 Ok(a) => a,
165 Err(_) => {
166 return Self::connect_default(url, config).await;
168 }
169 };
170
171 let std_stream = crate::uring_fs::tcp_connect(addr)
173 .await
174 .map_err(CdpError::Io)?;
175
176 std_stream.set_nonblocking(true).map_err(CdpError::Io)?;
178 if *DISABLE_NAGLE {
179 let _ = std_stream.set_nodelay(true);
180 }
181
182 let tokio_stream = tokio::net::TcpStream::from_std(std_stream).map_err(CdpError::Io)?;
184
185 let (ws, _) = tokio_tungstenite::client_async_with_config(
187 request,
188 MaybeTlsStream::Plain(tokio_stream),
189 Some(config),
190 )
191 .await?;
192
193 Ok(ws)
194 }
195}
196
197impl<T: EventMessage> Connection<T> {
198 fn next_call_id(&mut self) -> CallId {
199 let id = CallId::new(self.next_id);
200 self.next_id = self.next_id.wrapping_add(1);
201 id
202 }
203
204 pub fn submit_command(
207 &mut self,
208 method: MethodId,
209 session_id: Option<SessionId>,
210 params: serde_json::Value,
211 ) -> serde_json::Result<CallId> {
212 let id = self.next_call_id();
213 let call = MethodCall {
214 id,
215 method,
216 session_id: session_id.map(Into::into),
217 params,
218 };
219 self.pending_commands.push_back(call);
220 Ok(id)
221 }
222
223 fn start_send_next(&mut self, cx: &mut Context<'_>) -> Result<()> {
228 if self.needs_flush {
230 match self.ws.poll_flush_unpin(cx) {
231 Poll::Ready(Ok(())) => self.needs_flush = false,
232 Poll::Ready(Err(e)) => return Err(e.into()),
233 Poll::Pending => return Ok(()),
234 }
235 }
236
237 let mut sent_any = false;
239 while !self.pending_commands.is_empty() {
240 match self.ws.poll_ready_unpin(cx) {
241 Poll::Ready(Ok(())) => {
242 let Some(cmd) = self.pending_commands.pop_front() else {
243 break;
244 };
245 tracing::trace!("Sending {:?}", cmd);
246 let msg = serde_json::to_string(&cmd)?;
247 self.ws.start_send_unpin(msg.into())?;
248 sent_any = true;
249 }
250 _ => break,
251 }
252 }
253
254 if sent_any {
256 match self.ws.poll_flush_unpin(cx) {
257 Poll::Ready(Ok(())) => {}
258 Poll::Ready(Err(e)) => return Err(e.into()),
259 Poll::Pending => self.needs_flush = true,
260 }
261 }
262
263 Ok(())
264 }
265}
266
/// Capacity of the command channel that feeds the background writer task.
const WS_CMD_CHANNEL_CAPACITY: usize = 2048;

/// Capacity of the channel carrying decoded messages out of the background
/// reader task; when it is full the reader waits before pulling more frames.
const WS_READ_CHANNEL_CAPACITY: usize = 1024;
278
/// A [`Connection`] split into a background reader task and a background
/// writer task, produced by `Connection::into_async`.
#[derive(Debug)]
pub struct AsyncConnection<T: EventMessage> {
    /// Decoded inbound messages forwarded by the reader task.
    pub reader: WsReader<T>,
    /// Outbound commands consumed by the writer task.
    pub cmd_tx: mpsc::Sender<MethodCall>,
    /// Writer task: ends with the first write/serialisation error, or `Ok(())`
    /// once every command sender has been dropped.
    pub writer_handle: tokio::task::JoinHandle<Result<()>>,
    /// Reader task: exits on EOF, a Close frame, a transport error, or when it
    /// next tries to forward a message after `reader` has been dropped.
    pub reader_handle: tokio::task::JoinHandle<()>,
    /// Next command id to assign, carried over from the originating `Connection`.
    pub next_id: usize,
}
297
298impl<T: EventMessage + Unpin + Send + 'static> Connection<T> {
299 pub fn into_async(self) -> AsyncConnection<T> {
320 let (ws_sink, ws_stream) = self.ws.split();
321 let (cmd_tx, cmd_rx) = mpsc::channel(WS_CMD_CHANNEL_CAPACITY);
322 let (msg_tx, msg_rx) = mpsc::channel::<Result<Box<Message<T>>>>(WS_READ_CHANNEL_CAPACITY);
323
324 let writer_handle = tokio::spawn(ws_write_loop(ws_sink, cmd_rx));
325 let reader_handle = tokio::spawn(ws_read_loop::<T, _>(ws_stream, msg_tx));
326
327 let reader = WsReader {
328 rx: msg_rx,
329 _marker: PhantomData,
330 };
331
332 AsyncConnection {
333 reader,
334 cmd_tx,
335 writer_handle,
336 reader_handle,
337 next_id: self.next_id,
338 }
339 }
340}
341
342async fn ws_read_loop<T, S>(mut stream: S, tx: mpsc::Sender<Result<Box<Message<T>>>>)
370where
371 T: EventMessage,
372 S: Stream<Item = std::result::Result<WsMessage, tokio_tungstenite::tungstenite::Error>>
373 + Unpin,
374{
375 while let Some(frame) = stream.next().await {
376 match frame {
377 Ok(WsMessage::Text(text)) => {
378 match decode_message::<T>(text.as_bytes(), Some(&text)) {
379 Ok(msg) => {
380 if tx.send(Ok(msg)).await.is_err() {
381 return;
382 }
383 }
384 Err(err) => {
385 tracing::debug!(
386 target: "chromiumoxide::conn::raw_ws::parse_errors",
387 "Dropping malformed text WS frame: {err}",
388 );
389 }
390 }
391 }
392 Ok(WsMessage::Binary(buf)) => match decode_message::<T>(&buf, None) {
393 Ok(msg) => {
394 if tx.send(Ok(msg)).await.is_err() {
395 return;
396 }
397 }
398 Err(err) => {
399 tracing::debug!(
400 target: "chromiumoxide::conn::raw_ws::parse_errors",
401 "Dropping malformed binary WS frame: {err}",
402 );
403 }
404 },
405 Ok(WsMessage::Close(_)) => return,
406 Ok(WsMessage::Ping(_)) | Ok(WsMessage::Pong(_)) => {}
407 Ok(msg) => {
408 tracing::debug!(
409 target: "chromiumoxide::conn::raw_ws::parse_errors",
410 "Unexpected WS message type: {:?}",
411 msg
412 );
413 }
414 Err(err) => {
415 let _ = tx.send(Err(CdpError::Ws(err))).await;
418 return;
419 }
420 }
421 }
422}
423
424async fn ws_write_loop(
426 mut sink: SplitSink<WebSocketStream<ConnectStream>, WsMessage>,
427 mut rx: mpsc::Receiver<MethodCall>,
428) -> Result<()> {
429 while let Some(call) = rx.recv().await {
430 let msg = crate::serde_json::to_string(&call)?;
431 sink.feed(WsMessage::Text(msg.into()))
432 .await
433 .map_err(CdpError::Ws)?;
434
435 while let Ok(call) = rx.try_recv() {
437 let msg = crate::serde_json::to_string(&call)?;
438 sink.feed(WsMessage::Text(msg.into()))
439 .await
440 .map_err(CdpError::Ws)?;
441 }
442
443 sink.flush().await.map_err(CdpError::Ws)?;
445 }
446 Ok(())
447}
448
449#[derive(Debug)]
459pub struct WsReader<T: EventMessage> {
460 rx: mpsc::Receiver<Result<Box<Message<T>>>>,
461 _marker: PhantomData<T>,
462}
463
464impl<T: EventMessage + Unpin> WsReader<T> {
465 pub async fn next_message(&mut self) -> Option<Result<Box<Message<T>>>> {
472 self.rx.recv().await
473 }
474}
475
476impl<T: EventMessage + Unpin> Stream for Connection<T> {
477 type Item = Result<Box<Message<T>>>;
478
479 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
480 let pin = self.get_mut();
481
482 if let Err(err) = pin.start_send_next(cx) {
484 return Poll::Ready(Some(Err(err)));
485 }
486
487 const MAX_SKIPS_PER_POLL: u32 = 16;
496 let mut skips: u32 = 0;
497 loop {
498 match ready!(pin.ws.poll_next_unpin(cx)) {
499 Some(Ok(WsMessage::Text(text))) => {
500 match decode_message::<T>(text.as_bytes(), Some(&text)) {
501 Ok(msg) => return Poll::Ready(Some(Ok(msg))),
502 Err(err) => {
503 tracing::debug!(
504 target: "chromiumoxide::conn::raw_ws::parse_errors",
505 "Dropping malformed text WS frame: {err}",
506 );
507 skips += 1;
508 }
509 }
510 }
511 Some(Ok(WsMessage::Binary(buf))) => match decode_message::<T>(&buf, None) {
512 Ok(msg) => return Poll::Ready(Some(Ok(msg))),
513 Err(err) => {
514 tracing::debug!(
515 target: "chromiumoxide::conn::raw_ws::parse_errors",
516 "Dropping malformed binary WS frame: {err}",
517 );
518 skips += 1;
519 }
520 },
521 Some(Ok(WsMessage::Close(_))) => return Poll::Ready(None),
522 Some(Ok(WsMessage::Ping(_))) | Some(Ok(WsMessage::Pong(_))) => {
523 skips += 1;
524 }
525 Some(Ok(msg)) => {
526 tracing::debug!(
527 target: "chromiumoxide::conn::raw_ws::parse_errors",
528 "Unexpected WS message type: {:?}",
529 msg
530 );
531 skips += 1;
532 }
533 Some(Err(err)) => return Poll::Ready(Some(Err(CdpError::Ws(err)))),
534 None => return Poll::Ready(None),
535 }
536
537 if skips >= MAX_SKIPS_PER_POLL {
538 cx.waker().wake_by_ref();
539 return Poll::Pending;
540 }
541 }
542 }
543}
544
545#[cfg(not(feature = "serde_stacker"))]
549fn decode_message<T: EventMessage>(
550 bytes: &[u8],
551 raw_text_for_logging: Option<&str>,
552) -> Result<Box<Message<T>>> {
553 match serde_json::from_slice::<Box<Message<T>>>(bytes) {
554 Ok(msg) => {
555 tracing::trace!("Received {:?}", msg);
556 Ok(msg)
557 }
558 Err(err) => {
559 if let Some(txt) = raw_text_for_logging {
560 let preview = &txt[..txt.len().min(512)];
561 tracing::debug!(
562 target: "chromiumoxide::conn::raw_ws::parse_errors",
563 msg_len = txt.len(),
564 "Skipping unrecognized WS message {err} preview={preview}",
565 );
566 } else {
567 tracing::debug!(
568 target: "chromiumoxide::conn::raw_ws::parse_errors",
569 "Skipping unrecognized binary WS message {err}",
570 );
571 }
572 Err(err.into())
573 }
574 }
575}
576
/// Decode one WebSocket payload into a CDP message without a recursion limit.
///
/// serde_json's depth limit is disabled and `serde_stacker` is used instead, so
/// deeply nested payloads are handled. As with `serde_json::from_slice`, only
/// whitespace may follow the value. `raw_text_for_logging` is used only to log
/// a preview of at most 512 bytes when decoding fails.
///
/// # Errors
///
/// Returns the `serde_json` error (converted into the crate error) when the
/// payload is not a valid `Message<T>` or has trailing non-whitespace data.
#[cfg(feature = "serde_stacker")]
fn decode_message<T: EventMessage>(
    bytes: &[u8],
    raw_text_for_logging: Option<&str>,
) -> Result<Box<Message<T>>> {
    use serde::Deserialize;

    let mut json_de = serde_json::Deserializer::from_slice(bytes);
    json_de.disable_recursion_limit();

    let parsed = {
        let stacked = serde_stacker::Deserializer::new(&mut json_de);
        Box::<Message<T>>::deserialize(stacked)
    };
    // Match `serde_json::from_slice` (used without `serde_stacker`): reject
    // anything but whitespace after the value.
    let decoded = parsed.and_then(|msg| json_de.end().map(|()| msg));

    match decoded {
        Ok(msg) => {
            tracing::trace!("Received {:?}", msg);
            Ok(msg)
        }
        Err(err) => {
            if let Some(txt) = raw_text_for_logging {
                // `&txt[..512]` panics when byte 512 falls inside a multi-byte
                // UTF-8 sequence; step back to the previous char boundary.
                let mut cut = txt.len().min(512);
                while !txt.is_char_boundary(cut) {
                    cut -= 1;
                }
                let preview = &txt[..cut];
                tracing::debug!(
                    target: "chromiumoxide::conn::raw_ws::parse_errors",
                    msg_len = txt.len(),
                    "Skipping unrecognized WS message {err} preview={preview}",
                );
            } else {
                tracing::debug!(
                    target: "chromiumoxide::conn::raw_ws::parse_errors",
                    "Skipping unrecognized binary WS message {err}",
                );
            }
            Err(err.into())
        }
    }
}
615
#[cfg(test)]
mod ws_read_loop_tests {
    use super::*;
    use chromiumoxide_cdp::cdp::CdpEventMessage;
    use chromiumoxide_types::CallId;
    use futures_util::stream;
    use tokio::sync::mpsc;
    use tokio_tungstenite::tungstenite::Message as WsMessage;

    /// A minimal CDP response frame: `{"id":<id>,"result":{"ok":true}}`.
    fn response_frame(id: u64) -> WsMessage {
        WsMessage::Text(format!(r#"{{"id":{id},"result":{{"ok":true}}}}"#).into())
    }

    /// A CDP response frame whose result carries a `blob_bytes`-long string.
    fn large_response_frame(id: u64, blob_bytes: usize) -> WsMessage {
        let blob = "x".repeat(blob_bytes);
        WsMessage::Text(format!(r#"{{"id":{id},"result":{{"blob":"{blob}"}}}}"#).into())
    }

    /// Unwrap a decoded message that must be a command response and return its
    /// id. Panics on any other message kind, so assertions cannot pass vacuously.
    fn response_id(msg: Box<Message<CdpEventMessage>>) -> CallId {
        match *msg {
            Message::Response(resp) => resp.id,
            other => panic!("expected Response, got {other:?}"),
        }
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn forwards_messages_in_stream_order() {
        let frames = vec![
            Ok(response_frame(1)),
            Ok(response_frame(2)),
            Ok(response_frame(3)),
        ];
        let stream = stream::iter(frames);
        let (tx, mut rx) = mpsc::channel::<Result<Box<Message<CdpEventMessage>>>>(8);
        let task = tokio::spawn(ws_read_loop::<CdpEventMessage, _>(stream, tx));

        for expected in [1u64, 2, 3] {
            let msg = rx.recv().await.expect("msg").expect("decode ok");
            assert_eq!(response_id(msg), CallId::new(expected as usize));
        }
        assert!(rx.recv().await.is_none(), "channel must close on EOF");
        task.await.expect("reader task join");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn pings_and_pongs_never_reach_the_handler() {
        let frames = vec![
            Ok(WsMessage::Ping(vec![1, 2, 3].into())),
            Ok(response_frame(7)),
            Ok(WsMessage::Pong(vec![].into())),
            Ok(response_frame(8)),
        ];
        let stream = stream::iter(frames);
        let (tx, mut rx) = mpsc::channel::<Result<Box<Message<CdpEventMessage>>>>(8);
        let task = tokio::spawn(ws_read_loop::<CdpEventMessage, _>(stream, tx));

        for expected in [7u64, 8] {
            let msg = rx.recv().await.expect("msg").expect("decode ok");
            assert_eq!(response_id(msg), CallId::new(expected as usize));
        }
        assert!(rx.recv().await.is_none());
        task.await.expect("reader task join");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn malformed_frames_do_not_block_subsequent_valid_frames() {
        let frames = vec![
            Ok(WsMessage::Text("{not valid json".to_string().into())),
            Ok(response_frame(42)),
        ];
        let stream = stream::iter(frames);
        let (tx, mut rx) = mpsc::channel::<Result<Box<Message<CdpEventMessage>>>>(8);
        let task = tokio::spawn(ws_read_loop::<CdpEventMessage, _>(stream, tx));

        let msg = rx.recv().await.expect("msg").expect("decode ok");
        assert_eq!(response_id(msg), CallId::new(42));
        assert!(rx.recv().await.is_none());
        task.await.expect("reader task join");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn close_frame_terminates_the_reader() {
        let frames = vec![
            Ok(response_frame(1)),
            Ok(WsMessage::Close(None)),
            Ok(response_frame(2)),
        ];
        let stream = stream::iter(frames);
        let (tx, mut rx) = mpsc::channel::<Result<Box<Message<CdpEventMessage>>>>(8);
        let task = tokio::spawn(ws_read_loop::<CdpEventMessage, _>(stream, tx));

        let msg = rx.recv().await.expect("msg").expect("decode ok");
        assert_eq!(response_id(msg), CallId::new(1));
        assert!(
            rx.recv().await.is_none(),
            "reader must exit on Close; frames after Close must not appear"
        );
        task.await.expect("reader task join");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn transport_error_is_forwarded_once_then_reader_exits() {
        let frames = vec![
            Ok(response_frame(1)),
            Err(tokio_tungstenite::tungstenite::Error::ConnectionClosed),
            Ok(response_frame(2)),
        ];
        let stream = stream::iter(frames);
        let (tx, mut rx) = mpsc::channel::<Result<Box<Message<CdpEventMessage>>>>(8);
        let task = tokio::spawn(ws_read_loop::<CdpEventMessage, _>(stream, tx));

        let msg = rx.recv().await.expect("msg").expect("ok");
        assert_eq!(response_id(msg), CallId::new(1));
        match rx.recv().await {
            Some(Err(CdpError::Ws(_))) => {}
            other => panic!("expected forwarded Ws error, got {other:?}"),
        }
        assert!(rx.recv().await.is_none());
        task.await.expect("reader task join");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn bounded_channel_does_not_deadlock_under_backpressure() {
        const N: u64 = 512;
        let frames: Vec<_> = (1..=N).map(|id| Ok(response_frame(id))).collect();
        let stream = stream::iter(frames);

        let (tx, mut rx) = mpsc::channel::<Result<Box<Message<CdpEventMessage>>>>(1);
        let task = tokio::spawn(ws_read_loop::<CdpEventMessage, _>(stream, tx));

        let deadline = std::time::Duration::from_secs(5);
        let collected = tokio::time::timeout(deadline, async {
            let mut seen = 0u64;
            while let Some(frame) = rx.recv().await {
                let msg = frame.expect("decode ok");
                seen += 1;
                assert_eq!(
                    response_id(msg),
                    CallId::new(seen as usize),
                    "back-pressure must preserve FIFO order"
                );
            }
            seen
        })
        .await
        .expect("reader must make forward progress despite cap-1 back-pressure");

        assert_eq!(collected, N, "all frames must arrive");
        task.await.expect("reader task join");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn large_message_decodes_without_corruption() {
        let big = 2 * 1024 * 1024;
        let frames = vec![
            Ok(large_response_frame(100, big)),
            Ok(response_frame(101)),
        ];
        let stream = stream::iter(frames);
        let (tx, mut rx) = mpsc::channel::<Result<Box<Message<CdpEventMessage>>>>(4);
        let task = tokio::spawn(ws_read_loop::<CdpEventMessage, _>(stream, tx));

        let first = rx.recv().await.expect("msg").expect("ok");
        assert_eq!(response_id(first), CallId::new(100));
        let second = rx.recv().await.expect("msg").expect("ok");
        assert_eq!(response_id(second), CallId::new(101));
        assert!(rx.recv().await.is_none());
        task.await.expect("reader task join");
    }
}