strict_encoding/
reader.rs

1// Strict encoding library for deterministic binary serialization.
2//
3// SPDX-License-Identifier: Apache-2.0
4//
5// Designed in 2019-2025 by Dr Maxim Orlovsky <orlovsky@ubideco.org>
6// Written in 2024-2025 by Dr Maxim Orlovsky <orlovsky@ubideco.org>
7//
8// Copyright (C) 2022-2025 Laboratories for Ubiquitous Deterministic Computing (UBIDECO),
9//                         Institute for Distributed and Cognitive Systems (InDCS), Switzerland.
10// Copyright (C) 2022-2025 Dr Maxim Orlovsky.
11// All rights under the above copyrights are reserved.
12//
13// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
14// in compliance with the License. You may obtain a copy of the License at
15//
16//        http://www.apache.org/licenses/LICENSE-2.0
17//
18// Unless required by applicable law or agreed to in writing, software distributed under the License
19// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
20// or implied. See the License for the specific language governing permissions and limitations under
21// the License.
22
23use std::io;
24
25use crate::{
26    DecodeError, FieldName, ReadRaw, ReadStruct, ReadTuple, ReadUnion, StrictDecode, StrictEnum,
27    StrictStruct, StrictSum, StrictTuple, StrictUnion, TypedRead, VariantName,
28};
29
// TODO: Move to amplify crate
/// A simple way to count bytes read through [`io::Read`].
///
/// Every call to [`io::Read::read`] reports the whole buffer as filled and adds its length to
/// [`ReadCounter::count`]; the buffer contents are left untouched. Since it never signals
/// end-of-stream, it can be used to measure how many bytes a reading routine would consume.
#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Default, Debug)]
pub struct ReadCounter {
    /// Count of bytes which passed through this reader
    pub count: usize,
}

impl io::Read for ReadCounter {
    /// Pretends to fill `buf` completely and adds `buf.len()` to the counter.
    ///
    /// # Errors
    ///
    /// Returns [`io::ErrorKind::OutOfMemory`] if the counter would overflow `usize`; the counter
    /// is left unchanged in that case.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let len = buf.len();
        // `count` is a public field and may hold any value; use checked arithmetic (the same
        // policy `ConfinedReader` applies) instead of panicking or silently wrapping.
        self.count = self.count.checked_add(len).ok_or(io::ErrorKind::OutOfMemory)?;
        Ok(len)
    }
}
45
// TODO: Move to amplify crate
/// An [`io::Read`] adaptor which refuses to read more than a fixed total number of bytes.
///
/// The adaptor tracks how many bytes have been successfully read through it; any read which
/// would push that total above the configured limit fails.
#[derive(Clone, Debug)]
pub struct ConfinedReader<R: io::Read> {
    /// Total number of bytes accounted so far; only ever updated to values `<= limit`.
    count: usize,
    /// Maximum total number of bytes which may be read.
    limit: usize,
    /// The wrapped reader.
    reader: R,
}

impl<R: io::Read> From<R> for ConfinedReader<R> {
    /// Wraps `reader` with a limit of `usize::MAX` bytes.
    fn from(reader: R) -> Self { Self::with(usize::MAX, reader) }
}

impl<R: io::Read> ConfinedReader<R> {
    /// Wraps `reader`, allowing at most `limit` bytes to be read through it in total.
    pub fn with(limit: usize, reader: R) -> Self { Self { reader, limit, count: 0 } }

    /// Returns the number of bytes read through this adaptor so far.
    pub fn count(&self) -> usize { self.count }

    /// Consumes the adaptor and returns the wrapped reader.
    pub fn unconfine(self) -> R { self.reader }
}

impl<R: io::Read> io::Read for ConfinedReader<R> {
    /// Reads from the wrapped reader and accounts for the bytes it returned.
    ///
    /// # Errors
    ///
    /// Propagates errors from the wrapped reader. Returns [`io::ErrorKind::OutOfMemory`] if the
    /// running total would overflow `usize`, and [`io::ErrorKind::InvalidInput`] if it would
    /// exceed the limit. In these two cases the bytes have already been taken from the wrapped
    /// reader (and placed in `buf`), but they are not added to the count.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let fetched = self.reader.read(buf)?;
        let Some(total) = self.count.checked_add(fetched) else {
            return Err(io::ErrorKind::OutOfMemory.into());
        };
        if total > self.limit {
            return Err(io::ErrorKind::InvalidInput.into());
        }
        self.count = total;
        Ok(fetched)
    }
}
89
90#[derive(Clone, Debug)]
91pub struct StreamReader<R: io::Read>(ConfinedReader<R>);
92
93impl<R: io::Read> StreamReader<R> {
94    pub fn new<const MAX: usize>(inner: R) -> Self { Self(ConfinedReader::with(MAX, inner)) }
95    pub fn unconfine(self) -> R { self.0.unconfine() }
96}
97
98impl<T: AsRef<[u8]>> StreamReader<io::Cursor<T>> {
99    pub fn cursor<const MAX: usize>(inner: T) -> Self {
100        Self(ConfinedReader::with(MAX, io::Cursor::new(inner)))
101    }
102}
103
104impl<R: io::Read> ReadRaw for StreamReader<R> {
105    fn read_raw<const MAX_LEN: usize>(&mut self, len: usize) -> io::Result<Vec<u8>> {
106        use io::Read;
107        let mut buf = vec![0u8; len];
108        self.0.read_exact(&mut buf)?;
109        Ok(buf)
110    }
111
112    fn read_raw_array<const LEN: usize>(&mut self) -> io::Result<[u8; LEN]> {
113        use io::Read;
114        let mut buf = [0u8; LEN];
115        self.0.read_exact(&mut buf)?;
116        Ok(buf)
117    }
118}
119
120impl<T: AsRef<[u8]>> StreamReader<io::Cursor<T>> {
121    pub fn in_memory<const MAX: usize>(data: T) -> Self { Self::new::<MAX>(io::Cursor::new(data)) }
122    pub fn into_cursor(self) -> io::Cursor<T> { self.0.unconfine() }
123}
124
125impl StreamReader<ReadCounter> {
126    pub fn counter<const MAX: usize>() -> Self { Self::new::<MAX>(ReadCounter::default()) }
127}
128
129#[derive(Clone, Debug, From)]
130pub struct StrictReader<R: ReadRaw>(R);
131
132impl<T: AsRef<[u8]>> StrictReader<StreamReader<io::Cursor<T>>> {
133    pub fn in_memory<const MAX: usize>(data: T) -> Self {
134        Self(StreamReader::in_memory::<MAX>(data))
135    }
136    pub fn into_cursor(self) -> io::Cursor<T> { self.0.into_cursor() }
137}
138
139impl StrictReader<StreamReader<ReadCounter>> {
140    pub fn counter<const MAX: usize>() -> Self { Self(StreamReader::counter::<MAX>()) }
141}
142
143impl<R: ReadRaw> StrictReader<R> {
144    pub fn with(reader: R) -> Self { Self(reader) }
145
146    pub fn unbox(self) -> R { self.0 }
147}
148
149impl<R: ReadRaw> TypedRead for StrictReader<R> {
150    type TupleReader<'parent>
151        = TupleReader<'parent, R>
152    where Self: 'parent;
153    type StructReader<'parent>
154        = StructReader<'parent, R>
155    where Self: 'parent;
156    type UnionReader = Self;
157    type RawReader = R;
158
159    unsafe fn raw_reader(&mut self) -> &mut Self::RawReader { &mut self.0 }
160
161    fn read_union<T: StrictUnion>(
162        &mut self,
163        inner: impl FnOnce(VariantName, &mut Self::UnionReader) -> Result<T, DecodeError>,
164    ) -> Result<T, DecodeError> {
165        let name = T::strict_name().unwrap_or_else(|| tn!("__unnamed"));
166        let tag = u8::strict_decode(self)?;
167        let variant_name = T::variant_name_by_tag(tag)
168            .ok_or(DecodeError::UnionTagNotKnown(name.to_string(), tag))?;
169        inner(variant_name, self)
170    }
171
172    fn read_enum<T: StrictEnum>(&mut self) -> Result<T, DecodeError>
173    where u8: From<T> {
174        let name = T::strict_name().unwrap_or_else(|| tn!("__unnamed"));
175        let tag = u8::strict_decode(self)?;
176        T::try_from(tag).map_err(|_| DecodeError::EnumTagNotKnown(name.to_string(), tag))
177    }
178
179    fn read_tuple<'parent, 'me, T: StrictTuple>(
180        &'me mut self,
181        inner: impl FnOnce(&mut Self::TupleReader<'parent>) -> Result<T, DecodeError>,
182    ) -> Result<T, DecodeError>
183    where
184        Self: 'parent,
185        'me: 'parent,
186    {
187        let name = T::strict_name().unwrap_or_else(|| tn!("__unnamed"));
188        let mut reader = TupleReader {
189            read_fields: 0,
190            parent: self,
191        };
192        let res = inner(&mut reader)?;
193        assert_ne!(reader.read_fields, 0, "you forget to read fields for a tuple {}", name);
194        assert_eq!(
195            reader.read_fields,
196            T::FIELD_COUNT,
197            "the number of fields read for a tuple {} doesn't match type declaration",
198            name
199        );
200        Ok(res)
201    }
202
203    fn read_struct<'parent, 'me, T: StrictStruct>(
204        &'me mut self,
205        inner: impl FnOnce(&mut Self::StructReader<'parent>) -> Result<T, DecodeError>,
206    ) -> Result<T, DecodeError>
207    where
208        Self: 'parent,
209        'me: 'parent,
210    {
211        let name = T::strict_name().unwrap_or_else(|| tn!("__unnamed"));
212        let mut reader = StructReader {
213            named_fields: empty!(),
214            parent: self,
215        };
216        let res = inner(&mut reader)?;
217        assert!(!reader.named_fields.is_empty(), "you forget to read fields for a tuple {}", name);
218
219        for field in T::ALL_FIELDS {
220            let pos = reader
221                .named_fields
222                .iter()
223                .position(|f| f.as_str() == *field)
224                .unwrap_or_else(|| panic!("field {} is not read for {}", field, name));
225            reader.named_fields.remove(pos);
226        }
227        assert!(reader.named_fields.is_empty(), "excessive fields are read for {}", name);
228        Ok(res)
229    }
230}
231
232#[derive(Debug)]
233pub struct TupleReader<'parent, R: ReadRaw> {
234    read_fields: u8,
235    parent: &'parent mut StrictReader<R>,
236}
237
238impl<R: ReadRaw> ReadTuple for TupleReader<'_, R> {
239    fn read_field<T: StrictDecode>(&mut self) -> Result<T, DecodeError> {
240        self.read_fields += 1;
241        T::strict_decode(self.parent)
242    }
243}
244
245#[derive(Debug)]
246pub struct StructReader<'parent, R: ReadRaw> {
247    named_fields: Vec<FieldName>,
248    parent: &'parent mut StrictReader<R>,
249}
250
251impl<R: ReadRaw> ReadStruct for StructReader<'_, R> {
252    fn read_field<T: StrictDecode>(&mut self, field: FieldName) -> Result<T, DecodeError> {
253        self.named_fields.push(field);
254        T::strict_decode(self.parent)
255    }
256}
257
258impl<R: ReadRaw> ReadUnion for StrictReader<R> {
259    type TupleReader<'parent>
260        = TupleReader<'parent, R>
261    where Self: 'parent;
262    type StructReader<'parent>
263        = StructReader<'parent, R>
264    where Self: 'parent;
265
266    fn read_tuple<'parent, 'me, T: StrictSum>(
267        &'me mut self,
268        inner: impl FnOnce(&mut Self::TupleReader<'parent>) -> Result<T, DecodeError>,
269    ) -> Result<T, DecodeError>
270    where
271        Self: 'parent,
272        'me: 'parent,
273    {
274        let mut reader = TupleReader {
275            read_fields: 0,
276            parent: self,
277        };
278        inner(&mut reader)
279    }
280
281    fn read_struct<'parent, 'me, T: StrictSum>(
282        &'me mut self,
283        inner: impl FnOnce(&mut Self::StructReader<'parent>) -> Result<T, DecodeError>,
284    ) -> Result<T, DecodeError>
285    where
286        Self: 'parent,
287        'me: 'parent,
288    {
289        let mut reader = StructReader {
290            named_fields: empty!(),
291            parent: self,
292        };
293        inner(&mut reader)
294    }
295}