use std::any::Any;
use std::collections::HashMap;
use super::{validate_and_split, ExtensionBodyDef, ExtensionDescriptor};
use crate::error::Result;
/// Type-erased value produced by a custom extension parser registered with
/// [`ExtensionRegistry`].
///
/// Every `Debug + Any + Send + Sync` type implements this trait through the
/// blanket impl in this module; recover the concrete type with the inherent
/// `downcast_ref` / `is` helpers on `dyn ExtensionObject`.
///
/// NOTE: the blanket impl's bounds are also satisfied by
/// `Box<dyn ExtensionObject>` itself, so `&boxed` can coerce to a
/// `&dyn ExtensionObject` that wraps the *box* (and `boxed.as_any()` resolves
/// to the box's own impl). Downcasting such a reference to the inner type
/// yields `None`; deref first (`&*boxed`) or call the inherent helpers
/// directly on the box.
#[cfg(not(feature = "serde"))]
pub trait ExtensionObject: std::fmt::Debug + Any + Send + Sync {
    /// Returns `self` as `&dyn Any` so it can be downcast.
    fn as_any(&self) -> &dyn Any;
}

/// Type-erased value produced by a custom extension parser registered with
/// [`ExtensionRegistry`].
///
/// With the `serde` feature the trait additionally requires
/// `erased_serde::Serialize`; the blanket impl in this module covers every
/// `serde::Serialize` type. Recover the concrete type with the inherent
/// `downcast_ref` / `is` helpers on `dyn ExtensionObject`.
#[cfg(feature = "serde")]
pub trait ExtensionObject: std::fmt::Debug + Any + Send + Sync + erased_serde::Serialize {
    /// Returns `self` as `&dyn Any` so it can be downcast.
    fn as_any(&self) -> &dyn Any;
}

// Blanket impl: any debuggable, thread-safe `'static` value can be stored.
#[cfg(not(feature = "serde"))]
impl<O: std::fmt::Debug + Any + Send + Sync> ExtensionObject for O {
    fn as_any(&self) -> &dyn Any {
        self
    }
}

// Blanket impl for `serde` builds: the value must also be `serde::Serialize`.
#[cfg(feature = "serde")]
impl<O: std::fmt::Debug + Any + Send + Sync + serde::Serialize> ExtensionObject for O {
    fn as_any(&self) -> &dyn Any {
        self
    }
}

impl dyn ExtensionObject {
    /// Returns a reference to the concrete value if it is a `T`, else `None`.
    #[must_use]
    pub fn downcast_ref<T: Any>(&self) -> Option<&T> {
        let erased: &dyn Any = self.as_any();
        erased.downcast_ref::<T>()
    }

    /// Returns `true` when the concrete value is a `T`.
    #[must_use]
    pub fn is<T: Any>(&self) -> bool {
        self.downcast_ref::<T>().is_some()
    }
}
/// Serializes the type-erased `value` field of [`RegisteredExtension::Custom`]
/// through `erased_serde`, writing to the serde serializer `s`.
#[cfg(feature = "serde")]
// serde's `serialize_with` hands the field over by reference, i.e. as
// `&Box<dyn ExtensionObject>`, so the borrowed-box signature is required here.
#[allow(clippy::borrowed_box)]
pub(crate) fn serialize_erased<S: serde::Serializer>(
    v: &Box<dyn ExtensionObject>,
    s: S,
) -> std::result::Result<S::Ok, S::Error> {
    // Reach the trait object inside the box; it is `erased_serde::Serialize`
    // via the `ExtensionObject` supertrait in serde builds.
    let object: &dyn ExtensionObject = &**v;
    erased_serde::serialize(object, s)
}
/// Type-erased parser stored in [`ExtensionRegistry`] for one tag extension.
///
/// It receives the selector bytes that follow the tag-extension byte (as split
/// off by `validate_and_split`) and returns the parsed value boxed as an
/// [`ExtensionObject`]. The elided `&[u8]` makes the closure generic over the
/// input lifetime, same as an explicit `for<'a> Fn(&'a [u8])`.
pub(crate) type CustomParse =
    Box<dyn Fn(&[u8]) -> Result<Box<dyn ExtensionObject>> + Send + Sync>;
/// Registry of custom parsers for extension-descriptor bodies, keyed by tag
/// extension.
///
/// Tag extensions without a registered parser fall back to the built-in
/// decoder (see `ExtensionRegistry::parse_body`).
#[derive(Default)]
pub struct ExtensionRegistry {
    /// Custom parsers keyed by tag-extension byte.
    custom: HashMap<u8, CustomParse>,
}

// The stored parsers are closures and cannot be formatted, so `Debug` is
// written by hand and lists only which tag extensions have a custom parser.
impl std::fmt::Debug for ExtensionRegistry {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut tags: Vec<u8> = self.custom.keys().copied().collect();
        // HashMap iteration order is unspecified; sort for stable output.
        tags.sort_unstable();
        f.debug_struct("ExtensionRegistry")
            .field("custom_tag_extensions", &tags)
            .finish_non_exhaustive()
    }
}
impl ExtensionRegistry {
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn has_custom(&self, tag_extension: u8) -> bool {
self.custom.contains_key(&tag_extension)
}
pub fn register<T>(&mut self) -> &mut Self
where
T: for<'a> ExtensionBodyDef<'a> + ExtensionObject + 'static,
{
let tag_ext = T::TAG_EXTENSION;
self.custom.insert(
tag_ext,
Box::new(|sel| {
Ok(Box::new(<T as dvb_common::Parse>::parse(sel)?) as Box<dyn ExtensionObject>)
}),
);
self
}
pub fn parse_body<'a>(
&self,
tag_extension: u8,
selector: &'a [u8],
) -> Result<RegisteredExtension<'a>> {
if let Some(parse_fn) = self.custom.get(&tag_extension) {
let value = parse_fn(selector)?;
Ok(RegisteredExtension::Custom {
tag_extension,
value,
})
} else {
let body = super::parse_body(tag_extension, selector)?;
Ok(RegisteredExtension::Builtin(ExtensionDescriptor {
tag_extension,
body,
}))
}
}
pub fn parse<'a>(&self, bytes: &'a [u8]) -> Result<RegisteredExtension<'a>> {
let (tag_extension, sel) = validate_and_split(bytes)?;
self.parse_body(tag_extension, sel)
}
}
/// Outcome of parsing an extension descriptor through an `ExtensionRegistry`.
///
/// NOTE(review): with the `serde` feature, a container-level `rename_all` on an
/// enum renames only the variant tags (`builtin` / `custom`). The
/// `tag_extension` field inside `Custom` is serialized as-is (snake_case).
/// Confirm that this matches the JSON shape of the built-in descriptors.
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[cfg_attr(feature = "serde", serde(rename_all = "camelCase"))]
#[non_exhaustive]
pub enum RegisteredExtension<'a> {
    /// No custom parser was registered for the tag extension, so the body was
    /// decoded by the built-in `parse_body`. That may still yield a raw body
    /// for tag extensions the built-in decoder does not know.
    Builtin(super::ExtensionDescriptor<'a>),
    /// A parser registered via `ExtensionRegistry::register` decoded the body.
    Custom {
        /// Tag extension the value was parsed for.
        tag_extension: u8,
        // `Box<dyn ExtensionObject>` has no `serde::Serialize` impl of its own,
        // so serialization is routed through `erased_serde`.
        /// Type-erased parsed value; recover it with
        /// `value.downcast_ref::<T>()`.
        #[cfg_attr(feature = "serde", serde(serialize_with = "serialize_erased"))]
        value: Box<dyn ExtensionObject>,
    },
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::descriptors::extension::{ExtensionBodyDef, TAG, TAG_EXTENSION_LEN};
    use crate::error::Error;

    /// Tag extension the test body type registers itself under.
    const TEST_TAG_EXTENSION: u8 = 0x40;

    /// Minimal custom body that keeps a copy of the selector bytes.
    #[derive(Debug, PartialEq, Eq)]
    #[cfg_attr(feature = "serde", derive(serde::Serialize))]
    struct MyExtBody {
        payload: Vec<u8>,
    }

    impl<'a> ExtensionBodyDef<'a> for MyExtBody {
        const TAG_EXTENSION: u8 = TEST_TAG_EXTENSION;
        const NAME: &'static str = "MY_EXT_BODY";
    }

    impl<'a> dvb_common::Parse<'a> for MyExtBody {
        type Error = crate::error::Error;
        fn parse(sel: &'a [u8]) -> Result<Self> {
            let payload = sel.to_vec();
            Ok(Self { payload })
        }
    }

    /// Encodes a full extension descriptor: `TAG`, the length byte
    /// (selector length + `TAG_EXTENSION_LEN`), `tag_ext`, then `sel`.
    fn wrap_ext(tag_ext: u8, sel: &[u8]) -> Vec<u8> {
        let length = (sel.len() + TAG_EXTENSION_LEN) as u8;
        let mut encoded = Vec::with_capacity(3 + sel.len());
        encoded.extend_from_slice(&[TAG, length, tag_ext]);
        encoded.extend_from_slice(sel);
        encoded
    }

    /// Parses `raw` with an empty registry and returns the error it produces.
    fn parse_error(raw: &[u8]) -> Error {
        ExtensionRegistry::new().parse(raw).unwrap_err()
    }

    #[test]
    fn custom_extension_parsed_and_downcastable() {
        let mut registry = ExtensionRegistry::new();
        registry.register::<MyExtBody>();
        let selector = [0xDE, 0xAD, 0xBE];
        let encoded = wrap_ext(TEST_TAG_EXTENSION, &selector);

        match registry.parse(&encoded).unwrap() {
            RegisteredExtension::Custom {
                tag_extension,
                value,
            } => {
                assert_eq!(tag_extension, TEST_TAG_EXTENSION);
                let body = value
                    .downcast_ref::<MyExtBody>()
                    .expect("downcast should succeed");
                assert_eq!(body.payload, selector);
            }
            other => panic!("expected Custom, got {other:?}"),
        }
    }

    #[test]
    fn unregistered_tag_extension_yields_builtin() {
        use crate::descriptors::extension::ExtensionBody;
        use dvb_common::Serialize;

        let registry = ExtensionRegistry::new();
        let descriptor = crate::descriptors::extension::ExtensionDescriptor {
            tag_extension: 0x0B,
            body: ExtensionBody::ServiceRelocated(
                crate::descriptors::extension::ServiceRelocated {
                    old_original_network_id: 1,
                    old_transport_stream_id: 2,
                    old_service_id: 3,
                },
            ),
        };
        let mut encoded = vec![0u8; descriptor.serialized_len()];
        descriptor.serialize_into(&mut encoded).unwrap();

        match registry.parse(&encoded).unwrap() {
            RegisteredExtension::Builtin(d) => {
                assert_eq!(d.tag_extension, 0x0B);
                assert!(matches!(d.body, ExtensionBody::ServiceRelocated(_)));
            }
            other => panic!("expected Builtin, got {other:?}"),
        }
    }

    #[test]
    fn unknown_tag_extension_yields_builtin_raw() {
        use crate::descriptors::extension::ExtensionBody;

        let registry = ExtensionRegistry::new();
        let selector = [0xAA, 0xBB];
        let encoded = wrap_ext(0xFE, &selector);

        match registry.parse(&encoded).unwrap() {
            RegisteredExtension::Builtin(d) => {
                assert_eq!(d.tag_extension, 0xFE);
                assert!(matches!(d.body, ExtensionBody::Raw(b) if b == selector));
            }
            other => panic!("expected Builtin, got {other:?}"),
        }
    }

    #[test]
    fn parse_rejects_wrong_tag() {
        let err = parse_error(&[0x43, 1, 0x04]);
        assert!(matches!(err, Error::InvalidDescriptor { tag: 0x43, .. }));
    }

    #[test]
    fn parse_rejects_short_buffer() {
        let err = parse_error(&[TAG]);
        assert!(matches!(err, Error::BufferTooShort { .. }));
    }

    #[test]
    fn parse_rejects_empty_body() {
        let err = parse_error(&[TAG, 0]);
        assert!(matches!(err, Error::InvalidDescriptor { tag: TAG, .. }));
    }

    #[cfg(feature = "serde")]
    #[test]
    fn custom_variant_serializes_via_erased_serde() {
        let mut registry = ExtensionRegistry::new();
        registry.register::<MyExtBody>();
        let encoded = wrap_ext(TEST_TAG_EXTENSION, &[0x01, 0x02]);
        let parsed = registry.parse(&encoded).unwrap();

        let json = serde_json::to_value(&parsed).unwrap();
        let custom = json.get("custom").expect("expected 'custom' key");
        assert_eq!(custom["tag_extension"], u64::from(TEST_TAG_EXTENSION));
        assert_eq!(custom["value"]["payload"], serde_json::json!([1, 2]));
    }
}