use std::sync::Arc;
use std::time::Duration;
use futures::{Stream, StreamExt};
use serde::de::DeserializeOwned;
use serde_json::Value;
use tokio::sync::{broadcast, mpsc, oneshot};
use tokio_stream::wrappers::BroadcastStream;
use tokio_util::sync::CancellationToken;
use crate::actor::{EVENT_BUS_CAPACITY, OutboundCmd, run_actor};
use crate::error::{CallError, TransportError};
use crate::frame::RawEvent;
use crate::observer::TargetObserver;
/// Observer timeout used when a connection is spawned without an explicit
/// one (see `spawn_actor_with_observers`): five seconds.
pub(crate) const DEFAULT_OBSERVER_TIMEOUT: Duration = Duration::from_secs(5);
/// RPC error code that `Connection::call_raw` reports as
/// `TransportError::Shutdown` instead of as a regular RPC error.
// NOTE(review): presumably the actor replies with this code to calls still
// pending when it drains on shutdown — confirm in `crate::actor`.
pub(crate) const SHUTDOWN_DRAIN_CODE: i32 = -32001;
/// Cloneable handle to a connection driven by a background actor task.
///
/// The handle only wraps an [`Arc`], so every clone refers to the same
/// underlying [`ConnectionInner`].
#[derive(Clone, Debug)]
pub struct Connection {
/// State shared by all clones of this handle.
pub(crate) inner: Arc<ConnectionInner>,
}
/// State shared between all clones of a [`Connection`].
#[derive(Debug)]
pub(crate) struct ConnectionInner {
/// Sending half of the command queue; the receiving half is handed to
/// `run_actor` when the connection is spawned.
pub(crate) cmd_tx: mpsc::Sender<OutboundCmd>,
/// Broadcast sender for raw events; each subscription creates a new
/// receiver from it.
pub(crate) event_tx: broadcast::Sender<RawEvent>,
/// Token cancelled by `Connection::shutdown`; a clone of it is also given
/// to the actor.
pub(crate) shutdown: CancellationToken,
/// Timeout returned by `Connection::observer_timeout`.
// NOTE(review): how this timeout is applied is not visible in this file —
// confirm against the observer dispatch code.
pub(crate) observer_timeout: Duration,
}
impl Connection {
pub async fn call_raw(
&self,
method: impl Into<String>,
params: Value,
session_id: Option<String>,
) -> Result<Value, CallError> {
let (reply_tx, reply_rx) = oneshot::channel();
self.inner
.cmd_tx
.send(OutboundCmd {
method: method.into(),
params,
session_id,
reply: reply_tx,
})
.await
.map_err(|_| TransportError::Shutdown)?;
match reply_rx.await {
Ok(Ok(v)) => Ok(v),
Ok(Err(rpc_err)) => {
if rpc_err.code == SHUTDOWN_DRAIN_CODE {
Err(CallError::Transport(TransportError::Shutdown))
} else {
Err(CallError::Rpc(rpc_err.code, rpc_err.message, rpc_err.data))
}
}
Err(_) => Err(CallError::Transport(TransportError::Shutdown)),
}
}
pub fn subscribe_raw(&self) -> impl Stream<Item = RawEvent> + Send + Unpin + use<> {
Box::pin(
BroadcastStream::new(self.inner.event_tx.subscribe()).filter_map(|res| async move {
res.ok()
}),
)
}
pub fn subscribe<T>(
&self,
method: &'static str,
) -> impl Stream<Item = T> + Send + Unpin + use<T>
where
T: DeserializeOwned + Send + 'static,
{
Box::pin(
BroadcastStream::new(self.inner.event_tx.subscribe()).filter_map(
move |res| async move {
let ev = res.ok()?;
if ev.method == method {
serde_json::from_value(ev.params).ok()
} else {
None
}
},
),
)
}
pub fn shutdown(&self) {
self.inner.shutdown.cancel();
}
pub fn shutdown_token(&self) -> CancellationToken {
self.inner.shutdown.clone()
}
pub(crate) fn observer_timeout(&self) -> Duration {
self.inner.observer_timeout
}
}
/// Opens a WebSocket connection to `ws_url` and starts the actor without any
/// target observers.
///
/// # Errors
///
/// Returns whatever [`connect_with_observers`] returns.
pub async fn connect(ws_url: &str) -> Result<Connection, TransportError> {
    let no_observers = Vec::new();
    connect_with_observers(ws_url, no_observers).await
}
pub async fn connect_with_observers(
ws_url: &str,
observers: Vec<Arc<dyn TargetObserver>>,
) -> Result<Connection, TransportError> {
use tokio_tungstenite::connect_async;
let (ws, _resp) = connect_async(ws_url).await?;
Ok(spawn_actor_with_observers(ws, observers))
}
/// Starts the actor over an already-established message transport, without
/// any target observers.
///
/// `ws` must be both a sink of WebSocket messages and a stream of incoming
/// message results.
pub fn spawn_actor<S>(ws: S) -> Connection
where
    S: futures::Sink<
            tokio_tungstenite::tungstenite::Message,
            Error = tokio_tungstenite::tungstenite::Error,
        > + futures::Stream<
            Item = Result<
                tokio_tungstenite::tungstenite::Message,
                tokio_tungstenite::tungstenite::Error,
            >,
        > + Send
        + Unpin
        + 'static,
{
    let no_observers = Vec::new();
    spawn_actor_with_observers(ws, no_observers)
}
/// Starts the actor over an already-established message transport with the
/// given target observers and the default observer timeout
/// ([`DEFAULT_OBSERVER_TIMEOUT`]).
pub fn spawn_actor_with_observers<S>(ws: S, observers: Vec<Arc<dyn TargetObserver>>) -> Connection
where
    S: futures::Sink<
            tokio_tungstenite::tungstenite::Message,
            Error = tokio_tungstenite::tungstenite::Error,
        > + futures::Stream<
            Item = Result<
                tokio_tungstenite::tungstenite::Message,
                tokio_tungstenite::tungstenite::Error,
            >,
        > + Send
        + Unpin
        + 'static,
{
    let timeout = DEFAULT_OBSERVER_TIMEOUT;
    spawn_actor_with_observers_and_timeout(ws, observers, timeout)
}
/// Builds the shared connection state, spawns the actor task on the Tokio
/// runtime and returns a [`Connection`] handle to it.
///
/// * `ws` – the message transport driven by the actor.
/// * `observers` – target observers handed to the actor.
/// * `observer_timeout` – stored on the connection and returned by
///   `Connection::observer_timeout`.
///
/// # Panics
///
/// `tokio::spawn` panics when called outside of a Tokio runtime.
pub fn spawn_actor_with_observers_and_timeout<S>(
    ws: S,
    observers: Vec<Arc<dyn TargetObserver>>,
    observer_timeout: Duration,
) -> Connection
where
    S: futures::Sink<
            tokio_tungstenite::tungstenite::Message,
            Error = tokio_tungstenite::tungstenite::Error,
        > + futures::Stream<
            Item = Result<
                tokio_tungstenite::tungstenite::Message,
                tokio_tungstenite::tungstenite::Error,
            >,
        > + Send
        + Unpin
        + 'static,
{
    let shutdown = CancellationToken::new();
    // The initial receiver is never read; subscribers create their own
    // receivers from the sender. It is kept alive until this function returns.
    let (event_tx, _initial_rx) = broadcast::channel::<RawEvent>(EVENT_BUS_CAPACITY);
    // Bounded command queue: `call_raw` waits while 64 commands are queued.
    let (cmd_tx, cmd_rx) = mpsc::channel::<OutboundCmd>(64);

    let inner = Arc::new(ConnectionInner {
        cmd_tx,
        event_tx: event_tx.clone(),
        shutdown: shutdown.clone(),
        observer_timeout,
    });

    // The actor only receives a weak reference to the shared state.
    // NOTE(review): presumably so the actor does not keep the connection
    // alive by itself — confirm in `run_actor`.
    let actor = run_actor(
        ws,
        cmd_rx,
        event_tx,
        shutdown,
        observers,
        Arc::downgrade(&inner),
    );
    tokio::spawn(actor);

    Connection { inner }
}
/// In-memory stand-in for a WebSocket, available to tests and to builds with
/// the `testing` feature.
#[cfg(any(test, feature = "testing"))]
pub(crate) mod test_only {
    use std::pin::Pin;
    use std::task::{Context, Poll};

    use tokio::sync::mpsc;
    use tokio_tungstenite::tungstenite::{Error as WsError, Message};

    /// Channel-backed transport: messages written to the sink are pushed into
    /// `tx`, and messages read from the stream are pulled from `rx`.
    pub struct DriverStream {
        /// Receives every message the code under test sends.
        pub tx: mpsc::Sender<Message>,
        /// Supplies the frames (or errors) the code under test will read.
        pub rx: mpsc::Receiver<Result<Message, WsError>>,
    }

    impl futures::Sink<Message> for DriverStream {
        type Error = WsError;

        /// Always reports readiness; capacity is only checked in `start_send`.
        fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        /// Queues `item` without waiting. A full or closed channel is
        /// reported as `ConnectionClosed`.
        fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
            match self.tx.try_send(item) {
                Ok(()) => Ok(()),
                Err(_) => Err(WsError::ConnectionClosed),
            }
        }

        /// Nothing is buffered, so flushing completes immediately.
        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        /// Closing is a no-op that completes immediately.
        fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }
    }

    impl futures::Stream for DriverStream {
        type Item = Result<Message, WsError>;

        /// Yields the next frame from `rx`; the stream ends when all senders
        /// for `rx` are dropped.
        fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            let this = self.get_mut();
            this.rx.poll_recv(cx)
        }
    }
}
#[cfg(test)]
// Panicking and unwrapping are how these tests report failure.
#[allow(clippy::panic, clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::connection::test_only::DriverStream;
    use serde_json::json;
    use tokio_tungstenite::tungstenite::Message;

    /// A frame (or transport error) injected into the connection under test.
    type InboundFrame = Result<Message, tokio_tungstenite::tungstenite::Error>;

    /// Builds an in-memory transport plus the two test-side channel ends:
    /// a sender to inject inbound frames and a receiver to observe outbound
    /// messages.
    fn duplex_pair() -> (
        DriverStream,
        tokio::sync::mpsc::Sender<InboundFrame>,
        tokio::sync::mpsc::Receiver<Message>,
    ) {
        let (inbound_tx, inbound_rx) = tokio::sync::mpsc::channel::<InboundFrame>(32);
        let (outbound_tx, outbound_rx) = tokio::sync::mpsc::channel::<Message>(32);
        let driver = DriverStream {
            tx: outbound_tx,
            rx: inbound_rx,
        };
        (driver, inbound_tx, outbound_rx)
    }

    /// A call is written to the transport, and the reply carrying the same
    /// `id` is returned to the caller as the call's result.
    #[tokio::test]
    async fn call_raw_round_trips_through_actor() {
        let (ws, test_tx, mut test_rx) = duplex_pair();
        let conn = spawn_actor(ws);

        // Issue the call on its own task so this task can play the peer.
        let caller = conn.clone();
        let call = tokio::spawn(async move {
            caller
                .call_raw("Page.navigate", json!({ "url": "https://x.test" }), None)
                .await
        });

        // Read the outgoing request and pull out the id the actor assigned.
        let frame = test_rx.recv().await.unwrap();
        let Message::Text(text) = &frame else {
            panic!("expected text frame");
        };
        let request: Value = serde_json::from_str(text).unwrap();
        let id = request["id"].as_u64().unwrap();

        // Answer with a result that echoes the same id.
        let reply = json!({ "id": id, "result": { "frameId": "F1" } }).to_string();
        test_tx.send(Ok(Message::text(reply))).await.unwrap();

        let res = call.await.unwrap().unwrap();
        assert_eq!(res["frameId"], "F1");
        conn.shutdown();
    }
}