mod destination_unreachable;
mod echo_reply;
mod echo_request;
pub mod ndp;
mod time_exceeded;
mod too_big;
pub use self::destination_unreachable::*;
pub use self::echo_reply::*;
pub use self::echo_request::*;
pub use self::time_exceeded::*;
pub use self::too_big::*;
pub use capsule_macros::Icmpv6Packet;
use crate::packets::ip::v6::Ipv6Packet;
use crate::packets::ip::ProtocolNumbers;
use crate::packets::types::u16be;
use crate::packets::{checksum, Internal, Packet};
use crate::{ensure, SizeOf};
use anyhow::{anyhow, Result};
use std::fmt;
use std::ptr::NonNull;
/// Generic ICMPv6 packet layered on top of an IPv6 envelope.
///
/// Only the fields shared by every ICMPv6 message (type, code and checksum)
/// are exposed here; use [`Icmpv6::downcast`] to obtain a concrete
/// [`Icmpv6Message`].
pub struct Icmpv6<E: Ipv6Packet> {
    /// The enclosing IPv6 packet.
    envelope: E,
    /// Byte offset of the ICMPv6 header within the mbuf (the envelope's payload offset).
    offset: usize,
    /// Pointer to the ICMPv6 header as read from the mbuf at `offset`.
    header: NonNull<Icmpv6Header>,
}
impl<E: Ipv6Packet> Icmpv6<E> {
    /// Shared view of the fixed ICMPv6 header.
    #[inline]
    fn header(&self) -> &Icmpv6Header {
        // SAFETY: `header` comes from `Mbuf::read_data` in `try_parse` (or is copied by
        // `clone`) and is assumed to stay valid while `self.envelope` owns the mbuf.
        unsafe { self.header.as_ref() }
    }

    /// Exclusive view of the fixed ICMPv6 header.
    #[inline]
    fn header_mut(&mut self) -> &mut Icmpv6Header {
        // SAFETY: same pointer-validity invariant as `header`; `&mut self` ensures
        // no other reference is handed out through this wrapper at the same time.
        unsafe { self.header.as_mut() }
    }

    /// Returns the message type stored in the header.
    #[inline]
    pub fn msg_type(&self) -> Icmpv6Type {
        let raw = self.header().msg_type;
        Icmpv6Type::new(raw)
    }

    /// Returns the message code stored in the header.
    #[inline]
    pub fn code(&self) -> u8 {
        let header = self.header();
        header.code
    }

    /// Overwrites the message code in the header.
    #[inline]
    pub fn set_code(&mut self, code: u8) {
        let header = self.header_mut();
        header.code = code;
    }

    /// Returns the checksum, converted from the header's `u16be` field.
    #[inline]
    pub fn checksum(&self) -> u16 {
        let raw = self.header().checksum;
        raw.into()
    }

    /// Recomputes the checksum over the IPv6 pseudo-header and the whole ICMPv6
    /// message (from `offset()` for `len()` bytes) and stores it in the header.
    ///
    /// # Panics
    ///
    /// Panics via `unreachable!` if the packet's own byte range cannot be read back.
    #[inline]
    pub fn compute_checksum(&mut self) {
        // The checksum field must be zero while the sum is taken over the message.
        self.header_mut().checksum = u16be::default();

        let slice = match self.mbuf().read_data_slice(self.offset(), self.len()) {
            Ok(slice) => slice,
            // The range is exactly this packet's own bytes.
            Err(_) => unreachable!(),
        };
        let bytes = unsafe { slice.as_ref() };

        let length = bytes.len() as u16;
        let pseudo_sum = self
            .envelope()
            .pseudo_header(length, ProtocolNumbers::Icmpv6)
            .sum();
        self.header_mut().checksum = checksum::compute(pseudo_sum, bytes).into();
    }

    /// Converts this generic packet into the concrete message type `T`.
    ///
    /// # Errors
    ///
    /// Returns an error if the header's type differs from `T::msg_type()`, or
    /// whatever `T::try_parse` returns.
    #[inline]
    pub fn downcast<T: Icmpv6Message<Envelope = E>>(self) -> Result<T> {
        let expected = T::msg_type();
        ensure!(
            self.msg_type() == expected,
            anyhow!("the ICMPv6 packet is not {}.", expected)
        );
        T::try_parse(self, Internal(()))
    }
}
impl<E: Ipv6Packet> fmt::Debug for Icmpv6<E> {
    /// Prints the header fields, followed by `$`-prefixed buffer bookkeeping values.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let msg_type = self.msg_type().to_string();
        let checksum = format!("0x{:04x}", self.checksum());

        let mut out = f.debug_struct("icmpv6");
        out.field("type", &msg_type);
        out.field("code", &self.code());
        out.field("checksum", &checksum);
        out.field("$offset", &self.offset());
        out.field("$len", &self.len());
        out.field("$header_len", &self.header_len());
        out.finish()
    }
}
impl<E: Ipv6Packet> Packet for Icmpv6<E> {
    /// The enclosing IPv6 packet type.
    type Envelope = E;

    #[inline]
    fn envelope(&self) -> &Self::Envelope {
        &self.envelope
    }

    #[inline]
    fn envelope_mut(&mut self) -> &mut Self::Envelope {
        &mut self.envelope
    }

    /// Byte offset of the ICMPv6 header within the mbuf.
    #[inline]
    fn offset(&self) -> usize {
        self.offset
    }

    /// Length of the fixed type/code/checksum header only.
    #[inline]
    fn header_len(&self) -> usize {
        Icmpv6Header::size_of()
    }

    /// Clones the envelope and copies the raw header pointer and offset as-is.
    #[inline]
    unsafe fn clone(&self, internal: Internal) -> Self {
        Self {
            envelope: self.envelope.clone(internal),
            offset: self.offset,
            header: self.header,
        }
    }

    /// Parses the ICMPv6 header located at the envelope's payload offset.
    ///
    /// # Errors
    ///
    /// Fails if the envelope's next protocol is not ICMPv6, or if the header
    /// cannot be read from the mbuf.
    #[inline]
    fn try_parse(envelope: Self::Envelope, _internal: Internal) -> Result<Self> {
        ensure!(
            envelope.next_protocol() == ProtocolNumbers::Icmpv6,
            anyhow!("not an ICMPv6 packet.")
        );

        let offset = envelope.payload_offset();
        let header = envelope.mbuf().read_data(offset)?;

        Ok(Self {
            envelope,
            header,
            offset,
        })
    }

    /// Always fails: a bare ICMPv6 header is meaningless without a concrete message body.
    #[inline]
    fn try_push(_envelope: Self::Envelope, _internal: Internal) -> Result<Self> {
        Err(anyhow!(
            "cannot push a generic ICMPv6 header without a message body."
        ))
    }

    #[inline]
    fn deparse(self) -> Self::Envelope {
        self.envelope
    }

    /// Refreshes derived fields; for ICMPv6 that is the checksum.
    #[inline]
    fn reconcile(&mut self) {
        self.compute_checksum();
    }
}
/// ICMPv6 message type: the raw value of the header's first byte.
///
/// Well-known values are available as constants in [`Icmpv6Types`]. Any other
/// `u8` can still be represented, and it displays as its numeric value.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
#[repr(C, packed)]
pub struct Icmpv6Type(pub u8);

impl Icmpv6Type {
    /// Wraps a raw ICMPv6 type byte.
    pub const fn new(value: u8) -> Self {
        Icmpv6Type(value)
    }
}

/// Supported ICMPv6 message types, listed in ascending type-value order.
// The module and its constants are deliberately named like enum variants so call
// sites read `Icmpv6Types::EchoRequest`; hence the naming-lint allowances.
#[allow(non_snake_case)]
#[allow(non_upper_case_globals)]
pub mod Icmpv6Types {
    use super::Icmpv6Type;

    /// Destination Unreachable (type 1).
    pub const DestinationUnreachable: Icmpv6Type = Icmpv6Type(1);
    /// Packet Too Big (type 2).
    pub const PacketTooBig: Icmpv6Type = Icmpv6Type(2);
    /// Time Exceeded (type 3).
    pub const TimeExceeded: Icmpv6Type = Icmpv6Type(3);
    /// Echo Request (type 128).
    pub const EchoRequest: Icmpv6Type = Icmpv6Type(128);
    /// Echo Reply (type 129).
    pub const EchoReply: Icmpv6Type = Icmpv6Type(129);
    /// Router Solicitation (type 133).
    pub const RouterSolicitation: Icmpv6Type = Icmpv6Type(133);
    /// Router Advertisement (type 134).
    pub const RouterAdvertisement: Icmpv6Type = Icmpv6Type(134);
    /// Neighbor Solicitation (type 135).
    pub const NeighborSolicitation: Icmpv6Type = Icmpv6Type(135);
    /// Neighbor Advertisement (type 136).
    pub const NeighborAdvertisement: Icmpv6Type = Icmpv6Type(136);
    /// Redirect (type 137).
    pub const Redirect: Icmpv6Type = Icmpv6Type(137);
}

impl fmt::Display for Icmpv6Type {
    /// Writes the name of every type defined in [`Icmpv6Types`], or the raw
    /// numeric value for any other type.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match *self {
            // Previously missing: type 1 was printed as "1" despite having a constant.
            Icmpv6Types::DestinationUnreachable => "Destination Unreachable",
            Icmpv6Types::PacketTooBig => "Packet Too Big",
            Icmpv6Types::TimeExceeded => "Time Exceeded",
            Icmpv6Types::EchoRequest => "Echo Request",
            Icmpv6Types::EchoReply => "Echo Reply",
            Icmpv6Types::RouterSolicitation => "Router Solicitation",
            Icmpv6Types::RouterAdvertisement => "Router Advertisement",
            Icmpv6Types::NeighborSolicitation => "Neighbor Solicitation",
            Icmpv6Types::NeighborAdvertisement => "Neighbor Advertisement",
            Icmpv6Types::Redirect => "Redirect",
            _ => {
                // Unknown type: copy the byte out of the packed struct and print it.
                let value = self.0;
                return write!(f, "{}", value);
            }
        };
        f.write_str(name)
    }
}
/// The fixed ICMPv6 header shared by every message: type, code and checksum.
///
/// `#[repr(C, packed)]` keeps the fields in declaration order without padding,
/// because the struct is read in place from the mbuf (see `Mbuf::read_data` in
/// `Icmpv6::try_parse`). Do not reorder or resize these fields.
#[doc(hidden)]
#[derive(Clone, Copy, Debug, Default, SizeOf)]
#[repr(C, packed)]
pub struct Icmpv6Header {
    /// Raw message type (wrapped as `Icmpv6Type` by `Icmpv6::msg_type`).
    msg_type: u8,
    /// Type-specific message code.
    code: u8,
    /// Checksum over the IPv6 pseudo-header and the whole ICMPv6 message.
    checksum: u16be,
}
/// A concrete ICMPv6 message built on top of the generic [`Icmpv6`] packet.
///
/// [`Icmpv6::downcast`] checks the header's type against
/// [`Icmpv6Message::msg_type`] before delegating to `try_parse`.
pub trait Icmpv6Message {
    /// The IPv6 packet type that carries this message.
    type Envelope: Ipv6Packet;

    /// The ICMPv6 type value that identifies this message.
    fn msg_type() -> Icmpv6Type;

    /// Shared access to the underlying generic ICMPv6 packet.
    fn icmp(&self) -> &Icmpv6<Self::Envelope>;

    /// Exclusive access to the underlying generic ICMPv6 packet.
    fn icmp_mut(&mut self) -> &mut Icmpv6<Self::Envelope>;

    /// Consumes the message and returns the underlying generic ICMPv6 packet.
    fn into_icmp(self) -> Icmpv6<Self::Envelope>;

    /// Duplicates the message.
    ///
    /// # Safety
    ///
    /// NOTE(review): the contract is not stated in this file. `Packet::clone` for
    /// `Icmpv6` copies raw header pointers, so a clone may alias the original's
    /// buffer. Confirm the caller requirements before use.
    unsafe fn clone(&self, internal: Internal) -> Self;

    /// Builds the message from an already-parsed generic ICMPv6 packet.
    fn try_parse(icmp: Icmpv6<Self::Envelope>, internal: Internal) -> Result<Self>
    where
        Self: Sized;

    /// Builds a new message on top of a generic ICMPv6 packet.
    ///
    /// NOTE(review): presumably writes this message's fields into the buffer;
    /// see the implementors.
    fn try_push(icmp: Icmpv6<Self::Envelope>, internal: Internal) -> Result<Self>
    where
        Self: Sized;

    /// Finalizes the message. By default this only recomputes the ICMPv6 checksum.
    #[inline]
    fn reconcile(&mut self) {
        self.icmp_mut().compute_checksum();
    }
}
/// Accessors for the common ICMPv6 header fields (type, code and checksum).
///
/// NOTE(review): a derive macro of the same name is re-exported from
/// `capsule_macros` and presumably implements this trait; confirm.
pub trait Icmpv6Packet {
    /// Returns the message type.
    fn msg_type(&self) -> Icmpv6Type;

    /// Returns the message code.
    fn code(&self) -> u8;

    /// Overwrites the message code.
    fn set_code(&mut self, code: u8);

    /// Returns the checksum.
    fn checksum(&self) -> u16;
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::packets::icmp::v6::ndp::RouterAdvertisement;
    use crate::packets::ip::v6::Ipv6;
    use crate::packets::Ethernet;
    use crate::testils::byte_arrays::{ICMPV6_PACKET, IPV6_TCP_PACKET, ROUTER_ADVERT_PACKET};
    use crate::Mbuf;

    /// Parses raw frame bytes down to the IPv6 layer, panicking on any failure.
    fn ipv6_from(bytes: &[u8]) -> Ipv6 {
        let ethernet = Mbuf::from_bytes(bytes)
            .unwrap()
            .parse::<Ethernet>()
            .unwrap();
        ethernet.parse::<Ipv6>().unwrap()
    }

    #[test]
    fn size_of_icmpv6_header() {
        assert_eq!(Icmpv6Header::size_of(), 4);
    }

    #[capsule::test]
    fn parse_icmpv6_packet() {
        let icmpv6 = ipv6_from(&ROUTER_ADVERT_PACKET)
            .parse::<Icmpv6<Ipv6>>()
            .unwrap();

        assert_eq!(icmpv6.msg_type(), Icmpv6Types::RouterAdvertisement);
        assert_eq!(icmpv6.code(), 0);
        assert_eq!(icmpv6.checksum(), 0xf50c);

        let advert = icmpv6.downcast::<RouterAdvertisement<Ipv6>>().unwrap();

        assert_eq!(advert.msg_type(), Icmpv6Types::RouterAdvertisement);
        assert_eq!(advert.code(), 0);
        assert_eq!(advert.checksum(), 0xf50c);
        assert_eq!(advert.current_hop_limit(), 64);
        assert!(!advert.managed_addr_cfg());
        assert!(advert.other_cfg());
        assert_eq!(advert.router_lifetime(), 3600);
        assert_eq!(advert.reachable_time(), 0);
        assert_eq!(advert.retrans_timer(), 0);

        // Round-trip: deparse back to IPv6 and parse the message directly.
        let ipv6 = advert.deparse();
        assert!(ipv6.parse::<RouterAdvertisement<Ipv6>>().is_ok());
    }

    #[capsule::test]
    fn parse_wrong_icmpv6_type() {
        let icmpv6 = ipv6_from(&ICMPV6_PACKET)
            .parse::<Icmpv6<Ipv6>>()
            .unwrap();
        assert!(icmpv6.downcast::<EchoReply<Ipv6>>().is_err());
    }

    #[capsule::test]
    fn parse_non_icmpv6_packet() {
        let ipv6 = ipv6_from(&IPV6_TCP_PACKET);
        assert!(ipv6.parse::<Icmpv6<Ipv6>>().is_err());
    }

    #[capsule::test]
    fn compute_checksum() {
        let mut icmpv6 = ipv6_from(&ROUTER_ADVERT_PACKET)
            .parse::<Icmpv6<Ipv6>>()
            .unwrap();
        let expected = icmpv6.checksum();
        icmpv6.reconcile_all();
        assert_eq!(icmpv6.checksum(), expected);
    }

    #[capsule::test]
    fn push_icmpv6_header_without_body() {
        let ipv6 = Mbuf::new()
            .unwrap()
            .push::<Ethernet>()
            .unwrap()
            .push::<Ipv6>()
            .unwrap();
        assert!(ipv6.push::<Icmpv6<Ipv6>>().is_err());
    }
}