use serde::{Deserialize, Serialize};
/// A hostname pattern used in the allow/block lists of [`FetchPolicy`].
///
/// The wrapped string is either an exact hostname (`api.example.com`) or a
/// wildcard of the form `*.example.com`. A wildcard covers every subdomain of
/// the suffix at any depth but not the bare suffix itself. Matching is
/// case-insensitive; see [`DomainPattern::matches`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DomainPattern(pub String);
impl DomainPattern {
    /// Returns `true` when `domain` is covered by this pattern.
    ///
    /// Two pattern forms are recognised:
    ///
    /// * `*.example.com` matches every subdomain of `example.com` at any
    ///   depth (`api.example.com`, `deep.sub.example.com`), but not
    ///   `example.com` itself.
    /// * Any other string matches only the identical hostname.
    ///
    /// Comparison is case-insensitive. A single trailing dot (the
    /// fully-qualified spelling, `example.com.`) is ignored on both the
    /// pattern and the domain: it names the same host, so it must not be a
    /// way to slip past a block list or to miss an allow list entry.
    pub fn matches(&self, domain: &str) -> bool {
        /// Lower-cases `host` and drops one trailing root dot, if present.
        fn normalize(host: &str) -> String {
            host.strip_suffix('.').unwrap_or(host).to_lowercase()
        }

        let pattern = normalize(&self.0);
        let domain = normalize(domain);

        match pattern.strip_prefix("*.") {
            // The suffix must be preceded by a label separator, so that
            // `notexample.com` and the bare `example.com` are not accepted
            // by `*.example.com`.
            Some(suffix) => matches!(
                domain.strip_suffix(suffix),
                Some(head) if head.ends_with('.')
            ),
            None => domain == pattern,
        }
    }
}
/// Configuration describing which fetch requests are acceptable and the
/// limits that apply to them.
///
/// Only the domain, scheme and method lists are consulted by code in this
/// file (`check_domain`, `check_scheme`, `check_method`). The remaining
/// fields are plain configuration values here.
/// NOTE(review): presumably they are enforced by the HTTP client layer —
/// confirm against the callers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FetchPolicy {
/// Optional allow list. `None` places no allow-list restriction on the
/// domain; `Some(list)` requires the domain to match at least one pattern.
pub allowed_domains: Option<Vec<DomainPattern>>,
/// Block list. It is checked before the allow list, so a match here rejects
/// the domain even if the allow list would accept it.
pub blocked_domains: Vec<DomainPattern>,
/// NOTE(review): by its name, rejects targets on private/internal IP
/// addresses; it is not read anywhere in this file — confirm where it is
/// enforced.
pub deny_private_ips: bool,
/// HTTP methods accepted by `check_method` (compared ignoring ASCII case).
pub allowed_methods: Vec<String>,
/// URL schemes accepted by `check_scheme` (compared ignoring ASCII case).
pub allowed_schemes: Vec<String>,
/// Request body size limit; bytes per the field name. Not enforced in this file.
pub max_request_body_bytes: usize,
/// Response body size limit; bytes per the field name. Not enforced in this file.
pub max_response_body_bytes: usize,
/// Connect timeout; milliseconds per the field name. Not enforced in this file.
pub connect_timeout_ms: u64,
/// Whole-request timeout; milliseconds per the field name. Not enforced in this file.
pub request_timeout_ms: u64,
/// Redirect limit. NOTE(review): presumably per request — confirm with the client code.
pub max_redirects: u8,
/// Concurrency limit. NOTE(review): scope (global or per caller) is not visible here.
pub max_concurrent_requests: usize,
/// Rate limit in requests per minute, per the field name. Not enforced in this file.
pub max_requests_per_minute: u32,
}
impl Default for FetchPolicy {
fn default() -> Self {
Self {
allowed_domains: None,
blocked_domains: Vec::new(),
deny_private_ips: true,
allowed_methods: vec![
"GET".into(),
"POST".into(),
"PUT".into(),
"PATCH".into(),
"DELETE".into(),
"HEAD".into(),
"OPTIONS".into(),
],
allowed_schemes: vec!["https".into(), "http".into()],
max_request_body_bytes: 10 * 1024 * 1024,
max_response_body_bytes: 50 * 1024 * 1024,
connect_timeout_ms: 10_000,
request_timeout_ms: 30_000,
max_redirects: 10,
max_concurrent_requests: 50,
max_requests_per_minute: 500,
}
}
}
impl FetchPolicy {
pub fn check_domain(&self, domain: &str) -> Result<(), crate::error::FetchError> {
for pat in &self.blocked_domains {
if pat.matches(domain) {
return Err(crate::error::FetchError::DomainBlocked(domain.to_string()));
}
}
if let Some(ref allowed) = self.allowed_domains {
if !allowed.iter().any(|pat| pat.matches(domain)) {
return Err(crate::error::FetchError::DomainNotAllowed(
domain.to_string(),
));
}
}
Ok(())
}
pub fn check_scheme(&self, scheme: &str) -> Result<(), crate::error::FetchError> {
if !self
.allowed_schemes
.iter()
.any(|s| s.eq_ignore_ascii_case(scheme))
{
return Err(crate::error::FetchError::SchemeNotAllowed(
scheme.to_string(),
));
}
Ok(())
}
pub fn check_method(&self, method: &str) -> Result<(), crate::error::FetchError> {
if !self
.allowed_methods
.iter()
.any(|m| m.eq_ignore_ascii_case(method))
{
return Err(crate::error::FetchError::MethodNotAllowed(
method.to_string(),
));
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a pattern from a string literal.
    fn pattern(raw: &str) -> DomainPattern {
        DomainPattern(raw.to_string())
    }

    /// Asserts that `pat` accepts every host in `hits` and rejects every host
    /// in `misses`.
    fn assert_pattern(pat: &DomainPattern, hits: &[&str], misses: &[&str]) {
        for host in hits {
            assert!(pat.matches(host), "{} should match {}", pat.0, host);
        }
        for host in misses {
            assert!(!pat.matches(host), "{} should not match {}", pat.0, host);
        }
    }

    #[test]
    fn exact_domain_match() {
        assert_pattern(
            &pattern("api.example.com"),
            &["api.example.com", "API.EXAMPLE.COM"],
            &["other.example.com", "example.com"],
        );
    }

    #[test]
    fn wildcard_domain_match() {
        assert_pattern(
            &pattern("*.example.com"),
            &["api.example.com", "deep.sub.example.com"],
            &["example.com", "example.org", "notexample.com"],
        );
    }

    #[test]
    fn blocked_takes_precedence() {
        // The block list wins even though the allow list covers the host.
        let policy = FetchPolicy {
            allowed_domains: Some(vec![pattern("*.example.com")]),
            blocked_domains: vec![pattern("evil.example.com")],
            ..Default::default()
        };
        assert!(policy.check_domain("api.example.com").is_ok());
        assert!(policy.check_domain("evil.example.com").is_err());
    }

    #[test]
    fn allowlist_rejects_unlisted() {
        let policy = FetchPolicy {
            allowed_domains: Some(vec![pattern("api.example.com")]),
            ..Default::default()
        };
        assert!(policy.check_domain("api.example.com").is_ok());
        assert!(policy.check_domain("other.example.com").is_err());
    }

    #[test]
    fn no_allowlist_allows_all() {
        let verdict = FetchPolicy::default().check_domain("anything.example.com");
        assert!(verdict.is_ok());
    }

    #[test]
    fn scheme_validation() {
        let policy = FetchPolicy::default();
        for scheme in ["https", "http"].iter() {
            assert!(policy.check_scheme(scheme).is_ok(), "{} should pass", scheme);
        }
        assert!(policy.check_scheme("ftp").is_err());
    }

    #[test]
    fn method_validation() {
        let policy = FetchPolicy::default();
        for method in ["GET", "get"].iter() {
            assert!(policy.check_method(method).is_ok(), "{} should pass", method);
        }
        assert!(policy.check_method("TRACE").is_err());
    }
}