pcapsql_core/protocol/
pruning.rs

1//! Protocol pruning for query optimization.
2//!
3//! When executing SQL queries that only reference certain protocol tables,
4//! we can skip parsing protocols that aren't needed. This module provides
5//! utilities to compute the required set of protocols and check whether
6//! to continue parsing.
7//!
8//! # Example
9//!
10//! ```rust,ignore
11//! use std::collections::HashSet;
12//! use pcapsql_core::protocol::{default_registry, compute_required_protocols};
13//!
14//! let registry = default_registry();
15//!
16//! // Query only touches TCP table
17//! let required = compute_required_protocols(&["tcp"], &registry);
18//!
19//! // Required set includes TCP and its dependencies
20//! assert!(required.contains("tcp"));
21//! assert!(required.contains("ipv4"));
22//! assert!(required.contains("ethernet"));
23//!
24//! // But not unrelated protocols
25//! assert!(!required.contains("dns"));
26//! assert!(!required.contains("http"));
27//! ```
28
29use std::collections::HashSet;
30
31use super::registry::{Protocol, ProtocolRegistry};
32
33/// Compute the set of protocols required to satisfy a query.
34///
35/// Given a set of queried table names (e.g., `["tcp", "ipv4"]`), returns
36/// all protocols that must be parsed, including transitive dependencies.
37///
38/// # Arguments
39///
40/// * `queried_tables` - Names of protocol tables referenced in the query
41/// * `registry` - Protocol registry containing parser definitions
42///
43/// # Returns
44///
45/// A set of protocol names that must be parsed. Always includes "frames"
46/// (the base frame data) plus all queried protocols and their dependencies.
47pub fn compute_required_protocols(
48    queried_tables: &[&str],
49    registry: &ProtocolRegistry,
50) -> HashSet<String> {
51    let mut required = HashSet::new();
52
53    // Always need the ability to read frames
54    required.insert("frames".to_string());
55
56    // For each queried table, add it and its dependencies
57    for table in queried_tables {
58        add_with_dependencies(table, registry, &mut required);
59    }
60
61    required
62}
63
64/// Recursively add a protocol and all its dependencies to the required set.
65fn add_with_dependencies(
66    protocol: &str,
67    registry: &ProtocolRegistry,
68    required: &mut HashSet<String>,
69) {
70    // Get parser from registry
71    if let Some(parser) = registry.get_parser(protocol) {
72        let name = parser.name().to_string();
73        if required.insert(name) {
74            // Newly added, also add dependencies
75            for dep in parser.dependencies() {
76                add_with_dependencies(dep, registry, required);
77            }
78        }
79    } else {
80        // Protocol not in registry, just add the name
81        // (e.g., "frames" is not a parser but a pseudo-table)
82        required.insert(protocol.to_string());
83    }
84}
85
/// Decide whether the parser chain still has work to do for the current packet.
///
/// Parsing should go on as long as at least one required protocol has not
/// shown up in `parsed_so_far`. The `"frames"` entry is ignored because frame
/// data is always available without parsing.
///
/// Note that *every* other entry of `required` must be present before this
/// returns `false`. A required set that holds alternative paths (for example
/// both `"ipv4"` and `"ipv6"`) therefore keeps returning `true` for a packet
/// that only carries one of them.
///
/// # Arguments
///
/// * `parsed_so_far` - Names of protocols already parsed from the current packet
/// * `required` - Set of protocols needed for the query
///
/// # Returns
///
/// `true` if some required protocol is still missing, `false` once all of
/// them (other than `"frames"`) have been parsed.
pub fn should_continue_parsing(parsed_so_far: &[&str], required: &HashSet<String>) -> bool {
    required
        .iter()
        .map(String::as_str)
        .any(|needed| needed != "frames" && !parsed_so_far.contains(&needed))
}
112
113/// Check if a specific parser should be run.
114///
115/// Returns `true` if:
116/// 1. The parser's output is directly needed by the query, OR
117/// 2. The parser is on the path to a needed protocol (i.e., some required
118///    protocol depends on this one)
119///
120/// # Arguments
121///
122/// * `parser_name` - Name of the parser to check
123/// * `required` - Set of protocols needed for the query
124/// * `registry` - Protocol registry containing parser definitions
125pub fn should_run_parser(
126    parser_name: &str,
127    required: &HashSet<String>,
128    registry: &ProtocolRegistry,
129) -> bool {
130    // Directly required
131    if required.contains(parser_name) {
132        return true;
133    }
134
135    // Check if any required protocol depends on this one
136    for req in required {
137        if let Some(parser) = registry.get_parser(req) {
138            if parser.dependencies().contains(&parser_name) {
139                return true;
140            }
141        }
142    }
143
144    false
145}
146
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::default_registry;

    /// Build an owned name set from string literals.
    fn name_set(names: &[&str]) -> HashSet<String> {
        names.iter().map(|name| (*name).to_owned()).collect()
    }

    /// Resolve the required-protocol set for `tables` against the default registry.
    fn required_for(tables: &[&str]) -> HashSet<String> {
        let registry = default_registry();
        compute_required_protocols(tables, &registry)
    }

    /// Assert that every name in `names` is part of `set`.
    fn assert_includes(set: &HashSet<String>, names: &[&str]) {
        for name in names {
            assert!(set.contains(*name), "expected `{}` to be required", name);
        }
    }

    /// Assert that none of the names in `names` is part of `set`.
    fn assert_excludes(set: &HashSet<String>, names: &[&str]) {
        for name in names {
            assert!(!set.contains(*name), "`{}` should not be required", name);
        }
    }

    #[test]
    fn test_compute_required_protocols_tcp() {
        // TCP sits on top of ethernet -> ipv4/ipv6 -> tcp
        let required = required_for(&["tcp"]);

        assert_includes(&required, &["frames", "tcp", "ipv4", "ipv6", "ethernet"]);

        // Unrelated protocols must stay out
        assert_excludes(&required, &["dns", "dhcp", "tls"]);
    }

    #[test]
    fn test_compute_required_protocols_dns() {
        // DNS sits on top of ethernet -> ipv4/ipv6 -> udp/tcp -> dns
        // (tcp is included because DNS can run over TCP)
        let required = required_for(&["dns"]);

        assert_includes(
            &required,
            &["frames", "dns", "udp", "tcp", "ipv4", "ipv6", "ethernet"],
        );

        // TLS, SSH, etc. must stay out
        assert_excludes(&required, &["tls", "ssh"]);
    }

    #[test]
    fn test_compute_required_protocols_ethernet_only() {
        // Querying ethernet pulls in nothing beyond the ethernet layer
        let required = required_for(&["ethernet"]);

        assert_includes(&required, &["frames", "ethernet"]);

        // No L3+ protocols
        assert_excludes(&required, &["ipv4", "tcp", "dns"]);
    }

    #[test]
    fn test_compute_required_protocols_multiple_tables() {
        // A join between TCP and DNS needs both dependency paths
        // (udp comes in through DNS)
        let required = required_for(&["tcp", "dns"]);

        assert_includes(&required, &["tcp", "dns", "udp", "ipv4", "ethernet"]);
    }

    #[test]
    fn test_should_continue_parsing() {
        let required = name_set(&["frames", "ethernet", "ipv4", "tcp"]);

        // Each of these still misses at least one required protocol
        let incomplete: [&[&str]; 3] = [&[], &["ethernet"], &["ethernet", "ipv4"]];
        for parsed in incomplete.iter() {
            assert!(
                should_continue_parsing(parsed, &required),
                "parsing should continue after {:?}",
                parsed
            );
        }

        // Everything required has been parsed - should stop
        let complete = ["ethernet", "ipv4", "tcp"];
        assert!(!should_continue_parsing(&complete, &required));
    }

    #[test]
    fn test_should_run_parser() {
        let registry = default_registry();
        let required = name_set(&["tcp"]);

        // tcp is directly required; ipv4/ipv6 are dependencies of tcp
        for wanted in ["tcp", "ipv4", "ipv6"].iter() {
            assert!(
                should_run_parser(wanted, &required, &registry),
                "`{}` parser should run",
                wanted
            );
        }

        // dns is not required, and neither is udp when only tcp is needed
        for skipped in ["dns", "udp"].iter() {
            assert!(
                !should_run_parser(skipped, &required, &registry),
                "`{}` parser should be skipped",
                skipped
            );
        }
    }

    #[test]
    fn test_vxlan_dependencies() {
        // VXLAN needs the full encapsulation path
        let required = required_for(&["vxlan"]);

        assert_includes(&required, &["vxlan", "udp", "ipv4", "ipv6", "ethernet"]);
    }
}