use alloc::string::String;
use alloc::vec::Vec;
use core::net::IpAddr;
use zerodds_security_crypto::Suite;
use zerodds_security_permissions::ProtectionKind;
use crate::caps::PeerCapabilities;
use crate::shared::PeerKey;
/// Coarse protection level applied to traffic.
///
/// Variants are declared weakest-first, so the derived `Ord` yields
/// `None < Sign < Encrypt`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub enum ProtectionLevel {
    /// No protection (the default).
    #[default]
    None,
    /// Integrity protection only.
    Sign,
    /// Confidentiality (encryption).
    Encrypt,
}
impl ProtectionLevel {
    /// Collapses a [`ProtectionKind`] into a level.
    ///
    /// The origin-authentication flavour is discarded: `SignWithOriginAuthentication`
    /// becomes `Sign`, and `EncryptWithOriginAuthentication` becomes `Encrypt`.
    #[must_use]
    pub fn from_protection_kind(kind: ProtectionKind) -> Self {
        match kind {
            ProtectionKind::Encrypt | ProtectionKind::EncryptWithOriginAuthentication => {
                Self::Encrypt
            }
            ProtectionKind::Sign | ProtectionKind::SignWithOriginAuthentication => Self::Sign,
            ProtectionKind::None => Self::None,
        }
    }
    /// Maps the level back to the plain [`ProtectionKind`] variant.
    ///
    /// The result never carries origin authentication.
    #[must_use]
    pub fn to_protection_kind(self) -> ProtectionKind {
        match self {
            Self::Encrypt => ProtectionKind::Encrypt,
            Self::Sign => ProtectionKind::Sign,
            Self::None => ProtectionKind::None,
        }
    }
    /// Returns the stronger of `self` and `other`, according to the derived ordering.
    #[must_use]
    pub fn stronger(self, other: Self) -> Self {
        Ord::max(self, other)
    }
}
/// Cipher-suite preference carried in a [`PolicyDecision`].
///
/// Each variant corresponds one-to-one with a [`Suite`] variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SuiteHint {
    /// AES-128-GCM; implies [`ProtectionLevel::Encrypt`].
    Aes128Gcm,
    /// AES-256-GCM; implies [`ProtectionLevel::Encrypt`].
    Aes256Gcm,
    /// HMAC-SHA256; implies [`ProtectionLevel::Sign`].
    HmacSha256,
}
impl SuiteHint {
    /// Converts the hint into the corresponding crypto [`Suite`].
    #[must_use]
    pub fn to_suite(self) -> Suite {
        match self {
            Self::HmacSha256 => Suite::HmacSha256,
            Self::Aes256Gcm => Suite::Aes256Gcm,
            Self::Aes128Gcm => Suite::Aes128Gcm,
        }
    }
    /// Builds the hint matching a crypto [`Suite`]; this is the inverse of [`Self::to_suite`].
    #[must_use]
    pub fn from_suite(suite: Suite) -> Self {
        match suite {
            Suite::HmacSha256 => Self::HmacSha256,
            Suite::Aes256Gcm => Self::Aes256Gcm,
            Suite::Aes128Gcm => Self::Aes128Gcm,
        }
    }
    /// Returns the protection level this suite provides.
    ///
    /// The AES-GCM suites encrypt; the HMAC suite only signs.
    #[must_use]
    pub fn protection_level(self) -> ProtectionLevel {
        match self {
            Self::HmacSha256 => ProtectionLevel::Sign,
            Self::Aes128Gcm | Self::Aes256Gcm => ProtectionLevel::Encrypt,
        }
    }
}
/// CIDR-style address range.
///
/// It covers every address of the same family whose leading `prefix_len`
/// bits equal those of `base`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct IpRange {
    /// Network address; only its leading `prefix_len` bits are significant.
    pub base: IpAddr,
    /// Number of leading bits compared.
    ///
    /// A value wider than the address family (over 32 for IPv4, over 128 for
    /// IPv6) makes the range match nothing.
    pub prefix_len: u8,
}
impl IpRange {
    /// Returns `true` when `addr` falls inside this range.
    ///
    /// - An address of a different family than `base` never matches.
    /// - A prefix length of 0 matches every address of `base`'s family.
    /// - An out-of-range prefix length matches nothing.
    #[must_use]
    pub fn contains(&self, addr: &IpAddr) -> bool {
        let bits = u32::from(self.prefix_len);
        match (self.base, *addr) {
            (IpAddr::V4(net), IpAddr::V4(host)) if bits <= 32 => {
                // A shift of 32 (i.e. /0) is out of range for u32; checked_shl then
                // yields None, which becomes an all-zero mask matching everything.
                let mask = u32::MAX.checked_shl(32 - bits).unwrap_or(0);
                (u32::from(net) & mask) == (u32::from(host) & mask)
            }
            (IpAddr::V6(net), IpAddr::V6(host)) if bits <= 128 => {
                let mask = u128::MAX.checked_shl(128 - bits).unwrap_or(0);
                (u128::from(net) & mask) == (u128::from(host) & mask)
            }
            // Either the families differ or the prefix is too wide for the family.
            _ => false,
        }
    }
}
/// Network-interface class of an address.
///
/// Policy contexts carry this class (`OutboundCtx::interface`,
/// `InboundCtx::source_iface`).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum NetInterface {
    /// A loopback address; `classify_interface` returns this whenever
    /// `IpAddr::is_loopback()` is true.
    Loopback,
    /// Never produced by `classify_interface`. Presumably it is assigned
    /// elsewhere for the host's own non-loopback addresses — TODO confirm.
    LocalHost,
    /// An address inside one of `InterfaceConfig::local_subnets`; carries the matching range.
    LocalSubnet(IpRange),
    /// Fallback for addresses that match no loopback, named, or local-subnet rule.
    Wan,
    /// An address inside one of `InterfaceConfig::named`; carries that entry's name.
    Named(String),
}
/// Address ranges consulted by `classify_interface`.
///
/// The default configuration is empty: every non-loopback address then classifies as WAN.
#[derive(Debug, Clone, Default)]
pub struct InterfaceConfig {
    /// Ranges reported as `NetInterface::LocalSubnet`.
    ///
    /// These are checked after `named`; the first matching range wins.
    pub local_subnets: Vec<IpRange>,
    /// `(range, name)` pairs reported as `NetInterface::Named`.
    ///
    /// These are checked before `local_subnets`; the first matching entry wins.
    pub named: Vec<(IpRange, String)>,
}
/// Classifies `addr` into a [`NetInterface`] using `config`.
///
/// Rules are tried in this order; the first that applies wins:
/// 1. Loopback addresses (`IpAddr::is_loopback`) give [`NetInterface::Loopback`].
/// 2. The first `config.named` entry whose range contains `addr` gives
///    [`NetInterface::Named`] with that entry's name.
/// 3. The first `config.local_subnets` range containing `addr` gives
///    [`NetInterface::LocalSubnet`] with a copy of that range.
/// 4. Anything else gives [`NetInterface::Wan`].
///
/// [`NetInterface::LocalHost`] is never returned here.
#[must_use]
pub fn classify_interface(addr: &IpAddr, config: &InterfaceConfig) -> NetInterface {
    if addr.is_loopback() {
        return NetInterface::Loopback;
    }
    let named_hit = config
        .named
        .iter()
        .find(|(range, _)| range.contains(addr))
        .map(|(_, name)| NetInterface::Named(name.clone()));
    named_hit
        .or_else(|| {
            // Local subnets are consulted only when no named range matched.
            config
                .local_subnets
                .iter()
                .find(|range| range.contains(addr))
                .map(|range| NetInterface::LocalSubnet(range.clone()))
        })
        .unwrap_or(NetInterface::Wan)
}
/// Verdict returned by [`PolicyEngine`] decisions.
///
/// It specifies the protection to apply, an optional cipher-suite hint, and a drop flag.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PolicyDecision {
    /// Protection level to apply.
    pub protection: ProtectionLevel,
    /// Optional cipher-suite hint.
    ///
    /// [`PolicyDecision::with`] clears it when `protection` is [`ProtectionLevel::None`].
    pub suite: Option<SuiteHint>,
    /// `true` marks the decision as a drop (see [`PolicyDecision::DROP`]).
    pub drop: bool,
}
impl PolicyDecision {
    /// No protection, no suite, not dropped.
    pub const PLAIN: Self = Self {
        protection: ProtectionLevel::None,
        suite: None,
        drop: false,
    };
    /// Identical to [`Self::PLAIN`] except that `drop` is set.
    pub const DROP: Self = Self {
        drop: true,
        ..Self::PLAIN
    };
    /// Builds a non-dropping decision with the given protection and suite hint.
    ///
    /// The hint is discarded when `protection` is [`ProtectionLevel::None`],
    /// because no suite is used in that case.
    #[must_use]
    pub fn with(protection: ProtectionLevel, suite: Option<SuiteHint>) -> Self {
        let suite = suite.filter(|_| protection != ProtectionLevel::None);
        Self {
            protection,
            suite,
            drop: false,
        }
    }
}
/// Borrowed inputs for [`PolicyEngine::outbound_decision`].
#[derive(Debug)]
pub struct OutboundCtx<'a> {
    /// Domain identifier of the traffic (presumably the DDS domain id — TODO confirm).
    pub domain_id: u32,
    /// Topic name of the outgoing data.
    pub topic: &'a str,
    /// Partition names associated with the outgoing data.
    pub partition: &'a [String],
    /// Interface class the traffic goes out on (see `NetInterface`).
    pub interface: &'a NetInterface,
    /// Key of the receiving peer.
    pub remote_peer: &'a PeerKey,
    /// Capabilities advertised by the receiving peer; always present for outbound traffic.
    pub remote_caps: &'a PeerCapabilities,
}
/// Borrowed inputs for [`PolicyEngine::inbound_decision`].
#[derive(Debug)]
pub struct InboundCtx<'a> {
    /// Domain identifier of the traffic (presumably the DDS domain id — TODO confirm).
    pub domain_id: u32,
    /// Key of the sending peer.
    pub source_peer: &'a PeerKey,
    /// Interface class the traffic arrived on (see `NetInterface`).
    pub source_iface: &'a NetInterface,
    /// Capabilities of the sender, if any are available.
    ///
    /// NOTE(review): presumably `None` when the peer's capabilities are not yet known — confirm.
    pub source_caps: Option<&'a PeerCapabilities>,
    /// Whether the incoming message was security-prefixed.
    ///
    /// NOTE(review): presumably this means it carried a secure (SEC_PREFIX) submessage — confirm.
    pub is_sec_prefixed: bool,
}
/// Pluggable policy that decides how traffic is treated and which peers are accepted.
///
/// The `Send + Sync` supertraits allow a single engine to be shared across threads.
pub trait PolicyEngine: Send + Sync {
    /// Decides protection (or dropping) for traffic sent to the peer described by `ctx`.
    fn outbound_decision(&self, ctx: OutboundCtx<'_>) -> PolicyDecision;
    /// Decides protection requirements (or dropping) for traffic received from the peer described by `ctx`.
    fn inbound_decision(&self, ctx: InboundCtx<'_>) -> PolicyDecision;
    /// Returns `true` if a peer advertising `caps` is acceptable.
    fn accept_peer(&self, caps: &PeerCapabilities) -> bool;
}
/// Unit tests for protection levels, suite hints, IP ranges, interface
/// classification, and policy-decision constructors.
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;
    use alloc::string::ToString;
    use core::net::{Ipv4Addr, Ipv6Addr};

    #[test]
    fn protection_level_orders_none_sign_encrypt() {
        assert!(ProtectionLevel::None < ProtectionLevel::Sign);
        assert!(ProtectionLevel::Sign < ProtectionLevel::Encrypt);
    }

    #[test]
    fn protection_level_stronger_picks_max() {
        assert_eq!(
            ProtectionLevel::Sign.stronger(ProtectionLevel::Encrypt),
            ProtectionLevel::Encrypt
        );
        assert_eq!(
            ProtectionLevel::Encrypt.stronger(ProtectionLevel::None),
            ProtectionLevel::Encrypt
        );
        assert_eq!(
            ProtectionLevel::None.stronger(ProtectionLevel::None),
            ProtectionLevel::None
        );
    }

    #[test]
    fn protection_level_from_kind_collapses_origin_auth() {
        assert_eq!(
            ProtectionLevel::from_protection_kind(ProtectionKind::None),
            ProtectionLevel::None
        );
        assert_eq!(
            ProtectionLevel::from_protection_kind(ProtectionKind::Sign),
            ProtectionLevel::Sign
        );
        assert_eq!(
            ProtectionLevel::from_protection_kind(ProtectionKind::SignWithOriginAuthentication),
            ProtectionLevel::Sign
        );
        assert_eq!(
            ProtectionLevel::from_protection_kind(ProtectionKind::Encrypt),
            ProtectionLevel::Encrypt
        );
        assert_eq!(
            ProtectionLevel::from_protection_kind(ProtectionKind::EncryptWithOriginAuthentication),
            ProtectionLevel::Encrypt
        );
    }

    #[test]
    fn protection_level_to_kind_roundtrip_without_origin_auth() {
        for lvl in [
            ProtectionLevel::None,
            ProtectionLevel::Sign,
            ProtectionLevel::Encrypt,
        ] {
            let kind = lvl.to_protection_kind();
            assert_eq!(ProtectionLevel::from_protection_kind(kind), lvl);
        }
    }

    #[test]
    fn protection_level_default_is_none() {
        assert_eq!(ProtectionLevel::default(), ProtectionLevel::None);
    }

    #[test]
    fn suite_hint_roundtrip_suite() {
        for s in [Suite::Aes128Gcm, Suite::Aes256Gcm, Suite::HmacSha256] {
            assert_eq!(SuiteHint::from_suite(s).to_suite(), s);
        }
    }

    #[test]
    fn suite_hint_protection_level_matches_semantics() {
        assert_eq!(
            SuiteHint::Aes128Gcm.protection_level(),
            ProtectionLevel::Encrypt
        );
        assert_eq!(
            SuiteHint::Aes256Gcm.protection_level(),
            ProtectionLevel::Encrypt
        );
        assert_eq!(
            SuiteHint::HmacSha256.protection_level(),
            ProtectionLevel::Sign
        );
    }

    /// Shorthand for building an IPv4 `IpAddr`.
    fn v4(a: u8, b: u8, c: u8, d: u8) -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(a, b, c, d))
    }

    #[test]
    fn ip_range_v4_match_inside_prefix() {
        let r = IpRange {
            base: v4(10, 0, 0, 0),
            prefix_len: 24,
        };
        assert!(r.contains(&v4(10, 0, 0, 1)));
        assert!(r.contains(&v4(10, 0, 0, 255)));
        assert!(!r.contains(&v4(10, 0, 1, 0)));
        assert!(!r.contains(&v4(11, 0, 0, 0)));
    }

    #[test]
    fn ip_range_v4_non_octet_prefix_boundaries() {
        // /20 splits the third octet: 192.168.16.0 - 192.168.31.255.
        let r = IpRange {
            base: v4(192, 168, 16, 0),
            prefix_len: 20,
        };
        assert!(r.contains(&v4(192, 168, 16, 0)));
        assert!(r.contains(&v4(192, 168, 31, 255)));
        assert!(!r.contains(&v4(192, 168, 32, 0)));
        assert!(!r.contains(&v4(192, 168, 15, 255)));
    }

    #[test]
    fn ip_range_v4_prefix_zero_matches_all_v4() {
        let r = IpRange {
            base: v4(0, 0, 0, 0),
            prefix_len: 0,
        };
        assert!(r.contains(&v4(1, 2, 3, 4)));
        assert!(r.contains(&v4(255, 255, 255, 255)));
    }

    #[test]
    fn ip_range_v4_prefix_32_is_exact_host() {
        let r = IpRange {
            base: v4(192, 168, 1, 5),
            prefix_len: 32,
        };
        assert!(r.contains(&v4(192, 168, 1, 5)));
        assert!(!r.contains(&v4(192, 168, 1, 6)));
    }

    #[test]
    fn ip_range_v4_out_of_range_prefix_never_matches() {
        let r = IpRange {
            base: v4(10, 0, 0, 0),
            // Wider than the 32 bits of an IPv4 address.
            prefix_len: 40,
        };
        assert!(!r.contains(&v4(10, 0, 0, 1)));
    }

    #[test]
    fn ip_range_v6_basic_match() {
        let base = IpAddr::V6(Ipv6Addr::new(0xfd00, 0, 0, 0, 0, 0, 0, 0));
        let r = IpRange {
            base,
            prefix_len: 8,
        };
        assert!(r.contains(&IpAddr::V6(Ipv6Addr::new(0xfd01, 2, 3, 4, 5, 6, 7, 8))));
        assert!(!r.contains(&IpAddr::V6(Ipv6Addr::new(0xfe00, 0, 0, 0, 0, 0, 0, 0))));
    }

    #[test]
    fn ip_range_v6_non_hextet_prefix_boundaries() {
        // fe80::/10 covers fe80:: through febf:ffff:...
        let r = IpRange {
            base: IpAddr::V6(Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 0)),
            prefix_len: 10,
        };
        assert!(r.contains(&IpAddr::V6(Ipv6Addr::new(0xfebf, 1, 2, 3, 4, 5, 6, 7))));
        assert!(!r.contains(&IpAddr::V6(Ipv6Addr::new(0xfec0, 0, 0, 0, 0, 0, 0, 0))));
    }

    #[test]
    fn ip_range_v6_prefix_128_is_exact_host() {
        let r = IpRange {
            base: IpAddr::V6(Ipv6Addr::new(0xfd00, 0, 0, 0, 0, 0, 0, 1)),
            prefix_len: 128,
        };
        assert!(r.contains(&IpAddr::V6(Ipv6Addr::new(0xfd00, 0, 0, 0, 0, 0, 0, 1))));
        assert!(!r.contains(&IpAddr::V6(Ipv6Addr::new(0xfd00, 0, 0, 0, 0, 0, 0, 2))));
    }

    #[test]
    fn ip_range_v6_prefix_zero_matches_all_v6() {
        let r = IpRange {
            base: IpAddr::V6(Ipv6Addr::UNSPECIFIED),
            prefix_len: 0,
        };
        assert!(r.contains(&IpAddr::V6(Ipv6Addr::LOCALHOST)));
    }

    #[test]
    fn ip_range_v6_out_of_range_prefix_never_matches() {
        let r = IpRange {
            base: IpAddr::V6(Ipv6Addr::UNSPECIFIED),
            // Wider than the 128 bits of an IPv6 address.
            prefix_len: 200,
        };
        assert!(!r.contains(&IpAddr::V6(Ipv6Addr::LOCALHOST)));
    }

    #[test]
    fn ip_range_mixed_family_never_matches() {
        let r = IpRange {
            base: v4(10, 0, 0, 0),
            prefix_len: 8,
        };
        assert!(!r.contains(&IpAddr::V6(Ipv6Addr::LOCALHOST)));
        let r6 = IpRange {
            base: IpAddr::V6(Ipv6Addr::UNSPECIFIED),
            prefix_len: 0,
        };
        assert!(!r6.contains(&v4(10, 0, 0, 1)));
    }

    #[test]
    fn classify_loopback_v4() {
        let cfg = InterfaceConfig::default();
        assert_eq!(
            classify_interface(&v4(127, 0, 0, 1), &cfg),
            NetInterface::Loopback
        );
        assert_eq!(
            classify_interface(&v4(127, 1, 2, 3), &cfg),
            NetInterface::Loopback
        );
    }

    #[test]
    fn classify_loopback_v6() {
        let cfg = InterfaceConfig::default();
        assert_eq!(
            classify_interface(&IpAddr::V6(Ipv6Addr::LOCALHOST), &cfg),
            NetInterface::Loopback
        );
    }

    #[test]
    fn classify_loopback_wins_over_named() {
        // A named range covering 127/8 must not override the loopback check.
        let cfg = InterfaceConfig {
            named: alloc::vec![(
                IpRange {
                    base: v4(127, 0, 0, 0),
                    prefix_len: 8,
                },
                "lo-alias".to_string()
            )],
            ..InterfaceConfig::default()
        };
        assert_eq!(
            classify_interface(&v4(127, 0, 0, 1), &cfg),
            NetInterface::Loopback
        );
    }

    #[test]
    fn classify_local_subnet_after_loopback() {
        let cfg = InterfaceConfig {
            local_subnets: alloc::vec![IpRange {
                base: v4(10, 0, 0, 0),
                prefix_len: 24,
            }],
            ..InterfaceConfig::default()
        };
        match classify_interface(&v4(10, 0, 0, 5), &cfg) {
            NetInterface::LocalSubnet(r) => {
                assert_eq!(r.prefix_len, 24);
            }
            other => panic!("expected LocalSubnet, got {other:?}"),
        }
    }

    #[test]
    fn classify_wan_fallback() {
        let cfg = InterfaceConfig::default();
        assert_eq!(classify_interface(&v4(8, 8, 8, 8), &cfg), NetInterface::Wan);
    }

    #[test]
    fn classify_named_wins_over_local_subnet() {
        let vpn_range = IpRange {
            base: v4(10, 8, 0, 0),
            prefix_len: 16,
        };
        let mgmt_range = IpRange {
            base: v4(10, 0, 0, 0),
            prefix_len: 8,
        };
        let cfg = InterfaceConfig {
            local_subnets: alloc::vec![mgmt_range],
            named: alloc::vec![(vpn_range, "vpn".to_string())],
        };
        assert_eq!(
            classify_interface(&v4(10, 8, 1, 2), &cfg),
            NetInterface::Named("vpn".to_string())
        );
    }

    #[test]
    fn classify_named_first_match_wins() {
        let cfg = InterfaceConfig {
            named: alloc::vec![
                (
                    IpRange {
                        base: v4(10, 0, 0, 0),
                        prefix_len: 8,
                    },
                    "first".to_string()
                ),
                (
                    IpRange {
                        base: v4(10, 0, 0, 0),
                        prefix_len: 24,
                    },
                    "second".to_string()
                ),
            ],
            ..InterfaceConfig::default()
        };
        assert_eq!(
            classify_interface(&v4(10, 0, 0, 5), &cfg),
            NetInterface::Named("first".to_string())
        );
    }

    #[test]
    fn policy_decision_plain_constant() {
        assert_eq!(
            PolicyDecision::PLAIN,
            PolicyDecision {
                protection: ProtectionLevel::None,
                suite: None,
                drop: false,
            }
        );
    }

    #[test]
    fn policy_decision_drop_constant() {
        assert_eq!(
            PolicyDecision::DROP,
            PolicyDecision {
                protection: ProtectionLevel::None,
                suite: None,
                drop: true,
            }
        );
        assert_ne!(PolicyDecision::DROP, PolicyDecision::PLAIN);
    }

    #[test]
    fn policy_decision_with_none_forces_suite_none() {
        let d = PolicyDecision::with(ProtectionLevel::None, Some(SuiteHint::Aes128Gcm));
        assert!(d.suite.is_none());
    }

    #[test]
    fn policy_decision_with_sign_keeps_suite() {
        let d = PolicyDecision::with(ProtectionLevel::Sign, Some(SuiteHint::HmacSha256));
        assert_eq!(d.suite, Some(SuiteHint::HmacSha256));
        assert_eq!(d.protection, ProtectionLevel::Sign);
        assert!(!d.drop);
    }

    #[test]
    fn policy_decision_with_encrypt_keeps_suite() {
        let d = PolicyDecision::with(ProtectionLevel::Encrypt, Some(SuiteHint::Aes256Gcm));
        assert_eq!(d.suite, Some(SuiteHint::Aes256Gcm));
        assert_eq!(d.protection, ProtectionLevel::Encrypt);
        assert!(!d.drop);
    }
}