1use base64::Engine as _;
36use base64::engine::general_purpose::STANDARD as BASE64;
37use futures::StreamExt;
38use serde::Deserialize;
39use serde_json::{Map, Value, json};
40use std::collections::HashMap;
41use tokio::sync::oneshot;
42use tokio_util::sync::CancellationToken;
43use tracing::{trace, warn};
44use zendriver_transport::SessionHandle;
45
46use crate::builder::RequestPattern;
47use crate::error::InterceptionError;
48use crate::rule::Rule;
49use crate::types::{RequestInfo, RequestOverrides, ResourceType, ResponseInfo};
50
51#[derive(Debug)]
58#[must_use = "interception stops when the handle is dropped — bind it to a variable to keep it alive"]
59pub struct InterceptHandle {
60 cancel: CancellationToken,
61 done: Option<oneshot::Receiver<()>>,
64}
65
66impl InterceptHandle {
67 pub(crate) fn new(cancel: CancellationToken, done: oneshot::Receiver<()>) -> Self {
71 Self {
72 cancel,
73 done: Some(done),
74 }
75 }
76
77 pub async fn stop(mut self) -> Result<(), InterceptionError> {
86 self.cancel.cancel();
87 match self.done.take() {
88 Some(rx) => rx.await.map_err(|_| InterceptionError::SubscriptionClosed),
89 None => Ok(()),
90 }
91 }
92}
93
94impl Drop for InterceptHandle {
95 fn drop(&mut self) {
96 self.cancel.cancel();
101 }
102}
103
104#[derive(Debug, Deserialize)]
113pub(crate) struct RequestPausedEvent {
114 #[serde(rename = "requestId")]
115 pub(crate) request_id: String,
116 pub(crate) request: RequestPayload,
117 #[serde(rename = "resourceType", default)]
118 pub(crate) resource_type: Option<String>,
119 #[serde(rename = "responseStatusCode", default)]
121 pub(crate) response_status_code: Option<u16>,
122 #[serde(rename = "responseStatusText", default)]
123 pub(crate) response_status_text: Option<String>,
124 #[serde(rename = "responseHeaders", default)]
125 pub(crate) response_headers: Option<Vec<HeaderPair>>,
126}
127
128#[derive(Debug, Deserialize)]
129pub(crate) struct RequestPayload {
130 pub(crate) url: String,
131 pub(crate) method: String,
132 #[serde(default)]
133 pub(crate) headers: HashMap<String, String>,
134 #[serde(rename = "postData", default)]
138 pub(crate) post_data: Option<String>,
139 #[serde(rename = "hasPostData", default)]
140 _has_post_data: Option<bool>,
141 #[serde(rename = "postDataEntries", default)]
145 pub(crate) post_data_entries: Option<Vec<PostDataEntry>>,
146}
147
/// One element of `request.postDataEntries` in a `Fetch.requestPaused` event.
#[derive(Debug, Deserialize)]
pub(crate) struct PostDataEntry {
    /// Base64-encoded body chunk (decoded by `decode_post_data`); may be absent.
    #[serde(default)]
    pub(crate) bytes: Option<String>,
}

/// A `{ name, value }` header entry, as found in the event's `responseHeaders`.
#[derive(Debug, Deserialize)]
pub(crate) struct HeaderPair {
    /// Header name.
    pub(crate) name: String,
    /// Header value.
    pub(crate) value: String,
}
160
161pub(crate) async fn run_actor(
170 session: SessionHandle,
171 rules: Vec<Rule>,
172 patterns: Vec<RequestPattern>,
173 auth: Option<(String, String)>,
174 cancel: CancellationToken,
175 done: oneshot::Sender<()>,
176) {
177 let mut paused = session.subscribe::<Value>("Fetch.requestPaused");
183 let mut auth_required = session.subscribe::<Value>("Fetch.authRequired");
187
188 let enable_session = session.clone();
194 let enable_patterns: Vec<Value> = patterns.iter().map(serialize_pattern).collect();
195 let handle_auth_requests = auth.is_some();
196 tokio::spawn(async move {
197 if let Err(e) = enable_session
198 .call(
199 "Fetch.enable",
200 json!({
201 "patterns": enable_patterns,
202 "handleAuthRequests": handle_auth_requests,
203 }),
204 )
205 .await
206 {
207 warn!(error = %e, "interception: Fetch.enable failed; interception inactive");
208 }
209 });
210
211 loop {
213 tokio::select! {
214 () = cancel.cancelled() => {
215 trace!("interception: cancellation received, disabling Fetch and exiting");
216 break;
217 }
218 Some(ev_value) = paused.next() => {
219 let ev: RequestPausedEvent = match serde_json::from_value(ev_value) {
222 Ok(ev) => ev,
223 Err(e) => {
224 warn!(error = %e, "interception: skipping malformed Fetch.requestPaused event");
225 continue;
226 }
227 };
228 if let Err(e) = handle_paused(&session, &rules, ev).await {
229 warn!(error = %e, "interception: handler dispatch failed");
230 }
231 }
232 Some(ev_value) = auth_required.next() => {
233 let Some(request_id) = ev_value
239 .get("requestId")
240 .and_then(Value::as_str)
241 .map(str::to_owned)
242 else {
243 warn!("interception: Fetch.authRequired without requestId");
244 continue;
245 };
246 let response = match &auth {
247 Some((user, pass)) => json!({
248 "response": "ProvideCredentials",
249 "username": user,
250 "password": pass,
251 }),
252 None => json!({ "response": "Default" }),
253 };
254 if let Err(e) = session
255 .call(
256 "Fetch.continueWithAuth",
257 json!({
258 "requestId": request_id,
259 "authChallengeResponse": response,
260 }),
261 )
262 .await
263 {
264 warn!(error = %e, "interception: Fetch.continueWithAuth failed");
265 }
266 }
267 else => {
268 trace!("interception: event stream closed, exiting without Fetch.disable");
270 let _ = done.send(());
273 return;
274 }
275 }
276 }
277
278 if let Err(e) = session.call("Fetch.disable", json!({})).await {
282 warn!(error = %e, "interception: Fetch.disable failed during shutdown");
283 }
284 let _ = done.send(());
287}
288
289async fn handle_paused(
293 session: &SessionHandle,
294 rules: &[Rule],
295 ev: RequestPausedEvent,
296) -> Result<(), InterceptionError> {
297 let url = ev.request.url.clone();
298
299 let matched = rules.iter().find(|r| r.matches(&url));
303
304 match matched {
305 Some(Rule::Block { .. }) => fail_request(session, &ev.request_id, "BlockedByClient").await,
306 Some(Rule::Redirect { to, .. }) => continue_with_url(session, &ev.request_id, to).await,
307 Some(Rule::Respond {
308 status,
309 headers,
310 body,
311 ..
312 }) => fulfill_request(session, &ev.request_id, *status, headers, body).await,
313 Some(Rule::Modify { modify, .. }) => {
314 let info = build_request_info(&ev);
315 let overrides = modify(&info);
316 continue_with_overrides(session, &ev.request_id, overrides).await
317 }
318 None => continue_passthrough(session, &ev.request_id).await,
319 }
320}
321
322pub(crate) fn serialize_pattern(p: &RequestPattern) -> Value {
325 let mut obj = Map::new();
326 if let Some(url) = &p.url_pattern {
327 obj.insert("urlPattern".into(), Value::String(url.clone()));
328 }
329 if let Some(rt) = p.resource_type {
330 obj.insert("resourceType".into(), Value::String(rt.as_cdp_str().into()));
331 }
332 if let Some(stage) = p.request_stage {
333 obj.insert(
334 "requestStage".into(),
335 Value::String(stage.as_cdp_str().into()),
336 );
337 }
338 Value::Object(obj)
339}
340
341pub(crate) fn build_request_info(ev: &RequestPausedEvent) -> RequestInfo {
353 RequestInfo {
354 url: ev.request.url.clone(),
355 method: ev.request.method.clone(),
356 headers: ev
357 .request
358 .headers
359 .iter()
360 .map(|(k, v)| (k.clone(), v.clone()))
361 .collect(),
362 post_data: decode_post_data(&ev.request),
363 resource_type: parse_resource_type(ev.resource_type.as_deref()),
364 }
365}
366
367fn decode_post_data(req: &RequestPayload) -> Option<Vec<u8>> {
368 use base64::Engine as _;
369 use base64::engine::general_purpose::STANDARD as BASE64;
370
371 if let Some(entries) = req.post_data_entries.as_ref() {
372 let mut buf = Vec::new();
373 for entry in entries {
374 let Some(b64) = entry.bytes.as_deref() else {
375 continue;
376 };
377 match BASE64.decode(b64) {
378 Ok(bytes) => buf.extend_from_slice(&bytes),
379 Err(e) => {
380 tracing::warn!(error = %e, "interception: bad base64 in postDataEntries; skipping entry");
381 }
382 }
383 }
384 return Some(buf);
385 }
386 req.post_data.as_deref().map(|s| s.as_bytes().to_vec())
387}
388
389pub(crate) fn build_response_info(ev: &RequestPausedEvent) -> Option<ResponseInfo> {
396 let status = ev.response_status_code?;
397 let status_text = ev.response_status_text.clone().unwrap_or_default();
398 let headers: Vec<(String, String)> = ev
399 .response_headers
400 .as_ref()
401 .map(|hs| {
402 hs.iter()
403 .map(|h| (h.name.clone(), h.value.clone()))
404 .collect()
405 })
406 .unwrap_or_default();
407 Some(ResponseInfo {
408 status,
409 status_text,
410 headers,
411 })
412}
413
414pub(crate) fn headers_to_cdp(headers: &[(String, String)]) -> Vec<Value> {
418 headers
419 .iter()
420 .map(|(name, value)| json!({ "name": name, "value": value }))
421 .collect()
422}
423
424fn parse_resource_type(s: Option<&str>) -> ResourceType {
430 match s.unwrap_or("Other") {
431 "Document" => ResourceType::Document,
432 "Stylesheet" => ResourceType::Stylesheet,
433 "Image" => ResourceType::Image,
434 "Media" => ResourceType::Media,
435 "Font" => ResourceType::Font,
436 "Script" => ResourceType::Script,
437 "TextTrack" => ResourceType::TextTrack,
438 "XHR" => ResourceType::XHR,
439 "Fetch" => ResourceType::Fetch,
440 "EventSource" => ResourceType::EventSource,
441 "WebSocket" => ResourceType::WebSocket,
442 "Manifest" => ResourceType::Manifest,
443 "SignedExchange" => ResourceType::SignedExchange,
444 "Ping" => ResourceType::Ping,
445 "CSPViolationReport" => ResourceType::CSPViolationReport,
446 "Preflight" => ResourceType::Preflight,
447 _ => ResourceType::Other,
448 }
449}
450
451async fn fail_request(
454 session: &SessionHandle,
455 request_id: &str,
456 error_reason: &str,
457) -> Result<(), InterceptionError> {
458 session
459 .call(
460 "Fetch.failRequest",
461 json!({
462 "requestId": request_id,
463 "errorReason": error_reason,
464 }),
465 )
466 .await?;
467 Ok(())
468}
469
470async fn continue_passthrough(
471 session: &SessionHandle,
472 request_id: &str,
473) -> Result<(), InterceptionError> {
474 session
475 .call("Fetch.continueRequest", json!({ "requestId": request_id }))
476 .await?;
477 Ok(())
478}
479
480async fn continue_with_url(
481 session: &SessionHandle,
482 request_id: &str,
483 url: &str,
484) -> Result<(), InterceptionError> {
485 session
486 .call(
487 "Fetch.continueRequest",
488 json!({
489 "requestId": request_id,
490 "url": url,
491 }),
492 )
493 .await?;
494 Ok(())
495}
496
497async fn continue_with_overrides(
498 session: &SessionHandle,
499 request_id: &str,
500 overrides: RequestOverrides,
501) -> Result<(), InterceptionError> {
502 let mut params = Map::new();
503 params.insert("requestId".into(), Value::String(request_id.into()));
504 if let Some(url) = overrides.url {
505 params.insert("url".into(), Value::String(url));
506 }
507 if let Some(method) = overrides.method {
508 params.insert("method".into(), Value::String(method));
509 }
510 if let Some(headers) = overrides.headers {
511 params.insert("headers".into(), Value::Array(headers_to_cdp(&headers)));
512 }
513 if let Some(post_data) = overrides.post_data {
514 params.insert("postData".into(), Value::String(BASE64.encode(&post_data)));
515 }
516 session
517 .call("Fetch.continueRequest", Value::Object(params))
518 .await?;
519 Ok(())
520}
521
522async fn fulfill_request(
523 session: &SessionHandle,
524 request_id: &str,
525 status: u16,
526 headers: &[(String, String)],
527 body: &[u8],
528) -> Result<(), InterceptionError> {
529 let response_headers = headers_to_cdp(headers);
530 session
531 .call(
532 "Fetch.fulfillRequest",
533 json!({
534 "requestId": request_id,
535 "responseCode": status,
536 "responseHeaders": response_headers,
537 "body": BASE64.encode(body),
538 }),
539 )
540 .await?;
541 Ok(())
542}
543
#[cfg(test)]
#[allow(clippy::panic, clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::url_pattern::UrlPattern;
    use std::time::Duration;
    use zendriver_transport::testing::MockConnection;

    // Block rule end-to-end: a paused request matching the rule must be failed
    // with `BlockedByClient`, and cancellation must send `Fetch.disable`
    // before the actor signals `done`.
    #[tokio::test]
    async fn block_rule_dispatches_fail_request_with_blocked_by_client() {
        let (mut mock, conn) = MockConnection::pair();
        let sess = SessionHandle::new(conn.clone(), "S1");

        let rules = vec![Rule::Block {
            pattern: UrlPattern::new("*/blocked/*").unwrap(),
        }];
        let patterns = vec![RequestPattern {
            url_pattern: Some("*".into()),
            ..RequestPattern::default()
        }];
        let cancel = CancellationToken::new();
        let (done_tx, done_rx) = oneshot::channel();
        let actor_cancel = cancel.clone();
        let actor = tokio::spawn(async move {
            run_actor(sess, rules, patterns, None, actor_cancel, done_tx).await;
        });

        // The first command is Fetch.enable with the serialized patterns;
        // auth handling stays off because `auth` is None.
        let enable_id =
            tokio::time::timeout(Duration::from_secs(2), mock.expect_cmd("Fetch.enable"))
                .await
                .expect("actor did not send Fetch.enable within 2s");
        let enable_params = mock.last_sent()["params"].clone();
        assert_eq!(enable_params["handleAuthRequests"], false);
        assert_eq!(enable_params["patterns"][0]["urlPattern"], "*");
        mock.reply(enable_id, json!({})).await;

        // Pause a request whose URL matches the `*/blocked/*` rule.
        mock.emit_event_for_session(
            "Fetch.requestPaused",
            json!({
                "requestId": "REQ-1",
                "request": {
                    "url": "https://example.test/blocked/banner.png",
                    "method": "GET",
                    "headers": {},
                },
                "resourceType": "Image",
            }),
            "S1",
        )
        .await;

        // The block rule must surface as Fetch.failRequest for that request id.
        let fail_id =
            tokio::time::timeout(Duration::from_secs(2), mock.expect_cmd("Fetch.failRequest"))
                .await
                .expect("actor did not send Fetch.failRequest within 2s");
        let fail_params = mock.last_sent()["params"].clone();
        assert_eq!(fail_params["requestId"], "REQ-1");
        assert_eq!(fail_params["errorReason"], "BlockedByClient");
        mock.reply(fail_id, json!({})).await;

        // Cancel: the actor must send Fetch.disable, then signal `done`.
        cancel.cancel();
        let disable_id =
            tokio::time::timeout(Duration::from_secs(2), mock.expect_cmd("Fetch.disable"))
                .await
                .expect("actor did not send Fetch.disable on cancel");
        mock.reply(disable_id, json!({})).await;

        tokio::time::timeout(Duration::from_secs(2), done_rx)
            .await
            .expect("actor did not signal exit within 2s")
            .expect("oneshot sender dropped without sending");
        actor.await.unwrap();
        conn.shutdown();
    }

    // With credentials configured, Fetch.enable must request auth handling,
    // and a Fetch.authRequired event must be answered with those credentials.
    #[tokio::test]
    async fn actor_handles_auth_required_with_credentials() {
        let (mut mock, conn) = MockConnection::pair();
        let sess = SessionHandle::new(conn.clone(), "S1");
        let cancel = CancellationToken::new();
        let (done_tx, done_rx) = oneshot::channel();
        let actor_cancel = cancel.clone();
        let auth = Some(("user1".to_string(), "pass1".to_string()));
        let actor = tokio::spawn(async move {
            run_actor(
                sess,
                Vec::new(),
                vec![RequestPattern {
                    url_pattern: Some("*".into()),
                    ..RequestPattern::default()
                }],
                auth,
                actor_cancel,
                done_tx,
            )
            .await;
        });

        let enable_id =
            tokio::time::timeout(Duration::from_secs(2), mock.expect_cmd("Fetch.enable"))
                .await
                .expect("actor did not send Fetch.enable within 2s");
        assert_eq!(
            mock.last_sent()["params"]["handleAuthRequests"],
            true,
            "auth-enabled actor must flip handleAuthRequests"
        );
        mock.reply(enable_id, json!({})).await;

        // Simulate a proxy auth challenge.
        mock.emit_event_for_session(
            "Fetch.authRequired",
            json!({
                "requestId": "AUTH-REQ-1",
                "request": { "url": "https://example.test/", "method": "GET" },
                "frameId": "F1",
                "resourceType": "Document",
                "authChallenge": {
                    "source": "Proxy",
                    "origin": "http://proxy.test",
                    "scheme": "basic",
                    "realm": "",
                },
            }),
            "S1",
        )
        .await;

        // The reply must carry the configured username/password.
        let auth_id = tokio::time::timeout(
            Duration::from_secs(2),
            mock.expect_cmd("Fetch.continueWithAuth"),
        )
        .await
        .expect("actor did not send Fetch.continueWithAuth within 2s");
        let params = mock.last_sent()["params"].clone();
        assert_eq!(params["requestId"], "AUTH-REQ-1");
        assert_eq!(
            params["authChallengeResponse"]["response"],
            "ProvideCredentials"
        );
        assert_eq!(params["authChallengeResponse"]["username"], "user1");
        assert_eq!(params["authChallengeResponse"]["password"], "pass1");
        mock.reply(auth_id, json!({})).await;

        // Standard shutdown: Fetch.disable, then the `done` signal.
        cancel.cancel();
        let disable_id =
            tokio::time::timeout(Duration::from_secs(2), mock.expect_cmd("Fetch.disable"))
                .await
                .expect("actor did not send Fetch.disable on cancel");
        mock.reply(disable_id, json!({})).await;
        tokio::time::timeout(Duration::from_secs(2), done_rx)
            .await
            .expect("actor did not signal exit")
            .expect("oneshot sender dropped");
        actor.await.unwrap();
        conn.shutdown();
    }

    // Without credentials, auth handling is not requested, and a stray
    // Fetch.authRequired (minimal payload, only `requestId`) is answered with
    // `Default` rather than left pending.
    #[tokio::test]
    async fn actor_without_auth_responds_default_to_auth_required() {
        let (mut mock, conn) = MockConnection::pair();
        let sess = SessionHandle::new(conn.clone(), "S2");
        let cancel = CancellationToken::new();
        let (done_tx, done_rx) = oneshot::channel();
        let actor_cancel = cancel.clone();
        let actor = tokio::spawn(async move {
            run_actor(
                sess,
                Vec::new(),
                vec![RequestPattern {
                    url_pattern: Some("*".into()),
                    ..RequestPattern::default()
                }],
                None,
                actor_cancel,
                done_tx,
            )
            .await;
        });

        let enable_id =
            tokio::time::timeout(Duration::from_secs(2), mock.expect_cmd("Fetch.enable"))
                .await
                .expect("actor did not send Fetch.enable");
        assert_eq!(mock.last_sent()["params"]["handleAuthRequests"], false);
        mock.reply(enable_id, json!({})).await;

        mock.emit_event_for_session(
            "Fetch.authRequired",
            json!({ "requestId": "AUTH-REQ-2" }),
            "S2",
        )
        .await;

        let auth_id = tokio::time::timeout(
            Duration::from_secs(2),
            mock.expect_cmd("Fetch.continueWithAuth"),
        )
        .await
        .expect("actor did not respond to stray authRequired");
        assert_eq!(
            mock.last_sent()["params"]["authChallengeResponse"]["response"],
            "Default"
        );
        mock.reply(auth_id, json!({})).await;

        // Standard shutdown: Fetch.disable, then the `done` signal.
        cancel.cancel();
        let disable_id =
            tokio::time::timeout(Duration::from_secs(2), mock.expect_cmd("Fetch.disable"))
                .await
                .expect("actor did not send Fetch.disable");
        mock.reply(disable_id, json!({})).await;
        tokio::time::timeout(Duration::from_secs(2), done_rx)
            .await
            .expect("actor did not exit")
            .expect("oneshot dropped");
        actor.await.unwrap();
        conn.shutdown();
    }
}