Skip to main content

slice_codec/buffer/
slice.rs

1// Copyright (c) ZeroC, Inc.
2
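//! Buffer implementations backed by borrowed byte slices.
//!
//! [`SliceInputSource`] implements [`InputSource`] by reading from a `&[u8]`, and [`SliceOutputTarget`] implements
//! [`OutputTarget`] by writing into a `&mut [u8]`. Both track a position into the wrapped slice and report
//! [`ErrorKind::UnexpectedEob`] when an operation would run past its end.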
4
5use super::*;
6use crate::{ErrorKind, Result};
7use core::borrow::Borrow;
8use core::{debug_assert, debug_assert_eq};
9
/// Reads bytes from a borrowed `&[u8]`, implementing [`InputSource`].
///
/// Bytes are consumed front-to-back; the wrapped slice itself is only ever read, never modified.
#[derive(Debug)]
pub struct SliceInputSource<'a> {
    /// The borrowed bytes being read from.
    buffer: &'a [u8],
    /// Index of the next unread byte in `buffer`.
    pos: usize,
}
18
19impl<'a> SliceInputSource<'a> {
20    /// Checks whether there are at least `requested` unread bytes left in the buffer.
21    /// If there are, this returns `Ok`, and if there aren't this returns an [`ErrorKind::UnexpectedEob`] error.
22    ///
23    /// This function is only used internally to ensure a particular read operation is safe to attempt.
24    fn does_buffer_have_at_least(&self, requested: usize) -> Result<()> {
25        let remaining = self.remaining();
26        if remaining < requested {
27            let error = ErrorKind::UnexpectedEob { requested, remaining };
28            Err(error.into())
29        } else {
30            Ok(())
31        }
32    }
33
34    /// The implementation used by `peek_bytes_exact` and `read_bytes_exact`.
35    /// It's implemented as a separate function so we can return a different lifetime than what the trait demands.
36    ///
37    /// The trait function requires we return a lifetime bound to `self`, whereas this function returns a lifetime
38    /// bound to the underlying buffer (`'a`). Returning a narrower lifetime lets us mutate other fields of `self`.
39    fn peek_bytes_exact_impl<const N: usize>(&self) -> Result<&'a [u8; N]> {
40        let bytes = self.peek_byte_slice_exact_impl(N)?;
41
42        // SAFETY: `peek_byte_slice_exact_impl` is guaranteed to return exactly 'N' bytes, which means it's safe to
43        // convert, since `&[u8; N]` has the same layout as an `&[u8]` over 'N' bytes.
44        unsafe {
45            debug_assert_eq!(bytes.len(), N);
46            Ok(bytes.try_into().unwrap_unchecked())
47        }
48    }
49
50    /// The implementation used by `peek_byte_slice_exact` and `read_byte_slice_exact`.
51    /// It's implemented as a separate function so we can return a different lifetime than what the trait demands.
52    ///
53    /// The trait function requires we return a lifetime bound to `self`, whereas this function returns a lifetime
54    /// bound to the underlying buffer (`'a`). Returning a narrower lifetime lets us mutate other fields of `self`.
55    fn peek_byte_slice_exact_impl(&self, count: usize) -> Result<&'a [u8]> {
56        self.does_buffer_have_at_least(count)?;
57
58        // SAFETY: the necessary bounds checking is performed by the above function call.
59        unsafe {
60            let end = self.pos + count;
61            debug_assert!(self.buffer.get(self.pos..end).is_some());
62            Ok(self.buffer.get_unchecked(self.pos..end))
63        }
64    }
65}
66
67impl InputSource for SliceInputSource<'_> {
68    fn remaining(&self) -> usize {
69        self.buffer.len() - self.pos
70    }
71
72    fn peek_byte(&mut self) -> Result<u8> {
73        self.does_buffer_have_at_least(1)?;
74
75        // SAFETY: the necessary bounds checking is performed by the above function call.
76        unsafe {
77            debug_assert!(self.buffer.get(self.pos).is_some());
78            Ok(*self.buffer.get_unchecked(self.pos))
79        }
80    }
81
82    fn read_byte(&mut self) -> Result<u8> {
83        let byte = self.peek_byte()?;
84        self.pos += 1;
85        Ok(byte)
86    }
87
88    fn peek_bytes_exact<const N: usize>(&mut self) -> Result<&[u8; N]> {
89        self.peek_bytes_exact_impl()
90    }
91
92    fn read_bytes_exact<const N: usize>(&mut self) -> Result<&[u8; N]> {
93        let bytes = self.peek_bytes_exact_impl()?;
94        self.pos += N;
95        Ok(bytes)
96    }
97
98    fn peek_byte_slice_exact(&mut self, count: usize) -> Result<&[u8]> {
99        self.peek_byte_slice_exact_impl(count)
100    }
101
102    fn read_byte_slice_exact(&mut self, count: usize) -> Result<&[u8]> {
103        let byte_slice = self.peek_byte_slice_exact_impl(count)?;
104        self.pos += count;
105        Ok(byte_slice)
106    }
107
108    fn read_bytes_into_exact(&mut self, dst: &mut [u8]) -> Result<()> {
109        let src = self.read_byte_slice_exact(dst.len())?;
110
111        // SAFETY: `read_byte_slice_exact` is guaranteed to return exactly `dst.len()` bytes, so there is enough space
112        // in `dst` to write these bytes, and we know the slices cannot overlap because `dst` is mutably borrowed,
113        // which guarantees exclusive access.
114        unsafe {
115            debug_assert_eq!(src.len(), dst.len());
116            core::ptr::copy_nonoverlapping(src.as_ptr(), dst.as_mut_ptr(), dst.len());
117            Ok(())
118        }
119    }
120}
121
122impl<'a, T> From<&'a T> for SliceInputSource<'a>
123where
124    T: Borrow<[u8]> + ?Sized,
125{
126    /// Creates a new [`SliceInputSource`] that wraps the provided buffer.
127    fn from(value: &'a T) -> Self {
128        Self {
129            buffer: value.borrow(),
130            pos: 0,
131        }
132    }
133}
134
135// Allows users to create a [`Decoder`] directly from a slice,
136// without needing to construct an intermediate [`SliceInputSource`].
137impl<'a, T> From<T> for crate::decoder::Decoder<SliceInputSource<'a>>
138where
139    T: Into<SliceInputSource<'a>>,
140{
141    fn from(value: T) -> Self {
142        crate::decoder::Decoder::new(value.into())
143    }
144}
145
/// Writes bytes into a borrowed `&mut [u8]`, implementing [`OutputTarget`].
///
/// Bytes are written front-to-back. The buffer's length is fixed, so writes that would run past its end fail
/// instead of growing it.
#[derive(Debug)]
pub struct SliceOutputTarget<'a> {
    /// The borrowed bytes being written into.
    buffer: &'a mut [u8],
    /// Index of the next byte in `buffer` that will be written.
    pos: usize,
}
154
155impl<'a> SliceOutputTarget<'a> {
156    /// Checks whether there are at least `requested` unwritten bytes left in the buffer.
157    /// If there are, this returns `Ok`, and if there aren't this returns an [`ErrorKind::UnexpectedEob`] error.
158    ///
159    /// This function is only used internally to ensure a particular write operation is safe to attempt.
160    fn does_buffer_have_at_least(&self, requested: usize) -> Result<()> {
161        let remaining = self.remaining();
162        if remaining < requested {
163            let error = ErrorKind::UnexpectedEob { requested, remaining };
164            Err(error.into())
165        } else {
166            Ok(())
167        }
168    }
169}
170
171impl OutputTarget for SliceOutputTarget<'_> {
172    fn remaining(&self) -> usize {
173        self.buffer.len() - self.pos
174    }
175
176    fn write_byte(&mut self, byte: u8) -> Result<()> {
177        self.does_buffer_have_at_least(1)?;
178
179        // SAFETY: the above function call guarantees there's enough space in `self.buffer` to write a single byte.
180        unsafe {
181            debug_assert!(self.buffer.get_mut(self.pos).is_some());
182            *self.buffer.get_unchecked_mut(self.pos) = byte;
183            self.pos += 1;
184            Ok(())
185        }
186    }
187
188    fn write_bytes_exact(&mut self, bytes: &[u8]) -> Result<()> {
189        let count = bytes.len();
190        self.does_buffer_have_at_least(count)?;
191
192        // SAFETY: the above function call guarantees there's enough space in `self.buffer` to write `bytes`,
193        // and we know the slices cannot overlap because the mutable borrow of `self` guarantees exclusive access.
194        unsafe {
195            let end = self.pos + count;
196            debug_assert!(self.buffer.get_mut(self.pos..end).is_some());
197            let target_slice = self.buffer.get_unchecked_mut(self.pos..end);
198            debug_assert_eq!(target_slice.len(), count);
199
200            core::ptr::copy_nonoverlapping(bytes.as_ptr(), target_slice.as_mut_ptr(), count);
201            self.pos = end;
202            Ok(())
203        }
204    }
205
206    fn write_bytes_into_reserved_exact(&mut self, reservation: &mut Reservation, bytes: &[u8]) -> Result<()> {
207        // Get a mutable slice of the buffer - one that corresponds to the reserved range.
208        let Some(reserved_slice) = self.buffer.get_mut(reservation.range()) else {
209            let error = ErrorKind::InvalidReservation {
210                buffer_len: self.buffer.len(),
211                reserved_range: reservation.range(),
212            };
213            return Err(error.into());
214        };
215
216        // Ensure there's enough space remaining in the reservation.
217        if reserved_slice.len() < bytes.len() {
218            let error = ErrorKind::UnexpectedEob {
219                requested: bytes.len(),
220                remaining: reserved_slice.len(),
221            };
222            return Err(error.into());
223        }
224
225        // SAFETY: we just checked that there's enough space in `reserved_slice` to write `bytes`,
226        // and we know the slices cannot overlap because the mutable borrow of `self` guarantees exclusive access.
227        unsafe {
228            core::ptr::copy_nonoverlapping(bytes.as_ptr(), reserved_slice.as_mut_ptr(), bytes.len());
229            reservation.0.start += bytes.len();
230            Ok(())
231        }
232    }
233
234    fn reserve_space(&mut self, count: usize) -> Result<Reservation> {
235        self.does_buffer_have_at_least(count)?;
236
237        self.pos += count;
238        Ok(Reservation((self.pos - count)..self.pos))
239    }
240}
241
242impl<'a> From<&'a mut [u8]> for SliceOutputTarget<'a> {
243    /// Creates a new [`SliceOutputTarget`] that wraps the provided buffer.
244    fn from(value: &'a mut [u8]) -> Self {
245        Self { buffer: value, pos: 0 }
246    }
247}
248
249impl<'a, const N: usize> From<&'a mut [u8; N]> for SliceOutputTarget<'a> {
250    /// Creates a new [`SliceOutputTarget`] that wraps the provided array.
251    fn from(value: &'a mut [u8; N]) -> Self {
252        Self {
253            buffer: value.as_mut_slice(),
254            pos: 0,
255        }
256    }
257}
258
259// Allows users to create an [`Encoder`] directly from a slice,
260// without needing to construct an intermediate [`SliceOutputTarget`].
261impl<'a, T> From<T> for crate::encoder::Encoder<SliceOutputTarget<'a>>
262where
263    T: Into<SliceOutputTarget<'a>>,
264{
265    fn from(value: T) -> Self {
266        crate::encoder::Encoder::new(value.into())
267    }
268}
269
#[cfg(test)]
mod tests {
    use super::*;

    mod slice_input_source {
        use super::*;

        /// Verifies that [`does_buffer_have_at_least`] returns `Ok` when the number of remaining bytes is greater
        /// than or equal to the number of requested bytes.
        #[test]
        fn does_buffer_have_at_least_returns_ok() {
            // Arrange
            let buffer = [115, 108, 105, 99, 101];
            let source = SliceInputSource::from(&buffer);

            // Act
            let result = source.does_buffer_have_at_least(5);

            // Assert
            assert!(result.is_ok());
        }

        /// Verifies that [`does_buffer_have_at_least`] returns an [`ErrorKind::UnexpectedEob`] error when the
        /// number of remaining bytes is less than the number of requested bytes.
        #[test]
        fn does_buffer_have_at_least_returns_error() {
            // Arrange
            let source = SliceInputSource::from(&[115, 108, 105, 99, 101]);

            // Act
            let result = source.does_buffer_have_at_least(6);

            // Assert
            assert!(result.is_err());
            assert!(matches!(result.unwrap_err().kind(), ErrorKind::UnexpectedEob {
                requested: 6,
                remaining: 5
            }));
        }

        /// Verifies that [`peek_byte`] returns the correct byte from the buffer without consuming it.
        #[test]
        fn peek_byte_returns_correct_byte() {
            // Arrange
            let mut source = SliceInputSource::from(&[115, 108, 105, 99, 101]);

            // Act
            let result = source.peek_byte();

            // Assert
            assert!(result.is_ok());
            assert_eq!(result.unwrap(), 115);
            assert_eq!(source.pos, 0);
            assert_eq!(source.remaining(), 5);
        }

        /// Verifies that [`read_byte`] returns the correct byte from the buffer and consumes it.
        #[test]
        fn read_byte_returns_correct_byte() {
            // Arrange
            let mut source = SliceInputSource::from(&[115, 108, 105, 99, 101]);

            // Act
            let result = source.read_byte();

            // Assert
            assert!(result.is_ok());
            assert_eq!(result.unwrap(), 115);
            assert_eq!(source.pos, 1);
            assert_eq!(source.remaining(), 4);
        }

        /// Verifies that [`peek_bytes_exact`] returns the correct number of bytes from the buffer without consuming
        /// them.
        #[test]
        fn peek_bytes_exact_returns_correct_bytes() {
            // Arrange
            let mut source = SliceInputSource::from(&[115, 108, 105, 99, 101]);

            // Act
            let result = source.peek_bytes_exact::<3>();

            // Assert
            assert!(result.is_ok());
            assert_eq!(result.unwrap(), &[115, 108, 105]);
            assert_eq!(source.pos, 0);
            assert_eq!(source.remaining(), 5);
        }

        /// Verifies that [`read_bytes_exact`] returns the correct number of bytes from the buffer and consumes them.
        #[test]
        fn read_bytes_exact_returns_correct_bytes() {
            // Arrange
            let mut source = SliceInputSource::from(&[115, 108, 105, 99, 101]);

            // Act
            let result = source.read_bytes_exact::<3>();

            // Assert
            assert!(result.is_ok());
            assert_eq!(result.unwrap(), &[115, 108, 105]);
            assert_eq!(source.pos, 3);
            assert_eq!(source.remaining(), 2);
        }

        /// Verifies that [`read_byte_slice_exact`] returns the requested number of bytes and consumes them.
        #[test]
        fn read_byte_slice_exact_returns_correct_bytes() {
            // Arrange
            let mut source = SliceInputSource::from(&[115, 108, 105, 99, 101]);

            // Act
            let result = source.read_byte_slice_exact(2);

            // Assert
            assert_eq!(result.unwrap(), &[115, 108]);
            assert_eq!(source.pos, 2);
            assert_eq!(source.remaining(), 3);
        }

        /// Verifies that [`read_bytes_into_exact`] copies the requested bytes into the destination and consumes them.
        #[test]
        fn read_bytes_into_exact_copies_correct_bytes() {
            // Arrange
            let mut source = SliceInputSource::from(&[115, 108, 105, 99, 101]);
            let mut destination = [0u8; 3];

            // Act
            let result = source.read_bytes_into_exact(&mut destination);

            // Assert
            assert!(result.is_ok());
            assert_eq!(destination, [115, 108, 105]);
            assert_eq!(source.pos, 3);
            assert_eq!(source.remaining(), 2);
        }

        /// Verifies that reading past the end of the buffer fails with [`ErrorKind::UnexpectedEob`] and does not
        /// consume any bytes.
        #[test]
        fn read_past_end_returns_error_without_consuming() {
            // Arrange
            let mut source = SliceInputSource::from(&[115, 108]);

            // Act
            let result = source.read_bytes_exact::<3>();

            // Assert
            assert!(matches!(result.unwrap_err().kind(), ErrorKind::UnexpectedEob {
                requested: 3,
                remaining: 2
            }));
            assert_eq!(source.pos, 0);
            assert_eq!(source.remaining(), 2);
        }
    }

    mod slice_output_target {
        use super::*;

        /// Verifies that [`does_buffer_have_at_least`] returns `Ok` when the number of remaining bytes is greater
        /// than or equal to the number of requested bytes.
        #[test]
        fn does_buffer_have_at_least_returns_ok() {
            // Arrange
            let mut buffer = [115, 108, 105, 99, 101];
            let target = SliceOutputTarget::from(buffer.as_mut_slice());

            // Act
            let result = target.does_buffer_have_at_least(5);

            // Assert
            assert!(result.is_ok());
        }

        /// Verifies that [`does_buffer_have_at_least`] returns an [`ErrorKind::UnexpectedEob`] error when the
        /// number of remaining bytes is less than the number of requested bytes.
        #[test]
        fn does_buffer_have_at_least_returns_error() {
            // Arrange
            let mut buffer = [115, 108, 105, 99, 101];
            let target = SliceOutputTarget::from(buffer.as_mut_slice());

            // Act
            let result = target.does_buffer_have_at_least(6);

            // Assert
            assert!(result.is_err());
            assert!(matches!(result.unwrap_err().kind(), ErrorKind::UnexpectedEob {
                requested: 6,
                remaining: 5
            }));
        }

        /// Verifies that [`write_byte`] writes the correct byte to the buffer and advances the position.
        #[test]
        fn write_byte_writes_correct_byte() {
            // Arrange
            let mut buffer = [0; 5];
            let mut target = SliceOutputTarget::from(buffer.as_mut_slice());

            // Act
            let result = target.write_byte(115);

            // Assert
            assert!(result.is_ok());
            assert_eq!(target.buffer, [115, 0, 0, 0, 0]);
            assert_eq!(target.pos, 1);
            assert_eq!(target.remaining(), 4);
        }

        /// Verifies that [`write_bytes_exact`] writes the correct bytes to the buffer and advances the position.
        #[test]
        fn write_bytes_exact_writes_correct_bytes() {
            // Arrange
            let mut buffer = [0; 5];
            let mut target = SliceOutputTarget::from(buffer.as_mut_slice());

            // Act
            let result = target.write_bytes_exact(&[115, 108, 105, 99, 101]);

            // Assert
            assert!(result.is_ok());
            assert_eq!(target.buffer, [115, 108, 105, 99, 101]);
            assert_eq!(target.pos, 5);
            assert_eq!(target.remaining(), 0);
        }

        /// Verifies that [`write_bytes_exact`] fails with [`ErrorKind::UnexpectedEob`] when more bytes are written
        /// than remain, leaving both the buffer and the position untouched.
        #[test]
        fn write_bytes_exact_past_end_returns_error() {
            // Arrange
            let mut buffer = [0; 2];
            let mut target = SliceOutputTarget::from(buffer.as_mut_slice());

            // Act
            let result = target.write_bytes_exact(&[115, 108, 105]);

            // Assert
            assert!(matches!(result.unwrap_err().kind(), ErrorKind::UnexpectedEob {
                requested: 3,
                remaining: 2
            }));
            assert_eq!(target.pos, 0);
            assert_eq!(target.buffer, [0, 0]);
        }

        /// Verifies that [`reserve_space`] reserves the correct number of bytes in the buffer and advances the
        /// position past the reserved space so that the next write operation will not write into the reserved space.
        #[test]
        fn reserve_space_reserves_correct_space() {
            // Arrange
            let mut buffer = [0; 5];
            let mut target = SliceOutputTarget::from(buffer.as_mut_slice());

            // Act
            let reserve_result = target.reserve_space(3);
            let write_result = target.write_byte(99);

            // Assert
            assert!(reserve_result.is_ok());
            assert!(write_result.is_ok());

            assert_eq!(reserve_result.unwrap().0, 0..3);
            assert_eq!(target.pos, 4);
            assert_eq!(target.remaining(), 1);
            assert_eq!(target.buffer, [0, 0, 0, 99, 0]);
        }

        /// Verifies that [`reserve_space`] fails with [`ErrorKind::UnexpectedEob`] when more space is requested
        /// than remains, and does not advance the position.
        #[test]
        fn reserve_space_past_end_returns_error() {
            // Arrange
            let mut buffer = [0; 2];
            let mut target = SliceOutputTarget::from(buffer.as_mut_slice());

            // Act
            let Err(error) = target.reserve_space(3) else {
                panic!("reserving more space than remains in the buffer should fail");
            };

            // Assert
            assert!(matches!(error.kind(), ErrorKind::UnexpectedEob {
                requested: 3,
                remaining: 2
            }));
            assert_eq!(target.pos, 0);
        }

        /// Verifies that [`write_bytes_into_reserved_exact`] writes the correct bytes to the reserved space in the
        /// buffer, shrinks the reservation, and does not move the write position.
        #[test]
        fn write_bytes_into_reserved_exact_writes_correct_bytes() {
            // Arrange
            let mut buffer = [0; 5];
            let mut target = SliceOutputTarget::from(buffer.as_mut_slice());

            // Should advance the position to 3.
            let mut reservation = target.reserve_space(3).unwrap();

            // Write a byte after the reservation so the position is advanced to 4.
            target.write_bytes_exact(&[99]).unwrap();

            // Act
            let result = target.write_bytes_into_reserved_exact(&mut reservation, &[115, 108, 105]);

            // Write another byte; it must land at index 4, proving the reserved write did not move the position.
            target.write_byte(101).unwrap();

            // Assert
            assert!(result.is_ok());
            assert_eq!(reservation.0, 3..3);
            assert_eq!(target.buffer, [115, 108, 105, 99, 101]);
            assert_eq!(target.pos, 5);
            assert_eq!(target.remaining(), 0);
        }

        /// Verifies that [`write_bytes_into_reserved_exact`] fails with [`ErrorKind::UnexpectedEob`] when more bytes
        /// are written than the reservation holds, leaving the buffer and the reservation untouched.
        #[test]
        fn write_bytes_into_reserved_exact_overflow_returns_error() {
            // Arrange
            let mut buffer = [0; 5];
            let mut target = SliceOutputTarget::from(buffer.as_mut_slice());
            let mut reservation = target.reserve_space(2).unwrap();

            // Act
            let result = target.write_bytes_into_reserved_exact(&mut reservation, &[115, 108, 105]);

            // Assert
            assert!(matches!(result.unwrap_err().kind(), ErrorKind::UnexpectedEob {
                requested: 3,
                remaining: 2
            }));
            assert_eq!(reservation.0, 0..2);
            assert_eq!(target.buffer, [0, 0, 0, 0, 0]);
        }
    }
}