1use std::io;
24
25use crate::{
26 DecodeError, FieldName, ReadRaw, ReadStruct, ReadTuple, ReadUnion, StrictDecode, StrictEnum,
27 StrictStruct, StrictSum, StrictTuple, StrictUnion, TypedRead, VariantName,
28};
29
/// Byte source that never produces data and only tallies how many bytes were requested.
///
/// Every `read` reports the whole buffer as filled — without writing anything into it — and
/// adds the buffer length to [`ReadCounter::count`].
#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Default, Debug)]
pub struct ReadCounter {
    /// Total number of bytes reported as read so far.
    pub count: usize,
}

impl io::Read for ReadCounter {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        // `buf` is intentionally left untouched: only its length is accounted for.
        self.count += buf.len();
        Ok(buf.len())
    }
}
45
/// Reader adaptor that fails once more than `limit` bytes in total have been read.
///
/// The limit is checked after each read of the inner reader. A read that pushes the running
/// total past `limit` returns [`io::ErrorKind::InvalidInput`] and leaves the counter unchanged,
/// even though the inner reader has already consumed those bytes.
#[derive(Clone, Debug)]
pub struct ConfinedReader<R>
where R: io::Read
{
    /// Bytes successfully read so far.
    count: usize,
    /// Maximum total number of bytes that may be read.
    limit: usize,
    /// The wrapped reader.
    reader: R,
}

impl<R> From<R> for ConfinedReader<R>
where R: io::Read
{
    /// Wraps `reader` with the largest possible limit (`usize::MAX`).
    fn from(reader: R) -> Self { Self::with(usize::MAX, reader) }
}

impl<R> ConfinedReader<R>
where R: io::Read
{
    /// Wraps `reader`, allowing at most `limit` bytes to be read in total.
    pub fn with(limit: usize, reader: R) -> Self {
        Self {
            count: 0,
            limit,
            reader,
        }
    }

    /// Returns the number of bytes successfully read so far.
    pub fn count(&self) -> usize { self.count }

    /// Drops the limit and returns the wrapped reader.
    pub fn unconfine(self) -> R {
        let Self { reader, .. } = self;
        reader
    }
}

impl<R> io::Read for ConfinedReader<R>
where R: io::Read
{
    /// Reads from the inner reader and accounts the bytes against the limit.
    ///
    /// # Errors
    ///
    /// - any error returned by the inner reader;
    /// - [`io::ErrorKind::OutOfMemory`] if the running total would overflow `usize`;
    /// - [`io::ErrorKind::InvalidInput`] if the running total would exceed `limit`.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let received = self.reader.read(buf)?;
        let Some(total) = self.count.checked_add(received) else {
            return Err(io::ErrorKind::OutOfMemory.into());
        };
        if total > self.limit {
            return Err(io::ErrorKind::InvalidInput.into());
        }
        self.count = total;
        Ok(received)
    }
}
89
90#[derive(Clone, Debug)]
91pub struct StreamReader<R: io::Read>(ConfinedReader<R>);
92
93impl<R: io::Read> StreamReader<R> {
94 pub fn new<const MAX: usize>(inner: R) -> Self { Self(ConfinedReader::with(MAX, inner)) }
95 pub fn unconfine(self) -> R { self.0.unconfine() }
96}
97
98impl<T: AsRef<[u8]>> StreamReader<io::Cursor<T>> {
99 pub fn cursor<const MAX: usize>(inner: T) -> Self {
100 Self(ConfinedReader::with(MAX, io::Cursor::new(inner)))
101 }
102}
103
104impl<R: io::Read> ReadRaw for StreamReader<R> {
105 fn read_raw<const MAX_LEN: usize>(&mut self, len: usize) -> io::Result<Vec<u8>> {
106 use io::Read;
107 let mut buf = vec![0u8; len];
108 self.0.read_exact(&mut buf)?;
109 Ok(buf)
110 }
111
112 fn read_raw_array<const LEN: usize>(&mut self) -> io::Result<[u8; LEN]> {
113 use io::Read;
114 let mut buf = [0u8; LEN];
115 self.0.read_exact(&mut buf)?;
116 Ok(buf)
117 }
118}
119
120impl<T: AsRef<[u8]>> StreamReader<io::Cursor<T>> {
121 pub fn in_memory<const MAX: usize>(data: T) -> Self { Self::new::<MAX>(io::Cursor::new(data)) }
122 pub fn into_cursor(self) -> io::Cursor<T> { self.0.unconfine() }
123}
124
125impl StreamReader<ReadCounter> {
126 pub fn counter<const MAX: usize>() -> Self { Self::new::<MAX>(ReadCounter::default()) }
127}
128
/// Strict-encoding decoder wrapping a raw byte source `R`.
///
/// Implements `TypedRead` and `ReadUnion` on top of `R`. Use [`StrictReader::with`] to wrap
/// an existing raw reader and [`StrictReader::unbox`] to get it back.
#[derive(Clone, Debug, From)]
pub struct StrictReader<R: ReadRaw>(R);
131
132impl<T: AsRef<[u8]>> StrictReader<StreamReader<io::Cursor<T>>> {
133 pub fn in_memory<const MAX: usize>(data: T) -> Self {
134 Self(StreamReader::in_memory::<MAX>(data))
135 }
136 pub fn into_cursor(self) -> io::Cursor<T> { self.0.into_cursor() }
137}
138
139impl StrictReader<StreamReader<ReadCounter>> {
140 pub fn counter<const MAX: usize>() -> Self { Self(StreamReader::counter::<MAX>()) }
141}
142
143impl<R: ReadRaw> StrictReader<R> {
144 pub fn with(reader: R) -> Self { Self(reader) }
145
146 pub fn unbox(self) -> R { self.0 }
147}
148
149impl<R: ReadRaw> TypedRead for StrictReader<R> {
150 type TupleReader<'parent>
151 = TupleReader<'parent, R>
152 where Self: 'parent;
153 type StructReader<'parent>
154 = StructReader<'parent, R>
155 where Self: 'parent;
156 type UnionReader = Self;
157 type RawReader = R;
158
159 unsafe fn raw_reader(&mut self) -> &mut Self::RawReader { &mut self.0 }
160
161 fn read_union<T: StrictUnion>(
162 &mut self,
163 inner: impl FnOnce(VariantName, &mut Self::UnionReader) -> Result<T, DecodeError>,
164 ) -> Result<T, DecodeError> {
165 let name = T::strict_name().unwrap_or_else(|| tn!("__unnamed"));
166 let tag = u8::strict_decode(self)?;
167 let variant_name = T::variant_name_by_tag(tag)
168 .ok_or(DecodeError::UnionTagNotKnown(name.to_string(), tag))?;
169 inner(variant_name, self)
170 }
171
172 fn read_enum<T: StrictEnum>(&mut self) -> Result<T, DecodeError>
173 where u8: From<T> {
174 let name = T::strict_name().unwrap_or_else(|| tn!("__unnamed"));
175 let tag = u8::strict_decode(self)?;
176 T::try_from(tag).map_err(|_| DecodeError::EnumTagNotKnown(name.to_string(), tag))
177 }
178
179 fn read_tuple<'parent, 'me, T: StrictTuple>(
180 &'me mut self,
181 inner: impl FnOnce(&mut Self::TupleReader<'parent>) -> Result<T, DecodeError>,
182 ) -> Result<T, DecodeError>
183 where
184 Self: 'parent,
185 'me: 'parent,
186 {
187 let name = T::strict_name().unwrap_or_else(|| tn!("__unnamed"));
188 let mut reader = TupleReader {
189 read_fields: 0,
190 parent: self,
191 };
192 let res = inner(&mut reader)?;
193 assert_ne!(reader.read_fields, 0, "you forget to read fields for a tuple {}", name);
194 assert_eq!(
195 reader.read_fields,
196 T::FIELD_COUNT,
197 "the number of fields read for a tuple {} doesn't match type declaration",
198 name
199 );
200 Ok(res)
201 }
202
203 fn read_struct<'parent, 'me, T: StrictStruct>(
204 &'me mut self,
205 inner: impl FnOnce(&mut Self::StructReader<'parent>) -> Result<T, DecodeError>,
206 ) -> Result<T, DecodeError>
207 where
208 Self: 'parent,
209 'me: 'parent,
210 {
211 let name = T::strict_name().unwrap_or_else(|| tn!("__unnamed"));
212 let mut reader = StructReader {
213 named_fields: empty!(),
214 parent: self,
215 };
216 let res = inner(&mut reader)?;
217 assert!(!reader.named_fields.is_empty(), "you forget to read fields for a tuple {}", name);
218
219 for field in T::ALL_FIELDS {
220 let pos = reader
221 .named_fields
222 .iter()
223 .position(|f| f.as_str() == *field)
224 .unwrap_or_else(|| panic!("field {} is not read for {}", field, name));
225 reader.named_fields.remove(pos);
226 }
227 assert!(reader.named_fields.is_empty(), "excessive fields are read for {}", name);
228 Ok(res)
229 }
230}
231
232#[derive(Debug)]
233pub struct TupleReader<'parent, R: ReadRaw> {
234 read_fields: u8,
235 parent: &'parent mut StrictReader<R>,
236}
237
238impl<R: ReadRaw> ReadTuple for TupleReader<'_, R> {
239 fn read_field<T: StrictDecode>(&mut self) -> Result<T, DecodeError> {
240 self.read_fields += 1;
241 T::strict_decode(self.parent)
242 }
243}
244
245#[derive(Debug)]
246pub struct StructReader<'parent, R: ReadRaw> {
247 named_fields: Vec<FieldName>,
248 parent: &'parent mut StrictReader<R>,
249}
250
251impl<R: ReadRaw> ReadStruct for StructReader<'_, R> {
252 fn read_field<T: StrictDecode>(&mut self, field: FieldName) -> Result<T, DecodeError> {
253 self.named_fields.push(field);
254 T::strict_decode(self.parent)
255 }
256}
257
258impl<R: ReadRaw> ReadUnion for StrictReader<R> {
259 type TupleReader<'parent>
260 = TupleReader<'parent, R>
261 where Self: 'parent;
262 type StructReader<'parent>
263 = StructReader<'parent, R>
264 where Self: 'parent;
265
266 fn read_tuple<'parent, 'me, T: StrictSum>(
267 &'me mut self,
268 inner: impl FnOnce(&mut Self::TupleReader<'parent>) -> Result<T, DecodeError>,
269 ) -> Result<T, DecodeError>
270 where
271 Self: 'parent,
272 'me: 'parent,
273 {
274 let mut reader = TupleReader {
275 read_fields: 0,
276 parent: self,
277 };
278 inner(&mut reader)
279 }
280
281 fn read_struct<'parent, 'me, T: StrictSum>(
282 &'me mut self,
283 inner: impl FnOnce(&mut Self::StructReader<'parent>) -> Result<T, DecodeError>,
284 ) -> Result<T, DecodeError>
285 where
286 Self: 'parent,
287 'me: 'parent,
288 {
289 let mut reader = StructReader {
290 named_fields: empty!(),
291 parent: self,
292 };
293 inner(&mut reader)
294 }
295}