// ipstuff/masked.rs

use std::fmt::{Debug, Display, Formatter, Result as FmtResult};
use std::net::Ipv4Addr;
use std::ops::Not;
use std::str::FromStr;

use crate::IpBitwiseExt;

/// A subnet mask stored as four octets, most significant first, in a 4-byte-aligned wrapper.
///
/// Every value of this type is a valid subnet mask: some number of leading 1 bits followed only
/// by 0 bits. The public constructors refuse anything else.
#[derive(Copy, Clone, Eq, PartialEq, Hash)]
#[repr(align(4))]
pub struct Ipv4Mask {
    /// The mask octets, e.g. `[255, 255, 255, 0]` for a /24.
    mask: [u8; 4],
}

17impl Ipv4Mask {
18    /// Returns a mask with the specified length.
19    ///
20    /// # Panics
21    ///
22    /// Will panic if provided length is > 32
23    pub const fn new(len: u8) -> Self {
24        #[rustfmt::skip]
25        const MASKS: [[u8; 4]; 33] = [
26            // /0
27            [0, 0, 0, 0],
28            // /1 - /8
29            [128, 0, 0, 0], [192, 0, 0, 0], [224, 0, 0, 0], [240, 0, 0, 0],
30            [248, 0, 0, 0], [252, 0, 0, 0], [254, 0, 0, 0], [255, 0, 0, 0],
31            // /9 - /16
32            [255, 128, 0, 0], [255, 192, 0, 0], [255, 224, 0, 0], [255, 240, 0, 0],
33            [255, 248, 0, 0], [255, 252, 0, 0], [255, 254, 0, 0], [255, 255, 0, 0],
34            // /17 - /24
35            [255, 255, 128, 0], [255, 255, 192, 0], [255, 255, 224, 0], [255, 255, 240, 0],
36            [255, 255, 248, 0], [255, 255, 252, 0], [255, 255, 254, 0], [255, 255, 255, 0],
37            // /25 - /32
38            [255, 255, 255, 128], [255, 255, 255, 192], [255, 255, 255, 224], [255, 255, 255, 240],
39            [255, 255, 255, 248], [255, 255, 255, 252], [255, 255, 255, 254], [255, 255, 255, 255],
40        ];
41        let mask = MASKS[len as usize];
42        Self { mask }
43    }
44    /// Constructs a subnet mask from the provided bytes, if they represent a valid mask.
45    pub fn from_bytes(bytes: [u8; 4]) -> Option<Self> {
46        Self::from_u32(u32::from_be_bytes(bytes))
47    }
48    /// Constructs a subnet mask from the provided u32, if it represents a valid mask.
49    pub fn from_u32(x: u32) -> Option<Self> {
50        let ones = if cfg!(target_feature = "popcnt") {
51            x.count_ones() as u8 // popcnt;
52        } else {
53            (!x).leading_zeros() as u8 // not; bsr;
54        };
55        let zeros = x.trailing_zeros() as u8; // tzcnt / bsf;
56        // add; test; sete;
57        if ones + zeros == 32 {
58            let mask = x.to_be_bytes();
59            Some(Self { mask })
60        } else {
61            None
62        }
63    }
64    /// Returns the subnet mask as an array of bytes.
65    pub const fn octets(self) -> [u8; 4] {
66        self.mask
67    }
68    /// Returns this mask in u32 representation
69    pub const fn as_u32(self) -> u32 {
70        let bytes = self.octets();
71        (bytes[0] as u32) << 24 | (bytes[1] as u32) << 16 | (bytes[2] as u32) << 8 | bytes[3] as u32
72    }
73    /// Returns the length of the mask. That is, the number of 1 bits in this mask.
74    pub const fn len(self) -> u8 {
75        let x = self.as_u32();
76        #[cfg(target_feature = "popcnt")]
77        let len = x.count_ones() as u8;
78        #[cfg(not(target_feature = "popcnt"))]
79        let len = (!x).leading_zeros() as u8;
80        len
81    }
82}
83
84impl Display for Ipv4Mask {
85    fn fmt(&self, f: &mut Formatter) -> FmtResult {
86        if f.alternate() {
87            write!(f, "/{}", self.len())
88        } else {
89            let bytes = self.octets();
90            write!(f, "{}.{}.{}.{}", bytes[0], bytes[1], bytes[2], bytes[3])
91        }
92    }
93}
94
95impl Debug for Ipv4Mask {
96    fn fmt(&self, f: &mut Formatter) -> FmtResult {
97        Display::fmt(self, f)
98    }
99}
100
101impl Not for Ipv4Mask {
102    type Output = [u8; 4];
103    fn not(self) -> [u8; 4] {
104        let x = u32::from_ne_bytes(self.octets());
105        (!x).to_ne_bytes()
106    }
107}
108
109impl FromStr for Ipv4Mask {
110    type Err = InvalidIpv4Mask;
111    fn from_str(s: &str) -> Result<Self, InvalidIpv4Mask> {
112        let bytes = s.parse::<Ipv4Addr>().map_err(|_| InvalidIpv4Mask)?.octets();
113        Self::from_bytes(bytes).ok_or(InvalidIpv4Mask)
114    }
115}
/// Error returned when a string cannot be parsed as an [`Ipv4Mask`].
///
/// This happens when the text is not a dotted IPv4 address, or when its bits are not a run of
/// leading ones followed only by zeros.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct InvalidIpv4Mask;

impl Display for InvalidIpv4Mask {
    fn fmt(&self, f: &mut Formatter) -> FmtResult {
        f.write_str("invalid IPv4 subnet mask")
    }
}

// Lets the error be boxed as `dyn Error` and propagated with `?` in callers' error chains.
impl std::error::Error for InvalidIpv4Mask {}
119/// An 8-byte type representing an IPv4 address and subnet mask pair. The IP may be any ip
120/// within the represented network, and the mask may be any valid subnet mask.
121#[derive(Copy, Clone, Eq, PartialEq, Hash)]
122pub struct MaskedIpv4 {
123    /// The IP address
124    pub ip: Ipv4Addr,
125    /// The subnet mask
126    pub mask: Ipv4Mask,
127}
128
129impl MaskedIpv4 {
130    /// Constructs a MaskedIpv4 from the provided ip and mask.
131    pub const fn new(ip: Ipv4Addr, mask: Ipv4Mask) -> Self {
132        Self { ip, mask }
133    }
134    /// Constructs a MaskedIpv4 from the provided ip and mask length.
135    ///
136    /// # Panics
137    ///
138    /// Will panic if provided length > 32
139    pub const fn cidr(ip: Ipv4Addr, mask_len: u8) -> Self {
140        let mask = Ipv4Mask::new(mask_len);
141        Self::new(ip, mask)
142    }
143    /// Constructs a new MaskedIpv4 from the provided CIDR string.
144    pub fn from_cidr_str(s: &str) -> Option<Self> {
145        let mut parts = s.split("/");
146        let ip = parts.next()?.parse::<Ipv4Addr>().ok()?;
147        let mask_len = parts.next()?.parse::<u8>().ok()?;
148        if mask_len > 32 {
149            None
150        } else {
151            Some(Self::cidr(ip, mask_len))
152        }
153    }
154    /// Constructs a new MaskedIpv4 from the provided IP and subnet mask. There must be exactly one space between the IP and mask.
155    pub fn from_network_str(s: &str) -> Option<Self> {
156        let mut parts = s.split(" ");
157        let ip = parts.next()?.parse().ok()?;
158        let mask = parts.next()?.parse().ok()?;
159        Some(Self::new(ip, mask))
160    }
161    /// Returns a String with the IP and mask in CIDR format. Shortcut for `format!("{:#}", self)`
162    pub fn to_cidr_string(&self) -> String {
163        format!("{:#}", self)
164    }
165    /// Returns a String with the IP and mask in dotted decimal format. Shortcut for `format!("{}", self)`
166    pub fn to_network_string(&self) -> String {
167        format!("{}", self)
168    }
169    /// Returns the network adderss by setting all host bits to 0.
170    pub fn network_address(&self) -> Ipv4Addr {
171        self.ip.bitand(self.mask)
172    }
173    /// Constructs a new MaskedIpv4 using the network address and mask of this MaskedIpv4.
174    pub fn network(&self) -> MaskedIpv4 {
175        Self::new(self.network_address(), self.mask)
176    }
177    /// Returns true if all host bits in the IP are 0. Always returns false if the mask length is 31 or 32.
178    pub fn is_network_address(&self) -> bool {
179        self.mask.len() <= 30 && self.ip == self.network_address()
180    }
181    /// Returns the broadcast address by setting all host bits to 1.
182    pub fn broadcast_address(&self) -> Ipv4Addr {
183        self.ip.bitor(!self.mask)
184    }
185    /// Returns true if all host bits in the IP are 1. Always returns false if the mask length is 31 or 32.
186    pub fn is_broadcast_address(&self) -> bool {
187        self.mask.len() <= 30 && self.ip == self.broadcast_address()
188    }
189    /// Returns the number of network bits. That is, the length of the mask.
190    pub fn network_bits(&self) -> u8 {
191        self.mask.len()
192    }
193    /// Returns the number of host bits. That is, the number of 0 bits in the mask.
194    pub fn host_bits(&self) -> u8 {
195        32 - self.network_bits()
196    }
197    /// Returns the number of host addresses in the network.
198    ///
199    /// # Panics
200    ///
201    /// Will panic if usize is not large enough to hold the host count.
202    pub fn host_count(&self) -> usize {
203        let host_bits = self.host_bits();
204        match host_bits {
205            0 => 1,
206            1 => 2,
207            _ => 2usize.checked_shl(host_bits as u32).unwrap() - 2,
208        }
209    }
210    /// Returns the number of host addresses in the network as u64. Unlike host_count, this will never panic.
211    pub fn host_count_u64(&self) -> u64 {
212        let host_bits = self.host_bits();
213        match host_bits {
214            0 => 1,
215            1 => 2,
216            _ => (2 << host_bits) - 2,
217        }
218    }
219    /// Returns the number of networks of the provided mask length will fit in this network.
220    ///
221    /// # Panics
222    ///
223    /// Will panic if the provided length is > 32, or if the number of networks does not fit in usize
224    pub fn network_count(&self, len: u8) -> usize {
225        if len < self.mask.len() {
226            0
227        } else if len > 32 {
228            panic!("Invalid mask length > 32")
229        } else {
230            let borrowed_bits = len - self.mask.len();
231            2usize.checked_shl(borrowed_bits as u32).unwrap()
232        }
233    }
234    /// Returns the number of networks of the provided mask length will fit in this network as u64. Unlike network_count, this will not panic
235    /// due to overflow. May still panic if the provided length is too long.
236    ///
237    /// # Panics
238    ///
239    /// Will panic if the provided length is > 32
240    pub fn network_count_u64(&self, len: u8) -> u64 {
241        if len < self.mask.len() {
242            0
243        } else if len > 32 {
244            panic!("Invalid mask length > 32")
245        } else {
246            let borrowed_bits = len - self.mask.len();
247            2 << borrowed_bits
248        }
249    }
250    /// Returns true if this network contains the provided IP address, even if the provided IP is the network or broadcast address.
251    pub fn contains(&self, ip: Ipv4Addr) -> bool {
252        self.ip.bitand(self.mask) == ip.bitand(self.mask)
253    }
254}
255
256impl Display for MaskedIpv4 {
257    fn fmt(&self, f: &mut Formatter) -> FmtResult {
258        if f.alternate() {
259            write!(f, "{}/{}", self.ip, self.mask.len())
260        } else {
261            write!(f, "{} {}", self.ip, self.mask)
262        }
263    }
264}
265
266impl Debug for MaskedIpv4 {
267    fn fmt(&self, f: &mut Formatter) -> FmtResult {
268        Display::fmt(self, f)
269    }
270}
271
272impl FromStr for MaskedIpv4 {
273    type Err = InvalidMaskedIpv4;
274    fn from_str(s: &str) -> Result<Self, InvalidMaskedIpv4> {
275        Self::from_cidr_str(s)
276            .or_else(|| Self::from_network_str(s))
277            .ok_or(InvalidMaskedIpv4)
278    }
279}
/// Error returned when a string cannot be parsed as a [`MaskedIpv4`].
///
/// The string must be in CIDR notation (`"10.0.0.1/8"`) or IP-and-mask notation
/// (`"10.0.0.1 255.0.0.0"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct InvalidMaskedIpv4;

impl Display for InvalidMaskedIpv4 {
    fn fmt(&self, f: &mut Formatter) -> FmtResult {
        f.write_str("invalid IPv4 address and subnet mask")
    }
}

// Lets the error be boxed as `dyn Error` and propagated with `?` in callers' error chains.
impl std::error::Error for InvalidMaskedIpv4 {}

// impl FromStr for MaskedIpv4 {
//     type Err = ();
//     fn from_str(s: &str) -> Result<Self, ()> {
//         Self::from_cidr_str(s)
//             .or_else(||Self::from_ip_mask_str(s))
//             .ok_or(())
//     }
// }

// pub fn mask_from_len(len: u8) -> Option<[u8; 4]> {
//     if len == 0 { return Some([0,0,0,0]) }
//     // println!("len is {}", len);
//     let zeroes = 32u8.checked_sub(len)?;
//     assert!(zeroes < 32);
//     let x = 0xFFFF_FFFFu32 << zeroes;
//     Some(x.to_be_bytes())
// }
// #[allow(dead_code)] // fix me
// fn net_addr(ip: Ipv4Addr, mask: [u8; 4]) -> Ipv4Addr {
//     let bytes = ip.octets();
//     Ipv4Addr::new(
//         bytes[0] & mask[0],
//         bytes[1] & mask[1],
//         bytes[2] & mask[2],
//         bytes[3] & mask[3]
//     )
// }

// fn validate_mask(mask: [u8; 4]) -> Option<u8> {
//     let x = u32::from_be_bytes(mask);
//     match (x.count_zeros(), x.trailing_zeros()) {
//         (a, b) if a == b => Some(32 - a as u8),
//         _ => None
//     }
// }