1use std::collections::HashSet;
2use std::time::Duration;
3
/// Limits applied to outbound HTTP requests.
///
/// The URL and method allow-lists are enforced by the `validate_url` and
/// `validate_method` methods on this type; the remaining fields are plain
/// settings that this type only stores.
#[derive(Debug, Clone)]
pub struct NetworkPolicy {
    /// Master switch for network access.
    ///
    /// NOTE(review): the validation methods on this type never read this flag,
    /// so callers presumably check it themselves — confirm against the caller.
    pub enabled: bool,
    /// URL prefixes a request URL must start with in order to be accepted.
    pub allowed_url_prefixes: Vec<String>,
    /// HTTP methods that may be used. Lookups are exact string matches against
    /// the upper-cased request method, so entries should be stored in upper case.
    pub allowed_methods: HashSet<String>,
    /// Upper bound on redirects (presumably per request — enforced elsewhere).
    pub max_redirects: usize,
    /// Upper bound on the response size (presumably in bytes — enforced elsewhere).
    pub max_response_size: usize,
    /// Time limit for a request (enforced elsewhere).
    pub timeout: Duration,
}
17
18impl Default for NetworkPolicy {
19 fn default() -> Self {
20 Self {
21 enabled: false,
22 allowed_url_prefixes: Vec::new(),
23 allowed_methods: HashSet::from(["GET".to_string(), "POST".to_string()]),
24 max_redirects: 5,
25 max_response_size: 10 * 1024 * 1024, timeout: Duration::from_secs(30),
27 }
28 }
29}
30
31impl NetworkPolicy {
32 pub fn validate_url(&self, url: &str) -> Result<(), String> {
41 let parsed = url::Url::parse(url).map_err(|e| format!("invalid URL '{url}': {e}"))?;
42 let normalized = parsed.as_str();
43
44 for prefix in &self.allowed_url_prefixes {
45 let norm_prefix = url::Url::parse(prefix)
46 .map(|u| u.to_string())
47 .unwrap_or_else(|_| prefix.clone());
48 if normalized.starts_with(&norm_prefix) {
49 return Ok(());
50 }
51 }
52
53 Err(format!("URL not allowed by network policy: {normalized}"))
54 }
55
56 pub fn validate_method(&self, method: &str) -> Result<(), String> {
58 let upper = method.to_uppercase();
59 if self.allowed_methods.contains(&upper) {
60 Ok(())
61 } else {
62 Err(format!(
63 "HTTP method not allowed by network policy: {upper}"
64 ))
65 }
66 }
67}
68
#[cfg(test)]
mod tests {
    use super::*;

    /// Prefix shared by most URL tests (note the trailing slash).
    const API_PREFIX: &str = "https://api.example.com/";

    /// Returns a default policy whose allow-list holds exactly `prefix`.
    fn policy_allowing(prefix: &str) -> NetworkPolicy {
        NetworkPolicy {
            allowed_url_prefixes: vec![prefix.to_owned()],
            ..NetworkPolicy::default()
        }
    }

    #[test]
    fn default_is_disabled() {
        assert!(!NetworkPolicy::default().enabled);
    }

    #[test]
    fn default_allows_get_and_post() {
        let methods = NetworkPolicy::default().allowed_methods;
        assert!(methods.contains("GET"));
        assert!(methods.contains("POST"));
        assert!(!methods.contains("DELETE"));
    }

    #[test]
    fn validate_url_matches_prefix() {
        let policy = policy_allowing(API_PREFIX);
        for url in [
            "https://api.example.com/v1/data",
            "https://api.example.com/users?id=1",
        ] {
            assert!(policy.validate_url(url).is_ok(), "expected {url} to be allowed");
        }
    }

    #[test]
    fn validate_url_rejects_different_domain() {
        let verdict = policy_allowing(API_PREFIX).validate_url("https://api.example.com.evil.org/");
        assert!(verdict.is_err());
    }

    #[test]
    fn validate_url_rejects_different_scheme() {
        let verdict = policy_allowing(API_PREFIX).validate_url("http://api.example.com/");
        assert!(verdict.is_err());
    }

    #[test]
    fn validate_url_rejects_subdomain_without_trailing_slash() {
        // The configured prefix deliberately omits the trailing slash.
        let policy = policy_allowing("https://api.example.com");

        // A host that merely starts with the allowed host must not match.
        let lookalike = policy.validate_url("https://api.example.com.evil.com/");
        assert!(lookalike.is_err());

        // The real host is still accepted.
        let genuine = policy.validate_url("https://api.example.com/v1/data");
        assert!(genuine.is_ok());
    }

    #[test]
    fn validate_url_rejects_userinfo_attack() {
        // Here "api.example.com" is the userinfo part; the host is evil.com.
        let verdict = policy_allowing(API_PREFIX).validate_url("https://api.example.com@evil.com/");
        assert!(verdict.is_err());
    }

    #[test]
    fn validate_url_no_prefixes_rejects_all() {
        let verdict = NetworkPolicy::default().validate_url("https://example.com/");
        assert!(verdict.is_err());
    }

    #[test]
    fn validate_url_invalid_url() {
        let verdict = NetworkPolicy::default().validate_url("not a url");
        assert!(verdict.is_err());
    }

    #[test]
    fn validate_method_allowed() {
        let policy = NetworkPolicy::default();
        for method in ["GET", "get", "POST"] {
            assert!(
                policy.validate_method(method).is_ok(),
                "expected {method} to be allowed"
            );
        }
    }

    #[test]
    fn validate_method_rejected() {
        let policy = NetworkPolicy::default();
        for method in ["DELETE", "PUT"] {
            assert!(
                policy.validate_method(method).is_err(),
                "expected {method} to be rejected"
            );
        }
    }
}