use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
use tokio::sync::{Mutex, broadcast};
use tokio::time::{Instant, sleep, timeout};
use tracing::{debug, instrument, trace, warn};
use viewpoint_cdp::CdpEvent;
use viewpoint_cdp::protocol::network::{
LoadingFailedEvent, LoadingFinishedEvent, RequestWillBeSentEvent, ResponseReceivedEvent,
};
use super::DocumentLoadState;
use crate::error::WaitError;
/// Response details captured for a navigation's main document request.
///
/// Filled in by `LoadStateWaiter` when a `Network.responseReceived` event arrives for
/// the request it tracks as the main document (the first `Document`-type request seen
/// for its frame). Until such a response is seen, every field keeps its `Default` value.
#[derive(Debug, Clone, Default)]
pub struct NavigationResponseData {
    /// HTTP status of the main document response (converted with an `as u16` cast).
    pub status: Option<u16>,
    /// Headers of the main document response.
    pub headers: HashMap<String, String>,
    /// URL reported by the main document response.
    pub url: Option<String>,
}
/// Time budget applied by `LoadStateWaiter::wait_for_load_state` (30 seconds).
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Quiet period (no tracked request started or settled) required before the network
/// is considered idle (half a second).
const NETWORK_IDLE_THRESHOLD: Duration = Duration::from_micros(500_000);
/// Tracks document load progress for one frame of a CDP session by consuming CDP events.
///
/// The load state is advanced by `Page.domContentEventFired` and `Page.loadEventFired`
/// events. Network requests issued by the frame are counted so that network idle can be
/// detected. The response of the main document request is captured into a
/// [`NavigationResponseData`].
#[derive(Debug)]
pub struct LoadStateWaiter {
    /// Highest load state observed so far; starts at `Commit` and is only ever raised.
    current_state: Arc<Mutex<DocumentLoadState>>,
    /// Source of CDP events; events belonging to other sessions are skipped.
    event_rx: broadcast::Receiver<CdpEvent>,
    /// CDP session whose events this waiter consumes.
    session_id: String,
    /// Frame whose network requests are tracked.
    frame_id: String,
    /// Number of tracked requests still in flight; kept in step with `pending_request_ids`.
    pending_requests: Arc<AtomicUsize>,
    /// Ids of tracked requests still in flight; prevents counting an id twice.
    pending_request_ids: Arc<Mutex<HashSet<String>>>,
    /// Response captured for the main document request.
    response_data: Arc<Mutex<NavigationResponseData>>,
    /// Request id of the first `Document`-type request seen for the frame, if any.
    main_request_id: Arc<Mutex<Option<String>>>,
}
impl LoadStateWaiter {
pub fn new(
event_rx: broadcast::Receiver<CdpEvent>,
session_id: String,
frame_id: String,
) -> Self {
debug!(session_id = %session_id, frame_id = %frame_id, "Created LoadStateWaiter");
Self {
current_state: Arc::new(Mutex::new(DocumentLoadState::Commit)),
event_rx,
session_id,
frame_id,
pending_requests: Arc::new(AtomicUsize::new(0)),
pending_request_ids: Arc::new(Mutex::new(HashSet::new())),
response_data: Arc::new(Mutex::new(NavigationResponseData::default())),
main_request_id: Arc::new(Mutex::new(None)),
}
}
pub async fn wait_for_load_state(
&mut self,
target_state: DocumentLoadState,
) -> Result<(), WaitError> {
self.wait_for_load_state_with_timeout(target_state, DEFAULT_TIMEOUT)
.await
}
#[instrument(level = "debug", skip(self), fields(target_state = ?target_state, timeout_ms = timeout_duration.as_millis()))]
pub async fn wait_for_load_state_with_timeout(
&mut self,
target_state: DocumentLoadState,
timeout_duration: Duration,
) -> Result<(), WaitError> {
{
let current = *self.current_state.lock().await;
if target_state.is_reached(current) {
debug!(current = ?current, "Target state already reached");
return Ok(());
}
trace!(current = ?current, "Starting wait for target state");
}
let result = timeout(timeout_duration, self.wait_for_state_impl(target_state)).await;
match result {
Ok(Ok(())) => {
debug!("Wait completed successfully");
Ok(())
}
Ok(Err(e)) => {
warn!(error = ?e, "Wait failed with error");
Err(e)
}
Err(_) => {
warn!(timeout_ms = timeout_duration.as_millis(), "Wait timed out");
Err(WaitError::Timeout(timeout_duration))
}
}
}
async fn wait_for_state_impl(
&mut self,
target_state: DocumentLoadState,
) -> Result<(), WaitError> {
let mut last_network_activity = Instant::now();
loop {
{
let current = *self.current_state.lock().await;
if target_state.is_reached(current) {
if target_state == DocumentLoadState::NetworkIdle {
let pending = self.pending_requests.load(Ordering::Relaxed);
if pending == 0 && last_network_activity.elapsed() >= NETWORK_IDLE_THRESHOLD
{
return Ok(());
}
} else {
return Ok(());
}
}
}
let event = match self.event_rx.recv().await {
Ok(event) => event,
Err(broadcast::error::RecvError::Closed) => {
return Err(WaitError::PageClosed);
}
Err(broadcast::error::RecvError::Lagged(_)) => {
continue;
}
};
if event.session_id.as_deref() != Some(&self.session_id) {
continue;
}
match event.method.as_str() {
"Page.domContentEventFired" => {
let mut current = self.current_state.lock().await;
if *current < DocumentLoadState::DomContentLoaded {
debug!(previous = ?*current, "State transition: DomContentLoaded");
*current = DocumentLoadState::DomContentLoaded;
}
}
"Page.loadEventFired" => {
let mut current = self.current_state.lock().await;
if *current < DocumentLoadState::Load {
debug!(previous = ?*current, "State transition: Load");
*current = DocumentLoadState::Load;
}
}
"Network.requestWillBeSent" => {
if let Some(params) = event.params {
if let Ok(req) = serde_json::from_value::<RequestWillBeSentEvent>(params) {
if req.frame_id.as_deref() == Some(&self.frame_id) {
let mut ids = self.pending_request_ids.lock().await;
if ids.insert(req.request_id.clone()) {
let count =
self.pending_requests.fetch_add(1, Ordering::Relaxed) + 1;
trace!(request_id = %req.request_id, pending_count = count, "Network request started");
last_network_activity = Instant::now();
if req.resource_type.as_deref() == Some("Document") {
let mut main_req = self.main_request_id.lock().await;
if main_req.is_none() {
*main_req = Some(req.request_id.clone());
trace!(request_id = %req.request_id, "Tracking main document request");
}
}
}
}
}
}
}
"Network.responseReceived" => {
if let Some(params) = event.params {
if let Ok(resp) = serde_json::from_value::<ResponseReceivedEvent>(params) {
let main_req = self.main_request_id.lock().await;
if main_req.as_deref() == Some(&resp.request_id) {
let mut response_data = self.response_data.lock().await;
response_data.status = Some(resp.response.status as u16);
response_data.url = Some(resp.response.url.clone());
response_data.headers = resp.response.headers.clone();
trace!(
status = response_data.status,
url = ?response_data.url,
header_count = response_data.headers.len(),
"Captured main document response"
);
}
}
}
}
"Network.loadingFinished" => {
if let Some(params) = event.params {
if let Ok(finished) = serde_json::from_value::<LoadingFinishedEvent>(params)
{
let mut ids = self.pending_request_ids.lock().await;
if ids.remove(&finished.request_id) {
let count =
self.pending_requests.fetch_sub(1, Ordering::Relaxed) - 1;
trace!(request_id = %finished.request_id, pending_count = count, "Network request finished");
last_network_activity = Instant::now();
}
}
}
}
"Network.loadingFailed" => {
if let Some(params) = event.params {
if let Ok(failed) = serde_json::from_value::<LoadingFailedEvent>(params) {
let mut ids = self.pending_request_ids.lock().await;
if ids.remove(&failed.request_id) {
let count =
self.pending_requests.fetch_sub(1, Ordering::Relaxed) - 1;
trace!(request_id = %failed.request_id, pending_count = count, "Network request failed");
last_network_activity = Instant::now();
}
}
}
}
_ => {}
}
if target_state == DocumentLoadState::NetworkIdle {
let pending = self.pending_requests.load(Ordering::Relaxed);
let current = *self.current_state.lock().await;
if pending == 0 && current >= DocumentLoadState::Load {
sleep(NETWORK_IDLE_THRESHOLD).await;
let pending_after = self.pending_requests.load(Ordering::Relaxed);
if pending_after == 0 {
return Ok(());
}
}
}
}
}
pub async fn set_commit_received(&self) {
let mut current = self.current_state.lock().await;
if *current < DocumentLoadState::Commit {
debug!("State transition: Commit");
*current = DocumentLoadState::Commit;
}
}
pub async fn current_state(&self) -> DocumentLoadState {
*self.current_state.lock().await
}
pub async fn response_data(&self) -> NavigationResponseData {
self.response_data.lock().await.clone()
}
}