// makepad_zune_core/bytestream/reader.rs

/*
 * Copyright (c) 2023.
 *
 * This software is free software;
 *
 * You can redistribute it or modify it under terms of the MIT, Apache License or Zlib license
 */
8
9use core::cmp::min;
10
11use crate::bytestream::traits::ZReaderTrait;
12
/// Error message returned by the fallible readers when the stream does not
/// hold enough bytes to satisfy the request.
const ERROR_MSG: &str = "No more bytes";
14
15/// An encapsulation of a byte stream reader
16///
17/// This provides an interface similar to [std::io::Cursor] but
18/// it provides fine grained options for reading different integer data types from
19/// the underlying buffer.
20///
21/// There are two variants mainly error and non error variants,
22/// the error variants are useful for cases where you need bytes
23/// from the underlying stream, and cannot do with zero result.
24/// the non error variants are useful when you may have proved data already exists
25/// eg by using [`has`] method or you are okay with returning zero if the underlying
26/// buffer has been completely read.
27///
28/// [std::io::Cursor]: https://doc.rust-lang.org/std/io/struct.Cursor.html
29/// [`has`]: Self::has
30pub struct ZByteReader<T: ZReaderTrait> {
31    /// Data stream
32    stream:   T,
33    position: usize
34}
35
/// Byte order used when decoding a multi-byte integer.
enum Mode {
    /// Most significant byte first (big endian).
    BE,
    /// Least significant byte first (little endian).
    LE,
}
42
43impl<T: ZReaderTrait> ZByteReader<T> {
44    /// Create a new instance of the byte stream
45    ///
46    /// Bytes will be read from the start of `buf`.
47    ///
48    /// `buf` is expected to live as long as this and
49    /// all references to it live
50    ///
51    /// # Returns
52    /// A byte reader which will pull bits from bye
53    pub const fn new(buf: T) -> ZByteReader<T> {
54        ZByteReader {
55            stream:   buf,
56            position: 0
57        }
58    }
59    /// Skip `num` bytes ahead of the stream.
60    ///
61    /// This bumps up the internal cursor wit a wrapping addition
62    /// The bytes between current position and `num` will be skipped
63    ///
64    /// # Arguments
65    /// `num`: How many bytes to skip
66    ///
67    /// # Note
68    /// This does not consider length of the buffer, so skipping more bytes
69    /// than possible and then reading bytes will return an error if using error variants
70    /// or zero if using non-error variants
71    ///
72    /// # Example
73    /// ```
74    /// use zune_core::bytestream::ZByteReader;
75    /// let zero_to_hundred:Vec<u8> = (0..100).collect();
76    /// let mut stream = ZByteReader::new(&zero_to_hundred);
77    /// // skip 37 bytes
78    /// stream.skip(37);
79    ///
80    /// assert_eq!(stream.get_u8(),37);
81    /// ```
82    ///
83    /// See [`rewind`](ZByteReader::rewind) for moving the internal cursor back
84    pub fn skip(&mut self, num: usize) {
85        // Can this overflow ??
86        self.position = self.position.wrapping_add(num);
87    }
88    /// Undo a buffer read by moving the position pointer `num`
89    /// bytes behind.
90    ///
91    /// This operation will saturate at zero
92    pub fn rewind(&mut self, num: usize) {
93        self.position = self.position.saturating_sub(num);
94    }
95
96    /// Return whether the underlying buffer
97    /// has `num` bytes available for reading
98    ///
99    /// # Example
100    ///
101    /// ```
102    /// use zune_core::bytestream::ZByteReader;
103    /// let data = [0_u8;120];
104    /// let reader = ZByteReader::new(data.as_slice());
105    /// assert!(reader.has(3));
106    /// assert!(!reader.has(121));
107    /// ```
108    #[inline]
109    pub fn has(&self, num: usize) -> bool {
110        self.position.saturating_add(num) <= self.stream.get_len()
111    }
112    /// Get number of bytes available in the stream
113    #[inline]
114    pub fn get_bytes_left(&self) -> usize {
115        // Must be saturating to prevent underflow
116        self.stream.get_len().saturating_sub(self.position)
117    }
118    /// Get length of the underlying buffer.
119    ///
120    /// To get the number of bytes left in the buffer,
121    /// use [remaining] method
122    ///
123    /// [remaining]: Self::remaining
124    #[inline]
125    pub fn len(&self) -> usize {
126        self.stream.get_len()
127    }
128    /// Return true if the underlying buffer stream is empty
129    #[inline]
130    pub fn is_empty(&self) -> bool {
131        self.stream.get_len() == 0
132    }
133    /// Get current position of the buffer.
134    #[inline]
135    pub const fn get_position(&self) -> usize {
136        self.position
137    }
138    /// Return true whether or not we read to the end of the
139    /// buffer and have no more bytes left.
140    #[inline]
141    pub fn eof(&self) -> bool {
142        self.position >= self.len()
143    }
144    /// Get number of bytes unread inside this
145    /// stream.
146    ///
147    /// To get the length of the underlying stream,
148    /// use [len] method
149    ///
150    /// [len]: Self::len()
151    #[inline]
152    pub fn remaining(&self) -> usize {
153        self.stream.get_len().saturating_sub(self.position)
154    }
155    /// Get a part of the bytestream as a reference.
156    ///
157    /// This increments the position to point past the bytestream
158    /// if position+num is in bounds
159    pub fn get(&mut self, num: usize) -> Result<&[u8], &'static str> {
160        match self.stream.get_slice(self.position..self.position + num) {
161            Some(bytes) => {
162                self.position += num;
163                Ok(bytes)
164            }
165            None => Err(ERROR_MSG)
166        }
167    }
168    /// Look ahead position bytes and return a reference
169    /// to num_bytes from that position, or an error if the
170    /// peek would be out of bounds.
171    ///
172    /// This doesn't increment the position, bytes would have to be discarded
173    /// at a later point.
174    #[inline]
175    pub fn peek_at(&self, position: usize, num_bytes: usize) -> Result<&[u8], &'static str> {
176        let start = self.position + position;
177        let end = self.position + position + num_bytes;
178
179        match self.stream.get_slice(start..end) {
180            Some(bytes) => Ok(bytes),
181            None => Err(ERROR_MSG)
182        }
183    }
184    /// Get a fixed amount of bytes or return an error if we cant
185    /// satisfy the read
186    ///
187    /// This should be combined with [`has`] since if there are no
188    /// more bytes you get an error.
189    ///
190    /// But it's useful for cases where you expect bytes but they are not present
191    ///
192    /// For the zero  variant see, [`get_fixed_bytes_or_zero`]
193    ///
194    /// # Example
195    /// ```rust
196    /// use zune_core::bytestream::ZByteReader;
197    /// let mut stream = ZByteReader::new([0x0,0x5,0x3,0x2].as_slice());
198    /// let first_bytes = stream.get_fixed_bytes_or_err::<10>(); // not enough bytes
199    /// assert!(first_bytes.is_err());
200    /// ```
201    ///
202    /// [`has`]:Self::has
203    /// [`get_fixed_bytes_or_zero`]: Self::get_fixed_bytes_or_zero
204    #[inline]
205    pub fn get_fixed_bytes_or_err<const N: usize>(&mut self) -> Result<[u8; N], &'static str> {
206        let mut byte_store: [u8; N] = [0; N];
207
208        match self.stream.get_slice(self.position..self.position + N) {
209            Some(bytes) => {
210                self.position += N;
211                byte_store.copy_from_slice(bytes);
212
213                Ok(byte_store)
214            }
215            None => Err(ERROR_MSG)
216        }
217    }
218
219    /// Get a fixed amount of bytes or return a zero array size
220    /// if we can't satisfy the read
221    ///
222    /// This should be combined with [`has`] since if there are no
223    /// more bytes you get a zero initialized array
224    ///
225    /// For the error variant see, [`get_fixed_bytes_or_err`]
226    ///
227    /// # Example
228    /// ```rust
229    /// use zune_core::bytestream::ZByteReader;
230    /// let mut stream = ZByteReader::new([0x0,0x5,0x3,0x2].as_slice());
231    /// let first_bytes = stream.get_fixed_bytes_or_zero::<2>();
232    /// assert_eq!(first_bytes,[0x0,0x5]);
233    /// ```
234    ///
235    /// [`has`]:Self::has
236    /// [`get_fixed_bytes_or_err`]: Self::get_fixed_bytes_or_err
237    #[inline]
238    pub fn get_fixed_bytes_or_zero<const N: usize>(&mut self) -> [u8; N] {
239        let mut byte_store: [u8; N] = [0; N];
240
241        match self.stream.get_slice(self.position..self.position + N) {
242            Some(bytes) => {
243                self.position += N;
244                byte_store.copy_from_slice(bytes);
245
246                byte_store
247            }
248            None => byte_store
249        }
250    }
251    #[inline]
252    /// Skip bytes until a condition becomes false or the stream runs out of bytes
253    ///
254    /// # Example
255    ///
256    /// ```rust
257    /// use zune_core::bytestream::ZByteReader;
258    /// let mut stream = ZByteReader::new([0;10].as_slice());
259    /// stream.skip_until_false(|x| x.is_ascii()) // skip until we meet a non ascii character
260    /// ```
261    pub fn skip_until_false<F: Fn(u8) -> bool>(&mut self, func: F) {
262        // iterate until we have no more bytes
263        while !self.eof() {
264            // get a byte from stream
265            let byte = self.get_u8();
266
267            if !(func)(byte) {
268                // function returned false meaning we stop skipping
269                self.rewind(1);
270                break;
271            }
272        }
273    }
274    /// Return the remaining unread bytes in this byte reader
275    pub fn remaining_bytes(&self) -> &[u8] {
276        self.stream.get_slice(self.position..self.len()).unwrap()
277    }
278
279    pub fn read(&mut self, buf: &mut [u8]) -> Result<usize, &'static str> {
280        let buf_length = buf.len();
281        let start = self.position;
282        let end = min(self.len(), self.position + buf_length);
283        let diff = end - start;
284
285        buf[0..diff].copy_from_slice(self.stream.get_slice(start..end).unwrap());
286
287        self.skip(diff);
288
289        Ok(diff)
290    }
291
292    /// Read enough bytes to fill in
293    pub fn read_exact(&mut self, buf: &mut [u8]) -> Result<(), &'static str> {
294        let size = self.read(buf)?;
295
296        if size != buf.len() {
297            return Err("Could not read into the whole buffer");
298        }
299        Ok(())
300    }
301
302    /// Set the cursor position
303    ///
304    /// After this, all reads will proceed from the position as an anchor
305    /// point
306    pub fn set_position(&mut self, position: usize) {
307        self.position = position;
308    }
309}
310
// Generates the endian-aware readers for one integer type.
//
// Arguments, in order:
// - `$name`:  private reader returning 0 (cursor unchanged) on a short read,
// - `$name2`: private reader returning `Err(ERROR_MSG)` (cursor unchanged) on a short read,
// - `$name3` / `$name4`: public big / little endian error variants,
// - `$name5` / `$name6`: public big / little endian zero-on-failure variants,
// - `$int_type`: the integer type to decode.
macro_rules! get_single_type {
    ($name:tt,$name2:tt,$name3:tt,$name4:tt,$name5:tt,$name6:tt,$int_type:tt) => {
        impl<T:ZReaderTrait> ZByteReader<T>
        {
            // Zero-on-failure reader: same as `$name2` but maps the error to 0.
            #[inline(always)]
            fn $name(&mut self, mode: Mode) -> $int_type
            {
                self.$name2(mode).unwrap_or(0)
            }

            // Error-on-failure reader: decodes `size_of::<$int_type>()` bytes
            // at the cursor and advances past them only on success.
            #[inline(always)]
            fn $name2(&mut self, mode: Mode) -> Result<$int_type, &'static str>
            {
                const SIZE_OF_VAL: usize = core::mem::size_of::<$int_type>();

                let start = self.position;
                // `checked_add` keeps a cursor moved near `usize::MAX` (via skip
                // or set_position) from overflowing; that is treated as a short read.
                let bytes = match start.checked_add(SIZE_OF_VAL) {
                    Some(end) => self.stream.get_slice(start..end),
                    None => None,
                };
                let bytes = bytes.ok_or(ERROR_MSG)?;

                let mut space = [0; SIZE_OF_VAL];
                space.copy_from_slice(bytes);
                self.position += SIZE_OF_VAL;

                Ok(match mode
                {
                    Mode::LE => $int_type::from_le_bytes(space),
                    Mode::BE => $int_type::from_be_bytes(space),
                })
            }

            #[doc=concat!("Read ",stringify!($int_type)," as a big endian integer")]
            #[doc=concat!("Returning an error if the underlying buffer cannot support a ",stringify!($int_type)," read.")]
            #[inline]
            pub fn $name3(&mut self) -> Result<$int_type, &'static str>
            {
                self.$name2(Mode::BE)
            }

            #[doc=concat!("Read ",stringify!($int_type)," as a little endian integer")]
            #[doc=concat!("Returning an error if the underlying buffer cannot support a ",stringify!($int_type)," read.")]
            #[inline]
            pub fn $name4(&mut self) -> Result<$int_type, &'static str>
            {
                self.$name2(Mode::LE)
            }

            #[doc=concat!("Read ",stringify!($int_type)," as a big endian integer")]
            #[doc=concat!("Returning 0 if the underlying buffer does not have enough bytes for a ",stringify!($int_type)," read.")]
            #[inline(always)]
            pub fn $name5(&mut self) -> $int_type
            {
                self.$name(Mode::BE)
            }

            #[doc=concat!("Read ",stringify!($int_type)," as a little endian integer")]
            #[doc=concat!("Returning 0 if the underlying buffer does not have enough bytes for a ",stringify!($int_type)," read.")]
            #[inline(always)]
            pub fn $name6(&mut self) -> $int_type
            {
                self.$name(Mode::LE)
            }
        }
    };
}
394// U8 implementation
395// The benefit of our own unrolled u8 impl instead of macros is that this is sometimes used in some
396// impls and is called multiple times, e.g jpeg during huffman decoding.
397// we can make some functions leaner like get_u8 is branchless
398impl<T> ZByteReader<T>
399where
400    T: ZReaderTrait
401{
402    /// Retrieve a byte from the underlying stream
403    /// returning 0 if there are no more bytes available
404    ///
405    /// This means 0 might indicate a bit or an end of stream, but
406    /// this is useful for some scenarios where one needs a byte.
407    ///
408    /// For the panicking one, see [`get_u8_err`]
409    ///
410    /// [`get_u8_err`]: Self::get_u8_err
411    #[inline(always)]
412    pub fn get_u8(&mut self) -> u8 {
413        let byte = *self.stream.get_byte(self.position).unwrap_or(&0);
414
415        self.position += usize::from(self.position < self.len());
416        byte
417    }
418
419    /// Retrieve a byte from the underlying stream
420    /// returning an error if there are no more bytes available
421    ///
422    /// For the non panicking one, see [`get_u8`]
423    ///
424    /// [`get_u8`]: Self::get_u8
425    #[inline(always)]
426    pub fn get_u8_err(&mut self) -> Result<u8, &'static str> {
427        match self.stream.get_byte(self.position) {
428            Some(byte) => {
429                self.position += 1;
430                Ok(*byte)
431            }
432            None => Err(ERROR_MSG)
433        }
434    }
435}
436
437// u16,u32,u64 -> macros
438get_single_type!(
439    get_u16_inner_or_default,
440    get_u16_inner_or_die,
441    get_u16_be_err,
442    get_u16_le_err,
443    get_u16_be,
444    get_u16_le,
445    u16
446);
447get_single_type!(
448    get_u32_inner_or_default,
449    get_u32_inner_or_die,
450    get_u32_be_err,
451    get_u32_le_err,
452    get_u32_be,
453    get_u32_le,
454    u32
455);
456get_single_type!(
457    get_u64_inner_or_default,
458    get_u64_inner_or_die,
459    get_u64_be_err,
460    get_u64_le_err,
461    get_u64_be,
462    get_u64_le,
463    u64
464);