use super::{Pod, csum};
use std::{
fmt,
mem::size_of,
net::{Ipv4Addr, Ipv6Addr, SocketAddr},
};
/// Declares a header type as `Pod` and attaches a `LEN` constant holding its
/// in-memory size in bytes.
///
/// `Pod` is an `unsafe` trait (it is implemented with `unsafe impl`), so every
/// invocation is an assertion by the caller that `$record` meets its contract.
/// NOTE(review): the contract lives in the parent module — presumably
/// `#[repr(C)]`, no padding bytes, every bit pattern valid; confirm against
/// `super::Pod` before adding new invocations.
macro_rules! len {
    ($record:ty) => {
        // SAFETY: upheld by the invoking site, see the macro documentation.
        unsafe impl Pod for $record {}

        impl $record {
            /// Size of this type in bytes, exactly as laid out in memory.
            pub const LEN: usize = size_of::<Self>();
        }
    };
}
/// Generates a `#[repr(C)]` newtype around an integer that is stored in
/// network (big-endian) byte order.
///
/// * `$name` – name of the generated tuple struct
/// * `$int`  – the underlying primitive integer type
/// * `$fmt`  – format string the `Display` impl applies to the host-order value
///
/// The public field holds the raw big-endian value, so the struct can sit
/// directly inside packet header layouts. Note that the derived ordering and
/// hashing operate on that raw network-order representation.
macro_rules! net_int {
    ($name:ident, $int:ty, $fmt:literal) => {
        #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
        #[repr(C)]
        pub struct $name(pub $int);

        impl $name {
            /// Converts the stored network-order value into host byte order.
            #[inline]
            pub fn host(self) -> $int {
                let Self(raw) = self;
                <$int>::from_be(raw)
            }
        }

        impl From<$int> for $name {
            /// Wraps a host-order value, storing it in network byte order.
            #[inline]
            fn from(host_value: $int) -> Self {
                Self(<$int>::to_be(host_value))
            }
        }

        // Debug: the host-order value in plain decimal.
        impl std::fmt::Debug for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                let value = self.host();
                write!(f, "{}", value)
            }
        }

        // Display: the host-order value rendered with the per-type format string.
        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                let value = self.host();
                write!(f, $fmt, value)
            }
        }
    };
}

// 16-bit and 32-bit network-order integers; Display renders zero-padded hex.
net_int!(NetworkU16, u16, "{:04x}");
net_int!(NetworkU32, u32, "{:08x}");
/// A 48-bit Ethernet hardware address, held as its six octets in wire order.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
#[repr(C)]
pub struct MacAddress(pub [u8; 6]);

impl fmt::Debug for MacAddress {
    /// Debug output is the same colon-separated hex form as `Display`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(self, f)
    }
}

impl fmt::Display for MacAddress {
    /// Formats as six lowercase, zero-padded hex octets joined by `:`,
    /// e.g. `00:1a:2b:c3:d4:ff`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let [first, rest @ ..] = &self.0;
        write!(f, "{first:02x}")?;
        rest.iter().try_for_each(|octet| write!(f, ":{octet:02x}"))
    }
}
/// Ethernet II header: destination MAC, source MAC, EtherType (14 bytes).
///
/// `#[repr(C)]` so it can be read from and written to packet bytes as-is.
#[derive(Copy, Clone)]
#[cfg_attr(feature = "__debug", derive(Debug))]
#[repr(C)]
pub struct EthHdr {
    /// MAC address of the receiving station.
    pub destination: MacAddress,
    /// MAC address of the sending station.
    pub source: MacAddress,
    /// Raw EtherType in network byte order; compare against the `EtherType`
    /// constants, which are already converted with `to_be`.
    pub ether_type: EtherType::Enum,
}
len!(EthHdr);

impl EthHdr {
    /// Returns a copy of the header with the source and destination MAC
    /// addresses exchanged; the EtherType is carried over unchanged.
    #[inline]
    pub fn swapped(&self) -> Self {
        let Self {
            destination,
            source,
            ether_type,
        } = *self;
        Self {
            destination: source,
            source: destination,
            ether_type,
        }
    }
}
/// EtherType values stored pre-converted to network byte order, so they can be
/// compared (or matched) directly against `EthHdr::ether_type` read from a packet.
///
/// Each constant is spelled as its two wire bytes.
// allow: the module and its constants intentionally imitate an enum namespace.
#[allow(non_snake_case, non_upper_case_globals)]
pub mod EtherType {
    /// Raw EtherType as it sits in memory (big-endian wire order).
    pub type Enum = u16;
    /// IPv4 (0x0800).
    pub const Ipv4: Enum = u16::from_ne_bytes([0x08, 0x00]);
    /// ARP (0x0806).
    pub const Arp: Enum = u16::from_ne_bytes([0x08, 0x06]);
    /// IPv6 (0x86DD).
    pub const Ipv6: Enum = u16::from_ne_bytes([0x86, 0xdd]);
}
/// IP protocol numbers carried in the IPv4 `proto` / IPv6 `next_header` field.
///
/// They are single bytes, so unlike `EtherType` no byte-order conversion applies.
// allow: the module and its constants intentionally imitate an enum namespace.
#[allow(non_snake_case, non_upper_case_globals)]
pub mod IpProto {
    /// Raw one-byte protocol number.
    pub type Enum = u8;
    /// ICMP for IPv4 (1).
    pub const Icmp: Enum = 0x01;
    /// IGMP (2).
    pub const Igmp: Enum = 0x02;
    /// TCP (6).
    pub const Tcp: Enum = 0x06;
    /// UDP (17).
    pub const Udp: Enum = 0x11;
    /// ICMP for IPv6 (58).
    pub const Ipv6Icmp: Enum = 0x3a;
    /// UDP-Lite (136).
    pub const UdpLite: Enum = 0x88;
}
/// IPv4 header in its fixed 20-byte form (options are not part of this struct).
///
/// Addresses and length/identification use the network-order wrappers; the
/// remaining raw fields are kept exactly as they appear in the packet.
#[derive(Copy, Clone)]
#[repr(C)]
pub struct Ipv4Hdr {
    // First two header bytes (version/IHL, then DSCP/ECN) held as one raw u16.
    // NOTE(review): interpretation of this word depends on host byte order.
    bitfield: u16,
    /// Total datagram length in bytes (header plus payload).
    pub total_length: NetworkU16,
    /// Identification field.
    pub identification: NetworkU16,
    // Raw flags + fragment offset word; not interpreted by this module.
    fragment: u16,
    /// Time-to-live / hop counter.
    #[doc(alias = "ttl")]
    pub time_to_live: u8,
    /// Protocol of the payload (see `IpProto`).
    pub proto: IpProto::Enum,
    /// Header checksum, stored raw (not converted to host order).
    pub check: u16,
    /// Source address.
    pub source: NetworkU32,
    /// Destination address.
    pub destination: NetworkU32,
}
impl Ipv4Hdr {
    /// Re-initialises the header as a minimal IPv4 header.
    ///
    /// Every field is zeroed, then version 4 / IHL 5 (20 bytes, no options),
    /// the given TTL and the given protocol are set. Lengths, addresses and the
    /// checksum are left zero for the caller to fill in.
    #[inline]
    pub fn reset(&mut self, ttl: u8, proto: IpProto::Enum) {
        *self = Self::zeroed();
        // `bitfield` overlays the first two wire bytes: version/IHL, then DSCP/ECN.
        // Building it from bytes puts 0x45 first on any host endianness.
        self.bitfield = u16::from_ne_bytes([0x45, 0x00]);
        self.time_to_live = ttl;
        self.proto = proto;
    }

    /// Header length in bytes, decoded from the IHL nibble (a count of 32-bit words).
    #[doc(alias = "ihl")]
    #[inline]
    pub fn internet_header_length(&self) -> u8 {
        // IHL is the low nibble of the first wire byte; pick that byte explicitly
        // so the result does not depend on host endianness. Max is 15 * 4 = 60.
        let [version_ihl, _dscp_ecn] = self.bitfield.to_ne_bytes();
        (version_ihl & 0x0f) * 4
    }

    /// Recomputes `check` over this header's bytes.
    #[inline]
    pub fn calc_checksum(&mut self) {
        // The checksum field itself must be zero while the header is summed.
        self.check = 0;
        self.check = csum::fold_checksum(csum::partial(self.as_bytes(), 0));
    }

    /// Returns a copy with source and destination addresses exchanged; every
    /// other field, including the checksum, is copied unchanged.
    #[inline]
    pub fn swapped(&self) -> Self {
        Self {
            source: self.destination,
            destination: self.source,
            ..*self
        }
    }
}
len!(Ipv4Hdr);
#[cfg(feature = "__debug")]
impl fmt::Debug for Ipv4Hdr {
    /// Shows the commonly inspected fields, with addresses rendered as dotted
    /// quads and the raw checksum as four hex digits.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let source = Ipv4Addr::from_bits(self.source.host());
        let destination = Ipv4Addr::from_bits(self.destination.host());

        let mut out = f.debug_struct("Ipv4Hdr");
        out.field("total_length", &self.total_length)
            .field("proto", &self.proto)
            .field("ttl", &self.time_to_live)
            .field("check", &format_args!("{:04x}", self.check))
            .field("source", &source)
            .field("destination", &destination);
        out.finish_non_exhaustive()
    }
}
/// Fixed 40-byte IPv6 header (extension headers are not modelled here).
#[derive(Copy, Clone)]
#[repr(C)]
pub struct Ipv6Hdr {
    // Raw version / traffic class / flow label word.
    // NOTE(review): interpretation of this word depends on host byte order.
    bitfield: u32,
    /// Length in bytes of everything after this fixed header.
    pub payload_length: NetworkU16,
    /// Protocol of the following header (see `IpProto`).
    pub next_header: IpProto::Enum,
    /// Hop limit (the IPv6 counterpart of IPv4's TTL).
    pub hop_limit: u8,
    /// Source address octets in network order.
    pub source: [u8; 16],
    /// Destination address octets in network order.
    pub destination: [u8; 16],
}
impl Ipv6Hdr {
    /// Re-initialises the header as a fresh fixed IPv6 header.
    ///
    /// Every field is zeroed, then version 6 (traffic class and flow label 0),
    /// the given hop limit and next-header protocol are set. Payload length and
    /// addresses are left zero for the caller to fill in.
    #[inline]
    pub fn reset(&mut self, hop: u8, proto: IpProto::Enum) {
        *self = Self::zeroed();
        // The version is the high nibble of the first wire byte. Building the
        // word from bytes keeps 0x60 first regardless of host endianness.
        self.bitfield = u32::from_ne_bytes([0x60, 0x00, 0x00, 0x00]);
        self.next_header = proto;
        self.hop_limit = hop;
    }

    /// Returns a copy with source and destination addresses exchanged; all
    /// other fields are copied unchanged.
    #[inline]
    pub fn swapped(&self) -> Self {
        Self {
            source: self.destination,
            destination: self.source,
            ..*self
        }
    }
}
len!(Ipv6Hdr);
/// Builds an [`Ipv6Addr`] from its 16 octets in network (big-endian) order.
///
/// Being a `const fn`, it can also be used to build constant addresses.
#[inline]
pub const fn ipv6_addr_from_bytes(octets: [u8; 16]) -> Ipv6Addr {
    let bits = u128::from_be_bytes(octets);
    Ipv6Addr::from_bits(bits)
}
#[cfg(feature = "__debug")]
impl fmt::Debug for Ipv6Hdr {
    /// Shows length, next header and hop limit plus the addresses in their
    /// usual textual form instead of raw octets.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let source = ipv6_addr_from_bytes(self.source);
        let destination = ipv6_addr_from_bytes(self.destination);

        let mut out = f.debug_struct("Ipv6Hdr");
        out.field("payload_length", &self.payload_length)
            .field("next_header", &self.next_header)
            .field("hop_limit", &self.hop_limit)
            .field("source", &source)
            .field("destination", &destination);
        out.finish_non_exhaustive()
    }
}
/// UDP header (8 bytes).
#[derive(Copy, Clone)]
#[repr(C)]
pub struct UdpHdr {
    /// Source port.
    pub source: NetworkU16,
    /// Destination port.
    pub destination: NetworkU16,
    /// Length in bytes of the UDP header plus its payload.
    pub length: NetworkU16,
    /// Checksum, stored raw as it appears in the packet.
    pub check: u16,
}
len!(UdpHdr);

impl UdpHdr {
    /// Returns a copy with source and destination ports exchanged; the length
    /// and checksum fields are copied unchanged.
    #[inline]
    pub fn swapped(&self) -> Self {
        Self {
            source: self.destination,
            destination: self.source,
            ..*self
        }
    }
}
#[cfg(feature = "__debug")]
impl fmt::Debug for UdpHdr {
    /// Shows ports and length as host-order decimals and the raw checksum as
    /// four hex digits.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("UdpHdr");
        out.field("source", &self.source);
        out.field("destination", &self.destination);
        out.field("length", &self.length);
        out.field("check", &format_args!("{:04x}", self.check));
        out.finish()
    }
}
/// A source/destination address pair for either IP version.
///
/// Applied to a header with [`IpAddresses::with_header`] and comparable against
/// an [`IpHdr`] through `PartialEq`.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
pub enum IpAddresses {
    /// IPv4 source and destination.
    V4 {
        source: Ipv4Addr,
        destination: Ipv4Addr,
    },
    /// IPv6 source and destination.
    V6 {
        source: Ipv6Addr,
        destination: Ipv6Addr,
    },
}
use std::net::IpAddr;
impl IpAddresses {
    /// The source address, widened to [`IpAddr`].
    #[inline]
    pub fn source(&self) -> IpAddr {
        match *self {
            Self::V4 { source, .. } => IpAddr::V4(source),
            Self::V6 { source, .. } => IpAddr::V6(source),
        }
    }

    /// The destination address, widened to [`IpAddr`].
    #[inline]
    pub fn destination(&self) -> IpAddr {
        match *self {
            Self::V4 { destination, .. } => IpAddr::V4(destination),
            Self::V6 { destination, .. } => IpAddr::V6(destination),
        }
    }

    /// `(source, destination)` as [`IpAddr`]s.
    #[inline]
    pub fn both(&self) -> (IpAddr, IpAddr) {
        (self.source(), self.destination())
    }

    /// Builds a header carrying these addresses, based on `prev`.
    ///
    /// If `prev` has the same IP version, it is copied and only the addresses
    /// are replaced. Otherwise a fresh header of the new version is created via
    /// `reset`, inheriting `prev`'s hop limit/TTL and protocol. In both cases the
    /// hop counter of the result is then decremented once (saturating at 0).
    #[inline]
    pub fn with_header(self, prev: &IpHdr) -> IpHdr {
        let mut iphdr = match self {
            Self::V4 {
                source,
                destination,
            } => {
                let mut v4 = match prev {
                    IpHdr::V4(old) => *old,
                    IpHdr::V6(old) => {
                        let mut fresh = Ipv4Hdr::zeroed();
                        fresh.reset(old.hop_limit, old.next_header);
                        fresh
                    }
                };
                v4.source = source.to_bits().into();
                v4.destination = destination.to_bits().into();
                IpHdr::V4(v4)
            }
            Self::V6 {
                source,
                destination,
            } => {
                let mut v6 = match prev {
                    IpHdr::V6(old) => *old,
                    IpHdr::V4(old) => {
                        let mut fresh = Ipv6Hdr::zeroed();
                        fresh.reset(old.time_to_live, old.proto);
                        fresh
                    }
                };
                v6.source = source.octets();
                v6.destination = destination.octets();
                IpHdr::V6(v6)
            }
        };
        iphdr.decrement_hop();
        iphdr
    }
}
/// Either an IPv4 or an IPv6 header, as parsed from or written to a packet.
#[cfg_attr(feature = "__debug", derive(Debug))]
pub enum IpHdr {
    /// IPv4 header (fixed 20-byte form).
    V4(Ipv4Hdr),
    /// IPv6 fixed header.
    V6(Ipv6Hdr),
}
impl IpHdr {
    /// Returns a copy of the header with source and destination addresses exchanged.
    #[inline]
    pub fn swapped(&self) -> Self {
        match self {
            Self::V4(hdr) => Self::V4(hdr.swapped()),
            Self::V6(hdr) => Self::V6(hdr.swapped()),
        }
    }

    /// Decrements the TTL (IPv4) or hop limit (IPv6) by one, stopping at zero,
    /// and returns the resulting value.
    #[inline]
    pub fn decrement_hop(&mut self) -> u8 {
        let counter = match self {
            Self::V4(hdr) => &mut hdr.time_to_live,
            Self::V6(hdr) => &mut hdr.hop_limit,
        };
        *counter = counter.saturating_sub(1);
        *counter
    }
}
impl PartialEq<IpAddresses> for IpHdr {
    /// A header equals an address pair when both use the same IP version and
    /// both the source and destination addresses match.
    fn eq(&self, other: &IpAddresses) -> bool {
        match (self, *other) {
            (
                Self::V4(hdr),
                IpAddresses::V4 {
                    source,
                    destination,
                },
            ) => {
                (hdr.source.host(), hdr.destination.host())
                    == (source.to_bits(), destination.to_bits())
            }
            (
                Self::V6(hdr),
                IpAddresses::V6 {
                    source,
                    destination,
                },
            ) => (hdr.source, hdr.destination) == (source.octets(), destination.octets()),
            // Mismatched IP versions never compare equal.
            (Self::V4(_), IpAddresses::V6 { .. }) | (Self::V6(_), IpAddresses::V4 { .. }) => false,
        }
    }
}
/// Half-open byte range `start..end` locating a payload inside a packet.
///
/// Unlike `std::ops::Range<usize>`, this type is `Copy`.
#[derive(Copy, Clone)]
pub struct DataRange {
    pub start: usize,
    pub end: usize,
}

impl From<std::ops::Range<usize>> for DataRange {
    #[inline]
    fn from(range: std::ops::Range<usize>) -> Self {
        let std::ops::Range { start, end } = range;
        Self { start, end }
    }
}

impl fmt::Debug for DataRange {
    /// Renders like a range expression, e.g. `42..100`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { start, end } = *self;
        write!(f, "{}..{}", start, end)
    }
}

impl std::ops::Index<DataRange> for [u8] {
    type Output = [u8];

    /// Equivalent to `&self[start..end]`, including its panics on
    /// out-of-bounds or inverted ranges.
    #[inline]
    fn index(&self, range: DataRange) -> &Self::Output {
        &self[range.start..range.end]
    }
}
/// Decoded headers of an Ethernet → IP → UDP packet plus where its payload lives.
#[cfg_attr(feature = "__debug", derive(Debug))]
pub struct UdpHeaders {
    /// Ethernet header.
    pub eth: EthHdr,
    /// IPv4 or IPv6 header.
    pub ip: IpHdr,
    /// UDP header.
    pub udp: UdpHdr,
    /// Byte range of the UDP payload within the packet.
    pub data: DataRange,
}
impl UdpHeaders {
    /// Bundles already-decoded headers with the payload's byte range.
    #[inline]
    pub fn new(eth: EthHdr, ip: IpHdr, udp: UdpHdr, data: impl Into<DataRange>) -> Self {
        Self {
            eth,
            ip,
            udp,
            data: data.into(),
        }
    }

    /// Parses an Ethernet → IPv4/IPv6 → UDP packet.
    ///
    /// Returns `Ok(None)` when the frame is not UDP over IPv4/IPv6. It also
    /// returns `Ok(None)` when the IPv4 header carries options or has a
    /// malformed IHL (header length ≠ 20 bytes). This type assumes fixed-size
    /// IP headers (see `header_length` / `set_packet_headers`), so such packets
    /// are left to other handlers instead of being misread.
    ///
    /// # Errors
    ///
    /// - Propagates errors from `Packet::read` when a header does not fit.
    /// - Returns `PacketError::InsufficientData` when the UDP length field is
    ///   smaller than the UDP header itself.
    /// - Returns the same error when the UDP length does not exactly cover the
    ///   remainder of the packet.
    pub fn parse_packet(packet: &super::Packet) -> Result<Option<Self>, super::PacketError> {
        let mut offset = 0;
        let eth = packet.read::<EthHdr>(offset)?;
        offset += EthHdr::LEN;

        let ip = match eth.ether_type {
            EtherType::Ipv4 => {
                let ipv4 = packet.read::<Ipv4Hdr>(offset)?;
                // The UDP header is read at a fixed offset below. An IPv4 header
                // with options (IHL > 5) would shift it, so reject those here.
                if ipv4.proto != IpProto::Udp
                    || usize::from(ipv4.internet_header_length()) != Ipv4Hdr::LEN
                {
                    return Ok(None);
                }
                offset += Ipv4Hdr::LEN;
                IpHdr::V4(ipv4)
            }
            EtherType::Ipv6 => {
                let ipv6 = packet.read::<Ipv6Hdr>(offset)?;
                if ipv6.next_header != IpProto::Udp {
                    return Ok(None);
                }
                offset += Ipv6Hdr::LEN;
                IpHdr::V6(ipv6)
            }
            _ => return Ok(None),
        };

        let udp = packet.read::<UdpHdr>(offset)?;
        // The UDP length field covers the UDP header plus its payload.
        let udp_length = usize::from(udp.length.host());
        // Guard against a length shorter than the header: the payload range
        // computed below would otherwise be inverted.
        if udp_length < UdpHdr::LEN || offset + udp_length != packet.len() {
            return Err(super::PacketError::InsufficientData {
                offset,
                size: udp_length,
                length: packet.len(),
            });
        }

        let start = offset + UdpHdr::LEN;
        Ok(Some(Self {
            eth,
            ip,
            udp,
            data: (start..offset + udp_length).into(),
        }))
    }

    /// `true` when the IP layer is IPv4.
    #[inline]
    pub fn is_ipv4(&self) -> bool {
        matches!(self.ip, IpHdr::V4(_))
    }

    /// Combined size of the Ethernet, (fixed-size) IP and UDP headers, i.e. the
    /// payload offset in packets handled by this type.
    #[inline]
    pub fn header_length(&self) -> usize {
        let ip_len = if self.is_ipv4() {
            Ipv4Hdr::LEN
        } else {
            Ipv6Hdr::LEN
        };
        EthHdr::LEN + ip_len + UdpHdr::LEN
    }

    /// Payload length in bytes.
    #[inline(always)]
    pub fn data_length(&self) -> usize {
        self.data.end - self.data.start
    }

    /// Decrements the IP TTL / hop limit (saturating at 0) and returns the new value.
    #[inline]
    pub fn decrement_hop(&mut self) -> u8 {
        self.ip.decrement_hop()
    }

    /// Source IP address and UDP port. IPv6 results use flow info and scope id 0.
    #[inline]
    pub fn source_address(&self) -> SocketAddr {
        use std::net::{SocketAddrV4, SocketAddrV6};
        let port = self.udp.source.host();
        match &self.ip {
            IpHdr::V4(v4) => {
                SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from_bits(v4.source.host()), port))
            }
            IpHdr::V6(v6) => SocketAddr::V6(SocketAddrV6::new(
                ipv6_addr_from_bytes(v6.source),
                port,
                0,
                0,
            )),
        }
    }

    /// Destination IP address and UDP port. IPv6 results use flow info and scope id 0.
    #[inline]
    pub fn destination_address(&self) -> SocketAddr {
        use std::net::{SocketAddrV4, SocketAddrV6};
        let port = self.udp.destination.host();
        match &self.ip {
            IpHdr::V4(v4) => SocketAddr::V4(SocketAddrV4::new(
                Ipv4Addr::from_bits(v4.destination.host()),
                port,
            )),
            IpHdr::V6(v6) => SocketAddr::V6(SocketAddrV6::new(
                ipv6_addr_from_bytes(v6.destination),
                port,
                0,
                0,
            )),
        }
    }

    /// Writes the Ethernet, IP and UDP headers into `packet` at their fixed offsets.
    ///
    /// Before writing, the following are refreshed from `self.data`:
    /// - the IP length (IPv4 total length / IPv6 payload length);
    /// - the IPv4 header checksum;
    /// - the Ethernet EtherType;
    /// - the UDP length.
    ///
    /// The stored headers in `self` are updated too. The UDP checksum is written
    /// exactly as currently held in `self.udp`; it is not recomputed here.
    ///
    /// # Errors
    ///
    /// Propagates any error from `Packet::write`.
    pub fn set_packet_headers(
        &mut self,
        packet: &mut super::Packet,
    ) -> Result<(), super::PacketError> {
        let mut offset = EthHdr::LEN;
        // UDP length = header + payload.
        // NOTE(review): this and the IPv4 total length below assume the payload
        // fits the 16-bit length fields.
        let length = (self.data_length() + UdpHdr::LEN) as u16;
        self.eth.ether_type = match &mut self.ip {
            IpHdr::V4(v4) => {
                v4.total_length = (length + Ipv4Hdr::LEN as u16).into();
                v4.calc_checksum();
                packet.write(offset, *v4)?;
                offset += Ipv4Hdr::LEN;
                EtherType::Ipv4
            }
            IpHdr::V6(v6) => {
                v6.payload_length = length.into();
                packet.write(offset, *v6)?;
                offset += Ipv6Hdr::LEN;
                EtherType::Ipv6
            }
        };
        packet.write(0, self.eth)?;
        self.udp.length = length.into();
        packet.write(offset, self.udp)?;
        Ok(())
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// Header structs must match their on-the-wire sizes exactly, and a freshly
    /// reset IPv4 header must report the minimal 20-byte header length.
    #[test]
    fn sanity_check() {
        let expected_sizes = [
            ("EthHdr", EthHdr::LEN, 14),
            ("Ipv4Hdr", Ipv4Hdr::LEN, 20),
            ("Ipv6Hdr", Ipv6Hdr::LEN, 40),
            ("UdpHdr", UdpHdr::LEN, 8),
        ];
        for (name, actual, expected) in expected_sizes {
            assert_eq!(actual, expected, "unexpected size for {name}");
        }

        let mut ip = Ipv4Hdr::zeroed();
        ip.reset(56, IpProto::Tcp);
        assert_eq!(ip.internet_header_length(), 20);
    }
}