Skip to main content

agent_fetch/
policy.rs

1use serde::{Deserialize, Serialize};
2
/// Pattern for matching domains — either exact or wildcard (e.g. `*.example.com`).
///
/// The wrapped string is the raw pattern text. A pattern that starts with `*.` is a
/// wildcard: it matches any host that ends with `.` followed by the rest of the pattern,
/// so `*.example.com` covers `api.example.com` and `deep.sub.example.com` but not the
/// bare `example.com`. Any other pattern must equal the host. Comparison is
/// case-insensitive; see [`DomainPattern::matches`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DomainPattern(pub String);
6
7impl DomainPattern {
8    pub fn matches(&self, domain: &str) -> bool {
9        let pattern = self.0.to_lowercase();
10        let domain = domain.to_lowercase();
11
12        if let Some(suffix) = pattern.strip_prefix("*.") {
13            domain.ends_with(&format!(".{suffix}"))
14        } else {
15            domain == pattern
16        }
17    }
18}
19
/// Controls every aspect of what the safe HTTP client is allowed to do.
///
/// The [`Default`] implementation supplies the defaults quoted on each field. This type
/// itself only implements the domain, scheme and method checks (`check_domain`,
/// `check_scheme`, `check_method`); the remaining fields are plain configuration values
/// here and are presumably enforced by the client that consumes the policy — verify
/// against the caller.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FetchPolicy {
    /// If `Some`, only these domains may be fetched. If `None`, all public domains are allowed.
    /// Note that `Some(vec![])` therefore rejects every domain.
    pub allowed_domains: Option<Vec<DomainPattern>>,
    /// Domains that are always rejected (checked before `allowed_domains`).
    pub blocked_domains: Vec<DomainPattern>,
    /// Block requests that resolve to private/internal IPs (default: true).
    pub deny_private_ips: bool,
    /// Allowed HTTP methods, compared ASCII-case-insensitively
    /// (default: GET, POST, PUT, PATCH, DELETE, HEAD, OPTIONS).
    pub allowed_methods: Vec<String>,
    /// Allowed URL schemes, compared ASCII-case-insensitively (default: ["https", "http"]).
    pub allowed_schemes: Vec<String>,
    /// Max request body size in bytes (default: 10 MB, i.e. 10 * 1024 * 1024).
    pub max_request_body_bytes: usize,
    /// Max response body size in bytes (default: 50 MB, i.e. 50 * 1024 * 1024).
    pub max_response_body_bytes: usize,
    /// TCP connect timeout in milliseconds (default: 10 000).
    pub connect_timeout_ms: u64,
    /// Overall request timeout in milliseconds (default: 30 000).
    pub request_timeout_ms: u64,
    /// Maximum number of redirects to follow (default: 10).
    pub max_redirects: u8,
    /// Maximum number of concurrent in-flight requests (default: 50).
    pub max_concurrent_requests: usize,
    /// Maximum requests per minute globally (default: 500).
    pub max_requests_per_minute: u32,
}
48
49impl Default for FetchPolicy {
50    fn default() -> Self {
51        Self {
52            allowed_domains: None,
53            blocked_domains: Vec::new(),
54            deny_private_ips: true,
55            allowed_methods: vec![
56                "GET".into(),
57                "POST".into(),
58                "PUT".into(),
59                "PATCH".into(),
60                "DELETE".into(),
61                "HEAD".into(),
62                "OPTIONS".into(),
63            ],
64            allowed_schemes: vec!["https".into(), "http".into()],
65            max_request_body_bytes: 10 * 1024 * 1024,
66            max_response_body_bytes: 50 * 1024 * 1024,
67            connect_timeout_ms: 10_000,
68            request_timeout_ms: 30_000,
69            max_redirects: 10,
70            max_concurrent_requests: 50,
71            max_requests_per_minute: 500,
72        }
73    }
74}
75
76impl FetchPolicy {
77    /// Check domain against blocked list, then allowed list.
78    pub fn check_domain(&self, domain: &str) -> Result<(), crate::error::FetchError> {
79        for pat in &self.blocked_domains {
80            if pat.matches(domain) {
81                return Err(crate::error::FetchError::DomainBlocked(domain.to_string()));
82            }
83        }
84        if let Some(ref allowed) = self.allowed_domains {
85            if !allowed.iter().any(|pat| pat.matches(domain)) {
86                return Err(crate::error::FetchError::DomainNotAllowed(
87                    domain.to_string(),
88                ));
89            }
90        }
91        Ok(())
92    }
93
94    pub fn check_scheme(&self, scheme: &str) -> Result<(), crate::error::FetchError> {
95        if !self
96            .allowed_schemes
97            .iter()
98            .any(|s| s.eq_ignore_ascii_case(scheme))
99        {
100            return Err(crate::error::FetchError::SchemeNotAllowed(
101                scheme.to_string(),
102            ));
103        }
104        Ok(())
105    }
106
107    pub fn check_method(&self, method: &str) -> Result<(), crate::error::FetchError> {
108        if !self
109            .allowed_methods
110            .iter()
111            .any(|m| m.eq_ignore_ascii_case(method))
112        {
113            return Err(crate::error::FetchError::MethodNotAllowed(
114                method.to_string(),
115            ));
116        }
117        Ok(())
118    }
119}
120
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a pattern from a string literal.
    fn pattern(text: &str) -> DomainPattern {
        DomainPattern(text.into())
    }

    #[test]
    fn exact_domain_match() {
        let exact = pattern("api.example.com");

        // Same host, regardless of case.
        for host in ["api.example.com", "API.EXAMPLE.COM"] {
            assert!(exact.matches(host), "expected match for {host}");
        }
        // Sibling and parent hosts are different names.
        for host in ["other.example.com", "example.com"] {
            assert!(!exact.matches(host), "expected no match for {host}");
        }
    }

    #[test]
    fn wildcard_domain_match() {
        let wildcard = pattern("*.example.com");

        for host in ["api.example.com", "deep.sub.example.com"] {
            assert!(wildcard.matches(host), "expected match for {host}");
        }
        // The base domain does NOT match its own wildcard, and neither do look-alikes.
        for host in ["example.com", "example.org", "notexample.com"] {
            assert!(!wildcard.matches(host), "expected no match for {host}");
        }
    }

    #[test]
    fn blocked_takes_precedence() {
        let policy = FetchPolicy {
            allowed_domains: Some(vec![pattern("*.example.com")]),
            blocked_domains: vec![pattern("evil.example.com")],
            ..FetchPolicy::default()
        };

        assert!(policy.check_domain("api.example.com").is_ok());
        // Covered by the allow-list wildcard, but the block-list wins.
        assert!(policy.check_domain("evil.example.com").is_err());
    }

    #[test]
    fn allowlist_rejects_unlisted() {
        let policy = FetchPolicy {
            allowed_domains: Some(vec![pattern("api.example.com")]),
            ..FetchPolicy::default()
        };

        assert!(policy.check_domain("api.example.com").is_ok());
        assert!(policy.check_domain("other.example.com").is_err());
    }

    #[test]
    fn no_allowlist_allows_all() {
        let open_policy = FetchPolicy::default();
        assert!(open_policy.check_domain("anything.example.com").is_ok());
    }

    #[test]
    fn scheme_validation() {
        let defaults = FetchPolicy::default();

        for scheme in ["https", "http"] {
            assert!(defaults.check_scheme(scheme).is_ok(), "{scheme} should pass");
        }
        assert!(defaults.check_scheme("ftp").is_err());
    }

    #[test]
    fn method_validation() {
        let defaults = FetchPolicy::default();

        // Method comparison ignores ASCII case.
        for method in ["GET", "get"] {
            assert!(defaults.check_method(method).is_ok(), "{method} should pass");
        }
        assert!(defaults.check_method("TRACE").is_err());
    }
}