pingap_util/ip.rs
1// Copyright 2024-2025 Tree xie.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use ahash::AHashSet;
17use ipnet::IpNet;
18use std::net::{AddrParseError, IpAddr};
19use std::str::FromStr;
20
21/// IpRules stores pre-parsed IP addresses and networks for efficient access control.
22#[derive(Clone, Debug)]
23pub struct IpRules {
24 ip_net_list: Vec<IpNet>,
25 // Use a HashSet for O(1) average time complexity for individual IP lookups.
26 ip_set: AHashSet<IpAddr>,
27}
28
29impl IpRules {
30 /// Creates a new IpRules instance from a list of IP addresses and/or CIDR networks.
31 ///
32 /// The input values are parsed and stored in optimized data structures for fast lookups.
33 /// Invalid entries are ignored and a warning is logged.
34 pub fn new<T: AsRef<str>>(values: &[T]) -> Self {
35 let mut ip_net_list = vec![];
36 let mut ip_set = AHashSet::new();
37
38 for item in values {
39 let item_str = item.as_ref();
40 // Try parsing as a CIDR network first.
41 if let Ok(value) = IpNet::from_str(item_str) {
42 ip_net_list.push(value);
43 // If not a network, try parsing as a single IP address.
44 } else if let Ok(value) = IpAddr::from_str(item_str) {
45 ip_set.insert(value);
46 } else {
47 // If it's neither, warn about the invalid entry.
48 }
49 }
50 Self {
51 ip_net_list,
52 ip_set,
53 }
54 }
55
56 /// Checks if a given IP address matches any of the stored rules.
57 ///
58 /// This is the primary method for checking access. It parses the string
59 /// and then performs the efficient matching logic.
60 pub fn is_match(&self, ip: &str) -> Result<bool, AddrParseError> {
61 let addr = ip.parse::<IpAddr>()?;
62 Ok(self.is_match_addr(&addr))
63 }
64
65 /// A more performant version of `is_match` that accepts a pre-parsed `IpAddr`.
66 ///
67 /// This allows callers to avoid re-parsing the IP address if they already
68 /// have it in `IpAddr` form.
69 pub fn is_match_addr(&self, ip_addr: &IpAddr) -> bool {
70 // First, perform a fast O(1) lookup in the HashSet.
71 if self.ip_set.contains(ip_addr) {
72 return true;
73 }
74 // If not found, iterate through the network ranges.
75 self.ip_net_list.iter().any(|net| net.contains(ip_addr))
76 }
77}
78
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    #[test]
    fn test_ip_rules() {
        let rules = IpRules::new(&[
            "192.168.1.0/24", // A network
            "10.0.0.1",       // A single IP
            "2001:db8::/32",  // An IPv6 network
            "2001:db8:a::1",  // A single IPv6
            "not-an-ip",      // An invalid entry that should be ignored
        ]);

        // Check that the constructor correctly parsed and stored the rules.
        assert_eq!(rules.ip_net_list.len(), 2);
        assert_eq!(rules.ip_set.len(), 2);

        // --- Test is_match_addr for performance-critical paths ---
        let ip_in_net_v4 = "192.168.1.100".parse().unwrap();
        let exact_ip_v4 = "10.0.0.1".parse().unwrap();
        let outside_ip_v4 = "192.168.2.1".parse().unwrap();

        let ip_in_net_v6 = "2001:db8:dead:beef::1".parse().unwrap();
        let exact_ip_v6 = "2001:db8:a::1".parse().unwrap();
        let outside_ip_v6 = "2001:db9::1".parse().unwrap();

        assert!(rules.is_match_addr(&ip_in_net_v4));
        assert!(rules.is_match_addr(&exact_ip_v4));
        assert!(!rules.is_match_addr(&outside_ip_v4));

        assert!(rules.is_match_addr(&ip_in_net_v6));
        assert!(rules.is_match_addr(&exact_ip_v6));
        assert!(!rules.is_match_addr(&outside_ip_v6));

        // --- Test is_match for user-facing convenience ---
        assert_eq!(rules.is_match("192.168.1.1"), Ok(true));
        assert_eq!(rules.is_match("10.0.0.1"), Ok(true));
        assert_eq!(rules.is_match("192.168.3.1"), Ok(false));
        // Test invalid IP string input for is_match
        assert!(rules.is_match("999.999.999.999").is_err());
    }

    #[test]
    fn test_empty_rules_match_nothing() {
        let rules = IpRules::new::<&str>(&[]);
        assert!(rules.ip_net_list.is_empty());
        assert!(rules.ip_set.is_empty());
        assert_eq!(rules.is_match("127.0.0.1"), Ok(false));
        assert_eq!(rules.is_match("::1"), Ok(false));
        // Parse errors are still reported even when there are no rules.
        assert!(rules.is_match("not-an-ip").is_err());
    }

    #[test]
    fn test_network_boundaries() {
        let rules = IpRules::new(&["10.0.0.0/8", "2001:db8::/32"]);

        // IPv4 network: first and last addresses are inside, neighbours are not.
        assert_eq!(rules.is_match("10.0.0.0"), Ok(true));
        assert_eq!(rules.is_match("10.255.255.255"), Ok(true));
        assert_eq!(rules.is_match("9.255.255.255"), Ok(false));
        assert_eq!(rules.is_match("11.0.0.0"), Ok(false));

        // IPv6 network through the string API.
        assert_eq!(rules.is_match("2001:db8::1"), Ok(true));
        assert_eq!(
            rules.is_match("2001:db8:ffff:ffff:ffff:ffff:ffff:ffff"),
            Ok(true)
        );
        assert_eq!(rules.is_match("2001:db9::"), Ok(false));
    }
}