use crate::descriptors::descriptor_body;
use crate::error::{Error, Result};
use alloc::vec::Vec;
use dvb_common::{Parse, Serialize};
/// Descriptor tag of the transport_protocol_descriptor; checked on parse and
/// written as the first byte of the serialized form.
pub const TAG: u8 = 0x02;
/// Size in bytes of the descriptor header (tag byte + length byte) that
/// precedes the body in the serialized form.
const HEADER_LEN: usize = 2;
/// Size in bytes of the big-endian `protocol_id` field at the start of the body.
const PROTOCOL_ID_LEN: usize = 2;
/// Size in bytes of the `transport_protocol_label` field that follows `protocol_id`.
const LABEL_LEN: usize = 1;
/// `protocol_id` value whose selector bytes are decoded as an [`OcSelector`].
pub const PROTOCOL_ID_OBJECT_CAROUSEL: u16 = 0x0001;
/// `protocol_id` value whose selector bytes are decoded as an [`HttpSelector`].
pub const PROTOCOL_ID_HTTP: u16 = 0x0003;
/// Decoded selector of a descriptor whose `protocol_id` is
/// [`PROTOCOL_ID_OBJECT_CAROUSEL`].
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub struct OcSelector {
/// `true` when the most significant bit (0x80) of the first selector byte is set.
pub remote_connection: bool,
/// Location triple that follows the flag byte; `Some` exactly when
/// `remote_connection` is `true`.
pub remote_connection_info: Option<OcRemoteConnection>,
/// The byte following the flag byte (and the location triple, if present).
pub component_tag: u8,
}
/// Six-byte location triple carried by an object-carousel selector when its
/// remote-connection flag is set. Each field is a big-endian 16-bit value,
/// stored in the order listed here.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub struct OcRemoteConnection {
/// First 16-bit value of the triple.
pub original_network_id: u16,
/// Second 16-bit value of the triple.
pub transport_stream_id: u16,
/// Third 16-bit value of the triple.
pub service_id: u16,
}
/// One URL entry of an HTTP selector: a base followed by zero or more
/// extensions. All slices borrow from the descriptor's selector bytes and are
/// kept as raw bytes (no text decoding is performed here).
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub struct HttpUrlEntry<'a> {
/// Raw bytes of the length-prefixed URL base.
pub url_base: &'a [u8],
/// Raw bytes of each length-prefixed URL extension, in wire order.
pub url_extensions: Vec<&'a [u8]>,
}
/// Decoded selector of a descriptor whose `protocol_id` is [`PROTOCOL_ID_HTTP`].
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub struct HttpSelector<'a> {
/// URL entries in the order they appear in the selector bytes.
pub urls: Vec<HttpUrlEntry<'a>>,
}
/// Protocol-specific interpretation of a descriptor's selector bytes, as
/// produced by `TransportProtocolDescriptor::selector`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[non_exhaustive]
pub enum SelectorKind<'a> {
/// Selector decoded for [`PROTOCOL_ID_OBJECT_CAROUSEL`].
ObjectCarousel(OcSelector),
/// Selector decoded for [`PROTOCOL_ID_HTTP`].
Http(HttpSelector<'a>),
/// Raw selector bytes, returned for any other `protocol_id` and when an
/// object-carousel selector is too short to decode.
Unknown(&'a [u8]),
}
impl SelectorKind<'_> {
#[must_use]
pub fn name(&self) -> &'static str {
match self {
Self::ObjectCarousel(_) => "OBJECT_CAROUSEL",
Self::Http(_) => "HTTP",
Self::Unknown(_) => "UNKNOWN",
}
}
}
// NOTE(review): presumably generates a `Display` impl for `SelectorKind`
// (likely built on `name()`) — confirm against `dvb_common::impl_spec_display`.
dvb_common::impl_spec_display!(SelectorKind<'_>);
/// Borrowed view of a transport_protocol_descriptor (tag [`TAG`]).
///
/// The selector is kept as raw bytes so the descriptor serializes back
/// byte-for-byte; call `selector()` for a protocol-specific decoding.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
pub struct TransportProtocolDescriptor<'a> {
/// Big-endian 16-bit value taken from the first two body bytes.
pub protocol_id: u16,
/// Third body byte.
pub transport_protocol_label: u8,
/// All body bytes after `protocol_id` and `transport_protocol_label`.
#[cfg_attr(feature = "serde", serde(borrow))]
pub selector_bytes: &'a [u8],
}
impl<'a> TransportProtocolDescriptor<'a> {
    /// Interprets `selector_bytes` according to `protocol_id`.
    ///
    /// * [`PROTOCOL_ID_OBJECT_CAROUSEL`] yields [`SelectorKind::ObjectCarousel`],
    ///   or [`SelectorKind::Unknown`] when the bytes are too short to decode.
    /// * [`PROTOCOL_ID_HTTP`] always yields [`SelectorKind::Http`]; decoding is
    ///   lenient and stops at the first truncated field.
    /// * Any other value yields [`SelectorKind::Unknown`] with the raw bytes.
    ///
    /// This never panics, whatever the content of `selector_bytes`.
    #[must_use]
    pub fn selector(&self) -> SelectorKind<'a> {
        match self.protocol_id {
            PROTOCOL_ID_OBJECT_CAROUSEL => self.decode_oc_selector(),
            PROTOCOL_ID_HTTP => self.decode_http_selector(),
            _ => SelectorKind::Unknown(self.selector_bytes),
        }
    }

    /// Decodes an object-carousel selector.
    ///
    /// Layout: one flag byte (bit 0x80 = remote connection), then — only when
    /// the flag is set — three big-endian 16-bit ids, then one component tag
    /// byte. Trailing bytes after the component tag are ignored. Returns
    /// [`SelectorKind::Unknown`] with the raw bytes if any field is missing.
    fn decode_oc_selector(&self) -> SelectorKind<'a> {
        // original_network_id + transport_stream_id + service_id.
        const REMOTE_INFO_LEN: usize = 6;

        let b = self.selector_bytes;
        // Shortest valid form is flag byte + component tag.
        if b.len() < 2 {
            return SelectorKind::Unknown(b);
        }
        let remote_connection = (b[0] & 0x80) != 0;
        let mut pos = 1;
        let remote_connection_info = if remote_connection {
            if pos + REMOTE_INFO_LEN > b.len() {
                return SelectorKind::Unknown(b);
            }
            let info = OcRemoteConnection {
                original_network_id: u16::from_be_bytes([b[pos], b[pos + 1]]),
                transport_stream_id: u16::from_be_bytes([b[pos + 2], b[pos + 3]]),
                service_id: u16::from_be_bytes([b[pos + 4], b[pos + 5]]),
            };
            pos += REMOTE_INFO_LEN;
            Some(info)
        } else {
            None
        };
        // The component tag may be missing when the remote info used up the
        // remaining bytes.
        match b.get(pos) {
            Some(&component_tag) => SelectorKind::ObjectCarousel(OcSelector {
                remote_connection,
                remote_connection_info,
                component_tag,
            }),
            None => SelectorKind::Unknown(b),
        }
    }

    /// Decodes an HTTP selector: a sequence of entries, each made of a
    /// length-prefixed URL base, an extension count byte, and that many
    /// length-prefixed URL extensions.
    ///
    /// Decoding is best-effort and always returns [`SelectorKind::Http`]:
    /// * a URL base that runs past the end is dropped and decoding stops;
    /// * a URL base with no extension count byte after it is kept with no
    ///   extensions;
    /// * an extension that is missing or runs past the end ends decoding; the
    ///   current entry is kept with the extensions that were complete.
    ///
    /// After the first truncated field nothing further is decoded, so the
    /// bytes of a cut-off extension are never re-read as a new entry.
    fn decode_http_selector(&self) -> SelectorKind<'a> {
        let b = self.selector_bytes;
        let mut urls = Vec::new();
        let mut pos = 0;
        // Set once an extension is found to be truncated; the remaining bytes
        // are then out of sync with the entry structure and must not be parsed.
        let mut truncated = false;
        while !truncated && pos < b.len() {
            let url_base_length = usize::from(b[pos]);
            pos += 1;
            let base_end = pos + url_base_length;
            if base_end > b.len() {
                break;
            }
            let url_base = &b[pos..base_end];
            pos = base_end;
            if pos >= b.len() {
                // Base present but no extension count byte follows.
                urls.push(HttpUrlEntry {
                    url_base,
                    url_extensions: Vec::new(),
                });
                break;
            }
            let url_extension_count = usize::from(b[pos]);
            pos += 1;
            // Every extension occupies at least its length byte, so the bytes
            // left bound how many can really be present; do not trust the
            // count byte alone for the allocation size.
            let mut url_extensions =
                Vec::with_capacity(url_extension_count.min(b.len() - pos));
            for _ in 0..url_extension_count {
                if pos >= b.len() {
                    truncated = true;
                    break;
                }
                let ext_len = usize::from(b[pos]);
                pos += 1;
                let ext_end = pos + ext_len;
                if ext_end > b.len() {
                    truncated = true;
                    break;
                }
                url_extensions.push(&b[pos..ext_end]);
                pos = ext_end;
            }
            urls.push(HttpUrlEntry {
                url_base,
                url_extensions,
            });
        }
        SelectorKind::Http(HttpSelector { urls })
    }
}
impl<'a> Parse<'a> for TransportProtocolDescriptor<'a> {
    type Error = crate::error::Error;

    /// Parses a transport_protocol_descriptor from `bytes`.
    ///
    /// The body is obtained from `descriptor_body` (whose errors are passed
    /// through unchanged), then split into the two-byte big-endian
    /// `protocol_id`, the one-byte `transport_protocol_label`, and the
    /// remaining selector bytes, which are borrowed without interpretation.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidDescriptor` when the body is shorter than the
    /// three bytes needed for `protocol_id` and `transport_protocol_label`.
    fn parse(bytes: &'a [u8]) -> Result<Self> {
        let payload = descriptor_body(
            bytes,
            TAG,
            "TransportProtocolDescriptor",
            "unexpected tag for transport_protocol_descriptor",
        )?;

        let fixed_len = PROTOCOL_ID_LEN + LABEL_LEN;
        if payload.len() < fixed_len {
            return Err(Error::InvalidDescriptor {
                tag: TAG,
                reason: "transport_protocol_descriptor body shorter than minimum 3 bytes",
            });
        }

        // Fixed-size prefix on the left, opaque selector on the right.
        let (fixed, selector_bytes) = payload.split_at(fixed_len);
        Ok(Self {
            protocol_id: u16::from_be_bytes([fixed[0], fixed[1]]),
            transport_protocol_label: fixed[PROTOCOL_ID_LEN],
            selector_bytes,
        })
    }
}
impl Serialize for TransportProtocolDescriptor<'_> {
type Error = crate::error::Error;
fn serialized_len(&self) -> usize {
HEADER_LEN + PROTOCOL_ID_LEN + LABEL_LEN + self.selector_bytes.len()
}
fn serialize_into(&self, buf: &mut [u8]) -> Result<usize> {
let len = self.serialized_len();
let body_len = len - HEADER_LEN;
if body_len > u8::MAX as usize {
return Err(Error::InvalidDescriptor {
tag: TAG,
reason: "transport_protocol_descriptor body exceeds 255 bytes",
});
}
if buf.len() < len {
return Err(Error::OutputBufferTooSmall {
need: len,
have: buf.len(),
});
}
buf[0] = TAG;
buf[1] = body_len as u8;
buf[2..4].copy_from_slice(&self.protocol_id.to_be_bytes());
buf[4] = self.transport_protocol_label;
buf[5..5 + self.selector_bytes.len()].copy_from_slice(self.selector_bytes);
Ok(len)
}
}
// Supplies this descriptor's tag and display name to the crate's
// `DescriptorDef` trait.
// NOTE(review): how the trait uses these constants is not visible here —
// confirm against `crate::traits::DescriptorDef`.
impl<'a> crate::traits::DescriptorDef<'a> for TransportProtocolDescriptor<'a> {
// Same value as the module-level `TAG` constant.
const TAG: u8 = TAG;
const NAME: &'static str = "TRANSPORT_PROTOCOL";
}
// Unit tests for parsing, selector decoding and serialization round trips.
#[cfg(test)]
mod tests {
use super::*;
// Fixture: object-carousel descriptor without remote connection.
// Layout: tag, length 5, protocol_id 0x0001, label 0x01,
// selector = flag byte 0x00 + component_tag 0xB4.
fn build_oc_local() -> [u8; 7] {
[TAG, 5, 0x00, 0x01, 0x01, 0x00, 0xB4]
}
// Fixture: object-carousel descriptor with the remote-connection flag set.
// Layout: tag, length 11, protocol_id 0x0001, label 0x01,
// selector = flag byte 0x80 + three 16-bit ids (1, 2, 3) + component_tag 0x04.
fn build_oc_remote() -> [u8; 13] {
[
TAG, 11, 0x00, 0x01, 0x01, 0x80, 0x00, 0x01, 0x00, 0x02, 0x00, 0x03, 0x04, ]
}
// Fixture: HTTP descriptor with one URL entry.
// Layout: tag, length 14, protocol_id 0x0003, label 0x01,
// selector = base length 4 + "http", extension count 1, extension length 4 + "/app".
fn build_http() -> [u8; 16] {
[
TAG, 14, 0x00, 0x03, 0x01, 4, b'h', b't', b't', b'p', 1, 4, b'/', b'a', b'p', b'p',
]
}
// Header fields and the local object-carousel selector decode as expected.
#[test]
fn parse_oc_local_selector() {
let bytes = build_oc_local();
let d = TransportProtocolDescriptor::parse(&bytes).unwrap();
assert_eq!(d.protocol_id, PROTOCOL_ID_OBJECT_CAROUSEL);
assert_eq!(d.transport_protocol_label, 0x01);
assert_eq!(d.selector_bytes, &[0x00, 0xB4]);
match d.selector() {
SelectorKind::ObjectCarousel(oc) => {
assert!(!oc.remote_connection);
assert!(oc.remote_connection_info.is_none());
assert_eq!(oc.component_tag, 0xB4);
}
other => panic!("expected ObjectCarousel, got {other:?}"),
}
}
// The remote-connection triple and the trailing component tag are decoded.
#[test]
fn parse_oc_remote_selector() {
let bytes = build_oc_remote();
let d = TransportProtocolDescriptor::parse(&bytes).unwrap();
match d.selector() {
SelectorKind::ObjectCarousel(oc) => {
assert!(oc.remote_connection);
let rc = oc.remote_connection_info.unwrap();
assert_eq!(rc.original_network_id, 1);
assert_eq!(rc.transport_stream_id, 2);
assert_eq!(rc.service_id, 3);
assert_eq!(oc.component_tag, 0x04);
}
other => panic!("expected ObjectCarousel, got {other:?}"),
}
}
// A single URL base with a single extension is decoded from the HTTP selector.
#[test]
fn parse_http_selector() {
let bytes = build_http();
let d = TransportProtocolDescriptor::parse(&bytes).unwrap();
assert_eq!(d.protocol_id, PROTOCOL_ID_HTTP);
match d.selector() {
SelectorKind::Http(http) => {
assert_eq!(http.urls.len(), 1);
assert_eq!(http.urls[0].url_base, b"http");
assert_eq!(http.urls[0].url_extensions.len(), 1);
assert_eq!(http.urls[0].url_extensions[0], b"/app");
}
other => panic!("expected Http, got {other:?}"),
}
}
// Regression guard: the selector is flag byte + 6-byte triple with no
// component tag after it (7 bytes); decoding must fall back to Unknown
// instead of indexing past the end.
#[test]
fn oc_remote_selector_missing_component_tag_is_unknown_not_panic() {
let bytes = [
TAG, 10, 0x00, 0x01, 0x01, 0x80, 0x00, 0x01, 0x00, 0x02, 0x00, 0x03, ];
let d = TransportProtocolDescriptor::parse(&bytes).unwrap();
assert_eq!(d.selector_bytes.len(), 7); assert!(d.selector_bytes[0] & 0x80 != 0); assert!(matches!(d.selector(), SelectorKind::Unknown(_)));
}
// A protocol_id with no dedicated decoder yields the Unknown variant.
#[test]
fn parse_unknown_protocol() {
let bytes = [TAG, 5, 0x01, 0x00, 0x02, 0xCA, 0xFE];
let d = TransportProtocolDescriptor::parse(&bytes).unwrap();
assert_eq!(d.protocol_id, 0x0100);
assert!(matches!(d.selector(), SelectorKind::Unknown(_)));
}
// parse -> serialize reproduces the input bytes, and re-parsing gives an equal value.
#[test]
fn serialize_round_trip_oc() {
let bytes = build_oc_local();
let d = TransportProtocolDescriptor::parse(&bytes).unwrap();
let mut buf = vec![0u8; d.serialized_len()];
d.serialize_into(&mut buf).unwrap();
assert_eq!(buf.as_slice(), &bytes[..]);
let re = TransportProtocolDescriptor::parse(&buf).unwrap();
assert_eq!(d, re);
}
// Same round trip for the HTTP fixture.
#[test]
fn serialize_round_trip_http() {
let bytes = build_http();
let d = TransportProtocolDescriptor::parse(&bytes).unwrap();
let mut buf = vec![0u8; d.serialized_len()];
d.serialize_into(&mut buf).unwrap();
assert_eq!(buf.as_slice(), &bytes[..]);
let re = TransportProtocolDescriptor::parse(&buf).unwrap();
assert_eq!(d, re);
}
// Byte-exact serialization of the remote object-carousel fixture.
#[test]
fn serialize_round_trip_oc_remote() {
let bytes = build_oc_remote();
let d = TransportProtocolDescriptor::parse(&bytes).unwrap();
let mut buf = vec![0u8; d.serialized_len()];
d.serialize_into(&mut buf).unwrap();
assert_eq!(buf.as_slice(), &bytes[..]);
}
}