1mod action;
4mod error;
5mod net;
6mod port;
7
8use std::{fmt::Display, ops::RangeInclusive};
9
10pub use action::Action;
11pub use error::RuleError;
12use ipnet::IpNet;
13use net::IntoIpNet;
14use port::IntoPortRange;
15pub use push_packet_common::Protocol;
16
/// IP address family a [`Rule`] is restricted to.
///
/// Computed by [`Rule::address_family`] from the rule's source/destination CIDRs.
// `#[non_exhaustive]` has no effect inside the defining crate; it only becomes
// meaningful if this enum is ever exported publicly.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum AddressFamily {
    /// Neither a source nor a destination CIDR is set.
    Any,
    /// The rule's CIDR is IPv4.
    Ipv4,
    /// The rule's CIDR is IPv6.
    Ipv6,
}
23
/// Opaque numeric identifier for a rule.
///
/// The wrapped integer is only constructible inside this crate.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct RuleId(pub(crate) u32);

/// Renders the identifier as its bare integer value.
impl Display for RuleId {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(raw) = self;
        // Delegate to `u32`'s own `Display` so caller-supplied width/alignment flags apply.
        Display::fmt(raw, formatter)
    }
}
34
35pub struct Rule {
38 pub(crate) action: Action,
39 pub(crate) protocol: Option<Protocol>,
40 pub(crate) source_cidr: Option<IpNet>,
41 pub(crate) source_port: Option<RangeInclusive<u16>>,
42 pub(crate) destination_cidr: Option<IpNet>,
43 pub(crate) destination_port: Option<RangeInclusive<u16>>,
44}
45
46impl TryFrom<RuleBuilder> for Rule {
47 type Error = RuleError;
48 fn try_from(value: RuleBuilder) -> Result<Self, Self::Error> {
49 value.build()
50 }
51}
52
53impl Rule {
54 pub(crate) fn address_family(&self) -> AddressFamily {
55 match (self.source_cidr, self.destination_cidr) {
56 (Some(net), _) | (_, Some(net)) => match net {
57 IpNet::V4(_) => AddressFamily::Ipv4,
58 IpNet::V6(_) => AddressFamily::Ipv6,
59 },
60 _ => AddressFamily::Any,
61 }
62 }
63 #[must_use]
65 pub fn builder() -> RuleBuilder {
66 RuleBuilder::default()
67 }
68
69 #[must_use]
73 pub fn action(action: Action) -> RuleBuilder {
74 Rule::builder().action(action)
75 }
76
77 #[must_use]
79 pub fn protocol(protocol: Protocol) -> RuleBuilder {
80 Rule::builder().protocol(protocol)
81 }
82
83 pub fn source_cidr(cidr_range: impl IntoIpNet) -> RuleBuilder {
92 Rule::builder().source_cidr(cidr_range)
93 }
94
95 pub fn source_port(port: impl IntoPortRange) -> RuleBuilder {
99 Rule::builder().source_port(port)
100 }
101
102 pub fn destination_cidr(cidr_range: impl IntoIpNet) -> RuleBuilder {
111 Rule::builder().destination_cidr(cidr_range)
112 }
113
114 pub fn destination_port(port: impl IntoPortRange) -> RuleBuilder {
118 Rule::builder().destination_port(port)
119 }
120}
121
122#[derive(Default)]
125pub struct RuleBuilder {
126 action: Option<Action>,
127 protocol: Option<Protocol>,
128 source_cidr: Option<Result<IpNet, RuleError>>,
129 source_port: Option<RangeInclusive<u16>>,
130 destination_cidr: Option<Result<IpNet, RuleError>>,
131 destination_port: Option<RangeInclusive<u16>>,
132}
133
134impl RuleBuilder {
135 #[must_use]
139 pub fn action(mut self, action: Action) -> Self {
140 self.action = Some(action);
141 self
142 }
143
144 #[must_use]
146 pub fn protocol(mut self, protocol: Protocol) -> Self {
147 self.protocol = Some(protocol);
148 self
149 }
150
151 #[must_use]
160 pub fn source_cidr(mut self, cidr_range: impl IntoIpNet) -> Self {
161 let source_cidr = cidr_range.into_ip_net();
162 self.source_cidr = Some(source_cidr);
163 self
164 }
165
166 #[must_use]
170 pub fn source_port(mut self, port: impl IntoPortRange) -> Self {
171 self.source_port = Some(port.into_port_range());
172 self
173 }
174
175 #[must_use]
184 pub fn destination_cidr(mut self, cidr_range: impl IntoIpNet) -> Self {
185 let destination_cidr = cidr_range.into_ip_net();
186 self.destination_cidr = Some(destination_cidr);
187 self
188 }
189
190 #[must_use]
194 pub fn destination_port(mut self, port: impl IntoPortRange) -> Self {
195 self.destination_port = Some(port.into_port_range());
196 self
197 }
198
199 pub fn build(self) -> Result<Rule, RuleError> {
206 let Self {
207 action,
208 protocol,
209 source_cidr,
210 source_port,
211 destination_cidr,
212 destination_port,
213 } = self;
214
215 let action = action.ok_or(RuleError::MissingAction)?;
216
217 if protocol.is_none()
218 && source_cidr.is_none()
219 && source_port.is_none()
220 && destination_cidr.is_none()
221 && destination_port.is_none()
222 {
223 return Err(RuleError::MissingConstraint);
224 }
225
226 let (source_cidr, destination_cidr) = match (source_cidr, destination_cidr) {
227 (Some(src), Some(dst)) => {
228 let (src, dst) = (src?, dst?);
229 match (&src, &dst) {
230 (IpNet::V4(_), IpNet::V6(_)) | (IpNet::V6(_), IpNet::V4(_)) => {
231 return Err(RuleError::IncompatibleAddresses);
232 }
233 _ => (Some(src), Some(dst)),
234 }
235 }
236 (Some(src), None) => (Some(src?), None),
237 (None, Some(src)) => (None, Some(src?)),
238 _ => (None, None),
239 };
240
241 Ok(Rule {
242 action,
243 protocol,
244 source_cidr,
245 source_port,
246 destination_cidr,
247 destination_port,
248 })
249 }
250}
251
#[cfg(test)]
mod tests {

    use crate::rules::{AddressFamily, Rule, action::Action, error::RuleError};

    #[test]
    fn rule_builder_requires_action() {
        assert!(matches!(
            Rule::source_cidr("127.0.0.1").build(),
            Err(RuleError::MissingAction)
        ));
    }

    #[test]
    fn rule_builder_checks_action_before_constraints() {
        // An empty builder lacks both; the missing action must be reported first.
        assert!(matches!(
            Rule::builder().build(),
            Err(RuleError::MissingAction)
        ));
    }

    #[test]
    fn rule_builder_requires_a_constraint() {
        assert!(matches!(
            Rule::action(Action::Pass).build(),
            Err(RuleError::MissingConstraint)
        ));
    }

    #[test]
    fn rule_builder_builds_with_one_constraint_and_action() {
        assert!(
            Rule::protocol(push_packet_common::Protocol::Tcp)
                .action(Action::Pass)
                .build()
                .is_ok()
        );
    }

    #[test]
    fn rule_builder_builds_with_all_constraints_and_action() {
        let rule = Rule::builder()
            .protocol(push_packet_common::Protocol::Tcp)
            .source_cidr("127.0.0.1")
            .destination_cidr("127.0.0.1")
            .source_port(3000)
            .destination_port(80)
            .action(Action::Route)
            .build();
        assert!(rule.is_ok());
    }

    #[test]
    fn rule_builder_bad_ip_propagates() {
        assert!(
            Rule::builder()
                .source_cidr("badip")
                .action(Action::Pass)
                .build()
                .is_err_and(|e| matches!(e, RuleError::InvalidAddress { .. }))
        );
    }

    #[test]
    fn rule_builder_bad_destination_ip_propagates() {
        assert!(
            Rule::destination_cidr("badip")
                .action(Action::Pass)
                .build()
                .is_err_and(|e| matches!(e, RuleError::InvalidAddress { .. }))
        );
    }

    #[test]
    fn rule_builder_rejects_mixed_address_families() {
        assert!(matches!(
            Rule::source_cidr("127.0.0.1")
                .destination_cidr("::1")
                .action(Action::Pass)
                .build(),
            Err(RuleError::IncompatibleAddresses)
        ));
    }

    #[test]
    fn address_family_reflects_configured_cidrs() {
        let Ok(without_cidr) = Rule::protocol(push_packet_common::Protocol::Tcp)
            .action(Action::Pass)
            .build()
        else {
            panic!("protocol-only rule should build");
        };
        assert!(matches!(without_cidr.address_family(), AddressFamily::Any));

        let Ok(source_only) = Rule::source_cidr("127.0.0.1")
            .action(Action::Pass)
            .build()
        else {
            panic!("source-only rule should build");
        };
        assert!(matches!(source_only.address_family(), AddressFamily::Ipv4));

        let Ok(destination_only) = Rule::destination_cidr("127.0.0.1")
            .action(Action::Pass)
            .build()
        else {
            panic!("destination-only rule should build");
        };
        assert!(matches!(
            destination_only.address_family(),
            AddressFamily::Ipv4
        ));
    }
}