use std::sync::Arc;
use futures::stream::{Stream, StreamExt};
use serde_json::{Value, json};
use tokio::sync::oneshot;
use tokio_util::sync::CancellationToken;
use tracing::warn;
use zendriver_transport::SessionHandle;
use crate::actor::{
InterceptHandle, RequestPausedEvent, build_request_info, build_response_info, run_actor,
serialize_pattern,
};
use crate::error::InterceptionError;
use crate::host_matcher::HostMatcher;
use crate::paused::PausedRequest;
use crate::rule::Rule;
use crate::types::{
RequestInfo, RequestOverrides, RequestStage, ResourceType, ResponseInfo, ResponseOverrides,
};
use crate::url_pattern::UrlPattern;
/// One request filter for the `patterns` list of `Fetch.enable`.
///
/// All fields are optional and `RequestPattern::default()` leaves every one
/// of them unset. Conversion to the JSON sent over the wire is done by
/// `serialize_pattern` in `crate::actor`.
#[derive(Clone, Debug, Default)]
pub struct RequestPattern {
    /// URL wildcard to match (for example `"*"` or `"*/api/*"`).
    pub url_pattern: Option<String>,
    /// Optional restriction to a single resource type.
    pub resource_type: Option<ResourceType>,
    /// Optional stage (request or response) at which matching traffic pauses.
    pub request_stage: Option<RequestStage>,
}
/// Fluent builder that configures network interception for one tab session.
///
/// The builder accumulates `Fetch` request patterns, interception rules and
/// optional auth credentials. It then either spawns a rule-applying actor
/// ([`InterceptBuilder::start`]) or hands back a raw stream of paused
/// requests ([`InterceptBuilder::subscribe`]).
#[derive(Debug)]
pub struct InterceptBuilder<'tab> {
    /// Session that CDP commands are sent through; borrowed while building.
    tab: &'tab SessionHandle,
    /// Patterns forwarded to `Fetch.enable`. The last entry is the one that
    /// `at_request` / `at_response` / `resource` modify.
    patterns: Vec<RequestPattern>,
    /// Interception rules, stored in registration order.
    rules: Vec<Rule>,
    /// `(user, pass)` pair recorded by `handle_auth`, if any.
    auth: Option<(String, String)>,
}
impl<'tab> InterceptBuilder<'tab> {
    /// Creates an empty builder that will intercept traffic on `tab`.
    ///
    /// The builder starts with no patterns, no rules and no credentials.
    /// No CDP command is sent until [`Self::start`] or [`Self::subscribe`]
    /// is called.
    #[must_use]
    pub fn new(tab: &'tab SessionHandle) -> Self {
        Self {
            tab,
            patterns: Vec::new(),
            rules: Vec::new(),
            auth: None,
        }
    }

    /// Records `user` / `pass` credentials that [`Self::start`] forwards to
    /// the interception actor.
    ///
    /// Calling this again replaces the previously stored pair.
    /// [`Self::subscribe`] does not use the credentials and logs a warning
    /// if they were set.
    #[must_use]
    pub fn handle_auth(mut self, user: impl Into<String>, pass: impl Into<String>) -> Self {
        self.auth = Some((user.into(), pass.into()));
        self
    }

    /// Appends a new `Fetch` request pattern whose URL wildcard is `pattern`.
    ///
    /// Subsequent calls to [`Self::at_request`], [`Self::at_response`] and
    /// [`Self::resource`] refine this (now most recent) pattern.
    #[must_use]
    pub fn pattern(mut self, pattern: impl Into<String>) -> Self {
        self.patterns.push(RequestPattern {
            url_pattern: Some(pattern.into()),
            ..RequestPattern::default()
        });
        self
    }

    /// Makes the most recently added pattern pause at the request stage.
    ///
    /// If no pattern exists yet, an empty one (no `url_pattern`) is created
    /// first.
    #[must_use]
    pub fn at_request(mut self) -> Self {
        self.ensure_pattern().request_stage = Some(RequestStage::Request);
        self
    }

    /// Makes the most recently added pattern pause at the response stage.
    ///
    /// If no pattern exists yet, an empty one (no `url_pattern`) is created
    /// first.
    #[must_use]
    pub fn at_response(mut self) -> Self {
        self.ensure_pattern().request_stage = Some(RequestStage::Response);
        self
    }

    /// Restricts the most recently added pattern to resource type `kind`.
    ///
    /// If no pattern exists yet, an empty one (no `url_pattern`) is created
    /// first.
    #[must_use]
    pub fn resource(mut self, kind: ResourceType) -> Self {
        self.ensure_pattern().resource_type = Some(kind);
        self
    }

    /// Adds a [`Rule::Block`] for requests whose URL matches `pattern`.
    ///
    /// # Errors
    ///
    /// Returns an [`InterceptionError`] if `pattern` is rejected by
    /// [`UrlPattern::new`].
    pub fn block(mut self, pattern: impl Into<String>) -> Result<Self, InterceptionError> {
        self.rules.push(Rule::Block {
            pattern: UrlPattern::new(pattern)?,
        });
        Ok(self)
    }

    /// Adds a [`Rule::BlockHosts`] rule backed by the shared host `matcher`.
    #[must_use]
    pub fn block_hosts(mut self, matcher: Arc<HostMatcher>) -> Self {
        self.rules.push(Rule::BlockHosts { matcher });
        self
    }

    /// Adds a [`Rule::Redirect`] sending requests that match `from` to `to`.
    ///
    /// Only `from` is validated; `to` is stored as given.
    ///
    /// # Errors
    ///
    /// Returns an [`InterceptionError`] if `from` is rejected by
    /// [`UrlPattern::new`].
    pub fn redirect(
        mut self,
        from: impl Into<String>,
        to: impl Into<String>,
    ) -> Result<Self, InterceptionError> {
        self.rules.push(Rule::Redirect {
            from: UrlPattern::new(from)?,
            to: to.into(),
        });
        Ok(self)
    }

    /// Adds a [`Rule::Respond`] that answers requests matching `pattern`
    /// with a canned response built from `status`, `headers` and `body`.
    ///
    /// # Errors
    ///
    /// Returns an [`InterceptionError`] if `pattern` is rejected by
    /// [`UrlPattern::new`].
    pub fn respond(
        mut self,
        pattern: impl Into<String>,
        status: u16,
        headers: Vec<(String, String)>,
        body: Vec<u8>,
    ) -> Result<Self, InterceptionError> {
        self.rules.push(Rule::Respond {
            pattern: UrlPattern::new(pattern)?,
            status,
            headers,
            body,
        });
        Ok(self)
    }

    /// Adds a [`Rule::Modify`] whose closure computes [`RequestOverrides`]
    /// from each matching request's [`RequestInfo`].
    ///
    /// # Errors
    ///
    /// Returns an [`InterceptionError`] if `pattern` is rejected by
    /// [`UrlPattern::new`].
    pub fn modify_request<F>(
        mut self,
        pattern: impl Into<String>,
        modify: F,
    ) -> Result<Self, InterceptionError>
    where
        F: Fn(&RequestInfo) -> RequestOverrides + Send + Sync + 'static,
    {
        self.rules.push(Rule::Modify {
            pattern: UrlPattern::new(pattern)?,
            modify: Arc::new(modify),
        });
        Ok(self)
    }

    /// Adds a [`Rule::ModifyResponse`] whose closure computes
    /// [`ResponseOverrides`] from each matching response's [`ResponseInfo`].
    ///
    /// NOTE(review): the closure presumably only runs for response-stage
    /// pauses, so pair this with [`Self::at_response`]; confirm in
    /// `crate::actor`.
    ///
    /// # Errors
    ///
    /// Returns an [`InterceptionError`] if `pattern` is rejected by
    /// [`UrlPattern::new`].
    pub fn modify_response<F>(
        mut self,
        pattern: impl Into<String>,
        modify: F,
    ) -> Result<Self, InterceptionError>
    where
        F: Fn(&ResponseInfo) -> ResponseOverrides + Send + Sync + 'static,
    {
        self.rules.push(Rule::ModifyResponse {
            pattern: UrlPattern::new(pattern)?,
            modify: Arc::new(modify),
        });
        Ok(self)
    }

    /// Spawns the interception actor with the collected patterns, rules and
    /// credentials, and returns the handle that controls it.
    ///
    /// If no pattern was registered, a single match-all `"*"` pattern is
    /// used. Dropping the returned [`InterceptHandle`] stops interception.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime (via `tokio::spawn`).
    #[must_use = "interception stops when the handle is dropped — bind the returned InterceptHandle to keep it alive"]
    pub fn start(mut self) -> InterceptHandle {
        self.ensure_default_pattern();
        let Self {
            tab,
            patterns,
            rules,
            auth,
        } = self;

        let cancel = CancellationToken::new();
        let (done_tx, done_rx) = oneshot::channel();
        let session = tab.clone();
        let actor_cancel = cancel.clone();
        tokio::spawn(async move {
            run_actor(session, rules, patterns, auth, actor_cancel, done_tx).await;
        });
        InterceptHandle::new(cancel, done_rx)
    }

    /// Enables `Fetch` for the collected patterns and returns a stream that
    /// yields one [`PausedRequest`] per `Fetch.requestPaused` event.
    ///
    /// No rules run in this mode; the consumer decides how each paused
    /// request is resumed. Registered rules and [`Self::handle_auth`]
    /// credentials are ignored, and a warning is logged if any were set.
    /// If no pattern was registered, a single match-all `"*"` pattern is
    /// used.
    ///
    /// `Fetch.enable` is sent from a detached task, so a failure is only
    /// logged and leaves the stream silent. Malformed events are logged and
    /// skipped.
    ///
    /// NOTE(review): nothing here sends `Fetch.disable` when the stream is
    /// dropped. Confirm whether the session cleans up; otherwise matching
    /// requests may remain paused.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime (via `tokio::spawn`).
    #[must_use = "the returned stream is the only handle on the subscription"]
    pub fn subscribe(mut self) -> impl Stream<Item = PausedRequest> + Send + use<> {
        self.ensure_default_pattern();
        if !self.rules.is_empty() || self.auth.is_some() {
            warn!(
                rules = self.rules.len(),
                auth = self.auth.is_some(),
                "interception: subscribe() ignores registered rules and handle_auth() credentials; use start() to apply them"
            );
        }

        // The listener is created before Fetch.enable is issued, so events
        // fired right after enabling should not be missed (assuming
        // `subscribe` registers the listener synchronously).
        let raw = self.tab.subscribe::<Value>("Fetch.requestPaused");
        let session = self.tab.clone();
        let enable_session = session.clone();
        let enable_patterns: Vec<Value> = self.patterns.iter().map(serialize_pattern).collect();
        tokio::spawn(async move {
            if let Err(e) = enable_session
                .call(
                    "Fetch.enable",
                    json!({
                        "patterns": enable_patterns,
                        "handleAuthRequests": false,
                    }),
                )
                .await
            {
                warn!(error = %e, "interception: Fetch.enable failed; subscribe() stream will be empty");
            }
        });

        raw.filter_map(move |ev_value| {
            let session = session.clone();
            async move {
                let ev: RequestPausedEvent = match serde_json::from_value(ev_value) {
                    Ok(ev) => ev,
                    Err(e) => {
                        warn!(error = %e, "interception: skipping malformed Fetch.requestPaused event");
                        return None;
                    }
                };
                let info = build_request_info(&ev);
                let response = build_response_info(&ev);
                Some(PausedRequest::new(ev.request_id, info, response, session))
            }
        })
    }

    /// Inserts a match-all `"*"` pattern when none has been registered, so
    /// `Fetch.enable` always receives at least one pattern.
    fn ensure_default_pattern(&mut self) {
        if self.patterns.is_empty() {
            self.patterns.push(RequestPattern {
                url_pattern: Some("*".into()),
                ..RequestPattern::default()
            });
        }
    }

    /// Returns the most recently added pattern, first pushing an empty one
    /// (no `url_pattern`) when the list is empty.
    fn ensure_pattern(&mut self) -> &mut RequestPattern {
        if self.patterns.is_empty() {
            self.patterns.push(RequestPattern::default());
        }
        self.patterns
            .last_mut()
            .expect("ensure_pattern pushed if empty")
    }

    /// Number of registered rules (test-only introspection).
    #[cfg(test)]
    pub(crate) fn rules_count(&self) -> usize {
        self.rules.len()
    }
}
#[cfg(test)]
// Tests intentionally unwrap/expect and panic to fail loudly on any mismatch.
#[allow(clippy::panic, clippy::unwrap_used)]
mod tests {
    use super::*;
    use std::time::Duration;
    use zendriver_transport::testing::MockConnection;

    /// Each fallible rule adder appends exactly one rule.
    #[tokio::test]
    async fn three_rules_register_and_count() {
        let (_mock, conn) = MockConnection::pair();
        let sess = SessionHandle::new(conn.clone(), "S1");
        let builder = InterceptBuilder::new(&sess)
            .block("*/ads/*")
            .unwrap()
            .redirect("*/old/*", "https://example.com/new/")
            .unwrap()
            .respond(
                "*/api/health",
                200,
                vec![("content-type".into(), "application/json".into())],
                br#"{"ok":true}"#.to_vec(),
            )
            .unwrap();
        assert_eq!(builder.rules_count(), 3);
        conn.shutdown();
    }

    /// Stage modifiers target the last pattern, and create a URL-less
    /// pattern when none has been added yet.
    #[tokio::test]
    async fn stage_modifiers_target_last_pattern() {
        let (_mock, conn) = MockConnection::pair();
        let sess = SessionHandle::new(conn.clone(), "S1");
        let builder = InterceptBuilder::new(&sess)
            .at_response()
            .pattern("*.js")
            .at_request();
        assert_eq!(builder.patterns.len(), 2);
        assert!(builder.patterns[0].url_pattern.is_none());
        assert!(matches!(
            builder.patterns[0].request_stage,
            Some(RequestStage::Response)
        ));
        assert_eq!(builder.patterns[1].url_pattern.as_deref(), Some("*.js"));
        assert!(matches!(
            builder.patterns[1].request_stage,
            Some(RequestStage::Request)
        ));
        conn.shutdown();
    }

    /// Full actor round trip: Fetch.enable, a blocked request answered with
    /// Fetch.failRequest, then Fetch.disable on stop().
    #[tokio::test]
    async fn start_spawns_actor_with_rules() {
        let (mut mock, conn) = MockConnection::pair();
        let sess = SessionHandle::new(conn.clone(), "S1");
        let handle = InterceptBuilder::new(&sess)
            .block("*/blocked/*")
            .unwrap()
            .pattern("*")
            .start();
        let enable_id =
            tokio::time::timeout(Duration::from_secs(2), mock.expect_cmd("Fetch.enable"))
                .await
                .expect("actor did not send Fetch.enable within 2s");
        let enable_params = mock.last_sent()["params"].clone();
        assert_eq!(enable_params["handleAuthRequests"], false);
        assert_eq!(enable_params["patterns"][0]["urlPattern"], "*");
        mock.reply(enable_id, json!({})).await;
        mock.emit_event_for_session(
            "Fetch.requestPaused",
            json!({
                "requestId": "REQ-1",
                "request": {
                    "url": "https://example.test/blocked/banner.png",
                    "method": "GET",
                    "headers": {},
                },
                "resourceType": "Image",
            }),
            "S1",
        )
        .await;
        let fail_id =
            tokio::time::timeout(Duration::from_secs(2), mock.expect_cmd("Fetch.failRequest"))
                .await
                .expect("actor did not send Fetch.failRequest within 2s");
        let fail_params = mock.last_sent()["params"].clone();
        assert_eq!(fail_params["requestId"], "REQ-1");
        assert_eq!(fail_params["errorReason"], "BlockedByClient");
        mock.reply(fail_id, json!({})).await;
        // stop() waits for the actor, which needs a reply to Fetch.disable,
        // so run it concurrently with the mock side.
        let stop_fut = tokio::spawn(handle.stop());
        let disable_id =
            tokio::time::timeout(Duration::from_secs(2), mock.expect_cmd("Fetch.disable"))
                .await
                .expect("actor did not send Fetch.disable on stop()");
        mock.reply(disable_id, json!({})).await;
        stop_fut
            .await
            .expect("stop() task panicked")
            .expect("stop() returned Err");
        conn.shutdown();
    }

    /// With no explicit pattern, start() enables Fetch with a single "*".
    #[tokio::test]
    async fn start_defaults_to_match_all_pattern_when_none_registered() {
        let (mut mock, conn) = MockConnection::pair();
        let sess = SessionHandle::new(conn.clone(), "S1");
        let handle = InterceptBuilder::new(&sess)
            .block("*/blocked/*")
            .unwrap()
            .start();
        let enable_id =
            tokio::time::timeout(Duration::from_secs(2), mock.expect_cmd("Fetch.enable"))
                .await
                .expect("actor did not send Fetch.enable within 2s");
        let patterns = mock.last_sent()["params"]["patterns"].clone();
        let arr = patterns.as_array().expect("patterns must be a JSON array");
        assert_eq!(arr.len(), 1);
        assert_eq!(arr[0]["urlPattern"], "*");
        mock.reply(enable_id, json!({})).await;
        drop(handle);
        conn.shutdown();
    }

    /// subscribe() turns a request-stage Fetch.requestPaused event into a
    /// PausedRequest carrying the request data and no response.
    #[tokio::test]
    async fn subscribe_yields_paused_request_per_event() {
        let (mut mock, conn) = MockConnection::pair();
        let sess = SessionHandle::new(conn.clone(), "S1");
        let mut stream = Box::pin(InterceptBuilder::new(&sess).subscribe());
        let enable_id =
            tokio::time::timeout(Duration::from_secs(2), mock.expect_cmd("Fetch.enable"))
                .await
                .expect("subscribe() did not send Fetch.enable within 2s");
        mock.reply(enable_id, json!({})).await;
        mock.emit_event_for_session(
            "Fetch.requestPaused",
            json!({
                "requestId": "REQ-1",
                "request": {
                    "url": "https://example.test/widget.json",
                    "method": "GET",
                    "headers": {"accept": "application/json"},
                },
                "resourceType": "XHR",
            }),
            "S1",
        )
        .await;
        let paused = tokio::time::timeout(Duration::from_secs(2), stream.next())
            .await
            .expect("subscribe() stream did not yield within 2s")
            .expect("subscribe() stream closed before yielding");
        assert_eq!(paused.request_id, "REQ-1");
        assert_eq!(paused.request.url, "https://example.test/widget.json");
        assert_eq!(paused.request.method, "GET");
        assert_eq!(
            paused
                .request
                .headers
                .iter()
                .find(|(k, _)| k == "accept")
                .map(|(_, v)| v.as_str()),
            Some("application/json"),
        );
        assert!(
            paused.response.is_none(),
            "request-stage event has no response"
        );
        drop(stream);
        conn.shutdown();
    }
}