1use base64::Engine as _;
36use base64::engine::general_purpose::STANDARD as BASE64;
37use futures::StreamExt;
38use serde::Deserialize;
39use serde_json::{Map, Value, json};
40use std::collections::HashMap;
41use tokio::sync::oneshot;
42use tokio_util::sync::CancellationToken;
43use tracing::{trace, warn};
44use zendriver_transport::SessionHandle;
45
46use crate::builder::RequestPattern;
47use crate::error::InterceptionError;
48use crate::rule::Rule;
49use crate::types::{RequestInfo, RequestOverrides, ResourceType, ResponseInfo, ResponseOverrides};
50
/// Handle that keeps an interception alive.
///
/// Dropping the handle cancels its token; `stop` cancels the token and then
/// waits on the completion receiver.
#[derive(Debug)]
#[must_use = "interception stops when the handle is dropped — bind it to a variable to keep it alive"]
pub struct InterceptHandle {
    /// Token cancelled by `stop` and by `Drop` to request shutdown.
    cancel: CancellationToken,
    /// Completion signal awaited by `stop`; it is `take`n there, so it is
    /// consumed at most once.
    done: Option<oneshot::Receiver<()>>,
}
65
66impl InterceptHandle {
67 pub(crate) fn new(cancel: CancellationToken, done: oneshot::Receiver<()>) -> Self {
71 Self {
72 cancel,
73 done: Some(done),
74 }
75 }
76
77 #[cfg(any(test, feature = "test-support"))]
87 #[doc(hidden)]
88 pub fn for_tests() -> Self {
89 let (_done_tx, done_rx) = oneshot::channel();
90 Self {
91 cancel: CancellationToken::new(),
92 done: Some(done_rx),
93 }
94 }
95
96 pub async fn stop(mut self) -> Result<(), InterceptionError> {
105 self.cancel.cancel();
106 match self.done.take() {
107 Some(rx) => rx.await.map_err(|_| InterceptionError::SubscriptionClosed),
108 None => Ok(()),
109 }
110 }
111}
112
impl Drop for InterceptHandle {
    /// Requests shutdown by cancelling the token; it does not wait on `done`.
    fn drop(&mut self) {
        // `stop` takes `self` by value, so this also runs once `stop` returns.
        self.cancel.cancel();
    }
}
122
123#[derive(Debug, Deserialize)]
132pub(crate) struct RequestPausedEvent {
133 #[serde(rename = "requestId")]
134 pub(crate) request_id: String,
135 pub(crate) request: RequestPayload,
136 #[serde(rename = "resourceType", default)]
137 pub(crate) resource_type: Option<String>,
138 #[serde(rename = "responseStatusCode", default)]
140 pub(crate) response_status_code: Option<u16>,
141 #[serde(rename = "responseStatusText", default)]
142 pub(crate) response_status_text: Option<String>,
143 #[serde(rename = "responseHeaders", default)]
144 pub(crate) response_headers: Option<Vec<HeaderPair>>,
145}
146
147#[derive(Debug, Deserialize)]
148pub(crate) struct RequestPayload {
149 pub(crate) url: String,
150 pub(crate) method: String,
151 #[serde(default)]
152 pub(crate) headers: HashMap<String, String>,
153 #[serde(rename = "postData", default)]
157 pub(crate) post_data: Option<String>,
158 #[serde(rename = "hasPostData", default)]
159 _has_post_data: Option<bool>,
160 #[serde(rename = "postDataEntries", default)]
164 pub(crate) post_data_entries: Option<Vec<PostDataEntry>>,
165}
166
/// One element of a request's `postDataEntries` array.
#[derive(Debug, Deserialize)]
pub(crate) struct PostDataEntry {
    /// Chunk contents; `decode_post_data` treats this as base64 text and
    /// skips entries where it is absent.
    #[serde(default)]
    pub(crate) bytes: Option<String>,
}
173
/// A single header given as a `name`/`value` object, as used for
/// `responseHeaders` in the paused-request event.
#[derive(Debug, Deserialize)]
pub(crate) struct HeaderPair {
    /// Header name.
    pub(crate) name: String,
    /// Header value.
    pub(crate) value: String,
}
179
180pub(crate) async fn run_actor(
189 session: SessionHandle,
190 rules: Vec<Rule>,
191 patterns: Vec<RequestPattern>,
192 auth: Option<(String, String)>,
193 cancel: CancellationToken,
194 done: oneshot::Sender<()>,
195) {
196 let mut paused = session.subscribe::<Value>("Fetch.requestPaused");
202 let mut auth_required = session.subscribe::<Value>("Fetch.authRequired");
206
207 let enable_session = session.clone();
213 let enable_patterns: Vec<Value> = patterns.iter().map(serialize_pattern).collect();
214 let handle_auth_requests = auth.is_some();
215 tokio::spawn(async move {
216 if let Err(e) = enable_session
217 .call(
218 "Fetch.enable",
219 json!({
220 "patterns": enable_patterns,
221 "handleAuthRequests": handle_auth_requests,
222 }),
223 )
224 .await
225 {
226 warn!(error = %e, "interception: Fetch.enable failed; interception inactive");
227 }
228 });
229
230 loop {
232 tokio::select! {
233 () = cancel.cancelled() => {
234 trace!("interception: cancellation received, disabling Fetch and exiting");
235 break;
236 }
237 Some(ev_value) = paused.next() => {
238 let ev: RequestPausedEvent = match serde_json::from_value(ev_value) {
241 Ok(ev) => ev,
242 Err(e) => {
243 warn!(error = %e, "interception: skipping malformed Fetch.requestPaused event");
244 continue;
245 }
246 };
247 if let Err(e) = handle_paused(&session, &rules, ev).await {
248 warn!(error = %e, "interception: handler dispatch failed");
249 }
250 }
251 Some(ev_value) = auth_required.next() => {
252 let Some(request_id) = ev_value
258 .get("requestId")
259 .and_then(Value::as_str)
260 .map(str::to_owned)
261 else {
262 warn!("interception: Fetch.authRequired without requestId");
263 continue;
264 };
265 let response = match &auth {
266 Some((user, pass)) => json!({
267 "response": "ProvideCredentials",
268 "username": user,
269 "password": pass,
270 }),
271 None => json!({ "response": "Default" }),
272 };
273 if let Err(e) = session
274 .call(
275 "Fetch.continueWithAuth",
276 json!({
277 "requestId": request_id,
278 "authChallengeResponse": response,
279 }),
280 )
281 .await
282 {
283 warn!(error = %e, "interception: Fetch.continueWithAuth failed");
284 }
285 }
286 else => {
287 trace!("interception: event stream closed, exiting without Fetch.disable");
289 let _ = done.send(());
292 return;
293 }
294 }
295 }
296
297 if let Err(e) = session.call("Fetch.disable", json!({})).await {
301 warn!(error = %e, "interception: Fetch.disable failed during shutdown");
302 }
303 let _ = done.send(());
306}
307
308async fn handle_paused(
312 session: &SessionHandle,
313 rules: &[Rule],
314 ev: RequestPausedEvent,
315) -> Result<(), InterceptionError> {
316 let url = ev.request.url.clone();
317
318 let matched = rules.iter().find(|r| r.matches(&url));
322
323 match matched {
324 Some(Rule::Block { .. }) | Some(Rule::BlockHosts { .. }) => {
325 fail_request(session, &ev.request_id, "BlockedByClient").await
326 }
327 Some(Rule::Redirect { to, .. }) => continue_with_url(session, &ev.request_id, to).await,
328 Some(Rule::Respond {
329 status,
330 headers,
331 body,
332 ..
333 }) => fulfill_request(session, &ev.request_id, *status, headers, body).await,
334 Some(Rule::Modify { modify, .. }) => {
335 let info = build_request_info(&ev);
336 let overrides = modify(&info);
337 continue_with_overrides(session, &ev.request_id, overrides).await
338 }
339 Some(Rule::ModifyResponse { modify, .. }) => match build_response_info(&ev) {
340 Some(info) => {
341 let overrides = modify(&info);
342 continue_response_with_overrides(session, &ev.request_id, overrides).await
343 }
344 None => {
345 tracing::debug!(
350 request_id = %ev.request_id,
351 url = %url,
352 "interception: ModifyResponse matched at Request stage; no response yet, passing through"
353 );
354 continue_passthrough(session, &ev.request_id).await
355 }
356 },
357 None => continue_passthrough(session, &ev.request_id).await,
358 }
359}
360
361pub(crate) fn serialize_pattern(p: &RequestPattern) -> Value {
364 let mut obj = Map::new();
365 if let Some(url) = &p.url_pattern {
366 obj.insert("urlPattern".into(), Value::String(url.clone()));
367 }
368 if let Some(rt) = p.resource_type {
369 obj.insert("resourceType".into(), Value::String(rt.as_cdp_str().into()));
370 }
371 if let Some(stage) = p.request_stage {
372 obj.insert(
373 "requestStage".into(),
374 Value::String(stage.as_cdp_str().into()),
375 );
376 }
377 Value::Object(obj)
378}
379
380pub(crate) fn build_request_info(ev: &RequestPausedEvent) -> RequestInfo {
392 RequestInfo {
393 url: ev.request.url.clone(),
394 method: ev.request.method.clone(),
395 headers: ev
396 .request
397 .headers
398 .iter()
399 .map(|(k, v)| (k.clone(), v.clone()))
400 .collect(),
401 post_data: decode_post_data(&ev.request),
402 resource_type: parse_resource_type(ev.resource_type.as_deref()),
403 }
404}
405
406fn decode_post_data(req: &RequestPayload) -> Option<Vec<u8>> {
407 use base64::Engine as _;
408 use base64::engine::general_purpose::STANDARD as BASE64;
409
410 if let Some(entries) = req.post_data_entries.as_ref() {
411 let mut buf = Vec::new();
412 for entry in entries {
413 let Some(b64) = entry.bytes.as_deref() else {
414 continue;
415 };
416 match BASE64.decode(b64) {
417 Ok(bytes) => buf.extend_from_slice(&bytes),
418 Err(e) => {
419 tracing::warn!(error = %e, "interception: bad base64 in postDataEntries; skipping entry");
420 }
421 }
422 }
423 return Some(buf);
424 }
425 req.post_data.as_deref().map(|s| s.as_bytes().to_vec())
426}
427
428pub(crate) fn build_response_info(ev: &RequestPausedEvent) -> Option<ResponseInfo> {
435 let status = ev.response_status_code?;
436 let status_text = ev.response_status_text.clone().unwrap_or_default();
437 let headers: Vec<(String, String)> = ev
438 .response_headers
439 .as_ref()
440 .map(|hs| {
441 hs.iter()
442 .map(|h| (h.name.clone(), h.value.clone()))
443 .collect()
444 })
445 .unwrap_or_default();
446 Some(ResponseInfo {
447 status,
448 status_text,
449 headers,
450 })
451}
452
453pub(crate) fn headers_to_cdp(headers: &[(String, String)]) -> Vec<Value> {
457 headers
458 .iter()
459 .map(|(name, value)| json!({ "name": name, "value": value }))
460 .collect()
461}
462
463fn parse_resource_type(s: Option<&str>) -> ResourceType {
469 match s.unwrap_or("Other") {
470 "Document" => ResourceType::Document,
471 "Stylesheet" => ResourceType::Stylesheet,
472 "Image" => ResourceType::Image,
473 "Media" => ResourceType::Media,
474 "Font" => ResourceType::Font,
475 "Script" => ResourceType::Script,
476 "TextTrack" => ResourceType::TextTrack,
477 "XHR" => ResourceType::XHR,
478 "Fetch" => ResourceType::Fetch,
479 "EventSource" => ResourceType::EventSource,
480 "WebSocket" => ResourceType::WebSocket,
481 "Manifest" => ResourceType::Manifest,
482 "SignedExchange" => ResourceType::SignedExchange,
483 "Ping" => ResourceType::Ping,
484 "CSPViolationReport" => ResourceType::CSPViolationReport,
485 "Preflight" => ResourceType::Preflight,
486 _ => ResourceType::Other,
487 }
488}
489
490async fn fail_request(
493 session: &SessionHandle,
494 request_id: &str,
495 error_reason: &str,
496) -> Result<(), InterceptionError> {
497 session
498 .call(
499 "Fetch.failRequest",
500 json!({
501 "requestId": request_id,
502 "errorReason": error_reason,
503 }),
504 )
505 .await?;
506 Ok(())
507}
508
509async fn continue_passthrough(
510 session: &SessionHandle,
511 request_id: &str,
512) -> Result<(), InterceptionError> {
513 session
514 .call("Fetch.continueRequest", json!({ "requestId": request_id }))
515 .await?;
516 Ok(())
517}
518
519async fn continue_with_url(
520 session: &SessionHandle,
521 request_id: &str,
522 url: &str,
523) -> Result<(), InterceptionError> {
524 session
525 .call(
526 "Fetch.continueRequest",
527 json!({
528 "requestId": request_id,
529 "url": url,
530 }),
531 )
532 .await?;
533 Ok(())
534}
535
536async fn continue_with_overrides(
537 session: &SessionHandle,
538 request_id: &str,
539 overrides: RequestOverrides,
540) -> Result<(), InterceptionError> {
541 let mut params = Map::new();
542 params.insert("requestId".into(), Value::String(request_id.into()));
543 if let Some(url) = overrides.url {
544 params.insert("url".into(), Value::String(url));
545 }
546 if let Some(method) = overrides.method {
547 params.insert("method".into(), Value::String(method));
548 }
549 if let Some(headers) = overrides.headers {
550 params.insert("headers".into(), Value::Array(headers_to_cdp(&headers)));
551 }
552 if let Some(post_data) = overrides.post_data {
553 params.insert("postData".into(), Value::String(BASE64.encode(&post_data)));
554 }
555 session
556 .call("Fetch.continueRequest", Value::Object(params))
557 .await?;
558 Ok(())
559}
560
561async fn fulfill_request(
562 session: &SessionHandle,
563 request_id: &str,
564 status: u16,
565 headers: &[(String, String)],
566 body: &[u8],
567) -> Result<(), InterceptionError> {
568 let response_headers = headers_to_cdp(headers);
569 session
570 .call(
571 "Fetch.fulfillRequest",
572 json!({
573 "requestId": request_id,
574 "responseCode": status,
575 "responseHeaders": response_headers,
576 "body": BASE64.encode(body),
577 }),
578 )
579 .await?;
580 Ok(())
581}
582
583async fn continue_response_with_overrides(
590 session: &SessionHandle,
591 request_id: &str,
592 overrides: ResponseOverrides,
593) -> Result<(), InterceptionError> {
594 let mut params = Map::new();
595 params.insert("requestId".into(), Value::String(request_id.into()));
596 if let Some(status) = overrides.status {
597 params.insert("responseCode".into(), Value::from(status));
598 }
599 if let Some(phrase) = overrides.phrase {
600 params.insert("responsePhrase".into(), Value::String(phrase));
601 }
602 if let Some(headers) = overrides.headers {
603 params.insert(
604 "responseHeaders".into(),
605 Value::Array(headers_to_cdp(&headers)),
606 );
607 }
608 session
609 .call("Fetch.continueResponse", Value::Object(params))
610 .await?;
611 Ok(())
612}
613
// Actor-level tests: each one runs `run_actor` against a `MockConnection`,
// answers the commands the actor sends, and checks their parameters.
#[cfg(test)]
#[allow(clippy::panic, clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::url_pattern::UrlPattern;
    use std::time::Duration;
    use zendriver_transport::testing::MockConnection;

    /// A URL matching a `Block` rule must be answered with
    /// `Fetch.failRequest` / `BlockedByClient`, and cancellation must lead to
    /// `Fetch.disable` followed by the completion signal.
    #[tokio::test]
    async fn block_rule_dispatches_fail_request_with_blocked_by_client() {
        let (mut mock, conn) = MockConnection::pair();
        let sess = SessionHandle::new(conn.clone(), "S1");

        let rules = vec![Rule::Block {
            pattern: UrlPattern::new("*/blocked/*").unwrap(),
        }];
        let patterns = vec![RequestPattern {
            url_pattern: Some("*".into()),
            ..RequestPattern::default()
        }];
        let cancel = CancellationToken::new();
        let (done_tx, done_rx) = oneshot::channel();
        let actor_cancel = cancel.clone();
        let actor = tokio::spawn(async move {
            run_actor(sess, rules, patterns, None, actor_cancel, done_tx).await;
        });

        // The actor must enable Fetch with the given pattern and, without
        // credentials, with auth handling off.
        let enable_id =
            tokio::time::timeout(Duration::from_secs(2), mock.expect_cmd("Fetch.enable"))
                .await
                .expect("actor did not send Fetch.enable within 2s");
        let enable_params = mock.last_sent()["params"].clone();
        assert_eq!(enable_params["handleAuthRequests"], false);
        assert_eq!(enable_params["patterns"][0]["urlPattern"], "*");
        mock.reply(enable_id, json!({})).await;

        // Emit a paused request whose URL matches the block pattern.
        mock.emit_event_for_session(
            "Fetch.requestPaused",
            json!({
                "requestId": "REQ-1",
                "request": {
                    "url": "https://example.test/blocked/banner.png",
                    "method": "GET",
                    "headers": {},
                },
                "resourceType": "Image",
            }),
            "S1",
        )
        .await;

        let fail_id =
            tokio::time::timeout(Duration::from_secs(2), mock.expect_cmd("Fetch.failRequest"))
                .await
                .expect("actor did not send Fetch.failRequest within 2s");
        let fail_params = mock.last_sent()["params"].clone();
        assert_eq!(fail_params["requestId"], "REQ-1");
        assert_eq!(fail_params["errorReason"], "BlockedByClient");
        mock.reply(fail_id, json!({})).await;

        // Shutdown: cancel, answer Fetch.disable, then wait for the done signal.
        cancel.cancel();
        let disable_id =
            tokio::time::timeout(Duration::from_secs(2), mock.expect_cmd("Fetch.disable"))
                .await
                .expect("actor did not send Fetch.disable on cancel");
        mock.reply(disable_id, json!({})).await;

        tokio::time::timeout(Duration::from_secs(2), done_rx)
            .await
            .expect("actor did not signal exit within 2s")
            .expect("oneshot sender dropped without sending");
        actor.await.unwrap();
        conn.shutdown();
    }

    /// A request to `cdn.evil.com` with a `BlockHosts` rule listing `evil.com`
    /// must be answered with `Fetch.failRequest` / `BlockedByClient`.
    #[tokio::test]
    async fn block_hosts_rule_dispatches_fail_request() {
        use crate::host_matcher::HostMatcher;
        let (mut mock, conn) = MockConnection::pair();
        let sess = SessionHandle::new(conn.clone(), "S1");

        let rules = vec![Rule::BlockHosts {
            matcher: std::sync::Arc::new(HostMatcher::new(["evil.com".to_string()])),
        }];
        let patterns = vec![RequestPattern {
            url_pattern: Some("*".into()),
            ..RequestPattern::default()
        }];
        let cancel = CancellationToken::new();
        let (done_tx, done_rx) = oneshot::channel();
        let actor_cancel = cancel.clone();
        let actor = tokio::spawn(async move {
            run_actor(sess, rules, patterns, None, actor_cancel, done_tx).await;
        });

        let enable_id =
            tokio::time::timeout(Duration::from_secs(2), mock.expect_cmd("Fetch.enable"))
                .await
                .expect("actor did not send Fetch.enable within 2s");
        mock.reply(enable_id, json!({})).await;

        // Emit a paused request for a subdomain of the listed host.
        mock.emit_event_for_session(
            "Fetch.requestPaused",
            json!({
                "requestId": "REQ-1",
                "request": {
                    "url": "https://cdn.evil.com/fp.js",
                    "method": "GET",
                    "headers": {},
                },
                "resourceType": "Script",
            }),
            "S1",
        )
        .await;

        let fail_id =
            tokio::time::timeout(Duration::from_secs(2), mock.expect_cmd("Fetch.failRequest"))
                .await
                .expect("actor did not send Fetch.failRequest within 2s");
        let fail_params = mock.last_sent()["params"].clone();
        assert_eq!(fail_params["requestId"], "REQ-1");
        assert_eq!(fail_params["errorReason"], "BlockedByClient");
        mock.reply(fail_id, json!({})).await;

        // Shutdown: cancel, answer Fetch.disable, then wait for the done signal.
        cancel.cancel();
        let disable_id =
            tokio::time::timeout(Duration::from_secs(2), mock.expect_cmd("Fetch.disable"))
                .await
                .expect("actor did not send Fetch.disable on cancel");
        mock.reply(disable_id, json!({})).await;

        tokio::time::timeout(Duration::from_secs(2), done_rx)
            .await
            .expect("actor did not signal exit within 2s")
            .expect("oneshot sender dropped without sending");
        actor.await.unwrap();
        conn.shutdown();
    }

    /// With credentials configured, `Fetch.enable` must set
    /// `handleAuthRequests` and a `Fetch.authRequired` event must be answered
    /// with `ProvideCredentials` carrying those credentials.
    #[tokio::test]
    async fn actor_handles_auth_required_with_credentials() {
        let (mut mock, conn) = MockConnection::pair();
        let sess = SessionHandle::new(conn.clone(), "S1");
        let cancel = CancellationToken::new();
        let (done_tx, done_rx) = oneshot::channel();
        let actor_cancel = cancel.clone();
        let auth = Some(("user1".to_string(), "pass1".to_string()));
        let actor = tokio::spawn(async move {
            run_actor(
                sess,
                Vec::new(),
                vec![RequestPattern {
                    url_pattern: Some("*".into()),
                    ..RequestPattern::default()
                }],
                auth,
                actor_cancel,
                done_tx,
            )
            .await;
        });

        let enable_id =
            tokio::time::timeout(Duration::from_secs(2), mock.expect_cmd("Fetch.enable"))
                .await
                .expect("actor did not send Fetch.enable within 2s");
        assert_eq!(
            mock.last_sent()["params"]["handleAuthRequests"],
            true,
            "auth-enabled actor must flip handleAuthRequests"
        );
        mock.reply(enable_id, json!({})).await;

        // Emit an auth challenge for the actor's session.
        mock.emit_event_for_session(
            "Fetch.authRequired",
            json!({
                "requestId": "AUTH-REQ-1",
                "request": { "url": "https://example.test/", "method": "GET" },
                "frameId": "F1",
                "resourceType": "Document",
                "authChallenge": {
                    "source": "Proxy",
                    "origin": "http://proxy.test",
                    "scheme": "basic",
                    "realm": "",
                },
            }),
            "S1",
        )
        .await;

        let auth_id = tokio::time::timeout(
            Duration::from_secs(2),
            mock.expect_cmd("Fetch.continueWithAuth"),
        )
        .await
        .expect("actor did not send Fetch.continueWithAuth within 2s");
        let params = mock.last_sent()["params"].clone();
        assert_eq!(params["requestId"], "AUTH-REQ-1");
        assert_eq!(
            params["authChallengeResponse"]["response"],
            "ProvideCredentials"
        );
        assert_eq!(params["authChallengeResponse"]["username"], "user1");
        assert_eq!(params["authChallengeResponse"]["password"], "pass1");
        mock.reply(auth_id, json!({})).await;

        // Shutdown: cancel, answer Fetch.disable, then wait for the done signal.
        cancel.cancel();
        let disable_id =
            tokio::time::timeout(Duration::from_secs(2), mock.expect_cmd("Fetch.disable"))
                .await
                .expect("actor did not send Fetch.disable on cancel");
        mock.reply(disable_id, json!({})).await;
        tokio::time::timeout(Duration::from_secs(2), done_rx)
            .await
            .expect("actor did not signal exit")
            .expect("oneshot sender dropped");
        actor.await.unwrap();
        conn.shutdown();
    }

    /// Without credentials, `handleAuthRequests` must be off and a
    /// `Fetch.authRequired` event must still be answered, with `Default`.
    #[tokio::test]
    async fn actor_without_auth_responds_default_to_auth_required() {
        let (mut mock, conn) = MockConnection::pair();
        let sess = SessionHandle::new(conn.clone(), "S2");
        let cancel = CancellationToken::new();
        let (done_tx, done_rx) = oneshot::channel();
        let actor_cancel = cancel.clone();
        let actor = tokio::spawn(async move {
            run_actor(
                sess,
                Vec::new(),
                vec![RequestPattern {
                    url_pattern: Some("*".into()),
                    ..RequestPattern::default()
                }],
                None,
                actor_cancel,
                done_tx,
            )
            .await;
        });

        let enable_id =
            tokio::time::timeout(Duration::from_secs(2), mock.expect_cmd("Fetch.enable"))
                .await
                .expect("actor did not send Fetch.enable");
        assert_eq!(mock.last_sent()["params"]["handleAuthRequests"], false);
        mock.reply(enable_id, json!({})).await;

        // Minimal auth event: only the request id is present.
        mock.emit_event_for_session(
            "Fetch.authRequired",
            json!({ "requestId": "AUTH-REQ-2" }),
            "S2",
        )
        .await;

        let auth_id = tokio::time::timeout(
            Duration::from_secs(2),
            mock.expect_cmd("Fetch.continueWithAuth"),
        )
        .await
        .expect("actor did not respond to stray authRequired");
        assert_eq!(
            mock.last_sent()["params"]["authChallengeResponse"]["response"],
            "Default"
        );
        mock.reply(auth_id, json!({})).await;

        // Shutdown: cancel, answer Fetch.disable, then wait for the done signal.
        cancel.cancel();
        let disable_id =
            tokio::time::timeout(Duration::from_secs(2), mock.expect_cmd("Fetch.disable"))
                .await
                .expect("actor did not send Fetch.disable");
        mock.reply(disable_id, json!({})).await;
        tokio::time::timeout(Duration::from_secs(2), done_rx)
            .await
            .expect("actor did not exit")
            .expect("oneshot dropped");
        actor.await.unwrap();
        conn.shutdown();
    }
}