#![cfg(target_os = "windows")]
#![allow(unsafe_code)]
use windows::core::BSTR;
use windows::Win32::Foundation::VARIANT_TRUE;
use windows::Win32::NetworkManagement::WindowsFirewall::{
INetFwPolicy2, INetFwRules, NetFwPolicy2, NetFwRule, NET_FW_ACTION_ALLOW,
NET_FW_IP_PROTOCOL_TCP, NET_FW_IP_PROTOCOL_UDP, NET_FW_PROFILE2_DOMAIN,
NET_FW_PROFILE2_PRIVATE, NET_FW_RULE_DIR_IN,
};
use windows::Win32::System::Com::{
CoCreateInstance, CoInitializeEx, CoUninitialize, CLSCTX_INPROC_SERVER,
COINIT_APARTMENTTHREADED, COINIT_DISABLE_OLE1DDE,
};
use super::{FirewallError, API_RULE_NAME, MANAGED_RULE_NAMES, OVERLAY_RULE_NAME, RAFT_RULE_NAME};
/// Port opened for the overlay DNS responder (used by both DNS rules below).
const DNS_PORT: u16 = 53;
/// Firewall rule name for inbound overlay DNS over UDP.
const DNS_UDP_RULE_NAME: &str = "ZLayer DNS (UDP)";
/// Firewall rule name for inbound overlay DNS over TCP.
const DNS_TCP_RULE_NAME: &str = "ZLayer DNS (TCP)";
struct ComGuard;
impl ComGuard {
fn new() -> Result<Self, FirewallError> {
let hr = unsafe { CoInitializeEx(None, COINIT_APARTMENTTHREADED | COINIT_DISABLE_OLE1DDE) };
if hr.is_err() {
return Err(FirewallError::ComInit(format!("{hr:?}")));
}
Ok(Self)
}
}
impl Drop for ComGuard {
fn drop(&mut self) {
unsafe { CoUninitialize() };
}
}
/// An open view of the local Windows Defender Firewall rule collection.
///
/// Fields drop in declaration order, so the `rules` COM interface is released
/// before `_com` tears down this thread's COM initialisation.
struct FirewallPolicy {
    rules: INetFwRules,
    _com: ComGuard,
}
impl FirewallPolicy {
    /// Initialises COM on this thread and fetches the rule collection from
    /// `NetFwPolicy2`.
    ///
    /// # Errors
    /// - the error from `ComGuard::new` if COM cannot be initialised;
    /// - `FirewallError::PolicyUnavailable` if `NetFwPolicy2` cannot be created;
    /// - `FirewallError::Com` if `INetFwPolicy2::Rules` fails.
    fn open() -> Result<Self, FirewallError> {
        let com = ComGuard::new()?;
        // SAFETY: COM is initialised on this thread while `com` is alive, and no
        // aggregating outer object is supplied.
        let policy: INetFwPolicy2 = unsafe {
            CoCreateInstance::<Option<&windows::core::IUnknown>, INetFwPolicy2>(
                &NetFwPolicy2,
                None,
                CLSCTX_INPROC_SERVER,
            )
        }
        .map_err(|e| FirewallError::PolicyUnavailable(format!("{e}")))?;
        // SAFETY: `policy` is a live interface obtained just above.
        let rules = unsafe { policy.Rules() }
            .map_err(|e| FirewallError::Com(format!("INetFwPolicy2::Rules: {e}")))?;
        Ok(Self { rules, _com: com })
    }

    /// Returns `true` if `INetFwRules::Item` finds a rule named `name`.
    ///
    /// Any failure from `Item` is reported as `false`, not only "not found".
    fn rule_exists(&self, name: &str) -> bool {
        let key = BSTR::from(name);
        // SAFETY: `self.rules` is a live interface and `key` outlives the call.
        unsafe { self.rules.Item(&key) }.is_ok()
    }

    /// Creates an enabled, inbound, allow rule described by `spec`.
    ///
    /// The rule is scoped to the domain and private profiles and placed in the
    /// "ZLayer" group. It is then added to the policy's rule collection.
    ///
    /// # Errors
    /// `FirewallError::AddRule` if the rule object cannot be created or any
    /// property setter / `Add` call fails.
    fn add_rule(&self, spec: &RuleSpec<'_>) -> Result<(), FirewallError> {
        use windows::Win32::NetworkManagement::WindowsFirewall::INetFwRule;

        let to_error = |reason: String| FirewallError::AddRule {
            name: spec.name.to_string(),
            reason,
        };

        // SAFETY: COM is initialised on this thread (held by `self._com`), and no
        // aggregating outer object is supplied.
        let rule = unsafe {
            CoCreateInstance::<Option<&windows::core::IUnknown>, INetFwRule>(
                &NetFwRule,
                None,
                CLSCTX_INPROC_SERVER,
            )
        }
        .map_err(|e| to_error(format!("CoCreateInstance(NetFwRule): {e}")))?;

        let name_bstr = BSTR::from(spec.name);
        let desc_bstr = BSTR::from(spec.description);
        let ports_bstr = BSTR::from(spec.port.to_string().as_str());
        let group_bstr = BSTR::from("ZLayer");
        // Domain + private profiles only; the public profile is not included.
        let profile_mask = NET_FW_PROFILE2_DOMAIN.0 | NET_FW_PROFILE2_PRIVATE.0;

        let configure_and_add = || -> windows::core::Result<()> {
            // SAFETY: `rule` and `self.rules` are live interfaces, and every BSTR
            // argument outlives the call it is passed to.
            unsafe {
                rule.SetName(&name_bstr)?;
                rule.SetDescription(&desc_bstr)?;
                rule.SetProtocol(spec.protocol)?;
                rule.SetLocalPorts(&ports_bstr)?;
                rule.SetDirection(NET_FW_RULE_DIR_IN)?;
                rule.SetAction(NET_FW_ACTION_ALLOW)?;
                rule.SetEnabled(VARIANT_TRUE)?;
                rule.SetProfiles(profile_mask)?;
                rule.SetGrouping(&group_bstr)?;
                self.rules.Add(&rule)?;
            }
            Ok(())
        };
        configure_and_add().map_err(|e| to_error(format!("{e}")))
    }

    /// Removes the rule named `name`, treating an absent rule as success.
    ///
    /// # Errors
    /// `FirewallError::RemoveRule` if the rule exists but `Remove` fails.
    fn remove_rule(&self, name: &str) -> Result<(), FirewallError> {
        if !self.rule_exists(name) {
            return Ok(());
        }
        let key = BSTR::from(name);
        // SAFETY: `self.rules` is a live interface and `key` outlives the call.
        unsafe { self.rules.Remove(&key) }.map_err(|e| FirewallError::RemoveRule {
            name: name.to_string(),
            reason: format!("{e}"),
        })
    }
}
/// Parameters for one inbound allow rule handed to `FirewallPolicy::add_rule`.
struct RuleSpec<'a> {
    /// Firewall rule name; also the key used to look the rule up or remove it.
    name: &'a str,
    /// Free-text description stored on the rule.
    description: &'a str,
    /// The single local port the rule opens.
    port: u16,
    /// Raw `NET_FW_IP_PROTOCOL_*` value, e.g. `NET_FW_IP_PROTOCOL_TCP.0`.
    protocol: i32,
}
pub(super) fn ensure_overlay_rules(
wg_port: u16,
api_port: u16,
raft_port: u16,
) -> Result<(), FirewallError> {
let policy = FirewallPolicy::open()?;
let specs = [
RuleSpec {
name: OVERLAY_RULE_NAME,
description: "ZLayer encrypted overlay (WireGuard/boringtun) inbound UDP",
port: wg_port,
protocol: NET_FW_IP_PROTOCOL_UDP.0,
},
RuleSpec {
name: API_RULE_NAME,
description: "ZLayer daemon HTTP/gRPC API inbound TCP",
port: api_port,
protocol: NET_FW_IP_PROTOCOL_TCP.0,
},
RuleSpec {
name: RAFT_RULE_NAME,
description: "ZLayer Raft scheduler inbound TCP",
port: raft_port,
protocol: NET_FW_IP_PROTOCOL_TCP.0,
},
RuleSpec {
name: DNS_UDP_RULE_NAME,
description: "ZLayer overlay DNS responder inbound UDP",
port: DNS_PORT,
protocol: NET_FW_IP_PROTOCOL_UDP.0,
},
RuleSpec {
name: DNS_TCP_RULE_NAME,
description: "ZLayer overlay DNS responder inbound TCP",
port: DNS_PORT,
protocol: NET_FW_IP_PROTOCOL_TCP.0,
},
];
for spec in specs {
if policy.rule_exists(spec.name) {
tracing::debug!(rule = spec.name, "firewall rule already present; skipping");
continue;
}
let port = spec.port;
let name = spec.name;
policy.add_rule(&spec)?;
tracing::info!(rule = name, port = port, "installed firewall rule");
}
Ok(())
}
pub(super) fn remove_overlay_rules() -> Result<(), FirewallError> {
let policy = FirewallPolicy::open()?;
for name in MANAGED_RULE_NAMES {
policy.remove_rule(name)?;
}
for name in [DNS_UDP_RULE_NAME, DNS_TCP_RULE_NAME] {
policy.remove_rule(name)?;
}
Ok(())
}
/// Builds the firewall rule name for a dynamically published host port, for
/// example `"ZLayer Published 8080/TCP"`. The protocol is part of the name, so
/// TCP and UDP rules for the same port are distinct.
fn published_port_rule_name(port: u16, udp: bool) -> String {
    let transport = match udp {
        true => "UDP",
        false => "TCP",
    };
    format!("ZLayer Published {port}/{transport}")
}
/// Ensures an inbound allow rule exists for a dynamically published host port.
///
/// Does nothing if a rule with the expected name
/// (see `published_port_rule_name`) is already present.
///
/// # Errors
/// Propagates failures from opening the policy or adding the rule.
pub(super) fn ensure_published_port(port: u16, udp: bool) -> Result<(), FirewallError> {
    let policy = FirewallPolicy::open()?;
    let name = published_port_rule_name(port, udp);
    if policy.rule_exists(&name) {
        tracing::debug!(rule = %name, "published-port firewall rule already present; skipping");
        return Ok(());
    }

    let (protocol, description) = match udp {
        true => (
            NET_FW_IP_PROTOCOL_UDP.0,
            "ZLayer dynamically-published host port inbound UDP",
        ),
        false => (
            NET_FW_IP_PROTOCOL_TCP.0,
            "ZLayer dynamically-published host port inbound TCP",
        ),
    };
    policy.add_rule(&RuleSpec {
        name: &name,
        description,
        port,
        protocol,
    })?;

    tracing::info!(rule = %name, port = port, "installed published-port firewall rule");
    Ok(())
}
/// Best-effort removal of the rule for a dynamically published host port.
///
/// Never fails: failures to open the policy or to remove the rule are logged
/// at warn level. A rule that is already absent is treated as removed.
pub(super) fn remove_published_port(port: u16, udp: bool) {
    let name = published_port_rule_name(port, udp);
    let policy = match FirewallPolicy::open() {
        Ok(opened) => opened,
        Err(e) => {
            tracing::warn!(rule = %name, error = %e, "failed to open firewall policy to remove published-port rule");
            return;
        }
    };
    if let Err(e) = policy.remove_rule(&name) {
        tracing::warn!(rule = %name, error = %e, "failed to remove published-port firewall rule");
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every rule in `names` is present (`expected == true`) or
    /// absent (`expected == false`) in `policy`.
    fn assert_rules(policy: &FirewallPolicy, names: &[&str], expected: bool) {
        for &name in names {
            assert_eq!(
                policy.rule_exists(name),
                expected,
                "unexpected presence state for rule {name:?}"
            );
        }
    }

    #[test]
    #[ignore = "requires administrator privileges + Windows Defender Firewall service"]
    fn ensure_then_remove_roundtrip() {
        let every_rule = [
            OVERLAY_RULE_NAME,
            API_RULE_NAME,
            RAFT_RULE_NAME,
            DNS_UDP_RULE_NAME,
            DNS_TCP_RULE_NAME,
        ];

        ensure_overlay_rules(51820, 13669, 13670).expect("ensure failed");
        {
            let policy = FirewallPolicy::open().expect("open policy");
            assert_rules(&policy, &every_rule, true);
        }

        remove_overlay_rules().expect("remove failed");
        let policy = FirewallPolicy::open().expect("reopen policy");
        assert_rules(&policy, &every_rule, false);
    }

    #[test]
    #[ignore = "requires administrator privileges + Windows Defender Firewall service"]
    fn published_port_roundtrip() {
        let tcp_rule = published_port_rule_name(18080, false);
        let udp_rule = published_port_rule_name(18081, true);
        let never_installed = published_port_rule_name(18080, true);

        ensure_published_port(18080, false).expect("ensure tcp");
        ensure_published_port(18080, false).expect("ensure tcp idempotent");
        ensure_published_port(18081, true).expect("ensure udp");
        {
            let policy = FirewallPolicy::open().expect("open policy");
            assert_rules(&policy, &[tcp_rule.as_str(), udp_rule.as_str()], true);
            assert_rules(&policy, &[never_installed.as_str()], false);
        }

        remove_published_port(18080, false);
        remove_published_port(18081, true);
        // Removing an already-removed rule should be a silent no-op.
        remove_published_port(18080, false);

        let policy = FirewallPolicy::open().expect("reopen policy");
        assert_rules(&policy, &[tcp_rule.as_str(), udp_rule.as_str()], false);
    }

    #[test]
    #[ignore = "requires administrator privileges + Windows Defender Firewall service"]
    fn ensure_is_idempotent() {
        let _ = remove_overlay_rules();
        for attempt in ["first ensure", "second ensure (idempotent)"] {
            ensure_overlay_rules(51820, 13669, 13670).expect(attempt);
        }
        {
            let policy = FirewallPolicy::open().expect("open policy");
            assert_rules(
                &policy,
                &[OVERLAY_RULE_NAME, API_RULE_NAME, RAFT_RULE_NAME],
                true,
            );
        }
        remove_overlay_rules().expect("cleanup");
    }
}