use rustc_hash::FxHashMap;
/// Decision reported for an intercepted request.
///
/// A [`Route`] sends exactly one of these values over its oneshot channel;
/// how the receiving side applies it is outside this file.
#[derive(Debug, Clone)]
pub enum RouteAction {
/// Sent by `Route::continue_route`, by `Route::fallback`, and (with default
/// overrides) when a `Route` is dropped without an explicit decision.
Continue(ContinueOverrides),
/// Sent by `Route::fulfill` together with the response to use.
Fulfill(FulfillResponse),
/// Sent by `Route::abort`; the payload is the reason string passed to it.
Abort(String),
}
/// Optional replacement values carried by [`RouteAction::Continue`].
///
/// Every field is `None` in the derived `Default`.
/// NOTE(review): a `None` field presumably leaves that part of the request
/// untouched — confirm against the code that consumes `RouteAction`.
#[derive(Debug, Clone, Default)]
pub struct ContinueOverrides {
/// Replacement URL, if any.
pub url: Option<String>,
/// Replacement method, if any.
pub method: Option<String>,
/// Replacement headers as `(name, value)` pairs, if any.
pub headers: Option<Vec<(String, String)>>,
/// Replacement body bytes, if any.
pub post_data: Option<Vec<u8>>,
}
/// Response payload carried by `RouteAction::Fulfill`.
#[derive(Debug, Clone)]
pub struct FulfillResponse {
    /// Status code of the response; `200` in the default value.
    pub status: i32,
    /// Response headers as `(name, value)` pairs.
    pub headers: Vec<(String, String)>,
    /// Raw response body bytes.
    pub body: Vec<u8>,
    /// Optional content type. How it is combined with `headers` is decided by
    /// whoever consumes this value, not by this type.
    pub content_type: Option<String>,
}

impl Default for FulfillResponse {
    /// Produces a status-200 response with no headers, an empty body and no
    /// content type.
    fn default() -> Self {
        Self {
            status: 200,
            headers: Vec::new(),
            body: Vec::new(),
            content_type: None,
        }
    }
}
/// Data describing one intercepted request, as stored inside a [`Route`].
#[derive(Debug, Clone)]
pub struct InterceptedRequest {
/// Identifier of the request; copied into the `id` field by
/// `Route::network_request`.
pub request_id: String,
/// URL of the request.
pub url: String,
/// Method of the request (the tests in this file use `"POST"`).
pub method: String,
/// Request headers keyed by header name.
pub headers: FxHashMap<String, String>,
/// Request body as text, when one is present.
pub post_data: Option<String>,
/// Resource type label (the tests in this file use `"Fetch"`).
pub resource_type: String,
}
/// Handle for deciding what happens to one intercepted request.
///
/// The decision is delivered as a [`RouteAction`] over a oneshot channel.
pub struct Route {
// The request this route was created for.
request: InterceptedRequest,
// Sender for the decision; `None` once a decision has been sent, which
// stops a second send from the consuming methods or from `Drop`.
action_tx: Option<tokio::sync::oneshot::Sender<RouteAction>>,
}
impl Route {
    /// Creates a route for `request` whose decision will be sent on `action_tx`.
    #[must_use]
    pub fn new(request: InterceptedRequest, action_tx: tokio::sync::oneshot::Sender<RouteAction>) -> Self {
        let action_tx = Some(action_tx);
        Self { request, action_tx }
    }

    /// Borrows the intercepted request this route wraps.
    #[must_use]
    pub fn request(&self) -> &InterceptedRequest {
        &self.request
    }

    /// Builds a `crate::network::Request` from copies of the intercepted
    /// request's id, URL, method, resource type, body and headers.
    ///
    /// `is_navigation_request` is always `false`, and the frame id, redirect
    /// source, timing and raw-header callback are all left as `None`.
    #[must_use]
    pub fn network_request(&self) -> crate::network::Request {
        let source = &self.request;
        let init = crate::network::RequestInit {
            id: source.request_id.clone(),
            url: source.url.clone(),
            method: source.method.clone(),
            resource_type: source.resource_type.clone(),
            is_navigation_request: false,
            // The textual body is handed over as its raw bytes.
            post_data: source.post_data.as_deref().map(|text| text.as_bytes().to_vec()),
            headers: source.headers.clone(),
            frame_id: None,
            redirected_from: None,
            timing: None,
            raw_headers_fn: None,
        };
        crate::network::Request::new(init)
    }

    /// Sends `action` on the channel unless the sender was already taken.
    fn send_action(&mut self, action: RouteAction) {
        let Some(sender) = self.action_tx.take() else {
            return;
        };
        // The result of the send is deliberately ignored.
        let _ = sender.send(action);
    }

    /// Consumes the route and reports `RouteAction::Fulfill(response)`.
    pub fn fulfill(mut self, response: FulfillResponse) {
        self.send_action(RouteAction::Fulfill(response));
    }

    /// Consumes the route and reports `RouteAction::Continue(overrides)`.
    pub fn continue_route(mut self, overrides: ContinueOverrides) {
        self.send_action(RouteAction::Continue(overrides));
    }

    /// Consumes the route and reports `RouteAction::Continue(overrides)`,
    /// exactly as `continue_route` does.
    pub fn fallback(self, overrides: ContinueOverrides) {
        self.continue_route(overrides);
    }

    /// Consumes the route and reports `RouteAction::Abort` with an owned copy
    /// of `reason`.
    pub fn abort(mut self, reason: &str) {
        self.send_action(RouteAction::Abort(String::from(reason)));
    }
}
impl Drop for Route {
    /// If no decision was sent before the route is dropped, sends
    /// `RouteAction::Continue` with default (all-`None`) overrides.
    fn drop(&mut self) {
        let Some(sender) = self.action_tx.take() else {
            // A decision was already sent; nothing left to do.
            return;
        };
        // The result of the send is deliberately ignored.
        let _ = sender.send(RouteAction::Continue(ContinueOverrides::default()));
    }
}
/// Reference-counted callback that takes ownership of a [`Route`]; the
/// `Send + Sync` bounds allow it to be shared across threads.
pub type RouteHandler = std::sync::Arc<dyn Fn(Route) + Send + Sync>;
/// A handler registered for URLs accepted by a matcher, optionally limited to
/// a number of uses.
pub struct RegisteredRoute {
/// Decides which URLs this registration applies to.
pub matcher: crate::url_matcher::UrlMatcher,
/// Callback handed out when the registration is selected.
pub handler: RouteHandler,
/// Remaining number of uses; `None` means the registration has no limit
/// (`live` then always returns `true`).
pub remaining: Option<std::sync::Arc<std::sync::atomic::AtomicU32>>,
}
impl RegisteredRoute {
    /// Registers `handler` for URLs accepted by `matcher`.
    ///
    /// `times` is the number of uses allowed; `None` leaves the registration
    /// unlimited.
    #[must_use]
    pub fn new(matcher: crate::url_matcher::UrlMatcher, handler: RouteHandler, times: Option<u32>) -> Self {
        let remaining = times.map(|limit| {
            let counter = std::sync::atomic::AtomicU32::new(limit);
            std::sync::Arc::new(counter)
        });
        Self {
            matcher,
            handler,
            remaining,
        }
    }

    /// Reports whether the registration may still be used: always `true`
    /// without a limit, otherwise `true` while the counter is above zero.
    #[must_use]
    pub fn live(&self) -> bool {
        match &self.remaining {
            None => true,
            Some(counter) => counter.load(std::sync::atomic::Ordering::Acquire) != 0,
        }
    }
}
/// Picks the handler of the first live route in `routes` whose matcher accepts
/// `url`, consuming one use of that route.
///
/// For a route with a use limit the counter is decremented; the route is
/// removed from `routes` when this call used its last remaining use, or when
/// the counter was already zero. Routes without a limit are never removed
/// here.
///
/// Returns `None` when no live route matches `url`.
#[must_use]
pub fn take_matching_handler(routes: &mut Vec<RegisteredRoute>, url: &str) -> Option<RouteHandler> {
    let slot = routes
        .iter()
        .position(|candidate| candidate.live() && candidate.matcher.matches(url))?;

    let entry = &routes[slot];
    let handler = std::sync::Arc::clone(&entry.handler);

    let used_up = match entry.remaining.as_ref() {
        // Unlimited registrations are never exhausted.
        None => false,
        Some(counter) => {
            let outcome = counter.fetch_update(
                std::sync::atomic::Ordering::AcqRel,
                std::sync::atomic::Ordering::Acquire,
                |left| left.checked_sub(1),
            );
            match outcome {
                // A previous value of 1 (or less) means nothing is left now.
                Ok(before) => before <= 1,
                // The counter was already zero, so no decrement happened.
                Err(_) => true,
            }
        }
    };

    if used_up {
        routes.remove(slot);
    }
    Some(handler)
}
/// Returns the standard HTTP reason phrase for `code`.
///
/// Phrases follow RFC 9110. Any code without an entry in the table —
/// including `200` itself and values outside the HTTP range — yields `"OK"`;
/// that fallback is kept so existing callers see no change for codes that
/// were already unmapped and remain so.
#[must_use]
pub fn status_text(code: i32) -> &'static str {
    match code {
        100 => "Continue",
        101 => "Switching Protocols",
        201 => "Created",
        202 => "Accepted",
        203 => "Non-Authoritative Information",
        204 => "No Content",
        205 => "Reset Content",
        206 => "Partial Content",
        300 => "Multiple Choices",
        301 => "Moved Permanently",
        302 => "Found",
        303 => "See Other",
        304 => "Not Modified",
        307 => "Temporary Redirect",
        308 => "Permanent Redirect",
        400 => "Bad Request",
        401 => "Unauthorized",
        402 => "Payment Required",
        403 => "Forbidden",
        404 => "Not Found",
        405 => "Method Not Allowed",
        406 => "Not Acceptable",
        407 => "Proxy Authentication Required",
        408 => "Request Timeout",
        409 => "Conflict",
        410 => "Gone",
        411 => "Length Required",
        412 => "Precondition Failed",
        413 => "Content Too Large",
        414 => "URI Too Long",
        415 => "Unsupported Media Type",
        416 => "Range Not Satisfiable",
        417 => "Expectation Failed",
        421 => "Misdirected Request",
        422 => "Unprocessable Content",
        426 => "Upgrade Required",
        428 => "Precondition Required",
        429 => "Too Many Requests",
        431 => "Request Header Fields Too Large",
        451 => "Unavailable For Legal Reasons",
        500 => "Internal Server Error",
        501 => "Not Implemented",
        502 => "Bad Gateway",
        503 => "Service Unavailable",
        504 => "Gateway Timeout",
        505 => "HTTP Version Not Supported",
        // 200 deliberately has no arm of its own: it shares the fallback.
        _ => "OK",
    }
}
// Unit tests for `Route`: the action sent by `fallback` and the fields copied
// by `network_request`.
#[cfg(test)]
mod tests {
use super::*;
// Builds a POST request to https://example.com/api with one header and a
// "hello" body, shared by the tests below.
fn sample_request() -> InterceptedRequest {
let mut headers = FxHashMap::default();
headers.insert("x-from".to_string(), "test".to_string());
InterceptedRequest {
request_id: "req-1".to_string(),
url: "https://example.com/api".to_string(),
method: "POST".to_string(),
headers,
post_data: Some("hello".to_string()),
resource_type: "Fetch".to_string(),
}
}
// `fallback` must deliver a `Continue` action carrying the overrides it was given.
#[tokio::test]
async fn fallback_sends_continue_with_overrides() {
let (tx, rx) = tokio::sync::oneshot::channel();
let route = Route::new(sample_request(), tx);
route.fallback(ContinueOverrides {
method: Some("PUT".to_string()),
..Default::default()
});
match rx.await.expect("route action") {
RouteAction::Continue(o) => assert_eq!(o.method.as_deref(), Some("PUT")),
other => panic!("expected Continue, got {other:?}"),
}
}
// With default overrides, `fallback` must deliver a `Continue` whose fields are all `None`.
#[tokio::test]
async fn fallback_without_overrides_sends_unmodified_continue() {
let (tx, rx) = tokio::sync::oneshot::channel();
let route = Route::new(sample_request(), tx);
route.fallback(ContinueOverrides::default());
match rx.await.expect("route action") {
RouteAction::Continue(o) => {
assert!(o.url.is_none() && o.method.is_none() && o.headers.is_none() && o.post_data.is_none());
},
other => panic!("expected Continue, got {other:?}"),
}
}
// `network_request` must copy URL, method, resource type and body, and must
// report the request as a non-navigation request.
#[test]
fn network_request_carries_interception_fields() {
// The receiver is kept alive as `_rx` but never read in this test.
let (tx, _rx) = tokio::sync::oneshot::channel();
let route = Route::new(sample_request(), tx);
let req = route.network_request();
assert_eq!(req.url(), "https://example.com/api");
assert_eq!(req.method(), "POST");
assert_eq!(req.resource_type(), "Fetch");
assert_eq!(req.post_data().as_deref(), Some("hello"));
assert!(!req.is_navigation_request());
}
}