rfham_geo/geoip/providers/
local.rs1use rfham_core::error::CoreError;
21use serde_with::{DeserializeFromStr, SerializeDisplay};
22use std::{fmt::Display, net::Ipv4Addr, str::FromStr};
23
24#[derive(Clone, Debug, PartialEq, Eq, DeserializeFromStr, SerializeDisplay)]
33pub struct IpNetwork {
34 address: u32,
35 mask: u32,
36}
37
38impl Display for IpNetwork {
55 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56 write!(f, "{}/{}", self.address(), self.prefix_length())
57 }
58}
59
60impl FromStr for IpNetwork {
61 type Err = CoreError;
62
63 fn from_str(s: &str) -> Result<Self, Self::Err> {
64 let pair: Vec<&str> = s.split('/').collect::<Vec<_>>();
65 if pair.len() == 2 {
66 let address = Ipv4Addr::from_str(pair[0])
67 .map_err(|_| CoreError::InvalidValueFromStr(s.to_string(), "IpNetwork"))?;
68 let prefix_length = u8::from_str(pair[1])
69 .map_err(|_| CoreError::InvalidValueFromStr(s.to_string(), "IpNetwork"))?;
70 if prefix_length > 32 {
71 return Err(CoreError::InvalidValueFromStr(s.to_string(), "IpNetwork"));
72 }
73 Ok(IpNetwork::from_cidr(address, prefix_length))
74 } else {
75 Err(CoreError::InvalidValueFromStr(s.to_string(), "IpNetwork"))
76 }
77 }
78}
79
80impl IpNetwork {
81 pub fn from_cidr(address: Ipv4Addr, prefix_length: u8) -> Self {
82 assert!(prefix_length <= 32);
83 let mask = match prefix_length {
86 0 => 0u32,
87 32 => u32::MAX,
88 n => !(u32::MAX >> n),
89 };
90 Self {
91 address: address.to_bits(),
92 mask,
93 }
94 }
95
96 pub fn from_mask(address: Ipv4Addr, net_mask: Ipv4Addr) -> Self {
97 Self {
98 address: address.to_bits(),
99 mask: net_mask.to_bits(),
100 }
101 }
102
103 pub fn address(&self) -> Ipv4Addr {
104 self.address.into()
105 }
106
107 pub fn address_u32(&self) -> u32 {
108 self.address
109 }
110
111 pub fn mask(&self) -> Ipv4Addr {
112 self.mask.into()
113 }
114
115 pub fn mask_u32(&self) -> u32 {
116 self.mask
117 }
118
119 pub fn prefix_length(&self) -> u8 {
120 self.mask_u32().leading_ones() as u8
121 }
122
123 pub fn contains(&self, address: Ipv4Addr) -> bool {
124 (self.address & self.mask) == (address.to_bits() & self.mask)
125 }
126}
127
#[cfg(test)]
mod tests {
    use super::IpNetwork;
    use pretty_assertions::assert_eq;
    use std::{net::Ipv4Addr, str::FromStr};

    /// Parses a dotted-quad literal; every literal used in these tests is valid.
    fn ip(s: &str) -> Ipv4Addr {
        Ipv4Addr::from_str(s).expect("test address literal must be valid")
    }

    #[test]
    fn test_cidr_parse_and_display() {
        let net: IpNetwork = "192.168.1.0/24".parse().unwrap();
        assert_eq!(net.to_string(), "192.168.1.0/24");
        assert_eq!(net.prefix_length(), 24);
    }

    #[test]
    fn test_cidr_round_trip() {
        for text in ["10.0.0.0/8", "192.168.1.0/24", "0.0.0.0/0", "203.0.113.5/32"] {
            let net: IpNetwork = text.parse().unwrap();
            assert_eq!(net.to_string(), text);
            let reparsed: IpNetwork = net.to_string().parse().unwrap();
            assert_eq!(reparsed, net);
        }
    }

    #[test]
    fn test_display_keeps_host_bits() {
        let net: IpNetwork = "192.168.1.5/24".parse().unwrap();
        assert_eq!(net.address(), ip("192.168.1.5"));
        assert_eq!(net.to_string(), "192.168.1.5/24");
    }

    #[test]
    fn test_cidr_contains() {
        let net: IpNetwork = "10.0.0.0/8".parse().unwrap();
        assert!(net.contains(ip("10.1.2.3")));
        assert!(net.contains(ip("10.255.255.255")));
        assert!(!net.contains(ip("11.0.0.1")));
    }

    #[test]
    fn test_default_route_contains_everything() {
        let net: IpNetwork = "0.0.0.0/0".parse().unwrap();
        assert_eq!(net.prefix_length(), 0);
        assert_eq!(net.mask_u32(), 0);
        assert!(net.contains(ip("0.0.0.0")));
        assert!(net.contains(ip("255.255.255.255")));
    }

    #[test]
    fn test_cidr_host_route() {
        let net: IpNetwork = "203.0.113.5/32".parse().unwrap();
        assert_eq!(net.prefix_length(), 32);
        assert!(net.contains(ip("203.0.113.5")));
        assert!(!net.contains(ip("203.0.113.6")));
    }

    #[test]
    fn test_from_mask() {
        let net = IpNetwork::from_mask(ip("192.168.0.0"), ip("255.255.255.0"));
        assert_eq!(net.prefix_length(), 24);
        assert!(net.contains(ip("192.168.0.99")));
        assert!(!net.contains(ip("192.168.1.99")));
    }

    #[test]
    fn test_invalid_cidr_returns_error() {
        for bad in [
            "notanip/24",
            "192.168.1.0/33",
            "192.168.1.0",
            "192.168.1.0/",
            "/24",
            "10.0.0.0/8/8",
            "10.0.0.0/-1",
            "256.0.0.0/8",
        ] {
            assert!(bad.parse::<IpNetwork>().is_err(), "{bad} should be rejected");
        }
    }
}