use crate::error::{BrowsingError, Result};
use futures_util::{SinkExt, StreamExt};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{Mutex, mpsc};
use tokio::time::sleep;
use tokio_tungstenite::{connect_async, tungstenite::Message};
/// Maximum number of retries made after the first failed attempt of a CDP command.
pub const MAX_RETRY_ATTEMPTS: u32 = 3;
/// Backoff, in milliseconds, before the first retry; it doubles for each later retry.
pub const INITIAL_RETRY_DELAY_MS: u64 = 100;
/// Ceiling, in milliseconds, for any single backoff delay.
pub const MAX_RETRY_DELAY_MS: u64 = 5_000;
/// Chrome DevTools Protocol client that talks to a single WebSocket endpoint.
///
/// All mutable state sits behind `Arc<Mutex<_>>` (tokio's async mutex); `start`
/// hands a clone of `pending_requests` to the background I/O task it spawns.
pub struct CdpClient {
/// WebSocket URL given to `new` and dialed by `start`.
url: String,
/// Outgoing-message queue drained by the background task; `None` until `start` runs.
sender: Arc<Mutex<Option<mpsc::UnboundedSender<Message>>>>,
/// Set by `start`, whose matching sending half is dropped when `start` returns,
/// so this receiver never yields a message.
/// NOTE(review): looks like an unfinished event channel — confirm intent.
receiver: Arc<Mutex<Option<mpsc::UnboundedReceiver<Value>>>>,
/// Next request id to assign; incremented for every command attempt.
request_id: Arc<Mutex<u64>>,
/// Reply channels for in-flight commands, keyed by request id. The background
/// task removes an entry when a response carrying that id arrives.
pending_requests: Arc<Mutex<HashMap<u64, mpsc::UnboundedSender<Value>>>>,
}
impl CdpClient {
    /// Creates an unconnected client for the DevTools WebSocket endpoint at `url`.
    ///
    /// No network activity happens here; call [`CdpClient::start`] to connect.
    pub fn new(url: String) -> Self {
        Self {
            url,
            sender: Arc::new(Mutex::new(None)),
            receiver: Arc::new(Mutex::new(None)),
            request_id: Arc::new(Mutex::new(0)),
            pending_requests: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Connects to the WebSocket endpoint and spawns the background I/O task.
    ///
    /// The spawned task:
    /// - writes every message queued on the outgoing channel to the socket
    ///   (a queued close frame is written and then the task stops);
    /// - routes incoming text messages carrying a numeric `"id"` to the matching
    ///   entry in `pending_requests`; other incoming messages are discarded;
    /// - on exit (close frame, socket error, end of stream, or all outgoing
    ///   senders dropped) closes its outgoing queue and drops every pending reply
    ///   channel, so waiting commands fail with "No response received" instead of
    ///   hanging forever.
    ///
    /// # Errors
    ///
    /// Returns [`BrowsingError::Cdp`] if the WebSocket connection cannot be established.
    pub async fn start(&mut self) -> Result<()> {
        let (ws_stream, _) = connect_async(&self.url)
            .await
            .map_err(|e| BrowsingError::Cdp(format!("Failed to connect to CDP: {e}")))?;
        let (mut write, mut read) = ws_stream.split();
        let (tx, mut rx) = mpsc::unbounded_channel();
        // NOTE(review): `_tx_resp` is dropped when this function returns, so the
        // stored receiver never yields anything; CDP events are not delivered.
        let (_tx_resp, rx_resp) = mpsc::unbounded_channel();
        *self.sender.lock().await = Some(tx);
        *self.receiver.lock().await = Some(rx_resp);
        let pending_requests = Arc::clone(&self.pending_requests);
        tokio::spawn(async move {
            loop {
                tokio::select! {
                    msg = rx.recv() => {
                        match msg {
                            Some(close @ Message::Close(_)) => {
                                // Forward the close frame so the peer sees a close
                                // handshake instead of an abrupt disconnect.
                                if let Err(e) = write.send(close).await {
                                    tracing::debug!(
                                        "WebSocket send error during shutdown: {}",
                                        e
                                    );
                                }
                                break;
                            }
                            Some(m) => {
                                if let Err(e) = write.send(m).await {
                                    tracing::debug!(
                                        "WebSocket send error during shutdown: {}",
                                        e
                                    );
                                    break;
                                }
                            }
                            // Every sender is gone (client dropped or `start` re-run).
                            None => break,
                        }
                    }
                    msg = read.next() => {
                        match msg {
                            Some(Ok(Message::Text(text))) => {
                                if let Ok(value) = serde_json::from_str::<Value>(&text) {
                                    if let Some(id) = value.get("id").and_then(Value::as_u64) {
                                        // Release the map lock before handing the reply over.
                                        let waiter = pending_requests.lock().await.remove(&id);
                                        if let Some(reply_tx) = waiter {
                                            // The waiter may already have given up; that's fine.
                                            let _ = reply_tx.send(value);
                                        }
                                    }
                                }
                            }
                            Some(Ok(Message::Close(_))) => break,
                            Some(Err(e)) => {
                                tracing::debug!("WebSocket closed: {}", e);
                                break;
                            }
                            Some(Ok(_)) => {}
                            None => break,
                        }
                    }
                }
            }
            // The connection is finished. Close the outgoing queue first so that
            // later sends fail fast, then drop every reply channel so callers
            // blocked on a reply observe `None` instead of waiting forever.
            // NOTE(review): the pending map is shared across `start` calls, so a task
            // winding down after a reconnect also clears replies awaited on the new
            // connection; those commands surface "No response received" and are retried.
            drop(rx);
            pending_requests.lock().await.clear();
        });
        Ok(())
    }

    /// Sends a browser-level CDP command (no session) and returns its `result`.
    ///
    /// Retries transient failures; see [`CdpClient::is_retryable_error`].
    pub async fn send_command(&self, method: &str, params: Value) -> Result<Value> {
        self.send_command_with_session(method, params, None).await
    }

    /// Sends a CDP command, optionally scoped to `session_id`, and returns its `result`.
    ///
    /// Retryable failures are retried up to [`MAX_RETRY_ATTEMPTS`] times with
    /// exponential backoff.
    ///
    /// # Errors
    ///
    /// - A non-retryable error is returned unchanged.
    /// - If every attempt fails with a retryable error,
    ///   `BrowsingError::RetryLimitExceeded(total_attempts, last_error_message)`.
    pub async fn send_command_with_session(
        &self,
        method: &str,
        params: Value,
        session_id: Option<&str>,
    ) -> Result<Value> {
        self.send_command_with_session_retry(method, params, session_id, 0)
            .await
    }

    /// Retry driver behind [`CdpClient::send_command_with_session`].
    ///
    /// `attempt` is the zero-based index of the first attempt to make.
    async fn send_command_with_session_retry(
        &self,
        method: &str,
        params: Value,
        session_id: Option<&str>,
        mut attempt: u32,
    ) -> Result<Value> {
        loop {
            let error = match self
                ._send_command_internal(method, params.clone(), session_id)
                .await
            {
                Ok(result) => return Ok(result),
                Err(e) => e,
            };
            // Only transient errors count against the retry budget; anything else
            // is reported as-is so callers can still match on it.
            if !self.is_retryable_error(&error) {
                return Err(error);
            }
            if attempt >= MAX_RETRY_ATTEMPTS {
                return Err(BrowsingError::RetryLimitExceeded(
                    attempt + 1,
                    error.to_string(),
                ));
            }
            let delay_ms = Self::calculate_backoff_delay(attempt);
            tracing::warn!(
                "CDP command '{}' failed (attempt {}/{}), retrying in {}ms: {}",
                method,
                attempt + 1,
                // First attempt plus MAX_RETRY_ATTEMPTS retries.
                MAX_RETRY_ATTEMPTS + 1,
                delay_ms,
                error
            );
            sleep(Duration::from_millis(delay_ms)).await;
            attempt += 1;
        }
    }

    /// Performs one request/response round trip without retrying.
    ///
    /// Allocates a fresh request id, registers a reply channel for it, queues the
    /// JSON request on the outgoing channel and waits for the background task to
    /// deliver the matching response.
    ///
    /// # Errors
    ///
    /// All errors are [`BrowsingError::Cdp`]: client not started, outgoing channel
    /// closed ("Failed to send command"), reply channel dropped ("No response
    /// received"), or the response contained an `"error"` member ("CDP error").
    async fn _send_command_internal(
        &self,
        method: &str,
        params: Value,
        session_id: Option<&str>,
    ) -> Result<Value> {
        // Check connectivity before registering anything so nothing can leak.
        let sender = self.sender.lock().await.as_ref().cloned().ok_or_else(|| {
            BrowsingError::Cdp("CDP client is not connected. Call start() first.".to_string())
        })?;
        let id = {
            let mut next_id = self.request_id.lock().await;
            let id = *next_id;
            *next_id += 1;
            id
        };
        let mut request = serde_json::json!({
            "id": id,
            "method": method,
            "params": params
        });
        if let Some(sid) = session_id {
            request["sessionId"] = serde_json::json!(sid);
        }
        // Register before sending so a fast reply cannot arrive unclaimed.
        let (tx, mut rx) = mpsc::unbounded_channel();
        self.pending_requests.lock().await.insert(id, tx);
        if let Err(e) = sender.send(Message::Text(request.to_string())) {
            // Nobody will ever answer this id; don't leave its entry behind.
            self.pending_requests.lock().await.remove(&id);
            return Err(BrowsingError::Cdp(format!("Failed to send command: {e}")));
        }
        match rx.recv().await {
            Some(mut response) => {
                if let Some(error) = response.get("error") {
                    return Err(BrowsingError::Cdp(format!("CDP error: {error}")));
                }
                // A missing `result` yields `Value::Null`.
                Ok(response
                    .get_mut("result")
                    .map(Value::take)
                    .unwrap_or_default())
            }
            None => Err(BrowsingError::Cdp("No response received".to_string())),
        }
    }

    /// Returns `true` when `error` is a [`BrowsingError::Cdp`] whose message contains
    /// one of a fixed set of substrings (send failure, missing response,
    /// connection/WebSocket trouble, missing target or session).
    ///
    /// Matching is case-sensitive substring search; every other error kind is not retryable.
    pub fn is_retryable_error(&self, error: &BrowsingError) -> bool {
        const RETRYABLE_MARKERS: &[&str] = &[
            "Failed to send command",
            "No response received",
            "connection",
            "WebSocket",
            "Target not found",
            "Session not found",
        ];
        match error {
            BrowsingError::Cdp(msg) => RETRYABLE_MARKERS
                .iter()
                .any(|&marker| msg.contains(marker)),
            _ => false,
        }
    }

    /// Exponential backoff: `INITIAL_RETRY_DELAY_MS * 2^attempt` milliseconds,
    /// capped at [`MAX_RETRY_DELAY_MS`].
    ///
    /// Saturates instead of overflowing, so any `attempt` value is safe.
    pub fn calculate_backoff_delay(attempt: u32) -> u64 {
        INITIAL_RETRY_DELAY_MS
            .saturating_mul(2_u64.saturating_pow(attempt))
            .min(MAX_RETRY_DELAY_MS)
    }

    /// Asks the background task to send a WebSocket close frame and stop.
    ///
    /// Does nothing if the client was never started. A failure to queue the frame
    /// (for example because the task already exited) is deliberately ignored.
    pub async fn close(&self) {
        if let Some(sender) = self.sender.lock().await.as_ref() {
            let _ = sender.send(Message::Close(None));
        }
    }
}
/// A CDP session attached to a single target.
///
/// `CdpSession::for_target` fills in every field. `title` and `url` are read
/// once there and are not refreshed by any code in this file.
pub struct CdpSession {
/// Client whose connection carries this session's commands.
pub client: Arc<CdpClient>,
/// Id of the attached target.
pub target_id: String,
/// Session id returned by `Target.attachToTarget`.
pub session_id: String,
/// Target title read via `Target.getTargetInfo` ("Unknown title" if absent).
pub title: String,
/// Target URL read via `Target.getTargetInfo` ("about:blank" if absent).
pub url: String,
}
impl CdpSession {
pub async fn for_target(
client: Arc<CdpClient>,
target_id: String,
domains: Option<Vec<String>>,
) -> Result<Self> {
let params = serde_json::json!({
"targetId": target_id,
"flatten": true
});
let result = client.send_command("Target.attachToTarget", params).await?;
let session_id = result["sessionId"]
.as_str()
.ok_or_else(|| BrowsingError::Cdp("No sessionId in response".to_string()))?
.to_string();
let domains = domains.unwrap_or_else(|| {
vec![
"Page".to_string(),
"DOM".to_string(),
"DOMSnapshot".to_string(),
"Accessibility".to_string(),
"Runtime".to_string(),
"Inspector".to_string(),
]
});
for domain in &domains {
let method = format!("{domain}.enable");
let _ = client
.send_command_with_session(&method, serde_json::json!({}), Some(&session_id))
.await;
}
let target_info_params = serde_json::json!({"targetId": target_id});
let target_info = client
.send_command("Target.getTargetInfo", target_info_params)
.await?;
let title = target_info["targetInfo"]["title"]
.as_str()
.unwrap_or("Unknown title")
.to_string();
let url = target_info["targetInfo"]["url"]
.as_str()
.unwrap_or("about:blank")
.to_string();
Ok(Self {
client,
target_id,
session_id,
title,
url,
})
}
}