use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::path::PathBuf;
use std::sync::Arc;
use std::time::{Duration, Instant};
use crate::peers::FastDashMap;
use anyhow::{Context, Result, bail};
use arc_swap::ArcSwap;
use iroh::EndpointId;
use ray_proto::SuggestedFirewall;
use ray_proto::ipc::FirewallRuleView;
use serde::{Deserialize, Serialize};
use tokio_util::sync::CancellationToken;
pub use ray_proto::{Action, Direction, Protocol};
/// Which peer a [`FirewallRule`] applies to.
///
/// Variant names are serialized in lowercase (`any`, `identity`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum PeerFilter {
    /// Matches every peer.
    Any,
    /// Matches only the peer whose endpoint ID equals this one.
    Identity(EndpointId),
}
/// Inclusive range of transport-layer ports, `start..=end`.
///
/// The type itself does not enforce `start <= end`; [`parse_port_range`]
/// rejects inverted ranges, and an inverted range matches no port.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct PortRange {
    /// First port in the range (inclusive).
    pub start: u16,
    /// Last port in the range (inclusive).
    pub end: u16,
}
impl PortRange {
    /// Returns `true` when `port` lies within `start..=end`, both ends
    /// inclusive. An inverted range (`start > end`) contains no port.
    pub fn contains(&self, port: u16) -> bool {
        (self.start..=self.end).contains(&port)
    }
}
/// Where a rule came from. Used to find and replace groups of rules in bulk;
/// it plays no part in matching.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum RuleOrigin {
    /// Written by the local user. This is the default, and `FirewallRule`
    /// omits it when serializing.
    #[default]
    Local,
    /// Materialized from the suggestions of the named network; all rules
    /// tagged with the same name are replaced together by
    /// `SharedFirewall::replace_network_rules`.
    Network(String),
    /// The SSH pass-through rule built by `ssh_passthrough_rule` and managed
    /// by `SharedFirewall::set_ssh_passthrough`.
    Ssh,
}
impl RuleOrigin {
    /// Reports whether this is the default `Local` origin. Serde uses this to
    /// skip writing the origin of user-authored rules.
    pub fn is_local(&self) -> bool {
        match self {
            RuleOrigin::Local => true,
            RuleOrigin::Network(_) | RuleOrigin::Ssh => false,
        }
    }
}
/// Builds the inbound rule that admits SSH (TCP port 22) from any peer on any
/// network.
///
/// The rule is tagged [`RuleOrigin::Ssh`] so that
/// `SharedFirewall::set_ssh_passthrough` can find and remove it again.
pub fn ssh_passthrough_rule() -> FirewallRule {
    const SSH_PORT: u16 = 22;
    let ssh_only = PortRange {
        start: SSH_PORT,
        end: SSH_PORT,
    };
    FirewallRule {
        origin: RuleOrigin::Ssh,
        network: None,
        peer: PeerFilter::Any,
        port: Some(ssh_only),
        protocol: Protocol::Tcp,
        action: Action::Allow,
        direction: Direction::In,
    }
}
/// Returns `true` when two rules select exactly the same traffic.
///
/// Rules select the same traffic when direction, protocol, port range, peer
/// filter and network are all equal. `action` and `origin` are deliberately
/// ignored, which lets [`dedup_by_selector`] let a newer rule replace an
/// older one for the same traffic.
pub fn same_selector(a: &FirewallRule, b: &FirewallRule) -> bool {
    let selector_of_a = (&a.direction, &a.protocol, &a.port, &a.peer, &a.network);
    let selector_of_b = (&b.direction, &b.protocol, &b.port, &b.peer, &b.network);
    selector_of_a == selector_of_b
}
/// Removes rules superseded by a later rule with the same selector.
///
/// Selectors are compared with [`same_selector`]. The *last* occurrence of
/// each selector is kept, and the surviving rules keep their original
/// relative order.
pub fn dedup_by_selector(rules: Vec<FirewallRule>) -> Vec<FirewallRule> {
    let total = rules.len();
    // Walk newest-to-oldest so the first rule seen for a selector is the one
    // that survives.
    let mut survivors = rules.into_iter().rev().fold(
        Vec::with_capacity(total),
        |mut kept: Vec<FirewallRule>, candidate| {
            if kept.iter().all(|newer| !same_selector(newer, &candidate)) {
                kept.push(candidate);
            }
            kept
        },
    );
    survivors.reverse();
    survivors
}
/// A single firewall rule.
///
/// Rules are evaluated in configuration order. The first rule whose
/// direction, network, protocol, port and peer selectors all match decides
/// the packet's action.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct FirewallRule {
    /// Traffic direction this rule applies to.
    pub direction: Direction,
    /// Verdict applied when the rule matches.
    pub action: Action,
    /// IP protocol filter; `Any` matches every protocol.
    pub protocol: Protocol,
    /// Range of the packet's destination port to match. That is the local
    /// port for inbound traffic and the remote port for outbound traffic.
    /// `None` matches any port. Packets without ports (e.g. ICMP) are matched
    /// as port 0.
    pub port: Option<PortRange>,
    /// Which peer the rule applies to.
    pub peer: PeerFilter,
    /// Restricts the rule to packets evaluated with this network name.
    /// `None` matches every network. A rule with a network never matches
    /// when no network is supplied.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub network: Option<String>,
    /// Provenance of the rule; not used for matching. Omitted from
    /// serialized output when `Local`.
    #[serde(default, skip_serializing_if = "RuleOrigin::is_local")]
    pub origin: RuleOrigin,
}
/// Complete firewall configuration, as loaded from and saved to
/// `firewall.toml`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FirewallConfig {
    /// Action for inbound packets that match no rule and no tracked flow.
    /// Defaults to `Deny` when absent from the file.
    #[serde(default = "default_inbound_action")]
    pub default_inbound: Action,
    /// Action for outbound packets that match no rule. Defaults to `Allow`
    /// when absent from the file.
    #[serde(default = "default_outbound_action")]
    pub default_outbound: Action,
    /// Exposed through `SharedFirewall::reject_enabled`.
    /// NOTE(review): presumably makes the packet path actively reject denied
    /// packets instead of silently dropping them — confirm with the caller.
    #[serde(default)]
    pub reject: bool,
    /// Ordered rule list; the first match wins.
    /// This field has no serde default, so a config file must contain it.
    pub rules: Vec<FirewallRule>,
}
/// Default for `FirewallConfig::default_inbound`, used both by serde and by
/// `FirewallConfig::default`: deny unsolicited inbound traffic.
fn default_inbound_action() -> Action {
    Action::Deny
}
/// Default for `FirewallConfig::default_outbound`, used both by serde and by
/// `FirewallConfig::default`: allow outbound traffic.
fn default_outbound_action() -> Action {
    Action::Allow
}
/// Builds the inbound ICMP/ICMPv6 allow rule seeded into the default
/// configuration.
///
/// The rule admits ICMP from any peer on any network, with no port
/// restriction. It is an ordinary `Local` rule, so users can remove it or
/// shadow it with an earlier rule.
pub fn default_icmp_rule() -> FirewallRule {
    let any_peer = PeerFilter::Any;
    FirewallRule {
        origin: RuleOrigin::Local,
        network: None,
        peer: any_peer,
        port: None,
        protocol: Protocol::Icmp,
        action: Action::Allow,
        direction: Direction::In,
    }
}
impl Default for FirewallConfig {
    /// Default policy:
    /// - deny unsolicited inbound traffic;
    /// - allow outbound traffic;
    /// - `reject` disabled;
    /// - a single removable inbound ICMP allow rule.
    fn default() -> Self {
        let rules = vec![default_icmp_rule()];
        Self {
            rules,
            reject: false,
            default_outbound: default_outbound_action(),
            default_inbound: default_inbound_action(),
        }
    }
}
/// Idle lifetime of a tracked TCP flow. Return traffic is admitted only while
/// the flow was last seen less than this long ago, and the evictor drops
/// older entries.
const TCP_FLOW_TIMEOUT: Duration = Duration::from_secs(300);
/// Idle lifetime of every non-TCP tracked flow (UDP and ICMP echo).
const UDP_FLOW_TIMEOUT: Duration = Duration::from_secs(30);
/// Connection-tracking key.
///
/// The key is always expressed from the local host's point of view, so an
/// outbound packet and the inbound reply to it produce the same key.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
struct Flow {
    /// IP protocol number (6 = TCP, 17 = UDP, 1/58 = ICMP/ICMPv6).
    proto: u8,
    /// Address of this host.
    local_ip: IpAddr,
    /// Port on this host; 0 for protocols without ports.
    local_port: u16,
    /// Address of the remote side.
    peer_ip: IpAddr,
    /// Port on the remote side; 0 for protocols without ports.
    peer_port: u16,
    /// ICMP echo identifier; 0 for every other packet kind.
    icmp_id: u16,
}
/// Cheaply cloneable handle to the firewall.
///
/// Clones share the same active configuration and the same
/// connection-tracking table (both live behind `Arc`), so a config change
/// made through one handle is seen by all of them.
#[derive(Clone)]
pub struct SharedFirewall {
    /// Active configuration; replaced as a whole by `update`.
    inner: Arc<ArcSwap<FirewallConfig>>,
    /// Tracked flows, mapped to the time they were last seen.
    conntrack: Arc<FastDashMap<Flow, Instant>>,
}
impl SharedFirewall {
    /// Builds a firewall whose active configuration is `config` and whose
    /// connection-tracking table starts out empty.
    pub fn new(config: FirewallConfig) -> Self {
        let inner = Arc::new(ArcSwap::from_pointee(config));
        let conntrack = Arc::new(FastDashMap::default());
        Self { inner, conntrack }
    }

    /// Idle timeout for a tracked flow of IP protocol `proto`: TCP uses
    /// `TCP_FLOW_TIMEOUT`, every other protocol uses `UDP_FLOW_TIMEOUT`.
    fn idle_timeout_for(proto: u8) -> Duration {
        match proto {
            6 => TCP_FLOW_TIMEOUT,
            _ => UDP_FLOW_TIMEOUT,
        }
    }

    /// Builds the conntrack key for `info`, seen from the local host.
    ///
    /// For outbound packets the source is local; for inbound packets the
    /// destination is.
    fn flow_key(direction: Direction, info: &PacketInfo) -> Flow {
        match direction {
            Direction::In => Flow {
                proto: info.protocol,
                local_ip: info.dst_ip,
                local_port: info.dst_port,
                peer_ip: info.src_ip,
                peer_port: info.src_port,
                icmp_id: info.icmp_id,
            },
            Direction::Out => Flow {
                proto: info.protocol,
                local_ip: info.src_ip,
                local_port: info.src_port,
                peer_ip: info.dst_ip,
                peer_port: info.dst_port,
                icmp_id: info.icmp_id,
            },
        }
    }

    /// Returns the action of the first configured rule that matches, or
    /// `None` when no rule matches.
    ///
    /// A rule must match the direction, network, IP protocol, destination
    /// port and peer. A rule without a network matches every network; a rule
    /// with a network never matches when `network` is `None`.
    fn match_rule(
        &self,
        direction: Direction,
        protocol: u8,
        dst_port: u16,
        peer: &EndpointId,
        network: Option<&str>,
    ) -> Option<Action> {
        let config = self.inner.load();
        let winner = config.rules.iter().find(|rule| {
            rule.direction == direction
                && rule.network.as_deref().is_none_or(|name| Some(name) == network)
                && protocol_matches(rule.protocol, protocol)
                && rule.port.as_ref().is_none_or(|range| range.contains(dst_port))
                && match &rule.peer {
                    PeerFilter::Any => true,
                    PeerFilter::Identity(id) => id == peer,
                }
        });
        winner.map(|rule| rule.action)
    }

    /// Configured fallback action for `direction`.
    fn default_for(&self, direction: Direction) -> Action {
        let config = self.inner.load();
        match direction {
            Direction::In => config.default_inbound,
            Direction::Out => config.default_outbound,
        }
    }

    /// Evaluates a packet using only rules and defaults.
    ///
    /// No connection tracking and no network: network-scoped rules never
    /// match here.
    // NOTE(review): allowed as dead code — presumably only exercised by tests; confirm.
    #[allow(dead_code)]
    pub fn evaluate(
        &self,
        direction: Direction,
        protocol: u8,
        dst_port: u16,
        peer: &EndpointId,
    ) -> Action {
        match self.match_rule(direction, protocol, dst_port, peer, None) {
            Some(action) => action,
            None => self.default_for(direction),
        }
    }

    /// Current value of the config's `reject` flag.
    pub fn reject_enabled(&self) -> bool {
        let config = self.inner.load();
        config.reject
    }

    /// Decides the fate of a parsed packet.
    ///
    /// - Explicit rules always win, in both directions.
    /// - Outbound packets that end up allowed (by rule or by default) are
    ///   recorded in the conntrack table; see `track_outbound`.
    /// - Inbound packets matching no rule are admitted when they belong to a
    ///   live tracked flow, which also refreshes its timestamp. For ICMP only
    ///   echo replies qualify. Anything else gets the inbound default.
    pub fn evaluate_packet(
        &self,
        direction: Direction,
        info: &PacketInfo,
        peer: &EndpointId,
        network: Option<&str>,
    ) -> Action {
        let flow = Self::flow_key(direction, info);
        let explicit = self.match_rule(direction, info.protocol, info.dst_port, peer, network);
        match (direction, explicit) {
            (Direction::Out, verdict) => {
                let action = verdict.unwrap_or_else(|| self.default_for(Direction::Out));
                if action.is_allow() {
                    self.track_outbound(&flow, info);
                }
                action
            }
            (Direction::In, Some(action)) => action,
            (Direction::In, None) => {
                let proto = info.protocol;
                let could_be_reply = !is_icmp(proto) || is_icmp_echo_reply(proto, info.icmp_type);
                if could_be_reply && self.flow_active(&flow) {
                    self.conntrack.insert(flow, Instant::now());
                    Action::Allow
                } else {
                    self.default_for(Direction::In)
                }
            }
        }
    }

    /// Records an allowed outbound packet in the conntrack table.
    ///
    /// A TCP segment carrying FIN or RST removes its flow instead. ICMP
    /// packets other than echo requests are not tracked.
    fn track_outbound(&self, flow: &Flow, info: &PacketInfo) {
        const FIN: u8 = 0x01;
        const RST: u8 = 0x04;
        let closing_tcp = flow.proto == 6 && info.tcp_flags & (FIN | RST) != 0;
        if closing_tcp {
            self.conntrack.remove(flow);
        } else if !is_icmp(flow.proto) || is_icmp_echo_request(flow.proto, info.icmp_type) {
            self.conntrack.insert(*flow, Instant::now());
        }
    }

    /// Whether `flow` is tracked and was seen within its idle timeout.
    fn flow_active(&self, flow: &Flow) -> bool {
        let timeout = Self::idle_timeout_for(flow.proto);
        self.conntrack
            .get(flow)
            .is_some_and(|last_seen| last_seen.elapsed() < timeout)
    }

    /// Spawns a background task that sweeps expired flows.
    ///
    /// Every 60 seconds the task drops flows idle longer than their timeout,
    /// and it exits once `token` is cancelled. It uses `tokio::spawn`, so it
    /// must be called inside a Tokio runtime.
    pub fn spawn_evictor(self, token: CancellationToken) {
        const SWEEP_INTERVAL: Duration = Duration::from_secs(60);
        tokio::spawn(async move {
            loop {
                tokio::select! {
                    _ = token.cancelled() => return,
                    _ = tokio::time::sleep(SWEEP_INTERVAL) => {
                        let now = Instant::now();
                        self.conntrack.retain(|flow, last_seen| {
                            now.duration_since(*last_seen) < Self::idle_timeout_for(flow.proto)
                        });
                    }
                }
            }
        });
    }

    /// Replaces the active configuration wholesale.
    pub fn update(&self, config: FirewallConfig) {
        self.inner.store(config.into());
    }

    /// Snapshot of the active configuration.
    pub fn get_config(&self) -> Arc<FirewallConfig> {
        self.inner.load_full()
    }

    /// Swaps in `new_rules` as the complete rule set suggested by network
    /// `net`.
    ///
    /// Every rule with origin `Network(net)` is removed and `new_rules` is
    /// appended at the end. Returns the resulting configuration.
    /// NOTE(review): this is a read-modify-write, not atomic with respect to
    /// concurrent `update` calls.
    pub fn replace_network_rules(&self, net: &str, new_rules: Vec<FirewallRule>) -> FirewallConfig {
        let mut next = FirewallConfig::clone(&self.get_config());
        next.rules.retain(|rule| match &rule.origin {
            RuleOrigin::Network(name) => name != net,
            _ => true,
        });
        next.rules.extend(new_rules);
        self.update(next.clone());
        next
    }

    /// Removes any SSH pass-through rule and, when `enabled`, puts a fresh one
    /// at the top of the list so it takes precedence. Returns the resulting
    /// configuration.
    pub fn set_ssh_passthrough(&self, enabled: bool) -> FirewallConfig {
        let mut next = FirewallConfig::clone(&self.get_config());
        next.rules.retain(|rule| rule.origin != RuleOrigin::Ssh);
        if enabled {
            next.rules.insert(0, ssh_passthrough_rule());
        }
        self.update(next.clone());
        next
    }
}
/// Whether IP protocol number `ip_proto` satisfies the rule's protocol
/// filter.
///
/// `Icmp` covers both ICMPv4 (1) and ICMPv6 (58); `Any` accepts every
/// protocol.
fn protocol_matches(filter: Protocol, ip_proto: u8) -> bool {
    let accepted: &[u8] = match filter {
        Protocol::Any => return true,
        Protocol::Tcp => &[6],
        Protocol::Udp => &[17],
        Protocol::Icmp => &[1, 58],
    };
    accepted.contains(&ip_proto)
}
/// Fields extracted from a raw IP packet by [`parse_packet_info`].
///
/// A field is 0 when it does not apply to the packet's protocol (ports for
/// ICMP, TCP flags for UDP, ICMP type/identifier for TCP/UDP) or when the
/// packet was too short to contain it.
#[derive(Clone, Copy)]
pub struct PacketInfo {
    /// Source address.
    pub src_ip: IpAddr,
    /// Destination address.
    pub dst_ip: IpAddr,
    /// IP protocol number as reported by the parser (6 = TCP, 17 = UDP,
    /// 1/58 = ICMP/ICMPv6).
    pub protocol: u8,
    /// TCP/UDP source port.
    pub src_port: u16,
    /// TCP/UDP destination port.
    pub dst_port: u16,
    /// TCP flags byte (FIN = 0x01, SYN = 0x02, RST = 0x04, ACK = 0x10).
    pub tcp_flags: u8,
    /// ICMP/ICMPv6 message type.
    pub icmp_type: u8,
    /// ICMP echo identifier; 0 for non-echo messages.
    pub icmp_id: u16,
}
/// True for ICMPv4 (protocol 1) and ICMPv6 (next header 58).
fn is_icmp(proto: u8) -> bool {
    matches!(proto, 1 | 58)
}
/// True for an echo request: ICMPv4 type 8 or ICMPv6 type 128.
fn is_icmp_echo_request(proto: u8, icmp_type: u8) -> bool {
    matches!((proto, icmp_type), (1, 8) | (58, 128))
}
/// True for an echo reply: ICMPv4 type 0 or ICMPv6 type 129.
fn is_icmp_echo_reply(proto: u8, icmp_type: u8) -> bool {
    matches!((proto, icmp_type), (1, 0) | (58, 129))
}
/// Extracts the firewall-relevant fields from a raw IP packet.
///
/// The IP version is read from the high nibble of the first byte. Version 4
/// and version 6 packets go to their respective parsers; any other version,
/// or an empty buffer, yields `None`.
pub fn parse_packet_info(packet: &[u8]) -> Option<PacketInfo> {
    let &first_byte = packet.first()?;
    match first_byte >> 4 {
        6 => parse_ipv6(packet),
        4 => parse_ipv4(packet),
        _ => None,
    }
}
/// Parses an IPv4 packet into the fields the firewall matches on.
///
/// Returns `None` when:
/// - the buffer is shorter than the 20-byte minimum header;
/// - the IHL field is below its minimum of 5, i.e. the header is malformed;
/// - IHL announces a header longer than the buffer.
///
/// For non-first fragments (non-zero fragment offset) the transport header
/// is not in this packet. Ports, TCP flags and ICMP fields are then reported
/// as zero rather than read out of payload bytes.
///
/// NOTE(review): first fragments too short to hold the full transport header
/// are still parsed best-effort, with missing fields left at zero. No
/// reassembly is performed here.
fn parse_ipv4(packet: &[u8]) -> Option<PacketInfo> {
    const MIN_HEADER_LEN: usize = 20;
    if packet.len() < MIN_HEADER_LEN {
        return None;
    }
    let header_len = (packet[0] & 0x0F) as usize * 4;
    // IHL < 5 is malformed; accepting it would place the "transport header"
    // inside the IP header itself and read the addresses as ports.
    if header_len < MIN_HEADER_LEN || packet.len() < header_len {
        return None;
    }
    let protocol = packet[9];
    let src_ip = IpAddr::V4(Ipv4Addr::new(
        packet[12], packet[13], packet[14], packet[15],
    ));
    let dst_ip = IpAddr::V4(Ipv4Addr::new(
        packet[16], packet[17], packet[18], packet[19],
    ));
    // The low 13 bits of bytes 6..8 are the fragment offset in 8-byte units.
    // Only the fragment at offset 0 carries the transport header.
    let fragment_offset = u16::from_be_bytes([packet[6], packet[7]]) & 0x1FFF;
    let (src_port, dst_port, tcp_flags, icmp_type, icmp_id) = if fragment_offset == 0 {
        let (src_port, dst_port) = extract_ports(protocol, packet, header_len);
        let tcp_flags = extract_tcp_flags(protocol, packet, header_len);
        let (icmp_type, icmp_id) = extract_icmp(protocol, packet, header_len);
        (src_port, dst_port, tcp_flags, icmp_type, icmp_id)
    } else {
        (0, 0, 0, 0, 0)
    };
    Some(PacketInfo {
        src_ip,
        dst_ip,
        protocol,
        src_port,
        dst_port,
        tcp_flags,
        icmp_type,
        icmp_id,
    })
}
/// Parses an IPv6 packet into the fields the firewall matches on.
///
/// Returns `None` when the buffer is shorter than the 40-byte fixed header.
///
/// `protocol` is the upper-layer protocol found by walking the extension
/// header chain (hop-by-hop, routing, fragment, destination options,
/// authentication). A TCP segment carried behind e.g. a destination-options
/// header is therefore still classified as TCP (6) and matched against TCP
/// rules, instead of reported as 60 and slipping past them.
///
/// When no transport header is available, ports, TCP flags and ICMP fields
/// are zero. That happens for a non-first fragment or a truncated extension
/// chain.
fn parse_ipv6(packet: &[u8]) -> Option<PacketInfo> {
    if packet.len() < 40 {
        return None;
    }
    let mut src_octets = [0u8; 16];
    let mut dst_octets = [0u8; 16];
    src_octets.copy_from_slice(&packet[8..24]);
    dst_octets.copy_from_slice(&packet[24..40]);
    let src_ip = IpAddr::V6(Ipv6Addr::from(src_octets));
    let dst_ip = IpAddr::V6(Ipv6Addr::from(dst_octets));
    let (protocol, transport_offset) = ipv6_upper_layer(packet);
    let (src_port, dst_port, tcp_flags, icmp_type, icmp_id) = match transport_offset {
        Some(offset) => {
            let (src_port, dst_port) = extract_ports(protocol, packet, offset);
            let tcp_flags = extract_tcp_flags(protocol, packet, offset);
            let (icmp_type, icmp_id) = extract_icmp(protocol, packet, offset);
            (src_port, dst_port, tcp_flags, icmp_type, icmp_id)
        }
        None => (0, 0, 0, 0, 0),
    };
    Some(PacketInfo {
        src_ip,
        dst_ip,
        protocol,
        src_port,
        dst_port,
        tcp_flags,
        icmp_type,
        icmp_id,
    })
}

/// Walks the IPv6 extension header chain of `packet`.
///
/// The packet must be at least 40 bytes. Returns the final Next Header value
/// and the byte offset of the upper-layer header. The offset is `None` when
/// that header is not in this packet: a non-first fragment, where the
/// returned protocol is the fragment header's Next Header, or a truncated
/// chain. Other extension header types are not walked and are reported as
/// the upper-layer protocol. Every extension header is at least 8 bytes, so
/// the walk always terminates.
fn ipv6_upper_layer(packet: &[u8]) -> (u8, Option<usize>) {
    const HOP_BY_HOP: u8 = 0;
    const ROUTING: u8 = 43;
    const FRAGMENT: u8 = 44;
    const AUTH_HEADER: u8 = 51;
    const DEST_OPTIONS: u8 = 60;

    let mut next = packet[6];
    let mut offset = 40usize;
    loop {
        let ext_len = match next {
            // Hdr Ext Len counts 8-octet units beyond the first 8 octets.
            HOP_BY_HOP | ROUTING | DEST_OPTIONS => match packet.get(offset + 1) {
                Some(&units) => (usize::from(units) + 1) * 8,
                None => return (next, None),
            },
            // AH Payload Len counts 4-octet units minus 2.
            AUTH_HEADER => match packet.get(offset + 1) {
                Some(&units) => (usize::from(units) + 2) * 4,
                None => return (next, None),
            },
            FRAGMENT => {
                let Some(&[hi, lo]) = packet.get(offset + 2..offset + 4) else {
                    return (next, None);
                };
                // Upper 13 bits are the fragment offset; non-zero means this
                // fragment does not carry the transport header.
                if u16::from_be_bytes([hi, lo]) >> 3 != 0 {
                    return (packet[offset], None);
                }
                8
            }
            _ => return (next, Some(offset)),
        };
        // Every branch above verified that `offset` is in bounds.
        next = packet[offset];
        offset += ext_len;
    }
}
/// Reads the (source, destination) ports of a TCP or UDP header starting at
/// byte `header_len`.
///
/// Returns `(0, 0)` for other protocols or when fewer than four bytes of
/// transport header are present.
fn extract_ports(protocol: u8, packet: &[u8], header_len: usize) -> (u16, u16) {
    if !matches!(protocol, 6 | 17) {
        return (0, 0);
    }
    match packet.get(header_len..header_len + 4) {
        Some(&[src_hi, src_lo, dst_hi, dst_lo]) => (
            u16::from_be_bytes([src_hi, src_lo]),
            u16::from_be_bytes([dst_hi, dst_lo]),
        ),
        _ => (0, 0),
    }
}
/// Returns the TCP flags byte (offset 13 of the TCP header starting at
/// `header_len`).
///
/// Returns 0 for non-TCP packets or when the header is too short to contain
/// it.
fn extract_tcp_flags(protocol: u8, packet: &[u8], header_len: usize) -> u8 {
    match protocol {
        6 => packet.get(header_len + 13).copied().unwrap_or(0),
        _ => 0,
    }
}
/// Reads the ICMP/ICMPv6 type and echo identifier from the header at
/// `header_len`.
///
/// The identifier (bytes 4..6) is only read for echo requests and replies;
/// otherwise, or when the header is too short, it is 0. Non-ICMP packets and
/// packets with no ICMP byte at all yield `(0, 0)`.
fn extract_icmp(protocol: u8, packet: &[u8], header_len: usize) -> (u8, u16) {
    if !is_icmp(protocol) {
        return (0, 0);
    }
    let Some(&icmp_type) = packet.get(header_len) else {
        return (0, 0);
    };
    let is_echo =
        is_icmp_echo_request(protocol, icmp_type) || is_icmp_echo_reply(protocol, icmp_type);
    let echo_id = match packet.get(header_len + 4..header_len + 6) {
        Some(&[hi, lo]) if is_echo => u16::from_be_bytes([hi, lo]),
        _ => 0,
    };
    (icmp_type, echo_id)
}
/// Location of the firewall configuration file: `firewall.toml` inside the
/// application config directory.
pub fn firewall_path() -> Result<PathBuf> {
    let config_root = crate::config::config_dir()?;
    Ok(config_root.join("firewall.toml"))
}
pub fn load_firewall() -> Result<FirewallConfig> {
let path = firewall_path()?;
if !path.exists() {
return Ok(FirewallConfig::default());
}
let content =
std::fs::read_to_string(&path).with_context(|| format!("read {}", path.display()))?;
toml::from_str(&content).with_context(|| format!("parse {}", path.display()))
}
/// Serializes `config` as pretty-printed TOML and writes it to
/// [`firewall_path`].
///
/// NOTE(review): the meaning of `write_file`'s boolean argument is defined in
/// `crate::config` and not visible here.
pub fn save_firewall(config: &FirewallConfig) -> Result<()> {
    let target = firewall_path()?;
    let rendered = toml::to_string_pretty(config).context("serialize firewall config")?;
    crate::config::write_file(&target, rendered.as_bytes(), false)
}
/// Parses a port specification.
///
/// Accepted forms:
/// - a single port, `"22"`;
/// - an inclusive range, `"8000-9000"`;
/// - `"*"` for every port, `0-65535`.
///
/// Surrounding whitespace is ignored, including around the `-` of a range.
///
/// # Errors
/// Fails when a bound is not a valid `u16` or when `start > end`.
pub fn parse_port_range(s: &str) -> Result<PortRange> {
    // Trim once up front so "*", single ports and ranges are all treated the
    // same way (previously only "*" tolerated surrounding whitespace).
    let s = s.trim();
    if s == "*" {
        return Ok(PortRange {
            start: 0,
            end: u16::MAX,
        });
    }
    match s.split_once('-') {
        Some((start, end)) => {
            let start: u16 = start.trim().parse().context("invalid start port")?;
            let end: u16 = end.trim().parse().context("invalid end port")?;
            if start > end {
                bail!("start port ({start}) must be <= end port ({end})");
            }
            Ok(PortRange { start, end })
        }
        None => {
            let port: u16 = s.parse().context("invalid port number")?;
            Ok(PortRange {
                start: port,
                end: port,
            })
        }
    }
}
pub fn parse_port_list(s: &str) -> Result<Vec<PortRange>> {
let ranges = s
.split(',')
.map(str::trim)
.filter(|i| !i.is_empty())
.map(parse_port_range)
.collect::<Result<Vec<_>>>()?;
if ranges.is_empty() {
bail!("no valid port given");
}
Ok(ranges)
}
pub fn parse_spec_token(tok: &str) -> Result<(Protocol, Option<PortRange>)> {
let tok = tok.trim();
match tok.split_once(':') {
Some((proto_str, port_str)) => {
let proto = proto_str.parse::<Protocol>().map_err(anyhow::Error::msg)?;
match proto {
Protocol::Icmp | Protocol::Any => Ok((proto, None)),
Protocol::Tcp | Protocol::Udp => {
let range = parse_port_range(port_str)?;
Ok((proto, Some(range)))
}
}
}
None => {
if tok.parse::<u16>().is_ok() {
bail!("missing protocol prefix for '{tok}'; use e.g. 'tcp:{tok}' or 'icmp'");
}
let proto = tok.parse::<Protocol>().map_err(anyhow::Error::msg)?;
match proto {
Protocol::Icmp | Protocol::Any => Ok((proto, None)),
Protocol::Tcp | Protocol::Udp => Ok((
proto,
Some(PortRange {
start: 0,
end: u16::MAX,
}),
)),
}
}
}
}
/// Converts the firewall suggestions published for network `net` into
/// concrete inbound rules for this host.
///
/// Sections of `suggestions` are consulted in this order:
/// 1. the wildcard section `"*"`;
/// 2. the section keyed by `my_hostname`, unless that name is itself `"*"`.
///
/// Within each section, all `allows` entries are emitted before all
/// `denies` entries. A peer key of `"*"` becomes [`PeerFilter::Any`]. Any
/// other peer name is looked up through `resolve` and silently skipped when
/// it cannot be resolved. Each comma-separated port token is parsed with
/// [`parse_spec_token`]; invalid tokens are logged at warn level and skipped.
///
/// Every produced rule is `Direction::In`, scoped to `net`, and tagged
/// `RuleOrigin::Network(net)`.
///
/// NOTE(review): rules are evaluated first-match-wins. With this emission
/// order a wildcard-section rule shadows an overlapping host-specific one,
/// and an `allows` entry shadows an overlapping `denies` entry. Confirm that
/// precedence is intended or that callers reorder.
pub fn materialize_suggestions(
    net: &str,
    my_hostname: &str,
    suggestions: &SuggestedFirewall,
    resolve: &dyn Fn(&str) -> Option<EndpointId>,
) -> Vec<FirewallRule> {
    let section_keys =
        std::iter::once("*").chain((my_hostname != "*").then_some(my_hostname));
    let sections: Vec<_> = section_keys
        .filter_map(|key| suggestions.get(key))
        .collect();

    let mut rules = Vec::new();
    for section in sections {
        let by_action = [
            (Action::Allow, &section.allows),
            (Action::Deny, &section.denies),
        ];
        for (action, entries) in by_action {
            for (peer, ports) in entries {
                let peer_filter = if peer == "*" {
                    PeerFilter::Any
                } else if let Some(id) = resolve(peer) {
                    PeerFilter::Identity(id)
                } else {
                    continue;
                };
                let tokens = ports.split(',').map(str::trim).filter(|t| !t.is_empty());
                for tok in tokens {
                    let (protocol, port) = match parse_spec_token(tok) {
                        Ok(spec) => spec,
                        Err(e) => {
                            tracing::warn!(
                                token = %tok, network = %net, error = %e,
                                "skipping invalid firewall spec token"
                            );
                            continue;
                        }
                    };
                    rules.push(FirewallRule {
                        direction: Direction::In,
                        action,
                        protocol,
                        port,
                        peer: peer_filter.clone(),
                        network: Some(net.to_owned()),
                        origin: RuleOrigin::Network(net.to_owned()),
                    });
                }
            }
        }
    }
    rules
}
/// Renders a rule into its IPC display form.
///
/// Formatting rules:
/// - peers print as `"any"` or via `short_id`;
/// - ports print as `"*"`, a single number, or `"start-end"`;
/// - a missing network prints as `"any"`;
/// - `suggested_by` names the originating network, is `"ssh"` for the SSH
///   pass-through rule, and is `None` for local rules.
pub fn rule_view(
    rule: &FirewallRule,
    short_id: &dyn Fn(&EndpointId) -> String,
) -> FirewallRuleView {
    let peer = match &rule.peer {
        PeerFilter::Identity(id) => short_id(id),
        PeerFilter::Any => String::from("any"),
    };
    let port = match rule.port.as_ref() {
        Some(range) if range.start != range.end => format!("{}-{}", range.start, range.end),
        Some(range) => range.start.to_string(),
        None => String::from("*"),
    };
    let network = rule.network.as_deref().unwrap_or("any").to_owned();
    let suggested_by = match &rule.origin {
        RuleOrigin::Network(name) => Some(name.clone()),
        RuleOrigin::Ssh => Some(String::from("ssh")),
        RuleOrigin::Local => None,
    };
    FirewallRuleView {
        direction: rule.direction,
        action: rule.action,
        protocol: rule.protocol,
        port,
        peer,
        network,
        suggested_by,
    }
}
/// Renders every rule with [`rule_view`], preserving order.
pub fn rule_views(
    rules: &[FirewallRule],
    short_id: &dyn Fn(&EndpointId) -> String,
) -> Vec<FirewallRuleView> {
    let mut views = Vec::with_capacity(rules.len());
    for rule in rules {
        views.push(rule_view(rule, short_id));
    }
    views
}
#[cfg(test)]
mod tests {
use super::*;
fn test_id(seed: u8) -> EndpointId {
let mut key_bytes = [0u8; 32];
key_bytes[0] = seed;
iroh::SecretKey::from(key_bytes).public()
}
#[test]
fn parse_valid_ipv4_tcp() {
let mut pkt = vec![0u8; 40];
pkt[0] = 0x45; pkt[9] = 6; pkt[12..16].copy_from_slice(&[10, 0, 0, 1]);
pkt[16..20].copy_from_slice(&[10, 0, 0, 2]);
pkt[20] = 0x1F; pkt[21] = 0x90;
pkt[22] = 0x01; pkt[23] = 0xBB;
let info = parse_packet_info(&pkt).unwrap();
assert_eq!(info.src_ip, Ipv4Addr::new(10, 0, 0, 1));
assert_eq!(info.dst_ip, Ipv4Addr::new(10, 0, 0, 2));
assert_eq!(info.protocol, 6);
assert_eq!(info.src_port, 8080);
assert_eq!(info.dst_port, 443);
}
#[test]
fn parse_udp_packet() {
let mut pkt = vec![0u8; 28];
pkt[0] = 0x45;
pkt[9] = 17; pkt[20] = 0x00;
pkt[21] = 53; pkt[22] = 0x04;
pkt[23] = 0xD2;
let info = parse_packet_info(&pkt).unwrap();
assert_eq!(info.protocol, 17);
assert_eq!(info.src_port, 53);
assert_eq!(info.dst_port, 1234);
}
#[test]
fn parse_icmp_no_ports() {
let mut pkt = vec![0u8; 28];
pkt[0] = 0x45;
pkt[9] = 1;
let info = parse_packet_info(&pkt).unwrap();
assert_eq!(info.protocol, 1);
assert_eq!(info.src_port, 0);
assert_eq!(info.dst_port, 0);
}
#[test]
fn parse_too_short() {
assert!(parse_packet_info(&[0x45; 10]).is_none());
}
#[test]
fn parse_ipv6_basic() {
let mut pkt = vec![0u8; 40];
pkt[0] = 0x60; pkt[6] = 17; pkt[24] = 0x02; let info = parse_packet_info(&pkt).unwrap();
assert!(info.dst_ip.is_ipv6());
assert_eq!(info.protocol, 17);
}
#[test]
fn parse_not_ip() {
let pkt = vec![0x30; 40]; assert!(parse_packet_info(&pkt).is_none());
}
#[test]
fn default_config_is_secure_inbound() {
let fw = SharedFirewall::new(FirewallConfig::default());
assert_eq!(fw.evaluate(Direction::In, 6, 22, &test_id(1)), Action::Deny);
assert_eq!(
fw.evaluate(Direction::In, 17, 53, &test_id(1)),
Action::Deny
);
assert_eq!(fw.evaluate(Direction::In, 1, 0, &test_id(1)), Action::Allow);
assert_eq!(
fw.evaluate(Direction::In, 58, 0, &test_id(1)),
Action::Allow
);
assert_eq!(
fw.evaluate(Direction::Out, 6, 443, &test_id(1)),
Action::Allow
);
assert_eq!(
fw.evaluate(Direction::Out, 17, 53, &test_id(1)),
Action::Allow
);
}
#[test]
fn default_config_seeds_one_removable_icmp_rule() {
let config = FirewallConfig::default();
assert_eq!(config.rules.len(), 1);
let r = &config.rules[0];
assert_eq!(r.direction, Direction::In);
assert_eq!(r.action, Action::Allow);
assert_eq!(r.protocol, Protocol::Icmp);
assert_eq!(r.peer, PeerFilter::Any);
assert_eq!(r.origin, RuleOrigin::Local); }
#[test]
fn removing_seeded_icmp_rule_denies_icmp() {
let mut config = FirewallConfig::default();
config.rules.clear();
let fw = SharedFirewall::new(config);
assert_eq!(fw.evaluate(Direction::In, 1, 0, &test_id(1)), Action::Deny);
assert_eq!(fw.evaluate(Direction::In, 58, 0, &test_id(1)), Action::Deny);
}
#[test]
fn same_selector_ignores_action_and_origin_but_not_match_fields() {
let allow_icmp = default_icmp_rule();
let deny_icmp = FirewallRule {
action: Action::Deny,
..default_icmp_rule()
};
assert!(same_selector(&allow_icmp, &deny_icmp));
let deny_icmp_peer = FirewallRule {
action: Action::Deny,
peer: PeerFilter::Identity(test_id(7)),
..default_icmp_rule()
};
assert!(!same_selector(&allow_icmp, &deny_icmp_peer));
}
#[test]
fn merge_same_selector_then_prepend_makes_latest_win() {
let mut rules = vec![default_icmp_rule()];
let deny_icmp = FirewallRule {
action: Action::Deny,
..default_icmp_rule()
};
rules.retain(|r| !same_selector(r, &deny_icmp));
rules.insert(0, deny_icmp);
assert_eq!(rules.len(), 1, "merged, not accumulated");
let fw = SharedFirewall::new(FirewallConfig {
rules,
..FirewallConfig::default()
});
assert_eq!(fw.evaluate(Direction::In, 1, 0, &test_id(1)), Action::Deny);
}
#[test]
fn dedup_by_selector_keeps_newest_and_collapses_duplicates() {
let installed_22 = FirewallRule {
direction: Direction::In,
action: Action::Allow,
protocol: Protocol::Tcp,
port: Some(PortRange { start: 22, end: 22 }),
peer: PeerFilter::Any,
network: Some("homelab".into()),
origin: RuleOrigin::Network("homelab".into()),
};
let installed_80 = FirewallRule {
port: Some(PortRange {
start: 80,
end: 443,
}),
..installed_22.clone()
};
let reaccepted_22 = FirewallRule {
action: Action::Deny,
..installed_22.clone()
};
let out = dedup_by_selector(vec![
installed_80.clone(),
installed_22.clone(),
reaccepted_22.clone(),
]);
assert_eq!(out.len(), 2, "one rule per selector");
assert!(out.iter().any(|r| r == &installed_80));
assert!(out.iter().any(|r| r == &reaccepted_22));
assert!(!out.iter().any(|r| r == &installed_22));
}
#[test]
fn default_allow_override_permits_inbound_tcp() {
let fw = SharedFirewall::new(FirewallConfig {
default_inbound: Action::Allow,
..FirewallConfig::default()
});
assert_eq!(
fw.evaluate(Direction::In, 6, 22, &test_id(1)),
Action::Allow
);
}
#[test]
fn explicit_deny_in_icmp_ordered_before_seed_wins() {
let deny_icmp = FirewallRule {
direction: Direction::In,
action: Action::Deny,
protocol: Protocol::Icmp,
port: None,
peer: PeerFilter::Any,
network: None,
origin: RuleOrigin::Local,
};
let fw = SharedFirewall::new(FirewallConfig {
rules: vec![deny_icmp, default_icmp_rule()],
..FirewallConfig::default()
});
assert_eq!(fw.evaluate(Direction::In, 1, 0, &test_id(1)), Action::Deny);
assert_eq!(fw.evaluate(Direction::In, 58, 0, &test_id(1)), Action::Deny);
}
#[test]
fn default_config_denies_unsolicited_inbound_but_allows_return() {
let fw = SharedFirewall::new(FirewallConfig::default());
let me = Ipv4Addr::new(100, 64, 0, 2);
let peer = Ipv4Addr::new(100, 64, 0, 3);
let peer_id = test_id(1);
let unsolicited = tcp_pkt(peer, 51000, me, 8080, SYN);
assert_eq!(
fw.evaluate_packet(Direction::In, &unsolicited, &peer_id, None),
Action::Deny
);
let out = tcp_pkt(me, 50000, peer, 443, SYN);
assert_eq!(
fw.evaluate_packet(Direction::Out, &out, &peer_id, None),
Action::Allow
);
let ret = tcp_pkt(peer, 443, me, 50000, SYN | ACK);
assert_eq!(
fw.evaluate_packet(Direction::In, &ret, &peer_id, None),
Action::Allow
);
let mut icmp = vec![0u8; 28];
icmp[0] = 0x45;
icmp[9] = 1; icmp[12..16].copy_from_slice(&peer.octets());
icmp[16..20].copy_from_slice(&me.octets());
let icmp = parse_packet_info(&icmp).unwrap();
assert_eq!(
fw.evaluate_packet(Direction::In, &icmp, &peer_id, None),
Action::Allow
);
}
#[test]
fn evaluate_default_deny() {
let fw = SharedFirewall::new(FirewallConfig {
default_inbound: Action::Deny,
default_outbound: Action::Deny,
reject: false,
rules: vec![],
});
assert_eq!(fw.evaluate(Direction::In, 6, 22, &test_id(1)), Action::Deny);
}
#[test]
fn evaluate_deny_specific_port() {
let fw = SharedFirewall::new(FirewallConfig {
default_inbound: Action::Allow,
default_outbound: Action::Allow,
reject: false,
rules: vec![FirewallRule {
direction: Direction::In,
action: Action::Deny,
protocol: Protocol::Tcp,
port: Some(PortRange { start: 22, end: 22 }),
peer: PeerFilter::Any,
network: None,
origin: RuleOrigin::Local,
}],
});
assert_eq!(fw.evaluate(Direction::In, 6, 22, &test_id(1)), Action::Deny);
assert_eq!(
fw.evaluate(Direction::In, 6, 80, &test_id(1)),
Action::Allow
);
assert_eq!(
fw.evaluate(Direction::Out, 6, 22, &test_id(1)),
Action::Allow
);
}
#[test]
fn rule_scoped_to_arrival_network() {
let fw = SharedFirewall::new(FirewallConfig {
default_inbound: Action::Allow,
default_outbound: Action::Allow,
reject: false,
rules: vec![FirewallRule {
direction: Direction::In,
action: Action::Deny,
protocol: Protocol::Tcp,
port: Some(PortRange { start: 22, end: 22 }),
peer: PeerFilter::Any,
network: Some("db".to_string()),
origin: RuleOrigin::Local,
}],
});
let info = PacketInfo {
src_ip: std::net::IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
dst_ip: std::net::IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2)),
protocol: 6,
src_port: 40000,
dst_port: 22,
tcp_flags: 0,
icmp_type: 0,
icmp_id: 0,
};
let peer = test_id(1);
assert_eq!(
fw.evaluate_packet(Direction::In, &info, &peer, Some("db")),
Action::Deny
);
assert_eq!(
fw.evaluate_packet(Direction::In, &info, &peer, Some("dev")),
Action::Allow
);
assert_eq!(
fw.evaluate_packet(Direction::In, &info, &peer, None),
Action::Allow
);
}
#[test]
fn evaluate_port_range() {
let fw = SharedFirewall::new(FirewallConfig {
default_inbound: Action::Deny,
default_outbound: Action::Deny,
reject: false,
rules: vec![FirewallRule {
direction: Direction::In,
action: Action::Allow,
protocol: Protocol::Any,
port: Some(PortRange {
start: 80,
end: 443,
}),
peer: PeerFilter::Any,
network: None,
origin: RuleOrigin::Local,
}],
});
assert_eq!(
fw.evaluate(Direction::In, 6, 80, &test_id(1)),
Action::Allow
);
assert_eq!(
fw.evaluate(Direction::In, 17, 443, &test_id(1)),
Action::Allow
);
assert_eq!(fw.evaluate(Direction::In, 6, 22, &test_id(1)), Action::Deny);
}
#[test]
fn evaluate_peer_filter() {
let fw = SharedFirewall::new(FirewallConfig {
default_inbound: Action::Deny,
default_outbound: Action::Deny,
reject: false,
rules: vec![FirewallRule {
direction: Direction::In,
action: Action::Allow,
protocol: Protocol::Any,
port: None,
peer: PeerFilter::Identity(test_id(1)),
network: None,
origin: RuleOrigin::Local,
}],
});
assert_eq!(
fw.evaluate(Direction::In, 6, 22, &test_id(1)),
Action::Allow
);
assert_eq!(fw.evaluate(Direction::In, 6, 22, &test_id(2)), Action::Deny);
}
#[test]
fn evaluate_first_match_wins() {
let fw = SharedFirewall::new(FirewallConfig {
default_inbound: Action::Deny,
default_outbound: Action::Deny,
reject: false,
rules: vec![
FirewallRule {
direction: Direction::In,
action: Action::Deny,
protocol: Protocol::Tcp,
port: Some(PortRange { start: 22, end: 22 }),
peer: PeerFilter::Any,
network: None,
origin: RuleOrigin::Local,
},
FirewallRule {
direction: Direction::In,
action: Action::Allow,
protocol: Protocol::Any,
port: None,
peer: PeerFilter::Any,
network: None,
origin: RuleOrigin::Local,
},
],
});
assert_eq!(fw.evaluate(Direction::In, 6, 22, &test_id(1)), Action::Deny);
assert_eq!(
fw.evaluate(Direction::In, 6, 80, &test_id(1)),
Action::Allow
);
}
#[test]
fn port_range_parsing() {
let r = parse_port_range("80").unwrap();
assert_eq!(r, PortRange { start: 80, end: 80 });
let r = parse_port_range("80-443").unwrap();
assert_eq!(
r,
PortRange {
start: 80,
end: 443
}
);
assert!(parse_port_range("443-80").is_err());
assert!(parse_port_range("abc").is_err());
assert_eq!(
parse_port_range("*").unwrap(),
PortRange {
start: 0,
end: u16::MAX
}
);
}
#[test]
fn port_list_parsing() {
assert_eq!(
parse_port_list("80").unwrap(),
vec![PortRange { start: 80, end: 80 }]
);
assert_eq!(
parse_port_list("80,443,8000-9000").unwrap(),
vec![
PortRange { start: 80, end: 80 },
PortRange {
start: 443,
end: 443
},
PortRange {
start: 8000,
end: 9000
},
]
);
assert_eq!(
parse_port_list(" 22 , 80 ,").unwrap(),
vec![
PortRange { start: 22, end: 22 },
PortRange { start: 80, end: 80 },
]
);
assert!(parse_port_list("80,abc").is_err());
assert!(parse_port_list(",").is_err());
}
#[test]
fn config_serialization_roundtrip() {
let config = FirewallConfig {
default_inbound: Action::Deny,
default_outbound: Action::Deny,
reject: false,
rules: vec![FirewallRule {
direction: Direction::In,
action: Action::Allow,
protocol: Protocol::Tcp,
port: Some(PortRange {
start: 443,
end: 443,
}),
peer: PeerFilter::Any,
network: None,
origin: RuleOrigin::Local,
}],
};
let toml_str = toml::to_string_pretty(&config).unwrap();
let decoded: FirewallConfig = toml::from_str(&toml_str).unwrap();
assert_eq!(decoded.default_inbound, Action::Deny);
assert_eq!(decoded.default_outbound, Action::Deny);
assert_eq!(decoded.rules.len(), 1);
assert_eq!(decoded.rules[0].port.as_ref().unwrap().start, 443);
}
const SYN: u8 = 0x02;
const ACK: u8 = 0x10;
const FIN: u8 = 0x01;
const RST: u8 = 0x04;
fn tcp_pkt(
src: Ipv4Addr,
src_port: u16,
dst: Ipv4Addr,
dst_port: u16,
flags: u8,
) -> PacketInfo {
let mut p = vec![0u8; 40];
p[0] = 0x45; p[9] = 6; p[12..16].copy_from_slice(&src.octets());
p[16..20].copy_from_slice(&dst.octets());
p[20] = (src_port >> 8) as u8;
p[21] = src_port as u8;
p[22] = (dst_port >> 8) as u8;
p[23] = dst_port as u8;
p[32] = 0x50; p[33] = flags;
parse_packet_info(&p).unwrap()
}
fn udp_pkt(src: Ipv4Addr, src_port: u16, dst: Ipv4Addr, dst_port: u16) -> PacketInfo {
let mut p = vec![0u8; 28];
p[0] = 0x45;
p[9] = 17; p[12..16].copy_from_slice(&src.octets());
p[16..20].copy_from_slice(&dst.octets());
p[20] = (src_port >> 8) as u8;
p[21] = src_port as u8;
p[22] = (dst_port >> 8) as u8;
p[23] = dst_port as u8;
parse_packet_info(&p).unwrap()
}
#[test]
fn default_allow_plus_deny_in_22_blocks_ssh_but_allows_return() {
let fw = SharedFirewall::new(FirewallConfig {
default_inbound: Action::Allow,
default_outbound: Action::Allow,
reject: false,
rules: vec![FirewallRule {
direction: Direction::In,
action: Action::Deny,
protocol: Protocol::Tcp,
port: Some(PortRange { start: 22, end: 22 }),
peer: PeerFilter::Any,
network: None,
origin: RuleOrigin::Local,
}],
});
let me = Ipv4Addr::new(100, 64, 0, 2);
let peer = Ipv4Addr::new(100, 64, 0, 3);
let peer_id = test_id(1);
let inbound_ssh = tcp_pkt(peer, 51000, me, 22, SYN);
assert_eq!(
fw.evaluate_packet(Direction::In, &inbound_ssh, &peer_id, None),
Action::Deny
);
let outbound_ssh = tcp_pkt(me, 54321, peer, 22, SYN);
assert_eq!(
fw.evaluate_packet(Direction::Out, &outbound_ssh, &peer_id, None),
Action::Allow
);
let ret = tcp_pkt(peer, 22, me, 54321, ACK);
assert_eq!(
fw.evaluate_packet(Direction::In, &ret, &peer_id, None),
Action::Allow
);
}
#[test]
fn default_deny_allows_return_traffic_for_initiated_connections() {
let fw = SharedFirewall::new(FirewallConfig {
default_inbound: Action::Deny,
default_outbound: Action::Deny,
reject: false,
rules: vec![FirewallRule {
direction: Direction::Out,
action: Action::Allow,
protocol: Protocol::Tcp,
port: Some(PortRange {
start: 443,
end: 443,
}),
peer: PeerFilter::Any,
network: None,
origin: RuleOrigin::Local,
}],
});
let me = Ipv4Addr::new(100, 64, 0, 2);
let peer = Ipv4Addr::new(100, 64, 0, 3);
let peer_id = test_id(1);
let syn = tcp_pkt(me, 50000, peer, 443, SYN);
assert_eq!(
fw.evaluate_packet(Direction::Out, &syn, &peer_id, None),
Action::Allow
);
let ret = tcp_pkt(peer, 443, me, 50000, SYN | ACK);
assert_eq!(
fw.evaluate_packet(Direction::In, &ret, &peer_id, None),
Action::Allow
);
let unsolicited = tcp_pkt(peer, 1234, me, 8080, SYN);
assert_eq!(
fw.evaluate_packet(Direction::In, &unsolicited, &peer_id, None),
Action::Deny
);
let blocked_out = tcp_pkt(me, 40000, peer, 6667, SYN);
assert_eq!(
fw.evaluate_packet(Direction::Out, &blocked_out, &peer_id, None),
Action::Deny
);
let blocked_ret = tcp_pkt(peer, 6667, me, 40000, ACK);
assert_eq!(
fw.evaluate_packet(Direction::In, &blocked_ret, &peer_id, None),
Action::Deny
);
}
#[test]
fn tcp_fin_evicts_flow_so_return_traffic_stops() {
let fw = SharedFirewall::new(FirewallConfig {
default_inbound: Action::Deny,
default_outbound: Action::Deny,
reject: false,
rules: vec![FirewallRule {
direction: Direction::Out,
action: Action::Allow,
protocol: Protocol::Tcp,
port: Some(PortRange {
start: 443,
end: 443,
}),
peer: PeerFilter::Any,
network: None,
origin: RuleOrigin::Local,
}],
});
let me = Ipv4Addr::new(100, 64, 0, 2);
let peer = Ipv4Addr::new(100, 64, 0, 3);
let peer_id = test_id(2);
let syn = tcp_pkt(me, 50000, peer, 443, SYN);
assert_eq!(
fw.evaluate_packet(Direction::Out, &syn, &peer_id, None),
Action::Allow
);
let ret = tcp_pkt(peer, 443, me, 50000, ACK);
assert_eq!(
fw.evaluate_packet(Direction::In, &ret, &peer_id, None),
Action::Allow
);
let fin = tcp_pkt(me, 50000, peer, 443, FIN | ACK);
assert_eq!(
fw.evaluate_packet(Direction::Out, &fin, &peer_id, None),
Action::Allow
);
let after = tcp_pkt(peer, 443, me, 50000, ACK);
assert_eq!(
fw.evaluate_packet(Direction::In, &after, &peer_id, None),
Action::Deny
);
}
#[test]
fn udp_return_traffic_tracked_within_flow() {
let fw = SharedFirewall::new(FirewallConfig {
default_inbound: Action::Deny,
default_outbound: Action::Deny,
reject: false,
rules: vec![FirewallRule {
direction: Direction::Out,
action: Action::Allow,
protocol: Protocol::Udp,
port: Some(PortRange { start: 53, end: 53 }),
peer: PeerFilter::Any,
network: None,
origin: RuleOrigin::Local,
}],
});
let me = Ipv4Addr::new(100, 64, 0, 2);
let peer = Ipv4Addr::new(100, 64, 0, 3);
let peer_id = test_id(3);
let q = udp_pkt(me, 53000, peer, 53);
assert_eq!(
fw.evaluate_packet(Direction::Out, &q, &peer_id, None),
Action::Allow
);
let resp = udp_pkt(peer, 53, me, 53000);
assert_eq!(
fw.evaluate_packet(Direction::In, &resp, &peer_id, None),
Action::Allow
);
let unsolicited = udp_pkt(peer, 9999, me, 53);
assert_eq!(
fw.evaluate_packet(Direction::In, &unsolicited, &peer_id, None),
Action::Deny
);
}
#[test]
fn explicit_inbound_rule_still_wins_over_established() {
    // Permissive defaults, plus one explicit inbound TCP deny for peer 9.
    let fw = SharedFirewall::new(FirewallConfig {
        default_inbound: Action::Allow,
        default_outbound: Action::Allow,
        reject: false,
        rules: vec![FirewallRule {
            direction: Direction::In,
            action: Action::Deny,
            protocol: Protocol::Tcp,
            port: None,
            peer: PeerFilter::Identity(test_id(9)),
            network: None,
            origin: RuleOrigin::Local,
        }],
    });
    let me = Ipv4Addr::new(100, 64, 0, 2);
    let peer = Ipv4Addr::new(100, 64, 0, 3);
    let bad_peer = test_id(9);

    // The deny rule is inbound-only, so our SYN passes on the outbound default
    // and opens a flow. Asserting this precondition keeps the test from
    // passing vacuously if the outbound packet were rejected and no flow existed.
    let syn = tcp_pkt(me, 50000, peer, 443, SYN);
    assert_eq!(
        fw.evaluate_packet(Direction::Out, &syn, &bad_peer, None),
        Action::Allow
    );

    // Return traffic on that flow must still be caught by the explicit deny.
    let ret = tcp_pkt(peer, 443, me, 50000, ACK);
    assert_eq!(
        fw.evaluate_packet(Direction::In, &ret, &bad_peer, None),
        Action::Deny
    );
}
#[test]
fn parse_packet_extracts_tcp_flags() {
    let client = Ipv4Addr::new(100, 64, 0, 2);
    let server = Ipv4Addr::new(100, 64, 0, 3);

    // A bare SYN: SYN bit set, ACK bit clear.
    let opening = tcp_pkt(client, 1000, server, 443, SYN);
    assert_eq!(opening.tcp_flags & SYN, SYN);
    assert_eq!(opening.tcp_flags & ACK, 0);

    // The server's answer carries both bits.
    let syn_ack = SYN | ACK;
    let reply = tcp_pkt(server, 443, client, 1000, syn_ack);
    assert_eq!(reply.tcp_flags & syn_ack, syn_ack);
}
#[test]
fn tcp_rst_evicts_flow() {
    // Only outbound TCP to 443 is explicitly allowed; both defaults deny.
    let https_out = FirewallRule {
        direction: Direction::Out,
        action: Action::Allow,
        protocol: Protocol::Tcp,
        port: Some(PortRange {
            start: 443,
            end: 443,
        }),
        peer: PeerFilter::Any,
        network: None,
        origin: RuleOrigin::Local,
    };
    let fw = SharedFirewall::new(FirewallConfig {
        default_inbound: Action::Deny,
        default_outbound: Action::Deny,
        reject: false,
        rules: vec![https_out],
    });
    let client = Ipv4Addr::new(100, 64, 0, 2);
    let server = Ipv4Addr::new(100, 64, 0, 3);
    let server_id = test_id(4);

    // (direction, packet, expected verdict). These must be evaluated strictly
    // in order because each step updates the firewall's flow state:
    // SYN opens the flow, the ACK rides it, RST tears it down, and the
    // trailing ACK must then be rejected.
    let steps = [
        (
            Direction::Out,
            tcp_pkt(client, 50000, server, 443, SYN),
            Action::Allow,
        ),
        (
            Direction::In,
            tcp_pkt(server, 443, client, 50000, ACK),
            Action::Allow,
        ),
        (
            Direction::Out,
            tcp_pkt(client, 50000, server, 443, RST | ACK),
            Action::Allow,
        ),
        (
            Direction::In,
            tcp_pkt(server, 443, client, 50000, ACK),
            Action::Deny,
        ),
    ];
    for (dir, pkt, expected) in steps {
        assert_eq!(fw.evaluate_packet(dir, &pkt, &server_id, None), expected);
    }
}
// ICMPv4 message type codes for Echo Request / Echo Reply (RFC 792).
const ECHO_REQUEST_V4: u8 = 8;
const ECHO_REPLY_V4: u8 = 0;
// ICMPv6 message type codes for Echo Request / Echo Reply (RFC 4443).
const ECHO_REQUEST_V6: u8 = 128;
const ECHO_REPLY_V6: u8 = 129;
/// Builds a minimal IPv4/ICMP packet (20-byte IP header + 8-byte ICMP header)
/// and parses it with `parse_packet_info`.
///
/// Only version/IHL, the protocol number, the two addresses, the ICMP type and
/// the big-endian echo identifier are filled in; every other byte (lengths,
/// checksums, sequence number) is left zero.
fn icmp_pkt(src: Ipv4Addr, dst: Ipv4Addr, icmp_type: u8, id: u16) -> PacketInfo {
    const IP_HDR_LEN: usize = 20;
    const ICMP_HDR_LEN: usize = 8;
    let mut buf = vec![0u8; IP_HDR_LEN + ICMP_HDR_LEN];
    buf[0] = 0x45; // version 4, IHL = 5 words
    buf[9] = 1; // protocol: ICMP
    buf[12..16].copy_from_slice(&src.octets());
    buf[16..20].copy_from_slice(&dst.octets());
    buf[IP_HDR_LEN] = icmp_type;
    // Echo identifier lives at offset 4..6 of the ICMP header.
    buf[IP_HDR_LEN + 4..IP_HDR_LEN + 6].copy_from_slice(&id.to_be_bytes());
    parse_packet_info(&buf).unwrap()
}
/// Builds a minimal IPv6/ICMPv6 packet (40-byte fixed header + 8-byte ICMPv6
/// header) and parses it with `parse_packet_info`.
///
/// Only the version nibble, the next-header field, the two addresses, the
/// ICMPv6 type and the big-endian echo identifier are filled in; every other
/// byte (payload length, hop limit, checksum, sequence number) is left zero.
fn icmp6_pkt(src: Ipv6Addr, dst: Ipv6Addr, icmp_type: u8, id: u16) -> PacketInfo {
    const IP6_HDR_LEN: usize = 40;
    const ICMP6_HDR_LEN: usize = 8;
    let mut buf = vec![0u8; IP6_HDR_LEN + ICMP6_HDR_LEN];
    buf[0] = 0x60; // version 6
    buf[6] = 58; // next header: ICMPv6
    buf[8..24].copy_from_slice(&src.octets());
    buf[24..40].copy_from_slice(&dst.octets());
    buf[IP6_HDR_LEN] = icmp_type;
    // Echo identifier lives at offset 4..6 of the ICMPv6 header.
    buf[IP6_HDR_LEN + 4..IP6_HDR_LEN + 6].copy_from_slice(&id.to_be_bytes());
    parse_packet_info(&buf).unwrap()
}
#[test]
fn inbound_echo_request_not_masked_by_prior_outbound_ping() {
    // Inbound closed, outbound open, no explicit rules.
    let config = FirewallConfig {
        default_inbound: Action::Deny,
        default_outbound: Action::Allow,
        reject: false,
        rules: vec![],
    };
    let fw = SharedFirewall::new(config);
    let local = Ipv4Addr::new(100, 64, 0, 2);
    let remote = Ipv4Addr::new(100, 64, 0, 3);
    let remote_id = test_id(1);
    let verdict =
        |dir: Direction, pkt: &PacketInfo| fw.evaluate_packet(dir, pkt, &remote_id, None);

    // We ping the peer...
    let our_request = icmp_pkt(local, remote, ECHO_REQUEST_V4, 0x1234);
    assert_eq!(verdict(Direction::Out, &our_request), Action::Allow);

    // ...which must not let the peer ping *us* with the same identifier.
    let their_request = icmp_pkt(remote, local, ECHO_REQUEST_V4, 0x1234);
    assert_eq!(verdict(Direction::In, &their_request), Action::Deny);
}
#[test]
fn inbound_echo_reply_allowed_for_our_outbound_ping() {
    // Inbound closed, outbound open, no explicit rules.
    let fw = SharedFirewall::new(FirewallConfig {
        default_inbound: Action::Deny,
        default_outbound: Action::Allow,
        reject: false,
        rules: vec![],
    });
    let (local, remote) = (Ipv4Addr::new(100, 64, 0, 2), Ipv4Addr::new(100, 64, 0, 3));
    let remote_id = test_id(1);
    let ident: u16 = 0x1234;
    let verdict =
        |dir: Direction, pkt: &PacketInfo| fw.evaluate_packet(dir, pkt, &remote_id, None);

    // Our echo request leaves on the permissive outbound default...
    let request = icmp_pkt(local, remote, ECHO_REQUEST_V4, ident);
    assert_eq!(verdict(Direction::Out, &request), Action::Allow);

    // ...and the peer's reply with the same identifier is let back in even
    // though the inbound default is deny.
    let reply = icmp_pkt(remote, local, ECHO_REPLY_V4, ident);
    assert_eq!(verdict(Direction::In, &reply), Action::Allow);
}
#[test]
fn inbound_echo_reply_for_unrelated_id_is_not_return_traffic() {
    // Inbound closed, outbound open, no explicit rules.
    let fw = SharedFirewall::new(FirewallConfig {
        default_inbound: Action::Deny,
        default_outbound: Action::Allow,
        reject: false,
        rules: vec![],
    });
    let me = Ipv4Addr::new(100, 64, 0, 2);
    let peer = Ipv4Addr::new(100, 64, 0, 3);
    let peer_id = test_id(1);

    // Our ping (id 0x1111) must actually be allowed out. Otherwise no flow
    // would exist and the Deny below would pass for the wrong reason.
    let out_req = icmp_pkt(me, peer, ECHO_REQUEST_V4, 0x1111);
    assert_eq!(
        fw.evaluate_packet(Direction::Out, &out_req, &peer_id, None),
        Action::Allow
    );

    // A reply carrying a different identifier is not return traffic for that ping.
    let in_reply = icmp_pkt(peer, me, ECHO_REPLY_V4, 0x2222);
    assert_eq!(
        fw.evaluate_packet(Direction::In, &in_reply, &peer_id, None),
        Action::Deny
    );
}
#[test]
fn replying_to_a_ping_does_not_whitelist_the_pinger() {
    let config = FirewallConfig {
        default_inbound: Action::Deny,
        default_outbound: Action::Allow,
        reject: false,
        rules: vec![],
    };
    let fw = SharedFirewall::new(config);
    let local = Ipv4Addr::new(100, 64, 0, 2);
    let remote = Ipv4Addr::new(100, 64, 0, 3);
    let remote_id = test_id(1);
    let verdict =
        |dir: Direction, pkt: &PacketInfo| fw.evaluate_packet(dir, pkt, &remote_id, None);

    // Sending an echo *reply* is allowed by the outbound default...
    let our_reply = icmp_pkt(local, remote, ECHO_REPLY_V4, 0x1234);
    assert_eq!(verdict(Direction::Out, &our_reply), Action::Allow);

    // ...but must not turn the peer's echo *requests* into accepted traffic.
    let their_request = icmp_pkt(remote, local, ECHO_REQUEST_V4, 0x1234);
    assert_eq!(verdict(Direction::In, &their_request), Action::Deny);
}
#[test]
fn icmpv6_echo_request_not_masked_by_prior_outbound_ping() {
    let fw = SharedFirewall::new(FirewallConfig {
        default_inbound: Action::Deny,
        default_outbound: Action::Allow,
        reject: false,
        rules: vec![],
    });
    let local = Ipv6Addr::new(0x2, 0, 0, 0, 0, 0, 0, 2);
    let remote = Ipv6Addr::new(0x2, 0, 0, 0, 0, 0, 0, 3);
    let remote_id = test_id(1);
    let verdict =
        |dir: Direction, pkt: &PacketInfo| fw.evaluate_packet(dir, pkt, &remote_id, None);

    // Our ICMPv6 ping goes out...
    let our_request = icmp6_pkt(local, remote, ECHO_REQUEST_V6, 0x1234);
    assert_eq!(verdict(Direction::Out, &our_request), Action::Allow);

    // ...the peer may not ping us back under the same identifier...
    let their_request = icmp6_pkt(remote, local, ECHO_REQUEST_V6, 0x1234);
    assert_eq!(verdict(Direction::In, &their_request), Action::Deny);

    // ...but its genuine reply is still accepted.
    let their_reply = icmp6_pkt(remote, local, ECHO_REPLY_V6, 0x1234);
    assert_eq!(verdict(Direction::In, &their_reply), Action::Allow);
}
use ray_proto::{HostSuggestions, SuggestedFirewall};
fn suggest(subject: &str, allows: &[(&str, &str)]) -> SuggestedFirewall {
let mut entry = HostSuggestions::default();
for (peer, ports) in allows {
entry
.allows
.insert((*peer).to_string(), (*ports).to_string());
}
let mut map = SuggestedFirewall::new();
map.insert(subject.to_string(), entry);
map
}
#[test]
fn materialize_resolves_peer_hostnames_and_expands_comma_ports() {
    let me = test_id(1);
    let peer = test_id(2);
    let resolve = |name: &str| match name {
        "me" => Some(me),
        "peer" => Some(peer),
        _ => None,
    };

    // Two comma-separated tokens for one peer should yield two allow rules.
    let suggestions = suggest("me", &[("peer", "tcp:9000,tcp:8123")]);
    let rules = materialize_suggestions("prod", "me", &suggestions, &resolve);
    let allows: Vec<_> = rules
        .iter()
        .filter(|rule| rule.action == Action::Allow)
        .collect();
    assert_eq!(allows.len(), 2);

    // Every generated rule is an inbound TCP allow, scoped to and originating
    // from "prod", and pinned to the resolved identity of "peer".
    for rule in allows.iter() {
        assert_eq!(rule.direction, Direction::In);
        assert_eq!(rule.network.as_deref(), Some("prod"));
        assert_eq!(rule.origin, RuleOrigin::Network("prod".to_string()));
        assert_eq!(rule.peer, PeerFilter::Identity(peer));
        assert_eq!(rule.protocol, Protocol::Tcp);
    }

    let starts: Vec<u16> = allows
        .iter()
        .map(|rule| rule.port.as_ref().unwrap().start)
        .collect();
    assert!(starts.contains(&9000));
    assert!(starts.contains(&8123));
}
#[test]
fn materialize_allow_list_is_additive_no_catch_all() {
    let me = test_id(1);
    let peer = test_id(2);
    let resolve = |name: &str| match name {
        "me" => Some(me),
        "peer" => Some(peer),
        _ => None,
    };
    let suggestions = suggest("me", &[("peer", "tcp:9000")]);
    let rules = materialize_suggestions("prod", "me", &suggestions, &resolve);

    // An allow list adds exactly its own entries and nothing else.
    assert_eq!(rules.len(), 1, "only the suggested allow rule");
    let only = &rules[0];
    assert_eq!(only.action, Action::Allow);
    assert_eq!(only.peer, PeerFilter::Identity(peer));
    assert!(
        !rules.iter().any(|r| r.action == Action::Deny),
        "no catch-all deny should be appended"
    );
}
#[test]
fn materialize_deny_only_blacklist_no_catch_all() {
    let me = test_id(1);
    let eve = test_id(3);
    let resolve = |name: &str| match name {
        "me" => Some(me),
        "eve" => Some(eve),
        _ => None,
    };

    // "me" only blacklists "eve" on every protocol; no allow list at all.
    let entry = HostSuggestions {
        denies: std::iter::once(("eve".to_string(), "any".to_string())).collect(),
        ..Default::default()
    };
    let mut suggestions = SuggestedFirewall::new();
    suggestions.insert("me".to_string(), entry);
    let rules = materialize_suggestions("prod", "me", &suggestions, &resolve);

    // A single targeted deny, with no peer-wide rule alongside it.
    assert_eq!(rules.len(), 1);
    let only = &rules[0];
    assert_eq!(only.action, Action::Deny);
    assert_eq!(only.peer, PeerFilter::Identity(eve));
    assert!(!rules.iter().any(|r| r.peer == PeerFilter::Any));
}
#[test]
fn materialize_no_allow_list_no_default_keeps_open() {
    let me = test_id(1);
    let resolve = |name: &str| if name == "me" { Some(me) } else { None };

    // An entry for "me" with no allows and no denies.
    let mut suggestions = SuggestedFirewall::new();
    suggestions.insert("me".to_string(), HostSuggestions::default());

    let rules = materialize_suggestions("prod", "me", &suggestions, &resolve);
    assert!(rules.is_empty(), "expected no rules for an open subject");
}
#[test]
fn materialize_skips_unresolved_peers() {
    // A resolver that knows nobody, so "ghost" never maps to an identity and
    // its allow entry must not produce a rule.
    let resolve_nothing = |_name: &str| None;
    let suggestions = suggest("me", &[("ghost", "tcp:9000")]);
    let rules = materialize_suggestions("prod", "me", &suggestions, &resolve_nothing);
    assert!(rules.is_empty());
}
#[test]
fn materialize_no_rules_for_unknown_subject() {
    let me = test_id(1);
    let resolve = |name: &str| (name == "me").then_some(me);

    // The only entry is for host "other"; materializing for "me" yields nothing.
    let foreign = suggest("other", &[("me", "tcp:9000")]);
    let rules = materialize_suggestions("prod", "me", &foreign, &resolve);
    assert!(rules.is_empty());
}
#[test]
fn materialize_icmp_udp_and_wildcard_tokens() {
    let me = test_id(1);
    let peer = test_id(2);
    let resolve = |name: &str| match name {
        "me" => Some(me),
        "peer" => Some(peer),
        _ => None,
    };

    // One token per protocol flavour: port-less ICMP, all-ports TCP, a single
    // UDP port, and the protocol wildcard.
    let suggestions = suggest("me", &[("peer", "icmp,tcp:*,udp:53,any")]);
    let rules = materialize_suggestions("prod", "me", &suggestions, &resolve);
    let allows: Vec<_> = rules.iter().filter(|r| r.action == Action::Allow).collect();
    assert_eq!(allows.len(), 4);

    let by_protocol = |proto: Protocol| allows.iter().find(|r| r.protocol == proto).unwrap();

    let icmp = by_protocol(Protocol::Icmp);
    assert!(icmp.port.is_none(), "icmp rule must be port-less");

    let tcp_all = by_protocol(Protocol::Tcp);
    assert_eq!(tcp_all.port.as_ref().unwrap().end, u16::MAX);

    let udp = by_protocol(Protocol::Udp);
    assert_eq!(udp.port.as_ref().unwrap().start, 53);

    let wildcard = by_protocol(Protocol::Any);
    assert!(wildcard.port.is_none());
}
#[test]
fn materialize_wildcard_subject_applies_to_any_host() {
    let me = test_id(1);
    let resolve = |name: &str| if name == "me" { Some(me) } else { None };

    // Subject `*` with peer `*`: every host should allow every peer on tcp/6969.
    let everyone = HostSuggestions {
        allows: std::iter::once(("*".to_string(), "tcp:6969".to_string())).collect(),
        ..Default::default()
    };
    let mut suggestions = SuggestedFirewall::new();
    suggestions.insert("*".to_string(), everyone);

    let rules = materialize_suggestions("prod", "me", &suggestions, &resolve);
    let allow = rules
        .iter()
        .find(|r| r.action == Action::Allow)
        .expect("wildcard subject should materialize for any host");
    assert_eq!(allow.peer, PeerFilter::Any, "`*` peer ⇒ any peer");
    assert_eq!(allow.protocol, Protocol::Tcp);
    assert_eq!(allow.port.as_ref().unwrap().start, 6969);
    assert_eq!(allow.network.as_deref(), Some("prod"));
    assert!(!rules.iter().any(|r| r.action == Action::Deny));
}
#[test]
fn materialize_any_peer_star_bypasses_resolution() {
    // The resolver fails for every name. A `*` peer must still produce a rule
    // because it never needs resolving.
    let resolve_nothing = |_name: &str| None;
    let suggestions = suggest("me", &[("*", "udp:53")]);
    let rules = materialize_suggestions("prod", "me", &suggestions, &resolve_nothing);

    let allow = rules
        .iter()
        .find(|r| r.action == Action::Allow)
        .expect("`*` peer must not be dropped by the resolver");
    assert_eq!(allow.peer, PeerFilter::Any);
    assert_eq!(allow.protocol, Protocol::Udp);
    assert_eq!(allow.port.as_ref().unwrap().start, 53);
}
#[test]
fn materialize_merges_own_subject_and_wildcard() {
    let me = test_id(1);
    let peer = test_id(2);
    let resolve = |name: &str| match name {
        "me" => Some(me),
        "peer" => Some(peer),
        _ => None,
    };

    // A host-specific entry for "me"...
    let mut suggestions = suggest("me", &[("peer", "tcp:22")]);
    // ...plus a wildcard-subject entry that also applies to "me".
    let wildcard = HostSuggestions {
        allows: std::iter::once(("*".to_string(), "tcp:6969".to_string())).collect(),
        ..Default::default()
    };
    suggestions.insert("*".to_string(), wildcard);

    let rules = materialize_suggestions("prod", "me", &suggestions, &resolve);
    let allows: Vec<_> = rules.iter().filter(|r| r.action == Action::Allow).collect();
    assert_eq!(allows.len(), 2, "own subject + wildcard subject");

    let has_peer_ssh = allows
        .iter()
        .any(|r| r.peer == PeerFilter::Identity(peer) && r.port.as_ref().unwrap().start == 22);
    assert!(has_peer_ssh);

    let has_any_6969 = allows
        .iter()
        .any(|r| r.peer == PeerFilter::Any && r.port.as_ref().unwrap().start == 6969);
    assert!(has_any_6969);
}
#[test]
fn parse_spec_token_grammar() {
    let parse_ok = |token: &str| parse_spec_token(token).unwrap();

    // Port-bearing protocols: a single port, an inclusive range, and `:*`.
    let (proto, range) = parse_ok("tcp:22");
    assert_eq!(proto, Protocol::Tcp);
    assert_eq!(range.unwrap().start, 22);

    let (proto, range) = parse_ok("tcp:80-443");
    assert_eq!(proto, Protocol::Tcp);
    assert_eq!(range.unwrap().end, 443);

    let (proto, range) = parse_ok("udp:*");
    assert_eq!(proto, Protocol::Udp);
    assert_eq!(range.unwrap().end, u16::MAX);

    // Port-less tokens carry no range, with or without a `:*` suffix.
    for (token, expected) in [
        ("icmp", Protocol::Icmp),
        ("icmp:*", Protocol::Icmp),
        ("any", Protocol::Any),
    ] {
        let (proto, range) = parse_ok(token);
        assert_eq!(proto, expected);
        assert!(range.is_none());
    }

    // A bare port-bearing protocol name means "every port".
    let (proto, range) = parse_ok("tcp");
    assert_eq!(proto, Protocol::Tcp);
    assert_eq!(range.unwrap().end, u16::MAX);

    // Malformed tokens are rejected.
    for bad in ["22", "foo:22", "tcp:"] {
        assert!(parse_spec_token(bad).is_err());
    }
}
#[test]
fn replace_network_rules_swaps_network_set_keeps_local() {
    // A locally-authored rule. A network refresh must leave it untouched.
    let local_rule = FirewallRule {
        direction: Direction::In,
        action: Action::Allow,
        protocol: Protocol::Tcp,
        port: Some(PortRange { start: 22, end: 22 }),
        peer: PeerFilter::Any,
        network: None,
        origin: RuleOrigin::Local,
    };
    // A rule previously installed by "prod". The refresh should replace it.
    let stale_net = FirewallRule {
        direction: Direction::In,
        action: Action::Allow,
        protocol: Protocol::Tcp,
        port: Some(PortRange {
            start: 9000,
            end: 9000,
        }),
        peer: PeerFilter::Identity(test_id(9)),
        network: Some("prod".to_string()),
        origin: RuleOrigin::Network("prod".to_string()),
    };
    // A rule owned by a different network ("dev"). It must survive a "prod" refresh.
    let other_net = FirewallRule {
        direction: Direction::In,
        action: Action::Allow,
        protocol: Protocol::Tcp,
        port: Some(PortRange {
            start: 8080,
            end: 8080,
        }),
        peer: PeerFilter::Any,
        network: Some("dev".to_string()),
        origin: RuleOrigin::Network("dev".to_string()),
    };
    let config = FirewallConfig {
        default_inbound: Action::Allow,
        default_outbound: Action::Allow,
        reject: false,
        rules: vec![local_rule.clone(), stale_net.clone(), other_net.clone()],
    };
    let fw = SharedFirewall::new(config);

    // The new rule set "prod" wants: a single catch-all inbound deny.
    let fresh = FirewallRule {
        direction: Direction::In,
        action: Action::Deny,
        protocol: Protocol::Any,
        port: None,
        peer: PeerFilter::Any,
        network: Some("prod".to_string()),
        origin: RuleOrigin::Network("prod".to_string()),
    };
    let updated = fw.replace_network_rules("prod", vec![fresh]);

    // Rules not owned by "prod" are kept verbatim (not merely "some rule with
    // that origin still exists").
    assert!(updated.rules.contains(&local_rule));
    assert!(updated.rules.contains(&other_net));

    // The stale "prod" rule is gone...
    assert!(!updated.rules.contains(&stale_net));

    // ...and exactly the fresh deny now represents "prod".
    let prod: Vec<_> = updated
        .rules
        .iter()
        .filter(|r| matches!(&r.origin, RuleOrigin::Network(n) if n == "prod"))
        .collect();
    assert_eq!(prod.len(), 1);
    assert_eq!(prod[0].action, Action::Deny);
    assert_eq!(prod[0].peer, PeerFilter::Any);
}
#[test]
fn ssh_passthrough_toggles_a_single_managed_rule() {
    let fw = SharedFirewall::new(FirewallConfig::default());

    // A user-authored deny on port 22 that toggling passthrough must preserve.
    let user_rule = FirewallRule {
        direction: Direction::In,
        action: Action::Deny,
        protocol: Protocol::Tcp,
        port: Some(PortRange { start: 22, end: 22 }),
        peer: PeerFilter::Any,
        network: None,
        origin: RuleOrigin::Local,
    };
    let mut seeded = (*fw.get_config()).clone();
    seeded.rules.push(user_rule.clone());
    fw.update(seeded);

    // Enabling twice must still leave exactly one SSH-managed rule.
    fw.set_ssh_passthrough(true);
    let enabled = fw.set_ssh_passthrough(true);
    let managed: Vec<_> = enabled
        .rules
        .iter()
        .filter(|r| r.origin == RuleOrigin::Ssh)
        .collect();
    assert_eq!(managed.len(), 1);
    assert_eq!(managed[0].action, Action::Allow);
    assert_eq!(managed[0].port, Some(PortRange { start: 22, end: 22 }));
    // The managed rule sits at the front, and the user's rule is untouched.
    assert_eq!(enabled.rules[0].origin, RuleOrigin::Ssh);
    assert!(enabled.rules.contains(&user_rule));

    // Disabling removes every SSH-managed rule but keeps the user's rule.
    let disabled = fw.set_ssh_passthrough(false);
    assert!(disabled.rules.iter().all(|r| r.origin != RuleOrigin::Ssh));
    assert!(disabled.rules.contains(&user_rule));
}
}