1use std::collections::VecDeque;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::ready;
5
6use futures::stream::Stream;
7use futures::task::{Context, Poll};
8use futures::{SinkExt, StreamExt};
9use tokio_tungstenite::tungstenite::Message as WsMessage;
10use tokio_tungstenite::MaybeTlsStream;
11use tokio_tungstenite::{tungstenite::protocol::WebSocketConfig, WebSocketStream};
12
13use chromiumoxide_cdp::cdp::browser_protocol::target::SessionId;
14use chromiumoxide_types::{CallId, EventMessage, Message, MethodCall, MethodId};
15
16use crate::error::CdpError;
17use crate::error::Result;
18
19type ConnectStream = MaybeTlsStream<tokio::net::TcpStream>;
20
21#[must_use = "streams do nothing unless polled"]
23#[derive(Debug)]
24pub struct Connection<T: EventMessage> {
25 pending_commands: VecDeque<MethodCall>,
27 ws: WebSocketStream<ConnectStream>,
29 next_id: usize,
31 needs_flush: bool,
33 _marker: PhantomData<T>,
35}
36
37lazy_static::lazy_static! {
38 static ref DISABLE_NAGLE: bool = match std::env::var("DISABLE_NAGLE") {
40 Ok(disable_nagle) => disable_nagle == "true",
41 _ => true
42 };
43 static ref WEBSOCKET_DEFAULTS: bool = match std::env::var("WEBSOCKET_DEFAULTS") {
45 Ok(d) => d == "true",
46 _ => false
47 };
48}
49
50impl<T: EventMessage + Unpin> Connection<T> {
51 pub async fn connect(debug_ws_url: impl AsRef<str>) -> Result<Self> {
52 let mut config = WebSocketConfig::default();
53
54 if !*WEBSOCKET_DEFAULTS {
55 config.max_message_size = None;
56 config.max_frame_size = None;
57 }
58
59 let ws = if crate::uring_fs::is_enabled() {
60 Self::connect_uring(debug_ws_url.as_ref(), config).await?
61 } else {
62 Self::connect_default(debug_ws_url.as_ref(), config).await?
63 };
64
65 Ok(Self {
66 pending_commands: Default::default(),
67 ws,
68 next_id: 0,
69 needs_flush: false,
70 _marker: Default::default(),
71 })
72 }
73
74 async fn connect_default(
76 url: &str,
77 config: WebSocketConfig,
78 ) -> Result<WebSocketStream<ConnectStream>> {
79 let (ws, _) =
80 tokio_tungstenite::connect_async_with_config(url, Some(config), *DISABLE_NAGLE).await?;
81 Ok(ws)
82 }
83
84 async fn connect_uring(
87 url: &str,
88 config: WebSocketConfig,
89 ) -> Result<WebSocketStream<ConnectStream>> {
90 use tokio_tungstenite::tungstenite::client::IntoClientRequest;
91
92 let request = url.into_client_request()?;
93 let host = request
94 .uri()
95 .host()
96 .ok_or_else(|| CdpError::msg("no host in CDP WebSocket URL"))?;
97 let port = request.uri().port_u16().unwrap_or(9222);
98
99 let addr_str = format!("{}:{}", host, port);
101 let addr: std::net::SocketAddr = match addr_str.parse() {
102 Ok(a) => a,
103 Err(_) => {
104 return Self::connect_default(url, config).await;
106 }
107 };
108
109 let std_stream = crate::uring_fs::tcp_connect(addr)
111 .await
112 .map_err(CdpError::Io)?;
113
114 std_stream.set_nonblocking(true).map_err(CdpError::Io)?;
116 if *DISABLE_NAGLE {
117 let _ = std_stream.set_nodelay(true);
118 }
119
120 let tokio_stream = tokio::net::TcpStream::from_std(std_stream).map_err(CdpError::Io)?;
122
123 let (ws, _) = tokio_tungstenite::client_async_with_config(
125 request,
126 MaybeTlsStream::Plain(tokio_stream),
127 Some(config),
128 )
129 .await?;
130
131 Ok(ws)
132 }
133}
134
135impl<T: EventMessage> Connection<T> {
136 fn next_call_id(&mut self) -> CallId {
137 let id = CallId::new(self.next_id);
138 self.next_id = self.next_id.wrapping_add(1);
139 id
140 }
141
142 pub fn submit_command(
145 &mut self,
146 method: MethodId,
147 session_id: Option<SessionId>,
148 params: serde_json::Value,
149 ) -> serde_json::Result<CallId> {
150 let id = self.next_call_id();
151 let call = MethodCall {
152 id,
153 method,
154 session_id: session_id.map(Into::into),
155 params,
156 };
157 self.pending_commands.push_back(call);
158 Ok(id)
159 }
160
161 fn start_send_next(&mut self, cx: &mut Context<'_>) -> Result<()> {
166 if self.needs_flush {
168 match self.ws.poll_flush_unpin(cx) {
169 Poll::Ready(Ok(())) => self.needs_flush = false,
170 Poll::Ready(Err(e)) => return Err(e.into()),
171 Poll::Pending => return Ok(()),
172 }
173 }
174
175 let mut sent_any = false;
177 while !self.pending_commands.is_empty() {
178 match self.ws.poll_ready_unpin(cx) {
179 Poll::Ready(Ok(())) => {
180 let cmd = self.pending_commands.pop_front().unwrap();
181 tracing::trace!("Sending {:?}", cmd);
182 let msg = serde_json::to_string(&cmd)?;
183 self.ws.start_send_unpin(msg.into())?;
184 sent_any = true;
185 }
186 _ => break,
187 }
188 }
189
190 if sent_any {
192 match self.ws.poll_flush_unpin(cx) {
193 Poll::Ready(Ok(())) => {}
194 Poll::Ready(Err(e)) => return Err(e.into()),
195 Poll::Pending => self.needs_flush = true,
196 }
197 }
198
199 Ok(())
200 }
201}
202
203impl<T: EventMessage + Unpin> Stream for Connection<T> {
204 type Item = Result<Box<Message<T>>>;
205
206 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
207 let pin = self.get_mut();
208
209 if let Err(err) = pin.start_send_next(cx) {
211 return Poll::Ready(Some(Err(err)));
212 }
213
214 match ready!(pin.ws.poll_next_unpin(cx)) {
216 Some(Ok(WsMessage::Text(text))) => {
217 match decode_message::<T>(text.as_bytes(), Some(&text)) {
218 Ok(msg) => Poll::Ready(Some(Ok(msg))),
219 Err(err) => {
220 tracing::debug!(
221 target: "chromiumoxide::conn::raw_ws::parse_errors",
222 "Dropping malformed text WS frame: {err}",
223 );
224 cx.waker().wake_by_ref();
225 Poll::Pending
226 }
227 }
228 }
229 Some(Ok(WsMessage::Binary(buf))) => match decode_message::<T>(&buf, None) {
230 Ok(msg) => Poll::Ready(Some(Ok(msg))),
231 Err(err) => {
232 tracing::debug!(
233 target: "chromiumoxide::conn::raw_ws::parse_errors",
234 "Dropping malformed binary WS frame: {err}",
235 );
236 cx.waker().wake_by_ref();
237 Poll::Pending
238 }
239 },
240 Some(Ok(WsMessage::Close(_))) => Poll::Ready(None),
241 Some(Ok(WsMessage::Ping(_))) | Some(Ok(WsMessage::Pong(_))) => {
243 cx.waker().wake_by_ref();
244 Poll::Pending
245 }
246 Some(Ok(msg)) => {
247 tracing::debug!(
249 target: "chromiumoxide::conn::raw_ws::parse_errors",
250 "Unexpected WS message type: {:?}",
251 msg
252 );
253 cx.waker().wake_by_ref();
254 Poll::Pending
255 }
256 Some(Err(err)) => Poll::Ready(Some(Err(CdpError::Ws(err)))),
257 None => {
258 Poll::Ready(None)
260 }
261 }
262 }
263}
264
265#[cfg(not(feature = "serde_stacker"))]
269fn decode_message<T: EventMessage>(
270 bytes: &[u8],
271 raw_text_for_logging: Option<&str>,
272) -> Result<Box<Message<T>>> {
273 match serde_json::from_slice::<Box<Message<T>>>(bytes) {
274 Ok(msg) => {
275 tracing::trace!("Received {:?}", msg);
276 Ok(msg)
277 }
278 Err(err) => {
279 if let Some(txt) = raw_text_for_logging {
280 let preview = &txt[..txt.len().min(512)];
281 tracing::debug!(
282 target: "chromiumoxide::conn::raw_ws::parse_errors",
283 msg_len = txt.len(),
284 "Skipping unrecognized WS message {err} preview={preview}",
285 );
286 } else {
287 tracing::debug!(
288 target: "chromiumoxide::conn::raw_ws::parse_errors",
289 "Skipping unrecognized binary WS message {err}",
290 );
291 }
292 Err(err.into())
293 }
294 }
295}
296
/// Deserializes one CDP message from a WebSocket payload without a recursion limit.
///
/// serde_json's recursion limit is disabled and the deserializer is wrapped in
/// `serde_stacker`, which is meant to let deeply nested payloads decode instead of
/// erroring. Input with non-whitespace after the JSON value is rejected, matching
/// `serde_json::from_slice` in the non-`serde_stacker` build.
///
/// `raw_text_for_logging` is the payload as text for text frames and `None` for
/// binary frames. On failure, at most the first 512 bytes of it (cut on a UTF-8
/// character boundary) are logged at debug level as a preview.
///
/// # Errors
/// Returns the `serde_json` error, converted into the crate error type, when
/// `bytes` is not a valid `Message<T>` or has trailing data.
#[cfg(feature = "serde_stacker")]
fn decode_message<T: EventMessage>(
    bytes: &[u8],
    raw_text_for_logging: Option<&str>,
) -> Result<Box<Message<T>>> {
    use serde::Deserialize;
    let mut json = serde_json::Deserializer::from_slice(bytes);

    json.disable_recursion_limit();

    let value = Box::<Message<T>>::deserialize(serde_stacker::Deserializer::new(&mut json));
    // `from_slice` (used by the other cfg variant) calls `end()` internally; do the
    // same here so trailing garbage is not silently accepted.
    let parsed = value.and_then(|msg| json.end().map(|()| msg));

    match parsed {
        Ok(msg) => {
            tracing::trace!("Received {:?}", msg);
            Ok(msg)
        }
        Err(err) => {
            if let Some(txt) = raw_text_for_logging {
                // Slicing a `str` at a byte index that falls inside a multi-byte
                // character panics. Back off to the nearest boundary at or before
                // 512; index 0 is always a boundary, so the loop terminates.
                let mut end = txt.len().min(512);
                while !txt.is_char_boundary(end) {
                    end -= 1;
                }
                let preview = &txt[..end];
                tracing::debug!(
                    target: "chromiumoxide::conn::raw_ws::parse_errors",
                    msg_len = txt.len(),
                    "Skipping unrecognized WS message {err} preview={preview}",
                );
            } else {
                tracing::debug!(
                    target: "chromiumoxide::conn::raw_ws::parse_errors",
                    "Skipping unrecognized binary WS message {err}",
                );
            }
            Err(err.into())
        }
    }
}