1use byteorder::{ByteOrder, NetworkEndian};
2use core::fmt;
3
4use crate::phy::ChecksumCapabilities;
5use crate::wire::ip::checksum;
6use crate::wire::{IpAddress, IpProtocol};
7use crate::{Error, Result};
8
/// A read/write wrapper around a User Datagram Protocol packet buffer.
///
/// The wrapper stores the buffer as-is; accessors decode fields on demand.
// `Eq` is derived alongside `PartialEq` (as `Repr` already does) so packets over
// `Eq` buffers can be used wherever full equivalence is required.
#[derive(Debug, PartialEq, Eq, Clone)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct Packet<T: AsRef<[u8]>> {
    buffer: T,
}
15
16mod field {
17 #![allow(non_snake_case)]
18
19 use crate::wire::field::*;
20
21 pub const SRC_PORT: Field = 0..2;
22 pub const DST_PORT: Field = 2..4;
23 pub const LENGTH: Field = 4..6;
24 pub const CHECKSUM: Field = 6..8;
25
26 pub fn PAYLOAD(length: u16) -> Field {
27 CHECKSUM.end..(length as usize)
28 }
29}
30
31pub const HEADER_LEN: usize = field::CHECKSUM.end;
32
33#[allow(clippy::len_without_is_empty)]
34impl<T: AsRef<[u8]>> Packet<T> {
35 pub fn new_unchecked(buffer: T) -> Packet<T> {
37 Packet { buffer }
38 }
39
40 pub fn new_checked(buffer: T) -> Result<Packet<T>> {
45 let packet = Self::new_unchecked(buffer);
46 packet.check_len()?;
47 Ok(packet)
48 }
49
50 pub fn check_len(&self) -> Result<()> {
59 let buffer_len = self.buffer.as_ref().len();
60 if buffer_len < HEADER_LEN {
61 Err(Error::Truncated)
62 } else {
63 let field_len = self.len() as usize;
64 if buffer_len < field_len {
65 Err(Error::Truncated)
66 } else if field_len < HEADER_LEN {
67 Err(Error::Malformed)
68 } else {
69 Ok(())
70 }
71 }
72 }
73
74 pub fn into_inner(self) -> T {
76 self.buffer
77 }
78
79 #[inline]
81 pub fn src_port(&self) -> u16 {
82 let data = self.buffer.as_ref();
83 NetworkEndian::read_u16(&data[field::SRC_PORT])
84 }
85
86 #[inline]
88 pub fn dst_port(&self) -> u16 {
89 let data = self.buffer.as_ref();
90 NetworkEndian::read_u16(&data[field::DST_PORT])
91 }
92
93 #[inline]
95 pub fn len(&self) -> u16 {
96 let data = self.buffer.as_ref();
97 NetworkEndian::read_u16(&data[field::LENGTH])
98 }
99
100 #[inline]
102 pub fn checksum(&self) -> u16 {
103 let data = self.buffer.as_ref();
104 NetworkEndian::read_u16(&data[field::CHECKSUM])
105 }
106
107 pub fn verify_checksum(&self, src_addr: &IpAddress, dst_addr: &IpAddress) -> bool {
116 if cfg!(fuzzing) {
117 return true;
118 }
119
120 let data = self.buffer.as_ref();
121 checksum::combine(&[
122 checksum::pseudo_header(src_addr, dst_addr, IpProtocol::Udp, self.len() as u32),
123 checksum::data(&data[..self.len() as usize]),
124 ]) == !0
125 }
126}
127
128impl<'a, T: AsRef<[u8]> + ?Sized> Packet<&'a T> {
129 #[inline]
131 pub fn payload(&self) -> &'a [u8] {
132 let length = self.len();
133 let data = self.buffer.as_ref();
134 &data[field::PAYLOAD(length)]
135 }
136}
137
138impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> {
139 #[inline]
141 pub fn set_src_port(&mut self, value: u16) {
142 let data = self.buffer.as_mut();
143 NetworkEndian::write_u16(&mut data[field::SRC_PORT], value)
144 }
145
146 #[inline]
148 pub fn set_dst_port(&mut self, value: u16) {
149 let data = self.buffer.as_mut();
150 NetworkEndian::write_u16(&mut data[field::DST_PORT], value)
151 }
152
153 #[inline]
155 pub fn set_len(&mut self, value: u16) {
156 let data = self.buffer.as_mut();
157 NetworkEndian::write_u16(&mut data[field::LENGTH], value)
158 }
159
160 #[inline]
162 pub fn set_checksum(&mut self, value: u16) {
163 let data = self.buffer.as_mut();
164 NetworkEndian::write_u16(&mut data[field::CHECKSUM], value)
165 }
166
167 pub fn fill_checksum(&mut self, src_addr: &IpAddress, dst_addr: &IpAddress) {
173 self.set_checksum(0);
174 let checksum = {
175 let data = self.buffer.as_ref();
176 !checksum::combine(&[
177 checksum::pseudo_header(src_addr, dst_addr, IpProtocol::Udp, self.len() as u32),
178 checksum::data(&data[..self.len() as usize]),
179 ])
180 };
181 self.set_checksum(if checksum == 0 { 0xffff } else { checksum })
186 }
187
188 #[inline]
190 pub fn payload_mut(&mut self) -> &mut [u8] {
191 let length = self.len();
192 let data = self.buffer.as_mut();
193 &mut data[field::PAYLOAD(length)]
194 }
195}
196
197impl<T: AsRef<[u8]>> AsRef<[u8]> for Packet<T> {
198 fn as_ref(&self) -> &[u8] {
199 self.buffer.as_ref()
200 }
201}
202
/// A high-level representation of a User Datagram Protocol header.
///
/// Only the ports are stored; the length and checksum fields are derived from the
/// payload when the header is emitted.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct Repr {
    /// Source port.
    pub src_port: u16,
    /// Destination port.
    pub dst_port: u16,
}
210
211impl Repr {
212 pub fn parse<T>(
214 packet: &Packet<&T>,
215 src_addr: &IpAddress,
216 dst_addr: &IpAddress,
217 checksum_caps: &ChecksumCapabilities,
218 ) -> Result<Repr>
219 where
220 T: AsRef<[u8]> + ?Sized,
221 {
222 if packet.dst_port() == 0 {
224 return Err(Error::Malformed);
225 }
226 if checksum_caps.udp.rx() && !packet.verify_checksum(src_addr, dst_addr) {
228 match (src_addr, dst_addr) {
229 #[cfg(feature = "proto-ipv4")]
231 (&IpAddress::Ipv4(_), &IpAddress::Ipv4(_)) if packet.checksum() == 0 => (),
232 _ => return Err(Error::Checksum),
233 }
234 }
235
236 Ok(Repr {
237 src_port: packet.src_port(),
238 dst_port: packet.dst_port(),
239 })
240 }
241
242 pub fn header_len(&self) -> usize {
244 HEADER_LEN
245 }
246
247 pub fn emit<T: ?Sized>(
249 &self,
250 packet: &mut Packet<&mut T>,
251 src_addr: &IpAddress,
252 dst_addr: &IpAddress,
253 payload_len: usize,
254 emit_payload: impl FnOnce(&mut [u8]),
255 checksum_caps: &ChecksumCapabilities,
256 ) where
257 T: AsRef<[u8]> + AsMut<[u8]>,
258 {
259 packet.set_src_port(self.src_port);
260 packet.set_dst_port(self.dst_port);
261 packet.set_len((HEADER_LEN + payload_len) as u16);
262 emit_payload(packet.payload_mut());
263
264 if checksum_caps.udp.tx() {
265 packet.fill_checksum(src_addr, dst_addr)
266 } else {
267 packet.set_checksum(0);
270 }
271 }
272}
273
274impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for Packet<&'a T> {
275 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
276 write!(
278 f,
279 "UDP src={} dst={} len={}",
280 self.src_port(),
281 self.dst_port(),
282 self.payload().len()
283 )
284 }
285}
286
287impl fmt::Display for Repr {
288 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
289 write!(f, "UDP src={} dst={}", self.src_port, self.dst_port)
290 }
291}
292
293use crate::wire::pretty_print::{PrettyIndent, PrettyPrint};
294
295impl<T: AsRef<[u8]>> PrettyPrint for Packet<T> {
296 fn pretty_print(
297 buffer: &dyn AsRef<[u8]>,
298 f: &mut fmt::Formatter,
299 indent: &mut PrettyIndent,
300 ) -> fmt::Result {
301 match Packet::new_checked(buffer) {
302 Err(err) => write!(f, "{}({})", indent, err),
303 Ok(packet) => write!(f, "{}{}", indent, packet),
304 }
305 }
306}
307
#[cfg(test)]
mod test {
    use super::*;
    #[cfg(feature = "proto-ipv4")]
    use crate::wire::Ipv4Address;

    #[cfg(feature = "proto-ipv4")]
    const SRC_ADDR: Ipv4Address = Ipv4Address([192, 168, 1, 1]);
    #[cfg(feature = "proto-ipv4")]
    const DST_ADDR: Ipv4Address = Ipv4Address([192, 168, 1, 2]);

    #[cfg(feature = "proto-ipv4")]
    static PACKET_BYTES: [u8; 12] = [
        0xbf, 0x00, 0x00, 0x35, 0x00, 0x0c, 0x12, 0x4d, 0xaa, 0x00, 0x00, 0xff,
    ];

    #[cfg(feature = "proto-ipv4")]
    static NO_CHECKSUM_PACKET: [u8; 12] = [
        0xbf, 0x00, 0x00, 0x35, 0x00, 0x0c, 0x00, 0x00, 0xaa, 0x00, 0x00, 0xff,
    ];

    #[cfg(feature = "proto-ipv4")]
    static PAYLOAD_BYTES: [u8; 4] = [0xaa, 0x00, 0x00, 0xff];

    #[test]
    #[cfg(feature = "proto-ipv4")]
    fn test_deconstruct() {
        let packet = Packet::new_unchecked(&PACKET_BYTES[..]);
        assert_eq!(packet.src_port(), 48896);
        assert_eq!(packet.dst_port(), 53);
        assert_eq!(packet.len(), 12);
        assert_eq!(packet.checksum(), 0x124d);
        assert_eq!(packet.payload(), &PAYLOAD_BYTES[..]);
        assert!(packet.verify_checksum(&SRC_ADDR.into(), &DST_ADDR.into()));
    }

    #[test]
    #[cfg(feature = "proto-ipv4")]
    fn test_construct() {
        let mut bytes = vec![0xa5; 12];
        let mut packet = Packet::new_unchecked(&mut bytes);
        packet.set_src_port(48896);
        packet.set_dst_port(53);
        packet.set_len(12);
        packet.set_checksum(0xffff);
        packet.payload_mut().copy_from_slice(&PAYLOAD_BYTES[..]);
        packet.fill_checksum(&SRC_ADDR.into(), &DST_ADDR.into());
        assert_eq!(&packet.into_inner()[..], &PACKET_BYTES[..]);
    }

    #[test]
    fn test_impossible_len() {
        let mut bytes = vec![0; 12];
        let mut packet = Packet::new_unchecked(&mut bytes);
        packet.set_len(4);
        assert_eq!(packet.check_len(), Err(Error::Malformed));
    }

    // The buffer cannot even hold the fixed header.
    #[test]
    fn test_truncated_header() {
        let bytes = [0u8; 4];
        let packet = Packet::new_unchecked(&bytes[..]);
        assert_eq!(packet.check_len(), Err(Error::Truncated));
    }

    // The header fits, but the length field claims more bytes than the buffer has.
    #[test]
    fn test_truncated_payload() {
        let mut bytes = vec![0; 8];
        let mut packet = Packet::new_unchecked(&mut bytes);
        packet.set_len(12);
        assert_eq!(packet.check_len(), Err(Error::Truncated));
    }

    #[test]
    #[cfg(feature = "proto-ipv4")]
    fn test_zero_checksum() {
        let mut bytes = vec![0; 8];
        let mut packet = Packet::new_unchecked(&mut bytes);
        packet.set_src_port(1);
        packet.set_dst_port(31881);
        packet.set_len(8);
        packet.fill_checksum(&SRC_ADDR.into(), &DST_ADDR.into());
        assert_eq!(packet.checksum(), 0xffff);
    }

    #[cfg(feature = "proto-ipv4")]
    fn packet_repr() -> Repr {
        Repr {
            src_port: 48896,
            dst_port: 53,
        }
    }

    #[test]
    #[cfg(feature = "proto-ipv4")]
    fn test_parse() {
        let packet = Packet::new_unchecked(&PACKET_BYTES[..]);
        let repr = Repr::parse(
            &packet,
            &SRC_ADDR.into(),
            &DST_ADDR.into(),
            &ChecksumCapabilities::default(),
        )
        .unwrap();
        assert_eq!(repr, packet_repr());
    }

    // A zero destination port is rejected before any checksum verification.
    #[test]
    #[cfg(feature = "proto-ipv4")]
    fn test_parse_zero_dst_port() {
        let mut bytes = PACKET_BYTES;
        bytes[2] = 0;
        bytes[3] = 0;
        let packet = Packet::new_unchecked(&bytes[..]);
        let result = Repr::parse(
            &packet,
            &SRC_ADDR.into(),
            &DST_ADDR.into(),
            &ChecksumCapabilities::default(),
        );
        assert_eq!(result, Err(Error::Malformed));
    }

    #[test]
    #[cfg(feature = "proto-ipv4")]
    fn test_emit() {
        let repr = packet_repr();
        let mut bytes = vec![0xa5; repr.header_len() + PAYLOAD_BYTES.len()];
        let mut packet = Packet::new_unchecked(&mut bytes);
        repr.emit(
            &mut packet,
            &SRC_ADDR.into(),
            &DST_ADDR.into(),
            PAYLOAD_BYTES.len(),
            |payload| payload.copy_from_slice(&PAYLOAD_BYTES),
            &ChecksumCapabilities::default(),
        );
        assert_eq!(&packet.into_inner()[..], &PACKET_BYTES[..]);
    }

    #[test]
    #[cfg(feature = "proto-ipv4")]
    fn test_checksum_omitted() {
        let packet = Packet::new_unchecked(&NO_CHECKSUM_PACKET[..]);
        let repr = Repr::parse(
            &packet,
            &SRC_ADDR.into(),
            &DST_ADDR.into(),
            &ChecksumCapabilities::default(),
        )
        .unwrap();
        assert_eq!(repr, packet_repr());
    }
}
430}