// firewall_objects/service/transport.rs

1use crate::service::icmp::{IcmpSpec, IcmpVersion, parse_descriptor};
2use crate::service::registry;
3use std::cmp::Ordering;
4use std::collections::BTreeSet;
5use std::fmt;
6use std::hash::{Hash, Hasher};
7use std::str::FromStr;
8
/// Transport-level representation of a firewall service definition.
///
/// Values are parsed from text via `FromStr` and rendered via `fmt::Display`
/// (e.g. `tcp/443`, `udp/53`, `ip/47`).
///
/// The derived `PartialOrd`/`Ord` compare the variant first, in declaration
/// order (`Tcp` < `Udp` < `Icmp` < `IpProtocol` < `Any`), and then the payload.
/// Reordering the variants therefore changes sort order (and the order of any
/// ordered collection keyed on this type).
///
/// Feature flags:
/// - `serde` – Enables serialization for transport types and containers.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum TransportService {
    /// TCP service identified by a port number; text form `tcp/<port>`.
    Tcp(u16),
    /// UDP service identified by a port number; text form `udp/<port>`.
    Udp(u16),
    /// ICMP (v4 or v6) message selector; text form is `IcmpSpec`'s `Display`.
    Icmp(IcmpSpec),
    /// Raw IP protocol number (0-255); text form `ip/<number>`.
    IpProtocol(u8),
    /// Rendered as `any`. NOTE(review): presumably a wildcard matching every
    /// transport — the matching semantics live in the consumer; confirm.
    Any,
}
22
23impl TransportService {
24    pub const fn tcp(port: u16) -> Self {
25        TransportService::Tcp(port)
26    }
27
28    pub const fn udp(port: u16) -> Self {
29        TransportService::Udp(port)
30    }
31
32    pub const fn icmp(version: IcmpVersion, ty: u8, code: Option<u8>) -> Self {
33        TransportService::Icmp(IcmpSpec::new(version, ty, code))
34    }
35
36    pub const fn ip_protocol(number: u8) -> Self {
37        TransportService::IpProtocol(number)
38    }
39}
40
41impl fmt::Display for TransportService {
42    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
43        match self {
44            TransportService::Tcp(port) => write!(f, "tcp/{port}"),
45            TransportService::Udp(port) => write!(f, "udp/{port}"),
46            TransportService::Icmp(spec) => write!(f, "{spec}"),
47            TransportService::IpProtocol(proto) => write!(f, "ip/{proto}"),
48            TransportService::Any => write!(f, "any"),
49        }
50    }
51}
52
53impl FromStr for TransportService {
54    type Err = String;
55
56    /// Parse a transport definition from common strings:
57    /// - `"tcp/443"`
58    /// - `"udp/53"`
59    /// - `"icmp/echo-request"`, `"icmpv6/128:0"`
60    /// - `"ip/47"`
61    /// - Case-insensitive aliases such as `"https"` or `"ping"`
62    fn from_str(s: &str) -> Result<Self, Self::Err> {
63        let input = s.trim();
64
65        if input.is_empty() {
66            return Err("service definition cannot be empty".into());
67        }
68
69        if let Some(alias) = registry::lookup(input) {
70            return Ok(alias);
71        }
72
73        let (proto_raw, rest) = input
74            .split_once('/')
75            .ok_or_else(|| "service must be formatted as protocol/value".to_string())?;
76
77        let proto = proto_raw.trim().to_ascii_lowercase();
78        let value = rest.trim();
79
80        if value.is_empty() {
81            return Err("service value cannot be empty".into());
82        }
83
84        match proto.as_str() {
85            "tcp" => Ok(TransportService::Tcp(parse_port(value)?)),
86            "udp" => Ok(TransportService::Udp(parse_port(value)?)),
87            "icmp" | "icmpv4" => Ok(TransportService::Icmp(parse_descriptor(
88                value,
89                IcmpVersion::V4,
90            )?)),
91            "icmpv6" => Ok(TransportService::Icmp(parse_descriptor(
92                value,
93                IcmpVersion::V6,
94            )?)),
95            "ip" | "ipproto" | "proto" => {
96                Ok(TransportService::IpProtocol(parse_ip_protocol(value)?))
97            }
98            "any" => Ok(TransportService::Any),
99            _ => Err(format!("unknown transport protocol '{proto}'")),
100        }
101    }
102}
103
/// Parse the port number of a `tcp/<port>` or `udp/<port>` definition.
///
/// Any decimal value that fits in a `u16` is accepted. Everything else —
/// non-numeric text, negative or too-large numbers, the empty string — yields
/// the same range error message.
fn parse_port(value: &str) -> Result<u16, String> {
    match value.parse::<u16>() {
        Ok(port) => Ok(port),
        Err(_) => Err(String::from("port must be in the range 0-65535")),
    }
}
109
/// Parse the protocol number of an `ip/<number>` definition.
///
/// Any decimal value that fits in a `u8` is accepted. Anything else yields the
/// same range error message.
fn parse_ip_protocol(value: &str) -> Result<u8, String> {
    value
        .parse::<u8>()
        .ok()
        .ok_or_else(|| String::from("IP protocol number must be in the range 0-255"))
}
115
116impl fmt::Display for ServiceObj {
117    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
118        write!(f, "{}={}", self.name, self.value)
119    }
120}
121
/// Transport service definition paired with an identifying name.
///
/// This type's `PartialEq`, `Ord` and `Hash` impls consider only `name`. Two
/// objects with the same name but different `value`s are therefore treated
/// as the same object, e.g. a `BTreeSet<ServiceObj>` keeps only one of them.
/// Name uniqueness is enforced by `ServiceObjGroup::add` and the group
/// builder, not by this type; `ServiceObj::new` accepts any name.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone)]
pub struct ServiceObj {
    /// Identifier; the only field used for equality, ordering and hashing.
    pub name: String,
    /// Transport definition; ignored when comparing or hashing.
    pub value: TransportService,
}
129
130impl PartialEq for ServiceObj {
131    fn eq(&self, other: &Self) -> bool {
132        self.name.eq(&other.name)
133    }
134}
135
136impl Eq for ServiceObj {}
137
138impl PartialOrd for ServiceObj {
139    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
140        Some(self.cmp(other))
141    }
142}
143
144impl Ord for ServiceObj {
145    fn cmp(&self, other: &Self) -> Ordering {
146        self.name.cmp(&other.name)
147    }
148}
149
150impl Hash for ServiceObj {
151    fn hash<H: Hasher>(&self, state: &mut H) {
152        self.name.hash(state);
153    }
154}
155
156impl ServiceObj {
157    pub fn new(name: String, value: TransportService) -> Self {
158        Self { name, value }
159    }
160}
161
162impl TryFrom<(&str, &str)> for ServiceObj {
163    type Error = String;
164
165    fn try_from(v: (&str, &str)) -> Result<Self, Self::Error> {
166        let (name, value) = v;
167        Ok(Self::new(
168            name.to_string(),
169            TransportService::from_str(value)?,
170        ))
171    }
172}
173
/// Ordered set of service objects sharing a human-friendly name.
///
/// Members live in a `BTreeSet`, which orders them by `ServiceObj`'s
/// name-only `Ord`. Iteration is therefore in member-name order, and no two
/// members can share a name.
///
/// The derived comparisons inherit that name-only member semantics. Two
/// groups with the same name and the same member names compare equal even if
/// those members carry different transport definitions.
///
/// NOTE(review): `new` and the builder reject empty names, but the public
/// fields (and serde deserialization, when enabled) bypass that check.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct ServiceObjGroup {
    /// Group name (trimmed and non-empty when built via `new`/builder).
    pub name: String,
    /// Members, ordered and deduplicated by name.
    pub value: BTreeSet<ServiceObj>,
}
181
182impl ServiceObjGroup {
183    pub fn new(name: &str, value: BTreeSet<ServiceObj>) -> Result<Self, String> {
184        let name = name.trim();
185
186        if name.is_empty() {
187            return Err("service group name cannot be empty".into());
188        }
189
190        Ok(Self {
191            name: name.to_string(),
192            value,
193        })
194    }
195
196    pub fn add(&mut self, obj: ServiceObj) -> Result<(), String> {
197        if self.value.iter().any(|existing| existing.name == obj.name) {
198            return Err(format!(
199                "service object name '{}' already exists in group '{}'",
200                obj.name, self.name
201            ));
202        }
203
204        self.value.insert(obj);
205        Ok(())
206    }
207
208    pub fn remove(&mut self, obj: &ServiceObj) -> bool {
209        self.value.remove(obj)
210    }
211
212    pub fn len(&self) -> usize {
213        self.value.len()
214    }
215
216    pub fn is_empty(&self) -> bool {
217        self.value.is_empty()
218    }
219
220    pub fn iter(&self) -> impl Iterator<Item = &ServiceObj> {
221        self.value.iter()
222    }
223
224    /// Iterate over transport definitions directly.
225    pub fn services(&self) -> impl Iterator<Item = &TransportService> {
226        self.value.iter().map(|obj| &obj.value)
227    }
228
229    /// Fetch a specific service object by name.
230    pub fn get(&self, name: &str) -> Option<&ServiceObj> {
231        let needle = name.trim();
232        if needle.is_empty() {
233            return None;
234        }
235        self.value.iter().find(|obj| obj.name == needle)
236    }
237
238    /// Create a builder to assemble the group incrementally.
239    pub fn builder(name: &str) -> Result<ServiceObjGroupBuilder, String> {
240        ServiceObjGroupBuilder::new(name)
241    }
242}
243
244/// Builder that helps construct `ServiceObjGroup` instances.
245pub struct ServiceObjGroupBuilder {
246    name: String,
247    members: BTreeSet<ServiceObj>,
248}
249
250impl ServiceObjGroupBuilder {
251    pub fn new(name: &str) -> Result<Self, String> {
252        let name = name.trim();
253
254        if name.is_empty() {
255            return Err("service group name cannot be empty".into());
256        }
257
258        Ok(Self {
259            name: name.to_string(),
260            members: BTreeSet::new(),
261        })
262    }
263
264    fn insert(&mut self, obj: ServiceObj) -> Result<(), String> {
265        if self
266            .members
267            .iter()
268            .any(|existing| existing.name == obj.name)
269        {
270            return Err(format!(
271                "service object name '{}' already exists in group '{}'",
272                obj.name, self.name
273            ));
274        }
275        self.members.insert(obj);
276        Ok(())
277    }
278
279    /// Add a pre-built service object. Returned builder allows chaining.
280    pub fn with_obj(mut self, obj: ServiceObj) -> Result<Self, String> {
281        self.insert(obj)?;
282        Ok(self)
283    }
284
285    /// Add a named transport definition.
286    pub fn with_service(mut self, name: &str, value: TransportService) -> Result<Self, String> {
287        self.insert(ServiceObj::new(name.to_string(), value))?;
288        Ok(self)
289    }
290
291    /// Parse a service string such as `tcp/443` and add it under `name`.
292    pub fn with_parsed_service(mut self, name: &str, value: &str) -> Result<Self, String> {
293        self.insert(ServiceObj::try_from((name, value))?)?;
294        Ok(self)
295    }
296
297    /// Finish building the group.
298    pub fn build(self) -> ServiceObjGroup {
299        ServiceObjGroup {
300            name: self.name,
301            value: self.members,
302        }
303    }
304}
305
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parses_tcp_udp() {
        assert_eq!(
            TransportService::from_str("tcp/443").unwrap(),
            TransportService::tcp(443)
        );
        assert_eq!(
            TransportService::from_str("udp/53").unwrap(),
            TransportService::udp(53)
        );
    }

    #[test]
    fn parsing_ignores_whitespace_and_protocol_case() {
        assert_eq!(
            TransportService::from_str(" TCP / 22 ").unwrap(),
            TransportService::tcp(22)
        );
    }

    #[test]
    fn parses_ip_protocol_and_renders_it_back() {
        let gre = TransportService::from_str("ip/47").unwrap();
        assert_eq!(gre, TransportService::ip_protocol(47));
        assert_eq!(gre.to_string(), "ip/47");
        assert_eq!(TransportService::from_str("proto/47").unwrap(), gre);
    }

    #[test]
    fn rejects_malformed_definitions() {
        for bad in ["", "   ", "tcp443", "tcp/", "tcp/65536", "udp/-1", "ip/256", "sctp/80"] {
            assert!(
                TransportService::from_str(bad).is_err(),
                "expected {bad:?} to be rejected"
            );
        }
    }

    #[test]
    fn parses_icmp_variants() {
        let icmp = TransportService::from_str("icmp/echo-request").unwrap();
        assert_eq!(icmp, TransportService::icmp(IcmpVersion::V4, 8, None));

        let icmpv6 = TransportService::from_str("icmpv6/128:0").unwrap();
        assert_eq!(
            icmpv6,
            TransportService::icmp(IcmpVersion::V6, 128, Some(0))
        );
    }

    #[test]
    fn parses_aliases() {
        let https = TransportService::from_str("https").unwrap();
        assert_eq!(https, TransportService::tcp(443));
    }

    #[test]
    fn service_obj_identity_is_name_only() {
        let udp = ServiceObj::new("dns".into(), TransportService::udp(53));
        let tcp = ServiceObj::new("dns".into(), TransportService::tcp(53));
        assert_eq!(udp, tcp);
        assert_eq!(udp.cmp(&tcp), Ordering::Equal);
        assert_eq!(udp.to_string(), "dns=udp/53");
    }

    #[test]
    fn try_from_pair_parses_definition() {
        let obj = ServiceObj::try_from(("ssh", "tcp/22")).unwrap();
        assert_eq!(obj.name, "ssh");
        assert_eq!(obj.value, TransportService::tcp(22));
        assert!(ServiceObj::try_from(("ssh", "tcp/")).is_err());
    }

    #[test]
    fn group_rejects_duplicate_names() {
        let mut group = ServiceObjGroup::new("web", BTreeSet::new()).unwrap();
        group
            .add(ServiceObj::new("https".into(), TransportService::tcp(443)))
            .unwrap();

        let err = group
            .add(ServiceObj::new("https".into(), TransportService::udp(443)))
            .unwrap_err();

        assert!(err.contains("https"));
        assert_eq!(group.len(), 1);
    }

    #[test]
    fn group_names_are_trimmed_and_must_not_be_blank() {
        assert!(ServiceObjGroup::new("   ", BTreeSet::new()).is_err());
        assert!(ServiceObjGroup::builder("").is_err());

        let group = ServiceObjGroup::new("  web ", BTreeSet::new()).unwrap();
        assert_eq!(group.name, "web");
        assert!(group.is_empty());
    }

    #[test]
    fn builder_rejects_duplicate_names() {
        // `.err().expect(..)` avoids requiring `Debug` on the builder type.
        let err = ServiceObjGroup::builder("dns")
            .unwrap()
            .with_service("dns", TransportService::udp(53))
            .unwrap()
            .with_parsed_service("dns", "tcp/53")
            .err()
            .expect("duplicate member name must be rejected");
        assert!(err.contains("dns"));
    }

    #[test]
    fn remove_matches_members_by_name() {
        let mut group = ServiceObjGroup::builder("dns")
            .unwrap()
            .with_service("dns", TransportService::udp(53))
            .unwrap()
            .build();

        // A different transport under the same name still identifies the member.
        assert!(group.remove(&ServiceObj::new("dns".into(), TransportService::tcp(53))));
        assert!(group.is_empty());
        assert!(group.get("dns").is_none());
    }

    #[test]
    fn builder_and_iter_helpers_work() {
        let group = ServiceObjGroup::builder("databases")
            .unwrap()
            .with_parsed_service("postgres", "tcp/5432")
            .unwrap()
            .with_service("redis", TransportService::tcp(6379))
            .unwrap()
            .build();

        assert_eq!(group.len(), 2);
        assert!(group.get("postgres").is_some());
        assert!(group.get("  postgres ").is_some());
        assert!(group.get("   ").is_none());
        let rendered: Vec<_> = group.services().map(|svc| svc.to_string()).collect();
        assert_eq!(rendered, vec!["tcp/5432", "tcp/6379"]);
    }
}