use crate::packets::checksum::PseudoHeader;
use crate::packets::ip::v6::Ipv6Packet;
use crate::packets::ip::{IpPacket, ProtocolNumber, ProtocolNumbers};
use crate::packets::types::{u16be, u32be};
use crate::packets::{Internal, Packet};
use crate::{ensure, SizeOf};
use anyhow::{anyhow, Result};
use std::fmt;
use std::net::IpAddr;
use std::ptr::NonNull;
/// Mask over `frag_res_m` selecting the 13-bit fragment offset (upper bits),
/// stored in network byte order.
const FRAG_OS: u16be = u16be(0xfff8_u16.to_be());
/// Mask over `frag_res_m` selecting the M ("more fragments") flag (lowest bit),
/// stored in network byte order.
const FLAG_MORE: u16be = u16be(0x0001_u16.to_be());
/// IPv6 fragment extension header layered on top of an IPv6 envelope `E`.
///
/// The header bytes live in the envelope's mbuf; this type only records where
/// they start and keeps a raw pointer to them.
pub struct Fragment<E: Ipv6Packet> {
    /// The packet this fragment header is encapsulated in.
    envelope: E,
    /// Byte offset of the fragment header within the mbuf.
    offset: usize,
    /// Pointer to the fragment header bytes inside the mbuf.
    header: NonNull<FragmentHeader>,
}
impl<E: Ipv6Packet> Fragment<E> {
    /// Returns a shared view of the fragment header bytes in the mbuf.
    #[inline]
    fn header(&self) -> &FragmentHeader {
        // SAFETY: `self.header` was produced by `Mbuf::read_data` /
        // `Mbuf::write_data` at `self.offset` (see `try_parse` / `try_push`).
        // NOTE(review): soundness relies on the mbuf keeping that region valid
        // for as long as `self` holds the envelope — confirm against `Mbuf`.
        unsafe { self.header.as_ref() }
    }

    /// Returns a mutable view of the fragment header bytes in the mbuf.
    #[inline]
    fn header_mut(&mut self) -> &mut FragmentHeader {
        // SAFETY: same provenance as in `header`. `&mut self` rules out other
        // borrows through this packet; packets duplicated with the unsafe
        // `Packet::clone` share the pointer, which is that caller's contract.
        unsafe { self.header.as_mut() }
    }

    /// Returns the fragment offset field as carried on the wire, i.e. in
    /// 8-octet units rather than bytes.
    pub fn fragment_offset(&self) -> u16 {
        let field: u16 = (self.header().frag_res_m & FRAG_OS).into();
        field >> 3
    }

    /// Sets the fragment offset field, in 8-octet units.
    ///
    /// The field is only 13 bits wide: any higher bits of `offset` are shifted
    /// out and silently dropped. The reserved bits and the M flag are kept.
    pub fn set_fragment_offset(&mut self, offset: u16) {
        let header = self.header_mut();
        // Clear the old offset bits, then splice the new value into position.
        header.frag_res_m = (header.frag_res_m & !FRAG_OS) | u16be::from(offset << 3);
    }

    /// Returns `true` if the M ("more fragments") flag is set.
    pub fn more_fragments(&self) -> bool {
        // A bit test, so compare against zero for inequality, not ordering.
        (self.header().frag_res_m & FLAG_MORE) != u16be::MIN
    }

    /// Sets the M ("more fragments") flag, leaving the offset untouched.
    pub fn set_more_fragments(&mut self) {
        self.header_mut().frag_res_m |= FLAG_MORE
    }

    /// Clears the M ("more fragments") flag, leaving the offset untouched.
    pub fn unset_more_fragments(&mut self) {
        self.header_mut().frag_res_m &= !FLAG_MORE
    }

    /// Returns the identification value, converted to host byte order.
    pub fn identification(&self) -> u32 {
        self.header().identification.into()
    }

    /// Sets the identification value (given in host byte order).
    pub fn set_identification(&mut self, identification: u32) {
        self.header_mut().identification = identification.into()
    }
}
impl<E: Ipv6Packet> fmt::Debug for Fragment<E> {
    /// Renders the fragment header fields for diagnostics. The next header is
    /// shown via its `Display` form and the identification as lowercase hex.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let next_header = self.next_header().to_string();
        let identification = format!("{:x}", self.identification());

        let mut out = f.debug_struct("fragment");
        out.field("next_header", &next_header);
        out.field("fragment_offset", &self.fragment_offset());
        out.field("more_fragments", &self.more_fragments());
        out.field("identification", &identification);
        out.finish()
    }
}
impl<E: Ipv6Packet> Packet for Fragment<E> {
    /// The packet this fragment header is layered on.
    type Envelope = E;

    #[inline]
    fn envelope(&self) -> &Self::Envelope {
        &self.envelope
    }

    #[inline]
    fn envelope_mut(&mut self) -> &mut Self::Envelope {
        &mut self.envelope
    }

    /// Byte offset of the fragment header within the mbuf.
    #[inline]
    fn offset(&self) -> usize {
        self.offset
    }

    /// The fragment header has a fixed size.
    #[inline]
    fn header_len(&self) -> usize {
        FragmentHeader::size_of()
    }

    /// Duplicates the packet; the copy shares the same header pointer.
    #[inline]
    unsafe fn clone(&self, internal: Internal) -> Self {
        Self {
            envelope: self.envelope.clone(internal),
            header: self.header,
            offset: self.offset,
        }
    }

    /// Parses a fragment header located at the envelope's payload offset.
    ///
    /// Fails if the envelope's next header is not `Ipv6Frag`, or if the mbuf
    /// read fails.
    #[inline]
    fn try_parse(envelope: Self::Envelope, _internal: Internal) -> Result<Self> {
        ensure!(
            envelope.next_header() == ProtocolNumbers::Ipv6Frag,
            anyhow!("not an IPv6 fragment packet.")
        );

        let offset = envelope.payload_offset();
        let header = envelope.mbuf().read_data(offset)?;

        Ok(Fragment { envelope, header, offset })
    }

    /// Inserts a default fragment header at the envelope's payload offset.
    ///
    /// The new header takes over the envelope's previous next header, and the
    /// envelope is re-pointed at `Ipv6Frag`.
    #[inline]
    fn try_push(mut envelope: Self::Envelope, _internal: Internal) -> Result<Self> {
        let offset = envelope.payload_offset();
        let header = {
            let mbuf = envelope.mbuf_mut();
            mbuf.extend(offset, FragmentHeader::size_of())?;
            mbuf.write_data(offset, &FragmentHeader::default())?
        };

        let mut packet = Fragment { envelope, header, offset };

        let inner = packet.envelope().next_header();
        packet.set_next_header(inner);
        packet.envelope_mut().set_next_header(ProtocolNumbers::Ipv6Frag);

        Ok(packet)
    }

    /// Removes the fragment header from the mbuf and returns the envelope.
    ///
    /// The envelope inherits this header's next header value.
    #[inline]
    fn remove(mut self) -> Result<Self::Envelope> {
        // Read the next header before shrinking; the header bytes are removed
        // by the shrink.
        let next_header = self.next_header();
        let (offset, len) = (self.offset(), self.header_len());
        self.mbuf_mut().shrink(offset, len)?;

        let mut envelope = self.envelope;
        envelope.set_next_header(next_header);
        Ok(envelope)
    }

    /// Drops this layer's view and hands back the envelope unchanged.
    #[inline]
    fn deparse(self) -> Self::Envelope {
        self.envelope
    }
}
impl<E: Ipv6Packet> IpPacket for Fragment<E> {
#[inline]
fn next_protocol(&self) -> ProtocolNumber {
self.next_header()
}
#[inline]
fn set_next_protocol(&mut self, proto: ProtocolNumber) {
self.set_next_header(proto);
}
#[inline]
fn src(&self) -> IpAddr {
self.envelope().src()
}
#[inline]
fn set_src(&mut self, src: IpAddr) -> Result<()> {
self.envelope_mut().set_src(src)
}
#[inline]
fn dst(&self) -> IpAddr {
self.envelope().dst()
}
#[inline]
fn set_dst(&mut self, dst: IpAddr) -> Result<()> {
self.envelope_mut().set_dst(dst)
}
#[inline]
fn pseudo_header(&self, packet_len: u16, protocol: ProtocolNumber) -> PseudoHeader {
self.envelope().pseudo_header(packet_len, protocol)
}
#[inline]
fn truncate(&mut self, mtu: usize) -> Result<()> {
self.envelope_mut().truncate(mtu)
}
}
impl<E: Ipv6Packet> Ipv6Packet for Fragment<E> {
    /// Reads the fragment header's next header byte as a protocol number.
    #[inline]
    fn next_header(&self) -> ProtocolNumber {
        let raw = self.header().next_header;
        ProtocolNumber::new(raw)
    }

    /// Overwrites the fragment header's next header byte.
    #[inline]
    fn set_next_header(&mut self, next_header: ProtocolNumber) {
        let raw = next_header.0;
        self.header_mut().next_header = raw;
    }
}
/// On-the-wire layout of the 8-byte IPv6 fragment extension header
/// (RFC 8200, section 4.5). Field order is load-bearing because of
/// `repr(C, packed)`.
#[derive(Clone, Copy, Debug, Default, SizeOf)]
#[repr(C, packed)]
struct FragmentHeader {
    /// Protocol number of the header that follows this one.
    next_header: u8,
    /// Reserved byte; left at its default (zero) when a header is pushed.
    reserved: u8,
    /// 13-bit fragment offset, 2 reserved bits and the M flag packed together
    /// in network byte order (see `FRAG_OS` and `FLAG_MORE`).
    frag_res_m: u16be,
    /// Fragment identification, in network byte order.
    identification: u32be,
}
// Unit tests for the fragment header: layout size, parsing from fixtures,
// pushing a fresh header, inserting into an existing packet, and removal.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::packets::ip::v6::Ipv6;
    use crate::packets::Ethernet;
    use crate::testils::byte_arrays::{IPV6_FRAGMENT_PACKET, IPV6_TCP_PACKET};
    use crate::Mbuf;

    // The packed header struct must match the 8-byte wire size.
    #[test]
    fn size_of_fragment_header() {
        assert_eq!(8, FragmentHeader::size_of());
    }

    // Field values expected here come from the IPV6_FRAGMENT_PACKET fixture.
    #[capsule::test]
    fn parse_fragment_packet() {
        let packet = Mbuf::from_bytes(&IPV6_FRAGMENT_PACKET).unwrap();
        let ethernet = packet.parse::<Ethernet>().unwrap();
        let ipv6 = ethernet.parse::<Ipv6>().unwrap();
        let frag = ipv6.parse::<Fragment<Ipv6>>().unwrap();
        assert_eq!(ProtocolNumbers::Udp, frag.next_header());
        assert_eq!(543, frag.fragment_offset());
        assert!(!frag.more_fragments());
        assert_eq!(0xf88e_b466, frag.identification());
    }

    // Parsing must fail when the IPv6 next header is not Ipv6Frag.
    #[capsule::test]
    fn parse_non_fragment_packet() {
        let packet = Mbuf::from_bytes(&IPV6_TCP_PACKET).unwrap();
        let ethernet = packet.parse::<Ethernet>().unwrap();
        let ipv6 = ethernet.parse::<Ipv6>().unwrap();
        assert!(ipv6.parse::<Fragment<Ipv6>>().is_err());
    }

    // Pushing onto an empty packet, then exercising every setter; setting the
    // M flag must not disturb the offset.
    #[capsule::test]
    fn push_and_set_fragment_packet() {
        let packet = Mbuf::new().unwrap();
        let ethernet = packet.push::<Ethernet>().unwrap();
        let ipv6 = ethernet.push::<Ipv6>().unwrap();
        let mut frag = ipv6.push::<Fragment<Ipv6>>().unwrap();
        assert_eq!(FragmentHeader::size_of(), frag.len());
        assert_eq!(ProtocolNumbers::Ipv6Frag, frag.envelope().next_header());
        frag.set_fragment_offset(100);
        assert_eq!(100, frag.fragment_offset());
        assert!(!frag.more_fragments());
        frag.set_more_fragments();
        assert_eq!(100, frag.fragment_offset());
        assert!(frag.more_fragments());
        frag.unset_more_fragments();
        assert!(!frag.more_fragments());
        frag.set_identification(0xabcd_1234);
        assert_eq!(0xabcd_1234, frag.identification());
    }

    // Inserting into an existing packet: the new header inherits the old next
    // header and the payload length seen past it is unchanged.
    #[capsule::test]
    fn insert_fragment_packet() {
        let packet = Mbuf::from_bytes(&IPV6_TCP_PACKET).unwrap();
        let ethernet = packet.parse::<Ethernet>().unwrap();
        let ipv6 = ethernet.parse::<Ipv6>().unwrap();
        let next_header = ipv6.next_header();
        let payload_len = ipv6.payload_len();
        let frag = ipv6.push::<Fragment<Ipv6>>().unwrap();
        assert_eq!(ProtocolNumbers::Ipv6Frag, frag.envelope().next_header());
        assert_eq!(next_header, frag.next_header());
        assert_eq!(payload_len, frag.payload_len());
    }

    // Removing the header hands its next header back to the envelope and
    // leaves the envelope's payload equal to what followed the fragment header.
    #[capsule::test]
    fn remove_fragment_packet() {
        let packet = Mbuf::from_bytes(&IPV6_FRAGMENT_PACKET).unwrap();
        let ethernet = packet.parse::<Ethernet>().unwrap();
        let ipv6 = ethernet.parse::<Ipv6>().unwrap();
        let frag = ipv6.parse::<Fragment<Ipv6>>().unwrap();
        let next_header = frag.next_header();
        let payload_len = frag.payload_len();
        let ipv6 = frag.remove().unwrap();
        assert_eq!(next_header, ipv6.next_header());
        assert_eq!(payload_len, ipv6.payload_len());
    }
}