1#![cfg_attr(not(target_os = "linux"), allow(dead_code))]
8
9use serde_json::Value;
10
11use crate::control::listening::Proto;
12
/// How the `fips` inbound chain treats new traffic to a given protocol/port pair.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FilterState {
    /// No ruleset was loaded, so nothing filters the port.
    NoFirewall,
    /// A rule that fully matches the port accepts it.
    Accept,
    /// The port is dropped, by an explicit rule or by reaching the end of the chain.
    Drop,
    /// A rule that may match the port could not be evaluated precisely.
    Unknown,
}

impl FilterState {
    /// Returns the lowercase `snake_case` label for this state.
    pub fn as_str(self) -> &'static str {
        match self {
            Self::NoFirewall => "no_firewall",
            Self::Accept => "accept",
            Self::Drop => "drop",
            Self::Unknown => "unknown",
        }
    }
}
31
32pub struct FilterClassifier {
33 rules: Option<Vec<Rule>>,
34}
35
36#[derive(Debug, Clone)]
37struct Rule {
38 matches: Vec<MatchExpr>,
39 verdict: Verdict,
40}
41
42#[derive(Debug, Clone)]
43enum MatchExpr {
44 Iifname,
45 L4Proto(Proto),
46 Dport(Proto, PortMatch),
47 Unrecognized,
48}
49
/// The right-hand side of a recognized `dport` match.
#[derive(Debug, Clone)]
enum PortMatch {
    /// A single port, e.g. `tcp dport 22`.
    Single(u16),
    /// An anonymous set of ports, e.g. `tcp dport { 22, 80 }`.
    Set(Vec<u16>),
    /// An inclusive range `lo..=hi`, e.g. `tcp dport 22-25`; empty when `lo > hi`.
    Range(u16, u16),
}

impl PortMatch {
    /// Returns `true` if `port` is covered by this operand.
    fn matches(&self, port: u16) -> bool {
        match *self {
            Self::Single(single) => single == port,
            Self::Set(ref ports) => ports.iter().any(|&candidate| candidate == port),
            Self::Range(lo, hi) => (lo..=hi).contains(&port),
        }
    }
}
66
/// The terminal action of a rule, as far as the classifier can interpret it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Verdict {
    /// An `accept` statement.
    Accept,
    /// A `drop` statement.
    Drop,
    /// Any other verdict (`return`, `jump`, `goto`, `continue`, `reject`, `queue`) or none at all.
    Other,
}
73
74impl FilterClassifier {
75 pub fn no_firewall() -> Self {
76 Self { rules: None }
77 }
78
79 #[cfg(target_os = "linux")]
80 pub fn query() -> Self {
81 let Some(json) = run_nft_list() else {
82 return Self::no_firewall();
83 };
84 Self {
85 rules: Some(parse_inbound_rules(&json)),
86 }
87 }
88
89 #[cfg(not(target_os = "linux"))]
90 pub fn query() -> Self {
91 Self::no_firewall()
92 }
93
94 pub fn is_active(&self) -> bool {
95 self.rules.is_some()
96 }
97
98 pub fn classify(&self, proto: Proto, port: u16) -> FilterState {
99 let Some(rules) = &self.rules else {
100 return FilterState::NoFirewall;
101 };
102
103 let mut saw_unknown_for_port = false;
104
105 for rule in rules {
106 let mut references_port = false;
107 let mut canonical_for_port = true;
108 let mut has_proto_match = None;
109
110 for matcher in &rule.matches {
111 match matcher {
112 MatchExpr::Iifname => {}
113 MatchExpr::L4Proto(p) => {
114 has_proto_match = Some(*p);
115 if *p != proto {
116 canonical_for_port = false;
117 }
118 }
119 MatchExpr::Dport(p, port_match) => {
120 if *p == proto && port_match.matches(port) {
121 references_port = true;
122 } else if !port_match.matches(port) {
123 canonical_for_port = false;
124 }
125 }
126 MatchExpr::Unrecognized => {
127 if rule_might_reference_port(rule, proto, port) {
128 saw_unknown_for_port = true;
129 }
130 canonical_for_port = false;
131 }
132 }
133 }
134
135 if !references_port {
136 continue;
137 }
138 if !canonical_for_port {
139 saw_unknown_for_port = true;
140 continue;
141 }
142 if let Some(p) = has_proto_match
143 && p != proto
144 {
145 continue;
146 }
147
148 match rule.verdict {
149 Verdict::Accept => return FilterState::Accept,
150 Verdict::Drop => return FilterState::Drop,
151 Verdict::Other => saw_unknown_for_port = true,
152 }
153 }
154
155 if saw_unknown_for_port {
156 FilterState::Unknown
157 } else {
158 FilterState::Drop
159 }
160 }
161}
162
163fn rule_might_reference_port(rule: &Rule, proto: Proto, port: u16) -> bool {
164 rule.matches.iter().any(|matcher| match matcher {
165 MatchExpr::Dport(p, port_match) => *p == proto && port_match.matches(port),
166 _ => false,
167 })
168}
169
170#[cfg(target_os = "linux")]
171fn run_nft_list() -> Option<Value> {
172 use std::process::Command;
173
174 let output = Command::new("nft")
175 .args(["-j", "list", "table", "inet", "fips"])
176 .output()
177 .ok()?;
178
179 if !output.status.success() {
180 return None;
181 }
182
183 serde_json::from_slice::<Value>(&output.stdout).ok()
184}
185
186fn parse_inbound_rules(json: &Value) -> Vec<Rule> {
187 let Some(entries) = json.get("nftables").and_then(|v| v.as_array()) else {
188 return Vec::new();
189 };
190
191 entries
192 .iter()
193 .filter_map(|entry| entry.get("rule"))
194 .filter(|rule| {
195 rule.get("table").and_then(|v| v.as_str()) == Some("fips")
196 && rule.get("chain").and_then(|v| v.as_str()) == Some("inbound")
197 })
198 .map(parse_rule)
199 .collect()
200}
201
202fn parse_rule(rule: &Value) -> Rule {
203 let exprs = rule
204 .get("expr")
205 .and_then(|v| v.as_array())
206 .cloned()
207 .unwrap_or_default();
208
209 let mut matches = Vec::new();
210 let mut verdict = Verdict::Other;
211
212 for expr in &exprs {
213 if let Some(matcher) = expr.get("match") {
214 matches.push(parse_match(matcher));
215 } else if expr.get("accept").is_some() {
216 verdict = Verdict::Accept;
217 } else if expr.get("drop").is_some() {
218 verdict = Verdict::Drop;
219 } else if expr.get("return").is_some()
220 || expr.get("jump").is_some()
221 || expr.get("goto").is_some()
222 || expr.get("continue").is_some()
223 || expr.get("reject").is_some()
224 || expr.get("queue").is_some()
225 {
226 verdict = Verdict::Other;
227 }
228 }
229
230 Rule { matches, verdict }
231}
232
233fn parse_match(matcher: &Value) -> MatchExpr {
234 let op = matcher
235 .get("op")
236 .and_then(|value| value.as_str())
237 .unwrap_or("==");
238 let left = matcher.get("left").cloned().unwrap_or(Value::Null);
239 let right = matcher.get("right").cloned().unwrap_or(Value::Null);
240
241 if let Some(meta) = left.get("meta")
242 && meta.get("key").and_then(|v| v.as_str()) == Some("iifname")
243 && right.as_str().is_some()
244 {
245 let _ = op;
246 return MatchExpr::Iifname;
247 }
248
249 if let Some(meta) = left.get("meta")
250 && meta.get("key").and_then(|v| v.as_str()) == Some("l4proto")
251 && let Some(proto_str) = right.as_str()
252 && let Some(proto) = parse_proto(proto_str)
253 && op == "=="
254 {
255 return MatchExpr::L4Proto(proto);
256 }
257
258 if let Some(payload) = left.get("payload")
259 && payload.get("field").and_then(|v| v.as_str()) == Some("dport")
260 && let Some(proto_str) = payload.get("protocol").and_then(|v| v.as_str())
261 && let Some(proto) = parse_proto(proto_str)
262 && op == "=="
263 {
264 if let Some(port) = right.as_u64() {
265 return MatchExpr::Dport(proto, PortMatch::Single(port as u16));
266 }
267 if let Some(set) = right.get("set").and_then(|v| v.as_array()) {
268 let ports: Vec<u16> = set
269 .iter()
270 .filter_map(|v| v.as_u64().map(|port| port as u16))
271 .collect();
272 if ports.len() == set.len() {
273 return MatchExpr::Dport(proto, PortMatch::Set(ports));
274 }
275 }
276 if let Some(range) = right.get("range").and_then(|v| v.as_array())
277 && range.len() == 2
278 && let (Some(lo), Some(hi)) = (range[0].as_u64(), range[1].as_u64())
279 {
280 return MatchExpr::Dport(proto, PortMatch::Range(lo as u16, hi as u16));
281 }
282 }
283
284 MatchExpr::Unrecognized
285}
286
287fn parse_proto(value: &str) -> Option<Proto> {
288 match value {
289 "tcp" => Some(Proto::Tcp),
290 "udp" => Some(Proto::Udp),
291 _ => None,
292 }
293}
294
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Wraps each expression list in `rules_json` as a rule of the `inet fips` table's
    /// `inbound` chain and parses the resulting ruleset.
    fn make_classifier(rules_json: Value) -> FilterClassifier {
        let entries: Vec<Value> = rules_json
            .as_array()
            .unwrap()
            .iter()
            .map(|exprs| {
                json!({"rule": {
                    "family": "inet",
                    "table": "fips",
                    "chain": "inbound",
                    "expr": exprs,
                }})
            })
            .collect();
        let ruleset = json!({ "nftables": entries });
        FilterClassifier {
            rules: Some(parse_inbound_rules(&ruleset)),
        }
    }

    /// `<proto> dport == <right>` match statement.
    fn dport(proto: &str, right: Value) -> Value {
        json!({"match": {
            "op": "==",
            "left": {"payload": {"protocol": proto, "field": "dport"}},
            "right": right,
        }})
    }

    /// `ip6 saddr == <right>` match statement.
    fn ip6_saddr(right: Value) -> Value {
        json!({"match": {
            "op": "==",
            "left": {"payload": {"protocol": "ip6", "field": "saddr"}},
            "right": right,
        }})
    }

    /// `accept` verdict statement.
    fn accept_stmt() -> Value {
        json!({"accept": null})
    }

    /// `drop` verdict statement.
    fn drop_stmt() -> Value {
        json!({"drop": null})
    }

    #[test]
    fn no_firewall_means_no_firewall() {
        let classifier = FilterClassifier::no_firewall();
        assert_eq!(classifier.classify(Proto::Tcp, 22), FilterState::NoFirewall);
        assert_eq!(
            classifier.classify(Proto::Udp, 5353),
            FilterState::NoFirewall
        );
    }

    #[test]
    fn empty_chain_drops_everything() {
        let classifier = make_classifier(json!([]));
        assert_eq!(classifier.classify(Proto::Tcp, 22), FilterState::Drop);
        assert_eq!(classifier.classify(Proto::Udp, 5353), FilterState::Drop);
    }

    #[test]
    fn canonical_tcp_dport_accept() {
        let classifier = make_classifier(json!([[dport("tcp", json!(22)), accept_stmt()]]));
        assert_eq!(classifier.classify(Proto::Tcp, 22), FilterState::Accept);
        assert_eq!(classifier.classify(Proto::Tcp, 80), FilterState::Drop);
        assert_eq!(classifier.classify(Proto::Udp, 22), FilterState::Drop);
    }

    #[test]
    fn canonical_udp_dport_accept() {
        let classifier = make_classifier(json!([[dport("udp", json!(5353)), accept_stmt()]]));
        assert_eq!(classifier.classify(Proto::Udp, 5353), FilterState::Accept);
        assert_eq!(classifier.classify(Proto::Tcp, 5353), FilterState::Drop);
    }

    #[test]
    fn dport_set_accept() {
        let classifier = make_classifier(json!([[
            dport("tcp", json!({"set": [22, 80, 443]})),
            accept_stmt()
        ]]));
        assert_eq!(classifier.classify(Proto::Tcp, 22), FilterState::Accept);
        assert_eq!(classifier.classify(Proto::Tcp, 80), FilterState::Accept);
        assert_eq!(classifier.classify(Proto::Tcp, 443), FilterState::Accept);
        assert_eq!(classifier.classify(Proto::Tcp, 25), FilterState::Drop);
    }

    #[test]
    fn dport_range_accept() {
        let classifier = make_classifier(json!([[
            dport("tcp", json!({"range": [22, 25]})),
            accept_stmt()
        ]]));
        assert_eq!(classifier.classify(Proto::Tcp, 22), FilterState::Accept);
        assert_eq!(classifier.classify(Proto::Tcp, 25), FilterState::Accept);
        assert_eq!(classifier.classify(Proto::Tcp, 26), FilterState::Drop);
    }

    #[test]
    fn saddr_restricted_is_unknown() {
        let classifier = make_classifier(json!([[
            ip6_saddr(json!({"prefix": {"addr": "fd97::", "len": 64}})),
            dport("tcp", json!(22)),
            accept_stmt()
        ]]));
        assert_eq!(classifier.classify(Proto::Tcp, 22), FilterState::Unknown);
        assert_eq!(classifier.classify(Proto::Tcp, 80), FilterState::Drop);
    }

    #[test]
    fn jump_verdict_is_unknown() {
        let classifier = make_classifier(json!([[
            dport("tcp", json!(22)),
            json!({"jump": {"target": "some_chain"}})
        ]]));
        assert_eq!(classifier.classify(Proto::Tcp, 22), FilterState::Unknown);
    }

    #[test]
    fn explicit_drop_classifies_as_drop() {
        let classifier = make_classifier(json!([[dport("tcp", json!(22)), drop_stmt()]]));
        assert_eq!(classifier.classify(Proto::Tcp, 22), FilterState::Drop);
    }

    #[test]
    fn unrelated_rules_do_not_affect_port() {
        let classifier = make_classifier(json!([
            [
                {"match": {
                    "op": "!=",
                    "left": {"meta": {"key": "iifname"}},
                    "right": "fips0"
                }},
                {"return": null}
            ],
            [
                {"match": {
                    "op": "in",
                    "left": {"ct": {"key": "state"}},
                    "right": ["established", "related"]
                }},
                {"accept": null}
            ],
            [
                {"match": {
                    "op": "==",
                    "left": {"payload": {"protocol": "icmpv6", "field": "type"}},
                    "right": "echo-request"
                }},
                {"accept": null}
            ],
        ]));
        assert_eq!(classifier.classify(Proto::Tcp, 22), FilterState::Drop);
        assert_eq!(classifier.classify(Proto::Udp, 5353), FilterState::Drop);
    }

    #[test]
    fn l4proto_then_dport_accept() {
        let classifier = make_classifier(json!([[
            json!({"match": {
                "op": "==",
                "left": {"meta": {"key": "l4proto"}},
                "right": "tcp"
            }}),
            dport("tcp", json!(22)),
            accept_stmt()
        ]]));
        assert_eq!(classifier.classify(Proto::Tcp, 22), FilterState::Accept);
        assert_eq!(classifier.classify(Proto::Udp, 22), FilterState::Drop);
    }

    #[test]
    fn first_accept_match_wins() {
        let classifier = make_classifier(json!([
            [dport("tcp", json!(22)), accept_stmt()],
            [ip6_saddr(json!("fd00::1")), dport("tcp", json!(22)), drop_stmt()],
        ]));
        assert_eq!(classifier.classify(Proto::Tcp, 22), FilterState::Accept);
    }
}