maxminddb_writer/
paths.rs

1use std::{
2    net::{IpAddr, Ipv4Addr, Ipv6Addr},
3    str::FromStr,
4};
5
6use thiserror::Error;
7
/// Counts the zero bits at the least-significant end of a big-endian byte string.
///
/// The slice is scanned from its last byte backwards: every all-zero byte
/// contributes eight bits, and the first non-zero byte contributes its own
/// trailing zero bits and ends the scan. An empty or all-zero slice therefore
/// yields `s.len() * 8`.
fn trailing_zeros(s: &[u8]) -> usize {
    // Whole bytes of zeros at the tail.
    let zero_bytes = s.iter().rev().take_while(|&&byte| byte == 0).count();
    // Zero bits at the bottom of the first non-zero byte, if there is one.
    let partial_bits = s
        .iter()
        .rev()
        .nth(zero_bytes)
        .map_or(0, |byte| byte.trailing_zeros() as usize);
    zero_bytes * 8 + partial_bits
}
20
/// Splits the inclusive address range `start..=stop` into CIDR blocks.
///
/// Both bounds are big-endian octet arrays. The result lists
/// `(network octets, prefix length)` pairs in ascending address order; each
/// block is the largest one that is aligned at the current address and does not
/// reach past `stop`.
///
/// The arithmetic is done in `u128`, so borrows between octets (for example
/// `0.0.0.255..=0.0.1.0`) and ranges holding more than `usize::MAX` addresses
/// (typical for IPv6) are handled exactly. If `start` is greater than `stop`
/// the range is empty and an empty list is returned.
///
/// # Panics
///
/// Panics if `N > 16`, i.e. for addresses wider than 128 bits.
fn octets_with_mask_from_range<const N: usize>(
    start: [u8; N],
    stop: [u8; N],
) -> Vec<([u8; N], u8)> {
    assert!(N <= 16, "addresses wider than 128 bits are not supported");
    // At most 128 thanks to the assertion above, so the cast is lossless.
    let total_bits = (N * 8) as u32;

    let to_int = |octets: &[u8; N]| {
        octets
            .iter()
            .fold(0u128, |acc, &byte| (acc << 8) | u128::from(byte))
    };
    let to_octets = |value: u128| {
        let mut octets = [0u8; N];
        octets.copy_from_slice(&value.to_be_bytes()[16 - N..]);
        octets
    };

    let last = to_int(&stop);
    let mut first = to_int(&start);
    let mut result = Vec::new();
    if first > last {
        return result;
    }

    loop {
        // Largest block that is aligned at `first` ...
        let aligned = first.trailing_zeros().min(total_bits);
        // ... and largest block that still fits into the remaining range.
        // `last - first + 1` only overflows for the full 128-bit space.
        let fitting = (last - first)
            .checked_add(1)
            .map_or(u128::BITS, u128::ilog2);
        let suffix = aligned.min(fitting);
        // The prefix length is at most 128, so it fits into a `u8`.
        result.push((to_octets(first), (total_bits - suffix) as u8));

        if suffix == total_bits {
            // The block spans the whole address space.
            break;
        }
        // `suffix < total_bits <= 128`, so the shift cannot overflow, and the
        // block end cannot exceed `last` because the block fits the range.
        let block_end = first + ((1u128 << suffix) - 1);
        if block_end == last {
            break;
        }
        first = block_end + 1;
    }
    result
}
33
34fn octets_with_mask<const N: usize>(mut start: [u8; N], mut count: usize) -> Vec<([u8; N], u8)> {
35    let mut result = Vec::new();
36    while count > 0 {
37        // calculate the biggest possible mask
38        let zeros = trailing_zeros(&start);
39        let suffix = zeros.min(count.ilog2() as usize);
40        let mask = N * 8 - suffix;
41        result.push((start, mask as u8));
42
43        // increment start
44        let mut byte_to_change = N - suffix / 8 - 1;
45        let mut bit_to_change = suffix % 8;
46        loop {
47            if let Some(new_val) = start[byte_to_change].checked_add(1 << bit_to_change) {
48                start[byte_to_change] = new_val;
49                break;
50            } else {
51                start[byte_to_change] = 0;
52                if byte_to_change == 0 {
53                    break;
54                }
55                byte_to_change -= 1;
56                bit_to_change = 0;
57            }
58        }
59
60        // subtract the block from count
61        count -= 1 << suffix;
62    }
63    result
64}
65
/// Conversion of a value into a path expressed as a sequence of bits.
pub trait IntoBitPath {
    /// Iterator that yields the bits of the path one at a time.
    type Output: Iterator<Item = bool>;

    /// Consumes `self` and returns its bits as an iterator.
    fn into_bit_path(self) -> Self::Output;
}

/// Every iterator over `bool` already is a bit path and is handed back as is.
impl<Bits: Iterator<Item = bool>> IntoBitPath for Bits {
    type Output = Self;

    fn into_bit_path(self) -> Self {
        self
    }
}
82
/// An IP address paired with a prefix length, i.e. a network in CIDR notation.
///
/// Both fields are public and no constructor validates them, so `mask` is not
/// guaranteed to be within the bit width of `addr`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct IpAddrWithMask {
    /// The IPv4 or IPv6 address.
    pub addr: IpAddr,
    /// Number of leading address bits that belong to the prefix.
    pub mask: u8,
}
88
89impl IpAddrWithMask {
90    pub fn new(addr: IpAddr, mask: u8) -> Self {
91        Self { addr, mask }
92    }
93
94    pub fn from_count(addr: IpAddr, count: usize) -> Vec<Self> {
95        match addr {
96            IpAddr::V4(addr) => octets_with_mask(addr.octets(), count)
97                .into_iter()
98                .map(|(octets, mask)| {
99                    let addr = Ipv4Addr::from(octets);
100                    Self::new(IpAddr::V4(addr), mask)
101                })
102                .collect(),
103            IpAddr::V6(addr) => octets_with_mask(addr.octets(), count)
104                .into_iter()
105                .map(|(octets, mask)| {
106                    let addr = Ipv6Addr::from(octets);
107                    Self::new(IpAddr::V6(addr), mask)
108                })
109                .collect(),
110        }
111    }
112
113    pub fn from_ip_range(first: IpAddr, last: IpAddr) -> Vec<Self> {
114        match (first, last) {
115            (IpAddr::V4(first), IpAddr::V4(last)) => {
116                octets_with_mask_from_range(first.octets(), last.octets())
117                    .into_iter()
118                    .map(|(octets, mask)| {
119                        let addr = Ipv4Addr::from(octets);
120                        Self::new(IpAddr::V4(addr), mask)
121                    })
122                    .collect()
123            }
124            (IpAddr::V6(first), IpAddr::V6(last)) => {
125                octets_with_mask_from_range(first.octets(), last.octets())
126                    .into_iter()
127                    .map(|(octets, mask)| {
128                        let addr = Ipv6Addr::from(octets);
129                        Self::new(IpAddr::V6(addr), mask)
130                    })
131                    .collect()
132            }
133            _ => panic!("IP version mismatch"),
134        }
135    }
136}
137
138impl From<IpAddr> for IpAddrWithMask {
139    fn from(addr: IpAddr) -> Self {
140        match addr {
141            IpAddr::V4(addr) => Self::from(addr),
142            IpAddr::V6(addr) => Self::from(addr),
143        }
144    }
145}
146
147impl From<Ipv4Addr> for IpAddrWithMask {
148    fn from(addr: Ipv4Addr) -> Self {
149        Self {
150            addr: IpAddr::V4(addr),
151            mask: 32,
152        }
153    }
154}
155
156impl From<Ipv6Addr> for IpAddrWithMask {
157    fn from(addr: Ipv6Addr) -> Self {
158        Self {
159            addr: IpAddr::V6(addr),
160            mask: 128,
161        }
162    }
163}
164
165#[derive(Debug, Error)]
166pub enum IpAddrWithMaskParseError {
167    #[error("address parse error")]
168    AddrParseError(#[from] std::net::AddrParseError),
169    #[error("mask parse error")]
170    MaskParseError(#[from] std::num::ParseIntError),
171}
172
173impl FromStr for IpAddrWithMask {
174    type Err = IpAddrWithMaskParseError;
175
176    fn from_str(s: &str) -> Result<Self, Self::Err> {
177        let mut parts = s.split('/');
178        let addr = parts.next().unwrap_or(s);
179        let mask = parts.next();
180        let addr = IpAddr::from_str(addr)?;
181        if let Some(mask) = mask {
182            Ok(Self {
183                addr,
184                mask: mask.parse()?,
185            })
186        } else {
187            Ok(Self::from(addr))
188        }
189    }
190}
191
192impl IntoBitPath for IpAddrWithMask {
193    type Output = IpAddrWithMaskBitPath;
194
195    fn into_bit_path(self) -> Self::Output {
196        IpAddrWithMaskBitPath { addr: self, bit: 0 }
197    }
198}
199
200pub struct IpAddrWithMaskBitPath {
201    addr: IpAddrWithMask,
202    bit: u8,
203}
204
205impl Iterator for IpAddrWithMaskBitPath {
206    type Item = bool;
207
208    fn next(&mut self) -> Option<Self::Item> {
209        if self.bit >= self.addr.mask {
210            return None;
211        }
212        let result = match self.addr.addr {
213            IpAddr::V4(addr) => {
214                addr.octets()[self.bit as usize / 8] & (1 << (7 - self.bit % 8)) != 0
215            }
216            IpAddr::V6(addr) => {
217                addr.octets()[self.bit as usize / 8] & (1 << (7 - self.bit % 8)) != 0
218            }
219        };
220        self.bit += 1;
221        Some(result)
222    }
223}
224
/// Unit tests for the octet helpers and the public `IpAddrWithMask` API.
#[cfg(test)]
mod tests {
    use std::vec;

    use super::*;

    #[test]
    fn test_trailing_zeros() {
        assert_eq!(trailing_zeros(&[0, 0, 0, 0]), 32);
        assert_eq!(trailing_zeros(&[0, 0, 0, 1]), 0);
        assert_eq!(trailing_zeros(&[0, 0, 1, 0]), 8);
        assert_eq!(trailing_zeros(&[0, 1, 0, 0]), 16);
        assert_eq!(trailing_zeros(&[1, 0, 0, 0]), 24);
        assert_eq!(trailing_zeros(&[1, 0, 0, 1]), 0);
    }

    #[test]
    fn test_octets_with_mask() {
        assert_eq!(
            octets_with_mask([1, 0, 0, 0], 255),
            vec![
                ([1, 0, 0, 0], 25),
                ([1, 0, 0, 128], 26),
                ([1, 0, 0, 192], 27),
                ([1, 0, 0, 224], 28),
                ([1, 0, 0, 240], 29),
                ([1, 0, 0, 248], 30),
                ([1, 0, 0, 252], 31),
                ([1, 0, 0, 254], 32),
            ],
        );
        assert_eq!(
            octets_with_mask([1, 0, 0, 240], 32),
            vec![([1, 0, 0, 240], 28), ([1, 0, 1, 0], 28),],
        );
        assert_eq!(
            octets_with_mask([196, 11, 105, 0], 256),
            vec![([196, 11, 105, 0], 24),],
        );
        assert_eq!(
            octets_with_mask([196, 11, 105, 0], 1024),
            vec![
                ([196, 11, 105, 0], 24),
                ([196, 11, 106, 0], 23),
                ([196, 11, 108, 0], 24),
            ],
        );
    }

    #[test]
    fn test_octets_with_mask_from_range() {
        assert_eq!(
            octets_with_mask_from_range([196, 11, 105, 0], [196, 11, 105, 255]),
            vec![([196, 11, 105, 0], 24),],
        );
        assert_eq!(
            octets_with_mask_from_range([0, 0, 0, 0], [1, 0, 0, 255]),
            vec![([0, 0, 0, 0], 8), ([1, 0, 0, 0], 24)],
        );
    }

    #[test]
    fn test_ip_addr_with_mask() {
        let addr = "196.11.105.0".parse();
        let count = 1024;
        let addrs = IpAddrWithMask::from_count(addr.unwrap(), count);
        assert_eq!(
            addrs,
            vec![
                IpAddrWithMask {
                    addr: IpAddr::V4(Ipv4Addr::new(196, 11, 105, 0)),
                    mask: 24,
                },
                IpAddrWithMask {
                    addr: IpAddr::V4(Ipv4Addr::new(196, 11, 106, 0)),
                    mask: 23,
                },
                IpAddrWithMask {
                    addr: IpAddr::V4(Ipv4Addr::new(196, 11, 108, 0)),
                    mask: 24,
                },
            ]
        );
    }

    #[test]
    fn test_from_ip_range() {
        let first = IpAddr::V4(Ipv4Addr::new(196, 11, 105, 0));
        let last = IpAddr::V4(Ipv4Addr::new(196, 11, 105, 255));
        assert_eq!(
            IpAddrWithMask::from_ip_range(first, last),
            vec![IpAddrWithMask::new(first, 24)],
        );
    }

    #[test]
    fn test_from_str() {
        // Explicit mask.
        let parsed: IpAddrWithMask = "196.11.105.0/24".parse().unwrap();
        assert_eq!(
            parsed,
            IpAddrWithMask::new(IpAddr::V4(Ipv4Addr::new(196, 11, 105, 0)), 24),
        );

        // Without a mask the full address width is used.
        let parsed: IpAddrWithMask = "10.0.0.1".parse().unwrap();
        assert_eq!(parsed.mask, 32);
        let parsed: IpAddrWithMask = "::1".parse().unwrap();
        assert_eq!(parsed.addr, IpAddr::V6(Ipv6Addr::LOCALHOST));
        assert_eq!(parsed.mask, 128);

        // Each part reports its own error variant.
        assert!(matches!(
            "not-an-ip".parse::<IpAddrWithMask>(),
            Err(IpAddrWithMaskParseError::AddrParseError(_)),
        ));
        assert!(matches!(
            "10.0.0.0/x".parse::<IpAddrWithMask>(),
            Err(IpAddrWithMaskParseError::MaskParseError(_)),
        ));
    }

    #[test]
    fn test_bit_path() {
        // 160 is 0b1010_0000; only the first `mask` bits are produced.
        let net = IpAddrWithMask::new(IpAddr::V4(Ipv4Addr::new(160, 0, 0, 0)), 4);
        let bits: Vec<bool> = net.into_bit_path().collect();
        assert_eq!(bits, vec![true, false, true, false]);

        let empty = IpAddrWithMask::new(IpAddr::V4(Ipv4Addr::new(160, 0, 0, 0)), 0);
        assert_eq!(empty.into_bit_path().count(), 0);
    }
}
309}