pcapsql_core/protocol/
projection.rs

1//! Field projection configuration for parsing optimization.
2//!
3//! When executing SQL queries that only need certain columns,
4//! we can skip extracting fields that aren't needed. This module provides
5//! configuration structures to specify which fields to extract.
6//!
7//! # Example
8//!
9//! ```rust,ignore
10//! use std::collections::HashSet;
11//! use pcapsql_core::protocol::{ProjectionConfig, default_registry};
12//!
13//! // Create projection requesting only ports from TCP
14//! let mut config = ProjectionConfig::new();
15//! config.add_protocol_fields("tcp", &["src_port", "dst_port"]);
16//!
17//! // Parse with projection
18//! let fields = config.get("tcp");
19//! let result = parser.parse_projected(&data, &context, fields);
20//! ```
21
22use std::collections::{HashMap, HashSet};
23
/// Configuration for field projection during parsing.
///
/// Stores per-protocol field sets that control which fields are extracted.
/// Fields not in the set are skipped during parsing, reducing CPU usage.
/// A protocol with no entry is parsed in full; a protocol with an entry is
/// limited to the listed fields (plus chain fields, when enabled).
///
/// `Default` produces the same configuration as `ProjectionConfig::new()`:
/// no projections, with automatic inclusion of chain fields enabled.
#[derive(Debug, Clone)]
pub struct ProjectionConfig {
    /// Per-protocol field projections.
    /// Key is protocol name, value is set of required field names.
    protocol_fields: HashMap<String, HashSet<String>>,

    /// If true, always include fields needed for protocol chaining.
    /// These are fields required to detect and parse child protocols.
    include_chain_fields: bool,
}

impl Default for ProjectionConfig {
    /// Returns an empty configuration with chain-field inclusion enabled.
    ///
    /// Written by hand rather than derived: a derived `Default` would set
    /// `include_chain_fields` to `false`, disagreeing with `new()` and
    /// silently dropping the fields needed to detect child protocols.
    fn default() -> Self {
        Self {
            protocol_fields: HashMap::new(),
            include_chain_fields: true,
        }
    }
}
38
39impl ProjectionConfig {
40    /// Create a new empty projection configuration.
41    pub fn new() -> Self {
42        Self {
43            protocol_fields: HashMap::new(),
44            include_chain_fields: true,
45        }
46    }
47
48    /// Create configuration that includes chain fields.
49    ///
50    /// Chain fields are those needed to detect and parse child protocols,
51    /// such as IP protocol number or TCP/UDP ports.
52    pub fn with_chain_fields(mut self) -> Self {
53        self.include_chain_fields = true;
54        self
55    }
56
57    /// Disable automatic inclusion of chain fields.
58    pub fn without_chain_fields(mut self) -> Self {
59        self.include_chain_fields = false;
60        self
61    }
62
63    /// Add required fields for a protocol.
64    ///
65    /// # Arguments
66    ///
67    /// * `protocol` - Protocol name (e.g., "tcp", "dns")
68    /// * `fields` - Field names to extract (e.g., "src_port", "dst_port")
69    pub fn add_protocol_fields<I, S>(&mut self, protocol: &str, fields: I)
70    where
71        I: IntoIterator<Item = S>,
72        S: AsRef<str>,
73    {
74        let field_set = self
75            .protocol_fields
76            .entry(protocol.to_string())
77            .or_default();
78
79        for field in fields {
80            field_set.insert(field.as_ref().to_string());
81        }
82    }
83
84    /// Builder-style method to add protocol fields.
85    pub fn with_protocol_fields<I, S>(mut self, protocol: &str, fields: I) -> Self
86    where
87        I: IntoIterator<Item = S>,
88        S: AsRef<str>,
89    {
90        self.add_protocol_fields(protocol, fields);
91        self
92    }
93
94    /// Get the projection for a specific protocol.
95    ///
96    /// Returns None if no projection is configured for this protocol,
97    /// meaning all fields should be extracted.
98    pub fn get(&self, protocol: &str) -> Option<&HashSet<String>> {
99        self.protocol_fields.get(protocol)
100    }
101
102    /// Check if any projection is configured.
103    pub fn is_empty(&self) -> bool {
104        self.protocol_fields.is_empty()
105    }
106
107    /// Check if chain fields should be included.
108    pub fn include_chain_fields(&self) -> bool {
109        self.include_chain_fields
110    }
111
112    /// Get all protocol names with configured projections.
113    pub fn protocols(&self) -> impl Iterator<Item = &str> {
114        self.protocol_fields.keys().map(|s| s.as_str())
115    }
116
117    /// Create a projection config from DataFusion projection indices.
118    ///
119    /// Converts projection indices to field names based on schema.
120    ///
121    /// # Arguments
122    ///
123    /// * `protocol` - Protocol name
124    /// * `field_names` - Iterator of field names to include
125    pub fn from_field_names<I, S>(protocol: &str, field_names: I) -> Self
126    where
127        I: IntoIterator<Item = S>,
128        S: AsRef<str>,
129    {
130        let mut config = Self::new();
131        config.add_protocol_fields(protocol, field_names);
132        config
133    }
134}
135
/// Return the fields a protocol parser must always extract so that the next
/// protocol layer can be identified and parsed.
///
/// These names are added to a projection even when the query did not request
/// them. Protocols without an entry here (e.g. `"dns"`) yield an empty slice.
/// Matching is exact and case-sensitive.
pub fn chain_fields_for_protocol(protocol: &str) -> &'static [&'static str] {
    // Link-layer framings that hand off via an EtherType.
    const ETHERTYPE: &[&str] = &["ethertype"];
    // Transport layers that hand off via port numbers.
    const PORTS: &[&str] = &["src_port", "dst_port"];

    match protocol {
        "ethernet" | "vlan" => ETHERTYPE,
        "tcp" | "udp" => PORTS,
        "ipv4" => &["protocol", "src_ip", "dst_ip"],
        "ipv6" => &["next_header", "src_ip", "dst_ip"],
        "gre" => &["protocol_type"],
        "mpls" => &["bottom_of_stack"],
        "vxlan" => &["vni"],
        "gtp" => &["teid"],
        _ => &[],
    }
}
155
156/// Merge projection with chain fields if needed.
157///
158/// Returns a new set containing the projection fields plus any
159/// chain fields needed for protocol detection.
160pub fn merge_with_chain_fields(
161    protocol: &str,
162    projection: Option<&HashSet<String>>,
163    include_chain: bool,
164) -> Option<HashSet<String>> {
165    let projection = projection?;
166
167    if !include_chain {
168        return Some(projection.clone());
169    }
170
171    let chain_fields = chain_fields_for_protocol(protocol);
172    if chain_fields.is_empty() {
173        return Some(projection.clone());
174    }
175
176    let mut merged = projection.clone();
177    for field in chain_fields {
178        merged.insert((*field).to_string());
179    }
180
181    Some(merged)
182}
183
#[cfg(test)]
mod tests {
    use super::*;

    /// Sort a field set's names so assertions do not depend on hash order.
    fn sorted(fields: &HashSet<String>) -> Vec<&str> {
        let mut names: Vec<&str> = fields.iter().map(String::as_str).collect();
        names.sort_unstable();
        names
    }

    #[test]
    fn test_projection_config_new() {
        let config = ProjectionConfig::new();
        assert!(config.is_empty());
        assert!(config.include_chain_fields());
        assert_eq!(config.protocols().count(), 0);
    }

    #[test]
    fn test_add_protocol_fields() {
        let mut config = ProjectionConfig::new();
        config.add_protocol_fields("tcp", &["src_port", "dst_port"]);

        let fields = config.get("tcp").expect("tcp projection should be registered");
        assert!(fields.contains("src_port"));
        assert!(fields.contains("dst_port"));
        assert!(!fields.contains("seq"));
    }

    #[test]
    fn test_add_protocol_fields_accumulates() {
        let mut config = ProjectionConfig::new();
        config.add_protocol_fields("tcp", &["src_port"]);
        config.add_protocol_fields("tcp", vec![String::from("dst_port"), String::from("src_port")]);

        assert_eq!(sorted(config.get("tcp").unwrap()), vec!["dst_port", "src_port"]);
    }

    #[test]
    fn test_empty_field_list_registers_projection() {
        // An empty projection is distinct from "no projection": get() returns Some.
        let mut config = ProjectionConfig::new();
        config.add_protocol_fields("dns", std::iter::empty::<&str>());

        assert!(!config.is_empty());
        assert_eq!(config.get("dns").map(|fields| fields.len()), Some(0));
    }

    #[test]
    fn test_builder_pattern() {
        let config = ProjectionConfig::new()
            .with_protocol_fields("tcp", &["src_port", "dst_port"])
            .with_protocol_fields("udp", &["src_port", "dst_port", "length"]);

        assert!(!config.is_empty());
        assert!(config.get("tcp").is_some());
        assert!(config.get("udp").is_some());
        assert!(config.get("dns").is_none());
    }

    #[test]
    fn test_chain_field_toggles() {
        let config = ProjectionConfig::new().without_chain_fields();
        assert!(!config.include_chain_fields());

        let config = config.with_chain_fields();
        assert!(config.include_chain_fields());
    }

    #[test]
    fn test_protocols() {
        let config = ProjectionConfig::new()
            .with_protocol_fields("udp", &["length"])
            .with_protocol_fields("tcp", &["src_port"]);

        let mut names: Vec<&str> = config.protocols().collect();
        names.sort_unstable();
        assert_eq!(names, vec!["tcp", "udp"]);
    }

    #[test]
    fn test_chain_fields() {
        assert!(!chain_fields_for_protocol("ethernet").is_empty());
        assert!(!chain_fields_for_protocol("ipv4").is_empty());
        assert_eq!(
            chain_fields_for_protocol("tcp").to_vec(),
            vec!["src_port", "dst_port"]
        );
        assert!(chain_fields_for_protocol("dns").is_empty());
    }

    #[test]
    fn test_merge_with_chain_fields() {
        let projection: HashSet<String> = ["src_port"].iter().map(|s| s.to_string()).collect();

        // With chain fields enabled
        let merged = merge_with_chain_fields("tcp", Some(&projection), true)
            .expect("projection was provided");
        assert!(merged.contains("src_port"));
        assert!(merged.contains("dst_port")); // Chain field added

        // Without chain fields
        let merged = merge_with_chain_fields("tcp", Some(&projection), false)
            .expect("projection was provided");
        assert!(merged.contains("src_port"));
        assert!(!merged.contains("dst_port")); // Chain field NOT added
    }

    #[test]
    fn test_merge_without_projection_is_none() {
        // No projection means "extract everything"; merging must not invent one.
        assert!(merge_with_chain_fields("tcp", None, true).is_none());
        assert!(merge_with_chain_fields("tcp", None, false).is_none());
    }

    #[test]
    fn test_merge_protocol_without_chain_fields() {
        let projection: HashSet<String> = ["query_name"].iter().map(|s| s.to_string()).collect();

        let merged = merge_with_chain_fields("dns", Some(&projection), true);
        assert_eq!(merged, Some(projection));
    }

    #[test]
    fn test_from_field_names() {
        let config = ProjectionConfig::from_field_names("dns", &["query_name", "query_type"]);

        let fields = config.get("dns").expect("dns projection should be registered");
        assert_eq!(fields.len(), 2);
        assert!(fields.contains("query_name"));
        assert!(fields.contains("query_type"));
    }
}