use alloc::collections::{BTreeMap, BTreeSet};
use alloc::string::String;
use alloc::vec::Vec;
#[cfg(feature = "std")]
use std::collections::{HashMap, HashSet};
#[cfg(feature = "std")]
use std::hash::{BuildHasher, Hash};
use crate::codec::{Decode, Decoder};
use crate::error::{Result, SerialError};
use crate::traits::Deserialize;
/// Deserialization that may borrow from the input buffer instead of copying.
///
/// Unlike the owning [`Deserialize`] trait, implementors may produce values that
/// reference the bytes being decoded (e.g. `&'a str`, `&'a [u8]`). Such values are
/// tied to the lifetime `'a` of the buffer wrapped by the [`Decoder`].
pub trait DeserializeView<'a>: Sized {
    /// Decodes one value from `decoder`, advancing it past the bytes consumed.
    ///
    /// # Errors
    ///
    /// Returns a `SerialError` when the input is malformed or truncated for `Self`.
    fn deserialize_view(decoder: &mut Decoder<'a>) -> Result<Self>;
}
/// Decodes a `T` that may borrow from `bytes`, requiring the whole buffer to be consumed.
///
/// # Errors
///
/// Propagates any error from [`DeserializeView::deserialize_view`], and returns
/// [`SerialError::TrailingBytes`] when unread bytes are left after decoding `T`.
#[inline]
pub fn decode_view<'a, T: DeserializeView<'a>>(bytes: &'a [u8]) -> Result<T> {
    let mut decoder = Decoder::new(bytes);
    let decoded = T::deserialize_view(&mut decoder)?;
    match decoder.remaining() {
        0 => Ok(decoded),
        remaining => Err(SerialError::TrailingBytes { remaining }),
    }
}
impl<'a> DeserializeView<'a> for &'a str {
    /// Borrows a length-prefixed string straight out of the input buffer.
    ///
    /// # Errors
    ///
    /// Returns [`SerialError::InvalidUtf8`] when the borrowed bytes are not valid UTF-8,
    /// and propagates any error from reading the length-prefixed run.
    #[inline]
    fn deserialize_view(decoder: &mut Decoder<'a>) -> Result<Self> {
        let raw: &'a [u8] = decoder.read_length_prefixed_borrowed()?;
        match core::str::from_utf8(raw) {
            Ok(text) => Ok(text),
            Err(_) => Err(SerialError::InvalidUtf8),
        }
    }
}
impl<'a> DeserializeView<'a> for &'a [u8] {
    /// Returns a length-prefixed byte run as a slice borrowed for the input's lifetime `'a`.
    #[inline]
    fn deserialize_view(decoder: &mut Decoder<'a>) -> Result<Self> {
        let payload = decoder.read_length_prefixed_borrowed()?;
        Ok(payload)
    }
}
// Implements `DeserializeView` for each listed owned type by delegating to its
// `Deserialize` impl; these types never borrow from the input buffer.
macro_rules! view_via_owned {
    ($($owned:ty),+ $(,)?) => {
        $(
            impl<'a> DeserializeView<'a> for $owned {
                #[inline]
                fn deserialize_view(decoder: &mut Decoder<'a>) -> Result<Self> {
                    <Self as Deserialize>::deserialize(decoder)
                }
            }
        )+
    };
}
view_via_owned!(
u8,
u16,
u32,
u64,
u128,
usize,
i8,
i16,
i32,
i64,
i128,
isize,
bool,
f32,
f64,
(),
String,
);
impl<'a, T: DeserializeView<'a>> DeserializeView<'a> for Option<T> {
    /// Decodes a one-byte presence tag (`0x00` = `None`, `0x01` = `Some`) followed,
    /// when present, by the inner value.
    ///
    /// # Errors
    ///
    /// Returns [`SerialError::InvalidTag`] for any other tag byte, and propagates
    /// errors from reading the tag or the inner value.
    #[inline]
    fn deserialize_view(decoder: &mut Decoder<'a>) -> Result<Self> {
        let tag = decoder.read_byte()?;
        if tag == 0x00 {
            return Ok(None);
        }
        if tag != 0x01 {
            return Err(SerialError::InvalidTag {
                kind: "Option",
                tag,
            });
        }
        T::deserialize_view(decoder).map(Some)
    }
}
impl<'a, T, E> DeserializeView<'a> for core::result::Result<T, E>
where
    T: DeserializeView<'a>,
    E: DeserializeView<'a>,
{
    /// Decodes a one-byte variant tag (`0x00` = `Ok`, `0x01` = `Err`) followed by
    /// the payload of that variant.
    ///
    /// # Errors
    ///
    /// Returns [`SerialError::InvalidTag`] for any other tag byte, and propagates
    /// errors from reading the tag or the payload.
    #[inline]
    fn deserialize_view(decoder: &mut Decoder<'a>) -> Result<Self> {
        match decoder.read_byte()? {
            0x00 => T::deserialize_view(decoder).map(Ok),
            0x01 => E::deserialize_view(decoder).map(Err),
            other => Err(SerialError::InvalidTag {
                kind: "Result",
                tag: other,
            }),
        }
    }
}
impl<'a, T: DeserializeView<'a>, const N: usize> DeserializeView<'a> for [T; N] {
    /// Decodes exactly `N` consecutive elements (no length prefix) into an array.
    ///
    /// # Errors
    ///
    /// Propagates the first element decoding error.
    fn deserialize_view(decoder: &mut Decoder<'a>) -> Result<Self> {
        let mut items: Vec<T> = Vec::with_capacity(N);
        while items.len() < N {
            items.push(T::deserialize_view(decoder)?);
        }
        // `items.len() == N` at this point, so the conversion cannot fail; the
        // error arm only exists to avoid a panic path.
        match <[T; N]>::try_from(items) {
            Ok(array) => Ok(array),
            Err(_) => Err(SerialError::IntegerOutOfRange),
        }
    }
}
impl<'a, T: DeserializeView<'a>> DeserializeView<'a> for Vec<T> {
    /// Decodes a varint element count followed by that many `T` views.
    ///
    /// # Errors
    ///
    /// - [`SerialError::InvalidLength`] when the declared count exceeds the decoder's
    ///   `max_alloc` limit; `remaining` reports the unread input bytes at that point.
    /// - [`SerialError::IntegerOutOfRange`] when the count does not fit in `usize`.
    /// - Any error from reading the count or decoding an element.
    fn deserialize_view(decoder: &mut Decoder<'a>) -> Result<Self> {
        // Caps the up-front reservation so a large (but in-limit) declared count cannot
        // force a big allocation before any element has actually been decoded.
        const MAX_PREALLOC: usize = 4096;

        let declared = <Decoder<'a> as crate::Decode>::read_varint_u64(decoder)?;
        let max = <Decoder<'a> as crate::Decode>::max_alloc(decoder) as u64;
        if declared > max {
            return Err(SerialError::InvalidLength {
                declared,
                // Report the real number of unread bytes instead of a hard-coded 0.
                remaining: decoder.remaining() as _,
            });
        }
        let len = usize::try_from(declared).map_err(|_| SerialError::IntegerOutOfRange)?;
        let mut out = Vec::with_capacity(len.min(MAX_PREALLOC));
        for _ in 0..len {
            out.push(T::deserialize_view(decoder)?);
        }
        Ok(out)
    }
}
// Implements `DeserializeView` for a tuple of the given type parameters; the fields
// are decoded left to right with no length prefix or separator.
macro_rules! view_tuple {
    ($($elem:ident),+) => {
        impl<'a, $($elem),+> DeserializeView<'a> for ($($elem,)+)
        where
            $($elem: DeserializeView<'a>,)+
        {
            #[inline]
            fn deserialize_view(decoder: &mut Decoder<'a>) -> Result<Self> {
                Ok(($(<$elem as DeserializeView<'a>>::deserialize_view(decoder)?,)+))
            }
        }
    };
}
// Tuple arities 1 through 12.
view_tuple!(A);
view_tuple!(A, B);
view_tuple!(A, B, C);
view_tuple!(A, B, C, D);
view_tuple!(A, B, C, D, E);
view_tuple!(A, B, C, D, E, F);
view_tuple!(A, B, C, D, E, F, G);
view_tuple!(A, B, C, D, E, F, G, H);
view_tuple!(A, B, C, D, E, F, G, H, I);
view_tuple!(A, B, C, D, E, F, G, H, I, J);
view_tuple!(A, B, C, D, E, F, G, H, I, J, K);
view_tuple!(A, B, C, D, E, F, G, H, I, J, K, L);
impl<'a, K, V> DeserializeView<'a> for BTreeMap<K, V>
where
    K: DeserializeView<'a> + Ord,
    V: DeserializeView<'a>,
{
    /// Decodes a varint entry count followed by that many key/value pairs.
    ///
    /// If the input repeats a key, the value decoded last replaces earlier ones.
    ///
    /// # Errors
    ///
    /// - [`SerialError::InvalidLength`] when the declared count exceeds the decoder's
    ///   `max_alloc` limit; `remaining` reports the unread input bytes at that point.
    /// - [`SerialError::IntegerOutOfRange`] when the count does not fit in `usize`.
    /// - Any error from reading the count or decoding a key or value.
    fn deserialize_view(decoder: &mut Decoder<'a>) -> Result<Self> {
        let declared = <Decoder<'a> as crate::Decode>::read_varint_u64(decoder)?;
        let max = <Decoder<'a> as crate::Decode>::max_alloc(decoder) as u64;
        if declared > max {
            return Err(SerialError::InvalidLength {
                declared,
                // Report the real number of unread bytes instead of a hard-coded 0.
                remaining: decoder.remaining() as _,
            });
        }
        let len = usize::try_from(declared).map_err(|_| SerialError::IntegerOutOfRange)?;
        let mut out = BTreeMap::new();
        for _ in 0..len {
            let k = K::deserialize_view(decoder)?;
            let v = V::deserialize_view(decoder)?;
            let _ = out.insert(k, v);
        }
        Ok(out)
    }
}
impl<'a, T> DeserializeView<'a> for BTreeSet<T>
where
    T: DeserializeView<'a> + Ord,
{
    /// Decodes a varint element count followed by that many elements.
    ///
    /// If the input repeats an element, later duplicates are ignored.
    ///
    /// # Errors
    ///
    /// - [`SerialError::InvalidLength`] when the declared count exceeds the decoder's
    ///   `max_alloc` limit; `remaining` reports the unread input bytes at that point.
    /// - [`SerialError::IntegerOutOfRange`] when the count does not fit in `usize`.
    /// - Any error from reading the count or decoding an element.
    fn deserialize_view(decoder: &mut Decoder<'a>) -> Result<Self> {
        let declared = <Decoder<'a> as crate::Decode>::read_varint_u64(decoder)?;
        let max = <Decoder<'a> as crate::Decode>::max_alloc(decoder) as u64;
        if declared > max {
            return Err(SerialError::InvalidLength {
                declared,
                // Report the real number of unread bytes instead of a hard-coded 0.
                remaining: decoder.remaining() as _,
            });
        }
        let len = usize::try_from(declared).map_err(|_| SerialError::IntegerOutOfRange)?;
        let mut out = BTreeSet::new();
        for _ in 0..len {
            let _ = out.insert(T::deserialize_view(decoder)?);
        }
        Ok(out)
    }
}
#[cfg(feature = "std")]
impl<'a, K, V, S> DeserializeView<'a> for HashMap<K, V, S>
where
    K: DeserializeView<'a> + Hash + Eq,
    V: DeserializeView<'a>,
    S: BuildHasher + Default,
{
    /// Decodes a varint entry count followed by that many key/value pairs into a map
    /// built with a default-constructed hasher `S`.
    ///
    /// If the input repeats a key, the value decoded last replaces earlier ones.
    ///
    /// # Errors
    ///
    /// - [`SerialError::InvalidLength`] when the declared count exceeds the decoder's
    ///   `max_alloc` limit; `remaining` reports the unread input bytes at that point.
    /// - [`SerialError::IntegerOutOfRange`] when the count does not fit in `usize`.
    /// - Any error from reading the count or decoding a key or value.
    fn deserialize_view(decoder: &mut Decoder<'a>) -> Result<Self> {
        // Caps the up-front reservation so a large (but in-limit) declared count cannot
        // force a big allocation before any entry has actually been decoded.
        const MAX_PREALLOC: usize = 4096;

        let declared = <Decoder<'a> as crate::Decode>::read_varint_u64(decoder)?;
        let max = <Decoder<'a> as crate::Decode>::max_alloc(decoder) as u64;
        if declared > max {
            return Err(SerialError::InvalidLength {
                declared,
                // Report the real number of unread bytes instead of a hard-coded 0.
                remaining: decoder.remaining() as _,
            });
        }
        let len = usize::try_from(declared).map_err(|_| SerialError::IntegerOutOfRange)?;
        let mut out = HashMap::with_capacity_and_hasher(len.min(MAX_PREALLOC), S::default());
        for _ in 0..len {
            let k = K::deserialize_view(decoder)?;
            let v = V::deserialize_view(decoder)?;
            let _ = out.insert(k, v);
        }
        Ok(out)
    }
}
#[cfg(feature = "std")]
impl<'a, T, S> DeserializeView<'a> for HashSet<T, S>
where
    T: DeserializeView<'a> + Hash + Eq,
    S: BuildHasher + Default,
{
    /// Decodes a varint element count followed by that many elements into a set
    /// built with a default-constructed hasher `S`.
    ///
    /// If the input repeats an element, later duplicates are ignored.
    ///
    /// # Errors
    ///
    /// - [`SerialError::InvalidLength`] when the declared count exceeds the decoder's
    ///   `max_alloc` limit; `remaining` reports the unread input bytes at that point.
    /// - [`SerialError::IntegerOutOfRange`] when the count does not fit in `usize`.
    /// - Any error from reading the count or decoding an element.
    fn deserialize_view(decoder: &mut Decoder<'a>) -> Result<Self> {
        // Caps the up-front reservation so a large (but in-limit) declared count cannot
        // force a big allocation before any element has actually been decoded.
        const MAX_PREALLOC: usize = 4096;

        let declared = <Decoder<'a> as crate::Decode>::read_varint_u64(decoder)?;
        let max = <Decoder<'a> as crate::Decode>::max_alloc(decoder) as u64;
        if declared > max {
            return Err(SerialError::InvalidLength {
                declared,
                // Report the real number of unread bytes instead of a hard-coded 0.
                remaining: decoder.remaining() as _,
            });
        }
        let len = usize::try_from(declared).map_err(|_| SerialError::IntegerOutOfRange)?;
        let mut out = HashSet::with_capacity_and_hasher(len.min(MAX_PREALLOC), S::default());
        for _ in 0..len {
            let _ = out.insert(T::deserialize_view(decoder)?);
        }
        Ok(out)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::encode;
    use alloc::string::ToString;
    use alloc::vec;

    #[test]
    fn borrowed_str_round_trips() {
        let bytes = encode(&"hello").unwrap();
        let view: &str = decode_view(&bytes).unwrap();
        assert_eq!(view, "hello");
    }

    #[test]
    fn borrowed_bytes_round_trip() {
        let bytes = encode(&vec![1u8, 2, 3, 4, 5]).unwrap();
        let view: &[u8] = decode_view(&bytes).unwrap();
        assert_eq!(view, &[1, 2, 3, 4, 5]);
    }

    #[test]
    fn primitive_view_decodes_like_owning() {
        let bytes = encode(&42_u64).unwrap();
        let n: u64 = decode_view(&bytes).unwrap();
        assert_eq!(n, 42);
    }

    #[test]
    fn option_borrowed_view_round_trips() {
        let bytes = encode(&Some(String::from("hi"))).unwrap();
        let v: Option<&str> = decode_view(&bytes).unwrap();
        assert_eq!(v, Some("hi"));
        let none_bytes = encode::<Option<String>>(&None).unwrap();
        let v: Option<&str> = decode_view(&none_bytes).unwrap();
        assert_eq!(v, None);
    }

    #[test]
    fn option_view_rejects_unknown_tag() {
        let err = decode_view::<Option<u8>>(&[0x02u8]).expect_err("invalid Option tag");
        assert!(matches!(
            err,
            SerialError::InvalidTag {
                kind: "Option",
                tag: 0x02
            }
        ));
    }

    #[test]
    fn result_view_decodes_both_variants() {
        // Tag 0x00 = Ok, then a length-prefixed string.
        let ok_bytes = [0x00u8, 0x02, b'o', b'k'];
        let ok: core::result::Result<&str, &str> = decode_view(&ok_bytes).unwrap();
        assert_eq!(ok, Ok("ok"));
        // Tag 0x01 = Err, then a length-prefixed string.
        let err_bytes = [0x01u8, 0x03, b'b', b'a', b'd'];
        let err: core::result::Result<&str, &str> = decode_view(&err_bytes).unwrap();
        assert_eq!(err, Err("bad"));
    }

    #[test]
    fn result_view_rejects_unknown_tag() {
        let err = decode_view::<core::result::Result<u8, u8>>(&[0x07u8])
            .expect_err("invalid Result tag");
        assert!(matches!(
            err,
            SerialError::InvalidTag {
                kind: "Result",
                tag: 0x07
            }
        ));
    }

    #[test]
    fn array_view_reads_fixed_count_without_length_prefix() {
        // Two length-prefixed strings back to back; the array itself has no prefix.
        let bytes = [0x01u8, b'a', 0x02, b'b', b'c'];
        let view: [&str; 2] = decode_view(&bytes).unwrap();
        assert_eq!(view, ["a", "bc"]);
    }

    #[test]
    fn array_view_rejects_truncated_input() {
        // Only one of the two required elements is present.
        assert!(decode_view::<[&str; 2]>(&[0x01u8, b'a']).is_err());
    }

    #[test]
    fn tuple_with_borrowed_str_round_trips() {
        let owned = (7_u64, String::from("hello"), true);
        let bytes = encode(&owned).unwrap();
        let view: (u64, &str, bool) = decode_view(&bytes).unwrap();
        assert_eq!(view, (7, "hello", true));
    }

    #[test]
    fn vec_of_borrowed_str_round_trips() {
        let owned = vec![
            String::from("alpha"),
            String::from("beta"),
            String::from("gamma"),
        ];
        let bytes = encode(&owned).unwrap();
        let view: Vec<&str> = decode_view(&bytes).unwrap();
        assert_eq!(view, vec!["alpha", "beta", "gamma"]);
    }

    #[test]
    fn decode_view_rejects_trailing_bytes() {
        let mut bytes = encode(&"hi").unwrap();
        bytes.push(0xff);
        let err = decode_view::<&str>(&bytes).expect_err("trailing bytes");
        assert!(matches!(err, SerialError::TrailingBytes { remaining: 1 }));
    }

    #[test]
    fn borrowed_str_with_invalid_utf8_is_rejected() {
        let bytes = [0x02u8, 0xff, 0xff];
        let err = decode_view::<&str>(&bytes).expect_err("invalid utf-8");
        assert!(matches!(err, SerialError::InvalidUtf8));
    }

    #[test]
    fn map_with_borrowed_keys_round_trips() {
        let mut owned: BTreeMap<String, u32> = BTreeMap::new();
        let _ = owned.insert("alpha".to_string(), 1);
        let _ = owned.insert("beta".to_string(), 2);
        let bytes = encode(&owned).unwrap();
        let view: BTreeMap<&str, u32> = decode_view(&bytes).unwrap();
        assert_eq!(view.get("alpha"), Some(&1));
        assert_eq!(view.get("beta"), Some(&2));
    }
}