Skip to main content

chaser_oxide/
conn.rs

1use std::collections::VecDeque;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::task::ready;
5
6use async_tungstenite::tungstenite::Message as WsMessage;
7use async_tungstenite::{WebSocketStream, tungstenite::protocol::WebSocketConfig};
8use futures::stream::Stream;
9use futures::task::{Context, Poll};
10use futures::{SinkExt, StreamExt};
11
12use async_tungstenite::tokio::ConnectStream;
13use chromiumoxide_cdp::cdp::browser_protocol::target::SessionId;
14use chromiumoxide_types::{CallId, EventMessage, Message, MethodCall, MethodId};
15
16use crate::error::CdpError;
17use crate::error::Result;
18
19/// Exchanges the messages with the websocket
20#[must_use = "streams do nothing unless polled"]
21#[derive(Debug)]
22pub struct Connection<T: EventMessage> {
23    /// Queue of commands to send.
24    pending_commands: VecDeque<MethodCall>,
25    /// The websocket of the chromium instance
26    ws: WebSocketStream<ConnectStream>,
27    /// The identifier for a specific command
28    next_id: usize,
29    needs_flush: bool,
30    /// The message that is currently being proceessed
31    pending_flush: Option<MethodCall>,
32    _marker: PhantomData<T>,
33}
34
35impl<T: EventMessage + Unpin> Connection<T> {
36    pub async fn connect(debug_ws_url: impl AsRef<str>) -> Result<Self> {
37        let config = WebSocketConfig::default()
38            .max_message_size(None)
39            .max_frame_size(None);
40
41        let (ws, _) = async_tungstenite::tokio::connect_async_with_config(
42            debug_ws_url.as_ref(),
43            Some(config),
44        )
45        .await?;
46
47        Ok(Self {
48            pending_commands: Default::default(),
49            ws,
50            next_id: 0,
51            needs_flush: false,
52            pending_flush: None,
53            _marker: Default::default(),
54        })
55    }
56}
57
58impl<T: EventMessage> Connection<T> {
59    fn next_call_id(&mut self) -> CallId {
60        let id = CallId::new(self.next_id);
61        self.next_id = self.next_id.wrapping_add(1);
62        id
63    }
64
65    /// Queue in the command to send over the socket and return the id for this
66    /// command
67    pub fn submit_command(
68        &mut self,
69        method: MethodId,
70        session_id: Option<SessionId>,
71        params: serde_json::Value,
72    ) -> serde_json::Result<CallId> {
73        let id = self.next_call_id();
74        let call = MethodCall {
75            id,
76            method,
77            session_id: session_id.map(Into::into),
78            params,
79        };
80        self.pending_commands.push_back(call);
81        Ok(id)
82    }
83
84    /// flush any processed message and start sending the next over the conn
85    /// sink
86    fn start_send_next(&mut self, cx: &mut Context<'_>) -> Result<()> {
87        if self.needs_flush {
88            if let Poll::Ready(Ok(())) = self.ws.poll_flush_unpin(cx) {
89                self.needs_flush = false;
90            }
91        }
92        if self.pending_flush.is_none() && !self.needs_flush {
93            if let Some(cmd) = self.pending_commands.pop_front() {
94                tracing::trace!("Sending {:?}", cmd);
95                let msg = serde_json::to_string(&cmd)?;
96                self.ws.start_send_unpin(msg.into())?;
97                self.pending_flush = Some(cmd);
98            }
99        }
100        Ok(())
101    }
102}
103
104impl<T: EventMessage + Unpin> Stream for Connection<T> {
105    type Item = Result<Message<T>>;
106
107    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
108        let pin = self.get_mut();
109
110        loop {
111            // queue in the next message if not currently flushing
112            if let Err(err) = pin.start_send_next(cx) {
113                return Poll::Ready(Some(Err(err)));
114            }
115
116            // send the message
117            if let Some(call) = pin.pending_flush.take() {
118                if pin.ws.poll_ready_unpin(cx).is_ready() {
119                    pin.needs_flush = true;
120                    // try another flush
121                    continue;
122                } else {
123                    pin.pending_flush = Some(call);
124                }
125            }
126
127            break;
128        }
129
130        // read from the ws
131        match ready!(pin.ws.poll_next_unpin(cx)) {
132            Some(Ok(WsMessage::Text(text))) => {
133                let ready = match serde_json::from_str::<Message<T>>(&text) {
134                    Ok(msg) => {
135                        tracing::trace!("Received {:?}", msg);
136                        Ok(msg)
137                    }
138                    Err(err) => {
139                        let msg = text.as_str().to_string();
140                        tracing::debug!(target: "chromiumoxide::conn::raw_ws::parse_errors", msg, "Failed to parse raw WS message {}", err);
141                        Err(CdpError::InvalidMessage(text.as_str().to_string(), err))
142                    }
143                };
144                Poll::Ready(Some(ready))
145            }
146            Some(Ok(WsMessage::Close(_))) => Poll::Ready(None),
147            // ignore ping and pong
148            Some(Ok(WsMessage::Ping(_))) | Some(Ok(WsMessage::Pong(_))) => {
149                cx.waker().wake_by_ref();
150                Poll::Pending
151            }
152            Some(Ok(msg)) => Poll::Ready(Some(Err(CdpError::UnexpectedWsMessage(msg)))),
153            Some(Err(err)) => Poll::Ready(Some(Err(CdpError::Ws(err)))),
154            None => {
155                // ws connection closed
156                Poll::Ready(None)
157            }
158        }
159    }
160}