1use std::io;
25
26use crate::{
27 DecodeError, FieldName, ReadRaw, ReadStruct, ReadTuple, ReadUnion, StrictDecode, StrictEnum,
28 StrictStruct, StrictSum, StrictTuple, StrictUnion, TypedRead, VariantName,
29};
30
/// Byte tally used as a data-less [`io::Read`] source.
///
/// The `io::Read` implementation in this file never produces data; it only
/// adds the size of every buffer it is handed to [`ReadCounter::count`].
#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Default, Debug)]
pub struct ReadCounter {
    /// Sum of the lengths of all buffers passed to `read` so far.
    pub count: usize,
}
38
39impl io::Read for ReadCounter {
40 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
41 let count = buf.len();
42 self.count += count;
43 Ok(count)
44 }
45}
46
/// Wrapper around an [`io::Read`] source that tracks how many bytes were read
/// through it and rejects reads which would push the total above a limit.
#[derive(Clone, Debug)]
pub struct ConfinedReader<R: io::Read> {
    /// Total number of bytes successfully read through this wrapper.
    count: usize,
    /// Largest value `count` is allowed to reach.
    limit: usize,
    /// The wrapped data source.
    reader: R,
}
54
55impl<R: io::Read> From<R> for ConfinedReader<R> {
56 fn from(reader: R) -> Self {
57 Self {
58 count: 0,
59 limit: usize::MAX,
60 reader,
61 }
62 }
63}
64
65impl<R: io::Read> ConfinedReader<R> {
66 pub fn with(limit: usize, reader: R) -> Self {
67 Self {
68 count: 0,
69 limit,
70 reader,
71 }
72 }
73
74 pub fn count(&self) -> usize { self.count }
75
76 pub fn unconfine(self) -> R { self.reader }
77}
78
79impl<R: io::Read> io::Read for ConfinedReader<R> {
80 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
81 let len = self.reader.read(buf)?;
82 match self.count.checked_add(len) {
83 None => return Err(io::ErrorKind::OutOfMemory.into()),
84 Some(len) if len > self.limit => return Err(io::ErrorKind::InvalidInput.into()),
85 Some(len) => self.count = len,
86 };
87 Ok(len)
88 }
89}
90
/// Newtype over a [`ConfinedReader`]; this file implements `ReadRaw` for it,
/// so raw byte reads go through the confined (length-limited) reader.
#[derive(Clone, Debug)]
pub struct StreamReader<R: io::Read>(ConfinedReader<R>);
93
94impl<R: io::Read> StreamReader<R> {
95 pub fn new<const MAX: usize>(inner: R) -> Self { Self(ConfinedReader::with(MAX, inner)) }
96 pub fn unconfine(self) -> R { self.0.unconfine() }
97}
98
99impl<T: AsRef<[u8]>> StreamReader<io::Cursor<T>> {
100 pub fn cursor<const MAX: usize>(inner: T) -> Self {
101 Self(ConfinedReader::with(MAX, io::Cursor::new(inner)))
102 }
103}
104
105impl<R: io::Read> ReadRaw for StreamReader<R> {
106 fn read_raw<const MAX_LEN: usize>(&mut self, len: usize) -> io::Result<Vec<u8>> {
107 use io::Read;
108 let mut buf = vec![0u8; len];
109 self.0.read_exact(&mut buf)?;
110 Ok(buf)
111 }
112
113 fn read_raw_array<const LEN: usize>(&mut self) -> io::Result<[u8; LEN]> {
114 use io::Read;
115 let mut buf = [0u8; LEN];
116 self.0.read_exact(&mut buf)?;
117 Ok(buf)
118 }
119}
120
121impl<T: AsRef<[u8]>> StreamReader<io::Cursor<T>> {
122 pub fn in_memory<const MAX: usize>(data: T) -> Self { Self::new::<MAX>(io::Cursor::new(data)) }
123 pub fn into_cursor(self) -> io::Cursor<T> { self.0.unconfine() }
124}
125
126impl StreamReader<ReadCounter> {
127 pub fn counter<const MAX: usize>() -> Self { Self::new::<MAX>(ReadCounter::default()) }
128}
129
/// Newtype over a `ReadRaw` source; this file implements `TypedRead` and
/// `ReadUnion` for it.
// NOTE(review): the `From` derive is not imported in the visible part of the
// file; presumably it comes from a crate-level macro import — confirm.
#[derive(Clone, Debug, From)]
pub struct StrictReader<R: ReadRaw>(R);
132
133impl<T: AsRef<[u8]>> StrictReader<StreamReader<io::Cursor<T>>> {
134 pub fn in_memory<const MAX: usize>(data: T) -> Self {
135 Self(StreamReader::in_memory::<MAX>(data))
136 }
137 pub fn into_cursor(self) -> io::Cursor<T> { self.0.into_cursor() }
138}
139
140impl StrictReader<StreamReader<ReadCounter>> {
141 pub fn counter<const MAX: usize>() -> Self { Self(StreamReader::counter::<MAX>()) }
142}
143
144impl<R: ReadRaw> StrictReader<R> {
145 pub fn with(reader: R) -> Self { Self(reader) }
146
147 pub fn unbox(self) -> R { self.0 }
148}
149
150impl<R: ReadRaw> TypedRead for StrictReader<R> {
151 type TupleReader<'parent>
152 = TupleReader<'parent, R>
153 where Self: 'parent;
154 type StructReader<'parent>
155 = StructReader<'parent, R>
156 where Self: 'parent;
157 type UnionReader = Self;
158 type RawReader = R;
159
160 unsafe fn raw_reader(&mut self) -> &mut Self::RawReader { &mut self.0 }
161
162 fn read_union<T: StrictUnion>(
163 &mut self,
164 inner: impl FnOnce(VariantName, &mut Self::UnionReader) -> Result<T, DecodeError>,
165 ) -> Result<T, DecodeError> {
166 let name = T::strict_name().unwrap_or_else(|| tn!("__unnamed"));
167 let tag = u8::strict_decode(self)?;
168 let variant_name = T::variant_name_by_tag(tag)
169 .ok_or(DecodeError::UnionTagNotKnown(name.to_string(), tag))?;
170 inner(variant_name, self)
171 }
172
173 fn read_enum<T: StrictEnum>(&mut self) -> Result<T, DecodeError>
174 where u8: From<T> {
175 let name = T::strict_name().unwrap_or_else(|| tn!("__unnamed"));
176 let tag = u8::strict_decode(self)?;
177 T::try_from(tag).map_err(|_| DecodeError::EnumTagNotKnown(name.to_string(), tag))
178 }
179
180 fn read_tuple<'parent, 'me, T: StrictTuple>(
181 &'me mut self,
182 inner: impl FnOnce(&mut Self::TupleReader<'parent>) -> Result<T, DecodeError>,
183 ) -> Result<T, DecodeError>
184 where
185 Self: 'parent,
186 'me: 'parent,
187 {
188 let name = T::strict_name().unwrap_or_else(|| tn!("__unnamed"));
189 let mut reader = TupleReader {
190 read_fields: 0,
191 parent: self,
192 };
193 let res = inner(&mut reader)?;
194 assert_ne!(reader.read_fields, 0, "you forget to read fields for a tuple {}", name);
195 assert_eq!(
196 reader.read_fields,
197 T::FIELD_COUNT,
198 "the number of fields read for a tuple {} doesn't match type declaration",
199 name
200 );
201 Ok(res)
202 }
203
204 fn read_struct<'parent, 'me, T: StrictStruct>(
205 &'me mut self,
206 inner: impl FnOnce(&mut Self::StructReader<'parent>) -> Result<T, DecodeError>,
207 ) -> Result<T, DecodeError>
208 where
209 Self: 'parent,
210 'me: 'parent,
211 {
212 let name = T::strict_name().unwrap_or_else(|| tn!("__unnamed"));
213 let mut reader = StructReader {
214 named_fields: empty!(),
215 parent: self,
216 };
217 let res = inner(&mut reader)?;
218 assert!(!reader.named_fields.is_empty(), "you forget to read fields for a tuple {}", name);
219
220 for field in T::ALL_FIELDS {
221 let pos = reader
222 .named_fields
223 .iter()
224 .position(|f| f.as_str() == *field)
225 .unwrap_or_else(|| panic!("field {} is not read for {}", field, name));
226 reader.named_fields.remove(pos);
227 }
228 assert!(reader.named_fields.is_empty(), "excessive fields are read for {}", name);
229 Ok(res)
230 }
231}
232
/// Field reader handed to the closure that decodes a tuple; it counts the
/// fields read while forwarding the decoding to the parent reader.
#[derive(Debug)]
pub struct TupleReader<'parent, R: ReadRaw> {
    /// Number of `read_field` calls made so far.
    read_fields: u8,
    /// Reader the field values are decoded from.
    parent: &'parent mut StrictReader<R>,
}
238
239impl<R: ReadRaw> ReadTuple for TupleReader<'_, R> {
240 fn read_field<T: StrictDecode>(&mut self) -> Result<T, DecodeError> {
241 self.read_fields += 1;
242 T::strict_decode(self.parent)
243 }
244}
245
/// Field reader handed to the closure that decodes a struct; it records the
/// names of the fields read while forwarding the decoding to the parent
/// reader.
#[derive(Debug)]
pub struct StructReader<'parent, R: ReadRaw> {
    /// Names passed to `read_field` so far, in call order.
    named_fields: Vec<FieldName>,
    /// Reader the field values are decoded from.
    parent: &'parent mut StrictReader<R>,
}
251
252impl<R: ReadRaw> ReadStruct for StructReader<'_, R> {
253 fn read_field<T: StrictDecode>(&mut self, field: FieldName) -> Result<T, DecodeError> {
254 self.named_fields.push(field);
255 T::strict_decode(self.parent)
256 }
257}
258
259impl<R: ReadRaw> ReadUnion for StrictReader<R> {
260 type TupleReader<'parent>
261 = TupleReader<'parent, R>
262 where Self: 'parent;
263 type StructReader<'parent>
264 = StructReader<'parent, R>
265 where Self: 'parent;
266
267 fn read_tuple<'parent, 'me, T: StrictSum>(
268 &'me mut self,
269 inner: impl FnOnce(&mut Self::TupleReader<'parent>) -> Result<T, DecodeError>,
270 ) -> Result<T, DecodeError>
271 where
272 Self: 'parent,
273 'me: 'parent,
274 {
275 let mut reader = TupleReader {
276 read_fields: 0,
277 parent: self,
278 };
279 inner(&mut reader)
280 }
281
282 fn read_struct<'parent, 'me, T: StrictSum>(
283 &'me mut self,
284 inner: impl FnOnce(&mut Self::StructReader<'parent>) -> Result<T, DecodeError>,
285 ) -> Result<T, DecodeError>
286 where
287 Self: 'parent,
288 'me: 'parent,
289 {
290 let mut reader = StructReader {
291 named_fields: empty!(),
292 parent: self,
293 };
294 inner(&mut reader)
295 }
296}