1use std::fmt;
2use std::str::FromStr;
3
4use mm1_proto::message;
5
6use crate::address::{Address, AddressParseError};
7
8mod net_mask;
9
/// Bit width of `u64`, the integer width addresses and masks are expressed in
/// throughout this module; also the largest prefix length a [`NetMask`] accepts.
const ADDRESS_BITS: u8 = {
    // `u64::BITS` is a `u32`; its value (64) fits in a `u8` without truncation.
    let width: u32 = u64::BITS;
    width as u8
};
11
/// Prefix length of a [`NetAddress`]: how many leading (high-order) bits of an
/// address are fixed.
///
/// Values constructed through `TryFrom<u8>` or `FromStr` are always in
/// `0..=64`. `#[message]` is an `mm1_proto` attribute macro whose expansion is
/// not visible from this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[message(base_path = ::mm1_proto)]
pub struct NetMask(u8);
15
/// An [`Address`] paired with a [`NetMask`] prefix length, written `address/mask`.
///
/// The `derive_serialize = false, derive_deserialize = false` options presumably
/// turn off serde derives generated by `#[message]` (NOTE(review): confirm
/// against `mm1_proto`); this file instead implements `Serialize`/`Deserialize`
/// by hand using the `address/mask` string form.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[message(base_path = ::mm1_proto, derive_serialize = false, derive_deserialize = false)]
pub struct NetAddress {
    /// The address part (text before the `/`).
    pub address: Address,
    /// The prefix length (text after the `/`).
    pub mask: NetMask,
}
22
23#[derive(Debug, thiserror::Error)]
24pub enum MaskParseError {
25 #[error("invalid mask")]
26 InvalidMask(InvalidMask),
27 #[error("parse int error")]
28 ParseIntError(<u64 as FromStr>::Err),
29}
30
/// A prefix length rejected because it exceeds 64 (the address bit width).
///
/// Carries the rejected value so it appears in the error message. Produced by
/// `NetMask::try_from(u8)`. `#[message]` is an `mm1_proto` attribute macro
/// whose expansion is not visible from this file.
#[derive(Debug, thiserror::Error)]
#[error("invalid mask: {}", _0)]
#[message(base_path = ::mm1_proto)]
pub struct InvalidMask(u8);
35
36#[derive(Debug, thiserror::Error)]
37pub enum NetAddressParseError {
38 #[error("no slash")]
39 NoSlash,
40 #[error("parse addr error")]
41 ParseAddrError(AddressParseError),
42 #[error("parse mask error")]
43 ParseMaskError(MaskParseError),
44}
45
46impl NetMask {
47 pub fn bits_fixed(&self) -> u8 {
48 self.0
49 }
50
51 pub fn bits_available(&self) -> u8 {
52 ADDRESS_BITS - self.bits_fixed()
53 }
54
55 pub fn into_u64(self) -> u64 {
56 match self.0 as u32 {
57 0 => 0u64,
58 zero_to_63 => u64::MAX << (u64::BITS - zero_to_63),
59 }
60 }
61}
62
63impl From<Address> for NetAddress {
64 fn from(address: Address) -> Self {
65 (address, NetMask::M_64).into()
66 }
67}
68
69impl From<(Address, NetMask)> for NetAddress {
70 fn from((address, mask): (Address, NetMask)) -> Self {
71 Self { address, mask }
72 }
73}
74
75impl TryFrom<u8> for NetMask {
76 type Error = InvalidMask;
77
78 fn try_from(value: u8) -> Result<Self, Self::Error> {
79 match value {
80 0..=ADDRESS_BITS => Ok(Self(value)),
81 invalid => Err(InvalidMask(invalid)),
82 }
83 }
84}
85
86impl From<NetMask> for u8 {
87 fn from(value: NetMask) -> Self {
88 value.0
89 }
90}
91
92impl FromStr for NetMask {
93 type Err = MaskParseError;
94
95 fn from_str(s: &str) -> Result<Self, Self::Err> {
96 s.parse::<u8>()
97 .map_err(MaskParseError::ParseIntError)?
98 .try_into()
99 .map_err(MaskParseError::InvalidMask)
100 }
101}
102
103impl FromStr for NetAddress {
104 type Err = NetAddressParseError;
105
106 fn from_str(s: &str) -> Result<Self, Self::Err> {
107 let (addr, mask) = s.split_once('/').ok_or(NetAddressParseError::NoSlash)?;
108 let address = addr.parse().map_err(NetAddressParseError::ParseAddrError)?;
109 let mask = mask.parse().map_err(NetAddressParseError::ParseMaskError)?;
110
111 Ok(Self { address, mask })
112 }
113}
114
115impl fmt::Display for NetMask {
116 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
117 fmt::Display::fmt(&self.0, f)
118 }
119}
120
121impl fmt::Display for NetAddress {
122 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
123 write!(f, "{}/{}", self.address, self.mask)
124 }
125}
126
127impl serde::Serialize for NetAddress {
128 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
129 where
130 S: serde::Serializer,
131 {
132 self.to_string().serialize(serializer)
133 }
134}
135
136impl<'de> serde::Deserialize<'de> for NetAddress {
137 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
138 where
139 D: serde::Deserializer<'de>,
140 {
141 String::deserialize(deserializer)?
142 .parse()
143 .map_err(<D::Error as serde::de::Error>::custom)
144 }
145}
146
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `u8` string either parses to the same prefix length (0..=64) or
    /// fails with the matching error variant.
    #[test]
    fn parse_mask() {
        for i in 0..=64u8 {
            let mask: NetMask = i.to_string().parse().expect("parse mask");
            assert_eq!(mask.bits_fixed(), i);
            assert_eq!(u8::from(mask), i);
        }
        // Valid `u8`, but wider than an address.
        for i in 65..=u8::MAX {
            assert!(matches!(
                NetMask::from_str(&i.to_string()),
                Err(MaskParseError::InvalidMask(_))
            ));
        }
        // Not a `u8` at all.
        for text in ["", "-1", "256", "1.5", "abc"] {
            assert!(matches!(
                NetMask::from_str(text),
                Err(MaskParseError::ParseIntError(_))
            ));
        }
    }

    /// `Display` output parses back to the same value, and malformed input
    /// reports the right error variant.
    #[test]
    fn net_address_parse() {
        for i in 0..=64 {
            let address = Address::from_u64(i);
            let mask: NetMask = i.to_string().parse().expect("parse mask");
            let net = NetAddress { address, mask };
            let text = net.to_string();
            let parsed: NetAddress = text.parse().expect("parse net address");
            assert_eq!(parsed, net, "{text}");
        }

        assert!(matches!(
            "no-slash-here".parse::<NetAddress>(),
            Err(NetAddressParseError::NoSlash)
        ));

        let too_wide = format!("{}/65", Address::from_u64(0));
        assert!(matches!(
            too_wide.parse::<NetAddress>(),
            Err(NetAddressParseError::ParseMaskError(
                MaskParseError::InvalidMask(_)
            ))
        ));
    }

    #[test]
    fn mask_to_u64() {
        for (input, expected) in [
            (64, u64::MAX),
            (63, 0xFFFF_FFFF_FFFF_FFFE),
            (60, 0xFFFF_FFFF_FFFF_FFF0),
            (56, 0xFFFF_FFFF_FFFF_FF00),
            (52, 0xFFFF_FFFF_FFFF_F000),
            (48, 0xFFFF_FFFF_FFFF_0000),
            (44, 0xFFFF_FFFF_FFF0_0000),
            (40, 0xFFFF_FFFF_FF00_0000),
            (36, 0xFFFF_FFFF_F000_0000),
            (32, 0xFFFF_FFFF_0000_0000),
            (28, 0xFFFF_FFF0_0000_0000),
            (24, 0xFFFF_FF00_0000_0000),
            (20, 0xFFFF_F000_0000_0000),
            (16, 0xFFFF_0000_0000_0000),
            (12, 0xFFF0_0000_0000_0000),
            (8, 0xFF00_0000_0000_0000),
            (4, 0xF000_0000_0000_0000),
            (1, 0x8000_0000_0000_0000),
            (0, 0x0000_0000_0000_0000),
        ] {
            assert_eq!(NetMask(input).into_u64(), expected, "/{input}");
        }
    }
}