Skip to main content

push_packet/rules/
mod.rs

1//! Rule definitions and builders.
2
3mod action;
4mod error;
5mod net;
6mod port;
7
8use std::{fmt::Display, ops::RangeInclusive};
9
10pub use action::Action;
11pub use error::RuleError;
12use ipnet::IpNet;
13use net::IntoIpNet;
14use port::IntoPortRange;
15pub use push_packet_common::Protocol;
16
17#[non_exhaustive]
18pub(crate) enum AddressFamily {
19    Any,
20    Ipv4,
21    Ipv6,
22}
23
24/// ID for a [`Rule`]. This can be used to track rules for dynamic removal. Removed [`RuleId`]s are
25/// reclaimed.
26#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
27pub struct RuleId(pub(crate) u32);
28
29impl Display for RuleId {
30    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31        self.0.fmt(f)
32    }
33}
34
35/// A rule for controlling packet routing. Rules can be build with [`RuleBuilder`] and require at
36/// least one filter constraint, and an action.
37pub struct Rule {
38    pub(crate) action: Action,
39    pub(crate) protocol: Option<Protocol>,
40    pub(crate) source_cidr: Option<IpNet>,
41    pub(crate) source_port: Option<RangeInclusive<u16>>,
42    pub(crate) destination_cidr: Option<IpNet>,
43    pub(crate) destination_port: Option<RangeInclusive<u16>>,
44}
45
46impl TryFrom<RuleBuilder> for Rule {
47    type Error = RuleError;
48    fn try_from(value: RuleBuilder) -> Result<Self, Self::Error> {
49        value.build()
50    }
51}
52
53impl Rule {
54    pub(crate) fn address_family(&self) -> AddressFamily {
55        match (self.source_cidr, self.destination_cidr) {
56            (Some(net), _) | (_, Some(net)) => match net {
57                IpNet::V4(_) => AddressFamily::Ipv4,
58                IpNet::V6(_) => AddressFamily::Ipv6,
59            },
60            _ => AddressFamily::Any,
61        }
62    }
63    /// Creates a [`RuleBuilder`]
64    #[must_use]
65    pub fn builder() -> RuleBuilder {
66        RuleBuilder::default()
67    }
68
69    /// Creates a [`RuleBuilder`] and sets the rule's [`Action`]
70    ///
71    /// The action applies to all packets matching the rule, unless overridden by successive rules.
72    #[must_use]
73    pub fn action(action: Action) -> RuleBuilder {
74        Rule::builder().action(action)
75    }
76
77    /// Creates a [`RuleBuilder`] and sets the rule's [`Protocol`]
78    #[must_use]
79    pub fn protocol(protocol: Protocol) -> RuleBuilder {
80        Rule::builder().protocol(protocol)
81    }
82
83    /// Creates a [`RuleBuilder`] and sets the source CIDR.
84    ///
85    /// Accepts any IP address or CIDR notation:
86    /// - `"127.0.0.1"`: matches a single IP
87    /// - `"10.0.0.0/24"`: matches a CIDR
88    ///
89    /// This additionally accepts [`std::net::IpAddr`], [`std::net::Ipv4Addr`],
90    /// [`std::net::Ipv6Addr`], [`ipnet::IpNet`], [`ipnet::Ipv4Net`], and [`ipnet::Ipv6Net`].
91    pub fn source_cidr(cidr_range: impl IntoIpNet) -> RuleBuilder {
92        Rule::builder().source_cidr(cidr_range)
93    }
94
95    /// Creates a [`RuleBuilder`] and sets the source port
96    ///
97    /// Accepts a [`u16`] or range
98    pub fn source_port(port: impl IntoPortRange) -> RuleBuilder {
99        Rule::builder().source_port(port)
100    }
101
102    /// Creates a [`RuleBuilder`] and sets the destination CIDR.
103    ///
104    /// Accepts any IP address or CIDR notation:
105    /// - `"127.0.0.1"`: matches a single IP
106    /// - `"10.0.0.0/24"`: matches a CIDR
107    ///
108    /// This additionally accepts [`std::net::IpAddr`], [`std::net::Ipv4Addr`],
109    /// [`std::net::Ipv6Addr`], [`ipnet::IpNet`], [`ipnet::Ipv4Net`], and [`ipnet::Ipv6Net`].
110    pub fn destination_cidr(cidr_range: impl IntoIpNet) -> RuleBuilder {
111        Rule::builder().destination_cidr(cidr_range)
112    }
113
114    /// Creates a [`RuleBuilder`] and sets the destination port
115    ///
116    /// Accepts a [`u16`] or range
117    pub fn destination_port(port: impl IntoPortRange) -> RuleBuilder {
118        Rule::builder().destination_port(port)
119    }
120}
121
122/// A builder struct for [`Rule`]s. This should generally be constructed with [`Rule::builder`], or
123/// a shortcut such as [`Rule::source_cidr`].
124#[derive(Default)]
125pub struct RuleBuilder {
126    action: Option<Action>,
127    protocol: Option<Protocol>,
128    source_cidr: Option<Result<IpNet, RuleError>>,
129    source_port: Option<RangeInclusive<u16>>,
130    destination_cidr: Option<Result<IpNet, RuleError>>,
131    destination_port: Option<RangeInclusive<u16>>,
132}
133
134impl RuleBuilder {
135    /// Sets the rule's [`Action`]
136    ///
137    /// The action applies to all packets matching the rule, unless overridden by successive rules.
138    #[must_use]
139    pub fn action(mut self, action: Action) -> Self {
140        self.action = Some(action);
141        self
142    }
143
144    /// Sets the rule's [`Protocol`]
145    #[must_use]
146    pub fn protocol(mut self, protocol: Protocol) -> Self {
147        self.protocol = Some(protocol);
148        self
149    }
150
151    /// Sets the source CIDR.
152    ///
153    /// Accepts any IP address or CIDR notation:
154    /// - `"127.0.0.1"`: matches a single IP
155    /// - `"10.0.0.0/24"`: matches a CIDR
156    ///
157    /// This additionally accepts [`std::net::IpAddr`], [`std::net::Ipv4Addr`],
158    /// [`std::net::Ipv6Addr`], [`ipnet::IpNet`], [`ipnet::Ipv4Net`], and [`ipnet::Ipv6Net`].
159    #[must_use]
160    pub fn source_cidr(mut self, cidr_range: impl IntoIpNet) -> Self {
161        let source_cidr = cidr_range.into_ip_net();
162        self.source_cidr = Some(source_cidr);
163        self
164    }
165
166    /// Sets the source port
167    ///
168    /// Accepts a [`u16`] or range
169    #[must_use]
170    pub fn source_port(mut self, port: impl IntoPortRange) -> Self {
171        self.source_port = Some(port.into_port_range());
172        self
173    }
174
175    /// Sets the destination CIDR.
176    ///
177    /// Accepts any IP address or CIDR notation:
178    /// - `"127.0.0.1"`: matches a single IP
179    /// - `"10.0.0.0/24"`: matches a CIDR
180    ///
181    /// This additionally accepts [`std::net::IpAddr`], [`std::net::Ipv4Addr`],
182    /// [`std::net::Ipv6Addr`], [`ipnet::IpNet`], [`ipnet::Ipv4Net`], and [`ipnet::Ipv6Net`].
183    #[must_use]
184    pub fn destination_cidr(mut self, cidr_range: impl IntoIpNet) -> Self {
185        let destination_cidr = cidr_range.into_ip_net();
186        self.destination_cidr = Some(destination_cidr);
187        self
188    }
189
190    /// Sets the destination port
191    ///
192    /// Accepts a [`u16`] or range
193    #[must_use]
194    pub fn destination_port(mut self, port: impl IntoPortRange) -> Self {
195        self.destination_port = Some(port.into_port_range());
196        self
197    }
198
199    /// Builds the [`Rule`].
200    ///
201    /// # Errors
202    ///
203    /// Returns an [`RuleError`] if there is a missing action, invalid cidr, or no constraints (ips,
204    /// ports, protocols).
205    pub fn build(self) -> Result<Rule, RuleError> {
206        let Self {
207            action,
208            protocol,
209            source_cidr,
210            source_port,
211            destination_cidr,
212            destination_port,
213        } = self;
214
215        let action = action.ok_or(RuleError::MissingAction)?;
216
217        if protocol.is_none()
218            && source_cidr.is_none()
219            && source_port.is_none()
220            && destination_cidr.is_none()
221            && destination_port.is_none()
222        {
223            return Err(RuleError::MissingConstraint);
224        }
225
226        let (source_cidr, destination_cidr) = match (source_cidr, destination_cidr) {
227            (Some(src), Some(dst)) => {
228                let (src, dst) = (src?, dst?);
229                match (&src, &dst) {
230                    (IpNet::V4(_), IpNet::V6(_)) | (IpNet::V6(_), IpNet::V4(_)) => {
231                        return Err(RuleError::IncompatibleAddresses);
232                    }
233                    _ => (Some(src), Some(dst)),
234                }
235            }
236            (Some(src), None) => (Some(src?), None),
237            (None, Some(src)) => (None, Some(src?)),
238            _ => (None, None),
239        };
240
241        Ok(Rule {
242            action,
243            protocol,
244            source_cidr,
245            source_port,
246            destination_cidr,
247            destination_port,
248        })
249    }
250}
251
252#[cfg(test)]
253mod tests {
254
255    use crate::rules::{Rule, action::Action, error::RuleError};
256
257    #[test]
258    fn rule_builder_requires_action() {
259        assert!(matches!(
260            Rule::source_cidr("127.0.0.1").build(),
261            Err(RuleError::MissingAction)
262        ));
263    }
264
265    #[test]
266    fn rule_builder_requires_a_constraint() {
267        assert!(matches!(
268            Rule::action(Action::Pass).build(),
269            Err(RuleError::MissingConstraint)
270        ));
271    }
272
273    #[test]
274    fn rule_builder_builds_with_one_constraint_and_action() {
275        assert!(
276            Rule::protocol(push_packet_common::Protocol::Tcp)
277                .action(Action::Pass)
278                .build()
279                .is_ok()
280        );
281    }
282
283    #[test]
284    fn rule_builder_builds_with_all_constraints_and_action() {
285        let rule = Rule::builder()
286            .protocol(push_packet_common::Protocol::Tcp)
287            .source_cidr("127.0.0.1")
288            .destination_cidr("127.0.0.1")
289            .source_port(3000)
290            .destination_port(80)
291            .action(Action::Route)
292            .build();
293        assert!(rule.is_ok());
294    }
295
296    #[test]
297    fn rule_builder_bad_ip_propagates() {
298        assert!(
299            Rule::builder()
300                .source_cidr("badip")
301                .action(Action::Pass)
302                .build()
303                .is_err_and(|e| matches!(e, RuleError::InvalidAddress { .. }))
304        );
305    }
306}