Skip to main content

packet_strata/packet/
udp.rs

//! UDP (User Datagram Protocol) packet parser
//!
//! This module implements parsing for UDP datagrams as defined in RFC 768.
//! UDP provides a simple, connectionless transport service with no guarantees
//! of delivery, ordering, or duplicate protection.
//!
//! # UDP Header Format
//!
//! ```text
//!  0                   1                   2                   3
//!  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
//! +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//! |          Source Port          |       Destination Port        |
//! +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//! |            Length             |           Checksum            |
//! +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
//! ```
//!
//! # Key characteristics
//!
//! - Header size: 8 bytes (fixed; UDP has no options)
//! - Length field: total datagram length in bytes, header included (minimum 8)
//! - Checksum: optional in IPv4 (a value of 0 means "not computed"), mandatory in IPv6
//!
//! # Examples
//!
//! ## Basic UDP parsing
//!
//! ```
//! use packet_strata::packet::udp::UdpHeader;
//! use packet_strata::packet::HeaderParser;
//!
//! // UDP packet with DNS query
//! let packet = vec![
//!     0xC0, 0x00,        // Source port: 49152
//!     0x00, 0x35,        // Destination port: 53 (DNS)
//!     0x00, 0x10,        // Length: 16 bytes (8 header + 8 payload)
//!     0x00, 0x00,        // Checksum: 0 (not computed)
//!     // DNS payload follows...
//!     0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
//! ];
//!
//! let (header, payload) = UdpHeader::from_bytes(&packet).unwrap();
//! assert_eq!(header.src_port(), 49152);
//! assert_eq!(header.dst_port(), 53);
//! assert_eq!(header.length(), 16);
//! assert_eq!(payload.len(), 8);
//! ```
//!
//! ## UDP with well-known ports
//!
//! ```
//! use packet_strata::packet::udp::UdpHeader;
//! use packet_strata::packet::HeaderParser;
//!
//! // UDP packet for DHCP
//! let packet = vec![
//!     0x00, 0x44,        // Source port: 68 (DHCP client)
//!     0x00, 0x43,        // Destination port: 67 (DHCP server)
//!     0x00, 0x08,        // Length: 8 bytes (header only)
//!     0x12, 0x34,        // Checksum
//! ];
//!
//! let (header, _) = UdpHeader::from_bytes(&packet).unwrap();
//! assert_eq!(header.src_port(), 68);
//! assert_eq!(header.dst_port(), 67);
//! assert_eq!(header.checksum(), 0x1234);
//! ```
69
70use std::fmt::{self, Formatter};
71use std::mem;
72
73use zerocopy::byteorder::{BigEndian, U16};
74use zerocopy::{FromBytes, IntoBytes, Unaligned};
75
76use crate::packet::{HeaderParser, PacketHeader};
77
78/// UDP Header structure as defined in RFC 768
79///
80/// The UDP header is always 8 bytes and contains source port, destination port,
81/// length, and checksum fields.
82#[repr(C, packed)]
83#[derive(
84    FromBytes, IntoBytes, Unaligned, Debug, Clone, Copy, zerocopy::KnownLayout, zerocopy::Immutable,
85)]
86pub struct UdpHeader {
87    src_port: U16<BigEndian>,
88    dst_port: U16<BigEndian>,
89    length: U16<BigEndian>,
90    checksum: U16<BigEndian>,
91}
92
93impl UdpHeader {
94    /// Returns the source port number
95    #[inline]
96    pub fn src_port(&self) -> u16 {
97        self.src_port.get()
98    }
99
100    /// Returns the destination port number
101    #[inline]
102    pub fn dst_port(&self) -> u16 {
103        self.dst_port.get()
104    }
105
106    /// Returns the total length of the UDP datagram (header + data)
107    #[inline]
108    pub fn length(&self) -> u16 {
109        self.length.get()
110    }
111
112    /// Returns the checksum value
113    #[inline]
114    pub fn checksum(&self) -> u16 {
115        self.checksum.get()
116    }
117
118    /// Returns the length of the UDP header (always 8 bytes)
119    #[inline]
120    pub fn header_len(&self) -> usize {
121        mem::size_of::<UdpHeader>()
122    }
123
124    /// Returns the length of the payload data
125    #[inline]
126    pub fn payload_len(&self) -> usize {
127        let total = self.length() as usize;
128        total.saturating_sub(Self::FIXED_LEN)
129    }
130
131    /// Validates the UDP header
132    #[inline]
133    pub fn is_valid(&self) -> bool {
134        // UDP length must be at least 8 bytes (header size)
135        self.length() >= Self::FIXED_LEN as u16
136    }
137
138    /// Verify UDP checksum (requires pseudo-header from IP layer)
139    ///
140    /// Note: For IPv4, the checksum is optional (can be 0)
141    /// For IPv6, the checksum is mandatory
142    pub fn verify_checksum(&self, src_ip: u32, dst_ip: u32, udp_data: &[u8]) -> bool {
143        let checksum = self.checksum();
144
145        // Checksum of 0 means no checksum was computed (valid for IPv4)
146        if checksum == 0 {
147            return true;
148        }
149
150        let computed = Self::compute_checksum(src_ip, dst_ip, udp_data);
151        computed == checksum
152    }
153
154    /// Compute UDP checksum including pseudo-header
155    pub fn compute_checksum(src_ip: u32, dst_ip: u32, udp_data: &[u8]) -> u16 {
156        let mut sum: u32 = 0;
157
158        // Pseudo-header: source IP
159        sum += (src_ip >> 16) & 0xFFFF;
160        sum += src_ip & 0xFFFF;
161
162        // Pseudo-header: destination IP
163        sum += (dst_ip >> 16) & 0xFFFF;
164        sum += dst_ip & 0xFFFF;
165
166        // Pseudo-header: protocol (17 for UDP)
167        sum += 17;
168
169        // Pseudo-header: UDP length
170        sum += udp_data.len() as u32;
171
172        // UDP header and data
173        let mut i = 0;
174        while i < udp_data.len() {
175            if i + 1 < udp_data.len() {
176                let word = u16::from_be_bytes([udp_data[i], udp_data[i + 1]]);
177                sum += word as u32;
178                i += 2;
179            } else {
180                // Odd length: pad with zero
181                let word = u16::from_be_bytes([udp_data[i], 0]);
182                sum += word as u32;
183                i += 1;
184            }
185        }
186
187        // Fold 32-bit sum to 16 bits
188        while sum >> 16 != 0 {
189            sum = (sum & 0xFFFF) + (sum >> 16);
190        }
191
192        // One's complement
193        !sum as u16
194    }
195}
196
197impl PacketHeader for UdpHeader {
198    const NAME: &'static str = "UdpHeader";
199
200    #[inline]
201    fn is_valid(&self) -> bool {
202        self.is_valid()
203    }
204
205    type InnerType = ();
206
207    #[inline]
208    fn inner_type(&self) -> Self::InnerType {}
209}
210
211impl HeaderParser for UdpHeader {
212    type Output<'a> = &'a UdpHeader;
213
214    #[inline]
215    fn into_view<'a>(header: &'a Self, _: &'a [u8]) -> Self::Output<'a> {
216        header
217    }
218}
219
220impl fmt::Display for UdpHeader {
221    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
222        write!(
223            f,
224            "UDP {} -> {} len={}",
225            self.src_port(),
226            self.dst_port(),
227            self.length()
228        )
229    }
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_udp_header_basic() {
        let header = UdpHeader {
            src_port: U16::new(53),
            dst_port: U16::new(12345),
            length: U16::new(16), // 8 bytes header + 8 bytes payload
            checksum: U16::new(0),
        };

        assert_eq!(header.src_port(), 53);
        assert_eq!(header.dst_port(), 12345);
        assert_eq!(header.length(), 16);
        assert_eq!(header.header_len(), 8);
        assert_eq!(header.payload_len(), 8);
        assert!(header.is_valid());
    }

    #[test]
    fn test_udp_header_validation() {
        let invalid_header = UdpHeader {
            src_port: U16::new(53),
            dst_port: U16::new(12345),
            length: U16::new(7), // Too small
            checksum: U16::new(0),
        };

        assert!(!invalid_header.is_valid());

        let valid_header = UdpHeader {
            src_port: U16::new(53),
            dst_port: U16::new(12345),
            length: U16::new(8), // Minimum valid size
            checksum: U16::new(0),
        };

        assert!(valid_header.is_valid());
    }

    #[test]
    fn test_udp_checksum_zero() {
        let header = UdpHeader {
            src_port: U16::new(53),
            dst_port: U16::new(12345),
            length: U16::new(8),
            checksum: U16::new(0),
        };

        // Checksum of 0 should always be valid for IPv4
        assert!(header.verify_checksum(0x7f000001, 0x7f000001, &[]));
    }

    #[test]
    fn test_udp_header_size() {
        assert_eq!(mem::size_of::<UdpHeader>(), 8);
        assert_eq!(UdpHeader::FIXED_LEN, 8);
    }

    #[test]
    fn test_udp_parsing_basic() {
        let packet = create_test_packet();

        let result = UdpHeader::from_bytes(&packet);
        assert!(result.is_ok());

        let (header, payload) = result.unwrap();
        assert_eq!(header.src_port(), 12345);
        assert_eq!(header.dst_port(), 53);
        assert_eq!(header.length(), 16);
        assert_eq!(payload.len(), 8); // "DNS data" payload
        assert!(header.is_valid());
    }

    #[test]
    fn test_udp_parsing_too_small() {
        let packet = vec![0u8; 7]; // Only 7 bytes, need 8

        let result = UdpHeader::from_bytes(&packet);
        assert!(result.is_err());
    }

    #[test]
    fn test_udp_total_len() {
        let packet = create_test_packet();
        let (header, _) = UdpHeader::from_bytes(&packet).unwrap();

        // UDP header is always 8 bytes (no options like TCP)
        assert_eq!(header.total_len(&packet), 8);
    }

    #[test]
    fn test_udp_from_bytes_with_payload() {
        let mut packet = Vec::new();

        // UDP Header
        packet.extend_from_slice(&5000u16.to_be_bytes()); // Source port
        packet.extend_from_slice(&8080u16.to_be_bytes()); // Destination port

        let payload_data = b"Hello, UDP!";
        let total_length = 8 + payload_data.len();

        packet.extend_from_slice(&(total_length as u16).to_be_bytes()); // Length
        packet.extend_from_slice(&0u16.to_be_bytes()); // Checksum (not computed)

        // Add payload
        packet.extend_from_slice(payload_data);

        let result = UdpHeader::from_bytes(&packet);
        assert!(result.is_ok());

        let (header, payload) = result.unwrap();

        // Verify header fields
        assert_eq!(header.src_port(), 5000);
        assert_eq!(header.dst_port(), 8080);
        assert_eq!(header.length(), total_length as u16);
        assert_eq!(header.payload_len(), payload_data.len());

        // Verify payload separation
        assert_eq!(payload.len(), payload_data.len());
        assert_eq!(payload, payload_data);
    }

    #[test]
    fn test_udp_payload_length_calculation() {
        let header1 = UdpHeader {
            src_port: U16::new(1234),
            dst_port: U16::new(5678),
            length: U16::new(8), // Header only, no payload
            checksum: U16::new(0),
        };
        assert_eq!(header1.payload_len(), 0);

        let header2 = UdpHeader {
            src_port: U16::new(1234),
            dst_port: U16::new(5678),
            length: U16::new(100), // 8 bytes header + 92 bytes payload
            checksum: U16::new(0),
        };
        assert_eq!(header2.payload_len(), 92);

        // Invalid length (less than header size)
        let header3 = UdpHeader {
            src_port: U16::new(1234),
            dst_port: U16::new(5678),
            length: U16::new(5), // Invalid
            checksum: U16::new(0),
        };
        assert_eq!(header3.payload_len(), 0);
    }

    #[test]
    fn test_udp_dns_packet() {
        let mut packet = Vec::new();

        // UDP Header for DNS query
        packet.extend_from_slice(&54321u16.to_be_bytes()); // Source port (ephemeral)
        packet.extend_from_slice(&53u16.to_be_bytes()); // Destination port (DNS)

        // Simplified DNS query payload
        let dns_payload = vec![
            0xab, 0xcd, // Transaction ID
            0x01, 0x00, // Flags (standard query)
            0x00, 0x01, // Questions: 1
            0x00, 0x00, // Answer RRs: 0
            0x00, 0x00, // Authority RRs: 0
            0x00, 0x00, // Additional RRs: 0
        ];

        let total_length = 8 + dns_payload.len();
        packet.extend_from_slice(&(total_length as u16).to_be_bytes()); // Length
        packet.extend_from_slice(&0u16.to_be_bytes()); // Checksum

        // Add DNS payload
        packet.extend_from_slice(&dns_payload);

        let (header, payload) = UdpHeader::from_bytes(&packet).unwrap();

        assert_eq!(header.src_port(), 54321);
        assert_eq!(header.dst_port(), 53);
        assert_eq!(header.length(), total_length as u16);
        assert_eq!(payload.len(), dns_payload.len());
        assert_eq!(payload, dns_payload.as_slice());
    }

    #[test]
    fn test_udp_checksum_computation() {
        let src_ip = 0xC0A80101; // 192.168.1.1
        let dst_ip = 0xC0A80102; // 192.168.1.2

        let udp_packet = create_checksum_test_packet();
        let checksum = UdpHeader::compute_checksum(src_ip, dst_ip, &udp_packet);

        // Hand-computed RFC 1071 one's complement checksum for this datagram:
        // pseudo-header + header + "test" sums to 0x29BDE, folds to 0x9BE0,
        // and its complement is 0x641F.
        assert_eq!(checksum, 0x641F);
    }

    #[test]
    fn test_udp_verify_checksum_known_value() {
        let src_ip = 0xC0A80101; // 192.168.1.1
        let dst_ip = 0xC0A80102; // 192.168.1.2

        let header = UdpHeader {
            src_port: U16::new(12345),
            dst_port: U16::new(80),
            length: U16::new(12),
            checksum: U16::new(0x641F),
        };

        // Datagram with the checksum field cleared, as used for computation.
        let udp_packet = create_checksum_test_packet();
        assert!(header.verify_checksum(src_ip, dst_ip, &udp_packet));

        // A single flipped payload bit must be detected.
        let mut corrupted = udp_packet.clone();
        corrupted[11] ^= 0x01;
        assert!(!header.verify_checksum(src_ip, dst_ip, &corrupted));
    }

    #[test]
    fn test_udp_multiple_packets() {
        // Test parsing multiple different UDP packets
        let packets: Vec<(u16, u16, Vec<u8>)> = vec![
            (1234, 5678, b"payload1".to_vec()),
            (80, 54321, b"HTTP response".to_vec()),
            (53, 12345, b"DNS".to_vec()),
        ];

        for (src, dst, payload_data) in packets {
            let mut packet = Vec::new();
            packet.extend_from_slice(&src.to_be_bytes());
            packet.extend_from_slice(&dst.to_be_bytes());
            packet.extend_from_slice(&((8 + payload_data.len()) as u16).to_be_bytes());
            packet.extend_from_slice(&0u16.to_be_bytes());
            packet.extend_from_slice(&payload_data);

            let (header, payload) = UdpHeader::from_bytes(&packet).unwrap();
            assert_eq!(header.src_port(), src);
            assert_eq!(header.dst_port(), dst);
            assert_eq!(payload, payload_data.as_slice());
        }
    }

    // Helper function to create a test UDP packet
    fn create_test_packet() -> Vec<u8> {
        let mut packet = Vec::new();

        // Source port: 12345
        packet.extend_from_slice(&12345u16.to_be_bytes());

        // Destination port: 53 (DNS)
        packet.extend_from_slice(&53u16.to_be_bytes());

        // Length: 16 (8 bytes header + 8 bytes payload)
        packet.extend_from_slice(&16u16.to_be_bytes());

        // Checksum: 0
        packet.extend_from_slice(&0u16.to_be_bytes());

        // Payload: "DNS data"
        packet.extend_from_slice(b"DNS data");

        packet
    }

    // Helper: 12-byte datagram 12345 -> 80 carrying "test", checksum field zeroed.
    fn create_checksum_test_packet() -> Vec<u8> {
        let mut packet = Vec::new();
        packet.extend_from_slice(&12345u16.to_be_bytes()); // Source port
        packet.extend_from_slice(&80u16.to_be_bytes()); // Destination port
        packet.extend_from_slice(&12u16.to_be_bytes()); // Length
        packet.extend_from_slice(&0u16.to_be_bytes()); // Checksum (zero for computation)
        packet.extend_from_slice(b"test"); // 4 bytes of payload
        packet
    }
}