1use crate::SliceCell;
2use std::{
3 io::{self, Read, Seek, SeekFrom, Write},
4 vec::Vec,
5};
6#[cfg(feature = "tokio")]
7use std::{
8 pin::Pin,
9 task::{Context, Poll},
10};
11#[cfg(feature = "tokio")]
12use tokio::io::{AsyncRead, AsyncSeek, AsyncWrite, ReadBuf};
13
14impl Write for &SliceCell<u8> {
19 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
20 let write_len = std::cmp::min(self.len(), buf.len());
21 if write_len > 0 {
22 let dst;
23 (dst, *self) = self.split_at(write_len);
24 dst.copy_from_slice(buf);
25 }
26 Ok(write_len)
27 }
28
29 fn flush(&mut self) -> io::Result<()> {
30 Ok(())
31 }
32}
33
34impl Read for &SliceCell<u8> {
35 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
36 let read_len = std::cmp::min(self.len(), buf.len());
37 if read_len > 0 {
38 let src;
39 (src, *self) = self.split_at(read_len);
40 src.copy_into_slice(&mut buf[..read_len]);
41 }
42 Ok(read_len)
43 }
44 fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
45 let read_len = self.len();
46 if read_len == 0 {
47 return Ok(0);
48 }
49
50 buf.reserve(read_len);
51 let write_into = buf.spare_capacity_mut();
52 debug_assert!(write_into.len() >= read_len);
53
54 let src;
55 (src, *self) = self.split_at(read_len);
56 unsafe {
59 std::ptr::copy_nonoverlapping(
63 src.as_ptr().cast::<u8>(),
64 write_into.as_mut_ptr().cast(),
65 read_len,
66 );
67 buf.set_len(buf.len() + read_len);
69 }
70 Ok(read_len)
71 }
72}
73
/// Pairs a value with a `u64` position.
///
/// When `T: AsRef<SliceCell<u8>>`, the I/O trait implementations in this
/// file use the position as an offset into that slice.
pub struct Cursor<T> {
    /// The wrapped value.
    inner: T,
    /// Current position; not validated against `inner`, so it may lie past
    /// the end.
    pos: u64,
}

// Declared `Unpin` for every `T`, so a `Pin<&mut Cursor<T>>` can always be
// used as a plain `&mut Cursor<T>`.
impl<T> Unpin for Cursor<T> {}

impl<T> Cursor<T> {
    /// Creates a cursor over `inner` with the position set to `0`.
    pub const fn new(inner: T) -> Self {
        Cursor { pos: 0, inner }
    }

    /// Consumes the cursor and returns the wrapped value.
    pub fn into_inner(self) -> T {
        let Cursor { inner, .. } = self;
        inner
    }

    /// Returns the current position.
    pub fn position(&self) -> u64 {
        let Cursor { pos, .. } = self;
        *pos
    }

    /// Sets the position to `pos` without any bounds checking.
    pub fn set_position(&mut self, pos: u64) {
        self.pos = pos;
    }
}
99
100impl<T: AsRef<SliceCell<u8>>> Cursor<T> {
101 pub fn remaining_slice(&self) -> &SliceCell<u8> {
102 let len = self.pos.min(self.inner.as_ref().len() as u64);
103 &self.inner.as_ref()[(len as usize)..]
104 }
105}
106
107impl<T: AsRef<SliceCell<u8>>> Write for Cursor<T> {
108 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
109 let slice: &SliceCell<u8> = self.inner.as_ref();
110 let pos = std::cmp::min(self.pos, slice.len() as u64);
111 let amt = (&slice[(pos as usize)..]).write(buf)?;
112 self.pos += amt as u64;
113 Ok(amt)
114 }
115
116 fn flush(&mut self) -> io::Result<()> {
117 Ok(())
118 }
119}
120
121impl<T: AsRef<SliceCell<u8>>> Read for Cursor<T> {
122 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
123 let n = Read::read(&mut self.remaining_slice(), buf)?;
124 self.pos += n as u64;
125 Ok(n)
126 }
127 fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
128 let n = buf.len();
129 Read::read_exact(&mut self.remaining_slice(), buf)?;
130 self.pos += n as u64;
131 Ok(())
132 }
133}
134
135impl<T: AsRef<SliceCell<u8>>> Seek for Cursor<T> {
136 fn seek(&mut self, style: SeekFrom) -> io::Result<u64> {
137 let (base_pos, offset) = match style {
138 SeekFrom::Start(n) => {
139 self.pos = n;
140 return Ok(n);
141 }
142 SeekFrom::End(n) => (self.inner.as_ref().len() as u64, n),
143 SeekFrom::Current(n) => (self.pos, n),
144 };
145 match base_pos.checked_add_signed(offset) {
146 Some(n) => {
147 self.pos = n;
148 Ok(self.pos)
149 }
150 None => Err(io::Error::new(
151 io::ErrorKind::InvalidInput,
152 "invalid seek to a negative or overflowing position",
153 )),
154 }
155 }
156}
157
/// Async counterpart of the blocking reader: always completes immediately.
#[cfg(feature = "tokio")]
impl AsyncRead for &SliceCell<u8> {
    /// Copies up to `buf.remaining()` bytes from the front of the slice into
    /// `buf`, advances `self` past them, and returns `Poll::Ready(Ok(()))`.
    ///
    /// Never returns `Pending` or an error; an empty source or a full `buf`
    /// leaves both untouched.
    fn poll_read(
        mut self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let read_len = self.len().min(buf.remaining());
        if read_len == 0 {
            return Poll::Ready(Ok(()));
        }

        let (src, rest) = self.split_at(read_len);
        *self = rest;

        if cfg!(feature = "tokio_assumptions") {
            // SAFETY (NOTE(review): assumption to confirm): this views the
            // cell-backed bytes as a plain `&[u8]` for the duration of
            // `put_slice`. That presumably relies on `put_slice` doing a
            // plain copy during which nothing can write through the cell;
            // the `tokio_assumptions` feature appears to opt in to exactly
            // that assumption.
            buf.put_slice(unsafe { &*src.as_ptr() });
        } else {
            // SAFETY: only initialized bytes are written into the unfilled
            // region below, so no uninitialized data is exposed through it.
            let dest = unsafe { buf.unfilled_mut() };
            debug_assert!(
                read_len <= dest.len(),
                "unfilled.len() should be == buf.remaining()"
            );
            // SAFETY: `src` spans `read_len` bytes and `dest` has room for at
            // least `read_len` (checked above in debug builds, and
            // `read_len <= buf.remaining()` by construction). Assumes
            // `SliceCell::as_ptr` points at the slice's first byte and that
            // the source does not alias `buf`.
            unsafe {
                std::ptr::copy_nonoverlapping(
                    src.as_ptr() as *const u8,
                    dest.as_mut_ptr().cast(),
                    read_len,
                );
            }
            // SAFETY: the copy above initialized the first `read_len`
            // unfilled bytes.
            unsafe {
                buf.assume_init(read_len);
            }
            buf.advance(read_len);
        }

        Poll::Ready(Ok(()))
    }
}
204
/// Async counterpart of the blocking writer: every operation completes
/// immediately.
#[cfg(feature = "tokio")]
impl AsyncWrite for &SliceCell<u8> {
    /// Delegates to the blocking `Write::write` implementation for
    /// `&SliceCell<u8>` and wraps its result in `Poll::Ready`.
    fn poll_write(
        mut self: Pin<&mut Self>,
        _: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, io::Error>> {
        let outcome = <Self as Write>::write(&mut *self, buf);
        Poll::Ready(outcome)
    }

    /// No-op: always ready with `Ok(())`.
    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        Poll::Ready(Ok(()))
    }

    /// No-op: always ready with `Ok(())`.
    fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        Poll::Ready(Ok(()))
    }
}
223
/// Async read from the underlying slice at the cursor position.
#[cfg(feature = "tokio")]
impl<T: AsRef<SliceCell<u8>>> AsyncRead for Cursor<T> {
    /// Polls the `AsyncRead` implementation of the remaining slice and
    /// advances the position by however many bytes were added to `buf`.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // The inner reader only reports progress through `buf`, so measure
        // the filled length before and after to learn how far to advance.
        let filled_before = buf.filled().len();

        let mut remaining = self.remaining_slice();
        let outcome = AsyncRead::poll_read(Pin::new(&mut remaining), cx, buf);
        std::task::ready!(outcome)?;

        let advanced = buf.filled().len() - filled_before;
        self.pos += advanced as u64;
        Poll::Ready(Ok(()))
    }
}
242
/// Async write into the underlying slice at the cursor position; every
/// operation completes immediately.
#[cfg(feature = "tokio")]
impl<T: AsRef<SliceCell<u8>>> AsyncWrite for Cursor<T> {
    /// Delegates to the blocking `Write::write` implementation for
    /// `Cursor<T>` and wraps its result in `Poll::Ready`.
    fn poll_write(
        mut self: Pin<&mut Self>,
        _: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, io::Error>> {
        let outcome = <Self as Write>::write(&mut *self, buf);
        Poll::Ready(outcome)
    }

    /// No-op: always ready with `Ok(())`.
    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        Poll::Ready(Ok(()))
    }

    /// No-op: always ready with `Ok(())`.
    fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        Poll::Ready(Ok(()))
    }
}
261
/// Async seek: the seek happens synchronously in `start_seek`, and
/// `poll_complete` just reports the resulting position.
#[cfg(feature = "tokio")]
impl<T: AsRef<SliceCell<u8>>> AsyncSeek for Cursor<T> {
    /// Performs the seek immediately via the blocking `Seek` implementation,
    /// discarding the returned position and propagating any error.
    fn start_seek(mut self: Pin<&mut Self>, style: SeekFrom) -> io::Result<()> {
        <Self as Seek>::seek(&mut *self, style).map(|_| ())
    }

    /// Always ready with the cursor's current position.
    fn poll_complete(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<u64>> {
        let position = self.pos;
        Poll::Ready(Ok(position))
    }
}
273
#[cfg(test)]
mod tests {
    use crate::{io::Cursor, SliceCell};
    use alloc::boxed::Box;
    use std::io::{Read, Seek, SeekFrom, Write};

    /// Drives a writer cursor and a reader cursor that share one buffer
    /// through interleaved writes, reads and a rewind, checking that the
    /// reader observes exactly what the writer stored.
    fn interleaved_write_and_read<T: AsRef<SliceCell<u8>>>(
        mut writer: Cursor<T>,
        mut reader: Cursor<T>,
    ) {
        let mut buf = [0u8; 14];

        writer.write(b"Hello, world!!").unwrap();

        // The reader sees the first chunk that was just written ...
        reader.read(&mut buf).unwrap();
        assert_eq!(buf, *b"Hello, world!!");

        // ... and zeroes where the writer has not been yet.
        reader.read(&mut buf).unwrap();
        assert_eq!(buf, [0u8; 14]);

        writer.write(b"Wonderful day!").unwrap();
        writer.write(b"wow, much cell").unwrap();

        // The reader already skipped the second chunk, so it lands on the
        // third one.
        reader.read(&mut buf).unwrap();
        assert_eq!(buf, *b"wow, much cell");

        // After rewinding, all three chunks are visible in order.
        reader.seek(SeekFrom::Start(0)).unwrap();
        for expected in [b"Hello, world!!", b"Wonderful day!", b"wow, much cell"] {
            reader.read(&mut buf).unwrap();
            assert_eq!(buf, *expected);
        }
    }

    /// Two cursors borrowing the same boxed slice.
    #[test]
    fn concurrent() {
        let data: Box<SliceCell<u8>> =
            SliceCell::new_boxed(std::vec![0u8; 2048].into_boxed_slice());
        let writer: Cursor<&SliceCell<u8>> = Cursor::new(&*data);
        let reader: Cursor<&SliceCell<u8>> = Cursor::new(&*data);

        interleaved_write_and_read(writer, reader);
    }

    /// Two cursors each holding their own clone of a reference-counted slice.
    #[test]
    fn rc() {
        let data = SliceCell::try_new_rc(std::vec![0u8; 2048].into()).unwrap();
        let writer = Cursor::new(data.clone());
        let reader = Cursor::new(data.clone());
        // The cursors keep the buffer alive on their own.
        drop(data);

        interleaved_write_and_read(writer, reader);
    }
}