1use std::collections::VecDeque;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::ready;
5
6use futures_util::stream::{FuturesOrdered, SplitSink};
7use futures_util::{SinkExt, Stream, StreamExt};
8use std::future::Future;
9use std::task::{Context, Poll};
10use tokio::sync::mpsc;
11use tokio_tungstenite::tungstenite::Message as WsMessage;
12use tokio_tungstenite::MaybeTlsStream;
13use tokio_tungstenite::{tungstenite::protocol::WebSocketConfig, WebSocketStream};
14
15use chromiumoxide_cdp::cdp::browser_protocol::target::SessionId;
16use chromiumoxide_types::{CallId, EventMessage, Message, MethodCall, MethodId};
17
18use crate::error::CdpError;
19use crate::error::Result;
20
21type ConnectStream = MaybeTlsStream<tokio::net::TcpStream>;
22
/// A CDP connection over a single WebSocket, driven synchronously as a [`Stream`] of
/// decoded messages.
///
/// Commands queued with `submit_command` are written each time the stream is polled
/// (`poll_next` calls `start_send_next` first). `into_async` hands the socket over to
/// dedicated reader and writer tasks instead.
#[must_use = "streams do nothing unless polled"]
#[derive(Debug)]
pub struct Connection<T: EventMessage> {
    /// Commands accepted by `submit_command` that have not been written to the socket yet.
    pending_commands: VecDeque<MethodCall>,
    /// The underlying WebSocket stream, used for both reading and writing.
    ws: WebSocketStream<ConnectStream>,
    /// Id handed to the next submitted command; wraps around on overflow.
    next_id: usize,
    /// Set when a flush returned `Pending`; that flush is completed before more commands
    /// are written.
    needs_flush: bool,
    _marker: PhantomData<T>,
}
38
39lazy_static::lazy_static! {
40 static ref DISABLE_NAGLE: bool = match std::env::var("DISABLE_NAGLE") {
42 Ok(disable_nagle) => disable_nagle == "true",
43 _ => true
44 };
45 static ref WEBSOCKET_DEFAULTS: bool = match std::env::var("WEBSOCKET_DEFAULTS") {
47 Ok(d) => d == "true",
48 _ => false
49 };
50}
51
/// Retries used by `Connection::connect`. A value of `n` allows up to `n + 1` connection
/// attempts in total.
pub const DEFAULT_CONNECTION_RETRIES: u32 = 4;

/// Delay, in milliseconds, before the first retry. Each later retry waits three times
/// longer than the previous one.
const INITIAL_BACKOFF_MS: u64 = 50;

/// Upper bound, in milliseconds, on any single backoff delay between connection attempts.
const MAX_BACKOFF_MS: u64 = 2 * 1_000;
60
61impl<T: EventMessage + Unpin> Connection<T> {
62 pub async fn connect(debug_ws_url: impl AsRef<str>) -> Result<Self> {
63 Self::connect_with_retries(debug_ws_url, DEFAULT_CONNECTION_RETRIES).await
64 }
65
66 pub async fn connect_with_retries(debug_ws_url: impl AsRef<str>, retries: u32) -> Result<Self> {
67 let mut config = WebSocketConfig::default();
68
69 config.max_write_buffer_size = 4 * 1024 * 1024;
72
73 if !*WEBSOCKET_DEFAULTS {
74 config.max_message_size = None;
75 config.max_frame_size = None;
76 }
77
78 let url = debug_ws_url.as_ref();
79 let use_uring = crate::uring_fs::is_enabled();
80 let mut last_err = None;
81
82 for attempt in 0..=retries {
83 let result = if use_uring {
84 Self::connect_uring(url, config).await
85 } else {
86 Self::connect_default(url, config).await
87 };
88
89 match result {
90 Ok(ws) => {
91 return Ok(Self {
92 pending_commands: Default::default(),
93 ws,
94 next_id: 0,
95 needs_flush: false,
96 _marker: Default::default(),
97 });
98 }
99 Err(e) => {
100 let should_retry = match &e {
103 CdpError::Io(io_err)
105 if io_err.kind() == std::io::ErrorKind::ConnectionRefused =>
106 {
107 false
108 }
109 CdpError::Ws(tungstenite_err) => !matches!(
112 tungstenite_err,
113 tokio_tungstenite::tungstenite::Error::Http(_)
114 | tokio_tungstenite::tungstenite::Error::HttpFormat(_)
115 ),
116 _ => true,
117 };
118
119 last_err = Some(e);
120
121 if !should_retry {
122 break;
123 }
124
125 if attempt < retries {
126 let backoff_ms =
127 (INITIAL_BACKOFF_MS * 3u64.saturating_pow(attempt)).min(MAX_BACKOFF_MS);
128 tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
129 }
130 }
131 }
132 }
133
134 Err(last_err.unwrap_or_else(|| CdpError::msg("connection failed")))
135 }
136
137 async fn connect_default(
139 url: &str,
140 config: WebSocketConfig,
141 ) -> Result<WebSocketStream<ConnectStream>> {
142 let (ws, _) =
143 tokio_tungstenite::connect_async_with_config(url, Some(config), *DISABLE_NAGLE).await?;
144 Ok(ws)
145 }
146
147 async fn connect_uring(
150 url: &str,
151 config: WebSocketConfig,
152 ) -> Result<WebSocketStream<ConnectStream>> {
153 use tokio_tungstenite::tungstenite::client::IntoClientRequest;
154
155 let request = url.into_client_request()?;
156 let host = request
157 .uri()
158 .host()
159 .ok_or_else(|| CdpError::msg("no host in CDP WebSocket URL"))?;
160 let port = request.uri().port_u16().unwrap_or(9222);
161
162 let addr_str = format!("{}:{}", host, port);
164 let addr: std::net::SocketAddr = match addr_str.parse() {
165 Ok(a) => a,
166 Err(_) => {
167 return Self::connect_default(url, config).await;
169 }
170 };
171
172 let std_stream = crate::uring_fs::tcp_connect(addr)
174 .await
175 .map_err(CdpError::Io)?;
176
177 std_stream.set_nonblocking(true).map_err(CdpError::Io)?;
179 if *DISABLE_NAGLE {
180 let _ = std_stream.set_nodelay(true);
181 }
182
183 let tokio_stream = tokio::net::TcpStream::from_std(std_stream).map_err(CdpError::Io)?;
185
186 let (ws, _) = tokio_tungstenite::client_async_with_config(
188 request,
189 MaybeTlsStream::Plain(tokio_stream),
190 Some(config),
191 )
192 .await?;
193
194 Ok(ws)
195 }
196}
197
198impl<T: EventMessage> Connection<T> {
199 fn next_call_id(&mut self) -> CallId {
200 let id = CallId::new(self.next_id);
201 self.next_id = self.next_id.wrapping_add(1);
202 id
203 }
204
205 pub fn submit_command(
208 &mut self,
209 method: MethodId,
210 session_id: Option<SessionId>,
211 params: serde_json::Value,
212 ) -> serde_json::Result<CallId> {
213 let id = self.next_call_id();
214 let call = MethodCall {
215 id,
216 method,
217 session_id: session_id.map(Into::into),
218 params,
219 };
220 self.pending_commands.push_back(call);
221 Ok(id)
222 }
223
224 fn start_send_next(&mut self, cx: &mut Context<'_>) -> Result<()> {
229 if self.needs_flush {
231 match self.ws.poll_flush_unpin(cx) {
232 Poll::Ready(Ok(())) => self.needs_flush = false,
233 Poll::Ready(Err(e)) => return Err(e.into()),
234 Poll::Pending => return Ok(()),
235 }
236 }
237
238 let mut sent_any = false;
240 while !self.pending_commands.is_empty() {
241 match self.ws.poll_ready_unpin(cx) {
242 Poll::Ready(Ok(())) => {
243 let Some(cmd) = self.pending_commands.pop_front() else {
244 break;
245 };
246 tracing::trace!("Sending {:?}", cmd);
247 let msg = serde_json::to_string(&cmd)?;
248 self.ws.start_send_unpin(msg.into())?;
249 sent_any = true;
250 }
251 _ => break,
252 }
253 }
254
255 if sent_any {
257 match self.ws.poll_flush_unpin(cx) {
258 Poll::Ready(Ok(())) => {}
259 Poll::Ready(Err(e)) => return Err(e.into()),
260 Poll::Pending => self.needs_flush = true,
261 }
262 }
263
264 Ok(())
265 }
266}
267
/// Capacity of the command channel that feeds the writer task (`ws_write_loop`).
const WS_CMD_CHANNEL_CAPACITY: usize = 2048;

/// Capacity of the channel that carries decoded messages from the reader task
/// (`ws_read_loop`) to `WsReader`.
const WS_READ_CHANNEL_CAPACITY: usize = 1024;

/// Maximum number of frame decodes `ws_read_loop` keeps queued. While the queue is full it
/// stops reading new frames and only emits completed decodes.
const MAX_IN_FLIGHT_DECODES: usize = 32;

/// Frames at least this many bytes long are decoded on `tokio::task::spawn_blocking`
/// instead of inline in the reader task.
const LARGE_FRAME_THRESHOLD: usize = 256 * 1024;

/// A connection whose WebSocket has been split between a spawned reader task and a spawned
/// writer task, as produced by `Connection::into_async`.
#[derive(Debug)]
pub struct AsyncConnection<T: EventMessage> {
    /// Decoded messages (or the terminating transport error) from the reader task.
    pub reader: WsReader<T>,
    /// Sends commands to the writer task, which serializes and writes them.
    pub cmd_tx: mpsc::Sender<MethodCall>,
    /// Handle to the writer task. It resolves to `Err` with the error that stopped the
    /// writer, or to `Ok(())` once the command channel closes.
    pub writer_handle: tokio::task::JoinHandle<Result<()>>,
    /// Handle to the reader task.
    pub reader_handle: tokio::task::JoinHandle<()>,
    /// Next call id, carried over from the originating `Connection`.
    /// NOTE(review): presumably callers allocate ids for `cmd_tx` commands from here; confirm.
    pub next_id: usize,
}
324
325impl<T: EventMessage + Unpin + Send + 'static> Connection<T> {
326 pub fn into_async(self) -> AsyncConnection<T> {
347 let (ws_sink, ws_stream) = self.ws.split();
348 let (cmd_tx, cmd_rx) = mpsc::channel(WS_CMD_CHANNEL_CAPACITY);
349 let (msg_tx, msg_rx) = mpsc::channel::<Result<Box<Message<T>>>>(WS_READ_CHANNEL_CAPACITY);
350
351 let writer_handle = tokio::spawn(ws_write_loop(ws_sink, cmd_rx));
352 let reader_handle = tokio::spawn(ws_read_loop::<T, _>(ws_stream, msg_tx));
353
354 let reader = WsReader {
355 rx: msg_rx,
356 _marker: PhantomData,
357 };
358
359 AsyncConnection {
360 reader,
361 cmd_tx,
362 writer_handle,
363 reader_handle,
364 next_id: self.next_id,
365 }
366 }
367}
368
/// One frame's decode, queued in `ws_read_loop`'s `FuturesOrdered` so that results are
/// emitted in frame-arrival order regardless of which decode finishes first.
enum InFlightDecode<T: EventMessage + Send + 'static> {
    /// Already decoded inline. The `Option` lets `poll` move the result out exactly once.
    Ready(Option<Result<Box<Message<T>>>>),
    /// Decode running on Tokio's blocking thread pool (used for large frames).
    Blocking(tokio::task::JoinHandle<Result<Box<Message<T>>>>),
}
387
388impl<T: EventMessage + Send + 'static> Future for InFlightDecode<T> {
389 type Output = Result<Box<Message<T>>>;
390
391 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
392 match self.get_mut() {
397 InFlightDecode::Ready(slot) => Poll::Ready(
398 slot.take()
399 .expect("InFlightDecode::Ready polled after completion"),
400 ),
401 InFlightDecode::Blocking(handle) => match Pin::new(handle).poll(cx) {
402 Poll::Ready(Ok(res)) => Poll::Ready(res),
403 Poll::Ready(Err(join_err)) => Poll::Ready(Err(CdpError::msg(format!(
404 "WS decode blocking task join error: {join_err}"
405 )))),
406 Poll::Pending => Poll::Pending,
407 },
408 }
409 }
410}
411
412async fn emit_decoded<T>(
416 tx: &mpsc::Sender<Result<Box<Message<T>>>>,
417 res: Result<Box<Message<T>>>,
418) -> bool
419where
420 T: EventMessage + Send + 'static,
421{
422 match res {
423 Ok(msg) => tx.send(Ok(msg)).await.is_ok(),
424 Err(err) => {
425 tracing::debug!(
426 target: "chromiumoxide::conn::raw_ws::parse_errors",
427 "Dropping malformed WS frame: {err}",
428 );
429 true
430 }
431 }
432}
433
434async fn drain_in_flight<T>(
439 in_flight: &mut FuturesOrdered<InFlightDecode<T>>,
440 tx: &mpsc::Sender<Result<Box<Message<T>>>>,
441) where
442 T: EventMessage + Send + 'static,
443{
444 while let Some(res) = in_flight.next().await {
445 if !emit_decoded(tx, res).await {
446 return;
447 }
448 }
449}
450
451async fn ws_read_loop<T, S>(mut stream: S, tx: mpsc::Sender<Result<Box<Message<T>>>>)
479where
480 T: EventMessage + Send + 'static,
481 S: Stream<Item = std::result::Result<WsMessage, tokio_tungstenite::tungstenite::Error>>
482 + Unpin,
483{
484 let mut in_flight: FuturesOrdered<InFlightDecode<T>> = FuturesOrdered::new();
490
491 loop {
492 tokio::select! {
493 biased;
497
498 Some(res) = in_flight.next(), if !in_flight.is_empty() => {
503 if !emit_decoded(&tx, res).await {
504 return;
505 }
506 }
507
508 maybe_frame = stream.next(), if in_flight.len() < MAX_IN_FLIGHT_DECODES => {
513 match maybe_frame {
514 Some(Ok(WsMessage::Text(text))) => {
515 if text.len() >= LARGE_FRAME_THRESHOLD {
523 in_flight.push_back(InFlightDecode::Blocking(
524 tokio::task::spawn_blocking(move || {
525 decode_message::<T>(text.as_bytes(), None)
526 }),
527 ));
528 } else {
529 let res = decode_message::<T>(text.as_bytes(), Some(&text));
530 in_flight.push_back(InFlightDecode::Ready(Some(res)));
531 }
532 }
533 Some(Ok(WsMessage::Binary(buf))) => {
534 if buf.len() >= LARGE_FRAME_THRESHOLD {
539 in_flight.push_back(InFlightDecode::Blocking(
540 tokio::task::spawn_blocking(move || {
541 decode_message::<T>(&buf, None)
542 }),
543 ));
544 } else {
545 let res = decode_message::<T>(&buf, None);
546 in_flight.push_back(InFlightDecode::Ready(Some(res)));
547 }
548 }
549 Some(Ok(WsMessage::Close(_))) => {
550 drain_in_flight(&mut in_flight, &tx).await;
551 return;
552 }
553 Some(Ok(WsMessage::Ping(_))) | Some(Ok(WsMessage::Pong(_))) => {}
554 Some(Ok(msg)) => {
555 tracing::debug!(
556 target: "chromiumoxide::conn::raw_ws::parse_errors",
557 "Unexpected WS message type: {:?}",
558 msg
559 );
560 }
561 Some(Err(err)) => {
562 drain_in_flight(&mut in_flight, &tx).await;
566 let _ = tx.send(Err(CdpError::Ws(err))).await;
567 return;
568 }
569 None => {
570 drain_in_flight(&mut in_flight, &tx).await;
573 return;
574 }
575 }
576 }
577 }
578 }
579}
580
581async fn ws_write_loop(
583 mut sink: SplitSink<WebSocketStream<ConnectStream>, WsMessage>,
584 mut rx: mpsc::Receiver<MethodCall>,
585) -> Result<()> {
586 while let Some(call) = rx.recv().await {
587 let msg = crate::serde_json::to_string(&call)?;
588 sink.feed(WsMessage::Text(msg.into()))
589 .await
590 .map_err(CdpError::Ws)?;
591
592 while let Ok(call) = rx.try_recv() {
594 let msg = crate::serde_json::to_string(&call)?;
595 sink.feed(WsMessage::Text(msg.into()))
596 .await
597 .map_err(CdpError::Ws)?;
598 }
599
600 sink.flush().await.map_err(CdpError::Ws)?;
602 }
603 Ok(())
604}
605
/// Receiving end of the channel of decoded messages. The channel is fed by the reader task
/// that `Connection::into_async` spawns.
#[derive(Debug)]
pub struct WsReader<T: EventMessage> {
    rx: mpsc::Receiver<Result<Box<Message<T>>>>,
    _marker: PhantomData<T>,
}
620
621impl<T: EventMessage + Unpin> WsReader<T> {
622 pub async fn next_message(&mut self) -> Option<Result<Box<Message<T>>>> {
629 self.rx.recv().await
630 }
631}
632
633impl<T: EventMessage + Unpin> Stream for Connection<T> {
634 type Item = Result<Box<Message<T>>>;
635
636 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
637 let pin = self.get_mut();
638
639 if let Err(err) = pin.start_send_next(cx) {
641 return Poll::Ready(Some(Err(err)));
642 }
643
644 const MAX_SKIPS_PER_POLL: u32 = 16;
653 let mut skips: u32 = 0;
654 loop {
655 match ready!(pin.ws.poll_next_unpin(cx)) {
656 Some(Ok(WsMessage::Text(text))) => {
657 match decode_message::<T>(text.as_bytes(), Some(&text)) {
658 Ok(msg) => return Poll::Ready(Some(Ok(msg))),
659 Err(err) => {
660 tracing::debug!(
661 target: "chromiumoxide::conn::raw_ws::parse_errors",
662 "Dropping malformed text WS frame: {err}",
663 );
664 skips += 1;
665 }
666 }
667 }
668 Some(Ok(WsMessage::Binary(buf))) => match decode_message::<T>(&buf, None) {
669 Ok(msg) => return Poll::Ready(Some(Ok(msg))),
670 Err(err) => {
671 tracing::debug!(
672 target: "chromiumoxide::conn::raw_ws::parse_errors",
673 "Dropping malformed binary WS frame: {err}",
674 );
675 skips += 1;
676 }
677 },
678 Some(Ok(WsMessage::Close(_))) => return Poll::Ready(None),
679 Some(Ok(WsMessage::Ping(_))) | Some(Ok(WsMessage::Pong(_))) => {
680 skips += 1;
681 }
682 Some(Ok(msg)) => {
683 tracing::debug!(
684 target: "chromiumoxide::conn::raw_ws::parse_errors",
685 "Unexpected WS message type: {:?}",
686 msg
687 );
688 skips += 1;
689 }
690 Some(Err(err)) => return Poll::Ready(Some(Err(CdpError::Ws(err)))),
691 None => return Poll::Ready(None),
692 }
693
694 if skips >= MAX_SKIPS_PER_POLL {
695 cx.waker().wake_by_ref();
696 return Poll::Pending;
697 }
698 }
699 }
700}
701
702#[cfg(not(feature = "serde_stacker"))]
706fn decode_message<T: EventMessage>(
707 bytes: &[u8],
708 raw_text_for_logging: Option<&str>,
709) -> Result<Box<Message<T>>> {
710 match serde_json::from_slice::<Box<Message<T>>>(bytes) {
711 Ok(msg) => {
712 tracing::trace!("Received {:?}", msg);
713 Ok(msg)
714 }
715 Err(err) => {
716 if let Some(txt) = raw_text_for_logging {
717 let preview = &txt[..txt.len().min(512)];
718 tracing::debug!(
719 target: "chromiumoxide::conn::raw_ws::parse_errors",
720 msg_len = txt.len(),
721 "Skipping unrecognized WS message {err} preview={preview}",
722 );
723 } else {
724 tracing::debug!(
725 target: "chromiumoxide::conn::raw_ws::parse_errors",
726 "Skipping unrecognized binary WS message {err}",
727 );
728 }
729 Err(err.into())
730 }
731 }
732}
733
/// Deserializes one CDP message from a raw WebSocket payload, without a recursion limit
/// (`serde_stacker` build).
///
/// `raw_text_for_logging` is the payload as text, when the caller has it cheaply. On a parse
/// failure, up to 512 bytes of it are included in the debug log. The preview is cut on a
/// UTF-8 character boundary, so non-ASCII payloads cannot trigger a slicing panic.
///
/// # Errors
/// Returns the `serde_json` error, converted into `CdpError`, in two cases: `bytes` is not a
/// valid `Message<T>`, or it has trailing non-whitespace input. The trailing-input check
/// matches `serde_json::from_slice` in the non-`serde_stacker` build.
#[cfg(feature = "serde_stacker")]
fn decode_message<T: EventMessage>(
    bytes: &[u8],
    raw_text_for_logging: Option<&str>,
) -> Result<Box<Message<T>>> {
    use serde::Deserialize;

    // Longest prefix of `txt` that is at most 512 bytes and ends on a char boundary.
    // `&txt[..512]` on its own panics when byte 512 falls inside a multi-byte character.
    fn log_preview(txt: &str) -> &str {
        let mut end = txt.len().min(512);
        while !txt.is_char_boundary(end) {
            end -= 1;
        }
        &txt[..end]
    }

    let mut de = serde_json::Deserializer::from_slice(bytes);
    // Depth is handled by serde_stacker instead, so serde_json's own limit is turned off.
    de.disable_recursion_limit();

    // `end()` rejects trailing input, exactly as `serde_json::from_slice` does.
    let parsed = Box::<Message<T>>::deserialize(serde_stacker::Deserializer::new(&mut de))
        .and_then(|msg| de.end().map(|()| msg));

    match parsed {
        Ok(msg) => {
            tracing::trace!("Received {:?}", msg);
            Ok(msg)
        }
        Err(err) => {
            if let Some(txt) = raw_text_for_logging {
                let preview = log_preview(txt);
                tracing::debug!(
                    target: "chromiumoxide::conn::raw_ws::parse_errors",
                    msg_len = txt.len(),
                    "Skipping unrecognized WS message {err} preview={preview}",
                );
            } else {
                tracing::debug!(
                    target: "chromiumoxide::conn::raw_ws::parse_errors",
                    "Skipping unrecognized binary WS message {err}",
                );
            }
            Err(err.into())
        }
    }
}
772
#[cfg(test)]
mod ws_read_loop_tests {
    use super::*;
    use chromiumoxide_cdp::cdp::CdpEventMessage;
    use chromiumoxide_types::CallId;
    use futures_util::stream;
    use tokio::sync::mpsc;
    use tokio_tungstenite::tungstenite::Message as WsMessage;

    /// A small, well-formed CDP response frame: `{"id":<id>,"result":{"ok":true}}`.
    fn response_frame(id: u64) -> WsMessage {
        WsMessage::Text(format!(r#"{{"id":{id},"result":{{"ok":true}}}}"#).into())
    }

    /// A response frame whose result carries a `blob_bytes`-long string. Used to push a
    /// frame over `LARGE_FRAME_THRESHOLD`.
    fn large_response_frame(id: u64, blob_bytes: usize) -> WsMessage {
        let blob = "x".repeat(blob_bytes);
        WsMessage::Text(format!(r#"{{"id":{id},"result":{{"blob":"{blob}"}}}}"#).into())
    }

    /// Returns the call id of a decoded message, failing the test if it is not a response.
    /// This keeps a wrong decode from silently passing an `if let`.
    fn response_id(msg: Box<Message<CdpEventMessage>>) -> CallId {
        match *msg {
            Message::Response(resp) => resp.id,
            _ => panic!("expected Response"),
        }
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn forwards_messages_in_stream_order() {
        let frames = vec![
            Ok(response_frame(1)),
            Ok(response_frame(2)),
            Ok(response_frame(3)),
        ];
        let stream = stream::iter(frames);
        let (tx, mut rx) = mpsc::channel::<Result<Box<Message<CdpEventMessage>>>>(8);
        let task = tokio::spawn(ws_read_loop::<CdpEventMessage, _>(stream, tx));

        for expected in [1usize, 2, 3] {
            let msg = rx.recv().await.expect("msg").expect("decode ok");
            assert_eq!(response_id(msg), CallId::new(expected));
        }
        assert!(rx.recv().await.is_none(), "channel must close on EOF");
        task.await.expect("reader task join");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn pings_and_pongs_never_reach_the_handler() {
        let frames = vec![
            Ok(WsMessage::Ping(vec![1, 2, 3].into())),
            Ok(response_frame(7)),
            Ok(WsMessage::Pong(vec![].into())),
            Ok(response_frame(8)),
        ];
        let stream = stream::iter(frames);
        let (tx, mut rx) = mpsc::channel::<Result<Box<Message<CdpEventMessage>>>>(8);
        let task = tokio::spawn(ws_read_loop::<CdpEventMessage, _>(stream, tx));

        for expected in [7usize, 8] {
            let msg = rx.recv().await.expect("msg").expect("decode ok");
            assert_eq!(response_id(msg), CallId::new(expected));
        }
        assert!(rx.recv().await.is_none());
        task.await.expect("reader task join");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn malformed_frames_do_not_block_subsequent_valid_frames() {
        let frames = vec![
            Ok(WsMessage::Text("{not valid json".to_string().into())),
            Ok(response_frame(42)),
        ];
        let stream = stream::iter(frames);
        let (tx, mut rx) = mpsc::channel::<Result<Box<Message<CdpEventMessage>>>>(8);
        let task = tokio::spawn(ws_read_loop::<CdpEventMessage, _>(stream, tx));

        let msg = rx.recv().await.expect("msg").expect("decode ok");
        assert_eq!(response_id(msg), CallId::new(42));
        assert!(rx.recv().await.is_none());
        task.await.expect("reader task join");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn close_frame_terminates_the_reader() {
        let frames = vec![
            Ok(response_frame(1)),
            Ok(WsMessage::Close(None)),
            Ok(response_frame(2)),
        ];
        let stream = stream::iter(frames);
        let (tx, mut rx) = mpsc::channel::<Result<Box<Message<CdpEventMessage>>>>(8);
        let task = tokio::spawn(ws_read_loop::<CdpEventMessage, _>(stream, tx));

        let msg = rx.recv().await.expect("msg").expect("decode ok");
        assert_eq!(response_id(msg), CallId::new(1));
        assert!(
            rx.recv().await.is_none(),
            "reader must exit on Close; frames after Close must not appear"
        );
        task.await.expect("reader task join");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn transport_error_is_forwarded_once_then_reader_exits() {
        let frames = vec![
            Ok(response_frame(1)),
            Err(tokio_tungstenite::tungstenite::Error::ConnectionClosed),
            Ok(response_frame(2)),
        ];
        let stream = stream::iter(frames);
        let (tx, mut rx) = mpsc::channel::<Result<Box<Message<CdpEventMessage>>>>(8);
        let task = tokio::spawn(ws_read_loop::<CdpEventMessage, _>(stream, tx));

        let msg = rx.recv().await.expect("msg").expect("ok");
        assert_eq!(response_id(msg), CallId::new(1));
        match rx.recv().await {
            Some(Err(CdpError::Ws(_))) => {}
            other => panic!("expected forwarded Ws error, got {other:?}"),
        }
        assert!(rx.recv().await.is_none());
        task.await.expect("reader task join");
    }

    /// With a capacity-1 output channel, the reader must still deliver every frame, in order.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn bounded_channel_does_not_deadlock_under_backpressure() {
        const N: u64 = 512;
        let frames: Vec<_> = (1..=N).map(|id| Ok(response_frame(id))).collect();
        let stream = stream::iter(frames);

        let (tx, mut rx) = mpsc::channel::<Result<Box<Message<CdpEventMessage>>>>(1);
        let task = tokio::spawn(ws_read_loop::<CdpEventMessage, _>(stream, tx));

        let deadline = std::time::Duration::from_secs(5);
        let collected = tokio::time::timeout(deadline, async {
            let mut seen = 0u64;
            while let Some(frame) = rx.recv().await {
                seen += 1;
                assert_eq!(
                    response_id(frame.expect("decode ok")),
                    CallId::new(seen as usize),
                    "back-pressure must preserve FIFO order"
                );
            }
            seen
        })
        .await
        .expect("reader must make forward progress despite cap-1 back-pressure");

        assert_eq!(collected, N, "all frames must arrive");
        task.await.expect("reader task join");
    }

    /// A frame above `LARGE_FRAME_THRESHOLD` (decoded off-task) must round-trip intact.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn large_message_decodes_without_corruption() {
        let big = 2 * 1024 * 1024;
        let frames = vec![
            Ok(large_response_frame(100, big)),
            Ok(response_frame(101)),
        ];
        let stream = stream::iter(frames);
        let (tx, mut rx) = mpsc::channel::<Result<Box<Message<CdpEventMessage>>>>(4);
        let task = tokio::spawn(ws_read_loop::<CdpEventMessage, _>(stream, tx));

        let first = rx.recv().await.expect("msg").expect("ok");
        assert_eq!(response_id(first), CallId::new(100));
        let second = rx.recv().await.expect("msg").expect("ok");
        assert_eq!(response_id(second), CallId::new(101));
        assert!(rx.recv().await.is_none());
        task.await.expect("reader task join");
    }

    /// Large frames decode slower than small ones. The pipelined reader must nevertheless
    /// emit every frame in arrival order.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn pipelined_large_and_small_frames_keep_fifo_order() {
        let big = 2 * 1024 * 1024;
        let frames = vec![
            Ok(large_response_frame(1, big)),
            Ok(response_frame(2)),
            Ok(response_frame(3)),
            Ok(large_response_frame(4, big)),
            Ok(response_frame(5)),
            Ok(large_response_frame(6, big)),
            Ok(response_frame(7)),
            Ok(response_frame(8)),
        ];
        let expected: Vec<CallId> = (1..=8).map(CallId::new).collect();

        let stream = stream::iter(frames);
        let (tx, mut rx) = mpsc::channel::<Result<Box<Message<CdpEventMessage>>>>(16);
        let task = tokio::spawn(ws_read_loop::<CdpEventMessage, _>(stream, tx));

        let deadline = std::time::Duration::from_secs(10);
        let observed = tokio::time::timeout(deadline, async {
            let mut ids = Vec::with_capacity(expected.len());
            while let Some(frame) = rx.recv().await {
                ids.push(response_id(frame.expect("decode ok")));
            }
            ids
        })
        .await
        .expect("pipelined reader should make forward progress within 10s");

        assert_eq!(
            observed, expected,
            "pipelined reader must emit every frame in strict arrival order \
             regardless of per-frame decode latency"
        );
        task.await.expect("reader task join");
    }
}