mod discovery;
pub use discovery::{BrowserVersion, CdpConnectionOptions, discover_websocket_url};
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
use futures_util::{SinkExt, StreamExt};
use serde::Serialize;
use serde::de::DeserializeOwned;
use serde_json::Value;
use tokio::sync::{Mutex, broadcast, mpsc, oneshot};
use tokio::time::timeout;
use tokio_tungstenite::MaybeTlsStream;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tracing::{debug, error, info, instrument, trace, warn};
use crate::error::CdpError;
use crate::transport::{CdpEvent, CdpMessage, CdpRequest, CdpResponse};
/// Response deadline used by `CdpConnection::send_command`, which does not let
/// the caller choose one (thirty seconds).
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Capacity of the broadcast channel that fans CDP events out to subscribers.
const EVENT_CHANNEL_SIZE: usize = 0x100;
/// Table of requests still waiting for a reply: message id -> one-shot channel
/// through which the read loop hands back the matching response.
type PendingResponses = Arc<Mutex<HashMap<u64, oneshot::Sender<CdpResponse>>>>;

/// A CDP session carried over a single WebSocket connection.
///
/// Outgoing requests are queued to a background write task. A background read
/// task routes responses back to waiting callers by message id and broadcasts
/// events to subscribers.
#[derive(Debug)]
pub struct CdpConnection {
    /// Queue that feeds the background write loop.
    tx: mpsc::Sender<CdpRequest>,
    /// Broadcast *sender* for CDP events, despite the field name. Receivers are
    /// created from it on demand.
    event_rx: broadcast::Sender<CdpEvent>,
    /// Responders for in-flight requests, shared with the read loop.
    pending: PendingResponses,
    /// Counter that hands out request ids; each command takes the next value.
    message_id: AtomicU64,
    /// Handle of the spawned task that reads from the socket.
    _read_handle: tokio::task::JoinHandle<()>,
    /// Handle of the spawned task that writes to the socket.
    _write_handle: tokio::task::JoinHandle<()>,
}
impl CdpConnection {
#[instrument(level = "info", skip(ws_url), fields(ws_url = %ws_url))]
pub async fn connect(ws_url: &str) -> Result<Self, CdpError> {
Self::connect_with_options(ws_url, &CdpConnectionOptions::default()).await
}
#[instrument(level = "info", skip(ws_url, options), fields(ws_url = %ws_url))]
pub async fn connect_with_options(
ws_url: &str,
options: &CdpConnectionOptions,
) -> Result<Self, CdpError> {
info!("Connecting to CDP WebSocket endpoint");
let mut request =
ws_url
.into_client_request()
.map_err(|e: tokio_tungstenite::tungstenite::Error| {
CdpError::InvalidUrl(format!("{ws_url}: {e}"))
})?;
for (name, value) in &options.headers {
let header_name = name
.parse::<tokio_tungstenite::tungstenite::http::HeaderName>()
.map_err(|e| CdpError::ConnectionFailed(format!("invalid header name: {e}")))?;
let header_value = value
.parse::<tokio_tungstenite::tungstenite::http::HeaderValue>()
.map_err(|e| CdpError::ConnectionFailed(format!("invalid header value: {e}")))?;
request.headers_mut().insert(header_name, header_value);
}
type WsStream = tokio_tungstenite::WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>;
let connect_future = tokio_tungstenite::connect_async(request);
let (ws_stream, response): (WsStream, _) = if let Some(timeout_duration) = options.timeout {
timeout(timeout_duration, connect_future)
.await
.map_err(|_| CdpError::ConnectionTimeout(timeout_duration))?
.map_err(CdpError::from)?
} else {
connect_future.await?
};
info!(status = %response.status(), "WebSocket connection established");
let (write, read) = ws_stream.split();
let (tx, rx) = mpsc::channel::<CdpRequest>(64);
let (event_tx, _) = broadcast::channel::<CdpEvent>(EVENT_CHANNEL_SIZE);
let pending: Arc<Mutex<HashMap<u64, oneshot::Sender<CdpResponse>>>> =
Arc::new(Mutex::new(HashMap::new()));
let write_handle = tokio::spawn(Self::write_loop(rx, write));
debug!("Spawned CDP write loop");
let read_pending = pending.clone();
let read_event_tx = event_tx.clone();
let read_handle = tokio::spawn(Self::read_loop(read, read_pending, read_event_tx));
debug!("Spawned CDP read loop");
info!("CDP connection ready");
Ok(Self {
tx,
event_rx: event_tx,
pending,
message_id: AtomicU64::new(1),
_read_handle: read_handle,
_write_handle: write_handle,
})
}
pub async fn connect_via_http(endpoint_url: &str) -> Result<Self, CdpError> {
Self::connect_via_http_with_options(endpoint_url, CdpConnectionOptions::default()).await
}
#[instrument(level = "info", skip(options), fields(endpoint_url = %endpoint_url))]
pub async fn connect_via_http_with_options(
endpoint_url: &str,
options: CdpConnectionOptions,
) -> Result<Self, CdpError> {
let ws_url = discover_websocket_url(endpoint_url, &options).await?;
Self::connect_with_options(&ws_url, &options).await
}
async fn write_loop<S>(mut rx: mpsc::Receiver<CdpRequest>, mut sink: S)
where
S: futures_util::Sink<Message, Error = tokio_tungstenite::tungstenite::Error> + Unpin,
{
debug!("CDP write loop started");
while let Some(request) = rx.recv().await {
let method = request.method.clone();
let id = request.id;
let json = match serde_json::to_string(&request) {
Ok(j) => j,
Err(e) => {
error!(error = %e, method = %method, "Failed to serialize CDP request");
continue;
}
};
trace!(id = id, method = %method, json_len = json.len(), "Sending CDP request");
if sink.send(Message::Text(json.into())).await.is_err() {
warn!("WebSocket sink closed, ending write loop");
break;
}
debug!(id = id, method = %method, "CDP request sent");
}
debug!("CDP write loop ended");
}
async fn read_loop<S>(
mut stream: S,
pending: Arc<Mutex<HashMap<u64, oneshot::Sender<CdpResponse>>>>,
event_tx: broadcast::Sender<CdpEvent>,
) where
S: futures_util::Stream<Item = Result<Message, tokio_tungstenite::tungstenite::Error>>
+ Unpin,
{
debug!("CDP read loop started");
while let Some(msg) = stream.next().await {
let msg = match msg {
Ok(Message::Text(text)) => text,
Ok(Message::Close(frame)) => {
info!(?frame, "WebSocket closed by remote");
break;
}
Err(e) => {
warn!(error = %e, "WebSocket error, ending read loop");
break;
}
Ok(_) => continue,
};
trace!(json_len = msg.len(), "Received CDP message");
let cdp_msg: CdpMessage = match serde_json::from_str(&msg) {
Ok(m) => m,
Err(e) => {
error!(error = %e, "Failed to parse CDP message");
continue;
}
};
match cdp_msg {
CdpMessage::Response(resp) => {
let id = resp.id;
let has_error = resp.error.is_some();
debug!(id = id, has_error = has_error, "Received CDP response");
let mut pending = pending.lock().await;
if let Some(sender) = pending.remove(&id) {
let _ = sender.send(resp);
} else {
warn!(id = id, "Received response for unknown request ID");
}
}
CdpMessage::Event(ref event) => {
trace!(method = %event.method, session_id = ?event.session_id, "Received CDP event");
let _ = event_tx.send(event.clone());
}
}
}
debug!("CDP read loop ended");
}
pub async fn send_command<P, R>(
&self,
method: &str,
params: Option<P>,
session_id: Option<&str>,
) -> Result<R, CdpError>
where
P: Serialize,
R: DeserializeOwned,
{
self.send_command_with_timeout(method, params, session_id, DEFAULT_TIMEOUT)
.await
}
#[instrument(level = "debug", skip(self, params), fields(method = %method, session_id = ?session_id))]
pub async fn send_command_with_timeout<P, R>(
&self,
method: &str,
params: Option<P>,
session_id: Option<&str>,
timeout_duration: Duration,
) -> Result<R, CdpError>
where
P: Serialize,
R: DeserializeOwned,
{
let id = self.message_id.fetch_add(1, Ordering::Relaxed);
debug!(
id = id,
timeout_ms = timeout_duration.as_millis(),
"Preparing CDP command"
);
let params_value = params.map(|p| serde_json::to_value(p)).transpose()?;
let request = CdpRequest {
id,
method: method.to_string(),
params: params_value,
session_id: session_id.map(ToString::to_string),
};
let (resp_tx, resp_rx) = oneshot::channel();
{
let mut pending = self.pending.lock().await;
pending.insert(id, resp_tx);
trace!(
id = id,
pending_count = pending.len(),
"Registered pending response"
);
}
self.tx
.send(request)
.await
.map_err(|_| CdpError::ConnectionLost)?;
trace!(id = id, "Request queued for sending");
let response = timeout(timeout_duration, resp_rx)
.await
.map_err(|_| {
warn!(id = id, method = %method, "CDP command timed out");
CdpError::Timeout(timeout_duration)
})?
.map_err(|_| CdpError::ConnectionLost)?;
if let Some(ref error) = response.error {
warn!(id = id, method = %method, code = error.code, error_msg = %error.message, "CDP protocol error");
return Err(CdpError::Protocol {
code: error.code,
message: error.message.clone(),
});
}
debug!(id = id, "CDP command completed successfully");
let result = response.result.unwrap_or(Value::Null);
serde_json::from_value(result).map_err(CdpError::from)
}
pub fn subscribe_events(&self) -> broadcast::Receiver<CdpEvent> {
debug!("New CDP event subscription created");
self.event_rx.subscribe()
}
}
#[cfg(test)]
mod tests;