firefox_webdriver/browser/tab/
network.rs

1//! Network interception and blocking methods.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use base64::Engine;
7use base64::engine::general_purpose::STANDARD as Base64Standard;
8use serde_json::Value;
9use tracing::debug;
10
11use crate::browser::network::{
12    BodyAction, HeadersAction, InterceptedRequest, InterceptedRequestBody,
13    InterceptedRequestHeaders, InterceptedResponse, InterceptedResponseBody, RequestAction,
14    RequestBody,
15};
16use crate::error::{Error, Result};
17use crate::identifiers::InterceptId;
18use crate::protocol::{Command, Event, EventReply, NetworkCommand, Response};
19
20use super::Tab;
21
22// ============================================================================
23// Tab - Network
24// ============================================================================
25
26impl Tab {
27    /// Sets URL patterns to block.
28    ///
29    /// Patterns support wildcards (`*`).
30    ///
31    /// # Example
32    ///
33    /// ```ignore
34    /// tab.set_block_rules(&["*ads*", "*tracking*"]).await?;
35    /// ```
36    pub async fn set_block_rules(&self, patterns: &[&str]) -> Result<()> {
37        debug!(tab_id = %self.inner.tab_id, pattern_count = patterns.len(), "Setting block rules");
38
39        let command = Command::Network(NetworkCommand::SetBlockRules {
40            patterns: patterns.iter().map(|s| (*s).to_string()).collect(),
41        });
42
43        self.send_command(command).await?;
44        Ok(())
45    }
46
47    /// Clears all URL block rules.
48    pub async fn clear_block_rules(&self) -> Result<()> {
49        debug!(tab_id = %self.inner.tab_id, "Clearing block rules");
50        let command = Command::Network(NetworkCommand::ClearBlockRules);
51        self.send_command(command).await?;
52        Ok(())
53    }
54
55    /// Intercepts network requests with a callback.
56    ///
57    /// # Returns
58    ///
59    /// An `InterceptId` that can be used to stop this intercept.
60    ///
61    /// # Example
62    ///
63    /// ```ignore
64    /// use firefox_webdriver::RequestAction;
65    ///
66    /// let id = tab.intercept_request(|req| {
67    ///     if req.url.contains("ads") {
68    ///         RequestAction::block()
69    ///     } else {
70    ///         RequestAction::allow()
71    ///     }
72    /// }).await?;
73    /// ```
74    pub async fn intercept_request<F>(&self, callback: F) -> Result<InterceptId>
75    where
76        F: Fn(InterceptedRequest) -> RequestAction + Send + Sync + 'static,
77    {
78        debug!(tab_id = %self.inner.tab_id, "Enabling request interception");
79
80        let window = self.get_window()?;
81        let callback = Arc::new(callback);
82
83        window.inner.pool.set_event_handler(
84            window.inner.session_id,
85            Box::new(move |event: Event| {
86                if event.method.as_str() != "network.beforeRequestSent" {
87                    return None;
88                }
89
90                let request = parse_intercepted_request(&event);
91                let action = callback(request);
92                let result = request_action_to_json(&action);
93
94                Some(EventReply::new(
95                    event.id,
96                    "network.beforeRequestSent",
97                    result,
98                ))
99            }),
100        );
101
102        let command = Command::Network(NetworkCommand::AddIntercept {
103            intercept_requests: true,
104            intercept_request_headers: false,
105            intercept_request_body: false,
106            intercept_responses: false,
107            intercept_response_body: false,
108        });
109
110        let response = self.send_command(command).await?;
111        extract_intercept_id(&response)
112    }
113
114    /// Intercepts request headers with a callback.
115    pub async fn intercept_request_headers<F>(&self, callback: F) -> Result<InterceptId>
116    where
117        F: Fn(InterceptedRequestHeaders) -> HeadersAction + Send + Sync + 'static,
118    {
119        debug!(tab_id = %self.inner.tab_id, "Enabling request headers interception");
120
121        let window = self.get_window()?;
122        let callback = Arc::new(callback);
123
124        window.inner.pool.set_event_handler(
125            window.inner.session_id,
126            Box::new(move |event: Event| {
127                if event.method.as_str() != "network.requestHeaders" {
128                    return None;
129                }
130
131                let headers_data = parse_intercepted_request_headers(&event);
132                let action = callback(headers_data);
133                let result = headers_action_to_json(&action);
134
135                Some(EventReply::new(event.id, "network.requestHeaders", result))
136            }),
137        );
138
139        let command = Command::Network(NetworkCommand::AddIntercept {
140            intercept_requests: false,
141            intercept_request_headers: true,
142            intercept_request_body: false,
143            intercept_responses: false,
144            intercept_response_body: false,
145        });
146
147        let response = self.send_command(command).await?;
148        extract_intercept_id(&response)
149    }
150
151    /// Intercepts request body for logging (read-only).
152    pub async fn intercept_request_body<F>(&self, callback: F) -> Result<InterceptId>
153    where
154        F: Fn(InterceptedRequestBody) + Send + Sync + 'static,
155    {
156        debug!(tab_id = %self.inner.tab_id, "Enabling request body interception");
157
158        let window = self.get_window()?;
159        let callback = Arc::new(callback);
160
161        window.inner.pool.set_event_handler(
162            window.inner.session_id,
163            Box::new(move |event: Event| {
164                if event.method.as_str() != "network.requestBody" {
165                    return None;
166                }
167
168                let body_data = parse_intercepted_request_body(&event);
169                callback(body_data);
170
171                Some(EventReply::new(
172                    event.id,
173                    "network.requestBody",
174                    serde_json::json!({ "action": "allow" }),
175                ))
176            }),
177        );
178
179        let command = Command::Network(NetworkCommand::AddIntercept {
180            intercept_requests: false,
181            intercept_request_headers: false,
182            intercept_request_body: true,
183            intercept_responses: false,
184            intercept_response_body: false,
185        });
186
187        let response = self.send_command(command).await?;
188        extract_intercept_id(&response)
189    }
190
191    /// Intercepts response headers with a callback.
192    pub async fn intercept_response<F>(&self, callback: F) -> Result<InterceptId>
193    where
194        F: Fn(InterceptedResponse) -> HeadersAction + Send + Sync + 'static,
195    {
196        debug!(tab_id = %self.inner.tab_id, "Enabling response interception");
197
198        let window = self.get_window()?;
199        let callback = Arc::new(callback);
200
201        window.inner.pool.set_event_handler(
202            window.inner.session_id,
203            Box::new(move |event: Event| {
204                if event.method.as_str() != "network.responseHeaders" {
205                    return None;
206                }
207
208                let resp = parse_intercepted_response(&event);
209                let action = callback(resp);
210                let result = headers_action_to_json(&action);
211
212                Some(EventReply::new(event.id, "network.responseHeaders", result))
213            }),
214        );
215
216        let command = Command::Network(NetworkCommand::AddIntercept {
217            intercept_requests: false,
218            intercept_request_headers: false,
219            intercept_request_body: false,
220            intercept_responses: true,
221            intercept_response_body: false,
222        });
223
224        let response = self.send_command(command).await?;
225        extract_intercept_id(&response)
226    }
227
228    /// Intercepts response body with a callback.
229    pub async fn intercept_response_body<F>(&self, callback: F) -> Result<InterceptId>
230    where
231        F: Fn(InterceptedResponseBody) -> BodyAction + Send + Sync + 'static,
232    {
233        debug!(tab_id = %self.inner.tab_id, "Enabling response body interception");
234
235        let window = self.get_window()?;
236        let callback = Arc::new(callback);
237
238        window.inner.pool.set_event_handler(
239            window.inner.session_id,
240            Box::new(move |event: Event| {
241                if event.method.as_str() != "network.responseBody" {
242                    return None;
243                }
244
245                let body_data = parse_intercepted_response_body(&event);
246                let action = callback(body_data);
247                let result = body_action_to_json(&action);
248
249                Some(EventReply::new(event.id, "network.responseBody", result))
250            }),
251        );
252
253        let command = Command::Network(NetworkCommand::AddIntercept {
254            intercept_requests: false,
255            intercept_request_headers: false,
256            intercept_request_body: false,
257            intercept_responses: false,
258            intercept_response_body: true,
259        });
260
261        let response = self.send_command(command).await?;
262        extract_intercept_id(&response)
263    }
264
265    /// Stops network interception.
266    ///
267    /// # Arguments
268    ///
269    /// * `intercept_id` - The intercept ID returned from intercept methods
270    pub async fn stop_intercept(&self, intercept_id: &InterceptId) -> Result<()> {
271        debug!(tab_id = %self.inner.tab_id, %intercept_id, "Stopping interception");
272
273        let window = self.get_window()?;
274        window
275            .inner
276            .pool
277            .clear_event_handler(window.inner.session_id);
278
279        let command = Command::Network(NetworkCommand::RemoveIntercept {
280            intercept_id: intercept_id.clone(),
281        });
282
283        self.send_command(command).await?;
284        Ok(())
285    }
286}
287
288// ============================================================================
289// Helper Functions
290// ============================================================================
291
292/// Extracts intercept ID from response.
293fn extract_intercept_id(response: &Response) -> Result<InterceptId> {
294    let id = response
295        .result
296        .as_ref()
297        .and_then(|v| v.get("interceptId"))
298        .and_then(|v| v.as_str())
299        .ok_or_else(|| Error::protocol("No interceptId in response"))?;
300
301    Ok(InterceptId::new(id))
302}
303
304/// Parses intercepted request from event.
305fn parse_intercepted_request(event: &Event) -> InterceptedRequest {
306    InterceptedRequest {
307        request_id: event
308            .params
309            .get("requestId")
310            .and_then(|v| v.as_str())
311            .unwrap_or("")
312            .to_string(),
313        url: event
314            .params
315            .get("url")
316            .and_then(|v| v.as_str())
317            .unwrap_or("")
318            .to_string(),
319        method: event
320            .params
321            .get("method")
322            .and_then(|v| v.as_str())
323            .unwrap_or("GET")
324            .to_string(),
325        resource_type: event
326            .params
327            .get("resourceType")
328            .and_then(|v| v.as_str())
329            .unwrap_or("other")
330            .to_string(),
331        tab_id: event
332            .params
333            .get("tabId")
334            .and_then(|v| v.as_u64())
335            .unwrap_or(0) as u32,
336        frame_id: event
337            .params
338            .get("frameId")
339            .and_then(|v| v.as_u64())
340            .unwrap_or(0),
341        body: None,
342    }
343}
344
345/// Parses intercepted request headers from event.
346fn parse_intercepted_request_headers(event: &Event) -> InterceptedRequestHeaders {
347    InterceptedRequestHeaders {
348        request_id: event
349            .params
350            .get("requestId")
351            .and_then(|v| v.as_str())
352            .unwrap_or("")
353            .to_string(),
354        url: event
355            .params
356            .get("url")
357            .and_then(|v| v.as_str())
358            .unwrap_or("")
359            .to_string(),
360        method: event
361            .params
362            .get("method")
363            .and_then(|v| v.as_str())
364            .unwrap_or("GET")
365            .to_string(),
366        headers: event
367            .params
368            .get("headers")
369            .and_then(|v| v.as_object())
370            .map(|obj| {
371                obj.iter()
372                    .filter_map(|(k, v)| v.as_str().map(|s| (k.clone(), s.to_string())))
373                    .collect()
374            })
375            .unwrap_or_default(),
376        tab_id: event
377            .params
378            .get("tabId")
379            .and_then(|v| v.as_u64())
380            .unwrap_or(0) as u32,
381        frame_id: event
382            .params
383            .get("frameId")
384            .and_then(|v| v.as_u64())
385            .unwrap_or(0),
386    }
387}
388
389/// Parses intercepted request body from event.
390fn parse_intercepted_request_body(event: &Event) -> InterceptedRequestBody {
391    InterceptedRequestBody {
392        request_id: event
393            .params
394            .get("requestId")
395            .and_then(|v| v.as_str())
396            .unwrap_or("")
397            .to_string(),
398        url: event
399            .params
400            .get("url")
401            .and_then(|v| v.as_str())
402            .unwrap_or("")
403            .to_string(),
404        method: event
405            .params
406            .get("method")
407            .and_then(|v| v.as_str())
408            .unwrap_or("GET")
409            .to_string(),
410        resource_type: event
411            .params
412            .get("resourceType")
413            .and_then(|v| v.as_str())
414            .unwrap_or("other")
415            .to_string(),
416        tab_id: event
417            .params
418            .get("tabId")
419            .and_then(|v| v.as_u64())
420            .unwrap_or(0) as u32,
421        frame_id: event
422            .params
423            .get("frameId")
424            .and_then(|v| v.as_u64())
425            .unwrap_or(0),
426        body: event.params.as_object().and_then(parse_request_body),
427    }
428}
429
430/// Parses intercepted response from event.
431fn parse_intercepted_response(event: &Event) -> InterceptedResponse {
432    InterceptedResponse {
433        request_id: event
434            .params
435            .get("requestId")
436            .and_then(|v| v.as_str())
437            .unwrap_or("")
438            .to_string(),
439        url: event
440            .params
441            .get("url")
442            .and_then(|v| v.as_str())
443            .unwrap_or("")
444            .to_string(),
445        status: event
446            .params
447            .get("status")
448            .and_then(|v| v.as_u64())
449            .unwrap_or(0) as u16,
450        status_text: event
451            .params
452            .get("statusText")
453            .and_then(|v| v.as_str())
454            .unwrap_or("")
455            .to_string(),
456        headers: event
457            .params
458            .get("headers")
459            .and_then(|v| v.as_object())
460            .map(|obj| {
461                obj.iter()
462                    .filter_map(|(k, v)| v.as_str().map(|s| (k.clone(), s.to_string())))
463                    .collect()
464            })
465            .unwrap_or_default(),
466        tab_id: event
467            .params
468            .get("tabId")
469            .and_then(|v| v.as_u64())
470            .unwrap_or(0) as u32,
471        frame_id: event
472            .params
473            .get("frameId")
474            .and_then(|v| v.as_u64())
475            .unwrap_or(0),
476    }
477}
478
479/// Parses intercepted response body from event.
480fn parse_intercepted_response_body(event: &Event) -> InterceptedResponseBody {
481    InterceptedResponseBody {
482        request_id: event
483            .params
484            .get("requestId")
485            .and_then(|v| v.as_str())
486            .unwrap_or("")
487            .to_string(),
488        url: event
489            .params
490            .get("url")
491            .and_then(|v| v.as_str())
492            .unwrap_or("")
493            .to_string(),
494        tab_id: event
495            .params
496            .get("tabId")
497            .and_then(|v| v.as_u64())
498            .unwrap_or(0) as u32,
499        frame_id: event
500            .params
501            .get("frameId")
502            .and_then(|v| v.as_u64())
503            .unwrap_or(0),
504        body: event
505            .params
506            .get("body")
507            .and_then(|v| v.as_str())
508            .unwrap_or("")
509            .to_string(),
510        content_length: event
511            .params
512            .get("contentLength")
513            .and_then(|v| v.as_u64())
514            .unwrap_or(0) as usize,
515    }
516}
517
518/// Parses request body from event params.
519fn parse_request_body(params: &serde_json::Map<String, Value>) -> Option<RequestBody> {
520    let body = params.get("body")?;
521    let body_obj = body.as_object()?;
522
523    if let Some(error) = body_obj.get("error").and_then(|v| v.as_str()) {
524        return Some(RequestBody::Error(error.to_string()));
525    }
526
527    if let Some(form_data) = body_obj.get("data").and_then(|v| v.as_object())
528        && body_obj.get("type").and_then(|v| v.as_str()) == Some("formData")
529    {
530        let mut map = HashMap::new();
531        for (key, value) in form_data {
532            if let Some(arr) = value.as_array() {
533                let values: Vec<String> = arr
534                    .iter()
535                    .filter_map(|v| v.as_str().map(|s| s.to_string()))
536                    .collect();
537                map.insert(key.clone(), values);
538            }
539        }
540        return Some(RequestBody::FormData(map));
541    }
542
543    if let Some(raw_data) = body_obj.get("data").and_then(|v| v.as_array())
544        && body_obj.get("type").and_then(|v| v.as_str()) == Some("raw")
545    {
546        let mut bytes = Vec::new();
547        for item in raw_data {
548            if let Some(obj) = item.as_object()
549                && let Some(b64) = obj.get("data").and_then(|v| v.as_str())
550                && let Ok(decoded) = Base64Standard.decode(b64)
551            {
552                bytes.extend(decoded);
553            }
554        }
555        if !bytes.is_empty() {
556            return Some(RequestBody::Raw(bytes));
557        }
558    }
559
560    None
561}
562
563/// Converts request action to JSON.
564fn request_action_to_json(action: &RequestAction) -> Value {
565    match action {
566        RequestAction::Allow => serde_json::json!({ "action": "allow" }),
567        RequestAction::Block => serde_json::json!({ "action": "block" }),
568        RequestAction::Redirect(url) => serde_json::json!({ "action": "redirect", "url": url }),
569    }
570}
571
572/// Converts headers action to JSON.
573fn headers_action_to_json(action: &HeadersAction) -> Value {
574    match action {
575        HeadersAction::Allow => serde_json::json!({ "action": "allow" }),
576        HeadersAction::ModifyHeaders(h) => {
577            serde_json::json!({ "action": "modifyHeaders", "headers": h })
578        }
579    }
580}
581
582/// Converts body action to JSON.
583fn body_action_to_json(action: &BodyAction) -> Value {
584    match action {
585        BodyAction::Allow => serde_json::json!({ "action": "allow" }),
586        BodyAction::ModifyBody(b) => serde_json::json!({ "action": "modifyBody", "body": b }),
587    }
588}