Skip to main content

rfham_geo/geoip/providers/
local.rs

1//! IPv4 CIDR-range type for local geo-IP lookup tables.
2//!
3//! [`IpNetwork`] represents an IPv4 network in CIDR notation (`address/prefix`).
4//! It can be constructed from a CIDR string, from an `(address, prefix_length)` pair,
5//! or from an `(address, netmask)` pair.
6//!
7//! # Examples
8//!
9//! ```rust
10//! use rfham_geo::geoip::providers::local::IpNetwork;
11//! use std::{net::Ipv4Addr, str::FromStr};
12//!
13//! let net: IpNetwork = "192.168.1.0/24".parse().unwrap();
14//! assert_eq!(net.prefix_length(), 24);
15//! assert!(net.contains(Ipv4Addr::from_str("192.168.1.42").unwrap()));
16//! assert!(!net.contains(Ipv4Addr::from_str("192.168.2.1").unwrap()));
17//! assert_eq!(net.to_string(), "192.168.1.0/24");
18//! ```
19
20use rfham_core::error::CoreError;
21use serde_with::{DeserializeFromStr, SerializeDisplay};
22use std::{fmt::Display, net::Ipv4Addr, str::FromStr};
23
24// ------------------------------------------------------------------------------------------------
25// Public Macros
26// ------------------------------------------------------------------------------------------------
27
28// ------------------------------------------------------------------------------------------------
29// Public Types
30// ------------------------------------------------------------------------------------------------
31
32#[derive(Clone, Debug, PartialEq, Eq, DeserializeFromStr, SerializeDisplay)]
33pub struct IpNetwork {
34    address: u32,
35    mask: u32,
36}
37
38// ------------------------------------------------------------------------------------------------
39// Public Functions
40// ------------------------------------------------------------------------------------------------
41
42// ------------------------------------------------------------------------------------------------
43// Private Macros
44// ------------------------------------------------------------------------------------------------
45
46// ------------------------------------------------------------------------------------------------
47// Private Types
48// ------------------------------------------------------------------------------------------------
49
50// ------------------------------------------------------------------------------------------------
51// Implementations
52// ------------------------------------------------------------------------------------------------
53
54impl Display for IpNetwork {
55    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56        write!(f, "{}/{}", self.address(), self.prefix_length())
57    }
58}
59
60impl FromStr for IpNetwork {
61    type Err = CoreError;
62
63    fn from_str(s: &str) -> Result<Self, Self::Err> {
64        let pair: Vec<&str> = s.split('/').collect::<Vec<_>>();
65        if pair.len() == 2 {
66            let address = Ipv4Addr::from_str(pair[0])
67                .map_err(|_| CoreError::InvalidValueFromStr(s.to_string(), "IpNetwork"))?;
68            let prefix_length = u8::from_str(pair[1])
69                .map_err(|_| CoreError::InvalidValueFromStr(s.to_string(), "IpNetwork"))?;
70            if prefix_length > 32 {
71                return Err(CoreError::InvalidValueFromStr(s.to_string(), "IpNetwork"));
72            }
73            Ok(IpNetwork::from_cidr(address, prefix_length))
74        } else {
75            Err(CoreError::InvalidValueFromStr(s.to_string(), "IpNetwork"))
76        }
77    }
78}
79
80impl IpNetwork {
81    pub fn from_cidr(address: Ipv4Addr, prefix_length: u8) -> Self {
82        assert!(prefix_length <= 32);
83        // Network mask: prefix_length leading 1-bits followed by trailing 0-bits.
84        // Shifts by 32 overflow for u32, so /0 and /32 are special-cased.
85        let mask = match prefix_length {
86            0 => 0u32,
87            32 => u32::MAX,
88            n => !(u32::MAX >> n),
89        };
90        Self {
91            address: address.to_bits(),
92            mask,
93        }
94    }
95
96    pub fn from_mask(address: Ipv4Addr, net_mask: Ipv4Addr) -> Self {
97        Self {
98            address: address.to_bits(),
99            mask: net_mask.to_bits(),
100        }
101    }
102
103    pub fn address(&self) -> Ipv4Addr {
104        self.address.into()
105    }
106
107    pub fn address_u32(&self) -> u32 {
108        self.address
109    }
110
111    pub fn mask(&self) -> Ipv4Addr {
112        self.mask.into()
113    }
114
115    pub fn mask_u32(&self) -> u32 {
116        self.mask
117    }
118
119    pub fn prefix_length(&self) -> u8 {
120        self.mask_u32().leading_ones() as u8
121    }
122
123    pub fn contains(&self, address: Ipv4Addr) -> bool {
124        (self.address & self.mask) == (address.to_bits() & self.mask)
125    }
126}
127
128// ------------------------------------------------------------------------------------------------
129// Private Functions
130// ------------------------------------------------------------------------------------------------
131
132// ------------------------------------------------------------------------------------------------
133// Sub-Modules
134// ------------------------------------------------------------------------------------------------
135
136// ------------------------------------------------------------------------------------------------
137// Unit Tests
138// ------------------------------------------------------------------------------------------------
139
#[cfg(test)]
mod tests {
    use super::IpNetwork;
    use pretty_assertions::assert_eq;
    use std::{net::Ipv4Addr, str::FromStr};

    #[test]
    fn test_cidr_parse_and_display() {
        let net: IpNetwork = "192.168.1.0/24".parse().unwrap();
        assert_eq!(net.to_string(), "192.168.1.0/24");
        assert_eq!(net.prefix_length(), 24);
    }

    #[test]
    fn test_cidr_contains() {
        let net: IpNetwork = "10.0.0.0/8".parse().unwrap();
        assert!(net.contains(Ipv4Addr::from_str("10.1.2.3").unwrap()));
        assert!(net.contains(Ipv4Addr::from_str("10.255.255.255").unwrap()));
        assert!(!net.contains(Ipv4Addr::from_str("11.0.0.1").unwrap()));
    }

    #[test]
    fn test_cidr_host_route() {
        let net: IpNetwork = "203.0.113.5/32".parse().unwrap();
        assert_eq!(net.prefix_length(), 32);
        assert!(net.contains(Ipv4Addr::from_str("203.0.113.5").unwrap()));
        assert!(!net.contains(Ipv4Addr::from_str("203.0.113.6").unwrap()));
    }

    #[test]
    fn test_cidr_default_route() {
        // /0 is the one prefix whose mask needs no shift-free special handling by the
        // caller: it must match every address.
        let net: IpNetwork = "0.0.0.0/0".parse().unwrap();
        assert_eq!(net.prefix_length(), 0);
        assert_eq!(net.mask(), Ipv4Addr::new(0, 0, 0, 0));
        assert_eq!(net.to_string(), "0.0.0.0/0");
        assert!(net.contains(Ipv4Addr::new(255, 255, 255, 255)));
        assert!(net.contains(Ipv4Addr::new(1, 2, 3, 4)));
    }

    #[test]
    fn test_from_cidr_masks() {
        let addr = Ipv4Addr::new(172, 16, 0, 0);
        for prefix in 0..=32u8 {
            let net = IpNetwork::from_cidr(addr, prefix);
            assert_eq!(net.prefix_length(), prefix);
            assert_eq!(net.mask_u32().count_ones(), u32::from(prefix));
        }
        assert_eq!(
            IpNetwork::from_cidr(addr, 12).mask(),
            Ipv4Addr::new(255, 240, 0, 0)
        );
    }

    #[test]
    #[should_panic]
    fn test_from_cidr_rejects_prefix_over_32() {
        let _ = IpNetwork::from_cidr(Ipv4Addr::new(192, 168, 1, 0), 33);
    }

    #[test]
    fn test_display_round_trip() {
        // Host bits are preserved, so a non-network address also round-trips.
        for s in ["10.0.0.0/8", "192.168.1.42/24", "203.0.113.5/32", "0.0.0.0/0"] {
            let net: IpNetwork = s.parse().unwrap();
            assert_eq!(net.to_string(), s);
            assert_eq!(net.to_string().parse::<IpNetwork>().unwrap(), net);
        }
    }

    #[test]
    fn test_from_mask() {
        let addr = Ipv4Addr::from_str("192.168.0.0").unwrap();
        let mask = Ipv4Addr::from_str("255.255.255.0").unwrap();
        let net = IpNetwork::from_mask(addr, mask);
        assert_eq!(net.prefix_length(), 24);
        assert!(net.contains(Ipv4Addr::from_str("192.168.0.99").unwrap()));
    }

    #[test]
    fn test_invalid_cidr_returns_error() {
        assert!("notanip/24".parse::<IpNetwork>().is_err());
        // Out-of-range prefixes are rejected by `FromStr` with an `Err` before
        // `from_cidr` (which would panic) is ever called.
        assert!("192.168.1.0/33".parse::<IpNetwork>().is_err());
        assert!("192.168.1.0".parse::<IpNetwork>().is_err()); // no prefix
        assert!("192.168.1.0/".parse::<IpNetwork>().is_err()); // empty prefix
        assert!("192.168.1.0/24/8".parse::<IpNetwork>().is_err()); // extra '/'
        assert!("".parse::<IpNetwork>().is_err()); // empty input
    }
}