use serde::de::DeserializeOwned;
use serde_json::{json, Value};
use std::sync::Arc;
use tokio::time::{timeout, Duration};
use crate::cdp::CDPClient;
use crate::error::{BrowserError, Result};
/// Navigation milestone that [`Page::goto`] waits for after issuing `Page.navigate`.
///
/// Derives `PartialEq`/`Eq`/`Hash` so callers can compare and key on the mode.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum WaitUntil {
    /// Wait for the `Page.domContentEventFired` CDP event.
    DomContentLoaded,
    /// Wait for the `Page.loadEventFired` CDP event. This is the default.
    #[default]
    Load,
    /// Wait until the page's CDP session has gone quiet for a short period
    /// (see [`Page::goto`] for the exact rule).
    NetworkIdle,
}
/// Handle to the first element matching a CSS selector on a [`Page`].
///
/// Nothing is resolved at construction time; each action re-runs
/// `document.querySelector` with the stored selector.
#[derive(Clone)]
pub struct Locator {
    /// CSS selector text passed to `document.querySelector` on every action.
    selector: String,
    /// Page whose CDP session the selector is evaluated in.
    page: Page,
}
impl Locator {
fn new(selector: impl Into<String>, page: Page) -> Self {
Self {
selector: selector.into(),
page,
}
}
pub async fn click(&self) -> Result<()> {
self.page.click_selector(&self.selector).await
}
pub async fn type_text(&self, text: &str) -> Result<()> {
self.page.type_text_selector(&self.selector, text).await
}
pub async fn wait_for(&self) -> Result<()> {
self.page.wait_for_selector(&self.selector).await
}
pub async fn wait_for_timeout(&self, dur: Duration) -> Result<()> {
self.page.wait_for_selector_with_timeout(&self.selector, dur).await
}
pub async fn inner_text(&self) -> Result<String> {
let expr = format!("document.querySelector('{}')?.innerText ?? ''", escape_selector(&self.selector));
let result = self.page.send_command(
"Runtime.evaluate".to_string(),
Some(json!({ "expression": expr, "returnByValue": true })),
).await?;
result
.get("result")
.and_then(|r| r.get("value"))
.and_then(|v| v.as_str())
.map(|s| s.to_string())
.ok_or_else(|| BrowserError::invalid_response(
format!("inner_text('{}')", self.selector),
"unexpected result shape",
))
}
pub async fn get_attribute(&self, name: &str) -> Result<Option<String>> {
let expr = format!(
"document.querySelector('{}')?.getAttribute('{}') ?? null",
escape_selector(&self.selector),
name,
);
let result = self.page.send_command(
"Runtime.evaluate".to_string(),
Some(json!({ "expression": expr, "returnByValue": true })),
).await?;
let val = result
.get("result")
.and_then(|r| r.get("value"));
match val {
Some(Value::String(s)) => Ok(Some(s.clone())),
Some(Value::Null) | None => Ok(None),
_ => Ok(val.map(|v| v.to_string())),
}
}
}
/// A browser page driven over a single CDP session.
///
/// Cloning is cheap: the CDP connection is shared through an `Arc`, and only the
/// two identifier strings are copied.
#[derive(Clone)]
pub struct Page {
    /// Target identifier supplied when the page was constructed.
    pub target_id: String,
    /// Session identifier attached to every command this page sends.
    pub session_id: String,
    /// Shared CDP connection used to send commands and subscribe to events.
    cdp: Arc<CDPClient>,
}
impl Page {
#[doc(hidden)]
pub fn new(target_id: String, session_id: String, cdp: Arc<CDPClient>) -> Self {
Page {
target_id,
session_id,
cdp,
}
}
pub fn locator(&self, selector: &str) -> Locator {
Locator::new(selector, self.clone())
}
pub async fn goto(&self, url: &str, wait_until: WaitUntil) -> Result<()> {
const TIMEOUT_SECS: u64 = 30;
let url_owned = url.to_string();
let session_id = self.session_id.clone();
let event_method = match wait_until {
WaitUntil::DomContentLoaded => "Page.domContentEventFired",
WaitUntil::Load | WaitUntil::NetworkIdle => "Page.loadEventFired",
};
let mut event_rx = self.cdp.subscribe_events();
let _ = self.send_command("Page.enable".to_string(), None).await;
let response = self.send_command(
"Page.navigate".to_string(),
Some(json!({ "url": url })),
).await?;
if let Some(error_text) = response.get("errorText").and_then(|v| v.as_str()) {
return Err(BrowserError::navigation_failed(&url_owned, error_text));
}
let wait_result = timeout(Duration::from_secs(TIMEOUT_SECS), async {
match wait_until {
WaitUntil::NetworkIdle => {
let mut last_activity = tokio::time::Instant::now();
loop {
tokio::select! {
recv = event_rx.recv() => {
match recv {
Ok(msg)
if msg.session_id.as_deref() == Some(&session_id) =>
{
last_activity = tokio::time::Instant::now();
}
Ok(_) => {} Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {
last_activity = tokio::time::Instant::now();
}
Err(_) => {}
}
}
_ = tokio::time::sleep(Duration::from_millis(50)) => {
if last_activity.elapsed() >= Duration::from_millis(500) {
return Ok::<(), BrowserError>(());
}
}
}
}
}
_ => loop {
match event_rx.recv().await {
Ok(msg)
if msg.method.as_deref() == Some(event_method)
&& msg.session_id.as_deref() == Some(&session_id) =>
{
return Ok(());
}
Ok(_) => {} Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {
return Ok(()); }
Err(_) => tokio::time::sleep(Duration::from_millis(50)).await,
}
},
}
})
.await;
wait_result.map_err(|_| BrowserError::timeout(
format!("navigating to '{}'", url_owned),
TIMEOUT_SECS,
))?
}
pub async fn evaluate<T: DeserializeOwned>(&self, expression: &str) -> Result<T> {
let result = self.send_command(
"Runtime.evaluate".to_string(),
Some(json!({
"expression": expression,
"returnByValue": true,
"awaitPromise": true,
})),
).await?;
if let Some(exc) = result.get("exceptionDetails") {
let msg = exc
.get("exception")
.and_then(|e| e.get("description"))
.and_then(|d| d.as_str())
.unwrap_or("unknown JS exception");
return Err(BrowserError::command_failed("Runtime.evaluate", msg));
}
let value = result
.get("result")
.and_then(|r| r.get("value"))
.cloned()
.unwrap_or(Value::Null);
serde_json::from_value(value)
.map_err(|e| BrowserError::invalid_response("evaluate()", e.to_string()))
}
pub async fn wait_for_selector(&self, selector: &str) -> Result<()> {
self.wait_for_selector_with_timeout(selector, Duration::from_secs(30)).await
}
pub async fn wait_for_selector_with_timeout(&self, selector: &str, dur: Duration) -> Result<()> {
let selector = selector.to_string();
let timeout_secs = dur.as_secs();
let fut = async {
loop {
let expr = format!(
"!!document.querySelector('{}')",
escape_selector(&selector),
);
let result = self.send_command(
"Runtime.evaluate".to_string(),
Some(json!({ "expression": expr, "returnByValue": true })),
).await?;
if let Some(true) = result
.get("result")
.and_then(|r| r.get("value"))
.and_then(|v| v.as_bool())
{
return Ok::<(), BrowserError>(());
}
tokio::time::sleep(Duration::from_millis(100)).await;
}
};
timeout(dur, fut).await.map_err(|_| BrowserError::timeout(
format!("waiting for selector '{}'", selector),
timeout_secs,
))?
}
pub(crate) async fn click_selector(&self, selector: &str) -> Result<()> {
let expr = format!(
"document.querySelector('{}').click()",
escape_selector(selector),
);
self.send_command(
"Runtime.evaluate".to_string(),
Some(json!({ "expression": expr })),
).await?;
Ok(())
}
pub(crate) async fn type_text_selector(&self, selector: &str, text: &str) -> Result<()> {
let focus_expr = format!("document.querySelector('{}').focus()", escape_selector(selector));
self.send_command(
"Runtime.evaluate".to_string(),
Some(json!({ "expression": focus_expr })),
).await?;
for ch in text.chars() {
self.send_command(
"Input.dispatchKeyEvent".to_string(),
Some(json!({
"type": "char",
"text": ch.to_string(),
})),
).await?;
}
Ok(())
}
pub async fn click(&self, selector: &str) -> Result<()> {
self.click_selector(selector).await
}
pub async fn type_text(&self, selector: &str, text: &str) -> Result<()> {
self.type_text_selector(selector, text).await
}
pub async fn content(&self) -> Result<String> {
let result = self.send_command(
"Runtime.evaluate".to_string(),
Some(json!({ "expression": "document.documentElement.outerHTML" })),
).await?;
result
.get("result")
.and_then(|v| v.get("value"))
.and_then(|v| v.as_str())
.map(|s| s.to_string())
.ok_or_else(|| BrowserError::invalid_response("content()", "missing result.value string"))
}
pub async fn screenshot(&self) -> Result<Vec<u8>> {
let result = self.send_command(
"Page.captureScreenshot".to_string(),
None,
).await?;
let base64_data = result
.get("data")
.and_then(|v| v.as_str())
.ok_or_else(|| BrowserError::invalid_response("screenshot()", "missing data field"))?;
base64_decode(base64_data)
}
pub async fn intercept_requests<F>(&self, callback: F) -> Result<()>
where
F: Fn(&str, &str) -> bool + Send + 'static,
{
let _ = self.send_command("Network.enable".to_string(), None).await;
let _ = self.send_command(
"Network.setRequestInterception".to_string(),
Some(json!({ "patterns": [{ "urlPattern": "*" }] })),
).await;
let mut event_rx = self.cdp.subscribe_events();
let cdp = self.cdp.clone();
let session_id = self.session_id.clone();
tokio::spawn(async move {
while let Ok(msg) = event_rx.recv().await {
if msg.method.as_deref() != Some("Network.requestIntercepted") {
continue;
}
if msg.session_id.as_deref() != Some(&session_id) {
continue;
}
if let Some(params) = msg.params {
let url = params
.get("request")
.and_then(|r| r.get("url"))
.and_then(|u| u.as_str())
.unwrap_or("");
let resource_type = params
.get("request")
.and_then(|r| r.get("resourceType"))
.and_then(|r| r.as_str())
.unwrap_or("");
let request_id = params
.get("requestId")
.and_then(|r| r.as_str())
.unwrap_or("");
let should_abort = callback(url, resource_type);
let cdp_method = if should_abort {
"Network.abortRequest"
} else {
"Network.continueInterceptedRequest"
};
let _ = cdp
.send_command_with_session(
&session_id,
cdp_method.to_string(),
Some(json!({ "requestId": request_id })),
)
.await;
}
}
});
Ok(())
}
pub(crate) async fn send_command(&self, method: String, params: Option<Value>) -> Result<Value> {
self.cdp.send_command_with_session(&self.session_id, method, params).await
}
}
/// Escapes `s` so it can be placed inside a single-quoted JavaScript string
/// literal, such as `document.querySelector('…')`.
///
/// Backslashes are escaped first. Otherwise CSS escapes such as `#a\:b` would
/// lose their backslash, and an input containing `\'` could close the literal
/// early. Line terminators are escaped because a raw newline ends a JS string
/// literal.
fn escape_selector(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '\\' => out.push_str("\\\\"),
            '\'' => out.push_str("\\'"),
            '\n' => out.push_str("\\n"),
            '\r' => out.push_str("\\r"),
            '\u{2028}' => out.push_str("\\u2028"),
            '\u{2029}' => out.push_str("\\u2029"),
            other => out.push(other),
        }
    }
    out
}
/// Decodes standard-alphabet base64 (as returned in `Page.captureScreenshot`'s
/// `data` field) into raw bytes.
///
/// Decode failures are reported as `invalid_response` errors attributed to
/// `screenshot()`.
fn base64_decode(s: &str) -> Result<Vec<u8>> {
    use base64::{engine::general_purpose::STANDARD, Engine as _};
    STANDARD.decode(s).map_err(|err| {
        BrowserError::invalid_response("screenshot()", format!("base64 decode failed: {err}"))
    })
}
#[cfg(test)]
mod tests {
    use super::{escape_selector, WaitUntil};

    /// `WaitUntil::default()` must select the load event.
    #[test]
    fn test_wait_until_default() {
        assert!(matches!(WaitUntil::default(), WaitUntil::Load));
    }

    /// Selectors without quotes pass through untouched.
    #[test]
    fn test_escape_selector_plain() {
        let plain = "button#id";
        assert_eq!(escape_selector(plain), plain);
    }

    /// Single quotes are backslash-escaped for embedding in a JS literal.
    #[test]
    fn test_escape_selector_quotes() {
        let escaped = escape_selector("input[name='q']");
        assert_eq!(escaped, r"input[name=\'q\']");
    }
}