nex_packet/
udp.rs

1//! A UDP packet abstraction.
2
3use crate::checksum::{ChecksumMode, ChecksumState, TransportChecksumContext};
4use crate::ip::IpNextProtocol;
5use crate::packet::{MutablePacket, Packet};
6
7use crate::util;
8use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
9
10use bytes::{Buf, BufMut, Bytes, BytesMut};
11use nex_core::bitfield::u16be;
12#[cfg(feature = "serde")]
13use serde::{Deserialize, Serialize};
14
/// Length in bytes of the fixed UDP header: four 16-bit fields
/// (source port, destination port, length, checksum).
pub const UDP_HEADER_LEN: usize = 4 * std::mem::size_of::<u16>();
17
/// Represents the UDP header.
///
/// Every field is a 16-bit value stored as a plain integer; `UdpPacket`
/// decodes and encodes them big-endian (`get_u16` / `put_u16`).
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct UdpHeader {
    /// Source port.
    pub source: u16be,
    /// Destination port.
    pub destination: u16be,
    /// Length in bytes of the header plus payload; `UdpPacket::from_buf`
    /// rejects values smaller than `UDP_HEADER_LEN`.
    pub length: u16be,
    /// Checksum as carried in the header. Parsing does not validate it.
    pub checksum: u16be,
}
27
/// Represents a UDP Packet.
///
/// An owned, parsed view of a datagram: the decoded header plus the payload
/// bytes that follow it.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct UdpPacket {
    /// Decoded fixed-size header.
    pub header: UdpHeader,
    /// Bytes after the header. When parsed, this is limited to the span
    /// declared by `header.length`; any further input bytes are dropped.
    pub payload: Bytes,
}
34
35impl Packet for UdpPacket {
36    type Header = UdpHeader;
37    fn from_buf(mut bytes: &[u8]) -> Option<Self> {
38        if bytes.len() < UDP_HEADER_LEN {
39            return None;
40        }
41
42        let source = bytes.get_u16();
43        let destination = bytes.get_u16();
44        let length = bytes.get_u16();
45        let checksum = bytes.get_u16();
46
47        if length < UDP_HEADER_LEN as u16 {
48            return None;
49        }
50
51        let payload_len = length as usize - UDP_HEADER_LEN;
52        if bytes.len() < payload_len {
53            return None;
54        }
55
56        let (payload_slice, _) = bytes.split_at(payload_len);
57
58        Some(UdpPacket {
59            header: UdpHeader {
60                source,
61                destination,
62                length,
63                checksum,
64            },
65            payload: Bytes::copy_from_slice(payload_slice),
66        })
67    }
68    fn from_bytes(mut bytes: Bytes) -> Option<Self> {
69        Self::from_buf(&mut bytes)
70    }
71    fn to_bytes(&self) -> Bytes {
72        let mut buf = BytesMut::with_capacity(UDP_HEADER_LEN + self.payload.len());
73        buf.put_u16(self.header.source);
74        buf.put_u16(self.header.destination);
75        buf.put_u16((UDP_HEADER_LEN + self.payload.len()) as u16);
76        buf.put_u16(self.header.checksum);
77        buf.extend_from_slice(&self.payload);
78        buf.freeze()
79    }
80    fn header(&self) -> Bytes {
81        let mut buf = BytesMut::with_capacity(UDP_HEADER_LEN);
82        buf.put_u16(self.header.source);
83        buf.put_u16(self.header.destination);
84        buf.put_u16(self.header.length);
85        buf.put_u16(self.header.checksum);
86        buf.freeze()
87    }
88
89    fn payload(&self) -> Bytes {
90        self.payload.clone()
91    }
92
93    fn header_len(&self) -> usize {
94        UDP_HEADER_LEN
95    }
96
97    fn payload_len(&self) -> usize {
98        self.payload.len()
99    }
100
101    fn total_len(&self) -> usize {
102        self.header_len() + self.payload_len()
103    }
104
105    fn into_parts(self) -> (Self::Header, Bytes) {
106        (self.header, self.payload)
107    }
108}
109
/// Represents a mutable UDP packet.
///
/// Edits a caller-owned byte buffer in place and tracks whether the stored
/// checksum is stale.
pub struct MutableUdpPacket<'a> {
    /// Backing bytes; the header occupies the first `UDP_HEADER_LEN` bytes.
    buffer: &'a mut [u8],
    /// Checksum mode (manual/automatic) and dirty flag.
    checksum: ChecksumState,
    /// Pseudo-header addresses used by `recompute_checksum`. Without one,
    /// the checksum cannot be recomputed.
    checksum_context: Option<TransportChecksumContext>,
}
116
117impl<'a> MutablePacket<'a> for MutableUdpPacket<'a> {
118    type Packet = UdpPacket;
119
120    fn new(buffer: &'a mut [u8]) -> Option<Self> {
121        if buffer.len() < UDP_HEADER_LEN {
122            return None;
123        }
124
125        let length = u16::from_be_bytes([buffer[4], buffer[5]]);
126        if length != 0 {
127            if length < UDP_HEADER_LEN as u16 {
128                return None;
129            }
130
131            if length as usize > buffer.len() {
132                return None;
133            }
134        }
135
136        Some(Self {
137            buffer,
138            checksum: ChecksumState::new(),
139            checksum_context: None,
140        })
141    }
142
143    fn packet(&self) -> &[u8] {
144        &*self.buffer
145    }
146
147    fn packet_mut(&mut self) -> &mut [u8] {
148        &mut *self.buffer
149    }
150
151    fn header(&self) -> &[u8] {
152        &self.packet()[..UDP_HEADER_LEN]
153    }
154
155    fn header_mut(&mut self) -> &mut [u8] {
156        let (header, _) = (&mut *self.buffer).split_at_mut(UDP_HEADER_LEN);
157        header
158    }
159
160    fn payload(&self) -> &[u8] {
161        let length = self.total_len();
162        &self.packet()[UDP_HEADER_LEN..length]
163    }
164
165    fn payload_mut(&mut self) -> &mut [u8] {
166        let total_len = self.total_len();
167        let (_, payload) = (&mut *self.buffer).split_at_mut(UDP_HEADER_LEN);
168        &mut payload[..total_len.saturating_sub(UDP_HEADER_LEN)]
169    }
170}
171
172impl<'a> MutableUdpPacket<'a> {
173    /// Create a new packet without validating length fields.
174    pub fn new_unchecked(buffer: &'a mut [u8]) -> Self {
175        Self {
176            buffer,
177            checksum: ChecksumState::new(),
178            checksum_context: None,
179        }
180    }
181
182    fn raw(&self) -> &[u8] {
183        &*self.buffer
184    }
185
186    fn raw_mut(&mut self) -> &mut [u8] {
187        &mut *self.buffer
188    }
189
190    fn after_field_mutation(&mut self) {
191        self.checksum.mark_dirty();
192        if self.checksum.automatic() {
193            let _ = self.recompute_checksum();
194        }
195    }
196
197    fn write_checksum(&mut self, checksum: u16) {
198        self.raw_mut()[6..8].copy_from_slice(&checksum.to_be_bytes());
199    }
200
201    /// Returns the checksum recalculation mode.
202    pub fn checksum_mode(&self) -> ChecksumMode {
203        self.checksum.mode()
204    }
205
206    /// Sets the checksum recalculation mode.
207    pub fn set_checksum_mode(&mut self, mode: ChecksumMode) {
208        self.checksum.set_mode(mode);
209        if self.checksum.automatic() && self.checksum.is_dirty() {
210            let _ = self.recompute_checksum();
211        }
212    }
213
214    /// Enables automatic checksum recalculation when tracked fields change.
215    pub fn enable_auto_checksum(&mut self) {
216        self.set_checksum_mode(ChecksumMode::Automatic);
217    }
218
219    /// Disables automatic checksum recalculation.
220    pub fn disable_auto_checksum(&mut self) {
221        self.set_checksum_mode(ChecksumMode::Manual);
222    }
223
224    /// Returns true if the checksum needs to be recomputed.
225    pub fn is_checksum_dirty(&self) -> bool {
226        self.checksum.is_dirty()
227    }
228
229    /// Marks the checksum as stale and recomputes it when automatic mode is enabled.
230    pub fn mark_checksum_dirty(&mut self) {
231        self.checksum.mark_dirty();
232        if self.checksum.automatic() {
233            let _ = self.recompute_checksum();
234        }
235    }
236
237    /// Defines the pseudo-header context used when recomputing the checksum.
238    pub fn set_checksum_context(&mut self, context: TransportChecksumContext) {
239        self.checksum_context = Some(context);
240        if self.checksum.automatic() && self.checksum.is_dirty() {
241            let _ = self.recompute_checksum();
242        }
243    }
244
245    /// Sets an IPv4 pseudo-header context used for checksum recomputation.
246    pub fn set_ipv4_checksum_context(&mut self, source: Ipv4Addr, destination: Ipv4Addr) {
247        self.set_checksum_context(TransportChecksumContext::ipv4(source, destination));
248    }
249
250    /// Sets an IPv6 pseudo-header context used for checksum recomputation.
251    pub fn set_ipv6_checksum_context(&mut self, source: Ipv6Addr, destination: Ipv6Addr) {
252        self.set_checksum_context(TransportChecksumContext::ipv6(source, destination));
253    }
254
255    /// Clears the configured checksum pseudo-header context.
256    pub fn clear_checksum_context(&mut self) {
257        self.checksum_context = None;
258    }
259
260    /// Provides access to the configured checksum pseudo-header context.
261    pub fn checksum_context(&self) -> Option<TransportChecksumContext> {
262        self.checksum_context
263    }
264
265    /// Recomputes the UDP checksum if a pseudo-header context is available.
266    pub fn recompute_checksum(&mut self) -> Option<u16> {
267        let context = self.checksum_context?;
268
269        let checksum = match context {
270            TransportChecksumContext::Ipv4 {
271                source,
272                destination,
273            } => util::ipv4_checksum(
274                self.raw(),
275                3,
276                &[],
277                &source,
278                &destination,
279                IpNextProtocol::Udp,
280            ) as u16,
281            TransportChecksumContext::Ipv6 {
282                source,
283                destination,
284            } => util::ipv6_checksum(
285                self.raw(),
286                3,
287                &[],
288                &source,
289                &destination,
290                IpNextProtocol::Udp,
291            ) as u16,
292        };
293
294        self.write_checksum(checksum);
295        self.checksum.clear_dirty();
296        Some(checksum)
297    }
298
299    /// Returns the total length derived from the UDP length field.
300    pub fn total_len(&self) -> usize {
301        let field = u16::from_be_bytes([self.raw()[4], self.raw()[5]]);
302        if field == 0 {
303            self.raw().len()
304        } else {
305            field as usize
306        }
307    }
308
309    /// Returns the payload length.
310    pub fn payload_len(&self) -> usize {
311        self.total_len().saturating_sub(UDP_HEADER_LEN)
312    }
313
314    pub fn get_source(&self) -> u16 {
315        u16::from_be_bytes([self.raw()[0], self.raw()[1]])
316    }
317
318    pub fn set_source(&mut self, port: u16) {
319        self.raw_mut()[0..2].copy_from_slice(&port.to_be_bytes());
320        self.after_field_mutation();
321    }
322
323    pub fn get_destination(&self) -> u16 {
324        u16::from_be_bytes([self.raw()[2], self.raw()[3]])
325    }
326
327    pub fn set_destination(&mut self, port: u16) {
328        self.raw_mut()[2..4].copy_from_slice(&port.to_be_bytes());
329        self.after_field_mutation();
330    }
331
332    pub fn get_length(&self) -> u16 {
333        u16::from_be_bytes([self.raw()[4], self.raw()[5]])
334    }
335
336    pub fn set_length(&mut self, length: u16) {
337        self.raw_mut()[4..6].copy_from_slice(&length.to_be_bytes());
338        self.after_field_mutation();
339    }
340
341    pub fn get_checksum(&self) -> u16 {
342        u16::from_be_bytes([self.raw()[6], self.raw()[7]])
343    }
344
345    pub fn set_checksum(&mut self, checksum: u16) {
346        self.write_checksum(checksum);
347        self.checksum.clear_dirty();
348    }
349}
350
351pub fn checksum(packet: &UdpPacket, source: &IpAddr, destination: &IpAddr) -> u16 {
352    match (source, destination) {
353        (IpAddr::V4(src), IpAddr::V4(dst)) => ipv4_checksum(packet, src, dst),
354        (IpAddr::V6(src), IpAddr::V6(dst)) => ipv6_checksum(packet, src, dst),
355        _ => 0, // Unsupported IP version
356    }
357}
358
359/// Calculate a checksum for a packet built on IPv4.
360pub fn ipv4_checksum(packet: &UdpPacket, source: &Ipv4Addr, destination: &Ipv4Addr) -> u16be {
361    ipv4_checksum_adv(packet, &[], source, destination)
362}
363
364/// Calculate a checksum for a packet built on IPv4. Advanced version which
365/// accepts an extra slice of data that will be included in the checksum
366/// as being part of the data portion of the packet.
367///
368/// If `packet` contains an odd number of bytes the last byte will not be
369/// counted as the first byte of a word together with the first byte of
370/// `extra_data`.
371pub fn ipv4_checksum_adv(
372    packet: &UdpPacket,
373    extra_data: &[u8],
374    source: &Ipv4Addr,
375    destination: &Ipv4Addr,
376) -> u16be {
377    util::ipv4_checksum(
378        packet.to_bytes().as_ref(),
379        3,
380        extra_data,
381        source,
382        destination,
383        IpNextProtocol::Udp,
384    )
385}
386
387/// Calculate a checksum for a packet built on IPv6.
388pub fn ipv6_checksum(packet: &UdpPacket, source: &Ipv6Addr, destination: &Ipv6Addr) -> u16be {
389    ipv6_checksum_adv(packet, &[], source, destination)
390}
391
392/// Calculate the checksum for a packet built on IPv6. Advanced version which
393/// accepts an extra slice of data that will be included in the checksum
394/// as being part of the data portion of the packet.
395///
396/// If `packet` contains an odd number of bytes the last byte will not be
397/// counted as the first byte of a word together with the first byte of
398/// `extra_data`.
399pub fn ipv6_checksum_adv(
400    packet: &UdpPacket,
401    extra_data: &[u8],
402    source: &Ipv6Addr,
403    destination: &Ipv6Addr,
404) -> u16be {
405    util::ipv6_checksum(
406        packet.to_bytes().as_ref(),
407        3,
408        extra_data,
409        source,
410        destination,
411        IpNextProtocol::Udp,
412    )
413}
414
#[cfg(test)]
mod tests {
    use super::*;

    /// Source 0x1234, destination 0xabcd, length 12 (8 header + 4 payload).
    const PORTS_AND_LENGTH: [u8; 6] = [0x12, 0x34, 0xab, 0xcd, 0x00, 0x0c];
    /// Four-byte payload shared by every fixture.
    const PAYLOAD: &[u8; 4] = b"data";

    /// Builds the 12-byte fixture datagram with the given checksum bytes.
    fn datagram(checksum: [u8; 2]) -> [u8; 12] {
        let mut out = [0u8; 12];
        out[..6].copy_from_slice(&PORTS_AND_LENGTH);
        out[6..8].copy_from_slice(&checksum);
        out[8..].copy_from_slice(PAYLOAD);
        out
    }

    #[test]
    fn test_basic_udp_parse() {
        let wire = Bytes::copy_from_slice(&datagram([0x55, 0xaa]));
        let packet = UdpPacket::from_bytes(wire.clone()).expect("Failed to parse UDP packet");

        assert_eq!(packet.header.source, 0x1234);
        assert_eq!(packet.header.destination, 0xabcd);
        assert_eq!(packet.header.length, 12);
        assert_eq!(packet.header.checksum, 0x55aa);
        assert_eq!(packet.payload, Bytes::from_static(PAYLOAD));
        assert_eq!(packet.to_bytes(), wire);
    }

    #[test]
    fn test_basic_udp_create() {
        let payload = Bytes::from_static(PAYLOAD);
        let header = UdpHeader {
            source: 0x1234,
            destination: 0xabcd,
            length: (UDP_HEADER_LEN + payload.len()) as u16,
            checksum: 0x55aa,
        };
        let packet = UdpPacket {
            header,
            payload: payload.clone(),
        };

        let expected = Bytes::copy_from_slice(&datagram([0x55, 0xaa]));

        assert_eq!(packet.to_bytes(), expected);
        assert_eq!(packet.payload(), payload);
        assert_eq!(packet.header_len(), UDP_HEADER_LEN);
    }

    #[test]
    fn test_mutable_udp_packet_updates_in_place() {
        // Two spare bytes after the datagram: the buffer is larger than the
        // length field declares.
        let mut raw = [0u8; 14];
        raw[..12].copy_from_slice(&datagram([0x55, 0xaa]));

        let mut packet = MutableUdpPacket::new(&mut raw).expect("mutable udp");
        assert_eq!(packet.get_source(), 0x1234);
        packet.set_source(0x4321);
        packet.set_destination(0x0102);
        packet.payload_mut()[0] = b'x';
        packet.set_checksum(0xffff);

        let frozen = packet.freeze().expect("freeze");
        assert_eq!(frozen.header.source, 0x4321);
        assert_eq!(frozen.header.destination, 0x0102);
        assert_eq!(frozen.header.checksum, 0xffff);
        // The edit went straight into the caller's buffer.
        assert_eq!(&raw[UDP_HEADER_LEN], &b'x');
    }

    #[test]
    fn test_udp_auto_checksum_with_context() {
        let mut raw = datagram([0x00, 0x00]);
        let src = Ipv4Addr::new(192, 168, 0, 1);
        let dst = Ipv4Addr::new(192, 168, 0, 2);

        let mut packet = MutableUdpPacket::new(&mut raw).expect("mutable udp");
        packet.set_ipv4_checksum_context(src, dst);
        packet.enable_auto_checksum();

        let baseline = packet.recompute_checksum().expect("checksum");
        assert_eq!(baseline, packet.get_checksum());

        // In automatic mode a field change recomputes the checksum at once.
        packet.set_destination(0xabce);
        let updated = packet.get_checksum();
        assert_ne!(baseline, updated);
        assert!(!packet.is_checksum_dirty());

        let frozen = packet.freeze().expect("freeze");
        let expected = ipv4_checksum(&frozen, &src, &dst);
        assert_eq!(updated, expected as u16);
    }

    #[test]
    fn test_udp_manual_checksum_tracking() {
        let mut raw = datagram([0x00, 0x00]);
        let src = Ipv4Addr::new(10, 0, 0, 1);
        let dst = Ipv4Addr::new(10, 0, 0, 2);

        let mut packet = MutableUdpPacket::new(&mut raw).expect("mutable udp");
        packet.set_ipv4_checksum_context(src, dst);

        // Manual mode: the change only flags the checksum as stale.
        packet.set_source(0x2222);
        assert!(packet.is_checksum_dirty());

        let recomputed = packet.recompute_checksum().expect("checksum");
        assert_eq!(recomputed, packet.get_checksum());
        assert!(!packet.is_checksum_dirty());
    }
}