chromiumoxide/
conn.rs

1use std::collections::VecDeque;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::ready;
5
6use futures::stream::Stream;
7use futures::task::{Context, Poll};
8use futures::{SinkExt, StreamExt};
9use tokio_tungstenite::tungstenite::Message as WsMessage;
10use tokio_tungstenite::MaybeTlsStream;
11use tokio_tungstenite::{tungstenite::protocol::WebSocketConfig, WebSocketStream};
12
13use chromiumoxide_cdp::cdp::browser_protocol::target::SessionId;
14use chromiumoxide_types::{CallId, EventMessage, Message, MethodCall, MethodId};
15
16use crate::error::CdpError;
17use crate::error::Result;
18
19type ConnectStream = MaybeTlsStream<tokio::net::TcpStream>;
20
21/// Exchanges the messages with the websocket
22#[must_use = "streams do nothing unless polled"]
23#[derive(Debug)]
24pub struct Connection<T: EventMessage> {
25    /// Queue of commands to send.
26    pending_commands: VecDeque<MethodCall>,
27    /// The websocket of the chromium instance
28    ws: WebSocketStream<ConnectStream>,
29    /// The identifier for a specific command
30    next_id: usize,
31    /// A flush is required.
32    needs_flush: bool,
33    /// The message that is currently being proceessed
34    pending_flush: Option<MethodCall>,
35    /// The phantom marker.
36    _marker: PhantomData<T>,
37}
38
39lazy_static::lazy_static! {
40    /// Nagle's algorithm disabled?
41    static ref DISABLE_NAGLE: bool = match std::env::var("DISABLE_NAGLE") {
42        Ok(disable_nagle) => disable_nagle == "true",
43        _ => true
44    };
45    /// Websocket config defaults
46    static ref WEBSOCKET_DEFAULTS: bool = match std::env::var("WEBSOCKET_DEFAULTS") {
47        Ok(d) => d == "true",
48        _ => false
49    };
50}
51
52impl<T: EventMessage + Unpin> Connection<T> {
53    pub async fn connect(debug_ws_url: impl AsRef<str>) -> Result<Self> {
54        let mut config = WebSocketConfig::default();
55
56        if *WEBSOCKET_DEFAULTS == false {
57            config.max_message_size = None;
58            config.max_frame_size = None;
59        }
60
61        let (ws, _) = tokio_tungstenite::connect_async_with_config(
62            debug_ws_url.as_ref(),
63            Some(config),
64            *DISABLE_NAGLE,
65        )
66        .await?;
67
68        Ok(Self {
69            pending_commands: Default::default(),
70            ws,
71            next_id: 0,
72            needs_flush: false,
73            pending_flush: None,
74            _marker: Default::default(),
75        })
76    }
77}
78
79impl<T: EventMessage> Connection<T> {
80    fn next_call_id(&mut self) -> CallId {
81        let id = CallId::new(self.next_id);
82        self.next_id = self.next_id.wrapping_add(1);
83        id
84    }
85
86    /// Queue in the command to send over the socket and return the id for this
87    /// command
88    pub fn submit_command(
89        &mut self,
90        method: MethodId,
91        session_id: Option<SessionId>,
92        params: serde_json::Value,
93    ) -> serde_json::Result<CallId> {
94        let id = self.next_call_id();
95        let call = MethodCall {
96            id,
97            method,
98            session_id: session_id.map(Into::into),
99            params,
100        };
101        self.pending_commands.push_back(call);
102        Ok(id)
103    }
104
105    /// flush any processed message and start sending the next over the conn
106    /// sink
107    fn start_send_next(&mut self, cx: &mut Context<'_>) -> Result<()> {
108        if self.needs_flush {
109            if let Poll::Ready(Ok(())) = self.ws.poll_flush_unpin(cx) {
110                self.needs_flush = false;
111            }
112        }
113        if self.pending_flush.is_none() && !self.needs_flush {
114            if let Some(cmd) = self.pending_commands.pop_front() {
115                tracing::trace!("Sending {:?}", cmd);
116                let msg = serde_json::to_string(&cmd)?;
117                self.ws.start_send_unpin(msg.into())?;
118                self.pending_flush = Some(cmd);
119            }
120        }
121        Ok(())
122    }
123}
124
125impl<T: EventMessage + Unpin> Stream for Connection<T> {
126    type Item = Result<Box<Message<T>>>;
127
128    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
129        let pin = self.get_mut();
130
131        // flush pending outgoing messages
132        loop {
133            if let Err(err) = pin.start_send_next(cx) {
134                return Poll::Ready(Some(Err(err)));
135            }
136
137            if let Some(call) = pin.pending_flush.take() {
138                if pin.ws.poll_ready_unpin(cx).is_ready() {
139                    pin.needs_flush = true;
140                    // try another flush in this same poll
141                    continue;
142                } else {
143                    pin.pending_flush = Some(call);
144                }
145            }
146
147            break;
148        }
149
150        // read from the websocket
151        match ready!(pin.ws.poll_next_unpin(cx)) {
152            Some(Ok(WsMessage::Text(text))) => {
153                match decode_message::<T>(text.as_bytes(), Some(&text)) {
154                    Ok(msg) => Poll::Ready(Some(Ok(msg))),
155                    Err(err) => {
156                        tracing::debug!(
157                            target: "chromiumoxide::conn::raw_ws::parse_errors",
158                            "Dropping malformed text WS frame: {err}",
159                        );
160                        cx.waker().wake_by_ref();
161                        Poll::Pending
162                    }
163                }
164            }
165            Some(Ok(WsMessage::Binary(buf))) => match decode_message::<T>(&buf, None) {
166                Ok(msg) => Poll::Ready(Some(Ok(msg))),
167                Err(err) => {
168                    tracing::debug!(
169                        target: "chromiumoxide::conn::raw_ws::parse_errors",
170                        "Dropping malformed binary WS frame: {err}",
171                    );
172                    cx.waker().wake_by_ref();
173                    Poll::Pending
174                }
175            },
176            Some(Ok(WsMessage::Close(_))) => Poll::Ready(None),
177            // ignore ping and pong
178            Some(Ok(WsMessage::Ping(_))) | Some(Ok(WsMessage::Pong(_))) => {
179                cx.waker().wake_by_ref();
180                Poll::Pending
181            }
182            Some(Ok(msg)) => {
183                // Unexpected WS message type, but not fatal.
184                tracing::debug!(
185                    target: "chromiumoxide::conn::raw_ws::parse_errors",
186                    "Unexpected WS message type: {:?}",
187                    msg
188                );
189                cx.waker().wake_by_ref();
190                Poll::Pending
191            }
192            Some(Err(err)) => Poll::Ready(Some(Err(CdpError::Ws(err)))),
193            None => {
194                // ws connection closed
195                Poll::Ready(None)
196            }
197        }
198    }
199}
200
201/// Shared decode path for both text and binary WS frames.
202/// `raw_text_for_logging` is only provided for textual frames so we can log the original
203/// payload on parse failure if desired.
204#[cfg(not(feature = "serde_stacker"))]
205fn decode_message<T: EventMessage>(
206    bytes: &[u8],
207    raw_text_for_logging: Option<&str>,
208) -> Result<Box<Message<T>>> {
209    match serde_json::from_slice::<Box<Message<T>>>(bytes) {
210        Ok(msg) => {
211            tracing::trace!("Received {:?}", msg);
212            Ok(msg)
213        }
214        Err(err) => {
215            if let Some(txt) = raw_text_for_logging {
216                tracing::error!(
217                    target: "chromiumoxide::conn::raw_ws::parse_errors",
218                    msg_len = txt.len(),
219                    "Failed to parse raw WS message {err}",
220                );
221            } else {
222                tracing::error!(
223                    target: "chromiumoxide::conn::raw_ws::parse_errors",
224                    "Failed to parse binary WS message {err}",
225                );
226            }
227            Err(err.into())
228        }
229    }
230}
231
/// Decodes a CDP message from the payload of a text or binary websocket frame, with
/// serde_json's recursion limit disabled and deserialization routed through
/// `serde_stacker`.
///
/// As with `serde_json::from_slice`, which the non-`serde_stacker` build uses, the whole
/// payload must be consumed. Anything other than trailing whitespace after the message is
/// rejected, so both builds accept exactly the same frames.
///
/// `raw_text_for_logging` is `Some` only for text frames. On a parse failure it selects the
/// "raw" log message, and its length is recorded as `msg_len`. The payload itself is never
/// logged.
///
/// # Errors
///
/// Returns the `serde_json` error, converted into the crate error type, when `bytes` is not
/// a single valid message. The failure is also logged at error level.
#[cfg(feature = "serde_stacker")]
fn decode_message<T: EventMessage>(
    bytes: &[u8],
    raw_text_for_logging: Option<&str>,
) -> Result<Box<Message<T>>> {
    use serde::Deserialize;

    let mut json_de = serde_json::Deserializer::from_slice(bytes);
    // Deep nesting is handled by serde_stacker instead of serde_json's fixed depth limit.
    json_de.disable_recursion_limit();

    let parsed = Box::<Message<T>>::deserialize(serde_stacker::Deserializer::new(&mut json_de));
    // `serde_json::from_slice` calls `end()` to reject trailing characters after the
    // top-level value; without it, a frame like `{...}garbage` would be silently accepted.
    let parsed = parsed.and_then(|msg| json_de.end().map(|()| msg));

    match parsed {
        Ok(msg) => {
            tracing::trace!("Received {:?}", msg);
            Ok(msg)
        }
        Err(err) => {
            if let Some(txt) = raw_text_for_logging {
                tracing::error!(
                    target: "chromiumoxide::conn::raw_ws::parse_errors",
                    msg_len = txt.len(),
                    "Failed to parse raw WS message {err}",
                );
            } else {
                tracing::error!(
                    target: "chromiumoxide::conn::raw_ws::parse_errors",
                    "Failed to parse binary WS message {err}",
                );
            }
            Err(err.into())
        }
    }
}