edgesentry_rs/ingest/
network_policy.rs1use std::net::IpAddr;
2
3use thiserror::Error;
4
/// A single entry in a [`NetworkPolicy`] allowlist.
///
/// Entries are matched against a peer address with [`AllowedSource::contains`].
/// Address families never match each other: an IPv4 entry does not match an
/// IPv6 address (IPv4-mapped `::ffff:a.b.c.d` forms included), and vice versa.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AllowedSource {
    /// Matches exactly one address.
    Ip(IpAddr),
    /// Matches every address whose leading `prefix_len` bits equal those of
    /// `base`.
    ///
    /// Host bits set in `base` are ignored. A `prefix_len` wider than the
    /// address family (more than 32 for IPv4, 128 for IPv6) matches nothing.
    Cidr { base: IpAddr, prefix_len: u8 },
}

14impl AllowedSource {
15 pub fn contains(&self, addr: IpAddr) -> bool {
17 match self {
18 AllowedSource::Ip(allowed) => *allowed == addr,
19 AllowedSource::Cidr { base, prefix_len } => {
20 cidr_contains(*base, *prefix_len, addr)
21 }
22 }
23 }
24}
25
/// Returns `true` when `addr` lies inside the network `base/prefix_len`.
///
/// Only the leading `prefix_len` bits of each address are compared, so host
/// bits set in `base` do not matter. The check fails closed: mixing IPv4 and
/// IPv6 returns `false`, as does a prefix longer than the family's width
/// (32 or 128). A prefix of 0 matches every address of the same family.
fn cidr_contains(base: IpAddr, prefix_len: u8, addr: IpAddr) -> bool {
    let prefix = u32::from(prefix_len);
    match (base, addr) {
        (IpAddr::V4(net), IpAddr::V4(host)) => match 32u32.checked_sub(prefix) {
            Some(host_bits) => {
                // Shifting a `u32` by its full width (prefix 0) is out of range,
                // so `checked_shr` yields `None`; mapping that to 0 on both sides
                // makes /0 match everything.
                let net_part = u32::from(net).checked_shr(host_bits).unwrap_or(0);
                let host_part = u32::from(host).checked_shr(host_bits).unwrap_or(0);
                net_part == host_part
            }
            // Prefix wider than 32 bits: never matches.
            None => false,
        },
        (IpAddr::V6(net), IpAddr::V6(host)) => match 128u32.checked_sub(prefix) {
            Some(host_bits) => {
                // Same trick as the IPv4 arm, at 128-bit width.
                let net_part = u128::from(net).checked_shr(host_bits).unwrap_or(0);
                let host_part = u128::from(host).checked_shr(host_bits).unwrap_or(0);
                net_part == host_part
            }
            // Prefix wider than 128 bits: never matches.
            None => false,
        },
        (IpAddr::V4(_), IpAddr::V6(_)) | (IpAddr::V6(_), IpAddr::V4(_)) => false,
    }
}

57#[derive(Debug, Error, PartialEq, Eq)]
59pub enum NetworkPolicyError {
60 #[error("source {addr} is not in the allowlist")]
62 Denied { addr: IpAddr },
63 #[error("invalid CIDR '{0}': expected <addr>/<prefix_len>")]
65 InvalidCidr(String),
66}
67
68#[derive(Debug, Default, Clone)]
91pub struct NetworkPolicy {
92 allowed: Vec<AllowedSource>,
93}
94
95impl NetworkPolicy {
96 pub fn new() -> Self {
98 Self::default()
99 }
100
101 pub fn allow_ip(&mut self, addr: IpAddr) -> &mut Self {
103 self.allowed.push(AllowedSource::Ip(addr));
104 self
105 }
106
107 pub fn allow_cidr(&mut self, cidr: &str) -> Result<&mut Self, NetworkPolicyError> {
109 let (base, prefix_len) = parse_cidr(cidr)?;
110 self.allowed.push(AllowedSource::Cidr { base, prefix_len });
111 Ok(self)
112 }
113
114 pub fn check(&self, source: IpAddr) -> Result<(), NetworkPolicyError> {
117 if self.allowed.iter().any(|e| e.contains(source)) {
118 Ok(())
119 } else {
120 Err(NetworkPolicyError::Denied { addr: source })
121 }
122 }
123
124 pub fn entries(&self) -> &[AllowedSource] {
126 &self.allowed
127 }
128}
129
130fn parse_cidr(cidr: &str) -> Result<(IpAddr, u8), NetworkPolicyError> {
131 let err = || NetworkPolicyError::InvalidCidr(cidr.to_string());
132
133 let (addr_str, prefix_str) = cidr.split_once('/').ok_or_else(err)?;
134 let prefix_len: u8 = prefix_str.parse().map_err(|_| err())?;
135 let base: IpAddr = addr_str.parse().map_err(|_| err())?;
136
137 let max_prefix = match base {
138 IpAddr::V4(_) => 32,
139 IpAddr::V6(_) => 128,
140 };
141 if prefix_len > max_prefix {
142 return Err(err());
143 }
144
145 Ok((base, prefix_len))
146}