firewall_objects/ip/
network.rs1use ipnet::IpNet;
4use std::cmp::Ordering;
5use std::collections::BTreeSet;
6use std::fmt;
7use std::hash::{Hash, Hasher};
8use std::net::IpAddr;
9use std::str::FromStr;
10
11use crate::ip::fqdn::Fqdn;
12use crate::ip::range::IpRange;
13
14#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
15#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
16pub enum Network {
17 Host(IpAddr),
18 Network(IpNet),
19 Range(IpRange),
20 Fqdn(Fqdn),
21}
22
23impl FromStr for Network {
24 type Err = String;
25
26 fn from_str(s: &str) -> Result<Self, Self::Err> {
36 Network::new(s)
37 }
38}
39
40impl fmt::Display for Network {
41 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42 match self {
43 Network::Host(ip) => write!(f, "{}", ip),
44 Network::Network(net) => write!(f, "{}", net),
45 Network::Range(r) => write!(f, "{}", r),
46 Network::Fqdn(d) => write!(f, "{}", d),
47 }
48 }
49}
50
51impl Network {
52 pub fn new(input: &str) -> Result<Self, String> {
97 let s = input.trim();
98
99 if s.is_empty() {
100 return Err("network value cannot be empty".into());
101 }
102
103 if let Ok(ip) = s.parse::<IpAddr>() {
105 return Ok(Network::Host(ip));
106 }
107
108 if let Ok(net) = s.parse::<IpNet>() {
110 return Ok(Network::Network(net));
111 }
112
113 if s.contains('-')
115 && let Ok(range) = IpRange::parse(s)
116 {
117 return Ok(Network::Range(range));
118 }
119
120 if let Ok(fqdn) = Fqdn::new(s) {
122 return Ok(Network::Fqdn(fqdn));
123 }
124
125 Err(format!("invalid network value: {input}"))
126 }
127}
128
129impl fmt::Display for NetworkObj {
130 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
131 write!(f, "{}={}", self.name, self.value)
132 }
133}
134
135#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
149#[derive(Debug, Clone)]
150pub struct NetworkObj {
151 pub name: String,
152 pub value: Network,
153}
154
155impl PartialEq for NetworkObj {
156 fn eq(&self, other: &Self) -> bool {
157 self.name.eq(&other.name)
158 }
159}
160
161impl Eq for NetworkObj {}
162
163impl PartialOrd for NetworkObj {
164 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
165 Some(self.cmp(other))
166 }
167}
168
169impl Ord for NetworkObj {
170 fn cmp(&self, other: &Self) -> Ordering {
171 self.name.cmp(&other.name)
172 }
173}
174
175impl Hash for NetworkObj {
176 fn hash<H: Hasher>(&self, state: &mut H) {
177 self.name.hash(state);
178 }
179}
180
181impl fmt::Display for NetworkObjGroup {
182 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
183 writeln!(f, "{}:", self.name)?;
184 for m in &self.value {
185 writeln!(f, " {}", m)?;
186 }
187 Ok(())
188 }
189}
190
191impl NetworkObj {
192 pub fn new(name: String, value: Network) -> Self {
193 Self { name, value }
194 }
195}
196
197impl TryFrom<(&str, &str)> for NetworkObj {
198 type Error = String;
199
200 fn try_from(v: (&str, &str)) -> Result<Self, Self::Error> {
201 let (name, value) = v;
202 Ok(Self::new(name.to_string(), Network::new(value)?))
203 }
204}
205
206#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
207#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
208pub struct NetworkObjGroup {
209 pub name: String,
210 pub value: BTreeSet<NetworkObj>,
211}
212
213impl NetworkObjGroup {
214 pub fn new(name: &str, value: BTreeSet<NetworkObj>) -> Result<Self, String> {
216 let name = name.trim();
217
218 if name.is_empty() {
219 return Err("group name cannot be empty".into());
220 }
221
222 Ok(Self {
223 name: name.to_string(),
224 value,
225 })
226 }
227
228 pub fn add(&mut self, obj: NetworkObj) -> Result<(), String> {
230 if self.value.iter().any(|existing| existing.name == obj.name) {
231 return Err(format!(
232 "network object name '{}' already exists in group '{}'",
233 obj.name, self.name
234 ));
235 }
236
237 self.value.insert(obj);
238 Ok(())
239 }
240
241 pub fn remove(&mut self, obj: &NetworkObj) -> bool {
243 self.value.remove(obj)
244 }
245
246 pub fn len(&self) -> usize {
248 self.value.len()
249 }
250
251 pub fn is_empty(&self) -> bool {
253 self.value.is_empty()
254 }
255
256 pub fn iter(&self) -> impl Iterator<Item = &NetworkObj> {
258 self.value.iter()
259 }
260}
261
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::BTreeSet;
    use std::str::FromStr;

    #[test]
    fn network_parses_host() {
        let n = Network::new("192.0.2.10").unwrap();
        assert!(matches!(n, Network::Host(_)));
    }

    #[test]
    fn network_parses_cidr_and_displays_it_back() {
        let n = Network::new("192.0.2.0/24").unwrap();
        assert!(matches!(n, Network::Network(_)));
        assert_eq!(n.to_string(), "192.0.2.0/24");
    }

    #[test]
    fn network_trims_surrounding_whitespace() {
        let n = Network::new("  2001:db8::1 \t").unwrap();
        assert_eq!(n, Network::Host("2001:db8::1".parse().unwrap()));
    }

    #[test]
    fn network_rejects_blank_input() {
        assert!(Network::new("").is_err());
        assert!(Network::new("   ").is_err());
    }

    #[test]
    fn network_parses_range_before_fqdn() {
        let n = Network::new("192.0.2.10 - 192.0.2.20").unwrap();
        assert!(matches!(n, Network::Range(_)));
    }

    #[test]
    fn network_parses_fqdn_with_dash() {
        let n = Network::new("dash-host.example.com").unwrap();
        assert!(matches!(n, Network::Fqdn(_)));
    }

    #[test]
    fn obj_try_from_and_display() {
        let obj = NetworkObj::try_from(("db1", "192.0.2.10")).unwrap();
        assert_eq!(obj.name, "db1");
        assert_eq!(obj.to_string(), "db1=192.0.2.10");
    }

    #[test]
    fn group_trims_and_rejects_blank_names() {
        assert!(NetworkObjGroup::new("   ", BTreeSet::new()).is_err());

        let group = NetworkObjGroup::new("  critical ", BTreeSet::new()).unwrap();
        assert_eq!(group.name, "critical");
        assert!(group.is_empty());
    }

    #[test]
    fn group_rejects_duplicate_names() {
        let mut group = NetworkObjGroup::new("critical", BTreeSet::new()).unwrap();
        group
            .add(NetworkObj::new(
                "db1".into(),
                Network::from_str("192.0.2.10").unwrap(),
            ))
            .unwrap();

        let err = group
            .add(NetworkObj::new(
                "db1".into(),
                Network::from_str("192.0.2.11").unwrap(),
            ))
            .unwrap_err();

        assert!(err.contains("db1"));
        // A rejected add must leave the group untouched.
        assert_eq!(group.len(), 1);
    }

    #[test]
    fn group_remove_matches_by_name_only() {
        let mut group = NetworkObjGroup::new("critical", BTreeSet::new()).unwrap();
        group
            .add(NetworkObj::new(
                "db1".into(),
                Network::from_str("192.0.2.10").unwrap(),
            ))
            .unwrap();

        let same_name = NetworkObj::new("db1".into(), Network::from_str("192.0.2.99").unwrap());
        assert!(group.remove(&same_name));
        assert!(group.is_empty());
        assert!(!group.remove(&same_name));
    }
}
299}