strict_encoding/
reader.rs

1// Strict encoding library for deterministic binary serialization.
2//
3// SPDX-License-Identifier: Apache-2.0
4//
5// Designed in 2019-2025 by Dr Maxim Orlovsky <orlovsky@ubideco.org>
6// Written in 2024-2025 by Dr Maxim Orlovsky <orlovsky@ubideco.org>
7//
8// Copyright (C) 2019-2022 LNP/BP Standards Association.
9// Copyright (C) 2022-2025 Laboratories for Ubiquitous Deterministic Computing (UBIDECO),
10//                         Institute for Distributed and Cognitive Systems (InDCS), Switzerland.
11// Copyright (C) 2019-2025 Dr Maxim Orlovsky.
12// All rights under the above copyrights are reserved.
13//
14// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
15// in compliance with the License. You may obtain a copy of the License at
16//
17//        http://www.apache.org/licenses/LICENSE-2.0
18//
19// Unless required by applicable law or agreed to in writing, software distributed under the License
20// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
21// or implied. See the License for the specific language governing permissions and limitations under
22// the License.
23
24use std::io;
25
26use crate::{
27    DecodeError, FieldName, ReadRaw, ReadStruct, ReadTuple, ReadUnion, StrictDecode, StrictEnum,
28    StrictStruct, StrictSum, StrictTuple, StrictUnion, TypedRead, VariantName,
29};
30
// TODO: Move to amplify crate
/// A pseudo-reader which produces no real data and only tallies how many bytes were requested
/// from it through [`io::Read`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ReadCounter {
    /// Total number of bytes requested from this reader so far.
    pub count: usize,
}

impl io::Read for ReadCounter {
    /// Reports the whole of `buf` as read and adds its length to [`ReadCounter::count`].
    ///
    /// The contents of `buf` are left untouched.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let requested = buf.len();
        self.count += requested;
        Ok(requested)
    }
}
46
// TODO: Move to amplify crate
/// An [`io::Read`] adapter which refuses to read more than `limit` bytes from the wrapped reader.
///
/// Once `limit` bytes have been read, every further non-empty read fails with
/// [`io::ErrorKind::InvalidInput`]. The adapter never asks the inner reader for more bytes than
/// the remaining allowance. After a limit error, the inner reader is therefore positioned exactly
/// at the limit, and no data has been consumed without being accounted for.
#[derive(Clone, Debug)]
pub struct ConfinedReader<R: io::Read> {
    /// Number of bytes read through this adapter so far; never exceeds `limit`.
    count: usize,
    /// Maximum number of bytes which may be read through this adapter.
    limit: usize,
    /// The wrapped reader.
    reader: R,
}

impl<R: io::Read> From<R> for ConfinedReader<R> {
    /// Wraps `reader` without an effective limit (`usize::MAX`).
    fn from(reader: R) -> Self { Self::with(usize::MAX, reader) }
}

impl<R: io::Read> ConfinedReader<R> {
    /// Wraps `reader`, allowing at most `limit` bytes to be read through it.
    pub fn with(limit: usize, reader: R) -> Self {
        Self {
            count: 0,
            limit,
            reader,
        }
    }

    /// Returns the number of bytes read through this adapter so far.
    pub fn count(&self) -> usize { self.count }

    /// Drops the confinement and returns the wrapped reader.
    pub fn unconfine(self) -> R { self.reader }
}

impl<R: io::Read> io::Read for ConfinedReader<R> {
    /// Reads from the inner reader, never exceeding the configured limit.
    ///
    /// # Errors
    ///
    /// - [`io::ErrorKind::InvalidInput`] if `buf` is non-empty and the limit is already reached
    ///   (or the inner reader reports more bytes than it was given room for);
    /// - [`io::ErrorKind::OutOfMemory`] if the byte counter would overflow;
    /// - any error returned by the inner reader.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        if buf.is_empty() {
            // A zero-length read consumes nothing; delegate so the inner reader's own handling of
            // empty buffers is preserved.
            return self.reader.read(buf);
        }
        let remaining = self.limit.saturating_sub(self.count);
        if remaining == 0 {
            return Err(io::ErrorKind::InvalidInput.into());
        }
        // Cap the request at the remaining allowance. Bytes beyond the limit are then never pulled
        // out of the inner reader, where they would be lost once the error is reported.
        let max = buf.len().min(remaining);
        let len = self.reader.read(&mut buf[..max])?;
        // Defensive: a misbehaving inner reader may report more bytes than it was offered.
        match self.count.checked_add(len) {
            None => Err(io::ErrorKind::OutOfMemory.into()),
            Some(total) if total > self.limit => Err(io::ErrorKind::InvalidInput.into()),
            Some(total) => {
                self.count = total;
                Ok(len)
            }
        }
    }
}
90
91#[derive(Clone, Debug)]
92pub struct StreamReader<R: io::Read>(ConfinedReader<R>);
93
94impl<R: io::Read> StreamReader<R> {
95    pub fn new<const MAX: usize>(inner: R) -> Self { Self(ConfinedReader::with(MAX, inner)) }
96    pub fn unconfine(self) -> R { self.0.unconfine() }
97}
98
99impl<T: AsRef<[u8]>> StreamReader<io::Cursor<T>> {
100    pub fn cursor<const MAX: usize>(inner: T) -> Self {
101        Self(ConfinedReader::with(MAX, io::Cursor::new(inner)))
102    }
103}
104
105impl<R: io::Read> ReadRaw for StreamReader<R> {
106    fn read_raw<const MAX_LEN: usize>(&mut self, len: usize) -> io::Result<Vec<u8>> {
107        use io::Read;
108        let mut buf = vec![0u8; len];
109        self.0.read_exact(&mut buf)?;
110        Ok(buf)
111    }
112
113    fn read_raw_array<const LEN: usize>(&mut self) -> io::Result<[u8; LEN]> {
114        use io::Read;
115        let mut buf = [0u8; LEN];
116        self.0.read_exact(&mut buf)?;
117        Ok(buf)
118    }
119}
120
121impl<T: AsRef<[u8]>> StreamReader<io::Cursor<T>> {
122    pub fn in_memory<const MAX: usize>(data: T) -> Self { Self::new::<MAX>(io::Cursor::new(data)) }
123    pub fn into_cursor(self) -> io::Cursor<T> { self.0.unconfine() }
124}
125
126impl StreamReader<ReadCounter> {
127    pub fn counter<const MAX: usize>() -> Self { Self::new::<MAX>(ReadCounter::default()) }
128}
129
130#[derive(Clone, Debug, From)]
131pub struct StrictReader<R: ReadRaw>(R);
132
133impl<T: AsRef<[u8]>> StrictReader<StreamReader<io::Cursor<T>>> {
134    pub fn in_memory<const MAX: usize>(data: T) -> Self {
135        Self(StreamReader::in_memory::<MAX>(data))
136    }
137    pub fn into_cursor(self) -> io::Cursor<T> { self.0.into_cursor() }
138}
139
140impl StrictReader<StreamReader<ReadCounter>> {
141    pub fn counter<const MAX: usize>() -> Self { Self(StreamReader::counter::<MAX>()) }
142}
143
144impl<R: ReadRaw> StrictReader<R> {
145    pub fn with(reader: R) -> Self { Self(reader) }
146
147    pub fn unbox(self) -> R { self.0 }
148}
149
150impl<R: ReadRaw> TypedRead for StrictReader<R> {
151    type TupleReader<'parent>
152        = TupleReader<'parent, R>
153    where Self: 'parent;
154    type StructReader<'parent>
155        = StructReader<'parent, R>
156    where Self: 'parent;
157    type UnionReader = Self;
158    type RawReader = R;
159
160    unsafe fn raw_reader(&mut self) -> &mut Self::RawReader { &mut self.0 }
161
162    fn read_union<T: StrictUnion>(
163        &mut self,
164        inner: impl FnOnce(VariantName, &mut Self::UnionReader) -> Result<T, DecodeError>,
165    ) -> Result<T, DecodeError> {
166        let name = T::strict_name().unwrap_or_else(|| tn!("__unnamed"));
167        let tag = u8::strict_decode(self)?;
168        let variant_name = T::variant_name_by_tag(tag)
169            .ok_or(DecodeError::UnionTagNotKnown(name.to_string(), tag))?;
170        inner(variant_name, self)
171    }
172
173    fn read_enum<T: StrictEnum>(&mut self) -> Result<T, DecodeError>
174    where u8: From<T> {
175        let name = T::strict_name().unwrap_or_else(|| tn!("__unnamed"));
176        let tag = u8::strict_decode(self)?;
177        T::try_from(tag).map_err(|_| DecodeError::EnumTagNotKnown(name.to_string(), tag))
178    }
179
180    fn read_tuple<'parent, 'me, T: StrictTuple>(
181        &'me mut self,
182        inner: impl FnOnce(&mut Self::TupleReader<'parent>) -> Result<T, DecodeError>,
183    ) -> Result<T, DecodeError>
184    where
185        Self: 'parent,
186        'me: 'parent,
187    {
188        let name = T::strict_name().unwrap_or_else(|| tn!("__unnamed"));
189        let mut reader = TupleReader {
190            read_fields: 0,
191            parent: self,
192        };
193        let res = inner(&mut reader)?;
194        assert_ne!(reader.read_fields, 0, "you forget to read fields for a tuple {}", name);
195        assert_eq!(
196            reader.read_fields,
197            T::FIELD_COUNT,
198            "the number of fields read for a tuple {} doesn't match type declaration",
199            name
200        );
201        Ok(res)
202    }
203
204    fn read_struct<'parent, 'me, T: StrictStruct>(
205        &'me mut self,
206        inner: impl FnOnce(&mut Self::StructReader<'parent>) -> Result<T, DecodeError>,
207    ) -> Result<T, DecodeError>
208    where
209        Self: 'parent,
210        'me: 'parent,
211    {
212        let name = T::strict_name().unwrap_or_else(|| tn!("__unnamed"));
213        let mut reader = StructReader {
214            named_fields: empty!(),
215            parent: self,
216        };
217        let res = inner(&mut reader)?;
218        assert!(!reader.named_fields.is_empty(), "you forget to read fields for a tuple {}", name);
219
220        for field in T::ALL_FIELDS {
221            let pos = reader
222                .named_fields
223                .iter()
224                .position(|f| f.as_str() == *field)
225                .unwrap_or_else(|| panic!("field {} is not read for {}", field, name));
226            reader.named_fields.remove(pos);
227        }
228        assert!(reader.named_fields.is_empty(), "excessive fields are read for {}", name);
229        Ok(res)
230    }
231}
232
233#[derive(Debug)]
234pub struct TupleReader<'parent, R: ReadRaw> {
235    read_fields: u8,
236    parent: &'parent mut StrictReader<R>,
237}
238
239impl<R: ReadRaw> ReadTuple for TupleReader<'_, R> {
240    fn read_field<T: StrictDecode>(&mut self) -> Result<T, DecodeError> {
241        self.read_fields += 1;
242        T::strict_decode(self.parent)
243    }
244}
245
246#[derive(Debug)]
247pub struct StructReader<'parent, R: ReadRaw> {
248    named_fields: Vec<FieldName>,
249    parent: &'parent mut StrictReader<R>,
250}
251
252impl<R: ReadRaw> ReadStruct for StructReader<'_, R> {
253    fn read_field<T: StrictDecode>(&mut self, field: FieldName) -> Result<T, DecodeError> {
254        self.named_fields.push(field);
255        T::strict_decode(self.parent)
256    }
257}
258
259impl<R: ReadRaw> ReadUnion for StrictReader<R> {
260    type TupleReader<'parent>
261        = TupleReader<'parent, R>
262    where Self: 'parent;
263    type StructReader<'parent>
264        = StructReader<'parent, R>
265    where Self: 'parent;
266
267    fn read_tuple<'parent, 'me, T: StrictSum>(
268        &'me mut self,
269        inner: impl FnOnce(&mut Self::TupleReader<'parent>) -> Result<T, DecodeError>,
270    ) -> Result<T, DecodeError>
271    where
272        Self: 'parent,
273        'me: 'parent,
274    {
275        let mut reader = TupleReader {
276            read_fields: 0,
277            parent: self,
278        };
279        inner(&mut reader)
280    }
281
282    fn read_struct<'parent, 'me, T: StrictSum>(
283        &'me mut self,
284        inner: impl FnOnce(&mut Self::StructReader<'parent>) -> Result<T, DecodeError>,
285    ) -> Result<T, DecodeError>
286    where
287        Self: 'parent,
288        'me: 'parent,
289    {
290        let mut reader = StructReader {
291            named_fields: empty!(),
292            parent: self,
293        };
294        inner(&mut reader)
295    }
296}