pcapsql_core/protocol/
pruning.rs1use std::collections::HashSet;
30
31use super::registry::{Protocol, ProtocolRegistry};
32
33pub fn compute_required_protocols(
48 queried_tables: &[&str],
49 registry: &ProtocolRegistry,
50) -> HashSet<String> {
51 let mut required = HashSet::new();
52
53 required.insert("frames".to_string());
55
56 for table in queried_tables {
58 add_with_dependencies(table, registry, &mut required);
59 }
60
61 required
62}
63
64fn add_with_dependencies(
66 protocol: &str,
67 registry: &ProtocolRegistry,
68 required: &mut HashSet<String>,
69) {
70 if let Some(parser) = registry.get_parser(protocol) {
72 let name = parser.name().to_string();
73 if required.insert(name) {
74 for dep in parser.dependencies() {
76 add_with_dependencies(dep, registry, required);
77 }
78 }
79 } else {
80 required.insert(protocol.to_string());
83 }
84}
85
/// Reports whether packet parsing should continue past the layers parsed so far.
///
/// Returns `true` while at least one protocol in `required`, other than
/// `frames`, is missing from `parsed_so_far`; returns `false` once every such
/// protocol has been parsed (or when `required` holds nothing but `frames`).
///
/// NOTE(review): sets produced by `compute_required_protocols` can contain
/// alternatives (e.g. both `ipv4` and `ipv6` for a `tcp` query). A packet that
/// carries only one of them keeps this returning `true`, so pruning is
/// conservative there. Confirm this is the intended trade-off.
pub fn should_continue_parsing(parsed_so_far: &[&str], required: &HashSet<String>) -> bool {
    required
        .iter()
        // `frames` is not a parsed layer, so it never keeps the parser going.
        .filter(|name| name.as_str() != "frames")
        .any(|name| !parsed_so_far.contains(&name.as_str()))
}
112
113pub fn should_run_parser(
126 parser_name: &str,
127 required: &HashSet<String>,
128 registry: &ProtocolRegistry,
129) -> bool {
130 if required.contains(parser_name) {
132 return true;
133 }
134
135 for req in required {
137 if let Some(parser) = registry.get_parser(req) {
138 if parser.dependencies().contains(&parser_name) {
139 return true;
140 }
141 }
142 }
143
144 false
145}
146
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::default_registry;

    #[test]
    fn test_compute_required_protocols_tcp() {
        let registry = default_registry();

        let required = compute_required_protocols(&["tcp"], &registry);

        // Expected: TCP plus both IP versions and the link layer.
        assert!(required.contains("frames"));
        assert!(required.contains("tcp"));
        assert!(required.contains("ipv4"));
        assert!(required.contains("ipv6"));
        assert!(required.contains("ethernet"));

        // Unrelated application-layer protocols are pruned.
        assert!(!required.contains("dns"));
        assert!(!required.contains("dhcp"));
        assert!(!required.contains("tls"));
    }

    #[test]
    fn test_compute_required_protocols_dns() {
        let registry = default_registry();

        let required = compute_required_protocols(&["dns"], &registry);

        // Expected: DNS over both transports, plus the layers below them.
        assert!(required.contains("frames"));
        assert!(required.contains("dns"));
        assert!(required.contains("udp"));
        assert!(required.contains("tcp"));
        assert!(required.contains("ipv4"));
        assert!(required.contains("ipv6"));
        assert!(required.contains("ethernet"));

        assert!(!required.contains("tls"));
        assert!(!required.contains("ssh"));
    }

    #[test]
    fn test_compute_required_protocols_ethernet_only() {
        let registry = default_registry();

        let required = compute_required_protocols(&["ethernet"], &registry);

        assert!(required.contains("frames"));
        assert!(required.contains("ethernet"));

        // Nothing above the link layer is needed.
        assert!(!required.contains("ipv4"));
        assert!(!required.contains("tcp"));
        assert!(!required.contains("dns"));
    }

    #[test]
    fn test_compute_required_protocols_multiple_tables() {
        let registry = default_registry();

        let required = compute_required_protocols(&["tcp", "dns"], &registry);

        assert!(required.contains("tcp"));
        assert!(required.contains("dns"));
        assert!(required.contains("udp"));
        assert!(required.contains("ipv4"));
        assert!(required.contains("ethernet"));
    }

    #[test]
    fn test_compute_required_protocols_empty_query() {
        let registry = default_registry();

        let required = compute_required_protocols(&[], &registry);

        // With no queried tables only the always-present `frames` entry remains.
        assert_eq!(required.len(), 1);
        assert!(required.contains("frames"));
    }

    #[test]
    fn test_compute_required_protocols_unknown_table() {
        let registry = default_registry();

        let required = compute_required_protocols(&["not_a_registered_protocol"], &registry);

        // Names the registry cannot resolve are passed through verbatim.
        assert!(required.contains("frames"));
        assert!(required.contains("not_a_registered_protocol"));
        assert_eq!(required.len(), 2);
    }

    #[test]
    fn test_should_continue_parsing() {
        let required: HashSet<String> = ["frames", "ethernet", "ipv4", "tcp"]
            .iter()
            .map(|s| s.to_string())
            .collect();

        // Nothing parsed yet.
        assert!(should_continue_parsing(&[], &required));

        // Partially parsed.
        assert!(should_continue_parsing(&["ethernet"], &required));
        assert!(should_continue_parsing(&["ethernet", "ipv4"], &required));

        // Everything required (other than `frames`) has been parsed.
        assert!(!should_continue_parsing(
            &["ethernet", "ipv4", "tcp"],
            &required
        ));
    }

    #[test]
    fn test_should_continue_parsing_frames_only() {
        let required: HashSet<String> = ["frames"].iter().map(|s| s.to_string()).collect();

        // `frames` alone never keeps parsing going.
        assert!(!should_continue_parsing(&[], &required));
    }

    #[test]
    fn test_should_run_parser() {
        let registry = default_registry();
        let required: HashSet<String> = ["tcp"].iter().map(|s| s.to_string()).collect();

        // Directly required.
        assert!(should_run_parser("tcp", &required, &registry));

        // Dependencies of a required parser.
        assert!(should_run_parser("ipv4", &required, &registry));
        assert!(should_run_parser("ipv6", &required, &registry));

        // Neither required nor a dependency.
        assert!(!should_run_parser("dns", &required, &registry));
        assert!(!should_run_parser("udp", &required, &registry));
    }

    #[test]
    fn test_vxlan_dependencies() {
        let registry = default_registry();

        let required = compute_required_protocols(&["vxlan"], &registry);

        assert!(required.contains("vxlan"));
        assert!(required.contains("udp"));
        assert!(required.contains("ipv4"));
        assert!(required.contains("ipv6"));
        assert!(required.contains("ethernet"));
    }
}