use std::marker::PhantomData;
use nom::{
bytes::complete::{tag, take},
combinator::{map, map_res, verify},
number::{complete::u32, Endianness},
sequence::{delimited, terminated},
IResult,
};
use super::{CanonicalFixedSizedPod, FixedSizedPod};
/// Implemented by types that can be deserialized from the raw bytes of a pod.
///
/// `'de` is the lifetime of the input bytes; implementors such as `&'de str`
/// and `&'de [u8]` borrow directly from the input instead of copying it.
pub trait PodDeserialize<'de> {
/// Deserialize a `Self` by driving the given `deserializer`.
///
/// The deserializer is taken by value and handed back inside the
/// [`DeserializeSuccess`] that its `deserialize_*` methods return, so an
/// implementation has to go through one of those methods to produce a result.
///
/// # Errors
/// Returns the `nom` error produced when the input does not match.
fn deserialize(
deserializer: PodDeserializer<'de>,
) -> Result<(Self, DeserializeSuccess<'de>), nom::Err<nom::error::Error<&'de [u8]>>>
where
Self: Sized;
}
impl<'de> PodDeserialize<'de> for &'de str {
fn deserialize(
deserializer: PodDeserializer<'de>,
) -> Result<(Self, DeserializeSuccess<'de>), nom::Err<nom::error::Error<&'de [u8]>>>
where
Self: Sized,
{
deserializer.deserialize_str()
}
}
/// Owned strings are deserialized as a borrowed `&str` and then copied.
impl<'de> PodDeserialize<'de> for String {
    fn deserialize(
        deserializer: PodDeserializer<'de>,
    ) -> Result<(Self, DeserializeSuccess<'de>), nom::Err<nom::error::Error<&'de [u8]>>>
    where
        Self: Sized,
    {
        let (borrowed, success) = deserializer.deserialize_str()?;
        Ok((String::from(borrowed), success))
    }
}
impl<'de> PodDeserialize<'de> for &'de [u8] {
fn deserialize(
deserializer: PodDeserializer<'de>,
) -> Result<(Self, DeserializeSuccess<'de>), nom::Err<nom::error::Error<&'de [u8]>>>
where
Self: Sized,
{
deserializer.deserialize_bytes()
}
}
/// Owned byte buffers are deserialized as a borrowed slice and then copied.
impl<'de> PodDeserialize<'de> for Vec<u8> {
    fn deserialize(
        deserializer: PodDeserializer<'de>,
    ) -> Result<(Self, DeserializeSuccess<'de>), nom::Err<nom::error::Error<&'de [u8]>>>
    where
        Self: Sized,
    {
        let (bytes, success) = deserializer.deserialize_bytes()?;
        Ok((bytes.to_vec(), success))
    }
}
/// A `Vec` of fixed-size pods is deserialized from an array pod, one element
/// at a time.
impl<'de, P: FixedSizedPod> PodDeserialize<'de> for Vec<P> {
    fn deserialize(
        deserializer: PodDeserializer<'de>,
    ) -> Result<(Self, DeserializeSuccess<'de>), nom::Err<nom::error::Error<&'de [u8]>>>
    where
        Self: Sized,
    {
        let (mut arr_deserializer, num_elems) = deserializer.deserialize_array::<P>()?;

        // `num_elems` is derived from a length field in the input, so it cannot
        // be trusted as an allocation size: a few malformed bytes could ask for
        // gigabytes. Every element occupies at least one input byte (a zero
        // element size yields `num_elems == 0`), so the remaining input length
        // is an upper bound on how many elements can actually be read.
        let remaining_bytes = arr_deserializer.deserializer.input.len();
        let capacity = (num_elems as usize).min(remaining_bytes);

        let mut result = Vec::with_capacity(capacity);
        for _ in 0..num_elems {
            result.push(arr_deserializer.deserialize_element()?);
        }

        let success = arr_deserializer.end()?;
        Ok((result, success))
    }
}
/// Returned alongside the value by every successful `deserialize_*` call.
///
/// It wraps the [`PodDeserializer`] that performed the deserialization, whose
/// input has been advanced past the consumed pod. The field is private, so
/// code outside this module cannot construct one.
pub struct DeserializeSuccess<'de>(PodDeserializer<'de>);
/// Cursor over the raw bytes of a pod that is being deserialized.
pub struct PodDeserializer<'de> {
// The input that has not been consumed yet; advanced after every successful parse.
input: &'de [u8],
}
impl<'de> PodDeserializer<'de> {
    /// Deserialize a `P` from the start of `input`.
    ///
    /// On success, returns the input left over after the pod together with the
    /// deserialized value.
    ///
    /// # Errors
    /// Returns the parse error produced by `P::deserialize`.
    // The nested `Result`/`nom::Err` type is dictated by nom's error model.
    #[allow(clippy::type_complexity)]
    pub fn deserialize_from<P: PodDeserialize<'de>>(
        input: &'de [u8],
    ) -> Result<(&'de [u8], P), nom::Err<nom::error::Error<&'de [u8]>>> {
        let deserializer = Self { input };
        P::deserialize(deserializer).map(|(res, success)| (success.0.input, res))
    }

    /// Run the parser `f` on the remaining input.
    ///
    /// On success the cursor is advanced past what `f` consumed; on failure the
    /// cursor is left untouched.
    fn parse<T, F>(&mut self, mut f: F) -> Result<T, nom::Err<nom::error::Error<&'de [u8]>>>
    where
        F: FnMut(&'de [u8]) -> IResult<&'de [u8], T>,
    {
        f(self.input).map(|(input, result)| {
            self.input = input;
            result
        })
    }

    /// Number of padding bytes that follow a body of `size` bytes so that the
    /// next pod starts on an 8-byte boundary.
    ///
    /// The inner `% 8` keeps the subtraction from underflowing for any `size`.
    fn padding_for(size: u32) -> u32 {
        (8 - size % 8) % 8
    }

    /// Parser for a pod header: a native-endian `u32` size followed by the
    /// native-endian bytes of `type_`.
    ///
    /// Yields the size and fails if the type tag does not match.
    pub(super) fn header<'b>(type_: u32) -> impl FnMut(&'b [u8]) -> IResult<&'b [u8], u32> {
        terminated(u32(Endianness::Native), tag(type_.to_ne_bytes()))
    }

    /// Deserialize a pod whose body has a fixed size.
    ///
    /// Parses the header for `P::CanonicalType::TYPE`, the body, and the
    /// padding up to the next 8-byte boundary.
    ///
    /// # Errors
    /// Returns a parse error if the header, body or padding cannot be read.
    pub fn deserialize_fixed_sized_pod<P: FixedSizedPod>(
        mut self,
    ) -> Result<(P, DeserializeSuccess<'de>), nom::Err<nom::error::Error<&'de [u8]>>> {
        let padding = Self::padding_for(P::CanonicalType::SIZE);

        self.parse(delimited(
            Self::header(P::CanonicalType::TYPE),
            map(P::CanonicalType::deserialize_body, |res| {
                P::from_canonical_type(&res)
            }),
            take(padding),
        ))
        .map(|res| (res, DeserializeSuccess(self)))
    }

    /// Deserialize a string pod, borrowing the text from the input.
    ///
    /// The body is `len - 1` bytes of UTF-8 followed by a NUL terminator and
    /// padding to the next 8-byte boundary.
    ///
    /// # Errors
    /// Returns a parse error if the header does not match, the declared length
    /// is zero (no room for the NUL terminator), the terminator is missing, the
    /// text is not valid UTF-8, or the input is too short.
    pub fn deserialize_str(
        mut self,
    ) -> Result<(&'de str, DeserializeSuccess<'de>), nom::Err<nom::error::Error<&'de [u8]>>> {
        // Reject `len == 0` up front so that `len - 1` below cannot underflow.
        let len = self.parse(verify(
            Self::header(spa_sys::SPA_TYPE_String),
            |len: &u32| *len > 0,
        ))?;
        let padding = Self::padding_for(len);

        self.parse(terminated(
            map_res(terminated(take(len - 1), tag([b'\0'])), std::str::from_utf8),
            take(padding),
        ))
        .map(|res| (res, DeserializeSuccess(self)))
    }

    /// Deserialize a bytes pod, borrowing the bytes from the input.
    ///
    /// # Errors
    /// Returns a parse error if the header does not match or the input is
    /// shorter than the declared length plus padding.
    // The nested `Result`/`nom::Err` type is dictated by nom's error model.
    #[allow(clippy::type_complexity)]
    pub fn deserialize_bytes(
        mut self,
    ) -> Result<(&'de [u8], DeserializeSuccess<'de>), nom::Err<nom::error::Error<&'de [u8]>>> {
        let len = self.parse(Self::header(spa_sys::SPA_TYPE_Bytes))?;
        let padding = Self::padding_for(len);

        self.parse(terminated(take(len), take(padding)))
            .map(|res| (res, DeserializeSuccess(self)))
    }

    /// Start deserializing an array pod with elements of type `E`.
    ///
    /// Parses the array header and the child header, then returns an
    /// [`ArrayPodDeserializer`] for reading the elements together with the
    /// number of elements in the array.
    ///
    /// # Errors
    /// Returns a parse error if either header does not match, if the array
    /// body is too short to hold the child header, or if the child size differs
    /// from `E::CanonicalType::SIZE`.
    // The nested `Result`/`nom::Err` type is dictated by nom's error model.
    #[allow(clippy::type_complexity)]
    pub fn deserialize_array<E>(
        mut self,
    ) -> Result<(ArrayPodDeserializer<'de, E>, u32), nom::Err<nom::error::Error<&'de [u8]>>>
    where
        E: FixedSizedPod,
    {
        // The array body starts with an 8-byte child header (size + type), so a
        // shorter body is malformed; checking here keeps `len - 8` from underflowing.
        let len = self.parse(verify(
            Self::header(spa_sys::SPA_TYPE_Array),
            |len: &u32| *len >= 8,
        ))?;

        self.parse(verify(
            Self::header(E::CanonicalType::TYPE),
            |child_size: &u32| *child_size == E::CanonicalType::SIZE,
        ))?;

        // A zero-sized element type yields an empty array instead of dividing by zero.
        let num_elems = (len - 8)
            .checked_div(E::CanonicalType::SIZE)
            .unwrap_or(0);

        Ok((
            ArrayPodDeserializer {
                deserializer: self,
                length: num_elems,
                deserialized: 0,
                _phantom: PhantomData,
            },
            num_elems,
        ))
    }

    /// Start deserializing a struct pod.
    ///
    /// Parses the struct header and returns a [`StructPodDeserializer`] that
    /// tracks how many bytes of the struct body are still unread.
    ///
    /// # Errors
    /// Returns a parse error if the header does not match.
    pub fn deserialize_struct(
        mut self,
    ) -> Result<StructPodDeserializer<'de>, nom::Err<nom::error::Error<&'de [u8]>>> {
        let len = self.parse(Self::header(spa_sys::SPA_TYPE_Struct))?;

        Ok(StructPodDeserializer {
            deserializer: Some(self),
            remaining: len,
        })
    }
}
/// Reads the elements of an array pod one by one.
///
/// Created by `PodDeserializer::deserialize_array` after the array and child
/// headers have been parsed.
pub struct ArrayPodDeserializer<'de, E: FixedSizedPod> {
// Cursor positioned at the next unread element.
deserializer: PodDeserializer<'de>,
// Total number of elements in the array.
length: u32,
// Number of `deserialize_element` calls made so far.
deserialized: u32,
// Ties the element type `E` to this deserializer without storing a value.
_phantom: PhantomData<E>,
}
impl<'de, E: FixedSizedPod> ArrayPodDeserializer<'de, E> {
    /// Deserialize the next element of the array.
    ///
    /// # Errors
    /// Returns a parse error if the element body cannot be read. The element
    /// still counts as consumed in that case.
    ///
    /// # Panics
    /// Panics if all elements of the array have already been deserialized.
    pub fn deserialize_element(&mut self) -> Result<E, nom::Err<nom::error::Error<&'de [u8]>>> {
        // Note: `!a < b` would parse as `(!a) < b` (bitwise NOT), so the bound
        // is spelled out as a plain comparison.
        assert!(
            self.deserialized < self.length,
            "No elements left in the pod to deserialize"
        );

        let result = self
            .deserializer
            .parse(|input| E::CanonicalType::deserialize_body(input))
            .map(|res| E::from_canonical_type(&res));

        self.deserialized += 1;
        result
    }

    /// Finish the array by consuming the padding that follows the elements.
    ///
    /// # Errors
    /// Returns a parse error if the input is too short to contain the padding.
    ///
    /// # Panics
    /// Panics if not every element of the array has been deserialized.
    pub fn end(
        mut self,
    ) -> Result<DeserializeSuccess<'de>, nom::Err<nom::error::Error<&'de [u8]>>> {
        assert!(
            self.length == self.deserialized,
            "Not all elements were deserialized from the array pod"
        );

        // Only the element bytes matter here: the 8-byte child header that
        // precedes them does not change the alignment.
        let bytes_read = self.deserialized * E::CanonicalType::SIZE;
        let padding = (8 - bytes_read % 8) % 8;

        self.deserializer.parse(take(padding))?;
        Ok(DeserializeSuccess(self.deserializer))
    }
}
/// Reads the fields of a struct pod one by one.
///
/// Created by `PodDeserializer::deserialize_struct` after the struct header
/// has been parsed.
pub struct StructPodDeserializer<'de> {
// Taken out while a field is being deserialized and put back on success,
// so it is `None` after a field has failed to deserialize.
deserializer: Option<PodDeserializer<'de>>,
// Number of bytes of the struct body that have not been consumed yet.
remaining: u32,
}
impl<'de> StructPodDeserializer<'de> {
pub fn deserialize_field<P: PodDeserialize<'de>>(
&mut self,
) -> Result<Option<P>, nom::Err<nom::error::Error<&'de [u8]>>> {
if self.remaining == 0 {
Ok(None)
} else {
let deserializer = self
.deserializer
.take()
.expect("StructPodDeserializer does not contain a deserializer");
let remaining_input_len = deserializer.input.len();
let (res, success) = P::deserialize(deserializer)?;
self.remaining -= remaining_input_len as u32 - success.0.input.len() as u32;
self.deserializer = Some(success.0);
Ok(Some(res))
}
}
pub fn end(self) -> Result<DeserializeSuccess<'de>, nom::Err<nom::error::Error<&'de [u8]>>> {
assert!(
self.remaining == 0,
"Not all fields have been deserialized from the struct"
);
Ok(DeserializeSuccess(self.deserializer.expect(
"StructPodDeserializer does not contain a deserializer",
)))
}
}