1use serde::{Deserialize, Serialize};
2use std::fmt;
3use std::str::FromStr;
4
/// A network endpoint described by a host string and a port number.
///
/// The textual form is `host:port`; see the `FromStr` and `Display`
/// implementations in this file. `NetworkPolicy::iptables_script` turns each
/// value into a `-d <host> -p tcp --dport <port>` ACCEPT rule.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct HostPort {
    /// Host part of the endpoint, stored verbatim (this type performs no
    /// validation of its own).
    pub host: String,
    /// Port part of the endpoint.
    pub port: u16,
}
11
12impl HostPort {
13 pub fn new(host: impl Into<String>, port: u16) -> Self {
14 Self {
15 host: host.into(),
16 port,
17 }
18 }
19}
20
21impl fmt::Display for HostPort {
22 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
23 write!(f, "{}:{}", self.host, self.port)
24 }
25}
26
27impl FromStr for HostPort {
28 type Err = anyhow::Error;
29
30 fn from_str(s: &str) -> Result<Self, Self::Err> {
31 let (host, port) = s
32 .rsplit_once(':')
33 .ok_or_else(|| anyhow::anyhow!("expected host:port, got {:?}", s))?;
34 if host.is_empty() {
35 anyhow::bail!("host cannot be empty in {:?}", s);
36 }
37 let port: u16 = port
38 .parse()
39 .map_err(|_| anyhow::anyhow!("invalid port in {:?}", s))?;
40 Ok(Self::new(host, port))
41 }
42}
43
/// Built-in, named network policies.
///
/// Serde (de)serializes each variant as its lowercase name, which matches the
/// strings accepted by the `FromStr` impl and produced by the `Display` impl
/// in this file: `unrestricted`, `none`, `registries`, `dev`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum NetworkPreset {
    /// No restrictions. `rules()` is empty and `is_unrestricted()` is true;
    /// `NetworkPolicy` generates no iptables script for this preset.
    Unrestricted,
    /// Deny-all: an empty allow list (`is_deny_all()` is true).
    None,
    /// Allows the package-registry endpoints listed in `registry_rules()`.
    Registries,
    /// Allows everything in `registry_rules()` plus `dev_extra_rules()`.
    Dev,
}
57
58impl NetworkPreset {
59 pub fn rules(&self) -> Vec<HostPort> {
61 match self {
62 Self::Unrestricted => vec![], Self::None => vec![], Self::Registries => registry_rules(),
65 Self::Dev => {
66 let mut rules = registry_rules();
67 rules.extend(dev_extra_rules());
68 rules
69 }
70 }
71 }
72
73 pub fn is_unrestricted(&self) -> bool {
75 matches!(self, Self::Unrestricted)
76 }
77
78 pub fn is_deny_all(&self) -> bool {
80 matches!(self, Self::None)
81 }
82}
83
84impl FromStr for NetworkPreset {
85 type Err = anyhow::Error;
86
87 fn from_str(s: &str) -> Result<Self, Self::Err> {
88 match s {
89 "unrestricted" => Ok(Self::Unrestricted),
90 "none" => Ok(Self::None),
91 "registries" => Ok(Self::Registries),
92 "dev" => Ok(Self::Dev),
93 _ => anyhow::bail!(
94 "unknown network preset {:?} (expected: unrestricted, none, registries, dev)",
95 s
96 ),
97 }
98 }
99}
100
101impl fmt::Display for NetworkPreset {
102 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
103 match self {
104 Self::Unrestricted => write!(f, "unrestricted"),
105 Self::None => write!(f, "none"),
106 Self::Registries => write!(f, "registries"),
107 Self::Dev => write!(f, "dev"),
108 }
109 }
110}
111
/// Outbound network policy: either a built-in preset or an explicit allow list.
///
/// Serde representation is internally tagged: a `type` field holds the
/// lowercase variant name (`preset` or `allowlist`) next to the variant's own
/// fields.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum NetworkPolicy {
    /// Use one of the built-in [`NetworkPreset`] policies.
    Preset { preset: NetworkPreset },
    /// Allow only the listed endpoints.
    AllowList { rules: Vec<HostPort> },
}
121
122impl NetworkPolicy {
123 pub fn unrestricted() -> Self {
124 Self::Preset {
125 preset: NetworkPreset::Unrestricted,
126 }
127 }
128
129 pub fn deny_all() -> Self {
130 Self::Preset {
131 preset: NetworkPreset::None,
132 }
133 }
134
135 pub fn preset(preset: NetworkPreset) -> Self {
136 Self::Preset { preset }
137 }
138
139 pub fn allow_list(rules: Vec<HostPort>) -> Self {
140 Self::AllowList { rules }
141 }
142
143 pub fn is_unrestricted(&self) -> bool {
145 matches!(
146 self,
147 Self::Preset {
148 preset: NetworkPreset::Unrestricted
149 }
150 )
151 }
152
153 pub fn resolve_rules(&self) -> Option<Vec<HostPort>> {
156 match self {
157 Self::Preset { preset } if preset.is_unrestricted() => None,
158 Self::Preset { preset } => Some(preset.rules()),
159 Self::AllowList { rules } => Some(rules.clone()),
160 }
161 }
162
163 pub fn iptables_script(&self, bridge_dev: &str, guest_ip: &str) -> Option<String> {
169 let rules = self.resolve_rules()?;
170
171 let mut script = String::new();
172 script.push_str(&format!(
173 "# Network policy: drop all outbound from {} except allowed hosts\n",
174 guest_ip
175 ));
176
177 script.push_str(&format!(
179 "sudo iptables -I FORWARD -i {br} -s {ip} -j DROP\n",
180 br = bridge_dev,
181 ip = guest_ip,
182 ));
183
184 script.push_str(&format!(
186 "sudo iptables -I FORWARD -i {br} -s {ip} -m state --state ESTABLISHED,RELATED -j ACCEPT\n",
187 br = bridge_dev,
188 ip = guest_ip,
189 ));
190
191 script.push_str(&format!(
193 "sudo iptables -I FORWARD -i {br} -s {ip} -p udp --dport 53 -j ACCEPT\n",
194 br = bridge_dev,
195 ip = guest_ip,
196 ));
197 script.push_str(&format!(
198 "sudo iptables -I FORWARD -i {br} -s {ip} -p tcp --dport 53 -j ACCEPT\n",
199 br = bridge_dev,
200 ip = guest_ip,
201 ));
202
203 for rule in &rules {
205 script.push_str(&format!(
206 "sudo iptables -I FORWARD -i {br} -s {ip} -d {host} -p tcp --dport {port} -j ACCEPT\n",
207 br = bridge_dev,
208 ip = guest_ip,
209 host = rule.host,
210 port = rule.port,
211 ));
212 }
213
214 Some(script)
215 }
216
217 pub fn iptables_cleanup_script(&self, bridge_dev: &str, guest_ip: &str) -> Option<String> {
220 if self.is_unrestricted() {
221 return None;
222 }
223
224 Some(format!(
225 "# Clean up network policy rules for {ip}\n\
226 while sudo iptables -D FORWARD -i {br} -s {ip} -j DROP 2>/dev/null; do :; done\n\
227 while sudo iptables -D FORWARD -i {br} -s {ip} -m state --state ESTABLISHED,RELATED -j ACCEPT 2>/dev/null; do :; done\n\
228 while sudo iptables -D FORWARD -i {br} -s {ip} -p udp --dport 53 -j ACCEPT 2>/dev/null; do :; done\n\
229 while sudo iptables -D FORWARD -i {br} -s {ip} -p tcp --dport 53 -j ACCEPT 2>/dev/null; do :; done\n",
230 br = bridge_dev,
231 ip = guest_ip,
232 ))
233 }
234}
235
236impl Default for NetworkPolicy {
237 fn default() -> Self {
238 Self::unrestricted()
239 }
240}
241
242fn registry_rules() -> Vec<HostPort> {
243 vec![
244 HostPort::new("registry.npmjs.org", 443),
245 HostPort::new("crates.io", 443),
246 HostPort::new("static.crates.io", 443),
247 HostPort::new("index.crates.io", 443),
248 HostPort::new("pypi.org", 443),
249 HostPort::new("files.pythonhosted.org", 443),
250 ]
251}
252
253fn dev_extra_rules() -> Vec<HostPort> {
254 vec![
255 HostPort::new("github.com", 443),
256 HostPort::new("github.com", 22),
257 HostPort::new("api.github.com", 443),
258 HostPort::new("api.openai.com", 443),
259 HostPort::new("api.anthropic.com", 443),
260 ]
261}
262
/// Unit tests for endpoint parsing, preset resolution, serde round-trips and
/// the generated iptables scripts.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- HostPort: parsing, display, serde ----

    #[test]
    fn host_port_parse() {
        let hp: HostPort = "github.com:443".parse().unwrap();
        assert_eq!(hp.host, "github.com");
        assert_eq!(hp.port, 443);
    }

    #[test]
    fn host_port_parse_missing_port() {
        assert!("github.com".parse::<HostPort>().is_err());
    }

    #[test]
    fn host_port_parse_empty_host() {
        assert!(":443".parse::<HostPort>().is_err());
    }

    #[test]
    fn host_port_parse_invalid_port() {
        assert!("github.com:abc".parse::<HostPort>().is_err());
    }

    #[test]
    fn host_port_display() {
        let hp = HostPort::new("github.com", 443);
        assert_eq!(hp.to_string(), "github.com:443");
    }

    #[test]
    fn host_port_serde_roundtrip() {
        let hp = HostPort::new("api.openai.com", 443);
        let json = serde_json::to_string(&hp).unwrap();
        let parsed: HostPort = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, hp);
    }

    // ---- NetworkPreset: parsing, display, rules ----

    #[test]
    fn preset_parse() {
        assert_eq!("dev".parse::<NetworkPreset>().unwrap(), NetworkPreset::Dev);
        assert_eq!(
            "none".parse::<NetworkPreset>().unwrap(),
            NetworkPreset::None
        );
        assert_eq!(
            "registries".parse::<NetworkPreset>().unwrap(),
            NetworkPreset::Registries
        );
        assert_eq!(
            "unrestricted".parse::<NetworkPreset>().unwrap(),
            NetworkPreset::Unrestricted
        );
    }

    #[test]
    fn preset_parse_invalid() {
        assert!("foo".parse::<NetworkPreset>().is_err());
    }

    #[test]
    fn preset_display_roundtrip() {
        for preset in [
            NetworkPreset::Unrestricted,
            NetworkPreset::None,
            NetworkPreset::Registries,
            NetworkPreset::Dev,
        ] {
            let s = preset.to_string();
            let parsed: NetworkPreset = s.parse().unwrap();
            assert_eq!(parsed, preset);
        }
    }

    #[test]
    fn preset_rules_dev_includes_registries() {
        let dev_rules = NetworkPreset::Dev.rules();
        let reg_rules = NetworkPreset::Registries.rules();
        for reg in &reg_rules {
            assert!(
                dev_rules.contains(reg),
                "dev preset should include registry rule {}",
                reg
            );
        }
    }

    #[test]
    fn preset_rules_dev_has_github_and_ai() {
        let rules = NetworkPreset::Dev.rules();
        let hosts: Vec<&str> = rules.iter().map(|r| r.host.as_str()).collect();
        assert!(hosts.contains(&"github.com"));
        assert!(hosts.contains(&"api.openai.com"));
        assert!(hosts.contains(&"api.anthropic.com"));
    }

    #[test]
    fn preset_rules_none_is_empty() {
        assert!(NetworkPreset::None.rules().is_empty());
    }

    #[test]
    fn preset_rules_unrestricted_is_empty() {
        assert!(NetworkPreset::Unrestricted.rules().is_empty());
    }

    // ---- NetworkPolicy: construction, resolution, serde ----

    #[test]
    fn policy_default_is_unrestricted() {
        assert!(NetworkPolicy::default().is_unrestricted());
    }

    #[test]
    fn policy_unrestricted_no_rules() {
        assert!(NetworkPolicy::unrestricted().resolve_rules().is_none());
    }

    #[test]
    fn policy_deny_all_empty_rules() {
        let rules = NetworkPolicy::deny_all().resolve_rules().unwrap();
        assert!(rules.is_empty());
    }

    #[test]
    fn policy_preset_dev_resolves() {
        let policy = NetworkPolicy::preset(NetworkPreset::Dev);
        let rules = policy.resolve_rules().unwrap();
        assert!(!rules.is_empty());
        assert!(rules.iter().any(|r| r.host == "github.com"));
    }

    #[test]
    fn policy_allow_list_resolves() {
        let policy = NetworkPolicy::allow_list(vec![
            HostPort::new("example.com", 443),
            HostPort::new("example.com", 80),
        ]);
        let rules = policy.resolve_rules().unwrap();
        assert_eq!(rules.len(), 2);
    }

    #[test]
    fn policy_serde_roundtrip_preset() {
        let policy = NetworkPolicy::preset(NetworkPreset::Dev);
        let json = serde_json::to_string(&policy).unwrap();
        let parsed: NetworkPolicy = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, policy);
    }

    #[test]
    fn policy_serde_roundtrip_allow_list() {
        let policy = NetworkPolicy::allow_list(vec![HostPort::new("example.com", 443)]);
        let json = serde_json::to_string(&policy).unwrap();
        let parsed: NetworkPolicy = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, policy);
    }

    // ---- iptables script generation ----

    #[test]
    fn iptables_script_unrestricted_is_none() {
        let policy = NetworkPolicy::unrestricted();
        assert!(policy.iptables_script("br-mvm", "172.16.0.2").is_none());
    }

    #[test]
    fn iptables_script_deny_all_has_drop_no_host_rules() {
        let policy = NetworkPolicy::deny_all();
        let script = policy.iptables_script("br-mvm", "172.16.0.2").unwrap();
        assert!(script.contains("-j DROP"));
        assert!(script.contains("--dport 53"));

        // Apart from DNS and established-connection rules there must be no
        // ACCEPT lines at all.
        let accept_lines: Vec<&str> = script
            .lines()
            .filter(|l| {
                l.contains("-j ACCEPT") && !l.contains("--dport 53") && !l.contains("ESTABLISHED")
            })
            .collect();
        assert!(
            accept_lines.is_empty(),
            "deny-all should have no host ACCEPT rules"
        );
    }

    #[test]
    fn iptables_script_allow_list_has_host_rules() {
        let policy = NetworkPolicy::allow_list(vec![
            HostPort::new("github.com", 443),
            HostPort::new("api.openai.com", 443),
        ]);
        let script = policy.iptables_script("br-mvm", "172.16.0.3").unwrap();
        assert!(script.contains("-d github.com"));
        assert!(script.contains("-d api.openai.com"));
        assert!(script.contains("--dport 443"));
        assert!(script.contains("-s 172.16.0.3"));
        assert!(script.contains("-i br-mvm"));
    }

    #[test]
    fn iptables_cleanup_unrestricted_is_none() {
        let policy = NetworkPolicy::unrestricted();
        assert!(
            policy
                .iptables_cleanup_script("br-mvm", "172.16.0.2")
                .is_none()
        );
    }

    #[test]
    fn iptables_cleanup_deny_all_has_commands() {
        let policy = NetworkPolicy::deny_all();
        let script = policy
            .iptables_cleanup_script("br-mvm", "172.16.0.2")
            .unwrap();
        assert!(script.contains("iptables -D FORWARD"));
    }
}
478}