use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::time::Duration;
use async_tungstenite::tokio::{ConnectStream, connect_async};
use async_tungstenite::tungstenite::Message;
use async_tungstenite::tungstenite::error::Error as WsError;
use async_tungstenite::{WebSocketReceiver, WebSocketSender};
use futures::StreamExt as _;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::sync::{Mutex as AsyncMutex, broadcast, oneshot};
use tokio::task::JoinHandle;
/// Capacity of the broadcast channel that fans protocol events out to subscribers.
const EVENT_BUFFER: usize = 256;
/// Write half of the CDP WebSocket connection.
type Sink = WebSocketSender<ConnectStream>;
/// Read half of the CDP WebSocket connection.
type Stream = WebSocketReceiver<ConnectStream>;
/// Errors produced by the CDP client.
#[non_exhaustive]
#[derive(Debug, thiserror::Error)]
pub enum CdpError {
    /// Transport-level failure; carries a human-readable description.
    #[error("websocket: {0}")]
    WebSocket(String),
    /// The remote end answered a command with an `error` object.
    #[error("CDP {code}: {message}")]
    Remote { code: i64, message: String },
    /// A request could not be serialised or a response could not be deserialised.
    #[error("decode response: {0}")]
    Decode(String),
    /// An operation (named by `what`) did not finish within `elapsed`.
    #[error("CDP {what} timed out after {elapsed:?}")]
    Timeout { elapsed: Duration, what: &'static str },
    /// The client was closed, or its connection ended, before the operation completed.
    #[error("CDP client is closed")]
    Closed,
}
impl CdpError {
    /// Converts a tungstenite error into [`CdpError::WebSocket`], keeping only its
    /// `Display` text so the error type stays free of transport generics.
    fn ws(e: &WsError) -> Self {
        let rendered = e.to_string();
        Self::WebSocket(rendered)
    }
}
/// A protocol event received from the remote end (a frame carrying `method` but no `id`).
#[derive(Clone, Debug)]
pub struct CdpEvent {
    /// Event name, e.g. `Page.loadEventFired`.
    pub method: String,
    /// Event payload; `Value::Null` when the frame had no `params`.
    pub params: Value,
    /// The frame's `sessionId`, if it had one.
    pub session_id: Option<String>,
}
/// Outgoing command frame, serialised as
/// `{"id": .., "method": .., "params"?: .., "sessionId"?: ..}`.
#[derive(Serialize)]
struct Request<'a, P> {
    /// Correlates the eventual response with this request.
    id: u64,
    /// Fully-qualified command name, e.g. `Page.navigate`.
    method: &'a str,
    /// Command parameters; the key is omitted entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    params: Option<P>,
    /// Session the command is scoped to; emitted as `sessionId`, omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none", rename = "sessionId")]
    session_id: Option<&'a str>,
}
/// The `error` object inside a failed command response. Any extra keys are ignored.
#[derive(Deserialize)]
struct RemoteError {
    /// Numeric error code reported by the remote end.
    code: i64,
    /// Human-readable error description.
    message: String,
}
/// Any incoming message. Every field is optional so one type can parse both
/// command responses (which carry `id`) and events (which carry `method`).
#[derive(Deserialize)]
struct Frame {
    /// Present on command responses.
    id: Option<u64>,
    /// Present on events.
    method: Option<String>,
    /// Event payload.
    params: Option<Value>,
    /// Successful command result.
    result: Option<Value>,
    /// Failed command result.
    error: Option<RemoteError>,
    /// Session identifier, sent on the wire as `sessionId`.
    #[serde(rename = "sessionId")]
    session_id: Option<String>,
}
/// Response slots for in-flight commands, keyed by request id. A plain (sync)
/// mutex suffices because the lock is never held across an `.await`.
type PendingMap =
    std::sync::Mutex<HashMap<u64, oneshot::Sender<Result<Value, CdpError>>>>;

/// State shared between a client handle and its background reader task.
struct Inner {
    /// Write half of the socket; an async mutex because sends are awaited while it is held.
    sink: AsyncMutex<Sink>,
    /// Commands still waiting for their response.
    pending: PendingMap,
    /// Counter handing out request ids.
    next_id: AtomicU64,
    /// Fan-out channel for protocol events.
    events: broadcast::Sender<CdpEvent>,
    /// Set once the connection is considered closed.
    closed: AtomicBool,
}
impl Inner {
    /// Flags the connection as closed and fails every outstanding request with
    /// [`CdpError::Closed`]. Only the first call has any effect.
    fn mark_closed(&self) {
        let already_closed = self.closed.swap(true, Ordering::AcqRel);
        if !already_closed {
            // Move the whole map out while locked, then notify waiters after the
            // lock is released. A poisoned lock is recovered, not propagated.
            let orphaned = {
                let mut map = self
                    .pending
                    .lock()
                    .unwrap_or_else(std::sync::PoisonError::into_inner);
                std::mem::take(&mut *map)
            };
            orphaned.into_values().for_each(|waiter| {
                // The receiver may already be gone (caller timed out); that is fine.
                let _ = waiter.send(Err(CdpError::Closed));
            });
        }
    }
}
/// Client for the Chrome DevTools Protocol (CDP) over a single WebSocket.
///
/// A background task reads every incoming frame. It routes command responses
/// back to their callers and broadcasts events to subscribers.
pub struct CdpClient {
    /// State shared with the background reader task.
    inner: Arc<Inner>,
    /// Handle of the reader task; aborted on `close` and on drop.
    read_loop: JoinHandle<()>,
}
impl fmt::Debug for CdpClient {
    /// Shows whether the client is closed, how many requests are in flight
    /// (reported as 0 if the pending map's lock is poisoned), and whether the
    /// reader task has exited.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let inner = &self.inner;
        let is_closed = inner.closed.load(Ordering::Acquire);
        let in_flight = match inner.pending.lock() {
            Ok(map) => map.len(),
            Err(_) => 0,
        };
        f.debug_struct("CdpClient")
            .field("closed", &is_closed)
            .field("pending", &in_flight)
            .field("read_loop_finished", &self.read_loop.is_finished())
            .finish()
    }
}
impl Drop for CdpClient {
    /// Fails any outstanding requests and stops the reader task. Unlike
    /// `CdpClient::close`, dropping does not send a WebSocket close frame.
    fn drop(&mut self) {
        let Self {
            inner,
            read_loop: reader,
        } = self;
        inner.mark_closed();
        reader.abort();
    }
}
impl CdpClient {
pub async fn connect(url: &str) -> Result<Self, CdpError> {
let (ws, _resp) = connect_async(url).await.map_err(|e| CdpError::ws(&e))?;
let (sink, stream) = ws.split();
let (events_tx, _) = broadcast::channel(EVENT_BUFFER);
let inner = Arc::new(Inner {
sink: AsyncMutex::new(sink),
pending: std::sync::Mutex::new(HashMap::new()),
next_id: AtomicU64::new(1),
events: events_tx,
closed: AtomicBool::new(false),
});
let read_loop = tokio::spawn(read_loop(Arc::clone(&inner), stream));
Ok(Self { inner, read_loop })
}
pub async fn execute<P, R>(
&self,
method: &'static str,
params: P,
session_id: Option<&str>,
timeout: Duration,
) -> Result<R, CdpError>
where
P: Serialize,
R: DeserializeOwned,
{
if self.inner.closed.load(Ordering::Acquire) {
return Err(CdpError::Closed);
}
let id = self.inner.next_id.fetch_add(1, Ordering::AcqRel);
let req = Request {
id,
method,
params: Some(params),
session_id,
};
let json = serde_json::to_string(&req).map_err(|e| CdpError::Decode(e.to_string()))?;
let (tx, rx) = oneshot::channel();
{
let mut g = self
.inner
.pending
.lock()
.map_err(|_| CdpError::WebSocket("pending mutex poisoned".into()))?;
g.insert(id, tx);
}
let send = {
let mut sink = self.inner.sink.lock().await;
sink.send(Message::Text(json.into())).await
};
if let Err(e) = send {
let _ = self
.inner
.pending
.lock()
.map(|mut g| g.remove(&id))
.unwrap_or_default();
return Err(CdpError::ws(&e));
}
let wait = async {
rx.await.map_err(|_| CdpError::Closed)?.and_then(|value| {
serde_json::from_value::<R>(value).map_err(|e| CdpError::Decode(e.to_string()))
})
};
tokio::time::timeout(timeout, wait).await.map_err(|_| {
let _ = self
.inner
.pending
.lock()
.map(|mut g| g.remove(&id))
.unwrap_or_default();
CdpError::Timeout {
elapsed: timeout,
what: method,
}
})?
}
#[must_use]
pub fn subscribe_events(&self) -> broadcast::Receiver<CdpEvent> {
self.inner.events.subscribe()
}
pub async fn wait_for_event<F>(
&self,
predicate: F,
timeout: Duration,
what: &'static str,
) -> Result<CdpEvent, CdpError>
where
F: Fn(&CdpEvent) -> bool + Send + Sync,
{
let mut rx = self.subscribe_events();
Self::wait_for_event_on(&mut rx, predicate, timeout, what).await
}
pub async fn wait_for_event_on<F>(
rx: &mut broadcast::Receiver<CdpEvent>,
predicate: F,
timeout: Duration,
what: &'static str,
) -> Result<CdpEvent, CdpError>
where
F: Fn(&CdpEvent) -> bool + Send + Sync,
{
let wait = async {
loop {
match rx.recv().await {
Ok(evt) if predicate(&evt) => return Ok::<CdpEvent, CdpError>(evt),
Ok(_) | Err(broadcast::error::RecvError::Lagged(_)) => {}
Err(broadcast::error::RecvError::Closed) => return Err(CdpError::Closed),
}
}
};
tokio::time::timeout(timeout, wait)
.await
.map_err(|_| CdpError::Timeout {
elapsed: timeout,
what,
})?
}
pub async fn close(self) {
self.inner.mark_closed();
let _ = self.inner.sink.lock().await.close(None).await;
self.read_loop.abort();
}
}
async fn read_loop(inner: Arc<Inner>, mut stream: Stream) {
while let Some(msg) = stream.next().await {
if inner.closed.load(Ordering::Acquire) {
break;
}
let text = match msg {
Ok(Message::Text(t)) => t,
Ok(Message::Binary(b)) => {
let Ok(decoded) = String::from_utf8(b.into()) else {
tracing::warn!("CDP: non-UTF8 binary frame, dropped");
continue;
};
decoded.into()
}
Ok(Message::Close(_)) => {
tracing::debug!("CDP: peer closed");
break;
}
Ok(_) => continue, Err(e) => {
tracing::warn!(error = %e, "CDP: stream error, closing");
break;
}
};
let frame: Frame = match serde_json::from_str(&text) {
Ok(f) => f,
Err(e) => {
tracing::warn!(error = %e, "CDP: malformed frame, dropped");
continue;
}
};
match (frame.id, frame.method) {
(Some(id), _) => {
let tx = inner.pending.lock().ok().and_then(|mut g| g.remove(&id));
if let Some(tx) = tx {
let result = if let Some(err) = frame.error {
Err(CdpError::Remote {
code: err.code,
message: err.message,
})
} else {
Ok(frame.result.unwrap_or(Value::Null))
};
let _ = tx.send(result);
} else {
tracing::debug!(id, "CDP: response for unknown / cancelled id");
}
}
(None, Some(method)) => {
let evt = CdpEvent {
method,
params: frame.params.unwrap_or(Value::Null),
session_id: frame.session_id,
};
let _ = inner.events.send(evt);
}
(None, None) => {
tracing::warn!("CDP: frame has neither id nor method, dropped");
}
}
}
inner.mark_closed();
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Serialises `value` to a JSON tree, panicking on failure.
    fn to_json<T: Serialize>(value: &T) -> Value {
        serde_json::to_value(value).expect("serialisable")
    }

    /// Parses a raw wire frame, panicking on failure.
    fn parse(raw: &str) -> Frame {
        serde_json::from_str(raw).expect("valid frame")
    }

    #[test]
    fn request_serialises_with_optional_fields() {
        let bare: Request<'_, Value> = Request {
            id: 42,
            method: "Page.enable",
            params: None,
            session_id: None,
        };
        assert_eq!(to_json(&bare), json!({ "id": 42, "method": "Page.enable" }));
    }

    #[test]
    fn request_serialises_with_session_id() {
        let scoped = Request {
            id: 7,
            method: "Page.navigate",
            params: Some(json!({ "url": "https://example.com" })),
            session_id: Some("abc-123"),
        };
        let expected = json!({
            "id": 7,
            "method": "Page.navigate",
            "params": { "url": "https://example.com" },
            "sessionId": "abc-123",
        });
        assert_eq!(to_json(&scoped), expected);
    }

    #[test]
    fn frame_parses_a_response() {
        let frame = parse(r#"{"id": 1, "result": {"targetId": "T1"}}"#);
        assert_eq!(frame.id, Some(1));
        assert!(frame.method.is_none());
        assert_eq!(frame.result, Some(json!({ "targetId": "T1" })));
    }

    #[test]
    fn frame_parses_a_remote_error() {
        let frame =
            parse(r#"{"id": 9, "error": {"code": -32601, "message": "Method not found"}}"#);
        let RemoteError { code, message } = frame.error.expect("error object present");
        assert_eq!(code, -32601);
        assert_eq!(message, "Method not found");
    }

    #[test]
    fn frame_parses_an_event_with_session_id() {
        let frame = parse(
            r#"{"method": "Page.loadEventFired", "params": {"timestamp": 1.0}, "sessionId": "S1"}"#,
        );
        assert!(frame.id.is_none());
        assert_eq!(frame.method.as_deref(), Some("Page.loadEventFired"));
        assert_eq!(frame.session_id.as_deref(), Some("S1"));
        assert!(frame.params.is_some());
    }
}