use core::{fmt, ops};
use byteorder::{ByteOrder, NetworkEndian};
use crate::wire::{ip, Error, Reframe, Result, Payload, PayloadError, PayloadMut, payload};
use crate::wire::pretty_print::{PrettyPrint, PrettyIndent};
use super::ip::checksum;
/// A UDP datagram stored in a buffer `T`, paired with its header fields.
///
/// Build one with `Packet::new_checked` (validates and parses the buffer) or
/// `Packet::new_unchecked` (trusts the caller-supplied `Repr`).
#[derive(Debug, PartialEq, Clone)]
pub struct Packet<T> {
    /// Underlying storage; its `Payload` is viewed as the UDP header followed by data.
    buffer: T,
    /// Header fields, as parsed or supplied when the packet was constructed.
    repr: Repr,
}
// Unsized, in-place byte view of a UDP datagram (8-byte header then payload).
// `byte_wrapper!` is defined elsewhere in the crate; the inherent `impl udp`
// in this file relies on it providing `__from_macro_new_unchecked[_mut]` and a
// tuple field `.0` holding the raw `[u8]`.
byte_wrapper! {
#[derive(Debug, PartialEq, Eq)]
pub struct udp([u8]);
}
/// Byte ranges of the fixed UDP header fields within a datagram.
mod field {
    // `PAYLOAD` is a function but is named like the range constants it sits beside.
    #![allow(non_snake_case)]
    use crate::wire::field::Field;

    /// Source port, big-endian.
    pub(crate) const SRC_PORT: Field = 0..2;
    /// Destination port, big-endian.
    pub(crate) const DST_PORT: Field = 2..4;
    /// Length of header plus payload in bytes, big-endian.
    pub(crate) const LENGTH: Field = 4..6;
    /// Checksum, big-endian. Its end (8) is also the header length.
    pub(crate) const CHECKSUM: Field = 6..8;

    /// Payload range for a datagram whose length field is `length`: from the end
    /// of the header up to `length`. Slicing with it panics when `length` is below
    /// the header size or beyond the buffer.
    pub(crate) fn PAYLOAD(length: u16) -> Field {
        let end = usize::from(length);
        CHECKSUM.end..end
    }
}
impl udp {
pub fn new_unchecked(data: &[u8]) -> &Self {
Self::__from_macro_new_unchecked(data)
}
pub fn new_unchecked_mut(data: &mut [u8]) -> &mut Self {
Self::__from_macro_new_unchecked_mut(data)
}
pub fn new_checked(data: &[u8]) -> Result<&Self> {
Self::new_unchecked(data).check_len()?;
Ok(Self::new_unchecked(data))
}
pub fn new_checked_mut(data: &mut [u8]) -> Result<&mut Self> {
Self::new_checked(&data[..])?;
Ok(Self::new_unchecked_mut(data))
}
pub fn as_bytes(&self) -> &[u8] {
&self.0
}
pub fn as_bytes_mut(&mut self) -> &mut [u8] {
&mut self.0
}
pub fn check_len(&self) -> Result<()> {
let buffer_len = self.0.len();
if buffer_len < field::CHECKSUM.end {
Err(Error::Truncated)
} else {
let field_len = self.len() as usize;
if buffer_len < field_len {
Err(Error::Truncated)
} else if field_len < field::CHECKSUM.end {
Err(Error::Malformed)
} else {
Ok(())
}
}
}
#[inline]
pub fn src_port(&self) -> u16 {
NetworkEndian::read_u16(&self.0[field::SRC_PORT])
}
#[inline]
pub fn dst_port(&self) -> u16 {
NetworkEndian::read_u16(&self.0[field::DST_PORT])
}
#[inline]
pub fn len(&self) -> u16 {
NetworkEndian::read_u16(&self.0[field::LENGTH])
}
#[inline]
pub fn checksum(&self) -> u16 {
NetworkEndian::read_u16(&self.0[field::CHECKSUM])
}
#[inline]
pub fn set_src_port(&mut self, value: u16) {
NetworkEndian::write_u16(&mut self.0[field::SRC_PORT], value)
}
#[inline]
pub fn set_dst_port(&mut self, value: u16) {
NetworkEndian::write_u16(&mut self.0[field::DST_PORT], value)
}
#[inline]
pub fn set_len(&mut self, value: u16) {
NetworkEndian::write_u16(&mut self.0[field::LENGTH], value)
}
#[inline]
pub fn set_checksum(&mut self, value: u16) {
NetworkEndian::write_u16(&mut self.0[field::CHECKSUM], value)
}
pub fn fill_checksum(&mut self, src_addr: ip::Address, dst_addr: ip::Address) {
self.set_checksum(0);
let checksum = {
!checksum::combine(&[
checksum::pseudo_header(&src_addr, &dst_addr, ip::Protocol::Udp,
self.len() as u32),
checksum::data(&self.0[..self.len() as usize])
])
};
self.set_checksum(if checksum == 0 { 0xffff } else { checksum })
}
pub fn verify_checksum(&self, src_addr: ip::Address, dst_addr: ip::Address) -> bool {
if cfg!(fuzzing) { return true }
checksum::combine(&[
checksum::pseudo_header(&src_addr, &dst_addr, ip::Protocol::Udp,
self.len() as u32),
checksum::data(&self.0[..self.len() as usize])
]) == !0
}
pub fn payload_slice(&self) -> &[u8] {
let len = self.len();
&self.0[field::PAYLOAD(len)]
}
pub fn payload_mut_slice(&mut self) -> &mut [u8] {
let len = self.len();
&mut self.0[field::PAYLOAD(len)]
}
}
impl<T: Payload> Packet<T> {
    /// Validates `buffer`'s payload as a UDP datagram and parses its header.
    ///
    /// # Errors
    /// Length errors from `udp::new_checked`, and port or checksum errors from
    /// `Repr::parse` according to `checksum`.
    pub fn new_checked(buffer: T, checksum: Checksum) -> Result<Self> {
        let repr = {
            let datagram = udp::new_checked(buffer.payload())?;
            Repr::parse(datagram, checksum)?
        };
        Ok(Self::new_unchecked(buffer, repr))
    }

    /// Pairs `buffer` with `repr` without checking that they agree.
    pub fn new_unchecked(buffer: T, repr: Repr) -> Self {
        Packet { buffer, repr }
    }

    /// Borrows the underlying buffer.
    pub fn get_ref(&self) -> &T {
        &self.buffer
    }

    /// A copy of the header fields held by this packet.
    pub fn repr(&self) -> Repr {
        self.repr
    }

    /// Discards the header and returns the underlying buffer.
    pub fn into_inner(self) -> T {
        let Packet { buffer, .. } = self;
        buffer
    }

    /// The UDP payload bytes, mutably, as delimited by the length field in the buffer.
    pub fn payload_mut_slice(&mut self) -> &mut [u8] where T: PayloadMut {
        let datagram = udp::new_unchecked_mut(self.buffer.payload_mut());
        datagram.payload_mut_slice()
    }
}
impl<T: Payload + PayloadMut> Packet<T> {
    /// Writes the UDP checksum into the buffer according to `checksum`.
    ///
    /// `Manual` always computes it. `Lazy` computes it unless both addresses are
    /// IPv4, where an all-zero checksum field is permitted. `Ignored` leaves the
    /// field untouched.
    pub fn fill_checksum(&mut self, checksum: Checksum) {
        let datagram = udp::new_unchecked_mut(self.buffer.payload_mut());
        let (src_addr, dst_addr) = match checksum {
            Checksum::Ignored => return,
            Checksum::Lazy {
                src_addr: ip::Address::Ipv4(_),
                dst_addr: ip::Address::Ipv4(_),
            } => return,
            Checksum::Lazy { src_addr, dst_addr }
            | Checksum::Manual { src_addr, dst_addr } => (src_addr, dst_addr),
        };
        datagram.fill_checksum(src_addr, dst_addr);
    }
}
impl AsRef<[u8]> for udp {
    /// Same bytes as `udp::as_bytes`.
    fn as_ref(&self) -> &[u8] {
        self.as_bytes()
    }
}
impl AsMut<[u8]> for udp {
    /// Same bytes as `udp::as_bytes_mut`.
    fn as_mut(&mut self) -> &mut [u8] {
        self.as_bytes_mut()
    }
}
impl<'a, T: Payload + ?Sized> Packet<&'a T> {
    /// UDP payload bytes, borrowed for the buffer's lifetime `'a` rather than for
    /// the lifetime of `self`.
    #[inline]
    pub fn payload_slice(&self) -> &'a [u8] {
        let buffer: &'a T = self.buffer;
        udp::new_unchecked(buffer.payload()).payload_slice()
    }
}
impl<T: Payload + PayloadMut> Packet<T> {
    /// UDP payload bytes, mutably, as delimited by the length field in the buffer.
    #[inline]
    pub fn payload_mut(&mut self) -> &mut [u8] {
        let bytes = self.buffer.payload_mut();
        udp::new_unchecked_mut(bytes).payload_mut_slice()
    }
}
impl<T: Payload> ops::Deref for Packet<T> {
    type Target = udp;

    /// Views the buffer's payload as a UDP datagram without re-checking its length.
    fn deref(&self) -> &udp {
        let bytes = self.buffer.payload();
        udp::new_unchecked(bytes)
    }
}
impl<T: Payload> Payload for Packet<T> {
fn payload(&self) -> &payload {
self.payload_slice().into()
}
}
impl<T: Payload + PayloadMut> PayloadMut for Packet<T> {
    /// The UDP payload, mutably, as delimited by the length field in the buffer.
    fn payload_mut(&mut self) -> &mut payload {
        let datagram = udp::new_unchecked_mut(self.buffer.payload_mut());
        datagram.payload_mut_slice().into()
    }

    /// Resizes the underlying buffer to `length` payload bytes plus the 8-byte header.
    ///
    /// NOTE(review): neither the UDP length field nor the cached `repr` is updated
    /// here; confirm that callers re-emit the header afterwards.
    fn resize(&mut self, length: usize) -> core::result::Result<(), PayloadError> {
        let with_header = length + field::CHECKSUM.end;
        self.buffer.resize(with_header)
    }

    /// Informs `reframe` of the 8-byte UDP header via `Reframe::within_header`, then
    /// forwards it to the underlying buffer.
    fn reframe(&mut self, mut reframe: Reframe) -> core::result::Result<(), PayloadError> {
        let header_len = field::CHECKSUM.end;
        reframe.within_header(header_len);
        self.buffer.reframe(reframe)
    }
}
impl<T: AsRef<[u8]>> AsRef<[u8]> for Packet<T> {
    /// Forwards to the wrapped buffer's own `AsRef<[u8]>`. This may span more than
    /// the UDP view exposed through `Deref`.
    fn as_ref(&self) -> &[u8] {
        <T as AsRef<[u8]>>::as_ref(&self.buffer)
    }
}
/// Decoded UDP header fields.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct Repr {
    /// Source port.
    pub src_port: u16,
    /// Destination port; `Repr::parse` rejects zero.
    pub dst_port: u16,
    /// Value of the length field: header plus payload, in bytes.
    pub length: u16,
}
/// How the UDP checksum is treated when parsing or emitting a datagram.
///
/// `Debug` is derived so the type can appear in diagnostics and assertions.
/// `Clone`/`Copy` are derived because every variant holds only `ip::Address`
/// values, which are already copied freely by `Repr::parse`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Checksum {
    /// Verify on parse and compute on emit, using these pseudo-header addresses.
    ///
    /// On parse, an all-zero checksum field is accepted when both addresses are IPv4.
    Manual {
        src_addr: ip::Address,
        dst_addr: ip::Address,
    },
    /// Not verified by `Repr::parse`, and written as zero by `Repr::emit`.
    /// `Packet::fill_checksum` computes it later, unless both addresses are IPv4.
    Lazy {
        src_addr: ip::Address,
        dst_addr: ip::Address,
    },
    /// Neither verify nor compute the checksum.
    Ignored,
}
impl Repr {
pub fn parse(packet: &udp, checksum: Checksum) -> Result<Repr> {
packet.check_len()?;
if packet.dst_port() == 0 { return Err(Error::Malformed) }
if let Checksum::Manual { src_addr, dst_addr } = checksum {
match (src_addr, dst_addr) {
(ip::Address::Ipv4(_), ip::Address::Ipv4(_)) if packet.checksum() == 0 => { }
_ if !packet.verify_checksum(src_addr, dst_addr) => return Err(Error::WrongChecksum),
_ => (),
}
}
Ok(Repr {
src_port: packet.src_port(),
dst_port: packet.dst_port(),
length: packet.len(),
})
}
pub fn buffer_len(&self) -> usize {
self.length.into()
}
pub fn emit(&self, packet: &mut udp, checksum: Checksum) {
packet.set_src_port(self.src_port);
packet.set_dst_port(self.dst_port);
packet.set_len(self.length);
if let Checksum::Manual { src_addr, dst_addr, } = checksum {
packet.fill_checksum(src_addr, dst_addr)
} else {
packet.set_checksum(0);
}
}
}
impl Checksum {
    /// Builds `Checksum::Manual` from anything convertible into IP addresses.
    pub fn for_pseudo_header<A, B>(src_addr: A, dst_addr: B) -> Self
    where
        A: Into<ip::Address>,
        B: Into<ip::Address>,
    {
        let (src_addr, dst_addr) = (src_addr.into(), dst_addr.into());
        Checksum::Manual { src_addr, dst_addr }
    }
}
impl<T: Payload> fmt::Display for Packet<T> {
    /// Formats the cached header exactly as `Repr`'s `Display` does.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        fmt::Display::fmt(&self.repr, f)
    }
}
impl fmt::Display for Repr {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let payload_len = usize::from(self.length)
.checked_sub(field::CHECKSUM.end);
if let Some(payload_len) = payload_len {
write!(f, "UDP src={} dst={} len={}",
self.src_port, self.dst_port, payload_len)
} else {
write!(f, "UDP src={} dst={} len=??",
self.src_port, self.dst_port)
}
}
}
impl PrettyPrint for udp {
    /// Prints one indented line: the parsed header, or the parse error in
    /// parentheses. The checksum is not verified.
    fn pretty_print(buffer: &[u8], f: &mut fmt::Formatter,
                    indent: &mut PrettyIndent) -> fmt::Result {
        let parsed = Packet::new_checked(buffer, Checksum::Ignored);
        match parsed {
            Ok(packet) => write!(f, "{}{}", indent, packet),
            Err(err) => write!(f, "{}({})", indent, err),
        }
    }
}
#[cfg(test)]
mod test {
    use crate::wire::ip::v4::Address as Ipv4Address;
    use super::*;

    // Pseudo-header addresses that the checksum in PACKET_BYTES is valid for.
    const SRC_ADDR: Ipv4Address = Ipv4Address([192, 168, 1, 1]);
    const DST_ADDR: Ipv4Address = Ipv4Address([192, 168, 1, 2]);

    // Port 48896 -> 53, length 12, checksum 0x124d, four payload bytes.
    static PACKET_BYTES: [u8; 12] = [
        0xbf, 0x00, 0x00, 0x35, // src port, dst port
        0x00, 0x0c, 0x12, 0x4d, // length, checksum
        0xaa, 0x00, 0x00, 0xff, // payload
    ];

    static PAYLOAD_BYTES: [u8; 4] = [0xaa, 0x00, 0x00, 0xff];

    /// The header fields encoded in `PACKET_BYTES`.
    fn packet_repr() -> Repr {
        Repr {
            src_port: 48896,
            dst_port: 53,
            length: PACKET_BYTES.len() as u16,
        }
    }

    #[test]
    fn test_deconstruct() {
        let datagram = udp::new_unchecked(&PACKET_BYTES[..]);
        assert_eq!(datagram.src_port(), 48896);
        assert_eq!(datagram.dst_port(), 53);
        assert_eq!(datagram.len(), 12);
        assert_eq!(datagram.checksum(), 0x124d);
        assert_eq!(datagram.payload_slice(), &PAYLOAD_BYTES[..]);
        assert!(datagram.verify_checksum(SRC_ADDR.into(), DST_ADDR.into()));
    }

    #[test]
    fn test_construct() {
        let mut storage = [0xa5u8; 12];
        let datagram = udp::new_unchecked_mut(&mut storage[..]);
        datagram.set_src_port(48896);
        datagram.set_dst_port(53);
        datagram.set_len(12);
        // Deliberately non-zero: fill_checksum must clear the field before summing.
        datagram.set_checksum(0xffff);
        datagram.payload_mut_slice().copy_from_slice(&PAYLOAD_BYTES[..]);
        datagram.fill_checksum(SRC_ADDR.into(), DST_ADDR.into());
        assert_eq!(datagram.as_bytes(), &PACKET_BYTES[..]);
    }

    #[test]
    fn test_impossible_len() {
        let mut storage = [0u8; 12];
        let datagram = udp::new_unchecked_mut(&mut storage[..]);
        // A length below the 8-byte header can never describe a valid datagram.
        datagram.set_len(4);
        assert_eq!(datagram.check_len(), Err(Error::Malformed));
    }

    #[test]
    fn test_zero_checksum() {
        let mut storage = [0u8; 8];
        let datagram = udp::new_unchecked_mut(&mut storage[..]);
        // These ports make the computed checksum zero, which must be sent as 0xffff.
        datagram.set_src_port(1);
        datagram.set_dst_port(31881);
        datagram.set_len(8);
        datagram.fill_checksum(SRC_ADDR.into(), DST_ADDR.into());
        assert_eq!(datagram.checksum(), 0xffff);
    }

    #[test]
    fn test_parse() {
        let datagram = udp::new_unchecked(&PACKET_BYTES[..]);
        let parsed = Repr::parse(
            datagram,
            Checksum::for_pseudo_header(SRC_ADDR, DST_ADDR),
        ).unwrap();
        assert_eq!(parsed, packet_repr());
        assert_eq!(datagram.payload_slice(), &PAYLOAD_BYTES[..]);
    }

    #[test]
    fn test_emit() {
        let expected = packet_repr();
        let mut storage = vec![0xa5; expected.buffer_len()];
        let datagram = udp::new_unchecked_mut(&mut storage);
        // First pass lays out the header so the payload range is known.
        expected.emit(datagram, Checksum::Ignored);
        datagram.payload_mut_slice().copy_from_slice(&PAYLOAD_BYTES[..]);
        // Second pass computes the checksum over the now-present payload.
        expected.emit(datagram,
            Checksum::for_pseudo_header(SRC_ADDR, DST_ADDR));
        assert_eq!(datagram.as_bytes(), &PACKET_BYTES[..]);
        assert_eq!(datagram.payload_slice(), &PAYLOAD_BYTES[..]);
    }
}