use std::sync::Arc;
use serde_json::{Value, json};
use tokio::sync::broadcast::error::RecvError;
use tokio::sync::mpsc;
use tokio::task::AbortHandle;
use crate::Result;
use crate::browser::interceptor::ResumeOptions;
use crate::browser::listener::ListenFilter;
use crate::cdp::core::CdpCore;
use crate::protocol::Connection;
use crate::util::base64_encode;
/// Interception state shared through `CdpCore::intercept` and accessed via its async lock.
#[derive(Default)]
pub(crate) struct InterceptShared {
/// Set to `true` by `CdpIntercept::start*` and cleared by `CdpIntercept::stop`.
pub(crate) running: bool,
/// Abort handle of the spawned `intercept_pump` task, if one is running.
pub(crate) abort: Option<AbortHandle>,
/// Receiving end of the pump's channel. `CdpIntercept::next` takes it out while it
/// waits (so the lock is not held across the wait) and normally puts it back afterwards.
pub(crate) rx: Option<mpsc::UnboundedReceiver<CdpInterceptedRequest>>,
}
/// Controls `Fetch`-domain request interception for the CDP session held by a shared `CdpCore`.
pub struct CdpIntercept {
/// Shared core supplying the connection, session id, default timeout and interception state.
core: Arc<CdpCore>,
}
impl CdpIntercept {
    /// Creates an interception controller over the shared session core.
    pub(crate) fn new(core: Arc<CdpCore>) -> Self {
        Self { core }
    }

    /// Starts intercepting paused requests whose URL matches `keywords`, as decided by
    /// `ListenFilter::matches`. Requests that do not match are continued automatically.
    ///
    /// Any interception already running on this core is stopped first.
    ///
    /// # Errors
    /// Returns the `Fetch.enable` error, after stopping the just-started interception.
    pub async fn start(&self, keywords: &[&str]) -> Result<()> {
        self.start_with(keywords, false).await
    }

    /// Same as [`start`](Self::start), but with `ListenFilter::xhr_only` set.
    pub async fn start_xhr(&self, keywords: &[&str]) -> Result<()> {
        self.start_with(keywords, true).await
    }

    /// Shared implementation of `start` / `start_xhr`.
    async fn start_with(&self, keywords: &[&str], xhr_only: bool) -> Result<()> {
        let filter = ListenFilter {
            url_keywords: keywords.iter().map(|s| s.to_string()).collect(),
            xhr_only,
        };
        self.stop().await?;

        // Spawn the pump *before* enabling the Fetch domain. The pump subscribes to the
        // connection's event stream on its first poll, and only events broadcast after
        // that subscription reach it. A `Fetch.requestPaused` it never sees is never
        // continued. Spawning first gives it the whole `Fetch.enable` round trip to
        // subscribe, instead of racing the first pause events.
        let (tx, rx) = mpsc::unbounded_channel();
        let task = tokio::spawn(intercept_pump(
            self.core.conn.clone(),
            self.core.session_id.clone(),
            filter,
            tx,
        ));
        {
            let mut g = self.core.intercept.lock().await;
            // A concurrent `start*` may have installed its own pump after our `stop()`.
            // Abort it rather than leak it. An orphaned pump whose receiver was replaced
            // keeps auto-continuing every paused request, including ones this pump hands
            // to the caller.
            if let Some(previous) = g.abort.replace(task.abort_handle()) {
                previous.abort();
            }
            g.running = true;
            g.rx = Some(rx);
        }

        let enabled = self
            .core
            .send(
                "Fetch.enable",
                json!({ "patterns": [{ "urlPattern": "*" }] }),
            )
            .await;
        if enabled.is_err() {
            // Roll back so no pump is left running for a domain that was never enabled.
            let _ = self.stop().await;
        }
        enabled?;
        Ok(())
    }

    /// Returns whether interception is currently marked as running.
    pub async fn is_intercepting(&self) -> bool {
        self.core.intercept.lock().await.running
    }

    /// Waits for the next intercepted request.
    ///
    /// Waits at most `timeout`, or the core's default timeout when `None`. Returns
    /// `Ok(None)` in any of these cases:
    /// - interception is not running;
    /// - the wait times out;
    /// - the pump has shut down.
    ///
    /// The receiver is taken out of the shared state for the duration of the wait. As a
    /// result, a concurrent `next` call returns `Ok(None)` immediately. Dropping this
    /// future mid-wait drops the receiver; after that the pump auto-continues matching
    /// requests instead of delivering them.
    pub async fn next(
        &self,
        timeout: Option<std::time::Duration>,
    ) -> Result<Option<CdpInterceptedRequest>> {
        let d = timeout.unwrap_or_else(|| self.core.timeout());
        // Take the receiver out so the lock is not held while waiting.
        let mut rx = match self.core.intercept.lock().await.rx.take() {
            Some(rx) => rx,
            None => return Ok(None),
        };
        let got = tokio::time::timeout(d, rx.recv()).await.ok().flatten();
        let mut g = self.core.intercept.lock().await;
        // Put the receiver back only if this session is still live.
        // - If `stop()` ran meanwhile, the receiver belongs to a dead pump.
        // - If a `start*()` followed, restoring it would overwrite (and drop) the fresh
        //   receiver, and the new pump would then auto-continue everything.
        if g.running && g.rx.is_none() {
            g.rx = Some(rx);
        }
        Ok(got)
    }

    /// Stops interception: clears the shared state and aborts the pump task.
    ///
    /// Only if a pump was running, it also sends a best-effort `Fetch.disable` whose
    /// error is ignored. Requests still queued in the dropped receiver are discarded
    /// without being resumed. NOTE(review): presumably `Fetch.disable` releases them
    /// browser-side; confirm.
    pub async fn stop(&self) -> Result<()> {
        let abort = {
            let mut g = self.core.intercept.lock().await;
            g.running = false;
            g.rx = None;
            g.abort.take()
        };
        if let Some(a) = abort {
            a.abort();
            let _ = self.core.send("Fetch.disable", json!({})).await;
        }
        Ok(())
    }
}
/// A request paused by `Fetch.requestPaused` that matched the interception filter.
///
/// Resolve it by consuming it with one of `resume`, `resume_with`, `fulfill` or `abort`.
/// This type has no visible `Drop` handling, so dropping it sends nothing to the browser.
/// NOTE(review): the request presumably stays paused until `Fetch.disable` in that case;
/// confirm.
pub struct CdpInterceptedRequest {
/// Request URL (`request.url`; empty if absent).
pub url: String,
/// HTTP method (`request.method`; empty if absent).
pub method: String,
/// The event's `resourceType` (empty if absent).
pub resource_type: String,
/// Request headers as `(name, value)` pairs; non-string values are read as empty strings.
pub headers: Vec<(String, String)>,
/// `request.postData`, when present as a string.
pub post_data: Option<String>,
/// The event's `requestId`, used to address the `Fetch.*` command that resolves it.
request_id: String,
/// Connection the resolving command is sent on.
conn: Connection,
/// Session the request was paused in.
session_id: String,
}
impl CdpInterceptedRequest {
async fn fetch(&self, method: &str, params: Value) -> Result<()> {
self.conn
.send(method, params, Some(&self.session_id))
.await?;
Ok(())
}
pub async fn resume(self) -> Result<()> {
self.fetch(
"Fetch.continueRequest",
json!({ "requestId": self.request_id }),
)
.await
}
pub async fn resume_with(self, opts: ResumeOptions) -> Result<()> {
let mut p = serde_json::Map::new();
p.insert("requestId".into(), json!(self.request_id));
if let Some(u) = opts.url {
p.insert("url".into(), json!(u));
}
if let Some(m) = opts.method {
p.insert("method".into(), json!(m));
}
if let Some(h) = opts.headers {
p.insert("headers".into(), json!(to_header_array(&h)));
}
if let Some(d) = opts.post_data {
p.insert("postData".into(), json!(base64_encode(d.as_bytes())));
}
self.fetch("Fetch.continueRequest", Value::Object(p)).await
}
pub async fn fulfill(
self,
status: u16,
headers: Vec<(String, String)>,
body: &str,
) -> Result<()> {
let p = json!({
"requestId": self.request_id,
"responseCode": status,
"responseHeaders": to_header_array(&headers),
"body": base64_encode(body.as_bytes()),
});
self.fetch("Fetch.fulfillRequest", p).await
}
pub async fn abort(self, error_code: &str) -> Result<()> {
self.fetch(
"Fetch.failRequest",
json!({ "requestId": self.request_id, "errorReason": fail_reason(error_code) }),
)
.await
}
pub fn request_id(&self) -> &str {
&self.request_id
}
}
async fn intercept_pump(
conn: Connection,
session_id: String,
filter: ListenFilter,
tx: mpsc::UnboundedSender<CdpInterceptedRequest>,
) {
let auto_continue = |request_id: &str| {
let conn = conn.clone();
let sid = session_id.clone();
let id = request_id.to_string();
async move {
let _ = conn
.send(
"Fetch.continueRequest",
json!({ "requestId": id }),
Some(&sid),
)
.await;
}
};
let mut events = conn.subscribe();
loop {
let ev = match events.recv().await {
Ok(ev) => ev,
Err(RecvError::Lagged(_)) => continue,
Err(RecvError::Closed) => break,
};
if ev.session_id.as_deref() != Some(session_id.as_str()) {
continue;
}
if ev.method != "Fetch.requestPaused" {
continue;
}
let Some(request_id) = ev.params["requestId"].as_str() else {
continue;
};
let req = &ev.params["request"];
let url = req["url"].as_str().unwrap_or_default().to_string();
let resource_type = ev.params["resourceType"]
.as_str()
.unwrap_or_default()
.to_string();
if !filter.matches(&url, &resource_type) {
auto_continue(request_id).await;
continue;
}
let intercepted = CdpInterceptedRequest {
url,
method: req["method"].as_str().unwrap_or_default().to_string(),
resource_type,
headers: header_map_to_pairs(&req["headers"]),
post_data: req["postData"].as_str().map(str::to_string),
request_id: request_id.to_string(),
conn: conn.clone(),
session_id: session_id.clone(),
};
if tx.send(intercepted).is_err() {
auto_continue(request_id).await;
}
}
}
/// Converts `(name, value)` header pairs into the CDP `HeaderEntry` array shape
/// `[{ "name": ..., "value": ... }, ...]`, preserving order.
fn to_header_array(headers: &[(String, String)]) -> Vec<Value> {
    let mut entries = Vec::with_capacity(headers.len());
    for (name, value) in headers {
        entries.push(json!({ "name": name, "value": value }));
    }
    entries
}
/// Flattens a JSON headers object into `(name, value)` pairs in the object's iteration order.
///
/// Non-string values become empty strings; a non-object input yields no pairs.
fn header_map_to_pairs(v: &Value) -> Vec<(String, String)> {
    let Some(object) = v.as_object() else {
        return Vec::new();
    };
    object
        .iter()
        .map(|(name, value)| (name.to_owned(), value.as_str().unwrap_or("").to_owned()))
        .collect()
}
/// Maps a caller-supplied error code to a CDP `Network.ErrorReason` name for `Fetch.failRequest`.
///
/// Matching ignores ASCII case and the separators `_`, `-` and space. For example,
/// `"connectionrefused"`, `"ConnectionRefused"` and `"connection_refused"` are equivalent.
/// `"timeout"` is accepted as an alias of `"timedout"`. `"failed"` and any unrecognised
/// code map to `"Failed"`.
fn fail_reason(code: &str) -> &'static str {
    // Normalise once so separator/case variants hit the same arm.
    let key: String = code
        .chars()
        .filter(|c| !matches!(c, '_' | '-' | ' '))
        .map(|c| c.to_ascii_lowercase())
        .collect();
    match key.as_str() {
        "aborted" => "Aborted",
        "timedout" | "timeout" => "TimedOut",
        "accessdenied" => "AccessDenied",
        "connectionclosed" => "ConnectionClosed",
        "connectionreset" => "ConnectionReset",
        "connectionrefused" => "ConnectionRefused",
        // These two were previously missing and silently degraded to "Failed".
        "connectionaborted" => "ConnectionAborted",
        "connectionfailed" => "ConnectionFailed",
        "namenotresolved" => "NameNotResolved",
        "internetdisconnected" => "InternetDisconnected",
        "addressunreachable" => "AddressUnreachable",
        "blockedbyclient" => "BlockedByClient",
        "blockedbyresponse" => "BlockedByResponse",
        _ => "Failed",
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Known codes map case-insensitively to their CDP name; anything else becomes `Failed`.
    #[test]
    fn fail_reason_maps_known_and_unknown() {
        let cases = [
            ("aborted", "Aborted"),
            ("BlockedByClient", "BlockedByClient"),
            ("namenotresolved", "NameNotResolved"),
            ("whatever", "Failed"),
        ];
        for (input, expected) in cases {
            assert_eq!(fail_reason(input), expected, "input: {input}");
        }
    }

    /// Each header pair becomes a `{name, value}` object.
    #[test]
    fn header_array_shape() {
        let pairs = [("Content-Type".to_string(), "text/html".to_string())];
        let array = to_header_array(&pairs);
        let first = &array[0];
        assert_eq!(first["name"], "Content-Type");
        assert_eq!(first["value"], "text/html");
    }
}