use std::collections::HashMap;
use std::sync::{Arc, Mutex as StdMutex};
use std::time::Duration;
use crate::input::{CdpKeyboard, CdpMouse, CdpTouchscreen, HumanMouse};
use rustenium_cdp_definitions::Command;
use rustenium_cdp_definitions::base::CommandResponse;
use rustenium_core::error::CdpCommandResultError;
use rustenium_cdp_definitions::browser_protocol::dom::commands::{
DescribeNode, GetDocument, QuerySelector, QuerySelectorAll,
};
use rustenium_cdp_definitions::browser_protocol::dom::results::{
DescribeNodeResult, GetDocumentResult, QuerySelectorResult, QuerySelectorAllResult,
};
use rustenium_cdp_definitions::browser_protocol::dom::types::Node as DomNode;
use rustenium_cdp_definitions::browser_protocol::emulation::commands::SetDeviceMetricsOverride;
use rustenium_cdp_definitions::browser_protocol::page::commands::Navigate;
use rustenium_cdp_definitions::browser_protocol::page::results::NavigateResult;
use rustenium_cdp_definitions::browser_protocol::page::commands::CaptureScreenshot;
use rustenium_cdp_definitions::browser_protocol::page::results::CaptureScreenshotResult;
use rustenium_cdp_definitions::browser_protocol::target::command_builders::SetDiscoverTargetsBuilder;
use rustenium_cdp_definitions::browser_protocol::target::commands::CreateTarget;
use rustenium_cdp_definitions::browser_protocol::target::events::{TargetCreated, TargetDestroyed};
use rustenium_cdp_definitions::browser_protocol::target::results::CreateTargetResult;
use rustenium_cdp_definitions::browser_protocol::target::types::{TargetId, TargetInfo};
use rustenium_core::WebsocketConnectionTransport;
use rustenium_core::error::CdpSessionSendError;
use rustenium_core::session::CdpSession;
use rustenium_core::transport::{ConnectionTransport, ConnectionTransportConfig};
use rustenium_core::CdpEventManagement;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::net::TcpStream;
use tokio::sync::Mutex as TokioMutex;
use tokio::time::sleep;
/// CDP client handle bundling a shared session, input devices and a registry
/// of known page targets.
///
/// All fields are reference-counted, so clones share the same session,
/// devices and target map.
pub struct CdpAdapter<T: ConnectionTransport + Send + Sync> {
    /// Shared CDP session; `send_command` holds this lock for the duration of each send.
    pub session: Arc<TokioMutex<CdpSession<T>>>,
    /// Targets whose type is `"page"`, keyed by target id. Maintained by the
    /// event handlers installed in `listen_to_target_creation`.
    pub page_targets: Arc<StdMutex<HashMap<TargetId, TargetInfo>>>,
    /// Low-level CDP mouse.
    pub mouse: Arc<CdpMouse>,
    /// `HumanMouse` wrapper around a `CdpMouse`.
    pub human_mouse: Arc<HumanMouse<CdpMouse>>,
    /// CDP keyboard bound to the same transport type as the session.
    pub keyboard: Arc<CdpKeyboard<T>>,
    /// CDP touchscreen.
    pub touchscreen: Arc<CdpTouchscreen>,
}

// Implemented by hand: `#[derive(Clone)]` would require `T: Clone`, but every
// field is an `Arc`, so cloning never needs to clone a `T` value.
impl<T: ConnectionTransport + Send + Sync> Clone for CdpAdapter<T> {
    fn clone(&self) -> Self {
        Self {
            session: Arc::clone(&self.session),
            page_targets: Arc::clone(&self.page_targets),
            mouse: Arc::clone(&self.mouse),
            human_mouse: Arc::clone(&self.human_mouse),
            keyboard: Arc::clone(&self.keyboard),
            touchscreen: Arc::clone(&self.touchscreen),
        }
    }
}
impl CdpAdapter<WebsocketConnectionTransport> {
    /// Builds an adapter around an already-open websocket CDP session.
    ///
    /// The mouse and keyboard are constructed with the same shared modifier
    /// state (an `i64` behind a mutex, starting at 0). `human_mouse` wraps a
    /// clone of the plain mouse. The page-target map starts empty.
    pub fn new(session: Arc<TokioMutex<CdpSession<WebsocketConnectionTransport>>>) -> Self {
        // One modifier cell handed to both mouse and keyboard; presumably so
        // keyboard modifiers are visible to mouse events — confirm in input module.
        let shared_modifiers: Arc<StdMutex<i64>> = Arc::new(StdMutex::new(0));
        let base_mouse = CdpMouse::new(Arc::clone(&session), Arc::clone(&shared_modifiers));
        let human_mouse = Arc::new(HumanMouse::new(base_mouse.clone()));
        let keyboard = Arc::new(CdpKeyboard::new(Arc::clone(&session), shared_modifiers));
        let touchscreen = Arc::new(CdpTouchscreen::new(Arc::clone(&session)));

        Self {
            mouse: Arc::new(base_mouse),
            human_mouse,
            keyboard,
            touchscreen,
            page_targets: Arc::new(StdMutex::new(HashMap::new())),
            session,
        }
    }
}
impl<T: ConnectionTransport + Send + Sync> CdpAdapter<T> {
pub async fn listen_to_target_creation(&mut self) -> Result<(), CdpSessionSendError> {
let page_targets = self.page_targets.clone();
self.session.lock().await.add_event_handler(
[TargetCreated::IDENTIFIER],
move |event| {
let page_targets = page_targets.clone();
async move {
if let Ok(target) = event.try_into_event::<TargetCreated>() {
let info = target.params.target_info;
tracing::debug!("[CdpAdapter] Target created: id={}, type={}, url={}", info.target_id.as_ref(), info.r#type, info.url);
if info.r#type == "page" {
page_targets.lock().unwrap().insert(info.target_id.clone(), info);
}
}
}
},
);
let page_targets = self.page_targets.clone();
self.session.lock().await.add_event_handler(
[TargetDestroyed::IDENTIFIER],
move |event| {
let page_targets = page_targets.clone();
async move {
if let Ok(destroyed) = event.try_into_event::<TargetDestroyed>() {
let id = &destroyed.params.target_id;
tracing::debug!("[CdpAdapter] Target destroyed: id={}", id.as_ref());
page_targets.lock().unwrap().remove(id);
}
}
},
);
let command = SetDiscoverTargetsBuilder::default().discover(true).build().unwrap();
self.send_command(command).await?;
Ok(())
}
pub async fn send_command(
&mut self,
command: impl Into<Command>,
) -> Result<CommandResponse, CdpSessionSendError> {
self.session.lock().await.send(command).await
}
pub async fn navigate(
&mut self,
command: Navigate,
) -> Result<NavigateResult, crate::error::cdp::NavigateError> {
let result_value = self
.send_command(command)
.await
.map_err(|err| {
crate::error::cdp::NavigateError::CommandResultError(
CdpCommandResultError::SessionSendError(err),
)
})?
.result;
NavigateResult::try_from(result_value.clone()).map_err(|_| {
crate::error::cdp::NavigateError::CommandResultError(
CdpCommandResultError::InvalidResultTypeError(result_value),
)
})
}
pub async fn create_target(
&mut self,
command: CreateTarget,
) -> Result<TargetId, crate::error::cdp::CreateTargetError> {
let result_value = self
.send_command(command)
.await
.map_err(|err| {
crate::error::cdp::CreateTargetError::CommandResultError(
CdpCommandResultError::SessionSendError(err),
)
})?
.result;
let result = CreateTargetResult::try_from(result_value.clone()).map_err(|_| {
crate::error::cdp::CreateTargetError::CommandResultError(
CdpCommandResultError::InvalidResultTypeError(result_value),
)
})?;
Ok(result.target_id)
}
pub async fn fetch_node(
&mut self,
command: DescribeNode,
) -> Result<DomNode, crate::error::cdp::NodesFetchError> {
let result_value = self
.send_command(command)
.await
.map_err(|err| {
crate::error::cdp::NodesFetchError::CommandResultError(
CdpCommandResultError::SessionSendError(err),
)
})?
.result;
let result = DescribeNodeResult::try_from(result_value)
.map_err(|e| crate::error::cdp::NodesFetchError::ParseError(e.to_string()))?;
Ok(*result.node)
}
pub async fn emulate_device_metrics(
&mut self,
command: SetDeviceMetricsOverride,
) -> Result<(), crate::error::cdp::EmulateDeviceMetricsError> {
self.send_command(command)
.await
.map_err(|err| {
crate::error::cdp::EmulateDeviceMetricsError::CommandResultError(
CdpCommandResultError::SessionSendError(err),
)
})?;
Ok(())
}
async fn get_root_node_id(&mut self) -> Result<rustenium_cdp_definitions::browser_protocol::dom::types::NodeId, crate::error::cdp::LocateError> {
let result_value = self
.send_command(GetDocument::builder().depth(0).build())
.await
.map_err(|e| crate::error::cdp::LocateError::CommandResultError(CdpCommandResultError::SessionSendError(e)))?
.result;
let doc = GetDocumentResult::try_from(result_value)
.map_err(|e| crate::error::cdp::LocateError::ParseError(e.to_string()))?;
Ok(*doc.root.node_id)
}
async fn describe_by_id(
&mut self,
node_id: rustenium_cdp_definitions::browser_protocol::dom::types::NodeId,
) -> Result<DomNode, crate::error::cdp::LocateError> {
let result_value = self
.send_command(
DescribeNode::builder()
.node_id(node_id)
.depth(-1)
.build(),
)
.await
.map_err(|e| crate::error::cdp::LocateError::CommandResultError(CdpCommandResultError::SessionSendError(e)))?
.result;
DescribeNodeResult::try_from(result_value)
.map(|r| *r.node)
.map_err(|e| crate::error::cdp::LocateError::ParseError(e.to_string()))
}
pub async fn locate(&mut self, selector: &str) -> Result<Option<DomNode>, crate::error::cdp::LocateError> {
let root_id = self.get_root_node_id().await?;
let cmd = QuerySelector::builder()
.node_id(root_id)
.selector(selector)
.build()
.map_err(|e| crate::error::cdp::LocateError::ParseError(e))?;
let result_value = self
.send_command(cmd)
.await
.map_err(|e| crate::error::cdp::LocateError::CommandResultError(CdpCommandResultError::SessionSendError(e)))?
.result;
let qs = QuerySelectorResult::try_from(result_value)
.map_err(|e| crate::error::cdp::LocateError::ParseError(e.to_string()))?;
let node_id = *qs.node_id;
if *node_id.inner() == 0 {
return Ok(None);
}
Ok(Some(self.describe_by_id(node_id).await?))
}
pub async fn locate_all(&mut self, selector: &str) -> Result<Vec<DomNode>, crate::error::cdp::LocateError> {
let root_id = self.get_root_node_id().await?;
let cmd = QuerySelectorAll::builder()
.node_id(root_id)
.selector(selector)
.build()
.map_err(|e| crate::error::cdp::LocateError::ParseError(e))?;
let result_value = self
.send_command(cmd)
.await
.map_err(|e| crate::error::cdp::LocateError::CommandResultError(CdpCommandResultError::SessionSendError(e)))?
.result;
let ids = QuerySelectorAllResult::try_from(result_value)
.map_err(|e| crate::error::cdp::LocateError::ParseError(e.to_string()))?
.node_ids;
let mut nodes = Vec::with_capacity(ids.len());
for id in ids {
nodes.push(self.describe_by_id(id).await?);
}
Ok(nodes)
}
pub async fn wait_for(&mut self, selector: &str, timeout: Duration) -> Result<DomNode, crate::error::cdp::LocateError> {
let interval = Duration::from_millis(100);
let start = tokio::time::Instant::now();
loop {
if let Some(node) = self.locate(selector).await? {
return Ok(node);
}
if start.elapsed() >= timeout {
return Err(crate::error::cdp::LocateError::Timeout(selector.to_string()));
}
sleep(interval).await;
}
}
pub async fn wait_for_all(&mut self, selector: &str, timeout: Duration) -> Result<Vec<DomNode>, crate::error::cdp::LocateError> {
let interval = Duration::from_millis(100);
let start = tokio::time::Instant::now();
loop {
let nodes = self.locate_all(selector).await?;
if !nodes.is_empty() {
return Ok(nodes);
}
if start.elapsed() >= timeout {
return Err(crate::error::cdp::LocateError::Timeout(selector.to_string()));
}
sleep(interval).await;
}
}
pub async fn screenshot(&mut self) -> Result<String, crate::error::cdp::ScreenshotError> {
let cmd = CaptureScreenshot::builder().build();
let result_value = self
.send_command(cmd)
.await
.map_err(|e| crate::error::cdp::ScreenshotError::CommandResultError(
CdpCommandResultError::SessionSendError(e),
))?
.result;
let result = CaptureScreenshotResult::try_from(result_value)
.map_err(|e| crate::error::cdp::ScreenshotError::ParseError(e.to_string()))?;
Ok(String::from(result.data))
}
pub async fn close(&mut self) {
self.session.lock().await.close().await;
}
}
/// Fetches the browser's websocket debugger URL, retrying on failure.
///
/// Calls [`fetch_ws_debugger_url`] up to three times. It pauses 500 ms between
/// attempts, but not after the last one.
///
/// # Errors
///
/// Returns the error message from the final attempt if every attempt fails.
pub async fn fetch_ws_debugger_url_with_retry(
    host: &str,
    chrome_port: u16,
) -> Result<String, String> {
    const MAX_ATTEMPTS: u32 = 3;
    const RETRY_DELAY: Duration = Duration::from_millis(500);

    // Always overwritten: MAX_ATTEMPTS >= 1, so at least one attempt runs.
    let mut last_err = String::new();
    for attempt in 1..=MAX_ATTEMPTS {
        match fetch_ws_debugger_url(host, chrome_port).await {
            Ok(url) => return Ok(url),
            Err(e) => {
                last_err = e;
                if attempt < MAX_ATTEMPTS {
                    sleep(RETRY_DELAY).await;
                }
            }
        }
    }
    Err(last_err)
}
pub async fn fetch_ws_debugger_url(host: &str, port: u16) -> Result<String, String> {
let addr = format!("{}:{}", host, port);
let mut stream = TcpStream::connect(&addr)
.await
.map_err(|e| format!("connect to {}: {}", addr, e))?;
let request = format!(
"GET /json HTTP/1.1\r\nHost: {}\r\nConnection: close\r\n\r\n",
addr
);
stream
.write_all(request.as_bytes())
.await
.map_err(|e| format!("write request: {}", e))?;
let mut reader = BufReader::new(stream);
let mut headers = String::new();
let mut content_length: Option<usize> = None;
loop {
let mut line = String::new();
reader
.read_line(&mut line)
.await
.map_err(|e| format!("read header: {}", e))?;
if line == "\r\n" || line.is_empty() {
break;
}
if let Some(val) = line
.strip_prefix("Content-Length:")
.or_else(|| line.strip_prefix("content-length:"))
{
content_length = val.trim().parse().ok();
}
headers.push_str(&line);
}
let body = if let Some(len) = content_length {
let mut buf = vec![0u8; len];
tokio::io::AsyncReadExt::read_exact(&mut reader, &mut buf)
.await
.map_err(|e| format!("read body: {}", e))?;
String::from_utf8(buf).map_err(|e| format!("invalid utf8: {}", e))?
} else {
let mut buf = String::new();
tokio::time::timeout(
std::time::Duration::from_secs(5),
tokio::io::AsyncReadExt::read_to_string(&mut reader, &mut buf),
)
.await
.map_err(|_| "timeout reading response body".to_string())?
.map_err(|e| format!("read body: {}", e))?;
buf
};
let targets: serde_json::Value =
serde_json::from_str(&body).map_err(|e| format!("parse JSON: {}", e))?;
targets[0]["webSocketDebuggerUrl"]
.as_str()
.map(|s| s.to_string())
.ok_or_else(|| "webSocketDebuggerUrl not found in /json response".to_string())
}
/// Opens a websocket-backed CDP session from `connection_transport_config`.
///
/// Returns the session wrapped in `Arc<TokioMutex<_>>`, ready to be shared
/// (e.g. passed to `CdpAdapter::new`).
pub async fn start_cdp_session(
    connection_transport_config: &ConnectionTransportConfig,
) -> Arc<TokioMutex<CdpSession<WebsocketConnectionTransport>>> {
    Arc::new(TokioMutex::new(
        CdpSession::<WebsocketConnectionTransport>::ws_new(connection_transport_config).await,
    ))
}