1use base64::Engine as _;
36use base64::engine::general_purpose::STANDARD as BASE64;
37use futures::StreamExt;
38use serde::Deserialize;
39use serde_json::{Map, Value, json};
40use std::collections::HashMap;
41use tokio::sync::oneshot;
42use tokio_util::sync::CancellationToken;
43use tracing::{trace, warn};
44use zendriver_transport::SessionHandle;
45
46use crate::builder::RequestPattern;
47use crate::error::InterceptionError;
48use crate::rule::Rule;
49use crate::types::{RequestInfo, RequestOverrides, ResourceType, ResponseInfo};
50
/// Caller-side handle for a running interception actor.
///
/// Dropping the handle cancels `cancel` (see the `Drop` impl in this file), while
/// [`InterceptHandle::stop`] cancels and then waits for the exit signal on `done`.
/// NOTE(review): the token and receiver are presumably the counterparts of the ones handed
/// to `run_actor` — the pairing happens outside this file; confirm at the construction site.
#[derive(Debug)]
#[must_use = "interception stops when the handle is dropped — bind it to a variable to keep it alive"]
pub struct InterceptHandle {
    /// Token whose cancellation asks the actor loop to exit.
    cancel: CancellationToken,
    /// Exit-signal receiver. Kept in an `Option` because `InterceptHandle` implements `Drop`,
    /// so `stop` cannot move the field out directly and uses `Option::take` instead.
    done: Option<oneshot::Receiver<()>>,
}
65
66impl InterceptHandle {
67 pub(crate) fn new(cancel: CancellationToken, done: oneshot::Receiver<()>) -> Self {
71 Self {
72 cancel,
73 done: Some(done),
74 }
75 }
76
77 #[cfg(any(test, feature = "test-support"))]
87 #[doc(hidden)]
88 pub fn for_tests() -> Self {
89 let (_done_tx, done_rx) = oneshot::channel();
90 Self {
91 cancel: CancellationToken::new(),
92 done: Some(done_rx),
93 }
94 }
95
96 pub async fn stop(mut self) -> Result<(), InterceptionError> {
105 self.cancel.cancel();
106 match self.done.take() {
107 Some(rx) => rx.await.map_err(|_| InterceptionError::SubscriptionClosed),
108 None => Ok(()),
109 }
110 }
111}
112
/// Cancels the token when the handle goes out of scope, so forgetting to call
/// [`InterceptHandle::stop`] still asks the actor holding the same token to shut down.
/// Unlike `stop`, dropping does not wait for the actor's exit signal.
impl Drop for InterceptHandle {
    fn drop(&mut self) {
        // Also runs after `stop` has already cancelled the token; cancelling an
        // already-cancelled `CancellationToken` has no further effect.
        self.cancel.cancel();
    }
}
122
/// Payload of a CDP `Fetch.requestPaused` event, deserialized from the raw JSON in `run_actor`.
///
/// Only the fields read in this module are modelled; there is no `deny_unknown_fields`, so any
/// other keys in the event are ignored by serde.
#[derive(Debug, Deserialize)]
pub(crate) struct RequestPausedEvent {
    /// Identifier that every reply command (`Fetch.failRequest` / `continueRequest` /
    /// `fulfillRequest`) must echo back.
    #[serde(rename = "requestId")]
    pub(crate) request_id: String,
    /// The paused request itself.
    pub(crate) request: RequestPayload,
    /// Raw resource-type string; converted by `parse_resource_type` (`None` → `Other`).
    #[serde(rename = "resourceType", default)]
    pub(crate) resource_type: Option<String>,
    /// HTTP status code; `build_response_info` returns `None` when this is absent.
    /// NOTE(review): per the CDP docs this is only present when paused at the response
    /// stage — not enforced here.
    #[serde(rename = "responseStatusCode", default)]
    pub(crate) response_status_code: Option<u16>,
    /// HTTP status text; `build_response_info` substitutes an empty string when absent.
    #[serde(rename = "responseStatusText", default)]
    pub(crate) response_status_text: Option<String>,
    /// Response headers as ordered name/value pairs; an empty list is used when absent.
    #[serde(rename = "responseHeaders", default)]
    pub(crate) response_headers: Option<Vec<HeaderPair>>,
}
146
/// The `request` object nested inside a `Fetch.requestPaused` event.
#[derive(Debug, Deserialize)]
pub(crate) struct RequestPayload {
    /// Full request URL; this is what `Rule::matches` is evaluated against.
    pub(crate) url: String,
    /// HTTP method, e.g. `"GET"`.
    pub(crate) method: String,
    /// Request headers as a name → value map; empty when the key is missing.
    #[serde(default)]
    pub(crate) headers: HashMap<String, String>,
    /// Body as a plain string. `decode_post_data` uses it only when
    /// `post_data_entries` is absent.
    #[serde(rename = "postData", default)]
    pub(crate) post_data: Option<String>,
    /// Deserialized but not read anywhere in this file (hence the leading underscore).
    #[serde(rename = "hasPostData", default)]
    _has_post_data: Option<bool>,
    /// Body split into base64-encoded chunks; preferred by `decode_post_data` when present.
    #[serde(rename = "postDataEntries", default)]
    pub(crate) post_data_entries: Option<Vec<PostDataEntry>>,
}
166
/// One element of `postDataEntries`.
#[derive(Debug, Deserialize)]
pub(crate) struct PostDataEntry {
    /// Base64-encoded chunk (decoded with the standard alphabet in `decode_post_data`).
    /// Entries without it are skipped.
    #[serde(default)]
    pub(crate) bytes: Option<String>,
}
173
/// A single name/value header entry as it appears in `responseHeaders`.
#[derive(Debug, Deserialize)]
pub(crate) struct HeaderPair {
    /// Header name.
    pub(crate) name: String,
    /// Header value.
    pub(crate) value: String,
}
179
180pub(crate) async fn run_actor(
189 session: SessionHandle,
190 rules: Vec<Rule>,
191 patterns: Vec<RequestPattern>,
192 auth: Option<(String, String)>,
193 cancel: CancellationToken,
194 done: oneshot::Sender<()>,
195) {
196 let mut paused = session.subscribe::<Value>("Fetch.requestPaused");
202 let mut auth_required = session.subscribe::<Value>("Fetch.authRequired");
206
207 let enable_session = session.clone();
213 let enable_patterns: Vec<Value> = patterns.iter().map(serialize_pattern).collect();
214 let handle_auth_requests = auth.is_some();
215 tokio::spawn(async move {
216 if let Err(e) = enable_session
217 .call(
218 "Fetch.enable",
219 json!({
220 "patterns": enable_patterns,
221 "handleAuthRequests": handle_auth_requests,
222 }),
223 )
224 .await
225 {
226 warn!(error = %e, "interception: Fetch.enable failed; interception inactive");
227 }
228 });
229
230 loop {
232 tokio::select! {
233 () = cancel.cancelled() => {
234 trace!("interception: cancellation received, disabling Fetch and exiting");
235 break;
236 }
237 Some(ev_value) = paused.next() => {
238 let ev: RequestPausedEvent = match serde_json::from_value(ev_value) {
241 Ok(ev) => ev,
242 Err(e) => {
243 warn!(error = %e, "interception: skipping malformed Fetch.requestPaused event");
244 continue;
245 }
246 };
247 if let Err(e) = handle_paused(&session, &rules, ev).await {
248 warn!(error = %e, "interception: handler dispatch failed");
249 }
250 }
251 Some(ev_value) = auth_required.next() => {
252 let Some(request_id) = ev_value
258 .get("requestId")
259 .and_then(Value::as_str)
260 .map(str::to_owned)
261 else {
262 warn!("interception: Fetch.authRequired without requestId");
263 continue;
264 };
265 let response = match &auth {
266 Some((user, pass)) => json!({
267 "response": "ProvideCredentials",
268 "username": user,
269 "password": pass,
270 }),
271 None => json!({ "response": "Default" }),
272 };
273 if let Err(e) = session
274 .call(
275 "Fetch.continueWithAuth",
276 json!({
277 "requestId": request_id,
278 "authChallengeResponse": response,
279 }),
280 )
281 .await
282 {
283 warn!(error = %e, "interception: Fetch.continueWithAuth failed");
284 }
285 }
286 else => {
287 trace!("interception: event stream closed, exiting without Fetch.disable");
289 let _ = done.send(());
292 return;
293 }
294 }
295 }
296
297 if let Err(e) = session.call("Fetch.disable", json!({})).await {
301 warn!(error = %e, "interception: Fetch.disable failed during shutdown");
302 }
303 let _ = done.send(());
306}
307
308async fn handle_paused(
312 session: &SessionHandle,
313 rules: &[Rule],
314 ev: RequestPausedEvent,
315) -> Result<(), InterceptionError> {
316 let url = ev.request.url.clone();
317
318 let matched = rules.iter().find(|r| r.matches(&url));
322
323 match matched {
324 Some(Rule::Block { .. }) => fail_request(session, &ev.request_id, "BlockedByClient").await,
325 Some(Rule::Redirect { to, .. }) => continue_with_url(session, &ev.request_id, to).await,
326 Some(Rule::Respond {
327 status,
328 headers,
329 body,
330 ..
331 }) => fulfill_request(session, &ev.request_id, *status, headers, body).await,
332 Some(Rule::Modify { modify, .. }) => {
333 let info = build_request_info(&ev);
334 let overrides = modify(&info);
335 continue_with_overrides(session, &ev.request_id, overrides).await
336 }
337 None => continue_passthrough(session, &ev.request_id).await,
338 }
339}
340
341pub(crate) fn serialize_pattern(p: &RequestPattern) -> Value {
344 let mut obj = Map::new();
345 if let Some(url) = &p.url_pattern {
346 obj.insert("urlPattern".into(), Value::String(url.clone()));
347 }
348 if let Some(rt) = p.resource_type {
349 obj.insert("resourceType".into(), Value::String(rt.as_cdp_str().into()));
350 }
351 if let Some(stage) = p.request_stage {
352 obj.insert(
353 "requestStage".into(),
354 Value::String(stage.as_cdp_str().into()),
355 );
356 }
357 Value::Object(obj)
358}
359
360pub(crate) fn build_request_info(ev: &RequestPausedEvent) -> RequestInfo {
372 RequestInfo {
373 url: ev.request.url.clone(),
374 method: ev.request.method.clone(),
375 headers: ev
376 .request
377 .headers
378 .iter()
379 .map(|(k, v)| (k.clone(), v.clone()))
380 .collect(),
381 post_data: decode_post_data(&ev.request),
382 resource_type: parse_resource_type(ev.resource_type.as_deref()),
383 }
384}
385
386fn decode_post_data(req: &RequestPayload) -> Option<Vec<u8>> {
387 use base64::Engine as _;
388 use base64::engine::general_purpose::STANDARD as BASE64;
389
390 if let Some(entries) = req.post_data_entries.as_ref() {
391 let mut buf = Vec::new();
392 for entry in entries {
393 let Some(b64) = entry.bytes.as_deref() else {
394 continue;
395 };
396 match BASE64.decode(b64) {
397 Ok(bytes) => buf.extend_from_slice(&bytes),
398 Err(e) => {
399 tracing::warn!(error = %e, "interception: bad base64 in postDataEntries; skipping entry");
400 }
401 }
402 }
403 return Some(buf);
404 }
405 req.post_data.as_deref().map(|s| s.as_bytes().to_vec())
406}
407
408pub(crate) fn build_response_info(ev: &RequestPausedEvent) -> Option<ResponseInfo> {
415 let status = ev.response_status_code?;
416 let status_text = ev.response_status_text.clone().unwrap_or_default();
417 let headers: Vec<(String, String)> = ev
418 .response_headers
419 .as_ref()
420 .map(|hs| {
421 hs.iter()
422 .map(|h| (h.name.clone(), h.value.clone()))
423 .collect()
424 })
425 .unwrap_or_default();
426 Some(ResponseInfo {
427 status,
428 status_text,
429 headers,
430 })
431}
432
433pub(crate) fn headers_to_cdp(headers: &[(String, String)]) -> Vec<Value> {
437 headers
438 .iter()
439 .map(|(name, value)| json!({ "name": name, "value": value }))
440 .collect()
441}
442
443fn parse_resource_type(s: Option<&str>) -> ResourceType {
449 match s.unwrap_or("Other") {
450 "Document" => ResourceType::Document,
451 "Stylesheet" => ResourceType::Stylesheet,
452 "Image" => ResourceType::Image,
453 "Media" => ResourceType::Media,
454 "Font" => ResourceType::Font,
455 "Script" => ResourceType::Script,
456 "TextTrack" => ResourceType::TextTrack,
457 "XHR" => ResourceType::XHR,
458 "Fetch" => ResourceType::Fetch,
459 "EventSource" => ResourceType::EventSource,
460 "WebSocket" => ResourceType::WebSocket,
461 "Manifest" => ResourceType::Manifest,
462 "SignedExchange" => ResourceType::SignedExchange,
463 "Ping" => ResourceType::Ping,
464 "CSPViolationReport" => ResourceType::CSPViolationReport,
465 "Preflight" => ResourceType::Preflight,
466 _ => ResourceType::Other,
467 }
468}
469
470async fn fail_request(
473 session: &SessionHandle,
474 request_id: &str,
475 error_reason: &str,
476) -> Result<(), InterceptionError> {
477 session
478 .call(
479 "Fetch.failRequest",
480 json!({
481 "requestId": request_id,
482 "errorReason": error_reason,
483 }),
484 )
485 .await?;
486 Ok(())
487}
488
489async fn continue_passthrough(
490 session: &SessionHandle,
491 request_id: &str,
492) -> Result<(), InterceptionError> {
493 session
494 .call("Fetch.continueRequest", json!({ "requestId": request_id }))
495 .await?;
496 Ok(())
497}
498
499async fn continue_with_url(
500 session: &SessionHandle,
501 request_id: &str,
502 url: &str,
503) -> Result<(), InterceptionError> {
504 session
505 .call(
506 "Fetch.continueRequest",
507 json!({
508 "requestId": request_id,
509 "url": url,
510 }),
511 )
512 .await?;
513 Ok(())
514}
515
516async fn continue_with_overrides(
517 session: &SessionHandle,
518 request_id: &str,
519 overrides: RequestOverrides,
520) -> Result<(), InterceptionError> {
521 let mut params = Map::new();
522 params.insert("requestId".into(), Value::String(request_id.into()));
523 if let Some(url) = overrides.url {
524 params.insert("url".into(), Value::String(url));
525 }
526 if let Some(method) = overrides.method {
527 params.insert("method".into(), Value::String(method));
528 }
529 if let Some(headers) = overrides.headers {
530 params.insert("headers".into(), Value::Array(headers_to_cdp(&headers)));
531 }
532 if let Some(post_data) = overrides.post_data {
533 params.insert("postData".into(), Value::String(BASE64.encode(&post_data)));
534 }
535 session
536 .call("Fetch.continueRequest", Value::Object(params))
537 .await?;
538 Ok(())
539}
540
541async fn fulfill_request(
542 session: &SessionHandle,
543 request_id: &str,
544 status: u16,
545 headers: &[(String, String)],
546 body: &[u8],
547) -> Result<(), InterceptionError> {
548 let response_headers = headers_to_cdp(headers);
549 session
550 .call(
551 "Fetch.fulfillRequest",
552 json!({
553 "requestId": request_id,
554 "responseCode": status,
555 "responseHeaders": response_headers,
556 "body": BASE64.encode(body),
557 }),
558 )
559 .await?;
560 Ok(())
561}
562
// Actor-level tests driven through `MockConnection`. Each expected CDP command is awaited
// under a 2s `tokio::time::timeout`, so a missing command fails the test instead of hanging it.
#[cfg(test)]
#[allow(clippy::panic, clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::url_pattern::UrlPattern;
    use std::time::Duration;
    use zendriver_transport::testing::MockConnection;

    // A `Block` rule must answer a matching paused request with `Fetch.failRequest` /
    // `BlockedByClient`. Cancellation must then produce `Fetch.disable` followed by the exit
    // signal on `done`.
    #[tokio::test]
    async fn block_rule_dispatches_fail_request_with_blocked_by_client() {
        let (mut mock, conn) = MockConnection::pair();
        let sess = SessionHandle::new(conn.clone(), "S1");

        let rules = vec![Rule::Block {
            pattern: UrlPattern::new("*/blocked/*").unwrap(),
        }];
        let patterns = vec![RequestPattern {
            url_pattern: Some("*".into()),
            ..RequestPattern::default()
        }];
        let cancel = CancellationToken::new();
        let (done_tx, done_rx) = oneshot::channel();
        let actor_cancel = cancel.clone();
        let actor = tokio::spawn(async move {
            run_actor(sess, rules, patterns, None, actor_cancel, done_tx).await;
        });

        // Startup: Fetch.enable carries the serialized pattern, and handleAuthRequests is
        // false because no credentials were supplied.
        let enable_id =
            tokio::time::timeout(Duration::from_secs(2), mock.expect_cmd("Fetch.enable"))
                .await
                .expect("actor did not send Fetch.enable within 2s");
        let enable_params = mock.last_sent()["params"].clone();
        assert_eq!(enable_params["handleAuthRequests"], false);
        assert_eq!(enable_params["patterns"][0]["urlPattern"], "*");
        mock.reply(enable_id, json!({})).await;

        // A paused request whose URL matches the Block rule's pattern.
        mock.emit_event_for_session(
            "Fetch.requestPaused",
            json!({
                "requestId": "REQ-1",
                "request": {
                    "url": "https://example.test/blocked/banner.png",
                    "method": "GET",
                    "headers": {},
                },
                "resourceType": "Image",
            }),
            "S1",
        )
        .await;

        let fail_id =
            tokio::time::timeout(Duration::from_secs(2), mock.expect_cmd("Fetch.failRequest"))
                .await
                .expect("actor did not send Fetch.failRequest within 2s");
        let fail_params = mock.last_sent()["params"].clone();
        assert_eq!(fail_params["requestId"], "REQ-1");
        assert_eq!(fail_params["errorReason"], "BlockedByClient");
        mock.reply(fail_id, json!({})).await;

        // Shutdown: cancelling the token must trigger Fetch.disable before the exit signal.
        cancel.cancel();
        let disable_id =
            tokio::time::timeout(Duration::from_secs(2), mock.expect_cmd("Fetch.disable"))
                .await
                .expect("actor did not send Fetch.disable on cancel");
        mock.reply(disable_id, json!({})).await;

        tokio::time::timeout(Duration::from_secs(2), done_rx)
            .await
            .expect("actor did not signal exit within 2s")
            .expect("oneshot sender dropped without sending");
        actor.await.unwrap();
        conn.shutdown();
    }

    // With credentials configured, Fetch.enable must request auth handling, and an
    // authRequired event must be answered with ProvideCredentials plus the given user/password.
    #[tokio::test]
    async fn actor_handles_auth_required_with_credentials() {
        let (mut mock, conn) = MockConnection::pair();
        let sess = SessionHandle::new(conn.clone(), "S1");
        let cancel = CancellationToken::new();
        let (done_tx, done_rx) = oneshot::channel();
        let actor_cancel = cancel.clone();
        let auth = Some(("user1".to_string(), "pass1".to_string()));
        let actor = tokio::spawn(async move {
            run_actor(
                sess,
                Vec::new(),
                vec![RequestPattern {
                    url_pattern: Some("*".into()),
                    ..RequestPattern::default()
                }],
                auth,
                actor_cancel,
                done_tx,
            )
            .await;
        });

        let enable_id =
            tokio::time::timeout(Duration::from_secs(2), mock.expect_cmd("Fetch.enable"))
                .await
                .expect("actor did not send Fetch.enable within 2s");
        assert_eq!(
            mock.last_sent()["params"]["handleAuthRequests"],
            true,
            "auth-enabled actor must flip handleAuthRequests"
        );
        mock.reply(enable_id, json!({})).await;

        // A proxy-style basic-auth challenge; only `requestId` is read by the actor.
        mock.emit_event_for_session(
            "Fetch.authRequired",
            json!({
                "requestId": "AUTH-REQ-1",
                "request": { "url": "https://example.test/", "method": "GET" },
                "frameId": "F1",
                "resourceType": "Document",
                "authChallenge": {
                    "source": "Proxy",
                    "origin": "http://proxy.test",
                    "scheme": "basic",
                    "realm": "",
                },
            }),
            "S1",
        )
        .await;

        let auth_id = tokio::time::timeout(
            Duration::from_secs(2),
            mock.expect_cmd("Fetch.continueWithAuth"),
        )
        .await
        .expect("actor did not send Fetch.continueWithAuth within 2s");
        let params = mock.last_sent()["params"].clone();
        assert_eq!(params["requestId"], "AUTH-REQ-1");
        assert_eq!(
            params["authChallengeResponse"]["response"],
            "ProvideCredentials"
        );
        assert_eq!(params["authChallengeResponse"]["username"], "user1");
        assert_eq!(params["authChallengeResponse"]["password"], "pass1");
        mock.reply(auth_id, json!({})).await;

        cancel.cancel();
        let disable_id =
            tokio::time::timeout(Duration::from_secs(2), mock.expect_cmd("Fetch.disable"))
                .await
                .expect("actor did not send Fetch.disable on cancel");
        mock.reply(disable_id, json!({})).await;
        tokio::time::timeout(Duration::from_secs(2), done_rx)
            .await
            .expect("actor did not signal exit")
            .expect("oneshot sender dropped");
        actor.await.unwrap();
        conn.shutdown();
    }

    // Without credentials the actor still answers a stray authRequired event with the
    // `Default` response instead of leaving the request stalled.
    #[tokio::test]
    async fn actor_without_auth_responds_default_to_auth_required() {
        let (mut mock, conn) = MockConnection::pair();
        let sess = SessionHandle::new(conn.clone(), "S2");
        let cancel = CancellationToken::new();
        let (done_tx, done_rx) = oneshot::channel();
        let actor_cancel = cancel.clone();
        let actor = tokio::spawn(async move {
            run_actor(
                sess,
                Vec::new(),
                vec![RequestPattern {
                    url_pattern: Some("*".into()),
                    ..RequestPattern::default()
                }],
                None,
                actor_cancel,
                done_tx,
            )
            .await;
        });

        let enable_id =
            tokio::time::timeout(Duration::from_secs(2), mock.expect_cmd("Fetch.enable"))
                .await
                .expect("actor did not send Fetch.enable");
        assert_eq!(mock.last_sent()["params"]["handleAuthRequests"], false);
        mock.reply(enable_id, json!({})).await;

        // Minimal event: the actor only needs `requestId` to respond.
        mock.emit_event_for_session(
            "Fetch.authRequired",
            json!({ "requestId": "AUTH-REQ-2" }),
            "S2",
        )
        .await;

        let auth_id = tokio::time::timeout(
            Duration::from_secs(2),
            mock.expect_cmd("Fetch.continueWithAuth"),
        )
        .await
        .expect("actor did not respond to stray authRequired");
        assert_eq!(
            mock.last_sent()["params"]["authChallengeResponse"]["response"],
            "Default"
        );
        mock.reply(auth_id, json!({})).await;

        cancel.cancel();
        let disable_id =
            tokio::time::timeout(Duration::from_secs(2), mock.expect_cmd("Fetch.disable"))
                .await
                .expect("actor did not send Fetch.disable");
        mock.reply(disable_id, json!({})).await;
        tokio::time::timeout(Duration::from_secs(2), done_rx)
            .await
            .expect("actor did not exit")
            .expect("oneshot dropped");
        actor.await.unwrap();
        conn.shutdown();
    }
}