Skip to main content

firefox_webdriver/browser/tab/
network.rs

1//! Network interception and blocking methods.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicU64, Ordering};
6
7use base64::Engine;
8use base64::engine::general_purpose::STANDARD as Base64Standard;
9use serde_json::Value;
10use tracing::debug;
11
12use crate::browser::network::{
13    BodyAction, HeadersAction, InterceptedRequest, InterceptedRequestBody,
14    InterceptedRequestHeaders, InterceptedResponse, InterceptedResponseBody, RequestAction,
15    RequestBody, ResponseBodyData,
16};
17use crate::error::{Error, Result};
18use crate::identifiers::InterceptId;
19use crate::protocol::{Command, Event, EventReply, NetworkCommand, Response};
20
21use super::Tab;
22
23// ============================================================================
24// Constants
25// ============================================================================
26
/// Event method answered by handlers installed via `InterceptBuilder::requests`.
const EVENT_BEFORE_REQUEST: &str = "network.beforeRequestSent";
/// Event method answered by handlers installed via `InterceptBuilder::request_headers`.
const EVENT_REQUEST_HEADERS: &str = "network.requestHeaders";
/// Event method answered by handlers installed via `InterceptBuilder::request_body`.
const EVENT_REQUEST_BODY: &str = "network.requestBody";
/// Event method answered by handlers installed via `InterceptBuilder::responses`.
const EVENT_RESPONSE_HEADERS: &str = "network.responseHeaders";
/// Event method answered by handlers installed via `InterceptBuilder::response_body`.
const EVENT_RESPONSE_BODY: &str = "network.responseBody";
32
/// Process-wide sequence number backing [`next_handler_key`].
///
/// Each intercept registration draws a fresh value, so two registrations of the
/// same kind never share a handler key.
static HANDLER_KEY_COUNTER: AtomicU64 = AtomicU64::new(0);

/// Returns a handler key of the form `<prefix>_<n>`, where `n` is taken from a
/// process-wide counter and is therefore distinct for every call.
fn next_handler_key(prefix: &str) -> String {
    let sequence = HANDLER_KEY_COUNTER.fetch_add(1, Ordering::Relaxed);
    [prefix, sequence.to_string().as_str()].join("_")
}
41
42// ============================================================================
43// InterceptBuilder
44// ============================================================================
45
46/// Builder for configuring network intercept filters.
47///
48/// # Example
49///
50/// ```ignore
51/// let id = tab.intercept()
52///     .url_patterns(vec!["*api*".into()])
53///     .resource_types(vec!["xhr".into(), "fetch".into()])
54///     .requests(|req| {
55///         RequestAction::allow()
56///     }).await?;
57/// ```
58pub struct InterceptBuilder<'a> {
59    tab: &'a Tab,
60    url_patterns: Option<Vec<String>>,
61    resource_types: Option<Vec<String>>,
62}
63
64impl<'a> InterceptBuilder<'a> {
65    /// Sets URL patterns to filter intercepted requests/responses.
66    ///
67    /// Patterns support glob-style wildcards (`*`).
68    #[inline]
69    #[must_use]
70    pub fn url_patterns(mut self, patterns: Vec<String>) -> Self {
71        self.url_patterns = Some(patterns);
72        self
73    }
74
75    /// Sets resource types to filter intercepted requests/responses.
76    ///
77    /// Valid types: "document", "script", "stylesheet", "image", "media",
78    /// "font", "xhr", "fetch", "websocket", "other".
79    #[inline]
80    #[must_use]
81    pub fn resource_types(mut self, types: Vec<String>) -> Self {
82        self.resource_types = Some(types);
83        self
84    }
85
86    /// Intercepts requests with the configured filters.
87    ///
88    /// # Returns
89    ///
90    /// An `InterceptId` that can be used to stop this intercept.
91    pub async fn requests<F>(self, callback: F) -> Result<InterceptId>
92    where
93        F: Fn(InterceptedRequest) -> RequestAction + Send + Sync + 'static,
94    {
95        debug!(tab_id = %self.tab.inner.tab_id, "Enabling filtered request interception");
96
97        let window = self.tab.get_window()?;
98        let callback = Arc::new(callback);
99        let handler_key = next_handler_key("intercept_request");
100
101        window.inner.pool.add_event_handler(
102            window.inner.session_id,
103            handler_key.clone(),
104            Box::new(move |event: Event| {
105                if event.method.as_str() != EVENT_BEFORE_REQUEST {
106                    return None;
107                }
108
109                let request = parse_intercepted_request(&event);
110                let action = callback(request);
111                let result = request_action_to_json(&action);
112
113                Some(EventReply::new(
114                    event.id,
115                    EVENT_BEFORE_REQUEST,
116                    result,
117                ))
118            }),
119        );
120
121        let command = Command::Network(NetworkCommand::AddIntercept {
122            intercept_requests: true,
123            intercept_request_headers: false,
124            intercept_request_body: false,
125            intercept_responses: false,
126            intercept_response_body: false,
127            url_patterns: self.url_patterns,
128            resource_types: self.resource_types,
129        });
130
131        let response = self.tab.send_command(command).await?;
132        let intercept_id = extract_intercept_id(&response)?;
133        window.inner.intercept_handlers.lock().insert(intercept_id.clone(), handler_key);
134        Ok(intercept_id)
135    }
136
137    /// Intercepts request headers with the configured filters.
138    pub async fn request_headers<F>(self, callback: F) -> Result<InterceptId>
139    where
140        F: Fn(InterceptedRequestHeaders) -> HeadersAction + Send + Sync + 'static,
141    {
142        debug!(tab_id = %self.tab.inner.tab_id, "Enabling filtered request headers interception");
143
144        let window = self.tab.get_window()?;
145        let callback = Arc::new(callback);
146        let handler_key = next_handler_key("intercept_request_headers");
147
148        window.inner.pool.add_event_handler(
149            window.inner.session_id,
150            handler_key.clone(),
151            Box::new(move |event: Event| {
152                if event.method.as_str() != EVENT_REQUEST_HEADERS {
153                    return None;
154                }
155
156                let headers_data = parse_intercepted_request_headers(&event);
157                let action = callback(headers_data);
158                let result = headers_action_to_json(&action);
159
160                Some(EventReply::new(event.id, EVENT_REQUEST_HEADERS, result))
161            }),
162        );
163
164        let command = Command::Network(NetworkCommand::AddIntercept {
165            intercept_requests: false,
166            intercept_request_headers: true,
167            intercept_request_body: false,
168            intercept_responses: false,
169            intercept_response_body: false,
170            url_patterns: self.url_patterns,
171            resource_types: self.resource_types,
172        });
173
174        let response = self.tab.send_command(command).await?;
175        let intercept_id = extract_intercept_id(&response)?;
176        window.inner.intercept_handlers.lock().insert(intercept_id.clone(), handler_key);
177        Ok(intercept_id)
178    }
179
180    /// Intercepts request body with the configured filters (read-only).
181    pub async fn request_body<F>(self, callback: F) -> Result<InterceptId>
182    where
183        F: Fn(InterceptedRequestBody) + Send + Sync + 'static,
184    {
185        debug!(tab_id = %self.tab.inner.tab_id, "Enabling filtered request body interception");
186
187        let window = self.tab.get_window()?;
188        let callback = Arc::new(callback);
189        let handler_key = next_handler_key("intercept_request_body");
190
191        window.inner.pool.add_event_handler(
192            window.inner.session_id,
193            handler_key.clone(),
194            Box::new(move |event: Event| {
195                if event.method.as_str() != EVENT_REQUEST_BODY {
196                    return None;
197                }
198
199                let body_data = parse_intercepted_request_body(&event);
200                callback(body_data);
201
202                Some(EventReply::new(
203                    event.id,
204                    EVENT_REQUEST_BODY,
205                    serde_json::json!({ "action": "allow" }),
206                ))
207            }),
208        );
209
210        let command = Command::Network(NetworkCommand::AddIntercept {
211            intercept_requests: false,
212            intercept_request_headers: false,
213            intercept_request_body: true,
214            intercept_responses: false,
215            intercept_response_body: false,
216            url_patterns: self.url_patterns,
217            resource_types: self.resource_types,
218        });
219
220        let response = self.tab.send_command(command).await?;
221        let intercept_id = extract_intercept_id(&response)?;
222        window.inner.intercept_handlers.lock().insert(intercept_id.clone(), handler_key);
223        Ok(intercept_id)
224    }
225
226    /// Intercepts response headers with the configured filters.
227    pub async fn responses<F>(self, callback: F) -> Result<InterceptId>
228    where
229        F: Fn(InterceptedResponse) -> HeadersAction + Send + Sync + 'static,
230    {
231        debug!(tab_id = %self.tab.inner.tab_id, "Enabling filtered response interception");
232
233        let window = self.tab.get_window()?;
234        let callback = Arc::new(callback);
235        let handler_key = next_handler_key("intercept_response");
236
237        window.inner.pool.add_event_handler(
238            window.inner.session_id,
239            handler_key.clone(),
240            Box::new(move |event: Event| {
241                if event.method.as_str() != EVENT_RESPONSE_HEADERS {
242                    return None;
243                }
244
245                let resp = parse_intercepted_response(&event);
246                let action = callback(resp);
247                let result = headers_action_to_json(&action);
248
249                Some(EventReply::new(event.id, EVENT_RESPONSE_HEADERS, result))
250            }),
251        );
252
253        let command = Command::Network(NetworkCommand::AddIntercept {
254            intercept_requests: false,
255            intercept_request_headers: false,
256            intercept_request_body: false,
257            intercept_responses: true,
258            intercept_response_body: false,
259            url_patterns: self.url_patterns,
260            resource_types: self.resource_types,
261        });
262
263        let response = self.tab.send_command(command).await?;
264        let intercept_id = extract_intercept_id(&response)?;
265        window.inner.intercept_handlers.lock().insert(intercept_id.clone(), handler_key);
266        Ok(intercept_id)
267    }
268
269    /// Intercepts response body with the configured filters.
270    pub async fn response_body<F>(self, callback: F) -> Result<InterceptId>
271    where
272        F: Fn(InterceptedResponseBody) -> BodyAction + Send + Sync + 'static,
273    {
274        debug!(tab_id = %self.tab.inner.tab_id, "Enabling filtered response body interception");
275
276        let window = self.tab.get_window()?;
277        let callback = Arc::new(callback);
278        let handler_key = next_handler_key("intercept_response_body");
279
280        window.inner.pool.add_event_handler(
281            window.inner.session_id,
282            handler_key.clone(),
283            Box::new(move |event: Event| {
284                if event.method.as_str() != EVENT_RESPONSE_BODY {
285                    return None;
286                }
287
288                let body_data = parse_intercepted_response_body(&event);
289                let action = callback(body_data);
290                let result = body_action_to_json(&action);
291
292                Some(EventReply::new(event.id, EVENT_RESPONSE_BODY, result))
293            }),
294        );
295
296        let command = Command::Network(NetworkCommand::AddIntercept {
297            intercept_requests: false,
298            intercept_request_headers: false,
299            intercept_request_body: false,
300            intercept_responses: false,
301            intercept_response_body: true,
302            url_patterns: self.url_patterns,
303            resource_types: self.resource_types,
304        });
305
306        let response = self.tab.send_command(command).await?;
307        let intercept_id = extract_intercept_id(&response)?;
308        window.inner.intercept_handlers.lock().insert(intercept_id.clone(), handler_key);
309        Ok(intercept_id)
310    }
311}
312
313// ============================================================================
314// Tab - Network
315// ============================================================================
316
317impl Tab {
318    /// Creates an intercept builder for filtered network interception.
319    ///
320    /// # Example
321    ///
322    /// ```ignore
323    /// let id = tab.intercept()
324    ///     .url_patterns(vec!["*api*".into()])
325    ///     .resource_types(vec!["xhr".into(), "fetch".into()])
326    ///     .requests(|req| RequestAction::allow())
327    ///     .await?;
328    /// ```
329    #[must_use]
330    pub fn intercept(&self) -> InterceptBuilder<'_> {
331        InterceptBuilder {
332            tab: self,
333            url_patterns: None,
334            resource_types: None,
335        }
336    }
337
338    /// Sets URL patterns to block.
339    ///
340    /// Patterns support wildcards (`*`).
341    ///
342    /// # Example
343    ///
344    /// ```ignore
345    /// tab.set_block_rules(&["*ads*", "*tracking*"]).await?;
346    /// ```
347    pub async fn set_block_rules(&self, patterns: &[&str]) -> Result<()> {
348        debug!(tab_id = %self.inner.tab_id, pattern_count = patterns.len(), "Setting block rules");
349
350        let command = Command::Network(NetworkCommand::SetBlockRules {
351            patterns: patterns.iter().map(|s| (*s).to_string()).collect(),
352        });
353
354        self.send_command(command).await?;
355        Ok(())
356    }
357
358    /// Clears all URL block rules.
359    pub async fn clear_block_rules(&self) -> Result<()> {
360        debug!(tab_id = %self.inner.tab_id, "Clearing block rules");
361        let command = Command::Network(NetworkCommand::ClearBlockRules);
362        self.send_command(command).await?;
363        Ok(())
364    }
365
366    /// Intercepts network requests with a callback.
367    ///
368    /// # Returns
369    ///
370    /// An `InterceptId` that can be used to stop this intercept.
371    ///
372    /// # Example
373    ///
374    /// ```ignore
375    /// use firefox_webdriver::RequestAction;
376    ///
377    /// let id = tab.intercept_request(|req| {
378    ///     if req.url.contains("ads") {
379    ///         RequestAction::block()
380    ///     } else {
381    ///         RequestAction::allow()
382    ///     }
383    /// }).await?;
384    /// ```
385    pub async fn intercept_request<F>(&self, callback: F) -> Result<InterceptId>
386    where
387        F: Fn(InterceptedRequest) -> RequestAction + Send + Sync + 'static,
388    {
389        self.intercept().requests(callback).await
390    }
391
392    /// Intercepts request headers with a callback.
393    pub async fn intercept_request_headers<F>(&self, callback: F) -> Result<InterceptId>
394    where
395        F: Fn(InterceptedRequestHeaders) -> HeadersAction + Send + Sync + 'static,
396    {
397        self.intercept().request_headers(callback).await
398    }
399
400    /// Intercepts request body for logging (read-only).
401    pub async fn intercept_request_body<F>(&self, callback: F) -> Result<InterceptId>
402    where
403        F: Fn(InterceptedRequestBody) + Send + Sync + 'static,
404    {
405        self.intercept().request_body(callback).await
406    }
407
408    /// Intercepts response headers with a callback.
409    pub async fn intercept_response<F>(&self, callback: F) -> Result<InterceptId>
410    where
411        F: Fn(InterceptedResponse) -> HeadersAction + Send + Sync + 'static,
412    {
413        self.intercept().responses(callback).await
414    }
415
416    /// Intercepts response body with a callback.
417    pub async fn intercept_response_body<F>(&self, callback: F) -> Result<InterceptId>
418    where
419        F: Fn(InterceptedResponseBody) -> BodyAction + Send + Sync + 'static,
420    {
421        self.intercept().response_body(callback).await
422    }
423
424    /// Stops network interception.
425    ///
426    /// Only removes the specific handler associated with this intercept,
427    /// leaving other intercept handlers intact.
428    ///
429    /// # Arguments
430    ///
431    /// * `intercept_id` - The intercept ID returned from intercept methods
432    pub async fn stop_intercept(&self, intercept_id: &InterceptId) -> Result<()> {
433        debug!(tab_id = %self.inner.tab_id, %intercept_id, "Stopping interception");
434
435        let window = self.get_window()?;
436
437        if let Some(handler_key) = window.inner.intercept_handlers.lock().remove(intercept_id) {
438            window
439                .inner
440                .pool
441                .remove_event_handler(window.inner.session_id, &handler_key);
442        }
443
444        let command = Command::Network(NetworkCommand::RemoveIntercept {
445            intercept_id: intercept_id.clone(),
446        });
447
448        self.send_command(command).await?;
449        Ok(())
450    }
451}
452
453// ============================================================================
454// JSON Field Extraction Helpers
455// ============================================================================
456
457/// Extracts a string field from event params.
458fn get_str(params: &serde_json::Value, key: &str) -> String {
459    params.get(key).and_then(|v| v.as_str()).unwrap_or("").to_string()
460}
461
462/// Extracts a string field with a default value.
463fn get_str_or(params: &serde_json::Value, key: &str, default: &str) -> String {
464    params.get(key).and_then(|v| v.as_str()).unwrap_or(default).to_string()
465}
466
467/// Extracts a u32 field from event params.
468fn get_u32(params: &serde_json::Value, key: &str) -> u32 {
469    params.get(key).and_then(|v| v.as_u64()).unwrap_or(0) as u32
470}
471
472/// Extracts a u64 field from event params.
473fn get_u64(params: &serde_json::Value, key: &str) -> u64 {
474    params.get(key).and_then(|v| v.as_u64()).unwrap_or(0)
475}
476
477/// Extracts a u16 field from event params.
478fn get_u16(params: &serde_json::Value, key: &str) -> u16 {
479    params.get(key).and_then(|v| v.as_u64()).unwrap_or(0) as u16
480}
481
482/// Parses headers from event params.
483fn parse_headers(params: &serde_json::Value) -> HashMap<String, String> {
484    params.get("headers")
485        .and_then(|v| v.as_object())
486        .map(|obj| {
487            obj.iter()
488                .filter_map(|(k, v)| v.as_str().map(|s| (k.clone(), s.to_string())))
489                .collect()
490        })
491        .unwrap_or_default()
492}
493
494// ============================================================================
495// Helper Functions
496// ============================================================================
497
498/// Extracts intercept ID from response.
499fn extract_intercept_id(response: &Response) -> Result<InterceptId> {
500    let id = response
501        .result
502        .as_ref()
503        .and_then(|v| v.get("interceptId"))
504        .and_then(|v| v.as_str())
505        .ok_or_else(|| Error::protocol("No interceptId in response"))?;
506
507    Ok(InterceptId::new(id))
508}
509
510/// Parses intercepted request from event.
511fn parse_intercepted_request(event: &Event) -> InterceptedRequest {
512    InterceptedRequest {
513        request_id: get_str(&event.params, "requestId"),
514        url: get_str(&event.params, "url"),
515        method: get_str_or(&event.params, "method", "GET"),
516        headers: parse_headers(&event.params),
517        resource_type: get_str_or(&event.params, "resourceType", "other"),
518        tab_id: get_u32(&event.params, "tabId"),
519        frame_id: get_u64(&event.params, "frameId"),
520        body: None,
521    }
522}
523
524/// Parses intercepted request headers from event.
525fn parse_intercepted_request_headers(event: &Event) -> InterceptedRequestHeaders {
526    InterceptedRequestHeaders {
527        request_id: get_str(&event.params, "requestId"),
528        url: get_str(&event.params, "url"),
529        method: get_str_or(&event.params, "method", "GET"),
530        headers: parse_headers(&event.params),
531        tab_id: get_u32(&event.params, "tabId"),
532        frame_id: get_u64(&event.params, "frameId"),
533    }
534}
535
536/// Parses intercepted request body from event.
537fn parse_intercepted_request_body(event: &Event) -> InterceptedRequestBody {
538    InterceptedRequestBody {
539        request_id: get_str(&event.params, "requestId"),
540        url: get_str(&event.params, "url"),
541        method: get_str_or(&event.params, "method", "GET"),
542        resource_type: get_str_or(&event.params, "resourceType", "other"),
543        tab_id: get_u32(&event.params, "tabId"),
544        frame_id: get_u64(&event.params, "frameId"),
545        body: event.params.as_object().and_then(parse_request_body),
546    }
547}
548
549/// Parses intercepted response from event.
550fn parse_intercepted_response(event: &Event) -> InterceptedResponse {
551    InterceptedResponse {
552        request_id: get_str(&event.params, "requestId"),
553        url: get_str(&event.params, "url"),
554        status: get_u16(&event.params, "status"),
555        status_text: get_str(&event.params, "statusText"),
556        headers: parse_headers(&event.params),
557        tab_id: get_u32(&event.params, "tabId"),
558        frame_id: get_u64(&event.params, "frameId"),
559    }
560}
561
562/// Parses intercepted response body from event.
563fn parse_intercepted_response_body(event: &Event) -> InterceptedResponseBody {
564    let body = if let Some(b64) = event.params.get("bodyBase64").and_then(|v| v.as_str()) {
565        match Base64Standard.decode(b64) {
566            Ok(bytes) => ResponseBodyData::Binary(bytes),
567            Err(_) => ResponseBodyData::Text(String::new()),
568        }
569    } else {
570        let text = get_str(&event.params, "body");
571        ResponseBodyData::Text(text)
572    };
573
574    InterceptedResponseBody {
575        request_id: get_str(&event.params, "requestId"),
576        url: get_str(&event.params, "url"),
577        status: event
578            .params
579            .get("status")
580            .and_then(|v| v.as_u64())
581            .unwrap_or(200) as u16,
582        content_type: get_str_or(&event.params, "contentType", "application/octet-stream"),
583        body,
584        tab_id: get_u32(&event.params, "tabId"),
585        frame_id: get_u64(&event.params, "frameId"),
586        content_length: get_u64(&event.params, "contentLength") as usize,
587    }
588}
589
590/// Parses request body from event params.
591fn parse_request_body(params: &serde_json::Map<String, Value>) -> Option<RequestBody> {
592    let body = params.get("body")?;
593    let body_obj = body.as_object()?;
594
595    if let Some(error) = body_obj.get("error").and_then(|v| v.as_str()) {
596        return Some(RequestBody::Error(error.to_string()));
597    }
598
599    if let Some(form_data) = body_obj.get("data").and_then(|v| v.as_object())
600        && body_obj.get("type").and_then(|v| v.as_str()) == Some("formData")
601    {
602        let mut map = HashMap::new();
603        for (key, value) in form_data {
604            if let Some(arr) = value.as_array() {
605                let values: Vec<String> = arr
606                    .iter()
607                    .filter_map(|v| v.as_str().map(|s| s.to_string()))
608                    .collect();
609                map.insert(key.clone(), values);
610            }
611        }
612        return Some(RequestBody::FormData(map));
613    }
614
615    if let Some(raw_data) = body_obj.get("data").and_then(|v| v.as_array())
616        && body_obj.get("type").and_then(|v| v.as_str()) == Some("raw")
617    {
618        let mut bytes = Vec::new();
619        for item in raw_data {
620            if let Some(obj) = item.as_object()
621                && let Some(b64) = obj.get("data").and_then(|v| v.as_str())
622                && let Ok(decoded) = Base64Standard.decode(b64)
623            {
624                bytes.extend(decoded);
625            }
626        }
627        if !bytes.is_empty() {
628            return Some(RequestBody::Raw(bytes));
629        }
630    }
631
632    None
633}
634
635/// Converts request action to JSON.
636fn request_action_to_json(action: &RequestAction) -> Value {
637    match action {
638        RequestAction::Allow => serde_json::json!({ "action": "allow" }),
639        RequestAction::Block => serde_json::json!({ "action": "block" }),
640        RequestAction::Redirect(url) => serde_json::json!({ "action": "redirect", "url": url }),
641    }
642}
643
644/// Converts headers action to JSON.
645fn headers_action_to_json(action: &HeadersAction) -> Value {
646    match action {
647        HeadersAction::Allow => serde_json::json!({ "action": "allow" }),
648        HeadersAction::Block => serde_json::json!({ "action": "block" }),
649        HeadersAction::Modify {
650            headers,
651            status_code,
652        } => {
653            let mut json = serde_json::json!({ "action": "modifyHeaders", "headers": headers });
654            if let Some(code) = status_code {
655                json["statusCode"] = serde_json::json!(code);
656            }
657            json
658        }
659    }
660}
661
662/// Converts body action to JSON.
663fn body_action_to_json(action: &BodyAction) -> Value {
664    match action {
665        BodyAction::Allow => serde_json::json!({ "action": "allow" }),
666        BodyAction::ModifyBody(b) => serde_json::json!({ "action": "modifyBody", "body": b }),
667    }
668}