use std::fmt;
use std::sync::Arc;
use crate::host_matcher::HostMatcher;
use crate::types::{RequestInfo, RequestOverrides, ResponseInfo, ResponseOverrides};
use crate::url_pattern::UrlPattern;
/// A single URL rule, tested against a URL with [`Rule::matches`].
///
/// Most variants carry a [`UrlPattern`]. `BlockHosts` instead carries a shared
/// [`HostMatcher`] that is consulted with the URL's host. What happens once a
/// rule matches is decided by code outside this file. The per-variant notes
/// below describe the apparent intent (taken from the names) and should be
/// confirmed against the consumer.
pub enum Rule {
/// Block URLs matching `pattern` (presumably the request is refused; verify
/// against the consumer).
Block {
/// Pattern tested by [`Rule::matches`].
pattern: UrlPattern,
},
/// Redirect URLs matching `from` to `to`.
Redirect {
/// Pattern tested by [`Rule::matches`].
from: UrlPattern,
/// Redirect target. Presumably an absolute URL (the tests use one); it is
/// not validated here.
to: String,
},
/// Answer matching URLs with a synthesized response (presumably instead of
/// forwarding the request; confirm in the consumer).
Respond {
/// Pattern tested by [`Rule::matches`].
pattern: UrlPattern,
/// Status code for the synthesized response (presumably an HTTP status).
status: u16,
/// Response headers, presumably `(name, value)` pairs.
headers: Vec<(String, String)>,
/// Raw response body. The `Debug` output shows only its length.
body: Vec<u8>,
},
/// Rewrite matching requests with a caller-supplied closure.
Modify {
/// Pattern tested by [`Rule::matches`].
pattern: UrlPattern,
/// Computes [`RequestOverrides`] from a [`RequestInfo`]. It is `Send + Sync`
/// behind an `Arc`, and it is never invoked by `matches` or `Debug`.
modify: Arc<dyn Fn(&RequestInfo) -> RequestOverrides + Send + Sync>,
},
/// Rewrite responses for matching URLs with a caller-supplied closure.
ModifyResponse {
/// Pattern tested by [`Rule::matches`].
pattern: UrlPattern,
/// Computes [`ResponseOverrides`] from a [`ResponseInfo`]. It is
/// `Send + Sync` behind an `Arc`, and it is never invoked by `matches` or
/// `Debug`.
modify: Arc<dyn Fn(&ResponseInfo) -> ResponseOverrides + Send + Sync>,
},
/// Match by host rather than by URL pattern.
///
/// A URL matches when its host (extracted by `crate::host_matcher::host_of`)
/// is reported blocked by `matcher`. A URL with no extractable host never
/// matches. The tests show that subdomains of a listed host also match.
BlockHosts {
/// Shared host set. `Debug` prints only its `len()`.
matcher: Arc<HostMatcher>,
},
}
impl Rule {
    /// Reports whether this rule applies to `url`.
    ///
    /// Pattern-based variants delegate to their [`UrlPattern`]: `from` for
    /// `Redirect`, and `pattern` for `Block`, `Respond`, `Modify` and
    /// `ModifyResponse`. `BlockHosts` extracts the host with
    /// `crate::host_matcher::host_of` and asks the matcher whether that host is
    /// blocked. A URL with no extractable host never matches.
    pub fn matches(&self, url: &str) -> bool {
        let url_pattern = match self {
            Self::BlockHosts { matcher } => {
                // Host-based rules have no URL pattern, so they answer directly.
                let Some(host) = crate::host_matcher::host_of(url) else {
                    return false;
                };
                return matcher.is_blocked(host);
            }
            Self::Redirect { from, .. } => from,
            Self::Block { pattern }
            | Self::Respond { pattern, .. }
            | Self::Modify { pattern, .. }
            | Self::ModifyResponse { pattern, .. } => pattern,
        };
        url_pattern.matches(url)
    }
}
impl fmt::Debug for Rule {
    /// Formats the rule for diagnostics.
    ///
    /// Opaque or bulky payloads are summarised rather than dumped:
    /// - a `Respond` body appears only as `body_len`;
    /// - closures appear as the string `"<closure>"`;
    /// - a `BlockHosts` matcher appears as `hosts`, holding `matcher.len()`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Both closure-carrying variants render the same shape; only the
        // struct name differs.
        fn closure_rule(
            f: &mut fmt::Formatter<'_>,
            name: &str,
            pattern: &dyn fmt::Debug,
        ) -> fmt::Result {
            f.debug_struct(name)
                .field("pattern", pattern)
                .field("modify", &"<closure>")
                .finish()
        }

        match self {
            Self::Block { pattern } => {
                let mut out = f.debug_struct("Block");
                out.field("pattern", pattern);
                out.finish()
            }
            Self::Redirect { from, to } => {
                let mut out = f.debug_struct("Redirect");
                out.field("from", from);
                out.field("to", to);
                out.finish()
            }
            Self::Respond {
                pattern,
                status,
                headers,
                body,
            } => {
                // Only the body's length is printed, never its bytes.
                let body_len = body.len();
                let mut out = f.debug_struct("Respond");
                out.field("pattern", pattern)
                    .field("status", status)
                    .field("headers", headers)
                    .field("body_len", &body_len);
                out.finish()
            }
            Self::Modify { pattern, .. } => closure_rule(f, "Modify", pattern),
            Self::ModifyResponse { pattern, .. } => closure_rule(f, "ModifyResponse", pattern),
            Self::BlockHosts { matcher } => {
                let host_count = matcher.len();
                let mut out = f.debug_struct("BlockHosts");
                out.field("hosts", &host_count);
                out.finish()
            }
        }
    }
}
#[cfg(test)]
#[allow(clippy::panic, clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn block_matches_via_pattern() {
        let rule = Rule::Block {
            pattern: UrlPattern::new("*/ads/*").unwrap(),
        };
        assert!(rule.matches("https://example.com/ads/banner.png"));
        assert!(!rule.matches("https://example.com/content/main.css"));
    }

    #[test]
    fn redirect_matches_via_from_field() {
        let rule = Rule::Redirect {
            from: UrlPattern::new("*/old/*").unwrap(),
            to: "https://example.com/new/replacement".into(),
        };
        assert!(rule.matches("https://example.com/old/page.html"));
        assert!(!rule.matches("https://example.com/new/page.html"));
    }

    #[test]
    fn block_hosts_matches_on_host_and_subdomain() {
        let rule = Rule::BlockHosts {
            matcher: Arc::new(HostMatcher::new(["evil.com".to_string()])),
        };
        assert!(rule.matches("https://evil.com/track.js"));
        assert!(rule.matches("https://a.b.evil.com/x?y=1"));
        assert!(!rule.matches("https://good.com/app.js"));
        assert!(!rule.matches("https://notevil.com/app.js"));
        let dbg = format!("{rule:?}");
        assert!(dbg.contains("BlockHosts"), "got: {dbg}");
        assert!(dbg.contains("hosts"), "got: {dbg}");
    }

    #[test]
    fn rule_modify_response_matches_and_debug() {
        let rule = Rule::ModifyResponse {
            pattern: UrlPattern::new("*/api/*").unwrap(),
            modify: Arc::new(|_resp: &ResponseInfo| ResponseOverrides {
                status: Some(418),
                ..ResponseOverrides::default()
            }),
        };
        assert!(rule.matches("https://example.com/api/users"));
        assert!(!rule.matches("https://example.com/static/app.js"));
        let dbg = format!("{rule:?}");
        assert!(dbg.contains("ModifyResponse"), "got: {dbg}");
        assert!(dbg.contains("*/api/*"), "got: {dbg}");
        assert!(dbg.contains("<closure>"), "got: {dbg}");
    }

    /// `Respond` matches through its pattern, and `Debug` reports the body's
    /// length instead of printing the bytes.
    #[test]
    fn respond_matches_via_pattern_and_debug_elides_body() {
        let rule = Rule::Respond {
            pattern: UrlPattern::new("*/ads/*").unwrap(),
            status: 204,
            headers: vec![("x-test".to_string(), "1".to_string())],
            // 14 bytes; the exact text must not leak into the Debug output.
            body: b"secret-payload".to_vec(),
        };
        assert!(rule.matches("https://example.com/ads/banner.png"));
        assert!(!rule.matches("https://example.com/content/main.css"));
        let dbg = format!("{rule:?}");
        assert!(dbg.starts_with("Respond {"), "got: {dbg}");
        assert!(dbg.contains("status: 204"), "got: {dbg}");
        assert!(dbg.contains("body_len: 14"), "got: {dbg}");
        assert!(!dbg.contains("secret-payload"), "got: {dbg}");
    }

    /// `Modify` matches through its pattern. Neither `matches` nor `Debug`
    /// may call the request closure; the panicking closure enforces that.
    #[test]
    fn modify_matches_via_pattern_without_invoking_closure() {
        let rule = Rule::Modify {
            pattern: UrlPattern::new("*/api/*").unwrap(),
            modify: Arc::new(|_req: &RequestInfo| -> RequestOverrides {
                panic!("matches/Debug must not invoke the modify closure")
            }),
        };
        assert!(rule.matches("https://example.com/api/users"));
        assert!(!rule.matches("https://example.com/static/app.js"));
        let dbg = format!("{rule:?}");
        assert!(dbg.starts_with("Modify {"), "got: {dbg}");
        assert!(dbg.contains("*/api/*"), "got: {dbg}");
        assert!(dbg.contains("<closure>"), "got: {dbg}");
    }
}