1use core::{fmt, mem};
2
3use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout, Unaligned, network_endian};
4
5use super::{ChecksumWords, DataDebug};
6
/// Zero-copy view of an IPv4 packet buffer: the fixed 20-byte header fields
/// followed by every remaining byte of the buffer (header options, if any,
/// then the payload). Use [`Ipv4Pdu::as_parts`] to split the two using IHL.
#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned)]
#[repr(C, packed)]
pub struct Ipv4Pdu {
    /// Fixed (option-less) portion of the IPv4 header.
    pub fields: Ipv4HeaderFields,
    /// All bytes after the fixed header: options first, then the payload.
    pub options_payload: [u8],
}
13
14impl Ipv4Pdu {
15    #[inline]
16    pub fn from_bytes(buf: &[u8]) -> Result<&Self, Ipv4PduError> {
17        Ipv4Pdu::ref_from_bytes(buf)
18            .map_err(zerocopy::SizeError::from)
19            .map_err(Into::into)
20    }
21
22    #[inline]
23    pub fn from_bytes_mut(buf: &mut [u8]) -> Result<&mut Self, Ipv4PduError> {
24        Ipv4Pdu::mut_from_bytes(buf)
25            .map_err(zerocopy::SizeError::from)
26            .map_err(Into::into)
27    }
28
29    #[inline]
30    pub fn as_parts(&self) -> Result<(&Ipv4Header, &[u8]), Ipv4PduError> {
31        let len = self.fields.header_length();
32        let buf = self.as_bytes();
33        if len < Ipv4HeaderFields::SIZE {
34            return Err(Ipv4PduError::InvalidHeaderLength);
35        }
36        let (header, payload) = buf
37            .split_at_checked(len)
38            .ok_or(Ipv4PduError::InvalidHeaderLength)?;
39
40        Ok((
41            Ipv4Header::ref_from_bytes(header).map_err(zerocopy::SizeError::from)?,
42            payload,
43        ))
44    }
45
46    #[inline]
48    pub fn as_mut_parts(
49        &mut self,
50        options: usize,
51    ) -> Result<(&mut Ipv4Header, &mut [u8]), Ipv4PduError> {
52        let buf = self.as_mut_bytes();
53        let (header, payload) = buf
54            .split_at_mut_checked(Ipv4HeaderFields::SIZE.saturating_add(options))
55            .ok_or(Ipv4PduError::InvalidHeaderLength)?;
56        Ok((
57            Ipv4Header::mut_from_bytes(header).map_err(zerocopy::SizeError::from)?,
58            payload,
59        ))
60    }
61}
62
63impl fmt::Debug for Ipv4Pdu {
64    #[inline]
65    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66        f.debug_struct("Ipv4Pdu")
67            .field("fields", &self.fields)
68            .field("options_payload", &DataDebug(&self.options_payload))
70            .finish()
71    }
72}
73
/// Zero-copy view of a complete IPv4 header: the fixed fields followed by the
/// option bytes. The number of option bytes is the length of the slice the
/// view was created from, minus the fixed fields.
#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned)]
#[repr(C, packed)]
pub struct Ipv4Header {
    /// Fixed (option-less) portion of the header.
    pub fields: Ipv4HeaderFields,
    /// Raw IPv4 option bytes (possibly empty).
    pub options: [u8],
}
80
81impl Ipv4Header {
82    #[inline]
83    #[must_use]
84    pub const fn length(&self) -> usize {
85        mem::size_of_val(&self.fields).wrapping_add(self.options.len())
86    }
87
88    #[inline]
89    pub fn update_checksum(&mut self) -> Result<(), Ipv4PduError> {
90        self.as_mut_words()?.update_checksum(0);
91        Ok(())
92    }
93
94    #[inline]
95    pub fn verify_checksum(&self) -> Result<(), Ipv4PduError> {
96        self.as_words()?
97            .verify_checksum(0)
98            .map_err(|()| Ipv4PduError::InvalidChecksum)
99    }
100
101    #[inline]
102    pub fn pseudo_header(&self) -> Result<&Ipv4PseudoHeader, Ipv4PduError> {
103        Ipv4PseudoHeader::ref_from_bytes(self.as_bytes())
105            .map_err(zerocopy::SizeError::from)
106            .map_err(Into::into)
107    }
108
109    fn as_words(&self) -> Result<&Ipv4HeaderWords, Ipv4PduError> {
110        Ipv4HeaderWords::ref_from_bytes(self.as_bytes())
112            .map_err(zerocopy::SizeError::from)
113            .map_err(Into::into)
114    }
115
116    fn as_mut_words(&mut self) -> Result<&mut Ipv4HeaderWords, Ipv4PduError> {
117        Ipv4HeaderWords::mut_from_bytes(self.as_mut_bytes())
119            .map_err(zerocopy::SizeError::from)
120            .map_err(Into::into)
121    }
122}
123
124impl fmt::Debug for Ipv4Header {
125    #[inline]
126    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127        f.debug_struct("Ipv4Header")
128            .field("fields", &self.fields)
129            .field("options", &format_args!("{:x?}", &self.options))
131            .finish()
132    }
133}
134
/// Fixed 20-byte portion of an IPv4 header in wire layout (`repr(C, packed)`,
/// multi-byte integers stored in network byte order).
#[derive(Copy, Clone, Debug, FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned)]
#[repr(C, packed)]
pub struct Ipv4HeaderFields {
    /// Version in the high nibble, IHL (header length in 32-bit words) in the
    /// low nibble.
    pub version_ihl: u8,
    /// DSCP/ECN byte; not interpreted by the accessors in this module.
    pub dscp_ecn: u8,
    /// Total packet length in bytes (header plus payload).
    pub total_length: network_endian::U16,
    /// Identification field.
    pub identification: network_endian::U16,
    /// Flags (top 3 bits) and fragment offset (low 13 bits).
    pub fragmentation: Fragmentation,
    /// Time to live.
    pub ttl: u8,
    /// Upper-layer protocol number.
    pub protocol: InetProtocol,
    /// Header checksum.
    pub checksum: network_endian::U16,
    /// Source address.
    pub saddr: Ipv4Address,
    /// Destination address.
    pub daddr: Ipv4Address,
}
149
150impl Default for Ipv4HeaderFields {
151    #[inline]
152    fn default() -> Self {
153        Ipv4HeaderFields {
154            version_ihl: 0x45,
155            dscp_ecn: 0,
156            total_length: 20.into(),
157            identification: 0.into(),
158            fragmentation: Fragmentation::default(),
159            ttl: 255,
160            protocol: InetProtocol::TCP,
161            checksum: 0.into(),
162            saddr: Ipv4Address::UNSPECIFIED,
163            daddr: Ipv4Address::UNSPECIFIED,
164        }
165    }
166}
167
168impl fmt::Display for Ipv4HeaderFields {
169    #[inline]
170    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
171        write!(
172            f,
173            "IPv4: len={} id={} frag={}/{} ttl={} protocol={} checksum={:04x} saddr={} daddr={}",
174            self.total_length,
175            self.identification,
176            self.fragmentation.flags(),
177            self.fragmentation.offset(),
178            self.ttl,
179            self.protocol,
180            self.checksum,
181            self.saddr,
182            self.daddr,
183        )
184    }
185}
186
187impl Ipv4HeaderFields {
188    pub const SIZE: usize = mem::size_of::<Self>();
189    pub const WORDS: usize = Self::SIZE >> 1;
190
191    #[inline]
192    pub const fn set_version(&mut self, version: u8) {
193        self.version_ihl = (self.version_ihl & 0b0000_1111) | ((version & 0b1111) << 4);
194    }
195
196    #[inline]
197    pub const fn set_ihl(&mut self, ihl: u8) {
198        self.version_ihl = (self.version_ihl & 0b1111_0000) | (ihl & 0b1111);
199    }
200
201    #[inline]
202    #[must_use]
203    pub const fn version(&self) -> u8 {
204        (self.version_ihl >> 4) & 0b1111
205    }
206
207    #[inline]
208    #[must_use]
209    pub const fn ihl(&self) -> u8 {
210        self.version_ihl & 0b1111
211    }
212
213    #[inline]
214    #[must_use]
215    #[allow(clippy::as_conversions, reason = "unsigned to usize")]
216    pub const fn header_length(&self) -> usize {
217        (self.ihl() as usize).wrapping_mul(4)
218    }
219
220    #[inline]
221    #[must_use]
222    #[allow(clippy::as_conversions, reason = "unsigned to usize")]
223    pub const fn packet_length(&self) -> usize {
224        self.total_length.get() as usize
225    }
226}
227
/// An IPv4 address stored as its four octets in transmission order (first
/// octet first, as printed in dotted-quad notation).
#[derive(
    Copy,
    Clone,
    Debug,
    FromBytes,
    IntoBytes,
    KnownLayout,
    Immutable,
    Unaligned,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
)]
#[repr(transparent)]
pub struct Ipv4Address(pub [u8; 4]);
245
246impl Ipv4Address {
247    pub const UNSPECIFIED: Self = Self([0, 0, 0, 0]);
248
249    #[must_use]
250    #[inline]
251    pub const fn into_std(self) -> core::net::Ipv4Addr {
252        let [a, b, c, d] = self.0;
253        core::net::Ipv4Addr::new(a, b, c, d)
254    }
255
256    #[must_use]
257    #[inline]
258    pub const fn from_std(addr: core::net::Ipv4Addr) -> Self {
259        Ipv4Address(addr.octets())
260    }
261}
262
263impl From<core::net::Ipv4Addr> for Ipv4Address {
264    #[inline]
265    fn from(addr: core::net::Ipv4Addr) -> Self {
266        Ipv4Address::from_std(addr)
267    }
268}
269
270impl From<&core::net::Ipv4Addr> for Ipv4Address {
271    #[inline]
272    fn from(addr: &core::net::Ipv4Addr) -> Self {
273        Ipv4Address::from_std(*addr)
274    }
275}
276
277impl From<Ipv4Address> for core::net::Ipv4Addr {
278    #[inline]
279    fn from(addr: Ipv4Address) -> Self {
280        addr.into_std()
281    }
282}
283
284impl From<&Ipv4Address> for core::net::Ipv4Addr {
285    #[inline]
286    fn from(addr: &Ipv4Address) -> Self {
287        addr.into_std()
288    }
289}
290
291impl fmt::Display for Ipv4Address {
292    #[inline]
293    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
294        let [a, b, c, d] = self.0;
295        write!(f, "{a}.{b}.{c}.{d}")
296    }
297}
298
/// IP protocol number as carried in the IPv4 `protocol` field. It wraps a
/// raw `u8`, so unnamed protocol numbers are representable too.
#[derive(
    Copy,
    Clone,
    Debug,
    FromBytes,
    IntoBytes,
    KnownLayout,
    Immutable,
    Unaligned,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
)]
#[repr(transparent)]
pub struct InetProtocol(pub u8);
316
317impl InetProtocol {
318    pub const ICMP: Self = Self(1);
319    pub const IGMP: Self = Self(2);
320    pub const TCP: Self = Self(6);
321    pub const UDP: Self = Self(17);
322    pub const ENCAP: Self = Self(41);
323    pub const OSPF: Self = Self(89);
324    pub const SCTP: Self = Self(132);
325
326    #[inline]
327    #[must_use]
328    pub const fn name(self) -> Option<&'static str> {
329        Some(match self {
330            Self::ICMP => "ICMP",
331            Self::IGMP => "IGMP",
332            Self::TCP => "TCP",
333            Self::UDP => "UDP",
334            Self::ENCAP => "ENCAP",
335            Self::OSPF => "OSPF",
336            Self::SCTP => "SCTP",
337            _ => return None,
338        })
339    }
340}
341
342impl fmt::Display for InetProtocol {
343    #[inline]
344    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
345        if let Some(name) = self.name() {
346            f.write_str(name)
347        } else {
348            write!(f, "{}", self.0)
349        }
350    }
351}
352
/// The 16-bit flags/fragment-offset word of an IPv4 header, in network byte
/// order: the top 3 bits are flags and the low 13 bits are the fragment offset.
#[derive(
    Copy,
    Clone,
    Debug,
    Default,
    FromBytes,
    IntoBytes,
    KnownLayout,
    Immutable,
    Unaligned,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
)]
#[repr(transparent)]
pub struct Fragmentation(pub network_endian::U16);
371
372impl Fragmentation {
373    #[inline]
374    #[must_use]
375    pub const fn flags(self) -> u16 {
376        self.0.get() >> 13
377    }
378
379    #[inline]
380    #[must_use]
381    pub const fn dont_fragment(self) -> bool {
382        self.flags() & 0b010 != 0
383    }
384
385    #[inline]
386    #[must_use]
387    pub const fn more_fragments(self) -> bool {
388        self.flags() & 0b001 != 0
389    }
390
391    #[inline]
392    #[must_use]
393    #[allow(clippy::as_conversions, reason = "unsigned to usize")]
394    pub const fn offset(self) -> usize {
395        (self.0.get() & 0b1_1111_1111_1111) as usize
396    }
397}
398
/// Word-oriented view of an IPv4 header, used by `Ipv4Header::update_checksum`
/// and `Ipv4Header::verify_checksum`.
// NOTE(review): the first parameter is presumably the number of 16-bit words
// in the fixed header, and `5` presumably selects the checksum word (the
// `checksum` field sits at bytes 10..12, i.e. word index 5). Confirm against
// `ChecksumWords`.
type Ipv4HeaderWords = ChecksumWords<{ Ipv4HeaderFields::WORDS }, 5>;
400
/// Alternate view of an IPv4 header (fixed fields plus options) that exposes
/// what an upper-layer pseudo-header checksum needs: the source and
/// destination addresses, the protocol number and the payload length.
/// Fields prefixed with `_` exist only to preserve the layout.
#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned)]
#[repr(C, packed)]
pub struct Ipv4PseudoHeader {
    // First two header bytes; IHL is bits 8..12 of this big-endian word.
    version_ihl_dscp_ecn: network_endian::U16,
    // Total packet length in bytes.
    total_length: network_endian::U16,
    _identification: network_endian::U16,
    _fragmentation: Fragmentation,
    // TTL in the high byte, protocol number in the low byte.
    ttl_protocol: network_endian::U16,
    _checksum: network_endian::U16,
    // Source address as two big-endian 16-bit words.
    saddr: [network_endian::U16; 2],
    // Destination address as two big-endian 16-bit words.
    daddr: [network_endian::U16; 2],
    // Header options, if any; unused here.
    _options: [u8],
}
415
416impl Ipv4PseudoHeader {
417    #[allow(clippy::as_conversions, reason = "u16 to u32")]
418    #[inline]
419    #[must_use]
420    pub const fn checksum(&self) -> u32 {
421        let mut cs = 0u32;
422        cs = cs.wrapping_add(self.saddr[0].get() as u32);
423        cs = cs.wrapping_add(self.saddr[1].get() as u32);
424        cs = cs.wrapping_add(self.daddr[0].get() as u32);
425        cs = cs.wrapping_add(self.daddr[1].get() as u32);
426        cs = cs.wrapping_add(self.protocol() as u32);
427        cs.wrapping_add(self.payload_length() as u32)
428    }
429
430    const fn protocol(&self) -> u16 {
431        self.ttl_protocol.get() & 0xff
432    }
433
434    const fn header_length(&self) -> u16 {
435        ((self.version_ihl_dscp_ecn.get() & 0x0F00) >> 8).wrapping_mul(4)
436    }
437
438    const fn total_length(&self) -> u16 {
439        self.total_length.get()
440    }
441
442    const fn payload_length(&self) -> u16 {
443        self.total_length().wrapping_sub(self.header_length())
444    }
445}
446
447impl fmt::Debug for Ipv4PseudoHeader {
448    #[inline]
449    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
450        f.debug_struct("Ipv4PseudoHeader")
451            .field("saddr", &self.saddr)
452            .field("daddr", &self.daddr)
453            .field("proto", &self.protocol())
454            .field("len", &self.payload_length())
455            .finish()
456    }
457}
458
/// Errors produced while viewing, splitting or validating IPv4 headers and
/// PDUs.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum Ipv4PduError {
    /// The header length (from IHL, or as requested by the caller) is below
    /// the fixed-header minimum or does not fit in the buffer.
    InvalidHeaderLength,
    /// The header checksum did not verify.
    InvalidChecksum,
    /// A zero-copy size check failed: the buffer cannot hold the requested
    /// structure.
    BufferTooShort,
}

impl fmt::Display for Ipv4PduError {
    /// Short, lowercase, human-readable description of the error.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Self::InvalidHeaderLength => "invalid IPv4 header length",
            Self::InvalidChecksum => "invalid IPv4 header checksum",
            Self::BufferTooShort => "buffer too short for IPv4 PDU",
        })
    }
}

// `core::error::Error` (not `std`) keeps this usable in `no_std` builds.
impl core::error::Error for Ipv4PduError {}
465
466impl<T: ?Sized> From<zerocopy::SizeError<&mut [u8], T>> for Ipv4PduError {
468    #[inline]
469    fn from(_err: zerocopy::SizeError<&mut [u8], T>) -> Self {
470        Ipv4PduError::BufferTooShort
471    }
472}
473
474impl<T: ?Sized> From<zerocopy::SizeError<&[u8], T>> for Ipv4PduError {
475    #[inline]
476    fn from(_err: zerocopy::SizeError<&[u8], T>) -> Self {
477        Ipv4PduError::BufferTooShort
478    }
479}