agent_sdk/web/
security.rs1use anyhow::{Context, Result, bail};
4use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, ToSocketAddrs};
5use url::Url;
6
/// Hostnames and IP literals that every [`UrlValidator`] starts out blocking.
///
/// `UrlValidator::new` copies this list into the validator, and
/// `with_blocked_hosts` only ever appends to it, so these entries cannot be
/// removed through the builder API.
const DEFAULT_BLOCKED_HOSTS: &[&str] = &[
    // Loopback name and literals. Both the bare and the bracketed IPv6
    // spellings are listed.
    "localhost",
    "127.0.0.1",
    "0.0.0.0",
    "::1",
    "[::1]",
    // Instance-metadata endpoints (the unit tests treat the first two as
    // "metadata endpoints"; `metadata.goog` is presumably an alias — verify).
    "169.254.169.254",
    "metadata.google.internal",
    "metadata.goog",
];
18
/// Policy object that decides whether a URL is acceptable to fetch.
///
/// Construct it with [`UrlValidator::new`], adjust it with the `with_*`
/// builder methods, and apply it with [`UrlValidator::validate`].
#[derive(Clone, Debug)]
pub struct UrlValidator {
    /// When `Some`, only these domains and their subdomains pass validation.
    allowed_domains: Option<Vec<String>>,
    /// Hosts that are always rejected, together with their subdomains.
    blocked_hosts: Vec<String>,
    /// When `false`, hosts that resolve to a private address are rejected.
    allow_private_ips: bool,
    /// Redirect limit exposed through `max_redirects()`; `validate` itself
    /// does not read it.
    max_redirects: usize,
    /// When `true`, `http://` URLs are rejected.
    require_https: bool,
}
49
50impl Default for UrlValidator {
51 fn default() -> Self {
52 Self::new()
53 }
54}
55
56impl UrlValidator {
57 #[must_use]
59 pub fn new() -> Self {
60 Self {
61 allowed_domains: None,
62 blocked_hosts: DEFAULT_BLOCKED_HOSTS
63 .iter()
64 .map(|&s| s.to_string())
65 .collect(),
66 allow_private_ips: false,
67 max_redirects: 3,
68 require_https: true,
69 }
70 }
71
72 #[must_use]
74 pub fn with_allowed_domains(mut self, domains: Vec<String>) -> Self {
75 self.allowed_domains = Some(domains);
76 self
77 }
78
79 #[must_use]
81 pub fn with_blocked_hosts(mut self, hosts: Vec<String>) -> Self {
82 self.blocked_hosts.extend(hosts);
83 self
84 }
85
86 #[must_use]
88 pub const fn with_allow_private_ips(mut self, allow: bool) -> Self {
89 self.allow_private_ips = allow;
90 self
91 }
92
93 #[must_use]
95 pub const fn with_max_redirects(mut self, max: usize) -> Self {
96 self.max_redirects = max;
97 self
98 }
99
100 #[must_use]
102 pub const fn with_allow_http(mut self) -> Self {
103 self.require_https = false;
104 self
105 }
106
107 #[must_use]
109 pub const fn max_redirects(&self) -> usize {
110 self.max_redirects
111 }
112
113 pub fn validate(&self, url_str: &str) -> Result<Url> {
125 let url = Url::parse(url_str).context("Invalid URL format")?;
126
127 match url.scheme() {
129 "https" => {}
130 "http" => {
131 if self.require_https {
132 bail!("HTTPS required, but HTTP URL provided");
133 }
134 }
135 scheme => bail!("Unsupported URL scheme: {scheme}"),
136 }
137
138 let host = url.host_str().context("URL must have a host")?;
140
141 if self.blocked_hosts.iter().any(|blocked| {
143 host.eq_ignore_ascii_case(blocked) || host.ends_with(&format!(".{blocked}"))
144 }) {
145 bail!("Access to host '{host}' is blocked");
146 }
147
148 if let Some(ref allowed) = self.allowed_domains {
150 let is_allowed = allowed.iter().any(|domain| {
151 host.eq_ignore_ascii_case(domain) || host.ends_with(&format!(".{domain}"))
152 });
153 if !is_allowed {
154 bail!("Host '{host}' is not in the allowed domains list");
155 }
156 }
157
158 self.validate_resolved_ip(host)?;
160
161 Ok(url)
162 }
163
164 fn validate_resolved_ip(&self, host: &str) -> Result<()> {
166 let addrs: Vec<_> = format!("{host}:80")
168 .to_socket_addrs()
169 .map(Iterator::collect)
170 .unwrap_or_default();
171
172 for addr in addrs {
173 let ip = addr.ip();
174 if !self.allow_private_ips && is_private_ip(&ip) {
175 bail!("Access to private IP address {ip} is blocked");
176 }
177 if is_loopback(&ip) {
178 bail!("Access to loopback address {ip} is blocked");
179 }
180 if is_link_local(&ip) {
181 bail!("Access to link-local address {ip} is blocked");
182 }
183 }
184
185 Ok(())
186 }
187}
188
189fn is_private_ip(ip: &IpAddr) -> bool {
191 match ip {
192 IpAddr::V4(ipv4) => is_private_ipv4(*ipv4),
193 IpAddr::V6(ipv6) => is_private_ipv6(ipv6),
194 }
195}
196
/// Returns `true` for IPv4 addresses in `10.0.0.0/8`, `172.16.0.0/12`,
/// `192.168.0.0/16` or `100.64.0.0/10`.
fn is_private_ipv4(ip: Ipv4Addr) -> bool {
    // `Ipv4Addr::is_private` covers the first three blocks; the
    // `100.64.0.0/10` block is matched separately on the leading octets.
    ip.is_private() || matches!(ip.octets(), [100, 64..=127, _, _])
}
223
/// Returns `true` for IPv6 addresses in `fc00::/7`.
const fn is_private_ipv6(ip: &Ipv6Addr) -> bool {
    // The /7 prefix is decided entirely by the top seven bits of the first
    // octet: 0xfc and 0xfd both shift down to 0x7e.
    let [first, ..] = ip.octets();
    first >> 1 == 0x7e
}
230
/// Returns `true` when `ip` is a loopback address.
///
/// Covers `127.0.0.0/8` and `::1`. An IPv4-mapped IPv6 address
/// (`::ffff:a.b.c.d`) is judged by the IPv4 address it embeds, so
/// `::ffff:127.0.0.1` is reported as loopback as well.
const fn is_loopback(ip: &IpAddr) -> bool {
    match ip {
        IpAddr::V4(ipv4) => ipv4.is_loopback(),
        IpAddr::V6(ipv6) => match ipv6.to_ipv4_mapped() {
            Some(mapped) => mapped.is_loopback(),
            None => ipv6.is_loopback(),
        },
    }
}
238
/// Returns `true` when `ip` is a link-local address.
///
/// Covers `169.254.0.0/16` and `fe80::/10`. An IPv4-mapped IPv6 address
/// (`::ffff:a.b.c.d`) is judged by the IPv4 address it embeds, so
/// `::ffff:169.254.169.254` is reported as link-local as well.
const fn is_link_local(ip: &IpAddr) -> bool {
    match ip {
        IpAddr::V4(ipv4) => ipv4.is_link_local(),
        IpAddr::V6(ipv6) => match ipv6.to_ipv4_mapped() {
            Some(mapped) => mapped.is_link_local(),
            // fe80::/10 — only the top ten bits of the first segment matter.
            None => (ipv6.segments()[0] & 0xffc0) == 0xfe80,
        },
    }
}
250
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE: cases that validate a real hostname (e.g. `example.com`) go
    // through the system resolver as part of `validate`.

    #[test]
    fn test_valid_https_url() {
        let validator = UrlValidator::new();
        assert!(validator.validate("https://example.com").is_ok());
        assert!(validator.validate("https://example.com/path").is_ok());
        assert!(validator.validate("https://sub.example.com").is_ok());
    }

    #[test]
    fn test_http_blocked_by_default() {
        let validator = UrlValidator::new();
        let result = validator.validate("http://example.com");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("HTTPS required"));
    }

    #[test]
    fn test_http_allowed_with_flag() {
        let validator = UrlValidator::new().with_allow_http();
        assert!(validator.validate("http://example.com").is_ok());
    }

    #[test]
    fn test_localhost_blocked() {
        let validator = UrlValidator::new().with_allow_http();
        assert!(validator.validate("http://localhost").is_err());
        assert!(validator.validate("http://127.0.0.1").is_err());
        assert!(validator.validate("http://[::1]").is_err());
    }

    #[test]
    fn test_metadata_endpoints_blocked() {
        let validator = UrlValidator::new().with_allow_http();
        assert!(validator.validate("http://169.254.169.254").is_err());
        assert!(
            validator
                .validate("http://metadata.google.internal")
                .is_err()
        );
    }

    #[test]
    fn test_invalid_url() {
        let validator = UrlValidator::new();
        assert!(validator.validate("not-a-url").is_err());
        assert!(validator.validate("").is_err());
        assert!(validator.validate("ftp://example.com").is_err());
    }

    #[test]
    fn test_allowed_domains() {
        let validator = UrlValidator::new().with_allowed_domains(vec!["example.com".to_string()]);

        assert!(validator.validate("https://example.com").is_ok());
        assert!(validator.validate("https://sub.example.com").is_ok());

        let result = validator.validate("https://other.com");
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("not in the allowed domains")
        );
    }

    #[test]
    fn test_blocked_hosts() {
        let validator = UrlValidator::new().with_blocked_hosts(vec!["blocked.com".to_string()]);

        let result = validator.validate("https://blocked.com");
        assert!(result.is_err());
        // Match on "is blocked" rather than "blocked": the hostname itself
        // contains "blocked", so the shorter needle would match any error
        // that merely echoes the host.
        assert!(result.unwrap_err().to_string().contains("is blocked"));

        // Subdomains of a blocked host are rejected as well.
        assert!(validator.validate("https://sub.blocked.com").is_err());
    }

    #[test]
    fn test_private_ip_literal() {
        // IP literals are classified without a DNS lookup.
        let result = UrlValidator::new().validate("https://10.0.0.1");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("private IP"));

        let permissive = UrlValidator::new().with_allow_private_ips(true);
        assert!(permissive.validate("https://10.0.0.1").is_ok());
    }

    #[test]
    fn test_is_private_ipv4() {
        // Private ranges, checked at both ends.
        assert!(is_private_ipv4(Ipv4Addr::new(10, 0, 0, 1)));
        assert!(is_private_ipv4(Ipv4Addr::new(10, 255, 255, 255)));
        assert!(is_private_ipv4(Ipv4Addr::new(172, 16, 0, 1)));
        assert!(is_private_ipv4(Ipv4Addr::new(172, 31, 255, 255)));
        assert!(is_private_ipv4(Ipv4Addr::new(192, 168, 0, 1)));
        assert!(is_private_ipv4(Ipv4Addr::new(192, 168, 255, 255)));
        assert!(is_private_ipv4(Ipv4Addr::new(100, 64, 0, 1)));
        assert!(is_private_ipv4(Ipv4Addr::new(100, 127, 255, 255)));

        // Public addresses and the neighbours just outside each range.
        assert!(!is_private_ipv4(Ipv4Addr::new(8, 8, 8, 8)));
        assert!(!is_private_ipv4(Ipv4Addr::new(1, 1, 1, 1)));
        assert!(!is_private_ipv4(Ipv4Addr::new(172, 15, 0, 1)));
        assert!(!is_private_ipv4(Ipv4Addr::new(172, 32, 0, 1)));
        assert!(!is_private_ipv4(Ipv4Addr::new(100, 63, 255, 255)));
        assert!(!is_private_ipv4(Ipv4Addr::new(100, 128, 0, 0)));
    }

    #[test]
    fn test_is_private_ipv6() {
        assert!(is_private_ipv6(&Ipv6Addr::new(0xfc00, 0, 0, 0, 0, 0, 0, 1)));
        assert!(is_private_ipv6(&Ipv6Addr::new(0xfd00, 0xec2, 0, 0, 0, 0, 0, 0x254)));

        assert!(!is_private_ipv6(&Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 1)));
        assert!(!is_private_ipv6(&Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1)));
    }

    #[test]
    fn test_is_loopback() {
        assert!(is_loopback(&IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))));
        assert!(is_loopback(&IpAddr::V6(Ipv6Addr::LOCALHOST)));

        assert!(!is_loopback(&IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))));
    }

    #[test]
    fn test_is_link_local() {
        assert!(is_link_local(&IpAddr::V4(Ipv4Addr::new(169, 254, 169, 254))));
        assert!(is_link_local(&IpAddr::V6(Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 1))));
        assert!(is_link_local(&IpAddr::V6(Ipv6Addr::new(0xfebf, 0, 0, 0, 0, 0, 0, 1))));

        assert!(!is_link_local(&IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))));
        assert!(!is_link_local(&IpAddr::V6(Ipv6Addr::new(0xfec0, 0, 0, 0, 0, 0, 0, 1))));
    }

    #[test]
    fn test_max_redirects() {
        let validator = UrlValidator::new().with_max_redirects(5);
        assert_eq!(validator.max_redirects(), 5);
    }

    #[test]
    fn test_default_validator() {
        let validator = UrlValidator::default();
        assert!(!validator.allow_private_ips);
        assert!(validator.require_https);
        assert_eq!(validator.max_redirects, 3);
    }
}