1use core::net::{Ipv4Addr, SocketAddrV4};
2
3use super::bytes::{BytesIn, BytesOut};
4
5use super::{checksum_accumulate, checksum_finish, Error};
6
7#[allow(clippy::type_complexity)]
8pub fn decode(
9 src: Ipv4Addr,
10 dst: Ipv4Addr,
11 packet: &[u8],
12 filter_src: Option<u16>,
13 filter_dst: Option<u16>,
14) -> Result<Option<(SocketAddrV4, SocketAddrV4, &[u8])>, Error> {
15 let data = UdpPacketHeader::decode_with_payload(packet, src, dst, filter_src, filter_dst)?.map(
16 |(hdr, payload)| {
17 (
18 SocketAddrV4::new(src, hdr.src),
19 SocketAddrV4::new(dst, hdr.dst),
20 payload,
21 )
22 },
23 );
24
25 Ok(data)
26}
27
28pub fn encode<F>(
29 buf: &mut [u8],
30 src: SocketAddrV4,
31 dst: SocketAddrV4,
32 payload: F,
33) -> Result<&[u8], Error>
34where
35 F: FnOnce(&mut [u8]) -> Result<usize, Error>,
36{
37 let mut hdr = UdpPacketHeader::new(src.port(), dst.port());
38
39 hdr.encode_with_payload(buf, *src.ip(), *dst.ip(), |buf| payload(buf))
40}
41
/// A UDP header. Field values are plain integers; byte-order conversion happens in
/// `decode` / `encode`.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct UdpPacketHeader {
    /// Source port.
    pub src: u16,
    /// Destination port.
    pub dst: u16,
    /// Length of the datagram in bytes, header included.
    pub len: u16,
    /// Checksum as carried in the header.
    pub sum: u16,
}
55
56impl UdpPacketHeader {
57 pub const PROTO: u8 = 17;
58
59 pub const SIZE: usize = 8;
60 pub const CHECKSUM_WORD: usize = 3;
61
62 pub fn new(src: u16, dst: u16) -> Self {
64 Self {
65 src,
66 dst,
67 len: 0,
68 sum: 0,
69 }
70 }
71
72 pub fn decode(data: &[u8]) -> Result<Self, Error> {
74 let mut bytes = BytesIn::new(data);
75
76 Ok(Self {
77 src: u16::from_be_bytes(bytes.arr()?),
78 dst: u16::from_be_bytes(bytes.arr()?),
79 len: u16::from_be_bytes(bytes.arr()?),
80 sum: u16::from_be_bytes(bytes.arr()?),
81 })
82 }
83
84 pub fn encode<'o>(&self, buf: &'o mut [u8]) -> Result<&'o [u8], Error> {
86 let mut bytes = BytesOut::new(buf);
87
88 bytes
89 .push(&u16::to_be_bytes(self.src))?
90 .push(&u16::to_be_bytes(self.dst))?
91 .push(&u16::to_be_bytes(self.len))?
92 .push(&u16::to_be_bytes(self.sum))?;
93
94 let len = bytes.len();
95
96 Ok(&buf[..len])
97 }
98
99 pub fn encode_with_payload<'o, F>(
101 &mut self,
102 buf: &'o mut [u8],
103 src: Ipv4Addr,
104 dst: Ipv4Addr,
105 encoder: F,
106 ) -> Result<&'o [u8], Error>
107 where
108 F: FnOnce(&mut [u8]) -> Result<usize, Error>,
109 {
110 if buf.len() < Self::SIZE {
111 Err(Error::BufferOverflow)?;
112 }
113
114 let (hdr_buf, payload_buf) = buf.split_at_mut(Self::SIZE);
115
116 let payload_len = encoder(payload_buf)?;
117
118 let len = Self::SIZE + payload_len;
119 self.len = len as _;
120
121 let hdr_len = self.encode(hdr_buf)?.len();
122 assert_eq!(Self::SIZE, hdr_len);
123
124 let packet = &mut buf[..len];
125
126 let checksum = Self::checksum(packet, src, dst);
127 self.sum = checksum;
128
129 Self::inject_checksum(packet, checksum);
130
131 Ok(packet)
132 }
133
134 pub fn decode_with_payload(
136 packet: &[u8],
137 src: Ipv4Addr,
138 dst: Ipv4Addr,
139 filter_src: Option<u16>,
140 filter_dst: Option<u16>,
141 ) -> Result<Option<(Self, &[u8])>, Error> {
142 let hdr = Self::decode(packet)?;
143
144 if let Some(filter_src) = filter_src {
145 if filter_src != hdr.src {
146 return Ok(None);
147 }
148 }
149
150 if let Some(filter_dst) = filter_dst {
151 if filter_dst != hdr.dst {
152 return Ok(None);
153 }
154 }
155
156 let len = hdr.len as usize;
157 if packet.len() < len {
158 Err(Error::DataUnderflow)?;
159 }
160
161 let checksum = Self::checksum(&packet[..len], src, dst);
162
163 trace!(
164 "UDP header decoded, src={}, dst={}, size={}, checksum={}, ours={}",
165 hdr.src,
166 hdr.dst,
167 hdr.len,
168 hdr.sum,
169 checksum
170 );
171
172 if checksum != hdr.sum {
173 Err(Error::InvalidChecksum)?;
174 }
175
176 let packet = &packet[..len];
177
178 let payload_data = &packet[Self::SIZE..];
179
180 Ok(Some((hdr, payload_data)))
181 }
182
183 pub fn inject_checksum(packet: &mut [u8], checksum: u16) {
185 let checksum = checksum.to_be_bytes();
186
187 let offset = Self::CHECKSUM_WORD << 1;
188 packet[offset] = checksum[0];
189 packet[offset + 1] = checksum[1];
190 }
191
192 pub fn checksum(packet: &[u8], src: Ipv4Addr, dst: Ipv4Addr) -> u16 {
194 let mut buf = [0; 12];
195
196 let len = unwrap!(
198 unwrap!(
199 unwrap!(
200 unwrap!(
201 unwrap!(
202 BytesOut::new(&mut buf).push(&u32::to_be_bytes(src.into())),
203 "Unreachable"
204 )
205 .push(&u32::to_be_bytes(dst.into())),
206 "Unreachable"
207 )
208 .byte(0),
209 "Unreachable"
210 )
211 .byte(UdpPacketHeader::PROTO),
212 "Unreachable"
213 )
214 .push(&u16::to_be_bytes(packet.len() as u16)),
215 "Unreachable"
216 )
217 .len();
218
219 let sum = checksum_accumulate(&buf[..len], usize::MAX)
220 + checksum_accumulate(packet, Self::CHECKSUM_WORD);
221
222 checksum_finish(sum)
223 }
224}