use std::net::{Ipv4Addr, Ipv6Addr};
use std::sync::atomic::{AtomicU32, Ordering};
/// Transport-layer protocol carried by an IP packet.
///
/// The known variants use their IANA protocol numbers as discriminants, so
/// `proto as u8` yields the on-wire value; `Unknown` (0) covers everything else.
#[repr(u8)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum Protocol {
    /// Any protocol number this module does not decode.
    #[default]
    Unknown = 0,
    /// ICMP for IPv4 (protocol number 1).
    Icmp = 1,
    /// TCP (protocol number 6).
    Tcp = 6,
    /// UDP (protocol number 17).
    Udp = 17,
    /// ICMP for IPv6 (next-header value 58).
    Icmpv6 = 58,
}
impl From<u8> for Protocol {
fn from(value: u8) -> Self {
match value {
1 => Self::Icmp,
6 => Self::Tcp,
17 => Self::Udp,
58 => Self::Icmpv6,
_ => Self::Unknown,
}
}
}
/// IP version of a parsed packet.
///
/// `V4`/`V6` carry the matching IP version number as their discriminant.
#[repr(u8)]
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum IpVersion {
    /// Not identified as IPv4 or IPv6.
    #[default]
    Unknown = 0,
    /// IPv4.
    V4 = 4,
    /// IPv6.
    V6 = 6,
}
/// Header-parsing results for a packet: layer offsets, protocol, ports,
/// TCP flags and a flow hash.
///
/// `repr(C)` fixes the field order; with the explicit `_padding` the struct is
/// 24 bytes with no implicit padding bytes.
#[derive(Clone, Copy, Debug, Default)]
#[repr(C)]
pub struct PacketMetadata {
    /// Byte offset of the Ethernet (L2) header within the packet.
    pub l2_offset: u16,
    /// Byte offset of the IP (L3) header within the packet.
    pub l3_offset: u16,
    /// Byte offset of the transport (L4) header within the packet.
    pub l4_offset: u16,
    /// Transport protocol taken from the IP header.
    pub protocol: Protocol,
    /// IP version selected by the EtherType.
    pub ip_version: IpVersion,
    /// Flow hash computed over protocol, ports and IP addresses.
    pub flow_hash: u64,
    /// TCP/UDP source port (host byte order).
    pub src_port: u16,
    /// TCP/UDP destination port (host byte order).
    pub dst_port: u16,
    /// TCP flags byte (byte 13 of the TCP header) for TCP packets.
    pub flags: u8,
    /// Explicit tail padding up to the 8-byte alignment of `flow_hash`.
    _padding: [u8; 3],
}
impl PacketMetadata {
    /// Returns metadata with all offsets, ports, flags and the hash zeroed and
    /// the protocol / IP version set to `Unknown`.
    ///
    /// Produces the same value as `Default::default()`, but is usable in
    /// `const` contexts.
    #[inline]
    #[must_use]
    pub const fn new() -> Self {
        Self {
            protocol: Protocol::Unknown,
            ip_version: IpVersion::Unknown,
            l2_offset: 0,
            l3_offset: 0,
            l4_offset: 0,
            src_port: 0,
            dst_port: 0,
            flow_hash: 0,
            flags: 0,
            _padding: [0u8; 3],
        }
    }

    /// Returns `true` if the transport protocol is TCP.
    #[inline]
    #[must_use]
    pub const fn is_tcp(&self) -> bool {
        matches!(self, Self { protocol: Protocol::Tcp, .. })
    }

    /// Returns `true` if the transport protocol is UDP.
    #[inline]
    #[must_use]
    pub const fn is_udp(&self) -> bool {
        matches!(self, Self { protocol: Protocol::Udp, .. })
    }

    /// Returns `true` for ICMP over either IPv4 (`Icmp`) or IPv6 (`Icmpv6`).
    #[inline]
    #[must_use]
    pub const fn is_icmp(&self) -> bool {
        matches!(self, Self { protocol: Protocol::Icmp | Protocol::Icmpv6, .. })
    }
}
/// A view over a packet buffer that this type does not allocate, carrying
/// parsed header metadata, an atomic reference count and per-packet flags.
///
/// Aligned (and therefore padded) to 64 bytes.
#[repr(C)]
#[repr(align(64))]
pub struct ZeroCopyPacket {
    /// First byte of the packet; null for an empty packet.
    data: *const u8,
    /// Number of packet bytes reachable from `data`.
    len: u32,
    /// Results of `parse_headers`.
    metadata: PacketMetadata,
    /// Reference count adjusted by `add_ref` / `release`.
    refcount: AtomicU32,
    /// Descriptor index supplied at construction (presumably a ring slot — confirm with callers).
    desc_idx: u16,
    /// Bitset of the `FLAG_*` constants.
    flags: u16,
    /// Caller-supplied timestamp; units are not defined in this file.
    timestamp: u64,
}
// SAFETY: NOTE(review) — confirm. `data` is a `*const u8`, and no method in this
// file writes through it. `refcount` is atomic, and every other field is only
// mutated through `&mut self`. Moving or sharing a packet across threads is
// therefore sound only if the buffer behind `data` stays valid and unmodified
// while any thread can read it. Presumably that obligation sits with the callers
// of the `unsafe` constructors.
unsafe impl Send for ZeroCopyPacket {}
unsafe impl Sync for ZeroCopyPacket {}
impl Default for ZeroCopyPacket {
fn default() -> Self {
Self::empty()
}
}
impl ZeroCopyPacket {
pub const FLAG_NEEDS_CSUM: u16 = 1 << 0;
pub const FLAG_GSO: u16 = 1 << 1;
pub const FLAG_FROM_GUEST: u16 = 1 << 2;
pub const FLAG_TO_GUEST: u16 = 1 << 3;
#[inline]
#[must_use]
pub const fn empty() -> Self {
Self {
data: std::ptr::null(),
len: 0,
metadata: PacketMetadata::new(),
refcount: AtomicU32::new(0),
desc_idx: 0,
flags: 0,
timestamp: 0,
}
}
#[inline]
#[must_use]
pub const unsafe fn from_raw_parts(data: *const u8, len: u32, desc_idx: u16) -> Self {
Self {
data,
len,
metadata: PacketMetadata::new(),
refcount: AtomicU32::new(1),
desc_idx,
flags: 0,
timestamp: 0,
}
}
#[inline]
#[must_use]
pub const unsafe fn from_slice(data: &[u8], desc_idx: u16) -> Self {
Self {
data: data.as_ptr(),
len: data.len() as u32,
metadata: PacketMetadata::new(),
refcount: AtomicU32::new(1),
desc_idx,
flags: 0,
timestamp: 0,
}
}
#[inline]
#[must_use]
pub const fn is_empty(&self) -> bool {
self.len == 0 || self.data.is_null()
}
#[inline]
#[must_use]
pub const fn len(&self) -> usize {
self.len as usize
}
#[inline]
#[must_use]
pub unsafe fn as_slice(&self) -> &[u8] {
if self.data.is_null() {
&[]
} else {
unsafe { std::slice::from_raw_parts(self.data, self.len as usize) }
}
}
#[inline]
#[must_use]
pub const fn data_ptr(&self) -> *const u8 {
self.data
}
#[inline]
#[must_use]
pub const fn desc_idx(&self) -> u16 {
self.desc_idx
}
#[inline]
#[must_use]
pub const fn metadata(&self) -> &PacketMetadata {
&self.metadata
}
#[inline]
#[must_use]
pub fn metadata_mut(&mut self) -> &mut PacketMetadata {
&mut self.metadata
}
#[inline]
pub fn set_metadata(&mut self, metadata: PacketMetadata) {
self.metadata = metadata;
}
#[inline]
#[must_use]
pub const fn flags(&self) -> u16 {
self.flags
}
#[inline]
pub fn set_flags(&mut self, flags: u16) {
self.flags = flags;
}
#[inline]
pub fn add_flag(&mut self, flag: u16) {
self.flags |= flag;
}
#[inline]
#[must_use]
pub const fn has_flag(&self, flag: u16) -> bool {
self.flags & flag != 0
}
#[inline]
#[must_use]
pub const fn timestamp(&self) -> u64 {
self.timestamp
}
#[inline]
pub fn set_timestamp(&mut self, timestamp: u64) {
self.timestamp = timestamp;
}
#[inline]
pub fn add_ref(&self) {
self.refcount.fetch_add(1, Ordering::AcqRel);
}
#[inline]
pub fn release(&self) -> bool {
self.refcount.fetch_sub(1, Ordering::AcqRel) == 1
}
#[inline]
#[must_use]
pub fn refcount(&self) -> u32 {
self.refcount.load(Ordering::Acquire)
}
pub unsafe fn parse_headers(&mut self) {
if self.len < 14 {
return; }
let data = unsafe { std::slice::from_raw_parts(self.data, self.len as usize) };
self.metadata.l2_offset = 0;
self.metadata.l3_offset = 14;
let ethertype = u16::from_be_bytes([data[12], data[13]]);
match ethertype {
0x0800 => {
self.metadata.ip_version = IpVersion::V4;
self.parse_ipv4(data, 14);
}
0x86DD => {
self.metadata.ip_version = IpVersion::V6;
self.parse_ipv6(data, 14);
}
0x8100 => {
self.metadata.l3_offset = 18;
if self.len >= 18 {
let inner_ethertype = u16::from_be_bytes([data[16], data[17]]);
match inner_ethertype {
0x0800 => {
self.metadata.ip_version = IpVersion::V4;
self.parse_ipv4(data, 18);
}
0x86DD => {
self.metadata.ip_version = IpVersion::V6;
self.parse_ipv6(data, 18);
}
_ => {}
}
}
}
_ => {}
}
self.metadata.flow_hash = self.calculate_flow_hash();
}
fn parse_ipv4(&mut self, data: &[u8], offset: usize) {
if data.len() < offset + 20 {
return; }
let ihl = (data[offset] & 0x0F) as usize * 4;
self.metadata.l4_offset = (offset + ihl) as u16;
self.metadata.protocol = Protocol::from(data[offset + 9]);
let l4_offset = self.metadata.l4_offset as usize;
if data.len() >= l4_offset + 4 {
match self.metadata.protocol {
Protocol::Tcp | Protocol::Udp => {
self.metadata.src_port =
u16::from_be_bytes([data[l4_offset], data[l4_offset + 1]]);
self.metadata.dst_port =
u16::from_be_bytes([data[l4_offset + 2], data[l4_offset + 3]]);
if self.metadata.protocol == Protocol::Tcp && data.len() >= l4_offset + 14 {
self.metadata.flags = data[l4_offset + 13];
}
}
_ => {}
}
}
}
fn parse_ipv6(&mut self, data: &[u8], offset: usize) {
if data.len() < offset + 40 {
return; }
self.metadata.l4_offset = (offset + 40) as u16;
self.metadata.protocol = Protocol::from(data[offset + 6]);
let l4_offset = self.metadata.l4_offset as usize;
if data.len() >= l4_offset + 4 {
match self.metadata.protocol {
Protocol::Tcp | Protocol::Udp => {
self.metadata.src_port =
u16::from_be_bytes([data[l4_offset], data[l4_offset + 1]]);
self.metadata.dst_port =
u16::from_be_bytes([data[l4_offset + 2], data[l4_offset + 3]]);
if self.metadata.protocol == Protocol::Tcp && data.len() >= l4_offset + 14 {
self.metadata.flags = data[l4_offset + 13];
}
}
_ => {}
}
}
}
fn calculate_flow_hash(&self) -> u64 {
let mut hash: u64 = 0xcbf2_9ce4_8422_2325;
hash ^= self.metadata.protocol as u64;
hash = hash.wrapping_mul(0x0100_0000_01b3);
hash ^= self.metadata.src_port as u64;
hash = hash.wrapping_mul(0x0100_0000_01b3);
hash ^= self.metadata.dst_port as u64;
hash = hash.wrapping_mul(0x0100_0000_01b3);
unsafe {
let data = self.as_slice();
if self.metadata.ip_version == IpVersion::V4 {
let l3 = self.metadata.l3_offset as usize;
if data.len() >= l3 + 20 {
for i in 0..4 {
hash ^= data[l3 + 12 + i] as u64;
hash = hash.wrapping_mul(0x0100_0000_01b3);
}
for i in 0..4 {
hash ^= data[l3 + 16 + i] as u64;
hash = hash.wrapping_mul(0x0100_0000_01b3);
}
}
} else if self.metadata.ip_version == IpVersion::V6 {
let l3 = self.metadata.l3_offset as usize;
if data.len() >= l3 + 40 {
for i in 0..16 {
hash ^= data[l3 + 8 + i] as u64;
hash = hash.wrapping_mul(0x0100_0000_01b3);
}
for i in 0..16 {
hash ^= data[l3 + 24 + i] as u64;
hash = hash.wrapping_mul(0x0100_0000_01b3);
}
}
}
}
hash
}
#[must_use]
pub unsafe fn src_ipv4(&self) -> Option<Ipv4Addr> {
if self.metadata.ip_version != IpVersion::V4 {
return None;
}
let data = unsafe { self.as_slice() };
let l3 = self.metadata.l3_offset as usize;
if data.len() >= l3 + 20 {
Some(Ipv4Addr::new(
data[l3 + 12],
data[l3 + 13],
data[l3 + 14],
data[l3 + 15],
))
} else {
None
}
}
#[must_use]
pub unsafe fn dst_ipv4(&self) -> Option<Ipv4Addr> {
if self.metadata.ip_version != IpVersion::V4 {
return None;
}
let data = unsafe { self.as_slice() };
let l3 = self.metadata.l3_offset as usize;
if data.len() >= l3 + 20 {
Some(Ipv4Addr::new(
data[l3 + 16],
data[l3 + 17],
data[l3 + 18],
data[l3 + 19],
))
} else {
None
}
}
#[must_use]
pub unsafe fn src_ipv6(&self) -> Option<Ipv6Addr> {
if self.metadata.ip_version != IpVersion::V6 {
return None;
}
let data = unsafe { self.as_slice() };
let l3 = self.metadata.l3_offset as usize;
if data.len() >= l3 + 40 {
let mut octets = [0u8; 16];
octets.copy_from_slice(&data[l3 + 8..l3 + 24]);
Some(Ipv6Addr::from(octets))
} else {
None
}
}
#[must_use]
pub unsafe fn dst_ipv6(&self) -> Option<Ipv6Addr> {
if self.metadata.ip_version != IpVersion::V6 {
return None;
}
let data = unsafe { self.as_slice() };
let l3 = self.metadata.l3_offset as usize;
if data.len() >= l3 + 40 {
let mut octets = [0u8; 16];
octets.copy_from_slice(&data[l3 + 24..l3 + 40]);
Some(Ipv6Addr::from(octets))
} else {
None
}
}
}
impl std::fmt::Debug for ZeroCopyPacket {
    /// Prints every field. The reference count is a single relaxed snapshot and
    /// may already be stale if other threads are adjusting it.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let refcount = self.refcount.load(Ordering::Relaxed);
        let mut out = f.debug_struct("ZeroCopyPacket");
        out.field("data", &self.data)
            .field("len", &self.len)
            .field("metadata", &self.metadata)
            .field("refcount", &refcount)
            .field("desc_idx", &self.desc_idx)
            .field("flags", &self.flags)
            .field("timestamp", &self.timestamp);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an untagged Ethernet + IPv4 (IHL 5) + TCP frame carrying a SYN
    /// from 10.0.0.1:12345 to 10.0.0.2:80.
    fn ipv4_tcp_frame() -> Vec<u8> {
        let mut frame = vec![0u8; 54];
        frame[12..14].copy_from_slice(&0x0800u16.to_be_bytes());
        frame[14] = 0x45;
        frame[23] = 6;
        frame[26..30].copy_from_slice(&[10, 0, 0, 1]);
        frame[30..34].copy_from_slice(&[10, 0, 0, 2]);
        frame[34..36].copy_from_slice(&12345u16.to_be_bytes());
        frame[36..38].copy_from_slice(&80u16.to_be_bytes());
        frame[47] = 0x02;
        frame
    }

    /// Wraps `frame` in a packet and runs `parse_headers` on it.
    fn parsed(frame: &[u8]) -> ZeroCopyPacket {
        let mut pkt = unsafe { ZeroCopyPacket::from_slice(frame, 0) };
        unsafe { pkt.parse_headers() };
        pkt
    }

    #[test]
    fn test_packet_size() {
        assert!(std::mem::size_of::<ZeroCopyPacket>() <= 128);
        assert_eq!(std::mem::align_of::<ZeroCopyPacket>(), 64);
    }

    #[test]
    fn test_empty_packet() {
        let pkt = ZeroCopyPacket::empty();
        assert!(pkt.is_empty());
        assert_eq!(pkt.len(), 0);
        assert_eq!(pkt.refcount(), 0);
    }

    #[test]
    fn test_protocol_from() {
        assert_eq!(Protocol::from(1), Protocol::Icmp);
        assert_eq!(Protocol::from(6), Protocol::Tcp);
        assert_eq!(Protocol::from(17), Protocol::Udp);
        assert_eq!(Protocol::from(58), Protocol::Icmpv6);
        assert_eq!(Protocol::from(255), Protocol::Unknown);
    }

    #[test]
    fn test_metadata() {
        let mut meta = PacketMetadata::new();
        meta.protocol = Protocol::Tcp;
        meta.src_port = 12345;
        meta.dst_port = 80;
        assert!(meta.is_tcp());
        assert!(!meta.is_udp());
        assert!(!meta.is_icmp());
    }

    #[test]
    fn test_packet_from_slice() {
        let data = [0u8; 64];
        let pkt = unsafe { ZeroCopyPacket::from_slice(&data, 42) };
        assert!(!pkt.is_empty());
        assert_eq!(pkt.len(), 64);
        assert_eq!(pkt.desc_idx(), 42);
        assert_eq!(pkt.refcount(), 1);
    }

    #[test]
    fn test_refcount() {
        let data = [0u8; 64];
        let pkt = unsafe { ZeroCopyPacket::from_slice(&data, 0) };
        assert_eq!(pkt.refcount(), 1);
        pkt.add_ref();
        assert_eq!(pkt.refcount(), 2);
        assert!(!pkt.release());
        assert_eq!(pkt.refcount(), 1);
        assert!(pkt.release());
        assert_eq!(pkt.refcount(), 0);
    }

    #[test]
    fn test_flags() {
        let mut pkt = ZeroCopyPacket::empty();
        assert_eq!(pkt.flags(), 0);
        assert!(!pkt.has_flag(ZeroCopyPacket::FLAG_NEEDS_CSUM));
        pkt.add_flag(ZeroCopyPacket::FLAG_NEEDS_CSUM);
        assert!(pkt.has_flag(ZeroCopyPacket::FLAG_NEEDS_CSUM));
        assert!(!pkt.has_flag(ZeroCopyPacket::FLAG_GSO));
        pkt.add_flag(ZeroCopyPacket::FLAG_GSO);
        assert!(pkt.has_flag(ZeroCopyPacket::FLAG_NEEDS_CSUM));
        assert!(pkt.has_flag(ZeroCopyPacket::FLAG_GSO));
    }

    #[test]
    fn test_parse_ipv4_tcp() {
        let frame = ipv4_tcp_frame();
        let pkt = parsed(&frame);
        let meta = *pkt.metadata();
        assert_eq!(meta.ip_version, IpVersion::V4);
        assert_eq!(meta.l2_offset, 0);
        assert_eq!(meta.l3_offset, 14);
        assert_eq!(meta.l4_offset, 34);
        assert!(meta.is_tcp());
        assert_eq!((meta.src_port, meta.dst_port), (12345, 80));
        assert_eq!(meta.flags, 0x02);
        assert_eq!(unsafe { pkt.src_ipv4() }, Some(std::net::Ipv4Addr::new(10, 0, 0, 1)));
        assert_eq!(unsafe { pkt.dst_ipv4() }, Some(std::net::Ipv4Addr::new(10, 0, 0, 2)));
        assert_eq!(unsafe { pkt.src_ipv6() }, None);
    }

    #[test]
    fn test_parse_ipv6_udp() {
        let src = std::net::Ipv6Addr::LOCALHOST;
        let dst = std::net::Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1);
        let mut frame = vec![0u8; 62];
        frame[12..14].copy_from_slice(&0x86DDu16.to_be_bytes());
        frame[14] = 0x60;
        frame[20] = 17;
        frame[22..38].copy_from_slice(&src.octets());
        frame[38..54].copy_from_slice(&dst.octets());
        frame[54..56].copy_from_slice(&5353u16.to_be_bytes());
        frame[56..58].copy_from_slice(&53u16.to_be_bytes());
        let pkt = parsed(&frame);
        let meta = *pkt.metadata();
        assert_eq!(meta.ip_version, IpVersion::V6);
        assert_eq!(meta.l3_offset, 14);
        assert_eq!(meta.l4_offset, 54);
        assert!(meta.is_udp());
        assert_eq!((meta.src_port, meta.dst_port), (5353, 53));
        assert_eq!(unsafe { pkt.src_ipv6() }, Some(src));
        assert_eq!(unsafe { pkt.dst_ipv6() }, Some(dst));
        assert_eq!(unsafe { pkt.dst_ipv4() }, None);
    }

    #[test]
    fn test_parse_vlan_tagged_ipv4_udp() {
        let mut frame = vec![0u8; 46];
        frame[12..14].copy_from_slice(&0x8100u16.to_be_bytes());
        frame[16..18].copy_from_slice(&0x0800u16.to_be_bytes());
        frame[18] = 0x45;
        frame[27] = 17;
        frame[30..34].copy_from_slice(&[192, 168, 1, 1]);
        frame[34..38].copy_from_slice(&[192, 168, 1, 2]);
        frame[38..40].copy_from_slice(&1000u16.to_be_bytes());
        frame[40..42].copy_from_slice(&2000u16.to_be_bytes());
        let pkt = parsed(&frame);
        let meta = *pkt.metadata();
        assert_eq!(meta.ip_version, IpVersion::V4);
        assert_eq!(meta.l3_offset, 18);
        assert_eq!(meta.l4_offset, 38);
        assert!(meta.is_udp());
        assert_eq!((meta.src_port, meta.dst_port), (1000, 2000));
        assert_eq!(unsafe { pkt.src_ipv4() }, Some(std::net::Ipv4Addr::new(192, 168, 1, 1)));
        assert_eq!(unsafe { pkt.dst_ipv4() }, Some(std::net::Ipv4Addr::new(192, 168, 1, 2)));
    }

    #[test]
    fn test_parse_short_frame_is_noop() {
        let frame = [0u8; 10];
        let pkt = parsed(&frame);
        let meta = *pkt.metadata();
        assert_eq!(meta.l3_offset, 0);
        assert_eq!(meta.ip_version, IpVersion::Unknown);
        assert_eq!(meta.flow_hash, 0);
    }

    #[test]
    fn test_parse_non_ip_ethertype() {
        let mut frame = [0u8; 60];
        frame[12..14].copy_from_slice(&0x0806u16.to_be_bytes());
        let pkt = parsed(&frame);
        let meta = *pkt.metadata();
        assert_eq!(meta.l3_offset, 14);
        assert_eq!(meta.ip_version, IpVersion::Unknown);
        assert_eq!(meta.protocol, Protocol::Unknown);
        assert_eq!(unsafe { pkt.src_ipv4() }, None);
        assert_eq!(unsafe { pkt.src_ipv6() }, None);
    }

    #[test]
    fn test_flow_hash_deterministic_and_port_sensitive() {
        let frame_a = ipv4_tcp_frame();
        let mut frame_b = ipv4_tcp_frame();
        frame_b[34..36].copy_from_slice(&12346u16.to_be_bytes());
        let hash_of = |frame: &[u8]| parsed(frame).metadata().flow_hash;
        assert_eq!(hash_of(&frame_a), hash_of(&frame_a));
        assert_ne!(hash_of(&frame_a), hash_of(&frame_b));
    }
}