1use std::collections::VecDeque;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::ready;
5
6use futures_util::{SinkExt, Stream, StreamExt};
7use std::task::{Context, Poll};
8use tokio_tungstenite::tungstenite::Message as WsMessage;
9use tokio_tungstenite::MaybeTlsStream;
10use tokio_tungstenite::{tungstenite::protocol::WebSocketConfig, WebSocketStream};
11
12use chromiumoxide_cdp::cdp::browser_protocol::target::SessionId;
13use chromiumoxide_types::{CallId, EventMessage, Message, MethodCall, MethodId};
14
15use crate::error::CdpError;
16use crate::error::Result;
17
18type ConnectStream = MaybeTlsStream<tokio::net::TcpStream>;
19
20#[must_use = "streams do nothing unless polled"]
22#[derive(Debug)]
23pub struct Connection<T: EventMessage> {
24 pending_commands: VecDeque<MethodCall>,
26 ws: WebSocketStream<ConnectStream>,
28 next_id: usize,
30 needs_flush: bool,
32 _marker: PhantomData<T>,
34}
35
36lazy_static::lazy_static! {
37 static ref DISABLE_NAGLE: bool = match std::env::var("DISABLE_NAGLE") {
39 Ok(disable_nagle) => disable_nagle == "true",
40 _ => true
41 };
42 static ref WEBSOCKET_DEFAULTS: bool = match std::env::var("WEBSOCKET_DEFAULTS") {
44 Ok(d) => d == "true",
45 _ => false
46 };
47}
48
/// Extra connection attempts `Connection::connect` makes after the first one fails.
pub const DEFAULT_CONNECTION_RETRIES: u32 = 4_u32;

/// Delay before the first retry, in milliseconds; each later retry waits 3x longer.
const INITIAL_BACKOFF_MS: u64 = 50_u64;
54
55impl<T: EventMessage + Unpin> Connection<T> {
56 pub async fn connect(debug_ws_url: impl AsRef<str>) -> Result<Self> {
57 Self::connect_with_retries(debug_ws_url, DEFAULT_CONNECTION_RETRIES).await
58 }
59
60 pub async fn connect_with_retries(
61 debug_ws_url: impl AsRef<str>,
62 retries: u32,
63 ) -> Result<Self> {
64 let mut config = WebSocketConfig::default();
65
66 if !*WEBSOCKET_DEFAULTS {
67 config.max_message_size = None;
68 config.max_frame_size = None;
69 }
70
71 let url = debug_ws_url.as_ref();
72 let use_uring = crate::uring_fs::is_enabled();
73 let mut last_err = None;
74
75 for attempt in 0..=retries {
76 let result = if use_uring {
77 Self::connect_uring(url, config).await
78 } else {
79 Self::connect_default(url, config).await
80 };
81
82 match result {
83 Ok(ws) => {
84 return Ok(Self {
85 pending_commands: Default::default(),
86 ws,
87 next_id: 0,
88 needs_flush: false,
89 _marker: Default::default(),
90 });
91 }
92 Err(e) => {
93 last_err = Some(e);
94 if attempt < retries {
95 let backoff_ms = INITIAL_BACKOFF_MS * 3u64.saturating_pow(attempt);
96 tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
97 }
98 }
99 }
100 }
101
102 Err(last_err.unwrap_or_else(|| CdpError::msg("connection failed")))
103 }
104
105 async fn connect_default(
107 url: &str,
108 config: WebSocketConfig,
109 ) -> Result<WebSocketStream<ConnectStream>> {
110 let (ws, _) =
111 tokio_tungstenite::connect_async_with_config(url, Some(config), *DISABLE_NAGLE).await?;
112 Ok(ws)
113 }
114
115 async fn connect_uring(
118 url: &str,
119 config: WebSocketConfig,
120 ) -> Result<WebSocketStream<ConnectStream>> {
121 use tokio_tungstenite::tungstenite::client::IntoClientRequest;
122
123 let request = url.into_client_request()?;
124 let host = request
125 .uri()
126 .host()
127 .ok_or_else(|| CdpError::msg("no host in CDP WebSocket URL"))?;
128 let port = request.uri().port_u16().unwrap_or(9222);
129
130 let addr_str = format!("{}:{}", host, port);
132 let addr: std::net::SocketAddr = match addr_str.parse() {
133 Ok(a) => a,
134 Err(_) => {
135 return Self::connect_default(url, config).await;
137 }
138 };
139
140 let std_stream = crate::uring_fs::tcp_connect(addr)
142 .await
143 .map_err(CdpError::Io)?;
144
145 std_stream.set_nonblocking(true).map_err(CdpError::Io)?;
147 if *DISABLE_NAGLE {
148 let _ = std_stream.set_nodelay(true);
149 }
150
151 let tokio_stream = tokio::net::TcpStream::from_std(std_stream).map_err(CdpError::Io)?;
153
154 let (ws, _) = tokio_tungstenite::client_async_with_config(
156 request,
157 MaybeTlsStream::Plain(tokio_stream),
158 Some(config),
159 )
160 .await?;
161
162 Ok(ws)
163 }
164}
165
166impl<T: EventMessage> Connection<T> {
167 fn next_call_id(&mut self) -> CallId {
168 let id = CallId::new(self.next_id);
169 self.next_id = self.next_id.wrapping_add(1);
170 id
171 }
172
173 pub fn submit_command(
176 &mut self,
177 method: MethodId,
178 session_id: Option<SessionId>,
179 params: serde_json::Value,
180 ) -> serde_json::Result<CallId> {
181 let id = self.next_call_id();
182 let call = MethodCall {
183 id,
184 method,
185 session_id: session_id.map(Into::into),
186 params,
187 };
188 self.pending_commands.push_back(call);
189 Ok(id)
190 }
191
192 fn start_send_next(&mut self, cx: &mut Context<'_>) -> Result<()> {
197 if self.needs_flush {
199 match self.ws.poll_flush_unpin(cx) {
200 Poll::Ready(Ok(())) => self.needs_flush = false,
201 Poll::Ready(Err(e)) => return Err(e.into()),
202 Poll::Pending => return Ok(()),
203 }
204 }
205
206 let mut sent_any = false;
208 while !self.pending_commands.is_empty() {
209 match self.ws.poll_ready_unpin(cx) {
210 Poll::Ready(Ok(())) => {
211 let cmd = self.pending_commands.pop_front().unwrap();
212 tracing::trace!("Sending {:?}", cmd);
213 let msg = serde_json::to_string(&cmd)?;
214 self.ws.start_send_unpin(msg.into())?;
215 sent_any = true;
216 }
217 _ => break,
218 }
219 }
220
221 if sent_any {
223 match self.ws.poll_flush_unpin(cx) {
224 Poll::Ready(Ok(())) => {}
225 Poll::Ready(Err(e)) => return Err(e.into()),
226 Poll::Pending => self.needs_flush = true,
227 }
228 }
229
230 Ok(())
231 }
232}
233
234impl<T: EventMessage + Unpin> Stream for Connection<T> {
235 type Item = Result<Box<Message<T>>>;
236
237 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
238 let pin = self.get_mut();
239
240 if let Err(err) = pin.start_send_next(cx) {
242 return Poll::Ready(Some(Err(err)));
243 }
244
245 match ready!(pin.ws.poll_next_unpin(cx)) {
247 Some(Ok(WsMessage::Text(text))) => {
248 match decode_message::<T>(text.as_bytes(), Some(&text)) {
249 Ok(msg) => Poll::Ready(Some(Ok(msg))),
250 Err(err) => {
251 tracing::debug!(
252 target: "chromiumoxide::conn::raw_ws::parse_errors",
253 "Dropping malformed text WS frame: {err}",
254 );
255 cx.waker().wake_by_ref();
256 Poll::Pending
257 }
258 }
259 }
260 Some(Ok(WsMessage::Binary(buf))) => match decode_message::<T>(&buf, None) {
261 Ok(msg) => Poll::Ready(Some(Ok(msg))),
262 Err(err) => {
263 tracing::debug!(
264 target: "chromiumoxide::conn::raw_ws::parse_errors",
265 "Dropping malformed binary WS frame: {err}",
266 );
267 cx.waker().wake_by_ref();
268 Poll::Pending
269 }
270 },
271 Some(Ok(WsMessage::Close(_))) => Poll::Ready(None),
272 Some(Ok(WsMessage::Ping(_))) | Some(Ok(WsMessage::Pong(_))) => {
274 cx.waker().wake_by_ref();
275 Poll::Pending
276 }
277 Some(Ok(msg)) => {
278 tracing::debug!(
280 target: "chromiumoxide::conn::raw_ws::parse_errors",
281 "Unexpected WS message type: {:?}",
282 msg
283 );
284 cx.waker().wake_by_ref();
285 Poll::Pending
286 }
287 Some(Err(err)) => Poll::Ready(Some(Err(CdpError::Ws(err)))),
288 None => {
289 Poll::Ready(None)
291 }
292 }
293 }
294}
295
296#[cfg(not(feature = "serde_stacker"))]
300fn decode_message<T: EventMessage>(
301 bytes: &[u8],
302 raw_text_for_logging: Option<&str>,
303) -> Result<Box<Message<T>>> {
304 match serde_json::from_slice::<Box<Message<T>>>(bytes) {
305 Ok(msg) => {
306 tracing::trace!("Received {:?}", msg);
307 Ok(msg)
308 }
309 Err(err) => {
310 if let Some(txt) = raw_text_for_logging {
311 let preview = &txt[..txt.len().min(512)];
312 tracing::debug!(
313 target: "chromiumoxide::conn::raw_ws::parse_errors",
314 msg_len = txt.len(),
315 "Skipping unrecognized WS message {err} preview={preview}",
316 );
317 } else {
318 tracing::debug!(
319 target: "chromiumoxide::conn::raw_ws::parse_errors",
320 "Skipping unrecognized binary WS message {err}",
321 );
322 }
323 Err(err.into())
324 }
325 }
326}
327
/// Deserializes one CDP message from a WebSocket payload.
///
/// `serde_json`'s recursion limit is disabled and deserialization goes
/// through `serde_stacker`, which grows the stack on demand, so deeply nested
/// payloads are accepted.
///
/// `raw_text_for_logging` is the payload as text (text frames only). When
/// decoding fails it is logged at debug level with a preview of at most 512
/// bytes, cut at a UTF-8 character boundary.
///
/// # Errors
///
/// Returns the deserialization error, converted into the crate error type,
/// when `bytes` is not a valid `Message<T>`.
#[cfg(feature = "serde_stacker")]
fn decode_message<T: EventMessage>(
    bytes: &[u8],
    raw_text_for_logging: Option<&str>,
) -> Result<Box<Message<T>>> {
    use serde::Deserialize;
    let mut de = serde_json::Deserializer::from_slice(bytes);

    de.disable_recursion_limit();

    let de = serde_stacker::Deserializer::new(&mut de);

    match Box::<Message<T>>::deserialize(de) {
        Ok(msg) => {
            tracing::trace!("Received {:?}", msg);
            Ok(msg)
        }
        Err(err) => {
            if let Some(txt) = raw_text_for_logging {
                // Slicing a `str` inside a multi-byte UTF-8 character panics,
                // so step back to the nearest char boundary (0 always is one).
                let mut end = txt.len().min(512);
                while !txt.is_char_boundary(end) {
                    end -= 1;
                }
                let preview = &txt[..end];
                tracing::debug!(
                    target: "chromiumoxide::conn::raw_ws::parse_errors",
                    msg_len = txt.len(),
                    "Skipping unrecognized WS message {err} preview={preview}",
                );
            } else {
                tracing::debug!(
                    target: "chromiumoxide::conn::raw_ws::parse_errors",
                    "Skipping unrecognized binary WS message {err}",
                );
            }
            Err(err.into())
        }
    }
}