1use std::fmt::{self, Formatter};
71use std::mem;
72
73use zerocopy::byteorder::{BigEndian, U16};
74use zerocopy::{FromBytes, IntoBytes, Unaligned};
75
76use crate::packet::{HeaderParser, PacketHeader};
77
78#[repr(C, packed)]
83#[derive(
84 FromBytes, IntoBytes, Unaligned, Debug, Clone, Copy, zerocopy::KnownLayout, zerocopy::Immutable,
85)]
86pub struct UdpHeader {
87 src_port: U16<BigEndian>,
88 dst_port: U16<BigEndian>,
89 length: U16<BigEndian>,
90 checksum: U16<BigEndian>,
91}
92
93impl UdpHeader {
94 #[inline]
96 pub fn src_port(&self) -> u16 {
97 self.src_port.get()
98 }
99
100 #[inline]
102 pub fn dst_port(&self) -> u16 {
103 self.dst_port.get()
104 }
105
106 #[inline]
108 pub fn length(&self) -> u16 {
109 self.length.get()
110 }
111
112 #[inline]
114 pub fn checksum(&self) -> u16 {
115 self.checksum.get()
116 }
117
118 #[inline]
120 pub fn header_len(&self) -> usize {
121 mem::size_of::<UdpHeader>()
122 }
123
124 #[inline]
126 pub fn payload_len(&self) -> usize {
127 let total = self.length() as usize;
128 total.saturating_sub(Self::FIXED_LEN)
129 }
130
131 #[inline]
133 pub fn is_valid(&self) -> bool {
134 self.length() >= Self::FIXED_LEN as u16
136 }
137
138 pub fn verify_checksum(&self, src_ip: u32, dst_ip: u32, udp_data: &[u8]) -> bool {
143 let checksum = self.checksum();
144
145 if checksum == 0 {
147 return true;
148 }
149
150 let computed = Self::compute_checksum(src_ip, dst_ip, udp_data);
151 computed == checksum
152 }
153
154 pub fn compute_checksum(src_ip: u32, dst_ip: u32, udp_data: &[u8]) -> u16 {
156 let mut sum: u32 = 0;
157
158 sum += (src_ip >> 16) & 0xFFFF;
160 sum += src_ip & 0xFFFF;
161
162 sum += (dst_ip >> 16) & 0xFFFF;
164 sum += dst_ip & 0xFFFF;
165
166 sum += 17;
168
169 sum += udp_data.len() as u32;
171
172 let mut i = 0;
174 while i < udp_data.len() {
175 if i + 1 < udp_data.len() {
176 let word = u16::from_be_bytes([udp_data[i], udp_data[i + 1]]);
177 sum += word as u32;
178 i += 2;
179 } else {
180 let word = u16::from_be_bytes([udp_data[i], 0]);
182 sum += word as u32;
183 i += 1;
184 }
185 }
186
187 while sum >> 16 != 0 {
189 sum = (sum & 0xFFFF) + (sum >> 16);
190 }
191
192 !sum as u16
194 }
195}
196
197impl PacketHeader for UdpHeader {
198 const NAME: &'static str = "UdpHeader";
199
200 #[inline]
201 fn is_valid(&self) -> bool {
202 self.is_valid()
203 }
204
205 type InnerType = ();
206
207 #[inline]
208 fn inner_type(&self) -> Self::InnerType {}
209}
210
211impl HeaderParser for UdpHeader {
212 type Output<'a> = &'a UdpHeader;
213
214 #[inline]
215 fn into_view<'a>(header: &'a Self, _: &'a [u8]) -> Self::Output<'a> {
216 header
217 }
218}
219
220impl fmt::Display for UdpHeader {
221 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
222 write!(
223 f,
224 "UDP {} -> {} len={}",
225 self.src_port(),
226 self.dst_port(),
227 self.length()
228 )
229 }
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a header directly from host-order field values.
    fn make_header(src: u16, dst: u16, length: u16, checksum: u16) -> UdpHeader {
        UdpHeader {
            src_port: U16::new(src),
            dst_port: U16::new(dst),
            length: U16::new(length),
            checksum: U16::new(checksum),
        }
    }

    /// Serializes a datagram (8-byte header + `payload`) with the checksum field 0.
    fn build_datagram(src: u16, dst: u16, payload: &[u8]) -> Vec<u8> {
        // Test payloads are tiny, so the length always fits in a u16.
        let length = (8 + payload.len()) as u16;
        let mut packet = Vec::with_capacity(usize::from(length));
        packet.extend_from_slice(&src.to_be_bytes());
        packet.extend_from_slice(&dst.to_be_bytes());
        packet.extend_from_slice(&length.to_be_bytes());
        packet.extend_from_slice(&0u16.to_be_bytes());
        packet.extend_from_slice(payload);
        packet
    }

    /// 12345 -> 53, length 16, no checksum, payload `b"DNS data"`.
    fn create_test_packet() -> Vec<u8> {
        build_datagram(12345, 53, b"DNS data")
    }

    #[test]
    fn test_udp_header_basic() {
        let header = make_header(53, 12345, 16, 0);

        assert_eq!(header.src_port(), 53);
        assert_eq!(header.dst_port(), 12345);
        assert_eq!(header.length(), 16);
        assert_eq!(header.header_len(), 8);
        assert_eq!(header.payload_len(), 8);
        assert!(header.is_valid());
    }

    #[test]
    fn test_udp_header_validation() {
        // One byte short of the fixed header.
        let invalid_header = make_header(53, 12345, 7, 0);
        assert!(!invalid_header.is_valid());

        // Exactly the fixed header, empty payload.
        let valid_header = make_header(53, 12345, 8, 0);
        assert!(valid_header.is_valid());
    }

    #[test]
    fn test_udp_checksum_zero() {
        // A zero checksum field means "not computed" and always verifies.
        let header = make_header(53, 12345, 8, 0);
        assert!(header.verify_checksum(0x7f000001, 0x7f000001, &[]));
    }

    #[test]
    fn test_udp_header_size() {
        assert_eq!(mem::size_of::<UdpHeader>(), 8);
        assert_eq!(UdpHeader::FIXED_LEN, 8);
    }

    #[test]
    fn test_udp_parsing_basic() {
        let packet = create_test_packet();

        let result = UdpHeader::from_bytes(&packet);
        assert!(result.is_ok());

        let (header, payload) = result.unwrap();
        assert_eq!(header.src_port(), 12345);
        assert_eq!(header.dst_port(), 53);
        assert_eq!(header.length(), 16);
        assert_eq!(payload.len(), 8);
        assert!(header.is_valid());
    }

    #[test]
    fn test_udp_parsing_too_small() {
        let packet = vec![0u8; 7];
        let result = UdpHeader::from_bytes(&packet);
        assert!(result.is_err());
    }

    #[test]
    fn test_udp_total_len() {
        let packet = create_test_packet();
        let (header, _) = UdpHeader::from_bytes(&packet).unwrap();

        assert_eq!(header.total_len(&packet), 8);
    }

    #[test]
    fn test_udp_from_bytes_with_payload() {
        let payload_data = b"Hello, UDP!";
        let total_length = 8 + payload_data.len();
        let packet = build_datagram(5000, 8080, payload_data);

        let result = UdpHeader::from_bytes(&packet);
        assert!(result.is_ok());

        let (header, payload) = result.unwrap();

        assert_eq!(header.src_port(), 5000);
        assert_eq!(header.dst_port(), 8080);
        assert_eq!(header.length(), total_length as u16);
        assert_eq!(header.payload_len(), payload_data.len());

        assert_eq!(payload.len(), payload_data.len());
        assert_eq!(payload, payload_data);
    }

    #[test]
    fn test_udp_payload_length_calculation() {
        assert_eq!(make_header(1234, 5678, 8, 0).payload_len(), 0);
        assert_eq!(make_header(1234, 5678, 100, 0).payload_len(), 92);
        // A length below the header size saturates instead of underflowing.
        assert_eq!(make_header(1234, 5678, 5, 0).payload_len(), 0);
    }

    #[test]
    fn test_udp_dns_packet() {
        // Minimal DNS query header: id 0xabcd, RD flag, one question.
        let dns_payload = vec![
            0xab, 0xcd, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        ];
        let total_length = 8 + dns_payload.len();
        let packet = build_datagram(54321, 53, &dns_payload);

        let (header, payload) = UdpHeader::from_bytes(&packet).unwrap();

        assert_eq!(header.src_port(), 54321);
        assert_eq!(header.dst_port(), 53);
        assert_eq!(header.length(), total_length as u16);
        assert_eq!(payload.len(), dns_payload.len());
        assert_eq!(payload, dns_payload.as_slice());
    }

    #[test]
    fn test_udp_checksum_computation() {
        let src_ip = 0xC0A80101; // 192.168.1.1
        let dst_ip = 0xC0A80102; // 192.168.1.2
        let mut udp_packet = build_datagram(12345, 80, b"test");

        // Hand-computed one's-complement checksum of pseudo-header + datagram.
        let checksum = UdpHeader::compute_checksum(src_ip, dst_ip, &udp_packet);
        assert_eq!(checksum, 0x641F);

        // A header carrying that checksum verifies against the zero-field datagram...
        let header = make_header(12345, 80, 12, checksum);
        assert!(header.verify_checksum(src_ip, dst_ip, &udp_packet));

        // ...and a corrupted payload byte is rejected.
        let mut corrupted = udp_packet.clone();
        corrupted[8] ^= 0xFF;
        assert!(!header.verify_checksum(src_ip, dst_ip, &corrupted));

        // With the checksum written into the field, the total sum complements to 0.
        udp_packet[6..8].copy_from_slice(&checksum.to_be_bytes());
        assert_eq!(UdpHeader::compute_checksum(src_ip, dst_ip, &udp_packet), 0);
    }

    #[test]
    fn test_udp_multiple_packets() {
        let packets: Vec<(u16, u16, Vec<u8>)> = vec![
            (1234, 5678, b"payload1".to_vec()),
            (80, 54321, b"HTTP response".to_vec()),
            (53, 12345, b"DNS".to_vec()),
        ];

        for (src, dst, payload_data) in packets {
            let packet = build_datagram(src, dst, &payload_data);

            let (header, payload) = UdpHeader::from_bytes(&packet).unwrap();
            assert_eq!(header.src_port(), src);
            assert_eq!(header.dst_port(), dst);
            assert_eq!(payload, payload_data.as_slice());
        }
    }
}