use std::net::SocketAddr;
use ipnetwork::IpNetwork;
use serde::{Deserialize, Serialize};
use super::destination::{matches_cidr, matches_group};
/// A network policy: an ordered list of [`Rule`]s plus a fallback action.
///
/// Evaluation is first-match: the first rule that applies decides the outcome, and
/// `default_action` is used when no rule applies.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NetworkPolicy {
/// Action returned when no rule matches. When absent from serialized input this
/// deserializes to `Action::default()` (`Allow`); note this is NOT the rule set that
/// `NetworkPolicy::default()` produces.
#[serde(default)]
pub default_action: Action,
/// Rules in priority order; earlier rules take precedence over later ones.
/// Deserializes to an empty list when absent.
#[serde(default)]
pub rules: Vec<Rule>,
}
/// Outcome of a matching rule, or of a policy as a whole.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum Action {
/// Permit the traffic. This is the `Default` variant, so it is also what serde uses
/// for a missing `NetworkPolicy::default_action`.
#[default]
Allow,
/// Block the traffic.
Deny,
}
/// A single policy rule. Every populated criterion must hold for the rule to match;
/// `None` criteria match anything.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Rule {
/// Traffic direction. The egress evaluators on `NetworkPolicy` only consider
/// `Outbound` rules.
pub direction: Direction,
/// Destination criterion the traffic's address is tested against.
pub destination: Destination,
/// Protocol restriction; `None` matches every protocol.
#[serde(default)]
pub protocol: Option<Protocol>,
/// Inclusive port restriction; `None` matches every port. A rule with a port range
/// never matches in `NetworkPolicy::evaluate_egress_ip`, which has no port to test.
#[serde(default)]
pub ports: Option<PortRange>,
/// Action taken when this rule is the first match.
pub action: Action,
}
/// Direction of traffic a rule applies to.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Direction {
/// Traffic leaving the sandboxed side; the only direction the egress evaluators use.
Outbound,
/// NOTE(review): no evaluator in this file consults `Inbound` rules — the egress
/// evaluators skip them. Confirm whether another module handles ingress.
Inbound,
}
/// What a rule's destination criterion is matched against.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Destination {
/// Matches every destination address.
Any,
/// Addresses inside the given network (tested by `destination::matches_cidr`).
Cidr(IpNetwork),
/// A domain name (presumably exact-match — matching logic is not in this file).
/// Never matches in the IP-based evaluation performed here.
Domain(String),
/// A domain suffix. Never matches in the IP-based evaluation performed here.
DomainSuffix(String),
/// A predefined address group (membership decided by `destination::matches_group`).
Group(DestinationGroup),
}
/// Named address groups usable as a rule destination.
///
/// The concrete address ranges of each group are defined by
/// `destination::matches_group`, not in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DestinationGroup {
Loopback,
Private,
LinkLocal,
/// NOTE(review): presumably cloud instance-metadata endpoints — confirm in
/// `matches_group`.
Metadata,
/// Not denied by either the `public_only` or the `non_local` preset.
Multicast,
}
/// Protocol a rule may be restricted to, and the protocol passed to the evaluators.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Protocol {
Tcp,
Udp,
Icmpv4,
Icmpv6,
}
/// Inclusive port range `start..=end`.
///
/// Bounds are not validated on construction or deserialization; a range with
/// `start > end` contains no ports.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct PortRange {
/// First port in the range (inclusive).
pub start: u16,
/// Last port in the range (inclusive).
pub end: u16,
}
impl NetworkPolicy {
pub fn none() -> Self {
Self {
default_action: Action::Deny,
rules: vec![],
}
}
pub fn allow_all() -> Self {
Self {
default_action: Action::Allow,
rules: vec![],
}
}
pub fn public_only() -> Self {
Self {
default_action: Action::Allow,
rules: vec![
Rule::deny_outbound(Destination::Group(DestinationGroup::Loopback)),
Rule::deny_outbound(Destination::Group(DestinationGroup::Private)),
Rule::deny_outbound(Destination::Group(DestinationGroup::LinkLocal)),
Rule::deny_outbound(Destination::Group(DestinationGroup::Metadata)),
],
}
}
pub fn non_local() -> Self {
Self {
default_action: Action::Allow,
rules: vec![
Rule::deny_outbound(Destination::Group(DestinationGroup::Loopback)),
Rule::deny_outbound(Destination::Group(DestinationGroup::LinkLocal)),
Rule::deny_outbound(Destination::Group(DestinationGroup::Metadata)),
],
}
}
pub fn evaluate_egress(&self, dst: SocketAddr, protocol: Protocol) -> Action {
for rule in &self.rules {
if rule.direction != Direction::Outbound {
continue;
}
if let Some(ref rule_proto) = rule.protocol
&& *rule_proto != protocol
{
continue;
}
if let Some(ref ports) = rule.ports
&& !ports.contains(dst.port())
{
continue;
}
if !matches_destination(&rule.destination, dst.ip()) {
continue;
}
return rule.action;
}
self.default_action
}
pub fn evaluate_egress_ip(&self, dst: std::net::IpAddr, protocol: Protocol) -> Action {
for rule in &self.rules {
if rule.direction != Direction::Outbound {
continue;
}
if let Some(ref rule_proto) = rule.protocol
&& *rule_proto != protocol
{
continue;
}
if rule.ports.is_some() {
continue;
}
if !matches_destination(&rule.destination, dst) {
continue;
}
return rule.action;
}
self.default_action
}
}
impl Action {
pub fn is_allow(self) -> bool {
matches!(self, Action::Allow)
}
pub fn is_deny(self) -> bool {
matches!(self, Action::Deny)
}
}
impl Default for NetworkPolicy {
fn default() -> Self {
Self::public_only()
}
}
impl Rule {
pub fn allow_outbound(destination: Destination) -> Self {
Self {
direction: Direction::Outbound,
destination,
protocol: None,
ports: None,
action: Action::Allow,
}
}
pub fn deny_outbound(destination: Destination) -> Self {
Self {
direction: Direction::Outbound,
destination,
protocol: None,
ports: None,
action: Action::Deny,
}
}
}
impl PortRange {
    /// Range containing exactly `port`.
    pub fn single(port: u16) -> Self {
        Self::range(port, port)
    }

    /// Inclusive range `start..=end`. Bounds are stored as given; if `start > end`
    /// the range contains no ports.
    pub fn range(start: u16, end: u16) -> Self {
        Self { start, end }
    }

    /// Whether `port` lies within the inclusive bounds.
    pub fn contains(&self, port: u16) -> bool {
        (self.start..=self.end).contains(&port)
    }
}
/// Returns whether `addr` satisfies the destination criterion `dest`.
///
/// Domain-based destinations never match here, because only an IP address is
/// available at this level.
///
/// An IPv4-mapped IPv6 address (`::ffff:a.b.c.d`) is tested both as given and in its
/// canonical IPv4 form, so IPv4 CIDR and group rules also apply to it. Otherwise, a
/// dual-stack connection to e.g. `::ffff:127.0.0.1` could slip past an IPv4 deny rule.
/// If `matches_cidr`/`matches_group` already canonicalize, this extra check is
/// redundant but harmless.
fn matches_destination(dest: &Destination, addr: std::net::IpAddr) -> bool {
    let matches = |ip: std::net::IpAddr| match dest {
        Destination::Any => true,
        Destination::Cidr(network) => matches_cidr(network, ip),
        Destination::Group(group) => matches_group(*group, ip),
        // Domain rules need a hostname, which IP-level evaluation does not have.
        Destination::Domain(_) | Destination::DomainSuffix(_) => false,
    };
    let canonical = addr.to_canonical();
    matches(addr) || (canonical != addr && matches(canonical))
}