slice_cell/
io.rs

1use crate::SliceCell;
2use std::{
3    io::{self, Read, Seek, SeekFrom, Write},
4    vec::Vec,
5};
6#[cfg(feature = "tokio")]
7use std::{
8    pin::Pin,
9    task::{Context, Poll},
10};
11#[cfg(feature = "tokio")]
12use tokio::io::{AsyncRead, AsyncSeek, AsyncWrite, ReadBuf};
13
14// mostly copied from stdlib: e.g. library/std/src/io/cursor.rs:281
15// note that these *cannot* implement BufRead, since fill_buf returns a `&[u8]` which could be
16// shared with another thread.
17
18impl Write for &SliceCell<u8> {
19    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
20        let write_len = std::cmp::min(self.len(), buf.len());
21        if write_len > 0 {
22            let dst;
23            (dst, *self) = self.split_at(write_len);
24            dst.copy_from_slice(buf);
25        }
26        Ok(write_len)
27    }
28
29    fn flush(&mut self) -> io::Result<()> {
30        Ok(())
31    }
32}
33
34impl Read for &SliceCell<u8> {
35    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
36        let read_len = std::cmp::min(self.len(), buf.len());
37        if read_len > 0 {
38            let src;
39            (src, *self) = self.split_at(read_len);
40            src.copy_into_slice(&mut buf[..read_len]);
41        }
42        Ok(read_len)
43    }
44    fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
45        let read_len = self.len();
46        if read_len == 0 {
47            return Ok(0);
48        }
49
50        buf.reserve(read_len);
51        let write_into = buf.spare_capacity_mut();
52        debug_assert!(write_into.len() >= read_len);
53
54        let src;
55        (src, *self) = self.split_at(read_len);
56        // SAFETY: cannot use `Vec::extend_from_slice`, since it could reallocate, which could be arbitrary user code,
57        // e.g. a custom global allocator could access *src through a `thread_local`.
58        unsafe {
59            // SAFETY: *src does not overlap with write_into,
60            // and no other code can access *src concurrently,
61            // since copy_nonoverlapping is just a memcpy.
62            std::ptr::copy_nonoverlapping(
63                src.as_ptr().cast::<u8>(),
64                write_into.as_mut_ptr().cast(),
65                read_len,
66            );
67            // SAFETY: We just wrote `read_len` bytes into the spare capacity.
68            buf.set_len(buf.len() + read_len);
69        }
70        Ok(read_len)
71    }
72}
73
/// A position-tracking wrapper around some inner value `T`.
///
/// The position is a plain `u64` offset that starts at zero; nothing in this
/// type validates it against the size of `inner`.
pub struct Cursor<T> {
    /// The wrapped value.
    inner: T,
    /// Current offset, in bytes, from the start of `inner`.
    pos: u64,
}

// A `Cursor` may always be moved, even when `T` itself is `!Unpin`: none of
// the methods defined here expose the inner `T` by (pinned) reference.
impl<T> Unpin for Cursor<T> {}

impl<T> Cursor<T> {
    /// Wraps `inner` in a cursor positioned at offset zero.
    pub const fn new(inner: T) -> Self {
        Cursor { pos: 0, inner }
    }

    /// Consumes the cursor and hands back the wrapped value.
    pub fn into_inner(self) -> T {
        let Cursor { inner, .. } = self;
        inner
    }

    /// Returns the current offset of the cursor.
    pub fn position(&self) -> u64 {
        self.pos
    }

    /// Moves the cursor to `pos`. No bounds check is performed here.
    pub fn set_position(&mut self, pos: u64) {
        self.pos = pos;
    }
}
99
100impl<T: AsRef<SliceCell<u8>>> Cursor<T> {
101    pub fn remaining_slice(&self) -> &SliceCell<u8> {
102        let len = self.pos.min(self.inner.as_ref().len() as u64);
103        &self.inner.as_ref()[(len as usize)..]
104    }
105}
106
107impl<T: AsRef<SliceCell<u8>>> Write for Cursor<T> {
108    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
109        let slice: &SliceCell<u8> = self.inner.as_ref();
110        let pos = std::cmp::min(self.pos, slice.len() as u64);
111        let amt = (&slice[(pos as usize)..]).write(buf)?;
112        self.pos += amt as u64;
113        Ok(amt)
114    }
115
116    fn flush(&mut self) -> io::Result<()> {
117        Ok(())
118    }
119}
120
121impl<T: AsRef<SliceCell<u8>>> Read for Cursor<T> {
122    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
123        let n = Read::read(&mut self.remaining_slice(), buf)?;
124        self.pos += n as u64;
125        Ok(n)
126    }
127    fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
128        let n = buf.len();
129        Read::read_exact(&mut self.remaining_slice(), buf)?;
130        self.pos += n as u64;
131        Ok(())
132    }
133}
134
135impl<T: AsRef<SliceCell<u8>>> Seek for Cursor<T> {
136    fn seek(&mut self, style: SeekFrom) -> io::Result<u64> {
137        let (base_pos, offset) = match style {
138            SeekFrom::Start(n) => {
139                self.pos = n;
140                return Ok(n);
141            }
142            SeekFrom::End(n) => (self.inner.as_ref().len() as u64, n),
143            SeekFrom::Current(n) => (self.pos, n),
144        };
145        match base_pos.checked_add_signed(offset) {
146            Some(n) => {
147                self.pos = n;
148                Ok(self.pos)
149            }
150            None => Err(io::Error::new(
151                io::ErrorKind::InvalidInput,
152                "invalid seek to a negative or overflowing position",
153            )),
154        }
155    }
156}
157
#[cfg(feature = "tokio")]
impl AsyncRead for &SliceCell<u8> {
    /// Copies as many bytes as fit from the front of this slice into the
    /// unfilled part of `buf`, then advances `*self` past them.
    ///
    /// Always completes immediately with `Poll::Ready(Ok(()))`.
    fn poll_read(
        mut self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let count = std::cmp::min(buf.remaining(), self.len());
        if count == 0 {
            return Poll::Ready(Ok(()));
        }

        let (src, rest) = self.split_at(count);
        *self = rest;

        if cfg!(feature = "tokio_assumptions") {
            // SAFETY: Assumes that `ReadBuf::put_slice` does not perform a
            // context switch, and that it only touches itself and the slice it
            // is given (i.e. it never reaches `*self` other than through the
            // reference passed here).
            let bytes = unsafe { &*src.as_ptr() };
            buf.put_slice(bytes);
        } else {
            // SAFETY: no byte of this buffer is de-initialised below.
            let unfilled = unsafe { buf.unfilled_mut() };
            debug_assert!(
                count <= unfilled.len(),
                "unfilled.len() should be == buf.remaining()"
            );
            // SAFETY: `count` bytes are copied from `src` into `unfilled`.
            // The two cannot overlap, since `unfilled` is behind a `&mut`
            // while `src` is a cell type. Neither is overrun, since `count`
            // is the minimum of their lengths.
            unsafe {
                std::ptr::copy_nonoverlapping(
                    src.as_ptr().cast::<u8>(),
                    unfilled.as_mut_ptr().cast(),
                    count,
                );
            }
            // SAFETY: exactly `count` bytes were just written.
            unsafe {
                buf.assume_init(count);
            }
            buf.advance(count);
        }

        Poll::Ready(Ok(()))
    }
}
204
#[cfg(feature = "tokio")]
impl AsyncWrite for &SliceCell<u8> {
    /// Delegates to the synchronous `write` and is therefore always ready.
    fn poll_write(
        mut self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, io::Error>> {
        let outcome = self.write(buf);
        Poll::Ready(outcome)
    }

    /// Nothing is buffered, so flushing completes immediately.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        Poll::Ready(Ok(()))
    }

    /// There is nothing to shut down; completes immediately.
    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        Poll::Ready(Ok(()))
    }
}
223
#[cfg(feature = "tokio")]
impl<T: AsRef<SliceCell<u8>>> AsyncRead for Cursor<T> {
    /// Reads from the current position into `buf` and advances the position
    /// by however many bytes were added to the filled part of `buf`.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let filled_before = buf.filled().len();

        let mut remaining = self.remaining_slice();
        std::task::ready!(AsyncRead::poll_read(Pin::new(&mut remaining), cx, buf))?;

        // The growth of the filled region is the number of bytes consumed.
        let consumed = buf.filled().len() - filled_before;
        self.pos += consumed as u64;
        Poll::Ready(Ok(()))
    }
}
242
#[cfg(feature = "tokio")]
impl<T: AsRef<SliceCell<u8>>> AsyncWrite for Cursor<T> {
    /// Delegates to the synchronous `write` and is therefore always ready.
    fn poll_write(
        mut self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, io::Error>> {
        let outcome = self.write(buf);
        Poll::Ready(outcome)
    }

    /// Nothing is buffered, so flushing completes immediately.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        Poll::Ready(Ok(()))
    }

    /// There is nothing to shut down; completes immediately.
    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        Poll::Ready(Ok(()))
    }
}
261
#[cfg(feature = "tokio")]
impl<T: AsRef<SliceCell<u8>>> AsyncSeek for Cursor<T> {
    /// Performs the whole seek synchronously; any error is reported here.
    fn start_seek(mut self: Pin<&mut Self>, style: SeekFrom) -> io::Result<()> {
        self.seek(style).map(|_new_pos| ())
    }

    /// The seek already finished in `start_seek`, so this only reports the
    /// current position.
    fn poll_complete(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
        let current = self.pos;
        Poll::Ready(Ok(current))
    }
}
273
#[cfg(test)]
mod tests {
    use crate::{io::Cursor, SliceCell};
    use alloc::boxed::Box;
    use std::io::{Read, Seek, SeekFrom, Write};

    /// Size in bytes of every chunk written and read by these tests.
    const CHUNK: usize = 14;

    /// Writes one chunk and checks that it was written in full, so that a
    /// short write cannot go unnoticed.
    fn write_chunk<T: AsRef<SliceCell<u8>>>(writer: &mut Cursor<T>, chunk: &[u8; CHUNK]) {
        assert_eq!(writer.write(chunk).unwrap(), CHUNK);
    }

    /// Reads one chunk into a fresh buffer and checks both the byte count and
    /// the contents. The buffer is pre-filled with a marker value so stale
    /// data from an earlier read can never satisfy the comparison.
    fn expect_chunk<T: AsRef<SliceCell<u8>>>(reader: &mut Cursor<T>, expected: &[u8; CHUNK]) {
        let mut buf = [0xAAu8; CHUNK];
        assert_eq!(reader.read(&mut buf).unwrap(), CHUNK);
        assert_eq!(&buf, expected);
    }

    /// Drives a writer and a reader that share one zero-initialised buffer and
    /// checks that the reader observes the writer's data.
    fn interleaved<T: AsRef<SliceCell<u8>>>(mut writer: Cursor<T>, mut reader: Cursor<T>) {
        write_chunk(&mut writer, b"Hello, world!!");
        expect_chunk(&mut reader, b"Hello, world!!");

        // The reader is now ahead of the writer and sees untouched zeroes.
        expect_chunk(&mut reader, &[0u8; CHUNK]);

        write_chunk(&mut writer, b"Wonderful day!");
        write_chunk(&mut writer, b"wow, much cell");

        expect_chunk(&mut reader, b"wow, much cell");

        // Rewind and confirm that all three chunks landed back to back.
        assert_eq!(reader.seek(SeekFrom::Start(0)).unwrap(), 0);

        expect_chunk(&mut reader, b"Hello, world!!");
        expect_chunk(&mut reader, b"Wonderful day!");
        expect_chunk(&mut reader, b"wow, much cell");
    }

    #[test]
    fn concurrent() {
        let data: Box<SliceCell<u8>> =
            SliceCell::new_boxed(std::vec![0u8; 2048].into_boxed_slice());
        let writer: Cursor<&SliceCell<u8>> = Cursor::new(&*data);
        let reader: Cursor<&SliceCell<u8>> = Cursor::new(&*data);

        interleaved(writer, reader);
    }

    #[test]
    fn rc() {
        let data = SliceCell::try_new_rc(std::vec![0u8; 2048].into()).unwrap();
        let writer = Cursor::new(data.clone());
        let reader = Cursor::new(data.clone());
        // The cursors alone must keep the shared buffer alive.
        drop(data);

        interleaved(writer, reader);
    }
}