#![deny(unsafe_code)]
#[cfg(any(test, feature = "fault-injection"))]
pub mod fault;
pub mod lock;
pub use crate::platform::lock::{ReaderLock, WriterLock};
use std::fs::{File, OpenOptions};
use std::io;
use std::path::Path;
#[cfg(unix)]
use std::os::unix::fs::FileExt as _;
#[cfg(unix)]
use std::os::unix::fs::OpenOptionsExt as _;
#[cfg(windows)]
use std::os::windows::fs::FileExt as _;
/// Unix permission bits (`rw-------`) handed to `OpenOptions::mode` by the
/// `FileHandle` constructors, so files they create are accessible to the owner only.
#[cfg(unix)]
const OWNER_ONLY_MODE: u32 = 0o600;
use crate::error::{Error, Result};
/// Number of consecutive transient failures (see `is_transient_io_error`) after which
/// the Windows retry loops stop retrying and return the error to the caller.
#[cfg(windows)]
const WIN_TRANSIENT_RETRY_LIMIT: u32 = 10;
/// Abstraction over a file-like object that supports positional I/O, resizing and syncing.
///
/// `FileHandle` is the implementation visible in this module.
pub trait FileBackend: Sized {
    /// Reports the current length in bytes.
    fn len(&self) -> Result<u64>;

    /// Reports whether [`len`](Self::len) is zero.
    ///
    /// Any error produced by `len` is passed through unchanged.
    fn is_empty(&self) -> Result<bool> {
        self.len().map(|bytes| bytes == 0)
    }

    /// Reads into the whole of `buf`, starting at byte `offset`.
    fn read_exact_at(&self, buf: &mut [u8], offset: u64) -> Result<()>;

    /// Writes the whole of `buf`, starting at byte `offset`.
    fn write_all_at(&self, buf: &[u8], offset: u64) -> Result<()>;

    /// Sets the length to `new_len` bytes.
    fn set_len(&self, new_len: u64) -> Result<()>;

    /// Syncs file data with the durability level selected by `mode`.
    fn sync_data(&self, mode: SyncMode) -> Result<()>;

    /// Performs a full sync, independent of any `SyncMode`.
    fn sync_all(&self) -> Result<()>;
}
/// Durability level requested from `sync_data`.
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must keep a wildcard arm
/// when matching on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, serde::Serialize, serde::Deserialize)]
#[non_exhaustive]
pub enum SyncMode {
    /// The default mode; `FileHandle::sync_data` routes it to the platform-specific
    /// `sync_data_full` helper.
    #[default]
    Full,
    /// `FileHandle::sync_data` routes it to `sync_data_normal`.
    Normal,
    /// No sync is issued; `FileHandle::sync_data` returns `Ok(())` immediately.
    Off,
}
/// Wrapper around an open [`File`] that exposes length, positional read/write,
/// resize and sync operations returning the crate's `Result` type.
#[derive(Debug)]
pub struct FileHandle {
    // The underlying OS file that every method operates on.
    file: File,
}
impl FileHandle {
pub fn open_or_create<P: AsRef<Path>>(path: P) -> Result<Self> {
let mut opts = OpenOptions::new();
opts.read(true).write(true).create(true).truncate(false);
#[cfg(unix)]
opts.mode(OWNER_ONLY_MODE);
let file = opts.open(path)?;
Ok(Self { file })
}
pub fn create_new<P: AsRef<Path>>(path: P) -> Result<Self> {
let mut opts = OpenOptions::new();
opts.read(true).write(true).create_new(true);
#[cfg(unix)]
opts.mode(OWNER_ONLY_MODE);
let file = opts.open(path)?;
Ok(Self { file })
}
pub fn len(&self) -> Result<u64> {
let meta = self.file.metadata()?;
Ok(meta.len())
}
pub fn is_empty(&self) -> Result<bool> {
Ok(self.len()? == 0)
}
pub fn read_exact_at(&self, buf: &mut [u8], offset: u64) -> Result<()> {
read_exact_at_impl(&self.file, buf, offset).map_err(Error::from)
}
pub fn write_all_at(&self, buf: &[u8], offset: u64) -> Result<()> {
write_all_at_impl(&self.file, buf, offset).map_err(Error::from)
}
pub fn set_len(&self, new_len: u64) -> Result<()> {
#[cfg(windows)]
{
set_len_with_retry(&self.file, new_len).map_err(Error::from)
}
#[cfg(not(windows))]
{
self.file.set_len(new_len).map_err(Error::from)
}
}
pub fn sync_all(&self) -> Result<()> {
self.file.sync_all().map_err(Error::from)
}
pub fn sync_data(&self, mode: SyncMode) -> Result<()> {
match mode {
SyncMode::Off => Ok(()),
SyncMode::Normal => sync_data_normal(&self.file),
SyncMode::Full => sync_data_full(&self.file),
}
}
}
/// Forwards every required trait method to the inherent `FileHandle` method of the
/// same name (inherent methods take priority when called through a type path).
impl FileBackend for FileHandle {
    fn len(&self) -> Result<u64> {
        Self::len(self)
    }

    fn read_exact_at(&self, buf: &mut [u8], offset: u64) -> Result<()> {
        Self::read_exact_at(self, buf, offset)
    }

    fn write_all_at(&self, buf: &[u8], offset: u64) -> Result<()> {
        Self::write_all_at(self, buf, offset)
    }

    fn set_len(&self, new_len: u64) -> Result<()> {
        Self::set_len(self, new_len)
    }

    fn sync_data(&self, mode: SyncMode) -> Result<()> {
        Self::sync_data(self, mode)
    }

    fn sync_all(&self) -> Result<()> {
        Self::sync_all(self)
    }
}
/// Unix implementation: fills all of `buf` from `file` starting at byte `offset`,
/// using the standard library's positional `FileExt::read_exact_at`.
#[cfg(unix)]
fn read_exact_at_impl(file: &File, buf: &mut [u8], offset: u64) -> io::Result<()> {
    std::os::unix::fs::FileExt::read_exact_at(file, buf, offset)
}
/// Unix implementation: writes all of `buf` to `file` starting at byte `offset`,
/// using the standard library's positional `FileExt::write_all_at`.
#[cfg(unix)]
fn write_all_at_impl(file: &File, buf: &[u8], offset: u64) -> io::Result<()> {
    std::os::unix::fs::FileExt::write_all_at(file, buf, offset)
}
/// Windows implementation: fills all of `buf` from `file` starting at byte `offset`.
///
/// Issues `seek_read` calls until the buffer is full. A read of zero bytes while the
/// buffer still has room yields `UnexpectedEof`. `Interrupted` errors are retried
/// immediately. Errors accepted by `is_transient_io_error` are retried after
/// `windows_io_backoff`, until `WIN_TRANSIENT_RETRY_LIMIT` consecutive transient
/// failures have occurred; the counter starts over whenever a read makes progress.
/// Any other error is returned as-is.
#[cfg(windows)]
fn read_exact_at_impl(file: &File, mut buf: &mut [u8], mut offset: u64) -> io::Result<()> {
    let mut transient_failures: u32 = 0;
    while !buf.is_empty() {
        let read = match file.seek_read(buf, offset) {
            Ok(n) => n,
            Err(err) if err.kind() == io::ErrorKind::Interrupted => continue,
            Err(err) if is_transient_io_error(&err) => {
                if transient_failures >= WIN_TRANSIENT_RETRY_LIMIT - 1 {
                    return Err(err);
                }
                transient_failures += 1;
                windows_io_backoff(transient_failures);
                continue;
            }
            Err(err) => return Err(err),
        };
        if read == 0 {
            return Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                "failed to fill whole buffer",
            ));
        }
        // Progress was made: forget earlier transient failures and advance the window.
        transient_failures = 0;
        let filled = buf;
        buf = &mut filled[read..];
        offset += read as u64;
    }
    Ok(())
}
/// Windows implementation: writes all of `buf` to `file` starting at byte `offset`.
///
/// Issues `seek_write` calls until every byte is written. A write of zero bytes while
/// data remains yields `WriteZero`. `Interrupted` errors are retried immediately.
/// Errors accepted by `is_transient_io_error` are retried after `windows_io_backoff`,
/// until `WIN_TRANSIENT_RETRY_LIMIT` consecutive transient failures have occurred; the
/// counter starts over whenever a write makes progress. Any other error is returned
/// as-is.
#[cfg(windows)]
fn write_all_at_impl(file: &File, mut buf: &[u8], mut offset: u64) -> io::Result<()> {
    let mut transient_failures: u32 = 0;
    while !buf.is_empty() {
        let written = match file.seek_write(buf, offset) {
            Ok(n) => n,
            Err(err) if err.kind() == io::ErrorKind::Interrupted => continue,
            Err(err) if is_transient_io_error(&err) => {
                if transient_failures >= WIN_TRANSIENT_RETRY_LIMIT - 1 {
                    return Err(err);
                }
                transient_failures += 1;
                windows_io_backoff(transient_failures);
                continue;
            }
            Err(err) => return Err(err),
        };
        if written == 0 {
            return Err(io::Error::new(
                io::ErrorKind::WriteZero,
                "failed to write whole buffer",
            ));
        }
        // Progress was made: forget earlier transient failures and advance the window.
        transient_failures = 0;
        buf = &buf[written..];
        offset += written as u64;
    }
    Ok(())
}
/// Resizes `file` to `new_len` bytes, repeating the call under the policy of
/// `retry_transient_io` when it fails with a retryable error.
#[cfg(windows)]
fn set_len_with_retry(file: &File, new_len: u64) -> io::Result<()> {
    let resize = || file.set_len(new_len);
    retry_transient_io(resize)
}
/// Runs `op` until it succeeds or fails in a way that should not be retried.
///
/// * `Interrupted` errors are retried immediately and do not count towards the limit.
/// * Errors accepted by `is_transient_io_error` are retried after `windows_io_backoff`;
///   once `WIN_TRANSIENT_RETRY_LIMIT` of them have occurred, the latest one is returned.
/// * Every other error is returned on the spot.
#[cfg(windows)]
fn retry_transient_io<F: FnMut() -> io::Result<()>>(mut op: F) -> io::Result<()> {
    let mut transient_failures: u32 = 0;
    loop {
        let err = match op() {
            Ok(()) => return Ok(()),
            Err(err) => err,
        };
        if err.kind() == io::ErrorKind::Interrupted {
            continue;
        }
        let retry_budget_spent = transient_failures >= WIN_TRANSIENT_RETRY_LIMIT - 1;
        if !is_transient_io_error(&err) || retry_budget_spent {
            return Err(err);
        }
        transient_failures += 1;
        windows_io_backoff(transient_failures);
    }
}
/// Returns `true` when `e` carries the raw OS code `ERROR_LOCK_VIOLATION` or
/// `ERROR_SHARING_VIOLATION`; errors without a raw OS code are never transient.
#[cfg(windows)]
fn is_transient_io_error(e: &io::Error) -> bool {
    use windows_sys::Win32::Foundation::{ERROR_LOCK_VIOLATION, ERROR_SHARING_VIOLATION};
    e.raw_os_error().is_some_and(|code| {
        let retryable = [
            ERROR_LOCK_VIOLATION.cast_signed(),
            ERROR_SHARING_VIOLATION.cast_signed(),
        ];
        retryable.contains(&code)
    })
}
/// Blocks the current thread before a retry: `attempt` × 25 ms, capped at 250 ms.
#[cfg(windows)]
fn windows_io_backoff(attempt: u32) {
    use std::{thread, time::Duration};
    const STEP_MS: u64 = 25;
    const CAP_MS: u64 = 250;
    let uncapped_ms = u64::from(attempt).saturating_mul(STEP_MS);
    let pause = Duration::from_millis(uncapped_ms.min(CAP_MS));
    thread::sleep(pause);
}
fn sync_data_normal(file: &File) -> Result<()> {
file.sync_all().map_err(Error::from)
}
/// Handler for `SyncMode::Full` on Apple targets: calls `rustix::fs::fcntl_fullfsync`
/// and wraps a failure as `Error::Io`.
#[cfg(target_vendor = "apple")]
fn sync_data_full(file: &File) -> Result<()> {
    match rustix::fs::fcntl_fullfsync(file) {
        Ok(()) => Ok(()),
        Err(errno) => Err(Error::Io(io::Error::from(errno))),
    }
}
#[cfg(all(unix, not(target_vendor = "apple")))]
fn sync_data_full(file: &File) -> Result<()> {
rustix::fs::fdatasync(file).map_err(|e| Error::Io(io::Error::from(e)))
}
/// Handler for `SyncMode::Full` on Windows: calls `File::sync_all` and converts any
/// I/O error into `Error`.
#[cfg(windows)]
fn sync_data_full(file: &File) -> Result<()> {
    file.sync_all()?;
    Ok(())
}
pub fn remove_file_if_exists<P: AsRef<Path>>(path: P) -> Result<()> {
match std::fs::remove_file(path) {
Ok(()) => Ok(()),
Err(e) if e.kind() == io::ErrorKind::NotFound => Ok(()),
Err(e) => Err(Error::Io(e)),
}
}
impl From<io::ErrorKind> for Error {
fn from(kind: io::ErrorKind) -> Self {
Error::Io(io::Error::from(kind))
}
}
/// Unit tests for the sync modes and, on Windows, the transient-error retry helpers.
#[cfg(test)]
mod tests {
    use super::{FileHandle, SyncMode};
    use tempfile::TempDir;

    /// Creates a 4 KiB file in a fresh temp directory, fills it, and syncs it with `mode`.
    fn write_and_sync(mode: SyncMode) {
        let dir = TempDir::new().expect("tempdir");
        let path = dir.path().join("sync.bin");
        let h = FileHandle::open_or_create(&path).expect("open");
        h.set_len(4096).expect("set_len");
        h.write_all_at(&[0xABu8; 4096], 0).expect("write");
        h.sync_data(mode).expect("sync_data must succeed");
    }

    #[test]
    fn sync_data_full_returns_ok() {
        write_and_sync(SyncMode::Full);
    }

    #[test]
    fn sync_data_normal_returns_ok() {
        write_and_sync(SyncMode::Normal);
    }

    #[test]
    fn sync_data_off_is_noop() {
        write_and_sync(SyncMode::Off);
    }

    #[test]
    fn default_is_full() {
        assert_eq!(SyncMode::default(), SyncMode::Full);
    }

    /// Lock and sharing violations are transient; access-denied and errors without a
    /// raw OS code are not.
    #[cfg(windows)]
    #[test]
    fn is_transient_io_error_matches_lock_and_sharing_codes() {
        use super::is_transient_io_error;
        use std::io;
        use windows_sys::Win32::Foundation::{
            ERROR_ACCESS_DENIED, ERROR_LOCK_VIOLATION, ERROR_SHARING_VIOLATION,
        };
        let lock = io::Error::from_raw_os_error(ERROR_LOCK_VIOLATION.cast_signed());
        let share = io::Error::from_raw_os_error(ERROR_SHARING_VIOLATION.cast_signed());
        assert!(is_transient_io_error(&lock));
        assert!(is_transient_io_error(&share));
        let denied = io::Error::from_raw_os_error(ERROR_ACCESS_DENIED.cast_signed());
        assert!(!is_transient_io_error(&denied));
        // Built from an `ErrorKind`, so it carries no raw OS code at all.
        let not_found = io::Error::from(io::ErrorKind::NotFound);
        assert!(!is_transient_io_error(&not_found));
    }

    /// The backoff sleep stays well under a second even for a large attempt number.
    #[cfg(windows)]
    #[test]
    fn windows_io_backoff_is_bounded() {
        use super::windows_io_backoff;
        use std::time::Instant;
        let start = Instant::now();
        windows_io_backoff(1);
        assert!(start.elapsed() < std::time::Duration::from_secs(1));
        let start = Instant::now();
        windows_io_backoff(100);
        assert!(start.elapsed() < std::time::Duration::from_secs(1));
    }

    /// Two transient failures followed by a success: the helper returns `Ok` after
    /// exactly three calls.
    #[cfg(windows)]
    #[test]
    fn retry_transient_io_returns_first_success() {
        use super::retry_transient_io;
        use std::cell::Cell;
        use std::io;
        use windows_sys::Win32::Foundation::ERROR_LOCK_VIOLATION;
        let calls = Cell::new(0u32);
        let result = retry_transient_io(|| {
            calls.set(calls.get() + 1);
            if calls.get() < 3 {
                Err(io::Error::from_raw_os_error(
                    ERROR_LOCK_VIOLATION.cast_signed(),
                ))
            } else {
                Ok(())
            }
        });
        result.expect("must recover after transient sequence");
        assert_eq!(calls.get(), 3);
    }

    /// A permanently transient failure is attempted `WIN_TRANSIENT_RETRY_LIMIT` times,
    /// then the last error is surfaced.
    #[cfg(windows)]
    #[test]
    fn retry_transient_io_exhausts_and_surfaces_last_error() {
        use super::{retry_transient_io, WIN_TRANSIENT_RETRY_LIMIT};
        use std::cell::Cell;
        use std::io;
        use windows_sys::Win32::Foundation::ERROR_LOCK_VIOLATION;
        let calls = Cell::new(0u32);
        let err = retry_transient_io(|| -> io::Result<()> {
            calls.set(calls.get() + 1);
            Err(io::Error::from_raw_os_error(
                ERROR_LOCK_VIOLATION.cast_signed(),
            ))
        })
        .expect_err("must exhaust");
        assert_eq!(calls.get(), WIN_TRANSIENT_RETRY_LIMIT);
        assert_eq!(err.raw_os_error(), Some(ERROR_LOCK_VIOLATION.cast_signed()));
    }

    /// A non-transient error is returned after a single call, with no retry.
    #[cfg(windows)]
    #[test]
    fn retry_transient_io_returns_non_transient_immediately() {
        use super::retry_transient_io;
        use std::cell::Cell;
        use std::io;
        use windows_sys::Win32::Foundation::ERROR_ACCESS_DENIED;
        let calls = Cell::new(0u32);
        let err = retry_transient_io(|| -> io::Result<()> {
            calls.set(calls.get() + 1);
            Err(io::Error::from_raw_os_error(
                ERROR_ACCESS_DENIED.cast_signed(),
            ))
        })
        .expect_err("must fail");
        assert_eq!(calls.get(), 1);
        assert_eq!(err.raw_os_error(), Some(ERROR_ACCESS_DENIED.cast_signed()));
    }
}