nex_packet/
ipv4.rs

1//! An IPv4 packet abstraction.
2
3use crate::{
4    checksum::{ChecksumMode, ChecksumState},
5    ip::IpNextProtocol,
6    packet::{MutablePacket, Packet},
7    util,
8};
9use bytes::{BufMut, Bytes, BytesMut};
10use nex_core::bitfield::*;
11use std::net::Ipv4Addr;
12
13#[cfg(feature = "serde")]
14use serde::{Deserialize, Serialize};
15
/// Length in bytes of the fixed IPv4 header, i.e. the header without any options.
pub const IPV4_HEADER_LEN: usize = 20;
/// Number of bytes represented by one unit of the IHL (header length) field:
/// the field counts 32-bit words, so one unit is 4 bytes.
pub const IPV4_HEADER_LENGTH_BYTE_UNITS: usize = 4;
20
/// Values for the 3-bit IPv4 flags field (`Ipv4Header::flags`).
///
/// The constants are expressed in the unshifted 3-bit form that `Ipv4Header::flags`
/// stores (the top three bits of header byte 6), not as pre-shifted byte masks.
// The module and constant names are CamelCase, presumably so call sites read like an
// enum (`Ipv4Flags::DontFragment`); the allows silence the resulting naming lints.
#[allow(non_snake_case)]
#[allow(non_upper_case_globals)]
pub mod Ipv4Flags {
    use nex_core::bitfield::*;
    /// Don't Fragment (DF) flag.
    pub const DontFragment: u3 = 0b010;
    /// More Fragments (MF) flag.
    pub const MoreFragments: u3 = 0b001;
}
31
/// IPv4 option number: the low five bits of an option's type octet.
/// <http://www.iana.org/assignments/ip-parameters/ip-parameters.xhtml>
///
/// The copied flag and option class, which occupy the upper three bits of the type
/// octet, are carried separately in [`Ipv4OptionHeader`].
///
/// NOTE(review): `Unknown(n)` can be constructed directly with a number that also has
/// a named variant (e.g. `Unknown(7)`); such a value does not compare equal to `RR`.
/// [`Ipv4OptionType::new`] never produces that form.
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum Ipv4OptionType {
    /// End of Options List
    EOL = 0,
    /// No Operation
    NOP = 1,
    /// Security
    SEC = 2,
    /// Loose Source Route
    LSR = 3,
    /// Time Stamp
    TS = 4,
    /// Extended Security
    ESEC = 5,
    /// Commercial Security
    CIPSO = 6,
    /// Record Route
    RR = 7,
    /// Stream ID
    SID = 8,
    /// Strict Source Route
    SSR = 9,
    /// Experimental Measurement
    ZSU = 10,
    /// MTU Probe
    MTUP = 11,
    /// MTU Reply
    MTUR = 12,
    /// Experimental Flow Control
    FINN = 13,
    /// Experimental Access Control
    VISA = 14,
    /// Encode
    ENCODE = 15,
    /// IMI Traffic Descriptor
    IMITD = 16,
    /// Extended Internet Protocol
    EIP = 17,
    /// Traceroute
    TR = 18,
    /// Address Extension
    ADDEXT = 19,
    /// Router Alert
    RTRALT = 20,
    /// Selective Directed Broadcast
    SDB = 21,
    /// Unassigned
    Unassigned = 22,
    /// Dynamic Packet State
    DPS = 23,
    /// Upstream Multicast Packet
    UMP = 24,
    /// Quick-Start
    QS = 25,
    /// RFC3692-style Experiment
    EXP = 30,
    /// Any option number without a dedicated variant; holds the raw number.
    Unknown(u8),
}
95
96impl Ipv4OptionType {
97    /// Constructs a new Ipv4OptionType from u8
98    pub fn new(n: u8) -> Ipv4OptionType {
99        match n {
100            0 => Ipv4OptionType::EOL,
101            1 => Ipv4OptionType::NOP,
102            2 => Ipv4OptionType::SEC,
103            3 => Ipv4OptionType::LSR,
104            4 => Ipv4OptionType::TS,
105            5 => Ipv4OptionType::ESEC,
106            6 => Ipv4OptionType::CIPSO,
107            7 => Ipv4OptionType::RR,
108            8 => Ipv4OptionType::SID,
109            9 => Ipv4OptionType::SSR,
110            10 => Ipv4OptionType::ZSU,
111            11 => Ipv4OptionType::MTUP,
112            12 => Ipv4OptionType::MTUR,
113            13 => Ipv4OptionType::FINN,
114            14 => Ipv4OptionType::VISA,
115            15 => Ipv4OptionType::ENCODE,
116            16 => Ipv4OptionType::IMITD,
117            17 => Ipv4OptionType::EIP,
118            18 => Ipv4OptionType::TR,
119            19 => Ipv4OptionType::ADDEXT,
120            20 => Ipv4OptionType::RTRALT,
121            21 => Ipv4OptionType::SDB,
122            22 => Ipv4OptionType::Unassigned,
123            23 => Ipv4OptionType::DPS,
124            24 => Ipv4OptionType::UMP,
125            25 => Ipv4OptionType::QS,
126            30 => Ipv4OptionType::EXP,
127            _ => Ipv4OptionType::Unknown(n),
128        }
129    }
130    pub fn value(&self) -> u8 {
131        match *self {
132            Ipv4OptionType::EOL => 0,
133            Ipv4OptionType::NOP => 1,
134            Ipv4OptionType::SEC => 2,
135            Ipv4OptionType::LSR => 3,
136            Ipv4OptionType::TS => 4,
137            Ipv4OptionType::ESEC => 5,
138            Ipv4OptionType::CIPSO => 6,
139            Ipv4OptionType::RR => 7,
140            Ipv4OptionType::SID => 8,
141            Ipv4OptionType::SSR => 9,
142            Ipv4OptionType::ZSU => 10,
143            Ipv4OptionType::MTUP => 11,
144            Ipv4OptionType::MTUR => 12,
145            Ipv4OptionType::FINN => 13,
146            Ipv4OptionType::VISA => 14,
147            Ipv4OptionType::ENCODE => 15,
148            Ipv4OptionType::IMITD => 16,
149            Ipv4OptionType::EIP => 17,
150            Ipv4OptionType::TR => 18,
151            Ipv4OptionType::ADDEXT => 19,
152            Ipv4OptionType::RTRALT => 20,
153            Ipv4OptionType::SDB => 21,
154            Ipv4OptionType::Unassigned => 22,
155            Ipv4OptionType::DPS => 23,
156            Ipv4OptionType::UMP => 24,
157            Ipv4OptionType::QS => 25,
158            Ipv4OptionType::EXP => 30,
159            Ipv4OptionType::Unknown(n) => n,
160        }
161    }
162}
163
/// Decoded type octet (and optional length octet) of a single IPv4 option.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Ipv4OptionHeader {
    /// Copied flag: bit 7 of the type octet.
    pub copied: u1,
    /// Option class: bits 5-6 of the type octet.
    pub class: u2,
    /// Option number: the low five bits of the type octet.
    pub number: Ipv4OptionType,
    /// Total option length in bytes, including the type and length octets.
    /// `None` for the single-byte EOL and NOP options, which carry no length octet.
    pub length: Option<u8>,
}
173
/// A single IPv4 option: its header plus option data.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Ipv4OptionPacket {
    /// Type (and, where present, length) information.
    pub header: Ipv4OptionHeader,
    /// Option data following the type and length octets; empty for EOL and NOP.
    pub data: Bytes,
}
181
/// Decoded IPv4 header, including any options.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Ipv4Header {
    /// IP version (upper nibble of byte 0).
    pub version: u4,
    /// Header length (IHL) in 32-bit words (lower nibble of byte 0).
    pub header_length: u4,
    /// Differentiated Services Code Point (upper six bits of byte 1).
    pub dscp: u6,
    /// Explicit Congestion Notification (lower two bits of byte 1).
    pub ecn: u2,
    /// Total Length field: header plus payload, in bytes.
    pub total_length: u16be,
    /// Identification field.
    pub identification: u16be,
    /// 3-bit flags field; see [`Ipv4Flags`].
    pub flags: u3,
    /// 13-bit fragment offset field.
    pub fragment_offset: u13be,
    /// Time To Live.
    pub ttl: u8,
    /// Protocol of the encapsulated payload.
    pub next_level_protocol: IpNextProtocol,
    /// Header checksum as stored in the packet.
    pub checksum: u16be,
    /// Source address.
    pub source: Ipv4Addr,
    /// Destination address.
    pub destination: Ipv4Addr,
    /// Options decoded from the bytes between the fixed header and `header_length * 4`.
    pub options: Vec<Ipv4OptionPacket>,
}
201
/// An owned IPv4 packet: decoded header plus payload bytes.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Ipv4Packet {
    /// Decoded header, including options.
    pub header: Ipv4Header,
    /// Bytes following the header, bounded by the Total Length field when parsed.
    pub payload: Bytes,
}
209
210impl Packet for Ipv4Packet {
211    type Header = Ipv4Header;
212
213    fn from_buf(bytes: &[u8]) -> Option<Self> {
214        if bytes.len() < IPV4_HEADER_LEN {
215            return None;
216        }
217
218        let version = (bytes[0] & 0xF0) >> 4;
219        let header_length = (bytes[0] & 0x0F) as usize;
220        let total_length = u16::from_be_bytes([bytes[2], bytes[3]]) as usize;
221        let total_length = if total_length > bytes.len() {
222            // fallback
223            bytes.len()
224        } else {
225            total_length
226        };
227
228        if header_length < 5 {
229            return None;
230        }
231
232        let ihl_bytes = header_length * 4;
233        if ihl_bytes < IPV4_HEADER_LEN || ihl_bytes > total_length {
234            return None;
235        }
236        let payload = Bytes::copy_from_slice(&bytes[ihl_bytes..total_length]);
237
238        let mut options = Vec::new();
239        let mut i = IPV4_HEADER_LEN;
240
241        while i < ihl_bytes {
242            let b = bytes[i];
243            let copied = (b >> 7) & 0x01;
244            let class = (b >> 5) & 0x03;
245            let number = Ipv4OptionType::new(b & 0b0001_1111);
246
247            match number {
248                Ipv4OptionType::EOL => {
249                    options.push(Ipv4OptionPacket {
250                        header: Ipv4OptionHeader {
251                            copied,
252                            class,
253                            number,
254                            length: None,
255                        },
256                        data: Bytes::new(),
257                    });
258                    break;
259                }
260                Ipv4OptionType::NOP => {
261                    options.push(Ipv4OptionPacket {
262                        header: Ipv4OptionHeader {
263                            copied,
264                            class,
265                            number,
266                            length: None,
267                        },
268                        data: Bytes::new(),
269                    });
270                    i += 1;
271                }
272                _ => {
273                    if i + 2 > ihl_bytes {
274                        break;
275                    }
276                    let len = bytes[i + 1] as usize;
277                    if len < 2 || i + len > ihl_bytes {
278                        break;
279                    }
280
281                    let data = Bytes::copy_from_slice(&bytes[i + 2..i + len]);
282
283                    options.push(Ipv4OptionPacket {
284                        header: Ipv4OptionHeader {
285                            copied,
286                            class,
287                            number,
288                            length: Some(len as u8),
289                        },
290                        data,
291                    });
292
293                    i += len;
294                }
295            }
296        }
297
298        Some(Self {
299            header: Ipv4Header {
300                version: version as u4,
301                header_length: header_length as u4,
302                dscp: (bytes[1] >> 2) as u6,
303                ecn: (bytes[1] & 0x03) as u2,
304                total_length: u16::from_be_bytes([bytes[2], bytes[3]]) as u16be,
305                identification: u16::from_be_bytes([bytes[4], bytes[5]]) as u16be,
306                flags: (bytes[6] >> 5) as u3,
307                fragment_offset: ((u16::from_be_bytes([bytes[6], bytes[7]])) & 0x1FFF) as u13be,
308                ttl: bytes[8],
309                next_level_protocol: IpNextProtocol::new(bytes[9]),
310                checksum: u16::from_be_bytes([bytes[10], bytes[11]]) as u16be,
311                source: Ipv4Addr::new(bytes[12], bytes[13], bytes[14], bytes[15]),
312                destination: Ipv4Addr::new(bytes[16], bytes[17], bytes[18], bytes[19]),
313                options,
314            },
315            payload,
316        })
317    }
318
319    fn from_bytes(bytes: Bytes) -> Option<Self> {
320        Self::from_buf(&bytes)
321    }
322
323    fn to_bytes(&self) -> Bytes {
324        // 1. Version/IHL + DSCP/ECN
325        let mut tmp_buf = BytesMut::with_capacity(60); // max header size
326        for option in &self.header.options {
327            let number = option.header.number.value();
328            let type_byte =
329                (option.header.copied << 7) | (option.header.class << 5) | (number & 0b0001_1111);
330            tmp_buf.put_u8(type_byte);
331
332            match option.header.number {
333                Ipv4OptionType::EOL | Ipv4OptionType::NOP => {}
334                _ => {
335                    let len = option
336                        .header
337                        .length
338                        .unwrap_or((option.data.len() + 2) as u8);
339                    tmp_buf.put_u8(len);
340                    tmp_buf.extend_from_slice(&option.data);
341                }
342            }
343        }
344
345        // padding
346        while tmp_buf.len() % 4 != 0 {
347            tmp_buf.put_u8(0);
348        }
349
350        let header_len = IPV4_HEADER_LEN + tmp_buf.len();
351
352        let total_len_expected = header_len + self.payload.len();
353        // Check if the total length exceeds the header's total_length field
354        if total_len_expected > self.header.total_length as usize {
355            panic!(
356                "Payload too long: header {} + payload {} = {} > total_length {}",
357                header_len,
358                self.payload.len(),
359                total_len_expected,
360                self.header.total_length
361            );
362        }
363
364        let header_len_words = (header_len / 4) as u8;
365
366        let mut buf = BytesMut::with_capacity(self.total_len());
367
368        buf.put_u8((self.header.version << 4 | header_len_words) as u8);
369        buf.put_u8((self.header.dscp << 2 | self.header.ecn) as u8);
370
371        // 2. Fixed header fields
372        buf.put_u16(self.header.total_length);
373        buf.put_u16(self.header.identification);
374        buf.put_u16(((self.header.flags as u16) << 13) | self.header.fragment_offset);
375        buf.put_u8(self.header.ttl);
376        buf.put_u8(self.header.next_level_protocol.value());
377        buf.put_u16(self.header.checksum);
378        buf.extend_from_slice(&self.header.source.octets());
379        buf.extend_from_slice(&self.header.destination.octets());
380
381        // 3. options
382        buf.extend_from_slice(&tmp_buf);
383
384        // 4. payload
385        buf.extend_from_slice(&self.payload);
386
387        buf.freeze()
388    }
389
390    fn header(&self) -> Bytes {
391        self.to_bytes().slice(..self.header_len())
392    }
393
394    fn payload(&self) -> Bytes {
395        self.payload.clone()
396    }
397
398    fn header_len(&self) -> usize {
399        self.header.header_length as usize * 4
400    }
401
402    fn payload_len(&self) -> usize {
403        self.payload.len()
404    }
405
406    fn total_len(&self) -> usize {
407        self.header_len() + self.payload_len()
408    }
409
410    fn into_parts(self) -> (Self::Header, Bytes) {
411        (self.header, self.payload)
412    }
413}
414
415impl Ipv4Packet {
416    pub fn with_computed_checksum(mut self) -> Self {
417        self.header.checksum = checksum(&self);
418        self
419    }
420}
421
/// A mutable, zero-copy view of an IPv4 packet stored in a caller-provided buffer.
///
/// Setters write straight into the buffer and track whether the header checksum is
/// stale; in automatic checksum mode the checksum is rewritten after each field write.
pub struct MutableIpv4Packet<'a> {
    /// Underlying packet bytes; header fields are read and written in place.
    buffer: &'a mut [u8],
    /// Checksum mode and dirty-flag bookkeeping.
    checksum: ChecksumState,
}
427
428impl<'a> MutablePacket<'a> for MutableIpv4Packet<'a> {
429    type Packet = Ipv4Packet;
430
431    fn new(buffer: &'a mut [u8]) -> Option<Self> {
432        if buffer.len() < IPV4_HEADER_LEN {
433            return None;
434        }
435
436        let ihl = (buffer[0] & 0x0F) as usize;
437        if ihl < 5 {
438            return None;
439        }
440
441        let header_len = ihl * IPV4_HEADER_LENGTH_BYTE_UNITS;
442        if header_len > buffer.len() {
443            return None;
444        }
445
446        let total_len = u16::from_be_bytes([buffer[2], buffer[3]]) as usize;
447        if total_len != 0 && total_len < header_len {
448            return None;
449        }
450
451        Some(Self {
452            buffer,
453            checksum: ChecksumState::new(),
454        })
455    }
456
457    fn packet(&self) -> &[u8] {
458        &*self.buffer
459    }
460
461    fn packet_mut(&mut self) -> &mut [u8] {
462        &mut *self.buffer
463    }
464
465    fn header(&self) -> &[u8] {
466        let header_len = self.header_len();
467        &self.packet()[..header_len]
468    }
469
470    fn header_mut(&mut self) -> &mut [u8] {
471        let header_len = self.header_len();
472        let (header, _) = (&mut *self.buffer).split_at_mut(header_len);
473        header
474    }
475
476    fn payload(&self) -> &[u8] {
477        let start = self.header_len();
478        let end = start + self.payload_len();
479        &self.packet()[start..end]
480    }
481
482    fn payload_mut(&mut self) -> &mut [u8] {
483        let header_len = self.header_len();
484        let payload_len = self.payload_len();
485        let (_, payload) = (&mut *self.buffer).split_at_mut(header_len);
486        &mut payload[..payload_len]
487    }
488}
489
490impl<'a> MutableIpv4Packet<'a> {
491    /// Create a mutable packet without validating the header fields.
492    pub fn new_unchecked(buffer: &'a mut [u8]) -> Self {
493        Self {
494            buffer,
495            checksum: ChecksumState::new(),
496        }
497    }
498
499    fn raw(&self) -> &[u8] {
500        &*self.buffer
501    }
502
503    fn raw_mut(&mut self) -> &mut [u8] {
504        &mut *self.buffer
505    }
506
507    fn after_field_mutation(&mut self) {
508        self.checksum.mark_dirty();
509        if self.checksum.automatic() {
510            let _ = self.recompute_checksum();
511        }
512    }
513
514    fn write_checksum(&mut self, checksum: u16) {
515        self.raw_mut()[10..12].copy_from_slice(&checksum.to_be_bytes());
516    }
517
518    /// Returns the current checksum recalculation mode.
519    pub fn checksum_mode(&self) -> ChecksumMode {
520        self.checksum.mode()
521    }
522
523    /// Updates the checksum recalculation mode.
524    pub fn set_checksum_mode(&mut self, mode: ChecksumMode) {
525        self.checksum.set_mode(mode);
526        if self.checksum.automatic() && self.checksum.is_dirty() {
527            let _ = self.recompute_checksum();
528        }
529    }
530
531    /// Enables automatic checksum recalculation.
532    pub fn enable_auto_checksum(&mut self) {
533        self.set_checksum_mode(ChecksumMode::Automatic);
534    }
535
536    /// Disables automatic checksum recalculation.
537    pub fn disable_auto_checksum(&mut self) {
538        self.set_checksum_mode(ChecksumMode::Manual);
539    }
540
541    /// Returns true when the checksum must be recomputed before serialization.
542    pub fn is_checksum_dirty(&self) -> bool {
543        self.checksum.is_dirty()
544    }
545
546    /// Marks the checksum as stale and triggers recomputation when automatic mode is enabled.
547    pub fn mark_checksum_dirty(&mut self) {
548        self.checksum.mark_dirty();
549        if self.checksum.automatic() {
550            let _ = self.recompute_checksum();
551        }
552    }
553
554    /// Recomputes the IPv4 header checksum using the current buffer contents.
555    pub fn recompute_checksum(&mut self) -> Option<u16> {
556        let header_len = self.header_len();
557        if header_len > self.raw().len() {
558            return None;
559        }
560
561        let checksum = util::checksum(&self.raw()[..header_len], 5) as u16;
562        self.write_checksum(checksum);
563        self.checksum.clear_dirty();
564        Some(checksum)
565    }
566
567    /// Returns the header length in bytes.
568    pub fn header_len(&self) -> usize {
569        let ihl = (self.raw()[0] & 0x0F) as usize;
570        let header_len = ihl * IPV4_HEADER_LENGTH_BYTE_UNITS;
571        header_len.max(IPV4_HEADER_LEN).min(self.raw().len())
572    }
573
574    /// Returns the payload length based on the total length field.
575    pub fn payload_len(&self) -> usize {
576        let total = self.total_len();
577        total.saturating_sub(self.header_len())
578    }
579
580    /// Returns the effective total length of the packet.
581    pub fn total_len(&self) -> usize {
582        let total = u16::from_be_bytes([self.raw()[2], self.raw()[3]]) as usize;
583        if total == 0 {
584            self.raw().len()
585        } else {
586            total.min(self.raw().len())
587        }
588    }
589
590    /// Retrieve the version field.
591    pub fn get_version(&self) -> u8 {
592        self.raw()[0] >> 4
593    }
594
595    /// Update the version field.
596    pub fn set_version(&mut self, version: u8) {
597        let buffer = self.raw_mut();
598        buffer[0] = (buffer[0] & 0x0F) | ((version & 0x0F) << 4);
599        self.after_field_mutation();
600    }
601
602    /// Retrieve the header length in 32-bit words.
603    pub fn get_header_length(&self) -> u8 {
604        self.raw()[0] & 0x0F
605    }
606
607    /// Update the header length in 32-bit words.
608    pub fn set_header_length(&mut self, ihl: u8) {
609        let buffer = self.raw_mut();
610        buffer[0] = (buffer[0] & 0xF0) | (ihl & 0x0F);
611        self.after_field_mutation();
612    }
613
614    /// Retrieve the DSCP field.
615    pub fn get_dscp(&self) -> u8 {
616        self.raw()[1] >> 2
617    }
618
619    /// Update the DSCP field.
620    pub fn set_dscp(&mut self, dscp: u8) {
621        let buffer = self.raw_mut();
622        buffer[1] = (buffer[1] & 0x03) | ((dscp & 0x3F) << 2);
623        self.after_field_mutation();
624    }
625
626    /// Retrieve the ECN field.
627    pub fn get_ecn(&self) -> u8 {
628        self.raw()[1] & 0x03
629    }
630
631    /// Update the ECN field.
632    pub fn set_ecn(&mut self, ecn: u8) {
633        let buffer = self.raw_mut();
634        buffer[1] = (buffer[1] & 0xFC) | (ecn & 0x03);
635        self.after_field_mutation();
636    }
637
638    /// Retrieve the total length field.
639    pub fn get_total_length(&self) -> u16 {
640        u16::from_be_bytes([self.raw()[2], self.raw()[3]])
641    }
642
643    /// Update the total length field.
644    pub fn set_total_length(&mut self, len: u16) {
645        self.raw_mut()[2..4].copy_from_slice(&len.to_be_bytes());
646        self.after_field_mutation();
647    }
648
649    /// Retrieve the identification field.
650    pub fn get_identification(&self) -> u16 {
651        u16::from_be_bytes([self.raw()[4], self.raw()[5]])
652    }
653
654    /// Update the identification field.
655    pub fn set_identification(&mut self, id: u16) {
656        self.raw_mut()[4..6].copy_from_slice(&id.to_be_bytes());
657        self.after_field_mutation();
658    }
659
660    /// Retrieve the flags field.
661    pub fn get_flags(&self) -> u8 {
662        (self.raw()[6] & 0xE0) >> 5
663    }
664
665    /// Update the flags field.
666    pub fn set_flags(&mut self, flags: u8) {
667        let buffer = self.raw_mut();
668        buffer[6] = (buffer[6] & 0x1F) | ((flags & 0x07) << 5);
669        self.after_field_mutation();
670    }
671
672    /// Retrieve the fragment offset field.
673    pub fn get_fragment_offset(&self) -> u16 {
674        u16::from_be_bytes([self.raw()[6], self.raw()[7]]) & 0x1FFF
675    }
676
677    /// Update the fragment offset field.
678    pub fn set_fragment_offset(&mut self, offset: u16) {
679        let buffer = self.raw_mut();
680        let combined = (u16::from_be_bytes([buffer[6], buffer[7]]) & 0xE000) | (offset & 0x1FFF);
681        buffer[6..8].copy_from_slice(&combined.to_be_bytes());
682        self.after_field_mutation();
683    }
684
685    /// Retrieve the TTL field.
686    pub fn get_ttl(&self) -> u8 {
687        self.raw()[8]
688    }
689
690    /// Update the TTL field.
691    pub fn set_ttl(&mut self, ttl: u8) {
692        self.raw_mut()[8] = ttl;
693        self.after_field_mutation();
694    }
695
696    /// Retrieve the next-level protocol field.
697    pub fn get_next_level_protocol(&self) -> IpNextProtocol {
698        IpNextProtocol::new(self.raw()[9])
699    }
700
701    /// Update the next-level protocol field.
702    pub fn set_next_level_protocol(&mut self, proto: IpNextProtocol) {
703        self.raw_mut()[9] = proto.value();
704        self.after_field_mutation();
705    }
706
707    /// Retrieve the checksum field.
708    pub fn get_checksum(&self) -> u16 {
709        u16::from_be_bytes([self.raw()[10], self.raw()[11]])
710    }
711
712    /// Update the checksum field.
713    pub fn set_checksum(&mut self, checksum: u16) {
714        self.write_checksum(checksum);
715        self.checksum.clear_dirty();
716    }
717
718    /// Retrieve the source address.
719    pub fn get_source(&self) -> Ipv4Addr {
720        Ipv4Addr::new(
721            self.raw()[12],
722            self.raw()[13],
723            self.raw()[14],
724            self.raw()[15],
725        )
726    }
727
728    /// Update the source address.
729    pub fn set_source(&mut self, addr: Ipv4Addr) {
730        self.raw_mut()[12..16].copy_from_slice(&addr.octets());
731        self.after_field_mutation();
732    }
733
734    /// Retrieve the destination address.
735    pub fn get_destination(&self) -> Ipv4Addr {
736        Ipv4Addr::new(
737            self.raw()[16],
738            self.raw()[17],
739            self.raw()[18],
740            self.raw()[19],
741        )
742    }
743
744    /// Update the destination address.
745    pub fn set_destination(&mut self, addr: Ipv4Addr) {
746        self.raw_mut()[16..20].copy_from_slice(&addr.octets());
747        self.after_field_mutation();
748    }
749}
750
751/// Calculates a checksum of an IPv4 packet header.
752/// The checksum field of the packet is regarded as zeros during the calculation.
753pub fn checksum(packet: &Ipv4Packet) -> u16be {
754    use crate::util;
755
756    let bytes = packet.to_bytes();
757    let len = packet.header_len();
758    util::checksum(&bytes[..len], 5)
759}
760
#[cfg(test)]
mod tests {
    use super::*;
    // A plain 20-byte header (IHL=5) with an 8-byte payload must survive
    // parse -> serialize byte-for-byte.
    #[test]
    fn test_ipv4_packet_round_trip() {
        let raw = Bytes::from_static(&[
            0x45, 0x00, 0x00, 0x1c, // Version + IHL, DSCP + ECN, Total Length (28)
            0x1c, 0x46, 0x40, 0x00, // Identification, Flags + Fragment Offset
            0x40, 0x06, 0xb1, 0xe6, // TTL, Protocol (TCP), Header checksum
            0xc0, 0xa8, 0x00, 0x01, // Source: 192.168.0.1
            0xc0, 0xa8, 0x00, 0xc7, // Destination: 192.168.0.199
            // Payload (8 bytes)
            0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0xba, 0xbe,
        ]);

        let packet = Ipv4Packet::from_bytes(raw.clone()).expect("Failed to parse Ipv4Packet");
        assert_eq!(packet.header.version, 4);
        assert_eq!(packet.header.header_length, 5);
        assert_eq!(packet.header.total_length, 28u16);
        assert_eq!(packet.header.source, Ipv4Addr::new(192, 168, 0, 1));
        assert_eq!(packet.header.destination, Ipv4Addr::new(192, 168, 0, 199));
        assert_eq!(
            packet.payload,
            Bytes::from_static(&[0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0xba, 0xbe])
        );

        let serialized = packet.to_bytes();
        assert_eq!(&serialized[..], &raw[..]);
    }

    // Options (NOP, a 4-byte RR option, EOL + padding) are decoded and re-encoded
    // to the same bytes.
    #[test]
    fn test_ipv4_packet_with_options_round_trip() {
        let raw = Bytes::from_static(&[
            // IPv4 header: 20 fixed bytes + 8 option bytes = 28 bytes (IHL=7);
            // with the 4-byte payload the total length is 32 bytes.
            0x47, 0x00, 0x00,
            0x20, // [0-3] Version(4), IHL(7=28bytes), DSCP/ECN, Total Length=32 bytes
            0x12, 0x34, 0x40, 0x00, // [4-7] Identification, Flags=DF(0x40), Fragment Offset
            0x40, 0x11, 0x00,
            0x00, // [8-11] TTL=64, Protocol=17(UDP), Header Checksum (0 for now)
            0xc0, 0xa8, 0x00, 0x01, // [12-15] Source IP = 192.168.0.1
            0xc0, 0xa8, 0x00, 0x02, // [16-19] Destination IP = 192.168.0.2
            // IPv4 options (8 bytes)
            // Option 1: 1-byte NOP
            0x01, // [20] NOP (No Operation)
            // Option 2: 4 bytes
            0x87, 0x04, 0x12,
            0x34, // [21-24] Option Type=RR(7), Copied=1, Class=0, Length=4, Data=[0x12, 0x34]
            // Option 3: EOL (End of Options List) followed by padding
            0x00, // [25] EOL (End of Options List)
            0x00, // [26] Padding
            0x00, // [27] Padding
            // Payload (4 bytes)
            0xde, 0xad, 0xbe, 0xef, // [28-31] Payload: deadbeef
        ]);

        let packet = Ipv4Packet::from_bytes(raw.clone()).expect("Failed to parse Ipv4Packet");

        assert_eq!(packet.header.version, 4);
        assert_eq!(packet.header.header_length, 7);
        assert_eq!(packet.header.total_length, 32);
        assert_eq!(packet.header.source, Ipv4Addr::new(192, 168, 0, 1));
        assert_eq!(packet.header.destination, Ipv4Addr::new(192, 168, 0, 2));

        assert_eq!(
            packet.payload,
            Bytes::from_static(&[0xde, 0xad, 0xbe, 0xef])
        );

        assert_eq!(packet.header.options.len(), 3);
        assert_eq!(packet.header.options[0].header.number, Ipv4OptionType::NOP);
        assert_eq!(packet.header.options[1].header.copied, 1);
        assert_eq!(packet.header.options[1].header.class, 0);
        assert_eq!(packet.header.options[1].header.number, Ipv4OptionType::RR);
        assert_eq!(packet.header.options[1].header.number.value(), 7);
        assert_eq!(packet.header.options[1].header.length, Some(4));
        assert_eq!(packet.header.options[1].data.as_ref(), &[0x12, 0x34]);
        assert_eq!(packet.header.options[2].header.number, Ipv4OptionType::EOL);

        let serialized = packet.to_bytes();
        assert_eq!(&serialized[..], &raw[..]);
    }

    // Checks the option-type octet bit layout by hand (copied=1, class=0, LSR=3 -> 0x83).
    // This does not go through `to_bytes`.
    #[test]
    fn ipv4_option_packet_test() {
        let option = Ipv4OptionPacket {
            header: Ipv4OptionHeader {
                copied: 1,
                class: 0,
                number: Ipv4OptionType::LSR,
                length: Some(3),
            },
            data: Bytes::from_static(&[0x10]),
        };

        let mut buf = BytesMut::new();
        let ty = (option.header.copied << 7)
            | (option.header.class << 5)
            | (option.header.number.value() & 0x1F);
        buf.put_u8(ty);
        buf.put_u8(3);
        buf.put_slice(&[0x10]);

        assert_eq!(buf.freeze(), Bytes::from_static(&[0x83, 0x03, 0x10]));
    }

    // `to_bytes` must refuse to emit a packet whose header + payload exceed total_length.
    #[test]
    #[should_panic(expected = "Payload too long")]
    fn ipv4_payload_too_long_should_panic() {
        let packet = Ipv4Packet {
            header: Ipv4Header {
                version: 4,
                header_length: 5,
                dscp: 0,
                ecn: 0,
                total_length: 24, // Room for header 20 + payload 4, but 6 payload bytes are supplied
                identification: 0,
                flags: 0,
                fragment_offset: 0,
                ttl: 64,
                next_level_protocol: IpNextProtocol::Udp,
                checksum: 0,
                source: Ipv4Addr::LOCALHOST,
                destination: Ipv4Addr::LOCALHOST,
                options: vec![],
            },
            payload: Bytes::from_static(&[0, 1, 2, 3, 4, 5]), // 6 bytes payload
        };

        // This should panic because the payload length exceeds the total_length specified in the header
        let _ = packet.to_bytes();
    }

    // A checksum computed by `checksum` survives serialization and reparsing, and only
    // bytes 10..12 differ from the zero-checksum input.
    #[test]
    fn test_ipv4_checksum() {
        let raw = Bytes::from_static(&[
            0x45, 0x00, 0x00, 0x14, 0x00, 0x00, 0x40, 0x00, 0x40, 0x06, 0x00,
            0x00, // checksum: 0
            0x0a, 0x00, 0x00, 0x01, 0x0a, 0x00, 0x00, 0x02,
        ]);

        let mut packet = Ipv4Packet::from_bytes(raw.clone()).expect("Failed to parse");
        let computed = checksum(&packet);
        packet.header.checksum = computed;

        let serialized = packet.to_bytes();
        let reparsed = Ipv4Packet::from_bytes(serialized).expect("Reparse failed");

        // Check if the checksum matches
        assert_eq!(reparsed.header.checksum, computed);

        // Check if the serialized bytes match the original raw bytes
        let mut raw_copy = raw.to_vec();
        raw_copy[10] = (computed >> 8) as u8;
        raw_copy[11] = (computed & 0xff) as u8;
        assert_eq!(&packet.to_bytes()[..], &raw_copy[..]);
    }

    // Setters and payload_mut write through to the caller's buffer, and `freeze`
    // reflects the edits.
    #[test]
    fn test_mutable_ipv4_packet_updates() {
        let mut raw = [
            0x45, 0x00, 0x00, 0x1c, // Version + IHL, DSCP/ECN, Total Length
            0x1c, 0x46, 0x40, 0x00, // Identification, Flags/Fragment offset
            0x40, 0x06, 0x00, 0x00, // TTL, Protocol, Header checksum
            0xc0, 0xa8, 0x00, 0x01, // Source
            0xc0, 0xa8, 0x00, 0xc7, // Destination
            0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0xba, 0xbe, // Payload
        ];

        let mut packet = MutableIpv4Packet::new(&mut raw).expect("mutable ipv4");
        assert_eq!(packet.get_version(), 4);
        assert_eq!(packet.get_ttl(), 0x40);

        packet.set_ttl(128);
        packet.set_destination(Ipv4Addr::new(192, 0, 2, 1));
        packet.payload_mut()[0] = 0x11;

        {
            let packet_view = packet.packet();
            assert_eq!(packet_view[8], 128);
            assert_eq!(&packet_view[16..20], &[192, 0, 2, 1]);
            assert_eq!(packet_view[20], 0x11);
        }

        let frozen = packet.freeze().expect("freeze mutable packet");
        drop(packet);

        assert_eq!(raw[8], 128);
        assert_eq!(&raw[16..20], &[192, 0, 2, 1]);
        assert_eq!(raw[20], 0x11);

        assert_eq!(frozen.header.ttl, 128);
        assert_eq!(frozen.header.destination, Ipv4Addr::new(192, 0, 2, 1));
        assert_eq!(frozen.payload[0], 0x11);
    }

    // In automatic mode a field write rewrites the checksum immediately, and the result
    // agrees with the free `checksum` function.
    #[test]
    fn test_ipv4_auto_checksum_updates() {
        let mut raw = [
            0x45, 0x00, 0x00, 0x1c, 0x1c, 0x46, 0x40, 0x00, 0x40, 0x06, 0x00, 0x00, 0xc0, 0xa8,
            0x00, 0x01, 0xc0, 0xa8, 0x00, 0xc7, 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0xba, 0xbe,
        ];

        let mut packet = MutableIpv4Packet::new(&mut raw).expect("mutable ipv4");
        packet.enable_auto_checksum();
        let baseline = packet.recompute_checksum().expect("checksum");
        let before = packet.get_checksum();
        assert_eq!(baseline, before);

        packet.set_ttl(0x41);
        let after = packet.get_checksum();
        assert_ne!(before, after);
        assert!(!packet.is_checksum_dirty());

        let frozen = packet.freeze().expect("freeze");
        let expected = checksum(&frozen);
        assert_eq!(after, expected);
    }

    // Without automatic mode, a field write only marks the checksum dirty until
    // `recompute_checksum` is called.
    #[test]
    fn test_ipv4_manual_checksum_tracking() {
        let mut raw = [
            0x45, 0x00, 0x00, 0x1c, 0x1c, 0x46, 0x40, 0x00, 0x40, 0x06, 0xb1, 0xe6, 0xc0, 0xa8,
            0x00, 0x01, 0xc0, 0xa8, 0x00, 0xc7, 0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0xba, 0xbe,
        ];

        let mut packet = MutableIpv4Packet::new(&mut raw).expect("mutable ipv4");
        assert!(!packet.is_checksum_dirty());

        packet.set_identification(0x1c47);
        assert!(packet.is_checksum_dirty());

        let recomputed = packet.recompute_checksum().expect("checksum");
        assert_eq!(recomputed, packet.get_checksum());
        assert!(!packet.is_checksum_dirty());
    }
}