Skip to main content

chaser_oxide/
conn.rs

1use std::collections::VecDeque;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::ready;
5
6use async_tungstenite::tungstenite::Message as WsMessage;
7use async_tungstenite::{tungstenite::protocol::WebSocketConfig, WebSocketStream};
8use futures::stream::Stream;
9use futures::task::{Context, Poll};
10use futures::{SinkExt, StreamExt};
11
12use chromiumoxide_cdp::cdp::browser_protocol::target::SessionId;
13use chromiumoxide_types::{CallId, EventMessage, Message, MethodCall, MethodId};
14
15use crate::error::CdpError;
16use crate::error::Result;
17
18cfg_if::cfg_if! {
19    if #[cfg(feature = "async-std-runtime")] {
20       use async_tungstenite::async_std::ConnectStream;
21    } else if #[cfg(feature = "tokio-runtime")] {
22        use async_tungstenite::tokio::ConnectStream;
23    }
24}
25/// Exchanges the messages with the websocket
26#[must_use = "streams do nothing unless polled"]
27#[derive(Debug)]
28pub struct Connection<T: EventMessage> {
29    /// Queue of commands to send.
30    pending_commands: VecDeque<MethodCall>,
31    /// The websocket of the chromium instance
32    ws: WebSocketStream<ConnectStream>,
33    /// The identifier for a specific command
34    next_id: usize,
35    needs_flush: bool,
36    /// The message that is currently being proceessed
37    pending_flush: Option<MethodCall>,
38    _marker: PhantomData<T>,
39}
40
41impl<T: EventMessage + Unpin> Connection<T> {
42    pub async fn connect(debug_ws_url: impl AsRef<str>) -> Result<Self> {
43        let config = WebSocketConfig::default()
44            .max_message_size(None)
45            .max_frame_size(None);
46
47        cfg_if::cfg_if! {
48            if #[cfg(feature = "async-std-runtime")] {
49               let (ws, _) = async_tungstenite::async_std::connect_async_with_config(debug_ws_url.as_ref(), Some(config)).await?;
50            } else if #[cfg(feature = "tokio-runtime")] {
51                 let (ws, _) = async_tungstenite::tokio::connect_async_with_config(debug_ws_url.as_ref(), Some(config)).await?;
52            }
53        }
54
55        Ok(Self {
56            pending_commands: Default::default(),
57            ws,
58            next_id: 0,
59            needs_flush: false,
60            pending_flush: None,
61            _marker: Default::default(),
62        })
63    }
64}
65
66impl<T: EventMessage> Connection<T> {
67    fn next_call_id(&mut self) -> CallId {
68        let id = CallId::new(self.next_id);
69        self.next_id = self.next_id.wrapping_add(1);
70        id
71    }
72
73    /// Queue in the command to send over the socket and return the id for this
74    /// command
75    pub fn submit_command(
76        &mut self,
77        method: MethodId,
78        session_id: Option<SessionId>,
79        params: serde_json::Value,
80    ) -> serde_json::Result<CallId> {
81        let id = self.next_call_id();
82        let call = MethodCall {
83            id,
84            method,
85            session_id: session_id.map(Into::into),
86            params,
87        };
88        self.pending_commands.push_back(call);
89        Ok(id)
90    }
91
92    /// flush any processed message and start sending the next over the conn
93    /// sink
94    fn start_send_next(&mut self, cx: &mut Context<'_>) -> Result<()> {
95        if self.needs_flush {
96            if let Poll::Ready(Ok(())) = self.ws.poll_flush_unpin(cx) {
97                self.needs_flush = false;
98            }
99        }
100        if self.pending_flush.is_none() && !self.needs_flush {
101            if let Some(cmd) = self.pending_commands.pop_front() {
102                tracing::trace!("Sending {:?}", cmd);
103                let msg = serde_json::to_string(&cmd)?;
104                self.ws.start_send_unpin(msg.into())?;
105                self.pending_flush = Some(cmd);
106            }
107        }
108        Ok(())
109    }
110}
111
112impl<T: EventMessage + Unpin> Stream for Connection<T> {
113    type Item = Result<Message<T>>;
114
115    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
116        let pin = self.get_mut();
117
118        loop {
119            // queue in the next message if not currently flushing
120            if let Err(err) = pin.start_send_next(cx) {
121                return Poll::Ready(Some(Err(err)));
122            }
123
124            // send the message
125            if let Some(call) = pin.pending_flush.take() {
126                if pin.ws.poll_ready_unpin(cx).is_ready() {
127                    pin.needs_flush = true;
128                    // try another flush
129                    continue;
130                } else {
131                    pin.pending_flush = Some(call);
132                }
133            }
134
135            break;
136        }
137
138        // read from the ws
139        match ready!(pin.ws.poll_next_unpin(cx)) {
140            Some(Ok(WsMessage::Text(text))) => {
141                let ready = match serde_json::from_str::<Message<T>>(&text) {
142                    Ok(msg) => {
143                        tracing::trace!("Received {:?}", msg);
144                        Ok(msg)
145                    }
146                    Err(err) => {
147                        let msg = text.as_str().to_string();
148                        tracing::debug!(target: "chromiumoxide::conn::raw_ws::parse_errors", msg, "Failed to parse raw WS message {}", err);
149                        Err(CdpError::InvalidMessage(text.as_str().to_string(), err))
150                    }
151                };
152                Poll::Ready(Some(ready))
153            }
154            Some(Ok(WsMessage::Close(_))) => Poll::Ready(None),
155            // ignore ping and pong
156            Some(Ok(WsMessage::Ping(_))) | Some(Ok(WsMessage::Pong(_))) => {
157                cx.waker().wake_by_ref();
158                Poll::Pending
159            }
160            Some(Ok(msg)) => Poll::Ready(Some(Err(CdpError::UnexpectedWsMessage(msg)))),
161            Some(Err(err)) => Poll::Ready(Some(Err(CdpError::Ws(err)))),
162            None => {
163                // ws connection closed
164                Poll::Ready(None)
165            }
166        }
167    }
168}