use crate::packets::ip::v4::Ipv4;
use crate::packets::ip::v6::Ipv6;
use crate::packets::ip::{Flow, IpPacket, ProtocolNumbers};
use crate::packets::types::u16be;
use crate::packets::{checksum, Internal, Packet};
use crate::{ensure, SizeOf};
use anyhow::{anyhow, Result};
use std::fmt;
use std::net::IpAddr;
use std::ptr::NonNull;
/// A UDP datagram viewed in place inside its enclosing IP packet.
///
/// The header is not copied out; accessors read and write it through a
/// pointer into the envelope's mbuf.
pub struct Udp<E>
where
    E: IpPacket,
{
    /// The enclosing IP packet (IPv4 or IPv6) that carries this datagram.
    envelope: E,
    /// Pointer to the UDP header inside the envelope's mbuf.
    header: NonNull<UdpHeader>,
    /// Offset at which the UDP header starts (the envelope's payload offset).
    offset: usize,
}
impl<E: IpPacket> Udp<E> {
    /// Returns a shared view of the fixed-size UDP header.
    #[inline]
    fn header(&self) -> &UdpHeader {
        // SAFETY: `self.header` is produced by `read_data`/`write_data` on the
        // envelope's mbuf in `try_parse`/`try_push`. This relies on the mbuf
        // keeping that buffer at a stable address while `self` is alive —
        // NOTE(review): confirm against the mbuf implementation.
        unsafe { self.header.as_ref() }
    }

    /// Returns a mutable view of the fixed-size UDP header.
    #[inline]
    fn header_mut(&mut self) -> &mut UdpHeader {
        // SAFETY: same pointer provenance as `header`; `&mut self` guarantees
        // no other reference derived from this `Udp` is live.
        unsafe { self.header.as_mut() }
    }

    /// Returns the source port.
    #[inline]
    pub fn src_port(&self) -> u16 {
        self.header().src_port.into()
    }

    /// Sets the source port.
    ///
    /// The checksum is not updated; call `reconcile` afterwards if needed.
    #[inline]
    pub fn set_src_port(&mut self, src_port: u16) {
        self.header_mut().src_port = src_port.into();
    }

    /// Returns the destination port.
    #[inline]
    pub fn dst_port(&self) -> u16 {
        self.header().dst_port.into()
    }

    /// Sets the destination port.
    ///
    /// The checksum is not updated; call `reconcile` afterwards if needed.
    #[inline]
    pub fn set_dst_port(&mut self, dst_port: u16) {
        self.header_mut().dst_port = dst_port.into();
    }

    /// Returns the value of the header's length field.
    #[inline]
    pub fn length(&self) -> u16 {
        self.header().length.into()
    }

    /// Overwrites the header's length field; `reconcile` sets it from `len()`.
    #[inline]
    fn set_length(&mut self, length: u16) {
        self.header_mut().length = length.into()
    }

    /// Returns the raw checksum field.
    #[inline]
    pub fn checksum(&self) -> u16 {
        self.header().checksum.into()
    }

    /// Writes a computed checksum into the header.
    ///
    /// A computed value of `0` is stored as `0xFFFF`. Per RFC 768 an all-zero
    /// field means "no checksum", so a genuine zero must be sent as its
    /// one's-complement equivalent.
    #[inline]
    fn set_checksum(&mut self, checksum: u16) {
        self.header_mut().checksum = match checksum {
            0 => u16be::from(0xFFFF),
            _ => checksum.into(),
        }
    }

    /// Clears the checksum field to zero.
    ///
    /// Per RFC 768 a zero checksum means none was computed. IPv6 generally
    /// requires a UDP checksum (RFC 8200), so this is mainly meaningful over
    /// IPv4.
    #[inline]
    pub fn no_checksum(&mut self) {
        self.header_mut().checksum = u16be::default();
    }

    /// Returns the 5-tuple (src/dst address, src/dst port, protocol) of this
    /// datagram.
    #[inline]
    pub fn flow(&self) -> Flow {
        Flow::new(
            self.envelope().src(),
            self.envelope().dst(),
            self.src_port(),
            self.dst_port(),
            ProtocolNumbers::Udp,
        )
    }

    /// Rewrites the envelope's source address.
    ///
    /// The UDP checksum is updated incrementally to account for the changed
    /// pseudo-header. A zero (disabled) checksum is left at zero.
    ///
    /// # Errors
    ///
    /// Returns an error if the incremental checksum computation or the
    /// envelope's `set_src` fails, e.g. an IPv6 address on an IPv4 envelope.
    #[inline]
    pub fn set_src_ip(&mut self, src_ip: IpAddr) -> Result<()> {
        let old_ip = self.envelope().src();
        let checksum = self.adjusted_checksum(&old_ip, &src_ip)?;
        self.envelope_mut().set_src(src_ip)?;
        if let Some(checksum) = checksum {
            self.set_checksum(checksum);
        }
        Ok(())
    }

    /// Rewrites the envelope's destination address.
    ///
    /// The UDP checksum is updated incrementally to account for the changed
    /// pseudo-header. A zero (disabled) checksum is left at zero.
    ///
    /// # Errors
    ///
    /// Returns an error if the incremental checksum computation or the
    /// envelope's `set_dst` fails, e.g. an IPv6 address on an IPv4 envelope.
    #[inline]
    pub fn set_dst_ip(&mut self, dst_ip: IpAddr) -> Result<()> {
        let old_ip = self.envelope().dst();
        let checksum = self.adjusted_checksum(&old_ip, &dst_ip)?;
        self.envelope_mut().set_dst(dst_ip)?;
        if let Some(checksum) = checksum {
            self.set_checksum(checksum);
        }
        Ok(())
    }

    /// Computes the checksum that results from replacing `old_ip` with
    /// `new_ip` in the pseudo-header.
    ///
    /// Returns `None` when the datagram carries no checksum. The incremental
    /// computation still runs in that case, so any error it reports surfaces.
    #[inline]
    fn adjusted_checksum(&self, old_ip: &IpAddr, new_ip: &IpAddr) -> Result<Option<u16>> {
        let current = self.checksum();
        let updated = checksum::compute_with_ipaddr(current, old_ip, new_ip)?;
        // RFC 768: an all-zero field means the sender computed no checksum.
        // Adjusting it would turn "no checksum" into a wrong checksum that
        // receivers drop, so leave it untouched.
        Ok(if current == 0 { None } else { Some(updated) })
    }

    /// Recomputes the checksum from scratch.
    ///
    /// The sum covers the envelope's pseudo-header plus the UDP header and
    /// everything after it in the mbuf.
    #[inline]
    fn compute_checksum(&mut self) {
        // The checksum field must read as zero while it is summed over.
        self.no_checksum();
        let data = self
            .mbuf()
            .read_data_slice(self.offset, self.len())
            .expect("UDP segment must lie within its mbuf");
        // SAFETY: the slice was just bounds-checked by `read_data_slice`, and
        // nothing mutates the mbuf until after the checksum is computed.
        let data = unsafe { data.as_ref() };
        // NOTE(review): `as u16` truncates for segments > 65535 bytes
        // (e.g. IPv6 jumbograms); such datagrams are not expected here.
        let pseudo_header_sum = self
            .envelope()
            .pseudo_header(data.len() as u16, ProtocolNumbers::Udp)
            .sum();
        let checksum = checksum::compute(pseudo_header_sum, data);
        self.set_checksum(checksum);
    }
}
impl<E> fmt::Debug for Udp<E>
where
    E: IpPacket,
{
    /// Formats the header fields, followed by `$`-prefixed framing metadata
    /// (offset, length, header length).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let checksum_hex = format!("0x{:04x}", self.checksum());
        let mut builder = f.debug_struct("udp");
        builder.field("src_port", &self.src_port());
        builder.field("dst_port", &self.dst_port());
        builder.field("length", &self.length());
        builder.field("checksum", &checksum_hex);
        builder.field("$offset", &self.offset());
        builder.field("$len", &self.len());
        builder.field("$header_len", &self.header_len());
        builder.finish()
    }
}
impl<E> Packet for Udp<E>
where
    E: IpPacket,
{
    /// A UDP datagram is carried directly by an IP packet.
    type Envelope = E;

    #[inline]
    fn envelope(&self) -> &Self::Envelope {
        &self.envelope
    }

    #[inline]
    fn envelope_mut(&mut self) -> &mut Self::Envelope {
        &mut self.envelope
    }

    /// Offset of the UDP header, captured from the envelope's payload offset.
    #[inline]
    fn offset(&self) -> usize {
        self.offset
    }

    /// The UDP header has a fixed size.
    #[inline]
    fn header_len(&self) -> usize {
        UdpHeader::size_of()
    }

    /// Clones the envelope; the header pointer and offset are copied as-is.
    #[inline]
    unsafe fn clone(&self, internal: Internal) -> Self {
        let envelope = self.envelope.clone(internal);
        Self {
            envelope,
            header: self.header,
            offset: self.offset,
        }
    }

    /// Interprets the envelope's payload as a UDP datagram.
    ///
    /// Fails if the envelope's next protocol is not UDP, or if the header
    /// cannot be read from the mbuf at the payload offset.
    #[inline]
    fn try_parse(envelope: Self::Envelope, _internal: Internal) -> Result<Self> {
        let carries_udp = envelope.next_protocol() == ProtocolNumbers::Udp;
        ensure!(carries_udp, anyhow!("not a UDP packet."));

        let offset = envelope.payload_offset();
        let header = envelope.mbuf().read_data(offset)?;
        Ok(Self {
            envelope,
            header,
            offset,
        })
    }

    /// Inserts a zeroed UDP header at the envelope's payload offset and marks
    /// the envelope's next protocol as UDP.
    #[inline]
    fn try_push(mut envelope: Self::Envelope, _internal: Internal) -> Result<Self> {
        let offset = envelope.payload_offset();
        let header = {
            let mbuf = envelope.mbuf_mut();
            mbuf.extend(offset, UdpHeader::size_of())?;
            mbuf.write_data(offset, &UdpHeader::default())?
        };
        envelope.set_next_protocol(ProtocolNumbers::Udp);
        Ok(Self {
            envelope,
            header,
            offset,
        })
    }

    /// Drops the UDP view and hands back the enclosing IP packet.
    #[inline]
    fn deparse(self) -> Self::Envelope {
        let Udp { envelope, .. } = self;
        envelope
    }

    /// Syncs the length field with the actual segment size, then recomputes
    /// the checksum.
    #[inline]
    fn reconcile(&mut self) {
        self.set_length(self.len() as u16);
        self.compute_checksum();
    }
}
/// A UDP datagram carried inside an IPv4 packet.
pub type Udp4 = Udp<Ipv4>;
/// A UDP datagram carried inside an IPv6 packet.
pub type Udp6 = Udp<Ipv6>;
/// In-memory layout of the UDP header (8 bytes; see `size_of_udp_header`).
///
/// `#[repr(C)]` keeps the fields in declaration order. That order must not
/// change, because the struct is read from and written to the mbuf directly.
/// `Default` yields an all-zero header, which `try_push` writes.
#[derive(Clone, Copy, Debug, Default, SizeOf)]
#[repr(C)]
struct UdpHeader {
    // Source port (u16be; network byte order, going by the type name).
    src_port: u16be,
    // Destination port.
    dst_port: u16be,
    // Header + payload length; `reconcile` sets it from `len()`.
    length: u16be,
    // Checksum; 0 means "no checksum" (see `set_checksum` / `no_checksum`).
    checksum: u16be,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::packets::Ethernet;
    use crate::testils::byte_arrays::{IPV4_TCP_PACKET, IPV4_UDP_PACKET};
    use crate::Mbuf;
    use std::net::{Ipv4Addr, Ipv6Addr};

    /// Wraps `bytes` in an mbuf and parses it as Ethernet, then IPv4.
    fn ipv4_from(bytes: &[u8]) -> Ipv4 {
        let mbuf = Mbuf::from_bytes(bytes).unwrap();
        mbuf.parse::<Ethernet>().unwrap().parse::<Ipv4>().unwrap()
    }

    /// Parses `bytes` all the way down to a UDP-over-IPv4 datagram.
    fn udp4_from(bytes: &[u8]) -> Udp4 {
        ipv4_from(bytes).parse::<Udp4>().unwrap()
    }

    #[test]
    fn size_of_udp_header() {
        assert_eq!(8, UdpHeader::size_of());
    }

    #[capsule::test]
    fn parse_udp_packet() {
        let udp = udp4_from(&IPV4_UDP_PACKET);
        assert_eq!(39376, udp.src_port());
        assert_eq!(1087, udp.dst_port());
        assert_eq!(18, udp.length());
        assert_eq!(0x7228, udp.checksum());
    }

    #[capsule::test]
    fn parse_non_udp_packet() {
        let ipv4 = ipv4_from(&IPV4_TCP_PACKET);
        assert!(ipv4.parse::<Udp4>().is_err());
    }

    #[capsule::test]
    fn udp_flow_v4() {
        let flow = udp4_from(&IPV4_UDP_PACKET).flow();
        assert_eq!("139.133.217.110", flow.src_ip().to_string());
        assert_eq!("139.133.233.2", flow.dst_ip().to_string());
        assert_eq!(39376, flow.src_port());
        assert_eq!(1087, flow.dst_port());
        assert_eq!(ProtocolNumbers::Udp, flow.protocol());
    }

    #[capsule::test]
    fn set_src_dst_ip() {
        let mut udp = udp4_from(&IPV4_UDP_PACKET);

        let before = udp.checksum();
        let src = Ipv4Addr::new(10, 0, 0, 0);
        assert!(udp.set_src_ip(src.into()).is_ok());
        assert_ne!(before, udp.checksum());
        assert_eq!(src.to_string(), udp.envelope().src().to_string());

        let before = udp.checksum();
        let dst = Ipv4Addr::new(20, 0, 0, 0);
        assert!(udp.set_dst_ip(dst.into()).is_ok());
        assert_ne!(before, udp.checksum());
        assert_eq!(dst.to_string(), udp.envelope().dst().to_string());

        // An IPv6 address cannot be written into an IPv4 envelope.
        assert!(udp.set_src_ip(Ipv6Addr::UNSPECIFIED.into()).is_err());
    }

    #[capsule::test]
    fn compute_checksum() {
        let mut udp = udp4_from(&IPV4_UDP_PACKET);
        let expected = udp.checksum();
        udp.reconcile_all();
        assert_eq!(expected, udp.checksum());
    }

    #[capsule::test]
    fn push_udp_packet() {
        let udp = Mbuf::new()
            .unwrap()
            .push::<Ethernet>()
            .unwrap()
            .push::<Ipv4>()
            .unwrap()
            .push::<Udp4>()
            .unwrap();
        assert_eq!(UdpHeader::size_of(), udp.len());
        assert_eq!(ProtocolNumbers::Udp, udp.envelope().next_protocol());
    }
}