use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use futures_util::{SinkExt, StreamExt};
use serde_json::Value;
use tokio::net::TcpStream;
use tokio::sync::{oneshot, Mutex as AsyncMutex};
use tokio::task::JoinHandle;
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::tungstenite::protocol::WebSocketConfig;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::{connect_async_with_config, MaybeTlsStream, WebSocketStream};
use crate::CdpError;
/// Default per-command response timeout used by [`CdpClientConfig::default`] (30 seconds).
const DEFAULT_COMMAND_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Default inbound WebSocket message size cap used by [`CdpClientConfig::default`] (100 MiB).
const DEFAULT_MAX_MESSAGE_SIZE: usize = 100 << 20;

/// Options controlling how a CDP WebSocket connection is opened and used.
#[derive(Debug, Clone)]
pub struct CdpClientConfig {
    /// Forwarded to the WebSocket layer's `WebSocketConfig::max_message_size`.
    pub max_message_size: Option<usize>,
    /// Forwarded to the WebSocket layer's `WebSocketConfig::max_frame_size`.
    pub max_frame_size: Option<usize>,
    /// Extra HTTP headers inserted into the WebSocket handshake request.
    pub additional_headers: HashMap<String, String>,
    /// How long a command waits for its response before failing with a timeout.
    pub command_timeout: Duration,
}

impl Default for CdpClientConfig {
    /// 100 MiB message cap, no frame cap, no extra headers, 30 s command timeout.
    fn default() -> Self {
        let message_cap = Some(DEFAULT_MAX_MESSAGE_SIZE);
        Self {
            command_timeout: DEFAULT_COMMAND_TIMEOUT,
            additional_headers: HashMap::new(),
            max_frame_size: None,
            max_message_size: message_cap,
        }
    }
}
/// Full WebSocket connection to the DevTools endpoint (plain TCP or TLS).
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
/// Write half of [`WsStream`]; commands are sent through it.
type WsSink = futures_util::stream::SplitSink<WsStream, Message>;
/// Read half of [`WsStream`]; consumed by the background reader task.
type WsSource = futures_util::stream::SplitStream<WsStream>;
/// In-flight commands: command id mapped to the channel that receives its outcome.
type PendingRequests = HashMap<u64, oneshot::Sender<Result<Value, CdpError>>>;

/// Async callback invoked for a CDP event.
///
/// Called with the event's `params` and, when the message carried one, its `sessionId`.
pub type EventHandler = Arc<
    dyn Fn(Value, Option<String>) -> Pin<Box<dyn Future<Output = ()> + Send>>
        + Sync
        + Send,
>;
/// Registry of async handlers for CDP events, keyed by event method name.
///
/// Holds at most one handler per method; registering the same method again replaces the
/// previous handler. The internal lock is only held for map operations, never while a
/// handler runs and never while a removed handler is being dropped.
pub struct EventRegistry {
    handlers: std::sync::Mutex<HashMap<String, EventHandler>>,
}

impl EventRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            handlers: std::sync::Mutex::new(HashMap::new()),
        }
    }

    /// Locks the handler map, recovering the guard if the mutex was poisoned.
    ///
    /// Recovering (instead of `unwrap`) means a single panic that happened while the lock
    /// was held does not turn every later register/dispatch call into a panic as well.
    fn lock_handlers(&self) -> std::sync::MutexGuard<'_, HashMap<String, EventHandler>> {
        self.handlers
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Registers `handler` for events named `method`, replacing any existing handler.
    pub fn register(&self, method: &str, handler: EventHandler) {
        // The guard is a temporary and is released at the end of this statement.
        let replaced = self.lock_handlers().insert(method.to_string(), handler);
        // Drop the old handler outside the lock: its captured state may run arbitrary
        // `Drop` code, which must not be able to poison or re-enter this mutex.
        drop(replaced);
    }

    /// Removes the handler for `method`, if one is registered.
    pub fn unregister(&self, method: &str) {
        let removed = self.lock_handlers().remove(method);
        drop(removed);
    }

    /// Runs the handler registered for `method` (if any) on the current task.
    ///
    /// The handler is cloned out of the map first, so the lock is not held across its
    /// `.await`. Returns `true` if a handler was found and awaited, `false` otherwise.
    pub async fn handle_event(
        &self,
        method: &str,
        params: Value,
        session_id: Option<String>,
    ) -> bool {
        let handler = {
            let handlers = self.lock_handlers();
            handlers.get(method).cloned()
        };
        match handler {
            Some(handler) => {
                handler(params, session_id).await;
                true
            }
            None => false,
        }
    }

    /// Removes every registered handler.
    pub fn clear(&self) {
        // Move the whole map out under the lock, then drop the handlers after releasing it.
        let removed = std::mem::take(&mut *self.lock_handlers());
        drop(removed);
    }
}

impl Default for EventRegistry {
    /// Same as [`EventRegistry::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Handle to a Chrome DevTools Protocol (CDP) connection over a WebSocket.
///
/// Cloning is cheap: every clone shares the same socket, pending-request table and event
/// registry. The background reader task is aborted by [`CdpClient::close`] or when the last
/// clone is dropped.
#[derive(Clone)]
pub struct CdpClient {
    inner: Arc<ClientInner>,
}

/// State shared by all clones of a [`CdpClient`].
struct ClientInner {
    /// Write half of the socket; the async mutex serializes outgoing frames.
    sink: AsyncMutex<WsSink>,
    /// Last command id handed out; the first command uses id 1.
    next_id: AtomicU64,
    /// Commands awaiting a response, keyed by id. Shared with the reader task.
    pending: Arc<AsyncMutex<PendingRequests>>,
    /// Event handlers, shared with the reader task.
    event_registry: Arc<EventRegistry>,
    /// Set once the connection is closed, either by [`CdpClient::close`] or by the reader
    /// task when the socket ends. It must be the *same* flag the reader task sets, so that
    /// `send_raw` fails fast on a dead connection instead of waiting for `command_timeout`.
    closed: Arc<AtomicBool>,
    /// How long `send_raw` waits for a response.
    command_timeout: Duration,
    /// Reader task handle; taken (and aborted) by `close` or on drop.
    message_loop_handle: std::sync::Mutex<Option<JoinHandle<()>>>,
}

impl Drop for ClientInner {
    fn drop(&mut self) {
        // `get_mut` needs no locking because we hold `&mut self`. Recover from poisoning
        // rather than panicking inside `drop`.
        let handle = self
            .message_loop_handle
            .get_mut()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .take();
        if let Some(handle) = handle {
            handle.abort();
        }
    }
}

impl CdpClient {
    /// Connects to a DevTools WebSocket endpoint using [`CdpClientConfig::default`].
    ///
    /// # Errors
    /// Same as [`CdpClient::connect_with_config`].
    pub async fn connect(url: &str) -> Result<Self, CdpError> {
        Self::connect_with_config(url, CdpClientConfig::default()).await
    }

    /// Connects to a DevTools WebSocket endpoint and spawns the background reader task.
    ///
    /// # Errors
    /// * `url` cannot be converted into a WebSocket request, or connecting fails.
    /// * A key or value in `config.additional_headers` is not a valid HTTP header
    ///   name/value; this is reported as `CdpError::Protocol` with code `-1`.
    pub async fn connect_with_config(
        url: &str,
        config: CdpClientConfig,
    ) -> Result<Self, CdpError> {
        let mut request = url.into_client_request()?;
        for (key, value) in &config.additional_headers {
            request.headers_mut().insert(
                key.parse::<tokio_tungstenite::tungstenite::http::HeaderName>()
                    .map_err(|e| CdpError::Protocol {
                        code: -1,
                        message: format!("Invalid header name '{key}': {e}"),
                        data: None,
                    })?,
                value
                    .parse()
                    .map_err(|e| CdpError::Protocol {
                        code: -1,
                        message: format!("Invalid header value for '{key}': {e}"),
                        data: None,
                    })?,
            );
        }
        let mut ws_config = WebSocketConfig::default();
        ws_config.max_message_size = config.max_message_size;
        ws_config.max_frame_size = config.max_frame_size;
        let (ws_stream, _) =
            connect_async_with_config(request, Some(ws_config), false).await?;
        let (sink, stream) = ws_stream.split();
        let pending = Arc::new(AsyncMutex::new(HashMap::new()));
        let event_registry = Arc::new(EventRegistry::new());
        // One flag shared by the client and the reader task, so a remote close is
        // observed by `send_raw`.
        let closed = Arc::new(AtomicBool::new(false));
        let handle = tokio::spawn({
            let pending = pending.clone();
            let registry = event_registry.clone();
            let closed = closed.clone();
            async move {
                message_loop(stream, pending, registry, closed).await;
            }
        });
        Ok(Self {
            inner: Arc::new(ClientInner {
                sink: AsyncMutex::new(sink),
                next_id: AtomicU64::new(0),
                pending,
                event_registry,
                closed,
                command_timeout: config.command_timeout,
                message_loop_handle: std::sync::Mutex::new(Some(handle)),
            }),
        })
    }

    /// Sends a CDP command and waits for its response.
    ///
    /// When `session_id` is given it is sent as the message's `sessionId`.
    ///
    /// # Errors
    /// * `CdpError::ConnectionClosed` if the connection is, or becomes, closed.
    /// * `CdpError::Protocol` if the response carries an `error` object.
    /// * `CdpError::Timeout` if no response arrives within the configured timeout.
    /// * The converted transport error if writing the frame fails.
    pub async fn send_raw(
        &self,
        method: &str,
        params: Value,
        session_id: Option<&str>,
    ) -> Result<Value, CdpError> {
        if self.inner.closed.load(Ordering::Acquire) {
            return Err(CdpError::ConnectionClosed);
        }
        let id = self.inner.next_id.fetch_add(1, Ordering::Relaxed) + 1;
        let (tx, rx) = oneshot::channel();
        {
            let mut pending = self.inner.pending.lock().await;
            // Re-check under the lock. Whoever closes the connection sets `closed` before
            // draining `pending`, so an entry inserted after that drain would never be
            // answered and would only fail once `command_timeout` elapsed.
            if self.inner.closed.load(Ordering::Acquire) {
                return Err(CdpError::ConnectionClosed);
            }
            pending.insert(id, tx);
        }
        let mut msg = serde_json::json!({
            "id": id,
            "method": method,
            "params": params,
        });
        if let Some(sid) = session_id {
            msg["sessionId"] = Value::String(sid.to_string());
        }
        let send_result = self
            .inner
            .sink
            .lock()
            .await
            .send(Message::Text(msg.to_string().into()))
            .await;
        if let Err(e) = send_result {
            self.inner.pending.lock().await.remove(&id);
            return Err(e.into());
        }
        match tokio::time::timeout(self.inner.command_timeout, rx).await {
            Ok(Ok(result)) => result,
            // The sender was dropped without an answer.
            Ok(Err(_)) => Err(CdpError::ConnectionClosed),
            Err(_elapsed) => {
                self.inner.pending.lock().await.remove(&id);
                Err(CdpError::Timeout)
            }
        }
    }

    /// Invokes the registered handler for `method` directly on the current task.
    ///
    /// Returns `true` if a handler was registered (and has completed), `false` otherwise.
    pub async fn emit_event(
        &self,
        method: &str,
        params: Value,
        session_id: Option<&str>,
    ) -> bool {
        self.inner
            .event_registry
            .handle_event(method, params, session_id.map(String::from))
            .await
    }

    /// Registry used to dispatch incoming CDP events.
    pub(crate) fn event_registry(&self) -> &Arc<EventRegistry> {
        &self.inner.event_registry
    }

    /// Closes the connection.
    ///
    /// This marks the client closed and fails every in-flight command with
    /// `ConnectionClosed`. It then aborts the reader task and closes the WebSocket
    /// write half.
    ///
    /// # Errors
    /// Returns the (converted) error from closing the WebSocket sink.
    pub async fn close(&self) -> Result<(), CdpError> {
        self.inner.closed.store(true, Ordering::Release);
        {
            let mut pending = self.inner.pending.lock().await;
            for (_, tx) in pending.drain() {
                let _ = tx.send(Err(CdpError::ConnectionClosed));
            }
        }
        // Take the handle in its own statement so the std mutex guard is released before
        // the `.await` below; holding it across an await would make this future `!Send`.
        let handle = self
            .inner
            .message_loop_handle
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .take();
        if let Some(handle) = handle {
            handle.abort();
            let _ = handle.await;
        }
        self.inner.sink.lock().await.close().await?;
        Ok(())
    }
}
async fn message_loop(
mut stream: WsSource,
pending: Arc<AsyncMutex<PendingRequests>>,
event_registry: Arc<EventRegistry>,
closed: Arc<AtomicBool>,
) {
while let Some(msg_result) = stream.next().await {
match msg_result {
Ok(Message::Text(text)) => {
let data: Value = match serde_json::from_str(&text) {
Ok(v) => v,
Err(e) => {
tracing::warn!("Failed to parse CDP message: {e}");
continue;
}
};
if let Some(id) = data.get("id").and_then(|v| v.as_u64()) {
let mut pending = pending.lock().await;
if let Some(tx) = pending.remove(&id) {
let result = if let Some(error) = data.get("error") {
let code = error.get("code").and_then(|v| v.as_i64()).unwrap_or(0);
let message = error
.get("message")
.and_then(|v| v.as_str())
.unwrap_or("Unknown error")
.to_string();
let err_data = error.get("data").map(|v| v.to_string());
Err(CdpError::Protocol {
code,
message,
data: err_data,
})
} else {
Ok(data
.get("result")
.cloned()
.unwrap_or(Value::Object(Default::default())))
};
let _ = tx.send(result);
}
} else if let Some(method) = data.get("method").and_then(|v| v.as_str()) {
let params = data.get("params").cloned().unwrap_or_default();
let session_id = data
.get("sessionId")
.and_then(|v| v.as_str())
.map(String::from);
let registry = event_registry.clone();
let method = method.to_string();
tokio::spawn(async move {
registry.handle_event(&method, params, session_id).await;
});
}
}
Ok(Message::Close(_)) | Err(_) => {
closed.store(true, Ordering::Release);
let mut pending = pending.lock().await;
for (_, tx) in pending.drain() {
let _ = tx.send(Err(CdpError::ConnectionClosed));
}
break;
}
_ => {} }
}
}