// firewall_objects/service/definition.rs

use std::fmt;
use std::str::FromStr;

use super::transport::TransportService;

/// Transport-layer protocol that a port range in a service definition applies to.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Layer4Protocol {
    /// TCP; displayed as `tcp`.
    Tcp,
    /// UDP; displayed as `udp`.
    Udp,
}

impl fmt::Display for Layer4Protocol {
    /// Writes the lowercase protocol name (`tcp` or `udp`).
    ///
    /// Width/fill/alignment flags on the formatter are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match *self {
            Self::Tcp => "tcp",
            Self::Udp => "udp",
        };
        f.write_str(name)
    }
}

24#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
26#[derive(Debug, Clone, PartialEq, Eq)]
27pub enum ServiceToken {
28 Single(TransportService),
29 Range {
30 protocol: Layer4Protocol,
31 start: u16,
32 end: u16,
33 },
34}
35
36impl ServiceToken {
37 pub fn expand(&self) -> Vec<TransportService> {
39 match self {
40 ServiceToken::Single(svc) => vec![svc.clone()],
41 ServiceToken::Range {
42 protocol,
43 start,
44 end,
45 } => match protocol {
46 Layer4Protocol::Tcp => (*start..=*end).map(TransportService::tcp).collect(),
47 Layer4Protocol::Udp => (*start..=*end).map(TransportService::udp).collect(),
48 },
49 }
50 }
51}
52
53#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
55#[derive(Debug, Clone, PartialEq, Eq)]
56pub struct ServiceDefinition {
57 tokens: Vec<ServiceToken>,
58}
59
60impl ServiceDefinition {
61 pub fn parse(input: &str) -> Result<Self, String> {
63 let mut tokens = Vec::new();
64
65 for (idx, raw_token) in input.split(',').enumerate() {
66 let token = raw_token.trim();
67 if token.is_empty() {
68 return Err(format!("empty service token at position {}", idx + 1));
69 }
70
71 if let Some(range) = parse_range(token)? {
72 tokens.push(range);
73 continue;
74 }
75
76 let service = TransportService::from_str(token)?;
77 tokens.push(ServiceToken::Single(service));
78 }
79
80 if tokens.is_empty() {
81 return Err("service definition cannot be empty".into());
82 }
83
84 Ok(Self { tokens })
85 }
86
87 pub fn tokens(&self) -> &[ServiceToken] {
89 &self.tokens
90 }
91
92 pub fn iter(&self) -> impl Iterator<Item = &ServiceToken> {
94 self.tokens.iter()
95 }
96
97 pub fn expand(&self) -> Vec<TransportService> {
102 self.tokens.iter().flat_map(ServiceToken::expand).collect()
103 }
104}
105
106fn parse_range(token: &str) -> Result<Option<ServiceToken>, String> {
107 let (proto_raw, rest) = match token.split_once('/') {
108 Some(parts) => parts,
109 None => return Ok(None),
110 };
111
112 let proto = proto_raw.trim().to_ascii_lowercase();
113 let layer4 = match proto.as_str() {
114 "tcp" => Layer4Protocol::Tcp,
115 "udp" => Layer4Protocol::Udp,
116 _ => return Ok(None),
117 };
118
119 let value = rest.trim();
120 if let Some((start_raw, end_raw)) = value.split_once('-') {
121 let start = start_raw
122 .trim()
123 .parse::<u16>()
124 .map_err(|_| "invalid start port in range".to_string())?;
125 let end = end_raw
126 .trim()
127 .parse::<u16>()
128 .map_err(|_| "invalid end port in range".to_string())?;
129
130 if start > end {
131 return Err("range start must be <= end".into());
132 }
133
134 return Ok(Some(ServiceToken::Range {
135 protocol: layer4,
136 start,
137 end,
138 }));
139 }
140
141 Ok(None)
142}
143
144impl std::str::FromStr for ServiceDefinition {
145 type Err = String;
146
147 fn from_str(s: &str) -> Result<Self, Self::Err> {
148 ServiceDefinition::parse(s)
149 }
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parses_mixed_tokens() {
        let def = ServiceDefinition::parse("tcp/80, udp/53, https").unwrap();
        assert_eq!(def.tokens().len(), 3);
    }

    #[test]
    fn parses_ranges() {
        let def = ServiceDefinition::parse("tcp/1000-1002").unwrap();
        match def.tokens()[0] {
            ServiceToken::Range {
                protocol,
                start,
                end,
            } => {
                assert_eq!(protocol, Layer4Protocol::Tcp);
                assert_eq!(start, 1000);
                assert_eq!(end, 1002);
            }
            _ => panic!("expected range"),
        }

        let expanded = def.expand();
        assert_eq!(
            expanded.iter().map(|s| s.to_string()).collect::<Vec<_>>(),
            vec!["tcp/1000", "tcp/1001", "tcp/1002"]
        );
    }

    #[test]
    fn range_parsing_ignores_case_and_whitespace() {
        let def = ServiceDefinition::parse(" UDP / 5 - 7 ").unwrap();
        assert_eq!(def.tokens().len(), 1);
        assert_eq!(
            def.tokens()[0],
            ServiceToken::Range {
                protocol: Layer4Protocol::Udp,
                start: 5,
                end: 7,
            }
        );
    }

    #[test]
    fn single_port_and_max_port_ranges_expand() {
        assert_eq!(ServiceDefinition::parse("tcp/80-80").unwrap().expand().len(), 1);
        // The inclusive upper bound must not overflow at u16::MAX.
        assert_eq!(
            ServiceDefinition::parse("udp/65534-65535")
                .unwrap()
                .expand()
                .len(),
            2
        );
    }

    #[test]
    fn from_str_matches_parse() {
        let via_trait: ServiceDefinition = "tcp/1-2, udp/3-4".parse().unwrap();
        assert_eq!(
            via_trait,
            ServiceDefinition::parse("tcp/1-2, udp/3-4").unwrap()
        );
    }

    #[test]
    fn layer4_protocol_displays_lowercase() {
        assert_eq!(Layer4Protocol::Tcp.to_string(), "tcp");
        assert_eq!(Layer4Protocol::Udp.to_string(), "udp");
    }

    #[test]
    fn rejects_bad_ranges() {
        assert!(ServiceDefinition::parse("tcp/2000-1000").is_err());
        assert!(ServiceDefinition::parse("tcp/-100").is_err());
        assert!(ServiceDefinition::parse("tcp/100-").is_err());
        assert!(ServiceDefinition::parse("tcp/1-70000").is_err());
    }

    #[test]
    fn rejects_empty_tokens() {
        assert!(ServiceDefinition::parse("tcp/80,,udp/53").is_err());
        assert!(ServiceDefinition::parse("tcp/80,").is_err());
    }

    #[test]
    fn rejects_empty_input() {
        assert!(ServiceDefinition::parse("").is_err());
        assert!(ServiceDefinition::parse("   ").is_err());
    }
}