#![cfg(target_os = "windows")]
#![allow(unsafe_code)]
use windows::core::BSTR;
use windows::Win32::Foundation::VARIANT_TRUE;
use windows::Win32::NetworkManagement::WindowsFirewall::{
INetFwPolicy2, INetFwRules, NetFwPolicy2, NetFwRule, NET_FW_ACTION_ALLOW,
NET_FW_IP_PROTOCOL_TCP, NET_FW_IP_PROTOCOL_UDP, NET_FW_PROFILE2_DOMAIN,
NET_FW_PROFILE2_PRIVATE, NET_FW_RULE_DIR_IN,
};
use windows::Win32::System::Com::{
CoCreateInstance, CoInitializeEx, CoUninitialize, CLSCTX_INPROC_SERVER,
COINIT_APARTMENTTHREADED, COINIT_DISABLE_OLE1DDE,
};
use super::{FirewallError, API_RULE_NAME, MANAGED_RULE_NAMES, OVERLAY_RULE_NAME, RAFT_RULE_NAME};
struct ComGuard;
impl ComGuard {
fn new() -> Result<Self, FirewallError> {
let hr = unsafe { CoInitializeEx(None, COINIT_APARTMENTTHREADED | COINIT_DISABLE_OLE1DDE) };
if hr.is_err() {
return Err(FirewallError::ComInit(format!("{hr:?}")));
}
Ok(Self)
}
}
impl Drop for ComGuard {
fn drop(&mut self) {
unsafe { CoUninitialize() };
}
}
/// Handle to the Windows Firewall rule collection (`INetFwPolicy2::Rules`), bundled with the
/// COM initialisation it depends on.
struct FirewallPolicy {
    // Field order matters: Rust drops struct fields in declaration order, so the `rules` COM
    // interface is released before `_com` uninitialises COM on this thread.
    rules: INetFwRules,
    // Keeps COM initialised for as long as `rules` may be used.
    _com: ComGuard,
}
impl FirewallPolicy {
    /// Initialises COM on the calling thread and binds to the rule collection returned by
    /// `INetFwPolicy2::Rules`.
    ///
    /// # Errors
    ///
    /// * [`FirewallError::ComInit`] if COM cannot be initialised.
    /// * [`FirewallError::PolicyUnavailable`] if the `NetFwPolicy2` object cannot be created.
    /// * [`FirewallError::Com`] if the rule collection cannot be retrieved.
    fn open() -> Result<Self, FirewallError> {
        let com = ComGuard::new()?;
        // SAFETY: COM is initialised on this thread while `com` lives, which covers this call.
        // No aggregating outer object is passed.
        let policy: INetFwPolicy2 = unsafe {
            CoCreateInstance::<Option<&windows::core::IUnknown>, INetFwPolicy2>(
                &NetFwPolicy2,
                None,
                CLSCTX_INPROC_SERVER,
            )
        }
        .map_err(|e| FirewallError::PolicyUnavailable(format!("{e}")))?;
        // SAFETY: `policy` is a valid interface created above on this thread.
        let rules = unsafe { policy.Rules() }
            .map_err(|e| FirewallError::Com(format!("INetFwPolicy2::Rules: {e}")))?;
        Ok(Self { rules, _com: com })
    }

    /// Returns `true` when `INetFwRules::Item` finds a rule named `name`.
    ///
    /// Every lookup error, not only "no such rule", is reported as `false`.
    fn rule_exists(&self, name: &str) -> bool {
        let key = BSTR::from(name);
        // SAFETY: `self.rules` is a valid interface, and COM stays initialised while `self`
        // (and so `self._com`) is alive. `key` outlives the call.
        unsafe { self.rules.Item(&key) }.is_ok()
    }

    /// Creates an enabled inbound allow rule from `spec` and adds it to the collection.
    ///
    /// The rule applies to the Domain and Private profiles and is grouped under `ZLayer`.
    /// This method does not check for an existing rule with the same name.
    ///
    /// # Errors
    ///
    /// [`FirewallError::AddRule`] if the rule object cannot be created, any property setter
    /// fails, or `INetFwRules::Add` fails.
    fn add_rule(&self, spec: &RuleSpec<'_>) -> Result<(), FirewallError> {
        // SAFETY: COM is initialised on this thread for the lifetime of `self`. No aggregating
        // outer object is passed.
        let rule = unsafe {
            CoCreateInstance::<Option<&windows::core::IUnknown>, _>(
                &NetFwRule,
                None,
                CLSCTX_INPROC_SERVER,
            )
        }
        .map_err(|e| FirewallError::AddRule {
            name: spec.name.to_string(),
            reason: format!("CoCreateInstance(NetFwRule): {e}"),
        })?;
        let name_bstr = BSTR::from(spec.name);
        let desc_bstr = BSTR::from(spec.description);
        let ports_bstr = BSTR::from(spec.port.to_string().as_str());
        let group_bstr = BSTR::from("ZLayer");
        // Collect every COM failure into one `windows::core::Result` so it can be mapped to a
        // single `AddRule` error below.
        let configure_and_add = || -> windows::core::Result<()> {
            use windows::Win32::NetworkManagement::WindowsFirewall::INetFwRule;
            let rule: &INetFwRule = &rule;
            // SAFETY: `rule` and `self.rules` are valid interfaces on this COM-initialised
            // thread. Every BSTR argument outlives the call that borrows it.
            unsafe {
                rule.SetName(&name_bstr)?;
                rule.SetDescription(&desc_bstr)?;
                rule.SetProtocol(spec.protocol)?;
                rule.SetLocalPorts(&ports_bstr)?;
                rule.SetDirection(NET_FW_RULE_DIR_IN)?;
                rule.SetAction(NET_FW_ACTION_ALLOW)?;
                rule.SetEnabled(VARIANT_TRUE)?;
                let profile_mask = NET_FW_PROFILE2_DOMAIN.0 | NET_FW_PROFILE2_PRIVATE.0;
                rule.SetProfiles(profile_mask)?;
                rule.SetGrouping(&group_bstr)?;
                self.rules.Add(rule)?;
            }
            Ok(())
        };
        configure_and_add().map_err(|e| FirewallError::AddRule {
            name: spec.name.to_string(),
            reason: format!("{e}"),
        })
    }

    /// Removes every rule named `name`. Succeeds without doing anything if none exists.
    ///
    /// Firewall rules may share a display name (added by hand, or by racing installs), and
    /// `INetFwRules::Remove` may delete only one match per call. This method therefore keeps
    /// removing until the lookup no longer finds the name. The number of passes is bounded so
    /// a rule that refuses to go cannot spin forever.
    ///
    /// # Errors
    ///
    /// [`FirewallError::RemoveRule`] if `INetFwRules::Remove` fails, or if a rule with this name
    /// is still present after the maximum number of removal passes.
    fn remove_rule(&self, name: &str) -> Result<(), FirewallError> {
        const MAX_REMOVE_PASSES: usize = 16;
        let key = BSTR::from(name);
        for _ in 0..MAX_REMOVE_PASSES {
            if !self.rule_exists(name) {
                return Ok(());
            }
            // SAFETY: `self.rules` is a valid interface on this COM-initialised thread. `key`
            // outlives the call.
            unsafe { self.rules.Remove(&key) }.map_err(|e| FirewallError::RemoveRule {
                name: name.to_string(),
                reason: format!("{e}"),
            })?;
        }
        if self.rule_exists(name) {
            return Err(FirewallError::RemoveRule {
                name: name.to_string(),
                reason: format!("rule still present after {MAX_REMOVE_PASSES} removal passes"),
            });
        }
        Ok(())
    }
}
/// Description of one inbound allow rule managed by this module.
struct RuleSpec<'s> {
    /// Rule name. It is also the key used to look up and remove the rule.
    name: &'s str,
    /// Human-readable text stored as the rule's description.
    description: &'s str,
    /// Raw protocol number, i.e. the `.0` of a `NET_FW_IP_PROTOCOL_*` constant.
    protocol: i32,
    /// Local port the rule allows.
    port: u16,
}
pub(super) fn ensure_overlay_rules(
wg_port: u16,
api_port: u16,
raft_port: u16,
) -> Result<(), FirewallError> {
let policy = FirewallPolicy::open()?;
let specs = [
RuleSpec {
name: OVERLAY_RULE_NAME,
description: "ZLayer encrypted overlay (WireGuard/boringtun) inbound UDP",
port: wg_port,
protocol: NET_FW_IP_PROTOCOL_UDP.0,
},
RuleSpec {
name: API_RULE_NAME,
description: "ZLayer daemon HTTP/gRPC API inbound TCP",
port: api_port,
protocol: NET_FW_IP_PROTOCOL_TCP.0,
},
RuleSpec {
name: RAFT_RULE_NAME,
description: "ZLayer Raft scheduler inbound TCP",
port: raft_port,
protocol: NET_FW_IP_PROTOCOL_TCP.0,
},
];
for spec in specs {
if policy.rule_exists(spec.name) {
tracing::debug!(rule = spec.name, "firewall rule already present; skipping");
continue;
}
let port = spec.port;
let name = spec.name;
policy.add_rule(&spec)?;
tracing::info!(rule = name, port = port, "installed firewall rule");
}
Ok(())
}
/// Removes every firewall rule in `MANAGED_RULE_NAMES`. Names with no installed rule are
/// skipped silently.
///
/// # Errors
///
/// Stops at the first [`FirewallError`] from opening the policy or removing a rule.
pub(super) fn remove_overlay_rules() -> Result<(), FirewallError> {
    let policy = FirewallPolicy::open()?;
    MANAGED_RULE_NAMES
        .iter()
        .try_for_each(|rule_name| policy.remove_rule(rule_name))
}
#[cfg(test)]
mod tests {
    use super::*;

    // Both tests change the machine's real firewall configuration, so they are opt-in
    // (`#[ignore]`) and clean up after themselves.
    const WG_PORT: u16 = 51820;
    const API_PORT: u16 = 13669;
    const RAFT_PORT: u16 = 13670;

    #[test]
    #[ignore = "requires administrator privileges + Windows Defender Firewall service"]
    fn ensure_then_remove_roundtrip() {
        ensure_overlay_rules(WG_PORT, API_PORT, RAFT_PORT).expect("ensure failed");
        {
            // Scoped so this handle, and its COM guard, is released before removal.
            let policy = FirewallPolicy::open().expect("open policy");
            assert!(policy.rule_exists(OVERLAY_RULE_NAME));
            assert!(policy.rule_exists(API_RULE_NAME));
            assert!(policy.rule_exists(RAFT_RULE_NAME));
        }
        remove_overlay_rules().expect("remove failed");
        let policy = FirewallPolicy::open().expect("reopen policy");
        assert!(!policy.rule_exists(OVERLAY_RULE_NAME));
        assert!(!policy.rule_exists(API_RULE_NAME));
        assert!(!policy.rule_exists(RAFT_RULE_NAME));
    }

    #[test]
    #[ignore = "requires administrator privileges + Windows Defender Firewall service"]
    fn ensure_is_idempotent() {
        // Best-effort reset; a failure here just means there was nothing to remove.
        let _ = remove_overlay_rules();
        ensure_overlay_rules(WG_PORT, API_PORT, RAFT_PORT).expect("first ensure");
        ensure_overlay_rules(WG_PORT, API_PORT, RAFT_PORT).expect("second ensure (idempotent)");
        {
            let policy = FirewallPolicy::open().expect("open policy");
            assert!(policy.rule_exists(OVERLAY_RULE_NAME));
            assert!(policy.rule_exists(API_RULE_NAME));
            assert!(policy.rule_exists(RAFT_RULE_NAME));
        }
        remove_overlay_rules().expect("cleanup");
    }
}