use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::broadcast;
use viewpoint_cdp::CdpConnection;
use viewpoint_cdp::protocol::network::{
LoadingFailedEvent, LoadingFinishedEvent, RequestWillBeSentEvent, ResponseReceivedEvent,
};
use super::request::Request;
use super::response::Response;
use super::types::{ResourceType, UrlMatcher};
use crate::error::NetworkError;
/// Payload of [`NetworkEvent::Request`]: a request observed via
/// `Network.requestWillBeSent`.
#[derive(Clone, Debug)]
pub struct RequestEvent {
    /// The request as parsed from the CDP event.
    pub request: Request,
}
/// Payload of [`NetworkEvent::Response`]: a response observed via
/// `Network.responseReceived`.
#[derive(Clone, Debug)]
pub struct ResponseEvent {
    /// The received response, paired with its originating request.
    pub response: Response,
}
/// Payload of [`NetworkEvent::RequestFinished`]: a request whose loading
/// completed (`Network.loadingFinished`).
#[derive(Clone, Debug)]
pub struct RequestFinishedEvent {
    /// The request that finished loading.
    pub request: Request,
}
/// Payload of [`NetworkEvent::RequestFailed`]: a request whose loading failed
/// (`Network.loadingFailed`).
#[derive(Clone, Debug)]
pub struct RequestFailedEvent {
    /// The request that failed.
    pub request: Request,
    /// The error text reported by CDP for the failure.
    pub error: String,
}
/// A network lifecycle event emitted by [`NetworkEventListener`].
#[derive(Clone, Debug)]
pub enum NetworkEvent {
    /// A request is about to be sent.
    Request(RequestEvent),
    /// A response was received for a tracked request.
    Response(ResponseEvent),
    /// A tracked request finished loading.
    RequestFinished(RequestFinishedEvent),
    /// A tracked request failed to load.
    RequestFailed(RequestFailedEvent),
}
/// Converts raw CDP `Network.*` events for one session into [`NetworkEvent`]s
/// and fans them out to subscribers over a broadcast channel.
#[derive(Debug)]
pub struct NetworkEventListener {
    /// Connection whose event stream is consumed.
    connection: Arc<CdpConnection>,
    /// Only CDP events tagged with this session id are processed.
    session_id: String,
    /// Sender side of the outgoing [`NetworkEvent`] channel.
    event_tx: broadcast::Sender<NetworkEvent>,
}
impl NetworkEventListener {
pub fn new(connection: Arc<CdpConnection>, session_id: String) -> Self {
let (event_tx, _) = broadcast::channel(256);
Self {
connection,
session_id,
event_tx,
}
}
pub fn subscribe(&self) -> broadcast::Receiver<NetworkEvent> {
self.event_tx.subscribe()
}
pub fn start(&self) {
let mut cdp_events = self.connection.subscribe_events();
let session_id = self.session_id.clone();
let event_tx = self.event_tx.clone();
let connection = self.connection.clone();
tokio::spawn(async move {
let mut pending_requests: HashMap<String, Request> = HashMap::new();
while let Ok(event) = cdp_events.recv().await {
if event.session_id.as_deref() != Some(&session_id) {
continue;
}
match event.method.as_str() {
"Network.requestWillBeSent" => {
if let Some(params) = &event.params {
if let Ok(req_event) =
serde_json::from_value::<RequestWillBeSentEvent>(params.clone())
{
let previous_request = if req_event.redirect_response.is_some() {
pending_requests.remove(&req_event.request_id)
} else {
None
};
let request =
parse_request_will_be_sent(&req_event, previous_request);
pending_requests
.insert(req_event.request_id.clone(), request.clone());
let _ =
event_tx.send(NetworkEvent::Request(RequestEvent { request }));
}
}
}
"Network.responseReceived" => {
if let Some(params) = &event.params {
if let Ok(resp_event) =
serde_json::from_value::<ResponseReceivedEvent>(params.clone())
{
if let Some(request) =
pending_requests.get(&resp_event.request_id).cloned()
{
let response = Response::new(
resp_event.response,
request,
connection.clone(),
session_id.clone(),
resp_event.request_id.clone(),
);
let _ = event_tx
.send(NetworkEvent::Response(ResponseEvent { response }));
}
}
}
}
"Network.loadingFinished" => {
if let Some(params) = &event.params {
if let Ok(finished_event) =
serde_json::from_value::<LoadingFinishedEvent>(params.clone())
{
if let Some(request) =
pending_requests.remove(&finished_event.request_id)
{
let _ = event_tx.send(NetworkEvent::RequestFinished(
RequestFinishedEvent { request },
));
}
}
}
}
"Network.loadingFailed" => {
if let Some(params) = &event.params {
if let Ok(failed_event) =
serde_json::from_value::<LoadingFailedEvent>(params.clone())
{
if let Some(request) =
pending_requests.remove(&failed_event.request_id)
{
let _ = event_tx.send(NetworkEvent::RequestFailed(
RequestFailedEvent {
request,
error: failed_event.error_text,
},
));
}
}
}
}
_ => {}
}
}
});
}
}
fn parse_request_will_be_sent(
event: &RequestWillBeSentEvent,
previous_request: Option<Request>,
) -> Request {
let resource_type = event
.resource_type
.as_ref()
.map_or(ResourceType::Other, |t| parse_resource_type(t));
Request {
url: event.request.url.clone(),
method: event.request.method.clone(),
headers: event.request.headers.clone(),
post_data: event.request.post_data.clone(),
resource_type,
frame_id: event.frame_id.clone().unwrap_or_default(),
is_navigation: event.initiator.initiator_type == "navigation",
connection: None,
session_id: None,
request_id: Some(event.request_id.clone()),
redirected_from: previous_request.map(Box::new),
redirected_to: None,
timing: None,
failure_text: None,
}
}
/// Maps a CDP `Network.ResourceType` name, compared case-insensitively, to
/// [`ResourceType`].
///
/// `"Other"` and any unrecognised name map to [`ResourceType::Other`].
fn parse_resource_type(s: &str) -> ResourceType {
    match s.to_lowercase().as_str() {
        "document" => ResourceType::Document,
        "stylesheet" => ResourceType::Stylesheet,
        "image" => ResourceType::Image,
        "media" => ResourceType::Media,
        "font" => ResourceType::Font,
        "script" => ResourceType::Script,
        "texttrack" => ResourceType::TextTrack,
        "xhr" => ResourceType::Xhr,
        "fetch" => ResourceType::Fetch,
        "eventsource" => ResourceType::EventSource,
        "websocket" => ResourceType::WebSocket,
        "manifest" => ResourceType::Manifest,
        "ping" => ResourceType::Ping,
        // Covers CDP's explicit "Other" as well as every type not modelled above.
        _ => ResourceType::Other,
    }
}
/// Builder for waiting on the next request whose URL matches a pattern.
#[derive(Debug)]
pub struct WaitForRequestBuilder<'a, M> {
    /// Connection whose CDP events are watched.
    connection: &'a Arc<CdpConnection>,
    /// Only events for this CDP session are considered.
    session_id: &'a str,
    /// URL matcher a request must satisfy.
    pattern: M,
    /// Upper bound on how long `wait` may take.
    timeout: Duration,
}
impl<'a, M: UrlMatcher + Clone + 'static> WaitForRequestBuilder<'a, M> {
pub fn new(connection: &'a Arc<CdpConnection>, session_id: &'a str, pattern: M) -> Self {
Self {
connection,
session_id,
pattern,
timeout: Duration::from_secs(30),
}
}
#[must_use]
pub fn timeout(mut self, timeout: Duration) -> Self {
self.timeout = timeout;
self
}
pub async fn wait(self) -> Result<Request, NetworkError> {
let mut events = self.connection.subscribe_events();
let session_id = self.session_id.to_string();
let pattern = self.pattern;
let timeout = self.timeout;
tokio::time::timeout(timeout, async move {
while let Ok(event) = events.recv().await {
if event.session_id.as_deref() != Some(&session_id) {
continue;
}
if event.method == "Network.requestWillBeSent" {
if let Some(params) = &event.params {
if let Ok(req_event) =
serde_json::from_value::<RequestWillBeSentEvent>(params.clone())
{
if pattern.matches(&req_event.request.url) {
return Ok(parse_request_will_be_sent(&req_event, None));
}
}
}
}
}
Err(NetworkError::Aborted)
})
.await
.map_err(|_| NetworkError::Timeout(timeout))?
}
}
/// Builder for waiting on the next response whose URL matches a pattern.
#[derive(Debug)]
pub struct WaitForResponseBuilder<'a, M> {
    /// Connection whose CDP events are watched.
    connection: &'a Arc<CdpConnection>,
    /// Only events for this CDP session are considered.
    session_id: &'a str,
    /// URL matcher a response must satisfy.
    pattern: M,
    /// Upper bound on how long `wait` may take.
    timeout: Duration,
}
impl<'a, M: UrlMatcher + Clone + 'static> WaitForResponseBuilder<'a, M> {
pub fn new(connection: &'a Arc<CdpConnection>, session_id: &'a str, pattern: M) -> Self {
Self {
connection,
session_id,
pattern,
timeout: Duration::from_secs(30),
}
}
#[must_use]
pub fn timeout(mut self, timeout: Duration) -> Self {
self.timeout = timeout;
self
}
pub async fn wait(self) -> Result<Response, NetworkError> {
let mut events = self.connection.subscribe_events();
let session_id = self.session_id.to_string();
let pattern = self.pattern;
let timeout = self.timeout;
let connection = self.connection.clone();
tokio::time::timeout(timeout, async move {
let mut pending_requests: HashMap<String, Request> = HashMap::new();
while let Ok(event) = events.recv().await {
if event.session_id.as_deref() != Some(&session_id) {
continue;
}
match event.method.as_str() {
"Network.requestWillBeSent" => {
if let Some(params) = &event.params {
if let Ok(req_event) =
serde_json::from_value::<RequestWillBeSentEvent>(params.clone())
{
let request = parse_request_will_be_sent(&req_event, None);
pending_requests.insert(req_event.request_id.clone(), request);
}
}
}
"Network.responseReceived" => {
if let Some(params) = &event.params {
if let Ok(resp_event) =
serde_json::from_value::<ResponseReceivedEvent>(params.clone())
{
if pattern.matches(&resp_event.response.url) {
let request = pending_requests
.get(&resp_event.request_id)
.cloned()
.unwrap_or_else(|| Request {
url: resp_event.response.url.clone(),
method: "GET".to_string(),
headers: HashMap::new(),
post_data: None,
resource_type: ResourceType::Other,
frame_id: resp_event
.frame_id
.clone()
.unwrap_or_default(),
is_navigation: false,
connection: None,
session_id: None,
request_id: Some(resp_event.request_id.clone()),
redirected_from: None,
redirected_to: None,
timing: None,
failure_text: None,
});
return Ok(Response::new(
resp_event.response,
request,
connection.clone(),
session_id.clone(),
resp_event.request_id.clone(),
));
}
}
}
}
_ => {}
}
}
Err(NetworkError::Aborted)
})
.await
.map_err(|_| NetworkError::Timeout(timeout))?
}
}
#[cfg(test)]
mod tests;