pcapsql_core/protocol/
projection.rs1use std::collections::{HashMap, HashSet};
23
/// Per-protocol selection of field names, plus a flag controlling whether
/// chain fields are kept alongside the requested ones.
///
/// Instances can be built with [`ProjectionConfig::new`], `Default::default()`,
/// or [`ProjectionConfig::from_field_names`], and then refined with the
/// `with_*` builder methods or [`ProjectionConfig::add_protocol_fields`].
#[derive(Debug, Clone)]
pub struct ProjectionConfig {
    /// Requested field names, keyed by protocol name.
    protocol_fields: HashMap<String, HashSet<String>>,

    /// Whether chain fields should be included in addition to the requested
    /// fields. `true` for freshly constructed configs.
    include_chain_fields: bool,
}

impl Default for ProjectionConfig {
    /// Equivalent to [`ProjectionConfig::new`].
    ///
    /// Implemented by hand because a derived `Default` would initialise
    /// `include_chain_fields` to `false`, disagreeing with `new()`.
    fn default() -> Self {
        Self::new()
    }
}

impl ProjectionConfig {
    /// Creates a config with no protocol fields selected and chain-field
    /// inclusion enabled.
    pub fn new() -> Self {
        Self {
            protocol_fields: HashMap::new(),
            include_chain_fields: true,
        }
    }

    /// Builder-style: enables chain-field inclusion and returns the config.
    pub fn with_chain_fields(mut self) -> Self {
        self.include_chain_fields = true;
        self
    }

    /// Builder-style: disables chain-field inclusion and returns the config.
    pub fn without_chain_fields(mut self) -> Self {
        self.include_chain_fields = false;
        self
    }

    /// Adds `fields` to the set selected for `protocol`.
    ///
    /// Fields accumulate across calls for the same protocol and duplicates are
    /// collapsed. An entry for `protocol` is created even when `fields` is
    /// empty, so the config stops being [`is_empty`](Self::is_empty) afterwards.
    pub fn add_protocol_fields<I, S>(&mut self, protocol: &str, fields: I)
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        self.protocol_fields
            .entry(protocol.to_string())
            .or_default()
            .extend(fields.into_iter().map(|field| field.as_ref().to_string()));
    }

    /// Builder-style variant of [`add_protocol_fields`](Self::add_protocol_fields).
    pub fn with_protocol_fields<I, S>(mut self, protocol: &str, fields: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        self.add_protocol_fields(protocol, fields);
        self
    }

    /// Returns the field set selected for `protocol`, or `None` when nothing
    /// has been registered under that exact name.
    pub fn get(&self, protocol: &str) -> Option<&HashSet<String>> {
        self.protocol_fields.get(protocol)
    }

    /// Returns `true` when no protocol has an entry in this config.
    pub fn is_empty(&self) -> bool {
        self.protocol_fields.is_empty()
    }

    /// Returns whether chain fields are to be included.
    pub fn include_chain_fields(&self) -> bool {
        self.include_chain_fields
    }

    /// Iterates over the names of all protocols that have an entry.
    ///
    /// The order is unspecified because the names are stored as `HashMap` keys.
    pub fn protocols(&self) -> impl Iterator<Item = &str> {
        self.protocol_fields.keys().map(String::as_str)
    }

    /// Creates a config that selects `field_names` for a single `protocol`,
    /// with chain-field inclusion enabled (as in [`ProjectionConfig::new`]).
    pub fn from_field_names<I, S>(protocol: &str, field_names: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        let mut config = Self::new();
        config.add_protocol_fields(protocol, field_names);
        config
    }
}
135
/// Returns the chain field names registered for `protocol`.
///
/// The lookup is an exact, case-sensitive match on the protocol name. Any name
/// without an entry in the table yields an empty slice.
pub fn chain_fields_for_protocol(protocol: &str) -> &'static [&'static str] {
    // Field lists shared by more than one protocol.
    const ETHERTYPE: &[&str] = &["ethertype"];
    const PORTS: &[&str] = &["src_port", "dst_port"];

    match protocol {
        "ethernet" | "vlan" => ETHERTYPE,
        "tcp" | "udp" => PORTS,
        "ipv4" => &["protocol", "src_ip", "dst_ip"],
        "ipv6" => &["next_header", "src_ip", "dst_ip"],
        "gre" => &["protocol_type"],
        "mpls" => &["bottom_of_stack"],
        "vxlan" => &["vni"],
        "gtp" => &["teid"],
        _ => &[],
    }
}
155
156pub fn merge_with_chain_fields(
161 protocol: &str,
162 projection: Option<&HashSet<String>>,
163 include_chain: bool,
164) -> Option<HashSet<String>> {
165 let projection = projection?;
166
167 if !include_chain {
168 return Some(projection.clone());
169 }
170
171 let chain_fields = chain_fields_for_protocol(protocol);
172 if chain_fields.is_empty() {
173 return Some(projection.clone());
174 }
175
176 let mut merged = projection.clone();
177 for field in chain_fields {
178 merged.insert((*field).to_string());
179 }
180
181 Some(merged)
182}
183
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built config selects nothing and keeps chain fields enabled.
    #[test]
    fn test_projection_config_new() {
        let fresh = ProjectionConfig::new();

        assert!(fresh.is_empty());
        assert!(fresh.include_chain_fields());
    }

    /// Fields added for a protocol can be looked up again; others are absent.
    #[test]
    fn test_add_protocol_fields() {
        let mut config = ProjectionConfig::new();
        config.add_protocol_fields("tcp", &["src_port", "dst_port"]);

        let tcp = config
            .get("tcp")
            .expect("tcp fields should be registered");
        for wanted in &["src_port", "dst_port"] {
            assert!(tcp.contains(*wanted));
        }
        assert!(!tcp.contains("seq"));
    }

    /// Chained `with_protocol_fields` calls register every named protocol.
    #[test]
    fn test_builder_pattern() {
        let config = ProjectionConfig::new()
            .with_protocol_fields("tcp", &["src_port", "dst_port"])
            .with_protocol_fields("udp", &["src_port", "dst_port", "length"]);

        assert!(!config.is_empty());
        for registered in &["tcp", "udp"] {
            assert!(config.get(registered).is_some());
        }
        assert!(config.get("dns").is_none());
    }

    /// Protocols listed in the table have chain fields; `dns` has none.
    #[test]
    fn test_chain_fields() {
        for proto in &["ethernet", "ipv4", "tcp"] {
            assert!(!chain_fields_for_protocol(proto).is_empty());
        }
        assert!(chain_fields_for_protocol("dns").is_empty());
    }

    /// Merging adds chain fields only when inclusion is requested.
    #[test]
    fn test_merge_with_chain_fields() {
        let mut projection: HashSet<String> = HashSet::new();
        projection.insert("src_port".to_string());

        // Chain fields enabled: dst_port is added on top of the request.
        let with_chain = merge_with_chain_fields("tcp", Some(&projection), true)
            .expect("a projection was supplied");
        assert!(with_chain.contains("src_port"));
        assert!(with_chain.contains("dst_port"));

        // Chain fields disabled: only the requested field survives.
        let without_chain = merge_with_chain_fields("tcp", Some(&projection), false)
            .expect("a projection was supplied");
        assert!(without_chain.contains("src_port"));
        assert!(!without_chain.contains("dst_port"));
    }

    /// `from_field_names` registers exactly the given fields for the protocol.
    #[test]
    fn test_from_field_names() {
        let config = ProjectionConfig::from_field_names("dns", &["query_name", "query_type"]);

        let dns = config.get("dns").unwrap();
        assert_eq!(dns.len(), 2);
        for wanted in &["query_name", "query_type"] {
            assert!(dns.contains(*wanted));
        }
    }
}