firewall_objects/service/
transport.rs1use crate::service::icmp::{IcmpSpec, IcmpVersion, parse_descriptor};
2use crate::service::registry;
3use std::cmp::Ordering;
4use std::collections::BTreeSet;
5use std::fmt;
6use std::hash::{Hash, Hasher};
7use std::str::FromStr;
8
9#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
14#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
15pub enum TransportService {
16 Tcp(u16),
17 Udp(u16),
18 Icmp(IcmpSpec),
19 IpProtocol(u8),
20 Any,
21}
22
23impl TransportService {
24 pub const fn tcp(port: u16) -> Self {
25 TransportService::Tcp(port)
26 }
27
28 pub const fn udp(port: u16) -> Self {
29 TransportService::Udp(port)
30 }
31
32 pub const fn icmp(version: IcmpVersion, ty: u8, code: Option<u8>) -> Self {
33 TransportService::Icmp(IcmpSpec::new(version, ty, code))
34 }
35
36 pub const fn ip_protocol(number: u8) -> Self {
37 TransportService::IpProtocol(number)
38 }
39}
40
41impl fmt::Display for TransportService {
42 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
43 match self {
44 TransportService::Tcp(port) => write!(f, "tcp/{port}"),
45 TransportService::Udp(port) => write!(f, "udp/{port}"),
46 TransportService::Icmp(spec) => write!(f, "{spec}"),
47 TransportService::IpProtocol(proto) => write!(f, "ip/{proto}"),
48 TransportService::Any => write!(f, "any"),
49 }
50 }
51}
52
53impl FromStr for TransportService {
54 type Err = String;
55
56 fn from_str(s: &str) -> Result<Self, Self::Err> {
63 let input = s.trim();
64
65 if input.is_empty() {
66 return Err("service definition cannot be empty".into());
67 }
68
69 if let Some(alias) = registry::lookup(input) {
70 return Ok(alias);
71 }
72
73 let (proto_raw, rest) = input
74 .split_once('/')
75 .ok_or_else(|| "service must be formatted as protocol/value".to_string())?;
76
77 let proto = proto_raw.trim().to_ascii_lowercase();
78 let value = rest.trim();
79
80 if value.is_empty() {
81 return Err("service value cannot be empty".into());
82 }
83
84 match proto.as_str() {
85 "tcp" => Ok(TransportService::Tcp(parse_port(value)?)),
86 "udp" => Ok(TransportService::Udp(parse_port(value)?)),
87 "icmp" | "icmpv4" => Ok(TransportService::Icmp(parse_descriptor(
88 value,
89 IcmpVersion::V4,
90 )?)),
91 "icmpv6" => Ok(TransportService::Icmp(parse_descriptor(
92 value,
93 IcmpVersion::V6,
94 )?)),
95 "ip" | "ipproto" | "proto" => {
96 Ok(TransportService::IpProtocol(parse_ip_protocol(value)?))
97 }
98 "any" => Ok(TransportService::Any),
99 _ => Err(format!("unknown transport protocol '{proto}'")),
100 }
101 }
102}
103
/// Parses `value` as a `u16` port number.
///
/// Any parse failure (non-numeric text, empty string, out-of-range number)
/// is reported with the same fixed range message.
fn parse_port(value: &str) -> Result<u16, String> {
    match value.parse::<u16>() {
        Ok(port) => Ok(port),
        Err(_) => Err(String::from("port must be in the range 0-65535")),
    }
}
109
/// Parses `value` as a `u8` IP protocol number.
///
/// Any parse failure (non-numeric text, empty string, out-of-range number)
/// is reported with the same fixed range message.
fn parse_ip_protocol(value: &str) -> Result<u8, String> {
    let parsed = value.parse::<u8>().ok();
    parsed.ok_or_else(|| String::from("IP protocol number must be in the range 0-255"))
}
115
116impl fmt::Display for ServiceObj {
117 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
118 write!(f, "{}={}", self.name, self.value)
119 }
120}
121
122#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
124#[derive(Debug, Clone)]
125pub struct ServiceObj {
126 pub name: String,
127 pub value: TransportService,
128}
129
130impl PartialEq for ServiceObj {
131 fn eq(&self, other: &Self) -> bool {
132 self.name.eq(&other.name)
133 }
134}
135
136impl Eq for ServiceObj {}
137
138impl PartialOrd for ServiceObj {
139 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
140 Some(self.cmp(other))
141 }
142}
143
144impl Ord for ServiceObj {
145 fn cmp(&self, other: &Self) -> Ordering {
146 self.name.cmp(&other.name)
147 }
148}
149
150impl Hash for ServiceObj {
151 fn hash<H: Hasher>(&self, state: &mut H) {
152 self.name.hash(state);
153 }
154}
155
156impl ServiceObj {
157 pub fn new(name: String, value: TransportService) -> Self {
158 Self { name, value }
159 }
160}
161
162impl TryFrom<(&str, &str)> for ServiceObj {
163 type Error = String;
164
165 fn try_from(v: (&str, &str)) -> Result<Self, Self::Error> {
166 let (name, value) = v;
167 Ok(Self::new(
168 name.to_string(),
169 TransportService::from_str(value)?,
170 ))
171 }
172}
173
174#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
176#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
177pub struct ServiceObjGroup {
178 pub name: String,
179 pub value: BTreeSet<ServiceObj>,
180}
181
182impl ServiceObjGroup {
183 pub fn new(name: &str, value: BTreeSet<ServiceObj>) -> Result<Self, String> {
184 let name = name.trim();
185
186 if name.is_empty() {
187 return Err("service group name cannot be empty".into());
188 }
189
190 Ok(Self {
191 name: name.to_string(),
192 value,
193 })
194 }
195
196 pub fn add(&mut self, obj: ServiceObj) -> Result<(), String> {
197 if self.value.iter().any(|existing| existing.name == obj.name) {
198 return Err(format!(
199 "service object name '{}' already exists in group '{}'",
200 obj.name, self.name
201 ));
202 }
203
204 self.value.insert(obj);
205 Ok(())
206 }
207
208 pub fn remove(&mut self, obj: &ServiceObj) -> bool {
209 self.value.remove(obj)
210 }
211
212 pub fn len(&self) -> usize {
213 self.value.len()
214 }
215
216 pub fn is_empty(&self) -> bool {
217 self.value.is_empty()
218 }
219
220 pub fn iter(&self) -> impl Iterator<Item = &ServiceObj> {
221 self.value.iter()
222 }
223
224 pub fn services(&self) -> impl Iterator<Item = &TransportService> {
226 self.value.iter().map(|obj| &obj.value)
227 }
228
229 pub fn get(&self, name: &str) -> Option<&ServiceObj> {
231 let needle = name.trim();
232 if needle.is_empty() {
233 return None;
234 }
235 self.value.iter().find(|obj| obj.name == needle)
236 }
237
238 pub fn builder(name: &str) -> Result<ServiceObjGroupBuilder, String> {
240 ServiceObjGroupBuilder::new(name)
241 }
242}
243
244pub struct ServiceObjGroupBuilder {
246 name: String,
247 members: BTreeSet<ServiceObj>,
248}
249
250impl ServiceObjGroupBuilder {
251 pub fn new(name: &str) -> Result<Self, String> {
252 let name = name.trim();
253
254 if name.is_empty() {
255 return Err("service group name cannot be empty".into());
256 }
257
258 Ok(Self {
259 name: name.to_string(),
260 members: BTreeSet::new(),
261 })
262 }
263
264 fn insert(&mut self, obj: ServiceObj) -> Result<(), String> {
265 if self
266 .members
267 .iter()
268 .any(|existing| existing.name == obj.name)
269 {
270 return Err(format!(
271 "service object name '{}' already exists in group '{}'",
272 obj.name, self.name
273 ));
274 }
275 self.members.insert(obj);
276 Ok(())
277 }
278
279 pub fn with_obj(mut self, obj: ServiceObj) -> Result<Self, String> {
281 self.insert(obj)?;
282 Ok(self)
283 }
284
285 pub fn with_service(mut self, name: &str, value: TransportService) -> Result<Self, String> {
287 self.insert(ServiceObj::new(name.to_string(), value))?;
288 Ok(self)
289 }
290
291 pub fn with_parsed_service(mut self, name: &str, value: &str) -> Result<Self, String> {
293 self.insert(ServiceObj::try_from((name, value))?)?;
294 Ok(self)
295 }
296
297 pub fn build(self) -> ServiceObjGroup {
299 ServiceObjGroup {
300 name: self.name,
301 value: self.members,
302 }
303 }
304}
305
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `text` as a service definition, panicking on failure.
    fn parsed(text: &str) -> TransportService {
        text.parse::<TransportService>().unwrap()
    }

    /// `tcp/<port>` and `udp/<port>` map to the matching variants.
    #[test]
    fn parses_tcp_udp() {
        assert_eq!(parsed("tcp/443"), TransportService::tcp(443));
        assert_eq!(parsed("udp/53"), TransportService::udp(53));
    }

    /// A named ICMPv4 type and a numeric ICMPv6 `type:code` both parse.
    #[test]
    fn parses_icmp_variants() {
        let expected_v4 = TransportService::icmp(IcmpVersion::V4, 8, None);
        assert_eq!(parsed("icmp/echo-request"), expected_v4);

        let expected_v6 = TransportService::icmp(IcmpVersion::V6, 128, Some(0));
        assert_eq!(parsed("icmpv6/128:0"), expected_v6);
    }

    /// A bare alias without a `/` resolves to its service.
    #[test]
    fn parses_aliases() {
        assert_eq!(parsed("https"), TransportService::tcp(443));
    }

    /// Adding a second object with an existing name fails and names it.
    #[test]
    fn group_rejects_duplicate_names() {
        let mut group = ServiceObjGroup::new("web", BTreeSet::new()).unwrap();

        let first = ServiceObj::new("https".into(), TransportService::tcp(443));
        group.add(first).unwrap();

        let clash = ServiceObj::new("https".into(), TransportService::udp(443));
        let err = group.add(clash).unwrap_err();

        assert!(err.contains("https"));
    }

    /// The builder collects members; `len`, `get` and `services` see them.
    #[test]
    fn builder_and_iter_helpers_work() {
        let builder = ServiceObjGroup::builder("databases").unwrap();
        let builder = builder.with_parsed_service("postgres", "tcp/5432").unwrap();
        let builder = builder
            .with_service("redis", TransportService::tcp(6379))
            .unwrap();
        let group = builder.build();

        assert_eq!(group.len(), 2);
        assert!(group.get("postgres").is_some());

        let rendered: Vec<String> = group.services().map(ToString::to_string).collect();
        assert_eq!(rendered, vec!["tcp/5432", "tcp/6379"]);
    }
}