1use std::fmt;
2use std::str::FromStr;
3
4use mm1_proto::message;
5
6use crate::address::{Address, AddressParseError};
7
8mod net_mask;
9
/// Width of the address space that masks are measured against: a `u64`,
/// i.e. 64 bits. `NetMask::into_u64` expands masks into this width.
const ADDRESS_BITS: u8 = (std::mem::size_of::<u64>() * 8) as u8;
11
/// Prefix length of a [`NetAddress`]: how many leading (most-significant) bits
/// of the 64-bit address space are fixed.
///
/// Valid values are `0..=ADDRESS_BITS` (`0..=64`). The conversions in this file
/// (`TryFrom<u8>`, `FromStr`, serde `Deserialize`) reject anything larger. Serde
/// deserialization is hand-written (`derive_deserialize = false`) so that the
/// range check also applies to data coming off the wire.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[message(base_path = ::mm1_proto, derive_deserialize = false)]
pub struct NetMask(u8);
15
16impl<'de> serde::Deserialize<'de> for NetMask {
17 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
18 where
19 D: serde::Deserializer<'de>,
20 {
21 let value = u8::deserialize(deserializer)?;
24 NetMask::try_from(value).map_err(<D::Error as serde::de::Error>::custom)
25 }
26}
27
/// A network: an [`Address`] paired with a [`NetMask`] that says how many of the
/// address's leading bits identify the network.
///
/// The textual form (`Display` / `FromStr`) is `"<address>/<mask>"`. Serde impls
/// are hand-written in this module (`derive_serialize = false,
/// derive_deserialize = false`):
/// - human-readable formats use the textual form;
/// - other formats use an `(address, mask)` tuple.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[message(base_path = ::mm1_proto, derive_serialize = false, derive_deserialize = false)]
pub struct NetAddress {
    /// The address part. NOTE(review): nothing in this file clears the bits
    /// beyond `mask`, so the stored address is kept verbatim. Confirm that this
    /// is intended.
    pub address: Address,
    /// Number of leading bits of `address` that are fixed.
    pub mask: NetMask,
}
34
35#[derive(Debug, thiserror::Error)]
36pub enum MaskParseError {
37 #[error("invalid mask")]
38 InvalidMask(InvalidMask),
39 #[error("parse int error")]
40 ParseIntError(<u64 as FromStr>::Err),
41}
42
43#[derive(Debug, thiserror::Error)]
44#[error("invalid mask: {}", _0)]
45#[message(base_path = ::mm1_proto)]
46pub struct InvalidMask(u8);
47
48#[derive(Debug, thiserror::Error)]
49pub enum NetAddressParseError {
50 #[error("no slash")]
51 NoSlash,
52 #[error("parse addr error")]
53 ParseAddrError(AddressParseError),
54 #[error("parse mask error")]
55 ParseMaskError(MaskParseError),
56}
57
58impl NetMask {
59 pub fn bits_fixed(&self) -> u8 {
60 self.0
61 }
62
63 pub fn bits_available(&self) -> u8 {
64 ADDRESS_BITS - self.bits_fixed()
65 }
66
67 pub fn into_u64(self) -> u64 {
68 match self.0 as u32 {
69 0 => 0u64,
70 zero_to_63 => u64::MAX << (u64::BITS - zero_to_63),
71 }
72 }
73}
74
75impl From<Address> for NetAddress {
76 fn from(address: Address) -> Self {
77 (address, NetMask::MAX).into()
78 }
79}
80
81impl From<(Address, NetMask)> for NetAddress {
82 fn from((address, mask): (Address, NetMask)) -> Self {
83 Self { address, mask }
84 }
85}
86
87impl TryFrom<u8> for NetMask {
88 type Error = InvalidMask;
89
90 fn try_from(value: u8) -> Result<Self, Self::Error> {
91 match value {
92 0..=ADDRESS_BITS => Ok(Self(value)),
93 invalid => Err(InvalidMask(invalid)),
94 }
95 }
96}
97
98impl From<NetMask> for u8 {
99 fn from(value: NetMask) -> Self {
100 value.0
101 }
102}
103
104impl FromStr for NetMask {
105 type Err = MaskParseError;
106
107 fn from_str(s: &str) -> Result<Self, Self::Err> {
108 s.parse::<u8>()
109 .map_err(MaskParseError::ParseIntError)?
110 .try_into()
111 .map_err(MaskParseError::InvalidMask)
112 }
113}
114
115impl FromStr for NetAddress {
116 type Err = NetAddressParseError;
117
118 fn from_str(s: &str) -> Result<Self, Self::Err> {
119 let (addr, mask) = s.split_once('/').ok_or(NetAddressParseError::NoSlash)?;
120 let address = addr.parse().map_err(NetAddressParseError::ParseAddrError)?;
121 let mask = mask.parse().map_err(NetAddressParseError::ParseMaskError)?;
122
123 Ok(Self { address, mask })
124 }
125}
126
127impl fmt::Display for NetMask {
128 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
129 fmt::Display::fmt(&self.0, f)
130 }
131}
132
133impl fmt::Display for NetAddress {
134 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
135 write!(f, "{}/{}", self.address, self.mask)
136 }
137}
138
139impl serde::Serialize for NetAddress {
140 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
141 where
142 S: serde::Serializer,
143 {
144 if serializer.is_human_readable() {
145 self.to_string().serialize(serializer)
146 } else {
147 let Self { address, mask } = self;
148 (address, mask).serialize(serializer)
149 }
150 }
151}
152
153impl<'de> serde::Deserialize<'de> for NetAddress {
154 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
155 where
156 D: serde::Deserializer<'de>,
157 {
158 if deserializer.is_human_readable() {
159 String::deserialize(deserializer)?
160 .parse()
161 .map_err(<D::Error as serde::de::Error>::custom)
162 } else {
163 let (address, mask) = serde::Deserialize::deserialize(deserializer)?;
164 let net_address = Self { address, mask };
165 Ok(net_address)
166 }
167 }
168}
169
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_mask() {
        for i in 0..=64 {
            let _: NetMask = i.to_string().parse().expect("parse mask");
        }
        // Every remaining `u8` is a well-formed integer but an out-of-range mask.
        for i in 65..=u8::MAX {
            assert!(
                matches!(
                    NetMask::from_str(&i.to_string()),
                    Err(MaskParseError::InvalidMask(_))
                ),
                "/{i}"
            );
        }
        // Text that is not a `u8` at all fails at the integer-parsing stage.
        for s in ["256", "-1", "", "x"] {
            assert!(
                matches!(NetMask::from_str(s), Err(MaskParseError::ParseIntError(_))),
                "{s:?}"
            );
        }
    }

    #[test]
    fn net_address_parse() {
        for i in 0..=64 {
            let address = Address::from_u64(i);
            let mask: NetMask = i.to_string().parse().expect("parse mask");
            let expected = NetAddress { address, mask };
            // `Display` and `FromStr` must round-trip, because the
            // human-readable serde impls are built on exactly this pair.
            let rendered = expected.to_string();
            let parsed: NetAddress = rendered.parse().expect("parse net address");
            assert_eq!(parsed, expected, "{rendered}");
        }
        assert!(matches!(
            "no-slash".parse::<NetAddress>(),
            Err(NetAddressParseError::NoSlash)
        ));
        let bad_mask = format!("{}/65", Address::from_u64(1));
        assert!(matches!(
            bad_mask.parse::<NetAddress>(),
            Err(NetAddressParseError::ParseMaskError(
                MaskParseError::InvalidMask(_)
            ))
        ));
    }

    #[test]
    fn mask_to_u64() {
        for (input, expected) in [
            (64, u64::MAX),
            (60, 0xFFFF_FFFF_FFFF_FFF0),
            (56, 0xFFFF_FFFF_FFFF_FF00),
            (52, 0xFFFF_FFFF_FFFF_F000),
            (48, 0xFFFF_FFFF_FFFF_0000),
            (44, 0xFFFF_FFFF_FFF0_0000),
            (40, 0xFFFF_FFFF_FF00_0000),
            (36, 0xFFFF_FFFF_F000_0000),
            (32, 0xFFFF_FFFF_0000_0000),
            (28, 0xFFFF_FFF0_0000_0000),
            (24, 0xFFFF_FF00_0000_0000),
            (20, 0xFFFF_F000_0000_0000),
            (16, 0xFFFF_0000_0000_0000),
            (12, 0xFFF0_0000_0000_0000),
            (8, 0xFF00_0000_0000_0000),
            (4, 0xF000_0000_0000_0000),
            (0, 0x0000_0000_0000_0000),
        ] {
            assert_eq!(NetMask(input).into_u64(), expected, "/{input}");
        }
        // Every valid mask: exactly `bits` set bits, all of them leading.
        for bits in 0..=ADDRESS_BITS {
            let value = NetMask::try_from(bits).expect("valid mask").into_u64();
            assert_eq!(value.leading_ones(), u32::from(bits), "/{bits}");
            assert_eq!(value.count_ones(), u32::from(bits), "/{bits}");
        }
    }

    #[test]
    fn netmask_deserialize_rejects_out_of_range() {
        assert!(serde_json::from_str::<NetMask>("200").is_err());
        assert!(serde_json::from_str::<NetMask>("65").is_err());
        assert_eq!(serde_json::from_str::<NetMask>("0").unwrap(), NetMask::MIN);
        assert_eq!(serde_json::from_str::<NetMask>("64").unwrap(), NetMask::MAX);
    }

    #[test]
    fn net_address_wire_deserialize_rejects_out_of_range_mask() {
        let bytes = rmp_serde::to_vec(&(Address::from_u64(1), 200u8)).unwrap();
        assert!(rmp_serde::from_slice::<NetAddress>(&bytes).is_err());
    }

    #[test]
    fn net_address_serde_round_trip() {
        let net = NetAddress {
            address: Address::from_u64(42),
            mask: "16".parse().expect("parse mask"),
        };
        // Human-readable path (string form).
        let json = serde_json::to_string(&net).expect("to json");
        assert_eq!(serde_json::from_str::<NetAddress>(&json).expect("from json"), net);
        // Binary path (tuple form).
        let bytes = rmp_serde::to_vec(&net).expect("to msgpack");
        assert_eq!(
            rmp_serde::from_slice::<NetAddress>(&bytes).expect("from msgpack"),
            net
        );
    }
}