use crate::coding::{Decode, DecodeError, Encode, EncodeError};
use num_enum::{IntoPrimitive, TryFromPrimitive};
use super::Version;
use crate::ietf::Param;
/// Group ordering preference, converted to and from its `u8` discriminant.
#[derive(Clone, Copy, Debug, Eq, PartialEq, IntoPrimitive, TryFromPrimitive)]
#[repr(u8)]
pub enum GroupOrder {
    /// Wire value 0; [`GroupOrder::any_to_descending`] resolves it to `Descending`.
    Any = 0,
    /// Wire value 1.
    Ascending = 1,
    /// Wire value 2.
    Descending = 2,
}
impl GroupOrder {
    /// Resolves [`GroupOrder::Any`] to [`GroupOrder::Descending`]; `Ascending` and
    /// `Descending` are returned unchanged.
    pub fn any_to_descending(self) -> Self {
        if matches!(self, Self::Any) {
            Self::Descending
        } else {
            self
        }
    }
}
impl Encode<Version> for GroupOrder {
    /// Writes the variant's `u8` discriminant using `u8`'s own `Encode` impl.
    fn encode<W: bytes::BufMut>(&self, w: &mut W, version: Version) -> Result<(), EncodeError> {
        let raw: u8 = (*self).into();
        raw.encode(w, version)
    }
}
impl Decode<Version> for GroupOrder {
    /// Reads a `u8` and maps it to the matching variant.
    ///
    /// Unlike `Param::param_decode`, this is strict: `Any` is returned as-is and unknown
    /// values are rejected.
    ///
    /// # Errors
    ///
    /// Propagates errors from decoding the `u8`; returns [`DecodeError::InvalidValue`]
    /// when the byte is not a known discriminant.
    fn decode<R: bytes::Buf>(r: &mut R, version: Version) -> Result<Self, DecodeError> {
        let raw = u8::decode(r, version)?;
        match Self::try_from(raw) {
            Ok(order) => Ok(order),
            Err(_) => Err(DecodeError::InvalidValue),
        }
    }
}
impl Param for GroupOrder {
    /// Writes the variant's `u8` discriminant as a parameter value.
    fn param_encode<W: bytes::BufMut>(&self, w: &mut W, version: Version) -> Result<(), EncodeError> {
        let raw: u8 = (*self).into();
        raw.param_encode(w, version)
    }

    /// Reads a `u8` parameter value leniently.
    ///
    /// Unknown values fall back to `Descending`, and `Any` is resolved to `Descending`
    /// as well. Only errors from decoding the `u8` itself are propagated.
    fn param_decode<R: bytes::Buf>(r: &mut R, version: Version) -> Result<Self, DecodeError> {
        let raw = u8::param_decode(r, version)?;
        Ok(GroupOrder::try_from(raw).map_or(GroupOrder::Descending, GroupOrder::any_to_descending))
    }
}
/// Flags packed into a group (subgroup) stream's type id.
///
/// Converted to and from the numeric id by [`GroupFlags::encode`] and
/// [`GroupFlags::decode`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct GroupFlags {
    /// Type-id bit 0x01.
    pub has_extensions: bool,
    /// Type-id bit 0x04; when set, the header carries an explicit subgroup id.
    pub has_subgroup: bool,
    /// Type-id bit 0x02; may not be combined with `has_subgroup`.
    pub has_subgroup_object: bool,
    /// Type-id bit 0x08.
    pub has_end: bool,
    /// Selects the 0x10-based range (priority byte present) instead of the
    /// 0x30-based range (priority byte omitted).
    pub has_priority: bool,
}
impl GroupFlags {
    /// Lowest type id of the range whose header carries a publisher priority.
    pub const START: u64 = 0x10;
    /// Highest valid type id of the priority-carrying range.
    pub const END: u64 = 0x1d;
    /// Lowest type id of the range whose header omits the publisher priority.
    pub const START_NO_PRIORITY: u64 = 0x30;
    /// Highest valid type id of the priority-less range.
    pub const END_NO_PRIORITY: u64 = 0x3d;
    /// Bit OR-ed into the type id by every version except drafts 14 through 17.
    pub const FIRST_OBJECT_BIT: u64 = 0x40;

    // Per-flag bits, OR-ed onto `START` or `START_NO_PRIORITY`.
    const EXTENSIONS_BIT: u64 = 0x01;
    const SUBGROUP_OBJECT_BIT: u64 = 0x02;
    const SUBGROUP_BIT: u64 = 0x04;
    const END_BIT: u64 = 0x08;

    /// Returns `true` when `version` uses [`Self::FIRST_OBJECT_BIT`] in the type id.
    ///
    /// Shared by `encode` and `decode` so the two can never disagree about which
    /// versions carry the bit.
    fn uses_first_object_bit(version: Version) -> bool {
        !matches!(
            version,
            Version::Draft14 | Version::Draft15 | Version::Draft16 | Version::Draft17
        )
    }

    /// Packs these flags into the stream type id used by `version`.
    ///
    /// # Errors
    ///
    /// Returns [`EncodeError::InvalidState`] if both `has_subgroup` and
    /// `has_subgroup_object` are set; `decode` rejects that combination too.
    pub fn encode(&self, version: Version) -> Result<u64, EncodeError> {
        if self.has_subgroup && self.has_subgroup_object {
            return Err(EncodeError::InvalidState);
        }

        let mut id = if self.has_priority {
            Self::START
        } else {
            Self::START_NO_PRIORITY
        };
        if self.has_extensions {
            id |= Self::EXTENSIONS_BIT;
        }
        if self.has_subgroup_object {
            id |= Self::SUBGROUP_OBJECT_BIT;
        }
        if self.has_subgroup {
            id |= Self::SUBGROUP_BIT;
        }
        if self.has_end {
            id |= Self::END_BIT;
        }
        if Self::uses_first_object_bit(version) {
            id |= Self::FIRST_OBJECT_BIT;
        }
        Ok(id)
    }

    /// Unpacks a stream type id for `version` into flags.
    ///
    /// For versions that use [`Self::FIRST_OBJECT_BIT`], the bit is stripped before the
    /// range check, so ids with or without it are both accepted.
    ///
    /// # Errors
    ///
    /// Returns [`DecodeError::InvalidValue`] in two cases:
    /// - the (stripped) id lies outside both `START..=END` and
    ///   `START_NO_PRIORITY..=END_NO_PRIORITY`;
    /// - the id sets both the subgroup and subgroup-object bits.
    pub fn decode(id: u64, version: Version) -> Result<Self, DecodeError> {
        let id = if Self::uses_first_object_bit(version) {
            id & !Self::FIRST_OBJECT_BIT
        } else {
            id
        };

        // Shift the priority-less range down onto the priority range so a single set
        // of bit tests below serves both.
        let (has_priority, base_id) = if (Self::START..=Self::END).contains(&id) {
            (true, id)
        } else if (Self::START_NO_PRIORITY..=Self::END_NO_PRIORITY).contains(&id) {
            (false, id - (Self::START_NO_PRIORITY - Self::START))
        } else {
            return Err(DecodeError::InvalidValue);
        };

        let has_extensions = (base_id & Self::EXTENSIONS_BIT) != 0;
        let has_subgroup_object = (base_id & Self::SUBGROUP_OBJECT_BIT) != 0;
        let has_subgroup = (base_id & Self::SUBGROUP_BIT) != 0;
        let has_end = (base_id & Self::END_BIT) != 0;
        if has_subgroup && has_subgroup_object {
            return Err(DecodeError::InvalidValue);
        }

        Ok(Self {
            has_extensions,
            has_subgroup,
            has_subgroup_object,
            has_end,
            has_priority,
        })
    }
}
impl Default for GroupFlags {
fn default() -> Self {
Self {
has_extensions: false,
has_subgroup: false,
has_subgroup_object: false,
has_end: true,
has_priority: true,
}
}
}
/// Header of a subgroup data stream (SUBGROUP_HEADER).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct GroupHeader {
    /// Track alias, written right after the type id.
    pub track_alias: u64,
    /// Group id, written after the track alias.
    pub group_id: u64,
    /// Subgroup id. Only on the wire when `flags.has_subgroup` is set; must be 0
    /// otherwise.
    pub sub_group_id: u64,
    /// Publisher priority. Only on the wire when `flags.has_priority` is set.
    pub publisher_priority: u8,
    /// Flags that form the header's type id.
    pub flags: GroupFlags,
}
impl Encode<Version> for GroupHeader {
    /// Serializes the header in wire order: type id (from `flags`), track alias,
    /// group id, then the subgroup id and publisher priority if `flags` selects them.
    ///
    /// # Errors
    ///
    /// Returns [`EncodeError::InvalidState`] in two cases:
    /// - `sub_group_id` is non-zero while `flags.has_subgroup` is unset, so the id would
    ///   be silently dropped;
    /// - `flags` sets both `has_subgroup` and `has_subgroup_object`.
    ///
    /// Both checks run before anything is written, so `w` is untouched on these errors.
    /// Errors from the individual field encoders are propagated.
    fn encode<W: bytes::BufMut>(&self, w: &mut W, version: Version) -> Result<(), EncodeError> {
        tracing::trace!(?self, "encoding group header");

        // Validate up front: the original ordering wrote the type id, track alias and
        // group id before rejecting a stray sub_group_id, leaving a partial header in `w`.
        if !self.flags.has_subgroup && self.sub_group_id != 0 {
            return Err(EncodeError::InvalidState);
        }
        let type_id = self.flags.encode(version)?;

        type_id.encode(w, version)?;
        self.track_alias.encode(w, version)?;
        self.group_id.encode(w, version)?;
        if self.flags.has_subgroup {
            self.sub_group_id.encode(w, version)?;
        }
        if self.flags.has_priority {
            self.publisher_priority.encode(w, version)?;
        }
        Ok(())
    }
}
impl Decode<Version> for GroupHeader {
fn decode<R: bytes::Buf>(r: &mut R, version: Version) -> Result<Self, DecodeError> {
let flags = GroupFlags::decode(u64::decode(r, version)?, version)?;
let track_alias = u64::decode(r, version)?;
let group_id = u64::decode(r, version)?;
let sub_group_id = match flags.has_subgroup {
true => u64::decode(r, version)?,
false => 0,
};
let publisher_priority = if flags.has_priority {
u8::decode(r, version)?
} else {
128 };
let result = Self {
track_alias,
group_id,
sub_group_id,
publisher_priority,
flags,
};
tracing::trace!(?result, "decoded group header");
Ok(result)
}
}
// Unit tests for the GroupFlags <-> type-id mapping and the GroupHeader encoder.
#[cfg(test)]
mod tests {
use super::*;
// Walks the type ids in the priority-carrying range 0x10..=0x1D under draft-14:
// checks the decoded flag bits (0x01 extensions, 0x02 subgroup-object, 0x04 subgroup,
// 0x08 end) and that encode maps the flags back to the same id.
// 0x16 sets both subgroup bits and must be rejected.
#[test]
fn test_group_flags_spec_table() {
let flags = GroupFlags::decode(0x10, Version::Draft14).unwrap();
assert!(!flags.has_subgroup);
assert!(!flags.has_subgroup_object);
assert!(!flags.has_extensions);
assert!(!flags.has_end);
assert!(flags.has_priority);
assert_eq!(flags.encode(Version::Draft14).unwrap(), 0x10);
let flags = GroupFlags::decode(0x11, Version::Draft14).unwrap();
assert!(!flags.has_subgroup);
assert!(!flags.has_subgroup_object);
assert!(flags.has_extensions);
assert!(!flags.has_end);
assert_eq!(flags.encode(Version::Draft14).unwrap(), 0x11);
let flags = GroupFlags::decode(0x12, Version::Draft14).unwrap();
assert!(!flags.has_subgroup);
assert!(flags.has_subgroup_object);
assert!(!flags.has_extensions);
assert!(!flags.has_end);
assert_eq!(flags.encode(Version::Draft14).unwrap(), 0x12);
let flags = GroupFlags::decode(0x13, Version::Draft14).unwrap();
assert!(!flags.has_subgroup);
assert!(flags.has_subgroup_object);
assert!(flags.has_extensions);
assert!(!flags.has_end);
assert_eq!(flags.encode(Version::Draft14).unwrap(), 0x13);
let flags = GroupFlags::decode(0x14, Version::Draft14).unwrap();
assert!(flags.has_subgroup);
assert!(!flags.has_subgroup_object);
assert!(!flags.has_extensions);
assert!(!flags.has_end);
assert_eq!(flags.encode(Version::Draft14).unwrap(), 0x14);
let flags = GroupFlags::decode(0x15, Version::Draft14).unwrap();
assert!(flags.has_subgroup);
assert!(!flags.has_subgroup_object);
assert!(flags.has_extensions);
assert!(!flags.has_end);
assert_eq!(flags.encode(Version::Draft14).unwrap(), 0x15);
let flags = GroupFlags::decode(0x18, Version::Draft14).unwrap();
assert!(!flags.has_subgroup);
assert!(!flags.has_subgroup_object);
assert!(!flags.has_extensions);
assert!(flags.has_end);
assert_eq!(flags.encode(Version::Draft14).unwrap(), 0x18);
let flags = GroupFlags::decode(0x19, Version::Draft14).unwrap();
assert!(!flags.has_subgroup);
assert!(!flags.has_subgroup_object);
assert!(flags.has_extensions);
assert!(flags.has_end);
assert_eq!(flags.encode(Version::Draft14).unwrap(), 0x19);
let flags = GroupFlags::decode(0x1A, Version::Draft14).unwrap();
assert!(!flags.has_subgroup);
assert!(flags.has_subgroup_object);
assert!(!flags.has_extensions);
assert!(flags.has_end);
assert_eq!(flags.encode(Version::Draft14).unwrap(), 0x1A);
let flags = GroupFlags::decode(0x1B, Version::Draft14).unwrap();
assert!(!flags.has_subgroup);
assert!(flags.has_subgroup_object);
assert!(flags.has_extensions);
assert!(flags.has_end);
assert_eq!(flags.encode(Version::Draft14).unwrap(), 0x1B);
let flags = GroupFlags::decode(0x1C, Version::Draft14).unwrap();
assert!(flags.has_subgroup);
assert!(!flags.has_subgroup_object);
assert!(!flags.has_extensions);
assert!(flags.has_end);
assert_eq!(flags.encode(Version::Draft14).unwrap(), 0x1C);
let flags = GroupFlags::decode(0x1D, Version::Draft14).unwrap();
assert!(flags.has_subgroup);
assert!(!flags.has_subgroup_object);
assert!(flags.has_extensions);
assert!(flags.has_end);
assert_eq!(flags.encode(Version::Draft14).unwrap(), 0x1D);
// 0x16 = START | subgroup (0x04) | subgroup-object (0x02): mutually exclusive bits.
assert!(GroupFlags::decode(0x16, Version::Draft14).is_err());
}
// Same bit layout offset to 0x30..=0x3D: decoded flags report no priority and
// re-encode to the same id. 0x36 (both subgroup bits) must be rejected.
#[test]
fn test_group_flags_no_priority_range() {
let flags = GroupFlags::decode(0x30, Version::Draft14).unwrap();
assert!(!flags.has_priority);
assert!(!flags.has_subgroup);
assert!(!flags.has_extensions);
assert!(!flags.has_end);
assert_eq!(flags.encode(Version::Draft14).unwrap(), 0x30);
let flags = GroupFlags::decode(0x38, Version::Draft14).unwrap();
assert!(!flags.has_priority);
assert!(flags.has_end);
assert_eq!(flags.encode(Version::Draft14).unwrap(), 0x38);
let flags = GroupFlags::decode(0x3D, Version::Draft14).unwrap();
assert!(!flags.has_priority);
assert!(flags.has_subgroup);
assert!(flags.has_extensions);
assert!(flags.has_end);
assert_eq!(flags.encode(Version::Draft14).unwrap(), 0x3D);
assert!(GroupFlags::decode(0x36, Version::Draft14).is_err());
}
// Draft-18 encoding differs from draft-17 only by FIRST_OBJECT_BIT.
// Decoding under draft-18 strips the bit and round-trips the flags.
// Draft-17 treats the same id as out of range.
#[test]
fn test_first_object_bit_draft18() {
let flags = GroupFlags::default();
let encoded = flags.encode(Version::Draft18).unwrap();
assert_eq!(encoded & GroupFlags::FIRST_OBJECT_BIT, GroupFlags::FIRST_OBJECT_BIT);
let v17 = flags.encode(Version::Draft17).unwrap();
assert_eq!(encoded, v17 | GroupFlags::FIRST_OBJECT_BIT);
let decoded = GroupFlags::decode(v17 | GroupFlags::FIRST_OBJECT_BIT, Version::Draft18).unwrap();
assert_eq!(decoded, flags);
assert!(GroupFlags::decode(v17 | GroupFlags::FIRST_OBJECT_BIT, Version::Draft17).is_err());
}
// 0x70 and 0x7D are 0x30 and 0x3D with FIRST_OBJECT_BIT set; under draft-18 they
// decode as the priority-less range.
#[test]
fn test_draft18_extended_range() {
let flags = GroupFlags::decode(0x70, Version::Draft18).unwrap();
assert!(!flags.has_priority);
assert!(!flags.has_subgroup);
assert!(!flags.has_extensions);
assert!(!flags.has_end);
let flags = GroupFlags::decode(0x7D, Version::Draft18).unwrap();
assert!(!flags.has_priority);
assert!(flags.has_subgroup);
assert!(flags.has_extensions);
assert!(flags.has_end);
}
// Encodes a default-flag header under draft-18 and checks the first byte has bit 0x10
// set and bit 0x80 clear (the mask the uni-stream classifier is expected to use).
// NOTE(review): the default type id here is 0x10 | 0x08 | 0x40 = 0x58 (88).
// If `u64` encodes as a QUIC varint (one byte only for 0..=63), this value takes two
// bytes and `buf[0]` is the length-prefixed high byte, not the type id itself.
// Confirm against `crate::coding` and the real classifier.
#[test]
fn test_draft18_group_header_passes_stream_classifier() {
let header = GroupHeader {
track_alias: 1,
group_id: 0,
sub_group_id: 0,
publisher_priority: 0,
flags: GroupFlags::default(),
};
let mut buf = bytes::BytesMut::new();
header.encode(&mut buf, Version::Draft18).unwrap();
let type_byte = buf[0] as u64;
assert_eq!(
type_byte & 0x90,
0x10,
"draft-18 SUBGROUP_HEADER type 0x{type_byte:02x} not recognized by uni-stream classifier",
);
}
}