use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{mpsc, Mutex};
use crate::cdp::transport::CdpMessage;
use crate::cdp::types::{
NetworkLoadingFailedEvent, NetworkLoadingFinishedEvent, NetworkRequestWillBeSentEvent,
NetworkResponseReceivedEvent,
};
use crate::page::CapturedRequest;
/// A network lifecycle notification emitted by [`NetworkWatcher`] as CDP
/// `Network.*` events are processed.
#[derive(Debug, Clone)]
pub enum NetworkEvent {
    /// Emitted for `Network.requestWillBeSent`; carries the freshly captured
    /// request (no response data filled in yet).
    RequestStarted(CapturedRequest),
    /// Emitted for `Network.responseReceived`.
    ResponseReceived {
        /// CDP request identifier this response belongs to.
        request_id: String,
        /// Response status code as reported by CDP.
        status: i32,
        /// Response status text as reported by CDP.
        status_text: String,
        /// Response headers as reported by CDP.
        headers: HashMap<String, String>,
        /// Response MIME type, when CDP supplied one.
        mime_type: Option<String>,
    },
    /// Emitted for `Network.loadingFinished`.
    RequestCompleted {
        /// CDP request identifier that finished loading.
        request_id: String,
        /// `encodedDataLength` value from the CDP event.
        encoded_data_length: i64,
    },
    /// Emitted for `Network.loadingFailed`.
    RequestFailed {
        /// CDP request identifier that failed.
        request_id: String,
        /// Error text from the CDP event.
        error_text: String,
        /// `canceled` flag from the CDP event; `false` when CDP omitted it.
        canceled: bool,
    },
}
/// Upper bound on entries kept in the in-flight request table; when a new
/// request arrives at this size, the entry with the oldest timestamp is evicted.
const MAX_INFLIGHT_REQUESTS: usize = 10_000;
/// Tracks in-flight network requests from CDP `Network.*` events and
/// republishes them as [`NetworkEvent`]s through an internal queue.
pub struct NetworkWatcher {
    /// Requests seen via `requestWillBeSent` that have not yet finished,
    /// failed, or been evicted, keyed by CDP request id.
    requests: Arc<Mutex<HashMap<String, CapturedRequest>>>,
    /// Producer side of the event queue, fed while processing CDP events.
    event_tx: mpsc::Sender<NetworkEvent>,
    /// Consumer side of the event queue; locked so `recv`/`try_recv` can be
    /// called through `&self`.
    event_rx: Mutex<mpsc::Receiver<NetworkEvent>>,
}
impl NetworkWatcher {
pub fn new() -> Self {
let (event_tx, event_rx) = mpsc::channel(256);
Self {
requests: Arc::new(Mutex::new(HashMap::new())),
event_tx,
event_rx: Mutex::new(event_rx),
}
}
pub async fn process_event(&self, event: &CdpMessage) -> bool {
if let CdpMessage::Event { method, params, .. } = event {
match method.as_str() {
"Network.requestWillBeSent" => {
if let Ok(e) =
serde_json::from_value::<NetworkRequestWillBeSentEvent>(params.clone())
{
self.on_request_will_be_sent(e).await;
return true;
}
}
"Network.responseReceived" => {
if let Ok(e) =
serde_json::from_value::<NetworkResponseReceivedEvent>(params.clone())
{
self.on_response_received(e).await;
return true;
}
}
"Network.loadingFinished" => {
if let Ok(e) =
serde_json::from_value::<NetworkLoadingFinishedEvent>(params.clone())
{
self.on_loading_finished(e).await;
return true;
}
}
"Network.loadingFailed" => {
if let Ok(e) =
serde_json::from_value::<NetworkLoadingFailedEvent>(params.clone())
{
self.on_loading_failed(e).await;
return true;
}
}
_ => {}
}
}
false
}
async fn on_request_will_be_sent(&self, event: NetworkRequestWillBeSentEvent) {
let request = CapturedRequest {
request_id: event.request_id.clone(),
url: event.request.url.clone(),
method: event.request.method.clone(),
headers: event.request.headers.clone(),
post_data: event.request.post_data.clone(),
resource_type: event.r#type.clone(),
status: None,
status_text: None,
response_headers: None,
mime_type: None,
timestamp: event.timestamp,
complete: false,
};
{
let mut requests = self.requests.lock().await;
if requests.len() >= MAX_INFLIGHT_REQUESTS {
if let Some(oldest_id) = requests
.iter()
.min_by(|a, b| {
a.1.timestamp
.partial_cmp(&b.1.timestamp)
.unwrap_or(std::cmp::Ordering::Equal)
})
.map(|(id, _)| id.clone())
{
requests.remove(&oldest_id);
tracing::debug!(
"Evicted oldest in-flight request (cap={})",
MAX_INFLIGHT_REQUESTS
);
}
}
requests.insert(event.request_id.clone(), request.clone());
}
let _ = self
.event_tx
.send(NetworkEvent::RequestStarted(request))
.await;
}
async fn on_response_received(&self, event: NetworkResponseReceivedEvent) {
{
let mut requests = self.requests.lock().await;
if let Some(req) = requests.get_mut(&event.request_id) {
req.status = Some(event.response.status);
req.status_text = Some(event.response.status_text.clone());
req.response_headers = Some(event.response.headers.clone());
req.mime_type = event.response.mime_type.clone();
}
}
let _ = self
.event_tx
.send(NetworkEvent::ResponseReceived {
request_id: event.request_id,
status: event.response.status,
status_text: event.response.status_text,
headers: event.response.headers,
mime_type: event.response.mime_type,
})
.await;
}
async fn on_loading_finished(&self, event: NetworkLoadingFinishedEvent) {
{
let mut requests = self.requests.lock().await;
requests.remove(&event.request_id);
}
let _ = self
.event_tx
.send(NetworkEvent::RequestCompleted {
request_id: event.request_id,
encoded_data_length: event.encoded_data_length,
})
.await;
}
async fn on_loading_failed(&self, event: NetworkLoadingFailedEvent) {
{
let mut requests = self.requests.lock().await;
requests.remove(&event.request_id);
}
let _ = self
.event_tx
.send(NetworkEvent::RequestFailed {
request_id: event.request_id,
error_text: event.error_text,
canceled: event.canceled.unwrap_or(false),
})
.await;
}
pub async fn recv(&self) -> Option<NetworkEvent> {
let mut rx = self.event_rx.lock().await;
rx.recv().await
}
pub async fn try_recv(&self) -> Option<NetworkEvent> {
let mut rx = self.event_rx.lock().await;
rx.try_recv().ok()
}
pub async fn get_request(&self, request_id: &str) -> Option<CapturedRequest> {
let requests = self.requests.lock().await;
requests.get(request_id).cloned()
}
pub async fn get_all_requests(&self) -> Vec<CapturedRequest> {
let requests = self.requests.lock().await;
requests.values().cloned().collect()
}
pub async fn clear(&self) {
let mut requests = self.requests.lock().await;
requests.clear();
}
}
impl Default for NetworkWatcher {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh watcher tracks no requests.
    #[tokio::test]
    async fn test_network_watcher_creation() {
        let watcher = NetworkWatcher::new();
        assert!(watcher.get_all_requests().await.is_empty());
    }

    /// `Default` produces the same empty state as `new`.
    #[tokio::test]
    async fn test_default_matches_new() {
        let watcher = NetworkWatcher::default();
        assert!(watcher.get_all_requests().await.is_empty());
    }

    /// Looking up an id that was never seen yields `None` rather than panicking.
    #[tokio::test]
    async fn test_unknown_request_is_none() {
        let watcher = NetworkWatcher::new();
        assert!(watcher.get_request("missing").await.is_none());
    }

    /// `try_recv` on an empty queue returns immediately with `None`.
    #[tokio::test]
    async fn test_try_recv_on_empty_queue_returns_none() {
        let watcher = NetworkWatcher::new();
        assert!(watcher.try_recv().await.is_none());
    }

    /// Clearing an empty watcher is a harmless no-op.
    #[tokio::test]
    async fn test_clear_on_empty_watcher_is_noop() {
        let watcher = NetworkWatcher::new();
        watcher.clear().await;
        assert!(watcher.get_all_requests().await.is_empty());
    }
}