1use base64::Engine as _;
36use base64::engine::general_purpose::STANDARD as BASE64;
37use futures::StreamExt;
38use serde::Deserialize;
39use serde_json::{Map, Value, json};
40use std::collections::HashMap;
41use tokio::sync::oneshot;
42use tokio_util::sync::CancellationToken;
43use tracing::{trace, warn};
44use zendriver_transport::SessionHandle;
45
46use crate::builder::RequestPattern;
47use crate::error::InterceptionError;
48use crate::rule::Rule;
49use crate::types::{RequestInfo, RequestOverrides, ResourceType, ResponseInfo, ResponseOverrides};
50
/// Handle to a running interception actor.
///
/// Dropping the handle cancels the actor (see the `Drop` impl below) without
/// waiting for it; [`InterceptHandle::stop`] cancels and then waits for the
/// actor's completion signal.
#[derive(Debug)]
#[must_use = "interception stops when the handle is dropped — bind it to a variable to keep it alive"]
pub struct InterceptHandle {
    // Cancelled (by `stop` or on drop) to tell the actor to shut down.
    cancel: CancellationToken,
    // Completion signal from the actor — presumably the receiver paired with
    // the `done` sender handed to `run_actor` (confirm at the construction
    // site). Taken (left `None`) by `stop`.
    done: Option<oneshot::Receiver<()>>,
}
65
66impl InterceptHandle {
67 pub(crate) fn new(cancel: CancellationToken, done: oneshot::Receiver<()>) -> Self {
71 Self {
72 cancel,
73 done: Some(done),
74 }
75 }
76
77 #[cfg(any(test, feature = "test-support"))]
87 #[doc(hidden)]
88 pub fn for_tests() -> Self {
89 let (_done_tx, done_rx) = oneshot::channel();
90 Self {
91 cancel: CancellationToken::new(),
92 done: Some(done_rx),
93 }
94 }
95
96 pub async fn stop(mut self) -> Result<(), InterceptionError> {
105 self.cancel.cancel();
106 match self.done.take() {
107 Some(rx) => rx.await.map_err(|_| InterceptionError::SubscriptionClosed),
108 None => Ok(()),
109 }
110 }
111}
112
113impl Drop for InterceptHandle {
114 fn drop(&mut self) {
115 self.cancel.cancel();
120 }
121}
122
/// Subset of the CDP `Fetch.requestPaused` event payload that the actor uses.
///
/// The `response*` fields are only populated when the request was paused at
/// the Response stage; when `responseStatusCode` is absent,
/// `build_response_info` returns `None` and the event is treated as
/// Request-stage.
#[derive(Debug, Deserialize)]
pub(crate) struct RequestPausedEvent {
    // Interception id used to answer the pause (continue/fail/fulfill).
    #[serde(rename = "requestId")]
    pub(crate) request_id: String,
    pub(crate) request: RequestPayload,
    // CDP resource type name; mapped by `parse_resource_type` (absent → Other).
    #[serde(rename = "resourceType", default)]
    pub(crate) resource_type: Option<String>,
    // Present only at the Response stage.
    #[serde(rename = "responseStatusCode", default)]
    pub(crate) response_status_code: Option<u16>,
    #[serde(rename = "responseStatusText", default)]
    pub(crate) response_status_text: Option<String>,
    #[serde(rename = "responseHeaders", default)]
    pub(crate) response_headers: Option<Vec<HeaderPair>>,
}
146
/// The `request` object inside a `Fetch.requestPaused` event.
#[derive(Debug, Deserialize)]
pub(crate) struct RequestPayload {
    pub(crate) url: String,
    pub(crate) method: String,
    // Header name → value; empty when the field is missing.
    #[serde(default)]
    pub(crate) headers: HashMap<String, String>,
    // Plain-text body; `decode_post_data` only falls back to this when
    // `postDataEntries` is absent.
    #[serde(rename = "postData", default)]
    pub(crate) post_data: Option<String>,
    // Deserialized but not read anywhere in this module (hence the `_` prefix).
    #[serde(rename = "hasPostData", default)]
    _has_post_data: Option<bool>,
    // Body chunks with base64 `bytes`; preferred by `decode_post_data`.
    #[serde(rename = "postDataEntries", default)]
    pub(crate) post_data_entries: Option<Vec<PostDataEntry>>,
}
166
/// One element of `postDataEntries`.
#[derive(Debug, Deserialize)]
pub(crate) struct PostDataEntry {
    // Base64-encoded chunk (decoded with the STANDARD alphabet by
    // `decode_post_data`); entries without it are skipped.
    #[serde(default)]
    pub(crate) bytes: Option<String>,
}
173
/// A `{ name, value }` header entry as used in `responseHeaders`.
#[derive(Debug, Deserialize)]
pub(crate) struct HeaderPair {
    pub(crate) name: String,
    pub(crate) value: String,
}
179
180pub(crate) async fn run_actor(
189 session: SessionHandle,
190 rules: Vec<Rule>,
191 patterns: Vec<RequestPattern>,
192 auth: Option<(String, String)>,
193 cancel: CancellationToken,
194 done: oneshot::Sender<()>,
195) {
196 let mut paused = session.subscribe::<Value>("Fetch.requestPaused");
202 let mut auth_required = session.subscribe::<Value>("Fetch.authRequired");
206
207 let enable_session = session.clone();
213 let enable_patterns: Vec<Value> = patterns.iter().map(serialize_pattern).collect();
214 let handle_auth_requests = auth.is_some();
215 tokio::spawn(async move {
216 if let Err(e) = enable_session
217 .call(
218 "Fetch.enable",
219 json!({
220 "patterns": enable_patterns,
221 "handleAuthRequests": handle_auth_requests,
222 }),
223 )
224 .await
225 {
226 warn!(error = %e, "interception: Fetch.enable failed; interception inactive");
227 }
228 });
229
230 loop {
232 tokio::select! {
233 () = cancel.cancelled() => {
234 trace!("interception: cancellation received, disabling Fetch and exiting");
235 break;
236 }
237 Some(ev_value) = paused.next() => {
238 let ev: RequestPausedEvent = match serde_json::from_value(ev_value) {
241 Ok(ev) => ev,
242 Err(e) => {
243 warn!(error = %e, "interception: skipping malformed Fetch.requestPaused event");
244 continue;
245 }
246 };
247 if let Err(e) = handle_paused(&session, &rules, ev).await {
248 warn!(error = %e, "interception: handler dispatch failed");
249 }
250 }
251 Some(ev_value) = auth_required.next() => {
252 let Some(request_id) = ev_value
258 .get("requestId")
259 .and_then(Value::as_str)
260 .map(str::to_owned)
261 else {
262 warn!("interception: Fetch.authRequired without requestId");
263 continue;
264 };
265 let response = match &auth {
266 Some((user, pass)) => json!({
267 "response": "ProvideCredentials",
268 "username": user,
269 "password": pass,
270 }),
271 None => json!({ "response": "Default" }),
272 };
273 if let Err(e) = session
274 .call(
275 "Fetch.continueWithAuth",
276 json!({
277 "requestId": request_id,
278 "authChallengeResponse": response,
279 }),
280 )
281 .await
282 {
283 warn!(error = %e, "interception: Fetch.continueWithAuth failed");
284 }
285 }
286 else => {
287 trace!("interception: event stream closed, exiting without Fetch.disable");
289 let _ = done.send(());
292 return;
293 }
294 }
295 }
296
297 if let Err(e) = session.call("Fetch.disable", json!({})).await {
301 warn!(error = %e, "interception: Fetch.disable failed during shutdown");
302 }
303 let _ = done.send(());
306}
307
308async fn handle_paused(
312 session: &SessionHandle,
313 rules: &[Rule],
314 ev: RequestPausedEvent,
315) -> Result<(), InterceptionError> {
316 let url = ev.request.url.clone();
317
318 let matched = rules.iter().find(|r| r.matches(&url));
322
323 match matched {
324 Some(Rule::Block { .. }) => fail_request(session, &ev.request_id, "BlockedByClient").await,
325 Some(Rule::Redirect { to, .. }) => continue_with_url(session, &ev.request_id, to).await,
326 Some(Rule::Respond {
327 status,
328 headers,
329 body,
330 ..
331 }) => fulfill_request(session, &ev.request_id, *status, headers, body).await,
332 Some(Rule::Modify { modify, .. }) => {
333 let info = build_request_info(&ev);
334 let overrides = modify(&info);
335 continue_with_overrides(session, &ev.request_id, overrides).await
336 }
337 Some(Rule::ModifyResponse { modify, .. }) => match build_response_info(&ev) {
338 Some(info) => {
339 let overrides = modify(&info);
340 continue_response_with_overrides(session, &ev.request_id, overrides).await
341 }
342 None => {
343 tracing::debug!(
348 request_id = %ev.request_id,
349 url = %url,
350 "interception: ModifyResponse matched at Request stage; no response yet, passing through"
351 );
352 continue_passthrough(session, &ev.request_id).await
353 }
354 },
355 None => continue_passthrough(session, &ev.request_id).await,
356 }
357}
358
359pub(crate) fn serialize_pattern(p: &RequestPattern) -> Value {
362 let mut obj = Map::new();
363 if let Some(url) = &p.url_pattern {
364 obj.insert("urlPattern".into(), Value::String(url.clone()));
365 }
366 if let Some(rt) = p.resource_type {
367 obj.insert("resourceType".into(), Value::String(rt.as_cdp_str().into()));
368 }
369 if let Some(stage) = p.request_stage {
370 obj.insert(
371 "requestStage".into(),
372 Value::String(stage.as_cdp_str().into()),
373 );
374 }
375 Value::Object(obj)
376}
377
378pub(crate) fn build_request_info(ev: &RequestPausedEvent) -> RequestInfo {
390 RequestInfo {
391 url: ev.request.url.clone(),
392 method: ev.request.method.clone(),
393 headers: ev
394 .request
395 .headers
396 .iter()
397 .map(|(k, v)| (k.clone(), v.clone()))
398 .collect(),
399 post_data: decode_post_data(&ev.request),
400 resource_type: parse_resource_type(ev.resource_type.as_deref()),
401 }
402}
403
404fn decode_post_data(req: &RequestPayload) -> Option<Vec<u8>> {
405 use base64::Engine as _;
406 use base64::engine::general_purpose::STANDARD as BASE64;
407
408 if let Some(entries) = req.post_data_entries.as_ref() {
409 let mut buf = Vec::new();
410 for entry in entries {
411 let Some(b64) = entry.bytes.as_deref() else {
412 continue;
413 };
414 match BASE64.decode(b64) {
415 Ok(bytes) => buf.extend_from_slice(&bytes),
416 Err(e) => {
417 tracing::warn!(error = %e, "interception: bad base64 in postDataEntries; skipping entry");
418 }
419 }
420 }
421 return Some(buf);
422 }
423 req.post_data.as_deref().map(|s| s.as_bytes().to_vec())
424}
425
426pub(crate) fn build_response_info(ev: &RequestPausedEvent) -> Option<ResponseInfo> {
433 let status = ev.response_status_code?;
434 let status_text = ev.response_status_text.clone().unwrap_or_default();
435 let headers: Vec<(String, String)> = ev
436 .response_headers
437 .as_ref()
438 .map(|hs| {
439 hs.iter()
440 .map(|h| (h.name.clone(), h.value.clone()))
441 .collect()
442 })
443 .unwrap_or_default();
444 Some(ResponseInfo {
445 status,
446 status_text,
447 headers,
448 })
449}
450
451pub(crate) fn headers_to_cdp(headers: &[(String, String)]) -> Vec<Value> {
455 headers
456 .iter()
457 .map(|(name, value)| json!({ "name": name, "value": value }))
458 .collect()
459}
460
461fn parse_resource_type(s: Option<&str>) -> ResourceType {
467 match s.unwrap_or("Other") {
468 "Document" => ResourceType::Document,
469 "Stylesheet" => ResourceType::Stylesheet,
470 "Image" => ResourceType::Image,
471 "Media" => ResourceType::Media,
472 "Font" => ResourceType::Font,
473 "Script" => ResourceType::Script,
474 "TextTrack" => ResourceType::TextTrack,
475 "XHR" => ResourceType::XHR,
476 "Fetch" => ResourceType::Fetch,
477 "EventSource" => ResourceType::EventSource,
478 "WebSocket" => ResourceType::WebSocket,
479 "Manifest" => ResourceType::Manifest,
480 "SignedExchange" => ResourceType::SignedExchange,
481 "Ping" => ResourceType::Ping,
482 "CSPViolationReport" => ResourceType::CSPViolationReport,
483 "Preflight" => ResourceType::Preflight,
484 _ => ResourceType::Other,
485 }
486}
487
488async fn fail_request(
491 session: &SessionHandle,
492 request_id: &str,
493 error_reason: &str,
494) -> Result<(), InterceptionError> {
495 session
496 .call(
497 "Fetch.failRequest",
498 json!({
499 "requestId": request_id,
500 "errorReason": error_reason,
501 }),
502 )
503 .await?;
504 Ok(())
505}
506
507async fn continue_passthrough(
508 session: &SessionHandle,
509 request_id: &str,
510) -> Result<(), InterceptionError> {
511 session
512 .call("Fetch.continueRequest", json!({ "requestId": request_id }))
513 .await?;
514 Ok(())
515}
516
517async fn continue_with_url(
518 session: &SessionHandle,
519 request_id: &str,
520 url: &str,
521) -> Result<(), InterceptionError> {
522 session
523 .call(
524 "Fetch.continueRequest",
525 json!({
526 "requestId": request_id,
527 "url": url,
528 }),
529 )
530 .await?;
531 Ok(())
532}
533
534async fn continue_with_overrides(
535 session: &SessionHandle,
536 request_id: &str,
537 overrides: RequestOverrides,
538) -> Result<(), InterceptionError> {
539 let mut params = Map::new();
540 params.insert("requestId".into(), Value::String(request_id.into()));
541 if let Some(url) = overrides.url {
542 params.insert("url".into(), Value::String(url));
543 }
544 if let Some(method) = overrides.method {
545 params.insert("method".into(), Value::String(method));
546 }
547 if let Some(headers) = overrides.headers {
548 params.insert("headers".into(), Value::Array(headers_to_cdp(&headers)));
549 }
550 if let Some(post_data) = overrides.post_data {
551 params.insert("postData".into(), Value::String(BASE64.encode(&post_data)));
552 }
553 session
554 .call("Fetch.continueRequest", Value::Object(params))
555 .await?;
556 Ok(())
557}
558
559async fn fulfill_request(
560 session: &SessionHandle,
561 request_id: &str,
562 status: u16,
563 headers: &[(String, String)],
564 body: &[u8],
565) -> Result<(), InterceptionError> {
566 let response_headers = headers_to_cdp(headers);
567 session
568 .call(
569 "Fetch.fulfillRequest",
570 json!({
571 "requestId": request_id,
572 "responseCode": status,
573 "responseHeaders": response_headers,
574 "body": BASE64.encode(body),
575 }),
576 )
577 .await?;
578 Ok(())
579}
580
581async fn continue_response_with_overrides(
588 session: &SessionHandle,
589 request_id: &str,
590 overrides: ResponseOverrides,
591) -> Result<(), InterceptionError> {
592 let mut params = Map::new();
593 params.insert("requestId".into(), Value::String(request_id.into()));
594 if let Some(status) = overrides.status {
595 params.insert("responseCode".into(), Value::from(status));
596 }
597 if let Some(phrase) = overrides.phrase {
598 params.insert("responsePhrase".into(), Value::String(phrase));
599 }
600 if let Some(headers) = overrides.headers {
601 params.insert(
602 "responseHeaders".into(),
603 Value::Array(headers_to_cdp(&headers)),
604 );
605 }
606 session
607 .call("Fetch.continueResponse", Value::Object(params))
608 .await?;
609 Ok(())
610}
611
// Actor tests drive `run_actor` against a `MockConnection`. Each test plays
// the browser side step by step: wait for a command, assert its params, reply,
// emit events. Every wait is bounded by a 2s timeout so a missing command
// fails the test instead of hanging it.
#[cfg(test)]
#[allow(clippy::panic, clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::url_pattern::UrlPattern;
    use std::time::Duration;
    use zendriver_transport::testing::MockConnection;

    /// A `Rule::Block` match must be answered with `Fetch.failRequest` /
    /// `BlockedByClient`, and cancellation must end with `Fetch.disable` plus
    /// the `done` signal.
    #[tokio::test]
    async fn block_rule_dispatches_fail_request_with_blocked_by_client() {
        let (mut mock, conn) = MockConnection::pair();
        let sess = SessionHandle::new(conn.clone(), "S1");

        let rules = vec![Rule::Block {
            pattern: UrlPattern::new("*/blocked/*").unwrap(),
        }];
        let patterns = vec![RequestPattern {
            url_pattern: Some("*".into()),
            ..RequestPattern::default()
        }];
        let cancel = CancellationToken::new();
        let (done_tx, done_rx) = oneshot::channel();
        let actor_cancel = cancel.clone();
        let actor = tokio::spawn(async move {
            run_actor(sess, rules, patterns, None, actor_cancel, done_tx).await;
        });

        // Step 1: Fetch.enable with the serialized pattern and auth handling off.
        let enable_id =
            tokio::time::timeout(Duration::from_secs(2), mock.expect_cmd("Fetch.enable"))
                .await
                .expect("actor did not send Fetch.enable within 2s");
        let enable_params = mock.last_sent()["params"].clone();
        assert_eq!(enable_params["handleAuthRequests"], false);
        assert_eq!(enable_params["patterns"][0]["urlPattern"], "*");
        mock.reply(enable_id, json!({})).await;

        // Step 2: a paused request whose URL matches the block rule.
        mock.emit_event_for_session(
            "Fetch.requestPaused",
            json!({
                "requestId": "REQ-1",
                "request": {
                    "url": "https://example.test/blocked/banner.png",
                    "method": "GET",
                    "headers": {},
                },
                "resourceType": "Image",
            }),
            "S1",
        )
        .await;

        // Step 3: the actor must fail that request with BlockedByClient.
        let fail_id =
            tokio::time::timeout(Duration::from_secs(2), mock.expect_cmd("Fetch.failRequest"))
                .await
                .expect("actor did not send Fetch.failRequest within 2s");
        let fail_params = mock.last_sent()["params"].clone();
        assert_eq!(fail_params["requestId"], "REQ-1");
        assert_eq!(fail_params["errorReason"], "BlockedByClient");
        mock.reply(fail_id, json!({})).await;

        // Step 4: cancellation → Fetch.disable, then the done signal.
        cancel.cancel();
        let disable_id =
            tokio::time::timeout(Duration::from_secs(2), mock.expect_cmd("Fetch.disable"))
                .await
                .expect("actor did not send Fetch.disable on cancel");
        mock.reply(disable_id, json!({})).await;

        tokio::time::timeout(Duration::from_secs(2), done_rx)
            .await
            .expect("actor did not signal exit within 2s")
            .expect("oneshot sender dropped without sending");
        actor.await.unwrap();
        conn.shutdown();
    }

    /// With credentials configured, the actor enables `handleAuthRequests` and
    /// answers `Fetch.authRequired` with `ProvideCredentials`.
    #[tokio::test]
    async fn actor_handles_auth_required_with_credentials() {
        let (mut mock, conn) = MockConnection::pair();
        let sess = SessionHandle::new(conn.clone(), "S1");
        let cancel = CancellationToken::new();
        let (done_tx, done_rx) = oneshot::channel();
        let actor_cancel = cancel.clone();
        let auth = Some(("user1".to_string(), "pass1".to_string()));
        let actor = tokio::spawn(async move {
            run_actor(
                sess,
                Vec::new(),
                vec![RequestPattern {
                    url_pattern: Some("*".into()),
                    ..RequestPattern::default()
                }],
                auth,
                actor_cancel,
                done_tx,
            )
            .await;
        });

        let enable_id =
            tokio::time::timeout(Duration::from_secs(2), mock.expect_cmd("Fetch.enable"))
                .await
                .expect("actor did not send Fetch.enable within 2s");
        assert_eq!(
            mock.last_sent()["params"]["handleAuthRequests"],
            true,
            "auth-enabled actor must flip handleAuthRequests"
        );
        mock.reply(enable_id, json!({})).await;

        // A proxy auth challenge shaped like a Fetch.authRequired payload.
        mock.emit_event_for_session(
            "Fetch.authRequired",
            json!({
                "requestId": "AUTH-REQ-1",
                "request": { "url": "https://example.test/", "method": "GET" },
                "frameId": "F1",
                "resourceType": "Document",
                "authChallenge": {
                    "source": "Proxy",
                    "origin": "http://proxy.test",
                    "scheme": "basic",
                    "realm": "",
                },
            }),
            "S1",
        )
        .await;

        let auth_id = tokio::time::timeout(
            Duration::from_secs(2),
            mock.expect_cmd("Fetch.continueWithAuth"),
        )
        .await
        .expect("actor did not send Fetch.continueWithAuth within 2s");
        let params = mock.last_sent()["params"].clone();
        assert_eq!(params["requestId"], "AUTH-REQ-1");
        assert_eq!(
            params["authChallengeResponse"]["response"],
            "ProvideCredentials"
        );
        assert_eq!(params["authChallengeResponse"]["username"], "user1");
        assert_eq!(params["authChallengeResponse"]["password"], "pass1");
        mock.reply(auth_id, json!({})).await;

        cancel.cancel();
        let disable_id =
            tokio::time::timeout(Duration::from_secs(2), mock.expect_cmd("Fetch.disable"))
                .await
                .expect("actor did not send Fetch.disable on cancel");
        mock.reply(disable_id, json!({})).await;
        tokio::time::timeout(Duration::from_secs(2), done_rx)
            .await
            .expect("actor did not signal exit")
            .expect("oneshot sender dropped");
        actor.await.unwrap();
        conn.shutdown();
    }

    /// Without credentials, `handleAuthRequests` stays off, and a stray
    /// `Fetch.authRequired` is still answered, with `Default`.
    #[tokio::test]
    async fn actor_without_auth_responds_default_to_auth_required() {
        let (mut mock, conn) = MockConnection::pair();
        let sess = SessionHandle::new(conn.clone(), "S2");
        let cancel = CancellationToken::new();
        let (done_tx, done_rx) = oneshot::channel();
        let actor_cancel = cancel.clone();
        let actor = tokio::spawn(async move {
            run_actor(
                sess,
                Vec::new(),
                vec![RequestPattern {
                    url_pattern: Some("*".into()),
                    ..RequestPattern::default()
                }],
                None,
                actor_cancel,
                done_tx,
            )
            .await;
        });

        let enable_id =
            tokio::time::timeout(Duration::from_secs(2), mock.expect_cmd("Fetch.enable"))
                .await
                .expect("actor did not send Fetch.enable");
        assert_eq!(mock.last_sent()["params"]["handleAuthRequests"], false);
        mock.reply(enable_id, json!({})).await;

        // Minimal payload: only the requestId the actor needs.
        mock.emit_event_for_session(
            "Fetch.authRequired",
            json!({ "requestId": "AUTH-REQ-2" }),
            "S2",
        )
        .await;

        let auth_id = tokio::time::timeout(
            Duration::from_secs(2),
            mock.expect_cmd("Fetch.continueWithAuth"),
        )
        .await
        .expect("actor did not respond to stray authRequired");
        assert_eq!(
            mock.last_sent()["params"]["authChallengeResponse"]["response"],
            "Default"
        );
        mock.reply(auth_id, json!({})).await;

        cancel.cancel();
        let disable_id =
            tokio::time::timeout(Duration::from_secs(2), mock.expect_cmd("Fetch.disable"))
                .await
                .expect("actor did not send Fetch.disable");
        mock.reply(disable_id, json!({})).await;
        tokio::time::timeout(Duration::from_secs(2), done_rx)
            .await
            .expect("actor did not exit")
            .expect("oneshot dropped");
        actor.await.unwrap();
        conn.shutdown();
    }
}