1use std::collections::VecDeque;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::ready;
5
6use futures::stream::Stream;
7use futures::task::{Context, Poll};
8use futures::{SinkExt, StreamExt};
9use tokio_tungstenite::tungstenite::Message as WsMessage;
10use tokio_tungstenite::MaybeTlsStream;
11use tokio_tungstenite::{tungstenite::protocol::WebSocketConfig, WebSocketStream};
12
13use chromiumoxide_cdp::cdp::browser_protocol::target::SessionId;
14use chromiumoxide_types::{CallId, EventMessage, Message, MethodCall, MethodId};
15
16use crate::error::CdpError;
17use crate::error::Result;
18
19type ConnectStream = MaybeTlsStream<tokio::net::TcpStream>;
20
/// A CDP connection over a single WebSocket.
///
/// Commands are queued with [`Connection::submit_command`] and are written to the
/// socket while the connection is polled as a [`Stream`]. The same polling yields
/// decoded incoming messages as `Message<T>`, where `T` is the event message type.
#[must_use = "streams do nothing unless polled"]
#[derive(Debug)]
pub struct Connection<T: EventMessage> {
    /// Commands submitted via `submit_command` that have not yet been handed to
    /// the WebSocket sink.
    pending_commands: VecDeque<MethodCall>,
    /// The underlying WebSocket stream, used for both sending and receiving.
    ws: WebSocketStream<ConnectStream>,
    /// Counter used to build the next [`CallId`]; wraps around on overflow.
    next_id: usize,
    /// Set after a sent command has been accepted by the sink while a flush is
    /// still outstanding. No new command is started while this is true.
    needs_flush: bool,
    /// The command most recently passed to `start_send`. It is parked here until
    /// the sink reports ready again.
    pending_flush: Option<MethodCall>,
    /// Ties the event type `T` to the connection without storing a value of it.
    _marker: PhantomData<T>,
}
38
39lazy_static::lazy_static! {
40 static ref DISABLE_NAGLE: bool = match std::env::var("DISABLE_NAGLE") {
42 Ok(disable_nagle) => disable_nagle == "true",
43 _ => true
44 };
45 static ref WEBSOCKET_DEFAULTS: bool = match std::env::var("WEBSOCKET_DEFAULTS") {
47 Ok(d) => d == "true",
48 _ => false
49 };
50}
51
52impl<T: EventMessage + Unpin> Connection<T> {
53 pub async fn connect(debug_ws_url: impl AsRef<str>) -> Result<Self> {
54 let mut config = WebSocketConfig::default();
55
56 if !*WEBSOCKET_DEFAULTS {
57 config.max_message_size = None;
58 config.max_frame_size = None;
59 }
60
61 let ws = if crate::uring_fs::is_enabled() {
62 Self::connect_uring(debug_ws_url.as_ref(), config).await?
63 } else {
64 Self::connect_default(debug_ws_url.as_ref(), config).await?
65 };
66
67 Ok(Self {
68 pending_commands: Default::default(),
69 ws,
70 next_id: 0,
71 needs_flush: false,
72 pending_flush: None,
73 _marker: Default::default(),
74 })
75 }
76
77 async fn connect_default(
79 url: &str,
80 config: WebSocketConfig,
81 ) -> Result<WebSocketStream<ConnectStream>> {
82 let (ws, _) =
83 tokio_tungstenite::connect_async_with_config(url, Some(config), *DISABLE_NAGLE).await?;
84 Ok(ws)
85 }
86
87 async fn connect_uring(
90 url: &str,
91 config: WebSocketConfig,
92 ) -> Result<WebSocketStream<ConnectStream>> {
93 use tokio_tungstenite::tungstenite::client::IntoClientRequest;
94
95 let request = url.into_client_request()?;
96 let host = request
97 .uri()
98 .host()
99 .ok_or_else(|| CdpError::msg("no host in CDP WebSocket URL"))?;
100 let port = request.uri().port_u16().unwrap_or(9222);
101
102 let addr_str = format!("{}:{}", host, port);
104 let addr: std::net::SocketAddr = match addr_str.parse() {
105 Ok(a) => a,
106 Err(_) => {
107 return Self::connect_default(url, config).await;
109 }
110 };
111
112 let std_stream = crate::uring_fs::tcp_connect(addr)
114 .await
115 .map_err(CdpError::Io)?;
116
117 std_stream.set_nonblocking(true).map_err(CdpError::Io)?;
119 if *DISABLE_NAGLE {
120 let _ = std_stream.set_nodelay(true);
121 }
122
123 let tokio_stream = tokio::net::TcpStream::from_std(std_stream).map_err(CdpError::Io)?;
125
126 let (ws, _) = tokio_tungstenite::client_async_with_config(
128 request,
129 MaybeTlsStream::Plain(tokio_stream),
130 Some(config),
131 )
132 .await?;
133
134 Ok(ws)
135 }
136}
137
138impl<T: EventMessage> Connection<T> {
139 fn next_call_id(&mut self) -> CallId {
140 let id = CallId::new(self.next_id);
141 self.next_id = self.next_id.wrapping_add(1);
142 id
143 }
144
145 pub fn submit_command(
148 &mut self,
149 method: MethodId,
150 session_id: Option<SessionId>,
151 params: serde_json::Value,
152 ) -> serde_json::Result<CallId> {
153 let id = self.next_call_id();
154 let call = MethodCall {
155 id,
156 method,
157 session_id: session_id.map(Into::into),
158 params,
159 };
160 self.pending_commands.push_back(call);
161 Ok(id)
162 }
163
164 fn start_send_next(&mut self, cx: &mut Context<'_>) -> Result<()> {
167 if self.needs_flush {
168 if let Poll::Ready(Ok(())) = self.ws.poll_flush_unpin(cx) {
169 self.needs_flush = false;
170 }
171 }
172 if self.pending_flush.is_none() && !self.needs_flush {
173 if let Some(cmd) = self.pending_commands.pop_front() {
174 tracing::trace!("Sending {:?}", cmd);
175 let msg = serde_json::to_string(&cmd)?;
176 self.ws.start_send_unpin(msg.into())?;
177 self.pending_flush = Some(cmd);
178 }
179 }
180 Ok(())
181 }
182}
183
184impl<T: EventMessage + Unpin> Stream for Connection<T> {
185 type Item = Result<Box<Message<T>>>;
186
187 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
188 let pin = self.get_mut();
189
190 loop {
192 if let Err(err) = pin.start_send_next(cx) {
193 return Poll::Ready(Some(Err(err)));
194 }
195
196 if let Some(call) = pin.pending_flush.take() {
197 if pin.ws.poll_ready_unpin(cx).is_ready() {
198 pin.needs_flush = true;
199 continue;
201 } else {
202 pin.pending_flush = Some(call);
203 }
204 }
205
206 break;
207 }
208
209 match ready!(pin.ws.poll_next_unpin(cx)) {
211 Some(Ok(WsMessage::Text(text))) => {
212 match decode_message::<T>(text.as_bytes(), Some(&text)) {
213 Ok(msg) => Poll::Ready(Some(Ok(msg))),
214 Err(err) => {
215 tracing::debug!(
216 target: "chromiumoxide::conn::raw_ws::parse_errors",
217 "Dropping malformed text WS frame: {err}",
218 );
219 cx.waker().wake_by_ref();
220 Poll::Pending
221 }
222 }
223 }
224 Some(Ok(WsMessage::Binary(buf))) => match decode_message::<T>(&buf, None) {
225 Ok(msg) => Poll::Ready(Some(Ok(msg))),
226 Err(err) => {
227 tracing::debug!(
228 target: "chromiumoxide::conn::raw_ws::parse_errors",
229 "Dropping malformed binary WS frame: {err}",
230 );
231 cx.waker().wake_by_ref();
232 Poll::Pending
233 }
234 },
235 Some(Ok(WsMessage::Close(_))) => Poll::Ready(None),
236 Some(Ok(WsMessage::Ping(_))) | Some(Ok(WsMessage::Pong(_))) => {
238 cx.waker().wake_by_ref();
239 Poll::Pending
240 }
241 Some(Ok(msg)) => {
242 tracing::debug!(
244 target: "chromiumoxide::conn::raw_ws::parse_errors",
245 "Unexpected WS message type: {:?}",
246 msg
247 );
248 cx.waker().wake_by_ref();
249 Poll::Pending
250 }
251 Some(Err(err)) => Poll::Ready(Some(Err(CdpError::Ws(err)))),
252 None => {
253 Poll::Ready(None)
255 }
256 }
257 }
258}
259
260#[cfg(not(feature = "serde_stacker"))]
264fn decode_message<T: EventMessage>(
265 bytes: &[u8],
266 raw_text_for_logging: Option<&str>,
267) -> Result<Box<Message<T>>> {
268 match serde_json::from_slice::<Box<Message<T>>>(bytes) {
269 Ok(msg) => {
270 tracing::trace!("Received {:?}", msg);
271 Ok(msg)
272 }
273 Err(err) => {
274 if let Some(txt) = raw_text_for_logging {
275 let preview = &txt[..txt.len().min(512)];
276 tracing::debug!(
277 target: "chromiumoxide::conn::raw_ws::parse_errors",
278 msg_len = txt.len(),
279 "Skipping unrecognized WS message {err} preview={preview}",
280 );
281 } else {
282 tracing::debug!(
283 target: "chromiumoxide::conn::raw_ws::parse_errors",
284 "Skipping unrecognized binary WS message {err}",
285 );
286 }
287 Err(err.into())
288 }
289 }
290}
291
#[cfg(feature = "serde_stacker")]
/// Deserializes one WebSocket payload into a boxed CDP [`Message`], without
/// serde_json's recursion limit.
///
/// The limit is disabled and the deserializer is wrapped in
/// `serde_stacker::Deserializer`, an adapter intended to avoid stack overflow on
/// deeply nested input. Like `serde_json::from_slice`, trailing non-whitespace
/// data after the value is rejected.
///
/// `raw_text_for_logging` is the payload as text (for Text frames). When it is
/// present, a parse failure logs its length and a preview of up to 512 bytes.
///
/// # Errors
///
/// Returns the `serde_json` error, converted into the crate error type, when the
/// payload is not a valid `Message<T>`.
fn decode_message<T: EventMessage>(
    bytes: &[u8],
    raw_text_for_logging: Option<&str>,
) -> Result<Box<Message<T>>> {
    use serde::Deserialize;

    let mut json = serde_json::Deserializer::from_slice(bytes);
    json.disable_recursion_limit();

    let parsed = Box::<Message<T>>::deserialize(serde_stacker::Deserializer::new(&mut json))
        // Mirror `serde_json::from_slice`, which also calls `end()`.
        .and_then(|msg| json.end().map(|()| msg));

    match parsed {
        Ok(msg) => {
            tracing::trace!("Received {:?}", msg);
            Ok(msg)
        }
        Err(err) => {
            if let Some(txt) = raw_text_for_logging {
                // Byte 512 may land inside a multi-byte UTF-8 character, and
                // slicing there would panic; back off to the previous boundary.
                // Index 0 is always a boundary, so this terminates.
                let mut end = txt.len().min(512);
                while !txt.is_char_boundary(end) {
                    end -= 1;
                }
                let preview = &txt[..end];
                tracing::debug!(
                    target: "chromiumoxide::conn::raw_ws::parse_errors",
                    msg_len = txt.len(),
                    "Skipping unrecognized WS message {err} preview={preview}",
                );
            } else {
                tracing::debug!(
                    target: "chromiumoxide::conn::raw_ws::parse_errors",
                    "Skipping unrecognized binary WS message {err}",
                );
            }
            Err(err.into())
        }
    }
}