use std::collections::HashSet;
use std::net::{IpAddr, Ipv4Addr};
use ipconfig::get_adapters;
use ipnet::IpNet;
use scopeguard::guard;
use windows::Win32::NetworkManagement::WindowsFirewall::INetFwRule;
use windows::Win32::System::Com::{COINIT_APARTMENTTHREADED, CoInitializeEx, CoUninitialize};
use windows_firewall::{
Address, AddressKeyword, AddressRange, Direction, PortKeyword, PortRange, Protocol,
};
use windows_firewall::{FirewallRuleUpdate, Port};
use windows_firewall::{get_rule, list_rules};
use helpers::build::{
build_full_rule_for_protocol, build_icmp_full_rule, build_rule_for_interface,
build_tcp_full_rule,
};
use helpers::constants::RULE_NAME;
use crate::helpers::auto_remove_firewall_rule::AutoRemoveFirewallRule;
use crate::helpers::build::{build_base_rule, build_rule_for_address, build_rule_for_port};
use crate::helpers::utils::assert_firewall_rule_eq;
mod helpers;
/// Converts every rule returned by `list_rules` into an `INetFwRule` and checks
/// that no rule is lost along the way.
///
/// The conversions are collected eagerly: a bare `map` adapter is lazy, and
/// asking it for its `len()` only reads the size hint without ever invoking
/// the conversion closure, which would make this test pass vacuously.
#[test]
fn test_firewall_rules_conversion() {
    let firewall_rules = list_rules().expect("Failed to retrieve firewall rules");

    // SAFETY: COM is initialised on the current thread and the matching
    // `CoUninitialize` call is registered right below via the scope guard.
    unsafe {
        CoInitializeEx(None, COINIT_APARTMENTTHREADED).unwrap();
    }
    // SAFETY: only reached after `CoInitializeEx` succeeded above, so the
    // init/uninit calls stay balanced.
    let _com_cleanup = guard((), |()| unsafe { CoUninitialize() });

    // Declared after the guard so the converted rules are dropped before
    // `CoUninitialize` runs (locals drop in reverse declaration order).
    let inetfw_rules: Vec<INetFwRule> = firewall_rules
        .iter()
        .map(|rule| INetFwRule::try_from(rule).expect("Failed to convert to INetFwRule"))
        .collect();

    assert_eq!(
        firewall_rules.len(),
        inetfw_rules.len(),
        "Conversion changed the number of rules!"
    );
}
/// Calls `add_if_not_exists` twice with the same rule and expects only the
/// first call to report `added_or_changed`.
#[test]
fn test_add_rule_if_not_exists() {
    let name = format!("{RULE_NAME}_add_if_not_exists");
    let tcp_rule = build_tcp_full_rule(&name);

    // First attempt: the rule is expected to be reported as added.
    let first_attempt = AutoRemoveFirewallRule::add_if_not_exists(&tcp_rule).unwrap();
    assert!(first_attempt.added_or_changed);

    // Second attempt with the identical rule: nothing should be reported.
    let second_attempt = AutoRemoveFirewallRule::add_if_not_exists(&tcp_rule).unwrap();
    assert!(!second_attempt.added_or_changed);
}
/// Exercises `add_or_update` twice under the same rule name: once with a TCP
/// rule and once with an ICMP rule, then verifies the stored rule matches the
/// second set of settings.
#[test]
fn test_add_or_update() {
    let name = format!("{RULE_NAME}_add_or_update");

    // Initial call: the result is expected to flag the rule as added.
    let initial_rule = build_tcp_full_rule(&name);
    let add_outcome = AutoRemoveFirewallRule::add_or_update(&initial_rule)
        .expect("Failed to add or update full parameter firewall rule");
    assert!(add_outcome.added_or_changed, "Rule should be added");

    // Second call with different settings for the same name: the flag is
    // expected to be false this time.
    let updated_settings = build_icmp_full_rule(&name);
    let update_outcome = AutoRemoveFirewallRule::add_or_update(&updated_settings)
        .expect("Failed to add or update full parameter firewall rule");
    assert!(!update_outcome.added_or_changed, "Rule should be updated");

    // The rule read back from the firewall must equal the updated settings.
    let stored_rule = get_rule(&name).expect("Failed to get updated firewall rule");
    assert_firewall_rule_eq(&stored_rule, &updated_settings);
}
/// Adds a TCP rule, then toggles it off and back on, checking `enabled()`
/// after each toggle.
#[test]
fn test_enable_rule() {
    let name = format!("{RULE_NAME}_enable_rule");
    let mut tcp_rule = build_tcp_full_rule(&name);

    // Keep the helper guard alive until the end of the test.
    let _cleanup = AutoRemoveFirewallRule::add(&tcp_rule).unwrap();

    // Disable, then verify.
    tcp_rule.enable(false).unwrap();
    assert!(!tcp_rule.enabled());

    // Re-enable, then verify.
    tcp_rule.enable(true).unwrap();
    assert!(tcp_rule.enabled());
}
/// For every ordered pair of protocols (including a protocol paired with
/// itself), adds a rule using the first protocol, updates it to the second,
/// and checks the rule read back from the firewall after each step.
#[test]
fn test_all_protocol_transitions() {
    let protocols = [
        (Protocol::Tcp, "Tcp"),
        (Protocol::Udp, "Udp"),
        (Protocol::Icmpv4, "Icmpv4"),
        (Protocol::Icmpv6, "Icmpv6"),
        (Protocol::Igmp, "Igmp"),
        (Protocol::Ipv4, "Ipv4"),
        (Protocol::Ipv6, "Ipv6"),
        (Protocol::Gre, "Gre"),
        (Protocol::Esp, "Esp"),
        (Protocol::Ah, "Ah"),
        (Protocol::Sctp, "Sctp"),
        (Protocol::Any, "Any"),
    ];

    for (source_proto, source_label) in &protocols {
        for (target_proto, target_label) in &protocols {
            let rule_name = format!("{RULE_NAME}_transition_{source_label}_to_{target_label}");
            let mut rule = build_full_rule_for_protocol(&rule_name, *source_proto);

            // The guard must stay bound until the end of this iteration.
            let _guard = match AutoRemoveFirewallRule::add(&rule) {
                Ok(guard) => guard,
                Err(e) => panic!("Failed to add rule with protocol {:?}: {}", source_proto, e),
            };

            // Step 1: the freshly added rule round-trips unchanged.
            let stored_initial = get_rule(&rule_name).unwrap();
            assert_firewall_rule_eq(&stored_initial, &rule);

            // Step 2: switch the rule over to the target protocol's settings.
            let target_rule = build_full_rule_for_protocol(&rule_name, *target_proto);
            let new_settings = FirewallRuleUpdate::from(target_rule);
            if let Err(e) = rule.update(&new_settings) {
                panic!("Failed to update rule to protocol {:?}: {}", target_proto, e);
            }

            // Step 3: the stored rule is compared against the local `rule`
            // after the `update` call.
            let stored_updated = get_rule(&rule_name).unwrap();
            assert_firewall_rule_eq(&stored_updated, &rule);
        }
    }
}
/// Adds one rule per network adapter reported by `get_adapters`, bound to that
/// adapter's friendly name, and checks the rule read back from the firewall.
#[test]
fn test_add_rule_per_network_interface() {
    let adapters = get_adapters().expect("Failed to retrieve network interfaces");

    for adapter in adapters {
        let interface_name = adapter.friendly_name();
        let rule_name = format!("{RULE_NAME}_add_{interface_name}");
        let expected_rule = build_rule_for_interface(&rule_name, interface_name);

        // Bind the guard so it lives until the end of the iteration.
        let _guard = match AutoRemoveFirewallRule::add(&expected_rule) {
            Ok(guard) => guard,
            Err(e) => panic!(
                "Failed to add rule for interface '{}': {}",
                interface_name, e
            ),
        };

        let stored_rule = get_rule(&rule_name).expect("Failed to retrieve the rule");
        assert_firewall_rule_eq(&stored_rule, &expected_rule);
    }
}
/// For each network adapter reported by `get_adapters`, adds a base rule,
/// updates it so that it is bound to that adapter, and checks that the rule
/// read back from the firewall carries the interface.
#[test]
fn test_update_rule_per_network_interface() {
    let adapters = get_adapters().expect("Failed to retrieve network interfaces");

    for adapter in adapters {
        let interface_name = adapter.friendly_name();
        let rule_name = format!("{RULE_NAME}_update_{interface_name}");
        let mut rule = build_base_rule(&rule_name);

        let _guard = AutoRemoveFirewallRule::add(&rule);
        if let Err(e) = &_guard {
            panic!(
                "Failed to add rule for interface '{}': {}",
                interface_name, e
            );
        }

        let updated_settings = FirewallRuleUpdate::builder()
            .interfaces([interface_name])
            .build();

        println!("Updating rule for interface: {interface_name}");

        // `update` only borrows the settings, so they are passed by reference
        // directly instead of cloning them into a temporary first.
        let update_result = rule.update(&updated_settings);
        if let Err(e) = &update_result {
            panic!(
                "Failed to update rule for interface '{}': {}",
                interface_name, e
            );
        }

        let updated_rule = get_rule(&rule_name).expect("Failed to get updated firewall rule");

        // Set the expected interface on the local rule before comparing.
        rule.set_interfaces(Some(HashSet::from([interface_name.to_string()])));
        assert_firewall_rule_eq(&updated_rule, &rule);
    }
}
/// Walks every ordered pair of (direction, edge traversal) states: adds a TCP
/// rule in the first state, updates it to the second, and checks the rule read
/// back from the firewall after each step.
#[test]
fn test_direction_and_edge_traversal_transitions() {
    // Each entry: direction, edge-traversal flag, label used in the rule name.
    let states = [
        (Direction::In, true, "In_EdgeTrue"),
        (Direction::Out, false, "Out_EdgeFalse"),
    ];

    for (source_dir, source_edge, source_label) in &states {
        for (target_dir, target_edge, target_label) in &states {
            let rule_name = format!("{RULE_NAME}_transition_{source_label}_to_{target_label}");

            // Start from the full TCP rule and put it into the source state.
            let mut rule = build_tcp_full_rule(&rule_name);
            rule.set_direction(*source_dir);
            rule.set_edge_traversal(Some(*source_edge));

            // The guard stays bound for the remainder of the iteration.
            let _guard = match AutoRemoveFirewallRule::add(&rule) {
                Ok(guard) => guard,
                Err(e) => panic!(
                    "Failed to add rule with direction {:?} and edge traversal {:?}: {}",
                    source_dir, source_edge, e
                ),
            };

            let stored_initial = get_rule(&rule_name).unwrap();
            assert_firewall_rule_eq(&stored_initial, &rule);

            // Move the rule to the target state through an update.
            let new_settings = FirewallRuleUpdate::builder()
                .direction(*target_dir)
                .edge_traversal(*target_edge)
                .build();
            if let Err(e) = rule.update(&new_settings) {
                panic!(
                    "Failed to update rule to direction {:?} and edge traversal {:?}: {}",
                    target_dir, target_edge, e
                );
            }

            let stored_updated = get_rule(&rule_name).unwrap();

            // Set the target state on the local rule before comparing.
            rule.set_direction(*target_dir);
            rule.set_edge_traversal(Some(*target_edge));
            assert_firewall_rule_eq(&stored_updated, &rule);
        }
    }
}
/// Adds one rule per `Port` variant listed below and checks that each rule
/// read back from the firewall matches what was added.
#[test]
fn test_all_port_variants() {
    let ports = [
        Port::Any,
        Port::Keyword(PortKeyword::Rpc),
        Port::Keyword(PortKeyword::RpcEpmap),
        Port::Keyword(PortKeyword::Teredo),
        Port::Port(80),
        Port::Port(443),
        Port::Range(PortRange {
            start: 1000,
            end: 2000,
        }),
    ];

    ports.iter().enumerate().for_each(|(index, port)| {
        // The index keeps rule names unique across variants.
        let rule_name = format!("TEST_Port_RULE_{}", index);
        let expected_rule = build_rule_for_port(&rule_name, port);

        // Bound for the whole closure body so the guard outlives the checks.
        let _guard = match AutoRemoveFirewallRule::add(&expected_rule) {
            Ok(guard) => guard,
            Err(e) => panic!("Failed to add rule for port {:?}: {}", port, e),
        };

        let stored_rule = get_rule(&rule_name).expect("Failed to fetch the rule");
        assert_firewall_rule_eq(&stored_rule, &expected_rule);
    });
}
/// Adds one rule per `Address` variant listed below (keywords, a single IP, a
/// CIDR block and an explicit range) and checks that each rule read back from
/// the firewall matches what was added.
#[test]
fn test_all_address_variants() {
    // Building blocks for the CIDR and range variants.
    let subnet_base = IpAddr::V4(Ipv4Addr::new(192, 168, 0, 0));
    let range_start = IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1));
    let range_end = IpAddr::V4(Ipv4Addr::new(192, 168, 0, 255));

    let addresses = [
        Address::Any,
        Address::Keyword(AddressKeyword::DefaultGateway),
        Address::Keyword(AddressKeyword::Dhcp),
        Address::Keyword(AddressKeyword::Dns),
        Address::Keyword(AddressKeyword::Wins),
        Address::Keyword(AddressKeyword::LocalSubnet),
        Address::Ip(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))),
        Address::Cidr(IpNet::new(subnet_base, 24).unwrap()),
        Address::Range(AddressRange::new(range_start, range_end).unwrap()),
    ];

    for (index, address) in addresses.iter().enumerate() {
        // The index keeps rule names unique across variants.
        let rule_name = format!("TEST_Address_RULE_{}", index);
        let expected_rule = build_rule_for_address(&rule_name, address);

        // Keep the guard bound until the end of the iteration.
        let _guard = match AutoRemoveFirewallRule::add(&expected_rule) {
            Ok(guard) => guard,
            Err(e) => panic!("Failed to add rule for address {:?}: {}", address, e),
        };

        let stored_rule = get_rule(&rule_name).expect("Failed to fetch the rule");
        assert_firewall_rule_eq(&stored_rule, &expected_rule);
    }
}