use scopeguard::guard;
use std::mem::ManuallyDrop;
use tracing::warn;
use windows::Win32::NetworkManagement::WindowsFirewall::{INetFwRule, INetFwRules};
use windows::Win32::System::Ole::IEnumVARIANT;
use windows::Win32::System::Variant::VARIANT;
use windows::core::{BSTR, Interface};
use crate::errors::WindowsFirewallError;
use crate::firewall_rule::{Direction, FirewallRule};
use crate::utils::with_policy;
/// Returns the number of rules currently registered in the Windows Firewall policy.
///
/// # Errors
///
/// Propagates any error raised while obtaining the policy through `with_policy`,
/// while fetching the policy's rule collection, or while reading its count.
pub fn count_rules() -> Result<i32, WindowsFirewallError> {
    with_policy(|policy| {
        let rules: INetFwRules = unsafe { policy.Rules() }?;
        let total = unsafe { rules.Count() }?;
        Ok(total)
    })
}
pub fn list_rules() -> Result<Vec<FirewallRule>, WindowsFirewallError> {
let mut rules_list = Vec::new();
with_policy(|fw_policy| {
let fw_rules: INetFwRules = unsafe { fw_policy.Rules() }?;
let rules_count = unsafe { fw_rules.Count() }?;
let enumerator = unsafe { fw_rules._NewEnum() }?.cast::<IEnumVARIANT>()?;
let mut variants: [VARIANT; 1] = Default::default();
let mut pceltfetch: u32 = 0;
for _ in 0..rules_count {
let fetched = unsafe { enumerator.Next(&mut variants, &mut pceltfetch) };
let (true, Some(variant)) = (fetched.is_ok(), variants.first()) else {
warn!("Error while fetching rules");
continue;
};
let dispatch = unsafe { variant.Anonymous.Anonymous.Anonymous.pdispVal.clone() };
let _dispatch_cleanup = guard(dispatch.clone(), |mut d| {
unsafe { ManuallyDrop::drop(&mut d) };
});
let Some(dispatch) = dispatch.as_ref() else {
warn!("Variant does not contain a dispatch pointer");
continue;
};
let fw_rule = dispatch.cast::<INetFwRule>()?;
match fw_rule.try_into() {
Ok(rule) => rules_list.push(rule),
Err(e) => {
let fw_rule = dispatch.cast::<INetFwRule>()?;
warn!(
"Failed to convert {:?} rule into FirewallRule struct: {:?}",
unsafe { fw_rule.Name().unwrap_or_else(|_| BSTR::from("<unknown>")) },
e
);
}
}
}
Ok(rules_list)
})
}
/// Returns the firewall rules whose direction is [`Direction::In`], in enumeration order.
///
/// # Errors
///
/// Propagates any error returned by [`list_rules`].
pub fn list_incoming_rules() -> Result<Vec<FirewallRule>, WindowsFirewallError> {
    let mut rules = list_rules()?;
    rules.retain(|candidate| *candidate.direction() == Direction::In);
    Ok(rules)
}
/// Returns the firewall rules whose direction is [`Direction::Out`], in enumeration order.
///
/// # Errors
///
/// Propagates any error returned by [`list_rules`].
pub fn list_outgoing_rules() -> Result<Vec<FirewallRule>, WindowsFirewallError> {
    let mut rules = list_rules()?;
    rules.retain(|candidate| *candidate.direction() == Direction::Out);
    Ok(rules)
}