mod action;
mod error;
mod net;
mod port;
use std::{fmt::Display, ops::RangeInclusive};
pub use action::Action;
pub use error::RuleError;
use ipnet::IpNet;
use net::IntoIpNet;
use port::IntoPortRange;
pub use push_packet_common::Protocol;
#[non_exhaustive]
pub(crate) enum AddressFamily {
Any,
Ipv4,
Ipv6,
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct RuleId(pub(crate) u32);
impl Display for RuleId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
pub struct Rule {
pub(crate) action: Action,
pub(crate) protocol: Option<Protocol>,
pub(crate) source_cidr: Option<IpNet>,
pub(crate) source_port: Option<RangeInclusive<u16>>,
pub(crate) destination_cidr: Option<IpNet>,
pub(crate) destination_port: Option<RangeInclusive<u16>>,
}
impl TryFrom<RuleBuilder> for Rule {
type Error = RuleError;
fn try_from(value: RuleBuilder) -> Result<Self, Self::Error> {
value.build()
}
}
impl Rule {
pub(crate) fn address_family(&self) -> AddressFamily {
match (self.source_cidr, self.destination_cidr) {
(Some(net), _) | (_, Some(net)) => match net {
IpNet::V4(_) => AddressFamily::Ipv4,
IpNet::V6(_) => AddressFamily::Ipv6,
},
_ => AddressFamily::Any,
}
}
#[must_use]
pub fn builder() -> RuleBuilder {
RuleBuilder::default()
}
#[must_use]
pub fn action(action: Action) -> RuleBuilder {
Rule::builder().action(action)
}
#[must_use]
pub fn protocol(protocol: Protocol) -> RuleBuilder {
Rule::builder().protocol(protocol)
}
pub fn source_cidr(cidr_range: impl IntoIpNet) -> RuleBuilder {
Rule::builder().source_cidr(cidr_range)
}
pub fn source_port(port: impl IntoPortRange) -> RuleBuilder {
Rule::builder().source_port(port)
}
pub fn destination_cidr(cidr_range: impl IntoIpNet) -> RuleBuilder {
Rule::builder().destination_cidr(cidr_range)
}
pub fn destination_port(port: impl IntoPortRange) -> RuleBuilder {
Rule::builder().destination_port(port)
}
}
#[derive(Default)]
pub struct RuleBuilder {
action: Option<Action>,
protocol: Option<Protocol>,
source_cidr: Option<Result<IpNet, RuleError>>,
source_port: Option<RangeInclusive<u16>>,
destination_cidr: Option<Result<IpNet, RuleError>>,
destination_port: Option<RangeInclusive<u16>>,
}
impl RuleBuilder {
#[must_use]
pub fn action(mut self, action: Action) -> Self {
self.action = Some(action);
self
}
#[must_use]
pub fn protocol(mut self, protocol: Protocol) -> Self {
self.protocol = Some(protocol);
self
}
#[must_use]
pub fn source_cidr(mut self, cidr_range: impl IntoIpNet) -> Self {
let source_cidr = cidr_range.into_ip_net();
self.source_cidr = Some(source_cidr);
self
}
#[must_use]
pub fn source_port(mut self, port: impl IntoPortRange) -> Self {
self.source_port = Some(port.into_port_range());
self
}
#[must_use]
pub fn destination_cidr(mut self, cidr_range: impl IntoIpNet) -> Self {
let destination_cidr = cidr_range.into_ip_net();
self.destination_cidr = Some(destination_cidr);
self
}
#[must_use]
pub fn destination_port(mut self, port: impl IntoPortRange) -> Self {
self.destination_port = Some(port.into_port_range());
self
}
pub fn build(self) -> Result<Rule, RuleError> {
let Self {
action,
protocol,
source_cidr,
source_port,
destination_cidr,
destination_port,
} = self;
let action = action.ok_or(RuleError::MissingAction)?;
if protocol.is_none()
&& source_cidr.is_none()
&& source_port.is_none()
&& destination_cidr.is_none()
&& destination_port.is_none()
{
return Err(RuleError::MissingConstraint);
}
let (source_cidr, destination_cidr) = match (source_cidr, destination_cidr) {
(Some(src), Some(dst)) => {
let (src, dst) = (src?, dst?);
match (&src, &dst) {
(IpNet::V4(_), IpNet::V6(_)) | (IpNet::V6(_), IpNet::V4(_)) => {
return Err(RuleError::IncompatibleAddresses);
}
_ => (Some(src), Some(dst)),
}
}
(Some(src), None) => (Some(src?), None),
(None, Some(src)) => (None, Some(src?)),
_ => (None, None),
};
Ok(Rule {
action,
protocol,
source_cidr,
source_port,
destination_cidr,
destination_port,
})
}
}
#[cfg(test)]
mod tests {
use crate::rules::{Rule, action::Action, error::RuleError};
#[test]
fn rule_builder_requires_action() {
assert!(matches!(
Rule::source_cidr("127.0.0.1").build(),
Err(RuleError::MissingAction)
));
}
#[test]
fn rule_builder_requires_a_constraint() {
assert!(matches!(
Rule::action(Action::Pass).build(),
Err(RuleError::MissingConstraint)
));
}
#[test]
fn rule_builder_builds_with_one_constraint_and_action() {
assert!(
Rule::protocol(push_packet_common::Protocol::Tcp)
.action(Action::Pass)
.build()
.is_ok()
);
}
#[test]
fn rule_builder_builds_with_all_constraints_and_action() {
let rule = Rule::builder()
.protocol(push_packet_common::Protocol::Tcp)
.source_cidr("127.0.0.1")
.destination_cidr("127.0.0.1")
.source_port(3000)
.destination_port(80)
.action(Action::Route)
.build();
assert!(rule.is_ok());
}
#[test]
fn rule_builder_bad_ip_propagates() {
assert!(
Rule::builder()
.source_cidr("badip")
.action(Action::Pass)
.build()
.is_err_and(|e| matches!(e, RuleError::InvalidAddress { .. }))
);
}
}