use bytes::{Buf, BufMut};
use super::types::ObjectStatus;
use crate::error::CodecError;
use crate::varint::VarInt;
/// Type value that opens a subgroup data stream.
///
/// The low nibble is a set of flags on top of the `0x10` base:
/// - `0x01`: objects on the stream carry extension headers,
/// - `0x02`: the subgroup ID equals the first object's ID (only when `0x04` is clear),
/// - `0x04`: the header carries an explicit subgroup ID field,
/// - `0x08`: the stream contains the end of the group.
///
/// [`SubgroupStreamType::from_u8`] accepts only `0x10..=0x15` and `0x18..=0x1D`; codes that
/// set both `0x02` and `0x04` are undefined.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SubgroupStreamType(u8);
impl SubgroupStreamType {
    const BASE: u8 = 0x10;
    const EXTENSIONS: u8 = 0x01;
    const FIRST_OBJECT: u8 = 0x02;
    const EXPLICIT_ID: u8 = 0x04;
    const END_OF_GROUP: u8 = 0x08;

    /// Raw type value as written on the wire.
    pub fn as_u8(self) -> u8 {
        let SubgroupStreamType(raw) = self;
        raw
    }

    /// Validates a raw type value; returns `None` for undefined codes.
    pub fn from_u8(v: u8) -> Option<Self> {
        match v {
            0x10..=0x15 | 0x18..=0x1D => Some(SubgroupStreamType(v)),
            _ => None,
        }
    }

    /// Builds a type from individual flags.
    ///
    /// An explicit subgroup ID field takes precedence: `subgroup_id_is_first_object` is
    /// ignored when `subgroup_id_field_present` is true, so the result is always a code
    /// accepted by [`SubgroupStreamType::from_u8`].
    pub fn from_flags(
        subgroup_id_field_present: bool,
        subgroup_id_is_first_object: bool,
        extensions_present: bool,
        end_of_group: bool,
    ) -> Self {
        let id_bits = match (subgroup_id_field_present, subgroup_id_is_first_object) {
            (true, _) => Self::EXPLICIT_ID,
            (false, true) => Self::FIRST_OBJECT,
            (false, false) => 0,
        };
        let ext_bit = if extensions_present { Self::EXTENSIONS } else { 0 };
        let eog_bit = if end_of_group { Self::END_OF_GROUP } else { 0 };
        SubgroupStreamType(Self::BASE | id_bits | ext_bit | eog_bit)
    }

    /// True when every bit of `mask` is set.
    fn has(self, mask: u8) -> bool {
        self.0 & mask == mask
    }

    /// Whether the subgroup header carries an explicit subgroup ID.
    pub fn has_subgroup_id_field(self) -> bool {
        self.has(Self::EXPLICIT_ID)
    }

    /// Whether the subgroup ID is implied to be the first object's ID.
    pub fn subgroup_id_is_first_object(self) -> bool {
        self.has(Self::FIRST_OBJECT) && !self.has_subgroup_id_field()
    }

    /// Whether each object carries a length-prefixed extension header block.
    pub fn extensions_present(self) -> bool {
        self.has(Self::EXTENSIONS)
    }

    /// Whether the stream contains the end of its group.
    pub fn contains_end_of_group(self) -> bool {
        self.has(Self::END_OF_GROUP)
    }
}
/// Header that opens a subgroup data stream.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SubgroupHeader {
    /// Stream type; decides whether `subgroup_id` appears on the wire.
    pub stream_type: SubgroupStreamType,
    pub track_alias: VarInt,
    pub group_id: VarInt,
    /// Explicit subgroup ID. Only encoded when `stream_type.has_subgroup_id_field()`;
    /// `decode` leaves it `None` for every other stream type.
    pub subgroup_id: Option<VarInt>,
    pub publisher_priority: u8,
}
impl SubgroupHeader {
    /// Writes the header: type, track alias, group ID, optional subgroup ID, priority byte.
    ///
    /// If the stream type needs an explicit subgroup ID but `subgroup_id` is `None`, `0` is
    /// written in its place. A `Some` subgroup ID is not written when the type has no ID field.
    pub fn encode(&self, buf: &mut impl BufMut) {
        let type_code = u64::from(self.stream_type.as_u8());
        VarInt::from_u64(type_code).unwrap().encode(buf);
        self.track_alias.encode(buf);
        self.group_id.encode(buf);
        if self.stream_type.has_subgroup_id_field() {
            match self.subgroup_id {
                Some(id) => id.encode(buf),
                None => VarInt::from_u64(0).unwrap().encode(buf),
            }
        }
        buf.put_u8(self.publisher_priority);
    }

    /// Parses a header written by [`SubgroupHeader::encode`].
    ///
    /// # Errors
    /// `CodecError::InvalidField` for an undefined stream type, `CodecError::UnexpectedEnd`
    /// if the priority byte is missing; errors from `VarInt::decode` are propagated.
    pub fn decode(buf: &mut impl Buf) -> Result<Self, CodecError> {
        let raw_type = VarInt::decode(buf)?.into_inner();
        // Anything wider than one byte cannot be a subgroup stream type.
        let known = if raw_type > 0xFF { None } else { SubgroupStreamType::from_u8(raw_type as u8) };
        let stream_type = known.ok_or(CodecError::InvalidField)?;
        let track_alias = VarInt::decode(buf)?;
        let group_id = VarInt::decode(buf)?;
        let mut subgroup_id = None;
        if stream_type.has_subgroup_id_field() {
            subgroup_id = Some(VarInt::decode(buf)?);
        }
        if buf.remaining() == 0 {
            return Err(CodecError::UnexpectedEnd);
        }
        let publisher_priority = buf.get_u8();
        Ok(Self { stream_type, track_alias, group_id, subgroup_id, publisher_priority })
    }
}
/// One object carried on a subgroup data stream.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SubgroupObject {
    /// Absolute object ID; on the wire it is delta-coded by [`SubgroupObjectReader`].
    pub object_id: VarInt,
    /// Raw extension header bytes. Only encoded when the stream type has extensions;
    /// otherwise ignored on write and empty on read.
    pub extension_headers: Vec<u8>,
    /// Object status. On the wire a status is present exactly when the payload length is 0.
    pub status: Option<ObjectStatus>,
    /// Object payload bytes.
    pub payload: Vec<u8>,
}
/// Stateful codec for the objects that follow a [`SubgroupHeader`] on one stream.
///
/// Object IDs are delta-coded: the first object carries its ID as-is, and every later object
/// carries `id - previous_id - 1`. Despite the name, the type is also used for writing. Keep
/// one instance per stream and direction, because each tracks the last ID it handled.
#[derive(Debug, Clone)]
pub struct SubgroupObjectReader {
    /// Whether each object carries a length-prefixed extension header block.
    extensions_present: bool,
    /// Last object ID successfully read or written; `None` before the first object.
    prev_object_id: Option<u64>,
}
impl SubgroupObjectReader {
    /// Creates a codec for a stream opened with `header`.
    pub fn new(header: &SubgroupHeader) -> Self {
        Self { extensions_present: header.stream_type.extensions_present(), prev_object_id: None }
    }

    /// Decodes the next object from `buf`.
    ///
    /// The delta base advances only after the whole object has decoded. After an error
    /// (e.g. a truncated object), the caller may therefore retry from the same starting
    /// bytes without corrupting later object IDs.
    ///
    /// # Errors
    /// Returns `CodecError::InvalidField` if:
    /// - the reconstructed object ID overflows `u64`,
    /// - the object ID is not a valid varint, or
    /// - the status code is unknown.
    ///
    /// Errors from `VarInt::decode` and `read_bytes` are propagated unchanged.
    pub fn read_object(&mut self, buf: &mut impl Buf) -> Result<SubgroupObject, CodecError> {
        let delta = VarInt::decode(buf)?.into_inner();
        let object_id_val = match self.prev_object_id {
            None => delta,
            Some(prev) => prev
                .checked_add(1)
                .and_then(|v| v.checked_add(delta))
                .ok_or(CodecError::InvalidField)?,
        };
        let object_id = VarInt::from_u64(object_id_val).map_err(|_| CodecError::InvalidField)?;
        let extension_headers = if self.extensions_present {
            let ext_len = VarInt::decode(buf)?.into_inner() as usize;
            crate::types::read_bytes(buf, ext_len)?
        } else {
            Vec::new()
        };
        let payload_length = VarInt::decode(buf)?.into_inner() as usize;
        let (status, payload) = if payload_length == 0 {
            // A zero payload length is always followed by an object status.
            let status_val = VarInt::decode(buf)?.into_inner();
            let status = ObjectStatus::from_u64(status_val).ok_or(CodecError::InvalidField)?;
            (Some(status), Vec::new())
        } else {
            (None, crate::types::read_bytes(buf, payload_length)?)
        };
        // Commit the delta base only now that the object has been fully consumed. Updating
        // it earlier would make a retry after a partial read compute shifted object IDs.
        self.prev_object_id = Some(object_id_val);
        Ok(SubgroupObject { object_id, extension_headers, status, payload })
    }

    /// Encodes `object` into `buf`, delta-coding its ID against the previously written one.
    ///
    /// If `status` is `Some`, a zero length and the status are written, and `payload` is not
    /// written. An object with no status and an empty payload is written as a zero length
    /// followed by `ObjectStatus::Normal`. Readers always expect a status after a zero length;
    /// omitting it would make the next object's bytes be read as this object's status.
    ///
    /// # Errors
    /// `CodecError::InvalidField` if the object ID is not strictly greater than the last one
    /// written, or if a length does not fit in a varint. The writer state advances only on
    /// success, but bytes may already have been appended to `buf` when a length check fails.
    pub fn write_object(
        &mut self,
        object: &SubgroupObject,
        buf: &mut impl BufMut,
    ) -> Result<(), CodecError> {
        let oid = object.object_id.into_inner();
        let delta = match self.prev_object_id {
            None => oid,
            Some(prev) => oid
                .checked_sub(prev)
                .and_then(|v| v.checked_sub(1))
                .ok_or(CodecError::InvalidField)?,
        };
        VarInt::from_u64(delta).map_err(|_| CodecError::InvalidField)?.encode(buf);
        if self.extensions_present {
            VarInt::from_u64(object.extension_headers.len() as u64)
                .map_err(|_| CodecError::InvalidField)?
                .encode(buf);
            buf.put_slice(&object.extension_headers);
        }
        let status = match (object.status, object.payload.is_empty()) {
            (Some(status), _) => Some(status),
            (None, true) => Some(ObjectStatus::Normal),
            (None, false) => None,
        };
        if let Some(status) = status {
            VarInt::from_u64(0).unwrap().encode(buf);
            VarInt::from_u64(status.as_u64()).unwrap().encode(buf);
        } else {
            VarInt::from_u64(object.payload.len() as u64)
                .map_err(|_| CodecError::InvalidField)?
                .encode(buf);
            buf.put_slice(&object.payload);
        }
        self.prev_object_id = Some(oid);
        Ok(())
    }
}
/// Stream type value that opens a fetch data stream.
pub const FETCH_STREAM_TYPE: u8 = 0x05;
/// Header of a fetch data stream: the fetch stream type followed by a request ID.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FetchHeader {
    /// Request ID carried right after the stream type.
    pub request_id: VarInt,
}
impl FetchHeader {
    /// Writes [`FETCH_STREAM_TYPE`] as a varint, then `request_id`.
    pub fn encode(&self, buf: &mut impl BufMut) {
        let stream_type = VarInt::from_u64(u64::from(FETCH_STREAM_TYPE)).unwrap();
        stream_type.encode(buf);
        self.request_id.encode(buf);
    }

    /// Parses a fetch header.
    ///
    /// # Errors
    /// `CodecError::InvalidField` if the leading type is not [`FETCH_STREAM_TYPE`]; errors
    /// from `VarInt::decode` are propagated.
    pub fn decode(buf: &mut impl Buf) -> Result<Self, CodecError> {
        let is_fetch = VarInt::decode(buf)?.into_inner() == u64::from(FETCH_STREAM_TYPE);
        if !is_fetch {
            return Err(CodecError::InvalidField);
        }
        let request_id = VarInt::decode(buf)?;
        Ok(Self { request_id })
    }
}
/// One object on a fetch data stream.
///
/// Unlike subgroup stream objects, each fetch object carries its own group, subgroup and
/// object IDs and its publisher priority.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FetchObject {
    pub group_id: VarInt,
    pub subgroup_id: VarInt,
    pub object_id: VarInt,
    pub publisher_priority: u8,
    /// Raw extension header bytes; always length-prefixed on the wire (possibly empty).
    pub extension_headers: Vec<u8>,
    /// Object status. On the wire a status is present exactly when the payload length is 0.
    pub status: Option<ObjectStatus>,
    pub payload: Vec<u8>,
}
impl FetchObject {
    /// Serialises the object in this order:
    /// 1. group ID, subgroup ID, object ID,
    /// 2. priority byte,
    /// 3. length-prefixed extension headers,
    /// 4. either a length-prefixed payload or a zero length plus status.
    ///
    /// If `status` is `Some`, a zero length and the status are written, and `payload` is not
    /// written. With no status and an empty payload, a zero length followed by
    /// `ObjectStatus::Normal` is written. [`FetchObject::decode`] always reads a status after
    /// a zero length, so omitting it would desynchronise the stream.
    pub fn encode(&self, buf: &mut impl BufMut) {
        self.group_id.encode(buf);
        self.subgroup_id.encode(buf);
        self.object_id.encode(buf);
        buf.put_u8(self.publisher_priority);
        VarInt::from_u64(self.extension_headers.len() as u64).unwrap().encode(buf);
        buf.put_slice(&self.extension_headers);
        let status = match (self.status, self.payload.is_empty()) {
            (Some(status), _) => Some(status),
            (None, true) => Some(ObjectStatus::Normal),
            (None, false) => None,
        };
        if let Some(status) = status {
            VarInt::from_u64(0).unwrap().encode(buf);
            VarInt::from_u64(status.as_u64()).unwrap().encode(buf);
        } else {
            VarInt::from_u64(self.payload.len() as u64).unwrap().encode(buf);
            buf.put_slice(&self.payload);
        }
    }

    /// Parses one object in the layout written by [`FetchObject::encode`].
    ///
    /// # Errors
    /// `CodecError::UnexpectedEnd` if the priority byte is missing, `CodecError::InvalidField`
    /// for an unknown status code; errors from `VarInt::decode` / `read_bytes` are propagated.
    pub fn decode(buf: &mut impl Buf) -> Result<Self, CodecError> {
        let group_id = VarInt::decode(buf)?;
        let subgroup_id = VarInt::decode(buf)?;
        let object_id = VarInt::decode(buf)?;
        if buf.remaining() < 1 {
            return Err(CodecError::UnexpectedEnd);
        }
        let publisher_priority = buf.get_u8();
        let ext_len = VarInt::decode(buf)?.into_inner() as usize;
        let extension_headers = crate::types::read_bytes(buf, ext_len)?;
        let payload_length = VarInt::decode(buf)?.into_inner() as usize;
        let (status, payload) = if payload_length == 0 {
            // A zero payload length is always followed by an object status.
            let status_val = VarInt::decode(buf)?.into_inner();
            let status = ObjectStatus::from_u64(status_val).ok_or(CodecError::InvalidField)?;
            (Some(status), Vec::new())
        } else {
            (None, crate::types::read_bytes(buf, payload_length)?)
        };
        Ok(FetchObject {
            group_id,
            subgroup_id,
            object_id,
            publisher_priority,
            extension_headers,
            status,
            payload,
        })
    }
}
/// Type value that starts an object datagram.
///
/// Payload types `0x00..=0x07` use their low bits as flags:
/// - `0x01`: extension headers present,
/// - `0x02`: end of group,
/// - `0x04`: object ID omitted.
///
/// Status types `0x20` and `0x21` carry an object status instead of a payload; `0x21` also
/// has extension headers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DatagramType(u8);
impl DatagramType {
    const EXTENSIONS: u8 = 0x01;
    const END_OF_GROUP: u8 = 0x02;
    const NO_OBJECT_ID: u8 = 0x04;
    const STATUS: u8 = 0x20;

    /// Raw type value as written on the wire.
    pub fn as_u8(self) -> u8 {
        let DatagramType(raw) = self;
        raw
    }

    /// Validates a raw type value; returns `None` for undefined codes.
    pub fn from_u8(v: u8) -> Option<Self> {
        match v {
            0x00..=0x07 | 0x20 | 0x21 => Some(Self(v)),
            _ => None,
        }
    }

    /// Builds a payload-carrying datagram type from its flags.
    pub fn payload(object_id_present: bool, extensions_present: bool, end_of_group: bool) -> Self {
        let flags = [
            (extensions_present, Self::EXTENSIONS),
            (end_of_group, Self::END_OF_GROUP),
            (!object_id_present, Self::NO_OBJECT_ID),
        ];
        let bits = flags.iter().filter(|(on, _)| *on).fold(0u8, |acc, (_, bit)| acc | bit);
        Self(bits)
    }

    /// Builds a status-carrying datagram type (`0x20`, or `0x21` with extensions).
    pub fn status(extensions_present: bool) -> Self {
        let ext = if extensions_present { Self::EXTENSIONS } else { 0 };
        Self(Self::STATUS | ext)
    }

    /// Whether this type carries an object status rather than a payload.
    pub fn is_status(self) -> bool {
        self.0 >= Self::STATUS
    }

    /// Whether the object ID is on the wire (always true for status types).
    pub fn object_id_present(self) -> bool {
        self.is_status() || self.0 & Self::NO_OBJECT_ID == 0
    }

    /// Whether the datagram marks the end of its group (never for status types).
    pub fn end_of_group(self) -> bool {
        !self.is_status() && self.0 & Self::END_OF_GROUP != 0
    }

    /// Whether a length-prefixed extension header block follows the priority byte.
    pub fn extensions_present(self) -> bool {
        self.0 & Self::EXTENSIONS != 0
    }
}
/// A complete object datagram.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DatagramObject {
    /// Decides which of the optional fields below appear on the wire.
    pub datagram_type: DatagramType,
    pub track_alias: VarInt,
    pub group_id: VarInt,
    /// Not encoded when `datagram_type` omits the object ID; decoded as `0` in that case.
    pub object_id: VarInt,
    pub publisher_priority: u8,
    /// Encoded (length-prefixed) only when `datagram_type.extensions_present()`.
    pub extension_headers: Vec<u8>,
    /// Used only by status datagram types; `None` is encoded as `ObjectStatus::Normal`.
    pub status: Option<ObjectStatus>,
    /// Used only by payload datagram types; not length-prefixed, runs to the end of `buf`.
    pub payload: Vec<u8>,
}
impl DatagramObject {
    /// Serialises the datagram: type, track alias, group ID, optional object ID, priority,
    /// optional extension headers, then either a status or the raw payload.
    pub fn encode(&self, buf: &mut impl BufMut) {
        let kind = self.datagram_type;
        VarInt::from_u64(u64::from(kind.as_u8())).unwrap().encode(buf);
        self.track_alias.encode(buf);
        self.group_id.encode(buf);
        if kind.object_id_present() {
            self.object_id.encode(buf);
        }
        buf.put_u8(self.publisher_priority);
        if kind.extensions_present() {
            let ext_len = self.extension_headers.len() as u64;
            VarInt::from_u64(ext_len).unwrap().encode(buf);
            buf.put_slice(&self.extension_headers);
        }
        if !kind.is_status() {
            // The payload has no length prefix: it occupies the rest of the datagram.
            buf.put_slice(&self.payload);
            return;
        }
        let status = self.status.unwrap_or(ObjectStatus::Normal);
        VarInt::from_u64(status.as_u64()).unwrap().encode(buf);
    }

    /// Parses a whole datagram; for payload types every remaining byte becomes the payload.
    ///
    /// # Errors
    /// `CodecError::InvalidField` for an undefined type or unknown status code,
    /// `CodecError::UnexpectedEnd` if the priority byte is missing; errors from
    /// `VarInt::decode` / `read_bytes` are propagated.
    pub fn decode(buf: &mut impl Buf) -> Result<Self, CodecError> {
        let raw_type = VarInt::decode(buf)?.into_inner();
        // Anything wider than one byte cannot be a datagram type.
        let known = if raw_type > 0xFF { None } else { DatagramType::from_u8(raw_type as u8) };
        let datagram_type = known.ok_or(CodecError::InvalidField)?;
        let track_alias = VarInt::decode(buf)?;
        let group_id = VarInt::decode(buf)?;
        let object_id = if datagram_type.object_id_present() {
            VarInt::decode(buf)?
        } else {
            VarInt::from_u64(0).unwrap()
        };
        if buf.remaining() == 0 {
            return Err(CodecError::UnexpectedEnd);
        }
        let publisher_priority = buf.get_u8();
        let mut extension_headers = Vec::new();
        if datagram_type.extensions_present() {
            let ext_len = VarInt::decode(buf)?.into_inner() as usize;
            extension_headers = crate::types::read_bytes(buf, ext_len)?;
        }
        let (status, payload) = if datagram_type.is_status() {
            let code = VarInt::decode(buf)?.into_inner();
            let status = ObjectStatus::from_u64(code).ok_or(CodecError::InvalidField)?;
            (Some(status), Vec::new())
        } else {
            let rest = buf.remaining();
            (None, crate::types::read_bytes(buf, rest)?)
        };
        Ok(Self {
            datagram_type,
            track_alias,
            group_id,
            object_id,
            publisher_priority,
            extension_headers,
            status,
            payload,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `VarInt` from a value the test knows is in range.
    fn vi(v: u64) -> VarInt {
        VarInt::from_u64(v).unwrap()
    }

    /// Header shared by the object reader/writer tests: the given stream type, track alias 1,
    /// group 0, no explicit subgroup ID, priority 0.
    fn subgroup_header(stream_type: u8) -> SubgroupHeader {
        SubgroupHeader {
            stream_type: SubgroupStreamType::from_u8(stream_type).unwrap(),
            track_alias: vi(1),
            group_id: vi(0),
            subgroup_id: None,
            publisher_priority: 0,
        }
    }

    /// Object with the given ID and payload, no extension headers and no status.
    fn payload_object(id: u64, payload: Vec<u8>) -> SubgroupObject {
        SubgroupObject { object_id: vi(id), extension_headers: vec![], status: None, payload }
    }

    #[test]
    fn subgroup_type_0x10_all_off() {
        let t = SubgroupStreamType::from_u8(0x10).unwrap();
        assert!(!t.has_subgroup_id_field());
        assert!(!t.subgroup_id_is_first_object());
        assert!(!t.extensions_present());
        assert!(!t.contains_end_of_group());
    }

    #[test]
    fn subgroup_type_0x15_explicit_with_ext() {
        let t = SubgroupStreamType::from_u8(0x15).unwrap();
        assert!(t.has_subgroup_id_field());
        assert!(!t.subgroup_id_is_first_object());
        assert!(t.extensions_present());
        assert!(!t.contains_end_of_group());
    }

    #[test]
    fn subgroup_type_0x1d_all_on() {
        let t = SubgroupStreamType::from_u8(0x1D).unwrap();
        assert!(t.has_subgroup_id_field());
        // An explicit ID field overrides the first-object flag.
        assert!(!t.subgroup_id_is_first_object());
        assert!(t.extensions_present());
        assert!(t.contains_end_of_group());
    }

    #[test]
    fn subgroup_type_0x12_first_object() {
        let t = SubgroupStreamType::from_u8(0x12).unwrap();
        assert!(!t.has_subgroup_id_field());
        assert!(t.subgroup_id_is_first_object());
        assert!(!t.extensions_present());
        assert!(!t.contains_end_of_group());
    }

    #[test]
    fn subgroup_type_rejects_undefined() {
        for bad in [0x00u8, 0x0F, 0x16, 0x17, 0x1E, 0x1F, 0x20] {
            assert!(SubgroupStreamType::from_u8(bad).is_none(), "0x{bad:02x} should be rejected");
        }
    }

    #[test]
    fn subgroup_type_from_flags_roundtrip() {
        // Enumerate all 16 flag combinations via the bits of a counter.
        for bits in 0u8..16 {
            let (f_sg, f_first, f_ext, f_eog) =
                (bits & 1 != 0, bits & 2 != 0, bits & 4 != 0, bits & 8 != 0);
            let t = SubgroupStreamType::from_flags(f_sg, f_first, f_ext, f_eog);
            let raw = t.as_u8();
            assert_eq!(SubgroupStreamType::from_u8(raw), Some(t), "undefined code 0x{raw:02x}");
            assert_eq!(t.has_subgroup_id_field(), f_sg);
            // The first-object flag only survives when there is no explicit ID field.
            assert_eq!(t.subgroup_id_is_first_object(), f_first && !f_sg);
            assert_eq!(t.extensions_present(), f_ext);
            assert_eq!(t.contains_end_of_group(), f_eog);
        }
    }

    #[test]
    fn subgroup_header_roundtrip_0x10() {
        let h = SubgroupHeader {
            stream_type: SubgroupStreamType::from_u8(0x10).unwrap(),
            track_alias: vi(1),
            group_id: vi(0),
            subgroup_id: None,
            publisher_priority: 128,
        };
        let mut buf = Vec::new();
        h.encode(&mut buf);
        assert_eq!(buf[0], 0x10);
        let decoded = SubgroupHeader::decode(&mut &buf[..]).unwrap();
        assert_eq!(decoded, h);
    }

    #[test]
    fn subgroup_header_roundtrip_explicit_subgroup() {
        let h = SubgroupHeader {
            stream_type: SubgroupStreamType::from_u8(0x14).unwrap(),
            track_alias: vi(5),
            group_id: vi(10),
            subgroup_id: Some(vi(2)),
            publisher_priority: 64,
        };
        let mut buf = Vec::new();
        h.encode(&mut buf);
        let decoded = SubgroupHeader::decode(&mut &buf[..]).unwrap();
        assert_eq!(decoded, h);
    }

    #[test]
    fn subgroup_header_first_object_type_has_no_id_on_wire() {
        let h = SubgroupHeader {
            stream_type: SubgroupStreamType::from_u8(0x12).unwrap(),
            track_alias: vi(3),
            group_id: vi(4),
            subgroup_id: None,
            publisher_priority: 7,
        };
        let mut buf = Vec::new();
        h.encode(&mut buf);
        // type, track alias, group ID, priority: no subgroup ID field.
        assert_eq!(buf, vec![0x12, 0x03, 0x04, 0x07]);
        assert_eq!(SubgroupHeader::decode(&mut &buf[..]).unwrap(), h);
    }

    #[test]
    fn subgroup_header_decode_rejects_bad_type() {
        let buf = [0x16u8, 0x01, 0x00, 0x80];
        let err = SubgroupHeader::decode(&mut &buf[..]).unwrap_err();
        assert!(matches!(err, CodecError::InvalidField));
    }

    #[test]
    fn subgroup_reader_delta_sequential_ids() {
        let header = subgroup_header(0x10);
        let mut writer = SubgroupObjectReader::new(&header);
        let mut buf = Vec::new();
        for i in 0..3u64 {
            writer.write_object(&payload_object(i, vec![0xAA + i as u8; 4]), &mut buf).unwrap();
        }
        let mut reader = SubgroupObjectReader::new(&header);
        let mut cursor = &buf[..];
        let o0 = reader.read_object(&mut cursor).unwrap();
        assert_eq!(o0.object_id.into_inner(), 0);
        assert_eq!(o0.payload, vec![0xAA; 4]);
        let o1 = reader.read_object(&mut cursor).unwrap();
        assert_eq!(o1.object_id.into_inner(), 1);
        assert_eq!(o1.payload, vec![0xAB; 4]);
        let o2 = reader.read_object(&mut cursor).unwrap();
        assert_eq!(o2.object_id.into_inner(), 2);
        assert_eq!(o2.payload, vec![0xAC; 4]);
        assert!(cursor.is_empty());
    }

    #[test]
    fn subgroup_reader_delta_sparse_ids() {
        let header = subgroup_header(0x10);
        let mut writer = SubgroupObjectReader::new(&header);
        let mut buf = Vec::new();
        for id in [5u64, 10, 11] {
            writer.write_object(&payload_object(id, vec![1, 2, 3]), &mut buf).unwrap();
        }
        // Each object is delta (1 byte) + length (1 byte) + 3 payload bytes; the deltas are
        // the first ID itself, then `id - prev - 1`.
        let deltas: Vec<u8> = buf.chunks(5).map(|c| c[0]).collect();
        assert_eq!(deltas, vec![5, 4, 0]);
        let mut reader = SubgroupObjectReader::new(&header);
        let mut cursor = &buf[..];
        assert_eq!(reader.read_object(&mut cursor).unwrap().object_id.into_inner(), 5);
        assert_eq!(reader.read_object(&mut cursor).unwrap().object_id.into_inner(), 10);
        assert_eq!(reader.read_object(&mut cursor).unwrap().object_id.into_inner(), 11);
        assert!(cursor.is_empty());
    }

    #[test]
    fn subgroup_reader_with_extensions() {
        let header = subgroup_header(0x11);
        let mut writer = SubgroupObjectReader::new(&header);
        let mut buf = Vec::new();
        let object = SubgroupObject {
            object_id: vi(0),
            extension_headers: vec![0x01, 0x02, 0x03],
            status: None,
            payload: vec![0xFF],
        };
        writer.write_object(&object, &mut buf).unwrap();
        let mut reader = SubgroupObjectReader::new(&header);
        let o = reader.read_object(&mut &buf[..]).unwrap();
        assert_eq!(o.extension_headers, vec![0x01, 0x02, 0x03]);
        assert_eq!(o.payload, vec![0xFF]);
    }

    #[test]
    fn subgroup_reader_status_object() {
        let header = subgroup_header(0x10);
        let mut writer = SubgroupObjectReader::new(&header);
        let mut buf = Vec::new();
        let object = SubgroupObject {
            object_id: vi(7),
            extension_headers: vec![],
            status: Some(ObjectStatus::EndOfGroup),
            payload: vec![],
        };
        writer.write_object(&object, &mut buf).unwrap();
        let mut reader = SubgroupObjectReader::new(&header);
        let o = reader.read_object(&mut &buf[..]).unwrap();
        assert_eq!(o.object_id.into_inner(), 7);
        assert_eq!(o.status, Some(ObjectStatus::EndOfGroup));
        assert!(o.payload.is_empty());
    }

    #[test]
    fn fetch_header_roundtrip() {
        let h = FetchHeader { request_id: vi(99) };
        let mut buf = Vec::new();
        h.encode(&mut buf);
        assert_eq!(buf[0], 0x05);
        assert_eq!(FetchHeader::decode(&mut &buf[..]).unwrap(), h);
    }

    #[test]
    fn fetch_header_rejects_wrong_type() {
        let buf = [0x10u8, 0x05];
        assert!(FetchHeader::decode(&mut &buf[..]).is_err());
    }

    #[test]
    fn fetch_object_roundtrip_with_payload() {
        let obj = FetchObject {
            group_id: vi(3),
            subgroup_id: vi(1),
            object_id: vi(7),
            publisher_priority: 200,
            extension_headers: vec![0xAA, 0xBB],
            status: None,
            payload: vec![1, 2, 3, 4],
        };
        let mut buf = Vec::new();
        obj.encode(&mut buf);
        assert_eq!(FetchObject::decode(&mut &buf[..]).unwrap(), obj);
    }

    #[test]
    fn fetch_object_roundtrip_status() {
        let obj = FetchObject {
            group_id: vi(3),
            subgroup_id: vi(1),
            object_id: vi(8),
            publisher_priority: 200,
            extension_headers: vec![],
            status: Some(ObjectStatus::ObjectDoesNotExist),
            payload: vec![],
        };
        let mut buf = Vec::new();
        obj.encode(&mut buf);
        assert_eq!(FetchObject::decode(&mut &buf[..]).unwrap(), obj);
    }

    #[test]
    fn datagram_type_variants() {
        let t0 = DatagramType::from_u8(0x00).unwrap();
        assert!(t0.object_id_present());
        assert!(!t0.extensions_present());
        assert!(!t0.end_of_group());
        assert!(!t0.is_status());
        let t7 = DatagramType::from_u8(0x07).unwrap();
        assert!(!t7.object_id_present());
        assert!(t7.extensions_present());
        assert!(t7.end_of_group());
        assert!(!t7.is_status());
        let t20 = DatagramType::from_u8(0x20).unwrap();
        assert!(t20.is_status());
        assert!(!t20.extensions_present());
        assert!(t20.object_id_present());
        assert!(!t20.end_of_group());
        let t21 = DatagramType::from_u8(0x21).unwrap();
        assert!(t21.is_status());
        assert!(t21.extensions_present());
    }

    #[test]
    fn datagram_type_constructors_match_from_u8() {
        assert_eq!(DatagramType::payload(true, false, false), DatagramType::from_u8(0x00).unwrap());
        assert_eq!(DatagramType::payload(false, true, true), DatagramType::from_u8(0x07).unwrap());
        assert_eq!(DatagramType::status(false), DatagramType::from_u8(0x20).unwrap());
        assert_eq!(DatagramType::status(true), DatagramType::from_u8(0x21).unwrap());
    }

    #[test]
    fn datagram_type_rejects_undefined() {
        for bad in [0x08u8, 0x10, 0x1F, 0x22, 0x80] {
            assert!(DatagramType::from_u8(bad).is_none(), "0x{bad:02x}");
        }
    }

    #[test]
    fn datagram_object_0x00_roundtrip() {
        let d = DatagramObject {
            datagram_type: DatagramType::from_u8(0x00).unwrap(),
            track_alias: vi(1),
            group_id: vi(2),
            object_id: vi(3),
            publisher_priority: 100,
            extension_headers: vec![],
            status: None,
            payload: vec![0xDE, 0xAD, 0xBE, 0xEF],
        };
        let mut buf = Vec::new();
        d.encode(&mut buf);
        assert_eq!(DatagramObject::decode(&mut &buf[..]).unwrap(), d);
    }

    #[test]
    fn datagram_object_0x04_no_object_id() {
        let d = DatagramObject {
            datagram_type: DatagramType::from_u8(0x04).unwrap(),
            track_alias: vi(1),
            group_id: vi(2),
            object_id: vi(0),
            publisher_priority: 100,
            extension_headers: vec![],
            status: None,
            payload: vec![0xAA],
        };
        let mut buf = Vec::new();
        d.encode(&mut buf);
        // type, track alias, group ID, priority, payload: the object ID is omitted.
        assert_eq!(buf, vec![0x04, 0x01, 0x02, 100, 0xAA]);
        let decoded = DatagramObject::decode(&mut &buf[..]).unwrap();
        assert_eq!(decoded, d);
    }

    #[test]
    fn datagram_object_0x21_status_with_extensions() {
        let d = DatagramObject {
            datagram_type: DatagramType::from_u8(0x21).unwrap(),
            track_alias: vi(9),
            group_id: vi(4),
            object_id: vi(11),
            publisher_priority: 50,
            extension_headers: vec![0xCA, 0xFE],
            status: Some(ObjectStatus::EndOfTrack),
            payload: vec![],
        };
        let mut buf = Vec::new();
        d.encode(&mut buf);
        assert_eq!(DatagramObject::decode(&mut &buf[..]).unwrap(), d);
    }
}