use crate::{Result, docker::config::ConfigVerdict};
use nftables::{
expr::{Expression, Meta, MetaKey, NamedExpression, Payload, PayloadField},
schema::Rule,
stmt::{Counter, JumpTarget, Log, LogLevel, Match, Operator, Queue, Statement},
types::NfFamily,
};
use std::borrow::Cow;
use std::net::IpAddr;
pub struct RuleContext<'a> {
pub container_id: &'a str,
pub container_name: &'a str,
pub container_ips: &'a [IpAddr],
pub container_ports: &'a [(u16, String)], pub chain_name: &'a str,
pub table_name: &'a str,
pub family: NfFamily,
}
pub trait ToNftablesRule {
fn to_nftables_statements(&self) -> Result<Vec<Statement<'static>>>;
fn to_nftables_rule(
&self,
ctx: &RuleContext,
comment: Option<String>,
) -> Result<Rule<'static>> {
let statements = self.to_nftables_statements()?;
Ok(Rule {
family: ctx.family,
table: Cow::Owned(ctx.table_name.to_string()),
chain: Cow::Owned(ctx.chain_name.to_string()),
expr: Cow::Owned(statements),
handle: None,
index: None,
comment: comment.map(Cow::Owned),
})
}
fn match_protocol(protocol: &str) -> Statement<'static> {
Statement::Match(Match {
left: Expression::Named(NamedExpression::Meta(Meta {
key: MetaKey::L4proto,
})),
right: Expression::Number(match protocol.to_lowercase().as_str() {
"tcp" => 6,
"udp" => 17,
"icmp" => 1,
"icmpv6" => 58,
_ => 6, }),
op: Operator::EQ,
})
}
fn match_dst_port(protocol: &str, port: u16) -> Statement<'static> {
Statement::Match(Match {
left: Expression::Named(NamedExpression::Payload(Payload::PayloadField(
PayloadField {
protocol: Cow::Owned(protocol.to_string()),
field: Cow::Borrowed("dport"),
},
))),
right: Expression::Number(port as u32),
op: Operator::EQ,
})
}
fn match_dst_port_range(protocol: &str, start: u16, end: u16) -> Statement<'static> {
Statement::Match(Match {
left: Expression::Named(NamedExpression::Payload(Payload::PayloadField(
PayloadField {
protocol: Cow::Owned(protocol.to_string()),
field: Cow::Borrowed("dport"),
},
))),
right: Expression::Range(Box::new(nftables::expr::Range {
range: [
Expression::Number(start as u32),
Expression::Number(end as u32),
],
})),
op: Operator::EQ,
})
}
fn match_src_port(protocol: &str, port: u16) -> Statement<'static> {
Statement::Match(Match {
left: Expression::Named(NamedExpression::Payload(Payload::PayloadField(
PayloadField {
protocol: Cow::Owned(protocol.to_string()),
field: Cow::Borrowed("sport"),
},
))),
right: Expression::Number(port as u32),
op: Operator::EQ,
})
}
fn match_src_ip(ip: &str) -> Statement<'static> {
let protocol = if ip.contains(':') { "ip6" } else { "ip" };
Statement::Match(Match {
left: Expression::Named(NamedExpression::Payload(Payload::PayloadField(
PayloadField {
protocol: Cow::Borrowed(protocol),
field: Cow::Borrowed("saddr"),
},
))),
right: Expression::String(Cow::Owned(ip.to_string())),
op: Operator::EQ,
})
}
fn match_dst_ip(ip: &str) -> Statement<'static> {
let protocol = if ip.contains(':') { "ip6" } else { "ip" };
Statement::Match(Match {
left: Expression::Named(NamedExpression::Payload(Payload::PayloadField(
PayloadField {
protocol: Cow::Borrowed(protocol),
field: Cow::Borrowed("daddr"),
},
))),
right: Expression::String(Cow::Owned(ip.to_string())),
op: Operator::EQ,
})
}
fn log_statement(prefix: Option<&str>) -> Statement<'static> {
Statement::Log(Some(Log {
prefix: prefix.map(|p| Cow::Owned(p.to_string())),
group: None,
snaplen: None,
queue_threshold: None,
level: Some(LogLevel::Info),
flags: None,
}))
}
fn counter_statement() -> Statement<'static> {
Statement::Counter(Counter::Anonymous(None))
}
fn verdict_to_statement(verdict: &ConfigVerdict) -> Statement<'static> {
if !verdict.chain.is_empty() {
Statement::Jump(JumpTarget {
target: Cow::Owned(verdict.chain.clone()),
})
} else if verdict.queue > 0 {
Statement::Queue(Queue {
num: Expression::Number(verdict.queue as u32),
flags: None,
})
} else if verdict.drop {
Statement::Drop(None)
} else {
Statement::Accept(None)
}
}
}