#![allow(unsafe_code)]
use std::fs::File;
use std::future::Future;
use std::io;
use std::mem::MaybeUninit;
use std::os::fd::AsRawFd;
use std::pin::Pin;
use std::sync::Arc;
use std::task::Poll;
use tokio::io::AsyncRead;
use tokio::task;
use crate::{BlobRange, ReadStream};
/// Builds a [`ReadStream`] that yields the bytes of `file` covered by `range`.
///
/// When `range` carries no bounds the stream starts at offset 0 and has no
/// upper bound other than the file itself. The read buffer is sized to the
/// length of the range, capped at 2 MiB.
///
/// # Panics
///
/// In builds with overflow checks enabled, panics if the end of the range is
/// smaller than its start.
pub(crate) fn file_reader(file: impl Into<Arc<File>>, range: BlobRange) -> ReadStream<'static> {
    // Upper bound on the buffer handed to each positional read (2 MiB).
    const MAX_BUF_LEN: u64 = 1 << 21;

    let (start, end) = range.unwrap_or((0, u64::MAX));
    let span = end - start;
    let capacity = span.min(MAX_BUF_LEN) as usize;

    let reader = FileReader {
        file: file.into(),
        offset: start,
        end,
        state: FileReaderState::Idle(new_uninit(capacity)),
    };
    Box::pin(reader)
}
/// Heap buffer of bytes that may not have been written yet.
type UnsafeBuf = Box<[MaybeUninit<u8>]>;

/// Allocates a buffer of exactly `n` bytes without initializing its contents.
///
/// Going through `Vec` keeps the zero-length case sound: the global allocator
/// must never be asked for a zero-sized layout, and `Vec` simply skips the
/// allocation when `n == 0`.
///
/// # Panics
///
/// Panics if `n` exceeds the maximum allocation size; aborts through the
/// allocation error handler if the allocator fails.
fn new_uninit(n: usize) -> UnsafeBuf {
    let mut buf = Vec::<MaybeUninit<u8>>::with_capacity(n);
    // SAFETY: `with_capacity(n)` guarantees room for at least `n` elements,
    // and `MaybeUninit<u8>` is valid in any state, so no element needs to be
    // written before the length is raised.
    unsafe { buf.set_len(n) };
    buf.into_boxed_slice()
}
/// Views a slice of possibly-uninitialized bytes as ordinary bytes.
///
/// # Safety
///
/// Every element of `slice` must have been initialized before the call.
unsafe fn slice_assume_init_ref(slice: &[MaybeUninit<u8>]) -> &[u8] {
    let uninit: *const [MaybeUninit<u8>] = slice;
    let bytes = uninit as *const [u8];
    // SAFETY: `MaybeUninit<u8>` has the same layout as `u8`, the pointer comes
    // from a live shared borrow that the returned reference inherits, and the
    // caller guarantees every element is initialized.
    unsafe { &*bytes }
}
/// Reader over a byte range of a shared file; see its `AsyncRead` impl for
/// how the fields are driven.
struct FileReader {
/// Shared handle to the file; cloned into each blocking read task.
file: Arc<File>,
/// Absolute file position passed to the next positional read; advanced by
/// the number of bytes each read returns.
offset: u64,
/// Position at which reading stops (`offset >= end` ends the stream). Also
/// pulled back to `offset` when a read returns zero bytes.
end: u64,
/// Current stage of the read cycle; the buffer-carrying variants own the
/// read buffer.
state: FileReaderState,
}
/// Stage of a `FileReader`'s read cycle.
#[derive(Default)]
enum FileReaderState {
/// A blocking read task is in flight; on success it yields the buffer back
/// together with the number of bytes read.
Pending(task::JoinHandle<io::Result<(UnsafeBuf, usize)>>),
/// The buffer holds file bytes not yet handed to the caller: the two
/// indices are the start and end of the unconsumed part.
Queued(UnsafeBuf, usize, usize),
/// No read is in flight and nothing is buffered; the buffer is free for
/// the next read.
Idle(UnsafeBuf),
/// Placeholder left in place by `std::mem::take` while `poll_read` holds
/// the real state.
#[default]
None,
}
impl AsyncRead for FileReader {
    /// Copies the next available bytes of the file range into `dst`.
    ///
    /// The reader cycles through three stages: `Idle` spawns a blocking
    /// `pread` for the next chunk, `Pending` waits for that task, and `Queued`
    /// hands the bytes it produced to the caller, possibly over several calls
    /// when `dst` is smaller than the chunk.
    ///
    /// End of stream is signalled by returning `Poll::Ready(Ok(()))` without
    /// adding bytes to `dst`. That happens once `offset` reaches `end` or a
    /// read returns zero bytes.
    ///
    /// # Errors
    ///
    /// Returns the OS error reported by `pread`, or the join error of the
    /// blocking task converted to `io::Error`. The buffer is lost when a read
    /// fails, so every later poll returns an error as well instead of
    /// panicking.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        dst: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        loop {
            // Take the state by value; `None` sits in `this.state` until the
            // matching arm stores the successor state.
            match std::mem::take(&mut this.state) {
                FileReaderState::Queued(buf, start, end) => {
                    assert!(start < end);
                    let size = std::cmp::min(dst.remaining(), end - start);
                    // SAFETY: `end` is the byte count returned by `pread`, so
                    // `buf[..end]` was written by the kernel, and
                    // `start + size <= end`.
                    dst.put_slice(unsafe { slice_assume_init_ref(&buf[start..start + size]) });
                    if start + size == end {
                        // Chunk fully consumed: the buffer can be reused.
                        this.state = FileReaderState::Idle(buf);
                    } else {
                        this.state = FileReaderState::Queued(buf, start + size, end);
                    }
                    return Poll::Ready(Ok(()));
                }
                FileReaderState::Pending(mut handle) => {
                    let (buf, n) = match Pin::new(&mut handle).poll(cx) {
                        // On failure `?` returns early and `this.state` stays
                        // `None`; the `None` arm below handles later polls.
                        Poll::Ready(result) => result??,
                        Poll::Pending => {
                            this.state = FileReaderState::Pending(handle);
                            return Poll::Pending;
                        }
                    };
                    if n == 0 {
                        // Zero bytes means the file ended before `end`; clamp
                        // the range so the `Idle` arm reports end of stream.
                        this.end = this.offset;
                        this.state = FileReaderState::Idle(buf);
                    } else {
                        this.offset += n as u64;
                        this.state = FileReaderState::Queued(buf, 0, n);
                    }
                }
                FileReaderState::Idle(mut buf) => {
                    let offset = this.offset;
                    if offset >= this.end {
                        this.state = FileReaderState::Idle(buf);
                        // Nothing added to `dst`: end of stream.
                        return Poll::Ready(Ok(()));
                    }
                    let file = Arc::clone(&this.file);
                    let read_len = (this.end - offset).min(buf.len() as u64) as usize;
                    this.state = FileReaderState::Pending(task::spawn_blocking(move || {
                        // SAFETY: the closure owns `buf`, which is valid for
                        // writes of `read_len <= buf.len()` bytes, and it owns
                        // an `Arc<File>` that keeps the descriptor open for
                        // the duration of the call.
                        let result = unsafe {
                            libc::pread(
                                file.as_raw_fd(),
                                buf.as_mut_ptr().cast(),
                                read_len,
                                offset as i64,
                            )
                        };
                        if result < 0 {
                            Err(io::Error::last_os_error())
                        } else {
                            Ok::<_, io::Error>((buf, result as usize))
                        }
                    }));
                }
                FileReaderState::None => {
                    // Only seen when an earlier poll exited without storing a
                    // successor state, i.e. after a failed read. `mem::take`
                    // has already left `None` in place again, so the reader
                    // keeps reporting the failure rather than panicking.
                    return Poll::Ready(Err(io::Error::new(
                        io::ErrorKind::Other,
                        "file reader polled again after a failed read",
                    )));
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use std::{io::Write, sync::Arc};

    use anyhow::Result;
    use tokio::task;

    use super::file_reader;
    use crate::read_to_vec;

    /// Writes a short payload to a temporary file, then checks that reading
    /// the whole file back through `file_reader` returns exactly that payload.
    #[tokio::test]
    async fn test_file_reader() -> Result<()> {
        let shared = Arc::new(tempfile::tempfile()?);
        let writer = Arc::clone(&shared);

        // `Write` is implemented for `&File`, so the shared handle can be
        // written through without exclusive access.
        let write_task = task::spawn_blocking(move || {
            let mut handle: &std::fs::File = &writer;
            handle.write_all(b"hello world")
        });
        write_task.await??;

        let stream = file_reader(shared, None);
        let contents = read_to_vec(stream).await?;
        assert_eq!(contents, b"hello world");
        Ok(())
    }
}