use std::{
collections::HashMap,
sync::{
Arc,
atomic::{AtomicUsize, Ordering},
},
};
use chromiumoxide_types::{CallId, Command, MethodCall, Response};
use futures::{SinkExt, StreamExt, stream::SplitSink};
use tokio::{
net::TcpStream,
sync::{
Mutex,
oneshot::{self, Sender},
},
};
use tokio_websockets::{ClientBuilder, MaybeTlsStream, Message, WebSocketStream};
use tracing::{debug, error, instrument, warn};
use crate::{Error, error::Result};
/// Write half of the CDP WebSocket, shared behind an async mutex so that
/// concurrent senders take turns writing whole messages.
type CdpSink = Arc<Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>>;
/// Requests still waiting for a reply: call id -> oneshot sender that will
/// receive the matching [`Response`].
type ChannelMap = Arc<Mutex<HashMap<CallId, Sender<Response>>>>;
/// Cloneable handle to a single CDP WebSocket connection.
///
/// Every field is reference-counted, so all clones share the same write half,
/// the same table of pending requests and the same call-id counter.
#[derive(Clone, Debug)]
pub struct Connection {
    /// Write half of the socket; locked for the duration of each outgoing message.
    sink: CdpSink,
    /// Outstanding requests awaiting a response, keyed by call id.
    channels: ChannelMap,
    /// Monotonic source of call ids for outgoing requests.
    id: Arc<AtomicUsize>,
}
impl Connection {
#[instrument]
pub async fn connect(uri: &str) -> Result<Self> {
let uri = uri.parse()?;
let (socket, _) = ClientBuilder::from_uri(uri).connect().await?;
let (sink, mut stream) = socket.split();
let sink = Arc::new(Mutex::new(sink));
let channels = ChannelMap::default();
{
let channels = channels.clone();
tokio::spawn(async move {
loop {
let Some(message) = stream.next().await else {
debug!("CDP WS stream ends");
break;
};
let message = match message {
Ok(message) => message,
Err(err) => {
error!(%err, "Unable to read CDP WS stream");
continue;
}
};
let text = match message.as_text() {
Some(text) => text,
None => {
error!(?message, "CDP WS message is not text");
continue;
}
};
let response = match serde_json::from_str::<Response>(text) {
Ok(response) => response,
Err(err) => {
error!(%err, "Unable to deserialize CDP response");
continue;
}
};
let Some(tx) = channels.lock().await.remove(&response.id) else {
warn!(id = %response.id, "tx missing for response");
continue;
};
if let Err(response) = tx.send(response) {
error!(?response, "Unable to send response back");
}
}
});
}
Ok(Self {
sink,
channels,
id: Default::default(),
})
}
pub async fn send<T: Command>(
&self,
session_id: Option<String>,
command: T,
) -> Result<T::Response> {
let id = CallId::new(self.id.fetch_add(1, Ordering::SeqCst));
let method_call = MethodCall {
id,
session_id,
method: command.identifier(),
params: serde_json::to_value(command)?,
};
let text = serde_json::to_string(&method_call)?;
let (tx, rx) = oneshot::channel::<Response>();
self.channels.lock().await.insert(id, tx);
self.sink.lock().await.send(Message::text(text)).await?;
let response = rx.await?;
match response.result {
Some(response) => Ok(serde_json::from_value::<T::Response>(response)?),
None => Err(Error::Response(response.error)),
}
}
}