1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
use crate::AsBeBytes;
use super::{Header, PacketData, Protocol, ParseError};

/// Values kept on a `TcpHeader` for the checksum computation done in `make`.
///
/// Filled in by `TcpHeader::set_pseudo_header`; only the length of the
/// payload is stored, not the payload bytes themselves.
struct PseudoHeader {
    /// Source IPv4 address, one octet per element.
    src_ip: [u8; 4],
    /// Destination IPv4 address, one octet per element.
    dst_ip: [u8; 4],
    /// Stored by `set_pseudo_header` as 20 plus the payload length in bytes.
    data_len: u16,
}

/// A fixed-size (20 byte, no options) TCP header that can be serialized with
/// `Header::make` and read back with `Header::parse`.
///
/// NOTE(review): the `#[get]` / `#[set]` markers are consumed by the
/// `AddGetter` / `AddSetter` derives, which are defined outside this file;
/// presumably they generate accessor methods for the marked fields — confirm
/// against the derive crate.
#[derive(AddGetter, AddSetter)]
pub struct TcpHeader {
    // Source port; serialized into bytes 0-1 of the header.
    #[get]
    #[set]
    src_port: u16,
    // Destination port; serialized into bytes 2-3 of the header.
    #[get]
    #[set]
    dst_port: u16,
    // Flags byte (byte 13 of the header); individual bits are turned on via `set_flag`.
    #[get]
    flags: u8,
    // Window size; serialized into bytes 14-15. `new` initializes it to 0xffff.
    #[set]
    window: u16,
    // Checksum inputs; must be set with `set_pseudo_header` before `make`, which panics otherwise.
    pseudo_header: Option<PseudoHeader>,
}

/// The TCP control flags that `TcpHeader::set_flag` can turn on.
///
/// The bit noted on each variant is the mask `set_flag` ORs into the flags byte.
pub enum TcpFlags {
    /// Urgent: bit `0b00100000`.
    Urg,
    /// Acknowledgement: bit `0b00010000`.
    Ack,
    /// Push: bit `0b00001000`.
    Psh,
    /// Reset: bit `0b00000100`.
    Rst,
    /// Synchronize: bit `0b00000010`.
    Syn,
    /// Finish: bit `0b00000001`.
    Fin,
}

impl TcpHeader {
    /// Creates a header for the given ports with no flags set, a window of
    /// `0xffff` and no pseudo header.
    ///
    /// `set_pseudo_header` must be called before the header is serialized.
    pub fn new(src_port: u16, dst_port: u16) -> Self {
        Self {
            src_port,
            dst_port,
            flags: 0,
            window: 0xffff,
            pseudo_header: None,
        }
    }

    /// Turns on the bit belonging to `f` in the flags byte.
    ///
    /// Bits that are already set stay set; there is no way to clear a flag
    /// through this method.
    pub fn set_flag(&mut self, f: TcpFlags) {
        let mask: u8 = match f {
            TcpFlags::Urg => 1 << 5,
            TcpFlags::Ack => 1 << 4,
            TcpFlags::Psh => 1 << 3,
            TcpFlags::Rst => 1 << 2,
            TcpFlags::Syn => 1 << 1,
            TcpFlags::Fin => 1,
        };
        self.flags |= mask;
    }

    /// Records the addresses and segment length needed for the checksum.
    ///
    /// Only the length of `packet_data` is kept; the stored length is that
    /// length plus the 20 bytes of the header.
    ///
    /// # Panics
    ///
    /// Panics with "too much data" when `packet_data` is longer than
    /// `0xffff - 20` bytes, because the total would not fit in a `u16`.
    pub fn set_pseudo_header(&mut self, src_ip: [u8; 4], dst_ip: [u8; 4], packet_data: &[u8]) {
        const HEADER_LEN: u16 = 20;

        let payload_len = packet_data.len();
        assert!(
            payload_len <= usize::from(0xffff - HEADER_LEN),
            "too much data"
        );

        // The guard above makes the narrowing cast and the addition lossless.
        let data_len = HEADER_LEN + payload_len as u16;
        self.pseudo_header = Some(PseudoHeader {
            src_ip,
            dst_ip,
            data_len,
        });
    }
}

impl Header for TcpHeader {
    /// Serializes this header into its 20-byte wire form (no options).
    ///
    /// Sequence number, acknowledgement number and urgent pointer are always
    /// written as zero. The data-offset nibble is set to 5 (five 32-bit words),
    /// matching the 20 bytes that are actually emitted.
    ///
    /// The checksum covers the pseudo header (source address, destination
    /// address, protocol number 6 and the stored TCP length) plus every 16-bit
    /// word of the header itself. The payload bytes are not available here —
    /// the pseudo header only keeps their length — so the value is exact for
    /// segments without payload; for segments carrying data the payload words
    /// still have to be folded in by whoever appends the data.
    ///
    /// # Panics
    ///
    /// Panics if `set_pseudo_header` has not been called on this header.
    fn make(self) -> PacketData {
        // Header length in 32-bit words: 20 bytes / 4, since no options are written.
        const DATA_OFFSET_WORDS: u8 = 5;
        // Zero byte + protocol byte of the pseudo header; 6 is the protocol number of TCP.
        const TCP_PROTOCOL: u32 = 6;

        let pseudo = match self.pseudo_header.as_ref() {
            Some(pseudo) => pseudo,
            None => panic!("Please set the pseudo header data before calculating the checksum"),
        };

        let src_p = self.src_port.split_to_bytes();
        let dst_p = self.dst_port.split_to_bytes();
        let window_bytes = self.window.split_to_bytes();
        let mut packet = vec![
            src_p[0],
            src_p[1],
            dst_p[0],
            dst_p[1],
            0,
            0,
            0,
            0, // Seq num
            0,
            0,
            0,
            0, // Ack num
            // Data offset in the high nibble + 4 of the reserved bits; the other 2 of the
            // 6 total reserved bits are included at the start of the `flags` byte
            DATA_OFFSET_WORDS << 4,
            self.flags,
            window_bytes[0],
            window_bytes[1],
            0,
            0, // Checksum, must be zero while the sum is computed; filled in below
            0,
            0, // Urgent Pointer -> Should do this at some point
        ];

        // Sum of the header as big-endian 16-bit words. The header length is even,
        // so every chunk holds exactly two bytes.
        let header_sum: u32 = packet
            .chunks(2)
            .map(|word| (u32::from(word[0]) << 8) | u32::from(word[1]))
            .sum();

        // `data_len` already is header length + payload length, i.e. the TCP length
        // field of the pseudo header, so the header length is not added a second time.
        let total = ip_sum(pseudo.src_ip)
            + ip_sum(pseudo.dst_ip)
            + TCP_PROTOCOL
            + u32::from(pseudo.data_len)
            + header_sum;
        let checksum = finalize_checksum(total).split_to_bytes();

        packet[16] = checksum[0];
        packet[17] = checksum[1];
        packet
    }

    /// Reads the ports, flags byte and window out of a raw TCP header.
    ///
    /// All other fields (sequence/acknowledgement numbers, data offset,
    /// checksum, urgent pointer) are ignored, and the returned header has no
    /// pseudo header set.
    ///
    /// # Errors
    ///
    /// Returns `ParseError::InvalidLength` when `raw_data` is shorter than
    /// `get_min_length()` bytes.
    fn parse(raw_data: &[u8]) -> Result<Box<Self>, ParseError> {
        if raw_data.len() < Self::get_min_length().into() {
            return Err(ParseError::InvalidLength);
        }

        // Big-endian 16-bit word starting at byte `at`; the length check above
        // guarantees every index used below is in bounds.
        let word_at = |at: usize| (u16::from(raw_data[at]) << 8) | u16::from(raw_data[at + 1]);

        Ok(Box::new(Self {
            src_port: word_at(0),
            dst_port: word_at(2),
            flags: raw_data[13],
            window: word_at(14),
            pseudo_header: None,
        }))
    }

    /// Identifies this header as TCP.
    fn get_proto(&self) -> Protocol {
        Protocol::TCP
    }

    /// Length in bytes of the serialized header; always 20 because no options are written.
    fn get_length(&self) -> u8 {
        20
    }

    /// Smallest number of bytes `parse` accepts.
    fn get_min_length() -> u8 {
        20
    }
}

/// Adds up the two big-endian 16-bit words formed by the four octets.
///
/// Octets 0-1 make the first word and octets 2-3 the second; the result is
/// their plain (unfolded) sum, so it can exceed `0xFFFF`.
#[inline(always)]
fn ip_sum(octets: [u8; 4]) -> u32 {
    let [a, b, c, d] = octets;
    let upper = u16::from_be_bytes([a, b]);
    let lower = u16::from_be_bytes([c, d]);
    u32::from(upper) + u32::from(lower)
}

/// Folds a 32-bit running sum into 16 bits and returns its bitwise complement.
///
/// Any bits above the low 16 are repeatedly added back into the low 16 until
/// none remain; the inverted 16-bit result is returned.
#[inline]
fn finalize_checksum(mut cs: u32) -> u16 {
    loop {
        let carry = cs >> 16;
        if carry == 0 {
            break;
        }
        cs = carry + (cs & 0xFFFF);
    }
    // `cs` now fits in 16 bits, so the cast is lossless.
    !(cs as u16)
}