bitar/archive_reader/
io_reader.rs1use async_trait::async_trait;
2use bytes::{Bytes, BytesMut};
3use core::pin::Pin;
4use core::task::{Context, Poll};
5use futures_util::{ready, stream::Stream};
6use std::io;
7use tokio::io::{AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt, ReadBuf};
8
9use crate::archive_reader::ArchiveReader;
10use crate::ChunkOffset;
11
/// Adapter that exposes an `AsyncRead + AsyncSeek` value (for example a
/// `tokio::fs::File`) through the `ArchiveReader` trait.
pub struct IoReader<T>(T);

impl<T> IoReader<T> {
    /// Wraps `inner` so archive data can be read from it.
    pub fn new(inner: T) -> Self {
        IoReader(inner)
    }
}

impl<T> From<T> for IoReader<T> {
    /// Equivalent to [`IoReader::new`].
    fn from(inner: T) -> Self {
        Self::new(inner)
    }
}
27
28#[async_trait]
29impl<T> ArchiveReader for IoReader<T>
30where
31 T: AsyncRead + AsyncSeek + Unpin + Send,
32{
33 type Error = io::Error;
34
35 async fn read_at(&mut self, offset: u64, size: usize) -> Result<Bytes, io::Error> {
36 self.0.seek(io::SeekFrom::Start(offset)).await?;
37 let mut buf = BytesMut::with_capacity(size);
38 while buf.len() < size {
39 if self.0.read_buf(&mut buf).await? == 0 {
40 return Err(io::ErrorKind::UnexpectedEof.into());
41 }
42 }
43 Ok(buf.freeze())
44 }
45
46 fn read_chunks<'a>(
47 &'a mut self,
48 chunks: Vec<ChunkOffset>,
49 ) -> Pin<Box<dyn Stream<Item = Result<Bytes, io::Error>> + Send + 'a>> {
50 Box::pin(IoChunkReader::new(&mut self.0, chunks))
51 }
52}
53
/// Position of an `IoChunkReader` within the seek-then-read cycle that it
/// performs for every chunk.
enum IoChunkReaderState {
    /// A seek to the current chunk's offset has not been started yet.
    Seek,
    /// A seek has been started and must be polled until it completes.
    PollSeek,
    /// The reader is positioned at the chunk; bytes are being read.
    Read,
}
59
60struct IoChunkReader<'a, R>
61where
62 R: AsyncRead + AsyncSeek + Unpin + Send + ?Sized,
63{
64 state: IoChunkReaderState,
65 chunks: Vec<ChunkOffset>,
66 chunk_index: usize,
67 buf: BytesMut,
68 buf_offset: usize,
69 reader: &'a mut R,
70}
71
72impl<'a, R> IoChunkReader<'a, R>
73where
74 R: AsyncRead + AsyncSeekExt + Unpin + Send + ?Sized,
75{
76 fn new(reader: &'a mut R, chunks: Vec<ChunkOffset>) -> Self {
77 let first = chunks
78 .first()
79 .cloned()
80 .unwrap_or(ChunkOffset { offset: 0, size: 0 });
81 Self {
82 reader,
83 state: IoChunkReaderState::Seek,
84 chunk_index: 0,
85 buf: BytesMut::with_capacity(first.size),
86 chunks,
87 buf_offset: 0,
88 }
89 }
90
91 fn poll_chunk(&mut self, cx: &mut Context) -> Poll<Option<Result<Bytes, io::Error>>>
92 where
93 R: AsyncSeek + AsyncRead + Send + Unpin,
94 Self: Unpin + Send,
95 {
96 while self.chunk_index < self.chunks.len() {
97 let read_at = &self.chunks[self.chunk_index];
98 if self.buf_offset >= read_at.size {
99 self.buf_offset = 0;
100 self.chunk_index += 1;
101 self.state = IoChunkReaderState::Seek;
102 let chunk = self.buf.clone();
103 return Poll::Ready(Some(Ok(chunk.freeze())));
104 }
105 match self.state {
106 IoChunkReaderState::Seek => {
107 match Pin::new(&mut self.reader).start_seek(io::SeekFrom::Start(read_at.offset))
108 {
109 Ok(()) => self.state = IoChunkReaderState::PollSeek,
110 Err(err) => return Poll::Ready(Some(Err(err))),
111 }
112 }
113 IoChunkReaderState::PollSeek => {
114 match ready!(Pin::new(&mut self.reader).poll_complete(cx)) {
115 Ok(_rc) => self.state = IoChunkReaderState::Read,
116 Err(err) => return Poll::Ready(Some(Err(err))),
117 }
118 }
119 IoChunkReaderState::Read => {
120 if self.buf.len() != read_at.size {
121 self.buf.resize(read_at.size, 0);
122 }
123 let mut buf = ReadBuf::new(&mut self.buf[self.buf_offset..]);
124 match ready!(Pin::new(&mut self.reader).poll_read(cx, &mut buf)) {
125 Ok(()) if buf.filled().is_empty() => {
126 return Poll::Ready(Some(Err(io::Error::new(
127 io::ErrorKind::UnexpectedEof,
128 "archive ended unexpectedly",
129 ))));
130 }
131 Ok(()) => self.buf_offset += buf.filled().len(),
132 Err(err) => return Poll::Ready(Some(Err(err))),
133 }
134 }
135 }
136 }
137 Poll::Ready(None)
138 }
139}
140
141impl<R> Stream for IoChunkReader<'_, R>
142where
143 R: AsyncRead + AsyncSeek + Unpin + Send,
144{
145 type Item = Result<Bytes, io::Error>;
146
147 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
148 self.poll_chunk(cx)
149 }
150
151 fn size_hint(&self) -> (usize, Option<usize>) {
152 let chunks_left = self.chunks.len() - self.chunk_index;
153 (chunks_left, Some(chunks_left))
154 }
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::StreamExt;
    use std::io::Write;
    use tempfile::NamedTempFile;
    use tokio::fs::File;

    /// Deterministic payload of `len` bytes: 0, 1, ..., 255, 0, 1, ...
    fn pattern(len: usize) -> Vec<u8> {
        (0..len).map(|v| v as u8).collect()
    }

    /// Writes `data` to a new temporary file. The file is deleted when the
    /// returned handle is dropped, so keep it alive for the whole test.
    fn temp_file_with(data: &[u8]) -> NamedTempFile {
        let mut file = NamedTempFile::new().unwrap();
        file.write_all(data).unwrap();
        file
    }

    /// Opens `file` for async reading and wraps it in an `IoReader`.
    async fn open_reader(file: &NamedTempFile) -> IoReader<File> {
        IoReader::new(File::open(file.path()).await.unwrap())
    }

    /// Streams `chunks` out of a file containing `data` and checks that every
    /// chunk is yielded, in order, holding exactly the bytes at its offset.
    async fn assert_reads_chunks(data: &[u8], chunks: Vec<ChunkOffset>) {
        let file = temp_file_with(data);
        let mut reader = open_reader(&file).await;
        let read_back: Vec<Bytes> = reader
            .read_chunks(chunks.clone())
            .map(|chunk| chunk.unwrap())
            .collect()
            .await;
        assert_eq!(read_back.len(), chunks.len());
        for (chunk, location) in read_back.iter().zip(&chunks) {
            let start = location.offset as usize;
            assert_eq!(chunk, &data[start..start + location.size]);
        }
    }

    #[tokio::test]
    async fn local_read_single_small() {
        let expected = b"hello file".to_vec();
        let file = temp_file_with(&expected);
        let mut reader = open_reader(&file).await;
        let read_back = reader.read_at(0, expected.len()).await.unwrap();
        assert_eq!(read_back, expected);
    }

    #[tokio::test]
    async fn local_read_single_big() {
        let expected = pattern(10 * 1024 * 1024);
        let file = temp_file_with(&expected);
        let mut reader = open_reader(&file).await;
        let read_back = reader.read_at(0, expected.len()).await.unwrap();
        assert_eq!(read_back, expected);
    }

    #[tokio::test]
    async fn local_read_at_offset() {
        let data = pattern(1024);
        let file = temp_file_with(&data);
        let mut reader = open_reader(&file).await;
        let read_back = reader.read_at(100, 50).await.unwrap();
        assert_eq!(read_back, &data[100..150]);
    }

    #[tokio::test]
    async fn local_read_past_end_is_unexpected_eof() {
        let file = temp_file_with(b"short");
        let mut reader = open_reader(&file).await;
        let err = reader.read_at(2, 10).await.unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
    }

    #[tokio::test]
    async fn local_read_chunks() {
        let data = pattern(10 * 1024 * 1024);
        let chunks = vec![
            ChunkOffset::new(0, 10),
            ChunkOffset::new(10, 20),
            ChunkOffset::new(30, 30),
            ChunkOffset::new(60, 100),
            ChunkOffset::new(160, 200),
            ChunkOffset::new(360, 400),
            ChunkOffset::new(760, 8 * 1024 * 1024),
        ];
        assert_reads_chunks(&data, chunks).await;
    }

    #[tokio::test]
    async fn local_read_chunks_out_of_order() {
        let data = pattern(4096);
        let chunks = vec![
            ChunkOffset::new(3000, 96),
            ChunkOffset::new(0, 16),
            ChunkOffset::new(1024, 512),
            ChunkOffset::new(0, 16),
        ];
        assert_reads_chunks(&data, chunks).await;
    }

    #[tokio::test]
    async fn local_read_no_chunks() {
        assert_reads_chunks(b"unused", Vec::new()).await;
    }
}