buf_list/cursor/mod.rs

// Copyright (c) The buf-list Contributors
// SPDX-License-Identifier: Apache-2.0

#[cfg(feature = "futures03")]
mod futures_imp;
#[cfg(test)]
mod tests;
#[cfg(feature = "tokio1")]
mod tokio_imp;

use crate::{BufList, errors::ReadExactError};
use bytes::{Buf, Bytes};
use std::{
    cmp::Ordering,
    io::{self, IoSlice, IoSliceMut, SeekFrom},
};

/// A `Cursor` wraps an in-memory `BufList` and provides it with a [`Seek`] implementation.
///
/// `Cursor`s allow `BufList`s to implement [`Read`] and [`BufRead`], allowing a `BufList` to be
/// used anywhere you might use a reader that does actual I/O.
///
/// The cursor may either own or borrow a `BufList`: both `Cursor<BufList>` and `Cursor<&BufList>`
/// are supported.
///
/// # Optional features
///
/// * `tokio1`: With this feature enabled, [`Cursor`] implements the `tokio` crate's
///   [`AsyncSeek`](tokio::io::AsyncSeek), [`AsyncRead`](tokio::io::AsyncRead) and
///   [`AsyncBufRead`](tokio::io::AsyncBufRead).
/// * `futures03`: With this feature enabled, [`Cursor`] implements the `futures` crate's
///   [`AsyncSeek`](futures_io_03::AsyncSeek), [`AsyncRead`](futures_io_03::AsyncRead) and
///   [`AsyncBufRead`](futures_io_03::AsyncBufRead).
///
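/// # Examples
///
/// A small illustrative read: the `BufList` contents below are arbitrary, and the cursor is
/// driven through the standard [`Read`] trait.
///
/// ```
/// use buf_list::{BufList, Cursor};
/// use std::io::Read;
///
/// let mut cursor = Cursor::new(BufList::from(&[1, 2, 3, 4, 5][..]));
/// let mut out = Vec::new();
/// cursor.read_to_end(&mut out).unwrap();
/// assert_eq!(out, vec![1, 2, 3, 4, 5]);
/// ```
///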
/// [`Read`]: std::io::Read
/// [`BufRead`]: std::io::BufRead
/// [`Seek`]: std::io::Seek
pub struct Cursor<T> {
    inner: T,

    /// Data associated with the cursor.
    data: CursorData,
}

impl<T: AsRef<BufList>> Cursor<T> {
    /// Creates a new cursor wrapping the provided `BufList`.
    ///
    /// # Examples
    ///
    /// ```
    /// use buf_list::{BufList, Cursor};
    ///
    /// let cursor = Cursor::new(BufList::new());
    /// ```
    pub fn new(inner: T) -> Cursor<T> {
        let data = CursorData::new();
        Cursor { inner, data }
    }

    /// Consumes this cursor, returning the underlying value.
    ///
    /// # Examples
    ///
    /// ```
    /// use buf_list::{BufList, Cursor};
    ///
    /// let cursor = Cursor::new(BufList::new());
    ///
    /// let vec = cursor.into_inner();
    /// ```
    pub fn into_inner(self) -> T {
        self.inner
    }

    /// Gets a reference to the underlying value in this cursor.
    ///
    /// # Examples
    ///
    /// ```
    /// use buf_list::{BufList, Cursor};
    ///
    /// let cursor = Cursor::new(BufList::new());
    ///
    /// let reference = cursor.get_ref();
    /// ```
    pub const fn get_ref(&self) -> &T {
        &self.inner
    }

    /// Returns the current position of this cursor.
    ///
    /// # Examples
    ///
    /// ```
    /// use buf_list::{BufList, Cursor};
    /// use std::io::prelude::*;
    /// use std::io::SeekFrom;
    ///
    /// let mut cursor = Cursor::new(BufList::from(&[1, 2, 3, 4, 5][..]));
    ///
    /// assert_eq!(cursor.position(), 0);
    ///
    /// cursor.seek(SeekFrom::Current(2)).unwrap();
    /// assert_eq!(cursor.position(), 2);
    ///
    /// cursor.seek(SeekFrom::Current(-1)).unwrap();
    /// assert_eq!(cursor.position(), 1);
    /// ```
    pub const fn position(&self) -> u64 {
        self.data.pos
    }

    /// Sets the position of this cursor.
    ///
    /// # Examples
    ///
    /// ```
    /// use buf_list::{BufList, Cursor};
    ///
    /// let mut cursor = Cursor::new(BufList::from(&[1, 2, 3, 4, 5][..]));
    ///
    /// assert_eq!(cursor.position(), 0);
    ///
    /// cursor.set_position(2);
    /// assert_eq!(cursor.position(), 2);
    ///
    /// cursor.set_position(4);
    /// assert_eq!(cursor.position(), 4);
    /// ```
    pub fn set_position(&mut self, pos: u64) {
        self.data.set_pos(self.inner.as_ref(), pos);
    }

    // ---
    // Helper methods
    // ---
    #[cfg(test)]
    fn assert_invariants(&self) -> anyhow::Result<()> {
        self.data.assert_invariants(self.inner.as_ref())
    }
}

impl<T> Clone for Cursor<T>
where
    T: Clone,
{
    #[inline]
    fn clone(&self) -> Self {
        Cursor {
            inner: self.inner.clone(),
            data: self.data.clone(),
        }
    }

    #[inline]
    fn clone_from(&mut self, other: &Self) {
        self.inner.clone_from(&other.inner);
        self.data = other.data.clone();
    }
}

impl<T: AsRef<BufList>> io::Seek for Cursor<T> {
    fn seek(&mut self, style: SeekFrom) -> io::Result<u64> {
        self.data.seek_impl(self.inner.as_ref(), style)
    }

    fn stream_position(&mut self) -> io::Result<u64> {
        Ok(self.data.pos)
    }
}

impl<T: AsRef<BufList>> io::Read for Cursor<T> {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        Ok(self.data.read_impl(self.inner.as_ref(), buf))
    }

    fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
        Ok(self.data.read_vectored_impl(self.inner.as_ref(), bufs))
    }

    // TODO: is_read_vectored once that's available on stable Rust.

    fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
        self.data.read_exact_impl(self.inner.as_ref(), buf)
    }
}

impl<T: AsRef<BufList>> io::BufRead for Cursor<T> {
    fn fill_buf(&mut self) -> io::Result<&[u8]> {
        Ok(self.data.fill_buf_impl(self.inner.as_ref()))
    }

    fn consume(&mut self, amt: usize) {
        self.data.consume_impl(self.inner.as_ref(), amt);
    }
}

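// In addition to the `std::io` traits above, `Cursor` implements `bytes::Buf`, backed by the
// same cursor data: `remaining` counts the bytes left after `pos`, `chunk` mirrors `fill_buf`,
// and `advance` mirrors `consume`.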
impl<T: AsRef<BufList>> Buf for Cursor<T> {
    fn remaining(&self) -> usize {
        let total = self.data.num_bytes(self.inner.as_ref());
        total.saturating_sub(self.data.pos) as usize
    }

    fn chunk(&self) -> &[u8] {
        self.data.fill_buf_impl(self.inner.as_ref())
    }

    fn advance(&mut self, amt: usize) {
        self.data.consume_impl(self.inner.as_ref(), amt);
    }

    fn chunks_vectored<'iovs>(&'iovs self, iovs: &mut [IoSlice<'iovs>]) -> usize {
        if iovs.is_empty() {
            return 0;
        }

        let list = self.inner.as_ref();
        let mut filled = 0;
        let mut current_chunk = self.data.chunk;
        let mut current_pos = self.data.pos;

        // Iterate through chunks starting from the current position
        while filled < iovs.len() && current_chunk < list.num_chunks() {
            if let Some(chunk) = list.get_chunk(current_chunk) {
                let chunk_start_pos = list.get_start_pos()[current_chunk];
                let offset_in_chunk = (current_pos - chunk_start_pos) as usize;

                if offset_in_chunk < chunk.len() {
                    let chunk_slice = &chunk.as_ref()[offset_in_chunk..];
                    iovs[filled] = IoSlice::new(chunk_slice);
                    filled += 1;
                }

                current_chunk += 1;
                // Move to the start of the next chunk
                if let Some(&next_start_pos) = list.get_start_pos().get(current_chunk) {
                    current_pos = next_start_pos;
                } else {
                    break;
                }
            } else {
                break;
            }
        }

        filled
    }
}

#[derive(Clone, Debug)]
struct CursorData {
    /// The chunk number the cursor is pointing to. Kept in sync with pos.
    ///
    /// This is within the range [0, list.get_start_pos().len()). It is
    /// list.get_start_pos().len() - 1 iff pos is at or past list.num_bytes().
    chunk: usize,

    /// The overall position in the stream. Kept in sync with chunk.
    pos: u64,
}

impl CursorData {
    fn new() -> Self {
        Self { chunk: 0, pos: 0 }
    }

    #[cfg(test)]
    fn assert_invariants(&self, list: &BufList) -> anyhow::Result<()> {
        use anyhow::ensure;

        ensure!(
            self.pos >= list.get_start_pos()[self.chunk],
            "invariant failed: current position {} >= start position {} (chunk = {})",
            self.pos,
            list.get_start_pos()[self.chunk],
            self.chunk
        );

        let next_pos = list.get_start_pos().get(self.chunk + 1).copied().into();
        ensure!(
            Offset::Value(self.pos) < next_pos,
            "invariant failed: next start position {:?} > current position {} (chunk = {})",
            next_pos,
            self.pos,
            self.chunk
        );

        Ok(())
    }

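    /// Shared seek logic: resolves `style` against the start, end, or current position and
    /// delegates to `set_pos`, returning an error if the target position would be negative or
    /// overflow a `u64`.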
    fn seek_impl(&mut self, list: &BufList, style: SeekFrom) -> io::Result<u64> {
        let (base_pos, offset) = match style {
            SeekFrom::Start(n) => {
                self.set_pos(list, n);
                return Ok(n);
            }
            SeekFrom::End(n) => (self.num_bytes(list), n),
            SeekFrom::Current(n) => (self.pos, n),
        };
        // Can't use checked_add_signed since it was only stabilized in Rust 1.66. This is adapted
        // from
        // https://github.com/rust-lang/rust/blame/ed937594d3/library/std/src/io/cursor.rs#L295-L299.
        let new_pos = if offset >= 0 {
            base_pos.checked_add(offset as u64)
        } else {
            base_pos.checked_sub(offset.wrapping_neg() as u64)
        };
        match new_pos {
            Some(n) => {
                self.set_pos(list, n);
                Ok(self.pos)
            }
            None => Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "invalid seek to a negative or overflowing position",
            )),
        }
    }

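    /// Copies bytes from the list into `buf`, starting at the current position, and returns the
    /// number of bytes copied (fewer than `buf.len()` if the end of the list is reached first).
    /// Advances `pos` and `chunk` accordingly.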
    fn read_impl(&mut self, list: &BufList, buf: &mut [u8]) -> usize {
        // Read as much as possible until we fill up the buffer.
        let mut buf_pos = 0;
        while buf_pos < buf.len() {
            let (chunk, chunk_pos) = match self.get_chunk_and_pos(list) {
                Some(value) => value,
                None => break,
            };
            // The number of bytes to copy is the smaller of the two:
            // - the length of the chunk - the position in it.
            // - the number of bytes remaining, which is buf.len() - buf_pos.
            let n_to_copy = (chunk.len() - chunk_pos).min(buf.len() - buf_pos);
            let chunk_bytes = chunk.as_ref();

            let bytes_to_copy = &chunk_bytes[chunk_pos..(chunk_pos + n_to_copy)];
            let dest = &mut buf[buf_pos..(buf_pos + n_to_copy)];
            dest.copy_from_slice(bytes_to_copy);
            buf_pos += n_to_copy;

            // Increment the position.
            self.pos += n_to_copy as u64;
            // If we've finished reading through the chunk, move to the next chunk.
            if n_to_copy == chunk.len() - chunk_pos {
                self.chunk += 1;
            }
        }

        buf_pos
    }

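    /// Fills each buffer in `bufs` in order via `read_impl`, stopping once the end of the list
    /// is reached, and returns the total number of bytes copied.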
    fn read_vectored_impl(&mut self, list: &BufList, bufs: &mut [IoSliceMut<'_>]) -> usize {
        let mut nread = 0;
        for buf in bufs {
            // Copy data from the buffer until we run out of bytes to copy.
            let n = self.read_impl(list, buf);
            nread += n;
            if n < buf.len() {
                break;
            }
        }
        nread
    }

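    /// Backs `Read::read_exact`: copies exactly `buf.len()` bytes, or returns an `UnexpectedEof`
    /// error (moving the position to the end of the list, matching std on Rust 1.80+) if fewer
    /// bytes remain.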
    fn read_exact_impl(&mut self, list: &BufList, buf: &mut [u8]) -> io::Result<()> {
        // This is the same as read_impl as long as there's enough space.
        let total = self.num_bytes(list);
        let remaining = total.saturating_sub(self.pos);
        let buf_len = buf.len();
        if remaining < buf_len as u64 {
            // Rust 1.80 and above will cause the position to be set to the end
            // of the buffer, due to (apparently)
            // https://github.com/rust-lang/rust/pull/125404. Follow that
            // behavior.
            self.set_pos(list, total);
            return Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                ReadExactError { remaining, buf_len },
            ));
        }

        self.read_impl(list, buf);
        Ok(())
    }

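    /// Returns the unread remainder of the current chunk, or an empty slice at the end of the
    /// list. Backs both `BufRead::fill_buf` and `Buf::chunk`; the position is not advanced.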
    fn fill_buf_impl<'a>(&'a self, list: &'a BufList) -> &'a [u8] {
        const EMPTY_SLICE: &[u8] = &[];
        match self.get_chunk_and_pos(list) {
            Some((chunk, chunk_pos)) => &chunk.as_ref()[chunk_pos..],
            // An empty return value means the end of the buffer has been reached.
            None => EMPTY_SLICE,
        }
    }

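    /// Advances the position by `amt` bytes, as in `BufRead::consume` and `Buf::advance`.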
    fn consume_impl(&mut self, list: &BufList, amt: usize) {
        self.set_pos(list, self.pos + amt as u64);
    }

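    /// Moves the cursor to `new_pos`, keeping `chunk` in sync.
    ///
    /// Worked example (illustrative numbers): with start positions `[0, 3, 7, 10]` (three chunks
    /// of lengths 3, 4 and 3), moving forward from `pos = 1` (chunk 0) to `new_pos = 8`
    /// binary-searches `[3, 7, 10]` for 8, gets `Err(2)`, and so advances by 2 to chunk 2, which
    /// covers bytes 7..10.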
    fn set_pos(&mut self, list: &BufList, new_pos: u64) {
        match new_pos.cmp(&self.pos) {
            Ordering::Greater => {
                let start_pos = list.get_start_pos();
                let next_start = start_pos.get(self.chunk + 1).copied().into();
                if Offset::Value(new_pos) < next_start {
                    // Within the same chunk.
                } else {
                    // The above check ensures that we're not currently pointing to the last index
                    // (since it would have returned Eof, which is greater than Value(n) for any
                    // n).
                    //
                    // Do a binary search for this element.
                    match start_pos[self.chunk + 1..].binary_search(&new_pos) {
                        // We're starting the search from self.chunk + 1, which means that the value
                        // returned from binary_search is 1 less than the actual delta.
                        Ok(delta_minus_one) => {
                            // Exactly at the start point of a chunk.
                            self.chunk += 1 + delta_minus_one;
                        }
                        // The value returned in the error case (not at the start point of a chunk)
                        // is (delta - 1) + 1, so just delta.
                        Err(delta) => {
                            debug_assert!(
                                delta > 0,
                                "delta must be at least 1 since we already \
                                checked the same chunk (self.chunk = {})",
                                self.chunk,
                            );
                            self.chunk += delta;
                        }
                    }
                }
            }
            Ordering::Equal => {}
            Ordering::Less => {
                let start_pos = list.get_start_pos();
                if start_pos.get(self.chunk).copied() <= Some(new_pos) {
                    // Within the same chunk.
                } else {
                    match start_pos[..self.chunk].binary_search(&new_pos) {
                        Ok(chunk) => {
                            // Exactly at the start point of a chunk.
                            self.chunk = chunk;
                        }
                        Err(chunk_plus_1) => {
                            debug_assert!(
                                chunk_plus_1 > 0,
                                "chunk_plus_1 must be at least 1 since self.start_pos[0] is 0 \
                                 (self.chunk = {})",
                                self.chunk,
                            );
                            self.chunk = chunk_plus_1 - 1;
                        }
                    }
                }
            }
        }
        self.pos = new_pos;
    }

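    /// Returns the chunk the cursor currently points into together with the offset of `pos`
    /// within that chunk, or `None` if the position is at or past the end of the list.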
    #[inline]
    fn get_chunk_and_pos<'b>(&self, list: &'b BufList) -> Option<(&'b Bytes, usize)> {
        match list.get_chunk(self.chunk) {
            Some(chunk) => {
                // This guarantees that pos is not past the end of the list.
                debug_assert!(
                    self.pos < self.num_bytes(list),
                    "self.pos ({}) must be less than num_bytes ({})",
                    self.pos,
                    self.num_bytes(list)
                );
                Some((
                    chunk,
                    (self.pos - list.get_start_pos()[self.chunk]) as usize,
                ))
            }
            None => {
                // pos is past the end of the list.
                None
            }
        }
    }

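    /// Returns the total number of bytes in `list`: the last element of its start-position table.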
    fn num_bytes(&self, list: &BufList) -> u64 {
        *list
            .get_start_pos()
            .last()
            .expect("start_pos always has at least one element")
    }
}

/// The same as Option<T>, except that the variant ordering is reversed relative to Option's:
/// Eof > Value(T) for any T (whereas None < Some(T)).
#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord)]
enum Offset<T> {
    Value(T),
    Eof,
}

impl<T> From<Option<T>> for Offset<T> {
    fn from(value: Option<T>) -> Self {
        match value {
            Some(v) => Self::Value(v),
            None => Self::Eof,
        }
    }
}