1use serde::{Deserialize, Serialize};
2
3#[derive(Debug, Clone, Serialize, Deserialize)]
5pub struct DomainPattern(pub String);
6
7impl DomainPattern {
8 pub fn matches(&self, domain: &str) -> bool {
9 let pattern = self.0.to_lowercase();
10 let domain = domain.to_lowercase();
11
12 if let Some(suffix) = pattern.strip_prefix("*.") {
13 domain.ends_with(&format!(".{suffix}"))
14 } else {
15 domain == pattern
16 }
17 }
18}
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct FetchPolicy {
23 pub allowed_domains: Option<Vec<DomainPattern>>,
25 pub blocked_domains: Vec<DomainPattern>,
27 pub deny_private_ips: bool,
29 pub allowed_methods: Vec<String>,
31 pub allowed_schemes: Vec<String>,
33 pub max_request_body_bytes: usize,
35 pub max_response_body_bytes: usize,
37 pub connect_timeout_ms: u64,
39 pub request_timeout_ms: u64,
41 pub max_redirects: u8,
43 pub max_concurrent_requests: usize,
45 pub max_requests_per_minute: u32,
47}
48
49impl Default for FetchPolicy {
50 fn default() -> Self {
51 Self {
52 allowed_domains: None,
53 blocked_domains: Vec::new(),
54 deny_private_ips: true,
55 allowed_methods: vec![
56 "GET".into(),
57 "POST".into(),
58 "PUT".into(),
59 "PATCH".into(),
60 "DELETE".into(),
61 "HEAD".into(),
62 "OPTIONS".into(),
63 ],
64 allowed_schemes: vec!["https".into(), "http".into()],
65 max_request_body_bytes: 10 * 1024 * 1024,
66 max_response_body_bytes: 50 * 1024 * 1024,
67 connect_timeout_ms: 10_000,
68 request_timeout_ms: 30_000,
69 max_redirects: 10,
70 max_concurrent_requests: 50,
71 max_requests_per_minute: 500,
72 }
73 }
74}
75
76impl FetchPolicy {
77 pub fn check_domain(&self, domain: &str) -> Result<(), crate::error::FetchError> {
79 for pat in &self.blocked_domains {
80 if pat.matches(domain) {
81 return Err(crate::error::FetchError::DomainBlocked(domain.to_string()));
82 }
83 }
84 if let Some(ref allowed) = self.allowed_domains {
85 if !allowed.iter().any(|pat| pat.matches(domain)) {
86 return Err(crate::error::FetchError::DomainNotAllowed(
87 domain.to_string(),
88 ));
89 }
90 }
91 Ok(())
92 }
93
94 pub fn check_scheme(&self, scheme: &str) -> Result<(), crate::error::FetchError> {
95 if !self
96 .allowed_schemes
97 .iter()
98 .any(|s| s.eq_ignore_ascii_case(scheme))
99 {
100 return Err(crate::error::FetchError::SchemeNotAllowed(
101 scheme.to_string(),
102 ));
103 }
104 Ok(())
105 }
106
107 pub fn check_method(&self, method: &str) -> Result<(), crate::error::FetchError> {
108 if !self
109 .allowed_methods
110 .iter()
111 .any(|m| m.eq_ignore_ascii_case(method))
112 {
113 return Err(crate::error::FetchError::MethodNotAllowed(
114 method.to_string(),
115 ));
116 }
117 Ok(())
118 }
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::FetchError;

    #[test]
    fn exact_domain_match() {
        let pat = DomainPattern("api.example.com".into());
        assert!(pat.matches("api.example.com"));
        assert!(pat.matches("API.EXAMPLE.COM"));
        assert!(!pat.matches("other.example.com"));
        assert!(!pat.matches("example.com"));
    }

    #[test]
    fn wildcard_domain_match() {
        let pat = DomainPattern("*.example.com".into());
        assert!(pat.matches("api.example.com"));
        assert!(pat.matches("deep.sub.example.com"));
        assert!(pat.matches("API.Example.COM"));
        // The bare parent domain is not covered by its own wildcard.
        assert!(!pat.matches("example.com"));
        assert!(!pat.matches("example.org"));
        // Sharing the suffix text without a label boundary must not match.
        assert!(!pat.matches("notexample.com"));
    }

    #[test]
    fn blocked_takes_precedence() {
        let policy = FetchPolicy {
            allowed_domains: Some(vec![DomainPattern("*.example.com".into())]),
            blocked_domains: vec![DomainPattern("evil.example.com".into())],
            ..Default::default()
        };

        assert!(policy.check_domain("api.example.com").is_ok());
        // Pin the specific variant so a block is never reported as "not allowed".
        assert!(matches!(
            policy.check_domain("evil.example.com"),
            Err(FetchError::DomainBlocked(d)) if d == "evil.example.com"
        ));
    }

    #[test]
    fn allowlist_rejects_unlisted() {
        let policy = FetchPolicy {
            allowed_domains: Some(vec![DomainPattern("api.example.com".into())]),
            ..Default::default()
        };

        assert!(policy.check_domain("api.example.com").is_ok());
        assert!(matches!(
            policy.check_domain("other.example.com"),
            Err(FetchError::DomainNotAllowed(d)) if d == "other.example.com"
        ));
    }

    #[test]
    fn empty_allowlist_rejects_everything() {
        let policy = FetchPolicy {
            allowed_domains: Some(Vec::new()),
            ..Default::default()
        };

        assert!(matches!(
            policy.check_domain("api.example.com"),
            Err(FetchError::DomainNotAllowed(_))
        ));
    }

    #[test]
    fn no_allowlist_allows_all() {
        let policy = FetchPolicy::default();
        assert!(policy.check_domain("anything.example.com").is_ok());
    }

    #[test]
    fn scheme_validation() {
        let policy = FetchPolicy::default();
        assert!(policy.check_scheme("https").is_ok());
        assert!(policy.check_scheme("http").is_ok());
        assert!(policy.check_scheme("HTTPS").is_ok());
        assert!(matches!(
            policy.check_scheme("ftp"),
            Err(FetchError::SchemeNotAllowed(s)) if s == "ftp"
        ));
    }

    #[test]
    fn method_validation() {
        let policy = FetchPolicy::default();
        assert!(policy.check_method("GET").is_ok());
        assert!(policy.check_method("get").is_ok());
        assert!(matches!(
            policy.check_method("TRACE"),
            Err(FetchError::MethodNotAllowed(m)) if m == "TRACE"
        ));
    }
}