1use sha2::{Digest, Sha256};
7use std::net::IpAddr;
8use std::str::FromStr;
9use subtle::ConstantTimeEq;
10
/// An allow-list of client IP addresses and CIDR ranges.
///
/// An empty filter allows every address. Once at least one entry has been
/// added, only addresses matching some entry are allowed.
#[derive(Debug, Clone)]
pub struct IpFilter {
    allowed: Vec<IpFilterEntry>,
}

/// A single allow-list rule.
#[derive(Debug, Clone)]
enum IpFilterEntry {
    /// Matches exactly one address.
    Single(IpAddr),
    /// Matches every address whose leading `prefix_len` bits equal those of `network`.
    Cidr { network: IpAddr, prefix_len: u8 },
}

impl IpFilter {
    /// Creates an empty filter, which allows every address.
    pub fn new() -> Self {
        Self {
            allowed: Vec::new(),
        }
    }

    /// Adds a single address (`"10.0.0.1"`, `"::1"`) or a CIDR range
    /// (`"10.0.0.0/8"`, `"2001:db8::/32"`) to the allow-list.
    ///
    /// Host bits in a CIDR network address are permitted; they are masked off
    /// when matching.
    ///
    /// # Errors
    /// Returns a human-readable message when the address does not parse, the
    /// prefix length is not an integer in `0..=255`, or it exceeds 32 (IPv4) /
    /// 128 (IPv6).
    pub fn allow(&mut self, ip_or_cidr: &str) -> Result<(), String> {
        let entry = match ip_or_cidr.split_once('/') {
            Some((network_part, prefix_str)) => {
                let network = IpAddr::from_str(network_part)
                    .map_err(|e| format!("Invalid network address: {}", e))?;

                let prefix_len: u8 = prefix_str
                    .parse()
                    .map_err(|_| format!("Invalid CIDR prefix length: {}", prefix_str))?;

                let max_prefix = match network {
                    IpAddr::V4(_) => 32,
                    IpAddr::V6(_) => 128,
                };
                if prefix_len > max_prefix {
                    return Err(format!(
                        "CIDR prefix length {} exceeds maximum {} for {:?}",
                        prefix_len, max_prefix, network
                    ));
                }

                IpFilterEntry::Cidr {
                    network,
                    prefix_len,
                }
            }
            None => IpFilterEntry::Single(
                IpAddr::from_str(ip_or_cidr).map_err(|e| format!("Invalid IP address: {}", e))?,
            ),
        };
        self.allowed.push(entry);
        Ok(())
    }

    /// Returns `true` if `ip` is permitted by this filter.
    ///
    /// An empty filter permits everything. An IPv4-mapped IPv6 address
    /// (`::ffff:a.b.c.d`) is also checked as its embedded IPv4 address. Dual-stack
    /// sockets report IPv4 peers in that form, and without this check they would
    /// never match IPv4 rules.
    pub fn is_allowed(&self, ip: IpAddr) -> bool {
        if self.allowed.is_empty() {
            return true;
        }

        let mapped_v4 = match ip {
            IpAddr::V6(v6) => v6.to_ipv4_mapped().map(IpAddr::V4),
            IpAddr::V4(_) => None,
        };

        self.allowed.iter().any(|entry| {
            self.matches(ip, entry)
                || matches!(mapped_v4, Some(v4) if self.matches(v4, entry))
        })
    }

    /// Tests `ip` against a single rule.
    fn matches(&self, ip: IpAddr, entry: &IpFilterEntry) -> bool {
        match entry {
            IpFilterEntry::Single(allowed_ip) => ip == *allowed_ip,
            IpFilterEntry::Cidr {
                network,
                prefix_len,
            } => self.ip_in_cidr(ip, *network, *prefix_len),
        }
    }

    /// Returns `true` if `ip` and `network` share their leading `prefix_len` bits.
    ///
    /// Addresses of different families never match here.
    fn ip_in_cidr(&self, ip: IpAddr, network: IpAddr, prefix_len: u8) -> bool {
        match (ip, network) {
            (IpAddr::V4(ip), IpAddr::V4(net)) => {
                // A shift by the full width (prefix 0) overflows; `checked_shl`
                // turns that case into the all-zero mask, which matches everything.
                let mask = u32::MAX
                    .checked_shl(32u32.saturating_sub(u32::from(prefix_len)))
                    .unwrap_or(0);
                (u32::from(ip) & mask) == (u32::from(net) & mask)
            }
            (IpAddr::V6(ip), IpAddr::V6(net)) => {
                let mask = u128::MAX
                    .checked_shl(128u32.saturating_sub(u32::from(prefix_len)))
                    .unwrap_or(0);
                (u128::from(ip) & mask) == (u128::from(net) & mask)
            }
            _ => false,
        }
    }
}

impl Default for IpFilter {
    /// Equivalent to [`IpFilter::new`]: an empty, allow-everything filter.
    fn default() -> Self {
        Self::new()
    }
}
139
140#[derive(Debug, Clone)]
142pub enum AuthConfig {
143 None,
145 Bearer(String),
147 Basic {
149 username: String,
151 password: String,
153 },
154}
155
156impl AuthConfig {
157 pub fn bearer(token: impl Into<String>) -> Self {
159 Self::Bearer(token.into())
160 }
161
162 pub fn basic(username: impl Into<String>, password: impl Into<String>) -> Self {
164 Self::Basic {
165 username: username.into(),
166 password: password.into(),
167 }
168 }
169
170 pub fn validate(&self, header: &str) -> bool {
180 match self {
181 AuthConfig::None => true,
182 AuthConfig::Bearer(token) => {
183 if let Some(token_part) = header.strip_prefix("Bearer ") {
184 let expected_hash = Sha256::digest(token.as_bytes());
187 let provided_hash = Sha256::digest(token_part.as_bytes());
188 expected_hash.ct_eq(&provided_hash).into()
189 } else {
190 false
191 }
192 }
193 AuthConfig::Basic { username, password } => {
194 if let Some(creds_part) = header.strip_prefix("Basic ") {
195 if let Ok(decoded) = base64_decode(creds_part) {
197 let expected = format!("{}:{}", username, password);
198 let expected_hash = Sha256::digest(expected.as_bytes());
200 let decoded_hash = Sha256::digest(decoded.as_bytes());
201 expected_hash.ct_eq(&decoded_hash).into()
202 } else {
203 false
204 }
205 } else {
206 false
207 }
208 }
209 }
210 }
211}
212
/// Decodes standard base64 (RFC 4648 §4, `+` / `/` alphabet) into a UTF-8 string.
///
/// Trailing `=` padding is optional. When present, there may be at most two
/// padding characters, and they must bring the input length to a multiple of four.
///
/// # Errors
/// Returns a message if the input contains a byte outside the alphabet
/// (including `=` anywhere other than the trailing padding). It also errors
/// on malformed padding, on a length that cannot encode whole bytes, and on
/// decoded bytes that are not valid UTF-8.
fn base64_decode(s: &str) -> Result<String, String> {
    /// Maps a base64 alphabet byte to its 6-bit value.
    fn sextet(c: u8) -> Option<u8> {
        match c {
            b'A'..=b'Z' => Some(c - b'A'),
            b'a'..=b'z' => Some(c - b'a' + 26),
            b'0'..=b'9' => Some(c - b'0' + 52),
            b'+' => Some(62),
            b'/' => Some(63),
            _ => None,
        }
    }

    let bytes = s.as_bytes();
    let padding = bytes.iter().rev().take_while(|&&b| b == b'=').count();
    if padding > 2 || (padding > 0 && bytes.len() % 4 != 0) {
        return Err("Invalid base64 padding".to_string());
    }
    let data = &bytes[..bytes.len() - padding];

    // A lone trailing sextet carries only 6 bits, which cannot form a whole byte.
    if data.len() % 4 == 1 {
        return Err("Invalid base64 length".to_string());
    }

    let mut output = Vec::with_capacity(data.len() / 4 * 3 + 2);
    for chunk in data.chunks(4) {
        let mut group: u32 = 0;
        for &c in chunk {
            let value = sextet(c).ok_or_else(|| "Invalid base64 character".to_string())?;
            group = (group << 6) | u32::from(value);
        }
        // Left-align into a 24-bit group. A chunk of n sextets (2..=4) yields
        // n - 1 bytes, which sit in bytes 1..n of the big-endian u32.
        group <<= 6 * (4 - chunk.len());
        output.extend_from_slice(&group.to_be_bytes()[1..chunk.len()]);
    }

    String::from_utf8(output).map_err(|e| e.to_string())
}