use base64::Engine as _;
use base64::engine::general_purpose::STANDARD as BASE64;
use futures::StreamExt;
use serde::Deserialize;
use serde_json::{Map, Value, json};
use std::collections::HashMap;
use tokio::sync::oneshot;
use tokio_util::sync::CancellationToken;
use tracing::{trace, warn};
use zendriver_transport::SessionHandle;
use crate::builder::RequestPattern;
use crate::error::InterceptionError;
use crate::rule::Rule;
use crate::types::{RequestInfo, RequestOverrides, ResourceType, ResponseInfo, ResponseOverrides};
/// Owner-side handle for a running interception actor.
///
/// Dropping the handle cancels its token (see the `Drop` impl below) without
/// waiting; call `InterceptHandle::stop` to cancel *and* wait for the
/// completion signal.
#[derive(Debug)]
#[must_use = "interception stops when the handle is dropped — bind it to a variable to keep it alive"]
pub struct InterceptHandle {
/// Cancellation token; cancelled by `stop` and on drop.
/// NOTE(review): presumably the same token handed to `run_actor` — confirm at the construction site.
cancel: CancellationToken,
/// Completion receiver awaited by `stop`. Wrapped in `Option` because a type
/// with a `Drop` impl cannot have fields moved out of it; `stop` uses `take()`.
done: Option<oneshot::Receiver<()>>,
}
impl InterceptHandle {
    /// Bundles a cancellation token with the receiver half of the actor's
    /// completion channel.
    pub(crate) fn new(cancel: CancellationToken, done: oneshot::Receiver<()>) -> Self {
        let done = Some(done);
        Self { cancel, done }
    }

    /// Builds a handle that is not attached to any actor, for tests.
    ///
    /// The paired sender is dropped when this function returns, so awaiting
    /// [`stop`](Self::stop) on the result yields
    /// [`InterceptionError::SubscriptionClosed`].
    #[cfg(any(test, feature = "test-support"))]
    #[doc(hidden)]
    pub fn for_tests() -> Self {
        let (_sender, receiver) = oneshot::channel();
        Self::new(CancellationToken::new(), receiver)
    }

    /// Cancels interception and waits for the actor's completion signal.
    ///
    /// # Errors
    ///
    /// Returns [`InterceptionError::SubscriptionClosed`] when the completion
    /// sender is dropped without ever signalling, for example if the actor task
    /// ended abnormally.
    pub async fn stop(mut self) -> Result<(), InterceptionError> {
        self.cancel.cancel();
        let Some(receiver) = self.done.take() else {
            return Ok(());
        };
        receiver
            .await
            .map_err(|_| InterceptionError::SubscriptionClosed)
    }
}
/// Cancels interception when the handle goes out of scope.
///
/// This only fires the cancellation token; it does not wait for the actor to
/// finish. Use `InterceptHandle::stop` when shutdown must be observed.
impl Drop for InterceptHandle {
fn drop(&mut self) {
// Idempotent: also runs after `stop`, which has already cancelled.
self.cancel.cancel();
}
}
/// The subset of the CDP `Fetch.requestPaused` event payload that the actor consumes.
///
/// The `response*` fields are optional. `build_response_info` treats a missing
/// `responseStatusCode` as "paused at the request stage".
#[derive(Debug, Deserialize)]
pub(crate) struct RequestPausedEvent {
/// Id that every follow-up `Fetch.*` command sends back.
#[serde(rename = "requestId")]
pub(crate) request_id: String,
/// The intercepted request.
pub(crate) request: RequestPayload,
/// CDP resource-type name (e.g. `"Image"`), converted by `parse_resource_type`.
#[serde(rename = "resourceType", default)]
pub(crate) resource_type: Option<String>,
/// HTTP status code. It is present only once a response is available.
#[serde(rename = "responseStatusCode", default)]
pub(crate) response_status_code: Option<u16>,
/// HTTP status text accompanying `response_status_code`.
#[serde(rename = "responseStatusText", default)]
pub(crate) response_status_text: Option<String>,
/// Response headers as CDP `{name, value}` pairs.
#[serde(rename = "responseHeaders", default)]
pub(crate) response_headers: Option<Vec<HeaderPair>>,
}
/// The `request` object embedded in `Fetch.requestPaused`.
#[derive(Debug, Deserialize)]
pub(crate) struct RequestPayload {
/// Request URL; this is what the interception rules are matched against.
pub(crate) url: String,
/// HTTP method, e.g. `"GET"`.
pub(crate) method: String,
/// Request headers as a name → value map; empty when the field is absent.
#[serde(default)]
pub(crate) headers: HashMap<String, String>,
/// Request body as a string. `decode_post_data` uses it only as a fallback.
#[serde(rename = "postData", default)]
pub(crate) post_data: Option<String>,
/// Deserialized for completeness but not read anywhere in this file.
#[serde(rename = "hasPostData", default)]
_has_post_data: Option<bool>,
/// Body split into base64-encoded chunks; `decode_post_data` prefers this over `postData`.
#[serde(rename = "postDataEntries", default)]
pub(crate) post_data_entries: Option<Vec<PostDataEntry>>,
}
/// One element of `postDataEntries`.
#[derive(Debug, Deserialize)]
pub(crate) struct PostDataEntry {
/// Base64-encoded chunk of the body. Entries without it are skipped by `decode_post_data`.
#[serde(default)]
pub(crate) bytes: Option<String>,
}
/// A single CDP header entry, written as `{ "name": ..., "value": ... }`.
#[derive(Debug, Deserialize)]
pub(crate) struct HeaderPair {
/// Header name as reported by the browser.
pub(crate) name: String,
/// Header value as reported by the browser.
pub(crate) value: String,
}
pub(crate) async fn run_actor(
session: SessionHandle,
rules: Vec<Rule>,
patterns: Vec<RequestPattern>,
auth: Option<(String, String)>,
cancel: CancellationToken,
done: oneshot::Sender<()>,
) {
let mut paused = session.subscribe::<Value>("Fetch.requestPaused");
let mut auth_required = session.subscribe::<Value>("Fetch.authRequired");
let enable_session = session.clone();
let enable_patterns: Vec<Value> = patterns.iter().map(serialize_pattern).collect();
let handle_auth_requests = auth.is_some();
tokio::spawn(async move {
if let Err(e) = enable_session
.call(
"Fetch.enable",
json!({
"patterns": enable_patterns,
"handleAuthRequests": handle_auth_requests,
}),
)
.await
{
warn!(error = %e, "interception: Fetch.enable failed; interception inactive");
}
});
loop {
tokio::select! {
() = cancel.cancelled() => {
trace!("interception: cancellation received, disabling Fetch and exiting");
break;
}
Some(ev_value) = paused.next() => {
let ev: RequestPausedEvent = match serde_json::from_value(ev_value) {
Ok(ev) => ev,
Err(e) => {
warn!(error = %e, "interception: skipping malformed Fetch.requestPaused event");
continue;
}
};
if let Err(e) = handle_paused(&session, &rules, ev).await {
warn!(error = %e, "interception: handler dispatch failed");
}
}
Some(ev_value) = auth_required.next() => {
let Some(request_id) = ev_value
.get("requestId")
.and_then(Value::as_str)
.map(str::to_owned)
else {
warn!("interception: Fetch.authRequired without requestId");
continue;
};
let response = match &auth {
Some((user, pass)) => json!({
"response": "ProvideCredentials",
"username": user,
"password": pass,
}),
None => json!({ "response": "Default" }),
};
if let Err(e) = session
.call(
"Fetch.continueWithAuth",
json!({
"requestId": request_id,
"authChallengeResponse": response,
}),
)
.await
{
warn!(error = %e, "interception: Fetch.continueWithAuth failed");
}
}
else => {
trace!("interception: event stream closed, exiting without Fetch.disable");
let _ = done.send(());
return;
}
}
}
if let Err(e) = session.call("Fetch.disable", json!({})).await {
warn!(error = %e, "interception: Fetch.disable failed during shutdown");
}
let _ = done.send(());
}
/// Resolves one paused request according to the first rule whose pattern
/// matches its URL.
///
/// The rules map to CDP commands as follows:
///
/// * `Block` / `BlockHosts` → `Fetch.failRequest` with `BlockedByClient`.
/// * `Redirect` → `Fetch.continueRequest` with a replacement URL.
/// * `Respond` → `Fetch.fulfillRequest` with a canned response.
/// * `Modify` → `Fetch.continueRequest` with the callback's overrides.
/// * `ModifyResponse` → `Fetch.continueResponse` with the callback's overrides.
///   If no response status is present yet, the request passes through instead.
/// * No match → plain `Fetch.continueRequest`.
async fn handle_paused(
    session: &SessionHandle,
    rules: &[Rule],
    ev: RequestPausedEvent,
) -> Result<(), InterceptionError> {
    let request_id = ev.request_id.as_str();
    let Some(rule) = rules.iter().find(|rule| rule.matches(&ev.request.url)) else {
        return continue_passthrough(session, request_id).await;
    };
    match rule {
        Rule::Block { .. } | Rule::BlockHosts { .. } => {
            fail_request(session, request_id, "BlockedByClient").await
        }
        Rule::Redirect { to, .. } => continue_with_url(session, request_id, to).await,
        Rule::Respond {
            status,
            headers,
            body,
            ..
        } => fulfill_request(session, request_id, *status, headers, body).await,
        Rule::Modify { modify, .. } => {
            let overrides = modify(&build_request_info(&ev));
            continue_with_overrides(session, request_id, overrides).await
        }
        Rule::ModifyResponse { modify, .. } => {
            if let Some(info) = build_response_info(&ev) {
                let overrides = modify(&info);
                continue_response_with_overrides(session, request_id, overrides).await
            } else {
                tracing::debug!(
                    request_id = %ev.request_id,
                    url = %ev.request.url,
                    "interception: ModifyResponse matched at Request stage; no response yet, passing through"
                );
                continue_passthrough(session, request_id).await
            }
        }
    }
}
pub(crate) fn serialize_pattern(p: &RequestPattern) -> Value {
let mut obj = Map::new();
if let Some(url) = &p.url_pattern {
obj.insert("urlPattern".into(), Value::String(url.clone()));
}
if let Some(rt) = p.resource_type {
obj.insert("resourceType".into(), Value::String(rt.as_cdp_str().into()));
}
if let Some(stage) = p.request_stage {
obj.insert(
"requestStage".into(),
Value::String(stage.as_cdp_str().into()),
);
}
Value::Object(obj)
}
/// Builds the [`RequestInfo`] view of a paused request. This is the value
/// passed to `Modify` rule callbacks.
pub(crate) fn build_request_info(ev: &RequestPausedEvent) -> RequestInfo {
    let request = &ev.request;
    let headers = request
        .headers
        .iter()
        .map(|(name, value)| (name.to_owned(), value.to_owned()))
        .collect();
    RequestInfo {
        url: request.url.clone(),
        method: request.method.clone(),
        headers,
        post_data: decode_post_data(request),
        resource_type: parse_resource_type(ev.resource_type.as_deref()),
    }
}
fn decode_post_data(req: &RequestPayload) -> Option<Vec<u8>> {
use base64::Engine as _;
use base64::engine::general_purpose::STANDARD as BASE64;
if let Some(entries) = req.post_data_entries.as_ref() {
let mut buf = Vec::new();
for entry in entries {
let Some(b64) = entry.bytes.as_deref() else {
continue;
};
match BASE64.decode(b64) {
Ok(bytes) => buf.extend_from_slice(&bytes),
Err(e) => {
tracing::warn!(error = %e, "interception: bad base64 in postDataEntries; skipping entry");
}
}
}
return Some(buf);
}
req.post_data.as_deref().map(|s| s.as_bytes().to_vec())
}
/// Builds the [`ResponseInfo`] view for `ModifyResponse` callbacks.
///
/// Returns `None` when the event carries no `responseStatusCode`, i.e. when
/// there is no response yet. A missing status text becomes an empty string,
/// and missing headers become an empty list.
pub(crate) fn build_response_info(ev: &RequestPausedEvent) -> Option<ResponseInfo> {
    ev.response_status_code.map(|status| ResponseInfo {
        status,
        status_text: ev
            .response_status_text
            .as_deref()
            .unwrap_or_default()
            .to_owned(),
        headers: ev
            .response_headers
            .iter()
            .flatten()
            .map(|pair| (pair.name.clone(), pair.value.clone()))
            .collect(),
    })
}
/// Converts `(name, value)` pairs into CDP `HeaderEntry` objects
/// (`{ "name": ..., "value": ... }`) and keeps their input order.
pub(crate) fn headers_to_cdp(headers: &[(String, String)]) -> Vec<Value> {
    let mut entries = Vec::with_capacity(headers.len());
    for (name, value) in headers {
        entries.push(json!({ "name": name, "value": value }));
    }
    entries
}
/// Maps a CDP resource-type name onto [`ResourceType`].
///
/// A missing or unrecognised name yields [`ResourceType::Other`].
fn parse_resource_type(s: Option<&str>) -> ResourceType {
    let Some(name) = s else {
        return ResourceType::Other;
    };
    match name {
        "Document" => ResourceType::Document,
        "Stylesheet" => ResourceType::Stylesheet,
        "Image" => ResourceType::Image,
        "Media" => ResourceType::Media,
        "Font" => ResourceType::Font,
        "Script" => ResourceType::Script,
        "TextTrack" => ResourceType::TextTrack,
        "XHR" => ResourceType::XHR,
        "Fetch" => ResourceType::Fetch,
        "EventSource" => ResourceType::EventSource,
        "WebSocket" => ResourceType::WebSocket,
        "Manifest" => ResourceType::Manifest,
        "SignedExchange" => ResourceType::SignedExchange,
        "Ping" => ResourceType::Ping,
        "CSPViolationReport" => ResourceType::CSPViolationReport,
        "Preflight" => ResourceType::Preflight,
        _ => ResourceType::Other,
    }
}
/// Aborts a paused request via `Fetch.failRequest` with the given CDP
/// `errorReason`.
async fn fail_request(
    session: &SessionHandle,
    request_id: &str,
    error_reason: &str,
) -> Result<(), InterceptionError> {
    let params = json!({
        "requestId": request_id,
        "errorReason": error_reason,
    });
    session.call("Fetch.failRequest", params).await?;
    Ok(())
}
/// Resumes a paused request unchanged (`Fetch.continueRequest` carrying only
/// the request id).
async fn continue_passthrough(
    session: &SessionHandle,
    request_id: &str,
) -> Result<(), InterceptionError> {
    let params = json!({ "requestId": request_id });
    session.call("Fetch.continueRequest", params).await?;
    Ok(())
}
/// Resumes a paused request against a different URL
/// (`Fetch.continueRequest` with a `url` override).
async fn continue_with_url(
    session: &SessionHandle,
    request_id: &str,
    url: &str,
) -> Result<(), InterceptionError> {
    let params = json!({
        "requestId": request_id,
        "url": url,
    });
    session.call("Fetch.continueRequest", params).await?;
    Ok(())
}
async fn continue_with_overrides(
session: &SessionHandle,
request_id: &str,
overrides: RequestOverrides,
) -> Result<(), InterceptionError> {
let mut params = Map::new();
params.insert("requestId".into(), Value::String(request_id.into()));
if let Some(url) = overrides.url {
params.insert("url".into(), Value::String(url));
}
if let Some(method) = overrides.method {
params.insert("method".into(), Value::String(method));
}
if let Some(headers) = overrides.headers {
params.insert("headers".into(), Value::Array(headers_to_cdp(&headers)));
}
if let Some(post_data) = overrides.post_data {
params.insert("postData".into(), Value::String(BASE64.encode(&post_data)));
}
session
.call("Fetch.continueRequest", Value::Object(params))
.await?;
Ok(())
}
async fn fulfill_request(
session: &SessionHandle,
request_id: &str,
status: u16,
headers: &[(String, String)],
body: &[u8],
) -> Result<(), InterceptionError> {
let response_headers = headers_to_cdp(headers);
session
.call(
"Fetch.fulfillRequest",
json!({
"requestId": request_id,
"responseCode": status,
"responseHeaders": response_headers,
"body": BASE64.encode(body),
}),
)
.await?;
Ok(())
}
async fn continue_response_with_overrides(
session: &SessionHandle,
request_id: &str,
overrides: ResponseOverrides,
) -> Result<(), InterceptionError> {
let mut params = Map::new();
params.insert("requestId".into(), Value::String(request_id.into()));
if let Some(status) = overrides.status {
params.insert("responseCode".into(), Value::from(status));
}
if let Some(phrase) = overrides.phrase {
params.insert("responsePhrase".into(), Value::String(phrase));
}
if let Some(headers) = overrides.headers {
params.insert(
"responseHeaders".into(),
Value::Array(headers_to_cdp(&headers)),
);
}
session
.call("Fetch.continueResponse", Value::Object(params))
.await?;
Ok(())
}
// Actor-level tests. Each test spawns `run_actor` against a `MockConnection`,
// plays the browser side (replying to commands, emitting events), and asserts
// on the CDP commands the actor sends, in order.
#[cfg(test)]
// Tests use `unwrap`/`expect`/panics as their failure signal.
#[allow(clippy::panic, clippy::unwrap_used)]
mod tests {
use super::*;
use crate::url_pattern::UrlPattern;
use std::time::Duration;
use zendriver_transport::testing::MockConnection;
/// A `Block` rule must fail a matching paused request with `BlockedByClient`.
/// Cancellation must then produce `Fetch.disable` followed by the exit signal.
#[tokio::test]
async fn block_rule_dispatches_fail_request_with_blocked_by_client() {
let (mut mock, conn) = MockConnection::pair();
let sess = SessionHandle::new(conn.clone(), "S1");
let rules = vec![Rule::Block {
pattern: UrlPattern::new("*/blocked/*").unwrap(),
}];
let patterns = vec![RequestPattern {
url_pattern: Some("*".into()),
..RequestPattern::default()
}];
let cancel = CancellationToken::new();
let (done_tx, done_rx) = oneshot::channel();
let actor_cancel = cancel.clone();
let actor = tokio::spawn(async move {
run_actor(sess, rules, patterns, None, actor_cancel, done_tx).await;
});
// Handshake: Fetch.enable carries the serialized patterns and no auth handling.
let enable_id =
tokio::time::timeout(Duration::from_secs(2), mock.expect_cmd("Fetch.enable"))
.await
.expect("actor did not send Fetch.enable within 2s");
let enable_params = mock.last_sent()["params"].clone();
assert_eq!(enable_params["handleAuthRequests"], false);
assert_eq!(enable_params["patterns"][0]["urlPattern"], "*");
mock.reply(enable_id, json!({})).await;
// A request whose URL matches the block pattern.
mock.emit_event_for_session(
"Fetch.requestPaused",
json!({
"requestId": "REQ-1",
"request": {
"url": "https://example.test/blocked/banner.png",
"method": "GET",
"headers": {},
},
"resourceType": "Image",
}),
"S1",
)
.await;
let fail_id =
tokio::time::timeout(Duration::from_secs(2), mock.expect_cmd("Fetch.failRequest"))
.await
.expect("actor did not send Fetch.failRequest within 2s");
let fail_params = mock.last_sent()["params"].clone();
assert_eq!(fail_params["requestId"], "REQ-1");
assert_eq!(fail_params["errorReason"], "BlockedByClient");
mock.reply(fail_id, json!({})).await;
// Shutdown path: cancel -> Fetch.disable -> done signal.
cancel.cancel();
let disable_id =
tokio::time::timeout(Duration::from_secs(2), mock.expect_cmd("Fetch.disable"))
.await
.expect("actor did not send Fetch.disable on cancel");
mock.reply(disable_id, json!({})).await;
tokio::time::timeout(Duration::from_secs(2), done_rx)
.await
.expect("actor did not signal exit within 2s")
.expect("oneshot sender dropped without sending");
actor.await.unwrap();
conn.shutdown();
}
/// A `BlockHosts` rule must fail requests to a listed host. The request here
/// targets a subdomain of the listed `evil.com`.
#[tokio::test]
async fn block_hosts_rule_dispatches_fail_request() {
use crate::host_matcher::HostMatcher;
let (mut mock, conn) = MockConnection::pair();
let sess = SessionHandle::new(conn.clone(), "S1");
let rules = vec![Rule::BlockHosts {
matcher: std::sync::Arc::new(HostMatcher::new(["evil.com".to_string()])),
}];
let patterns = vec![RequestPattern {
url_pattern: Some("*".into()),
..RequestPattern::default()
}];
let cancel = CancellationToken::new();
let (done_tx, done_rx) = oneshot::channel();
let actor_cancel = cancel.clone();
let actor = tokio::spawn(async move {
run_actor(sess, rules, patterns, None, actor_cancel, done_tx).await;
});
let enable_id =
tokio::time::timeout(Duration::from_secs(2), mock.expect_cmd("Fetch.enable"))
.await
.expect("actor did not send Fetch.enable within 2s");
mock.reply(enable_id, json!({})).await;
mock.emit_event_for_session(
"Fetch.requestPaused",
json!({
"requestId": "REQ-1",
"request": {
"url": "https://cdn.evil.com/fp.js",
"method": "GET",
"headers": {},
},
"resourceType": "Script",
}),
"S1",
)
.await;
let fail_id =
tokio::time::timeout(Duration::from_secs(2), mock.expect_cmd("Fetch.failRequest"))
.await
.expect("actor did not send Fetch.failRequest within 2s");
let fail_params = mock.last_sent()["params"].clone();
assert_eq!(fail_params["requestId"], "REQ-1");
assert_eq!(fail_params["errorReason"], "BlockedByClient");
mock.reply(fail_id, json!({})).await;
// Shutdown path: cancel -> Fetch.disable -> done signal.
cancel.cancel();
let disable_id =
tokio::time::timeout(Duration::from_secs(2), mock.expect_cmd("Fetch.disable"))
.await
.expect("actor did not send Fetch.disable on cancel");
mock.reply(disable_id, json!({})).await;
tokio::time::timeout(Duration::from_secs(2), done_rx)
.await
.expect("actor did not signal exit within 2s")
.expect("oneshot sender dropped without sending");
actor.await.unwrap();
conn.shutdown();
}
/// With credentials configured, Fetch.enable must set `handleAuthRequests`,
/// and an auth challenge must be answered with `ProvideCredentials`.
#[tokio::test]
async fn actor_handles_auth_required_with_credentials() {
let (mut mock, conn) = MockConnection::pair();
let sess = SessionHandle::new(conn.clone(), "S1");
let cancel = CancellationToken::new();
let (done_tx, done_rx) = oneshot::channel();
let actor_cancel = cancel.clone();
let auth = Some(("user1".to_string(), "pass1".to_string()));
let actor = tokio::spawn(async move {
run_actor(
sess,
Vec::new(),
vec![RequestPattern {
url_pattern: Some("*".into()),
..RequestPattern::default()
}],
auth,
actor_cancel,
done_tx,
)
.await;
});
let enable_id =
tokio::time::timeout(Duration::from_secs(2), mock.expect_cmd("Fetch.enable"))
.await
.expect("actor did not send Fetch.enable within 2s");
assert_eq!(
mock.last_sent()["params"]["handleAuthRequests"],
true,
"auth-enabled actor must flip handleAuthRequests"
);
mock.reply(enable_id, json!({})).await;
// A proxy auth challenge shaped like a full Fetch.authRequired payload.
mock.emit_event_for_session(
"Fetch.authRequired",
json!({
"requestId": "AUTH-REQ-1",
"request": { "url": "https://example.test/", "method": "GET" },
"frameId": "F1",
"resourceType": "Document",
"authChallenge": {
"source": "Proxy",
"origin": "http://proxy.test",
"scheme": "basic",
"realm": "",
},
}),
"S1",
)
.await;
let auth_id = tokio::time::timeout(
Duration::from_secs(2),
mock.expect_cmd("Fetch.continueWithAuth"),
)
.await
.expect("actor did not send Fetch.continueWithAuth within 2s");
let params = mock.last_sent()["params"].clone();
assert_eq!(params["requestId"], "AUTH-REQ-1");
assert_eq!(
params["authChallengeResponse"]["response"],
"ProvideCredentials"
);
assert_eq!(params["authChallengeResponse"]["username"], "user1");
assert_eq!(params["authChallengeResponse"]["password"], "pass1");
mock.reply(auth_id, json!({})).await;
// Shutdown path: cancel -> Fetch.disable -> done signal.
cancel.cancel();
let disable_id =
tokio::time::timeout(Duration::from_secs(2), mock.expect_cmd("Fetch.disable"))
.await
.expect("actor did not send Fetch.disable on cancel");
mock.reply(disable_id, json!({})).await;
tokio::time::timeout(Duration::from_secs(2), done_rx)
.await
.expect("actor did not signal exit")
.expect("oneshot sender dropped");
actor.await.unwrap();
conn.shutdown();
}
/// Without credentials, a stray auth challenge (carrying only a `requestId`)
/// must still be answered, using the `Default` response.
#[tokio::test]
async fn actor_without_auth_responds_default_to_auth_required() {
let (mut mock, conn) = MockConnection::pair();
let sess = SessionHandle::new(conn.clone(), "S2");
let cancel = CancellationToken::new();
let (done_tx, done_rx) = oneshot::channel();
let actor_cancel = cancel.clone();
let actor = tokio::spawn(async move {
run_actor(
sess,
Vec::new(),
vec![RequestPattern {
url_pattern: Some("*".into()),
..RequestPattern::default()
}],
None,
actor_cancel,
done_tx,
)
.await;
});
let enable_id =
tokio::time::timeout(Duration::from_secs(2), mock.expect_cmd("Fetch.enable"))
.await
.expect("actor did not send Fetch.enable");
assert_eq!(mock.last_sent()["params"]["handleAuthRequests"], false);
mock.reply(enable_id, json!({})).await;
// Minimal payload: the actor reads only `requestId` from auth events.
mock.emit_event_for_session(
"Fetch.authRequired",
json!({ "requestId": "AUTH-REQ-2" }),
"S2",
)
.await;
let auth_id = tokio::time::timeout(
Duration::from_secs(2),
mock.expect_cmd("Fetch.continueWithAuth"),
)
.await
.expect("actor did not respond to stray authRequired");
assert_eq!(
mock.last_sent()["params"]["authChallengeResponse"]["response"],
"Default"
);
mock.reply(auth_id, json!({})).await;
// Shutdown path: cancel -> Fetch.disable -> done signal.
cancel.cancel();
let disable_id =
tokio::time::timeout(Duration::from_secs(2), mock.expect_cmd("Fetch.disable"))
.await
.expect("actor did not send Fetch.disable");
mock.reply(disable_id, json!({})).await;
tokio::time::timeout(Duration::from_secs(2), done_rx)
.await
.expect("actor did not exit")
.expect("oneshot dropped");
actor.await.unwrap();
conn.shutdown();
}
}