midnight_serialize/
deserializable.rs1use crate::VecExt;
15use crate::serializable::GLOBAL_TAG;
16use crate::tagged::Tagged;
17use std::borrow::Cow;
18use std::io::Read;
19use std::marker::PhantomData;
20use std::sync::Arc;
21use std::{collections::HashMap, collections::HashSet, hash::Hash};
22
/// Maximum nesting depth accepted by `Deserializable::check_rec`.
///
/// The limit is 50 when `debug_assertions` is enabled and 250 otherwise.
/// NOTE(review): presumably smaller in debug builds because unoptimised frames use more
/// stack per level — confirm.
pub const RECURSION_LIMIT: u32 = if cfg!(debug_assertions) { 50 } else { 250 };
27
28pub fn tagged_deserialize<T: Deserializable + Tagged>(mut reader: impl Read) -> std::io::Result<T> {
30 let tag_expected = format!("{GLOBAL_TAG}{}:", T::tag());
31 let mut read_tag = vec![0u8; tag_expected.len()];
32 let mut remaining_tag_buf = &mut read_tag[..];
33 while !remaining_tag_buf.is_empty() {
34 let read = reader.read(remaining_tag_buf)?;
35 if read == 0 {
36 let rem = remaining_tag_buf.len();
37 let len = read_tag.len() - rem;
38 read_tag.truncate(len);
39 break;
40 }
41 remaining_tag_buf = &mut remaining_tag_buf[read..];
42 }
43 if read_tag != tag_expected.as_bytes() {
44 let sanitised = String::from_utf8_lossy(&read_tag).replace(
45 |c: char| -> bool { !c.is_ascii_alphanumeric() && !":_-()[],".contains(c) },
46 "�",
47 );
48 return Err(std::io::Error::new(
49 std::io::ErrorKind::InvalidData,
50 format!("expected header tag '{tag_expected}', got '{sanitised}'"),
51 ));
52 }
53 let value = <T as Deserializable>::deserialize(&mut reader, 0)?;
54
55 #[allow(clippy::unbuffered_bytes)] let count = reader.bytes().count(); if count == 0 {
59 return Ok(value);
60 }
61
62 Err(std::io::Error::new(
63 std::io::ErrorKind::InvalidData,
64 format!(
65 "Not all bytes read deserializing '{}'; {} bytes remaining",
66 tag_expected, count
67 ),
68 ))
69}
70
71pub trait Deserializable
72where
73 Self: Sized,
74{
75 const LIMIT_RECURSION: bool = true;
76
77 fn deserialize(reader: &mut impl Read, recursion_depth: u32) -> std::io::Result<Self>;
78
79 fn check_rec(depth: &mut u32) -> std::io::Result<()> {
80 if Self::LIMIT_RECURSION {
81 *depth += 1;
82 if *depth > RECURSION_LIMIT {
83 return Err(std::io::Error::new(
84 std::io::ErrorKind::InvalidData,
85 "exceeded recursion depth deserializing",
86 ));
87 }
88 }
89 Ok(())
90 }
91}
92
93impl<T: Deserializable> Deserializable for Vec<T> {
94 fn deserialize(reader: &mut impl Read, mut recursion_depth: u32) -> std::io::Result<Self> {
95 Self::check_rec(&mut recursion_depth)?;
96 let len = <u32 as Deserializable>::deserialize(reader, recursion_depth)?;
97 let mut result = Vec::with_bounded_capacity(len as usize);
98 for _ in 0..len {
99 result.push(<T as Deserializable>::deserialize(reader, recursion_depth)?);
100 }
101 Ok(result)
102 }
103}
104
105impl<K: Deserializable + PartialOrd + Hash + Eq, V: Deserializable> Deserializable
106 for HashMap<K, V>
107{
108 fn deserialize(reader: &mut impl Read, mut recursion_depth: u32) -> std::io::Result<Self> {
109 Self::check_rec(&mut recursion_depth)?;
110 let len = <u32 as Deserializable>::deserialize(reader, recursion_depth)?;
111 let mut result = HashMap::new();
112 for _ in 0..len {
113 let k = <K as Deserializable>::deserialize(reader, recursion_depth)?;
114 let v = <V as Deserializable>::deserialize(reader, recursion_depth)?;
115 result.insert(k, v);
116 }
117 Ok(result)
118 }
119}
120
121impl<T: Deserializable + Hash + Eq> Deserializable for HashSet<T> {
122 fn deserialize(reader: &mut impl Read, mut recursion_depth: u32) -> std::io::Result<Self> {
123 Self::check_rec(&mut recursion_depth)?;
124 let len = <u32 as Deserializable>::deserialize(reader, recursion_depth)?;
125 let mut result = HashSet::new();
126 for _ in 0..len {
127 result.insert(<T as Deserializable>::deserialize(reader, recursion_depth)?);
128 }
129 Ok(result)
130 }
131}
132
133impl<T: Deserializable> Deserializable for Option<T> {
134 fn deserialize(reader: &mut impl Read, mut recursion_depth: u32) -> std::io::Result<Self> {
135 Self::check_rec(&mut recursion_depth)?;
136 let some = <u8 as Deserializable>::deserialize(reader, recursion_depth)?;
137 match some {
138 0 => Ok(None),
139 1 => Ok(Some(<T as Deserializable>::deserialize(
140 reader,
141 recursion_depth,
142 )?)),
143 _ => Err(std::io::Error::new(
144 std::io::ErrorKind::InvalidData,
145 format!("Invalid discriminant: {}.", some),
146 )),
147 }
148 }
149}
150
151impl<T: Deserializable> Deserializable for Arc<T> {
152 fn deserialize(
153 reader: &mut impl Read,
154 mut recursion_depth: u32,
155 ) -> Result<Self, std::io::Error> {
156 Self::check_rec(&mut recursion_depth)?;
157 Ok(Arc::new(T::deserialize(reader, recursion_depth)?))
158 }
159}
160
161impl<const N: usize> Deserializable for [u8; N] {
162 fn deserialize(reader: &mut impl Read, _recursion_depth: u32) -> std::io::Result<Self> {
163 let mut res = [0u8; N];
164 reader.read_exact(&mut res[..])?;
165 Ok(res)
166 }
167}
168
169impl<T: ?Sized> Deserializable for PhantomData<T> {
170 fn deserialize(_reader: &mut impl Read, _recursion_depth: u32) -> std::io::Result<Self> {
171 Ok(PhantomData)
172 }
173}
174
175impl<T: Deserializable> Deserializable for Box<T> {
176 fn deserialize(reader: &mut impl Read, recursion_depth: u32) -> std::io::Result<Self> {
177 T::deserialize(reader, recursion_depth).map(Box::new)
178 }
179}
180
181impl<'a, T: ToOwned + ?Sized> Deserializable for Cow<'a, T>
182where
183 T::Owned: Deserializable,
184{
185 fn deserialize(reader: &mut impl Read, recursion_depth: u32) -> std::io::Result<Self> {
186 <T::Owned>::deserialize(reader, recursion_depth).map(Cow::Owned)
187 }
188}