use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::sync::Arc;
use tokio::time::{timeout, Duration};
use crate::cdp::CDPClient;
use crate::error::{BrowserError, Result};
/// Condition that `Page::goto` waits for after issuing a navigation.
#[derive(Debug, Clone, Copy, Default)]
pub enum WaitUntil {
/// Wait for the `Page.domContentEventFired` event on the page's session.
DomContentLoaded,
/// Wait for the `Page.loadEventFired` event on the page's session.
/// This is the `Default` variant.
#[default]
Load,
/// Wait until no events have been observed for the page's session for a
/// short quiet period (see `Page::goto` for the exact window).
NetworkIdle,
}
/// A browser cookie as exchanged with `Page::cookies` / `Page::set_cookies`.
///
/// Optional fields that are `None` are omitted when serializing. `http_only`
/// and `same_site` use the camelCase wire names `httpOnly` and `sameSite`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Cookie {
/// Cookie name.
pub name: String,
/// Cookie value.
pub value: String,
/// Domain the cookie applies to, if specified.
#[serde(skip_serializing_if = "Option::is_none")]
pub domain: Option<String>,
/// Path the cookie applies to, if specified.
#[serde(skip_serializing_if = "Option::is_none")]
pub path: Option<String>,
/// Expiry timestamp, if specified.
// NOTE(review): presumably seconds since the Unix epoch, as in CDP — confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub expires: Option<f64>,
/// Secure flag; deserializes to `false` when the field is absent.
#[serde(default)]
pub secure: bool,
/// HttpOnly flag; deserializes to `false` when the field is absent.
#[serde(default, rename = "httpOnly")]
pub http_only: bool,
/// SameSite policy string, if specified.
#[serde(skip_serializing_if = "Option::is_none", rename = "sameSite")]
pub same_site: Option<String>,
}
/// A CSS selector bound to a [`Page`]; created with `Page::locator`.
///
/// Holds its own clone of the page, so it can be used independently of the
/// `Page` value it was created from.
#[derive(Clone)]
pub struct Locator {
// CSS selector passed to `document.querySelector` by the methods below.
selector: String,
// Page whose session the selector is evaluated in.
page: Page,
}
impl Locator {
fn new(selector: impl Into<String>, page: Page) -> Self {
Self {
selector: selector.into(),
page,
}
}
pub async fn click(&self) -> Result<()> {
self.page.click_selector(&self.selector).await
}
pub async fn type_text(&self, text: &str) -> Result<()> {
self.page.type_text_selector(&self.selector, text).await
}
pub async fn wait_for(&self) -> Result<()> {
self.page.wait_for_selector(&self.selector).await
}
pub async fn wait_for_timeout(&self, dur: Duration) -> Result<()> {
self.page
.wait_for_selector_with_timeout(&self.selector, dur)
.await
}
pub async fn inner_text(&self) -> Result<String> {
let expr = format!(
"document.querySelector('{}')?.innerText ?? ''",
escape_selector(&self.selector)
);
let result = self
.page
.send_command(
"Runtime.evaluate".to_string(),
Some(json!({ "expression": expr, "returnByValue": true })),
)
.await?;
result
.get("result")
.and_then(|r| r.get("value"))
.and_then(|v| v.as_str())
.map(|s| s.to_string())
.ok_or_else(|| {
BrowserError::invalid_response(
format!("inner_text('{}')", self.selector),
"unexpected result shape",
)
})
}
pub async fn get_attribute(&self, name: &str) -> Result<Option<String>> {
let expr = format!(
"document.querySelector('{}')?.getAttribute('{}') ?? null",
escape_selector(&self.selector),
name,
);
let result = self
.page
.send_command(
"Runtime.evaluate".to_string(),
Some(json!({ "expression": expr, "returnByValue": true })),
)
.await?;
let val = result.get("result").and_then(|r| r.get("value"));
match val {
Some(Value::String(s)) => Ok(Some(s.clone())),
Some(Value::Null) | None => Ok(None),
_ => Ok(val.map(|v| v.to_string())),
}
}
}
/// Handle to a single browser target, addressed through a shared CDP client.
///
/// Cloning is cheap relative to the connection: the CDP client is held in an
/// `Arc` and shared between clones.
#[derive(Clone)]
pub struct Page {
/// Identifier of the target this page wraps.
pub target_id: String,
/// Session identifier sent with every command issued by this page.
pub session_id: String,
// Shared client used to send commands and subscribe to events.
cdp: Arc<CDPClient>,
}
impl Page {
#[doc(hidden)]
pub fn new(target_id: String, session_id: String, cdp: Arc<CDPClient>) -> Self {
Page {
target_id,
session_id,
cdp,
}
}
/// Creates a [`Locator`] for `selector` that operates on a clone of this page.
pub fn locator(&self, selector: &str) -> Locator {
    let page = self.clone();
    Locator::new(selector, page)
}
pub async fn goto(&self, url: &str, wait_until: WaitUntil) -> Result<()> {
const TIMEOUT_SECS: u64 = 30;
let url_owned = url.to_string();
let session_id = self.session_id.clone();
let event_method = match wait_until {
WaitUntil::DomContentLoaded => "Page.domContentEventFired",
WaitUntil::Load | WaitUntil::NetworkIdle => "Page.loadEventFired",
};
let mut event_rx = self.cdp.subscribe_events();
let _ = self.send_command("Page.enable".to_string(), None).await;
let response = self
.send_command("Page.navigate".to_string(), Some(json!({ "url": url })))
.await?;
if let Some(error_text) = response.get("errorText").and_then(|v| v.as_str()) {
return Err(BrowserError::navigation_failed(&url_owned, error_text));
}
let wait_result = timeout(Duration::from_secs(TIMEOUT_SECS), async {
match wait_until {
WaitUntil::NetworkIdle => {
let mut last_activity = tokio::time::Instant::now();
loop {
tokio::select! {
recv = event_rx.recv() => {
match recv {
Ok(msg)
if msg.session_id.as_deref() == Some(&session_id) =>
{
last_activity = tokio::time::Instant::now();
}
Ok(_) => {} Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {
last_activity = tokio::time::Instant::now();
}
Err(_) => {}
}
}
_ = tokio::time::sleep(Duration::from_millis(50)) => {
if last_activity.elapsed() >= Duration::from_millis(500) {
return Ok::<(), BrowserError>(());
}
}
}
}
}
_ => loop {
match event_rx.recv().await {
Ok(msg)
if msg.method.as_deref() == Some(event_method)
&& msg.session_id.as_deref() == Some(&session_id) =>
{
return Ok(());
}
Ok(_) => {} Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {
return Ok(()); }
Err(_) => tokio::time::sleep(Duration::from_millis(50)).await,
}
},
}
})
.await;
wait_result.map_err(|_| {
BrowserError::timeout(format!("navigating to '{}'", url_owned), TIMEOUT_SECS)
})?
}
pub async fn evaluate_handle(&self, expression: &str) -> Result<String> {
let result = self
.send_command(
"Runtime.evaluate".to_string(),
Some(json!({
"expression": expression,
"returnByValue": false
})),
)
.await?;
if let Some(exc) = result.get("exceptionDetails") {
let msg = exc
.get("exception")
.and_then(|e| e.get("description"))
.and_then(|d| d.as_str())
.unwrap_or("unknown JS exception");
return Err(BrowserError::command_failed("Runtime.evaluate", msg));
}
result
.get("result")
.and_then(|v| v.get("objectId"))
.and_then(|v| v.as_str())
.map(|s| s.to_string())
.ok_or_else(|| {
BrowserError::invalid_response(
"evaluate_handle()",
"missing result.objectId — may have evaluated to a primitive",
)
})
}
/// Evaluates `expression` (awaiting any returned promise) and deserializes
/// the by-value result into `T`.
///
/// A response without `result.value` is deserialized from JSON `null`.
///
/// # Errors
/// Fails when the command fails, the script throws, or the value cannot be
/// deserialized into `T`.
pub async fn evaluate<T: DeserializeOwned>(&self, expression: &str) -> Result<T> {
    let params = json!({
        "expression": expression,
        "returnByValue": true,
        "awaitPromise": true,
    });
    let response = self
        .send_command("Runtime.evaluate".to_string(), Some(params))
        .await?;

    if let Some(details) = response.get("exceptionDetails") {
        let description = details
            .get("exception")
            .and_then(|exception| exception.get("description"))
            .and_then(Value::as_str)
            .unwrap_or("unknown JS exception");
        return Err(BrowserError::command_failed("Runtime.evaluate", description));
    }

    let raw = match response.get("result").and_then(|remote| remote.get("value")) {
        Some(found) => found.clone(),
        None => Value::Null,
    };
    match serde_json::from_value(raw) {
        Ok(parsed) => Ok(parsed),
        Err(e) => Err(BrowserError::invalid_response("evaluate()", e.to_string())),
    }
}
/// Waits for `selector` to match an element, giving up after 30 seconds.
pub async fn wait_for_selector(&self, selector: &str) -> Result<()> {
    let default_budget = Duration::from_secs(30);
    self.wait_for_selector_with_timeout(selector, default_budget)
        .await
}
/// Polls every 100 ms until `selector` matches an element or `dur` elapses.
///
/// # Errors
/// - a command fails;
/// - the probe expression throws (e.g. the selector is not valid CSS);
/// - nothing matches within `dur`.
pub async fn wait_for_selector_with_timeout(
    &self,
    selector: &str,
    dur: Duration,
) -> Result<()> {
    // The timeout error reports whole seconds; round up so a sub-second
    // budget is not shown as "0".
    let timeout_secs = dur.as_secs() + u64::from(dur.subsec_nanos() > 0);
    // The probe does not change between polls, so build it once.
    let expr = format!("!!document.querySelector('{}')", escape_selector(selector));

    let poll = async {
        loop {
            let result = self
                .send_command(
                    "Runtime.evaluate".to_string(),
                    Some(json!({ "expression": expr, "returnByValue": true })),
                )
                .await?;
            // A thrown exception will throw again on every poll; surface it
            // immediately instead of retrying until the timeout.
            if let Some(exc) = result.get("exceptionDetails") {
                let msg = exc
                    .get("exception")
                    .and_then(|e| e.get("description"))
                    .and_then(|d| d.as_str())
                    .unwrap_or("unknown JS exception");
                return Err(BrowserError::command_failed("Runtime.evaluate", msg));
            }
            if let Some(true) = result
                .get("result")
                .and_then(|r| r.get("value"))
                .and_then(|v| v.as_bool())
            {
                return Ok::<(), BrowserError>(());
            }
            tokio::time::sleep(Duration::from_millis(100)).await;
        }
    };

    timeout(dur, poll).await.map_err(|_| {
        BrowserError::timeout(format!("waiting for selector '{}'", selector), timeout_secs)
    })?
}
pub(crate) async fn click_selector(&self, selector: &str) -> Result<()> {
let expr = format!(
"document.querySelector('{}').click()",
escape_selector(selector),
);
self.send_command(
"Runtime.evaluate".to_string(),
Some(json!({ "expression": expr })),
)
.await?;
Ok(())
}
pub(crate) async fn type_text_selector(&self, selector: &str, text: &str) -> Result<()> {
let focus_expr = format!(
"document.querySelector('{}').focus()",
escape_selector(selector)
);
self.send_command(
"Runtime.evaluate".to_string(),
Some(json!({ "expression": focus_expr })),
)
.await?;
for ch in text.chars() {
self.send_command(
"Input.dispatchKeyEvent".to_string(),
Some(json!({
"type": "char",
"text": ch.to_string(),
})),
)
.await?;
}
Ok(())
}
pub async fn click(&self, selector: &str) -> Result<()> {
self.click_selector(selector).await
}
pub async fn type_text(&self, selector: &str, text: &str) -> Result<()> {
self.type_text_selector(selector, text).await
}
pub async fn content(&self) -> Result<String> {
let result = self
.send_command(
"Runtime.evaluate".to_string(),
Some(json!({ "expression": "document.documentElement.outerHTML" })),
)
.await?;
result
.get("result")
.and_then(|v| v.get("value"))
.and_then(|v| v.as_str())
.map(|s| s.to_string())
.ok_or_else(|| {
BrowserError::invalid_response("content()", "missing result.value string")
})
}
pub async fn screenshot(&self) -> Result<Vec<u8>> {
let result = self
.send_command("Page.captureScreenshot".to_string(), None)
.await?;
let base64_data = result
.get("data")
.and_then(|v| v.as_str())
.ok_or_else(|| BrowserError::invalid_response("screenshot()", "missing data field"))?;
base64_decode(base64_data)
}
pub async fn intercept_requests<F>(&self, callback: F) -> Result<()>
where
F: Fn(&str, &str) -> bool + Send + 'static,
{
let _ = self.send_command("Network.enable".to_string(), None).await;
let _ = self
.send_command(
"Network.setRequestInterception".to_string(),
Some(json!({ "patterns": [{ "urlPattern": "*" }] })),
)
.await;
let mut event_rx = self.cdp.subscribe_events();
let cdp = self.cdp.clone();
let session_id = self.session_id.clone();
tokio::spawn(async move {
while let Ok(msg) = event_rx.recv().await {
if msg.method.as_deref() != Some("Network.requestIntercepted") {
continue;
}
if msg.session_id.as_deref() != Some(&session_id) {
continue;
}
if let Some(params) = msg.params {
let url = params
.get("request")
.and_then(|r| r.get("url"))
.and_then(|u| u.as_str())
.unwrap_or("");
let resource_type = params
.get("request")
.and_then(|r| r.get("resourceType"))
.and_then(|r| r.as_str())
.unwrap_or("");
let request_id = params
.get("requestId")
.and_then(|r| r.as_str())
.unwrap_or("");
let should_abort = callback(url, resource_type);
let cdp_method = if should_abort {
"Network.abortRequest"
} else {
"Network.continueInterceptedRequest"
};
let _ = cdp
.send_command_with_session(
&session_id,
cdp_method.to_string(),
Some(json!({ "requestId": request_id })),
)
.await;
}
}
});
Ok(())
}
pub async fn cookies(&self) -> Result<Vec<Cookie>> {
let result = self
.send_command("Network.getCookies".to_string(), None)
.await?;
let cookies_array = result
.get("cookies")
.and_then(|v| v.as_array())
.ok_or_else(|| BrowserError::invalid_response("cookies()", "missing cookies array"))?;
let mut cookies = Vec::new();
for cookie_val in cookies_array {
if let Ok(cookie) = serde_json::from_value::<Cookie>(cookie_val.clone()) {
cookies.push(cookie);
}
}
Ok(cookies)
}
pub async fn set_cookies(&self, cookies: &[Cookie]) -> Result<()> {
let cookie_params: Vec<Value> = cookies
.iter()
.map(|c| {
let mut obj = json!({
"name": c.name,
"value": c.value,
});
if let Some(domain) = &c.domain {
obj["domain"] = json!(domain);
}
if let Some(path) = &c.path {
obj["path"] = json!(path);
}
if let Some(expires) = c.expires {
obj["expires"] = json!(expires);
}
if c.secure {
obj["secure"] = json!(true);
}
if c.http_only {
obj["httpOnly"] = json!(true);
}
if let Some(same_site) = &c.same_site {
obj["sameSite"] = json!(same_site);
}
obj
})
.collect();
self.send_command(
"Network.setCookies".to_string(),
Some(json!({ "cookies": cookie_params })),
)
.await?;
Ok(())
}
pub async fn pdf(&self) -> Result<Vec<u8>> {
self.pdf_with_options(None).await
}
pub async fn pdf_with_options(&self, options: Option<&Value>) -> Result<Vec<u8>> {
let mut params = json!({
"landscape": false,
"displayHeaderFooter": false,
"scale": 1.0,
"paperWidth": 8.5,
"paperHeight": 11.0,
"marginTop": 0.4,
"marginBottom": 0.4,
"marginLeft": 0.4,
"marginRight": 0.4,
"preferCSSPageSize": true,
"transferMode": "ReturnAsBase64",
});
if let Some(opts) = options {
if let Some(obj) = params.as_object_mut() {
if let Some(opts_obj) = opts.as_object() {
for (key, value) in opts_obj.iter() {
obj.insert(key.clone(), value.clone());
}
}
}
}
let result = self
.send_command("Page.printToPDF".to_string(), Some(params))
.await?;
let base64_data = result
.get("data")
.and_then(|v| v.as_str())
.ok_or_else(|| BrowserError::invalid_response("pdf()", "missing data field"))?;
base64_decode(base64_data)
}
pub(crate) async fn send_command(
&self,
method: String,
params: Option<Value>,
) -> Result<Value> {
self.cdp
.send_command_with_session(&self.session_id, method, params)
.await
}
}
/// Escapes `s` for embedding inside a single-quoted JavaScript string literal.
///
/// Backslashes are escaped as well as quotes: otherwise CSS escapes such as
/// `#a\:b` lose their backslash when the literal is parsed, and a trailing
/// backslash could cancel the escaping of a following quote. Line terminators
/// are escaped because a raw newline cannot appear inside a quoted literal.
fn escape_selector(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '\\' => escaped.push_str("\\\\"),
            '\'' => escaped.push_str("\\'"),
            '\n' => escaped.push_str("\\n"),
            '\r' => escaped.push_str("\\r"),
            '\u{2028}' => escaped.push_str("\\u2028"),
            '\u{2029}' => escaped.push_str("\\u2029"),
            other => escaped.push(other),
        }
    }
    escaped
}
/// Decodes standard-alphabet base64 text into bytes.
///
/// # Errors
/// Returns an invalid-response error when `s` is not valid base64.
fn base64_decode(s: &str) -> Result<Vec<u8>> {
    use base64::Engine;
    let engine = base64::engine::general_purpose::STANDARD;
    engine.decode(s).map_err(|e| {
        // This helper serves both screenshot and PDF payloads, so the error
        // context names the helper rather than one particular caller.
        BrowserError::invalid_response("base64_decode()", format!("base64 decode failed: {e}"))
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    // `Load` carries the `#[default]` attribute on `WaitUntil`.
    #[test]
    fn test_wait_until_default() {
        assert!(matches!(WaitUntil::default(), WaitUntil::Load));
    }

    // A selector without special characters passes through unchanged.
    #[test]
    fn test_escape_selector_plain() {
        let selector = "button#id";
        assert_eq!(escape_selector(selector), selector);
    }

    // Single quotes gain a leading backslash.
    #[test]
    fn test_escape_selector_quotes() {
        let escaped = escape_selector("input[name='q']");
        assert_eq!(escaped, r"input[name=\'q\']");
    }
}