use crate::packets::ip::v4::Ipv4;
use crate::packets::ip::v6::Ipv6;
use crate::packets::ip::{Flow, IpPacket, ProtocolNumbers};
use crate::packets::types::{u16be, u32be};
use crate::packets::{checksum, Internal, Packet};
use crate::{ensure, SizeOf};
use anyhow::{anyhow, Result};
use std::fmt;
use std::net::IpAddr;
use std::ptr::NonNull;
// Bit masks for the TCP control flags held in `TcpHeader::flags`, ordered
// from the most significant bit (CWR) down to the least significant (FIN).
// The NS flag is not in this byte; it is bit 0 of `offset_to_ns`.
const CWR: u8 = 0x80;
const ECE: u8 = 0x40;
const URG: u8 = 0x20;
const ACK: u8 = 0x10;
const PSH: u8 = 0x08;
const RST: u8 = 0x04;
const SYN: u8 = 0x02;
const FIN: u8 = 0x01;
/// A TCP segment parsed from (or pushed onto) the IP envelope `E`.
///
/// The header is not copied out of the packet: `header` points at the
/// header bytes inside the envelope's mbuf, and all accessors read and
/// write through that pointer.
pub struct Tcp<E: IpPacket> {
// The enclosing packet (e.g. `Ipv4`, `Ipv6`, `SegmentRouting<Ipv6>`) through
// which the mbuf is reached.
envelope: E,
// Pointer to the header inside the mbuf, obtained from `read_data` in
// `try_parse` or `write_data` in `try_push`.
header: NonNull<TcpHeader>,
// Offset of the TCP header within the mbuf (the envelope's payload offset).
offset: usize,
}
impl<E: IpPacket> Tcp<E> {
    /// Mask of the NS flag. Unlike the other control flags it is stored in
    /// bit 0 of `offset_to_ns`, not in the `flags` byte.
    const NS_BIT: u8 = 0x01;

    /// Shared view of the TCP header.
    #[inline]
    fn header(&self) -> &TcpHeader {
        // SAFETY: `self.header` is the pointer produced by `read_data` /
        // `write_data` on the envelope's mbuf (see `try_parse` / `try_push`);
        // it is assumed to stay valid while `self.envelope` is alive.
        unsafe { self.header.as_ref() }
    }

    /// Exclusive view of the TCP header.
    #[inline]
    fn header_mut(&mut self) -> &mut TcpHeader {
        // SAFETY: same pointer validity as `header`. `Packet::clone` copies
        // this pointer verbatim, so exclusivity relies on callers honouring
        // that unsafe method's contract.
        unsafe { self.header.as_mut() }
    }

    /// Tests whether any bit of `mask` is set in the `flags` byte.
    #[inline]
    fn has_flag(&self, mask: u8) -> bool {
        (self.header().flags & mask) != 0
    }

    /// Turns on the bits of `mask` in the `flags` byte.
    #[inline]
    fn raise_flag(&mut self, mask: u8) {
        self.header_mut().flags |= mask;
    }

    /// Turns off the bits of `mask` in the `flags` byte.
    #[inline]
    fn lower_flag(&mut self, mask: u8) {
        self.header_mut().flags &= !mask;
    }

    /// Returns the source port.
    #[inline]
    pub fn src_port(&self) -> u16 {
        self.header().src_port.into()
    }

    /// Sets the source port. The checksum is left as is; it is recomputed by
    /// `Packet::reconcile`.
    #[inline]
    pub fn set_src_port(&mut self, src_port: u16) {
        self.header_mut().src_port = src_port.into();
    }

    /// Returns the destination port.
    #[inline]
    pub fn dst_port(&self) -> u16 {
        self.header().dst_port.into()
    }

    /// Sets the destination port. The checksum is left as is; it is
    /// recomputed by `Packet::reconcile`.
    #[inline]
    pub fn set_dst_port(&mut self, dst_port: u16) {
        self.header_mut().dst_port = dst_port.into();
    }

    /// Returns the sequence number.
    #[inline]
    pub fn seq_no(&self) -> u32 {
        self.header().seq_no.into()
    }

    /// Sets the sequence number.
    #[inline]
    pub fn set_seq_no(&mut self, seq_no: u32) {
        self.header_mut().seq_no = seq_no.into();
    }

    /// Returns the acknowledgment number.
    #[inline]
    pub fn ack_no(&self) -> u32 {
        self.header().ack_no.into()
    }

    /// Sets the acknowledgment number.
    #[inline]
    pub fn set_ack_no(&mut self, ack_no: u32) {
        self.header_mut().ack_no = ack_no.into();
    }

    /// Returns the data offset: the upper nibble of `offset_to_ns`.
    #[inline]
    pub fn data_offset(&self) -> u8 {
        // Shifting a u8 right by 4 already discards the low nibble.
        self.header().offset_to_ns >> 4
    }

    /// Stores `data_offset` in the upper nibble of `offset_to_ns`, keeping
    /// the low nibble (which holds NS) untouched. Bits of `data_offset`
    /// above the low four are shifted out.
    // Private and currently unused, hence the `dead_code` allowance.
    #[allow(dead_code)]
    #[inline]
    fn set_data_offset(&mut self, data_offset: u8) {
        let low_nibble = self.header().offset_to_ns & 0x0f;
        self.header_mut().offset_to_ns = low_nibble | (data_offset << 4);
    }

    /// Returns whether the NS flag is set.
    #[inline]
    pub fn ns(&self) -> bool {
        (self.header().offset_to_ns & Self::NS_BIT) != 0
    }

    /// Sets the NS flag.
    #[inline]
    pub fn set_ns(&mut self) {
        self.header_mut().offset_to_ns |= Self::NS_BIT;
    }

    /// Clears the NS flag.
    #[inline]
    pub fn unset_ns(&mut self) {
        self.header_mut().offset_to_ns &= !Self::NS_BIT;
    }

    /// Returns whether the CWR flag is set.
    #[inline]
    pub fn cwr(&self) -> bool {
        self.has_flag(CWR)
    }

    /// Sets the CWR flag.
    #[inline]
    pub fn set_cwr(&mut self) {
        self.raise_flag(CWR);
    }

    /// Clears the CWR flag.
    #[inline]
    pub fn unset_cwr(&mut self) {
        self.lower_flag(CWR);
    }

    /// Returns whether the ECE flag is set.
    #[inline]
    pub fn ece(&self) -> bool {
        self.has_flag(ECE)
    }

    /// Sets the ECE flag.
    #[inline]
    pub fn set_ece(&mut self) {
        self.raise_flag(ECE);
    }

    /// Clears the ECE flag.
    #[inline]
    pub fn unset_ece(&mut self) {
        self.lower_flag(ECE);
    }

    /// Returns whether the URG flag is set.
    #[inline]
    pub fn urg(&self) -> bool {
        self.has_flag(URG)
    }

    /// Sets the URG flag.
    #[inline]
    pub fn set_urg(&mut self) {
        self.raise_flag(URG);
    }

    /// Clears the URG flag.
    #[inline]
    pub fn unset_urg(&mut self) {
        self.lower_flag(URG);
    }

    /// Returns whether the ACK flag is set.
    #[inline]
    pub fn ack(&self) -> bool {
        self.has_flag(ACK)
    }

    /// Sets the ACK flag.
    #[inline]
    pub fn set_ack(&mut self) {
        self.raise_flag(ACK);
    }

    /// Clears the ACK flag.
    #[inline]
    pub fn unset_ack(&mut self) {
        self.lower_flag(ACK);
    }

    /// Returns whether the PSH flag is set.
    #[inline]
    pub fn psh(&self) -> bool {
        self.has_flag(PSH)
    }

    /// Sets the PSH flag.
    #[inline]
    pub fn set_psh(&mut self) {
        self.raise_flag(PSH);
    }

    /// Clears the PSH flag.
    #[inline]
    pub fn unset_psh(&mut self) {
        self.lower_flag(PSH);
    }

    /// Returns whether the RST flag is set.
    #[inline]
    pub fn rst(&self) -> bool {
        self.has_flag(RST)
    }

    /// Sets the RST flag.
    #[inline]
    pub fn set_rst(&mut self) {
        self.raise_flag(RST);
    }

    /// Clears the RST flag.
    #[inline]
    pub fn unset_rst(&mut self) {
        self.lower_flag(RST);
    }

    /// Returns whether the SYN flag is set.
    #[inline]
    pub fn syn(&self) -> bool {
        self.has_flag(SYN)
    }

    /// Sets the SYN flag.
    #[inline]
    pub fn set_syn(&mut self) {
        self.raise_flag(SYN);
    }

    /// Clears the SYN flag.
    #[inline]
    pub fn unset_syn(&mut self) {
        self.lower_flag(SYN);
    }

    /// Returns `true` only when both SYN and ACK are set.
    #[inline]
    pub fn syn_ack(&self) -> bool {
        self.syn() && self.ack()
    }

    /// Returns whether the FIN flag is set.
    #[inline]
    pub fn fin(&self) -> bool {
        self.has_flag(FIN)
    }

    /// Sets the FIN flag.
    #[inline]
    pub fn set_fin(&mut self) {
        self.raise_flag(FIN);
    }

    /// Clears the FIN flag.
    #[inline]
    pub fn unset_fin(&mut self) {
        self.lower_flag(FIN);
    }

    /// Returns the window size.
    #[inline]
    pub fn window(&self) -> u16 {
        self.header().window.into()
    }

    /// Sets the window size.
    #[inline]
    pub fn set_window(&mut self, window: u16) {
        self.header_mut().window = window.into();
    }

    /// Returns the checksum currently stored in the header.
    #[inline]
    pub fn checksum(&self) -> u16 {
        self.header().checksum.into()
    }

    /// Overwrites the stored checksum (internal; callers go through
    /// `Packet::reconcile` or the IP setters).
    #[inline]
    fn set_checksum(&mut self, checksum: u16) {
        self.header_mut().checksum = checksum.into();
    }

    /// Returns the urgent pointer.
    #[inline]
    pub fn urgent_pointer(&self) -> u16 {
        self.header().urgent_pointer.into()
    }

    /// Sets the urgent pointer.
    #[inline]
    pub fn set_urgent_pointer(&mut self, urgent_pointer: u16) {
        self.header_mut().urgent_pointer = urgent_pointer.into();
    }

    /// Returns the flow this segment belongs to: envelope source and
    /// destination addresses, TCP source and destination ports, and
    /// `ProtocolNumbers::Tcp`.
    #[inline]
    pub fn flow(&self) -> Flow {
        let ip = self.envelope();
        Flow::new(
            ip.src(),
            ip.dst(),
            self.src_port(),
            self.dst_port(),
            ProtocolNumbers::Tcp,
        )
    }

    /// Replaces the envelope's source address and patches the TCP checksum
    /// from the current checksum plus the old and new addresses.
    ///
    /// # Errors
    ///
    /// Fails if the checksum adjustment or the envelope update fails, e.g.
    /// when an IPv6 address is given to an IPv4 envelope. The checksum is
    /// written only after both steps have succeeded.
    #[inline]
    pub fn set_src_ip(&mut self, src_ip: IpAddr) -> Result<()> {
        let previous = self.envelope().src();
        let patched = checksum::compute_with_ipaddr(self.checksum(), &previous, &src_ip)?;
        self.envelope_mut().set_src(src_ip)?;
        self.set_checksum(patched);
        Ok(())
    }

    /// Replaces the envelope's destination address and patches the TCP
    /// checksum from the current checksum plus the old and new addresses.
    ///
    /// # Errors
    ///
    /// Fails if the checksum adjustment or the envelope update fails, e.g.
    /// when an IPv6 address is given to an IPv4 envelope. The checksum is
    /// written only after both steps have succeeded.
    #[inline]
    pub fn set_dst_ip(&mut self, dst_ip: IpAddr) -> Result<()> {
        let previous = self.envelope().dst();
        let patched = checksum::compute_with_ipaddr(self.checksum(), &previous, &dst_ip)?;
        self.envelope_mut().set_dst(dst_ip)?;
        self.set_checksum(patched);
        Ok(())
    }

    /// Recomputes the checksum from scratch over the envelope's TCP
    /// pseudo-header and the `self.len()` bytes starting at `self.offset`.
    #[inline]
    fn compute_checksum(&mut self) {
        // The checksum field has to read as zero while the sum is taken.
        self.set_checksum(0);
        match self.mbuf().read_data_slice(self.offset, self.len()) {
            Ok(segment) => {
                // SAFETY: the pointer was just returned by `read_data_slice`
                // for this mbuf range and nothing has resized the mbuf since.
                let bytes = unsafe { segment.as_ref() };
                // NOTE(review): `as u16` assumes the segment is below 64 KiB.
                let pseudo_sum = self
                    .envelope()
                    .pseudo_header(bytes.len() as u16, ProtocolNumbers::Tcp)
                    .sum();
                self.set_checksum(checksum::compute(pseudo_sum, bytes));
            }
            // The range starts at this packet's own offset and spans its own
            // length, so it is expected to always lie inside the mbuf.
            Err(_) => unreachable!(),
        }
    }
}
impl<E: IpPacket> fmt::Debug for Tcp<E> {
    /// Renders the header fields, each control flag, and the packet's
    /// bookkeeping values (`$offset`, `$len`, `$header_len`) as a `tcp` struct.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("tcp");
        out.field("src_port", &self.src_port())
            .field("dst_port", &self.dst_port())
            .field("seq_no", &self.seq_no())
            .field("ack_no", &self.ack_no())
            .field("data_offset", &self.data_offset())
            .field("window", &self.window())
            .field("checksum", &format!("0x{:04x}", self.checksum()))
            .field("urgent pointer", &self.urgent_pointer());

        // Control flags, listed in the same order as before.
        let flags = [
            ("ns", self.ns()),
            ("cwr", self.cwr()),
            ("ece", self.ece()),
            ("urg", self.urg()),
            ("ack", self.ack()),
            ("psh", self.psh()),
            ("rst", self.rst()),
            ("syn", self.syn()),
            ("fin", self.fin()),
        ];
        for &(name, is_set) in flags.iter() {
            out.field(name, &is_set);
        }

        out.field("$offset", &self.offset())
            .field("$len", &self.len())
            .field("$header_len", &self.header_len())
            .finish()
    }
}
impl<E: IpPacket> Packet for Tcp<E> {
type Envelope = E;
#[inline]
fn envelope(&self) -> &Self::Envelope {
&self.envelope
}
#[inline]
fn envelope_mut(&mut self) -> &mut Self::Envelope {
&mut self.envelope
}
#[inline]
fn offset(&self) -> usize {
self.offset
}
#[inline]
fn header_len(&self) -> usize {
TcpHeader::size_of()
}
#[inline]
unsafe fn clone(&self, internal: Internal) -> Self {
Tcp::<E> {
envelope: self.envelope.clone(internal),
header: self.header,
offset: self.offset,
}
}
#[inline]
fn try_parse(envelope: Self::Envelope, _internal: Internal) -> Result<Self> {
ensure!(
envelope.next_protocol() == ProtocolNumbers::Tcp,
anyhow!("not a TCP packet.")
);
let mbuf = envelope.mbuf();
let offset = envelope.payload_offset();
let header = mbuf.read_data(offset)?;
Ok(Tcp {
envelope,
header,
offset,
})
}
#[inline]
fn try_push(mut envelope: Self::Envelope, _internal: Internal) -> Result<Self> {
let offset = envelope.payload_offset();
let mbuf = envelope.mbuf_mut();
mbuf.extend(offset, TcpHeader::size_of())?;
let header = mbuf.write_data(offset, &TcpHeader::default())?;
envelope.set_next_protocol(ProtocolNumbers::Tcp);
Ok(Tcp {
envelope,
header,
offset,
})
}
#[inline]
fn deparse(self) -> Self::Envelope {
self.envelope
}
#[inline]
fn reconcile(&mut self) {
self.compute_checksum();
}
}
/// TCP carried directly in an IPv4 packet.
pub type Tcp4 = Tcp<Ipv4>;
/// TCP carried directly in an IPv6 packet. Segments behind an IPv6 extension
/// header use the extension type as the envelope, e.g.
/// `Tcp<SegmentRouting<Ipv6>>`.
pub type Tcp6 = Tcp<Ipv6>;
/// Wire layout of the fixed 20-byte TCP header (options are not included).
///
/// `repr(C, packed)` keeps the fields in declaration order with no padding,
/// so the struct maps directly onto the header bytes in the mbuf.
#[derive(Clone, Copy, Debug, SizeOf)]
#[repr(C, packed)]
struct TcpHeader {
// Multi-byte fields use the `u16be` / `u32be` wrappers and are converted with
// `.into()` by the accessors on `Tcp`.
src_port: u16be,
dst_port: u16be,
seq_no: u32be,
ack_no: u32be,
// Upper 4 bits: data offset in 32-bit words (5 for this options-free
// header). Bit 0: the NS flag. Bits 1-3 are not interpreted by this module.
offset_to_ns: u8,
// Control flags CWR, ECE, URG, ACK, PSH, RST, SYN, FIN, one bit each (see
// the flag constants at the top of the file).
flags: u8,
window: u16be,
checksum: u16be,
urgent_pointer: u16be,
}
impl Default for TcpHeader {
    /// Returns a header with every field zeroed except the data offset. That
    /// is set to 5 words, describing a bare 20-byte header with no options.
    fn default() -> Self {
        // The data offset occupies the upper nibble of `offset_to_ns`.
        const NO_OPTIONS_DATA_OFFSET: u8 = 5;

        TcpHeader {
            offset_to_ns: NO_OPTIONS_DATA_OFFSET << 4,
            flags: 0,
            src_port: Default::default(),
            dst_port: Default::default(),
            seq_no: Default::default(),
            ack_no: Default::default(),
            window: Default::default(),
            checksum: Default::default(),
            urgent_pointer: Default::default(),
        }
    }
}
// Unit tests for TCP parsing, pushing, flow extraction and checksums.
// Tests that build an `Mbuf` use `#[capsule::test]` instead of `#[test]`;
// presumably that attribute sets up the environment mbuf allocation needs —
// TODO confirm.
#[cfg(test)]
mod tests {
use super::*;
use crate::packets::ip::v6::SegmentRouting;
use crate::packets::Ethernet;
use crate::testils::byte_arrays::{IPV4_TCP_PACKET, IPV4_UDP_PACKET, SR_TCP_PACKET};
use crate::Mbuf;
use std::net::{Ipv4Addr, Ipv6Addr};
// The options-free header struct must be exactly 20 bytes.
#[test]
fn size_of_tcp_header() {
assert_eq!(20, TcpHeader::size_of());
}
// Parses the IPv4 fixture and checks every header field. The fixture is a
// SYN with data offset 6, i.e. 4 bytes beyond the fixed 20-byte header.
#[capsule::test]
fn parse_tcp_packet() {
let packet = Mbuf::from_bytes(&IPV4_TCP_PACKET).unwrap();
let ethernet = packet.parse::<Ethernet>().unwrap();
let ipv4 = ethernet.parse::<Ipv4>().unwrap();
let tcp = ipv4.parse::<Tcp4>().unwrap();
assert_eq!(36869, tcp.src_port());
assert_eq!(23, tcp.dst_port());
assert_eq!(1_913_975_060, tcp.seq_no());
assert_eq!(0, tcp.ack_no());
assert_eq!(6, tcp.data_offset());
assert_eq!(8760, tcp.window());
assert_eq!(0xa92c, tcp.checksum());
assert_eq!(0, tcp.urgent_pointer());
assert!(!tcp.ns());
assert!(!tcp.cwr());
assert!(!tcp.ece());
assert!(!tcp.urg());
assert!(!tcp.ack());
assert!(!tcp.psh());
assert!(!tcp.rst());
assert!(tcp.syn());
assert!(!tcp.fin());
}
// A UDP payload must be rejected by the protocol check in `try_parse`.
#[capsule::test]
fn parse_non_tcp_packet() {
let packet = Mbuf::from_bytes(&IPV4_UDP_PACKET).unwrap();
let ethernet = packet.parse::<Ethernet>().unwrap();
let ipv4 = ethernet.parse::<Ipv4>().unwrap();
assert!(ipv4.parse::<Tcp4>().is_err());
}
// Flow extraction for TCP directly over IPv4.
#[capsule::test]
fn tcp_flow_v4() {
let packet = Mbuf::from_bytes(&IPV4_TCP_PACKET).unwrap();
let ethernet = packet.parse::<Ethernet>().unwrap();
let ipv4 = ethernet.parse::<Ipv4>().unwrap();
let tcp = ipv4.parse::<Tcp4>().unwrap();
let flow = tcp.flow();
assert_eq!("139.133.217.110", flow.src_ip().to_string());
assert_eq!("139.133.233.2", flow.dst_ip().to_string());
assert_eq!(36869, flow.src_port());
assert_eq!(23, flow.dst_port());
assert_eq!(ProtocolNumbers::Tcp, flow.protocol());
}
// Flow extraction when TCP sits behind an IPv6 segment routing header; the
// addresses come from the SRH envelope.
#[capsule::test]
fn tcp_flow_v6() {
let packet = Mbuf::from_bytes(&SR_TCP_PACKET).unwrap();
let ethernet = packet.parse::<Ethernet>().unwrap();
let ipv6 = ethernet.parse::<Ipv6>().unwrap();
let srh = ipv6.parse::<SegmentRouting<Ipv6>>().unwrap();
let tcp = srh.parse::<Tcp<SegmentRouting<Ipv6>>>().unwrap();
let flow = tcp.flow();
assert_eq!("2001:db8:85a3::1", flow.src_ip().to_string());
assert_eq!("2001:db8:85a3::8a2e:370:7333", flow.dst_ip().to_string());
assert_eq!(3464, flow.src_port());
assert_eq!(1024, flow.dst_port());
assert_eq!(ProtocolNumbers::Tcp, flow.protocol());
}
// Rewriting either address must update both the envelope and the checksum;
// an address of the wrong family must be rejected.
#[capsule::test]
fn set_src_dst_ip() {
let packet = Mbuf::from_bytes(&IPV4_TCP_PACKET).unwrap();
let ethernet = packet.parse::<Ethernet>().unwrap();
let ipv4 = ethernet.parse::<Ipv4>().unwrap();
let mut tcp = ipv4.parse::<Tcp4>().unwrap();
let old_checksum = tcp.checksum();
let new_ip = Ipv4Addr::new(10, 0, 0, 0);
assert!(tcp.set_src_ip(new_ip.into()).is_ok());
assert!(tcp.checksum() != old_checksum);
assert_eq!(new_ip.to_string(), tcp.envelope().src().to_string());
let old_checksum = tcp.checksum();
let new_ip = Ipv4Addr::new(20, 0, 0, 0);
assert!(tcp.set_dst_ip(new_ip.into()).is_ok());
assert!(tcp.checksum() != old_checksum);
assert_eq!(new_ip.to_string(), tcp.envelope().dst().to_string());
assert!(tcp.set_src_ip(Ipv6Addr::UNSPECIFIED.into()).is_err());
}
// Recomputing the checksum of an unmodified fixture must reproduce the
// checksum already stored in it.
#[capsule::test]
fn compute_checksum() {
let packet = Mbuf::from_bytes(&IPV4_TCP_PACKET).unwrap();
let ethernet = packet.parse::<Ethernet>().unwrap();
let ipv4 = ethernet.parse::<Ipv4>().unwrap();
let mut tcp = ipv4.parse::<Tcp4>().unwrap();
let expected = tcp.checksum();
tcp.reconcile_all();
assert_eq!(expected, tcp.checksum());
}
// Pushing adds a bare fixed-size header (data offset 5) and marks the
// envelope's next protocol as TCP.
#[capsule::test]
fn push_tcp_packet() {
let packet = Mbuf::new().unwrap();
let ethernet = packet.push::<Ethernet>().unwrap();
let ipv4 = ethernet.push::<Ipv4>().unwrap();
let tcp = ipv4.push::<Tcp4>().unwrap();
assert_eq!(TcpHeader::size_of(), tcp.len());
assert_eq!(5, tcp.data_offset());
assert_eq!(ProtocolNumbers::Tcp, tcp.envelope().next_protocol());
}
}