buf_list/cursor/
mod.rs

1// Copyright (c) The buf-list Contributors
2// SPDX-License-Identifier: Apache-2.0
3
4#[cfg(feature = "futures03")]
5mod futures_imp;
6#[cfg(test)]
7mod tests;
8#[cfg(feature = "tokio1")]
9mod tokio_imp;
10
11use crate::{BufList, errors::ReadExactError};
12use bytes::Bytes;
13use std::{
14    cmp::Ordering,
15    io::{self, IoSliceMut, SeekFrom},
16};
17
18/// A `Cursor` wraps an in-memory `BufList` and provides it with a [`Seek`] implementation.
19///
20/// `Cursor`s allow `BufList`s to implement [`Read`] and [`BufRead`], allowing a `BufList` to be
21/// used anywhere you might use a reader or writer that does actual I/O.
22///
23/// The cursor may either own or borrow a `BufList`: both `Cursor<BufList>` and `Cursor<&BufList>`
24/// are supported.
25///
26/// # Optional features
27///
28/// * `tokio1`: With this feature enabled, [`Cursor`] implements the `tokio` crate's
29///   [`AsyncSeek`](tokio::io::AsyncSeek), [`AsyncRead`](tokio::io::AsyncRead) and
30///   [`AsyncBufRead`](tokio::io::AsyncBufRead).
31/// * `futures03`: With this feature enabled, [`Cursor`] implements the `futures` crate's
32///   [`AsyncSeek`](futures_io_03::AsyncSeek), [`AsyncRead`](futures_io_03::AsyncRead) and
33///   [`AsyncBufRead`](futures_io_03::AsyncBufRead).
34///
35/// [`Read`]: std::io::Read
36/// [`BufRead`]: std::io::BufRead
37/// [`Seek`]: std::io::Seek
pub struct Cursor<T> {
    /// The wrapped value; per the struct docs this is an owned `BufList` or a `&BufList`
    /// (all read/seek impls go through `T: AsRef<BufList>`).
    inner: T,

    /// Data associated with the cursor.
    data: CursorData,
}
44
45impl<T: AsRef<BufList>> Cursor<T> {
46    /// Creates a new cursor wrapping the provided `BufList`.
47    ///
48    /// # Examples
49    ///
50    /// ```
51    /// use buf_list::{BufList, Cursor};
52    ///
53    /// let cursor = Cursor::new(BufList::new());
54    /// ```
55    pub fn new(inner: T) -> Cursor<T> {
56        let data = CursorData::new();
57        Cursor { inner, data }
58    }
59
60    /// Consumes this cursor, returning the underlying value.
61    ///
62    /// # Examples
63    ///
64    /// ```
65    /// use buf_list::{BufList, Cursor};
66    ///
67    /// let cursor = Cursor::new(BufList::new());
68    ///
69    /// let vec = cursor.into_inner();
70    /// ```
71    pub fn into_inner(self) -> T {
72        self.inner
73    }
74
75    /// Gets a reference to the underlying value in this cursor.
76    ///
77    /// # Examples
78    ///
79    /// ```
80    /// use buf_list::{BufList, Cursor};
81    ///
82    /// let cursor = Cursor::new(BufList::new());
83    ///
84    /// let reference = cursor.get_ref();
85    /// ```
86    pub const fn get_ref(&self) -> &T {
87        &self.inner
88    }
89
90    /// Returns the current position of this cursor.
91    ///
92    /// # Examples
93    ///
94    /// ```
95    /// use buf_list::{BufList, Cursor};
96    /// use std::io::prelude::*;
97    /// use std::io::SeekFrom;
98    ///
99    /// let mut cursor = Cursor::new(BufList::from(&[1, 2, 3, 4, 5][..]));
100    ///
101    /// assert_eq!(cursor.position(), 0);
102    ///
103    /// cursor.seek(SeekFrom::Current(2)).unwrap();
104    /// assert_eq!(cursor.position(), 2);
105    ///
106    /// cursor.seek(SeekFrom::Current(-1)).unwrap();
107    /// assert_eq!(cursor.position(), 1);
108    /// ```
109    pub const fn position(&self) -> u64 {
110        self.data.pos
111    }
112
113    /// Sets the position of this cursor.
114    ///
115    /// # Examples
116    ///
117    /// ```
118    /// use buf_list::{BufList, Cursor};
119    ///
120    /// let mut cursor = Cursor::new(BufList::from(&[1, 2, 3, 4, 5][..]));
121    ///
122    /// assert_eq!(cursor.position(), 0);
123    ///
124    /// cursor.set_position(2);
125    /// assert_eq!(cursor.position(), 2);
126    ///
127    /// cursor.set_position(4);
128    /// assert_eq!(cursor.position(), 4);
129    /// ```
130    pub fn set_position(&mut self, pos: u64) {
131        self.data.set_pos(self.inner.as_ref(), pos);
132    }
133
134    // ---
135    // Helper methods
136    // ---
137    #[cfg(test)]
138    fn assert_invariants(&self) -> anyhow::Result<()> {
139        self.data.assert_invariants(self.inner.as_ref())
140    }
141}
142
143impl<T> Clone for Cursor<T>
144where
145    T: Clone,
146{
147    #[inline]
148    fn clone(&self) -> Self {
149        Cursor {
150            inner: self.inner.clone(),
151            data: self.data.clone(),
152        }
153    }
154
155    #[inline]
156    fn clone_from(&mut self, other: &Self) {
157        self.inner.clone_from(&other.inner);
158        self.data = other.data.clone();
159    }
160}
161
162impl<T: AsRef<BufList>> io::Seek for Cursor<T> {
163    fn seek(&mut self, style: SeekFrom) -> io::Result<u64> {
164        self.data.seek_impl(self.inner.as_ref(), style)
165    }
166
167    fn stream_position(&mut self) -> io::Result<u64> {
168        Ok(self.data.pos)
169    }
170}
171
172impl<T: AsRef<BufList>> io::Read for Cursor<T> {
173    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
174        Ok(self.data.read_impl(self.inner.as_ref(), buf))
175    }
176
177    fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
178        Ok(self.data.read_vectored_impl(self.inner.as_ref(), bufs))
179    }
180
181    // TODO: is_read_vectored once that's available on stable Rust.
182
183    fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
184        self.data.read_exact_impl(self.inner.as_ref(), buf)
185    }
186}
187
188impl<T: AsRef<BufList>> io::BufRead for Cursor<T> {
189    fn fill_buf(&mut self) -> io::Result<&[u8]> {
190        Ok(self.data.fill_buf_impl(self.inner.as_ref()))
191    }
192
193    fn consume(&mut self, amt: usize) {
194        self.data.consume_impl(self.inner.as_ref(), amt);
195    }
196}
197
/// Position state for a [`Cursor`], tracked independently of the wrapped list so that the
/// same logic serves `Cursor<BufList>` and `Cursor<&BufList>`.
#[derive(Clone, Debug)]
struct CursorData {
    /// The chunk number the cursor is pointing to. Kept in sync with pos.
    ///
    /// This is within the range [0, self.start_pos.len()). It is self.start_pos.len() - 1 iff pos
    /// is greater than or equal to list.num_bytes() (set_pos lands on the final start_pos entry
    /// whenever pos reaches or passes the end of the list).
    chunk: usize,

    /// The overall position in the stream. Kept in sync with chunk.
    pos: u64,
}
209
impl CursorData {
    /// Creates cursor state pointing at the start of the list: chunk 0, position 0.
    fn new() -> Self {
        Self { chunk: 0, pos: 0 }
    }

    /// Test-only check that `chunk` and `pos` are in sync: `pos` must lie within the
    /// half-open range `[start_pos[chunk], start_pos[chunk + 1])` (the upper bound is
    /// `Eof` for the last entry, which compares greater than any concrete offset).
    #[cfg(test)]
    fn assert_invariants(&self, list: &BufList) -> anyhow::Result<()> {
        use anyhow::ensure;

        ensure!(
            self.pos >= list.get_start_pos()[self.chunk],
            "invariant failed: current position {} >= start position {} (chunk = {})",
            self.pos,
            list.get_start_pos()[self.chunk],
            self.chunk
        );

        let next_pos = list.get_start_pos().get(self.chunk + 1).copied().into();
        ensure!(
            Offset::Value(self.pos) < next_pos,
            "invariant failed: next start position {:?} > current position {} (chunk = {})",
            next_pos,
            self.pos,
            self.chunk
        );

        Ok(())
    }

    /// Shared implementation of `Seek::seek`: resolves `style` against the appropriate
    /// base position and delegates to `set_pos`, rejecting negative/overflowing targets
    /// with `InvalidInput` (mirroring `std::io::Cursor`).
    fn seek_impl(&mut self, list: &BufList, style: SeekFrom) -> io::Result<u64> {
        let (base_pos, offset) = match style {
            SeekFrom::Start(n) => {
                // Absolute seek cannot overflow; update state and return immediately.
                self.set_pos(list, n);
                return Ok(n);
            }
            SeekFrom::End(n) => (self.num_bytes(list), n),
            SeekFrom::Current(n) => (self.pos, n),
        };
        // Can't use checked_add_signed since it was only stabilized in Rust 1.66. This is adapted
        // from
        // https://github.com/rust-lang/rust/blame/ed937594d3/library/std/src/io/cursor.rs#L295-L299.
        let new_pos = if offset >= 0 {
            base_pos.checked_add(offset as u64)
        } else {
            base_pos.checked_sub(offset.wrapping_neg() as u64)
        };
        match new_pos {
            Some(n) => {
                self.set_pos(list, n);
                Ok(self.pos)
            }
            None => Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "invalid seek to a negative or overflowing position",
            )),
        }
    }

    /// Shared implementation of `Read::read`: copies bytes chunk by chunk into `buf`,
    /// advancing `pos`/`chunk` as it goes. Returns the number of bytes copied (0 at EOF).
    fn read_impl(&mut self, list: &BufList, buf: &mut [u8]) -> usize {
        // Read as much as possible until we fill up the buffer.
        let mut buf_pos = 0;
        while buf_pos < buf.len() {
            let (chunk, chunk_pos) = match self.get_chunk_and_pos(list) {
                Some(value) => value,
                // None means the cursor is at or past the end of the list.
                None => break,
            };
            // The number of bytes to copy is the smaller of the two:
            // - the length of the chunk - the position in it.
            // - the number of bytes remaining, which is buf.len() - buf_pos.
            let n_to_copy = (chunk.len() - chunk_pos).min(buf.len() - buf_pos);
            let chunk_bytes = chunk.as_ref();

            let bytes_to_copy = &chunk_bytes[chunk_pos..(chunk_pos + n_to_copy)];
            let dest = &mut buf[buf_pos..(buf_pos + n_to_copy)];
            dest.copy_from_slice(bytes_to_copy);
            buf_pos += n_to_copy;

            // Increment the position.
            self.pos += n_to_copy as u64;
            // If we've finished reading through the chunk, move to the next chunk.
            if n_to_copy == chunk.len() - chunk_pos {
                self.chunk += 1;
            }
        }

        buf_pos
    }

    /// Shared implementation of `Read::read_vectored`: fills each buffer in turn via
    /// `read_impl`, stopping early once a buffer is only partially filled (EOF reached).
    fn read_vectored_impl(&mut self, list: &BufList, bufs: &mut [IoSliceMut<'_>]) -> usize {
        let mut nread = 0;
        for buf in bufs {
            // Copy data from the buffer until we run out of bytes to copy.
            let n = self.read_impl(list, buf);
            nread += n;
            if n < buf.len() {
                break;
            }
        }
        nread
    }

    /// Shared implementation of `Read::read_exact`: errors with `UnexpectedEof` (and a
    /// `ReadExactError` payload) if fewer than `buf.len()` bytes remain; otherwise fills
    /// `buf` completely.
    fn read_exact_impl(&mut self, list: &BufList, buf: &mut [u8]) -> io::Result<()> {
        // This is the same as read_impl as long as there's enough space.
        let total = self.num_bytes(list);
        let remaining = total.saturating_sub(self.pos);
        let buf_len = buf.len();
        if remaining < buf_len as u64 {
            // Rust 1.80 and above will cause the position to be set to the end
            // of the buffer, due to (apparently)
            // https://github.com/rust-lang/rust/pull/125404. Follow that
            // behavior.
            self.set_pos(list, total);
            return Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                ReadExactError { remaining, buf_len },
            ));
        }

        self.read_impl(list, buf);
        Ok(())
    }

    /// Shared implementation of `BufRead::fill_buf`: returns the unread remainder of the
    /// current chunk without advancing the cursor; an empty slice signals EOF.
    fn fill_buf_impl<'a>(&'a self, list: &'a BufList) -> &'a [u8] {
        const EMPTY_SLICE: &[u8] = &[];
        match self.get_chunk_and_pos(list) {
            Some((chunk, chunk_pos)) => &chunk.as_ref()[chunk_pos..],
            // An empty return value means the end of the buffer has been reached.
            None => EMPTY_SLICE,
        }
    }

    /// Shared implementation of `BufRead::consume`. Per the `consume` contract, `amt`
    /// must be at most the length last returned by `fill_buf`, so the addition here is
    /// not expected to overflow.
    fn consume_impl(&mut self, list: &BufList, amt: usize) {
        self.set_pos(list, self.pos + amt as u64);
    }

    /// Sets `pos` to `new_pos` and re-derives `chunk` to keep the two in sync. Searches
    /// only the portion of `start_pos` on the far side of the current chunk, so small
    /// moves stay cheap.
    fn set_pos(&mut self, list: &BufList, new_pos: u64) {
        match new_pos.cmp(&self.pos) {
            Ordering::Greater => {
                let start_pos = list.get_start_pos();
                // `Eof` compares greater than any Value, covering the last-chunk case.
                let next_start = start_pos.get(self.chunk + 1).copied().into();
                if Offset::Value(new_pos) < next_start {
                    // Within the same chunk.
                } else {
                    // The above check ensures that we're not currently pointing to the last index
                    // (since it would have returned Eof, which is greater than Offset(n) for any
                    // n).
                    //
                    // Do a binary search for this element.
                    match start_pos[self.chunk + 1..].binary_search(&new_pos) {
                        // We're starting the search from self.chunk + 1, which means that the value
                        // returned from binary_search is 1 less than the actual delta.
                        Ok(delta_minus_one) => {
                            // Exactly at the start point of a chunk.
                            self.chunk += 1 + delta_minus_one;
                        }
                        // The value returned in the error case (not at the start point of a chunk)
                        // is (delta - 1) + 1, so just delta.
                        Err(delta) => {
                            debug_assert!(
                                delta > 0,
                                "delta must be at least 1 since we already \
                                checked the same chunk (self.chunk = {})",
                                self.chunk,
                            );
                            self.chunk += delta;
                        }
                    }
                }
            }
            Ordering::Equal => {}
            Ordering::Less => {
                let start_pos = list.get_start_pos();
                if start_pos.get(self.chunk).copied() <= Some(new_pos) {
                    // Within the same chunk.
                } else {
                    // Moving backward: search the prefix of start_pos before the current chunk.
                    match start_pos[..self.chunk].binary_search(&new_pos) {
                        Ok(chunk) => {
                            // Exactly at the start point of a chunk.
                            self.chunk = chunk;
                        }
                        Err(chunk_plus_1) => {
                            debug_assert!(
                                chunk_plus_1 > 0,
                                "chunk_plus_1 must be at least 1 since self.start_pos[0] is 0 \
                                 (self.chunk = {})",
                                self.chunk,
                            );
                            self.chunk = chunk_plus_1 - 1;
                        }
                    }
                }
            }
        }
        self.pos = new_pos;
    }

    /// Returns the current chunk and the byte offset of `pos` within it, or `None` if
    /// the cursor is at or past the end of the list.
    #[inline]
    fn get_chunk_and_pos<'b>(&self, list: &'b BufList) -> Option<(&'b Bytes, usize)> {
        match list.get_chunk(self.chunk) {
            Some(chunk) => {
                // This guarantees that pos is not past the end of the list.
                debug_assert!(
                    self.pos < self.num_bytes(list),
                    "self.pos ({}) is less than num_bytes ({})",
                    self.pos,
                    self.num_bytes(list)
                );
                Some((
                    chunk,
                    (self.pos - list.get_start_pos()[self.chunk]) as usize,
                ))
            }
            None => {
                // pos is past the end of the list.
                None
            }
        }
    }

    /// Total number of bytes in the list: the final entry of `start_pos` is the running
    /// total past the last chunk.
    fn num_bytes(&self, list: &BufList) -> u64 {
        *list
            .get_start_pos()
            .last()
            .expect("start_pos always has at least one element")
    }
}
436
/// This is the same as Option<T> except Offset and Eof are reversed in ordering, i.e. Eof >
/// Offset(T) for any T.
///
/// The derived `PartialOrd`/`Ord` depend on declaration order: `Value` is declared before
/// `Eof`, so every `Value(T)` compares less than `Eof`. `set_pos` relies on this to treat
/// "no next chunk" as an unbounded upper limit; do not reorder the variants.
#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord)]
enum Offset<T> {
    Value(T),
    Eof,
}
444
445impl<T> From<Option<T>> for Offset<T> {
446    fn from(value: Option<T>) -> Self {
447        match value {
448            Some(v) => Self::Value(v),
449            None => Self::Eof,
450        }
451    }
452}