1use std::collections::HashSet;
4
5use smallvec::SmallVec;
6
7use etherparse::UdpHeaderSlice;
8
9use super::{FieldValue, ParseContext, ParseResult, Protocol};
10use crate::schema::{DataKind, FieldDescriptor};
11
/// IP protocol number that identifies UDP (17 in the IANA protocol-number registry).
///
/// `UdpProtocol::can_parse` compares the `ip_protocol` context hint against this value.
pub const IP_PROTO_UDP: u8 = 17;
14
/// Parser handle for the UDP layer.
///
/// This is a unit struct: it carries no state, so it is free to copy and can be
/// constructed simply by naming it (`UdpProtocol`).
#[derive(Debug, Clone, Copy)]
pub struct UdpProtocol;
18
19impl Protocol for UdpProtocol {
20 fn name(&self) -> &'static str {
21 "udp"
22 }
23
24 fn display_name(&self) -> &'static str {
25 "UDP"
26 }
27
28 fn can_parse(&self, context: &ParseContext) -> Option<u32> {
29 match context.hint("ip_protocol") {
30 Some(proto) if proto == IP_PROTO_UDP as u64 => Some(100),
31 _ => None,
32 }
33 }
34
35 fn parse<'a>(&self, data: &'a [u8], _context: &ParseContext) -> ParseResult<'a> {
36 match UdpHeaderSlice::from_slice(data) {
37 Ok(udp) => {
38 let mut fields = SmallVec::new();
39
40 fields.push(("src_port", FieldValue::UInt16(udp.source_port())));
41 fields.push(("dst_port", FieldValue::UInt16(udp.destination_port())));
42 fields.push(("length", FieldValue::UInt16(udp.length())));
43 fields.push(("checksum", FieldValue::UInt16(udp.checksum())));
44
45 let mut child_hints = SmallVec::new();
46 child_hints.push(("src_port", udp.source_port() as u64));
47 child_hints.push(("dst_port", udp.destination_port() as u64));
48 child_hints.push(("transport", 17)); ParseResult::success(fields, &data[8..], child_hints)
52 }
53 Err(e) => ParseResult::error(format!("UDP parse error: {e}"), data),
54 }
55 }
56
57 fn schema_fields(&self) -> Vec<FieldDescriptor> {
58 vec![
59 FieldDescriptor::new("udp.src_port", DataKind::UInt16).set_nullable(true),
60 FieldDescriptor::new("udp.dst_port", DataKind::UInt16).set_nullable(true),
61 FieldDescriptor::new("udp.length", DataKind::UInt16).set_nullable(true),
62 FieldDescriptor::new("udp.checksum", DataKind::UInt16).set_nullable(true),
63 ]
64 }
65
66 fn child_protocols(&self) -> &[&'static str] {
67 &["dns", "dhcp", "ntp"]
68 }
69
70 fn dependencies(&self) -> &'static [&'static str] {
71 &["ipv4", "ipv6"]
72 }
73
74 fn parse_projected<'a>(
75 &self,
76 data: &'a [u8],
77 _context: &ParseContext,
78 fields: Option<&HashSet<String>>,
79 ) -> ParseResult<'a> {
80 let fields = match fields {
82 None => return self.parse(data, _context),
83 Some(f) if f.is_empty() => return self.parse(data, _context),
84 Some(f) => f,
85 };
86
87 match UdpHeaderSlice::from_slice(data) {
88 Ok(udp) => {
89 let mut result_fields = SmallVec::new();
90
91 let src_port = udp.source_port();
93 let dst_port = udp.destination_port();
94
95 if fields.contains("src_port") {
97 result_fields.push(("src_port", FieldValue::UInt16(src_port)));
98 }
99 if fields.contains("dst_port") {
100 result_fields.push(("dst_port", FieldValue::UInt16(dst_port)));
101 }
102 if fields.contains("length") {
103 result_fields.push(("length", FieldValue::UInt16(udp.length())));
104 }
105 if fields.contains("checksum") {
106 result_fields.push(("checksum", FieldValue::UInt16(udp.checksum())));
107 }
108
109 let mut child_hints = SmallVec::new();
110 child_hints.push(("src_port", src_port as u64));
111 child_hints.push(("dst_port", dst_port as u64));
112 child_hints.push(("transport", 17)); ParseResult::success(result_fields, &data[8..], child_hints)
116 }
117 Err(e) => ParseResult::error(format!("UDP parse error: {e}"), data),
118 }
119 }
120
121 fn cheap_fields(&self) -> &'static [&'static str] {
122 &["src_port", "dst_port", "length", "checksum"]
124 }
125
126 fn expensive_fields(&self) -> &'static [&'static str] {
127 &[]
129 }
130}
131
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the context used by the parsing tests: a fresh `ParseContext`
    /// carrying the `ip_protocol = 17` hint that `can_parse` checks for.
    fn udp_context() -> ParseContext {
        let mut context = ParseContext::new(1);
        context.insert_hint("ip_protocol", 17);
        context
    }

    #[test]
    fn test_parse_udp() {
        // src 53, dst 49152, length 32, checksum 0, followed by 4 payload bytes.
        let packet = [
            0x00, 0x35, 0xc0, 0x00, 0x00, 0x20, 0x00, 0x00, 0xde, 0xad, 0xbe, 0xef,
        ];

        let result = UdpProtocol.parse(&packet, &udp_context());

        assert!(result.is_ok());
        assert_eq!(result.get("src_port"), Some(&FieldValue::UInt16(53)));
        assert_eq!(result.get("dst_port"), Some(&FieldValue::UInt16(49152)));
        assert_eq!(result.get("length"), Some(&FieldValue::UInt16(32)));
        assert_eq!(result.remaining.len(), 4);
    }

    #[test]
    fn test_parse_udp_dns_query() {
        // src 50000, dst 53, length 28, checksum 0xabcd, then 4 payload bytes.
        let packet = [
            0xc3, 0x50, 0x00, 0x35, 0x00, 0x1c, 0xab, 0xcd, 0x12, 0x34, 0x01, 0x00,
        ];

        let result = UdpProtocol.parse(&packet, &udp_context());

        assert!(result.is_ok());
        assert_eq!(result.get("src_port"), Some(&FieldValue::UInt16(50000)));
        assert_eq!(result.get("dst_port"), Some(&FieldValue::UInt16(53)));
        assert_eq!(result.hint("dst_port"), Some(53u64));
    }

    #[test]
    fn test_parse_udp_dhcp() {
        // src 68, dst 67, header only.
        let packet = [0x00, 0x44, 0x00, 0x43, 0x01, 0x00, 0x00, 0x00];

        let result = UdpProtocol.parse(&packet, &udp_context());

        assert!(result.is_ok());
        assert_eq!(result.get("src_port"), Some(&FieldValue::UInt16(68)));
        assert_eq!(result.get("dst_port"), Some(&FieldValue::UInt16(67)));
    }

    #[test]
    fn test_can_parse_udp() {
        let parser = UdpProtocol;

        // No `ip_protocol` hint at all.
        let without_hint = ParseContext::new(1);
        assert!(parser.can_parse(&without_hint).is_none());

        // Hint present but naming a different protocol number.
        let mut other_protocol = ParseContext::new(1);
        other_protocol.insert_hint("ip_protocol", 6);
        assert!(parser.can_parse(&other_protocol).is_none());

        // Hint matches the UDP protocol number.
        let mut udp = ParseContext::new(1);
        udp.insert_hint("ip_protocol", 17);
        assert!(parser.can_parse(&udp).is_some());
    }

    #[test]
    fn test_parse_udp_too_short() {
        // Only 4 of the 8 header bytes are present.
        let truncated = [0x00, 0x35, 0xc0, 0x00];

        let result = UdpProtocol.parse(&truncated, &udp_context());

        assert!(!result.is_ok());
        assert!(result.error.is_some());
    }

    #[test]
    fn test_udp_child_hints() {
        // src 0x1234 = 4660, dst 0x5678 = 22136.
        let packet = [0x12, 0x34, 0x56, 0x78, 0x00, 0x10, 0x00, 0x00];

        let result = UdpProtocol.parse(&packet, &udp_context());

        assert!(result.is_ok());
        assert_eq!(result.hint("src_port"), Some(4660u64));
        assert_eq!(result.hint("dst_port"), Some(22136u64));
        assert_eq!(result.hint("transport"), Some(17u64));
    }

    #[test]
    fn test_udp_minimal_header() {
        // Length field of 8: header only, no payload.
        let packet = [0x00, 0x50, 0x00, 0x51, 0x00, 0x08, 0x00, 0x00];

        let result = UdpProtocol.parse(&packet, &udp_context());

        assert!(result.is_ok());
        assert_eq!(result.get("length"), Some(&FieldValue::UInt16(8)));
        assert!(result.remaining.is_empty());
    }

    #[test]
    fn test_udp_projected_parsing_ports_only() {
        let packet = [0x00, 0x35, 0xc0, 0x00, 0x00, 0x20, 0xab, 0xcd];

        let projection: HashSet<String> = ["src_port", "dst_port"]
            .iter()
            .map(|s| s.to_string())
            .collect();
        let result = UdpProtocol.parse_projected(&packet, &udp_context(), Some(&projection));

        assert!(result.is_ok());
        // Requested fields are present...
        assert_eq!(result.get("src_port"), Some(&FieldValue::UInt16(53)));
        assert_eq!(result.get("dst_port"), Some(&FieldValue::UInt16(49152)));
        // ...fields outside the projection are omitted...
        assert!(result.get("length").is_none());
        assert!(result.get("checksum").is_none());
        // ...and the child hints are emitted regardless of the projection.
        assert_eq!(result.hint("src_port"), Some(53u64));
        assert_eq!(result.hint("dst_port"), Some(49152u64));
    }
}