use super::super::*;
use core::net::Ipv6Addr;
/// Size in bytes of the multicast address that forms an MLDv1 message body.
const MLD_MULTICAST_ADDRESS_LEN: usize = 16;

/// Body of an MLDv1 Query, Report or Done message: a single IPv6 multicast address.
///
/// The ICMPv6 header in front of it (type, code, checksum and the 4-byte
/// delay/reserved word) is produced separately by the `Icmpv6` constructors;
/// this layer carries only the address bytes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MulticastListenerMessage {
    /// Multicast group the message refers to; `::` for a general query.
    pub(crate) multicast_address: Ipv6Addr,
}

impl MulticastListenerMessage {
    /// Builds a message body for `multicast_address`.
    pub fn new(multicast_address: Ipv6Addr) -> Self {
        MulticastListenerMessage { multicast_address }
    }

    /// Builder-style setter returning a copy with the multicast address replaced.
    pub fn multicast_address(self, multicast_address: Ipv6Addr) -> Self {
        let mut updated = self;
        updated.multicast_address = multicast_address;
        updated
    }

    /// Returns the multicast address carried by this message.
    pub fn multicast_address_value(&self) -> Ipv6Addr {
        self.multicast_address
    }
}
impl Layer for MulticastListenerMessage {
    fn name(&self) -> &'static str {
        "MulticastListenerMessage"
    }

    fn summary(&self) -> String {
        let group = self.multicast_address;
        format!("MulticastListenerMessage(multicast={group})")
    }

    fn inspection_fields(&self) -> Vec<(&'static str, String)> {
        let address = self.multicast_address.to_string();
        vec![("multicast_address", address)]
    }

    /// The body is always exactly one IPv6 address.
    fn encoded_len(&self) -> usize {
        MLD_MULTICAST_ADDRESS_LEN
    }

    /// Appends the 16 address octets; the layer context is not consulted.
    fn compile(&self, _ctx: &LayerContext<'_>, out: &mut Vec<u8>) -> Result<()> {
        let octets = self.multicast_address.octets();
        out.extend_from_slice(&octets);
        Ok(())
    }

    impl_layer_object!(MulticastListenerMessage);
}

impl_layer_div!(MulticastListenerMessage);
/// Builds the 4-byte ICMPv6 rest-of-header word used by MLD queries and MLDv1
/// messages: the 16-bit delay/response code in big-endian order followed by
/// two zero bytes.
fn mld_rest_of_header(max_response_delay: u16) -> [u8; 4] {
    let [high, low] = max_response_delay.to_be_bytes();
    [high, low, 0, 0]
}
/// MLDv1 message constructors.
impl Icmpv6 {
    /// Builds a Multicast Listener Query for `multicast_address` with the given
    /// Maximum Response Delay.
    pub fn mld_query(multicast_address: Ipv6Addr, max_response_delay: u16) -> Packet {
        let body = MulticastListenerMessage::new(multicast_address);
        Self::mld_message(ICMPV6_MULTICAST_LISTENER_QUERY, max_response_delay, body)
    }

    /// Builds a general query, i.e. a query whose multicast address is `::`.
    pub fn mld_general_query(max_response_delay: u16) -> Packet {
        let unspecified = Ipv6Addr::UNSPECIFIED;
        Self::mld_query(unspecified, max_response_delay)
    }

    /// Builds a Multicast Listener Report for `group`; the delay field is zero.
    pub fn mld_report(group: Ipv6Addr) -> Packet {
        let body = MulticastListenerMessage::new(group);
        Self::mld_message(ICMPV6_MULTICAST_LISTENER_REPORT, 0, body)
    }

    /// Builds a Multicast Listener Done for `group`; the delay field is zero.
    pub fn mld_done(group: Ipv6Addr) -> Packet {
        let body = MulticastListenerMessage::new(group);
        Self::mld_message(ICMPV6_MULTICAST_LISTENER_DONE, 0, body)
    }

    /// Shared builder: an ICMPv6 header with `icmp_type`, code 0 and the delay in
    /// the rest-of-header word, combined with `body` through the `/` layer operator.
    fn mld_message(
        icmp_type: u8,
        max_response_delay: u16,
        body: MulticastListenerMessage,
    ) -> Packet {
        let header = Self::new()
            .icmp_type(icmp_type)
            .code(0)
            .rest_of_header(mld_rest_of_header(max_response_delay));
        header / body
    }
}
pub(crate) fn decode_multicast_listener_message(bytes: &[u8]) -> Result<MulticastListenerMessage> {
if bytes.len() != MLD_MULTICAST_ADDRESS_LEN {
return Err(CrafterError::buffer_too_short(
"icmpv6.mld.multicast_address",
MLD_MULTICAST_ADDRESS_LEN,
bytes.len(),
));
}
let mut octets = [0u8; MLD_MULTICAST_ADDRESS_LEN];
octets.copy_from_slice(&bytes[..MLD_MULTICAST_ADDRESS_LEN]);
Ok(MulticastListenerMessage {
multicast_address: Ipv6Addr::from(octets),
})
}
/// Size of every IPv6 address field in MLDv2 messages.
const MLDV2_ADDRESS_LEN: usize = 16;
/// Query bytes between the multicast address and the source list:
/// Resv/S/QRV (1), QQIC (1), Number of Sources (2).
const MLDV2_QUERY_FLAGS_LEN: usize = 1 + 1 + 2;
/// Smallest MLDv2 Query body: multicast address plus the fields above, no sources.
pub const MLDV2_QUERY_MIN_BODY_LEN: usize = MLDV2_ADDRESS_LEN + MLDV2_QUERY_FLAGS_LEN;
/// S (suppress router-side processing) bit of the Resv/S/QRV byte.
pub const MLDV2_QUERY_S_FLAG: u8 = 0b0000_1000;
/// QRV (querier's robustness variable) bits of the Resv/S/QRV byte.
pub const MLDV2_QUERY_QRV_MASK: u8 = 0b0000_0111;
/// Reserved bits of the Resv/S/QRV byte: everything that is neither S nor QRV.
pub const MLDV2_QUERY_RESV_MASK: u8 = !(MLDV2_QUERY_S_FLAG | MLDV2_QUERY_QRV_MASK);
/// Fixed part of a Multicast Address Record: Record Type (1), Aux Data Len (1),
/// Number of Sources (2), Multicast Address (16).
const MLDV2_RECORD_FIXED_LEN: usize = 1 + 1 + 2 + MLDV2_ADDRESS_LEN;
/// Aux Data Len counts 4-byte (32-bit) words.
const MLDV2_AUX_DATA_UNIT: usize = 4;
/// Record Type of an MLDv2 Multicast Address Record.
///
/// Wire values 1..=6 map to the named variants. Any other byte is kept verbatim
/// in [`MulticastRecordType::Unknown`], so it survives a decode/encode round trip.
/// An `Unknown` built by hand with a value in 1..=6 encodes that byte, but
/// decodes back as the corresponding named variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum MulticastRecordType {
    /// Wire value 1.
    ModeIsInclude,
    /// Wire value 2.
    ModeIsExclude,
    /// Wire value 3.
    ChangeToIncludeMode,
    /// Wire value 4.
    ChangeToExcludeMode,
    /// Wire value 5.
    AllowNewSources,
    /// Wire value 6.
    BlockOldSources,
    /// Any other wire value, preserved as-is.
    Unknown(u8),
}

impl MulticastRecordType {
    /// Returns the on-the-wire Record Type byte.
    pub fn to_u8(self) -> u8 {
        match self {
            Self::ModeIsInclude => 1,
            Self::ModeIsExclude => 2,
            Self::ChangeToIncludeMode => 3,
            Self::ChangeToExcludeMode => 4,
            Self::AllowNewSources => 5,
            Self::BlockOldSources => 6,
            Self::Unknown(raw) => raw,
        }
    }

    /// Maps a Record Type byte to a variant; unrecognised bytes become `Unknown`.
    pub fn from_u8(value: u8) -> Self {
        match value {
            1 => Self::ModeIsInclude,
            2 => Self::ModeIsExclude,
            3 => Self::ChangeToIncludeMode,
            4 => Self::ChangeToExcludeMode,
            5 => Self::AllowNewSources,
            6 => Self::BlockOldSources,
            raw => Self::Unknown(raw),
        }
    }
}

/// Upper-snake-case name for known types, `Unknown(<byte>)` otherwise.
impl core::fmt::Display for MulticastRecordType {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match *self {
            Self::ModeIsInclude => f.write_str("MODE_IS_INCLUDE"),
            Self::ModeIsExclude => f.write_str("MODE_IS_EXCLUDE"),
            Self::ChangeToIncludeMode => f.write_str("CHANGE_TO_INCLUDE_MODE"),
            Self::ChangeToExcludeMode => f.write_str("CHANGE_TO_EXCLUDE_MODE"),
            Self::AllowNewSources => f.write_str("ALLOW_NEW_SOURCES"),
            Self::BlockOldSources => f.write_str("BLOCK_OLD_SOURCES"),
            Self::Unknown(raw) => write!(f, "Unknown({raw})"),
        }
    }
}
/// One Multicast Address Record of an MLDv2 Report.
///
/// `encode_into` lays it out as:
/// - Record Type (1 byte)
/// - Aux Data Len, in 32-bit words (1 byte)
/// - Number of Sources (2 bytes, big-endian)
/// - Multicast Address (16 bytes)
/// - each source address (16 bytes apiece)
/// - the auxiliary data, zero-padded to a whole number of words
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MulticastAddressRecord {
    record_type: MulticastRecordType,
    multicast_address: Ipv6Addr,
    sources: Vec<Ipv6Addr>,
    /// Auxiliary data as stored; decoded records keep their wire padding.
    /// Encoding pads or truncates it to `aux_data_len()` words.
    aux_data: Vec<u8>,
}

impl MulticastAddressRecord {
    /// Creates a record of `record_type` for `multicast_address`, with no
    /// sources and no auxiliary data.
    pub fn new(record_type: MulticastRecordType, multicast_address: Ipv6Addr) -> Self {
        Self {
            record_type,
            multicast_address,
            sources: vec![],
            aux_data: vec![],
        }
    }

    /// Replaces the record type.
    pub fn record_type(self, record_type: MulticastRecordType) -> Self {
        Self { record_type, ..self }
    }

    /// Replaces the multicast address.
    pub fn multicast_address(self, multicast_address: Ipv6Addr) -> Self {
        Self {
            multicast_address,
            ..self
        }
    }

    /// Appends a single source address.
    pub fn source(mut self, source: Ipv6Addr) -> Self {
        self.sources.push(source);
        self
    }

    /// Replaces the entire source list.
    pub fn sources(self, sources: Vec<Ipv6Addr>) -> Self {
        Self { sources, ..self }
    }

    /// Replaces the auxiliary data. The bytes are stored as given; padding to a
    /// word boundary is applied only when the record is encoded.
    pub fn aux_data(self, aux_data: impl Into<Vec<u8>>) -> Self {
        Self {
            aux_data: aux_data.into(),
            ..self
        }
    }

    /// Returns the record type.
    pub fn record_type_value(&self) -> MulticastRecordType {
        self.record_type
    }

    /// Returns the multicast address.
    pub fn multicast_address_value(&self) -> Ipv6Addr {
        self.multicast_address
    }

    /// Returns the source count as written to the 16-bit Number of Sources field.
    ///
    /// NOTE(review): the `as u16` cast wraps if there are more than `u16::MAX`
    /// sources.
    pub fn number_of_sources(&self) -> u16 {
        self.sources.len() as u16
    }

    /// Returns the source addresses.
    pub fn sources_ref(&self) -> &[Ipv6Addr] {
        self.sources.as_slice()
    }

    /// Returns the Aux Data Len field: the auxiliary data length in 32-bit words,
    /// rounded up and saturating at `u8::MAX`.
    pub fn aux_data_len(&self) -> u8 {
        mldv2_aux_data_len_words(self.aux_data.len())
    }

    /// Returns the stored auxiliary data bytes.
    pub fn aux_data_value(&self) -> &[u8] {
        self.aux_data.as_slice()
    }

    /// Auxiliary data resized to exactly `aux_data_len()` words.
    ///
    /// Short data is zero-padded. Data longer than the 8-bit length field can
    /// describe is truncated.
    fn aux_data_padded(&self) -> Vec<u8> {
        let padded_len = usize::from(self.aux_data_len()) * MLDV2_AUX_DATA_UNIT;
        self.aux_data
            .iter()
            .copied()
            .chain(core::iter::repeat(0))
            .take(padded_len)
            .collect()
    }

    /// Number of bytes `encode_into` appends.
    fn encoded_len(&self) -> usize {
        let sources_len = self.sources.len() * MLDV2_ADDRESS_LEN;
        let aux_len = usize::from(self.aux_data_len()) * MLDV2_AUX_DATA_UNIT;
        MLDV2_RECORD_FIXED_LEN + sources_len + aux_len
    }

    /// Appends the wire encoding of this record to `out`.
    fn encode_into(&self, out: &mut Vec<u8>) {
        out.extend_from_slice(&[self.record_type.to_u8(), self.aux_data_len()]);
        out.extend_from_slice(&self.number_of_sources().to_be_bytes());
        out.extend_from_slice(&self.multicast_address.octets());
        out.extend(self.sources.iter().flat_map(|source| source.octets()));
        out.extend_from_slice(&self.aux_data_padded());
    }
}
/// One-line summary of a record.
///
/// Example: `MODE_IS_INCLUDE(group=ff1e::db8:1, sources=2, aux=2B)`. The `aux`
/// count is the number of stored auxiliary bytes, not the padded wire length.
impl core::fmt::Display for MulticastAddressRecord {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let Self {
            record_type,
            multicast_address,
            sources,
            aux_data,
        } = self;
        write!(
            f,
            "{record_type}(group={multicast_address}, sources={}, aux={}B)",
            sources.len(),
            aux_data.len()
        )
    }
}
/// Converts an auxiliary-data byte count into the Aux Data Len field value.
///
/// The result is the number of 32-bit words needed to hold the data, rounded
/// up. It is clamped to `u8::MAX` because the field is a single byte.
fn mldv2_aux_data_len_words(aux_bytes: usize) -> u8 {
    let capped = aux_bytes
        .div_ceil(MLDV2_AUX_DATA_UNIT)
        .min(usize::from(u8::MAX));
    // `capped` is at most 255, so this narrowing cast cannot truncate.
    capped as u8
}
/// Body of an MLDv2 Report: the list of Multicast Address Records.
///
/// The record count is not part of this body. `Icmpv6::mldv2_report` writes it
/// into the ICMPv6 header from [`Mldv2Report::number_of_records`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Mldv2Report {
    pub(crate) records: Vec<MulticastAddressRecord>,
}

impl Mldv2Report {
    /// Creates a report with no records.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends one record.
    pub fn record(mut self, record: MulticastAddressRecord) -> Self {
        self.records.push(record);
        self
    }

    /// Replaces the whole record list.
    pub fn records(mut self, records: Vec<MulticastAddressRecord>) -> Self {
        self.records = records;
        self
    }

    /// Returns the record count as the 16-bit header field.
    ///
    /// NOTE(review): the `as u16` cast wraps above `u16::MAX` records.
    pub fn number_of_records(&self) -> u16 {
        self.records.len() as u16
    }

    /// Returns the records in order.
    pub fn records_ref(&self) -> &[MulticastAddressRecord] {
        self.records.as_slice()
    }
}
impl Layer for Mldv2Report {
fn name(&self) -> &'static str {
"Mldv2Report"
}
fn summary(&self) -> String {
format!("Mldv2Report(records={})", self.records.len())
}
fn inspection_fields(&self) -> Vec<(&'static str, String)> {
let mut fields = vec![("record_count", self.records.len().to_string())];
for (index, record) in self.records.iter().enumerate() {
fields.push((record_field_name(index), record.to_string()));
}
fields
}
fn encoded_len(&self) -> usize {
self.records.iter().map(|record| record.encoded_len()).sum()
}
fn compile(&self, _ctx: &LayerContext<'_>, out: &mut Vec<u8>) -> Result<()> {
for record in &self.records {
record.encode_into(out);
}
Ok(())
}
impl_layer_object!(Mldv2Report);
}
impl_layer_div!(Mldv2Report);
/// Static inspection label for the record at `index`.
///
/// Indices 0..=7 get distinct labels. Later records share `record[*]`, because
/// the label must be a `&'static str`.
fn record_field_name(index: usize) -> &'static str {
    match index {
        0 => "record[0]",
        1 => "record[1]",
        2 => "record[2]",
        3 => "record[3]",
        4 => "record[4]",
        5 => "record[5]",
        6 => "record[6]",
        7 => "record[7]",
        _ => "record[*]",
    }
}
/// Body of an MLDv2 Multicast Listener Query: everything after the 8-byte
/// ICMPv6 header.
///
/// Fields, in wire order: multicast address, the Resv/S/QRV byte, QQIC,
/// Number of Sources, and the source list.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Mldv2Query {
    /// Group being queried; `::` for a general query.
    pub(crate) multicast_address: Ipv6Addr,
    /// S flag (suppress router-side processing).
    pub(crate) suppress_router_processing: bool,
    /// Querier's Robustness Variable. Only the low 3 bits are encoded.
    pub(crate) querier_robustness: u8,
    /// Querier's Query Interval Code, encoded verbatim.
    pub(crate) querier_query_interval_code: u8,
    pub(crate) sources: Vec<Ipv6Addr>,
    /// High reserved bits of the Resv/S/QRV byte. Preserved from decoding so a
    /// decoded query re-encodes unchanged; zero for newly built queries.
    pub(crate) reserved_bits: u8,
}

impl Mldv2Query {
    /// Creates a query for `multicast_address`: S cleared, QRV and QQIC zero,
    /// no sources.
    pub fn new(multicast_address: Ipv6Addr) -> Self {
        Self {
            multicast_address,
            suppress_router_processing: false,
            querier_robustness: 0,
            querier_query_interval_code: 0,
            sources: Vec::new(),
            reserved_bits: 0,
        }
    }

    /// Replaces the multicast address.
    pub fn multicast_address(mut self, multicast_address: Ipv6Addr) -> Self {
        self.multicast_address = multicast_address;
        self
    }

    /// Sets or clears the S flag.
    pub fn suppress_router_processing(mut self, suppress: bool) -> Self {
        self.suppress_router_processing = suppress;
        self
    }

    /// Sets the Querier's Robustness Variable.
    ///
    /// Only the low 3 bits fit in the QRV field, so the value is masked here.
    /// The getter, `compile` and the decoder already mask. Masking at the setter
    /// also keeps the derived `PartialEq`/`Debug` consistent, so a built query
    /// equals what decoding its own encoding yields.
    pub fn querier_robustness(mut self, qrv: u8) -> Self {
        self.querier_robustness = qrv & MLDV2_QUERY_QRV_MASK;
        self
    }

    /// Sets the Querier's Query Interval Code (written verbatim).
    pub fn querier_query_interval_code(mut self, qqic: u8) -> Self {
        self.querier_query_interval_code = qqic;
        self
    }

    /// Appends one source address.
    pub fn source(mut self, source: Ipv6Addr) -> Self {
        self.sources.push(source);
        self
    }

    /// Replaces the whole source list.
    pub fn sources(mut self, sources: Vec<Ipv6Addr>) -> Self {
        self.sources = sources;
        self
    }

    /// Returns the multicast address.
    pub fn multicast_address_value(&self) -> Ipv6Addr {
        self.multicast_address
    }

    /// Returns the S flag.
    pub fn suppress_router_processing_value(&self) -> bool {
        self.suppress_router_processing
    }

    /// Returns the 3-bit QRV value.
    ///
    /// The mask is still applied because crate code can set the field directly.
    pub fn querier_robustness_value(&self) -> u8 {
        self.querier_robustness & MLDV2_QUERY_QRV_MASK
    }

    /// Returns the QQIC byte.
    pub fn querier_query_interval_code_value(&self) -> u8 {
        self.querier_query_interval_code
    }

    /// Returns the source count as the 16-bit Number of Sources field.
    ///
    /// NOTE(review): the `as u16` cast wraps above `u16::MAX` sources.
    pub fn number_of_sources(&self) -> u16 {
        self.sources.len() as u16
    }

    /// Returns the source addresses.
    pub fn sources_ref(&self) -> &[Ipv6Addr] {
        &self.sources
    }

    /// Packs the reserved bits, the S flag and QRV into their shared byte.
    fn resv_s_qrv_byte(&self) -> u8 {
        let mut byte = self.reserved_bits & MLDV2_QUERY_RESV_MASK;
        if self.suppress_router_processing {
            byte |= MLDV2_QUERY_S_FLAG;
        }
        byte |= self.querier_robustness & MLDV2_QUERY_QRV_MASK;
        byte
    }
}
impl Layer for Mldv2Query {
    fn name(&self) -> &'static str {
        "Mldv2Query"
    }

    fn summary(&self) -> String {
        format!(
            "Mldv2Query(multicast={group}, S={s}, QRV={qrv}, QQIC={qqic}, sources={n})",
            group = self.multicast_address,
            s = self.suppress_router_processing,
            qrv = self.querier_robustness_value(),
            qqic = self.querier_query_interval_code,
            n = self.sources.len(),
        )
    }

    fn inspection_fields(&self) -> Vec<(&'static str, String)> {
        Vec::from([
            ("multicast_address", self.multicast_address.to_string()),
            ("s_flag", self.suppress_router_processing.to_string()),
            ("qrv", self.querier_robustness_value().to_string()),
            ("qqic", self.querier_query_interval_code.to_string()),
            ("source_count", self.sources.len().to_string()),
        ])
    }

    /// Fixed 20-byte part plus 16 bytes per source.
    fn encoded_len(&self) -> usize {
        self.sources.len() * MLDV2_ADDRESS_LEN + MLDV2_QUERY_MIN_BODY_LEN
    }

    /// Writes the body in wire order; the layer context is not consulted.
    fn compile(&self, _ctx: &LayerContext<'_>, out: &mut Vec<u8>) -> Result<()> {
        out.extend_from_slice(&self.multicast_address.octets());
        out.extend_from_slice(&[self.resv_s_qrv_byte(), self.querier_query_interval_code]);
        out.extend_from_slice(&self.number_of_sources().to_be_bytes());
        out.extend(self.sources.iter().flat_map(|source| source.octets()));
        Ok(())
    }

    impl_layer_object!(Mldv2Query);
}

impl_layer_div!(Mldv2Query);
/// MLDv2 message constructors.
impl Icmpv6 {
    /// Builds an MLDv2 Report carrying `records`.
    ///
    /// The record count goes into the last two bytes of the ICMPv6
    /// rest-of-header word; the first two bytes are zero.
    pub fn mldv2_report(records: Vec<MulticastAddressRecord>) -> Packet {
        let body = Mldv2Report::new().records(records);
        let [count_hi, count_lo] = body.number_of_records().to_be_bytes();
        Self::new()
            .icmp_type(ICMPV6_MLDV2_REPORT)
            .code(0)
            .rest_of_header([0, 0, count_hi, count_lo])
            / body
    }

    /// Builds an MLDv2 Query. It uses the same ICMPv6 type as an MLDv1 query,
    /// with `max_response_code` in the first two rest-of-header bytes.
    pub fn mldv2_query(max_response_code: u16, query: Mldv2Query) -> Packet {
        let header = Self::new()
            .icmp_type(ICMPV6_MULTICAST_LISTENER_QUERY)
            .code(0)
            .rest_of_header(mld_rest_of_header(max_response_code));
        header / query
    }

    /// Builds an MLDv2 general query: multicast address `::`, no sources.
    pub fn mldv2_general_query(max_response_code: u16) -> Packet {
        let query = Mldv2Query::new(Ipv6Addr::UNSPECIFIED);
        Self::mldv2_query(max_response_code, query)
    }
}
/// Decodes the body of an MLDv2 Query (the bytes after the ICMPv6 header).
///
/// Any bytes after the advertised source list are ignored, and are therefore
/// not reproduced when the query is compiled again.
///
/// # Errors
///
/// Returns `CrafterError::buffer_too_short` in two cases:
/// - `bytes` is shorter than [`MLDV2_QUERY_MIN_BODY_LEN`];
/// - `bytes` cannot hold as many sources as the Number of Sources field announces.
pub(crate) fn decode_mldv2_query(bytes: &[u8]) -> Result<Mldv2Query> {
    /// Reads one address from a slice of exactly `MLDV2_ADDRESS_LEN` bytes.
    fn read_addr(chunk: &[u8]) -> Ipv6Addr {
        let mut octets = [0u8; MLDV2_ADDRESS_LEN];
        octets.copy_from_slice(chunk);
        Ipv6Addr::from(octets)
    }

    if bytes.len() < MLDV2_QUERY_MIN_BODY_LEN {
        return Err(CrafterError::buffer_too_short(
            "icmpv6.mldv2.query",
            MLDV2_QUERY_MIN_BODY_LEN,
            bytes.len(),
        ));
    }
    // fixed = address (16) + Resv/S/QRV, QQIC, Number of Sources (4).
    let (fixed, source_area) = bytes.split_at(MLDV2_QUERY_MIN_BODY_LEN);
    let (address_bytes, flags) = fixed.split_at(MLDV2_ADDRESS_LEN);
    let resv_s_qrv = flags[0];
    let source_count = usize::from(u16::from_be_bytes([flags[2], flags[3]]));

    let needed = MLDV2_QUERY_MIN_BODY_LEN + source_count * MLDV2_ADDRESS_LEN;
    if bytes.len() < needed {
        return Err(CrafterError::buffer_too_short(
            "icmpv6.mldv2.query.sources",
            needed,
            bytes.len(),
        ));
    }
    // The length check above guarantees at least `source_count` whole chunks.
    let sources: Vec<Ipv6Addr> = source_area
        .chunks_exact(MLDV2_ADDRESS_LEN)
        .take(source_count)
        .map(read_addr)
        .collect();

    Ok(Mldv2Query {
        multicast_address: read_addr(address_bytes),
        suppress_router_processing: (resv_s_qrv & MLDV2_QUERY_S_FLAG) != 0,
        querier_robustness: resv_s_qrv & MLDV2_QUERY_QRV_MASK,
        querier_query_interval_code: flags[1],
        sources,
        reserved_bits: resv_s_qrv & MLDV2_QUERY_RESV_MASK,
    })
}
/// Decodes an MLDv2 Report body into its Multicast Address Records.
///
/// Records are read back to back until the input is exhausted; an empty body
/// yields a report with no records. Auxiliary data is kept with its wire
/// padding. The record count in the ICMPv6 header is not consulted here.
///
/// # Errors
///
/// Returns `CrafterError::buffer_too_short` when the bytes left over are fewer
/// than a fixed record header, or fewer than the sources and aux data that
/// header announces.
pub(crate) fn decode_mldv2_report(bytes: &[u8]) -> Result<Mldv2Report> {
    let mut records = Vec::new();
    let mut rest = bytes;
    while !rest.is_empty() {
        if rest.len() < MLDV2_RECORD_FIXED_LEN {
            return Err(CrafterError::buffer_too_short(
                "icmpv6.mldv2.report.record",
                MLDV2_RECORD_FIXED_LEN,
                rest.len(),
            ));
        }
        let (header, payload) = rest.split_at(MLDV2_RECORD_FIXED_LEN);
        let record_type = MulticastRecordType::from_u8(header[0]);
        let aux_len = usize::from(header[1]) * MLDV2_AUX_DATA_UNIT;
        let source_count = usize::from(u16::from_be_bytes([header[2], header[3]]));
        let mut group = [0u8; MLDV2_ADDRESS_LEN];
        group.copy_from_slice(&header[4..]);

        let sources_len = source_count * MLDV2_ADDRESS_LEN;
        let record_len = MLDV2_RECORD_FIXED_LEN + sources_len + aux_len;
        if rest.len() < record_len {
            return Err(CrafterError::buffer_too_short(
                "icmpv6.mldv2.report.record_body",
                record_len,
                rest.len(),
            ));
        }
        let (source_bytes, after_sources) = payload.split_at(sources_len);
        let sources: Vec<Ipv6Addr> = source_bytes
            .chunks_exact(MLDV2_ADDRESS_LEN)
            .map(|chunk| {
                let mut octets = [0u8; MLDV2_ADDRESS_LEN];
                octets.copy_from_slice(chunk);
                Ipv6Addr::from(octets)
            })
            .collect();
        let aux_data = after_sources[..aux_len].to_vec();

        records.push(MulticastAddressRecord {
            record_type,
            multicast_address: Ipv6Addr::from(group),
            sources,
            aux_data,
        });
        rest = &rest[record_len..];
    }
    Ok(Mldv2Report { records })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocols::icmp::{
        Icmpv6, Icmpv6Body, Mldv2Query, Mldv2Report, MulticastAddressRecord,
        MulticastListenerMessage, MulticastRecordType, ICMPV6_MLDV2_REPORT,
        ICMPV6_MULTICAST_LISTENER_DONE, ICMPV6_MULTICAST_LISTENER_QUERY,
        ICMPV6_MULTICAST_LISTENER_REPORT,
    };
    use crate::{Ipv6, NetworkLayer, Packet};

    /// `ff02::1`, the link-local all-nodes group.
    fn all_nodes() -> Ipv6Addr {
        Ipv6Addr::new(0xff02, 0, 0, 0, 0, 0, 0, 1)
    }

    /// Link-local unicast source `fe80::10`.
    fn link_local_src() -> Ipv6Addr {
        Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 0x0010)
    }

    /// Test multicast group `ff1e::db8:1`.
    fn doc_group() -> Ipv6Addr {
        Ipv6Addr::new(0xff1e, 0, 0, 0, 0, 0, 0x0db8, 0x0001)
    }

    /// Offset of the ICMPv6 header in compiled packets: the 40-byte fixed IPv6
    /// header (these tests add no extension headers).
    const ICMPV6_OFFSET: usize = 40;

    #[test]
    fn mldv1_general_query_round_trips() {
        let compiled = (Ipv6::new()
            .src(link_local_src())
            .dst(all_nodes())
            .hop_limit(1)
            / Icmpv6::mld_general_query(10_000))
        .compile()
        .unwrap();
        let bytes = compiled.as_bytes();
        assert_eq!(bytes[ICMPV6_OFFSET], ICMPV6_MULTICAST_LISTENER_QUERY);
        assert_eq!(bytes[ICMPV6_OFFSET + 1], 0);
        assert_eq!(
            &bytes[ICMPV6_OFFSET + 4..ICMPV6_OFFSET + 6],
            &10_000u16.to_be_bytes()
        );
        assert_eq!(&bytes[ICMPV6_OFFSET + 6..ICMPV6_OFFSET + 8], &[0, 0]);
        assert_eq!(
            &bytes[ICMPV6_OFFSET + 8..ICMPV6_OFFSET + 24],
            &Ipv6Addr::UNSPECIFIED.octets()
        );
        let decoded = Packet::decode_from_l3(NetworkLayer::Ipv6, bytes).unwrap();
        let icmpv6 = decoded.layer::<Icmpv6>().unwrap();
        assert_eq!(icmpv6.icmp_type_value(), ICMPV6_MULTICAST_LISTENER_QUERY);
        let Icmpv6Body::MulticastListenerQuery { max_response_delay } = icmpv6.body() else {
            panic!("type 130 with a 16-byte body should decode as an MLDv1 query");
        };
        assert_eq!(max_response_delay, 10_000);
        let mld = decoded.layer::<MulticastListenerMessage>().unwrap();
        assert_eq!(mld.multicast_address_value(), Ipv6Addr::UNSPECIFIED);
        assert_eq!(decoded.compile().unwrap().as_bytes(), bytes);
    }

    #[test]
    fn mldv1_group_specific_query_round_trips() {
        let compiled = (Ipv6::new()
            .src(link_local_src())
            .dst(doc_group())
            .hop_limit(1)
            / Icmpv6::mld_query(doc_group(), 1_000))
        .compile()
        .unwrap();
        let bytes = compiled.as_bytes();
        assert_eq!(bytes[ICMPV6_OFFSET], ICMPV6_MULTICAST_LISTENER_QUERY);
        assert_eq!(
            &bytes[ICMPV6_OFFSET + 4..ICMPV6_OFFSET + 6],
            &1_000u16.to_be_bytes()
        );
        assert_eq!(
            &bytes[ICMPV6_OFFSET + 8..ICMPV6_OFFSET + 24],
            &doc_group().octets()
        );
        let decoded = Packet::decode_from_l3(NetworkLayer::Ipv6, bytes).unwrap();
        let icmpv6 = decoded.layer::<Icmpv6>().unwrap();
        assert!(matches!(
            icmpv6.body(),
            Icmpv6Body::MulticastListenerQuery {
                max_response_delay: 1_000
            }
        ));
        let mld = decoded.layer::<MulticastListenerMessage>().unwrap();
        assert_eq!(mld.multicast_address_value(), doc_group());
        assert_eq!(decoded.compile().unwrap().as_bytes(), bytes);
    }

    #[test]
    fn mldv1_report_round_trips() {
        let compiled = (Ipv6::new()
            .src(link_local_src())
            .dst(doc_group())
            .hop_limit(1)
            / Icmpv6::mld_report(doc_group()))
        .compile()
        .unwrap();
        let bytes = compiled.as_bytes();
        assert_eq!(bytes[ICMPV6_OFFSET], ICMPV6_MULTICAST_LISTENER_REPORT);
        assert_eq!(bytes[ICMPV6_OFFSET + 1], 0);
        assert_eq!(&bytes[ICMPV6_OFFSET + 4..ICMPV6_OFFSET + 6], &[0, 0]);
        assert_eq!(
            &bytes[ICMPV6_OFFSET + 8..ICMPV6_OFFSET + 24],
            &doc_group().octets()
        );
        let decoded = Packet::decode_from_l3(NetworkLayer::Ipv6, bytes).unwrap();
        let icmpv6 = decoded.layer::<Icmpv6>().unwrap();
        assert_eq!(icmpv6.icmp_type_value(), ICMPV6_MULTICAST_LISTENER_REPORT);
        assert!(matches!(
            icmpv6.body(),
            Icmpv6Body::MulticastListenerReport {
                max_response_delay: 0
            }
        ));
        let mld = decoded.layer::<MulticastListenerMessage>().unwrap();
        assert_eq!(mld.multicast_address_value(), doc_group());
        assert_eq!(decoded.compile().unwrap().as_bytes(), bytes);
    }

    #[test]
    fn mldv1_done_round_trips() {
        let compiled = (Ipv6::new()
            .src(link_local_src())
            .dst(all_nodes())
            .hop_limit(1)
            / Icmpv6::mld_done(doc_group()))
        .compile()
        .unwrap();
        let bytes = compiled.as_bytes();
        assert_eq!(bytes[ICMPV6_OFFSET], ICMPV6_MULTICAST_LISTENER_DONE);
        assert_eq!(bytes[ICMPV6_OFFSET + 1], 0);
        assert_eq!(&bytes[ICMPV6_OFFSET + 4..ICMPV6_OFFSET + 6], &[0, 0]);
        assert_eq!(
            &bytes[ICMPV6_OFFSET + 8..ICMPV6_OFFSET + 24],
            &doc_group().octets()
        );
        let decoded = Packet::decode_from_l3(NetworkLayer::Ipv6, bytes).unwrap();
        let icmpv6 = decoded.layer::<Icmpv6>().unwrap();
        assert_eq!(icmpv6.icmp_type_value(), ICMPV6_MULTICAST_LISTENER_DONE);
        assert!(matches!(
            icmpv6.body(),
            Icmpv6Body::MulticastListenerDone {
                max_response_delay: 0
            }
        ));
        let mld = decoded.layer::<MulticastListenerMessage>().unwrap();
        assert_eq!(mld.multicast_address_value(), doc_group());
        assert_eq!(decoded.compile().unwrap().as_bytes(), bytes);
    }

    #[test]
    fn mldv1_decode_rejects_non_16_byte_body() {
        assert!(decode_multicast_listener_message(&[0u8; 20]).is_err());
        assert!(decode_multicast_listener_message(&[0u8; 8]).is_err());
        assert!(decode_multicast_listener_message(&[0u8; 16]).is_ok());
    }

    /// Second test multicast group `ff1e::db8:2`.
    fn doc_group2() -> Ipv6Addr {
        Ipv6Addr::new(0xff1e, 0, 0, 0, 0, 0, 0x0db8, 0x0002)
    }

    /// Source `2001:db8::101`.
    fn doc_source1() -> Ipv6Addr {
        Ipv6Addr::new(0x2001, 0x0db8, 0, 0, 0, 0, 0, 0x0101)
    }

    /// Source `2001:db8::102`.
    fn doc_source2() -> Ipv6Addr {
        Ipv6Addr::new(0x2001, 0x0db8, 0, 0, 0, 0, 0, 0x0102)
    }

    #[test]
    fn mldv2_report_two_records_round_trips() {
        let include = MulticastAddressRecord::new(MulticastRecordType::ModeIsInclude, doc_group())
            .source(doc_source1())
            .source(doc_source2())
            .aux_data([0xaa, 0xbb]);
        let exclude =
            MulticastAddressRecord::new(MulticastRecordType::ModeIsExclude, doc_group2());
        let report = Icmpv6::mldv2_report(vec![include, exclude]);
        let compiled = (Ipv6::new()
            .src(link_local_src())
            .dst(Ipv6Addr::new(0xff02, 0, 0, 0, 0, 0, 0, 0x0016))
            .hop_limit(1)
            / report)
            .compile()
            .unwrap();
        let bytes = compiled.as_bytes();
        assert_eq!(bytes[ICMPV6_OFFSET], ICMPV6_MLDV2_REPORT);
        assert_eq!(bytes[ICMPV6_OFFSET + 1], 0);
        assert_eq!(&bytes[ICMPV6_OFFSET + 4..ICMPV6_OFFSET + 6], &[0, 0]);
        assert_eq!(
            &bytes[ICMPV6_OFFSET + 6..ICMPV6_OFFSET + 8],
            &2u16.to_be_bytes()
        );
        let decoded = Packet::decode_from_l3(NetworkLayer::Ipv6, bytes).unwrap();
        let icmpv6 = decoded.layer::<Icmpv6>().unwrap();
        assert_eq!(icmpv6.icmp_type_value(), ICMPV6_MLDV2_REPORT);
        assert!(matches!(
            icmpv6.body(),
            Icmpv6Body::Mldv2Report {
                number_of_records: 2
            }
        ));
        let report = decoded.layer::<Mldv2Report>().unwrap();
        assert_eq!(report.number_of_records(), 2);
        let records = report.records_ref();
        assert_eq!(
            records[0].record_type_value(),
            MulticastRecordType::ModeIsInclude
        );
        assert_eq!(records[0].multicast_address_value(), doc_group());
        assert_eq!(records[0].number_of_sources(), 2);
        assert_eq!(records[0].sources_ref(), &[doc_source1(), doc_source2()]);
        assert_eq!(records[0].aux_data_len(), 1);
        assert_eq!(records[0].aux_data_value(), &[0xaa, 0xbb, 0x00, 0x00]);
        assert_eq!(
            records[1].record_type_value(),
            MulticastRecordType::ModeIsExclude
        );
        assert_eq!(records[1].multicast_address_value(), doc_group2());
        assert_eq!(records[1].number_of_sources(), 0);
        assert!(records[1].sources_ref().is_empty());
        assert_eq!(records[1].aux_data_len(), 0);
        assert_eq!(decoded.compile().unwrap().as_bytes(), bytes);
    }

    #[test]
    fn mldv2_query_with_sources_round_trips_and_decodes_as_mldv2() {
        let query = Mldv2Query::new(doc_group())
            .suppress_router_processing(true)
            .querier_robustness(2)
            .querier_query_interval_code(125)
            .source(doc_source1())
            .source(doc_source2());
        let compiled = (Ipv6::new()
            .src(link_local_src())
            .dst(doc_group())
            .hop_limit(1)
            / Icmpv6::mldv2_query(10_000, query))
        .compile()
        .unwrap();
        let bytes = compiled.as_bytes();
        assert_eq!(bytes[ICMPV6_OFFSET], ICMPV6_MULTICAST_LISTENER_QUERY);
        assert_eq!(bytes[ICMPV6_OFFSET + 1], 0);
        assert_eq!(
            &bytes[ICMPV6_OFFSET + 4..ICMPV6_OFFSET + 6],
            &10_000u16.to_be_bytes()
        );
        assert_eq!(
            &bytes[ICMPV6_OFFSET + 8..ICMPV6_OFFSET + 24],
            &doc_group().octets()
        );
        assert_eq!(bytes[ICMPV6_OFFSET + 24], MLDV2_QUERY_S_FLAG | 2);
        assert_eq!(bytes[ICMPV6_OFFSET + 25], 125);
        assert_eq!(
            &bytes[ICMPV6_OFFSET + 26..ICMPV6_OFFSET + 28],
            &2u16.to_be_bytes()
        );
        let decoded = Packet::decode_from_l3(NetworkLayer::Ipv6, bytes).unwrap();
        let icmpv6 = decoded.layer::<Icmpv6>().unwrap();
        assert_eq!(icmpv6.icmp_type_value(), ICMPV6_MULTICAST_LISTENER_QUERY);
        let query = decoded.layer::<Mldv2Query>().unwrap();
        assert_eq!(query.multicast_address_value(), doc_group());
        assert!(query.suppress_router_processing_value());
        assert_eq!(query.querier_robustness_value(), 2);
        assert_eq!(query.querier_query_interval_code_value(), 125);
        assert_eq!(query.number_of_sources(), 2);
        assert_eq!(query.sources_ref(), &[doc_source1(), doc_source2()]);
        assert!(decoded.layer::<MulticastListenerMessage>().is_none());
        assert_eq!(decoded.compile().unwrap().as_bytes(), bytes);
    }

    #[test]
    fn type_130_body_length_disambiguates_mldv1_from_mldv2() {
        let mldv1 = (Ipv6::new()
            .src(link_local_src())
            .dst(doc_group())
            .hop_limit(1)
            / Icmpv6::mld_query(doc_group(), 1_000))
        .compile()
        .unwrap();
        let mldv1_bytes = mldv1.as_bytes();
        assert_eq!(mldv1_bytes.len() - (ICMPV6_OFFSET + 8), 16);
        let decoded_v1 = Packet::decode_from_l3(NetworkLayer::Ipv6, mldv1_bytes).unwrap();
        assert!(decoded_v1.layer::<MulticastListenerMessage>().is_some());
        assert!(decoded_v1.layer::<Mldv2Query>().is_none());
        let mldv2 = (Ipv6::new()
            .src(link_local_src())
            .dst(all_nodes())
            .hop_limit(1)
            / Icmpv6::mldv2_general_query(10_000))
        .compile()
        .unwrap();
        let mldv2_bytes = mldv2.as_bytes();
        assert_eq!(
            mldv2_bytes.len() - (ICMPV6_OFFSET + 8),
            MLDV2_QUERY_MIN_BODY_LEN
        );
        let decoded_v2 = Packet::decode_from_l3(NetworkLayer::Ipv6, mldv2_bytes).unwrap();
        assert!(decoded_v2.layer::<Mldv2Query>().is_some());
        assert!(decoded_v2.layer::<MulticastListenerMessage>().is_none());
    }

    #[test]
    fn mldv2_record_preserves_unknown_record_type() {
        let record = MulticastAddressRecord::new(MulticastRecordType::Unknown(200), doc_group());
        assert_eq!(
            record.record_type_value(),
            MulticastRecordType::Unknown(200)
        );
        let decoded = decode_mldv2_report(&{
            let mut out = Vec::new();
            record.encode_into(&mut out);
            out
        })
        .unwrap();
        assert_eq!(decoded.number_of_records(), 1);
        assert_eq!(
            decoded.records_ref()[0].record_type_value(),
            MulticastRecordType::Unknown(200)
        );
    }

    #[test]
    fn mldv2_record_pads_aux_data_to_whole_words() {
        let record =
            MulticastAddressRecord::new(MulticastRecordType::AllowNewSources, doc_group())
                .aux_data([1, 2, 3, 4, 5]);
        assert_eq!(record.aux_data_len(), 2);
        let mut out = Vec::new();
        record.encode_into(&mut out);
        assert_eq!(out.len(), record.encoded_len());
        assert_eq!(out.len(), MLDV2_RECORD_FIXED_LEN + 8);
        assert_eq!(out[1], 2);
        assert_eq!(&out[out.len() - 8..], &[1, 2, 3, 4, 5, 0, 0, 0]);
    }

    #[test]
    fn mldv2_decoders_reject_truncated_input() {
        // Report: shorter than one fixed record header.
        assert!(decode_mldv2_report(&[0u8; MLDV2_RECORD_FIXED_LEN - 1]).is_err());
        // Report: header announces one source that is not present.
        let mut record = vec![MulticastRecordType::ModeIsInclude.to_u8(), 0, 0, 1];
        record.extend_from_slice(&doc_group().octets());
        assert!(decode_mldv2_report(&record).is_err());
        // Report: an empty body is a valid report with zero records.
        assert_eq!(decode_mldv2_report(&[]).unwrap().number_of_records(), 0);
        // Query: shorter than the fixed body.
        assert!(decode_mldv2_query(&[0u8; MLDV2_QUERY_MIN_BODY_LEN - 1]).is_err());
        // Query: announces one source that is not present.
        let mut query = doc_group().octets().to_vec();
        query.extend_from_slice(&[0, 0, 0, 1]);
        assert!(decode_mldv2_query(&query).is_err());
    }
}