use std::collections::HashMap;
use futures::stream::{FuturesUnordered, StreamExt};
use tracing::{debug, instrument, trace, warn};
use viewpoint_cdp::protocol::dom::{BackendNodeId, DescribeNodeParams, DescribeNodeResult};
use viewpoint_js::js;
use super::locator::aria_js::aria_snapshot_with_refs_js;
use super::locator::AriaSnapshot;
use super::ref_resolution::format_ref;
use super::Page;
use crate::error::PageError;
/// Default number of element-ref resolution calls issued concurrently per batch.
pub const DEFAULT_MAX_CONCURRENCY: usize = 50;

/// Options controlling how an ARIA snapshot is captured.
///
/// Built with a consuming builder style:
/// `SnapshotOptions::default().max_concurrency(10).include_refs(false)`.
#[derive(Debug, Clone)]
pub struct SnapshotOptions {
    // Batch size used when resolving element refs in parallel. Always >= 1.
    max_concurrency: usize,
    // Whether element refs are resolved and attached to the snapshot.
    include_refs: bool,
}

impl Default for SnapshotOptions {
    /// Returns options with `max_concurrency = DEFAULT_MAX_CONCURRENCY` and refs enabled.
    fn default() -> Self {
        Self {
            max_concurrency: DEFAULT_MAX_CONCURRENCY,
            include_refs: true,
        }
    }
}

impl SnapshotOptions {
    /// Sets the maximum number of ref-resolution calls issued concurrently.
    ///
    /// The value is used as a batch (chunk) size, and a batch size of zero is
    /// invalid (`slice::chunks` panics on it), so `0` is clamped to `1`
    /// (fully sequential resolution).
    #[must_use]
    pub fn max_concurrency(mut self, max: usize) -> Self {
        self.max_concurrency = max.max(1);
        self
    }

    /// Sets whether element refs should be resolved and attached to the snapshot.
    #[must_use]
    pub fn include_refs(mut self, include: bool) -> Self {
        self.include_refs = include;
        self
    }

    /// Returns the configured maximum concurrency (always at least 1).
    pub fn get_max_concurrency(&self) -> usize {
        self.max_concurrency
    }

    /// Returns whether element refs will be included in the snapshot.
    pub fn get_include_refs(&self) -> bool {
        self.include_refs
    }
}
impl Page {
/// Captures an ARIA snapshot of the page, including child frame content,
/// using [`SnapshotOptions::default`].
///
/// # Errors
///
/// Propagates any error from `aria_snapshot_with_frames_and_options`.
#[instrument(level = "debug", skip(self), fields(target_id = %self.target_id))]
pub async fn aria_snapshot_with_frames(&self) -> Result<AriaSnapshot, PageError> {
    let defaults = SnapshotOptions::default();
    self.aria_snapshot_with_frames_and_options(defaults).await
}
#[instrument(level = "debug", skip(self, options), fields(target_id = %self.target_id))]
pub async fn aria_snapshot_with_frames_and_options(
&self,
options: SnapshotOptions,
) -> Result<AriaSnapshot, PageError> {
if self.closed {
return Err(PageError::Closed);
}
let main_frame = self.main_frame().await?;
let mut root_snapshot = main_frame
.aria_snapshot_with_options(options.clone())
.await?;
let frames = self.frames().await?;
let child_frames: Vec<_> = frames.iter().filter(|f| !f.is_main()).collect();
if child_frames.is_empty() {
return Ok(root_snapshot);
}
debug!(
frame_count = child_frames.len(),
"Capturing child frame snapshots in parallel"
);
let frame_futures: FuturesUnordered<_> = child_frames
.iter()
.map(|frame| {
let frame_id = frame.id().to_string();
let frame_url = frame.url().clone();
let frame_name = frame.name().clone();
let opts = options.clone();
async move {
match frame.aria_snapshot_with_options(opts).await {
Ok(snapshot) => Some((frame_id, frame_url, frame_name, snapshot)),
Err(e) => {
warn!(
error = %e,
frame_id = %frame_id,
frame_url = %frame_url,
"Failed to capture frame snapshot, skipping"
);
None
}
}
}
})
.collect();
let results: Vec<_> = frame_futures.collect().await;
let mut frame_snapshots: HashMap<String, AriaSnapshot> = HashMap::new();
for result in results.into_iter().flatten() {
let (frame_id, frame_url, frame_name, snapshot) = result;
if !frame_url.is_empty() && frame_url != "about:blank" {
frame_snapshots.insert(frame_url, snapshot.clone());
}
if !frame_name.is_empty() {
frame_snapshots.insert(frame_name, snapshot.clone());
}
frame_snapshots.insert(frame_id, snapshot);
}
stitch_frame_content(&mut root_snapshot, &frame_snapshots, 0);
Ok(root_snapshot)
}
/// Captures an ARIA snapshot of the page using [`SnapshotOptions::default`].
///
/// # Errors
///
/// Propagates any error from `aria_snapshot_with_options`.
#[instrument(level = "debug", skip(self), fields(target_id = %self.target_id))]
pub async fn aria_snapshot(&self) -> Result<AriaSnapshot, PageError> {
    let defaults = SnapshotOptions::default();
    self.aria_snapshot_with_options(defaults).await
}
/// Captures an ARIA snapshot of the page with the given options.
///
/// # Errors
///
/// Returns [`PageError::Closed`] if the page is closed; otherwise propagates
/// any error from the underlying snapshot capture.
#[instrument(level = "debug", skip(self, options), fields(target_id = %self.target_id))]
pub async fn aria_snapshot_with_options(
    &self,
    options: SnapshotOptions,
) -> Result<AriaSnapshot, PageError> {
    if self.closed {
        Err(PageError::Closed)
    } else {
        self.capture_snapshot_with_refs(options).await
    }
}
/// Evaluates the snapshot script in the page and builds an [`AriaSnapshot`].
///
/// The script is run with `Runtime.evaluate` (result kept as a remote object).
/// The `snapshot` property of that object is fetched by value and
/// deserialized. When `options.include_refs` is set, the `elements` array
/// property is enumerated, each element is resolved to a backend node id,
/// and the refs are applied to the snapshot.
///
/// The remote result object (and the `elements` array object, if fetched) is
/// released on both the success and the error path; release failures are
/// ignored because they are best-effort cleanup.
///
/// # Errors
///
/// Returns [`PageError::EvaluationFailed`] if the script throws, returns no
/// object id, or the snapshot JSON cannot be parsed; propagates errors from
/// the underlying CDP commands.
#[instrument(level = "debug", skip(self, options), fields(target_id = %self.target_id))]
async fn capture_snapshot_with_refs(
    &self,
    options: SnapshotOptions,
) -> Result<AriaSnapshot, PageError> {
    let snapshot_fn = aria_snapshot_with_refs_js();
    let js_code = js! {
        (function() {
            const getSnapshotWithRefs = @{snapshot_fn};
            return getSnapshotWithRefs(document.body);
        })()
    };
    let result: viewpoint_cdp::protocol::runtime::EvaluateResult = self
        .connection()
        .send_command(
            "Runtime.evaluate",
            Some(viewpoint_cdp::protocol::runtime::EvaluateParams {
                expression: js_code,
                object_group: Some("viewpoint-snapshot".to_string()),
                include_command_line_api: None,
                silent: Some(true),
                context_id: None,
                return_by_value: Some(false),
                await_promise: Some(false),
            }),
            Some(self.session_id()),
        )
        .await?;
    if let Some(exception) = result.exception_details {
        return Err(PageError::EvaluationFailed(exception.text));
    }
    let result_object_id = result.result.object_id.ok_or_else(|| {
        PageError::EvaluationFailed("No object ID from snapshot evaluation".to_string())
    })?;

    // Everything that can fail while the remote object is alive runs inside
    // this block, so the release below happens no matter how the block exits.
    let outcome = async {
        let snapshot_value = self
            .get_property_value(&result_object_id, "snapshot")
            .await?;
        let mut snapshot: AriaSnapshot =
            serde_json::from_value(snapshot_value).map_err(|e| {
                PageError::EvaluationFailed(format!("Failed to parse aria snapshot: {e}"))
            })?;

        if options.include_refs {
            let elements_result = self
                .get_property_object(&result_object_id, "elements")
                .await?;
            if let Some(elements_object_id) = elements_result {
                let listed = self.get_all_array_elements(&elements_object_id).await;
                let element_object_ids = match listed {
                    Ok(ids) => ids,
                    Err(err) => {
                        // Do not leak the array handle when enumeration fails.
                        let _ = self.release_object(&elements_object_id).await;
                        return Err(err);
                    }
                };
                let element_count = element_object_ids.len();
                debug!(
                    element_count = element_count,
                    max_concurrency = options.max_concurrency,
                    "Resolving element refs in parallel"
                );
                // NOTE(review): the per-element object ids are not released
                // individually here — confirm releasing the parents suffices.
                let ref_map = self
                    .resolve_node_ids_parallel(element_object_ids, options.max_concurrency)
                    .await;
                debug!(
                    resolved_count = ref_map.len(),
                    total_count = element_count,
                    "Completed parallel ref resolution"
                );
                apply_refs_to_snapshot(&mut snapshot, &ref_map);
                let _ = self.release_object(&elements_object_id).await;
            }
        }
        Ok::<AriaSnapshot, PageError>(snapshot)
    }
    .await;

    // Best-effort cleanup on both success and failure.
    let _ = self.release_object(&result_object_id).await;
    outcome
}
/// Fetches the elements of a remote array in a single `Runtime.getProperties` call.
///
/// Only own properties whose name parses as a `usize` and whose value
/// carries an object id are kept; everything else is skipped. The result is
/// `(index, object_id)` pairs sorted by index.
///
/// # Errors
///
/// Propagates any error from the `Runtime.getProperties` command.
async fn get_all_array_elements(
    &self,
    array_object_id: &str,
) -> Result<Vec<(usize, String)>, PageError> {
    #[derive(Debug, serde::Deserialize)]
    struct ArrayProperty {
        name: String,
        value: Option<viewpoint_cdp::protocol::runtime::RemoteObject>,
    }
    #[derive(Debug, serde::Deserialize)]
    struct PropertiesResponse {
        result: Vec<ArrayProperty>,
    }

    let params = serde_json::json!({
        "objectId": array_object_id,
        "ownProperties": true,
        "generatePreview": false
    });
    let response: PropertiesResponse = self
        .connection()
        .send_command("Runtime.getProperties", Some(params), Some(self.session_id()))
        .await?;

    let mut elements: Vec<(usize, String)> = response
        .result
        .into_iter()
        .filter_map(|prop| {
            // Non-numeric property names are not array slots.
            let index = prop.name.parse::<usize>().ok()?;
            let object_id = prop.value?.object_id?;
            Some((index, object_id))
        })
        .collect();
    elements.sort_by_key(|&(index, _)| index);

    trace!(element_count = elements.len(), "Batch-fetched array elements");
    Ok(elements)
}
async fn resolve_node_ids_parallel(
&self,
element_object_ids: Vec<(usize, String)>,
max_concurrency: usize,
) -> HashMap<usize, BackendNodeId> {
let mut ref_map = HashMap::new();
for chunk in element_object_ids.chunks(max_concurrency) {
let futures: FuturesUnordered<_> = chunk
.iter()
.map(|(index, object_id)| {
let index = *index;
let object_id = object_id.clone();
async move {
match self.describe_node(&object_id).await {
Ok(backend_node_id) => {
trace!(
index = index,
backend_node_id = backend_node_id,
"Resolved element ref"
);
Some((index, backend_node_id))
}
Err(e) => {
debug!(index = index, error = %e, "Failed to get backendNodeId for element");
None
}
}
}
})
.collect();
let results: Vec<_> = futures.collect().await;
for result in results.into_iter().flatten() {
ref_map.insert(result.0, result.1);
}
}
ref_map
}
/// Reads `property` from a remote object and returns it by value as JSON.
///
/// Uses `Runtime.callFunctionOn` with `returnByValue: true`. A response
/// without a value yields `serde_json::Value::Null`.
///
/// `property` is spliced into the function source as `this.<property>`.
///
/// # Errors
///
/// Propagates any error from the `Runtime.callFunctionOn` command.
async fn get_property_value(
    &self,
    object_id: &str,
    property: &str,
) -> Result<serde_json::Value, PageError> {
    #[derive(Debug, serde::Deserialize)]
    struct CallFunctionResponse {
        result: viewpoint_cdp::protocol::runtime::RemoteObject,
    }

    let getter_source = format!("function() {{ return this.{property}; }}");
    let params = serde_json::json!({
        "objectId": object_id,
        "functionDeclaration": getter_source,
        "returnByValue": true
    });
    let response: CallFunctionResponse = self
        .connection()
        .send_command("Runtime.callFunctionOn", Some(params), Some(self.session_id()))
        .await?;

    Ok(response.result.value.unwrap_or(serde_json::Value::Null))
}
/// Reads `property` from a remote object and returns its remote object id.
///
/// Uses `Runtime.callFunctionOn` with `returnByValue: false`. Returns `None`
/// when the response carries no object id.
///
/// `property` is spliced into the function source as `this.<property>`.
///
/// # Errors
///
/// Propagates any error from the `Runtime.callFunctionOn` command.
async fn get_property_object(
    &self,
    object_id: &str,
    property: &str,
) -> Result<Option<String>, PageError> {
    #[derive(Debug, serde::Deserialize)]
    struct CallFunctionResponse {
        result: viewpoint_cdp::protocol::runtime::RemoteObject,
    }

    let getter_source = format!("function() {{ return this.{property}; }}");
    let params = serde_json::json!({
        "objectId": object_id,
        "functionDeclaration": getter_source,
        "returnByValue": false
    });
    let response: CallFunctionResponse = self
        .connection()
        .send_command("Runtime.callFunctionOn", Some(params), Some(self.session_id()))
        .await?;

    Ok(response.result.object_id)
}
/// Resolves a remote object id to its backend node id via `DOM.describeNode`.
///
/// The node is described with `depth` 0 and identified by `object_id` only.
///
/// # Errors
///
/// Propagates any error from the `DOM.describeNode` command.
async fn describe_node(&self, object_id: &str) -> Result<BackendNodeId, PageError> {
    let params = DescribeNodeParams {
        node_id: None,
        backend_node_id: None,
        object_id: Some(object_id.to_owned()),
        depth: Some(0),
        pierce: None,
    };
    let described: DescribeNodeResult = self
        .connection()
        .send_command("DOM.describeNode", Some(params), Some(self.session_id()))
        .await?;
    Ok(described.node.backend_node_id)
}
/// Releases a remote object handle via `Runtime.releaseObject`.
///
/// The command's response payload is discarded.
///
/// # Errors
///
/// Propagates any error from the `Runtime.releaseObject` command.
async fn release_object(&self, object_id: &str) -> Result<(), PageError> {
    let params = serde_json::json!({ "objectId": object_id });
    let _ack: serde_json::Value = self
        .connection()
        .send_command("Runtime.releaseObject", Some(params), Some(self.session_id()))
        .await?;
    Ok(())
}
}
/// Recursively replaces each node's temporary `element_index` with a `node_ref`.
///
/// For every node that has an `element_index`, the index is cleared; if
/// `ref_map` has an entry for it, `node_ref` is set to the formatted
/// backend node id. Nodes whose index is missing from `ref_map` keep their
/// existing `node_ref`. All children are processed the same way.
pub(crate) fn apply_refs_to_snapshot(
    snapshot: &mut AriaSnapshot,
    ref_map: &HashMap<usize, BackendNodeId>,
) {
    // `take` both reads the index and clears it in one step.
    if let Some(index) = snapshot.element_index.take() {
        if let Some(backend_node_id) = ref_map.get(&index).copied() {
            snapshot.node_ref = Some(format_ref(backend_node_id));
        }
    }
    for child in snapshot.children.iter_mut() {
        apply_refs_to_snapshot(child, ref_map);
    }
}
/// Recursively replaces iframe boundary nodes with the matching frame snapshot.
///
/// For each node with `is_frame == Some(true)`, a snapshot is looked up in
/// `frame_snapshots` by the node's `frame_url`, falling back to its
/// `frame_name`. On a match the node's children are replaced by a clone of
/// that snapshot and `is_frame` is set to `Some(false)` so the boundary is
/// not processed again.
///
/// `depth` counts frame nesting: it increases only when descending into
/// content that was just stitched in, not for ordinary tree levels. This
/// keeps the `MAX_DEPTH` guard (which stops runaway frame-in-frame stitching)
/// from cutting off iframes that merely sit deep inside a single document's
/// tree. Callers pass `0` for the root snapshot.
fn stitch_frame_content(
    snapshot: &mut AriaSnapshot,
    frame_snapshots: &HashMap<String, AriaSnapshot>,
    depth: usize,
) {
    const MAX_DEPTH: usize = 10;
    if depth > MAX_DEPTH {
        warn!(
            depth = depth,
            "Max frame nesting depth exceeded, stopping recursion"
        );
        return;
    }

    // Children stay at the current frame depth unless a frame is stitched below.
    let mut child_depth = depth;

    if snapshot.is_frame == Some(true) {
        let frame_snapshot = snapshot
            .frame_url
            .as_ref()
            .and_then(|url| frame_snapshots.get(url))
            .or_else(|| {
                snapshot
                    .frame_name
                    .as_ref()
                    .and_then(|name| frame_snapshots.get(name))
            });
        if let Some(frame_content) = frame_snapshot {
            debug!(
                frame_url = ?snapshot.frame_url,
                frame_name = ?snapshot.frame_name,
                depth = depth,
                "Stitching frame content into snapshot"
            );
            snapshot.is_frame = Some(false);
            snapshot.children = vec![frame_content.clone()];
            // The new children belong to a nested frame.
            child_depth = depth + 1;
        } else {
            debug!(
                frame_url = ?snapshot.frame_url,
                frame_name = ?snapshot.frame_name,
                "No matching frame snapshot found for iframe boundary"
            );
        }
    }

    for child in &mut snapshot.children {
        stitch_frame_content(child, frame_snapshots, child_depth);
    }
}