// mm1_address/subnet.rs

1use std::fmt;
2use std::str::FromStr;
3
4use mm1_proto::message;
5
6use crate::address::{Address, AddressParseError};
7
8mod net_mask;
9
/// Number of bits in a `u64`; the upper bound for a mask's prefix length in
/// this module.
const ADDRESS_BITS: u8 = u64::BITS as u8;
11
12#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
13#[message(base_path = ::mm1_proto, derive_deserialize = false)]
14pub struct NetMask(u8);
15
16impl<'de> serde::Deserialize<'de> for NetMask {
17    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
18    where
19        D: serde::Deserializer<'de>,
20    {
21        // Route through the validating TryFrom so an out-of-range mask is
22        // rejected on every format instead of panicking later.
23        let value = u8::deserialize(deserializer)?;
24        NetMask::try_from(value).map_err(<D::Error as serde::de::Error>::custom)
25    }
26}
27
28#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
29#[message(base_path = ::mm1_proto, derive_serialize = false, derive_deserialize = false)]
30pub struct NetAddress {
31    pub address: Address,
32    pub mask:    NetMask,
33}
34
35#[derive(Debug, thiserror::Error)]
36pub enum MaskParseError {
37    #[error("invalid mask")]
38    InvalidMask(InvalidMask),
39    #[error("parse int error")]
40    ParseIntError(<u64 as FromStr>::Err),
41}
42
43#[derive(Debug, thiserror::Error)]
44#[error("invalid mask: {}", _0)]
45#[message(base_path = ::mm1_proto)]
46pub struct InvalidMask(u8);
47
48#[derive(Debug, thiserror::Error)]
49pub enum NetAddressParseError {
50    #[error("no slash")]
51    NoSlash,
52    #[error("parse addr error")]
53    ParseAddrError(AddressParseError),
54    #[error("parse mask error")]
55    ParseMaskError(MaskParseError),
56}
57
58impl NetMask {
59    pub fn bits_fixed(&self) -> u8 {
60        self.0
61    }
62
63    pub fn bits_available(&self) -> u8 {
64        ADDRESS_BITS - self.bits_fixed()
65    }
66
67    pub fn into_u64(self) -> u64 {
68        match self.0 as u32 {
69            0 => 0u64,
70            zero_to_63 => u64::MAX << (u64::BITS - zero_to_63),
71        }
72    }
73}
74
75impl From<Address> for NetAddress {
76    fn from(address: Address) -> Self {
77        (address, NetMask::MAX).into()
78    }
79}
80
81impl From<(Address, NetMask)> for NetAddress {
82    fn from((address, mask): (Address, NetMask)) -> Self {
83        Self { address, mask }
84    }
85}
86
87impl TryFrom<u8> for NetMask {
88    type Error = InvalidMask;
89
90    fn try_from(value: u8) -> Result<Self, Self::Error> {
91        match value {
92            0..=ADDRESS_BITS => Ok(Self(value)),
93            invalid => Err(InvalidMask(invalid)),
94        }
95    }
96}
97
98impl From<NetMask> for u8 {
99    fn from(value: NetMask) -> Self {
100        value.0
101    }
102}
103
104impl FromStr for NetMask {
105    type Err = MaskParseError;
106
107    fn from_str(s: &str) -> Result<Self, Self::Err> {
108        s.parse::<u8>()
109            .map_err(MaskParseError::ParseIntError)?
110            .try_into()
111            .map_err(MaskParseError::InvalidMask)
112    }
113}
114
115impl FromStr for NetAddress {
116    type Err = NetAddressParseError;
117
118    fn from_str(s: &str) -> Result<Self, Self::Err> {
119        let (addr, mask) = s.split_once('/').ok_or(NetAddressParseError::NoSlash)?;
120        let address = addr.parse().map_err(NetAddressParseError::ParseAddrError)?;
121        let mask = mask.parse().map_err(NetAddressParseError::ParseMaskError)?;
122
123        Ok(Self { address, mask })
124    }
125}
126
127impl fmt::Display for NetMask {
128    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
129        fmt::Display::fmt(&self.0, f)
130    }
131}
132
133impl fmt::Display for NetAddress {
134    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
135        write!(f, "{}/{}", self.address, self.mask)
136    }
137}
138
139impl serde::Serialize for NetAddress {
140    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
141    where
142        S: serde::Serializer,
143    {
144        if serializer.is_human_readable() {
145            self.to_string().serialize(serializer)
146        } else {
147            let Self { address, mask } = self;
148            (address, mask).serialize(serializer)
149        }
150    }
151}
152
153impl<'de> serde::Deserialize<'de> for NetAddress {
154    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
155    where
156        D: serde::Deserializer<'de>,
157    {
158        if deserializer.is_human_readable() {
159            String::deserialize(deserializer)?
160                .parse()
161                .map_err(<D::Error as serde::de::Error>::custom)
162        } else {
163            let (address, mask) = serde::Deserialize::deserialize(deserializer)?;
164            let net_address = Self { address, mask };
165            Ok(net_address)
166        }
167    }
168}
169
#[cfg(test)]
mod tests {
    use super::*;

    /// Every prefix length in `0..=64` parses; everything in `65..128` is
    /// rejected.
    #[test]
    fn parse_mask() {
        for i in 0..=64 {
            let text = i.to_string();
            let _: NetMask = text.parse().expect("parse mask");
        }
        for i in 65..128 {
            let text = i.to_string();
            assert!(NetMask::from_str(&text).is_err());
        }
    }

    /// Smoke test: builds a `NetAddress` for each valid mask and prints it.
    /// It only checks that mask parsing and formatting do not panic; nothing
    /// is asserted about the printed text.
    #[test]
    fn net_address_parse() {
        for i in 0..=64 {
            let mask = i.to_string().parse().unwrap();
            let address = Address::from_u64(i);
            let net_address = NetAddress { address, mask };
            eprintln!("{}", net_address);
        }
    }

    /// `into_u64` sets exactly the top `input` bits, checked at every nibble
    /// boundary.
    #[test]
    fn mask_to_u64() {
        const CASES: [(u8, u64); 17] = [
            (64, u64::MAX),
            (60, 0xFFFF_FFFF_FFFF_FFF0),
            (56, 0xFFFF_FFFF_FFFF_FF00),
            (52, 0xFFFF_FFFF_FFFF_F000),
            (48, 0xFFFF_FFFF_FFFF_0000),
            (44, 0xFFFF_FFFF_FFF0_0000),
            (40, 0xFFFF_FFFF_FF00_0000),
            (36, 0xFFFF_FFFF_F000_0000),
            (32, 0xFFFF_FFFF_0000_0000),
            (28, 0xFFFF_FFF0_0000_0000),
            (24, 0xFFFF_FF00_0000_0000),
            (20, 0xFFFF_F000_0000_0000),
            (16, 0xFFFF_0000_0000_0000),
            (12, 0xFFF0_0000_0000_0000),
            (8, 0xFF00_0000_0000_0000),
            (4, 0xF000_0000_0000_0000),
            (0, 0x0000_0000_0000_0000),
        ];

        for (input, expected) in CASES {
            let actual = NetMask(input).into_u64();
            assert_eq!(actual, expected, "/{input}");
        }
    }

    // Regression tests for #137: deserialization must reject out-of-range masks
    // instead of accepting them and panicking later.
    #[test]
    fn netmask_deserialize_rejects_out_of_range() {
        for raw in ["200", "65"] {
            assert!(serde_json::from_str::<NetMask>(raw).is_err());
        }

        // Valid masks still deserialize.
        for (raw, expected) in [("0", NetMask::MIN), ("64", NetMask::MAX)] {
            assert_eq!(serde_json::from_str::<NetMask>(raw).unwrap(), expected);
        }
    }

    #[test]
    fn net_address_wire_deserialize_rejects_out_of_range_mask() {
        // The non-human-readable path reads (Address, NetMask) as a tuple; an
        // out-of-range mask on the wire must be rejected there too.
        let payload = (Address::from_u64(1), 200u8);
        let bytes = rmp_serde::to_vec(&payload).unwrap();

        let decoded = rmp_serde::from_slice::<NetAddress>(&bytes);
        assert!(decoded.is_err());
    }
}