use std::{io, path::Path};
use super::Progress;
/// Unwraps the outcome of an I/O job run via `tokio::task::spawn_blocking`.
///
/// - If the job completed, its own `io::Result` is returned unchanged.
/// - If the job panicked, the panic is re-raised on the awaiting task with
///   `std::panic::resume_unwind`.
/// - Any other join failure (a cancelled task) is turned into an `io::Error`
///   of kind `Interrupted`.
///   NOTE(review): `Interrupted` usually means "safe to retry"; confirm no
///   caller retry-loops on it.
#[cfg(all(any(unix, windows), feature = "async"))]
fn handle_blocking_io_task_result<T>(
    result: Result<io::Result<T>, tokio::task::JoinError>,
) -> io::Result<T> {
    let join_error = match result {
        Ok(io_result) => return io_result,
        Err(join_error) => join_error,
    };
    if join_error.is_panic() {
        // Re-raise rather than converting: a panic is a bug, not an I/O error.
        std::panic::resume_unwind(join_error.into_panic())
    }
    Err(io::Error::new(
        io::ErrorKind::Interrupted,
        format!("blocking task was cancelled: {join_error}"),
    ))
}
/// Asynchronously reads up to `len` bytes of `path`, starting at byte `offset`.
///
/// The returned vector is shorter than `len` when end-of-file is reached
/// first. A zero-length request returns an empty vector without touching the
/// filesystem.
///
/// On Unix and Windows the positional read runs inside
/// `tokio::task::spawn_blocking`. On other targets it uses the `tokio::fs`
/// seek + read implementation.
///
/// # Errors
/// Propagates errors from opening or reading the file. On Unix/Windows it also
/// returns `InvalidInput` when `offset + len` overflows, and `Interrupted` if
/// the blocking task is cancelled. A panic inside the blocking task is
/// re-raised.
#[cfg(feature = "async")]
pub async fn async_read_byte_range(
    path: impl AsRef<Path>,
    offset: u64,
    len: usize,
) -> io::Result<Vec<u8>> {
    // Short-circuit before spawning any task or opening the file.
    if len == 0 {
        return Ok(Vec::new());
    }
    let len = len as u64;
    #[cfg(not(any(unix, windows)))]
    {
        seek_read_async_internal(path, offset, len, None::<&dyn Progress>).await
    }
    #[cfg(any(unix, windows))]
    {
        // The blocking closure must be 'static, so it takes an owned path.
        let owned_path = path.as_ref().to_owned();
        let joined = tokio::task::spawn_blocking(move || {
            read_at_internal(owned_path, offset, len, None::<&dyn Progress>)
        })
        .await;
        handle_blocking_io_task_result(joined)
    }
}
/// Asynchronously reads up to `len` bytes of `path` from `offset`, reporting
/// progress to `pb`.
///
/// `pb` is moved into the read (hence `Send + 'static`). The bytes read are
/// reported through `Progress::inc`, and `Progress::finish` is called when
/// reading ends.
///
/// A zero-length request finishes `pb` immediately and never touches the
/// filesystem. This matches the blocking `read_byte_range_with_progress`,
/// which also finishes the indicator for empty ranges.
///
/// The returned vector is shorter than `len` if end-of-file comes first. On
/// Unix and Windows the read runs inside `tokio::task::spawn_blocking`. On
/// other targets it uses the `tokio::fs` seek + read implementation.
///
/// # Errors
/// Propagates errors from opening or reading the file. On Unix/Windows it also
/// returns `InvalidInput` when `offset + len` overflows, and `Interrupted` if
/// the blocking task is cancelled. A panic inside the blocking task is
/// re-raised.
#[cfg(feature = "async")]
pub async fn async_read_byte_range_with_progress(
    path: impl AsRef<Path>,
    offset: u64,
    len: u64,
    pb: impl Progress + Send + 'static,
) -> io::Result<Vec<u8>> {
    if len == 0 {
        // Without this, an empty async read would leave the indicator
        // unfinished, unlike the blocking API.
        pb.finish();
        return Ok(Vec::new());
    }
    #[cfg(any(unix, windows))]
    {
        // The blocking closure must be 'static: it owns both the path and `pb`.
        let path_buf = path.as_ref().to_path_buf();
        let result =
            tokio::task::spawn_blocking(move || read_at_internal(path_buf, offset, len, Some(&pb)))
                .await;
        handle_blocking_io_task_result(result)
    }
    #[cfg(not(any(unix, windows)))]
    {
        seek_read_async_internal(path, offset, len, Some(&pb)).await
    }
}
/// Reads up to `len` bytes of `path`, starting at byte `offset`, blocking the
/// current thread.
///
/// The returned vector is shorter than `len` when end-of-file comes first.
/// `len == 0` yields an empty vector.
///
/// # Errors
/// Propagates errors from opening or reading the file. On Unix/Windows it also
/// returns `InvalidInput` when `offset + len` overflows a `u64`.
pub fn read_byte_range(path: impl AsRef<Path>, offset: u64, len: usize) -> io::Result<Vec<u8>> {
    let len = len as u64;
    #[cfg(not(any(unix, windows)))]
    {
        seek_read_blocking_internal(path, offset, len, None::<&dyn Progress>)
    }
    #[cfg(any(unix, windows))]
    {
        read_at_internal(path, offset, len, None::<&dyn Progress>)
    }
}
/// Reads up to `len` bytes of `path` from `offset`, blocking the current
/// thread, and reports progress to `pb`.
///
/// `pb` receives `inc` calls for the bytes read, and `finish` once reading
/// ends after the file has been opened. A zero-length request finishes `pb`
/// immediately without opening the file.
///
/// # Errors
/// Propagates errors from opening or reading the file. On Unix/Windows it also
/// returns `InvalidInput` when `offset + len` overflows a `u64`.
pub fn read_byte_range_with_progress(
    path: impl AsRef<Path>,
    offset: u64,
    len: u64,
    pb: &impl Progress,
) -> io::Result<Vec<u8>> {
    let progress = Some(pb);
    #[cfg(not(any(unix, windows)))]
    {
        seek_read_blocking_internal(path, offset, len, progress)
    }
    #[cfg(any(unix, windows))]
    {
        read_at_internal(path, offset, len, progress)
    }
}
/// Converts a requested read length into an in-memory buffer size.
///
/// # Errors
/// Returns `InvalidInput` when `len` exceeds `usize::MAX`. That can only
/// happen on targets whose pointers are narrower than 64 bits.
#[inline]
fn validate_len_for_buffer(len: u64) -> io::Result<usize> {
    match usize::try_from(len) {
        Ok(size) => Ok(size),
        Err(_) => Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "length does not fit in memory buffer (usize)",
        )),
    }
}
/// Blocking read of up to `len` bytes of `path`, starting at byte `offset`.
///
/// Reads are positional: `FileExt::read_at` on Unix and `FileExt::seek_read`
/// on Windows. The result is shorter than `len` when end-of-file is reached
/// first, and `Interrupted` errors from individual reads are retried.
///
/// Memory: the output buffer is pre-sized from the file's current length past
/// `offset`, clamped to `len`, and grown on demand by
/// [`std::io::Read::read_to_end`]. Requesting a large range of a small file
/// therefore does not allocate and zero-fill all `len` bytes up front. A
/// stale or zero size hint only affects the initial allocation; reading still
/// continues up to `len` bytes or end-of-file.
///
/// Progress: when `pb` is `Some`, `inc(n)` is called for every non-empty chunk
/// read. `finish()` is called once reading stops, whether it succeeded or
/// failed. A zero-length request finishes `pb` immediately and never opens the
/// file.
///
/// # Errors
/// - `InvalidInput` if `offset + len` overflows `u64` or `len` does not fit in
///   `usize`. Both checks happen before the file is opened.
/// - Any error from opening or reading the file.
#[cfg(any(unix, windows))]
fn read_at_internal(
    path: impl AsRef<Path>,
    offset: u64,
    len: u64,
    pb: Option<&(impl Progress + ?Sized)>,
) -> io::Result<Vec<u8>> {
    use std::io::Read as _;

    /// Reads into `buf` at absolute position `offset`, independent of the
    /// file cursor.
    #[inline]
    fn read_at_compat(file: &std::fs::File, buf: &mut [u8], offset: u64) -> io::Result<usize> {
        #[cfg(unix)]
        {
            use std::os::unix::fs::FileExt as _;
            file.read_at(buf, offset)
        }
        #[cfg(windows)]
        {
            use std::os::windows::fs::FileExt as _;
            file.seek_read(buf, offset)
        }
    }

    /// Sequential `Read` adapter over positional reads. It tracks its own
    /// position and reports every non-empty chunk to the progress sink.
    struct PositionalReader<'a, P: ?Sized> {
        file: &'a std::fs::File,
        pos: u64,
        pb: Option<&'a P>,
    }

    impl<P: Progress + ?Sized> std::io::Read for PositionalReader<'_, P> {
        fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
            let n = read_at_compat(self.file, buf, self.pos)?;
            if n > 0 {
                // Cannot overflow: the reader is wrapped in `take(len)`, and
                // `offset + len` has been checked to fit in a u64.
                self.pos += n as u64;
                if let Some(pb) = self.pb {
                    pb.inc(n as u64);
                }
            }
            Ok(n)
        }
    }

    if len == 0 {
        if let Some(pb) = pb {
            pb.finish();
        }
        return Ok(Vec::new());
    }
    // Validate the request before touching the filesystem.
    if offset.checked_add(len).is_none() {
        return Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "read range offset + length overflows a u64",
        ));
    }
    let max_len = validate_len_for_buffer(len)?;

    let file = std::fs::File::open(path.as_ref())?;

    // The metadata size is only a hint. If it is unavailable, stale, or 0,
    // `read_to_end` simply grows the buffer as data arrives.
    let remaining_hint = file
        .metadata()
        .map(|meta| meta.len().saturating_sub(offset))
        .unwrap_or(0);
    let initial_capacity = usize::try_from(remaining_hint)
        .unwrap_or(usize::MAX)
        .min(max_len);

    let mut buffer = Vec::with_capacity(initial_capacity);
    let reader = PositionalReader {
        file: &file,
        pos: offset,
        pb,
    };
    // `take(len)` caps the total number of bytes read. `read_to_end` retries
    // `Interrupted` and stops at the first zero-length read (end-of-file).
    let read_result = reader.take(len).read_to_end(&mut buffer);

    if let Some(pb) = pb {
        pb.finish();
    }
    read_result?;
    Ok(buffer)
}
// Generates a `seek` + `read` implementation of the byte-range reader. It is
// used on targets other than Unix and Windows; those use `read_at_internal`.
// Arguments, comma-separated:
// - `$vis`: visibility of the generated fn (may be empty).
// - `$name`, `$doc`: name and doc string of the generated fn.
// - `$async`: the `async` keyword, or empty for a blocking fn.
// - `@await`: when present, `.await` is appended to every I/O call.
// - `$file`: file type providing `open(path)`.
// - `$read_trait`, `$seek_trait`: traits imported inside the body, providing
//   `read`/`take`/`read_to_end` and `seek` respectively.
//
// Generated behavior:
// - `len == 0` finishes `pb` (if any) and returns an empty vector without
//   opening the file.
// - Otherwise `len` is checked to fit in `usize` before the file is opened.
// - The file is then seeked to `offset`, and at most `len` bytes are read
//   (fewer at end-of-file).
// - Errors from `open`/`seek` return early without calling `pb.finish()`.
// NOTE(review): unlike `read_at_internal`, this has no `offset + len`
// overflow check, and `Vec::with_capacity(capacity)` reserves the full `len`
// up front even when the file is much shorter.
#[cfg(not(any(unix, windows)))]
macro_rules! define_seek_read_internal {
(
$vis:vis,
$name:ident,
$doc:expr,
$($async:ident)?,
$( @$await:tt )?,
$file:ty,
$read_trait:path,
$seek_trait:path
) => {
#[doc = $doc]
$vis $($async)? fn $name(
path: impl AsRef<Path>,
offset: u64,
len: u64,
pb: Option<&(impl Progress + ?Sized)>,
) -> io::Result<Vec<u8>> {
use $read_trait;
use $seek_trait;
if len == 0 {
if let Some(pb) = pb {
pb.finish();
}
return Ok(Vec::new());
}
let capacity = validate_len_for_buffer(len)?;
let mut file = <$file>::open(path)$(.$await)? ?;
file.seek(io::SeekFrom::Start(offset))$(.$await)? ?;
// Progress path: read into a scratch buffer of at most `READ_CHUNK` bytes
// so each chunk can be reported via `pb.inc`. `Interrupted` reads are
// retried, and `pb.finish()` runs after the loop even when it ended with an
// error.
if let Some(pb) = pb {
let mut reader = file.take(len);
let mut buffer = Vec::with_capacity(capacity);
let mut read_buf = vec![0; READ_CHUNK.min(capacity)];
let result = loop {
match reader.read(&mut read_buf)$(.$await)? {
Ok(0) => break Ok(buffer), Ok(n) => {
buffer.extend_from_slice(&read_buf[..n]);
pb.inc(n as u64);
}
Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {},
Err(e) => break Err(e),
}
};
pb.finish();
result
} else {
// No progress sink: `read_to_end` over the `take(len)` reader. If it
// reports `Interrupted`, it is called again and keeps appending to the
// same buffer.
let mut reader = file.take(len);
let mut buffer = Vec::with_capacity(capacity);
loop {
match reader.read_to_end(&mut buffer)$(.$await)? {
Ok(_) => break Ok(buffer),
Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {},
Err(e) => break Err(e),
}
}
}
}
};
}
// Per-`read` scratch-buffer size (64 KiB) for the progress-reporting branch
// of `define_seek_read_internal!`.
#[cfg(not(any(unix, windows)))]
const READ_CHUNK: usize = 64 * 1024;
// Async variant for non-Unix/Windows targets: `tokio::fs::File` with `.await`
// on each I/O call. It is crate-private (empty visibility) and is called by
// the async public entry points on these targets.
#[cfg(all(not(any(unix, windows)), feature = "async"))]
define_seek_read_internal!(
, seek_read_async_internal,
"Internal async implementation using `seek` and `read` for other platforms.",
async,
@await,
tokio::fs::File,
tokio::io::AsyncReadExt,
tokio::io::AsyncSeekExt
);
// Blocking variant for non-Unix/Windows targets: `std::fs::File` with no
// `async`/`.await` (both macro slots are left empty). Unlike the async
// variant, it is declared `pub`. NOTE(review): confirm that is intentional.
#[cfg(not(any(unix, windows)))]
define_seek_read_internal!(
pub, seek_read_blocking_internal,
"Internal blocking implementation using `seek` and `read` for other platforms.",
, , std::fs::File,
std::io::Read,
std::io::Seek
);
#[cfg(test)]
mod tests {
    use std::{
        io::Write as _,
        sync::{
            Arc,
            atomic::{AtomicBool, AtomicU64, Ordering},
        },
    };

    use tempfile::NamedTempFile;

    use super::*;

    /// `Progress` double that accumulates the reported byte total and records
    /// whether `finish` was called.
    struct MockProgress {
        total: Arc<AtomicU64>,
        finished: Arc<AtomicBool>,
    }

    impl Progress for MockProgress {
        fn inc(&self, delta: u64) {
            self.total.fetch_add(delta, Ordering::SeqCst);
        }
        fn finish(&self) {
            self.finished.store(true, Ordering::SeqCst);
        }
    }

    /// Returns a fresh `MockProgress` plus shared handles to its counters.
    /// A test can still inspect the counters after moving the mock into the
    /// async API, which takes it by value.
    fn mock_progress() -> (MockProgress, Arc<AtomicU64>, Arc<AtomicBool>) {
        let total = Arc::new(AtomicU64::new(0));
        let finished = Arc::new(AtomicBool::new(false));
        let progress = MockProgress {
            total: Arc::clone(&total),
            finished: Arc::clone(&finished),
        };
        (progress, total, finished)
    }

    /// Writes `content` to a new temp file and returns the file handle, its
    /// path, and a copy of the content. Keep the handle alive (tests bind it
    /// to `_file`); presumably the file is removed when it is dropped.
    fn setup_test_file(content: &[u8]) -> (NamedTempFile, std::path::PathBuf, Vec<u8>) {
        let mut temp_file = NamedTempFile::new().unwrap();
        temp_file.write_all(content).unwrap();
        let path = temp_file.path().to_path_buf();
        (temp_file, path, content.to_vec())
    }

    #[test]
    fn test_sync_read_middle() {
        let (_file, path, content) = setup_test_file(b"abcdefghijklmnopqrstuvwxyz");
        let result = read_byte_range(&path, 5, 10).unwrap();
        assert_eq!(result, &content[5..15]);
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_async_read_middle() {
        let (_file, path, content) = setup_test_file(b"abcdefghijklmnopqrstuvwxyz");
        let result = async_read_byte_range(&path, 5, 10).await.unwrap();
        assert_eq!(result, &content[5..15]);
    }

    #[test]
    fn test_read_at_start() {
        let (_file, path, content) = setup_test_file(b"abcdefghijklmnopqrstuvwxyz");
        let result = read_byte_range(&path, 0, 5).unwrap();
        assert_eq!(result, &content[0..5]);
    }

    #[test]
    fn test_read_at_end() {
        let (_file, path, content) = setup_test_file(b"abcdefghijklmnopqrstuvwxyz");
        let result = read_byte_range(&path, 21, 5).unwrap();
        assert_eq!(result, &content[21..26]);
    }

    #[test]
    fn test_read_full_file() {
        let (_file, path, content) = setup_test_file(b"abcdefghijklmnopqrstuvwxyz");
        let result = read_byte_range(&path, 0, content.len()).unwrap();
        assert_eq!(result, content);
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_async_read_full_file() {
        let (_file, path, content) = setup_test_file(b"abcdefghijklmnopqrstuvwxyz");
        let result = async_read_byte_range(&path, 0, content.len())
            .await
            .unwrap();
        assert_eq!(result, content);
    }

    #[test]
    fn test_read_past_eof() {
        let (_file, path, content) = setup_test_file(b"short file");
        let result = read_byte_range(&path, 0, 20).unwrap();
        assert_eq!(result, content);
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_async_read_past_eof() {
        let (_file, path, content) = setup_test_file(b"short file");
        let result = async_read_byte_range(&path, 5, 100).await.unwrap();
        assert_eq!(result, &content[5..]);
    }

    #[test]
    fn test_zero_length_read() {
        let (_file, path, _) = setup_test_file(b"some data");
        let result = read_byte_range(&path, 5, 0).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn test_zero_length_read_with_progress() {
        let (_file, path, _) = setup_test_file(b"abc");
        let (progress, total, finished) = mock_progress();
        let result = read_byte_range_with_progress(&path, 0, 0, &progress).unwrap();
        assert!(result.is_empty());
        assert!(finished.load(Ordering::SeqCst));
        assert_eq!(total.load(Ordering::SeqCst), 0);
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_async_zero_length_with_progress() {
        let (_file, path, _) = setup_test_file(b"abc");
        let (progress, total, _) = mock_progress();
        let result = async_read_byte_range_with_progress(path, 0, 0, progress)
            .await
            .unwrap();
        assert!(result.is_empty());
        assert_eq!(total.load(Ordering::SeqCst), 0);
    }

    #[test]
    fn test_sync_with_progress() {
        let (_file, path, _) = setup_test_file(&[0u8; 1000]);
        let (progress, total, finished) = mock_progress();
        let result = read_byte_range_with_progress(&path, 100, 500, &progress).unwrap();
        assert_eq!(result.len(), 500);
        assert_eq!(total.load(Ordering::SeqCst), 500);
        assert!(finished.load(Ordering::SeqCst));
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_async_with_progress() {
        let (_file, path, _) = setup_test_file(&[0u8; 1000]);
        let (progress, total, finished) = mock_progress();
        let result = async_read_byte_range_with_progress(&path, 100, 500, progress)
            .await
            .unwrap();
        assert_eq!(result.len(), 500);
        // The mock was moved into the call; inspect it through the shared handles.
        assert_eq!(total.load(Ordering::SeqCst), 500);
        assert!(finished.load(Ordering::SeqCst));
    }

    #[test]
    fn test_file_not_found() {
        let path = Path::new("a/file/that/does/not/exist.txt");
        let result = read_byte_range(path, 0, 10);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().kind(), io::ErrorKind::NotFound);
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_async_file_not_found() {
        let path = Path::new("a/file/that/does/not/exist.txt");
        let result = async_read_byte_range(path, 0, 10).await;
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().kind(), io::ErrorKind::NotFound);
    }

    #[cfg(any(unix, windows))]
    #[test]
    fn test_offset_overflow() {
        let (_file, path, _) = setup_test_file(b"data");
        let offset = u64::MAX - 5;
        let len = 10;
        let result = read_byte_range(&path, offset, len);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().kind(), io::ErrorKind::InvalidInput);
    }

    #[cfg(feature = "async")]
    #[tokio::test]
    async fn test_async_cancellation_is_not_panic() {
        let large_content = vec![0u8; 1024 * 1024];
        let (_file, path, _) = setup_test_file(&large_content);
        let task = tokio::spawn(async move { async_read_byte_range(path, 0, 1024 * 1024).await });
        task.abort();
        let result = task.await;
        assert!(result.is_err());
        assert!(result.unwrap_err().is_cancelled());
    }
}