/// Size in bytes of a TCP header that carries no options (5 × 32-bit words).
const TCP_HEADER_MIN_LEN: usize = 5 * 4;
/// Index of the flags byte inside the TCP header.
const TCP_FLAGS_OFFSET: usize = 13;
/// Bit mask selecting the SYN flag within the flags byte.
const TCP_FLAG_SYN: u8 = 1 << 1;
/// Option kind identifying a Maximum Segment Size option.
const TCP_OPT_MSS: u8 = 2;
/// Expected total length of an MSS option: kind byte, length byte, 16-bit value.
const TCP_OPT_MSS_LEN: u8 = 1 + 1 + 2;
/// Reports whether the SYN bit is set in the flags byte of `tcp_header`.
///
/// A slice shorter than an option-less TCP header is never considered a SYN.
fn is_tcp_syn(tcp_header: &[u8]) -> bool {
    let has_full_header = tcp_header.len() >= TCP_HEADER_MIN_LEN;
    has_full_header && (tcp_header[TCP_FLAGS_OFFSET] & TCP_FLAG_SYN) == TCP_FLAG_SYN
}
/// Returns the TCP header length in bytes, decoded from the high nibble of
/// byte 12 (a count of 32-bit words).
///
/// Returns `0` when the slice is shorter than an option-less TCP header.
fn get_tcp_data_offset(tcp_header: &[u8]) -> usize {
    if tcp_header.len() >= TCP_HEADER_MIN_LEN {
        let words = usize::from(tcp_header[12] >> 4);
        words * 4
    } else {
        0
    }
}
/// Clamps the Maximum Segment Size option of an IPv6 TCP SYN segment.
///
/// Only packets whose fixed 40-byte IPv6 header is immediately followed by a
/// TCP header (Next Header = 6, so no extension headers) are inspected. The
/// TCP segment is bounded by the IPv6 Payload Length field. If the segment is
/// a SYN and its first well-formed MSS option advertises more than `max_mss`,
/// the option is rewritten to `max_mss` and the TCP checksum is recomputed.
///
/// Returns `true` if the packet was modified. Returns `false`, leaving the
/// buffer untouched, in these cases:
/// - the packet is not IPv6, not TCP, or not a SYN;
/// - there is no MSS option, or it is already `<= max_mss`;
/// - the packet is malformed: the declared payload overruns the buffer, the
///   TCP header is longer than the declared payload, or the option list is
///   truncated.
pub fn clamp_tcp_mss(ipv6_packet: &mut [u8], max_mss: u16) -> bool {
    const IPV6_HEADER_LEN: usize = 40;
    const IPPROTO_TCP: u8 = 6;

    if ipv6_packet.len() < IPV6_HEADER_LEN || ipv6_packet[0] >> 4 != 6 {
        return false;
    }
    if ipv6_packet[6] != IPPROTO_TCP {
        return false;
    }

    // Bound the TCP segment by the Payload Length field, not the buffer.
    // Bytes past it are not part of the segment. A declared length that
    // overruns the buffer would make the checksum recomputation slice out of
    // bounds (a panic on attacker-controlled input), so reject it up front.
    let payload_len = usize::from(u16::from_be_bytes([ipv6_packet[4], ipv6_packet[5]]));
    let tcp_start = IPV6_HEADER_LEN;
    let segment_end = tcp_start + payload_len;
    if segment_end > ipv6_packet.len() {
        return false;
    }

    let segment = &ipv6_packet[tcp_start..segment_end];
    if !is_tcp_syn(segment) {
        return false;
    }
    // The header must fit inside the declared segment. Otherwise a rewritten
    // option would lie outside the range the checksum covers.
    let tcp_header_len = get_tcp_data_offset(segment);
    if tcp_header_len < TCP_HEADER_MIN_LEN || tcp_header_len > segment.len() {
        return false;
    }

    let options_start = tcp_start + TCP_HEADER_MIN_LEN;
    let options_end = tcp_start + tcp_header_len;
    let mss_option = match find_mss_option(&ipv6_packet[options_start..options_end]) {
        Some(offset) => options_start + offset,
        None => return false,
    };

    // The 16-bit MSS value follows the option's kind and length bytes.
    let value_at = mss_option + 2;
    let current_mss = u16::from_be_bytes([ipv6_packet[value_at], ipv6_packet[value_at + 1]]);
    if current_mss <= max_mss {
        return false;
    }
    ipv6_packet[value_at..value_at + 2].copy_from_slice(&max_mss.to_be_bytes());
    // NOTE(review): a full recomputation also yields a valid checksum for a
    // segment that arrived with a bad one. An RFC 1624 incremental update
    // would preserve such an error instead.
    recalculate_tcp_checksum(ipv6_packet, tcp_start);
    true
}

/// Returns the offset, within `options`, of the first MSS option whose
/// length byte equals `TCP_OPT_MSS_LEN`.
///
/// Scanning stops, returning `None`, at End of Option List (kind 0), or at a
/// malformed option: a missing length byte, a length below 2, or a length
/// that runs past the end of `options`. No-Operation (kind 1) is a single
/// byte. Any other option, including an MSS option with an unexpected length,
/// is skipped by its length byte.
fn find_mss_option(options: &[u8]) -> Option<usize> {
    let mut i = 0;
    while let Some(&kind) = options.get(i) {
        match kind {
            0 => return None,
            1 => i += 1,
            _ => {
                let length = usize::from(*options.get(i + 1)?);
                if length < 2 || i + length > options.len() {
                    return None;
                }
                if kind == TCP_OPT_MSS && length == usize::from(TCP_OPT_MSS_LEN) {
                    return Some(i);
                }
                i += length;
            }
        }
    }
    None
}
/// Recomputes the TCP checksum of the segment that starts at `tcp_start`.
///
/// Segment extent:
/// - It runs from `tcp_start` to the end of the IPv6 payload, i.e. to byte
///   `40 + Payload Length`. Payload Length is read from bytes 4..6.
/// - `tcp_start` may therefore lie past IPv6 extension headers.
///
/// What is summed:
/// - The pseudo-header: the source and destination addresses (bytes 8..40),
///   the segment length, and next-header value 6.
/// - The segment itself, with an odd trailing byte padded with zero.
///
/// The checksum field (bytes 16..18 of the TCP header) is zeroed before
/// summing and then overwritten with the result.
///
/// The packet is left completely unchanged when any of these holds:
/// - the buffer is shorter than the fixed 40-byte header;
/// - `tcp_start` lies inside that header;
/// - the declared payload overruns the buffer;
/// - the segment is too short to contain the checksum field.
///
/// NOTE(review): if a Routing header is present, RFC 8200 §8.1 requires the
/// final destination in the pseudo-header. This function uses the fixed
/// header's destination address.
fn recalculate_tcp_checksum(ipv6_packet: &mut [u8], tcp_start: usize) {
    const IPV6_HEADER_LEN: usize = 40;
    const CHECKSUM_OFFSET: usize = 16;
    const IPPROTO_TCP: u32 = 6;

    // Sums `bytes` as big-endian 16-bit words, zero-padding an odd final byte.
    fn sum_be_words(bytes: &[u8]) -> u32 {
        bytes
            .chunks(2)
            .map(|word| u32::from(u16::from_be_bytes([word[0], word.get(1).copied().unwrap_or(0)])))
            .sum()
    }

    if ipv6_packet.len() < IPV6_HEADER_LEN {
        return;
    }
    let payload_len = usize::from(u16::from_be_bytes([ipv6_packet[4], ipv6_packet[5]]));
    // Payload Length counts everything after the fixed header, including any
    // extension headers. The segment therefore ends there whatever
    // `tcp_start` is.
    let segment_end = IPV6_HEADER_LEN + payload_len;
    // Validate every bound before touching the buffer, so a bad packet is
    // never left with a half-written (zeroed) checksum.
    if tcp_start < IPV6_HEADER_LEN
        || segment_end > ipv6_packet.len()
        || segment_end.saturating_sub(tcp_start) < CHECKSUM_OFFSET + 2
    {
        return;
    }
    let checksum_at = tcp_start + CHECKSUM_OFFSET;

    ipv6_packet[checksum_at..checksum_at + 2].copy_from_slice(&[0, 0]);

    let segment = &ipv6_packet[tcp_start..segment_end];
    // The segment is at most 65535 bytes (bounded by the 16-bit Payload
    // Length). So this cast is lossless, and the running sum of at most
    // ~32k words stays far below u32::MAX.
    let segment_len = segment.len() as u32;
    let mut sum =
        sum_be_words(&ipv6_packet[8..40]) + segment_len + IPPROTO_TCP + sum_be_words(segment);
    while sum >> 16 != 0 {
        sum = (sum & 0xffff) + (sum >> 16);
    }
    let checksum = !(sum as u16);
    ipv6_packet[checksum_at..checksum_at + 2].copy_from_slice(&checksum.to_be_bytes());
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Offset of the TCP header in the packets built below (no extension headers).
    const TCP_START: usize = 40;
    /// Source address used by the SYN packets: fd00::1.
    const SRC: [u8; 16] = [0xfd, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1];
    /// Destination address used by the SYN packets: fd00::2.
    const DST: [u8; 16] = [0xfd, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2];

    /// Builds an 80-byte IPv6 + TCP SYN packet.
    ///
    /// The TCP header is 40 bytes long. Its options area starts with an MSS
    /// option carrying `mss`, followed by an End of Option List byte. The
    /// checksum is filled in with `recalculate_tcp_checksum`.
    fn make_tcp_syn_packet(src: [u8; 16], dst: [u8; 16], mss: u16) -> Vec<u8> {
        let mut packet = vec![0u8; TCP_START + 40];

        // Fixed IPv6 header: version 6, payload length 40, next header 6, hop limit 64.
        packet[0] = 0x60;
        packet[4..6].copy_from_slice(&40u16.to_be_bytes());
        packet[6] = 6;
        packet[7] = 64;
        packet[8..24].copy_from_slice(&src);
        packet[24..40].copy_from_slice(&dst);

        {
            let tcp = &mut packet[TCP_START..];
            tcp[0..2].copy_from_slice(&12345u16.to_be_bytes()); // source port
            tcp[2..4].copy_from_slice(&80u16.to_be_bytes()); // destination port
            tcp[4..8].copy_from_slice(&1000u32.to_be_bytes()); // sequence number
            tcp[8..12].copy_from_slice(&0u32.to_be_bytes()); // acknowledgment number
            tcp[12] = 0xa0; // data offset: 10 words = 40-byte header
            tcp[13] = TCP_FLAG_SYN;
            tcp[14..16].copy_from_slice(&8192u16.to_be_bytes()); // window
            tcp[20] = TCP_OPT_MSS;
            tcp[21] = TCP_OPT_MSS_LEN;
            tcp[22..24].copy_from_slice(&mss.to_be_bytes());
            tcp[24] = 0; // End of Option List
        }

        recalculate_tcp_checksum(&mut packet, TCP_START);
        packet
    }

    /// Reads back the MSS value at the position `make_tcp_syn_packet` writes it.
    fn advertised_mss(packet: &[u8]) -> u16 {
        u16::from_be_bytes([packet[TCP_START + 22], packet[TCP_START + 23]])
    }

    #[test]
    fn test_clamp_tcp_mss_reduces_large_mss() {
        let mut packet = make_tcp_syn_packet(SRC, DST, 1460);
        assert!(clamp_tcp_mss(&mut packet, 1200));
        assert_eq!(advertised_mss(&packet), 1200);
    }

    #[test]
    fn test_clamp_tcp_mss_leaves_small_mss_unchanged() {
        let mut packet = make_tcp_syn_packet(SRC, DST, 1000);
        assert!(!clamp_tcp_mss(&mut packet, 1200));
        assert_eq!(advertised_mss(&packet), 1000);
    }

    #[test]
    fn test_clamp_tcp_mss_ignores_non_syn() {
        let mut packet = make_tcp_syn_packet(SRC, DST, 1460);
        // Replace the flags byte with ACK only.
        packet[TCP_START + 13] = 0x10;
        assert!(!clamp_tcp_mss(&mut packet, 1200));
    }

    #[test]
    fn test_clamp_tcp_mss_ignores_non_tcp() {
        let mut packet = vec![0u8; 80];
        packet[0] = 0x60;
        packet[6] = 17;
        assert!(!clamp_tcp_mss(&mut packet, 1200));
    }
}