Skip to main content

firewall_objects/service/
definition.rs

1//! High-level service definition parser (comma-separated lists, ranges).
2
3use super::transport::TransportService;
4use std::fmt;
5use std::str::FromStr;
6
/// Transport-layer (layer-4) protocol that may appear in a port-range token
/// such as `tcp/1000-1005`.
///
/// Variant order is significant: the derived `PartialOrd`/`Ord` compare by
/// declaration order, so `Tcp` sorts before `Udp`.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Layer4Protocol {
    /// Transmission Control Protocol; displayed as `tcp`.
    Tcp,
    /// User Datagram Protocol; displayed as `udp`.
    Udp,
}
14
15impl fmt::Display for Layer4Protocol {
16    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
17        match self {
18            Layer4Protocol::Tcp => write!(f, "tcp"),
19            Layer4Protocol::Udp => write!(f, "udp"),
20        }
21    }
22}
23
24/// Parsed token within a service definition list.
25#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
26#[derive(Debug, Clone, PartialEq, Eq)]
27pub enum ServiceToken {
28    Single(TransportService),
29    Range {
30        protocol: Layer4Protocol,
31        start: u16,
32        end: u16,
33    },
34}
35
36impl ServiceToken {
37    /// Expand a range token into individual transport services.
38    pub fn expand(&self) -> Vec<TransportService> {
39        match self {
40            ServiceToken::Single(svc) => vec![svc.clone()],
41            ServiceToken::Range {
42                protocol,
43                start,
44                end,
45            } => match protocol {
46                Layer4Protocol::Tcp => (*start..=*end).map(TransportService::tcp).collect(),
47                Layer4Protocol::Udp => (*start..=*end).map(TransportService::udp).collect(),
48            },
49        }
50    }
51}
52
53/// Representation of a comma-separated service definition string.
54#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
55#[derive(Debug, Clone, PartialEq, Eq)]
56pub struct ServiceDefinition {
57    tokens: Vec<ServiceToken>,
58}
59
60impl ServiceDefinition {
61    /// Parse a string like `"tcp/80, udp/53, tcp/1000-1005"` into tokens.
62    pub fn parse(input: &str) -> Result<Self, String> {
63        let mut tokens = Vec::new();
64
65        for (idx, raw_token) in input.split(',').enumerate() {
66            let token = raw_token.trim();
67            if token.is_empty() {
68                return Err(format!("empty service token at position {}", idx + 1));
69            }
70
71            if let Some(range) = parse_range(token)? {
72                tokens.push(range);
73                continue;
74            }
75
76            let service = TransportService::from_str(token)?;
77            tokens.push(ServiceToken::Single(service));
78        }
79
80        if tokens.is_empty() {
81            return Err("service definition cannot be empty".into());
82        }
83
84        Ok(Self { tokens })
85    }
86
87    /// Return the parsed tokens.
88    pub fn tokens(&self) -> &[ServiceToken] {
89        &self.tokens
90    }
91
92    /// Iterate over tokens.
93    pub fn iter(&self) -> impl Iterator<Item = &ServiceToken> {
94        self.tokens.iter()
95    }
96
97    /// Expand all tokens into explicit `TransportService` values.
98    ///
99    /// Ranges will produce one element per port in the span, while aliases
100    /// and single entries retain their normalized representation.
101    pub fn expand(&self) -> Vec<TransportService> {
102        self.tokens.iter().flat_map(ServiceToken::expand).collect()
103    }
104}
105
106fn parse_range(token: &str) -> Result<Option<ServiceToken>, String> {
107    let (proto_raw, rest) = match token.split_once('/') {
108        Some(parts) => parts,
109        None => return Ok(None),
110    };
111
112    let proto = proto_raw.trim().to_ascii_lowercase();
113    let layer4 = match proto.as_str() {
114        "tcp" => Layer4Protocol::Tcp,
115        "udp" => Layer4Protocol::Udp,
116        _ => return Ok(None),
117    };
118
119    let value = rest.trim();
120    if let Some((start_raw, end_raw)) = value.split_once('-') {
121        let start = start_raw
122            .trim()
123            .parse::<u16>()
124            .map_err(|_| "invalid start port in range".to_string())?;
125        let end = end_raw
126            .trim()
127            .parse::<u16>()
128            .map_err(|_| "invalid end port in range".to_string())?;
129
130        if start > end {
131            return Err("range start must be <= end".into());
132        }
133
134        return Ok(Some(ServiceToken::Range {
135            protocol: layer4,
136            start,
137            end,
138        }));
139    }
140
141    Ok(None)
142}
143
144impl std::str::FromStr for ServiceDefinition {
145    type Err = String;
146
147    fn from_str(s: &str) -> Result<Self, Self::Err> {
148        ServiceDefinition::parse(s)
149    }
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parses_mixed_tokens() {
        let def = ServiceDefinition::parse("tcp/80, udp/53, https").unwrap();
        assert_eq!(def.tokens().len(), 3);
    }

    #[test]
    fn parses_ranges() {
        let def = ServiceDefinition::parse("tcp/1000-1002").unwrap();
        match def.tokens()[0] {
            ServiceToken::Range {
                protocol,
                start,
                end,
            } => {
                assert_eq!(protocol, Layer4Protocol::Tcp);
                assert_eq!(start, 1000);
                assert_eq!(end, 1002);
            }
            _ => panic!("expected range"),
        }

        let expanded = def.expand();
        assert_eq!(
            expanded.iter().map(|s| s.to_string()).collect::<Vec<_>>(),
            vec!["tcp/1000", "tcp/1001", "tcp/1002"]
        );
    }

    #[test]
    fn parses_udp_ranges_case_insensitively() {
        let def = ServiceDefinition::parse("UDP/53-54").unwrap();
        assert_eq!(def.tokens().len(), 1);
        assert_eq!(
            def.tokens()[0],
            ServiceToken::Range {
                protocol: Layer4Protocol::Udp,
                start: 53,
                end: 54,
            }
        );
        assert_eq!(
            def.expand(),
            vec![TransportService::udp(53), TransportService::udp(54)]
        );
    }

    #[test]
    fn accepts_single_port_range() {
        let def = ServiceDefinition::parse("tcp/443-443").unwrap();
        assert_eq!(def.expand(), vec![TransportService::tcp(443)]);
    }

    #[test]
    fn tolerates_whitespace_around_tokens() {
        let def = ServiceDefinition::parse("  tcp/80 ,udp/53  , tcp/1-2 ").unwrap();
        assert_eq!(def.tokens().len(), 3);
    }

    #[test]
    fn rejects_bad_ranges() {
        assert!(ServiceDefinition::parse("tcp/2000-1000").is_err());
        assert!(ServiceDefinition::parse("tcp/-100").is_err());
        assert!(ServiceDefinition::parse("tcp/100-").is_err());
        assert!(ServiceDefinition::parse("tcp/abc-100").is_err());
        // 65536 does not fit in a u16 port.
        assert!(ServiceDefinition::parse("udp/1-65536").is_err());
    }

    #[test]
    fn rejects_empty_tokens() {
        assert_eq!(
            ServiceDefinition::parse("tcp/80,,udp/53").unwrap_err(),
            "empty service token at position 2"
        );
        assert!(ServiceDefinition::parse("tcp/80,").is_err());
        assert!(ServiceDefinition::parse(",tcp/80").is_err());
    }

    #[test]
    fn rejects_blank_input() {
        assert!(ServiceDefinition::parse("").is_err());
        assert!(ServiceDefinition::parse("   ").is_err());
    }

    #[test]
    fn from_str_matches_parse() {
        let via_parse = ServiceDefinition::parse("tcp/80, udp/1-3").unwrap();
        let via_from_str: ServiceDefinition = "tcp/80, udp/1-3".parse().unwrap();
        assert_eq!(via_parse, via_from_str);
    }
}