Skip to main content

pingap_util/
ip.rs

1// Copyright 2024-2025 Tree xie.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use ahash::AHashSet;
17use ipnet::IpNet;
18use std::net::{AddrParseError, IpAddr};
19use std::str::FromStr;
20
21/// IpRules stores pre-parsed IP addresses and networks for efficient access control.
22#[derive(Clone, Debug)]
23pub struct IpRules {
24    ip_net_list: Vec<IpNet>,
25    // Use a HashSet for O(1) average time complexity for individual IP lookups.
26    ip_set: AHashSet<IpAddr>,
27}
28
29impl IpRules {
30    /// Creates a new IpRules instance from a list of IP addresses and/or CIDR networks.
31    ///
32    /// The input values are parsed and stored in optimized data structures for fast lookups.
33    /// Invalid entries are ignored and a warning is logged.
34    pub fn new<T: AsRef<str>>(values: &[T]) -> Self {
35        let mut ip_net_list = vec![];
36        let mut ip_set = AHashSet::new();
37
38        for item in values {
39            let item_str = item.as_ref();
40            // Try parsing as a CIDR network first.
41            if let Ok(value) = IpNet::from_str(item_str) {
42                ip_net_list.push(value);
43            // If not a network, try parsing as a single IP address.
44            } else if let Ok(value) = IpAddr::from_str(item_str) {
45                ip_set.insert(value);
46            } else {
47                // If it's neither, warn about the invalid entry.
48            }
49        }
50        Self {
51            ip_net_list,
52            ip_set,
53        }
54    }
55
56    /// Checks if a given IP address matches any of the stored rules.
57    ///
58    /// This is the primary method for checking access. It parses the string
59    /// and then performs the efficient matching logic.
60    pub fn is_match(&self, ip: &str) -> Result<bool, AddrParseError> {
61        let addr = ip.parse::<IpAddr>()?;
62        Ok(self.is_match_addr(&addr))
63    }
64
65    /// A more performant version of `is_match` that accepts a pre-parsed `IpAddr`.
66    ///
67    /// This allows callers to avoid re-parsing the IP address if they already
68    /// have it in `IpAddr` form.
69    pub fn is_match_addr(&self, ip_addr: &IpAddr) -> bool {
70        // First, perform a fast O(1) lookup in the HashSet.
71        if self.ip_set.contains(ip_addr) {
72            return true;
73        }
74        // If not found, iterate through the network ranges.
75        self.ip_net_list.iter().any(|net| net.contains(ip_addr))
76    }
77}
78
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    /// Exercises construction plus both lookup entry points of `IpRules`.
    #[test]
    fn test_ip_rules() {
        // Parses a literal that is known to be a well-formed address.
        let addr = |text: &str| text.parse::<IpAddr>().unwrap();

        let entries = [
            "192.168.1.0/24", // A network
            "10.0.0.1",       // A single IP
            "2001:db8::/32",  // An IPv6 network
            "2001:db8:a::1",  // A single IPv6
            "not-an-ip",      // An invalid entry that should be ignored
        ];
        let rules = IpRules::new(&entries);

        // Two networks and two single addresses survive; the invalid entry
        // ends up in neither collection.
        assert_eq!(rules.ip_net_list.len(), 2);
        assert_eq!(rules.ip_set.len(), 2);

        // --- is_match_addr: lookups with pre-parsed addresses ---
        let addr_cases = [
            ("192.168.1.100", true),         // inside the IPv4 network
            ("10.0.0.1", true),              // exact IPv4 address
            ("192.168.2.1", false),          // outside every IPv4 rule
            ("2001:db8:dead:beef::1", true), // inside the IPv6 network
            ("2001:db8:a::1", true),         // exact IPv6 address
            ("2001:db9::1", false),          // outside every IPv6 rule
        ];
        for &(text, expected) in &addr_cases {
            assert_eq!(rules.is_match_addr(&addr(text)), expected, "address {}", text);
        }

        // --- is_match: lookups straight from strings ---
        let str_cases = [
            ("192.168.1.1", true),
            ("10.0.0.1", true),
            ("192.168.3.1", false),
        ];
        for &(text, expected) in &str_cases {
            assert_eq!(rules.is_match(text), Ok(expected));
        }
        // A string that is not a valid IP address surfaces the parse error.
        assert!(rules.is_match("999.999.999.999").is_err());
    }
}