nex_packet/
tcp.rs

1//! A TCP packet abstraction.
2
3use crate::checksum::{ChecksumMode, ChecksumState, TransportChecksumContext};
4use crate::ip::IpNextProtocol;
5use crate::packet::{MutablePacket, Packet};
6
7use crate::util::{self, Octets};
8use std::net::Ipv6Addr;
9use std::net::{IpAddr, Ipv4Addr};
10
11use bytes::{Buf, BufMut, Bytes, BytesMut};
12use nex_core::bitfield::{u4, u16be, u32be};
13#[cfg(feature = "serde")]
14use serde::{Deserialize, Serialize};
15
/// Smallest legal TCP data offset, in 32-bit words (a header carrying no options).
pub const TCP_MIN_DATA_OFFSET: u8 = 5;
/// Length in bytes of the fixed TCP header, i.e. the minimum data offset times four.
pub const TCP_HEADER_LEN: usize = TCP_MIN_DATA_OFFSET as usize * 4;
/// Maximum number of bytes the options area may occupy.
pub const TCP_OPTION_MAX_LEN: usize = 40;
/// Maximum TCP header length in bytes: the fixed header plus a full options area.
pub const TCP_HEADER_MAX_LEN: usize = TCP_HEADER_LEN + TCP_OPTION_MAX_LEN;
24
/// Represents a TCP Option Kind.
/// <https://www.iana.org/assignments/tcp-parameters/tcp-parameters.xhtml#tcp-parameters-1>
///
/// Each named variant's discriminant is the kind byte it stands for. Any byte
/// without a named variant is carried by [`TcpOptionKind::RESERVED`], so
/// [`TcpOptionKind::new`] followed by [`TcpOptionKind::value`] returns the
/// original byte for every `u8`.
#[allow(non_camel_case_types)]
#[repr(u8)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum TcpOptionKind {
    /// End of option list; single byte, no length field.
    EOL = 0,
    /// No-operation; single byte, no length field. May be used for alignment.
    NOP = 1,
    /// Maximum segment size.
    MSS = 2,
    /// Window scale.
    WSCALE = 3,
    /// Selective acknowledgment permitted.
    SACK_PERMITTED = 4,
    /// Selective acknowledgment blocks.
    SACK = 5,
    ECHO = 6,
    ECHO_REPLY = 7,
    /// Timestamps (two 32-bit values).
    TIMESTAMPS = 8,
    POCP = 9,
    POSP = 10,
    CC = 11,
    CC_NEW = 12,
    CC_ECHO = 13,
    ALT_CHECKSUM_REQ = 14,
    ALT_CHECKSUM_DATA = 15,
    SKEETER = 16,
    BUBBA = 17,
    TRAILER_CHECKSUM = 18,
    MD5_SIGNATURE = 19,
    SCPS_CAPABILITIES = 20,
    SELECTIVE_ACK = 21,
    RECORD_BOUNDARIES = 22,
    CORRUPTION_EXPERIENCED = 23,
    SNAP = 24,
    UNASSIGNED = 25,
    TCP_COMPRESSION_FILTER = 26,
    QUICK_START = 27,
    USER_TIMEOUT = 28,
    TCP_AO = 29,
    MPTCP = 30,
    RESERVED_31 = 31,
    RESERVED_32 = 32,
    RESERVED_33 = 33,
    FAST_OPEN_COOKIE = 34,
    TCP_ENO = 69,
    ACC_ECNO_0 = 172,
    ACC_ECNO_1 = 174,
    EXPERIMENT_1 = 253,
    EXPERIMENT_2 = 254,
    /// Any kind byte that has no named variant; the payload is that raw byte.
    RESERVED(u8),
}

impl TcpOptionKind {
    /// Maps a raw kind byte to its variant.
    ///
    /// Bytes without a named variant are wrapped in [`TcpOptionKind::RESERVED`].
    pub fn new(n: u8) -> TcpOptionKind {
        match n {
            0 => Self::EOL,
            1 => Self::NOP,
            2 => Self::MSS,
            3 => Self::WSCALE,
            4 => Self::SACK_PERMITTED,
            5 => Self::SACK,
            6 => Self::ECHO,
            7 => Self::ECHO_REPLY,
            8 => Self::TIMESTAMPS,
            9 => Self::POCP,
            10 => Self::POSP,
            11 => Self::CC,
            12 => Self::CC_NEW,
            13 => Self::CC_ECHO,
            14 => Self::ALT_CHECKSUM_REQ,
            15 => Self::ALT_CHECKSUM_DATA,
            16 => Self::SKEETER,
            17 => Self::BUBBA,
            18 => Self::TRAILER_CHECKSUM,
            19 => Self::MD5_SIGNATURE,
            20 => Self::SCPS_CAPABILITIES,
            21 => Self::SELECTIVE_ACK,
            22 => Self::RECORD_BOUNDARIES,
            23 => Self::CORRUPTION_EXPERIENCED,
            24 => Self::SNAP,
            25 => Self::UNASSIGNED,
            26 => Self::TCP_COMPRESSION_FILTER,
            27 => Self::QUICK_START,
            28 => Self::USER_TIMEOUT,
            29 => Self::TCP_AO,
            30 => Self::MPTCP,
            31 => Self::RESERVED_31,
            32 => Self::RESERVED_32,
            33 => Self::RESERVED_33,
            34 => Self::FAST_OPEN_COOKIE,
            69 => Self::TCP_ENO,
            172 => Self::ACC_ECNO_0,
            174 => Self::ACC_ECNO_1,
            253 => Self::EXPERIMENT_1,
            254 => Self::EXPERIMENT_2,
            other => Self::RESERVED(other),
        }
    }

    /// Returns the variant's identifier as a static string.
    ///
    /// Every [`TcpOptionKind::RESERVED`] value yields `"RESERVED"`, regardless
    /// of the byte it carries.
    pub fn name(&self) -> &'static str {
        match self {
            Self::EOL => "EOL",
            Self::NOP => "NOP",
            Self::MSS => "MSS",
            Self::WSCALE => "WSCALE",
            Self::SACK_PERMITTED => "SACK_PERMITTED",
            Self::SACK => "SACK",
            Self::ECHO => "ECHO",
            Self::ECHO_REPLY => "ECHO_REPLY",
            Self::TIMESTAMPS => "TIMESTAMPS",
            Self::POCP => "POCP",
            Self::POSP => "POSP",
            Self::CC => "CC",
            Self::CC_NEW => "CC_NEW",
            Self::CC_ECHO => "CC_ECHO",
            Self::ALT_CHECKSUM_REQ => "ALT_CHECKSUM_REQ",
            Self::ALT_CHECKSUM_DATA => "ALT_CHECKSUM_DATA",
            Self::SKEETER => "SKEETER",
            Self::BUBBA => "BUBBA",
            Self::TRAILER_CHECKSUM => "TRAILER_CHECKSUM",
            Self::MD5_SIGNATURE => "MD5_SIGNATURE",
            Self::SCPS_CAPABILITIES => "SCPS_CAPABILITIES",
            Self::SELECTIVE_ACK => "SELECTIVE_ACK",
            Self::RECORD_BOUNDARIES => "RECORD_BOUNDARIES",
            Self::CORRUPTION_EXPERIENCED => "CORRUPTION_EXPERIENCED",
            Self::SNAP => "SNAP",
            Self::UNASSIGNED => "UNASSIGNED",
            Self::TCP_COMPRESSION_FILTER => "TCP_COMPRESSION_FILTER",
            Self::QUICK_START => "QUICK_START",
            Self::USER_TIMEOUT => "USER_TIMEOUT",
            Self::TCP_AO => "TCP_AO",
            Self::MPTCP => "MPTCP",
            Self::RESERVED_31 => "RESERVED_31",
            Self::RESERVED_32 => "RESERVED_32",
            Self::RESERVED_33 => "RESERVED_33",
            Self::FAST_OPEN_COOKIE => "FAST_OPEN_COOKIE",
            Self::TCP_ENO => "TCP_ENO",
            Self::ACC_ECNO_0 => "ACC_ECNO_0",
            Self::ACC_ECNO_1 => "ACC_ECNO_1",
            Self::EXPERIMENT_1 => "EXPERIMENT_1",
            Self::EXPERIMENT_2 => "EXPERIMENT_2",
            Self::RESERVED(_) => "RESERVED",
        }
    }

    /// Returns the raw kind byte; the inverse of [`TcpOptionKind::new`].
    // A plain `as u8` cast is not available because `RESERVED` carries data,
    // so the mapping is spelled out.
    pub fn value(&self) -> u8 {
        match *self {
            Self::EOL => 0,
            Self::NOP => 1,
            Self::MSS => 2,
            Self::WSCALE => 3,
            Self::SACK_PERMITTED => 4,
            Self::SACK => 5,
            Self::ECHO => 6,
            Self::ECHO_REPLY => 7,
            Self::TIMESTAMPS => 8,
            Self::POCP => 9,
            Self::POSP => 10,
            Self::CC => 11,
            Self::CC_NEW => 12,
            Self::CC_ECHO => 13,
            Self::ALT_CHECKSUM_REQ => 14,
            Self::ALT_CHECKSUM_DATA => 15,
            Self::SKEETER => 16,
            Self::BUBBA => 17,
            Self::TRAILER_CHECKSUM => 18,
            Self::MD5_SIGNATURE => 19,
            Self::SCPS_CAPABILITIES => 20,
            Self::SELECTIVE_ACK => 21,
            Self::RECORD_BOUNDARIES => 22,
            Self::CORRUPTION_EXPERIENCED => 23,
            Self::SNAP => 24,
            Self::UNASSIGNED => 25,
            Self::TCP_COMPRESSION_FILTER => 26,
            Self::QUICK_START => 27,
            Self::USER_TIMEOUT => 28,
            Self::TCP_AO => 29,
            Self::MPTCP => 30,
            Self::RESERVED_31 => 31,
            Self::RESERVED_32 => 32,
            Self::RESERVED_33 => 33,
            Self::FAST_OPEN_COOKIE => 34,
            Self::TCP_ENO => 69,
            Self::ACC_ECNO_0 => 172,
            Self::ACC_ECNO_1 => 174,
            Self::EXPERIMENT_1 => 253,
            Self::EXPERIMENT_2 => 254,
            Self::RESERVED(raw) => raw,
        }
    }

    /// Returns the size in bytes this table associates with the option kind.
    ///
    /// Kinds that are not listed return `0`.
    pub fn size(&self) -> usize {
        match self {
            Self::EOL | Self::NOP => 1,
            Self::SACK_PERMITTED | Self::POCP => 2,
            Self::WSCALE | Self::POSP | Self::ALT_CHECKSUM_REQ | Self::TRAILER_CHECKSUM => 3,
            Self::MSS | Self::USER_TIMEOUT => 4,
            Self::ECHO | Self::ECHO_REPLY => 6,
            Self::QUICK_START => 8,
            Self::SACK | Self::TIMESTAMPS => 10,
            Self::ALT_CHECKSUM_DATA => 12,
            Self::MD5_SIGNATURE => 18,
            _ => 0,
        }
    }
}
239
/// Bit masks for the TCP flags byte.
/// <https://www.iana.org/assignments/tcp-parameters/tcp-parameters.xhtml#tcp-header-flags>
///
/// Each constant has exactly one bit set; combine them with `|` and test them
/// with `&`.
#[allow(non_snake_case, non_upper_case_globals)]
pub mod TcpFlags {
    /// CWR - Congestion Window Reduced. Set by the sender to signal that it
    /// received a segment with ECE set and reacted through its congestion
    /// control mechanism.
    pub const CWR: u8 = 1 << 7;
    /// ECE - ECN-Echo. Its meaning depends on SYN:
    /// with SYN set, the peer is ECN capable;
    /// with SYN clear, a packet marked Congestion Experienced (ECN=11 in the
    /// IP header) was received during normal transmission.
    pub const ECE: u8 = 1 << 6;
    /// URG - the Urgent pointer field is significant.
    pub const URG: u8 = 1 << 5;
    /// ACK - the Acknowledgment field is significant. Every packet after the
    /// client's initial SYN should carry this flag.
    pub const ACK: u8 = 1 << 4;
    /// PSH - Push function: asks for buffered data to be handed to the
    /// receiving application.
    pub const PSH: u8 = 1 << 3;
    /// RST - Reset the connection.
    pub const RST: u8 = 1 << 2;
    /// SYN - Synchronize sequence numbers. Only the first packet sent from
    /// each end should carry this flag.
    pub const SYN: u8 = 1 << 1;
    /// FIN - No more data from sender.
    pub const FIN: u8 = 1 << 0;
}
270
271/// Represents the TCP option header.
272#[derive(Clone, Debug, PartialEq, Eq)]
273#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
274pub struct TcpOptionHeader {
275    pub kind: TcpOptionKind,
276    pub length: Option<u8>,
277    pub data: Bytes,
278}
279
280impl TcpOptionHeader {
281    /// Get the timestamp of the TCP option
282    pub fn get_timestamp(&self) -> (u32, u32) {
283        if self.kind == TcpOptionKind::TIMESTAMPS && self.data.len() >= 8 {
284            let mut my: [u8; 4] = [0; 4];
285            my.copy_from_slice(&self.data[0..4]);
286            let mut their: [u8; 4] = [0; 4];
287            their.copy_from_slice(&self.data[4..8]);
288            (u32::from_be_bytes(my), u32::from_be_bytes(their))
289        } else {
290            return (0, 0);
291        }
292    }
293    /// Get the MSS of the TCP option
294    pub fn get_mss(&self) -> u16 {
295        if self.kind == TcpOptionKind::MSS && self.data.len() >= 2 {
296            let mut mss: [u8; 2] = [0; 2];
297            mss.copy_from_slice(&self.data[0..2]);
298            u16::from_be_bytes(mss)
299        } else {
300            0
301        }
302    }
303    /// Get the WSCALE of the TCP option
304    pub fn get_wscale(&self) -> u8 {
305        if self.kind == TcpOptionKind::WSCALE && self.data.len() > 0 {
306            self.data[0]
307        } else {
308            0
309        }
310    }
311}
312
313/// A TCP option.
314#[derive(Clone, Debug, PartialEq, Eq)]
315#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
316pub struct TcpOptionPacket {
317    kind: TcpOptionKind,
318    length: Option<u8>,
319    data: Bytes,
320}
321
322impl TcpOptionPacket {
323    /// NOP: This may be used to align option fields on 32-bit boundaries for better performance.
324    pub fn nop() -> Self {
325        TcpOptionPacket {
326            kind: TcpOptionKind::NOP,
327            length: None,
328            data: Bytes::new(),
329        }
330    }
331
332    /// Timestamp: TCP timestamps, defined in RFC 1323, can help TCP determine in which order
333    /// packets were sent. TCP timestamps are not normally aligned to the system clock and
334    /// start at some random value.
335    pub fn timestamp(my: u32, their: u32) -> Self {
336        let mut data = BytesMut::new();
337        data.extend_from_slice(&my.octets()[..]);
338        data.extend_from_slice(&their.octets()[..]);
339
340        TcpOptionPacket {
341            kind: TcpOptionKind::TIMESTAMPS,
342            length: Some(10),
343            data: data.freeze(),
344        }
345    }
346
347    /// MSS: The maximum segment size (MSS) is the largest amount of data, specified in bytes,
348    /// that TCP is willing to receive in a single segment.
349    pub fn mss(val: u16) -> Self {
350        let mut data = BytesMut::new();
351        data.extend_from_slice(&val.octets()[..]);
352
353        TcpOptionPacket {
354            kind: TcpOptionKind::MSS,
355            length: Some(4),
356            data: data.freeze(),
357        }
358    }
359
360    /// Window scale: The TCP window scale option, as defined in RFC 1323, is an option used to
361    /// increase the maximum window size from 65,535 bytes to 1 gigabyte.
362    pub fn wscale(val: u8) -> Self {
363        TcpOptionPacket {
364            kind: TcpOptionKind::WSCALE,
365            length: Some(3),
366            data: Bytes::from(vec![val]),
367        }
368    }
369
370    /// Selective acknowledgment (SACK) option, defined in RFC 2018 allows the receiver to acknowledge
371    /// discontinuous blocks of packets which were received correctly. This options enables use of
372    /// SACK during negotiation.
373    pub fn sack_perm() -> Self {
374        TcpOptionPacket {
375            kind: TcpOptionKind::SACK_PERMITTED,
376            length: Some(2),
377            data: Bytes::new(),
378        }
379    }
380
381    /// Selective acknowledgment (SACK) option, defined in RFC 2018 allows the receiver to acknowledge
382    /// discontinuous blocks of packets which were received correctly. The acknowledgement can specify
383    /// a number of SACK blocks, where each SACK block is conveyed by the starting and ending sequence
384    /// numbers of a contiguous range that the receiver correctly received.
385    pub fn selective_ack(acks: &[u32]) -> Self {
386        let mut data = BytesMut::new();
387        for ack in acks {
388            data.extend_from_slice(&ack.octets()[..]);
389        }
390        TcpOptionPacket {
391            kind: TcpOptionKind::SACK,
392            length: Some(1 /* number */ + 1 /* length */ + data.len() as u8),
393            data: data.freeze(),
394        }
395    }
396    /// Get the TCP option kind.
397    pub fn kind(&self) -> TcpOptionKind {
398        self.kind
399    }
400    /// Get length of the TCP option.
401    pub fn length(&self) -> u8 {
402        if let Some(len) = self.length {
403            len
404        } else {
405            // If length is None, it means the option has no length (like NOP).
406            0
407        }
408    }
409    /// Get the timestamp of the TCP option
410    pub fn get_timestamp(&self) -> (u32, u32) {
411        if self.kind == TcpOptionKind::TIMESTAMPS && self.data.len() >= 8 {
412            let mut my: [u8; 4] = [0; 4];
413            my.copy_from_slice(&self.data[0..4]);
414            let mut their: [u8; 4] = [0; 4];
415            their.copy_from_slice(&self.data[4..8]);
416            (u32::from_be_bytes(my), u32::from_be_bytes(their))
417        } else {
418            return (0, 0);
419        }
420    }
421    /// Get the MSS of the TCP option
422    pub fn get_mss(&self) -> u16 {
423        if self.kind == TcpOptionKind::MSS && self.data.len() >= 2 {
424            let mut mss: [u8; 2] = [0; 2];
425            mss.copy_from_slice(&self.data[0..2]);
426            u16::from_be_bytes(mss)
427        } else {
428            0
429        }
430    }
431    /// Get the WSCALE of the TCP option
432    pub fn get_wscale(&self) -> u8 {
433        if self.kind == TcpOptionKind::WSCALE && self.data.len() > 0 {
434            self.data[0]
435        } else {
436            0
437        }
438    }
439}
440
/// Represents the TCP header.
///
/// Fields are listed in wire order, matching the order in which
/// `TcpPacket::from_buf` reads them and `TcpPacket::to_bytes` writes them.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct TcpHeader {
    /// Source port (bytes 0..2).
    pub source: u16be,
    /// Destination port (bytes 2..4).
    pub destination: u16be,
    /// Sequence number (bytes 4..8).
    pub sequence: u32be,
    /// Acknowledgement number (bytes 8..12).
    pub acknowledgement: u32be,
    /// Header length in 32-bit words (high nibble of byte 12).
    ///
    /// Filled in when parsing. `TcpPacket::to_bytes` does not read this field;
    /// it derives the encoded data offset from `options` instead.
    pub data_offset: u4,
    /// Reserved bits (low nibble of byte 12).
    pub reserved: u4,
    /// Flag bits (byte 13); see `TcpFlags` for the individual masks.
    pub flags: u8,
    /// Window size (bytes 14..16).
    pub window: u16be,
    /// Checksum (bytes 16..18). Serialized as stored; `to_bytes` does not
    /// recompute it.
    pub checksum: u16be,
    /// Urgent pointer (bytes 18..20).
    pub urgent_ptr: u16be,
    /// Options in the order they appear after the fixed 20-byte header.
    pub options: Vec<TcpOptionPacket>,
}
457
/// Represents a TCP packet.
///
/// An owned, parsed view: the decoded header plus the bytes that follow it.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct TcpPacket {
    /// Decoded header fields and options.
    pub header: TcpHeader,
    /// Everything after the header (fixed part plus options area).
    pub payload: Bytes,
}
464
465impl Packet for TcpPacket {
466    type Header = TcpHeader;
467
468    fn from_buf(mut bytes: &[u8]) -> Option<Self> {
469        if bytes.len() < TCP_HEADER_LEN {
470            return None;
471        }
472
473        let source = bytes.get_u16();
474        let destination = bytes.get_u16();
475        let sequence = bytes.get_u32();
476        let acknowledgement = bytes.get_u32();
477
478        let offset_reserved = bytes.get_u8();
479        let data_offset = offset_reserved >> 4;
480        let reserved = offset_reserved & 0x0F;
481
482        let flags = bytes.get_u8();
483        let window = bytes.get_u16();
484        let checksum = bytes.get_u16();
485        let urgent_ptr = bytes.get_u16();
486
487        let header_len = data_offset as usize * 4;
488        if header_len < TCP_HEADER_LEN || bytes.len() + 20 < header_len {
489            return None;
490        }
491
492        let mut options = Vec::new();
493        let options_len = header_len - TCP_HEADER_LEN;
494        let (mut options_bytes, rest) = bytes.split_at(options_len);
495        bytes = rest;
496
497        while options_bytes.has_remaining() {
498            let kind = TcpOptionKind::new(options_bytes.get_u8());
499            match kind {
500                TcpOptionKind::EOL => {
501                    options.push(TcpOptionPacket {
502                        kind,
503                        length: None,
504                        data: Bytes::new(),
505                    });
506                    break;
507                }
508                TcpOptionKind::NOP => {
509                    options.push(TcpOptionPacket {
510                        kind,
511                        length: None,
512                        data: Bytes::new(),
513                    });
514                }
515                _ => {
516                    if options_bytes.remaining() < 1 {
517                        return None;
518                    }
519                    let len = options_bytes.get_u8();
520                    if len < 2 || (len as usize) > options_bytes.remaining() + 2 {
521                        return None;
522                    }
523                    let data_len = (len - 2) as usize;
524                    let (data_slice, rest) = options_bytes.split_at(data_len);
525                    options_bytes = rest;
526                    options.push(TcpOptionPacket {
527                        kind,
528                        length: Some(len),
529                        data: Bytes::copy_from_slice(data_slice),
530                    });
531                }
532            }
533        }
534
535        Some(TcpPacket {
536            header: TcpHeader {
537                source,
538                destination,
539                sequence,
540                acknowledgement,
541                data_offset: u4::from_be(data_offset),
542                reserved: u4::from_be(reserved),
543                flags,
544                window,
545                checksum,
546                urgent_ptr,
547                options,
548            },
549            payload: Bytes::copy_from_slice(bytes),
550        })
551    }
552    fn from_bytes(mut bytes: Bytes) -> Option<Self> {
553        Self::from_buf(&mut bytes)
554    }
555
556    fn to_bytes(&self) -> Bytes {
557        // Calculate the actual encoded length of TCP options
558        let mut enc_opt_len = 0usize;
559        for opt in &self.header.options {
560            match opt.kind {
561                TcpOptionKind::EOL | TcpOptionKind::NOP => enc_opt_len += 1,
562                _ => {
563                    // Total length including kind + length fields
564                    let len = opt.length.unwrap_or(2) as usize;
565                    enc_opt_len += len;
566                }
567            }
568        }
569        // Round up to the nearest 4-byte boundary
570        let padded_opt_len = (enc_opt_len + 3) & !3;
571        let header_len = TCP_HEADER_LEN + padded_opt_len;
572        // In 32-bit words
573        let data_offset_words = (header_len / 4) as u8;
574
575        // Write the TCP header
576        let mut bytes = BytesMut::with_capacity(header_len + self.payload.len());
577        bytes.put_u16(self.header.source);
578        bytes.put_u16(self.header.destination);
579        bytes.put_u32(self.header.sequence);
580        bytes.put_u32(self.header.acknowledgement);
581
582        let offset_reserved = (data_offset_words << 4) | (self.header.reserved.to_be() & 0x0F);
583        bytes.put_u8(offset_reserved);
584
585        bytes.put_u8(self.header.flags);
586        bytes.put_u16(self.header.window);
587        bytes.put_u16(self.header.checksum);
588        bytes.put_u16(self.header.urgent_ptr);
589
590        // Encode the options
591        let before_opts = bytes.len();
592        for opt in &self.header.options {
593            bytes.put_u8(opt.kind.value());
594            if let Some(length) = opt.length {
595                bytes.put_u8(length);
596                bytes.extend_from_slice(&opt.data);
597            }
598        }
599        // Add option padding (zero-filled) to reach the padded length
600        let written_opt = bytes.len() - before_opts;
601        let pad = padded_opt_len.saturating_sub(written_opt);
602        for _ in 0..pad {
603            bytes.put_u8(0);
604        }
605
606        // Append payload
607        bytes.extend_from_slice(&self.payload);
608
609        bytes.freeze()
610    }
611
612    fn header(&self) -> Bytes {
613        self.to_bytes().slice(..self.header_len())
614    }
615
616    fn payload(&self) -> Bytes {
617        self.payload.clone()
618    }
619
620    fn header_len(&self) -> usize {
621        let base = TCP_HEADER_LEN;
622        let mut opt_len = 0;
623
624        for opt in &self.header.options {
625            match opt.kind {
626                TcpOptionKind::EOL | TcpOptionKind::NOP => {
627                    opt_len += 1; // EOL and NOP are one byte
628                }
629                _ => {
630                    // kind(1B) + length(1B) + payload
631                    if let Some(len) = opt.length {
632                        opt_len += len as usize;
633                    } else {
634                        // Ensure at least 2 bytes (kind + length)
635                        opt_len += 2;
636                    }
637                }
638            }
639        }
640
641        let total = base + opt_len;
642        // The TCP header is always rounded to a 4 byte boundary
643        (total + 3) & !0x03
644    }
645
646    fn payload_len(&self) -> usize {
647        self.payload.len()
648    }
649
650    fn total_len(&self) -> usize {
651        self.header_len() + self.payload_len()
652    }
653
654    fn into_parts(self) -> (Self::Header, Bytes) {
655        (self.header, self.payload)
656    }
657}
658
659impl TcpPacket {
660    pub fn tcp_options_length(&self) -> usize {
661        if self.header.data_offset > 5 {
662            self.header.data_offset as usize * 4 - 20
663        } else {
664            0
665        }
666    }
667}
668
/// Represents a mutable TCP packet.
///
/// Wraps a caller-provided byte buffer and edits the TCP fields in place.
pub struct MutableTcpPacket<'a> {
    // Raw packet bytes (header followed by payload), borrowed from the caller.
    buffer: &'a mut [u8],
    // Checksum bookkeeping: recalculation mode plus a dirty flag.
    checksum: ChecksumState,
    // Pseudo-header context used to compute the checksum; `None` until set.
    checksum_context: Option<TransportChecksumContext>,
}
675
676impl<'a> MutablePacket<'a> for MutableTcpPacket<'a> {
677    type Packet = TcpPacket;
678
679    fn new(buffer: &'a mut [u8]) -> Option<Self> {
680        if buffer.len() < TCP_HEADER_LEN {
681            return None;
682        }
683
684        let data_offset = buffer[12] >> 4;
685        if data_offset < TCP_MIN_DATA_OFFSET {
686            return None;
687        }
688
689        let header_len = (data_offset as usize) * 4;
690        if header_len > buffer.len() {
691            return None;
692        }
693
694        Some(Self {
695            buffer,
696            checksum: ChecksumState::new(),
697            checksum_context: None,
698        })
699    }
700
701    fn packet(&self) -> &[u8] {
702        &*self.buffer
703    }
704
705    fn packet_mut(&mut self) -> &mut [u8] {
706        &mut *self.buffer
707    }
708
709    fn header(&self) -> &[u8] {
710        let len = self.header_len();
711        &self.packet()[..len]
712    }
713
714    fn header_mut(&mut self) -> &mut [u8] {
715        let len = self.header_len();
716        let (header, _) = (&mut *self.buffer).split_at_mut(len);
717        header
718    }
719
720    fn payload(&self) -> &[u8] {
721        let len = self.header_len();
722        &self.packet()[len..]
723    }
724
725    fn payload_mut(&mut self) -> &mut [u8] {
726        let len = self.header_len();
727        let (_, payload) = (&mut *self.buffer).split_at_mut(len);
728        payload
729    }
730}
731
732impl<'a> MutableTcpPacket<'a> {
    /// Create a packet without validating the header fields.
    ///
    /// Neither the buffer length nor the data offset is checked here. The
    /// accessors index the buffer directly, so using them on a buffer that is
    /// too short for the field being accessed will panic.
    pub fn new_unchecked(buffer: &'a mut [u8]) -> Self {
        Self {
            buffer,
            checksum: ChecksumState::new(),
            checksum_context: None,
        }
    }
741
742    fn raw(&self) -> &[u8] {
743        &*self.buffer
744    }
745
746    fn raw_mut(&mut self) -> &mut [u8] {
747        &mut *self.buffer
748    }
749
    /// Hook run by every field setter after it has written to the buffer.
    ///
    /// Marks the checksum dirty and, in automatic mode, recomputes it
    /// immediately. The recompute result is deliberately ignored: it is
    /// `None` when no checksum context is configured, in which case the
    /// checksum simply stays dirty.
    fn after_field_mutation(&mut self) {
        self.checksum.mark_dirty();
        if self.checksum.automatic() {
            let _ = self.recompute_checksum();
        }
    }
756
    /// Stores `value` big-endian in the checksum field (bytes 16..18).
    ///
    /// Writes the bytes only; it does not touch the dirty flag.
    fn write_checksum(&mut self, value: u16) {
        self.raw_mut()[16..18].copy_from_slice(&value.to_be_bytes());
    }
760
    /// Returns the checksum recalculation mode for the packet.
    pub fn checksum_mode(&self) -> ChecksumMode {
        self.checksum.mode()
    }
765
    /// Updates how checksum recalculation should be handled.
    ///
    /// If the new mode is automatic and the checksum is already dirty, it is
    /// recomputed right away. The recompute result is ignored; it is `None`
    /// when no checksum context has been configured.
    pub fn set_checksum_mode(&mut self, mode: ChecksumMode) {
        self.checksum.set_mode(mode);
        if self.checksum.automatic() && self.checksum.is_dirty() {
            let _ = self.recompute_checksum();
        }
    }
773
    /// Enables automatic checksum recomputation after field mutations.
    ///
    /// Shorthand for `set_checksum_mode(ChecksumMode::Automatic)`, so a dirty
    /// checksum is recomputed immediately.
    pub fn enable_auto_checksum(&mut self) {
        self.set_checksum_mode(ChecksumMode::Automatic);
    }
778
    /// Disables automatic checksum recomputation.
    ///
    /// Shorthand for `set_checksum_mode(ChecksumMode::Manual)`.
    pub fn disable_auto_checksum(&mut self) {
        self.set_checksum_mode(ChecksumMode::Manual);
    }
783
    /// Returns true if the checksum needs to be updated before serialization.
    ///
    /// The flag is set by field mutations and cleared by a successful
    /// `recompute_checksum`.
    pub fn is_checksum_dirty(&self) -> bool {
        self.checksum.is_dirty()
    }
788
789    /// Marks the checksum as dirty and recomputes it when automatic mode is enabled.
790    pub fn mark_checksum_dirty(&mut self) {
791        self.checksum.mark_dirty();
792        if self.checksum.automatic() {
793            let _ = self.recompute_checksum();
794        }
795    }
796
    /// Configures the pseudo-header context required for checksum calculation.
    ///
    /// If the packet is in automatic mode with a dirty checksum, the checksum
    /// is recomputed immediately using the new context.
    pub fn set_checksum_context(&mut self, context: TransportChecksumContext) {
        self.checksum_context = Some(context);
        if self.checksum.automatic() && self.checksum.is_dirty() {
            let _ = self.recompute_checksum();
        }
    }
804
    /// Sets an IPv4 pseudo-header context for checksum calculation.
    ///
    /// Builds the context from `source` and `destination` and forwards it to
    /// `set_checksum_context`.
    pub fn set_ipv4_checksum_context(&mut self, source: Ipv4Addr, destination: Ipv4Addr) {
        self.set_checksum_context(TransportChecksumContext::ipv4(source, destination));
    }
809
    /// Sets an IPv6 pseudo-header context for checksum calculation.
    ///
    /// Builds the context from `source` and `destination` and forwards it to
    /// `set_checksum_context`.
    pub fn set_ipv6_checksum_context(&mut self, source: Ipv6Addr, destination: Ipv6Addr) {
        self.set_checksum_context(TransportChecksumContext::ipv6(source, destination));
    }
814
    /// Clears the configured pseudo-header context.
    ///
    /// Afterwards `recompute_checksum` returns `None` until a new context is
    /// set. The checksum bytes and the dirty flag are left untouched.
    pub fn clear_checksum_context(&mut self) {
        self.checksum_context = None;
    }
819
    /// Returns the currently configured pseudo-header context.
    ///
    /// `None` when no context has been set or it has been cleared.
    pub fn checksum_context(&self) -> Option<TransportChecksumContext> {
        self.checksum_context
    }
824
    /// Recomputes the checksum using the configured pseudo-header context.
    ///
    /// Returns `None`, leaving the buffer and the dirty flag untouched, when
    /// no context is configured. Otherwise the checksum is computed over the
    /// whole buffer, written to bytes 16..18, the dirty flag is cleared, and
    /// the new value is returned.
    pub fn recompute_checksum(&mut self) -> Option<u16> {
        let context = self.checksum_context?;

        // NOTE(review): the literal `8` is presumably the index of the 16-bit
        // word holding the checksum field (byte offset 16), which the helper
        // skips while summing, and `&[]` presumably means "no extra data" —
        // confirm against `util::ipv4_checksum` / `util::ipv6_checksum`.
        let checksum = match context {
            TransportChecksumContext::Ipv4 {
                source,
                destination,
            } => util::ipv4_checksum(
                self.raw(),
                8,
                &[],
                &source,
                &destination,
                IpNextProtocol::Tcp,
            ) as u16,
            TransportChecksumContext::Ipv6 {
                source,
                destination,
            } => util::ipv6_checksum(
                self.raw(),
                8,
                &[],
                &source,
                &destination,
                IpNextProtocol::Tcp,
            ) as u16,
        };

        self.write_checksum(checksum);
        self.checksum.clear_dirty();
        Some(checksum)
    }
858
859    /// Returns the header length in bytes.
860    pub fn header_len(&self) -> usize {
861        let offset = (self.raw()[12] >> 4).max(TCP_MIN_DATA_OFFSET);
862        let len = (offset as usize) * 4;
863        len.min(self.raw().len())
864    }
865
866    /// Returns the payload length of the packet.
867    pub fn payload_len(&self) -> usize {
868        self.raw().len().saturating_sub(self.header_len())
869    }
870
871    pub fn get_source(&self) -> u16 {
872        u16::from_be_bytes([self.raw()[0], self.raw()[1]])
873    }
874
875    pub fn set_source(&mut self, value: u16) {
876        self.raw_mut()[0..2].copy_from_slice(&value.to_be_bytes());
877        self.after_field_mutation();
878    }
879
880    pub fn get_destination(&self) -> u16 {
881        u16::from_be_bytes([self.raw()[2], self.raw()[3]])
882    }
883
884    pub fn set_destination(&mut self, value: u16) {
885        self.raw_mut()[2..4].copy_from_slice(&value.to_be_bytes());
886        self.after_field_mutation();
887    }
888
889    pub fn get_sequence(&self) -> u32 {
890        u32::from_be_bytes([self.raw()[4], self.raw()[5], self.raw()[6], self.raw()[7]])
891    }
892
893    pub fn set_sequence(&mut self, value: u32) {
894        self.raw_mut()[4..8].copy_from_slice(&value.to_be_bytes());
895        self.after_field_mutation();
896    }
897
898    pub fn get_acknowledgement(&self) -> u32 {
899        u32::from_be_bytes([self.raw()[8], self.raw()[9], self.raw()[10], self.raw()[11]])
900    }
901
902    pub fn set_acknowledgement(&mut self, value: u32) {
903        self.raw_mut()[8..12].copy_from_slice(&value.to_be_bytes());
904        self.after_field_mutation();
905    }
906
907    pub fn get_data_offset(&self) -> u8 {
908        self.raw()[12] >> 4
909    }
910
911    pub fn set_data_offset(&mut self, offset: u8) {
912        let buf = self.raw_mut();
913        buf[12] = (buf[12] & 0x0F) | ((offset & 0x0F) << 4);
914        self.after_field_mutation();
915    }
916
917    pub fn get_reserved(&self) -> u8 {
918        self.raw()[12] & 0x0F
919    }
920
921    pub fn set_reserved(&mut self, value: u8) {
922        let buf = self.raw_mut();
923        buf[12] = (buf[12] & 0xF0) | (value & 0x0F);
924        self.after_field_mutation();
925    }
926
927    pub fn get_flags(&self) -> u8 {
928        self.raw()[13]
929    }
930
931    pub fn set_flags(&mut self, flags: u8) {
932        self.raw_mut()[13] = flags;
933        self.after_field_mutation();
934    }
935
936    pub fn get_window(&self) -> u16 {
937        u16::from_be_bytes([self.raw()[14], self.raw()[15]])
938    }
939
940    pub fn set_window(&mut self, value: u16) {
941        self.raw_mut()[14..16].copy_from_slice(&value.to_be_bytes());
942        self.after_field_mutation();
943    }
944
945    pub fn get_checksum(&self) -> u16 {
946        u16::from_be_bytes([self.raw()[16], self.raw()[17]])
947    }
948
    /// Stores an explicit checksum value and marks the checksum as clean.
    ///
    /// `value` is written through `write_checksum`, after which the checksum
    /// dirty flag is cleared. Unlike the other field setters this does not
    /// call `after_field_mutation`.
    pub fn set_checksum(&mut self, value: u16) {
        self.write_checksum(value);
        // The caller supplied the checksum, so it is no longer considered stale.
        self.checksum.clear_dirty();
    }
953
954    pub fn get_urgent_ptr(&self) -> u16 {
955        u16::from_be_bytes([self.raw()[18], self.raw()[19]])
956    }
957
958    pub fn set_urgent_ptr(&mut self, value: u16) {
959        self.raw_mut()[18..20].copy_from_slice(&value.to_be_bytes());
960        self.after_field_mutation();
961    }
962
963    pub fn options(&self) -> &[u8] {
964        let len = self.header_len();
965        &self.raw()[TCP_HEADER_LEN..len]
966    }
967
968    pub fn options_mut(&mut self) -> &mut [u8] {
969        let len = self.header_len();
970        &mut self.raw_mut()[TCP_HEADER_LEN..len]
971    }
972}
973
974pub fn checksum(packet: &TcpPacket, source: &IpAddr, destination: &IpAddr) -> u16 {
975    match (source, destination) {
976        (IpAddr::V4(src), IpAddr::V4(dst)) => ipv4_checksum(packet, src, dst),
977        (IpAddr::V6(src), IpAddr::V6(dst)) => ipv6_checksum(packet, src, dst),
978        _ => 0, // Unsupported IP version
979    }
980}
981
982/// Calculate a checksum for a packet built on IPv4.
983pub fn ipv4_checksum(packet: &TcpPacket, source: &Ipv4Addr, destination: &Ipv4Addr) -> u16 {
984    ipv4_checksum_adv(packet, &[], source, destination)
985}
986
987/// Calculate the checksum for a packet built on IPv4, Advanced version which
988/// accepts an extra slice of data that will be included in the checksum
989/// as being part of the data portion of the packet.
990///
991/// If `packet` contains an odd number of bytes the last byte will not be
992/// counted as the first byte of a word together with the first byte of
993/// `extra_data`.
994pub fn ipv4_checksum_adv(
995    packet: &TcpPacket,
996    extra_data: &[u8],
997    source: &Ipv4Addr,
998    destination: &Ipv4Addr,
999) -> u16 {
1000    util::ipv4_checksum(
1001        &packet.to_bytes(),
1002        8,
1003        extra_data,
1004        source,
1005        destination,
1006        IpNextProtocol::Tcp,
1007    )
1008}
1009
1010/// Calculate a checksum for a packet built on IPv6.
1011pub fn ipv6_checksum(packet: &TcpPacket, source: &Ipv6Addr, destination: &Ipv6Addr) -> u16 {
1012    ipv6_checksum_adv(packet, &[], source, destination)
1013}
1014
1015/// Calculate the checksum for a packet built on IPv6, Advanced version which
1016/// accepts an extra slice of data that will be included in the checksum
1017/// as being part of the data portion of the packet.
1018///
1019/// If `packet` contains an odd number of bytes the last byte will not be
1020/// counted as the first byte of a word together with the first byte of
1021/// `extra_data`.
1022pub fn ipv6_checksum_adv(
1023    packet: &TcpPacket,
1024    extra_data: &[u8],
1025    source: &Ipv6Addr,
1026    destination: &Ipv6Addr,
1027) -> u16 {
1028    util::ipv6_checksum(
1029        &packet.to_bytes(),
1030        8,
1031        extra_data,
1032        source,
1033        destination,
1034        IpNextProtocol::Tcp,
1035    )
1036}
1037
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a segment carrying NOP, NOP and Timestamps options plus a
    /// four-byte payload, then checks every header field and the round trip
    /// back to bytes.
    #[test]
    fn test_basic_tcp_parse() {
        let wire = Bytes::from_static(&[
            0xc1, 0x67, // source
            0x23, 0x28, // destination
            0x90, 0x37, 0xd2, 0xb8, // seq
            0x94, 0x4b, 0xb2, 0x76, // ack
            0x80, 0x18, 0x0f, 0xaf, // offset+reserved, flags, win
            0xc0, 0x31, // checksum
            0x00, 0x00, // urg ptr
            0x01, 0x01, // NOP, NOP
            0x08, 0x0a, 0x2c, 0x57, 0xcd, 0xa5, 0x02, 0xa0, 0x41, 0x92, // timestamp
            0x74, 0x65, 0x73, 0x74, // payload: "test"
        ]);
        let parsed = TcpPacket::from_bytes(wire.clone()).unwrap();
        let header = &parsed.header;

        // Fixed header fields.
        assert_eq!(header.source, 0xc167);
        assert_eq!(header.destination, 0x2328);
        assert_eq!(header.sequence, 0x9037d2b8);
        assert_eq!(header.acknowledgement, 0x944bb276);
        assert_eq!(header.data_offset, 8); // adjusted
        assert_eq!(header.reserved, 0);
        assert_eq!(header.flags, 0x18); // PSH + ACK
        assert_eq!(header.window, 0x0faf);
        assert_eq!(header.checksum, 0xc031);
        assert_eq!(header.urgent_ptr, 0x0000);

        // Options: two NOPs followed by a timestamp.
        assert_eq!(header.options.len(), 3);
        let expected_kinds = [
            TcpOptionKind::NOP,
            TcpOptionKind::NOP,
            TcpOptionKind::TIMESTAMPS,
        ];
        for (option, kind) in header.options.iter().zip(expected_kinds.iter()) {
            assert_eq!(option.kind, *kind);
        }
        assert_eq!(
            header.options[2].get_timestamp(),
            (0x2c57cda5, 0x02a04192)
        );

        // Payload, lengths and re-serialization.
        assert_eq!(parsed.payload, Bytes::from_static(b"test"));
        assert_eq!(parsed.header_len(), 32); // adjusted
        assert_eq!(parsed.to_bytes(), wire);
        assert_eq!(parsed.header().len(), 32); // adjusted
        assert_eq!(parsed.payload().len(), 4);
    }

    /// Builds a packet from its parts, serializes it, and checks that parsing
    /// the bytes yields an equal packet.
    #[test]
    fn test_basic_tcp_create() {
        let header = TcpHeader {
            source: 0xc167,
            destination: 0x2328,
            sequence: 0x9037d2b8,
            acknowledgement: 0x944bb276,
            data_offset: 8.into(), // 8 * 4 = 32 bytes
            reserved: 0.into(),
            flags: 0x18, // PSH + ACK
            window: 0x0faf,
            checksum: 0xc031,
            urgent_ptr: 0x0000,
            options: vec![
                TcpOptionPacket::nop(),
                TcpOptionPacket::nop(),
                TcpOptionPacket::timestamp(0x2c57cda5, 0x02a04192),
            ],
        };
        let built = TcpPacket {
            header,
            payload: Bytes::from_static(b"test"),
        };

        let encoded = built.to_bytes();
        let decoded = TcpPacket::from_bytes(encoded.clone()).expect("Failed to parse TCP packet");

        assert_eq!(decoded, built);
        assert_eq!(decoded.to_bytes(), encoded);
        assert_eq!(decoded.header.options.len(), 3);
        assert_eq!(
            decoded.header.options[2].get_timestamp(),
            (0x2c57cda5, 0x02a04192)
        );
    }

    /// Mutates header fields and payload through the mutable view and checks
    /// that the frozen packet reflects every change.
    #[test]
    fn test_mutable_tcp_packet_round_trip() {
        let mut buffer = [
            0x00, 0x50, // source
            0x01, 0xbb, // destination
            0x00, 0x00, 0x00, 0x01, // seq
            0x00, 0x00, 0x00, 0x00, // ack
            0x50, // data offset/reserved
            0x18, // flags
            0x40, 0x00, // window
            0x12, 0x34, // checksum
            0x00, 0x00, // urgent pointer
            b'h', b'e', b'l', b'l', b'o', // payload
        ];

        let mut view = MutableTcpPacket::new(&mut buffer).expect("mutable tcp");
        assert_eq!(view.get_source(), 80);

        view.set_source(1234);
        view.set_destination(4321);
        view.set_sequence(0xfeedbeef);
        view.set_flags(0x11);
        view.payload_mut()[0] = b'H';

        let snapshot = view.freeze().expect("freeze");
        assert_eq!(snapshot.header.source, 1234);
        assert_eq!(snapshot.header.destination, 4321);
        assert_eq!(snapshot.header.sequence, 0xfeedbeef);
        assert_eq!(snapshot.header.flags, 0x11);
        assert_eq!(snapshot.payload[0], b'H');
    }

    /// With an IPv4 context and auto checksum enabled, changing a field must
    /// refresh the stored checksum and leave it clean.
    #[test]
    fn test_tcp_auto_checksum_with_context() {
        let mut buffer = [
            0x00, 0x50, 0x01, 0xbb, // ports
            0x00, 0x00, 0x00, 0x01, // seq
            0x00, 0x00, 0x00, 0x00, // ack
            0x50, 0x18, 0x40, 0x00, // offset/reserved, flags, window
            0x00, 0x00, 0x00, 0x00, // checksum, urgent pointer
            b'h', b'e', b'l', b'l', b'o', // payload
        ];

        let source_ip = Ipv4Addr::new(192, 0, 2, 1);
        let destination_ip = Ipv4Addr::new(198, 51, 100, 2);

        let mut view = MutableTcpPacket::new(&mut buffer).expect("mutable tcp");
        view.set_ipv4_checksum_context(source_ip, destination_ip);
        view.enable_auto_checksum();

        let before = view.recompute_checksum().expect("checksum");
        assert_eq!(before, view.get_checksum());

        // Changing the window must change the stored checksum automatically.
        view.set_window(0x2000);
        let after = view.get_checksum();
        assert_ne!(before, after);
        assert!(!view.is_checksum_dirty());

        // The automatic value must agree with the standalone helper.
        let snapshot = view.freeze().expect("freeze");
        let reference = ipv4_checksum(&snapshot, &source_ip, &destination_ip);
        assert_eq!(after, reference as u16);
    }

    /// Without auto checksum, a field change only marks the checksum dirty
    /// until it is recomputed explicitly.
    #[test]
    fn test_tcp_manual_checksum_tracking() {
        let mut buffer = [
            0x12, 0x34, 0xab, 0xcd, // ports
            0x00, 0x00, 0x00, 0x00, // seq
            0x00, 0x00, 0x00, 0x00, // ack
            0x50, 0x02, 0x10, 0x00, // offset/reserved, flags, window
            0x00, 0x00, 0x00, 0x00, // checksum, urgent pointer
        ];

        let mut view = MutableTcpPacket::new(&mut buffer).expect("mutable tcp");
        let source_ip = Ipv6Addr::LOCALHOST;
        let destination_ip = Ipv6Addr::LOCALHOST;
        view.set_ipv6_checksum_context(source_ip, destination_ip);

        view.set_flags(0x12);
        assert!(view.is_checksum_dirty());

        let refreshed = view.recompute_checksum().expect("checksum");
        assert_eq!(refreshed, view.get_checksum());
        assert!(!view.is_checksum_dirty());
    }
}