slice_cell/
io.rs

1use crate::SliceCell;
2use std::{
3    io::{self, Read, Seek, SeekFrom, Write},
4    vec::Vec,
5};
6#[cfg(feature = "tokio")]
7use std::{
8    pin::Pin,
9    task::{Context, Poll},
10};
11#[cfg(feature = "tokio")]
12use tokio::io::{AsyncRead, AsyncSeek, AsyncWrite, ReadBuf};
13
14// mostly copied from stdlib: e.g. library/std/src/io/cursor.rs:281
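// note that these *cannot* implement BufRead: `fill_buf` returns a `&[u8]`, which would alias
// memory that can still be written through other `&SliceCell` references to the same buffer
// (e.g. a second `Cursor` over the same data).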
17
18impl Write for &SliceCell<u8> {
19    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
20        let write_len = std::cmp::min(self.len(), buf.len());
21        if write_len > 0 {
22            let dst;
23            (dst, *self) = self.split_at(write_len);
24            dst.copy_from_slice(buf);
25        }
26        Ok(write_len)
27    }
28
29    fn flush(&mut self) -> io::Result<()> {
30        Ok(())
31    }
32}
33
34impl Read for &SliceCell<u8> {
35    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
36        let read_len = std::cmp::min(self.len(), buf.len());
37        if read_len > 0 {
38            let src;
39            (src, *self) = self.split_at(read_len);
40            src.copy_into_slice(&mut buf[..read_len]);
41        }
42        Ok(read_len)
43    }
44    fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
45        let read_len = self.len();
46        if read_len == 0 {
47            return Ok(0);
48        }
49
50        buf.reserve(read_len);
51        let write_into = buf.spare_capacity_mut();
52        debug_assert!(write_into.len() >= read_len);
53
54        let src;
55        (src, *self) = self.split_at(read_len);
56        // SAFETY: cannot use `Vec::extend_from_slice`, since it could reallocate, which could be arbitrary user code,
57        // e.g. a custom global allocator could access *src through a `thread_local`.
58        unsafe {
59            // SAFETY: *src does not overlap with write_into,
60            // and no other code can access *src concurrently,
61            // since copy_nonoverlapping is just a memcpy.
62            std::ptr::copy_nonoverlapping(
63                src.as_ptr().cast::<u8>(),
64                write_into.as_mut_ptr().cast(),
65                read_len,
66            );
67            // SAFETY: We just wrote `read_len` bytes into the spare capacity.
68            buf.set_len(buf.len() + read_len);
69        }
70        Ok(read_len)
71    }
72}
73
/// A cursor over a buffer that can be viewed as a [`SliceCell<u8>`], tracking a byte position.
///
/// This mirrors [`std::io::Cursor`], but reads and writes go through the interior-mutable
/// `SliceCell`, so several cursors may refer to the same buffer at the same time.
///
/// The position may be set past the end of the buffer; reads and writes there transfer no bytes.
#[derive(Debug, Clone, Default)]
pub struct Cursor<T> {
    /// The underlying storage.
    inner: T,
    /// Current offset, in bytes, from the start of `inner`.
    pos: u64,
}
78
79// SAFETY: We do not allow access to the inner `T` by (pinned) reference.
80impl<T> Unpin for Cursor<T> {}
81
82impl<T> Cursor<T> {
83    pub const fn new(inner: T) -> Self {
84        Self { inner, pos: 0 }
85    }
86
87    pub fn into_inner(self) -> T {
88        self.inner
89    }
90
91    pub fn position(&self) -> u64 {
92        self.pos
93    }
94
95    pub fn set_position(&mut self, pos: u64) {
96        self.pos = pos;
97    }
98}
99
100impl<T: AsRef<SliceCell<u8>>> Cursor<T> {
101    pub fn remaining_slice(&self) -> &SliceCell<u8> {
102        let len = self.pos.min(self.inner.as_ref().len() as u64);
103        &self.inner.as_ref()[(len as usize)..]
104    }
105}
106
107impl<T: AsRef<SliceCell<u8>>> Write for Cursor<T> {
108    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
109        let slice: &SliceCell<u8> = self.inner.as_ref();
110        let pos = std::cmp::min(self.pos, slice.len() as u64);
111        let amt = (&slice[(pos as usize)..]).write(buf)?;
112        self.pos += amt as u64;
113        Ok(amt)
114    }
115
116    fn flush(&mut self) -> io::Result<()> {
117        Ok(())
118    }
119}
120
121impl<T: AsRef<SliceCell<u8>>> Read for Cursor<T> {
122    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
123        let n = Read::read(&mut self.remaining_slice(), buf)?;
124        self.pos += n as u64;
125        Ok(n)
126    }
127    fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
128        let n = buf.len();
129        Read::read_exact(&mut self.remaining_slice(), buf)?;
130        self.pos += n as u64;
131        Ok(())
132    }
133}
134
135impl<T: AsRef<SliceCell<u8>>> Seek for Cursor<T> {
136    fn seek(&mut self, style: SeekFrom) -> io::Result<u64> {
137        let (base_pos, offset) = match style {
138            SeekFrom::Start(n) => {
139                self.pos = n;
140                return Ok(n);
141            }
142            SeekFrom::End(n) => (self.inner.as_ref().len() as u64, n),
143            SeekFrom::Current(n) => (self.pos, n),
144        };
145        match base_pos.checked_add_signed(offset) {
146            Some(n) => {
147                self.pos = n;
148                Ok(self.pos)
149            }
150            None => Err(io::Error::new(
151                io::ErrorKind::InvalidInput,
152                "invalid seek to a negative or overflowing position",
153            )),
154        }
155    }
156}
157
#[cfg_attr(feature = "nightly_docs", doc(cfg(feature = "tokio")))]
#[cfg(feature = "tokio")]
impl AsyncRead for &SliceCell<u8> {
    /// Copies as many bytes as fit into `buf` from the front of this slice and advances the
    /// slice past them. Always completes immediately; if nothing fits or the slice is
    /// exhausted, no bytes are transferred.
    fn poll_read(
        mut self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let count = self.len().min(buf.remaining());
        if count == 0 {
            return Poll::Ready(Ok(()));
        }

        let (src, rest) = self.split_at(count);
        *self = rest;

        if cfg!(feature = "tokio_assumptions") {
            // SAFETY: Assumes that `ReadBuf::put_slice` does not perform a context switch and
            // only accesses itself and the slice passed to it (i.e. it never reaches *self
            // except through the reference we hand it), so the cells cannot change while the
            // temporary `&[u8]` is alive.
            buf.put_slice(unsafe { &*src.as_ptr() });
        } else {
            // SAFETY: we only ever write initialized bytes here; nothing is de-initialized.
            let unfilled = unsafe { buf.unfilled_mut() };
            debug_assert!(
                count <= unfilled.len(),
                "unfilled.len() should be == buf.remaining()"
            );
            // SAFETY: `unfilled` is an exclusive `&mut [MaybeUninit<u8>]`, so it cannot overlap
            // the cell-backed `src`. Both are at least `count` bytes long, because `count` is the
            // minimum of `self.len()` and `buf.remaining()` (== `unfilled.len()`).
            unsafe {
                std::ptr::copy_nonoverlapping(
                    src.as_ptr() as *const u8,
                    unfilled.as_mut_ptr().cast(),
                    count,
                );
            }
            // SAFETY: the first `count` unfilled bytes were just written above.
            unsafe {
                buf.assume_init(count);
            }
            buf.advance(count);
        }
        Poll::Ready(Ok(()))
    }
}
205
#[cfg_attr(feature = "nightly_docs", doc(cfg(feature = "tokio")))]
#[cfg(feature = "tokio")]
impl AsyncWrite for &SliceCell<u8> {
    /// Completes immediately by delegating to the synchronous [`Write`] impl, which only copies
    /// bytes in memory.
    fn poll_write(
        self: Pin<&mut Self>,
        _: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, io::Error>> {
        let handle: &mut &SliceCell<u8> = Pin::get_mut(self);
        Poll::Ready(Write::write(handle, buf))
    }

    /// Nothing is buffered, so flushing is always immediately ready.
    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Shutting down only needs a flush, which is a no-op.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        AsyncWrite::poll_flush(self, cx)
    }
}
225
#[cfg_attr(feature = "nightly_docs", doc(cfg(feature = "tokio")))]
#[cfg(feature = "tokio")]
impl<T: AsRef<SliceCell<u8>>> AsyncRead for Cursor<T> {
    /// Reads from the current position into `buf` and advances by however many bytes were
    /// appended to `buf`'s filled region.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let before = buf.filled().len();
        let mut remaining = self.remaining_slice();
        match AsyncRead::poll_read(Pin::new(&mut remaining), cx, buf) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
            Poll::Ready(Ok(())) => {}
        }
        let after = buf.filled().len();
        self.pos += (after - before) as u64;
        Poll::Ready(Ok(()))
    }
}
245
#[cfg_attr(feature = "nightly_docs", doc(cfg(feature = "tokio")))]
#[cfg(feature = "tokio")]
impl<T: AsRef<SliceCell<u8>>> AsyncWrite for Cursor<T> {
    /// Completes immediately by delegating to the synchronous [`Write`] impl for `Cursor`.
    fn poll_write(
        self: Pin<&mut Self>,
        _: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, io::Error>> {
        let cursor: &mut Cursor<T> = Pin::get_mut(self);
        Poll::Ready(Write::write(cursor, buf))
    }

    /// Nothing is buffered, so flushing is always immediately ready.
    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Shutting down only needs a flush, which is a no-op.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        AsyncWrite::poll_flush(self, cx)
    }
}
265
#[cfg_attr(feature = "nightly_docs", doc(cfg(feature = "tokio")))]
#[cfg(feature = "tokio")]
impl<T: AsRef<SliceCell<u8>>> AsyncSeek for Cursor<T> {
    /// Performs the whole seek eagerly through the synchronous [`Seek`] impl; any error is
    /// reported here rather than from `poll_complete`.
    fn start_seek(self: Pin<&mut Self>, style: SeekFrom) -> io::Result<()> {
        Seek::seek(Pin::get_mut(self), style).map(|_| ())
    }

    /// The seek already finished in `start_seek`, so this just reports the current position.
    fn poll_complete(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<u64>> {
        let current = self.pos;
        Poll::Ready(Ok(current))
    }
}
278
#[cfg(test)]
mod tests {
    use crate::{io::Cursor, SliceCell};
    use alloc::boxed::Box;
    use std::io::{Read, Seek, SeekFrom, Write};

    /// Drives a writer and a reader that share one 2048-byte, zero-initialized buffer, checking
    /// that each side sees the other's effects. Byte counts are asserted rather than ignored
    /// (ignoring them trips clippy's `unused_io_amount`).
    fn exercise_shared_buffer<T: AsRef<SliceCell<u8>>>(
        mut writer: Cursor<T>,
        mut reader: Cursor<T>,
    ) {
        let mut buf = [0u8; 14];

        assert_eq!(writer.write(b"Hello, world!!").unwrap(), 14);

        assert_eq!(reader.read(&mut buf).unwrap(), 14);
        assert_eq!(buf, *b"Hello, world!!");

        // The writer has not reached this region yet, so it is still zeroed.
        assert_eq!(reader.read(&mut buf).unwrap(), 14);
        assert_eq!(buf, [0u8; 14]);

        assert_eq!(writer.write(b"Wonderful day!").unwrap(), 14);
        assert_eq!(writer.write(b"wow, much cell").unwrap(), 14);

        assert_eq!(reader.read(&mut buf).unwrap(), 14);
        assert_eq!(buf, *b"wow, much cell");

        assert_eq!(reader.seek(SeekFrom::Start(0)).unwrap(), 0);

        assert_eq!(reader.read(&mut buf).unwrap(), 14);
        assert_eq!(buf, *b"Hello, world!!");
        assert_eq!(reader.read(&mut buf).unwrap(), 14);
        assert_eq!(buf, *b"Wonderful day!");
        assert_eq!(reader.read(&mut buf).unwrap(), 14);
        assert_eq!(buf, *b"wow, much cell");
    }

    #[test]
    fn concurrent() {
        let data: Box<SliceCell<u8>> =
            SliceCell::new_boxed(std::vec![0u8; 2048].into_boxed_slice());
        exercise_shared_buffer(Cursor::new(&*data), Cursor::new(&*data));
    }

    #[test]
    fn rc() {
        let data = SliceCell::try_new_rc(std::vec![0u8; 2048].into()).unwrap();
        let writer = Cursor::new(data.clone());
        let reader = Cursor::new(data.clone());
        // The cursors must keep the buffer alive on their own.
        drop(data);
        exercise_shared_buffer(writer, reader);
    }

    #[test]
    fn seek_past_end_and_errors() {
        let data: Box<SliceCell<u8>> = SliceCell::new_boxed(std::vec![0u8; 16].into_boxed_slice());
        let mut cursor = Cursor::new(&*data);

        assert_eq!(cursor.seek(SeekFrom::End(-4)).unwrap(), 12);
        assert_eq!(cursor.seek(SeekFrom::Current(2)).unwrap(), 14);

        // Seeking past the end is allowed; reads and writes there transfer nothing.
        assert_eq!(cursor.seek(SeekFrom::Current(10)).unwrap(), 24);
        let mut buf = [0u8; 4];
        assert_eq!(cursor.read(&mut buf).unwrap(), 0);
        assert_eq!(cursor.write(b"abc").unwrap(), 0);
        assert_eq!(cursor.position(), 24);

        // A negative target is rejected and leaves the position untouched.
        assert!(cursor.seek(SeekFrom::Current(-100)).is_err());
        assert_eq!(cursor.position(), 24);
    }

    #[test]
    fn read_to_end_from_position() {
        let data: Box<SliceCell<u8>> =
            SliceCell::new_boxed(std::vec![1u8, 2, 3, 4, 5].into_boxed_slice());
        let mut cursor = Cursor::new(&*data);
        cursor.set_position(2);

        let mut out = std::vec::Vec::new();
        assert_eq!(cursor.read_to_end(&mut out).unwrap(), 3);
        assert_eq!(out, [3u8, 4, 5]);
        assert_eq!(cursor.position(), 5);

        // Nothing is left after reaching the end.
        assert_eq!(cursor.read_to_end(&mut out).unwrap(), 0);
        assert_eq!(out, [3u8, 4, 5]);
    }
}