use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use anyhow::{anyhow, Result};
use futures_util::{SinkExt, StreamExt};
use serde_json::{json, Value};
use tokio::sync::{broadcast, mpsc, oneshot, Mutex};
use tokio_tungstenite::tungstenite::Message;
pub mod protocol;
use protocol::{CdpError, Request, Response};
/// How long `send_with_session` waits for a matching response before giving up.
const REQUEST_TIMEOUT: Duration = Duration::from_secs(30);
/// Capacity of the event broadcast channel. Per tokio `broadcast` semantics, a
/// receiver that falls more than this many events behind gets `RecvError::Lagged`
/// and skips the oldest events.
const EVENT_CHANNEL_CAPACITY: usize = 256;
/// One-shot responders for in-flight requests, keyed by request id.
type PendingMap = HashMap<u64, oneshot::Sender<Result<Value, CdpError>>>;
/// An unsolicited CDP message: one that carries a `method` but no request `id`.
///
/// Every receiver obtained from `CdpClient::subscribe` gets a copy.
#[derive(Debug, Clone)]
pub struct CdpEvent {
    /// Event name, e.g. `Target.targetCreated`.
    pub method: String,
    /// The message's `params` payload, as received.
    pub params: Value,
    /// The `sessionId` the message was tagged with, if any.
    pub session_id: Option<String>,
}
/// Chrome DevTools Protocol client that multiplexes requests and events over one
/// WebSocket.
///
/// Each request gets a fresh numeric `id` and a one-shot responder registered in
/// `pending`. A writer task sends the serialized requests, and a reader task routes
/// replies back by `id`. Messages with a `method` but no `id` are broadcast as
/// [`CdpEvent`]s to every receiver returned by [`CdpClient::subscribe`].
pub struct CdpClient {
    /// Next request id. Only uniqueness matters, so relaxed atomic increments suffice.
    next_id: AtomicU64,
    /// Responders for in-flight requests. Set to `None` when the reader task stops,
    /// so later requests fail immediately instead of waiting for the timeout.
    pending: Arc<Mutex<Option<PendingMap>>>,
    /// Sending half of the event fan-out; receivers come from `subscribe`.
    events_tx: broadcast::Sender<CdpEvent>,
    /// Queue of serialized requests consumed by the writer task.
    write_tx: mpsc::UnboundedSender<String>,
    /// Background task reading frames from the socket.
    reader_handle: tokio::task::JoinHandle<()>,
    /// Background task writing queued frames to the socket.
    writer_handle: tokio::task::JoinHandle<()>,
}

impl CdpClient {
    /// Connects to a CDP WebSocket endpoint and starts the background reader and
    /// writer tasks.
    ///
    /// # Errors
    ///
    /// Returns an error if the WebSocket handshake with `ws_url` fails.
    pub async fn connect(ws_url: &str) -> Result<Self> {
        let (ws_stream, _) = tokio_tungstenite::connect_async(ws_url).await?;
        let (mut ws_sink, mut ws_stream) = ws_stream.split();
        let pending: Arc<Mutex<Option<PendingMap>>> =
            Arc::new(Mutex::new(Some(HashMap::new())));
        let (events_tx, _) = broadcast::channel(EVENT_CHANNEL_CAPACITY);
        let (write_tx, mut write_rx) = mpsc::unbounded_channel::<String>();

        // Writer: forward queued requests until the queue closes (every sender
        // dropped) or a write fails, then attempt a graceful WebSocket close.
        let writer_handle = tokio::spawn(async move {
            while let Some(text) = write_rx.recv().await {
                if ws_sink.send(Message::Text(text)).await.is_err() {
                    break;
                }
            }
            let _ = ws_sink.close().await;
        });

        let pending_r = Arc::clone(&pending);
        let events_r = events_tx.clone();
        // Reader: route responses to their waiting request and broadcast events,
        // until the peer closes the socket or a read fails.
        let reader_handle = tokio::spawn(async move {
            while let Some(msg) = ws_stream.next().await {
                let text = match msg {
                    Ok(Message::Text(t)) => t,
                    // Accept JSON sent in binary frames as long as it is valid UTF-8.
                    Ok(Message::Binary(b)) => match String::from_utf8(b) {
                        Ok(s) => s,
                        Err(_) => continue,
                    },
                    Ok(Message::Close(_)) | Err(_) => break,
                    // Ping/pong and other frames carry no CDP payload.
                    Ok(_) => continue,
                };
                // Frames that do not parse as a CDP message are skipped.
                let resp: Response = match serde_json::from_str(&text) {
                    Ok(r) => r,
                    Err(_) => continue,
                };
                if let Some(id) = resp.id {
                    let mut guard = pending_r.lock().await;
                    // An unknown id (e.g. a request that already timed out) is ignored.
                    if let Some(tx) = guard.as_mut().and_then(|map| map.remove(&id)) {
                        let res = match resp.error {
                            Some(err) => Err(err),
                            None => Ok(resp.result),
                        };
                        // The requester may have given up already; nothing to do then.
                        let _ = tx.send(res);
                    }
                } else if let Some(method) = resp.method {
                    // Sending fails only when nobody is subscribed; the event is dropped.
                    let _ = events_r.send(CdpEvent {
                        method,
                        params: resp.params,
                        session_id: resp.session_id,
                    });
                }
            }
            // Replacing the map with `None` under the lock marks the connection as
            // closed. `send_with_session` checks this under the same lock, so no
            // request can be registered after this drain and then wait forever.
            let drained = pending_r.lock().await.take();
            if let Some(map) = drained {
                for (_, tx) in map {
                    let _ = tx.send(Err(CdpError {
                        code: -1,
                        message: "connection closed".into(),
                    }));
                }
            }
        });

        Ok(Self {
            next_id: AtomicU64::new(1),
            pending,
            events_tx,
            write_tx,
            reader_handle,
            writer_handle,
        })
    }

    /// Discovers the browser's WebSocket URL from `GET {base_url}/json/version`
    /// and connects to it with [`CdpClient::connect`].
    ///
    /// # Errors
    ///
    /// Returns an error if any of the following happens:
    /// - the HTTP request fails or returns a non-success status;
    /// - the body is not JSON;
    /// - `webSocketDebuggerUrl` is missing or not a string;
    /// - the WebSocket connection fails.
    pub async fn connect_http(base_url: &str) -> Result<Self> {
        let base = base_url.trim_end_matches('/');
        let url = format!("{base}/json/version");
        // Reject error statuses up front, so that a 404/500 is reported as an HTTP
        // error instead of surfacing later as a decode or missing-field error.
        let resp: Value = reqwest::get(&url)
            .await?
            .error_for_status()?
            .json()
            .await?;
        let ws_url = resp
            .get("webSocketDebuggerUrl")
            .and_then(Value::as_str)
            .ok_or_else(|| anyhow!("webSocketDebuggerUrl missing from {url}"))?
            .to_string();
        Self::connect(&ws_url).await
    }

    /// Sends `method` with `params` without a session id and waits for its result.
    ///
    /// # Errors
    ///
    /// Same conditions as [`CdpClient::send_with_session`].
    pub async fn send(&self, method: &str, params: Value) -> Result<Value> {
        self.send_with_session(method, params, None).await
    }

    /// Sends `method` with `params`, tags it with `session_id` when given, and waits
    /// for the matching response's `result`.
    ///
    /// # Errors
    ///
    /// Returns an error if any of the following happens:
    /// - the request cannot be serialized;
    /// - the connection has closed, either before sending or while waiting;
    /// - the writer task has exited;
    /// - the peer replies with a CDP error;
    /// - no reply arrives within `REQUEST_TIMEOUT`.
    pub async fn send_with_session(
        &self,
        method: &str,
        params: Value,
        session_id: Option<&str>,
    ) -> Result<Value> {
        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
        let req = Request {
            id,
            method,
            params,
            session_id: session_id.map(|s| s.to_string()),
        };
        let text = serde_json::to_string(&req)?;
        let (tx, rx) = oneshot::channel();
        {
            let mut guard = self.pending.lock().await;
            match guard.as_mut() {
                Some(map) => {
                    map.insert(id, tx);
                }
                // The reader has already stopped, so nobody would ever answer.
                None => return Err(anyhow!("connection closed")),
            }
        }
        if self.write_tx.send(text).is_err() {
            self.forget_pending(id).await;
            return Err(anyhow!("writer task closed"));
        }
        match tokio::time::timeout(REQUEST_TIMEOUT, rx).await {
            Ok(Ok(Ok(v))) => Ok(v),
            Ok(Ok(Err(e))) => Err(anyhow!(e)),
            Ok(Err(_)) => Err(anyhow!("response channel dropped")),
            Err(_) => {
                self.forget_pending(id).await;
                Err(anyhow!("CDP request timed out after {:?}", REQUEST_TIMEOUT))
            }
        }
    }

    /// Removes the responder registered for `id`, if it is still pending.
    async fn forget_pending(&self, id: u64) {
        if let Some(map) = self.pending.lock().await.as_mut() {
            map.remove(&id);
        }
    }

    /// Returns a receiver for events that arrive after this call.
    ///
    /// The channel holds at most `EVENT_CHANNEL_CAPACITY` events. A receiver that
    /// falls further behind gets `RecvError::Lagged` and skips the oldest events.
    pub fn subscribe(&self) -> broadcast::Receiver<CdpEvent> {
        self.events_tx.subscribe()
    }

    /// Attaches to `target_id` with `flatten: true` and returns the new session id,
    /// which can be passed to [`CdpClient::send_with_session`].
    ///
    /// # Errors
    ///
    /// Returns an error if the request fails or the response lacks a string
    /// `sessionId`.
    pub async fn attach_to_target(&self, target_id: &str) -> Result<String> {
        let v = self
            .send(
                "Target.attachToTarget",
                json!({ "targetId": target_id, "flatten": true }),
            )
            .await?;
        v.get("sessionId")
            .and_then(Value::as_str)
            .map(|s| s.to_string())
            .ok_or_else(|| anyhow!("sessionId missing from Target.attachToTarget response"))
    }

    /// Lists targets via `Target.getTargets` and returns the raw `targetInfos` entries.
    ///
    /// A response without a `targetInfos` array yields an empty list.
    ///
    /// # Errors
    ///
    /// Returns an error if the request itself fails.
    pub async fn list_targets(&self) -> Result<Vec<Value>> {
        // Pass an empty params object, as `attach_to_target` passes an object, rather
        // than relying on the peer accepting a null `params`.
        let mut v = self.send("Target.getTargets", json!({})).await?;
        // Move the array out of the response instead of cloning it.
        match v.get_mut("targetInfos").map(Value::take) {
            Some(Value::Array(a)) => Ok(a),
            _ => Ok(vec![]),
        }
    }

    /// Shuts the client down.
    ///
    /// Dropping the request queue lets the writer send anything already queued and
    /// then close the WebSocket. The reader is then aborted rather than left to wait
    /// for the peer.
    pub async fn close(self) {
        drop(self.write_tx);
        let _ = self.writer_handle.await;
        self.reader_handle.abort();
        let _ = self.reader_handle.await;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::{SinkExt, StreamExt};
    use tokio_tungstenite::tungstenite::Message;

    /// Binds an ephemeral loopback port; returns the listener and its `ws://` URL.
    async fn bind_loopback() -> (tokio::net::TcpListener, String) {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let url = format!("ws://{}", listener.local_addr().unwrap());
        (listener, url)
    }

    /// A response's `result` object is handed back by `send` unchanged.
    #[tokio::test]
    async fn round_trip_request_response() {
        let (listener, url) = bind_loopback().await;
        // Mock browser: answer every text frame with a result echoing its method.
        tokio::spawn(async move {
            let (tcp, _) = listener.accept().await.unwrap();
            let mut server = tokio_tungstenite::accept_async(tcp).await.unwrap();
            while let Some(Ok(frame)) = server.next().await {
                if let Message::Text(body) = frame {
                    let request: Value = serde_json::from_str(&body).unwrap();
                    let reply = json!({
                        "id": request["id"].as_u64().unwrap(),
                        "result": {"ok": true, "echo": request["method"]}
                    });
                    server.send(Message::Text(reply.to_string())).await.unwrap();
                }
            }
        });

        let client = CdpClient::connect(&url).await.unwrap();
        let result = client
            .send("Page.navigate", json!({"url": "about:blank"}))
            .await
            .unwrap();
        assert_eq!(result["ok"], true);
        assert_eq!(result["echo"], "Page.navigate");
        client.close().await;
    }

    /// A message without an `id` reaches subscribers with its session id intact.
    #[tokio::test]
    async fn broadcast_event_to_subscriber() {
        let (listener, url) = bind_loopback().await;
        // Hold the event back until the client has subscribed: a broadcast receiver
        // only sees values sent after it was created.
        let (subscribed_tx, subscribed_rx) = oneshot::channel::<()>();
        tokio::spawn(async move {
            let (tcp, _) = listener.accept().await.unwrap();
            let mut server = tokio_tungstenite::accept_async(tcp).await.unwrap();
            let _ = subscribed_rx.await;
            let event = json!({
                "method": "Target.targetCreated",
                "params": {"targetInfo": {"targetId": "abc"}},
                "sessionId": "S1"
            });
            server.send(Message::Text(event.to_string())).await.unwrap();
            // Keep the socket open until the client side goes away.
            while let Some(Ok(_)) = server.next().await {}
        });

        let client = CdpClient::connect(&url).await.unwrap();
        let mut events = client.subscribe();
        subscribed_tx.send(()).unwrap();
        let event = tokio::time::timeout(Duration::from_secs(5), events.recv())
            .await
            .expect("event timeout")
            .expect("event recv");
        assert_eq!(event.method, "Target.targetCreated");
        assert_eq!(event.session_id.as_deref(), Some("S1"));
        assert_eq!(event.params["targetInfo"]["targetId"], "abc");
        client.close().await;
    }
}