Skip to main content

ax_io/utils/
cursor.rs

1#[cfg(feature = "alloc")]
2use alloc::{boxed::Box, string::String, vec::Vec};
3use core::{cmp, io::BorrowedCursor};
4
5use crate::{BufRead, Error, IoBuf, IoBufMut, Read, Result, Seek, SeekFrom, Write};
6
/// An in-memory buffer paired with a position, which gives the buffer a
/// [`Seek`] implementation.
///
/// Any buffer implementing <code>[AsRef]<\[u8]></code> can be wrapped. Doing so
/// lets it implement [`Read`] and/or [`Write`], so it can stand in wherever a
/// reader or writer performing real I/O would otherwise be required.
#[derive(Debug, Default, Eq, PartialEq)]
pub struct Cursor<T> {
    /// The wrapped buffer.
    inner: T,
    /// Current offset into `inner`. It is stored unchecked and may therefore
    /// lie beyond the end of the buffer.
    pos: u64,
}

impl<T> Cursor<T> {
    /// Wraps `inner` in a cursor positioned at offset `0`.
    ///
    /// The starting position is `0` regardless of how much data `inner`
    /// already holds (e.g. a non-empty [`Vec`]), so the first write overwrites
    /// existing content rather than appending to it.
    pub const fn new(inner: T) -> Cursor<T> {
        Self { inner, pos: 0 }
    }

    /// Discards the position and hands back the wrapped value.
    pub fn into_inner(self) -> T {
        let Self { inner, .. } = self;
        inner
    }

    /// Borrows the wrapped value.
    pub const fn get_ref(&self) -> &T {
        &self.inner
    }

    /// Mutably borrows the wrapped value.
    ///
    /// Changing the wrapped value behind the cursor's back (for example
    /// shrinking it) does not update the stored position, so take care not to
    /// leave the two inconsistent.
    pub const fn get_mut(&mut self) -> &mut T {
        &mut self.inner
    }

    /// Reports the cursor's current offset.
    pub const fn position(&self) -> u64 {
        self.pos
    }

    /// Moves the cursor to offset `pos`; the value is stored without any
    /// bounds check against the wrapped buffer.
    pub const fn set_position(&mut self, pos: u64) {
        self.pos = pos;
    }
}
58
59impl<T> Cursor<T>
60where
61    T: AsRef<[u8]>,
62{
63    /// Splits the underlying slice at the cursor position and returns them.
64    pub fn split(&self) -> (&[u8], &[u8]) {
65        let slice = self.inner.as_ref();
66        let pos = self.pos.min(slice.len() as u64);
67        slice.split_at(pos as usize)
68    }
69}
70
71impl<T> Cursor<T>
72where
73    T: AsMut<[u8]>,
74{
75    /// Splits the underlying slice at the cursor position and returns them
76    /// mutably.
77    pub fn split_mut(&mut self) -> (&mut [u8], &mut [u8]) {
78        let slice = self.inner.as_mut();
79        let pos = self.pos.min(slice.len() as u64);
80        slice.split_at_mut(pos as usize)
81    }
82}
83
84impl<T> Clone for Cursor<T>
85where
86    T: Clone,
87{
88    #[inline]
89    fn clone(&self) -> Self {
90        Cursor {
91            inner: self.inner.clone(),
92            pos: self.pos,
93        }
94    }
95
96    #[inline]
97    fn clone_from(&mut self, other: &Self) {
98        self.inner.clone_from(&other.inner);
99        self.pos = other.pos;
100    }
101}
102
103impl<T> Seek for Cursor<T>
104where
105    T: AsRef<[u8]>,
106{
107    fn seek(&mut self, style: SeekFrom) -> Result<u64> {
108        let (base_pos, offset) = match style {
109            SeekFrom::Start(n) => {
110                self.pos = n;
111                return Ok(n);
112            }
113            SeekFrom::End(n) => (self.inner.as_ref().len() as u64, n),
114            SeekFrom::Current(n) => (self.pos, n),
115        };
116        match base_pos.checked_add_signed(offset) {
117            Some(n) => {
118                self.pos = n;
119                Ok(self.pos)
120            }
121            None => Err(Error::InvalidInput),
122        }
123    }
124
125    fn stream_len(&mut self) -> Result<u64> {
126        Ok(self.inner.as_ref().len() as u64)
127    }
128
129    fn stream_position(&mut self) -> Result<u64> {
130        Ok(self.pos)
131    }
132}
133
134impl<T> Read for Cursor<T>
135where
136    T: AsRef<[u8]>,
137{
138    fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
139        let n = Read::read(&mut Cursor::split(self).1, buf)?;
140        self.pos += n as u64;
141        Ok(n)
142    }
143
144    fn read_exact(&mut self, buf: &mut [u8]) -> Result<()> {
145        let result = Read::read_exact(&mut Cursor::split(self).1, buf);
146
147        match result {
148            Ok(_) => self.pos += buf.len() as u64,
149            // The only possible error condition is EOF, so place the cursor at "EOF"
150            Err(_) => self.pos = self.inner.as_ref().len() as u64,
151        }
152
153        result
154    }
155
156    fn read_buf(&mut self, mut cursor: BorrowedCursor<'_>) -> Result<()> {
157        let prev_written = cursor.written();
158
159        Read::read_buf(&mut Cursor::split(self).1, cursor.reborrow())?;
160
161        self.pos += (cursor.written() - prev_written) as u64;
162
163        Ok(())
164    }
165
166    fn read_buf_exact(&mut self, mut cursor: BorrowedCursor<'_>) -> Result<()> {
167        let prev_written = cursor.written();
168
169        let result = Read::read_buf_exact(&mut Cursor::split(self).1, cursor.reborrow());
170        self.pos += (cursor.written() - prev_written) as u64;
171
172        result
173    }
174
175    #[cfg(feature = "alloc")]
176    fn read_to_end(&mut self, buf: &mut Vec<u8>) -> Result<usize> {
177        let content = Cursor::split(self).1;
178        let len = content.len();
179        buf.try_reserve(len).map_err(|_| Error::NoMemory)?;
180        buf.extend_from_slice(content);
181        self.pos += len as u64;
182
183        Ok(len)
184    }
185
186    #[cfg(feature = "alloc")]
187    fn read_to_string(&mut self, buf: &mut String) -> Result<usize> {
188        let content = str::from_utf8(Cursor::split(self).1).map_err(|_| Error::IllegalBytes)?;
189        let len = content.len();
190        buf.try_reserve(len).map_err(|_| Error::NoMemory)?;
191        buf.push_str(content);
192        self.pos += len as u64;
193
194        Ok(len)
195    }
196}
197
198impl<T> BufRead for Cursor<T>
199where
200    T: AsRef<[u8]>,
201{
202    fn fill_buf(&mut self) -> Result<&[u8]> {
203        Ok(Cursor::split(self).1)
204    }
205
206    fn consume(&mut self, amt: usize) {
207        self.pos += amt as u64;
208    }
209}
210
211fn slice_write(pos_mut: &mut u64, slice: &mut [u8], buf: &[u8]) -> Result<usize> {
212    let pos = cmp::min(*pos_mut, slice.len() as u64);
213    let amt = (&mut slice[(pos as usize)..]).write(buf)?;
214    *pos_mut += amt as u64;
215    Ok(amt)
216}
217
218#[inline]
219fn slice_write_all(pos_mut: &mut u64, slice: &mut [u8], buf: &[u8]) -> Result<()> {
220    let n = slice_write(pos_mut, slice, buf)?;
221    if n < buf.len() {
222        Err(Error::WriteZero)
223    } else {
224        Ok(())
225    }
226}
227
228/// Reserves the required space, and pads the vec with 0s if necessary.
229#[cfg(feature = "alloc")]
230fn reserve_and_pad(pos_mut: &mut u64, vec: &mut Vec<u8>, buf_len: usize) -> Result<usize> {
231    let pos: usize = (*pos_mut).try_into().map_err(|_| Error::InvalidInput)?;
232
233    // For safety reasons, we don't want these numbers to overflow
234    // otherwise our allocation won't be enough
235    let desired_cap = pos.saturating_add(buf_len);
236    if desired_cap > vec.capacity() {
237        // We want our vec's total capacity
238        // to have room for (pos+buf_len) bytes. Reserve allocates
239        // based on additional elements from the length, so we need to
240        // reserve the difference
241        vec.reserve(desired_cap - vec.len());
242    }
243    // Pad if pos is above the current len.
244    if pos > vec.len() {
245        let diff = pos - vec.len();
246        // Unfortunately, `resize()` would suffice but the optimiser does not
247        // realise the `reserve` it does can be eliminated. So we do it manually
248        // to eliminate that extra branch
249        let spare = vec.spare_capacity_mut();
250        debug_assert!(spare.len() >= diff);
251        // Safety: we have allocated enough capacity for this.
252        // And we are only writing, not reading
253        unsafe {
254            spare
255                .get_unchecked_mut(..diff)
256                .fill(core::mem::MaybeUninit::new(0));
257            vec.set_len(pos);
258        }
259    }
260
261    Ok(pos)
262}
263
264/// Writes the slice to the vec without allocating.
265///
266/// # Safety
267///
268/// `vec` must have `buf.len()` spare capacity.
269#[cfg(feature = "alloc")]
/// Copies `buf` into the allocation of `vec` starting at byte offset `pos`
/// and returns the offset just past the copied bytes. The length of `vec` is
/// not modified; the caller is responsible for extending it if needed.
unsafe fn vec_write_all_unchecked(pos: usize, vec: &mut Vec<u8>, buf: &[u8]) -> usize {
    let count = buf.len();
    let end = pos + count;
    debug_assert!(end <= vec.capacity());
    // SAFETY: the caller guarantees that the allocation of `vec` has room for
    // `count` bytes starting at `pos`, so the destination range lies within
    // it.
    unsafe {
        let dst = vec.as_mut_ptr().add(pos);
        dst.copy_from(buf.as_ptr(), count);
    }
    end
}
275
/// Growing `write_all` implementation shared by the `Vec`-backed [`Cursor`]s.
///
/// Two situations are supported beyond a plain append:
///
/// * The vector already holds initialised data while the position is before
///   its end (for instance `0`): the write overwrites existing content.
/// * The position lies beyond the vector's current length (for instance an
///   empty vector with position `N`): the gap is zero-filled first, and the
///   data is written after it.
///
/// On success the whole of `buf` has been written, the position has advanced
/// by `buf.len()`, and `buf.len()` is returned.
#[cfg(feature = "alloc")]
fn vec_write_all(pos_mut: &mut u64, vec: &mut Vec<u8>, buf: &[u8]) -> Result<usize> {
    let written = buf.len();
    let start = reserve_and_pad(pos_mut, vec, written)?;

    // SAFETY: `reserve_and_pad` has ensured that the capacity covers
    // `start + written` and that every byte below `start` is initialised.
    let end = unsafe { vec_write_all_unchecked(start, vec, buf) };
    if end > vec.len() {
        // SAFETY: all bytes below `end` are initialised, either by the
        // padding step or by the copy just performed.
        unsafe { vec.set_len(end) };
    }

    // Move the cursor past the bytes just written.
    *pos_mut += written as u64;
    Ok(written)
}
304
305impl Write for Cursor<&mut [u8]> {
306    #[inline]
307    fn write(&mut self, buf: &[u8]) -> Result<usize> {
308        slice_write(&mut self.pos, self.inner, buf)
309    }
310
311    #[inline]
312    fn write_all(&mut self, buf: &[u8]) -> Result<()> {
313        slice_write_all(&mut self.pos, self.inner, buf)
314    }
315
316    #[inline]
317    fn flush(&mut self) -> Result<()> {
318        Ok(())
319    }
320}
321
/// Writing into a borrowed vector: the vector grows as needed, so every write
/// goes through `vec_write_all`.
#[cfg(feature = "alloc")]
impl Write for Cursor<&mut Vec<u8>> {
    /// Writes the whole of `buf` at the current position and returns its
    /// length.
    fn write(&mut self, buf: &[u8]) -> Result<usize> {
        let Self { inner, pos } = self;
        vec_write_all(pos, inner, buf)
    }

    /// Same as `write`, discarding the byte count.
    fn write_all(&mut self, buf: &[u8]) -> Result<()> {
        let Self { inner, pos } = self;
        vec_write_all(pos, inner, buf).map(|_| ())
    }

    /// Nothing is buffered, so flushing always succeeds.
    #[inline]
    fn flush(&mut self) -> Result<()> {
        Ok(())
    }
}
338
/// Writing into an owned vector: the vector grows as needed, so every write
/// goes through `vec_write_all`.
#[cfg(feature = "alloc")]
impl Write for Cursor<Vec<u8>> {
    /// Writes the whole of `buf` at the current position and returns its
    /// length.
    fn write(&mut self, buf: &[u8]) -> Result<usize> {
        let Self { inner, pos } = self;
        vec_write_all(pos, inner, buf)
    }

    /// Same as `write`, discarding the byte count.
    fn write_all(&mut self, buf: &[u8]) -> Result<()> {
        let Self { inner, pos } = self;
        vec_write_all(pos, inner, buf).map(|_| ())
    }

    /// Nothing is buffered, so flushing always succeeds.
    #[inline]
    fn flush(&mut self) -> Result<()> {
        Ok(())
    }
}
355
/// Writing into a boxed slice: the slice never grows, so writes stop at its
/// end.
#[cfg(feature = "alloc")]
impl Write for Cursor<Box<[u8]>> {
    /// Writes as much of `buf` as `slice_write` stores at the current
    /// position and returns that count.
    #[inline]
    fn write(&mut self, buf: &[u8]) -> Result<usize> {
        let Self { inner, pos } = self;
        slice_write(pos, inner, buf)
    }

    /// Writes all of `buf` via `slice_write_all`, failing if it does not fit.
    #[inline]
    fn write_all(&mut self, buf: &[u8]) -> Result<()> {
        let Self { inner, pos } = self;
        slice_write_all(pos, inner, buf)
    }

    /// Nothing is buffered, so flushing always succeeds.
    #[inline]
    fn flush(&mut self) -> Result<()> {
        Ok(())
    }
}
373
374impl<const N: usize> Write for Cursor<[u8; N]> {
375    #[inline]
376    fn write(&mut self, buf: &[u8]) -> Result<usize> {
377        slice_write(&mut self.pos, &mut self.inner, buf)
378    }
379
380    #[inline]
381    fn write_all(&mut self, buf: &[u8]) -> Result<()> {
382        slice_write_all(&mut self.pos, &mut self.inner, buf)
383    }
384
385    #[inline]
386    fn flush(&mut self) -> Result<()> {
387        Ok(())
388    }
389}
390
391impl<T: IoBuf> IoBuf for Cursor<T> {
392    #[inline]
393    fn remaining(&self) -> usize {
394        self.inner.remaining() - (self.pos as usize)
395    }
396}
397
398impl<T: IoBufMut> IoBufMut for Cursor<T> {
399    #[inline]
400    fn remaining_mut(&self) -> usize {
401        self.inner.remaining_mut() - (self.pos as usize)
402    }
403}