1use std::collections::VecDeque;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::ready;
5
6use futures_util::stream::{SplitSink, SplitStream};
7use futures_util::{SinkExt, Stream, StreamExt};
8use std::task::{Context, Poll};
9use tokio::sync::mpsc;
10use tokio_tungstenite::tungstenite::Message as WsMessage;
11use tokio_tungstenite::MaybeTlsStream;
12use tokio_tungstenite::{tungstenite::protocol::WebSocketConfig, WebSocketStream};
13
14use chromiumoxide_cdp::cdp::browser_protocol::target::SessionId;
15use chromiumoxide_types::{CallId, EventMessage, Message, MethodCall, MethodId};
16
17use crate::error::CdpError;
18use crate::error::Result;
19
/// Transport under every WebSocket this module opens: a Tokio TCP stream that may or may
/// not be wrapped in TLS. The `connect_uring` path always uses the `Plain` variant.
type ConnectStream = MaybeTlsStream<tokio::net::TcpStream>;
21
22#[must_use = "streams do nothing unless polled"]
24#[derive(Debug)]
25pub struct Connection<T: EventMessage> {
26 pending_commands: VecDeque<MethodCall>,
28 ws: WebSocketStream<ConnectStream>,
30 next_id: usize,
32 needs_flush: bool,
34 _marker: PhantomData<T>,
36}
37
38lazy_static::lazy_static! {
39 static ref DISABLE_NAGLE: bool = match std::env::var("DISABLE_NAGLE") {
41 Ok(disable_nagle) => disable_nagle == "true",
42 _ => true
43 };
44 static ref WEBSOCKET_DEFAULTS: bool = match std::env::var("WEBSOCKET_DEFAULTS") {
46 Ok(d) => d == "true",
47 _ => false
48 };
49}
50
/// Extra attempts `Connection::connect` makes after the first failed one, giving
/// `DEFAULT_CONNECTION_RETRIES + 1` attempts in total.
pub const DEFAULT_CONNECTION_RETRIES: u32 = 4;

/// Delay in milliseconds before the first retry; each later retry triples it.
const INITIAL_BACKOFF_MS: u64 = 50;
56
57impl<T: EventMessage + Unpin> Connection<T> {
58 pub async fn connect(debug_ws_url: impl AsRef<str>) -> Result<Self> {
59 Self::connect_with_retries(debug_ws_url, DEFAULT_CONNECTION_RETRIES).await
60 }
61
62 pub async fn connect_with_retries(debug_ws_url: impl AsRef<str>, retries: u32) -> Result<Self> {
63 let mut config = WebSocketConfig::default();
64
65 config.max_write_buffer_size = 4 * 1024 * 1024;
68
69 if !*WEBSOCKET_DEFAULTS {
70 config.max_message_size = None;
71 config.max_frame_size = None;
72 }
73
74 let url = debug_ws_url.as_ref();
75 let use_uring = crate::uring_fs::is_enabled();
76 let mut last_err = None;
77
78 for attempt in 0..=retries {
79 let result = if use_uring {
80 Self::connect_uring(url, config).await
81 } else {
82 Self::connect_default(url, config).await
83 };
84
85 match result {
86 Ok(ws) => {
87 return Ok(Self {
88 pending_commands: Default::default(),
89 ws,
90 next_id: 0,
91 needs_flush: false,
92 _marker: Default::default(),
93 });
94 }
95 Err(e) => {
96 last_err = Some(e);
97 if attempt < retries {
98 let backoff_ms = INITIAL_BACKOFF_MS * 3u64.saturating_pow(attempt);
99 tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
100 }
101 }
102 }
103 }
104
105 Err(last_err.unwrap_or_else(|| CdpError::msg("connection failed")))
106 }
107
108 async fn connect_default(
110 url: &str,
111 config: WebSocketConfig,
112 ) -> Result<WebSocketStream<ConnectStream>> {
113 let (ws, _) =
114 tokio_tungstenite::connect_async_with_config(url, Some(config), *DISABLE_NAGLE).await?;
115 Ok(ws)
116 }
117
118 async fn connect_uring(
121 url: &str,
122 config: WebSocketConfig,
123 ) -> Result<WebSocketStream<ConnectStream>> {
124 use tokio_tungstenite::tungstenite::client::IntoClientRequest;
125
126 let request = url.into_client_request()?;
127 let host = request
128 .uri()
129 .host()
130 .ok_or_else(|| CdpError::msg("no host in CDP WebSocket URL"))?;
131 let port = request.uri().port_u16().unwrap_or(9222);
132
133 let addr_str = format!("{}:{}", host, port);
135 let addr: std::net::SocketAddr = match addr_str.parse() {
136 Ok(a) => a,
137 Err(_) => {
138 return Self::connect_default(url, config).await;
140 }
141 };
142
143 let std_stream = crate::uring_fs::tcp_connect(addr)
145 .await
146 .map_err(CdpError::Io)?;
147
148 std_stream.set_nonblocking(true).map_err(CdpError::Io)?;
150 if *DISABLE_NAGLE {
151 let _ = std_stream.set_nodelay(true);
152 }
153
154 let tokio_stream = tokio::net::TcpStream::from_std(std_stream).map_err(CdpError::Io)?;
156
157 let (ws, _) = tokio_tungstenite::client_async_with_config(
159 request,
160 MaybeTlsStream::Plain(tokio_stream),
161 Some(config),
162 )
163 .await?;
164
165 Ok(ws)
166 }
167}
168
169impl<T: EventMessage> Connection<T> {
170 fn next_call_id(&mut self) -> CallId {
171 let id = CallId::new(self.next_id);
172 self.next_id = self.next_id.wrapping_add(1);
173 id
174 }
175
176 pub fn submit_command(
179 &mut self,
180 method: MethodId,
181 session_id: Option<SessionId>,
182 params: serde_json::Value,
183 ) -> serde_json::Result<CallId> {
184 let id = self.next_call_id();
185 let call = MethodCall {
186 id,
187 method,
188 session_id: session_id.map(Into::into),
189 params,
190 };
191 self.pending_commands.push_back(call);
192 Ok(id)
193 }
194
195 fn start_send_next(&mut self, cx: &mut Context<'_>) -> Result<()> {
200 if self.needs_flush {
202 match self.ws.poll_flush_unpin(cx) {
203 Poll::Ready(Ok(())) => self.needs_flush = false,
204 Poll::Ready(Err(e)) => return Err(e.into()),
205 Poll::Pending => return Ok(()),
206 }
207 }
208
209 let mut sent_any = false;
211 while !self.pending_commands.is_empty() {
212 match self.ws.poll_ready_unpin(cx) {
213 Poll::Ready(Ok(())) => {
214 let Some(cmd) = self.pending_commands.pop_front() else {
215 break;
216 };
217 tracing::trace!("Sending {:?}", cmd);
218 let msg = serde_json::to_string(&cmd)?;
219 self.ws.start_send_unpin(msg.into())?;
220 sent_any = true;
221 }
222 _ => break,
223 }
224 }
225
226 if sent_any {
228 match self.ws.poll_flush_unpin(cx) {
229 Poll::Ready(Ok(())) => {}
230 Poll::Ready(Err(e)) => return Err(e.into()),
231 Poll::Pending => self.needs_flush = true,
232 }
233 }
234
235 Ok(())
236 }
237}
238
/// Capacity of the bounded channel between `AsyncConnection::cmd_tx` and the writer task.
const WS_CMD_CHANNEL_CAPACITY: usize = 2_048;
243
244#[derive(Debug)]
246pub struct AsyncConnection<T: EventMessage> {
247 pub reader: WsReader<T>,
249 pub cmd_tx: mpsc::Sender<MethodCall>,
251 pub writer_handle: tokio::task::JoinHandle<Result<()>>,
253 pub next_id: usize,
255}
256
257impl<T: EventMessage + Unpin> Connection<T> {
258 pub fn into_async(self) -> AsyncConnection<T> {
265 let (ws_sink, ws_stream) = self.ws.split();
266 let (cmd_tx, cmd_rx) = mpsc::channel(WS_CMD_CHANNEL_CAPACITY);
267
268 let writer_handle = tokio::spawn(ws_write_loop(ws_sink, cmd_rx));
269
270 let reader = WsReader {
271 inner: ws_stream,
272 _marker: PhantomData,
273 };
274
275 AsyncConnection {
276 reader,
277 cmd_tx,
278 writer_handle,
279 next_id: self.next_id,
280 }
281 }
282}
283
284async fn ws_write_loop(
286 mut sink: SplitSink<WebSocketStream<ConnectStream>, WsMessage>,
287 mut rx: mpsc::Receiver<MethodCall>,
288) -> Result<()> {
289 while let Some(call) = rx.recv().await {
290 let msg = crate::serde_json::to_string(&call)?;
291 sink.feed(WsMessage::Text(msg.into()))
292 .await
293 .map_err(CdpError::Ws)?;
294
295 while let Ok(call) = rx.try_recv() {
297 let msg = crate::serde_json::to_string(&call)?;
298 sink.feed(WsMessage::Text(msg.into()))
299 .await
300 .map_err(CdpError::Ws)?;
301 }
302
303 sink.flush().await.map_err(CdpError::Ws)?;
305 }
306 Ok(())
307}
308
309#[derive(Debug)]
314pub struct WsReader<T: EventMessage> {
315 inner: SplitStream<WebSocketStream<ConnectStream>>,
316 _marker: PhantomData<T>,
317}
318
319impl<T: EventMessage + Unpin> WsReader<T> {
320 pub async fn next_message(&mut self) -> Option<Result<Box<Message<T>>>> {
324 loop {
325 match self.inner.next().await? {
326 Ok(WsMessage::Text(text)) => {
327 match decode_message::<T>(text.as_bytes(), Some(&text)) {
328 Ok(msg) => return Some(Ok(msg)),
329 Err(err) => {
330 tracing::debug!(
331 target: "chromiumoxide::conn::raw_ws::parse_errors",
332 "Dropping malformed text WS frame: {err}",
333 );
334 continue;
335 }
336 }
337 }
338 Ok(WsMessage::Binary(buf)) => match decode_message::<T>(&buf, None) {
339 Ok(msg) => return Some(Ok(msg)),
340 Err(err) => {
341 tracing::debug!(
342 target: "chromiumoxide::conn::raw_ws::parse_errors",
343 "Dropping malformed binary WS frame: {err}",
344 );
345 continue;
346 }
347 },
348 Ok(WsMessage::Close(_)) => return None,
349 Ok(WsMessage::Ping(_)) | Ok(WsMessage::Pong(_)) => continue,
350 Ok(msg) => {
351 tracing::debug!(
352 target: "chromiumoxide::conn::raw_ws::parse_errors",
353 "Unexpected WS message type: {:?}",
354 msg
355 );
356 continue;
357 }
358 Err(err) => return Some(Err(CdpError::Ws(err))),
359 }
360 }
361 }
362}
363
364impl<T: EventMessage + Unpin> Stream for Connection<T> {
365 type Item = Result<Box<Message<T>>>;
366
367 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
368 let pin = self.get_mut();
369
370 if let Err(err) = pin.start_send_next(cx) {
372 return Poll::Ready(Some(Err(err)));
373 }
374
375 loop {
379 match ready!(pin.ws.poll_next_unpin(cx)) {
380 Some(Ok(WsMessage::Text(text))) => {
381 match decode_message::<T>(text.as_bytes(), Some(&text)) {
382 Ok(msg) => return Poll::Ready(Some(Ok(msg))),
383 Err(err) => {
384 tracing::debug!(
385 target: "chromiumoxide::conn::raw_ws::parse_errors",
386 "Dropping malformed text WS frame: {err}",
387 );
388 continue;
389 }
390 }
391 }
392 Some(Ok(WsMessage::Binary(buf))) => match decode_message::<T>(&buf, None) {
393 Ok(msg) => return Poll::Ready(Some(Ok(msg))),
394 Err(err) => {
395 tracing::debug!(
396 target: "chromiumoxide::conn::raw_ws::parse_errors",
397 "Dropping malformed binary WS frame: {err}",
398 );
399 continue;
400 }
401 },
402 Some(Ok(WsMessage::Close(_))) => return Poll::Ready(None),
403 Some(Ok(WsMessage::Ping(_))) | Some(Ok(WsMessage::Pong(_))) => continue,
405 Some(Ok(msg)) => {
406 tracing::debug!(
407 target: "chromiumoxide::conn::raw_ws::parse_errors",
408 "Unexpected WS message type: {:?}",
409 msg
410 );
411 continue;
412 }
413 Some(Err(err)) => return Poll::Ready(Some(Err(CdpError::Ws(err)))),
414 None => return Poll::Ready(None),
415 }
416 }
417 }
418}
419
420#[cfg(not(feature = "serde_stacker"))]
424fn decode_message<T: EventMessage>(
425 bytes: &[u8],
426 raw_text_for_logging: Option<&str>,
427) -> Result<Box<Message<T>>> {
428 match serde_json::from_slice::<Box<Message<T>>>(bytes) {
429 Ok(msg) => {
430 tracing::trace!("Received {:?}", msg);
431 Ok(msg)
432 }
433 Err(err) => {
434 if let Some(txt) = raw_text_for_logging {
435 let preview = &txt[..txt.len().min(512)];
436 tracing::debug!(
437 target: "chromiumoxide::conn::raw_ws::parse_errors",
438 msg_len = txt.len(),
439 "Skipping unrecognized WS message {err} preview={preview}",
440 );
441 } else {
442 tracing::debug!(
443 target: "chromiumoxide::conn::raw_ws::parse_errors",
444 "Skipping unrecognized binary WS message {err}",
445 );
446 }
447 Err(err.into())
448 }
449 }
450}
451
/// Deserializes one CDP message from the raw bytes of a WebSocket frame, without a
/// recursion limit.
///
/// serde_json's recursion limit is disabled and the deserializer is wrapped in
/// `serde_stacker`, which grows the stack on demand, so deeply nested payloads are accepted.
/// `raw_text_for_logging` is the frame's text, when it was a text frame. It is used only to
/// log a preview of at most 512 bytes when decoding fails.
///
/// # Errors
/// Returns the deserialization error, converted into the crate error type, when `bytes` is
/// not a valid `Message<T>`. The failure is also logged at debug level.
#[cfg(feature = "serde_stacker")]
fn decode_message<T: EventMessage>(
    bytes: &[u8],
    raw_text_for_logging: Option<&str>,
) -> Result<Box<Message<T>>> {
    use serde::Deserialize;
    let mut de = serde_json::Deserializer::from_slice(bytes);

    de.disable_recursion_limit();

    let de = serde_stacker::Deserializer::new(&mut de);

    match Box::<Message<T>>::deserialize(de) {
        Ok(msg) => {
            tracing::trace!("Received {:?}", msg);
            Ok(msg)
        }
        Err(err) => {
            if let Some(txt) = raw_text_for_logging {
                // Slicing a `str` at an arbitrary byte offset panics inside a multi-byte
                // UTF-8 sequence, so back up to the nearest char boundary.
                let mut end = txt.len().min(512);
                while !txt.is_char_boundary(end) {
                    end -= 1;
                }
                let preview = &txt[..end];
                tracing::debug!(
                    target: "chromiumoxide::conn::raw_ws::parse_errors",
                    msg_len = txt.len(),
                    "Skipping unrecognized WS message {err} preview={preview}",
                );
            } else {
                tracing::debug!(
                    target: "chromiumoxide::conn::raw_ws::parse_errors",
                    "Skipping unrecognized binary WS message {err}",
                );
            }
            Err(err.into())
        }
    }
}