use futures_util::SinkExt;
use serde_json::{json, Value};
use std::collections::HashMap;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::{Arc, Mutex as StdMutex};
use tokio::sync::{broadcast, mpsc, oneshot};
use tokio::time::{timeout, Duration};
use tokio_tungstenite::tungstenite::Message;
use tracing::Instrument;
use crate::error::{BrowserError, Result};
/// An outgoing CDP command; serialised to the wire format by [`CDPRequest::to_json`].
#[derive(Debug, Clone)]
pub struct CDPRequest {
/// Request id; a reply carrying the same `id` is matched back to this request.
pub id: u32,
/// CDP method name, e.g. `"Page.navigate"`.
pub method: String,
/// Method parameters; the `params` key is omitted from the JSON when `None`.
pub params: Option<Value>,
/// Session to route the command to; the `sessionId` key is omitted when `None`.
pub session_id: Option<String>,
}
impl CDPRequest {
pub fn new(id: u32, method: String, params: Option<Value>) -> Self {
Self {
id,
method,
params,
session_id: None,
}
}
pub fn with_session(
id: u32,
method: String,
params: Option<Value>,
session_id: String,
) -> Self {
Self {
id,
method,
params,
session_id: Some(session_id),
}
}
pub fn to_json(&self) -> Value {
let mut obj = json!({
"id": self.id,
"method": self.method,
});
if let Some(session_id) = &self.session_id {
obj["sessionId"] = json!(session_id);
}
if let Some(params) = &self.params {
obj["params"] = params.clone();
}
obj
}
}
/// An incoming CDP frame, decoded by [`CDPMessage::from_json`].
///
/// `CDPClient::handle_message` treats a frame with an `id` as a command response
/// and a frame with a `method` but no `id` as an event.
#[derive(Debug, Clone)]
pub struct CDPMessage {
/// Id of the request this frame answers, if any.
pub id: Option<u32>,
/// Event method name, if any.
pub method: Option<String>,
/// Raw `params` object, if present.
pub params: Option<Value>,
/// Raw `result` object of a response, if present.
pub result: Option<Value>,
/// Raw `error` object of a response, if present.
pub error: Option<Value>,
/// Session the frame belongs to, if present.
pub session_id: Option<String>,
}
impl CDPMessage {
pub fn from_json(value: Value) -> Result<Self> {
Ok(CDPMessage {
id: value.get("id").and_then(|v| v.as_u64()).map(|v| v as u32),
method: value
.get("method")
.and_then(|v| v.as_str())
.map(|s| s.to_string()),
params: value.get("params").cloned(),
result: value.get("result").cloned(),
error: value.get("error").cloned(),
session_id: value
.get("sessionId")
.and_then(|v| v.as_str())
.map(|s| s.to_string()),
})
}
}
/// Write half (`SplitSink`) of the WebSocket stream type returned by
/// [`CDPClient::connect`], carrying tungstenite [`Message`] frames over a
/// plain-TCP or TLS connection. Consumed by [`spawn_writer_task`].
pub type WebSocketSink = futures_util::stream::SplitSink<
tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
Message,
>;
/// CDP client that multiplexes command/response pairs and events over a single
/// WebSocket connection.
pub struct CDPClient {
/// WebSocket URL dialled by `connect`.
ws_url: String,
/// Source of request ids (initialised to 1 in `new`).
message_id_counter: Arc<AtomicU32>,
/// One-shot reply slots for in-flight requests, keyed by request id.
pending_responses: Arc<StdMutex<HashMap<u32, oneshot::Sender<Value>>>>,
/// Fan-out channel for event frames (a `method` and no `id`).
event_broadcast: broadcast::Sender<CDPMessage>,
/// Sender feeding the writer task; `None` until `set_writer` is called.
ws_tx: Arc<StdMutex<Option<mpsc::UnboundedSender<Message>>>>,
}
impl CDPClient {
pub fn new(ws_url: String) -> Self {
let (event_broadcast, _) = broadcast::channel(1024);
Self {
ws_url,
message_id_counter: Arc::new(AtomicU32::new(1)),
pending_responses: Arc::new(StdMutex::new(HashMap::new())),
event_broadcast,
ws_tx: Arc::new(StdMutex::new(None)),
}
}
pub fn set_writer(&self, tx: mpsc::UnboundedSender<Message>) {
*self.ws_tx.lock().expect("ws_tx mutex poisoned") = Some(tx);
}
pub fn next_id(&self) -> u32 {
self.message_id_counter.fetch_add(1, Ordering::SeqCst)
}
pub async fn connect(
&self,
) -> Result<
tokio_tungstenite::WebSocketStream<
tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
>,
> {
let (ws_stream, _) = tokio_tungstenite::connect_async(&self.ws_url)
.await
.map_err(|e| BrowserError::connection_failed(&self.ws_url, e.to_string()))?;
Ok(ws_stream)
}
pub fn send_raw(&self, msg: String) -> Result<()> {
let tx_guard = self.ws_tx.lock().expect("ws_tx mutex poisoned");
let tx = tx_guard.as_ref().ok_or_else(|| {
BrowserError::websocket("send_raw", "WebSocket writer not initialised")
})?;
tx.send(Message::Text(msg))
.map_err(|_| BrowserError::websocket("send_raw", "WebSocket writer task ended"))
}
pub fn subscribe_events(&self) -> broadcast::Receiver<CDPMessage> {
self.event_broadcast.subscribe()
}
#[tracing::instrument(level = "info", skip(self, params), fields(method = %method, id))]
pub async fn send_command(&self, method: String, params: Option<Value>) -> Result<Value> {
let id = self.next_id();
tracing::Span::current().record("id", id);
let request = CDPRequest::new(id, method.clone(), params);
let (tx, rx) = oneshot::channel();
self.register_response_handler(id, tx);
let json_str = tracing::info_span!("serialize").in_scope(|| request.to_json().to_string());
let bytes = json_str.len();
tracing::info_span!("ws_send", bytes).in_scope(|| self.send_raw(json_str))?;
const TIMEOUT_SECS: u64 = 30;
let wait = async {
match timeout(Duration::from_secs(TIMEOUT_SECS), rx).await {
Ok(Ok(value)) => Ok(value),
Ok(Err(_)) => Err(BrowserError::command_failed(
&method,
"response channel closed unexpectedly",
)),
Err(_) => {
self.pending_responses
.lock()
.expect("pending_responses mutex poisoned")
.remove(&id);
Err(BrowserError::timeout(
format!("waiting for response to '{method}'"),
TIMEOUT_SECS,
))
}
}
};
wait.instrument(tracing::info_span!("await_response")).await
}
#[tracing::instrument(level = "info", skip(self, params), fields(method = %method, id, session_id = %session_id))]
pub async fn send_command_with_session(
&self,
session_id: &str,
method: String,
params: Option<Value>,
) -> Result<Value> {
let id = self.next_id();
tracing::Span::current().record("id", id);
let request = CDPRequest::with_session(id, method.clone(), params, session_id.to_string());
let (tx, rx) = oneshot::channel();
self.register_response_handler(id, tx);
let json_str = tracing::info_span!("serialize").in_scope(|| request.to_json().to_string());
let bytes = json_str.len();
tracing::info_span!("ws_send", bytes).in_scope(|| self.send_raw(json_str))?;
const TIMEOUT_SECS: u64 = 30;
let wait = async {
match timeout(Duration::from_secs(TIMEOUT_SECS), rx).await {
Ok(Ok(value)) => Ok(value),
Ok(Err(_)) => Err(BrowserError::command_failed(
&method,
"response channel closed unexpectedly",
)),
Err(_) => {
self.pending_responses
.lock()
.expect("pending_responses mutex poisoned")
.remove(&id);
Err(BrowserError::timeout(
format!("waiting for response to '{method}'"),
TIMEOUT_SECS,
))
}
}
};
wait.instrument(tracing::info_span!("await_response")).await
}
pub fn register_response_handler(&self, id: u32, tx: oneshot::Sender<Value>) {
self.pending_responses
.lock()
.expect("pending_responses mutex poisoned")
.insert(id, tx);
}
pub fn fail_all_pending(&self, reason: &str) {
let mut pending = self
.pending_responses
.lock()
.expect("pending_responses mutex poisoned");
let count = pending.len();
pending.clear(); drop(pending);
if count > 0 {
tracing::warn!(
pending_count = count,
reason = reason,
"WebSocket terminated; failing in-flight CDP requests"
);
}
}
#[tracing::instrument(level = "debug", skip_all, fields(method = ?msg.method, id = ?msg.id))]
pub fn handle_message(&self, msg: CDPMessage) -> Result<()> {
if let Some(id) = msg.id {
let tx = self
.pending_responses
.lock()
.expect("pending_responses mutex poisoned")
.remove(&id);
if let Some(tx) = tx {
if let Some(error) = msg.error {
let _ = tx.send(json!({ "error": error }));
} else if let Some(result) = msg.result {
let _ = tx.send(result);
} else {
let _ = tx.send(json!({}));
}
}
} else if msg.method.is_some() {
let _ = self.event_broadcast.send(msg);
}
Ok(())
}
}
pub fn spawn_writer_task(
mut sink: WebSocketSink,
mut rx: mpsc::UnboundedReceiver<Message>,
cdp: Arc<CDPClient>,
) -> tokio::task::JoinHandle<()> {
tokio::spawn(async move {
while let Some(msg) = rx.recv().await {
if let Err(e) = sink.send(msg).await {
tracing::error!(error = %e, "WebSocket write error; terminating writer");
cdp.fail_all_pending(&format!("write error: {e}"));
return;
}
}
tracing::debug!("WebSocket writer task exiting (channel closed)");
})
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a client that is never connected; routing tests need no socket.
    fn offline_client() -> CDPClient {
        CDPClient::new("ws://127.0.0.1:9222/devtools/browser".to_string())
    }

    #[test]
    fn test_cdp_request_creation() {
        let req = CDPRequest::new(
            1,
            "Page.navigate".to_string(),
            Some(json!({"url": "https://example.com"})),
        );
        assert_eq!(req.id, 1);
        assert_eq!(req.method, "Page.navigate");
        assert_eq!(req.params.as_ref().unwrap()["url"], "https://example.com");
    }

    #[test]
    fn test_cdp_request_to_json() {
        let req = CDPRequest::new(
            1,
            "Page.navigate".to_string(),
            Some(json!({"url": "https://example.com"})),
        );
        let json = req.to_json();
        assert_eq!(json["id"], 1);
        assert_eq!(json["method"], "Page.navigate");
        assert_eq!(json["params"]["url"], "https://example.com");
    }

    #[test]
    fn test_cdp_message_from_json() {
        let json_val = json!({
            "id": 1,
            "result": {"url": "https://example.com"},
            "sessionId": "SES001"
        });
        let msg = CDPMessage::from_json(json_val).unwrap();
        assert_eq!(msg.id, Some(1));
        assert_eq!(msg.result.as_ref().unwrap()["url"], "https://example.com");
        assert_eq!(msg.session_id.as_deref(), Some("SES001"));
    }

    #[test]
    fn test_cdp_message_session_id_parsed() {
        let event = json!({
            "method": "Page.loadEventFired",
            "params": {},
            "sessionId": "ABC123"
        });
        let msg = CDPMessage::from_json(event).unwrap();
        assert_eq!(msg.method.as_deref(), Some("Page.loadEventFired"));
        assert_eq!(msg.session_id.as_deref(), Some("ABC123"));
    }

    #[test]
    fn test_cdp_request_with_session() {
        let req = CDPRequest::with_session(
            2,
            "Runtime.evaluate".to_string(),
            Some(json!({"expression": "1+1"})),
            "SES001".to_string(),
        );
        let json = req.to_json();
        assert_eq!(json["sessionId"], "SES001");
        assert_eq!(json["method"], "Runtime.evaluate");
    }

    #[test]
    fn test_handle_message_delivers_result_to_pending_request() {
        let cdp = offline_client();
        let (tx, mut rx) = oneshot::channel();
        cdp.register_response_handler(7, tx);
        let msg = CDPMessage::from_json(json!({"id": 7, "result": {"frameId": "F1"}})).unwrap();
        cdp.handle_message(msg).unwrap();
        assert_eq!(rx.try_recv().unwrap(), json!({"frameId": "F1"}));
    }

    #[test]
    fn test_handle_message_wraps_error_reply() {
        let cdp = offline_client();
        let (tx, mut rx) = oneshot::channel();
        cdp.register_response_handler(8, tx);
        let error = json!({"code": -32000, "message": "boom"});
        let msg = CDPMessage::from_json(json!({"id": 8, "error": error.clone()})).unwrap();
        cdp.handle_message(msg).unwrap();
        assert_eq!(rx.try_recv().unwrap(), json!({ "error": error }));
    }

    #[test]
    fn test_handle_message_empty_reply_is_empty_object() {
        let cdp = offline_client();
        let (tx, mut rx) = oneshot::channel();
        cdp.register_response_handler(9, tx);
        cdp.handle_message(CDPMessage::from_json(json!({"id": 9})).unwrap())
            .unwrap();
        assert_eq!(rx.try_recv().unwrap(), json!({}));
    }

    #[test]
    fn test_handle_message_broadcasts_events() {
        let cdp = offline_client();
        let mut events = cdp.subscribe_events();
        let event = json!({"method": "Page.loadEventFired", "params": {}});
        cdp.handle_message(CDPMessage::from_json(event).unwrap())
            .unwrap();
        let received = events.try_recv().unwrap();
        assert_eq!(received.method.as_deref(), Some("Page.loadEventFired"));
        assert_eq!(received.id, None);
    }

    #[test]
    fn test_fail_all_pending_closes_waiters() {
        let cdp = offline_client();
        let (tx, mut rx) = oneshot::channel();
        cdp.register_response_handler(10, tx);
        cdp.fail_all_pending("test shutdown");
        assert!(rx.try_recv().is_err());
    }
}