use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use futures_util::{SinkExt, StreamExt};
use serde_json::Value;
use tokio::sync::{broadcast, oneshot, Mutex};
use tokio_tungstenite::connect_async;
use tokio_tungstenite::tungstenite::Message;
use super::types::{CdpCommand, CdpEvent, CdpMessage};
type PendingMap = Arc<Mutex<HashMap<u64, oneshot::Sender<CdpMessage>>>>;
#[derive(Debug, Clone)]
pub struct RawCdpMessage {
pub text: String,
pub session_id: Option<String>,
}
pub struct CdpClient {
ws_tx: Arc<
Mutex<
futures_util::stream::SplitSink<
tokio_tungstenite::WebSocketStream<
tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
>,
Message,
>,
>,
>,
next_id: AtomicU64,
pending: PendingMap,
event_tx: broadcast::Sender<CdpEvent>,
raw_tx: broadcast::Sender<RawCdpMessage>,
_reader_handle: tokio::task::JoinHandle<()>,
}
impl CdpClient {
pub async fn connect(url: &str) -> Result<Self, String> {
let (ws_stream, _) = connect_async(url)
.await
.map_err(|e| format!("CDP WebSocket connect failed: {}", e))?;
let (ws_tx, mut ws_rx) = ws_stream.split();
let ws_tx = Arc::new(Mutex::new(ws_tx));
let pending: PendingMap = Arc::new(Mutex::new(HashMap::new()));
let (event_tx, _) = broadcast::channel(256);
let (raw_tx, _) = broadcast::channel(512);
let pending_clone = pending.clone();
let event_tx_clone = event_tx.clone();
let raw_tx_clone = raw_tx.clone();
let reader_handle = tokio::spawn(async move {
while let Some(msg) = ws_rx.next().await {
let msg = match msg {
Ok(Message::Text(text)) => text,
Ok(Message::Close(_)) => break,
Ok(_) => continue,
Err(_) => break,
};
if raw_tx_clone.receiver_count() > 0 {
let session_id = serde_json::from_str::<serde_json::Value>(&msg)
.ok()
.and_then(|v| v.get("sessionId")?.as_str().map(String::from));
let _ = raw_tx_clone.send(RawCdpMessage {
text: msg.clone(),
session_id,
});
}
let parsed: CdpMessage = match serde_json::from_str(&msg) {
Ok(m) => m,
Err(_) => continue,
};
if let Some(id) = parsed.id {
let mut pending = pending_clone.lock().await;
if let Some(tx) = pending.remove(&id) {
let _ = tx.send(parsed);
}
} else if let Some(ref method) = parsed.method {
let event = CdpEvent {
method: method.clone(),
params: parsed.params.clone().unwrap_or(Value::Null),
session_id: parsed.session_id.clone(),
};
let _ = event_tx_clone.send(event);
}
}
});
Ok(Self {
ws_tx,
next_id: AtomicU64::new(1),
pending,
event_tx,
raw_tx,
_reader_handle: reader_handle,
})
}
pub async fn send_command(
&self,
method: &str,
params: Option<Value>,
session_id: Option<&str>,
) -> Result<Value, String> {
let id = self.next_id.fetch_add(1, Ordering::SeqCst);
let cmd = CdpCommand {
id,
method: method.to_string(),
params,
session_id: session_id.map(|s| s.to_string()),
};
let json = serde_json::to_string(&cmd)
.map_err(|e| format!("Failed to serialize CDP command: {}", e))?;
let (tx, rx) = oneshot::channel();
{
let mut pending = self.pending.lock().await;
pending.insert(id, tx);
}
{
let mut ws_tx = self.ws_tx.lock().await;
ws_tx
.send(Message::Text(json))
.await
.map_err(|e| format!("Failed to send CDP command: {}", e))?;
}
let response = match tokio::time::timeout(std::time::Duration::from_secs(30), rx).await {
Ok(Ok(resp)) => resp,
Ok(Err(_)) => return Err("CDP response channel closed".to_string()),
Err(_) => {
self.pending.lock().await.remove(&id);
return Err(format!("CDP command timed out: {}", method));
}
};
if let Some(error) = response.error {
return Err(format!("CDP error ({}): {}", method, error));
}
Ok(response.result.unwrap_or(Value::Null))
}
pub fn subscribe(&self) -> broadcast::Receiver<CdpEvent> {
self.event_tx.subscribe()
}
pub fn subscribe_raw(&self) -> broadcast::Receiver<RawCdpMessage> {
self.raw_tx.subscribe()
}
pub fn inspect_handle(&self) -> InspectProxyHandle {
InspectProxyHandle {
ws_tx: self.ws_tx.clone(),
raw_tx: self.raw_tx.clone(),
}
}
pub async fn send_command_typed<P: serde::Serialize, R: serde::de::DeserializeOwned>(
&self,
method: &str,
params: &P,
session_id: Option<&str>,
) -> Result<R, String> {
let params_value = serde_json::to_value(params)
.map_err(|e| format!("Failed to serialize params: {}", e))?;
let result = self
.send_command(method, Some(params_value), session_id)
.await?;
serde_json::from_value(result)
.map_err(|e| format!("Failed to deserialize CDP response for {}: {}", method, e))
}
pub async fn send_command_no_params(
&self,
method: &str,
session_id: Option<&str>,
) -> Result<Value, String> {
self.send_command(method, None, session_id).await
}
pub async fn send_raw(&self, json: String) -> Result<(), String> {
let mut ws_tx = self.ws_tx.lock().await;
ws_tx
.send(Message::Text(json))
.await
.map_err(|e| format!("Failed to send raw CDP message: {}", e))
}
}
type WsTx = Arc<
Mutex<
futures_util::stream::SplitSink<
tokio_tungstenite::WebSocketStream<
tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
>,
Message,
>,
>,
>;
pub struct InspectProxyHandle {
ws_tx: WsTx,
raw_tx: broadcast::Sender<RawCdpMessage>,
}
impl InspectProxyHandle {
pub async fn send_raw(&self, json: String) -> Result<(), String> {
let mut ws_tx = self.ws_tx.lock().await;
ws_tx
.send(Message::Text(json))
.await
.map_err(|e| format!("Failed to send raw CDP message: {}", e))
}
pub fn subscribe_raw(&self) -> broadcast::Receiver<RawCdpMessage> {
self.raw_tx.subscribe()
}
}