edge_raw/
udp.rs

1use core::net::{Ipv4Addr, SocketAddrV4};
2
3use super::bytes::{BytesIn, BytesOut};
4
5use super::{checksum_accumulate, checksum_finish, Error};
6
7#[allow(clippy::type_complexity)]
8pub fn decode(
9    src: Ipv4Addr,
10    dst: Ipv4Addr,
11    packet: &[u8],
12    filter_src: Option<u16>,
13    filter_dst: Option<u16>,
14) -> Result<Option<(SocketAddrV4, SocketAddrV4, &[u8])>, Error> {
15    let data = UdpPacketHeader::decode_with_payload(packet, src, dst, filter_src, filter_dst)?.map(
16        |(hdr, payload)| {
17            (
18                SocketAddrV4::new(src, hdr.src),
19                SocketAddrV4::new(dst, hdr.dst),
20                payload,
21            )
22        },
23    );
24
25    Ok(data)
26}
27
28pub fn encode<F>(
29    buf: &mut [u8],
30    src: SocketAddrV4,
31    dst: SocketAddrV4,
32    payload: F,
33) -> Result<&[u8], Error>
34where
35    F: FnOnce(&mut [u8]) -> Result<usize, Error>,
36{
37    let mut hdr = UdpPacketHeader::new(src.port(), dst.port());
38
39    hdr.encode_with_payload(buf, *src.ip(), *dst.ip(), |buf| payload(buf))
40}
41
/// Represents a parsed UDP header
///
/// All fields are kept in host byte order; conversion to and from the
/// big-endian wire representation happens when encoding / decoding.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct UdpPacketHeader {
    /// Source port
    pub src: u16,
    /// Destination port
    pub dst: u16,
    /// UDP length: size of the header plus the payload, in bytes
    pub len: u16,
    /// UDP checksum, as carried in the header
    pub sum: u16,
}
55
56impl UdpPacketHeader {
57    pub const PROTO: u8 = 17;
58
59    pub const SIZE: usize = 8;
60    pub const CHECKSUM_WORD: usize = 3;
61
62    /// Create a new header instance
63    pub fn new(src: u16, dst: u16) -> Self {
64        Self {
65            src,
66            dst,
67            len: 0,
68            sum: 0,
69        }
70    }
71
72    /// Decodes the header from a byte slice
73    pub fn decode(data: &[u8]) -> Result<Self, Error> {
74        let mut bytes = BytesIn::new(data);
75
76        Ok(Self {
77            src: u16::from_be_bytes(bytes.arr()?),
78            dst: u16::from_be_bytes(bytes.arr()?),
79            len: u16::from_be_bytes(bytes.arr()?),
80            sum: u16::from_be_bytes(bytes.arr()?),
81        })
82    }
83
84    /// Encodes the header into the provided buf slice
85    pub fn encode<'o>(&self, buf: &'o mut [u8]) -> Result<&'o [u8], Error> {
86        let mut bytes = BytesOut::new(buf);
87
88        bytes
89            .push(&u16::to_be_bytes(self.src))?
90            .push(&u16::to_be_bytes(self.dst))?
91            .push(&u16::to_be_bytes(self.len))?
92            .push(&u16::to_be_bytes(self.sum))?;
93
94        let len = bytes.len();
95
96        Ok(&buf[..len])
97    }
98
99    /// Encodes the header and the provided payload into the provided buf slice
100    pub fn encode_with_payload<'o, F>(
101        &mut self,
102        buf: &'o mut [u8],
103        src: Ipv4Addr,
104        dst: Ipv4Addr,
105        encoder: F,
106    ) -> Result<&'o [u8], Error>
107    where
108        F: FnOnce(&mut [u8]) -> Result<usize, Error>,
109    {
110        if buf.len() < Self::SIZE {
111            Err(Error::BufferOverflow)?;
112        }
113
114        let (hdr_buf, payload_buf) = buf.split_at_mut(Self::SIZE);
115
116        let payload_len = encoder(payload_buf)?;
117
118        let len = Self::SIZE + payload_len;
119        self.len = len as _;
120
121        let hdr_len = self.encode(hdr_buf)?.len();
122        assert_eq!(Self::SIZE, hdr_len);
123
124        let packet = &mut buf[..len];
125
126        let checksum = Self::checksum(packet, src, dst);
127        self.sum = checksum;
128
129        Self::inject_checksum(packet, checksum);
130
131        Ok(packet)
132    }
133
134    /// Decodes the provided packet into a header and a payload slice
135    pub fn decode_with_payload(
136        packet: &[u8],
137        src: Ipv4Addr,
138        dst: Ipv4Addr,
139        filter_src: Option<u16>,
140        filter_dst: Option<u16>,
141    ) -> Result<Option<(Self, &[u8])>, Error> {
142        let hdr = Self::decode(packet)?;
143
144        if let Some(filter_src) = filter_src {
145            if filter_src != hdr.src {
146                return Ok(None);
147            }
148        }
149
150        if let Some(filter_dst) = filter_dst {
151            if filter_dst != hdr.dst {
152                return Ok(None);
153            }
154        }
155
156        let len = hdr.len as usize;
157        if packet.len() < len {
158            Err(Error::DataUnderflow)?;
159        }
160
161        let checksum = Self::checksum(&packet[..len], src, dst);
162
163        trace!(
164            "UDP header decoded, src={}, dst={}, size={}, checksum={}, ours={}",
165            hdr.src,
166            hdr.dst,
167            hdr.len,
168            hdr.sum,
169            checksum
170        );
171
172        if checksum != hdr.sum {
173            Err(Error::InvalidChecksum)?;
174        }
175
176        let packet = &packet[..len];
177
178        let payload_data = &packet[Self::SIZE..];
179
180        Ok(Some((hdr, payload_data)))
181    }
182
183    /// Injects the checksum into the provided packet
184    pub fn inject_checksum(packet: &mut [u8], checksum: u16) {
185        let checksum = checksum.to_be_bytes();
186
187        let offset = Self::CHECKSUM_WORD << 1;
188        packet[offset] = checksum[0];
189        packet[offset + 1] = checksum[1];
190    }
191
192    /// Computes the checksum for an already encoded packet
193    pub fn checksum(packet: &[u8], src: Ipv4Addr, dst: Ipv4Addr) -> u16 {
194        let mut buf = [0; 12];
195
196        // Pseudo IP-header for UDP checksum calculation
197        let len = unwrap!(
198            unwrap!(
199                unwrap!(
200                    unwrap!(
201                        unwrap!(
202                            BytesOut::new(&mut buf).push(&u32::to_be_bytes(src.into())),
203                            "Unreachable"
204                        )
205                        .push(&u32::to_be_bytes(dst.into())),
206                        "Unreachable"
207                    )
208                    .byte(0),
209                    "Unreachable"
210                )
211                .byte(UdpPacketHeader::PROTO),
212                "Unreachable"
213            )
214            .push(&u16::to_be_bytes(packet.len() as u16)),
215            "Unreachable"
216        )
217        .len();
218
219        let sum = checksum_accumulate(&buf[..len], usize::MAX)
220            + checksum_accumulate(packet, Self::CHECKSUM_WORD);
221
222        checksum_finish(sum)
223    }
224}