1use crate::checksum::{ChecksumMode, ChecksumState, TransportChecksumContext};
4use crate::ip::IpNextProtocol;
5use crate::packet::{MutablePacket, Packet};
6
7use crate::util;
8use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
9
10use bytes::{Buf, BufMut, Bytes, BytesMut};
11use nex_core::bitfield::u16be;
12#[cfg(feature = "serde")]
13use serde::{Deserialize, Serialize};
14
/// Size in bytes of the fixed UDP header: four big-endian 16-bit fields
/// (source port, destination port, length, checksum).
pub const UDP_HEADER_LEN: usize = 8;
17
/// Decoded UDP header fields, listed in wire order.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct UdpHeader {
    /// Source port.
    pub source: u16be,
    /// Destination port.
    pub destination: u16be,
    /// Length field: header plus payload, in bytes, as carried on the wire.
    pub length: u16be,
    /// Checksum field as carried on the wire; parsing does not verify it.
    pub checksum: u16be,
}
27
/// An owned UDP datagram: the decoded header plus its payload bytes.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct UdpPacket {
    /// Decoded header fields.
    pub header: UdpHeader,
    /// Bytes following the 8-byte header, limited to what the length field announces.
    pub payload: Bytes,
}
34
35impl Packet for UdpPacket {
36 type Header = UdpHeader;
37 fn from_buf(mut bytes: &[u8]) -> Option<Self> {
38 if bytes.len() < UDP_HEADER_LEN {
39 return None;
40 }
41
42 let source = bytes.get_u16();
43 let destination = bytes.get_u16();
44 let length = bytes.get_u16();
45 let checksum = bytes.get_u16();
46
47 if length < UDP_HEADER_LEN as u16 {
48 return None;
49 }
50
51 let payload_len = length as usize - UDP_HEADER_LEN;
52 if bytes.len() < payload_len {
53 return None;
54 }
55
56 let (payload_slice, _) = bytes.split_at(payload_len);
57
58 Some(UdpPacket {
59 header: UdpHeader {
60 source,
61 destination,
62 length,
63 checksum,
64 },
65 payload: Bytes::copy_from_slice(payload_slice),
66 })
67 }
68 fn from_bytes(mut bytes: Bytes) -> Option<Self> {
69 Self::from_buf(&mut bytes)
70 }
71 fn to_bytes(&self) -> Bytes {
72 let mut buf = BytesMut::with_capacity(UDP_HEADER_LEN + self.payload.len());
73 buf.put_u16(self.header.source);
74 buf.put_u16(self.header.destination);
75 buf.put_u16((UDP_HEADER_LEN + self.payload.len()) as u16);
76 buf.put_u16(self.header.checksum);
77 buf.extend_from_slice(&self.payload);
78 buf.freeze()
79 }
80 fn header(&self) -> Bytes {
81 let mut buf = BytesMut::with_capacity(UDP_HEADER_LEN);
82 buf.put_u16(self.header.source);
83 buf.put_u16(self.header.destination);
84 buf.put_u16(self.header.length);
85 buf.put_u16(self.header.checksum);
86 buf.freeze()
87 }
88
89 fn payload(&self) -> Bytes {
90 self.payload.clone()
91 }
92
93 fn header_len(&self) -> usize {
94 UDP_HEADER_LEN
95 }
96
97 fn payload_len(&self) -> usize {
98 self.payload.len()
99 }
100
101 fn total_len(&self) -> usize {
102 self.header_len() + self.payload_len()
103 }
104
105 fn into_parts(self) -> (Self::Header, Bytes) {
106 (self.header, self.payload)
107 }
108}
109
/// A UDP datagram edited in place inside a caller-supplied byte buffer.
///
/// Setters write straight into `buffer`; checksum upkeep is driven by the
/// `checksum` state (mode + dirty flag) and the optional pseudo-header context.
pub struct MutableUdpPacket<'a> {
    /// Backing storage; may be longer than the datagram's length field.
    buffer: &'a mut [u8],
    /// Checksum mode and dirty-tracking state.
    checksum: ChecksumState,
    /// Source/destination addresses used for checksum recomputation;
    /// `None` means the checksum cannot be recomputed.
    checksum_context: Option<TransportChecksumContext>,
}
116
117impl<'a> MutablePacket<'a> for MutableUdpPacket<'a> {
118 type Packet = UdpPacket;
119
120 fn new(buffer: &'a mut [u8]) -> Option<Self> {
121 if buffer.len() < UDP_HEADER_LEN {
122 return None;
123 }
124
125 let length = u16::from_be_bytes([buffer[4], buffer[5]]);
126 if length != 0 {
127 if length < UDP_HEADER_LEN as u16 {
128 return None;
129 }
130
131 if length as usize > buffer.len() {
132 return None;
133 }
134 }
135
136 Some(Self {
137 buffer,
138 checksum: ChecksumState::new(),
139 checksum_context: None,
140 })
141 }
142
143 fn packet(&self) -> &[u8] {
144 &*self.buffer
145 }
146
147 fn packet_mut(&mut self) -> &mut [u8] {
148 &mut *self.buffer
149 }
150
151 fn header(&self) -> &[u8] {
152 &self.packet()[..UDP_HEADER_LEN]
153 }
154
155 fn header_mut(&mut self) -> &mut [u8] {
156 let (header, _) = (&mut *self.buffer).split_at_mut(UDP_HEADER_LEN);
157 header
158 }
159
160 fn payload(&self) -> &[u8] {
161 let length = self.total_len();
162 &self.packet()[UDP_HEADER_LEN..length]
163 }
164
165 fn payload_mut(&mut self) -> &mut [u8] {
166 let total_len = self.total_len();
167 let (_, payload) = (&mut *self.buffer).split_at_mut(UDP_HEADER_LEN);
168 &mut payload[..total_len.saturating_sub(UDP_HEADER_LEN)]
169 }
170}
171
172impl<'a> MutableUdpPacket<'a> {
173 pub fn new_unchecked(buffer: &'a mut [u8]) -> Self {
175 Self {
176 buffer,
177 checksum: ChecksumState::new(),
178 checksum_context: None,
179 }
180 }
181
182 fn raw(&self) -> &[u8] {
183 &*self.buffer
184 }
185
186 fn raw_mut(&mut self) -> &mut [u8] {
187 &mut *self.buffer
188 }
189
190 fn after_field_mutation(&mut self) {
191 self.checksum.mark_dirty();
192 if self.checksum.automatic() {
193 let _ = self.recompute_checksum();
194 }
195 }
196
197 fn write_checksum(&mut self, checksum: u16) {
198 self.raw_mut()[6..8].copy_from_slice(&checksum.to_be_bytes());
199 }
200
201 pub fn checksum_mode(&self) -> ChecksumMode {
203 self.checksum.mode()
204 }
205
206 pub fn set_checksum_mode(&mut self, mode: ChecksumMode) {
208 self.checksum.set_mode(mode);
209 if self.checksum.automatic() && self.checksum.is_dirty() {
210 let _ = self.recompute_checksum();
211 }
212 }
213
214 pub fn enable_auto_checksum(&mut self) {
216 self.set_checksum_mode(ChecksumMode::Automatic);
217 }
218
219 pub fn disable_auto_checksum(&mut self) {
221 self.set_checksum_mode(ChecksumMode::Manual);
222 }
223
224 pub fn is_checksum_dirty(&self) -> bool {
226 self.checksum.is_dirty()
227 }
228
229 pub fn mark_checksum_dirty(&mut self) {
231 self.checksum.mark_dirty();
232 if self.checksum.automatic() {
233 let _ = self.recompute_checksum();
234 }
235 }
236
237 pub fn set_checksum_context(&mut self, context: TransportChecksumContext) {
239 self.checksum_context = Some(context);
240 if self.checksum.automatic() && self.checksum.is_dirty() {
241 let _ = self.recompute_checksum();
242 }
243 }
244
245 pub fn set_ipv4_checksum_context(&mut self, source: Ipv4Addr, destination: Ipv4Addr) {
247 self.set_checksum_context(TransportChecksumContext::ipv4(source, destination));
248 }
249
250 pub fn set_ipv6_checksum_context(&mut self, source: Ipv6Addr, destination: Ipv6Addr) {
252 self.set_checksum_context(TransportChecksumContext::ipv6(source, destination));
253 }
254
255 pub fn clear_checksum_context(&mut self) {
257 self.checksum_context = None;
258 }
259
260 pub fn checksum_context(&self) -> Option<TransportChecksumContext> {
262 self.checksum_context
263 }
264
265 pub fn recompute_checksum(&mut self) -> Option<u16> {
267 let context = self.checksum_context?;
268
269 let checksum = match context {
270 TransportChecksumContext::Ipv4 {
271 source,
272 destination,
273 } => util::ipv4_checksum(
274 self.raw(),
275 3,
276 &[],
277 &source,
278 &destination,
279 IpNextProtocol::Udp,
280 ) as u16,
281 TransportChecksumContext::Ipv6 {
282 source,
283 destination,
284 } => util::ipv6_checksum(
285 self.raw(),
286 3,
287 &[],
288 &source,
289 &destination,
290 IpNextProtocol::Udp,
291 ) as u16,
292 };
293
294 self.write_checksum(checksum);
295 self.checksum.clear_dirty();
296 Some(checksum)
297 }
298
299 pub fn total_len(&self) -> usize {
301 let field = u16::from_be_bytes([self.raw()[4], self.raw()[5]]);
302 if field == 0 {
303 self.raw().len()
304 } else {
305 field as usize
306 }
307 }
308
309 pub fn payload_len(&self) -> usize {
311 self.total_len().saturating_sub(UDP_HEADER_LEN)
312 }
313
314 pub fn get_source(&self) -> u16 {
315 u16::from_be_bytes([self.raw()[0], self.raw()[1]])
316 }
317
318 pub fn set_source(&mut self, port: u16) {
319 self.raw_mut()[0..2].copy_from_slice(&port.to_be_bytes());
320 self.after_field_mutation();
321 }
322
323 pub fn get_destination(&self) -> u16 {
324 u16::from_be_bytes([self.raw()[2], self.raw()[3]])
325 }
326
327 pub fn set_destination(&mut self, port: u16) {
328 self.raw_mut()[2..4].copy_from_slice(&port.to_be_bytes());
329 self.after_field_mutation();
330 }
331
332 pub fn get_length(&self) -> u16 {
333 u16::from_be_bytes([self.raw()[4], self.raw()[5]])
334 }
335
336 pub fn set_length(&mut self, length: u16) {
337 self.raw_mut()[4..6].copy_from_slice(&length.to_be_bytes());
338 self.after_field_mutation();
339 }
340
341 pub fn get_checksum(&self) -> u16 {
342 u16::from_be_bytes([self.raw()[6], self.raw()[7]])
343 }
344
345 pub fn set_checksum(&mut self, checksum: u16) {
346 self.write_checksum(checksum);
347 self.checksum.clear_dirty();
348 }
349}
350
351pub fn checksum(packet: &UdpPacket, source: &IpAddr, destination: &IpAddr) -> u16 {
352 match (source, destination) {
353 (IpAddr::V4(src), IpAddr::V4(dst)) => ipv4_checksum(packet, src, dst),
354 (IpAddr::V6(src), IpAddr::V6(dst)) => ipv6_checksum(packet, src, dst),
355 _ => 0, }
357}
358
359pub fn ipv4_checksum(packet: &UdpPacket, source: &Ipv4Addr, destination: &Ipv4Addr) -> u16be {
361 ipv4_checksum_adv(packet, &[], source, destination)
362}
363
364pub fn ipv4_checksum_adv(
372 packet: &UdpPacket,
373 extra_data: &[u8],
374 source: &Ipv4Addr,
375 destination: &Ipv4Addr,
376) -> u16be {
377 util::ipv4_checksum(
378 packet.to_bytes().as_ref(),
379 3,
380 extra_data,
381 source,
382 destination,
383 IpNextProtocol::Udp,
384 )
385}
386
387pub fn ipv6_checksum(packet: &UdpPacket, source: &Ipv6Addr, destination: &Ipv6Addr) -> u16be {
389 ipv6_checksum_adv(packet, &[], source, destination)
390}
391
392pub fn ipv6_checksum_adv(
400 packet: &UdpPacket,
401 extra_data: &[u8],
402 source: &Ipv6Addr,
403 destination: &Ipv6Addr,
404) -> u16be {
405 util::ipv6_checksum(
406 packet.to_bytes().as_ref(),
407 3,
408 extra_data,
409 source,
410 destination,
411 IpNextProtocol::Udp,
412 )
413}
414
#[cfg(test)]
mod tests {
    use super::*;

    /// 0x1234 -> 0xabcd, length 12, checksum 0x55aa, payload "data".
    static DATAGRAM: [u8; 12] = [
        0x12, 0x34, // source port
        0xab, 0xcd, // destination port
        0x00, 0x0c, // length: 8-byte header + 4-byte payload
        0x55, 0xaa, // checksum
        b'd', b'a', b't', b'a', // payload
    ];

    /// `DATAGRAM` with its checksum field zeroed.
    static DATAGRAM_NO_CHECKSUM: [u8; 12] = [
        0x12, 0x34, 0xab, 0xcd, 0x00, 0x0c, 0x00, 0x00, b'd', b'a', b't', b'a',
    ];

    #[test]
    fn test_basic_udp_parse() {
        let wire = Bytes::from_static(&DATAGRAM);
        let parsed = UdpPacket::from_bytes(wire.clone()).expect("Failed to parse UDP packet");

        let hdr = &parsed.header;
        assert_eq!(hdr.source, 0x1234);
        assert_eq!(hdr.destination, 0xabcd);
        assert_eq!(hdr.length, 12);
        assert_eq!(hdr.checksum, 0x55aa);
        assert_eq!(parsed.payload, Bytes::from_static(b"data"));
        assert_eq!(parsed.to_bytes(), wire);
    }

    #[test]
    fn test_basic_udp_create() {
        let body = Bytes::from_static(b"data");
        let built = UdpPacket {
            header: UdpHeader {
                source: 0x1234,
                destination: 0xabcd,
                length: (UDP_HEADER_LEN + body.len()) as u16,
                checksum: 0x55aa,
            },
            payload: body.clone(),
        };

        assert_eq!(built.to_bytes(), Bytes::from_static(&DATAGRAM));
        assert_eq!(built.payload(), body);
        assert_eq!(built.header_len(), UDP_HEADER_LEN);
    }

    #[test]
    fn test_mutable_udp_packet_updates_in_place() {
        // Buffer two bytes longer than the datagram's length field.
        let mut storage = [0u8; 14];
        storage[..DATAGRAM.len()].copy_from_slice(&DATAGRAM);

        let mut view = MutableUdpPacket::new(&mut storage).expect("mutable udp");
        assert_eq!(view.get_source(), 0x1234);
        view.set_source(0x4321);
        view.set_destination(0x0102);
        view.payload_mut()[0] = b'x';
        view.set_checksum(0xffff);

        let snapshot = view.freeze().expect("freeze");
        assert_eq!(snapshot.header.source, 0x4321);
        assert_eq!(snapshot.header.destination, 0x0102);
        assert_eq!(snapshot.header.checksum, 0xffff);
        assert_eq!(&storage[UDP_HEADER_LEN], &b'x');
    }

    #[test]
    fn test_udp_auto_checksum_with_context() {
        let mut storage = DATAGRAM_NO_CHECKSUM;
        let src = Ipv4Addr::new(192, 168, 0, 1);
        let dst = Ipv4Addr::new(192, 168, 0, 2);

        let mut view = MutableUdpPacket::new(&mut storage).expect("mutable udp");
        view.set_ipv4_checksum_context(src, dst);
        view.enable_auto_checksum();

        let before = view.recompute_checksum().expect("checksum");
        assert_eq!(before, view.get_checksum());

        // In automatic mode a header write refreshes the checksum immediately.
        view.set_destination(0xabce);
        let after = view.get_checksum();
        assert_ne!(before, after);
        assert!(!view.is_checksum_dirty());

        let snapshot = view.freeze().expect("freeze");
        let reference = ipv4_checksum(&snapshot, &src, &dst);
        assert_eq!(after, reference as u16);
    }

    #[test]
    fn test_udp_manual_checksum_tracking() {
        let mut storage = DATAGRAM_NO_CHECKSUM;
        let src = Ipv4Addr::new(10, 0, 0, 1);
        let dst = Ipv4Addr::new(10, 0, 0, 2);

        let mut view = MutableUdpPacket::new(&mut storage).expect("mutable udp");
        view.set_ipv4_checksum_context(src, dst);

        // Default mode leaves the checksum stale until recomputed explicitly.
        view.set_source(0x2222);
        assert!(view.is_checksum_dirty());

        let fresh = view.recompute_checksum().expect("checksum");
        assert_eq!(fresh, view.get_checksum());
        assert!(!view.is_checksum_dirty());
    }
}
537}