cdp_core/network/
network_intercept.rs

1use crate::DomainType;
2use crate::error::Result;
3use crate::page::Page;
4use async_trait::async_trait;
5use cdp_protocol::fetch::{
6    self as fetch_cdp, ContinueRequest, ContinueResponse, FailRequest, FulfillRequest,
7    RequestPattern, RequestStage,
8};
9use cdp_protocol::network;
10use serde::{Deserialize, Serialize};
11use std::collections::{HashMap, HashSet};
12use std::str::FromStr;
13use std::sync::Arc;
14use tokio::sync::Mutex;
15
16/// HTTP methods supported by the interceptor helpers.
17#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
18pub enum HttpMethod {
19    GET,
20    POST,
21    PUT,
22    DELETE,
23    PATCH,
24    HEAD,
25    OPTIONS,
26    CONNECT,
27    TRACE,
28}
29
30impl HttpMethod {
31    pub fn as_str(&self) -> &'static str {
32        match self {
33            HttpMethod::GET => "GET",
34            HttpMethod::POST => "POST",
35            HttpMethod::PUT => "PUT",
36            HttpMethod::DELETE => "DELETE",
37            HttpMethod::PATCH => "PATCH",
38            HttpMethod::HEAD => "HEAD",
39            HttpMethod::OPTIONS => "OPTIONS",
40            HttpMethod::CONNECT => "CONNECT",
41            HttpMethod::TRACE => "TRACE",
42        }
43    }
44}
45
46impl FromStr for HttpMethod {
47    type Err = ();
48
49    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
50        match s.to_uppercase().as_str() {
51            "GET" => Ok(HttpMethod::GET),
52            "POST" => Ok(HttpMethod::POST),
53            "PUT" => Ok(HttpMethod::PUT),
54            "DELETE" => Ok(HttpMethod::DELETE),
55            "PATCH" => Ok(HttpMethod::PATCH),
56            "HEAD" => Ok(HttpMethod::HEAD),
57            "OPTIONS" => Ok(HttpMethod::OPTIONS),
58            "CONNECT" => Ok(HttpMethod::CONNECT),
59            "TRACE" => Ok(HttpMethod::TRACE),
60            _ => Ok(HttpMethod::GET), // Default to GET for unknown methods
61        }
62    }
63}
64
65/// Metadata captured for an intercepted request.
66#[derive(Debug, Clone)]
67pub struct InterceptedRequest {
68    /// The CDP-assigned request identifier.
69    pub request_id: String,
70    /// Request URL.
71    pub url: String,
72    /// HTTP method in use.
73    pub method: HttpMethod,
74    /// Request headers.
75    pub headers: HashMap<String, String>,
76    /// POST body if present.
77    pub post_data: Option<String>,
78    /// Resource type reported by CDP.
79    pub resource_type: Option<String>,
80}
81
/// Metadata captured for an intercepted or monitored response.
#[derive(Debug, Clone)]
pub struct InterceptedResponse {
    /// The CDP-assigned request identifier.
    pub request_id: String,
    /// HTTP status code.
    pub status_code: i64,
    /// HTTP status text (reason phrase).
    pub status_text: String,
    /// Response headers, one value per header name.
    pub headers: HashMap<String, String>,
    /// `true` when `body` holds base64-encoded binary data rather than plain text.
    pub base_64_encoded: bool,
    /// Response body captured for the request.
    ///
    /// - When available the payload contains either raw text or a base64 encoded
    ///   blob; `base_64_encoded` tells the two apart.
    /// - Requests with no body (304, 204, and similar) leave this as `None`.
    /// - Binary responses are stored using base64 encoding.
    pub body: Option<String>,
}
101
102/// Options used when mutating a request before it continues.
103#[derive(Debug, Clone, Default)]
104pub struct RequestModification {
105    /// Updated URL.
106    pub url: Option<String>,
107    /// Updated HTTP method.
108    pub method: Option<HttpMethod>,
109    /// Updated headers.
110    pub headers: Option<HashMap<String, String>>,
111    /// Updated POST body.
112    pub post_data: Option<String>,
113}
114
/// Synthetic response used to answer a request without issuing a network request.
#[derive(Debug, Clone)]
pub struct ResponseMock {
    /// Status code reported to the client.
    pub status_code: i64,
    /// Response headers reported to the client.
    pub headers: HashMap<String, String>,
    /// Response body text (handed to CDP as bytes, which the protocol
    /// transmits base64-encoded).
    pub body: String,
}

impl Default for ResponseMock {
    /// An empty `200` response without headers.
    fn default() -> Self {
        let status_code = 200;
        Self {
            status_code,
            headers: HashMap::default(),
            body: String::default(),
        }
    }
}
135
136/// Trait describing the interception primitives exposed by `Page`.
137#[async_trait]
138pub trait NetworkInterceptor {
139    /// Enables request interception with the provided URL patterns.
140    async fn enable_request_interception(self: &Arc<Self>, patterns: Vec<String>) -> Result<()>;
141
142    /// Disables request interception.
143    async fn disable_request_interception(self: &Arc<Self>) -> Result<()>;
144
145    /// Continues a request without modification.
146    async fn continue_request(self: &Arc<Self>, request_id: &str) -> Result<()>;
147
148    /// Continues a request after applying the provided modifications.
149    async fn continue_request_with_modification(
150        self: &Arc<Self>,
151        request_id: &str,
152        modification: RequestModification,
153    ) -> Result<()>;
154
155    /// Aborts the request with the CDP error reason provided.
156    async fn fail_request(self: &Arc<Self>, request_id: &str, error_reason: &str) -> Result<()>;
157
158    /// Fulfills the request with a mocked response payload.
159    async fn fulfill_request(
160        self: &Arc<Self>,
161        request_id: &str,
162        response: ResponseMock,
163    ) -> Result<()>;
164
165    /// Continues the response without modification.
166    async fn continue_response(self: &Arc<Self>, request_id: &str) -> Result<()>;
167
168    /// Continues the response after applying modifications.
169    async fn continue_response_with_modification(
170        self: &Arc<Self>,
171        request_id: &str,
172        response: ResponseMock,
173    ) -> Result<()>;
174}
175
176#[async_trait]
177impl NetworkInterceptor for Page {
178    async fn enable_request_interception(self: &Arc<Self>, patterns: Vec<String>) -> Result<()> {
179        // Convert raw URL patterns into the CDP request pattern format.
180        let request_patterns = patterns
181            .into_iter()
182            .map(|url_pattern| RequestPattern {
183                url_pattern: Some(url_pattern),
184                resource_type: None,
185                request_stage: Some(RequestStage::Request),
186            })
187            .collect();
188
189        // Enable the Fetch domain with our pattern set.
190        self.domain_manager
191            .enable_fetch_domain_with_patterns(Some(request_patterns))
192            .await?;
193
194        Ok(())
195    }
196
197    async fn disable_request_interception(self: &Arc<Self>) -> Result<()> {
198        // Disable the Fetch domain for the current session.
199        self.domain_manager.disable_fetch_domain().await?;
200        Ok(())
201    }
202
203    async fn continue_request(self: &Arc<Self>, request_id: &str) -> Result<()> {
204        let cont = ContinueRequest {
205            request_id: request_id.to_string(),
206            url: None,
207            method: None,
208            post_data: None,
209            headers: None,
210            intercept_response: None,
211        };
212
213        self.session
214            .send_command::<_, fetch_cdp::ContinueRequestReturnObject>(cont, None)
215            .await?;
216
217        Ok(())
218    }
219
220    async fn continue_request_with_modification(
221        self: &Arc<Self>,
222        request_id: &str,
223        modification: RequestModification,
224    ) -> Result<()> {
225        let headers = modification.headers.map(|h| {
226            h.into_iter()
227                .map(|(k, v)| fetch_cdp::HeaderEntry { name: k, value: v })
228                .collect()
229        });
230
231        let post_data = modification.post_data.map(|s| s.into_bytes());
232
233        let cont = ContinueRequest {
234            request_id: request_id.to_string(),
235            url: modification.url,
236            method: modification.method.map(|m| m.as_str().to_string()),
237            post_data,
238            headers,
239            intercept_response: None,
240        };
241
242        self.session
243            .send_command::<_, fetch_cdp::ContinueRequestReturnObject>(cont, None)
244            .await?;
245
246        Ok(())
247    }
248
249    async fn fail_request(self: &Arc<Self>, request_id: &str, error_reason: &str) -> Result<()> {
250        // Convert the supplied string into the CDP `ErrorReason` enum.
251        let error = match error_reason.to_uppercase().as_str() {
252            "FAILED" => network::ErrorReason::Failed,
253            "ABORTED" => network::ErrorReason::Aborted,
254            "TIMEDOUT" => network::ErrorReason::TimedOut,
255            "ACCESSDENIED" => network::ErrorReason::AccessDenied,
256            "CONNECTIONCLOSED" => network::ErrorReason::ConnectionClosed,
257            "CONNECTIONRESET" => network::ErrorReason::ConnectionReset,
258            "CONNECTIONREFUSED" => network::ErrorReason::ConnectionRefused,
259            "CONNECTIONABORTED" => network::ErrorReason::ConnectionAborted,
260            "CONNECTIONFAILED" => network::ErrorReason::ConnectionFailed,
261            "NAMENOTRESOLVED" => network::ErrorReason::NameNotResolved,
262            "INTERNETDISCONNECTED" => network::ErrorReason::InternetDisconnected,
263            "ADDRESSUNREACHABLE" => network::ErrorReason::AddressUnreachable,
264            "BLOCKEDBYCLIENT" => network::ErrorReason::BlockedByClient,
265            "BLOCKEDBYRESPONSE" => network::ErrorReason::BlockedByResponse,
266            _ => network::ErrorReason::Failed,
267        };
268
269        let fail = FailRequest {
270            request_id: request_id.to_string(),
271            error_reason: error,
272        };
273
274        self.session
275            .send_command::<_, fetch_cdp::FailRequestReturnObject>(fail, None)
276            .await?;
277
278        Ok(())
279    }
280
281    async fn fulfill_request(
282        self: &Arc<Self>,
283        request_id: &str,
284        response: ResponseMock,
285    ) -> Result<()> {
286        // Convert the response body into bytes for CDP.
287        let body_bytes = response.body.into_bytes();
288
289        let headers = response
290            .headers
291            .into_iter()
292            .map(|(k, v)| fetch_cdp::HeaderEntry { name: k, value: v })
293            .collect();
294
295        let fulfill = FulfillRequest {
296            request_id: request_id.to_string(),
297            response_code: response.status_code as u32,
298            response_headers: Some(headers),
299            binary_response_headers: None,
300            body: Some(body_bytes),
301            response_phrase: None,
302        };
303
304        self.session
305            .send_command::<_, fetch_cdp::FulfillRequestReturnObject>(fulfill, None)
306            .await?;
307
308        Ok(())
309    }
310
311    async fn continue_response(self: &Arc<Self>, request_id: &str) -> Result<()> {
312        let cont = ContinueResponse {
313            request_id: request_id.to_string(),
314            response_code: None,
315            response_phrase: None,
316            response_headers: None,
317            binary_response_headers: None,
318        };
319
320        self.session
321            .send_command::<_, fetch_cdp::ContinueResponseReturnObject>(cont, None)
322            .await?;
323
324        Ok(())
325    }
326
327    async fn continue_response_with_modification(
328        self: &Arc<Self>,
329        request_id: &str,
330        response: ResponseMock,
331    ) -> Result<()> {
332        let headers = response
333            .headers
334            .into_iter()
335            .map(|(k, v)| fetch_cdp::HeaderEntry { name: k, value: v })
336            .collect();
337
338        // Convert the response body into bytes for CDP.
339        let body_bytes = response.body.into_bytes();
340
341        let cont = ContinueResponse {
342            request_id: request_id.to_string(),
343            response_code: Some(response.status_code as u32),
344            response_phrase: None,
345            response_headers: Some(headers),
346            binary_response_headers: Some(body_bytes),
347        };
348
349        self.session
350            .send_command::<_, fetch_cdp::ContinueResponseReturnObject>(cont, None)
351            .await?;
352
353        Ok(())
354    }
355}
356
357/// Convenience helpers that wrap the interceptor trait.
358#[async_trait]
359pub trait RequestInterceptorExt {
360    /// Intercepts every request.
361    async fn intercept_all_requests(self: &Arc<Self>) -> Result<()>;
362
363    /// Intercepts requests that match the provided pattern.
364    async fn intercept_requests_matching(self: &Arc<Self>, pattern: &str) -> Result<()>;
365
366    /// Blocks common image formats.
367    async fn block_images(self: &Arc<Self>) -> Result<()>;
368
369    /// Blocks stylesheet resources.
370    async fn block_stylesheets(self: &Arc<Self>) -> Result<()>;
371}
372
373#[async_trait]
374impl RequestInterceptorExt for Page {
375    async fn intercept_all_requests(self: &Arc<Self>) -> Result<()> {
376        self.enable_request_interception(vec!["*".to_string()])
377            .await
378    }
379
380    async fn intercept_requests_matching(self: &Arc<Self>, pattern: &str) -> Result<()> {
381        self.enable_request_interception(vec![pattern.to_string()])
382            .await
383    }
384
385    async fn block_images(self: &Arc<Self>) -> Result<()> {
386        self.enable_request_interception(vec![
387            "*.png".to_string(),
388            "*.jpg".to_string(),
389            "*.jpeg".to_string(),
390            "*.gif".to_string(),
391            "*.webp".to_string(),
392        ])
393        .await
394    }
395
396    async fn block_stylesheets(self: &Arc<Self>) -> Result<()> {
397        self.enable_request_interception(vec!["*.css".to_string()])
398            .await
399    }
400}
401
402// ========= Network monitoring =========
403
404/// Network event envelope emitted by the monitor.
405#[derive(Clone, Debug)]
406pub enum NetworkEvent {
407    /// Request about to be sent.
408    RequestWillBeSent {
409        request_id: String,
410        url: String,
411        method: String,
412        headers: serde_json::Value,
413    },
414    /// Request finished loading.
415    LoadingFinished { request_id: String },
416    /// Request failed to load.
417    LoadingFailed {
418        request_id: String,
419        error_text: String,
420    },
421    /// Response received.
422    ResponseReceived {
423        request_id: String,
424        status: i64,
425        headers: serde_json::Value,
426    },
427    /// Response served from cache.
428    RequestServedFromCache { request_id: String },
429}
430
431/// Network event callback signature.
432pub type NetworkEventCallback = Arc<dyn Fn(NetworkEvent) + Send + Sync>;
433
434/// Tracks network activity and propagates events to user callbacks.
435pub struct NetworkMonitor {
436    /// Registered callbacks.
437    pub callbacks: Arc<Mutex<Vec<NetworkEventCallback>>>,
438    /// Number of inflight requests.
439    inflight_count: Arc<std::sync::atomic::AtomicUsize>,
440    /// Known request identifiers to prevent duplicate accounting.
441    active_requests: Arc<Mutex<HashSet<String>>>,
442}
443
444impl NetworkMonitor {
445    fn new() -> Self {
446        Self {
447            inflight_count: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
448            callbacks: Arc::new(Mutex::new(Vec::new())),
449            active_requests: Arc::new(Mutex::new(HashSet::new())),
450        }
451    }
452
453    /// Returns the number of active requests.
454    pub fn get_inflight_count(&self) -> usize {
455        self.inflight_count
456            .load(std::sync::atomic::Ordering::SeqCst)
457    }
458
459    /// Marks a request as started, incrementing the counter only once.
460    pub async fn request_started(&self, request_id: &str) {
461        let mut active = self.active_requests.lock().await;
462        if active.insert(request_id.to_string()) {
463            self.inflight_count
464                .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
465        } else {
466            tracing::trace!("request_started called for tracked request {request_id}");
467        }
468    }
469
470    /// Marks a request as finished and decrements the counter if tracked.
471    pub async fn request_finished(&self, request_id: &str) {
472        let mut active = self.active_requests.lock().await;
473        if active.remove(request_id) {
474            if self
475                .inflight_count
476                .fetch_update(
477                    std::sync::atomic::Ordering::SeqCst,
478                    std::sync::atomic::Ordering::SeqCst,
479                    |current| current.checked_sub(1),
480                )
481                .is_err()
482            {
483                // Reset to zero if a mismatch occurs to avoid underflow.
484                self.inflight_count
485                    .store(0, std::sync::atomic::Ordering::SeqCst);
486                tracing::warn!(
487                    "request_finished detected underflow for {request_id}, resetting inflight count"
488                );
489            }
490        } else {
491            tracing::trace!("request_finished called for unknown request {request_id}");
492        }
493    }
494
495    /// Resets the inflight counter.
496    pub async fn reset_inflight(&self) {
497        self.inflight_count
498            .store(0, std::sync::atomic::Ordering::SeqCst);
499        self.active_requests.lock().await.clear();
500    }
501
502    /// Registers a network event callback.
503    pub async fn add_callback(&self, callback: NetworkEventCallback) {
504        self.callbacks.lock().await.push(callback);
505    }
506
507    /// Emits an event to all registered callbacks.
508    pub async fn trigger_event(&self, event: NetworkEvent) {
509        let callbacks = self.callbacks.lock().await;
510        for callback in callbacks.iter() {
511            callback(event.clone());
512        }
513    }
514}
515
516impl Default for NetworkMonitor {
517    fn default() -> Self {
518        Self::new()
519    }
520}
521
522// ========= Response monitoring =========
523
524/// Callback used to decide whether a URL should be inspected.
525pub type ResponseFilterCallback = Arc<dyn Fn(&str) -> bool + Send + Sync>;
526
527/// Callback invoked with the captured response metadata.
528pub type ResponseHandlerCallback = Arc<dyn Fn(&InterceptedResponse) + Send + Sync>;
529
530/// Manages response filters and handlers.
531pub struct ResponseMonitorManager {
532    /// Registered monitor pairs.
533    monitors: Mutex<Vec<(ResponseFilterCallback, ResponseHandlerCallback)>>,
534    /// Tracks whether monitoring is enabled.
535    enabled: std::sync::atomic::AtomicBool,
536    /// Pending responses waiting for body (requestId -> Response)
537    pending_responses: Mutex<HashMap<String, InterceptedResponse>>,
538}
539
540impl ResponseMonitorManager {
541    fn new() -> Self {
542        Self {
543            monitors: Mutex::new(Vec::new()),
544            enabled: std::sync::atomic::AtomicBool::new(false),
545            pending_responses: Mutex::new(HashMap::new()),
546        }
547    }
548
549    /// Returns whether response monitoring is active.
550    pub fn is_enabled(&self) -> bool {
551        self.enabled.load(std::sync::atomic::Ordering::SeqCst)
552    }
553
554    /// Adds a monitor pair and enables monitoring.
555    pub async fn add_monitor(
556        &self,
557        filter: ResponseFilterCallback,
558        handler: ResponseHandlerCallback,
559    ) {
560        let mut monitors = self.monitors.lock().await;
561        monitors.push((filter, handler));
562        // Automatically enable monitoring when a handler exists.
563        self.enabled
564            .store(true, std::sync::atomic::Ordering::SeqCst);
565    }
566
567    /// Clears all monitors and disables monitoring.
568    pub async fn clear_monitors(&self) {
569        let mut monitors = self.monitors.lock().await;
570        monitors.clear();
571        // Disable monitoring once the list is empty.
572        self.enabled
573            .store(false, std::sync::atomic::Ordering::SeqCst);
574    }
575
576    /// Dispatches the response to all registered handlers.
577    pub async fn handle_response(&self, response: &InterceptedResponse) {
578        // Skip work unless monitoring is active.
579        if !self.is_enabled() {
580            return;
581        }
582        let monitors = self.monitors.lock().await;
583        for (_, handler) in monitors.iter() {
584            handler(response);
585        }
586    }
587
588    pub async fn filter_url(&self, url: &str) -> bool {
589        if !self.is_enabled() {
590            return false;
591        }
592
593        let monitors = self.monitors.lock().await;
594        monitors.iter().any(|(filter, _)| filter(url))
595    }
596
597    pub async fn store_pending_response(&self, response: InterceptedResponse) {
598        self.pending_responses
599            .lock()
600            .await
601            .insert(response.request_id.clone(), response);
602    }
603
604    pub async fn retrieve_pending_response(&self, request_id: &str) -> Option<InterceptedResponse> {
605        self.pending_responses.lock().await.remove(request_id)
606    }
607}
608
609impl Default for ResponseMonitorManager {
610    fn default() -> Self {
611        Self::new()
612    }
613}
614
615/// Response monitoring convenience methods
616impl Page {
617    /// Registers a non-blocking handler for all responses that pass the filter.
618    ///
619    /// # Parameters
620    /// * `filter` - Returns `true` when a response should be forwarded to the handler.
621    /// * `handler` - Receives the captured response metadata.
622    ///
623    /// # Examples
624    /// ```no_run
625    /// # use cdp_core::Page;
626    /// # use std::sync::Arc;
627    /// # async fn example(page: Arc<Page>) -> anyhow::Result<()> {
628    /// page.monitor_responses(
629    ///     |url| url.contains("/api/"),
630    ///     |response| {
631    ///         println!("API Response: {} - {}", response.status_code, response.status_text);
632    ///         if let Some(body) = &response.body {
633    ///             println!("Body: {}", body);
634    ///         }
635    ///     },
636    /// ).await?;
637    /// # Ok(())
638    /// # }
639    /// ```
640    pub async fn monitor_responses<F, H>(self: &Arc<Self>, filter: F, handler: H) -> Result<()>
641    where
642        F: Fn(&str) -> bool + Send + Sync + 'static,
643        H: Fn(&InterceptedResponse) + Send + Sync + 'static,
644    {
645        // Ensure the Network domain is active before monitoring.
646        if !self.domain_manager.is_enabled(DomainType::Network).await {
647            self.domain_manager.enable_network_domain().await?;
648        }
649
650        // Register the monitor pair.
651        self.response_monitor_manager
652            .add_monitor(Arc::new(filter), Arc::new(handler))
653            .await;
654
655        Ok(())
656    }
657
658    /// Registers a handler for responses whose URLs contain the given pattern.
659    ///
660    /// # Parameters
661    /// * `url_pattern` - Substring that must be present in the URL.
662    /// * `handler` - Invoked with the captured response metadata.
663    ///
664    /// # Examples
665    /// ```no_run
666    /// # use cdp_core::Page;
667    /// # use std::sync::Arc;
668    /// # async fn example(page: Arc<Page>) -> anyhow::Result<()> {
669    /// page.monitor_responses_matching(
670    ///     "data.json",
671    ///     |response| {
672    ///         println!("Data Response: {}", response.status_code);
673    ///     },
674    /// ).await?;
675    /// # Ok(())
676    /// # }
677    /// ```
678    pub async fn monitor_responses_matching<H>(
679        self: &Arc<Self>,
680        url_pattern: &str,
681        handler: H,
682    ) -> Result<()>
683    where
684        H: Fn(&InterceptedResponse) + Send + Sync + 'static,
685    {
686        let pattern = url_pattern.to_string();
687        self.monitor_responses(move |url| url.contains(&pattern), handler)
688            .await
689    }
690
691    /// Removes all registered response monitors.
692    ///
693    /// # Examples
694    /// ```no_run
695    /// # use cdp_core::Page;
696    /// # use std::sync::Arc;
697    /// # async fn example(page: Arc<Page>) -> anyhow::Result<()> {
698    /// page.stop_response_monitoring().await?;
699    /// # Ok(())
700    /// # }
701    /// ```
702    pub async fn stop_response_monitoring(self: &Arc<Self>) -> Result<()> {
703        // Remove all handlers; keep the Network domain enabled because other features rely on it.
704        self.response_monitor_manager.clear_monitors().await;
705        Ok(())
706    }
707}