Skip to main content

koi_common/
firewall.rs

1/// Firewall port metadata reported by capability modules.
2#[derive(Debug, Clone, PartialEq, Eq)]
3pub struct FirewallPort {
4    pub name: String,
5    pub protocol: FirewallProtocol,
6    pub port: u16,
7}
8
9impl FirewallPort {
10    pub fn new(name: impl Into<String>, protocol: FirewallProtocol, port: u16) -> Self {
11        Self {
12            name: name.into(),
13            protocol,
14            port,
15        }
16    }
17}
18
/// Transport protocol a firewall port entry applies to.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum FirewallProtocol {
    /// TCP; rendered as `"TCP"` by `as_str`.
    Tcp,
    /// UDP; rendered as `"UDP"` by `as_str`.
    Udp,
}
24
25impl FirewallProtocol {
26    pub fn as_str(&self) -> &'static str {
27        match self {
28            FirewallProtocol::Tcp => "TCP",
29            FirewallProtocol::Udp => "UDP",
30        }
31    }
32}
33
34/// Format a firewall rule name: `"{prefix} {port.name} ({PROTO} {port})"`.
35pub fn firewall_rule_name(prefix: &str, port: &FirewallPort) -> String {
36    format!(
37        "{} {} ({} {})",
38        prefix,
39        port.name,
40        port.protocol.as_str(),
41        port.port
42    )
43}
44
/// Makes a best-effort attempt to give every listed port an inbound-allow
/// rule in Windows Firewall by invoking `netsh advfirewall firewall`.
///
/// Rules are **port-based**: only the protocol and local port are passed to
/// `netsh`, never a program path, so they apply whichever executable is
/// running.
///
/// * Entries repeating an earlier `(protocol, port)` pair are skipped; the
///   first entry's name is the one used for the rule.
/// * Idempotent – each rule is deleted by name, then re-added.
/// * Non-fatal  – failures are logged as warnings; nothing is returned as an
///   error.
///
/// Returns the number of rules whose `add rule` command exited successfully.
#[cfg(windows)]
pub fn ensure_firewall_rules(prefix: &str, ports: &[FirewallPort]) -> usize {
    use std::collections::HashSet;
    use std::process::{Command, Stdio};

    /// Starts `netsh advfirewall firewall <verb> rule name=<rule>` with both
    /// output streams discarded; the caller appends further arguments.
    fn netsh_rule(verb: &str, rule: &str) -> Command {
        let mut cmd = Command::new("netsh");
        cmd.args(["advfirewall", "firewall", verb, "rule"])
            .arg(format!("name={rule}"))
            .stdout(Stdio::null())
            .stderr(Stdio::null());
        cmd
    }

    let mut handled = HashSet::new();
    let mut created = 0usize;

    for entry in ports {
        // `insert` returns false when this (protocol, port) pair was already
        // processed, which is how duplicates are skipped.
        if !handled.insert((entry.protocol, entry.port)) {
            continue;
        }

        let rule_name = firewall_rule_name(prefix, entry);

        // Drop any existing rule of the same name first. The result is
        // ignored on purpose: the rule may simply not exist yet.
        let _ = netsh_rule("delete", &rule_name).status();

        let outcome = netsh_rule("add", &rule_name)
            .args(["dir=in", "action=allow"])
            .arg(format!("protocol={}", entry.protocol.as_str()))
            .arg(format!("localport={}", entry.port))
            .status();

        match outcome {
            // The process could not be started at all.
            Err(err) => {
                tracing::warn!(
                    rule = %rule_name,
                    error = %err,
                    "Failed to run netsh"
                );
            }
            // netsh ran but reported failure.
            Ok(status) if !status.success() => {
                tracing::warn!(
                    rule = %rule_name,
                    exit_code = ?status.code(),
                    "Could not create firewall rule (not elevated?)"
                );
            }
            Ok(_) => {
                tracing::info!(
                    rule = %rule_name,
                    "Firewall rule ensured"
                );
                created += 1;
            }
        }
    }

    created
}
116
117/// No-op on non-Windows platforms – always returns 0.
118#[cfg(not(windows))]
119pub fn ensure_firewall_rules(_prefix: &str, _ports: &[FirewallPort]) -> usize {
120    0
121}