Skip to main content

edgesentry_rs/ingest/
network_policy.rs

1use std::net::IpAddr;
2
3use thiserror::Error;
4
/// A single entry in the allowlist: either an exact IP address or a CIDR block.
///
/// Use [`AllowedSource::contains`] to test whether an address is covered by an
/// entry. Equality between entries is structural (derived): two `Cidr` values
/// compare equal only when both `base` and `prefix_len` are identical.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AllowedSource {
    /// Exact IP address match.
    Ip(IpAddr),
    /// CIDR block (IPv4 or IPv6).
    ///
    /// `base` is the network address and `prefix_len` is the number of leading
    /// bits an address must share with `base` to be considered inside the block.
    Cidr { base: IpAddr, prefix_len: u8 },
}
13
14impl AllowedSource {
15    /// Returns `true` if `addr` is covered by this entry.
16    pub fn contains(&self, addr: IpAddr) -> bool {
17        match self {
18            AllowedSource::Ip(allowed) => *allowed == addr,
19            AllowedSource::Cidr { base, prefix_len } => {
20                cidr_contains(*base, *prefix_len, addr)
21            }
22        }
23    }
24}
25
/// Reports whether `addr` lies inside the network `base`/`prefix_len`.
///
/// Only the leading `prefix_len` bits of the two addresses are compared, so
/// any host bits set in `base` are ignored. A `prefix_len` larger than the
/// address width (32 for IPv4, 128 for IPv6) never matches. Addresses of
/// different families never match either; in particular an IPv4-mapped IPv6
/// address is not compared against an IPv4 block.
fn cidr_contains(base: IpAddr, prefix_len: u8, addr: IpAddr) -> bool {
    let width = u32::from(prefix_len);
    match (base, addr) {
        (IpAddr::V4(network), IpAddr::V4(candidate)) => {
            if width > 32 {
                return false;
            }
            // A full-width shift (prefix 0) makes `checked_shl` return `None`;
            // that becomes the empty mask, which matches every address.
            let mask = u32::MAX.checked_shl(32 - width).unwrap_or(0);
            ((u32::from(network) ^ u32::from(candidate)) & mask) == 0
        }
        (IpAddr::V6(network), IpAddr::V6(candidate)) => {
            if width > 128 {
                return false;
            }
            // Same masking scheme as the IPv4 arm, on 128-bit values.
            let mask = u128::MAX.checked_shl(128 - width).unwrap_or(0);
            ((u128::from(network) ^ u128::from(candidate)) & mask) == 0
        }
        // Mismatched families never match.
        (IpAddr::V4(_), IpAddr::V6(_)) | (IpAddr::V6(_), IpAddr::V4(_)) => false,
    }
}
56
/// Errors produced by [`NetworkPolicy`].
#[derive(Debug, Error, PartialEq, Eq)]
pub enum NetworkPolicyError {
    /// Source address is not in the allowlist.
    ///
    /// Returned by [`NetworkPolicy::check`]; `addr` is the source address that
    /// was rejected.
    #[error("source {addr} is not in the allowlist")]
    Denied { addr: IpAddr },
    /// The supplied CIDR string could not be parsed.
    ///
    /// Returned by [`NetworkPolicy::allow_cidr`]; the payload is the input
    /// string exactly as it was supplied. Raised when the `/` separator is
    /// missing, when the address or prefix length does not parse, or when the
    /// prefix length exceeds the address width (32 for IPv4, 128 for IPv6).
    #[error("invalid CIDR '{0}': expected <addr>/<prefix_len>")]
    InvalidCidr(String),
}
67
/// Deny-by-default IP/CIDR allowlist for ingest endpoints.
///
/// A source address is accepted when at least one configured entry covers it;
/// a policy with no entries (including the one produced by `Default`) rejects
/// every address.
///
/// # Usage
///
/// Build a policy with [`NetworkPolicy::new`], populate it with
/// [`allow_ip`](Self::allow_ip) / [`allow_cidr`](Self::allow_cidr), then call
/// [`check`](Self::check) with the source address of each incoming connection
/// **before** passing the payload to [`IngestService`](super::IngestService).
///
/// ```rust
/// use std::net::IpAddr;
/// use edgesentry_rs::NetworkPolicy;
///
/// let mut policy = NetworkPolicy::new();
/// policy.allow_cidr("10.0.0.0/8").unwrap();
///
/// let trusted: IpAddr = "10.1.2.3".parse().unwrap();
/// assert!(policy.check(trusted).is_ok());
///
/// let untrusted: IpAddr = "192.168.1.1".parse().unwrap();
/// assert!(policy.check(untrusted).is_err());
/// ```
#[derive(Debug, Default, Clone)]
pub struct NetworkPolicy {
    // Allowlist entries in insertion order; duplicates are not filtered out.
    allowed: Vec<AllowedSource>,
}
94
95impl NetworkPolicy {
96    /// Create an empty policy (all sources denied until rules are added).
97    pub fn new() -> Self {
98        Self::default()
99    }
100
101    /// Permit a single IP address.
102    pub fn allow_ip(&mut self, addr: IpAddr) -> &mut Self {
103        self.allowed.push(AllowedSource::Ip(addr));
104        self
105    }
106
107    /// Permit all addresses within a CIDR block, e.g. `"10.0.0.0/8"` or `"fd00::/8"`.
108    pub fn allow_cidr(&mut self, cidr: &str) -> Result<&mut Self, NetworkPolicyError> {
109        let (base, prefix_len) = parse_cidr(cidr)?;
110        self.allowed.push(AllowedSource::Cidr { base, prefix_len });
111        Ok(self)
112    }
113
114    /// Returns `Ok(())` if `source` is covered by at least one allowlist entry,
115    /// or `Err(NetworkPolicyError::Denied)` if not.
116    pub fn check(&self, source: IpAddr) -> Result<(), NetworkPolicyError> {
117        if self.allowed.iter().any(|e| e.contains(source)) {
118            Ok(())
119        } else {
120            Err(NetworkPolicyError::Denied { addr: source })
121        }
122    }
123
124    /// Returns the list of configured allowlist entries.
125    pub fn entries(&self) -> &[AllowedSource] {
126        &self.allowed
127    }
128}
129
130fn parse_cidr(cidr: &str) -> Result<(IpAddr, u8), NetworkPolicyError> {
131    let err = || NetworkPolicyError::InvalidCidr(cidr.to_string());
132
133    let (addr_str, prefix_str) = cidr.split_once('/').ok_or_else(err)?;
134    let prefix_len: u8 = prefix_str.parse().map_err(|_| err())?;
135    let base: IpAddr = addr_str.parse().map_err(|_| err())?;
136
137    let max_prefix = match base {
138        IpAddr::V4(_) => 32,
139        IpAddr::V6(_) => 128,
140    };
141    if prefix_len > max_prefix {
142        return Err(err());
143    }
144
145    Ok((base, prefix_len))
146}