#![cfg(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64"))]
use std::arch::asm;
/// Runs one of two architecture-specific blocks (normally containing `asm!`)
/// inside a single `unsafe` block.
///
/// The first block is kept on `x86`/`x86_64` and the second on `aarch64`;
/// the other one is removed by `#[cfg]`.
///
/// NOTE(review): the x86 arm is also selected for 32-bit `x86`, but the x86
/// blocks passed to this macro in this file use 64-bit (`q`-suffixed)
/// instructions and `u64` register operands, which cannot build there —
/// confirm whether 32-bit x86 is really a supported target.
macro_rules! impl_block {
($x86:block, $aarch64:block) => {
unsafe {
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{
$x86
}
#[cfg(target_arch = "aarch64")]
{
$aarch64
}
}
};
}
/// Folds a 32-bit ones'-complement sum down to 16 bits and returns its
/// bitwise complement, i.e. the value that goes into a checksum field.
#[inline]
pub fn fold_checksum(csum: u32) -> u16 {
    // Two end-around-carry folds are always enough: the first leaves at most
    // 0x1_fffe, the second at most 0xffff.
    let once = (csum >> 16) + (csum & 0xffff);
    let twice = (once >> 16) + (once & 0xffff);
    !(twice as u16)
}
/// Folds a 32-bit ones'-complement sum to 16 bits *without* complementing it.
#[inline]
pub fn to_u16(csum: u32) -> u16 {
    // Adding the halves-swapped value leaves high + low (plus the carry out
    // of the low-half addition) in the upper 16 bits.
    let swapped = csum.rotate_left(16);
    let total = csum.wrapping_add(swapped);
    (total >> 16) as u16
}
/// Adds two 32-bit values in ones'-complement arithmetic, i.e. addition with
/// end-around carry.
///
/// Written in portable Rust rather than inline assembly. The previous `asm!`
/// declared no `nomem`/`nostack`/`pure` options, so the compiler had to treat
/// every call as possibly touching arbitrary memory and the stack. That is an
/// optimisation barrier in hot checksum code, and it needed `unsafe` for
/// nothing. LLVM typically lowers the form below to an add / add-with-carry
/// pair.
#[inline]
pub fn add(a: u32, b: u32) -> u32 {
    let (sum, carry) = a.overflowing_add(b);
    // Cannot overflow: whenever `carry` is set, `sum` is at most u32::MAX - 1.
    sum + u32::from(carry)
}
/// Subtracts `b` from `a` in ones'-complement arithmetic.
///
/// In ones'-complement, negation is bitwise NOT, so `a - b == a + !b`.
#[inline]
pub fn sub(a: u32, b: u32) -> u32 {
    // `u32::MAX - b` is exactly `!b` and never underflows.
    let negated = u32::MAX - b;
    add(a, negated)
}
/// Computes `seed + sum(to) - sum(from)` in ones'-complement arithmetic and
/// folds it to 16 bits without complementing it (see `to_u16`).
///
/// Presumably used for incremental checksum updates when the bytes `from`
/// are replaced by `to`. Confirm with callers.
#[inline]
pub fn diff(from: &[u8], to: &[u8], seed: u32) -> u16 {
    // The two partial sums are computed sequentially. The previous version
    // spawned two OS threads per call via `std::thread::scope`, which costs
    // far more than summing packet-sized buffers and produced the same values.
    let ret = match (from.is_empty(), to.is_empty()) {
        (false, false) => sub(partial(to, seed), partial(from, 0)),
        // Nothing removed: just add `to` on top of the seed.
        (true, false) => partial(to, seed),
        // Nothing added: seed - sum(from) == !(!seed + sum(from)), since
        // bitwise NOT is ones'-complement negation.
        (false, true) => !partial(from, !seed),
        (true, true) => seed,
    };
    to_u16(ret)
}
/// Folds a 64-bit ones'-complement accumulator down to 32 bits with
/// end-around carry.
#[inline]
fn finalize(sum: u64) -> u32 {
    // Adding the halves-swapped value leaves high + low (plus the carry out
    // of the low-half addition) in the upper 32 bits.
    let swapped = sum.rotate_right(32);
    let total = sum.wrapping_add(swapped);
    (total >> 32) as u32
}
/// Computes the unfolded 32-bit ones'-complement sum of `buf`, starting from
/// the seed `sum`.
///
/// The bytes are consumed as native-endian 64-bit words and accumulated with
/// end-around carry. The 64-bit accumulator is then reduced to 32 bits by
/// `finalize`. The result is neither folded to 16 bits nor complemented (see
/// `fold_checksum` / `to_u16`). A final run of 1..=7 bytes is treated as if
/// it were zero-padded at its end.
pub fn partial(mut buf: &[u8], sum: u32) -> u32 {
    /// Adds the five native-endian 64-bit words in `bytes`, which must be
    /// exactly 40 bytes long, to `sum` with end-around carry.
    #[inline]
    fn update_40(mut sum: u64, bytes: &[u8]) -> u64 {
        debug_assert_eq!(bytes.len(), 40);
        impl_block!(
            {
                asm!(
                    "addq 0*8({buf}), {sum}",
                    "adcq 1*8({buf}), {sum}",
                    "adcq 2*8({buf}), {sum}",
                    "adcq 3*8({buf}), {sum}",
                    "adcq 4*8({buf}), {sum}",
                    "adcq $0, {sum}",
                    buf = in(reg) bytes.as_ptr(),
                    sum = inout(reg) sum,
                    options(att_syntax)
                );
            },
            {
                asm!(
                    "ldr {tmp},[{buf},#0]",
                    "adds {sum}, {sum}, {tmp}",
                    "ldr {tmp},[{buf},#8]",
                    "adcs {sum}, {sum}, {tmp}",
                    "ldr {tmp},[{buf},#16]",
                    "adcs {sum}, {sum}, {tmp}",
                    "ldr {tmp},[{buf},#24]",
                    "adcs {sum}, {sum}, {tmp}",
                    "ldr {tmp},[{buf},#32]",
                    "adcs {sum}, {sum}, {tmp}",
                    "adc {sum}, {sum}, {zero}",
                    buf = in(reg) bytes.as_ptr(),
                    sum = inout(reg) sum,
                    tmp = out(reg) _,
                    zero = in(reg) 0u64,
                );
            }
        );
        sum
    }

    let mut sum = sum as u64;

    // Bulk: two independent accumulators over 80-byte strides (presumably so
    // the two carry chains can overlap), merged with end-around carry.
    if buf.len() >= 80 {
        let mut sum2 = 0;
        while buf.len() >= 80 {
            sum = update_40(sum, &buf[..40]);
            sum2 = update_40(sum2, &buf[40..80]);
            buf = &buf[80..];
        }
        impl_block!(
            {
                asm!(
                    "addq {0}, {sum}",
                    "adcq $0, {sum}",
                    in(reg) sum2,
                    sum = inout(reg) sum,
                    options(att_syntax)
                );
            },
            {
                asm!(
                    "adds {sum}, {sum}, {sum2}",
                    "adc {sum}, {sum}, {zero}",
                    sum2 = in(reg) sum2,
                    sum = inout(reg) sum,
                    zero = in(reg) 0u64,
                );
            }
        );
    }

    if buf.len() >= 40 {
        sum = update_40(sum, &buf[..40]);
        buf = &buf[40..];
        if buf.is_empty() {
            return finalize(sum);
        }
    }

    // Fewer than 40 bytes remain; consume them as 32/16/8-byte chunks
    // according to the bits of `len`, then the final 0..=7 bytes.
    let len = buf.len();
    if len & 32 != 0 {
        impl_block!(
            {
                asm!(
                    "addq 0*8({buf}), {sum}",
                    "adcq 1*8({buf}), {sum}",
                    "adcq 2*8({buf}), {sum}",
                    "adcq 3*8({buf}), {sum}",
                    "adcq $0, {sum}",
                    buf = in(reg) buf.as_ptr(),
                    sum = inout(reg) sum,
                    options(att_syntax)
                );
            },
            {
                asm!(
                    "ldr {tmp},[{buf},#0]",
                    "adds {sum}, {sum}, {tmp}",
                    "ldr {tmp},[{buf},#8]",
                    "adcs {sum}, {sum}, {tmp}",
                    "ldr {tmp},[{buf},#16]",
                    "adcs {sum}, {sum}, {tmp}",
                    "ldr {tmp},[{buf},#24]",
                    "adcs {sum}, {sum}, {tmp}",
                    "adc {sum}, {sum}, {zero}",
                    buf = in(reg) buf.as_ptr(),
                    sum = inout(reg) sum,
                    tmp = out(reg) _,
                    zero = in(reg) 0u64,
                );
            }
        );
        buf = &buf[32..];
    }
    if len & 16 != 0 {
        impl_block!(
            {
                asm!(
                    "addq 0*8({buf}), {sum}",
                    "adcq 1*8({buf}), {sum}",
                    "adcq $0, {sum}",
                    buf = in(reg) buf.as_ptr(),
                    sum = inout(reg) sum,
                    options(att_syntax)
                );
            },
            {
                asm!(
                    "ldr {tmp},[{buf},#0]",
                    "adds {sum}, {sum}, {tmp}",
                    "ldr {tmp},[{buf},#8]",
                    "adcs {sum}, {sum}, {tmp}",
                    "adc {sum}, {sum}, {zero}",
                    buf = in(reg) buf.as_ptr(),
                    sum = inout(reg) sum,
                    tmp = out(reg) _,
                    zero = in(reg) 0u64,
                );
            }
        );
        buf = &buf[16..];
    }
    if len & 8 != 0 {
        impl_block!(
            {
                asm!(
                    "addq 0*8({buf}), {sum}",
                    "adcq $0, {sum}",
                    buf = in(reg) buf.as_ptr(),
                    sum = inout(reg) sum,
                    options(att_syntax)
                );
            },
            {
                asm!(
                    "ldr {tmp},[{buf},#0]",
                    "adds {sum}, {sum}, {tmp}",
                    "adc {sum}, {sum}, {zero}",
                    buf = in(reg) buf.as_ptr(),
                    sum = inout(reg) sum,
                    tmp = out(reg) _,
                    zero = in(reg) 0u64,
                );
            }
        );
        buf = &buf[8..];
    }
    if len & 7 != 0 {
        // Only `len & 7` (1..=7) bytes remain. Copy them into a zeroed 8-byte
        // buffer instead of doing an 8-byte load at `buf.as_ptr()`. Such a
        // load reads past the end of the slice, which is undefined behaviour
        // and can fault when the slice ends at a page boundary. Zero-padding
        // the high-address bytes yields the same value the old
        // load-then-mask sequence produced on little-endian targets.
        debug_assert_eq!(buf.len(), len & 7);
        let mut padded = [0u8; 8];
        padded[..buf.len()].copy_from_slice(buf);
        let trail = u64::from_ne_bytes(padded);
        impl_block!(
            {
                asm!(
                    "addq {trail}, {sum}",
                    "adcq $0, {sum}",
                    trail = in(reg) trail,
                    sum = inout(reg) sum,
                    options(att_syntax)
                );
            },
            {
                asm!(
                    "adds {sum}, {sum}, {trail}",
                    "adc {sum}, {sum}, {zero}",
                    trail = in(reg) trail,
                    sum = inout(reg) sum,
                    zero = in(reg) 0u64,
                );
            }
        );
    }
    finalize(sum)
}
/// A precomputed ones'-complement partial sum over a UDP payload, together
/// with the payload length it covers.
///
/// `checksum` holds the unfolded 32-bit value produced by `partial`, or `0`
/// when `DataChecksum::calculate_if_needed` skipped the summation because the
/// packet can offload checksumming.
#[derive(Copy, Clone, Debug)]
pub struct DataChecksum {
    /// Unfolded 32-bit ones'-complement sum of the payload bytes.
    checksum: u32,
    /// Number of payload bytes covered by `checksum`.
    length: usize,
}
impl DataChecksum {
    /// Computes the unfolded partial sum of `data` (seed `0`) and records its
    /// length.
    ///
    /// # Panics
    ///
    /// Panics if `data` is longer than 4096 bytes.
    #[inline]
    pub fn calculate(data: &[u8]) -> Self {
        let length = data.len();
        assert!(
            length <= 4096,
            "the specified slice is too large to fit in a Packet"
        );
        let checksum = partial(data, 0);
        Self { checksum, length }
    }

    /// Like `DataChecksum::calculate`, but skips the summation and stores `0`
    /// when `packet.can_offload_checksum()` returns true.
    ///
    /// # Panics
    ///
    /// Panics if `data` is longer than 4096 bytes.
    pub fn calculate_if_needed(data: &[u8], packet: &super::Packet) -> Self {
        let length = data.len();
        assert!(
            length <= 4096,
            "the specified slice is too large to fit in a Packet"
        );
        if packet.can_offload_checksum() {
            return Self { checksum: 0, length };
        }
        Self {
            checksum: partial(data, 0),
            length,
        }
    }
}
use crate::packet::net_types as nt;
/// Error type returned by `Packet::calc_udp_checksum`.
#[derive(Debug)]
pub enum UdpCalcError {
/// The Ethernet frame's EtherType is neither IPv4 nor IPv6; carries the
/// EtherType that was found.
NotIp(nt::EtherType::Enum),
/// The IP protocol (IPv4) / next-header (IPv6) field is not UDP; carries the
/// value that was found.
NotUdp(nt::IpProto::Enum),
/// A packet access (header read/write or TX metadata update) failed.
Packet(super::PacketError),
}
use std::fmt;
impl fmt::Display for UdpCalcError {
    /// Writes a one-line, human-readable description of the failure.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::NotIp(et) => write!(f, "not an IP packet, but a {et:?}"),
            Self::NotUdp(proto) => write!(f, "not a UDP packet, but a {proto:?}"),
            Self::Packet(fe) => write!(f, "failed to parse packet: {fe}"),
        }
    }
}
impl std::error::Error for UdpCalcError {
    /// Exposes the wrapped packet error as the underlying cause. The other
    /// variants have no source.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        if let Self::Packet(fe) = self {
            Some(fe)
        } else {
            None
        }
    }
}
impl From<super::PacketError> for UdpCalcError {
#[inline]
fn from(value: super::PacketError) -> Self {
Self::Packet(value)
}
}
impl super::Packet {
/// Computes the UDP checksum of the Ethernet + IPv4/IPv6 + UDP frame held in
/// this packet and writes it into the UDP header's `check` field.
///
/// Only a fixed-size IP header is handled: the UDP header is read directly
/// after `Ipv4Hdr::LEN` / `Ipv6Hdr::LEN` bytes. IPv4 options are rejected
/// only by a debug assertion. An IPv6 next-header other than UDP (including
/// an extension header) is reported as `NotUdp`.
///
/// If `self.can_offload_checksum()` is true, the folded, non-complemented
/// pseudo-header sum is stored in `check`, and a `CsumOffload::Request`
/// (start = UDP header offset, offset = `check` field) is attached to the TX
/// metadata. In that case the returned value is the bitwise NOT of what was
/// stored, not the final checksum. Otherwise the checksum is computed in
/// software over the pseudo header, the UDP header, and every byte from the
/// end of the UDP header to `self.len()`, then written and returned.
///
/// NOTE(review): the software path sums up to `self.len()` rather than up to
/// the UDP `length` field, so any trailing frame padding would be included.
/// Confirm that packets passed here never carry trailing bytes.
///
/// # Errors
///
/// - `UdpCalcError::NotIp` if the EtherType is neither IPv4 nor IPv6.
/// - `UdpCalcError::NotUdp` if the IP protocol / next-header is not UDP.
/// - `UdpCalcError::Packet` if reading or writing a header, or setting the
///   TX metadata, fails.
pub fn calc_udp_checksum(&mut self) -> Result<u16, UdpCalcError> {
use crate::packet::Pod as _;
use nt::*;
let mut offset = 0;
let eth = self.read::<EthHdr>(offset)?;
offset += EthHdr::LEN;
// Each arm yields the 32-bit pseudo-header seed and the UDP header copy;
// `offset` is left pointing at the UDP header.
let (pseudo_seed, mut udp_hdr) = match eth.ether_type {
EtherType::Ipv4 => {
let ipv4 = self.read::<Ipv4Hdr>(offset)?;
debug_assert_eq!(
ipv4.internet_header_length(),
Ipv4Hdr::LEN as u8,
"ipv4 options are not supported"
);
offset += Ipv4Hdr::LEN;
if ipv4.proto != IpProto::Udp {
return Err(UdpCalcError::NotUdp(ipv4.proto));
}
let udp_hdr = self.read::<UdpHdr>(offset)?;
// IPv4 pseudo-header sum: saddr + daddr + (UDP length + protocol), with
// end-around carry. Shifting length+proto left by 8 makes each one
// contribute its byte-swapped value once folded to 16 bits. That matches
// raw network-order address words summed on a little-endian host.
// NOTE(review): assumes `source.0` / `destination.0` hold the addresses in
// network byte order. Confirm.
let mut sum = 0;
impl_block!(
{
asm!(
"addl {saddr:e}, {sum:e}",
"adcl {daddr:e}, {sum:e}",
"adcl {pseudo:e}, {sum:e}",
"adcl $0, {sum:e}",
saddr = in(reg) ipv4.source.0,
daddr = in(reg) ipv4.destination.0,
pseudo = in(reg) (udp_hdr.length.host() as u32 + IpProto::Udp as u32) << 8,
sum = inout(reg) sum,
options(att_syntax)
);
},
{
asm!(
"adds {sum:w}, {sum:w}, {saddr:w}",
"adcs {sum:w}, {sum:w}, {daddr:w}",
"adcs {sum:w}, {sum:w}, {pseudo:w}",
"adc {sum:w}, {sum:w}, {zero:w}",
saddr = in(reg) ipv4.source.0,
daddr = in(reg) ipv4.destination.0,
pseudo = in(reg) (udp_hdr.length.host() as u32 + IpProto::Udp as u32) << 8,
sum = inout(reg) sum,
zero = in(reg) 0u32,
);
}
);
(sum, udp_hdr)
}
EtherType::Ipv6 => {
let ipv6 = self.read::<Ipv6Hdr>(offset)?;
offset += Ipv6Hdr::LEN;
if ipv6.next_header != IpProto::Udp {
return Err(UdpCalcError::NotUdp(ipv6.next_header));
}
let udp_hdr = self.read::<UdpHdr>(offset)?;
// Seed with the 8-byte pseudo-header tail (32-bit upper-layer length,
// three zero bytes, next header). The layout matches what a native u64
// load of those network-order bytes gives on a little-endian host. The
// 16-byte source and destination addresses are then added as two u64
// words each.
let mut sum = ((udp_hdr.length.host() as u32).to_be() as u64)
.wrapping_add((IpProto::Udp as u64).to_be());
impl_block!(
{
asm!(
"addq 0*8({saddr}), {sum}",
"adcq 1*8({saddr}), {sum}",
"adcq 0*8({daddr}), {sum}",
"adcq 1*8({daddr}), {sum}",
"adcq $0, {sum}",
saddr = in(reg) ipv6.source.as_ptr(),
daddr = in(reg) ipv6.destination.as_ptr(),
sum = inout(reg) sum,
options(att_syntax)
);
},
{
asm!(
"ldr {tmp},[{saddr},#0]",
"adds {sum}, {sum}, {tmp}",
"ldr {tmp},[{saddr},#8]",
"adcs {sum}, {sum}, {tmp}",
"ldr {tmp},[{daddr},#0]",
"adcs {sum}, {sum}, {tmp}",
"ldr {tmp},[{daddr},#8]",
"adcs {sum}, {sum}, {tmp}",
"adc {sum}, {sum}, {zero}",
saddr = in(reg) ipv6.source.as_ptr(),
daddr = in(reg) ipv6.destination.as_ptr(),
sum = inout(reg) sum,
tmp = out(reg) _,
zero = in(reg) 0u64,
);
}
);
// Reduce the 64-bit accumulator to the 32-bit seed used by both paths.
(finalize(sum), udp_hdr)
}
invalid => return Err(UdpCalcError::NotIp(invalid)),
};
let checksum = if self.can_offload_checksum() {
// Offload path: store the folded, non-complemented pseudo-header sum
// (`!csum` undoes fold_checksum's complement). This is presumably the
// starting value the offload engine expects. Then request offload
// starting at the UDP header.
let csum = fold_checksum(pseudo_seed);
udp_hdr.check = !csum;
self.write(offset, udp_hdr)?;
self.set_tx_metadata(
crate::packet::CsumOffload::Request {
start: offset as u16,
offset: std::mem::offset_of!(UdpHdr, check) as u16,
},
false,
)?;
csum
} else {
// Software path: sum the UDP header (with `check` zeroed) on top of the
// pseudo-header seed, then the payload up to the end of the packet.
udp_hdr.check = 0;
let sum = partial(udp_hdr.as_bytes(), pseudo_seed);
let data_offset = offset + nt::UdpHdr::LEN;
let data_payload = &self[data_offset..self.len()];
let mut csum = fold_checksum(partial(data_payload, sum));
// An all-zero UDP checksum field means "no checksum" (RFC 768), so a
// computed 0 is sent as 0xffff.
if csum == 0 {
csum = 0xffff;
}
udp_hdr.check = csum;
self.write(offset, udp_hdr)?;
csum
};
Ok(checksum)
}
}
impl nt::UdpHeaders {
/// Computes the UDP checksum for these headers from a precomputed payload
/// sum, stores it in `self.udp.check`, and returns it.
///
/// Side effect: sets `self.data.end = self.data.start + data_checksum.length`.
///
/// The sum covers the following parts. A computed 0 is replaced with 0xffff.
/// - The IPv4/IPv6 pseudo header (addresses, UDP length, protocol).
/// - A temporary copy of the UDP header, with `length` set to
///   `data_checksum.length + UdpHdr::LEN` and `check` set to 0.
/// - `data_checksum.checksum`.
///
/// NOTE(review): `self.udp.length` itself is not updated here. The stored
/// checksum is only valid if the header is later emitted with that same
/// length. Confirm with callers.
/// NOTE(review): a `DataChecksum` built by `calculate_if_needed` for an
/// offload-capable packet carries a zero payload sum, and this function does
/// not account for that. Confirm how callers combine the two.
#[inline]
pub fn calc_checksum(&mut self, data_checksum: DataChecksum) -> u16 {
self.data.end = self.data.start + data_checksum.length;
// Start from the payload's partial sum, widened to the 64-bit accumulator.
let mut sum = data_checksum.checksum as u64;
// UDP length: payload plus the UDP header.
let data_len = data_checksum.length + nt::UdpHdr::LEN;
match &self.ip {
nt::IpHdr::V4(v4) => {
// `to_be()` puts the host-order length+protocol and addresses into
// network byte order to match the native loads used for the payload.
// The zero-checksum UDP header copy is added as a single 64-bit load
// (assumes `UdpHdr` is 8 bytes; TODO confirm).
impl_block!(
{
asm!(
"addq {pseudo_udp}, {sum}",
"adcq {saddr}, {sum}",
"adcq {daddr}, {sum}",
"adcq 0*8({udp}), {sum}",
"adcq $0, {sum}",
pseudo_udp = in(reg) ((data_len + nt::IpProto::Udp as usize) as u64).to_be(),
saddr = in(reg) (v4.source.host() as u64).to_be(),
daddr = in(reg) (v4.destination.host() as u64).to_be(),
udp = in(reg) &nt::UdpHdr {
source: self.udp.source,
destination: self.udp.destination,
length: (data_len as u16).into(),
check: 0,
},
sum = inout(reg) sum,
options(att_syntax)
);
},
{
asm!(
"adds {sum}, {sum}, {pseudo_udp}",
"adcs {sum}, {sum}, {saddr}",
"adcs {sum}, {sum}, {daddr}",
"ldr {tmp},[{udp},#0]",
"adcs {sum}, {sum}, {tmp}",
"adc {sum}, {sum}, {zero}",
pseudo_udp = in(reg) ((data_len + nt::IpProto::Udp as usize) as u64).to_be(),
saddr = in(reg) (v4.source.host() as u64).to_be(),
daddr = in(reg) (v4.destination.host() as u64).to_be(),
udp = in(reg) &nt::UdpHdr {
source: self.udp.source,
destination: self.udp.destination,
length: (data_len as u16).into(),
check: 0,
},
sum = inout(reg) sum,
tmp = out(reg) _,
zero = in(reg) 0u64,
);
}
);
}
nt::IpHdr::V6(v6) => {
// Copy the 16-byte addresses into locals before taking pointers to them.
// Presumably the header fields cannot be borrowed in place (e.g. a packed
// struct); confirm.
let source = v6.source;
let destination = v6.destination;
impl_block!(
{
asm!(
"addq {pseudo_udp}, {sum}",
"adcq 0*8({saddr}), {sum}",
"adcq 1*8({saddr}), {sum}",
"adcq 0*8({daddr}), {sum}",
"adcq 1*8({daddr}), {sum}",
"adcq 0*8({udp}), {sum}",
"adcq $0, {sum}",
pseudo_udp = in(reg) ((data_len + nt::IpProto::Udp as usize) as u64).to_be(),
saddr = in(reg) source.as_ptr(),
daddr = in(reg) destination.as_ptr(),
udp = in(reg) &nt::UdpHdr {
source: self.udp.source,
destination: self.udp.destination,
length: (data_len as u16).into(),
check: 0,
},
sum = inout(reg) sum,
options(att_syntax)
);
},
{
asm!(
"adds {sum}, {sum}, {pseudo_udp}",
"ldr {tmp},[{saddr},#0]",
"adcs {sum}, {sum}, {tmp}",
"ldr {tmp},[{saddr},#8]",
"adcs {sum}, {sum}, {tmp}",
"ldr {tmp},[{daddr},#0]",
"adcs {sum}, {sum}, {tmp}",
"ldr {tmp},[{daddr},#8]",
"adcs {sum}, {sum}, {tmp}",
"ldr {tmp},[{udp},#0]",
"adcs {sum}, {sum}, {tmp}",
"adc {sum}, {sum}, {zero}",
pseudo_udp = in(reg) ((data_len + nt::IpProto::Udp as usize) as u64).to_be(),
saddr = in(reg) source.as_ptr(),
daddr = in(reg) destination.as_ptr(),
udp = in(reg) &nt::UdpHdr {
source: self.udp.source,
destination: self.udp.destination,
length: (data_len as u16).into(),
check: 0,
},
sum = inout(reg) sum,
tmp = out(reg) _,
zero = in(reg) 0u64,
);
}
);
}
}
// Fold 64 -> 32 -> 16 bits and complement. A computed 0 is stored as 0xffff.
self.udp.check = fold_checksum(finalize(sum));
if self.udp.check == 0 {
self.udp.check = 0xffff;
}
self.udp.check
}
}