use base64::Engine as _;
use base64::engine::general_purpose::STANDARD as BASE64;
use serde_json::{Map, Value, json};
use zendriver_transport::SessionHandle;
use crate::error::InterceptionError;
use crate::types::{AbortReason, RequestInfo, RequestOverrides, ResponseInfo};
/// A network request paused by the CDP `Fetch` domain, awaiting a decision.
///
/// Resolve it with one of the consuming methods: [`continue_`](Self::continue_),
/// [`abort`](Self::abort), [`respond`](Self::respond),
/// [`modify_and_continue`](Self::modify_and_continue) or
/// [`continue_response`](Self::continue_response). If it is dropped before any of
/// those has released it, `Drop` sends a best-effort `Fetch.continueRequest`.
#[derive(Debug)]
pub struct PausedRequest {
    /// The `requestId` sent with every `Fetch.*` command for this request.
    pub request_id: String,
    /// The intercepted request.
    pub request: RequestInfo,
    /// The intercepted response; `Some` for a response-stage pause
    /// (`continue_response` returns `WrongStage` when this is `None`).
    pub response: Option<ResponseInfo>,
    /// Session the `Fetch.*` commands are sent on.
    session: SessionHandle,
    /// Set once a terminal action has been dispatched, so `Drop` skips its fallback.
    released: bool,
}
impl PausedRequest {
pub(crate) fn new(
request_id: impl Into<String>,
request: RequestInfo,
response: Option<ResponseInfo>,
session: SessionHandle,
) -> Self {
Self {
request_id: request_id.into(),
request,
response,
session,
released: false,
}
}
pub async fn continue_(mut self) -> Result<(), InterceptionError> {
self.released = true;
self.session
.call(
"Fetch.continueRequest",
json!({ "requestId": self.request_id }),
)
.await?;
Ok(())
}
pub async fn abort(mut self, reason: AbortReason) -> Result<(), InterceptionError> {
self.released = true;
self.session
.call(
"Fetch.failRequest",
json!({
"requestId": self.request_id,
"errorReason": reason.as_cdp_str(),
}),
)
.await?;
Ok(())
}
pub async fn respond(
mut self,
status: u16,
headers: Vec<(String, String)>,
body: Vec<u8>,
) -> Result<(), InterceptionError> {
self.released = true;
let response_headers = crate::actor::headers_to_cdp(&headers);
self.session
.call(
"Fetch.fulfillRequest",
json!({
"requestId": self.request_id,
"responseCode": status,
"responseHeaders": response_headers,
"body": BASE64.encode(&body),
}),
)
.await?;
Ok(())
}
pub async fn modify_and_continue(
mut self,
overrides: RequestOverrides,
) -> Result<(), InterceptionError> {
self.released = true;
let mut params = Map::new();
params.insert("requestId".into(), Value::String(self.request_id.clone()));
if let Some(url) = overrides.url {
params.insert("url".into(), Value::String(url));
}
if let Some(method) = overrides.method {
params.insert("method".into(), Value::String(method));
}
if let Some(headers) = overrides.headers {
params.insert(
"headers".into(),
Value::Array(crate::actor::headers_to_cdp(&headers)),
);
}
if let Some(post_data) = overrides.post_data {
params.insert("postData".into(), Value::String(BASE64.encode(&post_data)));
}
self.session
.call("Fetch.continueRequest", Value::Object(params))
.await?;
Ok(())
}
pub async fn continue_response(
mut self,
status: Option<u16>,
phrase: Option<String>,
headers: Option<Vec<(String, String)>>,
) -> Result<(), InterceptionError> {
if self.response.is_none() {
return Err(InterceptionError::WrongStage);
}
self.released = true;
let mut params = Map::new();
params.insert("requestId".into(), Value::String(self.request_id.clone()));
if let Some(status) = status {
params.insert("responseCode".into(), Value::from(status));
}
if let Some(phrase) = phrase {
params.insert("responsePhrase".into(), Value::String(phrase));
}
if let Some(headers) = headers {
params.insert(
"responseHeaders".into(),
Value::Array(crate::actor::headers_to_cdp(&headers)),
);
}
self.session
.call("Fetch.continueResponse", Value::Object(params))
.await?;
Ok(())
}
pub async fn body(&self) -> Result<Vec<u8>, InterceptionError> {
let res = self
.session
.call(
"Fetch.getResponseBody",
json!({ "requestId": self.request_id }),
)
.await?;
let body = res.get("body").and_then(Value::as_str).ok_or_else(|| {
InterceptionError::InvalidResponse(
"Fetch.getResponseBody returned no body field".into(),
)
})?;
let base64_encoded = res
.get("base64Encoded")
.and_then(Value::as_bool)
.unwrap_or(false);
if base64_encoded {
BASE64
.decode(body)
.map_err(|e| InterceptionError::InvalidResponse(format!("invalid base64: {e}")))
} else {
Ok(body.as_bytes().to_vec())
}
}
}
impl Drop for PausedRequest {
    /// Best-effort fallback for a request that no terminal action released.
    ///
    /// It sends a plain `Fetch.continueRequest` so the request is not left paused.
    ///
    /// `drop` cannot be async, so the call is spawned onto the current Tokio
    /// runtime. `tokio::spawn` panics when called outside a runtime context, and a
    /// panic inside `drop` (especially while already unwinding) can abort the whole
    /// process. The handle is therefore looked up fallibly, and the fallback is
    /// skipped with a warning when no runtime is available.
    fn drop(&mut self) {
        if self.released {
            return;
        }
        // The request is being destroyed, so take the id instead of cloning it.
        let request_id = std::mem::take(&mut self.request_id);
        let runtime = match tokio::runtime::Handle::try_current() {
            Ok(handle) => handle,
            Err(err) => {
                tracing::warn!(
                    error = %err,
                    request_id = %request_id,
                    "PausedRequest::drop: no Tokio runtime available; skipping fallback Fetch.continueRequest"
                );
                return;
            }
        };
        let session = self.session.clone();
        runtime.spawn(async move {
            if let Err(e) = session
                .call("Fetch.continueRequest", json!({ "requestId": request_id }))
                .await
            {
                tracing::debug!(
                    error = %e,
                    request_id = %request_id,
                    "PausedRequest::drop: best-effort Fetch.continueRequest failed (session likely closed)"
                );
            }
        });
    }
}
// Each test drives `PausedRequest` against a `MockConnection`. It spawns the
// action under test, waits on the mock side for the expected CDP command,
// inspects the params that were sent, then replies so the spawned future can
// resolve.
#[cfg(test)]
// Test code unwraps and panics freely: any failure should fail the test outright.
#[allow(clippy::panic, clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::types::ResourceType;
    use zendriver_transport::SessionHandle;
    use zendriver_transport::testing::MockConnection;

    /// Minimal `RequestInfo` fixture (GET, no headers, no body).
    fn make_request_info() -> RequestInfo {
        RequestInfo {
            url: "https://example.test/widget".into(),
            method: "GET".into(),
            headers: Vec::new(),
            post_data: None,
            resource_type: ResourceType::XHR,
        }
    }

    /// `ResponseInfo` fixture. Passing it to `PausedRequest::new` models a
    /// response-stage pause.
    fn make_response_info() -> ResponseInfo {
        ResponseInfo {
            status: 200,
            status_text: "OK".into(),
            headers: Vec::new(),
        }
    }

    // `continue_` must send `Fetch.continueRequest` whose only param is `requestId`.
    #[tokio::test]
    async fn continue_dispatches_fetch_continue_request() {
        let (mut mock, conn) = MockConnection::pair();
        let sess = SessionHandle::new(conn.clone(), "S1");
        let req = PausedRequest::new("REQ-1", make_request_info(), None, sess);
        let fut = tokio::spawn(async move { req.continue_().await });
        let id = mock.expect_cmd("Fetch.continueRequest").await;
        let sent = mock.last_sent();
        assert_eq!(sent["params"]["requestId"], "REQ-1");
        // No overrides were requested, so nothing but `requestId` may be present.
        let params_obj = sent["params"].as_object().unwrap();
        assert_eq!(params_obj.len(), 1);
        mock.reply(id, serde_json::json!({})).await;
        fut.await.unwrap().unwrap();
        conn.shutdown();
    }

    // Dropping without a terminal action must trigger the Drop fallback continue.
    #[tokio::test]
    async fn drop_without_terminal_action_fires_fallback_continue() {
        let (mut mock, conn) = MockConnection::pair();
        let sess = SessionHandle::new(conn.clone(), "S1");
        {
            let req = PausedRequest::new("REQ-DROP", make_request_info(), None, sess);
            drop(req);
        }
        // The fallback is sent from a spawned task, so await it instead of polling once.
        let id = mock.expect_cmd("Fetch.continueRequest").await;
        let sent = mock.last_sent();
        assert_eq!(sent["params"]["requestId"], "REQ-DROP");
        mock.reply(id, serde_json::json!({})).await;
        conn.shutdown();
    }

    // After `continue_` released the request, dropping it must not send a second command.
    #[tokio::test]
    async fn continue_does_not_double_fire_on_drop() {
        let (mut mock, conn) = MockConnection::pair();
        let sess = SessionHandle::new(conn.clone(), "S1");
        let req = PausedRequest::new("REQ-ONCE", make_request_info(), None, sess);
        let fut = tokio::spawn(async move { req.continue_().await });
        let id = mock.expect_cmd("Fetch.continueRequest").await;
        assert_eq!(mock.last_sent()["params"]["requestId"], "REQ-ONCE");
        mock.reply(id, serde_json::json!({})).await;
        fut.await.unwrap().unwrap();
        // Give any (erroneously) spawned Drop fallback a chance to run before checking.
        tokio::task::yield_now().await;
        assert!(
            mock.try_recv_cmd().is_none(),
            "Drop fired a second Fetch.continueRequest after continue_ already released"
        );
        conn.shutdown();
    }

    // `respond` must send `Fetch.fulfillRequest` with a base64 body and CDP-shaped headers.
    #[tokio::test]
    async fn respond_dispatches_fetch_fulfill_with_base64_body() {
        let (mut mock, conn) = MockConnection::pair();
        let sess = SessionHandle::new(conn.clone(), "S1");
        let req = PausedRequest::new("REQ-2", make_request_info(), None, sess);
        let body = b"hello world".to_vec();
        let expected_b64 = BASE64.encode(&body);
        let fut = tokio::spawn(async move {
            req.respond(
                200,
                vec![("content-type".into(), "text/plain".into())],
                body,
            )
            .await
        });
        let id = mock.expect_cmd("Fetch.fulfillRequest").await;
        let sent = mock.last_sent();
        assert_eq!(sent["params"]["requestId"], "REQ-2");
        assert_eq!(sent["params"]["responseCode"], 200);
        assert_eq!(sent["params"]["body"], expected_b64);
        // Headers are sent as `{ name, value }` objects.
        let headers = sent["params"]["responseHeaders"].as_array().unwrap();
        assert_eq!(headers.len(), 1);
        assert_eq!(headers[0]["name"], "content-type");
        assert_eq!(headers[0]["value"], "text/plain");
        mock.reply(id, serde_json::json!({})).await;
        fut.await.unwrap().unwrap();
        conn.shutdown();
    }

    // At the response stage, `continue_response` sends only the overrides that are `Some`.
    #[tokio::test]
    async fn continue_response_dispatches_fetch_continue_response() {
        let (mut mock, conn) = MockConnection::pair();
        let sess = SessionHandle::new(conn.clone(), "S1");
        let req = PausedRequest::new(
            "REQ-CR",
            make_request_info(),
            Some(make_response_info()),
            sess,
        );
        let fut = tokio::spawn(async move {
            req.continue_response(Some(204), None, Some(vec![("x".into(), "y".into())]))
                .await
        });
        let id = mock.expect_cmd("Fetch.continueResponse").await;
        let sent = mock.last_sent();
        assert_eq!(sent["params"]["requestId"], "REQ-CR");
        assert_eq!(sent["params"]["responseCode"], 204);
        // `phrase` was `None`, so the key must be absent rather than null.
        assert!(
            sent["params"]
                .as_object()
                .unwrap()
                .get("responsePhrase")
                .is_none()
        );
        let headers = sent["params"]["responseHeaders"].as_array().unwrap();
        assert_eq!(headers.len(), 1);
        assert_eq!(headers[0]["name"], "x");
        assert_eq!(headers[0]["value"], "y");
        mock.reply(id, serde_json::json!({})).await;
        fut.await.unwrap().unwrap();
        conn.shutdown();
    }

    // A request-stage pause (no response) must be rejected with `WrongStage`.
    #[tokio::test]
    async fn continue_response_wrong_stage_errs() {
        let (mut mock, conn) = MockConnection::pair();
        let sess = SessionHandle::new(conn.clone(), "S1");
        let req = PausedRequest::new("REQ-WS", make_request_info(), None, sess);
        let err = req
            .continue_response(Some(200), None, None)
            .await
            .expect_err("continue_response at Request stage must error");
        assert!(matches!(err, InterceptionError::WrongStage));
        tokio::task::yield_now().await;
        // The unreleased request is dropped, so a fallback `Fetch.continueRequest` is
        // allowed here; only `Fetch.continueResponse` is forbidden.
        while let Some((method, _id)) = mock.try_recv_cmd() {
            assert_ne!(
                method, "Fetch.continueResponse",
                "WrongStage path must not dispatch Fetch.continueResponse"
            );
        }
        conn.shutdown();
    }

    // After `continue_response` released the request, Drop must not send a fallback.
    #[tokio::test]
    async fn continue_response_no_double_fire_on_drop() {
        let (mut mock, conn) = MockConnection::pair();
        let sess = SessionHandle::new(conn.clone(), "S1");
        let req = PausedRequest::new(
            "REQ-CR-ONCE",
            make_request_info(),
            Some(make_response_info()),
            sess,
        );
        let fut = tokio::spawn(async move { req.continue_response(Some(200), None, None).await });
        let id = mock.expect_cmd("Fetch.continueResponse").await;
        assert_eq!(mock.last_sent()["params"]["requestId"], "REQ-CR-ONCE");
        mock.reply(id, serde_json::json!({})).await;
        fut.await.unwrap().unwrap();
        // Give any (erroneously) spawned Drop fallback a chance to run before checking.
        tokio::task::yield_now().await;
        assert!(
            mock.try_recv_cmd().is_none(),
            "Drop fired a fallback continueRequest after continue_response already released"
        );
        conn.shutdown();
    }
}