1use crate::SliceCell;
2use std::{
3 io::{self, Read, Seek, SeekFrom, Write},
4 vec::Vec,
5};
6#[cfg(feature = "tokio")]
7use std::{
8 pin::Pin,
9 task::{Context, Poll},
10};
11#[cfg(feature = "tokio")]
12use tokio::io::{AsyncRead, AsyncSeek, AsyncWrite, ReadBuf};
13
14impl Write for &SliceCell<u8> {
19 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
20 let write_len = std::cmp::min(self.len(), buf.len());
21 if write_len > 0 {
22 let dst;
23 (dst, *self) = self.split_at(write_len);
24 dst.copy_from_slice(buf);
25 }
26 Ok(write_len)
27 }
28
29 fn flush(&mut self) -> io::Result<()> {
30 Ok(())
31 }
32}
33
34impl Read for &SliceCell<u8> {
35 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
36 let read_len = std::cmp::min(self.len(), buf.len());
37 if read_len > 0 {
38 let src;
39 (src, *self) = self.split_at(read_len);
40 src.copy_into_slice(&mut buf[..read_len]);
41 }
42 Ok(read_len)
43 }
44 fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
45 let read_len = self.len();
46 if read_len == 0 {
47 return Ok(0);
48 }
49
50 buf.reserve(read_len);
51 let write_into = buf.spare_capacity_mut();
52 debug_assert!(write_into.len() >= read_len);
53
54 let src;
55 (src, *self) = self.split_at(read_len);
56 unsafe {
59 std::ptr::copy_nonoverlapping(
63 src.as_ptr().cast::<u8>(),
64 write_into.as_mut_ptr().cast(),
65 read_len,
66 );
67 buf.set_len(buf.len() + read_len);
69 }
70 Ok(read_len)
71 }
72}
73
/// A seekable reader/writer over a buffer of [`SliceCell<u8>`] cells, analogous to
/// [`std::io::Cursor`].
///
/// `T` is a handle to the buffer; the `Read`, `Write` and `Seek` impls require
/// `T: AsRef<SliceCell<u8>>` (for example `&SliceCell<u8>`). Because data is written through
/// shared cells, several cursors may refer to the same buffer at once, each tracking its own
/// position independently.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Cursor<T> {
    /// Handle to the underlying cell buffer.
    inner: T,
    /// Current byte offset. This is not clamped and may lie past the end of the buffer.
    pos: u64,
}
78
79impl<T> Unpin for Cursor<T> {}
81
82impl<T> Cursor<T> {
83 pub const fn new(inner: T) -> Self {
84 Self { inner, pos: 0 }
85 }
86
87 pub fn into_inner(self) -> T {
88 self.inner
89 }
90
91 pub fn position(&self) -> u64 {
92 self.pos
93 }
94
95 pub fn set_position(&mut self, pos: u64) {
96 self.pos = pos;
97 }
98}
99
100impl<T: AsRef<SliceCell<u8>>> Cursor<T> {
101 pub fn remaining_slice(&self) -> &SliceCell<u8> {
102 let len = self.pos.min(self.inner.as_ref().len() as u64);
103 &self.inner.as_ref()[(len as usize)..]
104 }
105}
106
107impl<T: AsRef<SliceCell<u8>>> Write for Cursor<T> {
108 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
109 let slice: &SliceCell<u8> = self.inner.as_ref();
110 let pos = std::cmp::min(self.pos, slice.len() as u64);
111 let amt = (&slice[(pos as usize)..]).write(buf)?;
112 self.pos += amt as u64;
113 Ok(amt)
114 }
115
116 fn flush(&mut self) -> io::Result<()> {
117 Ok(())
118 }
119}
120
121impl<T: AsRef<SliceCell<u8>>> Read for Cursor<T> {
122 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
123 let n = Read::read(&mut self.remaining_slice(), buf)?;
124 self.pos += n as u64;
125 Ok(n)
126 }
127 fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
128 let n = buf.len();
129 Read::read_exact(&mut self.remaining_slice(), buf)?;
130 self.pos += n as u64;
131 Ok(())
132 }
133}
134
135impl<T: AsRef<SliceCell<u8>>> Seek for Cursor<T> {
136 fn seek(&mut self, style: SeekFrom) -> io::Result<u64> {
137 let (base_pos, offset) = match style {
138 SeekFrom::Start(n) => {
139 self.pos = n;
140 return Ok(n);
141 }
142 SeekFrom::End(n) => (self.inner.as_ref().len() as u64, n),
143 SeekFrom::Current(n) => (self.pos, n),
144 };
145 match base_pos.checked_add_signed(offset) {
146 Some(n) => {
147 self.pos = n;
148 Ok(self.pos)
149 }
150 None => Err(io::Error::new(
151 io::ErrorKind::InvalidInput,
152 "invalid seek to a negative or overflowing position",
153 )),
154 }
155 }
156}
157
#[cfg_attr(feature = "nightly_docs", doc(cfg(feature = "tokio")))]
#[cfg(feature = "tokio")]
impl AsyncRead for &SliceCell<u8> {
    /// Moves as many bytes as `buf` has room for out of the front of the cell slice and
    /// advances the reference past them.
    ///
    /// Always returns `Poll::Ready`. An exhausted slice adds nothing to `buf`.
    fn poll_read(
        mut self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let count = self.len().min(buf.remaining());
        if count == 0 {
            return Poll::Ready(Ok(()));
        }

        let (head, tail) = self.split_at(count);
        *self = tail;

        if cfg!(feature = "tokio_assumptions") {
            // SAFETY: `head` covers `count` bytes. The shared `&[u8]` view lives only for the
            // `put_slice` call, which copies out of it.
            // NOTE(review): this relies on nothing writing to those cells during the call;
            // confirm that is the exact assumption the `tokio_assumptions` feature encodes.
            let bytes: &[u8] = unsafe { &*head.as_ptr() };
            buf.put_slice(bytes);
        } else {
            // SAFETY: the code below only writes initialized bytes into `unfilled` and never
            // de-initializes any part of the buffer.
            let unfilled = unsafe { buf.unfilled_mut() };
            debug_assert!(
                count <= unfilled.len(),
                "unfilled.len() should be == buf.remaining()"
            );
            // SAFETY: `head` holds `count` readable bytes and `unfilled` has at least `count`
            // writable slots (`count <= buf.remaining()`). `unfilled` is borrowed mutably from
            // `buf`, so it cannot alias the cells behind the shared `head`.
            unsafe {
                std::ptr::copy_nonoverlapping(
                    head.as_ptr() as *const u8,
                    unfilled.as_mut_ptr().cast::<u8>(),
                    count,
                );
            }
            // SAFETY: the first `count` unfilled bytes were initialized by the copy above.
            unsafe { buf.assume_init(count) };
            buf.advance(count);
        }
        Poll::Ready(Ok(()))
    }
}
205
#[cfg_attr(feature = "nightly_docs", doc(cfg(feature = "tokio")))]
#[cfg(feature = "tokio")]
impl AsyncWrite for &SliceCell<u8> {
    /// Completes immediately by forwarding to the synchronous [`Write`] impl for
    /// `&SliceCell<u8>`.
    fn poll_write(
        self: Pin<&mut Self>,
        _: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, io::Error>> {
        let cells: &mut Self = self.get_mut();
        Poll::Ready(Write::write(cells, buf))
    }

    /// Nothing is buffered, so flushing completes immediately.
    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Shutdown has nothing to release and completes immediately.
    fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        Poll::Ready(Ok(()))
    }
}
225
#[cfg_attr(feature = "nightly_docs", doc(cfg(feature = "tokio")))]
#[cfg(feature = "tokio")]
impl<T: AsRef<SliceCell<u8>>> AsyncRead for Cursor<T> {
    /// Reads from the current position into `buf` and advances the position by however many
    /// bytes were added to `buf`'s filled region.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let cursor = self.get_mut();
        let filled_before = buf.filled().len();
        let mut rest = cursor.remaining_slice();
        std::task::ready!(AsyncRead::poll_read(Pin::new(&mut rest), cx, buf))?;
        let copied = buf.filled().len() - filled_before;
        cursor.pos += copied as u64;
        Poll::Ready(Ok(()))
    }
}
245
#[cfg_attr(feature = "nightly_docs", doc(cfg(feature = "tokio")))]
#[cfg(feature = "tokio")]
impl<T: AsRef<SliceCell<u8>>> AsyncWrite for Cursor<T> {
    /// Completes immediately by forwarding to the synchronous [`Write`] impl, which writes
    /// at the current position and advances it.
    fn poll_write(
        self: Pin<&mut Self>,
        _: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, io::Error>> {
        let cursor: &mut Self = self.get_mut();
        Poll::Ready(Write::write(cursor, buf))
    }

    /// Nothing is buffered, so flushing completes immediately.
    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Shutdown has nothing to release and completes immediately.
    fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        Poll::Ready(Ok(()))
    }
}
265
#[cfg_attr(feature = "nightly_docs", doc(cfg(feature = "tokio")))]
#[cfg(feature = "tokio")]
impl<T: AsRef<SliceCell<u8>>> AsyncSeek for Cursor<T> {
    /// Performs the entire seek synchronously via [`Seek::seek`]. Errors are reported here,
    /// and the resulting position is reported by `poll_complete`.
    fn start_seek(self: Pin<&mut Self>, style: SeekFrom) -> io::Result<()> {
        Seek::seek(self.get_mut(), style).map(|_new_pos| ())
    }

    /// Always ready: reports the cursor's current position.
    fn poll_complete(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<u64>> {
        let pos = self.pos;
        Poll::Ready(Ok(pos))
    }
}
278
#[cfg(test)]
mod tests {
    use crate::{io::Cursor, SliceCell};
    use alloc::boxed::Box;
    use std::io::{Read, Seek, SeekFrom, Write};

    /// Runs the shared write/read/seek scenario against two cursors that alias one buffer.
    ///
    /// Uses `write_all`/`read_exact` rather than `write`/`read`, so a short transfer fails the
    /// test instead of being silently ignored (clippy `unused_io_amount`).
    fn check_aliased_cursors<T: AsRef<SliceCell<u8>>>(
        mut writer: Cursor<T>,
        mut reader: Cursor<T>,
    ) {
        let mut buf = [0u8; 14];

        writer.write_all(b"Hello, world!!").unwrap();

        reader.read_exact(&mut buf).unwrap();
        assert_eq!(buf, *b"Hello, world!!");

        // The reader is now past everything written so far, so it sees untouched zeros.
        reader.read_exact(&mut buf).unwrap();
        assert_eq!(buf, [0u8; 14]);

        writer.write_all(b"Wonderful day!").unwrap();
        writer.write_all(b"wow, much cell").unwrap();

        // The reader sits at offset 28, where the second of these two writes landed.
        reader.read_exact(&mut buf).unwrap();
        assert_eq!(buf, *b"wow, much cell");

        reader.seek(SeekFrom::Start(0)).unwrap();

        reader.read_exact(&mut buf).unwrap();
        assert_eq!(buf, *b"Hello, world!!");
        reader.read_exact(&mut buf).unwrap();
        assert_eq!(buf, *b"Wonderful day!");
        reader.read_exact(&mut buf).unwrap();
        assert_eq!(buf, *b"wow, much cell");
    }

    #[test]
    fn concurrent() {
        let data: Box<SliceCell<u8>> =
            SliceCell::new_boxed(std::vec![0u8; 2048].into_boxed_slice());
        check_aliased_cursors(Cursor::new(&*data), Cursor::new(&*data));
    }

    #[test]
    fn rc() {
        let data = SliceCell::try_new_rc(std::vec![0u8; 2048].into()).unwrap();
        let writer = Cursor::new(data.clone());
        let reader = Cursor::new(data.clone());
        // From here on, only the cursors' clones keep the buffer alive.
        drop(data);
        check_aliased_cursors(writer, reader);
    }
}