// musli_binary_common/reader.rs

//! Trait for governing how a particular source of bytes is read.
//!
//! `musli` requires all sources to reference the complete data being read from
//! them, which allows it to assume that the bytes are always returned with the
//! `'de` lifetime.
6
7use core::fmt;
8use core::marker;
9use core::ops::Range;
10use core::ptr;
11use core::slice;
12
13use musli::de::ValueVisitor;
14use musli::error::Error;
15
16/// A reader where the current position is exactly known.
17pub trait PositionedReader<'de>: Reader<'de> {
18    /// The exact position of a reader.
19    fn pos(&self) -> usize;
20}
21
22/// Trait governing how a source of bytes is read.
23///
24/// This requires the reader to be able to hand out contiguous references to the
25/// byte source through [Reader::read_bytes].
26pub trait Reader<'de> {
27    /// Error type raised by the current reader.
28    type Error: Error;
29
30    /// Skip over the given number of bytes.
31    fn skip(&mut self, n: usize) -> Result<(), Self::Error>;
32
33    /// Read a slice into the given buffer.
34    #[inline]
35    fn read(&mut self, buf: &mut [u8]) -> Result<(), Self::Error> {
36        return self.read_bytes(buf.len(), Visitor::<Self::Error>(buf, marker::PhantomData));
37
38        struct Visitor<'a, E>(&'a mut [u8], marker::PhantomData<E>);
39
40        impl<'a, 'de, E> ValueVisitor<'de> for Visitor<'a, E>
41        where
42            E: Error,
43        {
44            type Target = [u8];
45            type Ok = ();
46            type Error = E;
47
48            #[inline]
49            fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50                write!(f, "bytes")
51            }
52
53            #[inline]
54            fn visit_borrowed(self, bytes: &'de Self::Target) -> Result<Self::Ok, Self::Error> {
55                self.visit_any(bytes)
56            }
57
58            #[inline]
59            fn visit_any(self, bytes: &Self::Target) -> Result<Self::Ok, Self::Error> {
60                self.0.copy_from_slice(bytes);
61                Ok(())
62            }
63        }
64    }
65
66    /// Read a slice out of the current reader.
67    fn read_bytes<V>(&mut self, n: usize, visitor: V) -> Result<V::Ok, V::Error>
68    where
69        V: ValueVisitor<'de, Target = [u8], Error = Self::Error>;
70
71    /// Read a single byte.
72    #[inline]
73    fn read_byte(&mut self) -> Result<u8, Self::Error> {
74        let [byte] = self.read_array::<1>()?;
75        Ok(byte)
76    }
77
78    /// Read an array out of the current reader.
79    #[inline]
80    fn read_array<const N: usize>(&mut self) -> Result<[u8; N], Self::Error> {
81        return self.read_bytes(N, Visitor::<N, Self::Error>([0u8; N], marker::PhantomData));
82
83        struct Visitor<const N: usize, E>([u8; N], marker::PhantomData<E>);
84
85        impl<'de, const N: usize, E> ValueVisitor<'de> for Visitor<N, E>
86        where
87            E: Error,
88        {
89            type Target = [u8];
90            type Ok = [u8; N];
91            type Error = E;
92
93            #[inline]
94            fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
95                write!(f, "bytes")
96            }
97
98            #[inline]
99            fn visit_borrowed(self, bytes: &'de Self::Target) -> Result<Self::Ok, Self::Error> {
100                self.visit_any(bytes)
101            }
102
103            #[inline]
104            fn visit_any(mut self, bytes: &Self::Target) -> Result<Self::Ok, Self::Error> {
105                self.0.copy_from_slice(bytes);
106                Ok(self.0)
107            }
108        }
109    }
110
111    /// Keep an accurate record of the position within the reader.
112    fn with_position(self) -> WithPosition<Self>
113    where
114        Self: Sized,
115    {
116        WithPosition {
117            pos: 0,
118            reader: self,
119        }
120    }
121
122    /// Keep an accurate record of the position within the reader.
123    fn limit(self, limit: usize) -> Limit<Self>
124    where
125        Self: Sized,
126    {
127        Limit {
128            remaining: limit,
129            reader: self,
130        }
131    }
132}
133
134decl_message_repr!(SliceReaderErrorRepr, "error reading from slice");
135
136/// An error raised while decoding a slice.
137#[derive(Debug)]
138pub struct SliceReaderError(SliceReaderErrorRepr);
139
140impl fmt::Display for SliceReaderError {
141    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
142        self.0.fmt(f)
143    }
144}
145
146impl Error for SliceReaderError {
147    #[inline]
148    fn custom<T>(message: T) -> Self
149    where
150        T: 'static + Send + Sync + fmt::Display + fmt::Debug,
151    {
152        Self(SliceReaderErrorRepr::collect(message))
153    }
154
155    #[inline]
156    fn message<T>(message: T) -> Self
157    where
158        T: fmt::Display,
159    {
160        Self(SliceReaderErrorRepr::collect(message))
161    }
162}
163
164#[cfg(feature = "std")]
165impl std::error::Error for SliceReaderError {}
166
167impl<'de> Reader<'de> for &'de [u8] {
168    type Error = SliceReaderError;
169
170    #[inline]
171    fn skip(&mut self, n: usize) -> Result<(), Self::Error> {
172        if self.len() < n {
173            return Err(SliceReaderError::custom("buffer underflow"));
174        }
175
176        let (_, tail) = self.split_at(n);
177        *self = tail;
178        Ok(())
179    }
180
181    #[inline]
182    fn read_bytes<V>(&mut self, n: usize, visitor: V) -> Result<V::Ok, V::Error>
183    where
184        V: ValueVisitor<'de, Target = [u8], Error = Self::Error>,
185    {
186        if self.len() < n {
187            return Err(SliceReaderError::custom("buffer underflow"));
188        }
189
190        let (head, tail) = self.split_at(n);
191        *self = tail;
192        visitor.visit_borrowed(head)
193    }
194
195    #[inline]
196    fn read(&mut self, buf: &mut [u8]) -> Result<(), Self::Error> {
197        if self.len() < buf.len() {
198            return Err(SliceReaderError::custom("buffer underflow"));
199        }
200
201        let (head, tail) = self.split_at(buf.len());
202        buf.copy_from_slice(head);
203        *self = tail;
204        Ok(())
205    }
206}
207
/// An efficient [Reader] wrapper around a slice.
///
/// The unread portion of the slice is tracked as a range of raw pointers, so
/// advancing the reader only moves the start pointer.
pub struct SliceReader<'de> {
    range: Range<*const u8>,
    _marker: marker::PhantomData<&'de [u8]>,
}

// SAFETY: `SliceReader` only ever reads through `range`, which points into a
// `&'de [u8]` whose lifetime is carried by `_marker`. It is therefore exactly
// as thread-safe as that shared slice, which is both `Send` and `Sync`; the raw
// pointers merely suppress the auto-derived impls.
unsafe impl<'de> Send for SliceReader<'de> {}
unsafe impl<'de> Sync for SliceReader<'de> {}

impl<'de> SliceReader<'de> {
    /// Construct a new instance around the specified slice.
    #[inline]
    pub fn new(slice: &'de [u8]) -> Self {
        Self {
            range: slice.as_ptr_range(),
            _marker: marker::PhantomData,
        }
    }
}
224
225impl<'de> Reader<'de> for SliceReader<'de> {
226    type Error = SliceReaderError;
227
228    #[inline]
229    fn skip(&mut self, n: usize) -> Result<(), Self::Error> {
230        self.range.start = bounds_check_add(&self.range, n)?;
231        Ok(())
232    }
233
234    #[inline]
235    fn read_bytes<V>(&mut self, n: usize, visitor: V) -> Result<V::Ok, V::Error>
236    where
237        V: ValueVisitor<'de, Target = [u8], Error = Self::Error>,
238    {
239        let outcome = bounds_check_add(&self.range, n)?;
240
241        unsafe {
242            let bytes = slice::from_raw_parts(self.range.start, n);
243            self.range.start = outcome;
244            visitor.visit_borrowed(bytes)
245        }
246    }
247
248    #[inline]
249    fn read(&mut self, buf: &mut [u8]) -> Result<(), Self::Error> {
250        let outcome = bounds_check_add(&self.range, buf.len())?;
251
252        unsafe {
253            ptr::copy_nonoverlapping(self.range.start, buf.as_mut_ptr(), buf.len());
254            self.range.start = outcome;
255        }
256
257        Ok(())
258    }
259}
260
261#[inline]
262fn bounds_check_add(range: &Range<*const u8>, len: usize) -> Result<*const u8, SliceReaderError> {
263    let outcome = range.start.wrapping_add(len);
264
265    if outcome > range.end || outcome < range.start {
266        Err(SliceReaderError::custom("buffer underflow"))
267    } else {
268        Ok(outcome)
269    }
270}
271
272/// Keep a record of the current position.
273///
274/// Constructed through [Reader::with_position].
275pub struct WithPosition<R> {
276    pos: usize,
277    reader: R,
278}
279
280impl<'de, R> PositionedReader<'de> for WithPosition<R>
281where
282    R: Reader<'de>,
283{
284    #[inline]
285    fn pos(&self) -> usize {
286        self.pos
287    }
288}
289
290impl<'de, R> Reader<'de> for WithPosition<R>
291where
292    R: Reader<'de>,
293{
294    type Error = R::Error;
295
296    #[inline]
297    fn skip(&mut self, n: usize) -> Result<(), Self::Error> {
298        self.reader.skip(n)?;
299        self.pos += n;
300        Ok(())
301    }
302
303    #[inline]
304    fn read_bytes<V>(&mut self, n: usize, visitor: V) -> Result<V::Ok, V::Error>
305    where
306        V: ValueVisitor<'de, Target = [u8], Error = Self::Error>,
307    {
308        let ok = self.reader.read_bytes(n, visitor)?;
309        self.pos += n;
310        Ok(ok)
311    }
312
313    #[inline]
314    fn read(&mut self, buf: &mut [u8]) -> Result<(), Self::Error> {
315        self.reader.read(buf)?;
316        self.pos += buf.len();
317        Ok(())
318    }
319
320    #[inline]
321    fn read_byte(&mut self) -> Result<u8, Self::Error> {
322        let b = self.reader.read_byte()?;
323        self.pos += 1;
324        Ok(b)
325    }
326
327    #[inline]
328    fn read_array<const N: usize>(&mut self) -> Result<[u8; N], Self::Error> {
329        let array = self.reader.read_array()?;
330        self.pos += N;
331        Ok(array)
332    }
333}
334
335/// Limit the number of bytes that can be read out of a reader to the specified limit.
336///
337/// Constructed through [Reader::limit].
338pub struct Limit<R> {
339    remaining: usize,
340    reader: R,
341}
342
343impl<'de, R> Limit<R>
344where
345    R: Reader<'de>,
346{
347    fn bounds_check(&mut self, n: usize) -> Result<(), R::Error> {
348        match self.remaining.checked_sub(n) {
349            Some(remaining) => {
350                self.remaining = remaining;
351                Ok(())
352            }
353            None => Err(R::Error::custom("out of bounds")),
354        }
355    }
356}
357
358impl<'de, R> PositionedReader<'de> for Limit<R>
359where
360    R: PositionedReader<'de>,
361{
362    #[inline]
363    fn pos(&self) -> usize {
364        self.reader.pos()
365    }
366}
367
368impl<'de, R> Reader<'de> for Limit<R>
369where
370    R: Reader<'de>,
371{
372    type Error = R::Error;
373
374    #[inline]
375    fn skip(&mut self, n: usize) -> Result<(), Self::Error> {
376        self.bounds_check(n)?;
377        self.reader.skip(n)
378    }
379
380    #[inline]
381    fn read_bytes<V>(&mut self, n: usize, visitor: V) -> Result<V::Ok, V::Error>
382    where
383        V: ValueVisitor<'de, Target = [u8], Error = Self::Error>,
384    {
385        self.bounds_check(n)?;
386        self.reader.read_bytes(n, visitor)
387    }
388
389    #[inline]
390    fn read(&mut self, buf: &mut [u8]) -> Result<(), Self::Error> {
391        self.bounds_check(buf.len())?;
392        self.reader.read(buf)
393    }
394
395    #[inline]
396    fn read_byte(&mut self) -> Result<u8, Self::Error> {
397        self.bounds_check(1)?;
398        self.reader.read_byte()
399    }
400
401    #[inline]
402    fn read_array<const N: usize>(&mut self) -> Result<[u8; N], Self::Error> {
403        self.bounds_check(N)?;
404        self.reader.read_array()
405    }
406}
407
// Forwarding implementations so that `&mut R` can be used wherever a reader `R` is expected.
409
410impl<'de, R> PositionedReader<'de> for &mut R
411where
412    R: ?Sized + PositionedReader<'de>,
413{
414    #[inline]
415    fn pos(&self) -> usize {
416        (**self).pos()
417    }
418}
419
420impl<'de, R> Reader<'de> for &mut R
421where
422    R: ?Sized + Reader<'de>,
423{
424    type Error = R::Error;
425
426    #[inline]
427    fn skip(&mut self, n: usize) -> Result<(), Self::Error> {
428        (**self).skip(n)
429    }
430
431    #[inline]
432    fn read_bytes<V>(&mut self, n: usize, visitor: V) -> Result<V::Ok, V::Error>
433    where
434        V: ValueVisitor<'de, Target = [u8], Error = Self::Error>,
435    {
436        (**self).read_bytes(n, visitor)
437    }
438
439    #[inline]
440    fn read(&mut self, buf: &mut [u8]) -> Result<(), Self::Error> {
441        (**self).read(buf)
442    }
443
444    #[inline]
445    fn read_byte(&mut self) -> Result<u8, Self::Error> {
446        (**self).read_byte()
447    }
448
449    #[inline]
450    fn read_array<const N: usize>(&mut self) -> Result<[u8; N], Self::Error> {
451        (**self).read_array()
452    }
453}