// pcapsql_core/protocol/udp.rs

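//! UDP protocol parser.
//!
//! Decodes the fixed UDP header (source port, destination port, length and
//! checksum) and passes the ports on as hints for child protocol parsers.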
2
3use std::collections::HashSet;
4
5use smallvec::SmallVec;
6
7use etherparse::UdpHeaderSlice;
8
9use super::{FieldValue, ParseContext, ParseResult, Protocol};
10use crate::schema::{DataKind, FieldDescriptor};
11
/// Value of the IP header's protocol field that identifies a UDP payload
/// (decimal 17).
pub const IP_PROTO_UDP: u8 = 0x11;
14
/// Stateless parser for the UDP transport header.
///
/// The type carries no data, so it is freely copyable and every instance is
/// interchangeable.
#[derive(Clone, Copy, Debug)]
pub struct UdpProtocol;
18
19impl Protocol for UdpProtocol {
20    fn name(&self) -> &'static str {
21        "udp"
22    }
23
24    fn display_name(&self) -> &'static str {
25        "UDP"
26    }
27
28    fn can_parse(&self, context: &ParseContext) -> Option<u32> {
29        match context.hint("ip_protocol") {
30            Some(proto) if proto == IP_PROTO_UDP as u64 => Some(100),
31            _ => None,
32        }
33    }
34
35    fn parse<'a>(&self, data: &'a [u8], _context: &ParseContext) -> ParseResult<'a> {
36        match UdpHeaderSlice::from_slice(data) {
37            Ok(udp) => {
38                let mut fields = SmallVec::new();
39
40                fields.push(("src_port", FieldValue::UInt16(udp.source_port())));
41                fields.push(("dst_port", FieldValue::UInt16(udp.destination_port())));
42                fields.push(("length", FieldValue::UInt16(udp.length())));
43                fields.push(("checksum", FieldValue::UInt16(udp.checksum())));
44
45                let mut child_hints = SmallVec::new();
46                child_hints.push(("src_port", udp.source_port() as u64));
47                child_hints.push(("dst_port", udp.destination_port() as u64));
48                child_hints.push(("transport", 17)); // UDP
49
50                // UDP header is always 8 bytes
51                ParseResult::success(fields, &data[8..], child_hints)
52            }
53            Err(e) => ParseResult::error(format!("UDP parse error: {e}"), data),
54        }
55    }
56
57    fn schema_fields(&self) -> Vec<FieldDescriptor> {
58        vec![
59            FieldDescriptor::new("udp.src_port", DataKind::UInt16).set_nullable(true),
60            FieldDescriptor::new("udp.dst_port", DataKind::UInt16).set_nullable(true),
61            FieldDescriptor::new("udp.length", DataKind::UInt16).set_nullable(true),
62            FieldDescriptor::new("udp.checksum", DataKind::UInt16).set_nullable(true),
63        ]
64    }
65
66    fn child_protocols(&self) -> &[&'static str] {
67        &["dns", "dhcp", "ntp"]
68    }
69
70    fn dependencies(&self) -> &'static [&'static str] {
71        &["ipv4", "ipv6"]
72    }
73
74    fn parse_projected<'a>(
75        &self,
76        data: &'a [u8],
77        _context: &ParseContext,
78        fields: Option<&HashSet<String>>,
79    ) -> ParseResult<'a> {
80        // If no projection, use full parse
81        let fields = match fields {
82            None => return self.parse(data, _context),
83            Some(f) if f.is_empty() => return self.parse(data, _context),
84            Some(f) => f,
85        };
86
87        match UdpHeaderSlice::from_slice(data) {
88            Ok(udp) => {
89                let mut result_fields = SmallVec::new();
90
91                // Always extract ports for child hints
92                let src_port = udp.source_port();
93                let dst_port = udp.destination_port();
94
95                // Only insert requested fields
96                if fields.contains("src_port") {
97                    result_fields.push(("src_port", FieldValue::UInt16(src_port)));
98                }
99                if fields.contains("dst_port") {
100                    result_fields.push(("dst_port", FieldValue::UInt16(dst_port)));
101                }
102                if fields.contains("length") {
103                    result_fields.push(("length", FieldValue::UInt16(udp.length())));
104                }
105                if fields.contains("checksum") {
106                    result_fields.push(("checksum", FieldValue::UInt16(udp.checksum())));
107                }
108
109                let mut child_hints = SmallVec::new();
110                child_hints.push(("src_port", src_port as u64));
111                child_hints.push(("dst_port", dst_port as u64));
112                child_hints.push(("transport", 17)); // UDP
113
114                // UDP header is always 8 bytes
115                ParseResult::success(result_fields, &data[8..], child_hints)
116            }
117            Err(e) => ParseResult::error(format!("UDP parse error: {e}"), data),
118        }
119    }
120
121    fn cheap_fields(&self) -> &'static [&'static str] {
122        // All UDP fields come from the fixed 8-byte header
123        &["src_port", "dst_port", "length", "checksum"]
124    }
125
126    fn expensive_fields(&self) -> &'static [&'static str] {
127        // UDP has no expensive fields
128        &[]
129    }
130}
131
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs a full parse of `data` with a context that carries the UDP
    /// `ip_protocol` hint.
    fn parse_udp(data: &[u8]) -> ParseResult<'_> {
        let mut context = ParseContext::new(1);
        context.insert_hint("ip_protocol", u64::from(IP_PROTO_UDP));
        UdpProtocol.parse(data, &context)
    }

    /// Runs a projected parse of `data` with a context that carries the UDP
    /// `ip_protocol` hint.
    fn parse_udp_projected<'a>(
        data: &'a [u8],
        fields: Option<&HashSet<String>>,
    ) -> ParseResult<'a> {
        let mut context = ParseContext::new(1);
        context.insert_hint("ip_protocol", u64::from(IP_PROTO_UDP));
        UdpProtocol.parse_projected(data, &context, fields)
    }

    /// Builds an owned projection set from field names.
    fn projection(names: &[&str]) -> HashSet<String> {
        names.iter().map(|name| (*name).to_owned()).collect()
    }

    #[test]
    fn test_parse_udp() {
        // UDP header (8 bytes) followed by 4 payload bytes
        let packet = [
            0x00, 0x35, // Src port: 53 (DNS)
            0xc0, 0x00, // Dst port: 49152
            0x00, 0x20, // Length: 32
            0x00, 0x00, // Checksum
            0xde, 0xad, 0xbe, 0xef, // Payload
        ];

        let result = parse_udp(&packet);

        assert!(result.is_ok());
        assert_eq!(result.get("src_port"), Some(&FieldValue::UInt16(53)));
        assert_eq!(result.get("dst_port"), Some(&FieldValue::UInt16(49152)));
        assert_eq!(result.get("length"), Some(&FieldValue::UInt16(32)));
        assert_eq!(result.get("checksum"), Some(&FieldValue::UInt16(0)));
        assert_eq!(result.remaining.len(), 4); // Payload bytes
    }

    #[test]
    fn test_parse_udp_dns_query() {
        let packet = [
            0xc3, 0x50, // Src port: 50000
            0x00, 0x35, // Dst port: 53 (DNS)
            0x00, 0x1c, // Length: 28
            0xab, 0xcd, // Checksum
            0x12, 0x34, 0x01, 0x00, // DNS query payload (simplified)
        ];

        let result = parse_udp(&packet);

        assert!(result.is_ok());
        assert_eq!(result.get("src_port"), Some(&FieldValue::UInt16(50000)));
        assert_eq!(result.get("dst_port"), Some(&FieldValue::UInt16(53)));
        assert_eq!(result.hint("dst_port"), Some(53u64));
    }

    #[test]
    fn test_parse_udp_dhcp() {
        let header = [
            0x00, 0x44, // Src port: 68 (DHCP client)
            0x00, 0x43, // Dst port: 67 (DHCP server)
            0x01, 0x00, // Length: 256
            0x00, 0x00, // Checksum
        ];

        let result = parse_udp(&header);

        assert!(result.is_ok());
        assert_eq!(result.get("src_port"), Some(&FieldValue::UInt16(68)));
        assert_eq!(result.get("dst_port"), Some(&FieldValue::UInt16(67)));
    }

    #[test]
    fn test_can_parse_udp() {
        let parser = UdpProtocol;

        // Without hint
        let no_hint = ParseContext::new(1);
        assert!(parser.can_parse(&no_hint).is_none());

        // With TCP protocol
        let mut tcp = ParseContext::new(1);
        tcp.insert_hint("ip_protocol", 6);
        assert!(parser.can_parse(&tcp).is_none());

        // With UDP protocol
        let mut udp = ParseContext::new(1);
        udp.insert_hint("ip_protocol", 17);
        assert!(parser.can_parse(&udp).is_some());
        assert_eq!(parser.can_parse(&udp), Some(100));
    }

    #[test]
    fn test_parse_udp_too_short() {
        let short_header = [0x00, 0x35, 0xc0, 0x00]; // Only 4 bytes

        let result = parse_udp(&short_header);

        assert!(!result.is_ok());
        assert!(result.error.is_some());
    }

    #[test]
    fn test_udp_child_hints() {
        let header = [
            0x12, 0x34, // Src port: 4660
            0x56, 0x78, // Dst port: 22136
            0x00, 0x10, // Length: 16
            0x00, 0x00, // Checksum
        ];

        let result = parse_udp(&header);

        assert!(result.is_ok());
        assert_eq!(result.hint("src_port"), Some(4660u64));
        assert_eq!(result.hint("dst_port"), Some(22136u64));
        assert_eq!(result.hint("transport"), Some(17u64));
    }

    #[test]
    fn test_udp_minimal_header() {
        // Exactly 8 bytes (minimum valid UDP)
        let header = [
            0x00, 0x50, // Src port: 80
            0x00, 0x51, // Dst port: 81
            0x00, 0x08, // Length: 8 (header only)
            0x00, 0x00, // Checksum
        ];

        let result = parse_udp(&header);

        assert!(result.is_ok());
        assert_eq!(result.get("length"), Some(&FieldValue::UInt16(8)));
        assert!(result.remaining.is_empty());
    }

    #[test]
    fn test_udp_projected_parsing_ports_only() {
        let header = [
            0x00, 0x35, // Src port: 53 (DNS)
            0xc0, 0x00, // Dst port: 49152
            0x00, 0x20, // Length: 32
            0xab, 0xcd, // Checksum
        ];

        // Project to only ports
        let fields = projection(&["src_port", "dst_port"]);
        let result = parse_udp_projected(&header, Some(&fields));

        assert!(result.is_ok());
        // Requested fields are present
        assert_eq!(result.get("src_port"), Some(&FieldValue::UInt16(53)));
        assert_eq!(result.get("dst_port"), Some(&FieldValue::UInt16(49152)));
        // Unrequested fields are NOT present
        assert!(result.get("length").is_none());
        assert!(result.get("checksum").is_none());
        // Child hints are still populated
        assert_eq!(result.hint("src_port"), Some(53u64));
        assert_eq!(result.hint("dst_port"), Some(49152u64));
        assert_eq!(result.hint("transport"), Some(17u64));
    }

    #[test]
    fn test_udp_projected_parsing_non_port_fields() {
        let header = [
            0x00, 0x35, // Src port: 53 (DNS)
            0xc0, 0x00, // Dst port: 49152
            0x00, 0x20, // Length: 32
            0xab, 0xcd, // Checksum: 43981
        ];

        let fields = projection(&["length", "checksum"]);
        let result = parse_udp_projected(&header, Some(&fields));

        assert!(result.is_ok());
        assert_eq!(result.get("length"), Some(&FieldValue::UInt16(32)));
        assert_eq!(result.get("checksum"), Some(&FieldValue::UInt16(0xabcd)));
        // Ports were not requested as fields...
        assert!(result.get("src_port").is_none());
        assert!(result.get("dst_port").is_none());
        // ...but are still forwarded as child hints
        assert_eq!(result.hint("src_port"), Some(53u64));
        assert_eq!(result.hint("dst_port"), Some(49152u64));
    }

    #[test]
    fn test_udp_projected_parsing_falls_back_to_full_parse() {
        let header = [
            0x00, 0x35, // Src port: 53 (DNS)
            0xc0, 0x00, // Dst port: 49152
            0x00, 0x20, // Length: 32
            0xab, 0xcd, // Checksum: 43981
        ];

        // Both "no projection" and "empty projection" mean "all fields".
        let empty: HashSet<String> = HashSet::new();
        for fields in [None, Some(&empty)] {
            let result = parse_udp_projected(&header, fields);

            assert!(result.is_ok());
            assert_eq!(result.get("src_port"), Some(&FieldValue::UInt16(53)));
            assert_eq!(result.get("dst_port"), Some(&FieldValue::UInt16(49152)));
            assert_eq!(result.get("length"), Some(&FieldValue::UInt16(32)));
            assert_eq!(result.get("checksum"), Some(&FieldValue::UInt16(0xabcd)));
        }
    }

    #[test]
    fn test_udp_projected_parsing_too_short() {
        let short_header = [0x00, 0x35, 0xc0, 0x00]; // Only 4 bytes

        let fields = projection(&["src_port"]);
        let result = parse_udp_projected(&short_header, Some(&fields));

        assert!(!result.is_ok());
        assert!(result.error.is_some());
    }
}