bitar/archive_reader/
io_reader.rs

1use async_trait::async_trait;
2use bytes::{Bytes, BytesMut};
3use core::pin::Pin;
4use core::task::{Context, Poll};
5use futures_util::{ready, stream::Stream};
6use std::io;
7use tokio::io::{AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt, ReadBuf};
8
9use crate::archive_reader::ArchiveReader;
10use crate::ChunkOffset;
11
/// Wrapper which implements `ArchiveReader` for any type which implements
/// tokio `AsyncRead` and `AsyncSeek`.
///
/// The wrapped value is stored as the single tuple field and is used as the
/// source for both single reads (`read_at`) and chunk streams (`read_chunks`).
#[derive(Debug)]
pub struct IoReader<T>(T);
15
16impl<T> IoReader<T> {
17    pub fn new(inner: T) -> Self {
18        Self(inner)
19    }
20}
21
22impl<T> From<T> for IoReader<T> {
23    fn from(inner: T) -> Self {
24        Self(inner)
25    }
26}
27
28#[async_trait]
29impl<T> ArchiveReader for IoReader<T>
30where
31    T: AsyncRead + AsyncSeek + Unpin + Send,
32{
33    type Error = io::Error;
34
35    async fn read_at(&mut self, offset: u64, size: usize) -> Result<Bytes, io::Error> {
36        self.0.seek(io::SeekFrom::Start(offset)).await?;
37        let mut buf = BytesMut::with_capacity(size);
38        while buf.len() < size {
39            if self.0.read_buf(&mut buf).await? == 0 {
40                return Err(io::ErrorKind::UnexpectedEof.into());
41            }
42        }
43        Ok(buf.freeze())
44    }
45
46    fn read_chunks<'a>(
47        &'a mut self,
48        chunks: Vec<ChunkOffset>,
49    ) -> Pin<Box<dyn Stream<Item = Result<Bytes, io::Error>> + Send + 'a>> {
50        Box::pin(IoChunkReader::new(&mut self.0, chunks))
51    }
52}
53
/// Step of the per-chunk seek/read cycle driven by `IoChunkReader::poll_chunk`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum IoChunkReaderState {
    /// A seek to the current chunk's offset still has to be started.
    Seek,
    /// A seek has been started and must be polled until it completes.
    PollSeek,
    /// The seek has completed; chunk bytes are being read into the buffer.
    Read,
}
59
60struct IoChunkReader<'a, R>
61where
62    R: AsyncRead + AsyncSeek + Unpin + Send + ?Sized,
63{
64    state: IoChunkReaderState,
65    chunks: Vec<ChunkOffset>,
66    chunk_index: usize,
67    buf: BytesMut,
68    buf_offset: usize,
69    reader: &'a mut R,
70}
71
72impl<'a, R> IoChunkReader<'a, R>
73where
74    R: AsyncRead + AsyncSeekExt + Unpin + Send + ?Sized,
75{
76    fn new(reader: &'a mut R, chunks: Vec<ChunkOffset>) -> Self {
77        let first = chunks
78            .first()
79            .cloned()
80            .unwrap_or(ChunkOffset { offset: 0, size: 0 });
81        Self {
82            reader,
83            state: IoChunkReaderState::Seek,
84            chunk_index: 0,
85            buf: BytesMut::with_capacity(first.size),
86            chunks,
87            buf_offset: 0,
88        }
89    }
90
91    fn poll_chunk(&mut self, cx: &mut Context) -> Poll<Option<Result<Bytes, io::Error>>>
92    where
93        R: AsyncSeek + AsyncRead + Send + Unpin,
94        Self: Unpin + Send,
95    {
96        while self.chunk_index < self.chunks.len() {
97            let read_at = &self.chunks[self.chunk_index];
98            if self.buf_offset >= read_at.size {
99                self.buf_offset = 0;
100                self.chunk_index += 1;
101                self.state = IoChunkReaderState::Seek;
102                let chunk = self.buf.clone();
103                return Poll::Ready(Some(Ok(chunk.freeze())));
104            }
105            match self.state {
106                IoChunkReaderState::Seek => {
107                    match Pin::new(&mut self.reader).start_seek(io::SeekFrom::Start(read_at.offset))
108                    {
109                        Ok(()) => self.state = IoChunkReaderState::PollSeek,
110                        Err(err) => return Poll::Ready(Some(Err(err))),
111                    }
112                }
113                IoChunkReaderState::PollSeek => {
114                    match ready!(Pin::new(&mut self.reader).poll_complete(cx)) {
115                        Ok(_rc) => self.state = IoChunkReaderState::Read,
116                        Err(err) => return Poll::Ready(Some(Err(err))),
117                    }
118                }
119                IoChunkReaderState::Read => {
120                    if self.buf.len() != read_at.size {
121                        self.buf.resize(read_at.size, 0);
122                    }
123                    let mut buf = ReadBuf::new(&mut self.buf[self.buf_offset..]);
124                    match ready!(Pin::new(&mut self.reader).poll_read(cx, &mut buf)) {
125                        Ok(()) if buf.filled().is_empty() => {
126                            return Poll::Ready(Some(Err(io::Error::new(
127                                io::ErrorKind::UnexpectedEof,
128                                "archive ended unexpectedly",
129                            ))));
130                        }
131                        Ok(()) => self.buf_offset += buf.filled().len(),
132                        Err(err) => return Poll::Ready(Some(Err(err))),
133                    }
134                }
135            }
136        }
137        Poll::Ready(None)
138    }
139}
140
141impl<R> Stream for IoChunkReader<'_, R>
142where
143    R: AsyncRead + AsyncSeek + Unpin + Send,
144{
145    type Item = Result<Bytes, io::Error>;
146
147    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
148        self.poll_chunk(cx)
149    }
150
151    fn size_hint(&self) -> (usize, Option<usize>) {
152        let chunks_left = self.chunks.len() - self.chunk_index;
153        (chunks_left, Some(chunks_left))
154    }
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::StreamExt;
    use std::io::Write;
    use tempfile::NamedTempFile;
    use tokio::fs::File;

    /// Build `len` bytes where each byte is its own index truncated to `u8`.
    fn patterned(len: usize) -> Vec<u8> {
        (0..len).map(|v| v as u8).collect()
    }

    /// Write `content` to a fresh temporary file and open it via `IoReader`.
    ///
    /// The temporary file is returned as well so that the caller keeps it
    /// alive for as long as the reader is in use.
    async fn reader_over(content: &[u8]) -> (NamedTempFile, IoReader<File>) {
        let mut file = NamedTempFile::new().unwrap();
        file.write_all(content).unwrap();
        let reader = IoReader(File::open(file.path()).await.unwrap());
        (file, reader)
    }

    #[tokio::test]
    async fn local_read_single_small() {
        let expected: Vec<u8> = b"hello file".to_vec();
        let (_file, mut reader) = reader_over(&expected).await;
        let read_back = reader.read_at(0, expected.len()).await.unwrap();
        assert_eq!(read_back, expected);
    }

    #[tokio::test]
    async fn local_read_single_big() {
        let expected = patterned(10 * 1024 * 1024);
        let (_file, mut reader) = reader_over(&expected).await;
        let read_back = reader.read_at(0, expected.len()).await.unwrap();
        assert_eq!(read_back, expected);
    }

    #[tokio::test]
    async fn local_read_chunks() {
        let expected = patterned(10 * 1024 * 1024);
        let chunks = vec![
            ChunkOffset::new(0, 10),
            ChunkOffset::new(10, 20),
            ChunkOffset::new(30, 30),
            ChunkOffset::new(60, 100),
            ChunkOffset::new(160, 200),
            ChunkOffset::new(360, 400),
            ChunkOffset::new(760, 8 * 1024 * 1024),
        ];
        let (_file, mut reader) = reader_over(&expected).await;
        let mut stream = reader.read_chunks(chunks.clone());

        // The chunks are contiguous from offset zero, so a running position
        // over `expected` tracks where every chunk has to start.
        let mut start = 0;
        for location in &chunks {
            let chunk = stream
                .next()
                .await
                .expect("stream ended before all chunks were read")
                .unwrap();
            let end = start + location.size;
            assert_eq!(chunk, &expected[start..end]);
            start = end;
        }
        // Exactly one item per requested chunk.
        assert!(stream.next().await.is_none());
    }
}