1use std::collections::VecDeque;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::ready;
5
6use futures::stream::Stream;
7use futures::task::{Context, Poll};
8use futures::{SinkExt, StreamExt};
9use tokio_tungstenite::tungstenite::Message as WsMessage;
10use tokio_tungstenite::MaybeTlsStream;
11use tokio_tungstenite::{tungstenite::protocol::WebSocketConfig, WebSocketStream};
12
13use chromiumoxide_cdp::cdp::browser_protocol::target::SessionId;
14use chromiumoxide_types::{CallId, EventMessage, Message, MethodCall, MethodId};
15
16use crate::error::CdpError;
17use crate::error::Result;
18
19type ConnectStream = MaybeTlsStream<tokio::net::TcpStream>;
20
21#[must_use = "streams do nothing unless polled"]
23#[derive(Debug)]
24pub struct Connection<T: EventMessage> {
25 pending_commands: VecDeque<MethodCall>,
27 ws: WebSocketStream<ConnectStream>,
29 next_id: usize,
31 needs_flush: bool,
33 pending_flush: Option<MethodCall>,
35 _marker: PhantomData<T>,
37}
38
39lazy_static::lazy_static! {
40 static ref DISABLE_NAGLE: bool = match std::env::var("DISABLE_NAGLE") {
42 Ok(disable_nagle) => disable_nagle == "true",
43 _ => true
44 };
45 static ref WEBSOCKET_DEFAULTS: bool = match std::env::var("WEBSOCKET_DEFAULTS") {
47 Ok(d) => d == "true",
48 _ => false
49 };
50}
51
52impl<T: EventMessage + Unpin> Connection<T> {
53 pub async fn connect(debug_ws_url: impl AsRef<str>) -> Result<Self> {
54 let mut config = WebSocketConfig::default();
55
56 if *WEBSOCKET_DEFAULTS == false {
57 config.max_message_size = None;
58 config.max_frame_size = None;
59 }
60
61 let (ws, _) = tokio_tungstenite::connect_async_with_config(
62 debug_ws_url.as_ref(),
63 Some(config),
64 *DISABLE_NAGLE,
65 )
66 .await?;
67
68 Ok(Self {
69 pending_commands: Default::default(),
70 ws,
71 next_id: 0,
72 needs_flush: false,
73 pending_flush: None,
74 _marker: Default::default(),
75 })
76 }
77}
78
79impl<T: EventMessage> Connection<T> {
80 fn next_call_id(&mut self) -> CallId {
81 let id = CallId::new(self.next_id);
82 self.next_id = self.next_id.wrapping_add(1);
83 id
84 }
85
86 pub fn submit_command(
89 &mut self,
90 method: MethodId,
91 session_id: Option<SessionId>,
92 params: serde_json::Value,
93 ) -> serde_json::Result<CallId> {
94 let id = self.next_call_id();
95 let call = MethodCall {
96 id,
97 method,
98 session_id: session_id.map(Into::into),
99 params,
100 };
101 self.pending_commands.push_back(call);
102 Ok(id)
103 }
104
105 fn start_send_next(&mut self, cx: &mut Context<'_>) -> Result<()> {
108 if self.needs_flush {
109 if let Poll::Ready(Ok(())) = self.ws.poll_flush_unpin(cx) {
110 self.needs_flush = false;
111 }
112 }
113 if self.pending_flush.is_none() && !self.needs_flush {
114 if let Some(cmd) = self.pending_commands.pop_front() {
115 tracing::trace!("Sending {:?}", cmd);
116 let msg = serde_json::to_string(&cmd)?;
117 self.ws.start_send_unpin(msg.into())?;
118 self.pending_flush = Some(cmd);
119 }
120 }
121 Ok(())
122 }
123}
124
125impl<T: EventMessage + Unpin> Stream for Connection<T> {
126 type Item = Result<Box<Message<T>>>;
127
128 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
129 let pin = self.get_mut();
130
131 loop {
133 if let Err(err) = pin.start_send_next(cx) {
134 return Poll::Ready(Some(Err(err)));
135 }
136
137 if let Some(call) = pin.pending_flush.take() {
138 if pin.ws.poll_ready_unpin(cx).is_ready() {
139 pin.needs_flush = true;
140 continue;
142 } else {
143 pin.pending_flush = Some(call);
144 }
145 }
146
147 break;
148 }
149
150 match ready!(pin.ws.poll_next_unpin(cx)) {
152 Some(Ok(WsMessage::Text(text))) => {
153 match decode_message::<T>(text.as_bytes(), Some(&text)) {
154 Ok(msg) => Poll::Ready(Some(Ok(msg))),
155 Err(err) => {
156 tracing::debug!(
157 target: "chromiumoxide::conn::raw_ws::parse_errors",
158 "Dropping malformed text WS frame: {err}",
159 );
160 cx.waker().wake_by_ref();
161 Poll::Pending
162 }
163 }
164 }
165 Some(Ok(WsMessage::Binary(buf))) => match decode_message::<T>(&buf, None) {
166 Ok(msg) => Poll::Ready(Some(Ok(msg))),
167 Err(err) => {
168 tracing::debug!(
169 target: "chromiumoxide::conn::raw_ws::parse_errors",
170 "Dropping malformed binary WS frame: {err}",
171 );
172 cx.waker().wake_by_ref();
173 Poll::Pending
174 }
175 },
176 Some(Ok(WsMessage::Close(_))) => Poll::Ready(None),
177 Some(Ok(WsMessage::Ping(_))) | Some(Ok(WsMessage::Pong(_))) => {
179 cx.waker().wake_by_ref();
180 Poll::Pending
181 }
182 Some(Ok(msg)) => {
183 tracing::debug!(
185 target: "chromiumoxide::conn::raw_ws::parse_errors",
186 "Unexpected WS message type: {:?}",
187 msg
188 );
189 cx.waker().wake_by_ref();
190 Poll::Pending
191 }
192 Some(Err(err)) => Poll::Ready(Some(Err(CdpError::Ws(err)))),
193 None => {
194 Poll::Ready(None)
196 }
197 }
198 }
199}
200
201#[cfg(not(feature = "serde_stacker"))]
205fn decode_message<T: EventMessage>(
206 bytes: &[u8],
207 raw_text_for_logging: Option<&str>,
208) -> Result<Box<Message<T>>> {
209 match serde_json::from_slice::<Box<Message<T>>>(bytes) {
210 Ok(msg) => {
211 tracing::trace!("Received {:?}", msg);
212 Ok(msg)
213 }
214 Err(err) => {
215 if let Some(txt) = raw_text_for_logging {
216 tracing::error!(
217 target: "chromiumoxide::conn::raw_ws::parse_errors",
218 msg_len = txt.len(),
219 "Failed to parse raw WS message {err}",
220 );
221 } else {
222 tracing::error!(
223 target: "chromiumoxide::conn::raw_ws::parse_errors",
224 "Failed to parse binary WS message {err}",
225 );
226 }
227 Err(err.into())
228 }
229 }
230}
231
/// Decodes one websocket payload into a CDP [`Message`] with `serde_json`'s
/// recursion limit disabled, routing deserialization through the
/// `serde_stacker` adapter (which grows the stack dynamically) so deep
/// nesting is not rejected.
///
/// Like `serde_json::from_slice`, the payload must hold exactly one JSON value
/// optionally followed by whitespace; trailing data is an error.
///
/// `raw_text_for_logging` is `Some` for text frames; on failure only its
/// length is logged. Every failure is logged at `error` level under the
/// `chromiumoxide::conn::raw_ws::parse_errors` target before being returned.
#[cfg(feature = "serde_stacker")]
fn decode_message<T: EventMessage>(
    bytes: &[u8],
    raw_text_for_logging: Option<&str>,
) -> Result<Box<Message<T>>> {
    use serde::Deserialize;

    let mut json = serde_json::Deserializer::from_slice(bytes);
    // Depth is handled by serde_stacker's stack growth instead of a fixed limit.
    json.disable_recursion_limit();

    let decoded = Box::<Message<T>>::deserialize(serde_stacker::Deserializer::new(&mut json));
    // `serde_json::from_slice` (used when this feature is off) rejects trailing
    // non-whitespace input via `Deserializer::end`; do the same here so both
    // builds accept exactly the same payloads.
    let decoded = decoded.and_then(|msg| json.end().map(|()| msg));

    match decoded {
        Ok(msg) => {
            tracing::trace!("Received {:?}", msg);
            Ok(msg)
        }
        Err(err) => {
            if let Some(txt) = raw_text_for_logging {
                tracing::error!(
                    target: "chromiumoxide::conn::raw_ws::parse_errors",
                    msg_len = txt.len(),
                    "Failed to parse raw WS message {err}",
                );
            } else {
                tracing::error!(
                    target: "chromiumoxide::conn::raw_ws::parse_errors",
                    "Failed to parse binary WS message {err}",
                );
            }
            Err(err.into())
        }
    }
}