use anyhow::{Result, bail};
/// Built-in host patterns covering GitHub, npm, PyPI, crates.io and Docker
/// registry endpoints.
///
/// Every entry must be accepted by `parse_pattern`. A `*.suffix` entry matches
/// exactly one label in front of `suffix` and never the apex itself, which is
/// why apex hosts such as `pypi.org` are listed next to `*.pypi.org`.
/// Some exact entries (e.g. `registry.npmjs.org`) are also covered by a
/// wildcard entry in the same list; they are spelled out explicitly anyway.
pub const DEFAULT_DEV_ALLOWLIST: &[&str] = &[
// GitHub
"github.com",
"*.githubusercontent.com",
"api.github.com",
"codeload.github.com",
// npm
"*.npmjs.org",
"registry.npmjs.org",
// PyPI
"*.pypi.org",
"pypi.org",
"files.pythonhosted.org",
// crates.io
"crates.io",
"static.crates.io",
"index.crates.io",
// Docker
"*.docker.io",
"registry-1.docker.io",
"auth.docker.io",
];
/// One parsed allowlist entry, as produced by `parse_pattern`.
#[derive(Debug, Clone, PartialEq, Eq)]
enum Pattern {
/// Matches a single hostname exactly. The stored string is the
/// ASCII-lowercased, trimmed pattern text.
Exact(String),
/// Matches hosts that consist of exactly one non-empty label followed by
/// the stored suffix. The suffix is lowercased and keeps its leading dot:
/// `*.npmjs.org` is stored as `.npmjs.org`.
Wildcard(String),
}
/// A hostname allowlist made of exact and single-label wildcard patterns.
///
/// The `Default` value holds no patterns and therefore allows no host.
#[derive(Debug, Clone, Default)]
pub struct Filter {
// Parsed entries, kept in the order they were handed to `Filter::new`.
patterns: Vec<Pattern>,
}
impl Filter {
    /// Builds a filter from textual patterns.
    ///
    /// Each item is either an exact hostname (`github.com`) or a leftmost-label
    /// wildcard (`*.npmjs.org`). Patterns are trimmed and ASCII-lowercased while
    /// parsing.
    ///
    /// # Errors
    ///
    /// Returns the error of the first pattern that `parse_pattern` rejects;
    /// later patterns are not examined.
    pub fn new<I, S>(patterns: I) -> Result<Self>
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        // Collecting into `Result` stops at the first failing pattern.
        let patterns = patterns
            .into_iter()
            .map(|entry| parse_pattern(entry.as_ref()))
            .collect::<Result<Vec<_>>>()?;
        Ok(Self { patterns })
    }

    /// Reports whether `host` is matched by at least one pattern.
    ///
    /// A trailing `:port` is removed and the comparison is ASCII
    /// case-insensitive. Exact patterns must equal the whole host; wildcard
    /// patterns require exactly one non-empty label in front of their suffix,
    /// so `*.npmjs.org` matches `registry.npmjs.org` but neither `npmjs.org`
    /// nor `a.b.npmjs.org`.
    pub fn allows(&self, host: &str) -> bool {
        let candidate = strip_port(host).to_ascii_lowercase();
        self.patterns.iter().any(|pattern| match pattern {
            Pattern::Exact(expected) => candidate == *expected,
            // The part left of the suffix must be a single, non-empty label.
            Pattern::Wildcard(tail) => matches!(
                candidate.strip_suffix(tail.as_str()),
                Some(label) if !label.is_empty() && !label.contains('.')
            ),
        })
    }

    /// Number of patterns held by this filter.
    pub fn len(&self) -> usize {
        self.patterns.len()
    }

    /// `true` when the filter holds no patterns (and so allows nothing).
    pub fn is_empty(&self) -> bool {
        self.patterns.is_empty()
    }
}
/// Removes a trailing `:port` from `host` and returns the remaining host part.
///
/// * `name` and `name:port` yield `name` (an empty port, `name:`, is stripped
///   too).
/// * A bracketed IPv6 literal, `[addr]` or `[addr]:port`, yields `[addr]` with
///   the brackets kept; the colons inside the brackets are part of the address.
/// * Any other input with more than one `:` (for example a bare IPv6 literal
///   such as `::1`) or with malformed brackets is returned unchanged, because
///   no port can be identified unambiguously.
fn strip_port(host: &str) -> &str {
    if host.starts_with('[') {
        let Some(close) = host.find(']') else {
            // Unterminated bracket: nothing we can safely cut off.
            return host;
        };
        // `]` is a single ASCII byte, so `close + 1` is a valid char boundary.
        let (literal, tail) = host.split_at(close + 1);
        let well_formed = match tail.strip_prefix(':') {
            Some(port) => !port.contains(':'),
            None => tail.is_empty(),
        };
        return if well_formed { literal } else { host };
    }

    match host.rsplit_once(':') {
        // Only a lone colon separates host and port; several colons mean an
        // unbracketed IPv6 literal (or garbage), which must not be truncated.
        Some((name, _port)) if !name.contains(':') => name,
        _ => host,
    }
}
/// Parses one allowlist entry into a [`Pattern`].
///
/// Surrounding whitespace is trimmed and the text is ASCII-lowercased.
/// `*.rest` becomes [`Pattern::Wildcard`] holding `.rest`; any other text
/// without `*` becomes [`Pattern::Exact`].
///
/// # Errors
///
/// Fails when the trimmed input is empty, when `*.` is not followed by
/// anything, or when `*` appears anywhere other than as the whole leftmost
/// label.
fn parse_pattern(raw: &str) -> Result<Pattern> {
    // Keep the name `raw`: the error messages below capture the trimmed text.
    let raw = raw.trim();
    if raw.is_empty() {
        bail!("pattern must not be empty");
    }

    let normalized = raw.to_ascii_lowercase();
    match normalized.strip_prefix("*.") {
        Some("") => {
            bail!("pattern {raw:?}: wildcard must be followed by at least one label");
        }
        Some(base) if base.contains('*') => {
            bail!("pattern {raw:?}: wildcard may only appear as the leftmost label");
        }
        // Store the suffix with its leading dot so matching is a plain suffix test.
        Some(base) => Ok(Pattern::Wildcard(format!(".{base}"))),
        None if normalized.contains('*') => {
            bail!("pattern {raw:?}: wildcard only allowed as leftmost label, e.g. *.example.com");
        }
        None => Ok(Pattern::Exact(normalized)),
    }
}
// Unit tests for pattern parsing, host matching, the built-in allowlist and
// the port-stripping helper.
#[cfg(test)]
mod tests {
use super::*;
// --- parse_pattern: accepted inputs and normalization ---
#[test]
fn parse_pattern_accepts_exact() {
assert_eq!(
parse_pattern("github.com").unwrap(),
Pattern::Exact("github.com".into())
);
}
#[test]
fn parse_pattern_lowercases_exact() {
assert_eq!(
parse_pattern("GitHub.COM").unwrap(),
Pattern::Exact("github.com".into())
);
}
#[test]
fn parse_pattern_accepts_wildcard() {
// The stored wildcard suffix keeps its leading dot.
assert_eq!(
parse_pattern("*.npmjs.org").unwrap(),
Pattern::Wildcard(".npmjs.org".into())
);
}
#[test]
fn parse_pattern_lowercases_wildcard() {
assert_eq!(
parse_pattern("*.NPMJS.org").unwrap(),
Pattern::Wildcard(".npmjs.org".into())
);
}
#[test]
fn parse_pattern_trims_whitespace() {
assert_eq!(
parse_pattern(" github.com ").unwrap(),
Pattern::Exact("github.com".into())
);
}
// --- parse_pattern: rejected inputs (checked via error message fragments) ---
#[test]
fn parse_pattern_rejects_empty() {
let err = parse_pattern("").expect_err("must reject empty");
assert!(err.to_string().contains("must not be empty"));
}
#[test]
fn parse_pattern_rejects_whitespace_only() {
let err = parse_pattern(" ").expect_err("must reject whitespace-only");
assert!(err.to_string().contains("must not be empty"));
}
#[test]
fn parse_pattern_rejects_bare_wildcard() {
let err = parse_pattern("*.").expect_err("must reject bare *.");
assert!(err.to_string().contains("at least one label"));
}
#[test]
fn parse_pattern_rejects_internal_wildcard() {
let err = parse_pattern("foo.*.com").expect_err("must reject internal *");
assert!(err.to_string().contains("leftmost label"));
}
#[test]
fn parse_pattern_rejects_trailing_wildcard() {
let err = parse_pattern("foo.*").expect_err("must reject trailing *");
assert!(err.to_string().contains("leftmost label"));
}
#[test]
fn parse_pattern_rejects_double_wildcard() {
let err = parse_pattern("*.*.com").expect_err("must reject *.*");
assert!(err.to_string().contains("leftmost label"));
}
#[test]
fn parse_pattern_rejects_prefix_glob() {
let err = parse_pattern("*foo.com").expect_err("must reject *foo");
assert!(err.to_string().contains("leftmost label"));
}
// --- Filter::allows with exact patterns ---
#[test]
fn allows_exact_match() {
let f = Filter::new(["github.com"]).unwrap();
assert!(f.allows("github.com"));
}
#[test]
fn allows_exact_match_case_insensitive() {
let f = Filter::new(["github.com"]).unwrap();
assert!(f.allows("GitHub.COM"));
}
#[test]
fn allows_exact_strips_port() {
let f = Filter::new(["github.com"]).unwrap();
assert!(f.allows("github.com:443"));
}
#[test]
fn allows_exact_does_not_match_subdomain() {
let f = Filter::new(["github.com"]).unwrap();
assert!(!f.allows("api.github.com"));
}
#[test]
fn allows_exact_does_not_match_suffix() {
// A host that merely ends with the pattern text must not match.
let f = Filter::new(["github.com"]).unwrap();
assert!(!f.allows("evilgithub.com"));
}
// --- Filter::allows with wildcard patterns ---
#[test]
fn allows_wildcard_matches_one_subdomain() {
let f = Filter::new(["*.npmjs.org"]).unwrap();
assert!(f.allows("registry.npmjs.org"));
assert!(f.allows("www.npmjs.org"));
}
#[test]
fn allows_wildcard_does_not_match_apex() {
let f = Filter::new(["*.npmjs.org"]).unwrap();
assert!(!f.allows("npmjs.org"));
}
#[test]
fn allows_wildcard_does_not_match_two_labels() {
let f = Filter::new(["*.npmjs.org"]).unwrap();
assert!(!f.allows("a.b.npmjs.org"));
}
#[test]
fn allows_wildcard_does_not_match_suffix_attack() {
// The allowed name appears in the middle of the host, not at its end.
let f = Filter::new(["*.npmjs.org"]).unwrap();
assert!(!f.allows("evil.npmjs.org.attacker.com"));
}
#[test]
fn allows_wildcard_strips_port() {
let f = Filter::new(["*.npmjs.org"]).unwrap();
assert!(f.allows("registry.npmjs.org:443"));
}
#[test]
fn allows_wildcard_case_insensitive() {
let f = Filter::new(["*.npmjs.org"]).unwrap();
assert!(f.allows("REGISTRY.NPMJS.ORG"));
}
// --- Filter::allows with several patterns, or none ---
#[test]
fn allows_apex_and_wildcard_together() {
let f = Filter::new(["pypi.org", "*.pypi.org"]).unwrap();
assert!(f.allows("pypi.org"));
assert!(f.allows("files.pypi.org"));
}
#[test]
fn allows_returns_false_when_empty() {
let f = Filter::default();
assert!(!f.allows("github.com"));
assert!(!f.allows("anything.example.com"));
}
#[test]
fn allows_short_circuits_on_first_match() {
let f = Filter::new(["github.com", "*.npmjs.org"]).unwrap();
assert!(f.allows("github.com"));
assert!(f.allows("registry.npmjs.org"));
}
// --- DEFAULT_DEV_ALLOWLIST ---
#[test]
fn default_dev_allowlist_parses_cleanly() {
// Every entry must parse, and none may be dropped.
let f = Filter::new(DEFAULT_DEV_ALLOWLIST).expect("default allowlist must parse");
assert_eq!(f.len(), DEFAULT_DEV_ALLOWLIST.len());
}
#[test]
fn default_dev_allowlist_covers_common_cases() {
let f = Filter::new(DEFAULT_DEV_ALLOWLIST).unwrap();
assert!(f.allows("github.com"));
assert!(f.allows("api.github.com"));
assert!(f.allows("registry.npmjs.org"));
assert!(f.allows("pypi.org"));
assert!(f.allows("files.pythonhosted.org"));
assert!(f.allows("crates.io"));
assert!(f.allows("registry-1.docker.io"));
assert!(!f.allows("evil.example.com"));
assert!(!f.allows("nation-state.adversary.io"));
}
// --- Filter::len / Filter::is_empty ---
#[test]
fn len_and_is_empty() {
let f = Filter::default();
assert_eq!(f.len(), 0);
assert!(f.is_empty());
let f = Filter::new(["a.com", "b.com"]).unwrap();
assert_eq!(f.len(), 2);
assert!(!f.is_empty());
}
// --- strip_port ---
#[test]
fn strip_port_handles_no_port() {
assert_eq!(strip_port("github.com"), "github.com");
}
#[test]
fn strip_port_handles_port() {
assert_eq!(strip_port("github.com:443"), "github.com");
}
}