messagepack_core/
io.rs

1//! Minimal write abstraction used by encoders.
2
/// Byte sink that encoders write into.
///
/// This is a deliberately small stand-in for a `Write` trait so that encoders
/// are not tied to one particular I/O model.
pub trait IoWrite {
    /// Error reported when a write cannot be completed.
    type Error: core::error::Error;

    /// Write every byte of `buf` to the sink.
    fn write(&mut self, buf: &[u8]) -> Result<(), Self::Error>;
}
11
/// Error returned by the fixed-capacity writers in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum WError {
    /// The destination does not have enough room left for the write.
    BufferFull,
}

impl core::fmt::Display for WError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Resolve the message first, then emit it in one place.
        let message = match self {
            Self::BufferFull => "Buffer is full",
        };
        f.write_str(message)
    }
}

impl core::error::Error for WError {}
28
/// Writer that fills a caller-provided mutable byte slice from the front.
pub struct SliceWriter<'a> {
    /// Destination storage.
    buf: &'a mut [u8],
    /// Index of the next byte to be written in `buf`.
    cursor: usize,
}

impl<'a> SliceWriter<'a> {
    /// Wrap `buf` in a writer positioned at its first byte.
    pub fn from_slice(buf: &'a mut [u8]) -> Self {
        SliceWriter { cursor: 0, buf }
    }

    /// Number of bytes that can still be written (not the number written so far).
    fn len(&self) -> usize {
        let capacity = self.buf.len();
        capacity - self.cursor
    }
}
45
46impl IoWrite for SliceWriter<'_> {
47    type Error = WError;
48
49    fn write(&mut self, buf: &[u8]) -> Result<(), Self::Error> {
50        if self.len() >= buf.len() {
51            let to = &mut self.buf[self.cursor..self.cursor + buf.len()];
52            to.copy_from_slice(buf);
53            self.cursor += buf.len();
54            Ok(())
55        } else {
56            Err(WError::BufferFull)
57        }
58    }
59}
60
61#[cfg(all(not(test), not(feature = "std")))]
62impl IoWrite for &mut [u8] {
63    type Error = WError;
64
65    fn write(&mut self, buf: &[u8]) -> Result<(), Self::Error> {
66        let this = core::mem::take(self);
67
68        let (written, rest) = this
69            .split_at_mut_checked(buf.len())
70            .ok_or(WError::BufferFull)?;
71        written.copy_from_slice(buf);
72        *self = rest;
73
74        Ok(())
75    }
76}
77
#[cfg(all(not(test), feature = "alloc", not(feature = "std")))]
mod alloc_without_std {
    //! `IoWrite` for `Vec<u8>` when `alloc` is enabled but `std` is not.

    use super::IoWrite;
    use alloc::vec::Vec;
    use core::convert::Infallible;

    impl IoWrite for Vec<u8> {
        type Error = Infallible;

        /// Append `buf` to the end of the vector.
        fn write(&mut self, buf: &[u8]) -> Result<(), Self::Error> {
            self.extend_from_slice(buf);
            Ok(())
        }
    }

    impl IoWrite for &mut Vec<u8> {
        type Error = Infallible;

        /// Append `buf` to the end of the referenced vector.
        fn write(&mut self, buf: &[u8]) -> Result<(), Self::Error> {
            self.extend_from_slice(buf);
            Ok(())
        }
    }
}
97
#[cfg(feature = "alloc")]
mod vec_writer {
    use alloc::vec::Vec;

    use super::IoWrite;

    /// Writer that appends to a borrowed `Vec<u8>`.
    pub struct VecRefWriter<'a> {
        vec: &'a mut Vec<u8>,
    }

    impl<'a> VecRefWriter<'a> {
        /// Create a writer that appends to `vec`.
        pub fn new(vec: &'a mut Vec<u8>) -> Self {
            VecRefWriter { vec }
        }
    }

    impl IoWrite for VecRefWriter<'_> {
        type Error = core::convert::Infallible;

        /// Append `buf` to the borrowed vector.
        fn write(&mut self, buf: &[u8]) -> Result<(), Self::Error> {
            self.vec.extend_from_slice(buf);
            Ok(())
        }
    }

    /// Writer that owns the `Vec<u8>` it appends to.
    #[derive(Default)]
    pub struct VecWriter {
        vec: Vec<u8>,
    }

    impl VecWriter {
        /// Create a writer backed by an empty vector.
        pub fn new() -> Self {
            VecWriter { vec: Vec::new() }
        }

        /// Consume the writer and return everything written so far.
        pub fn into_vec(self) -> Vec<u8> {
            let VecWriter { vec } = self;
            vec
        }
    }

    impl IoWrite for VecWriter {
        type Error = core::convert::Infallible;

        /// Append `buf` to the owned vector.
        fn write(&mut self, buf: &[u8]) -> Result<(), Self::Error> {
            self.vec.extend_from_slice(buf);
            Ok(())
        }
    }
}
155#[cfg(feature = "alloc")]
156pub use vec_writer::{VecRefWriter, VecWriter};
157
#[cfg(any(test, feature = "std"))]
impl<W: std::io::Write> IoWrite for W {
    type Error = std::io::Error;

    /// Forward to `std::io::Write::write_all`.
    ///
    /// The call is fully qualified because both traits have a method named
    /// `write`.
    fn write(&mut self, buf: &[u8]) -> Result<(), Self::Error> {
        std::io::Write::write_all(self, buf)
    }
}
169
/// Byte slice handed out by a reader, tagged with how long it stays valid.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Reference<'de, 'a> {
    /// Bytes borrowed for the `'de` lifetime.
    Borrowed(&'de [u8]),
    /// Bytes borrowed only for the shorter `'a` lifetime; they may be
    /// invalidated soon, so copy them if they need to be kept.
    Copied(&'a [u8]),
}

impl Reference<'_, '_> {
    /// View the bytes without caring which variant holds them.
    pub fn as_bytes(&self) -> &[u8] {
        // `Reference` is `Copy`, so match by value and hand back the inner slice.
        match *self {
            Reference::Borrowed(bytes) => bytes,
            Reference::Copied(bytes) => bytes,
        }
    }
}
188
189/// decode input source
190pub trait IoRead<'de> {
191    /// Error type produced by the reader.
192    type Error: core::error::Error + 'static;
193    /// read exactly `len` bytes and consume
194    fn read_slice<'a>(&'a mut self, len: usize) -> Result<Reference<'de, 'a>, Self::Error>;
195}
196
/// Reader over an in-memory byte slice.
pub struct SliceReader<'de> {
    /// Bytes that have not been read yet.
    cursor: &'de [u8],
}

impl<'de> SliceReader<'de> {
    /// Create a reader positioned at the start of `buf`.
    pub fn new(buf: &'de [u8]) -> Self {
        SliceReader { cursor: buf }
    }

    /// Return the bytes that have not been read yet.
    pub fn rest(&self) -> &'de [u8] {
        self.cursor
    }
}
214
/// Error returned by the in-memory readers in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum RError {
    /// The source ran out of bytes before the requested length was read.
    BufferEmpty,
}

impl core::fmt::Display for RError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Resolve the message first, then emit it in one place.
        let message = match self {
            Self::BufferEmpty => "Buffer is empty",
        };
        f.write_str(message)
    }
}

impl core::error::Error for RError {}
231
232impl<'de> IoRead<'de> for SliceReader<'de> {
233    type Error = RError;
234
235    #[inline]
236    fn read_slice<'a>(&'a mut self, len: usize) -> Result<Reference<'de, 'a>, Self::Error> {
237        let (read, rest) = self
238            .cursor
239            .split_at_checked(len)
240            .ok_or(RError::BufferEmpty)?;
241        self.cursor = rest;
242        Ok(Reference::Borrowed(read))
243    }
244}
245
#[cfg(feature = "alloc")]
mod iter_reader {
    use crate::io::RError;

    use super::IoRead;

    /// Reader that pulls bytes from an iterator.
    ///
    /// Each read is staged in an internal buffer that is reused by the next
    /// read, so results are handed out as `Reference::Copied`.
    pub struct IterReader<I> {
        it: I,
        buf: alloc::vec::Vec<u8>,
    }

    impl<I> IterReader<I>
    where
        I: Iterator<Item = u8>,
    {
        /// Create a reader that consumes bytes from `it`.
        pub fn new(it: I) -> Self {
            Self {
                it,
                buf: alloc::vec::Vec::new(),
            }
        }
    }

    impl<'de, I> IoRead<'de> for IterReader<I>
    where
        I: Iterator<Item = u8>,
    {
        type Error = RError;

        /// Pull exactly `len` bytes from the iterator into the staging buffer.
        ///
        /// Returns `RError::BufferEmpty` when the iterator ends early; the bytes
        /// pulled before the end are consumed and discarded.
        fn read_slice<'a>(
            &'a mut self,
            len: usize,
        ) -> Result<super::Reference<'de, 'a>, Self::Error> {
            self.buf.clear();

            // No up-front `reserve(len)`: `len` is whatever the caller decoded, so
            // reserving it blindly lets a bogus length panic with a capacity
            // overflow (or exhaust memory) before a single byte is pulled.
            // `extend` already pre-sizes from the iterator's own size hint.
            self.buf.extend(self.it.by_ref().take(len));
            if self.buf.len() != len {
                return Err(RError::BufferEmpty);
            }

            Ok(super::Reference::Copied(self.buf.as_slice()))
        }
    }
}
293#[cfg(feature = "alloc")]
294pub use iter_reader::IterReader;
295
#[cfg(feature = "std")]
mod std_reader {
    use std::io::Read;

    use super::IoRead;

    /// Reader that pulls bytes from a `std::io::Read`.
    ///
    /// Each read is staged in an internal buffer that is reused by the next
    /// read, so results are handed out as `Reference::Copied`.
    pub struct StdReader<R> {
        reader: R,
        buf: std::vec::Vec<u8>,
    }

    impl<R> StdReader<R>
    where
        R: Read,
    {
        /// Create a reader that consumes bytes from `reader`.
        pub fn new(reader: R) -> Self {
            Self {
                reader,
                buf: std::vec::Vec::new(),
            }
        }
    }

    impl<'de, R> IoRead<'de> for StdReader<R>
    where
        R: Read,
    {
        type Error = std::io::Error;

        /// Read exactly `len` bytes from the underlying reader.
        ///
        /// # Errors
        ///
        /// Propagates any I/O error from the underlying reader, and returns an
        /// error of kind `UnexpectedEof` when the stream ends before `len` bytes
        /// were read.
        fn read_slice<'a>(
            &'a mut self,
            len: usize,
        ) -> Result<super::Reference<'de, 'a>, Self::Error> {
            self.buf.clear();

            // `len` is whatever the caller decoded. Reading through `take` lets the
            // buffer grow as data actually arrives instead of allocating and
            // zero-filling `len` bytes up front for a stream that may be far shorter.
            let limit = u64::try_from(len).unwrap_or(u64::MAX);
            let filled = self.reader.by_ref().take(limit).read_to_end(&mut self.buf)?;
            if filled != len {
                return Err(std::io::ErrorKind::UnexpectedEof.into());
            }

            Ok(super::Reference::Copied(self.buf.as_slice()))
        }
    }
}
338#[cfg(feature = "std")]
339pub use std_reader::StdReader;
340
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn buffer_full() {
        // Arrange: a one-byte destination
        let buf: &mut [u8] = &mut [0u8];
        let mut writer = SliceWriter::from_slice(buf);

        // Act + Assert: a two-byte write is rejected with the specific error,
        // rather than merely "something panicked".
        assert_eq!(writer.write(&[1, 2]), Err(WError::BufferFull));
    }

    #[test]
    fn slice_writer_writes_sequentially() {
        // Arrange
        let mut storage = [0u8; 4];
        let mut writer = SliceWriter::from_slice(&mut storage);

        // Act: two writes that both fit
        writer.write(&[1, 2]).expect("write 2 bytes");
        writer.write(&[3]).expect("write 1 byte");

        // Assert: bytes land back to back and the tail is untouched
        assert_eq!(storage, [1, 2, 3, 0]);
    }

    #[test]
    fn slice_reader_reads_and_advances() {
        // Arrange: make a reader over a fixed slice
        let input: &[u8] = &[1, 2, 3, 4, 5];
        let mut reader = SliceReader::new(input);

        // Act: read exact 2 bytes, then 3 bytes
        {
            // Keep the first borrow in a narrower scope
            let a = reader.read_slice(2).expect("read 2 bytes");
            assert_eq!(a.as_bytes(), &[1, 2]);
        }
        let b = reader.read_slice(3).expect("read 3 bytes");
        // Assert: returned slices match and rest is empty
        assert_eq!(b.as_bytes(), &[3, 4, 5]);
        assert_eq!(reader.rest(), &[]);
    }

    #[test]
    fn slice_reader_returns_error_on_overshoot() {
        // Arrange
        let input: &[u8] = &[10, 20];
        let mut reader = SliceReader::new(input);

        // Act: first read consumes all bytes
        let first = reader.read_slice(2).expect("read 2 bytes");
        assert_eq!(first.as_bytes(), &[10, 20]);

        // Assert: second read fails with BufferEmpty
        assert!(matches!(reader.read_slice(1), Err(RError::BufferEmpty)));
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn iter_reader_reads_exact_length() {
        // Arrange: iterator with 4 items
        let it = [7u8, 8, 9, 10].into_iter();
        let mut reader = IterReader::new(it);

        // Act: read 3 then 1
        {
            let part1 = reader.read_slice(3).expect("read 3 bytes");
            assert_eq!(part1.as_bytes(), &[7, 8, 9]);
        }
        let part2 = reader.read_slice(1).expect("read 1 byte");

        // Assert
        assert_eq!(part2.as_bytes(), &[10]);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn iter_reader_returns_error_when_insufficient() {
        // Arrange: iterator shorter than requested length
        let it = [1u8, 2].into_iter();
        let mut reader = IterReader::new(it);

        // Act + Assert: request more than available -> error
        assert!(matches!(reader.read_slice(3), Err(RError::BufferEmpty)));
    }
}