1use crate::DomainType;
2use crate::error::Result;
3use crate::page::Page;
4use async_trait::async_trait;
5use cdp_protocol::fetch::{
6 self as fetch_cdp, ContinueRequest, ContinueResponse, FailRequest, FulfillRequest,
7 RequestPattern, RequestStage,
8};
9use cdp_protocol::network;
10use serde::{Deserialize, Serialize};
11use std::collections::{HashMap, HashSet};
12use std::str::FromStr;
13use std::sync::Arc;
14use tokio::sync::Mutex;
15
16#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
18pub enum HttpMethod {
19 GET,
20 POST,
21 PUT,
22 DELETE,
23 PATCH,
24 HEAD,
25 OPTIONS,
26 CONNECT,
27 TRACE,
28}
29
30impl HttpMethod {
31 pub fn as_str(&self) -> &'static str {
32 match self {
33 HttpMethod::GET => "GET",
34 HttpMethod::POST => "POST",
35 HttpMethod::PUT => "PUT",
36 HttpMethod::DELETE => "DELETE",
37 HttpMethod::PATCH => "PATCH",
38 HttpMethod::HEAD => "HEAD",
39 HttpMethod::OPTIONS => "OPTIONS",
40 HttpMethod::CONNECT => "CONNECT",
41 HttpMethod::TRACE => "TRACE",
42 }
43 }
44}
45
46impl FromStr for HttpMethod {
47 type Err = ();
48
49 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
50 match s.to_uppercase().as_str() {
51 "GET" => Ok(HttpMethod::GET),
52 "POST" => Ok(HttpMethod::POST),
53 "PUT" => Ok(HttpMethod::PUT),
54 "DELETE" => Ok(HttpMethod::DELETE),
55 "PATCH" => Ok(HttpMethod::PATCH),
56 "HEAD" => Ok(HttpMethod::HEAD),
57 "OPTIONS" => Ok(HttpMethod::OPTIONS),
58 "CONNECT" => Ok(HttpMethod::CONNECT),
59 "TRACE" => Ok(HttpMethod::TRACE),
60 _ => Ok(HttpMethod::GET), }
62 }
63}
64
65#[derive(Debug, Clone)]
67pub struct InterceptedRequest {
68 pub request_id: String,
70 pub url: String,
72 pub method: HttpMethod,
74 pub headers: HashMap<String, String>,
76 pub post_data: Option<String>,
78 pub resource_type: Option<String>,
80}
81
/// Snapshot of a response handed to response monitors.
#[derive(Clone, Debug)]
pub struct InterceptedResponse {
    /// Identifier of the request this response belongs to.
    pub request_id: String,
    /// HTTP status code.
    pub status_code: i64,
    /// HTTP status text.
    pub status_text: String,
    /// Response headers as name → value pairs.
    pub headers: HashMap<String, String>,
    /// Set when `body` holds base64-encoded data (NOTE(review): inferred from
    /// the field name; confirm against the code that fills it in).
    pub base_64_encoded: bool,
    /// Response body, if it was retrieved.
    pub body: Option<String>,
}
101
102#[derive(Debug, Clone, Default)]
104pub struct RequestModification {
105 pub url: Option<String>,
107 pub method: Option<HttpMethod>,
109 pub headers: Option<HashMap<String, String>>,
111 pub post_data: Option<String>,
113}
114
/// Canned response used to fulfil a request or to rewrite a paused response.
#[derive(Clone, Debug)]
pub struct ResponseMock {
    /// HTTP status code to report.
    pub status_code: i64,
    /// Response headers as name → value pairs.
    pub headers: HashMap<String, String>,
    /// Response body text.
    pub body: String,
}

impl Default for ResponseMock {
    /// An empty `200` response with no headers.
    fn default() -> Self {
        ResponseMock {
            status_code: 200,
            headers: HashMap::default(),
            body: String::default(),
        }
    }
}
135
136#[async_trait]
138pub trait NetworkInterceptor {
139 async fn enable_request_interception(self: &Arc<Self>, patterns: Vec<String>) -> Result<()>;
141
142 async fn disable_request_interception(self: &Arc<Self>) -> Result<()>;
144
145 async fn continue_request(self: &Arc<Self>, request_id: &str) -> Result<()>;
147
148 async fn continue_request_with_modification(
150 self: &Arc<Self>,
151 request_id: &str,
152 modification: RequestModification,
153 ) -> Result<()>;
154
155 async fn fail_request(self: &Arc<Self>, request_id: &str, error_reason: &str) -> Result<()>;
157
158 async fn fulfill_request(
160 self: &Arc<Self>,
161 request_id: &str,
162 response: ResponseMock,
163 ) -> Result<()>;
164
165 async fn continue_response(self: &Arc<Self>, request_id: &str) -> Result<()>;
167
168 async fn continue_response_with_modification(
170 self: &Arc<Self>,
171 request_id: &str,
172 response: ResponseMock,
173 ) -> Result<()>;
174}
175
176#[async_trait]
177impl NetworkInterceptor for Page {
178 async fn enable_request_interception(self: &Arc<Self>, patterns: Vec<String>) -> Result<()> {
179 let request_patterns = patterns
181 .into_iter()
182 .map(|url_pattern| RequestPattern {
183 url_pattern: Some(url_pattern),
184 resource_type: None,
185 request_stage: Some(RequestStage::Request),
186 })
187 .collect();
188
189 self.domain_manager
191 .enable_fetch_domain_with_patterns(Some(request_patterns))
192 .await?;
193
194 Ok(())
195 }
196
197 async fn disable_request_interception(self: &Arc<Self>) -> Result<()> {
198 self.domain_manager.disable_fetch_domain().await?;
200 Ok(())
201 }
202
203 async fn continue_request(self: &Arc<Self>, request_id: &str) -> Result<()> {
204 let cont = ContinueRequest {
205 request_id: request_id.to_string(),
206 url: None,
207 method: None,
208 post_data: None,
209 headers: None,
210 intercept_response: None,
211 };
212
213 self.session
214 .send_command::<_, fetch_cdp::ContinueRequestReturnObject>(cont, None)
215 .await?;
216
217 Ok(())
218 }
219
220 async fn continue_request_with_modification(
221 self: &Arc<Self>,
222 request_id: &str,
223 modification: RequestModification,
224 ) -> Result<()> {
225 let headers = modification.headers.map(|h| {
226 h.into_iter()
227 .map(|(k, v)| fetch_cdp::HeaderEntry { name: k, value: v })
228 .collect()
229 });
230
231 let post_data = modification.post_data.map(|s| s.into_bytes());
232
233 let cont = ContinueRequest {
234 request_id: request_id.to_string(),
235 url: modification.url,
236 method: modification.method.map(|m| m.as_str().to_string()),
237 post_data,
238 headers,
239 intercept_response: None,
240 };
241
242 self.session
243 .send_command::<_, fetch_cdp::ContinueRequestReturnObject>(cont, None)
244 .await?;
245
246 Ok(())
247 }
248
249 async fn fail_request(self: &Arc<Self>, request_id: &str, error_reason: &str) -> Result<()> {
250 let error = match error_reason.to_uppercase().as_str() {
252 "FAILED" => network::ErrorReason::Failed,
253 "ABORTED" => network::ErrorReason::Aborted,
254 "TIMEDOUT" => network::ErrorReason::TimedOut,
255 "ACCESSDENIED" => network::ErrorReason::AccessDenied,
256 "CONNECTIONCLOSED" => network::ErrorReason::ConnectionClosed,
257 "CONNECTIONRESET" => network::ErrorReason::ConnectionReset,
258 "CONNECTIONREFUSED" => network::ErrorReason::ConnectionRefused,
259 "CONNECTIONABORTED" => network::ErrorReason::ConnectionAborted,
260 "CONNECTIONFAILED" => network::ErrorReason::ConnectionFailed,
261 "NAMENOTRESOLVED" => network::ErrorReason::NameNotResolved,
262 "INTERNETDISCONNECTED" => network::ErrorReason::InternetDisconnected,
263 "ADDRESSUNREACHABLE" => network::ErrorReason::AddressUnreachable,
264 "BLOCKEDBYCLIENT" => network::ErrorReason::BlockedByClient,
265 "BLOCKEDBYRESPONSE" => network::ErrorReason::BlockedByResponse,
266 _ => network::ErrorReason::Failed,
267 };
268
269 let fail = FailRequest {
270 request_id: request_id.to_string(),
271 error_reason: error,
272 };
273
274 self.session
275 .send_command::<_, fetch_cdp::FailRequestReturnObject>(fail, None)
276 .await?;
277
278 Ok(())
279 }
280
281 async fn fulfill_request(
282 self: &Arc<Self>,
283 request_id: &str,
284 response: ResponseMock,
285 ) -> Result<()> {
286 let body_bytes = response.body.into_bytes();
288
289 let headers = response
290 .headers
291 .into_iter()
292 .map(|(k, v)| fetch_cdp::HeaderEntry { name: k, value: v })
293 .collect();
294
295 let fulfill = FulfillRequest {
296 request_id: request_id.to_string(),
297 response_code: response.status_code as u32,
298 response_headers: Some(headers),
299 binary_response_headers: None,
300 body: Some(body_bytes),
301 response_phrase: None,
302 };
303
304 self.session
305 .send_command::<_, fetch_cdp::FulfillRequestReturnObject>(fulfill, None)
306 .await?;
307
308 Ok(())
309 }
310
311 async fn continue_response(self: &Arc<Self>, request_id: &str) -> Result<()> {
312 let cont = ContinueResponse {
313 request_id: request_id.to_string(),
314 response_code: None,
315 response_phrase: None,
316 response_headers: None,
317 binary_response_headers: None,
318 };
319
320 self.session
321 .send_command::<_, fetch_cdp::ContinueResponseReturnObject>(cont, None)
322 .await?;
323
324 Ok(())
325 }
326
327 async fn continue_response_with_modification(
328 self: &Arc<Self>,
329 request_id: &str,
330 response: ResponseMock,
331 ) -> Result<()> {
332 let headers = response
333 .headers
334 .into_iter()
335 .map(|(k, v)| fetch_cdp::HeaderEntry { name: k, value: v })
336 .collect();
337
338 let body_bytes = response.body.into_bytes();
340
341 let cont = ContinueResponse {
342 request_id: request_id.to_string(),
343 response_code: Some(response.status_code as u32),
344 response_phrase: None,
345 response_headers: Some(headers),
346 binary_response_headers: Some(body_bytes),
347 };
348
349 self.session
350 .send_command::<_, fetch_cdp::ContinueResponseReturnObject>(cont, None)
351 .await?;
352
353 Ok(())
354 }
355}
356
357#[async_trait]
359pub trait RequestInterceptorExt {
360 async fn intercept_all_requests(self: &Arc<Self>) -> Result<()>;
362
363 async fn intercept_requests_matching(self: &Arc<Self>, pattern: &str) -> Result<()>;
365
366 async fn block_images(self: &Arc<Self>) -> Result<()>;
368
369 async fn block_stylesheets(self: &Arc<Self>) -> Result<()>;
371}
372
373#[async_trait]
374impl RequestInterceptorExt for Page {
375 async fn intercept_all_requests(self: &Arc<Self>) -> Result<()> {
376 self.enable_request_interception(vec!["*".to_string()])
377 .await
378 }
379
380 async fn intercept_requests_matching(self: &Arc<Self>, pattern: &str) -> Result<()> {
381 self.enable_request_interception(vec![pattern.to_string()])
382 .await
383 }
384
385 async fn block_images(self: &Arc<Self>) -> Result<()> {
386 self.enable_request_interception(vec![
387 "*.png".to_string(),
388 "*.jpg".to_string(),
389 "*.jpeg".to_string(),
390 "*.gif".to_string(),
391 "*.webp".to_string(),
392 ])
393 .await
394 }
395
396 async fn block_stylesheets(self: &Arc<Self>) -> Result<()> {
397 self.enable_request_interception(vec!["*.css".to_string()])
398 .await
399 }
400}
401
402#[derive(Clone, Debug)]
406pub enum NetworkEvent {
407 RequestWillBeSent {
409 request_id: String,
410 url: String,
411 method: String,
412 headers: serde_json::Value,
413 },
414 LoadingFinished { request_id: String },
416 LoadingFailed {
418 request_id: String,
419 error_text: String,
420 },
421 ResponseReceived {
423 request_id: String,
424 status: i64,
425 headers: serde_json::Value,
426 },
427 RequestServedFromCache { request_id: String },
429}
430
431pub type NetworkEventCallback = Arc<dyn Fn(NetworkEvent) + Send + Sync>;
433
434pub struct NetworkMonitor {
436 pub callbacks: Arc<Mutex<Vec<NetworkEventCallback>>>,
438 inflight_count: Arc<std::sync::atomic::AtomicUsize>,
440 active_requests: Arc<Mutex<HashSet<String>>>,
442}
443
444impl NetworkMonitor {
445 fn new() -> Self {
446 Self {
447 inflight_count: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
448 callbacks: Arc::new(Mutex::new(Vec::new())),
449 active_requests: Arc::new(Mutex::new(HashSet::new())),
450 }
451 }
452
453 pub fn get_inflight_count(&self) -> usize {
455 self.inflight_count
456 .load(std::sync::atomic::Ordering::SeqCst)
457 }
458
459 pub async fn request_started(&self, request_id: &str) {
461 let mut active = self.active_requests.lock().await;
462 if active.insert(request_id.to_string()) {
463 self.inflight_count
464 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
465 } else {
466 tracing::trace!("request_started called for tracked request {request_id}");
467 }
468 }
469
470 pub async fn request_finished(&self, request_id: &str) {
472 let mut active = self.active_requests.lock().await;
473 if active.remove(request_id) {
474 if self
475 .inflight_count
476 .fetch_update(
477 std::sync::atomic::Ordering::SeqCst,
478 std::sync::atomic::Ordering::SeqCst,
479 |current| current.checked_sub(1),
480 )
481 .is_err()
482 {
483 self.inflight_count
485 .store(0, std::sync::atomic::Ordering::SeqCst);
486 tracing::warn!(
487 "request_finished detected underflow for {request_id}, resetting inflight count"
488 );
489 }
490 } else {
491 tracing::trace!("request_finished called for unknown request {request_id}");
492 }
493 }
494
495 pub async fn reset_inflight(&self) {
497 self.inflight_count
498 .store(0, std::sync::atomic::Ordering::SeqCst);
499 self.active_requests.lock().await.clear();
500 }
501
502 pub async fn add_callback(&self, callback: NetworkEventCallback) {
504 self.callbacks.lock().await.push(callback);
505 }
506
507 pub async fn trigger_event(&self, event: NetworkEvent) {
509 let callbacks = self.callbacks.lock().await;
510 for callback in callbacks.iter() {
511 callback(event.clone());
512 }
513 }
514}
515
516impl Default for NetworkMonitor {
517 fn default() -> Self {
518 Self::new()
519 }
520}
521
522pub type ResponseFilterCallback = Arc<dyn Fn(&str) -> bool + Send + Sync>;
526
527pub type ResponseHandlerCallback = Arc<dyn Fn(&InterceptedResponse) + Send + Sync>;
529
530pub struct ResponseMonitorManager {
532 monitors: Mutex<Vec<(ResponseFilterCallback, ResponseHandlerCallback)>>,
534 enabled: std::sync::atomic::AtomicBool,
536 pending_responses: Mutex<HashMap<String, InterceptedResponse>>,
538}
539
540impl ResponseMonitorManager {
541 fn new() -> Self {
542 Self {
543 monitors: Mutex::new(Vec::new()),
544 enabled: std::sync::atomic::AtomicBool::new(false),
545 pending_responses: Mutex::new(HashMap::new()),
546 }
547 }
548
549 pub fn is_enabled(&self) -> bool {
551 self.enabled.load(std::sync::atomic::Ordering::SeqCst)
552 }
553
554 pub async fn add_monitor(
556 &self,
557 filter: ResponseFilterCallback,
558 handler: ResponseHandlerCallback,
559 ) {
560 let mut monitors = self.monitors.lock().await;
561 monitors.push((filter, handler));
562 self.enabled
564 .store(true, std::sync::atomic::Ordering::SeqCst);
565 }
566
567 pub async fn clear_monitors(&self) {
569 let mut monitors = self.monitors.lock().await;
570 monitors.clear();
571 self.enabled
573 .store(false, std::sync::atomic::Ordering::SeqCst);
574 }
575
576 pub async fn handle_response(&self, response: &InterceptedResponse) {
578 if !self.is_enabled() {
580 return;
581 }
582 let monitors = self.monitors.lock().await;
583 for (_, handler) in monitors.iter() {
584 handler(response);
585 }
586 }
587
588 pub async fn filter_url(&self, url: &str) -> bool {
589 if !self.is_enabled() {
590 return false;
591 }
592
593 let monitors = self.monitors.lock().await;
594 monitors.iter().any(|(filter, _)| filter(url))
595 }
596
597 pub async fn store_pending_response(&self, response: InterceptedResponse) {
598 self.pending_responses
599 .lock()
600 .await
601 .insert(response.request_id.clone(), response);
602 }
603
604 pub async fn retrieve_pending_response(&self, request_id: &str) -> Option<InterceptedResponse> {
605 self.pending_responses.lock().await.remove(request_id)
606 }
607}
608
609impl Default for ResponseMonitorManager {
610 fn default() -> Self {
611 Self::new()
612 }
613}
614
615impl Page {
617 pub async fn monitor_responses<F, H>(self: &Arc<Self>, filter: F, handler: H) -> Result<()>
641 where
642 F: Fn(&str) -> bool + Send + Sync + 'static,
643 H: Fn(&InterceptedResponse) + Send + Sync + 'static,
644 {
645 if !self.domain_manager.is_enabled(DomainType::Network).await {
647 self.domain_manager.enable_network_domain().await?;
648 }
649
650 self.response_monitor_manager
652 .add_monitor(Arc::new(filter), Arc::new(handler))
653 .await;
654
655 Ok(())
656 }
657
658 pub async fn monitor_responses_matching<H>(
679 self: &Arc<Self>,
680 url_pattern: &str,
681 handler: H,
682 ) -> Result<()>
683 where
684 H: Fn(&InterceptedResponse) + Send + Sync + 'static,
685 {
686 let pattern = url_pattern.to_string();
687 self.monitor_responses(move |url| url.contains(&pattern), handler)
688 .await
689 }
690
691 pub async fn stop_response_monitoring(self: &Arc<Self>) -> Result<()> {
703 self.response_monitor_manager.clear_monitors().await;
705 Ok(())
706 }
707}