#![deny(unsafe_code)]
use std::fs::File;
use std::io::{self, Read};
use std::thread::{self, JoinHandle};
use crossbeam_channel::{Receiver, Sender, TryRecvError, bounded};
/// Chunk size, in bytes, used by the constructors that take no explicit
/// configuration (4 MiB).
pub const DEFAULT_CHUNK_SIZE: usize = 4 << 20;
/// Capacity of the bounded chunk channel used by the constructors that take no
/// explicit configuration.
pub const DEFAULT_PREFETCH_DEPTH: usize = 4;
/// Length argument (128 MiB) handed to `crate::os_hints::advise_willneed_raw`
/// after each chunk is read.
// NOTE(review): presumably a readahead window in bytes; confirm against os_hints.
const WILLNEED_LOOKAHEAD: i64 = 128 << 20;
/// What travels through the channel: one chunk of bytes, or the I/O error that
/// ended production.
type Item = Result<Vec<u8>, io::Error>;
/// A [`Read`] adapter whose data is read ahead by a background thread (named
/// `fgumi-prefetch`) and handed over in chunks through a bounded channel.
///
/// Dropping the reader closes the receiving side of the channel and joins the
/// producer thread.
#[derive(Debug)]
pub struct PrefetchReader {
/// Chunk currently being served to the caller, paired with the offset of the
/// next unread byte inside it. `None` until the first chunk is received.
current: Option<(Vec<u8>, usize)>,
/// Receiving end of the chunk channel. Set to `None` on drop so that the
/// producer's next `send` fails and the thread can exit.
rx: Option<Receiver<Item>>,
/// Join handle of the producer thread; taken and joined on drop.
handle: Option<JoinHandle<()>>,
/// Number of times `read` found the channel empty and had to block for the
/// producer.
consumer_stalls: u64,
/// Total number of bytes copied into caller buffers so far.
bytes_consumed: u64,
}
impl PrefetchReader {
    /// Wraps `inner` using [`DEFAULT_CHUNK_SIZE`] and [`DEFAULT_PREFETCH_DEPTH`].
    ///
    /// # Panics
    ///
    /// Panics if the producer thread cannot be spawned.
    #[must_use]
    pub fn new<R: Read + Send + 'static>(inner: R) -> Self {
        Self::build(inner, DEFAULT_CHUNK_SIZE, DEFAULT_PREFETCH_DEPTH, None)
    }

    /// Wraps `inner`, reading it in chunks of up to `chunk_size` bytes and
    /// buffering at most `prefetch_depth` chunks in the channel.
    ///
    /// No OS read-ahead hints are issued for readers created this way.
    ///
    /// # Panics
    ///
    /// Panics if `chunk_size` or `prefetch_depth` is zero, or if the producer
    /// thread cannot be spawned.
    #[must_use]
    pub fn with_config<R: Read + Send + 'static>(
        inner: R,
        chunk_size: usize,
        prefetch_depth: usize,
    ) -> Self {
        let no_hint_fd: Option<i32> = None;
        Self::build(inner, chunk_size, prefetch_depth, no_hint_fd)
    }

    /// Wraps `file` using [`DEFAULT_CHUNK_SIZE`] and [`DEFAULT_PREFETCH_DEPTH`];
    /// see [`PrefetchReader::from_file_with_config`].
    ///
    /// # Panics
    ///
    /// Panics if the producer thread cannot be spawned.
    #[must_use]
    pub fn from_file(file: File) -> Self {
        Self::from_file_with_config(file, DEFAULT_CHUNK_SIZE, DEFAULT_PREFETCH_DEPTH)
    }

    /// Wraps `file` like [`PrefetchReader::with_config`], additionally passing
    /// whatever `crate::os_hints::hint_fd` returns for the file on to the
    /// producer thread.
    // NOTE(review): `hint_fd` presumably yields the raw descriptor on platforms
    // that support read-ahead advice; confirm against os_hints.
    ///
    /// # Panics
    ///
    /// Panics if `chunk_size` or `prefetch_depth` is zero, or if the producer
    /// thread cannot be spawned.
    #[must_use]
    pub fn from_file_with_config(file: File, chunk_size: usize, prefetch_depth: usize) -> Self {
        // Query the hint descriptor before `file` is moved into the producer.
        let fd_for_hints = crate::os_hints::hint_fd(&file);
        Self::build(file, chunk_size, prefetch_depth, fd_for_hints)
    }

    /// Shared constructor: validates the configuration, creates the bounded
    /// channel and spawns the producer thread that owns `inner`.
    fn build<R: Read + Send + 'static>(
        inner: R,
        chunk_size: usize,
        prefetch_depth: usize,
        hint_fd: Option<i32>,
    ) -> Self {
        assert!(chunk_size > 0, "PrefetchReader chunk_size must be > 0");
        assert!(prefetch_depth > 0, "PrefetchReader prefetch_depth must be > 0");

        let (tx, rx) = bounded::<Item>(prefetch_depth);
        // The closure owns both the inner reader and the sending half, so both
        // are released when the producer thread finishes.
        let producer = move || producer_main(inner, chunk_size, hint_fd, &tx);
        let handle = thread::Builder::new()
            .name(String::from("fgumi-prefetch"))
            .spawn(producer)
            .expect("failed to spawn fgumi-prefetch thread");

        Self {
            rx: Some(rx),
            handle: Some(handle),
            current: None,
            bytes_consumed: 0,
            consumer_stalls: 0,
        }
    }

    /// Total number of bytes handed to callers of `read` so far.
    #[must_use]
    pub fn bytes_consumed(&self) -> u64 {
        self.bytes_consumed
    }

    /// Number of times `read` found the channel empty and had to wait for the
    /// producer thread.
    #[must_use]
    pub fn consumer_stalls(&self) -> u64 {
        self.consumer_stalls
    }
}
/// Entry point of the producer thread.
///
/// Runs [`producer_loop`] under `catch_unwind`. If the loop panics, the panic
/// message (or a generic fallback when the payload is neither `&'static str`
/// nor `String`) is forwarded to the consumer as an `io::Error`, so the reading
/// side sees an error rather than a silent end of stream.
fn producer_main<R: Read>(inner: R, chunk_size: usize, hint_fd: Option<i32>, tx: &Sender<Item>) {
    // A dedicated sender for the panic report, created before the loop runs.
    let panic_tx = tx.clone();

    let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
        producer_loop(inner, chunk_size, hint_fd, tx);
    }));

    let Err(payload) = outcome else {
        return;
    };

    let msg = payload
        .downcast_ref::<&'static str>()
        .map(|text| (*text).to_string())
        .or_else(|| payload.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| "fgumi-prefetch producer thread panicked".to_string());

    // Best effort: the consumer may already be gone.
    let _ = panic_tx.send(Err(io::Error::other(msg)));
}
/// Body of the producer thread: reads `inner` chunk by chunk and sends each
/// chunk to the consumer through `tx`.
///
/// Each chunk is filled with repeated `read` calls until it holds `chunk_size`
/// bytes, the reader reports end of input, or a non-retryable error occurs.
/// A short read therefore does not produce a tiny chunk backed by a freshly
/// zeroed `chunk_size` allocation; only the final chunk (or the one preceding
/// an error) can be shorter than `chunk_size`.
///
/// The function returns when:
/// * the reader reaches end of input (any buffered bytes are sent first),
/// * the reader fails with an error other than `Interrupted` (bytes read so far
///   are sent, followed by the error), or
/// * sending fails because the receiving side has been dropped.
///
/// When `hint_fd` is `Some`, the running byte count and `WILLNEED_LOOKAHEAD`
/// are passed to `crate::os_hints::advise_willneed_raw` after every chunk.
// NOTE(review): the byte count is used as a file offset, which assumes the
// reader started at offset 0; confirm against callers.
fn producer_loop<R: Read>(
    mut inner: R,
    chunk_size: usize,
    hint_fd: Option<i32>,
    tx: &Sender<Item>,
) {
    // `PrefetchReader::build` already rejects a zero chunk size; guard here too
    // so the fill loop below can never spin forever emitting empty chunks.
    if chunk_size == 0 {
        return;
    }
    let mut position: i64 = 0;
    loop {
        let mut buf = vec![0u8; chunk_size];
        let mut filled: usize = 0;
        let mut eof = false;
        // Keep reading until the chunk is full: `Read::read` may return fewer
        // bytes than requested without that meaning end of input.
        while filled < chunk_size {
            match inner.read(&mut buf[filled..]) {
                Ok(0) => {
                    eof = true;
                    break;
                }
                Ok(n) => filled += n,
                // Interrupted reads are retried transparently.
                Err(e) if e.kind() == io::ErrorKind::Interrupted => {}
                Err(e) => {
                    // Deliver what was read before the failure, then the error.
                    if filled > 0 {
                        buf.truncate(filled);
                        let _ = tx.send(Ok(buf));
                    }
                    let _ = tx.send(Err(e));
                    return;
                }
            }
        }
        if filled == 0 && eof {
            return;
        }
        position = position.saturating_add(i64::try_from(filled).unwrap_or(i64::MAX));
        if let Some(fd) = hint_fd {
            crate::os_hints::advise_willneed_raw(fd, position, WILLNEED_LOOKAHEAD);
        }
        buf.truncate(filled);
        // A failed send means the consumer dropped its receiver: stop producing.
        if tx.send(Ok(buf)).is_err() {
            return;
        }
        if eof {
            return;
        }
    }
}
/// Consumer side: serves bytes out of the chunk currently held, fetching the
/// next chunk from the producer when the current one is exhausted.
impl Read for PrefetchReader {
    /// Copies up to `buf.len()` bytes from the current chunk into `buf`.
    ///
    /// Returns `Ok(0)` when `buf` is empty, when the receiver is absent, or when
    /// the channel is disconnected and drained. An error sent by the producer is
    /// returned to the caller as-is. A single call never spans two chunks.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        if buf.is_empty() {
            return Ok(0);
        }
        loop {
            // Take the chunk out, serve from it, and put it back with the
            // advanced offset; an exhausted chunk is simply not put back.
            if let Some((chunk, offset)) = self.current.take() {
                if offset < chunk.len() {
                    let n = std::cmp::min(buf.len(), chunk.len() - offset);
                    let end = offset + n;
                    buf[..n].copy_from_slice(&chunk[offset..end]);
                    self.bytes_consumed += n as u64;
                    self.current = Some((chunk, end));
                    return Ok(n);
                }
            }

            let Some(rx) = self.rx.as_ref() else {
                return Ok(0);
            };

            // `None` means the producer is gone and nothing is left to read.
            let next = match rx.try_recv() {
                Ok(item) => Some(item),
                Err(TryRecvError::Empty) => {
                    // Nothing buffered: count the stall, then block.
                    self.consumer_stalls += 1;
                    rx.recv().ok()
                }
                Err(TryRecvError::Disconnected) => None,
            };

            match next {
                None => return Ok(0),
                Some(Err(e)) => return Err(e),
                Some(Ok(data)) => {
                    // Empty chunks carry no bytes; skip them and fetch again.
                    if !data.is_empty() {
                        self.current = Some((data, 0));
                    }
                }
            }
        }
    }
}
/// Shuts the producer down and waits for it to finish.
impl Drop for PrefetchReader {
    fn drop(&mut self) {
        // Drop the receiver first so the producer's next `send` fails and the
        // thread returns instead of blocking on a full channel.
        drop(self.rx.take());
        drop(self.current.take());

        let Some(producer) = self.handle.take() else {
            return;
        };
        if producer.join().is_err() {
            log::debug!("fgumi-prefetch producer thread panicked during shutdown");
        }
    }
}
// Unit and property tests for `PrefetchReader`: byte-for-byte fidelity,
// end-of-stream behavior, error propagation and shutdown on drop.
#[cfg(test)]
mod tests {
use super::*;
use proptest::prelude::*;
use std::io::Cursor;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
// Deterministic test payload: byte `i` is `i % 251`.
fn sample_bytes(len: usize) -> Vec<u8> {
(0..len).map(|i| u8::try_from(i % 251).expect("mod 251 fits in u8")).collect()
}
// An empty source yields `Ok(0)` on the first read and counts no bytes.
#[test]
fn empty_input_returns_zero_immediately() {
let mut reader = PrefetchReader::new(Cursor::new(Vec::<u8>::new()));
let mut buf = [0u8; 16];
assert_eq!(reader.read(&mut buf).unwrap(), 0);
assert_eq!(reader.bytes_consumed(), 0);
}
// The `File`-based constructor returns the file contents unchanged.
#[test]
fn from_file_reads_correctly() {
use std::io::Write;
let data = sample_bytes(50_000);
let mut tmp = tempfile::NamedTempFile::new().expect("create temp file");
tmp.write_all(&data).expect("write temp file");
let file = File::open(tmp.path()).expect("reopen temp file");
let mut reader = PrefetchReader::from_file(file);
let mut out = Vec::new();
reader.read_to_end(&mut out).unwrap();
assert_eq!(out, data);
assert_eq!(reader.bytes_consumed(), data.len() as u64);
}
// Input much smaller than one default chunk round-trips intact.
#[test]
fn read_to_end_small_matches_input() {
let data = b"hello, fgumi prefetch".to_vec();
let mut reader = PrefetchReader::new(Cursor::new(data.clone()));
let mut out = Vec::new();
reader.read_to_end(&mut out).unwrap();
assert_eq!(out, data);
assert_eq!(reader.bytes_consumed(), data.len() as u64);
}
// Input spanning many 8 KiB chunks (length not a multiple of the chunk size)
// round-trips intact.
#[test]
fn read_to_end_large_matches_input() {
let data = sample_bytes(1_000_003);
let mut reader = PrefetchReader::with_config(Cursor::new(data.clone()), 8 * 1024, 2);
let mut out = Vec::new();
reader.read_to_end(&mut out).unwrap();
assert_eq!(out, data);
assert_eq!(reader.bytes_consumed(), data.len() as u64);
}
// Caller buffer (7 bytes) and chunk size (17 bytes) that do not divide each
// other still reproduce the input exactly.
#[test]
fn tiny_chunk_size_and_many_small_reads() {
let data = sample_bytes(5_000);
let mut reader = PrefetchReader::with_config(Cursor::new(data.clone()), 17, 2);
let mut out = Vec::new();
let mut tmp = [0u8; 7];
loop {
let n = reader.read(&mut tmp).unwrap();
if n == 0 {
break;
}
out.extend_from_slice(&tmp[..n]);
}
assert_eq!(out, data);
}
// Once end of stream has been reported, every further read returns `Ok(0)`.
#[test]
fn repeated_read_after_eof_returns_zero_forever() {
let mut reader = PrefetchReader::new(Cursor::new(b"abc".to_vec()));
let mut out = Vec::new();
reader.read_to_end(&mut out).unwrap();
assert_eq!(out, b"abc");
let mut tmp = [0u8; 8];
for _ in 0..10 {
assert_eq!(reader.read(&mut tmp).unwrap(), 0);
}
}
// Dropping a reader that was never read from must return (the test passes by
// finishing) even though the source holds far more data than the channel.
#[test]
fn drop_before_consuming_does_not_hang() {
let data = vec![0u8; 1_000_000];
let _reader = PrefetchReader::with_config(Cursor::new(data), 4 * 1024, 2);
}
// Dropping after a single small read must likewise return.
#[test]
fn partial_read_then_drop_does_not_hang() {
let data = sample_bytes(500_000);
let mut reader = PrefetchReader::with_config(Cursor::new(data), 4 * 1024, 2);
let mut tmp = [0u8; 32];
let n = reader.read(&mut tmp).unwrap();
assert!(n > 0);
}
// An error from the inner reader surfaces with its original kind on the first
// read; the following read reports end of stream.
#[test]
fn error_from_inner_reader_propagates_once() {
struct AlwaysErr;
impl Read for AlwaysErr {
fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
Err(io::Error::new(io::ErrorKind::PermissionDenied, "nope"))
}
}
let mut reader = PrefetchReader::new(AlwaysErr);
let mut buf = [0u8; 16];
let err = reader.read(&mut buf).expect_err("first read should error");
assert_eq!(err.kind(), io::ErrorKind::PermissionDenied);
assert_eq!(reader.read(&mut buf).unwrap(), 0);
}
// Bytes read before a failure are delivered first; the error follows on the
// next read.
#[test]
fn error_after_some_data_delivers_data_then_error() {
// Yields up to 128 bytes on the first call and fails on every later call.
struct DataThenErr {
sent: bool,
}
impl Read for DataThenErr {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
if self.sent {
Err(io::Error::other("subsequent error"))
} else {
self.sent = true;
let n = buf.len().min(128);
for (i, b) in buf.iter_mut().take(n).enumerate() {
*b = u8::try_from(i).expect("n <= 128 so i fits in u8");
}
Ok(n)
}
}
}
let mut reader = PrefetchReader::with_config(DataThenErr { sent: false }, 1024, 2);
let mut out = Vec::new();
let mut tmp = [0u8; 256];
let n = reader.read(&mut tmp).unwrap();
assert_eq!(n, 128);
out.extend_from_slice(&tmp[..n]);
assert_eq!(out.len(), 128);
let err = reader.read(&mut tmp).expect_err("second read should error");
assert!(matches!(err.kind(), io::ErrorKind::Other | io::ErrorKind::UnexpectedEof));
}
// `Interrupted` errors from the inner reader are retried and never reach the
// caller.
#[test]
fn interrupted_errors_are_retried_transparently() {
// Fails with `Interrupted` three times, then yields up to 10 bytes, then
// reports end of input.
struct FlakyThenEof {
call: usize,
}
impl Read for FlakyThenEof {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.call += 1;
if self.call <= 3 {
return Err(io::Error::new(io::ErrorKind::Interrupted, "try again"));
}
if self.call == 4 {
let n = buf.len().min(10);
for (i, b) in buf.iter_mut().take(n).enumerate() {
*b = u8::try_from(i + 1).expect("n <= 10 so i+1 fits in u8");
}
return Ok(n);
}
Ok(0)
}
}
let mut reader = PrefetchReader::with_config(FlakyThenEof { call: 0 }, 64, 2);
let mut out = Vec::new();
reader.read_to_end(&mut out).unwrap();
assert_eq!(out, (1..=10).collect::<Vec<u8>>());
}
// After the `PrefetchReader` is dropped, the inner reader (owned by the
// producer thread) must have been dropped as well.
#[test]
fn drop_joins_producer_thread() {
// In-memory reader that raises `flag` when it is dropped.
struct Tracked {
flag: Arc<AtomicBool>,
data: Vec<u8>,
pos: usize,
}
impl Read for Tracked {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
if self.pos >= self.data.len() {
return Ok(0);
}
let n = buf.len().min(self.data.len() - self.pos);
buf[..n].copy_from_slice(&self.data[self.pos..self.pos + n]);
self.pos += n;
Ok(n)
}
}
impl Drop for Tracked {
fn drop(&mut self) {
self.flag.store(true, Ordering::SeqCst);
}
}
let flag = Arc::new(AtomicBool::new(false));
let inner = Tracked { flag: Arc::clone(&flag), data: sample_bytes(1024), pos: 0 };
{
let mut reader = PrefetchReader::with_config(inner, 64, 2);
let mut out = Vec::new();
reader.read_to_end(&mut out).unwrap();
assert_eq!(out.len(), 1024);
}
assert!(
flag.load(Ordering::SeqCst),
"producer thread should have dropped the inner reader"
);
}
// Property: for arbitrary data, chunk size, channel depth and caller buffer
// size, the bytes read back equal the bytes put in.
proptest! {
#[test]
fn prop_byte_identical_to_cursor(
data in prop::collection::vec(any::<u8>(), 0..8_192),
chunk_size in 1usize..2_048,
depth in 1usize..6,
read_size in 1usize..256,
) {
let expected = data.clone();
let mut reader = PrefetchReader::with_config(
Cursor::new(data),
chunk_size,
depth,
);
let mut out = Vec::with_capacity(expected.len());
let mut tmp = vec![0u8; read_size];
loop {
let n = reader.read(&mut tmp).expect("read should not fail on Cursor");
if n == 0 {
break;
}
out.extend_from_slice(&tmp[..n]);
}
prop_assert_eq!(out, expected);
}
}
}