1use std::{
2 net::{IpAddr, Ipv4Addr, Ipv6Addr},
3 str::FromStr,
4};
5
6use thiserror::Error;
7
/// Counts the trailing zero bits of `s`, read as one big-endian number.
///
/// An all-zero (or empty) slice yields `8 * s.len()`.
fn trailing_zeros(s: &[u8]) -> usize {
    // Locate the least-significant non-zero byte; everything after it is zero.
    match s.iter().rposition(|&byte| byte != 0) {
        Some(idx) => (s.len() - 1 - idx) * 8 + s[idx].trailing_zeros() as usize,
        None => s.len() * 8,
    }
}
20
/// Splits the inclusive address range `start..=stop` into the minimal list of
/// aligned CIDR blocks, in ascending order.
///
/// Both bounds are big-endian octet arrays (`Ipv4Addr::octets()` for `N = 4`,
/// `Ipv6Addr::octets()` for `N = 16`). Each entry is the block's first address
/// and its prefix length in bits (`0..=N * 8`). An empty `Vec` is returned when
/// `start > stop`.
///
/// The arithmetic is done on whole `u128` values rather than by subtracting
/// octets pairwise. Ranges whose low octets "borrow" (e.g.
/// `10.0.0.200..=10.0.1.10`) and IPv6 ranges wider than `usize` are therefore
/// handled correctly.
///
/// # Panics
/// Panics if `N > 16`, since such addresses do not fit in a `u128`.
fn octets_with_mask_from_range<const N: usize>(
    start: [u8; N],
    stop: [u8; N],
) -> Vec<([u8; N], u8)> {
    assert!(N <= 16, "octets_with_mask_from_range supports at most 16 octets");
    let width = (N * 8) as u32;
    let to_int = |octets: [u8; N]| {
        octets
            .iter()
            .fold(0u128, |acc, &byte| (acc << 8) | u128::from(byte))
    };
    let to_octets = |mut value: u128| {
        let mut octets = [0u8; N];
        for byte in octets.iter_mut().rev() {
            // Truncation intended: take the lowest octet.
            *byte = value as u8;
            value >>= 8;
        }
        octets
    };

    let mut first = to_int(start);
    let last = to_int(stop);
    let mut result = Vec::new();
    if first > last {
        return result;
    }
    loop {
        // The largest block starting at `first` is limited by the alignment of
        // `first` (an all-zero address is aligned to the full width)...
        let align = first.trailing_zeros().min(width);
        // ...and by how many addresses remain (`span + 1`, which only overflows
        // for the complete IPv6 range).
        let span = last - first;
        let fits = span.checked_add(1).map_or(u128::BITS, u128::ilog2);
        let suffix = align.min(fits);
        result.push((to_octets(first), (width - suffix) as u8));

        // Last address inside this block. A 128-bit suffix covers everything.
        // `2^suffix <= span + 1` guarantees `block_last <= last`, so this
        // addition cannot overflow.
        let block_last = first + 1u128.checked_shl(suffix).map_or(u128::MAX, |size| size - 1);
        if block_last >= last {
            break;
        }
        first = block_last + 1;
    }
    result
}
33
34fn octets_with_mask<const N: usize>(mut start: [u8; N], mut count: usize) -> Vec<([u8; N], u8)> {
35 let mut result = Vec::new();
36 while count > 0 {
37 let zeros = trailing_zeros(&start);
39 let suffix = zeros.min(count.ilog2() as usize);
40 let mask = N * 8 - suffix;
41 result.push((start, mask as u8));
42
43 let mut byte_to_change = N - suffix / 8 - 1;
45 let mut bit_to_change = suffix % 8;
46 loop {
47 if let Some(new_val) = start[byte_to_change].checked_add(1 << bit_to_change) {
48 start[byte_to_change] = new_val;
49 break;
50 } else {
51 start[byte_to_change] = 0;
52 if byte_to_change == 0 {
53 break;
54 }
55 byte_to_change -= 1;
56 bit_to_change = 0;
57 }
58 }
59
60 count -= 1 << suffix;
62 }
63 result
64}
65
/// Conversion of a value into an iterator of bits (`true` for a 1 bit).
pub trait IntoBitPath {
    /// Consumes `self` and returns its bits as an iterator.
    fn into_bit_path(self) -> Self::Output;

    /// The iterator type yielding the bits.
    type Output: Iterator<Item = bool>;
}
71
72impl<T> IntoBitPath for T
73where
74 T: Iterator<Item = bool>,
75{
76 type Output = T;
77
78 fn into_bit_path(self) -> Self::Output {
79 self
80 }
81}
82
/// An IP address paired with a prefix length (`mask`), e.g. `10.0.0.0/8`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct IpAddrWithMask {
    /// The address. For values produced by `from_count`/`from_ip_range` this is
    /// the first address of the block.
    pub addr: IpAddr,
    /// Prefix length in bits (32/128 for a single IPv4/IPv6 host).
    /// NOTE(review): not validated against the address family, so values above
    /// 32 (IPv4) or 128 (IPv6) can be constructed.
    pub mask: u8,
}
88
89impl IpAddrWithMask {
90 pub fn new(addr: IpAddr, mask: u8) -> Self {
91 Self { addr, mask }
92 }
93
94 pub fn from_count(addr: IpAddr, count: usize) -> Vec<Self> {
95 match addr {
96 IpAddr::V4(addr) => octets_with_mask(addr.octets(), count)
97 .into_iter()
98 .map(|(octets, mask)| {
99 let addr = Ipv4Addr::from(octets);
100 Self::new(IpAddr::V4(addr), mask)
101 })
102 .collect(),
103 IpAddr::V6(addr) => octets_with_mask(addr.octets(), count)
104 .into_iter()
105 .map(|(octets, mask)| {
106 let addr = Ipv6Addr::from(octets);
107 Self::new(IpAddr::V6(addr), mask)
108 })
109 .collect(),
110 }
111 }
112
113 pub fn from_ip_range(first: IpAddr, last: IpAddr) -> Vec<Self> {
114 match (first, last) {
115 (IpAddr::V4(first), IpAddr::V4(last)) => {
116 octets_with_mask_from_range(first.octets(), last.octets())
117 .into_iter()
118 .map(|(octets, mask)| {
119 let addr = Ipv4Addr::from(octets);
120 Self::new(IpAddr::V4(addr), mask)
121 })
122 .collect()
123 }
124 (IpAddr::V6(first), IpAddr::V6(last)) => {
125 octets_with_mask_from_range(first.octets(), last.octets())
126 .into_iter()
127 .map(|(octets, mask)| {
128 let addr = Ipv6Addr::from(octets);
129 Self::new(IpAddr::V6(addr), mask)
130 })
131 .collect()
132 }
133 _ => panic!("IP version mismatch"),
134 }
135 }
136}
137
138impl From<IpAddr> for IpAddrWithMask {
139 fn from(addr: IpAddr) -> Self {
140 match addr {
141 IpAddr::V4(addr) => Self::from(addr),
142 IpAddr::V6(addr) => Self::from(addr),
143 }
144 }
145}
146
147impl From<Ipv4Addr> for IpAddrWithMask {
148 fn from(addr: Ipv4Addr) -> Self {
149 Self {
150 addr: IpAddr::V4(addr),
151 mask: 32,
152 }
153 }
154}
155
156impl From<Ipv6Addr> for IpAddrWithMask {
157 fn from(addr: Ipv6Addr) -> Self {
158 Self {
159 addr: IpAddr::V6(addr),
160 mask: 128,
161 }
162 }
163}
164
/// Errors returned when parsing an `IpAddrWithMask` from a string.
#[derive(Debug, Error)]
pub enum IpAddrWithMaskParseError {
    /// The address part (before `/`, or the whole input) is not a valid IP address.
    #[error("address parse error")]
    AddrParseError(#[from] std::net::AddrParseError),
    /// The part after `/` could not be parsed as the `u8` prefix length.
    #[error("mask parse error")]
    MaskParseError(#[from] std::num::ParseIntError),
}
172
173impl FromStr for IpAddrWithMask {
174 type Err = IpAddrWithMaskParseError;
175
176 fn from_str(s: &str) -> Result<Self, Self::Err> {
177 let mut parts = s.split('/');
178 let addr = parts.next().unwrap_or(s);
179 let mask = parts.next();
180 let addr = IpAddr::from_str(addr)?;
181 if let Some(mask) = mask {
182 Ok(Self {
183 addr,
184 mask: mask.parse()?,
185 })
186 } else {
187 Ok(Self::from(addr))
188 }
189 }
190}
191
192impl IntoBitPath for IpAddrWithMask {
193 type Output = IpAddrWithMaskBitPath;
194
195 fn into_bit_path(self) -> Self::Output {
196 IpAddrWithMaskBitPath { addr: self, bit: 0 }
197 }
198}
199
/// Iterator over the leading `mask` bits of an `IpAddrWithMask`, most
/// significant bit first; created by `IntoBitPath::into_bit_path`.
pub struct IpAddrWithMaskBitPath {
    /// The address/prefix being walked.
    addr: IpAddrWithMask,
    /// Index of the next bit to yield, counted from the most significant bit
    /// of the first octet.
    bit: u8,
}
204
205impl Iterator for IpAddrWithMaskBitPath {
206 type Item = bool;
207
208 fn next(&mut self) -> Option<Self::Item> {
209 if self.bit >= self.addr.mask {
210 return None;
211 }
212 let result = match self.addr.addr {
213 IpAddr::V4(addr) => {
214 addr.octets()[self.bit as usize / 8] & (1 << (7 - self.bit % 8)) != 0
215 }
216 IpAddr::V6(addr) => {
217 addr.octets()[self.bit as usize / 8] & (1 << (7 - self.bit % 8)) != 0
218 }
219 };
220 self.bit += 1;
221 Some(result)
222 }
223}
224
#[cfg(test)]
mod tests {
    use std::vec;

    use super::*;

    // Trailing zero bits of a big-endian byte string; all-zero input counts every bit.
    #[test]
    fn test_trailing_zeros() {
        assert_eq!(trailing_zeros(&[0, 0, 0, 0]), 32);
        assert_eq!(trailing_zeros(&[0, 0, 0, 1]), 0);
        assert_eq!(trailing_zeros(&[0, 0, 1, 0]), 8);
        assert_eq!(trailing_zeros(&[0, 1, 0, 0]), 16);
        assert_eq!(trailing_zeros(&[1, 0, 0, 0]), 24);
        assert_eq!(trailing_zeros(&[1, 0, 0, 1]), 0);
    }

    // Expected entries are (first address octets, prefix length in bits).
    #[test]
    fn test_octets_with_mask() {
        // 255 = 128 + 64 + ... + 1, so an aligned start yields halving blocks /25 .. /32.
        assert_eq!(
            octets_with_mask([1, 0, 0, 0], 255),
            vec![
                ([1, 0, 0, 0], 25),
                ([1, 0, 0, 128], 26),
                ([1, 0, 0, 192], 27),
                ([1, 0, 0, 224], 28),
                ([1, 0, 0, 240], 29),
                ([1, 0, 0, 248], 30),
                ([1, 0, 0, 252], 31),
                ([1, 0, 0, 254], 32),
            ],
        );
        // .240 is only 16-aligned, so 32 addresses become two /28s; the second one
        // carries into the third octet.
        assert_eq!(
            octets_with_mask([1, 0, 0, 240], 32),
            vec![([1, 0, 0, 240], 28), ([1, 0, 1, 0], 28),],
        );
        assert_eq!(
            octets_with_mask([196, 11, 105, 0], 256),
            vec![([196, 11, 105, 0], 24),],
        );
        // 105.0 is only /24-aligned: 256 + 512 + 256 = 1024 addresses.
        assert_eq!(
            octets_with_mask([196, 11, 105, 0], 1024),
            vec![
                ([196, 11, 105, 0], 24),
                ([196, 11, 106, 0], 23),
                ([196, 11, 108, 0], 24),
            ],
        );
    }

    // Inclusive ranges split into aligned blocks.
    #[test]
    fn test_octets_with_mask_from_range() {
        assert_eq!(
            octets_with_mask_from_range([196, 11, 105, 0], [196, 11, 105, 255]),
            vec![([196, 11, 105, 0], 24),],
        );
        // 0.0.0.0/8 covers up to 0.255.255.255; the last 256 addresses form 1.0.0.0/24.
        assert_eq!(
            octets_with_mask_from_range([0, 0, 0, 0], [1, 0, 0, 255]),
            vec![([0, 0, 0, 0], 8), ([1, 0, 0, 0], 24)],
        );
    }

    // Same split as the 1024-address case above, through the public IpAddr API.
    #[test]
    fn test_ip_addr_with_mask() {
        let addr = "196.11.105.0".parse();
        let count = 1024;
        let addrs = IpAddrWithMask::from_count(addr.unwrap(), count);
        assert_eq!(
            addrs,
            vec![
                IpAddrWithMask {
                    addr: IpAddr::V4(Ipv4Addr::new(196, 11, 105, 0)),
                    mask: 24,
                },
                IpAddrWithMask {
                    addr: IpAddr::V4(Ipv4Addr::new(196, 11, 106, 0)),
                    mask: 23,
                },
                IpAddrWithMask {
                    addr: IpAddr::V4(Ipv4Addr::new(196, 11, 108, 0)),
                    mask: 24,
                },
            ]
        );
    }
}