use alloc::vec::Vec;
use crate::error::{Error, Result};
use crate::type_field::TypeField;
// Raw H-Type byte values recognised by this module.
//
// Two separate tables share the same numeric space: the first four constants are decoded by
// `MandatoryHType::from_u8`, the last two by `OptionalHType::from_u8`. That is why 0x00 and 0x01
// each appear twice below.
// NOTE(review): the values presumably follow the ULE Next-Header registry (RFC 4326 / RFC 5163) —
// confirm against the spec.

/// Mandatory H-Type decoded as [`MandatoryHType::TestSndu`].
pub const H_TYPE_TEST_SNDU: u8 = 0x00;
/// Mandatory H-Type decoded as [`MandatoryHType::BridgedFrame`].
pub const H_TYPE_BRIDGED_FRAME: u8 = 0x01;
/// Mandatory H-Type decoded as [`MandatoryHType::TsConcat`].
pub const H_TYPE_TS_CONCAT: u8 = 0x02;
/// Mandatory H-Type decoded as [`MandatoryHType::PduConcat`].
pub const H_TYPE_PDU_CONCAT: u8 = 0x03;
/// Optional H-Type decoded as [`OptionalHType::TimeStamp`].
pub const H_TYPE_TIMESTAMP: u8 = 0x01;
/// Optional H-Type decoded as [`OptionalHType::ExtPadding`].
pub const H_TYPE_EXT_PADDING: u8 = 0x00;
/// Typed view of the H-Type byte of a mandatory extension header.
///
/// Bytes without a dedicated variant are carried verbatim in [`MandatoryHType::Other`], so
/// `from_u8` followed by `to_u8` always yields the original byte.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[non_exhaustive]
pub enum MandatoryHType {
    /// Raw value `H_TYPE_TEST_SNDU`.
    TestSndu,
    /// Raw value `H_TYPE_BRIDGED_FRAME`.
    BridgedFrame,
    /// Raw value `H_TYPE_TS_CONCAT`.
    TsConcat,
    /// Raw value `H_TYPE_PDU_CONCAT`.
    PduConcat,
    /// Any other raw value, stored unchanged.
    Other(u8),
}

impl MandatoryHType {
    /// Classifies a raw H-Type byte; unrecognised bytes become `Other(raw)`.
    pub fn from_u8(raw: u8) -> Self {
        match raw {
            H_TYPE_TEST_SNDU => Self::TestSndu,
            H_TYPE_BRIDGED_FRAME => Self::BridgedFrame,
            H_TYPE_TS_CONCAT => Self::TsConcat,
            H_TYPE_PDU_CONCAT => Self::PduConcat,
            unknown => Self::Other(unknown),
        }
    }

    /// Returns the raw byte for this type; `Other(v)` yields `v` untouched.
    pub fn to_u8(self) -> u8 {
        match self {
            Self::Other(raw) => raw,
            Self::TestSndu => H_TYPE_TEST_SNDU,
            Self::BridgedFrame => H_TYPE_BRIDGED_FRAME,
            Self::TsConcat => H_TYPE_TS_CONCAT,
            Self::PduConcat => H_TYPE_PDU_CONCAT,
        }
    }

    /// Short label for the type.
    ///
    /// Every `Other(_)` value reports the generic label `"mandatory"`, even when the wrapped byte
    /// happens to equal one of the named constants.
    pub fn name(&self) -> &'static str {
        match *self {
            Self::Other(_) => "mandatory",
            Self::TestSndu => "test-sndu",
            Self::BridgedFrame => "bridged-frame",
            Self::TsConcat => "ts-concat",
            Self::PduConcat => "pdu-concat",
        }
    }
}
// NOTE(review): `impl_spec_display!` is defined in `dvb_common`, outside this file. It presumably
// generates the `Display` impl from `name()`, with special handling for the `Other` variant —
// confirm against the macro definition.
dvb_common::impl_spec_display!(MandatoryHType, Other);
/// Typed view of the H-Type byte of an optional extension header.
///
/// Bytes without a dedicated variant are carried verbatim in [`OptionalHType::Other`], so
/// `from_u8` followed by `to_u8` always yields the original byte.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[non_exhaustive]
pub enum OptionalHType {
    /// Raw value `H_TYPE_EXT_PADDING`.
    ExtPadding,
    /// Raw value `H_TYPE_TIMESTAMP`.
    TimeStamp,
    /// Any other raw value, stored unchanged.
    Other(u8),
}

impl OptionalHType {
    /// Classifies a raw H-Type byte; unrecognised bytes become `Other(raw)`.
    pub fn from_u8(raw: u8) -> Self {
        match raw {
            H_TYPE_EXT_PADDING => Self::ExtPadding,
            H_TYPE_TIMESTAMP => Self::TimeStamp,
            unknown => Self::Other(unknown),
        }
    }

    /// Returns the raw byte for this type; `Other(v)` yields `v` untouched.
    pub fn to_u8(self) -> u8 {
        match self {
            Self::Other(raw) => raw,
            Self::ExtPadding => H_TYPE_EXT_PADDING,
            Self::TimeStamp => H_TYPE_TIMESTAMP,
        }
    }

    /// Short label for the type.
    ///
    /// Every `Other(_)` value reports the generic label `"optional"`, even when the wrapped byte
    /// happens to equal one of the named constants.
    pub fn name(&self) -> &'static str {
        match *self {
            Self::Other(_) => "optional",
            Self::ExtPadding => "extension-padding",
            Self::TimeStamp => "timestamp",
        }
    }
}
// NOTE(review): `impl_spec_display!` is defined in `dvb_common`, outside this file. It presumably
// generates the `Display` impl from `name()`, with special handling for the `Other` variant —
// confirm against the macro definition.
dvb_common::impl_spec_display!(OptionalHType, Other);
/// One extension header, stored with its raw H-Type byte and an owned copy of its body.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[non_exhaustive]
pub enum ExtensionHeader {
    /// Header that carries an explicit length.
    Optional {
        /// Length field; `wire_len` reports `2 * h_len` bytes for this header.
        h_len: u8,
        /// Raw H-Type byte (see [`OptionalHType`]).
        h_type: u8,
        /// Header body bytes.
        body: Vec<u8>,
    },
    /// Header without a length field; `h_len()` reports 0 for it.
    Mandatory {
        /// Raw H-Type byte (see [`MandatoryHType`]).
        h_type: u8,
        /// Header body bytes.
        body: Vec<u8>,
    },
}

impl ExtensionHeader {
    /// Length field of the header: the stored `h_len` for `Optional`, always 0 for `Mandatory`.
    pub fn h_len(&self) -> u8 {
        if let Self::Optional { h_len, .. } = self {
            *h_len
        } else {
            0
        }
    }

    /// Raw H-Type byte, whichever variant this is.
    pub fn h_type(&self) -> u8 {
        match self {
            Self::Optional { h_type, .. } | Self::Mandatory { h_type, .. } => *h_type,
        }
    }

    /// The `TypeField::NextHeader` value that announces this header.
    pub fn type_field(&self) -> TypeField {
        let (h_len, h_type) = (self.h_len(), self.h_type());
        TypeField::NextHeader { h_len, h_type }
    }

    /// `true` only for the `Mandatory` variant.
    pub fn is_mandatory(&self) -> bool {
        match self {
            Self::Mandatory { .. } => true,
            Self::Optional { .. } => false,
        }
    }

    /// Typed H-Type of a mandatory header; `None` for an optional one.
    pub fn mandatory_h_type(&self) -> Option<MandatoryHType> {
        if self.is_mandatory() {
            Some(MandatoryHType::from_u8(self.h_type()))
        } else {
            None
        }
    }

    /// Typed H-Type of an optional header; `None` for a mandatory one.
    pub fn optional_h_type(&self) -> Option<OptionalHType> {
        if self.is_mandatory() {
            None
        } else {
            Some(OptionalHType::from_u8(self.h_type()))
        }
    }

    /// Label of the header's H-Type, looked up in the table matching the variant.
    pub fn name(&self) -> &'static str {
        let raw = self.h_type();
        if self.is_mandatory() {
            MandatoryHType::from_u8(raw).name()
        } else {
            OptionalHType::from_u8(raw).name()
        }
    }

    /// Size in bytes attributed to this header.
    ///
    /// For `Optional` this is derived from `h_len` alone (`2 * h_len`) and ignores the actual
    /// `body` length; for `Mandatory` it is the body length plus 2.
    pub fn wire_len(&self) -> usize {
        match self {
            Self::Mandatory { body, .. } => body.len() + 2,
            Self::Optional { h_len, .. } => usize::from(*h_len) * 2,
        }
    }
}
// NOTE(review): `impl_spec_display!` is defined in `dvb_common`, outside this file. It presumably
// generates the `Display` impl from `name()` — confirm against the macro definition.
dvb_common::impl_spec_display!(ExtensionHeader);
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub struct PayloadChain<'a> {
pub headers: Vec<ExtensionHeader>,
pub final_type: TypeField,
pub pdu: &'a [u8],
}
impl<'a> PayloadChain<'a> {
pub fn parse(first_type: TypeField, data: &'a [u8]) -> Result<Self> {
let mut headers = Vec::new();
let mut cur = first_type;
let mut off = 0usize;
loop {
match cur {
TypeField::EtherType(_) => {
return Ok(PayloadChain {
headers,
final_type: cur,
pdu: &data[off..],
});
}
TypeField::NextHeader { h_len, h_type } => {
if h_len == 0 {
return Ok(PayloadChain {
headers,
final_type: cur,
pdu: &data[off..],
});
}
let body_len = 2 * (h_len as usize) - 2;
let next_type_at = off + body_len;
if next_type_at + 2 > data.len() {
return Err(Error::InvalidExtensionHeader {
reason: "optional extension header body/next-type exceeds payload",
});
}
let body = data[off..next_type_at].to_vec();
headers.push(ExtensionHeader::Optional {
h_len,
h_type,
body,
});
let next_raw = u16::from_be_bytes([data[next_type_at], data[next_type_at + 1]]);
cur = TypeField::from_u16(next_raw);
off = next_type_at + 2;
}
}
}
}
pub fn serialized_len(&self) -> usize {
let mut n = 0usize;
for h in &self.headers {
n += (h.wire_len() - 2) + 2;
}
n + self.pdu.len()
}
pub fn base_type(&self) -> TypeField {
match self.headers.first() {
Some(h) => h.type_field(),
None => self.final_type,
}
}
pub fn serialize_into(&self, out: &mut [u8]) -> Result<usize> {
let need = self.serialized_len();
if out.len() < need {
return Err(Error::OutputBufferTooSmall {
need,
have: out.len(),
});
}
let mut off = 0usize;
for (i, h) in self.headers.iter().enumerate() {
let body = match h {
ExtensionHeader::Optional { body, .. } => body,
ExtensionHeader::Mandatory { .. } => {
return Err(Error::InvalidExtensionHeader {
reason: "mandatory header must be the chain terminator, not a link",
});
}
};
out[off..off + body.len()].copy_from_slice(body);
off += body.len();
let following = if i + 1 < self.headers.len() {
self.headers[i + 1].type_field()
} else {
self.final_type
};
out[off..off + 2].copy_from_slice(&following.to_u16().to_be_bytes());
off += 2;
}
out[off..off + self.pdu.len()].copy_from_slice(self.pdu);
off += self.pdu.len();
Ok(off)
}
}
// Unit tests: payload chains are round-tripped through `crate::sndu::Sndu` (serialize, then
// parse), and the typed H-Type accessors and labels are checked directly.
#[cfg(test)]
mod tests {
use super::*;
// A single optional header in front of an EtherType payload survives serialize -> parse.
#[test]
fn optional_header_chain_round_trip() {
use crate::sndu::Sndu;
let ts = ExtensionHeader::Optional {
h_len: 3,
h_type: H_TYPE_TIMESTAMP,
body: alloc::vec![0xAA, 0xBB, 0xCC, 0xDD],
};
// wire_len of an optional header is 2 * h_len.
assert_eq!(ts.wire_len(), 6);
let pdu = [0x45u8, 0x00, 0x00, 0x10];
let chain = PayloadChain {
headers: alloc::vec![ts.clone()],
final_type: TypeField::EtherType(0x0800),
pdu: &pdu,
};
// The base type is the first header's type field (h_len 3, h_type 0x01).
assert_eq!(chain.base_type().to_u16(), 0x0301);
let sndu = Sndu {
dest_address: None,
payload: chain.clone(),
};
let mut buf = alloc::vec![0u8; sndu.serialized_len()];
sndu.serialize_into(&mut buf).unwrap();
let parsed = Sndu::parse(&buf).unwrap();
assert_eq!(parsed.payload.headers.len(), 1);
assert_eq!(parsed.payload.headers[0], ts);
assert_eq!(parsed.payload.final_type, TypeField::EtherType(0x0800));
assert_eq!(parsed.payload.pdu, &pdu);
assert_eq!(parsed, sndu);
}
// A chain with no optional headers that ends in a NextHeader with h_len 0: the terminator is kept
// as `final_type` and the bytes after it come back as `pdu`.
#[test]
fn mandatory_header_round_trip() {
use crate::sndu::Sndu;
let body = [0xDEu8, 0xAD, 0xBE, 0xEF, 0x00];
let chain = PayloadChain {
headers: Vec::new(),
final_type: TypeField::NextHeader {
h_len: 0,
h_type: H_TYPE_TEST_SNDU,
},
pdu: &body,
};
// With no headers, the base type is `final_type` itself.
assert_eq!(chain.base_type().to_u16(), 0x0000);
let sndu = Sndu {
dest_address: Some([1, 2, 3, 4, 5, 6]),
payload: chain,
};
let mut buf = alloc::vec![0u8; sndu.serialized_len()];
sndu.serialize_into(&mut buf).unwrap();
let parsed = Sndu::parse(&buf).unwrap();
assert!(parsed.payload.headers.is_empty());
assert_eq!(
parsed.payload.final_type,
TypeField::NextHeader {
h_len: 0,
h_type: 0
}
);
assert_eq!(parsed.payload.pdu, &body);
assert_eq!(parsed, sndu);
}
// Two chained optional headers keep their order and contents across a round trip.
#[test]
fn two_optional_headers_chain() {
use crate::sndu::Sndu;
// h_len 1 leaves room for the 2-byte type field only, so the body is empty.
let h1 = ExtensionHeader::Optional {
h_len: 1,
h_type: H_TYPE_EXT_PADDING,
body: Vec::new(), };
// h_len 2 gives a 2-byte body; 0x42 has no dedicated OptionalHType variant.
let h2 = ExtensionHeader::Optional {
h_len: 2,
h_type: 0x42,
body: alloc::vec![0x11, 0x22], };
let pdu = [0x99u8];
let chain = PayloadChain {
headers: alloc::vec![h1.clone(), h2.clone()],
final_type: TypeField::EtherType(0x86DD),
pdu: &pdu,
};
let sndu = Sndu {
dest_address: None,
payload: chain,
};
let mut buf = alloc::vec![0u8; sndu.serialized_len()];
sndu.serialize_into(&mut buf).unwrap();
let parsed = Sndu::parse(&buf).unwrap();
assert_eq!(parsed.payload.headers, alloc::vec![h1, h2]);
assert_eq!(parsed.payload.final_type, TypeField::EtherType(0x86DD));
assert_eq!(parsed.payload.pdu, &pdu);
assert_eq!(parsed, sndu);
}
// `optional_h_type` / `mandatory_h_type` answer only for the matching variant and fall back to
// `Other(raw)` for unrecognised bytes.
#[test]
fn typed_h_type_accessors() {
let ts = ExtensionHeader::Optional {
h_len: 3,
h_type: H_TYPE_TIMESTAMP,
body: alloc::vec![0, 0, 0, 0],
};
assert_eq!(ts.optional_h_type(), Some(OptionalHType::TimeStamp));
assert_eq!(ts.mandatory_h_type(), None);
let mand = ExtensionHeader::Mandatory {
h_type: H_TYPE_BRIDGED_FRAME,
body: alloc::vec![],
};
assert_eq!(mand.mandatory_h_type(), Some(MandatoryHType::BridgedFrame));
assert_eq!(mand.optional_h_type(), None);
let unk_m = ExtensionHeader::Mandatory {
h_type: 0xF0,
body: alloc::vec![],
};
assert_eq!(unk_m.mandatory_h_type(), Some(MandatoryHType::Other(0xF0)));
let unk_o = ExtensionHeader::Optional {
h_len: 2,
h_type: 0xF0,
body: alloc::vec![0, 0],
};
assert_eq!(unk_o.optional_h_type(), Some(OptionalHType::Other(0xF0)));
}
// Guard: every mandatory H_TYPE constant must map to its own label rather than the generic
// "mandatory" fallback used for `Other`.
#[test]
fn all_h_type_constants_have_non_default_mandatory_label() {
let mandatory_constants: &[(u8, &str)] = &[
(H_TYPE_TEST_SNDU, "test-sndu"),
(H_TYPE_BRIDGED_FRAME, "bridged-frame"),
(H_TYPE_TS_CONCAT, "ts-concat"),
(H_TYPE_PDU_CONCAT, "pdu-concat"),
];
for &(raw, expected_label) in mandatory_constants {
let t = MandatoryHType::from_u8(raw);
assert_ne!(
t.name(),
"mandatory",
"H_TYPE constant 0x{raw:02X} maps to the default fallback label — add a named arm"
);
assert_eq!(
t.name(),
expected_label,
"H_TYPE constant 0x{raw:02X} label mismatch"
);
}
}
// Guard: every optional H_TYPE constant must map to its own label rather than the generic
// "optional" fallback used for `Other`.
#[test]
fn all_h_type_constants_have_non_default_optional_label() {
let optional_constants: &[(u8, &str)] = &[
(H_TYPE_EXT_PADDING, "extension-padding"),
(H_TYPE_TIMESTAMP, "timestamp"),
];
for &(raw, expected_label) in optional_constants {
let t = OptionalHType::from_u8(raw);
assert_ne!(
t.name(), "optional",
"optional H_TYPE constant 0x{raw:02X} maps to the default fallback label — add a named arm"
);
assert_eq!(
t.name(),
expected_label,
"optional H_TYPE constant 0x{raw:02X} label mismatch"
);
}
}
}