1use std::collections::VecDeque;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::ready;
5
6use futures::stream::Stream;
7use futures::task::{Context, Poll};
8use futures::{SinkExt, StreamExt};
9use tokio_tungstenite::tungstenite::Message as WsMessage;
10use tokio_tungstenite::MaybeTlsStream;
11use tokio_tungstenite::{tungstenite::protocol::WebSocketConfig, WebSocketStream};
12
13use chromiumoxide_cdp::cdp::browser_protocol::target::SessionId;
14use chromiumoxide_types::{CallId, EventMessage, Message, MethodCall, MethodId};
15
16use crate::error::CdpError;
17use crate::error::Result;
18
19type ConnectStream = MaybeTlsStream<tokio::net::TcpStream>;
20
21#[must_use = "streams do nothing unless polled"]
23#[derive(Debug)]
24pub struct Connection<T: EventMessage> {
25 pending_commands: VecDeque<MethodCall>,
27 ws: WebSocketStream<ConnectStream>,
29 next_id: usize,
31 needs_flush: bool,
33 _marker: PhantomData<T>,
35}
36
37lazy_static::lazy_static! {
38 static ref DISABLE_NAGLE: bool = match std::env::var("DISABLE_NAGLE") {
40 Ok(disable_nagle) => disable_nagle == "true",
41 _ => true
42 };
43 static ref WEBSOCKET_DEFAULTS: bool = match std::env::var("WEBSOCKET_DEFAULTS") {
45 Ok(d) => d == "true",
46 _ => false
47 };
48}
49
50pub const DEFAULT_CONNECTION_RETRIES: u32 = 4;
52
53const INITIAL_BACKOFF_MS: u64 = 50;
55
56impl<T: EventMessage + Unpin> Connection<T> {
57 pub async fn connect(debug_ws_url: impl AsRef<str>) -> Result<Self> {
58 Self::connect_with_retries(debug_ws_url, DEFAULT_CONNECTION_RETRIES).await
59 }
60
61 pub async fn connect_with_retries(
62 debug_ws_url: impl AsRef<str>,
63 retries: u32,
64 ) -> Result<Self> {
65 let mut config = WebSocketConfig::default();
66
67 if !*WEBSOCKET_DEFAULTS {
68 config.max_message_size = None;
69 config.max_frame_size = None;
70 }
71
72 let url = debug_ws_url.as_ref();
73 let use_uring = crate::uring_fs::is_enabled();
74 let mut last_err = None;
75
76 for attempt in 0..=retries {
77 let result = if use_uring {
78 Self::connect_uring(url, config).await
79 } else {
80 Self::connect_default(url, config).await
81 };
82
83 match result {
84 Ok(ws) => {
85 return Ok(Self {
86 pending_commands: Default::default(),
87 ws,
88 next_id: 0,
89 needs_flush: false,
90 _marker: Default::default(),
91 });
92 }
93 Err(e) => {
94 last_err = Some(e);
95 if attempt < retries {
96 let backoff_ms = INITIAL_BACKOFF_MS * 3u64.saturating_pow(attempt);
97 tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
98 }
99 }
100 }
101 }
102
103 Err(last_err.unwrap_or_else(|| CdpError::msg("connection failed")))
104 }
105
106 async fn connect_default(
108 url: &str,
109 config: WebSocketConfig,
110 ) -> Result<WebSocketStream<ConnectStream>> {
111 let (ws, _) =
112 tokio_tungstenite::connect_async_with_config(url, Some(config), *DISABLE_NAGLE).await?;
113 Ok(ws)
114 }
115
116 async fn connect_uring(
119 url: &str,
120 config: WebSocketConfig,
121 ) -> Result<WebSocketStream<ConnectStream>> {
122 use tokio_tungstenite::tungstenite::client::IntoClientRequest;
123
124 let request = url.into_client_request()?;
125 let host = request
126 .uri()
127 .host()
128 .ok_or_else(|| CdpError::msg("no host in CDP WebSocket URL"))?;
129 let port = request.uri().port_u16().unwrap_or(9222);
130
131 let addr_str = format!("{}:{}", host, port);
133 let addr: std::net::SocketAddr = match addr_str.parse() {
134 Ok(a) => a,
135 Err(_) => {
136 return Self::connect_default(url, config).await;
138 }
139 };
140
141 let std_stream = crate::uring_fs::tcp_connect(addr)
143 .await
144 .map_err(CdpError::Io)?;
145
146 std_stream.set_nonblocking(true).map_err(CdpError::Io)?;
148 if *DISABLE_NAGLE {
149 let _ = std_stream.set_nodelay(true);
150 }
151
152 let tokio_stream = tokio::net::TcpStream::from_std(std_stream).map_err(CdpError::Io)?;
154
155 let (ws, _) = tokio_tungstenite::client_async_with_config(
157 request,
158 MaybeTlsStream::Plain(tokio_stream),
159 Some(config),
160 )
161 .await?;
162
163 Ok(ws)
164 }
165}
166
167impl<T: EventMessage> Connection<T> {
168 fn next_call_id(&mut self) -> CallId {
169 let id = CallId::new(self.next_id);
170 self.next_id = self.next_id.wrapping_add(1);
171 id
172 }
173
174 pub fn submit_command(
177 &mut self,
178 method: MethodId,
179 session_id: Option<SessionId>,
180 params: serde_json::Value,
181 ) -> serde_json::Result<CallId> {
182 let id = self.next_call_id();
183 let call = MethodCall {
184 id,
185 method,
186 session_id: session_id.map(Into::into),
187 params,
188 };
189 self.pending_commands.push_back(call);
190 Ok(id)
191 }
192
193 fn start_send_next(&mut self, cx: &mut Context<'_>) -> Result<()> {
198 if self.needs_flush {
200 match self.ws.poll_flush_unpin(cx) {
201 Poll::Ready(Ok(())) => self.needs_flush = false,
202 Poll::Ready(Err(e)) => return Err(e.into()),
203 Poll::Pending => return Ok(()),
204 }
205 }
206
207 let mut sent_any = false;
209 while !self.pending_commands.is_empty() {
210 match self.ws.poll_ready_unpin(cx) {
211 Poll::Ready(Ok(())) => {
212 let cmd = self.pending_commands.pop_front().unwrap();
213 tracing::trace!("Sending {:?}", cmd);
214 let msg = serde_json::to_string(&cmd)?;
215 self.ws.start_send_unpin(msg.into())?;
216 sent_any = true;
217 }
218 _ => break,
219 }
220 }
221
222 if sent_any {
224 match self.ws.poll_flush_unpin(cx) {
225 Poll::Ready(Ok(())) => {}
226 Poll::Ready(Err(e)) => return Err(e.into()),
227 Poll::Pending => self.needs_flush = true,
228 }
229 }
230
231 Ok(())
232 }
233}
234
235impl<T: EventMessage + Unpin> Stream for Connection<T> {
236 type Item = Result<Box<Message<T>>>;
237
238 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
239 let pin = self.get_mut();
240
241 if let Err(err) = pin.start_send_next(cx) {
243 return Poll::Ready(Some(Err(err)));
244 }
245
246 match ready!(pin.ws.poll_next_unpin(cx)) {
248 Some(Ok(WsMessage::Text(text))) => {
249 match decode_message::<T>(text.as_bytes(), Some(&text)) {
250 Ok(msg) => Poll::Ready(Some(Ok(msg))),
251 Err(err) => {
252 tracing::debug!(
253 target: "chromiumoxide::conn::raw_ws::parse_errors",
254 "Dropping malformed text WS frame: {err}",
255 );
256 cx.waker().wake_by_ref();
257 Poll::Pending
258 }
259 }
260 }
261 Some(Ok(WsMessage::Binary(buf))) => match decode_message::<T>(&buf, None) {
262 Ok(msg) => Poll::Ready(Some(Ok(msg))),
263 Err(err) => {
264 tracing::debug!(
265 target: "chromiumoxide::conn::raw_ws::parse_errors",
266 "Dropping malformed binary WS frame: {err}",
267 );
268 cx.waker().wake_by_ref();
269 Poll::Pending
270 }
271 },
272 Some(Ok(WsMessage::Close(_))) => Poll::Ready(None),
273 Some(Ok(WsMessage::Ping(_))) | Some(Ok(WsMessage::Pong(_))) => {
275 cx.waker().wake_by_ref();
276 Poll::Pending
277 }
278 Some(Ok(msg)) => {
279 tracing::debug!(
281 target: "chromiumoxide::conn::raw_ws::parse_errors",
282 "Unexpected WS message type: {:?}",
283 msg
284 );
285 cx.waker().wake_by_ref();
286 Poll::Pending
287 }
288 Some(Err(err)) => Poll::Ready(Some(Err(CdpError::Ws(err)))),
289 None => {
290 Poll::Ready(None)
292 }
293 }
294 }
295}
296
297#[cfg(not(feature = "serde_stacker"))]
301fn decode_message<T: EventMessage>(
302 bytes: &[u8],
303 raw_text_for_logging: Option<&str>,
304) -> Result<Box<Message<T>>> {
305 match serde_json::from_slice::<Box<Message<T>>>(bytes) {
306 Ok(msg) => {
307 tracing::trace!("Received {:?}", msg);
308 Ok(msg)
309 }
310 Err(err) => {
311 if let Some(txt) = raw_text_for_logging {
312 let preview = &txt[..txt.len().min(512)];
313 tracing::debug!(
314 target: "chromiumoxide::conn::raw_ws::parse_errors",
315 msg_len = txt.len(),
316 "Skipping unrecognized WS message {err} preview={preview}",
317 );
318 } else {
319 tracing::debug!(
320 target: "chromiumoxide::conn::raw_ws::parse_errors",
321 "Skipping unrecognized binary WS message {err}",
322 );
323 }
324 Err(err.into())
325 }
326 }
327}
328
/// Deserializes one websocket payload into a CDP [`Message`] without
/// serde_json's recursion limit.
///
/// The limit is disabled, and the deserializer is wrapped in
/// `serde_stacker::Deserializer` (which grows the stack as needed), so deeply
/// nested payloads can be decoded. `raw_text_for_logging` is the payload as text
/// when it arrived in a text frame. It is used only to attach a preview (at most
/// 512 bytes) to the debug log when decoding fails.
///
/// # Errors
///
/// Returns the `serde_json` error, converted into the crate error type, when
/// `bytes` is not a valid message for `T`.
#[cfg(feature = "serde_stacker")]
fn decode_message<T: EventMessage>(
    bytes: &[u8],
    raw_text_for_logging: Option<&str>,
) -> Result<Box<Message<T>>> {
    use serde::Deserialize;
    let mut de = serde_json::Deserializer::from_slice(bytes);

    de.disable_recursion_limit();

    let de = serde_stacker::Deserializer::new(&mut de);

    match Box::<Message<T>>::deserialize(de) {
        Ok(msg) => {
            tracing::trace!("Received {:?}", msg);
            Ok(msg)
        }
        Err(err) => {
            if let Some(txt) = raw_text_for_logging {
                // Cut at most 512 bytes, backing off to a UTF-8 character
                // boundary. Slicing a `str` mid-character would panic.
                let mut end = txt.len().min(512);
                while !txt.is_char_boundary(end) {
                    end -= 1;
                }
                let preview = &txt[..end];
                tracing::debug!(
                    target: "chromiumoxide::conn::raw_ws::parse_errors",
                    msg_len = txt.len(),
                    "Skipping unrecognized WS message {err} preview={preview}",
                );
            } else {
                tracing::debug!(
                    target: "chromiumoxide::conn::raw_ws::parse_errors",
                    "Skipping unrecognized binary WS message {err}",
                );
            }
            Err(err.into())
        }
    }
}