edge_raw/
ip.rs

1use core::net::Ipv4Addr;
2
3use super::bytes::{BytesIn, BytesOut};
4
5use super::{checksum_accumulate, checksum_finish, Error};
6
7#[allow(clippy::type_complexity)]
8pub fn decode(
9    packet: &[u8],
10    filter_src: Ipv4Addr,
11    filter_dst: Ipv4Addr,
12    filter_proto: Option<u8>,
13) -> Result<Option<(Ipv4Addr, Ipv4Addr, u8, &[u8])>, Error> {
14    let data = Ipv4PacketHeader::decode_with_payload(packet, filter_src, filter_dst, filter_proto)?
15        .map(|(hdr, payload)| (hdr.src, hdr.dst, hdr.p, payload));
16
17    Ok(data)
18}
19
20pub fn encode<F>(
21    buf: &mut [u8],
22    src: Ipv4Addr,
23    dst: Ipv4Addr,
24    proto: u8,
25    encoder: F,
26) -> Result<&[u8], Error>
27where
28    F: FnOnce(&mut [u8]) -> Result<usize, Error>,
29{
30    let mut hdr = Ipv4PacketHeader::new(src, dst, proto);
31
32    hdr.encode_with_payload(buf, encoder)
33}
34
/// Represents a parsed (or to-be-encoded) IPv4 header.
///
/// Fields appear in on-wire order. Multi-byte fields hold native integer values; the
/// `decode`/`encode` methods convert them from/to big-endian wire bytes.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct Ipv4PacketHeader {
    /// IP version (high nibble of the first byte); `4` for IPv4.
    pub version: u8,
    /// Header length in bytes (the IHL nibble multiplied by 4).
    pub hlen: u8,
    /// Type of service byte.
    pub tos: u8,
    /// Total datagram length in bytes: header plus payload.
    pub len: u16,
    /// Identification.
    pub id: u16,
    /// Flags and fragment offset field (see `IP_DF` / `IP_MF`).
    pub off: u16,
    /// Time to live.
    pub ttl: u8,
    /// Protocol number of the payload.
    pub p: u8,
    /// Header checksum, as stored in the header.
    pub sum: u16,
    /// Source address.
    pub src: Ipv4Addr,
    /// Destination address.
    pub dst: Ipv4Addr,
}
62
63impl Ipv4PacketHeader {
64    pub const MIN_SIZE: usize = 20;
65    pub const CHECKSUM_WORD: usize = 5;
66
67    pub const IP_DF: u16 = 0x4000; // Don't fragment flag
68    pub const IP_MF: u16 = 0x2000; // More fragments flag
69
70    /// Create a new header instance
71    pub fn new(src: Ipv4Addr, dst: Ipv4Addr, proto: u8) -> Self {
72        Self {
73            version: 4,
74            hlen: Self::MIN_SIZE as _,
75            tos: 0,
76            len: Self::MIN_SIZE as _,
77            id: 0,
78            off: 0,
79            ttl: 64,
80            p: proto,
81            sum: 0,
82            src,
83            dst,
84        }
85    }
86
87    /// Decodes the header from a byte slice
88    pub fn decode(data: &[u8]) -> Result<Self, Error> {
89        let mut bytes = BytesIn::new(data);
90
91        let vhl = bytes.byte()?;
92
93        Ok(Self {
94            version: vhl >> 4,
95            hlen: (vhl & 0x0f) * 4,
96            tos: bytes.byte()?,
97            len: u16::from_be_bytes(bytes.arr()?),
98            id: u16::from_be_bytes(bytes.arr()?),
99            off: u16::from_be_bytes(bytes.arr()?),
100            ttl: bytes.byte()?,
101            p: bytes.byte()?,
102            sum: u16::from_be_bytes(bytes.arr()?),
103            src: u32::from_be_bytes(bytes.arr()?).into(),
104            dst: u32::from_be_bytes(bytes.arr()?).into(),
105        })
106    }
107
108    /// Encodes the header into the provided buf slice
109    pub fn encode<'o>(&self, buf: &'o mut [u8]) -> Result<&'o [u8], Error> {
110        let mut bytes = BytesOut::new(buf);
111
112        bytes
113            .byte(
114                (self.version << 4)
115                    | (self.hlen / 4 + (if !self.hlen.is_multiple_of(4) { 1 } else { 0 })),
116            )?
117            .byte(self.tos)?
118            .push(&u16::to_be_bytes(self.len))?
119            .push(&u16::to_be_bytes(self.id))?
120            .push(&u16::to_be_bytes(self.off))?
121            .byte(self.ttl)?
122            .byte(self.p)?
123            .push(&u16::to_be_bytes(self.sum))?
124            .push(&u32::to_be_bytes(self.src.into()))?
125            .push(&u32::to_be_bytes(self.dst.into()))?;
126
127        let len = bytes.len();
128
129        Ok(&buf[..len])
130    }
131
132    /// Encodes the header and the provided payload into the provided buf slice
133    pub fn encode_with_payload<'o, F>(
134        &mut self,
135        buf: &'o mut [u8],
136        encoder: F,
137    ) -> Result<&'o [u8], Error>
138    where
139        F: FnOnce(&mut [u8]) -> Result<usize, Error>,
140    {
141        let hdr_len = self.hlen as usize;
142        if hdr_len < Self::MIN_SIZE || buf.len() < hdr_len {
143            Err(Error::BufferOverflow)?;
144        }
145
146        let (hdr_buf, payload_buf) = buf.split_at_mut(hdr_len);
147
148        let payload_len = encoder(payload_buf)?;
149
150        let len = hdr_len + payload_len;
151        self.len = len as _;
152
153        let min_hdr_len = self.encode(hdr_buf)?.len();
154        assert_eq!(min_hdr_len, Self::MIN_SIZE);
155
156        hdr_buf[Self::MIN_SIZE..hdr_len].fill(0);
157
158        let checksum = Self::checksum(hdr_buf);
159        self.sum = checksum;
160
161        Self::inject_checksum(hdr_buf, checksum);
162
163        Ok(&buf[..len])
164    }
165
166    /// Decodes the provided packet into a header and a payload slice
167    pub fn decode_with_payload(
168        packet: &[u8],
169        filter_src: Ipv4Addr,
170        filter_dst: Ipv4Addr,
171        filter_proto: Option<u8>,
172    ) -> Result<Option<(Self, &[u8])>, Error> {
173        let hdr = Self::decode(packet)?;
174        if hdr.version == 4 {
175            // IPv4
176
177            if !filter_src.is_unspecified() && !hdr.src.is_broadcast() && filter_src != hdr.src {
178                return Ok(None);
179            }
180
181            if !filter_dst.is_unspecified() && !hdr.dst.is_broadcast() && filter_dst != hdr.dst {
182                return Ok(None);
183            }
184
185            if let Some(filter_proto) = filter_proto {
186                if filter_proto != hdr.p {
187                    return Ok(None);
188                }
189            }
190
191            let len = hdr.len as usize;
192            if packet.len() < len {
193                Err(Error::DataUnderflow)?;
194            }
195
196            let checksum = Self::checksum(&packet[..len]);
197
198            trace!("IP header decoded, total_size={}, src={}, dst={}, hlen={}, size={}, checksum={}, ours={}", packet.len(), hdr.src, hdr.dst, hdr.hlen, hdr.len, hdr.sum, checksum);
199
200            if checksum != hdr.sum {
201                Err(Error::InvalidChecksum)?;
202            }
203
204            let packet = &packet[..len];
205            let hdr_len = hdr.hlen as usize;
206            if packet.len() < hdr_len {
207                Err(Error::DataUnderflow)?;
208            }
209
210            Ok(Some((hdr, &packet[hdr_len..])))
211        } else {
212            Err(Error::InvalidFormat)
213        }
214    }
215
216    /// Injects the checksum into the provided packet
217    pub fn inject_checksum(packet: &mut [u8], checksum: u16) {
218        let checksum = checksum.to_be_bytes();
219
220        let offset = Self::CHECKSUM_WORD << 1;
221        packet[offset] = checksum[0];
222        packet[offset + 1] = checksum[1];
223    }
224
225    /// Computes the checksum for an already encoded packet
226    pub fn checksum(packet: &[u8]) -> u16 {
227        let hlen = (packet[0] & 0x0f) as usize * 4;
228
229        let sum = checksum_accumulate(&packet[..hlen], Self::CHECKSUM_WORD);
230
231        checksum_finish(sum)
232    }
233}