1use std::net::IpAddr;
16
17use anyhow::{Context, Result};
18
19use crate::config::AccessCfg;
20
/// One parsed access-list entry: a network address plus prefix length.
///
/// Values built by `Cidr::parse` have every host bit below the prefix cleared
/// in `base`, and a bare address is stored as a full-length prefix
/// (`/32` or `/128`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Cidr {
    /// IPv4 network; `prefix` is in `0..=32`.
    V4 { base: u32, prefix: u8 },
    /// IPv6 network; `prefix` is in `0..=128`.
    V6 { base: u128, prefix: u8 },
}
28
29impl Cidr {
30 fn parse(s: &str) -> Result<Cidr> {
33 let s = s.trim();
34 let (addr_part, prefix_part) = match s.split_once('/') {
35 Some((a, p)) => (a, Some(p)),
36 None => (s, None),
37 };
38 let addr: IpAddr = addr_part
39 .parse()
40 .with_context(|| format!("invalid IP/CIDR address {s:?}"))?;
41 match addr {
42 IpAddr::V4(v4) => {
43 let prefix = match prefix_part {
44 Some(p) => p
45 .parse::<u8>()
46 .ok()
47 .filter(|p| *p <= 32)
48 .with_context(|| format!("invalid IPv4 CIDR prefix in {s:?} (0-32)"))?,
49 None => 32,
50 };
51 let base = u32::from(v4) & mask_v4(prefix);
52 Ok(Cidr::V4 { base, prefix })
53 }
54 IpAddr::V6(v6) => {
55 let prefix = match prefix_part {
56 Some(p) => p
57 .parse::<u8>()
58 .ok()
59 .filter(|p| *p <= 128)
60 .with_context(|| format!("invalid IPv6 CIDR prefix in {s:?} (0-128)"))?,
61 None => 128,
62 };
63 let base = u128::from(v6) & mask_v6(prefix);
64 Ok(Cidr::V6 { base, prefix })
65 }
66 }
67 }
68
69 fn contains(&self, ip: IpAddr) -> bool {
70 match (self, ip) {
71 (Cidr::V4 { base, prefix }, IpAddr::V4(v4)) => {
72 u32::from(v4) & mask_v4(*prefix) == *base
73 }
74 (Cidr::V6 { base, prefix }, IpAddr::V6(v6)) => {
75 u128::from(v6) & mask_v6(*prefix) == *base
76 }
77 _ => false,
79 }
80 }
81}
82
/// Builds an IPv4 netmask whose top `prefix` bits are set
/// (e.g. `24` yields `0xFFFF_FF00`, i.e. `255.255.255.0`).
///
/// `/0` is handled separately because shifting a `u32` by 32 bits overflows.
/// Callers must pass a prefix in `0..=32`.
fn mask_v4(prefix: u8) -> u32 {
    match prefix {
        0 => 0,
        bits => u32::MAX << (32 - bits),
    }
}
92
/// Builds an IPv6 netmask whose top `prefix` bits are set.
///
/// `/0` is handled separately because shifting a `u128` by 128 bits
/// overflows. Callers must pass a prefix in `0..=128`.
fn mask_v6(prefix: u8) -> u128 {
    match prefix {
        0 => 0,
        bits => u128::MAX << (128 - bits),
    }
}
100
101pub struct AccessPolicy {
103 allow: Vec<Cidr>,
104 deny: Vec<Cidr>,
105}
106
107impl AccessPolicy {
108 pub fn build(cfg: &AccessCfg) -> Result<Option<AccessPolicy>> {
112 if cfg.allow.is_empty() && cfg.deny.is_empty() {
113 return Ok(None);
114 }
115 let allow = cfg
116 .allow
117 .iter()
118 .map(|s| Cidr::parse(s))
119 .collect::<Result<_>>()?;
120 let deny = cfg
121 .deny
122 .iter()
123 .map(|s| Cidr::parse(s))
124 .collect::<Result<_>>()?;
125 Ok(Some(AccessPolicy { allow, deny }))
126 }
127
128 pub fn allowed(&self, ip: IpAddr) -> bool {
131 if self.deny.iter().any(|c| c.contains(ip)) {
132 return false;
133 }
134 if self.allow.is_empty() {
135 return true;
136 }
137 self.allow.iter().any(|c| c.contains(ip))
138 }
139}
140
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses an address literal used by a test case.
    fn addr(text: &str) -> IpAddr {
        text.parse().expect("test address literal must parse")
    }

    /// Builds a policy from lists the test expects to be valid and non-empty.
    fn policy(allow: &[&str], deny: &[&str]) -> AccessPolicy {
        let cfg = AccessCfg {
            allow: allow.iter().map(ToString::to_string).collect(),
            deny: deny.iter().map(ToString::to_string).collect(),
        };
        AccessPolicy::build(&cfg)
            .expect("test entries must parse")
            .expect("non-empty lists must produce a policy")
    }

    #[test]
    fn empty_lists_build_to_none() {
        let built = AccessPolicy::build(&AccessCfg::default()).unwrap();
        assert!(built.is_none());
    }

    #[test]
    fn allowlist_is_a_whitelist() {
        let p = policy(&["10.0.0.0/8", "203.0.113.7"], &[]);
        for admitted in ["10.1.2.3", "203.0.113.7"] {
            assert!(p.allowed(addr(admitted)), "{admitted} should be admitted");
        }
        assert!(!p.allowed(addr("8.8.8.8")));
    }

    #[test]
    fn deny_wins_over_allow() {
        let p = policy(&["10.0.0.0/8"], &["10.0.0.5"]);
        assert!(p.allowed(addr("10.0.0.6")));
        assert!(!p.allowed(addr("10.0.0.5")));
    }

    #[test]
    fn deny_only_blocks_listed_and_admits_rest() {
        let p = policy(&[], &["192.168.0.0/16"]);
        assert!(!p.allowed(addr("192.168.1.1")));
        assert!(p.allowed(addr("203.0.113.1")));
    }

    #[test]
    fn ipv6_and_cross_family() {
        let p = policy(&["2001:db8::/32"], &[]);
        assert!(p.allowed(addr("2001:db8::1")));
        assert!(!p.allowed(addr("2001:dead::1")));
        // A native IPv4 client never matches an IPv6-only allow list.
        assert!(!p.allowed(addr("10.0.0.1")));
    }

    #[test]
    fn rejects_bad_entries() {
        for bad in ["not-an-ip", "10.0.0.0/99"] {
            let cfg = AccessCfg {
                allow: vec![bad.into()],
                deny: vec![],
            };
            assert!(AccessPolicy::build(&cfg).is_err(), "{bad:?} should be rejected");
        }
    }
}