use super::{new_error, FFErr};
use crate::fe::FRes;
use libc::{
c_int, c_void, close, fdatasync, fstat, ftruncate, iovec, off_t, open, pread, preadv, pwrite, pwritev, size_t,
stat, unlink, EACCES, EBADF, EDQUOT, EFAULT, EINVAL, EIO, EISDIR, EMSGSIZE, ENOSPC, EPERM, EROFS, ESPIPE,
O_CLOEXEC, O_CREAT, O_NOATIME, O_RDWR, O_TRUNC, S_IRUSR, S_IWUSR,
};
use std::{
ffi::CString,
io,
os::unix::ffi::OsStrExt,
path::PathBuf,
sync::atomic::{AtomicI32, Ordering},
};
/// Sentinel stored in a [`File`] once its descriptor has been closed.
const CLOSED_FD: i32 = -1;

/// Raw Linux file handle.
///
/// The descriptor lives in an atomic so that `close` can swap in [`CLOSED_FD`] and release
/// the descriptor exactly once, even when called repeatedly or from several threads.
///
/// `Send` and `Sync` are derived automatically from `AtomicI32`, so no `unsafe impl` is
/// required.
///
/// NOTE(review): no `Drop` impl is visible here, so a `File` dropped without calling
/// `close` leaks its descriptor — confirm that every owner closes it explicitly.
pub(crate) struct File(AtomicI32);
impl File {
pub(crate) unsafe fn new(path: &PathBuf, is_new: bool, mid: u8) -> FRes<Self> {
let fd = open_with_flags(path, prep_flags(is_new), mid)?;
Ok(Self(AtomicI32::new(fd)))
}
#[inline]
pub(crate) fn fd(&self) -> i32 {
self.0.load(Ordering::Acquire)
}
#[inline]
pub(super) unsafe fn sync(&self, mid: u8) -> FRes<()> {
debug_assert!(self.fd() != CLOSED_FD, "Invalid fd for LinuxFile");
const MAX_RETRIES: usize = 4;
let mut retries = 0;
loop {
if fdatasync(self.fd() as c_int) != 0 {
let error = last_os_error();
let error_raw = error.raw_os_error();
if error.kind() == io::ErrorKind::Interrupted {
continue;
}
if error_raw == Some(EINVAL) || error_raw == Some(EBADF) {
return new_error(mid, FFErr::Hcf, error);
}
if error_raw == Some(EROFS) {
return new_error(mid, FFErr::Wrt, error);
}
if error_raw == Some(EIO) {
if retries < MAX_RETRIES {
retries += 1;
std::thread::yield_now();
continue;
}
return new_error(mid, FFErr::Syn, error);
}
return new_error(mid, FFErr::Unk, error);
}
return Ok(());
}
}
#[inline(always)]
pub(crate) unsafe fn close(&self, mid: u8) -> FRes<()> {
let fd = self.0.swap(CLOSED_FD, Ordering::AcqRel);
if fd == CLOSED_FD {
return Ok(());
}
if close(fd) != 0 {
let error = last_os_error();
let error_raw = error.raw_os_error();
if error_raw == Some(EIO) {
return new_error(mid, FFErr::Syn, error);
}
return new_error(mid, FFErr::Unk, error);
}
Ok(())
}
#[inline]
pub(super) unsafe fn unlink(&self, path: &PathBuf, mid: u8) -> FRes<()> {
let cpath = path_to_cstring(path, mid)?;
if unlink(cpath.as_ptr()) != 0 {
let error = last_os_error();
return new_error(mid, FFErr::Unk, error);
}
Ok(())
}
#[inline]
pub(crate) unsafe fn length(&self, mid: u8) -> FRes<u64> {
debug_assert!(self.fd() != CLOSED_FD, "Invalid fd for LinuxFile");
let mut st = std::mem::zeroed::<stat>();
let res = fstat(self.fd(), &mut st);
if res != 0 {
let error = last_os_error();
let error_raw = error.raw_os_error();
if error_raw == Some(EBADF) || error_raw == Some(EFAULT) {
return new_error(mid, FFErr::Hcf, error);
}
return new_error(mid, FFErr::Unk, error);
}
Ok(st.st_size as u64)
}
pub(crate) unsafe fn resize(&self, new_len: u64, mid: u8) -> FRes<()> {
debug_assert!(self.fd() != CLOSED_FD, "Invalid fd for LinuxFile");
if ftruncate(self.fd(), new_len as off_t) != 0 {
let error = last_os_error();
let error_raw = error.raw_os_error();
if error_raw == Some(EINVAL) || error_raw == Some(EBADF) {
return new_error(mid, FFErr::Hcf, error);
}
if error_raw == Some(EROFS) {
return new_error(mid, FFErr::Wrt, error);
}
if error_raw == Some(ENOSPC) {
return new_error(mid, FFErr::Nsp, error);
}
return new_error(mid, FFErr::Unk, error);
}
Ok(())
}
#[inline(always)]
pub(super) unsafe fn pread(&self, buf_ptr: *mut u8, offset: usize, len_to_read: usize, mid: u8) -> FRes<()> {
debug_assert_ne!(len_to_read, 0, "invalid length");
debug_assert!(!buf_ptr.is_null(), "invalid buffer pointer");
debug_assert!(self.fd() != CLOSED_FD, "Invalid fd for LinuxFile");
let mut read = 0usize;
while read < len_to_read {
let res = pread(
self.fd(),
buf_ptr.add(read) as *mut c_void,
(len_to_read - read) as size_t,
(offset + read) as i64,
);
if res <= 0 {
let error = std::io::Error::last_os_error();
let error_raw = error.raw_os_error();
if error.kind() == io::ErrorKind::Interrupted {
continue;
}
if res == 0 {
return new_error(mid, FFErr::Eof, error);
}
if error_raw == Some(EACCES) || error_raw == Some(EPERM) {
return new_error(mid, FFErr::Red, error);
}
if error_raw == Some(EINVAL)
|| error_raw == Some(EBADF)
|| error_raw == Some(EFAULT)
|| error_raw == Some(ESPIPE)
{
return new_error(mid, FFErr::Hcf, error);
}
return new_error(mid, FFErr::Unk, error);
}
read += res as usize;
}
Ok(())
}
#[inline(always)]
pub(super) unsafe fn preadv(&self, buf_ptrs: &[*mut u8], offset: usize, buffer_size: usize, mid: u8) -> FRes<()> {
#[cfg(debug_assertions)]
{
debug_assert_ne!(buffer_size, 0, "invalid buffer length");
debug_assert!(self.fd() != CLOSED_FD, "Invalid fd for LinuxFile");
}
let mut consumed = 0usize;
let mut iovecs: Vec<iovec> = buf_ptrs
.iter()
.map(|ptr| iovec {
iov_base: *ptr as *mut c_void,
iov_len: buffer_size,
})
.collect();
while !iovecs.is_empty() {
let res = preadv(
self.fd(),
iovecs.as_ptr(),
iovecs.len() as c_int,
offset as off_t + consumed as off_t,
);
if res <= 0 {
let error = io::Error::last_os_error();
let error_raw = error.raw_os_error();
if error.kind() == io::ErrorKind::Interrupted {
continue;
}
if res == 0 {
return new_error(mid, FFErr::Eof, error);
}
if error_raw == Some(EACCES) || error_raw == Some(EPERM) {
return new_error(mid, FFErr::Red, error);
}
if error_raw == Some(EINVAL)
|| error_raw == Some(EBADF)
|| error_raw == Some(EFAULT)
|| error_raw == Some(ESPIPE)
|| error_raw == Some(EMSGSIZE)
{
return new_error(mid, FFErr::Hcf, error);
}
return new_error(mid, FFErr::Unk, error);
}
let mut remaining = res as usize;
while remaining > 0 {
let iov = &mut iovecs[0];
if remaining >= iov.iov_len {
remaining -= iov.iov_len;
consumed += iov.iov_len;
iovecs.remove(0);
} else {
iov.iov_base = (iov.iov_base as *mut u8).add(remaining) as *mut c_void;
iov.iov_len -= remaining;
consumed += remaining;
remaining = 0;
}
}
}
Ok(())
}
#[inline(always)]
pub(super) unsafe fn pwrite(&self, buf_ptr: *const u8, offset: usize, len_to_write: usize, mid: u8) -> FRes<()> {
debug_assert_ne!(len_to_write, 0, "invalid length");
debug_assert!(!buf_ptr.is_null(), "invalid buffer pointer");
debug_assert!(self.fd() != CLOSED_FD, "Invalid fd for LinuxFile");
let mut written = 0usize;
while written < len_to_write {
let res = pwrite(
self.fd(),
buf_ptr.add(written) as *const c_void,
(len_to_write - written) as size_t,
(offset + written) as i64,
);
if res <= 0 {
let error = std::io::Error::last_os_error();
let error_raw = error.raw_os_error();
if error.kind() == std::io::ErrorKind::Interrupted {
continue;
}
if res == 0 {
return new_error(mid, FFErr::Eof, error);
}
if error_raw == Some(EROFS) {
return new_error(mid, FFErr::Wrt, error);
}
if error_raw == Some(EINVAL)
|| error_raw == Some(EBADF)
|| error_raw == Some(EFAULT)
|| error_raw == Some(ESPIPE)
{
return new_error(mid, FFErr::Hcf, error);
}
return new_error(mid, FFErr::Unk, error);
}
written += res as usize;
}
Ok(())
}
#[inline(always)]
pub(super) unsafe fn pwritev(
&self,
buf_ptrs: &[*const u8],
offset: usize,
buffer_size: usize,
mid: u8,
) -> FRes<()> {
#[cfg(debug_assertions)]
{
debug_assert_ne!(buffer_size, 0, "invalid buffer length");
debug_assert!(self.fd() != CLOSED_FD, "Invalid fd for LinuxFile");
}
let mut consumed = 0usize;
let mut iovecs: Vec<iovec> = buf_ptrs
.iter()
.map(|ptr| iovec {
iov_base: *ptr as *mut c_void,
iov_len: buffer_size,
})
.collect();
while !iovecs.is_empty() {
let res = pwritev(
self.fd(),
iovecs.as_ptr(),
iovecs.len() as c_int,
offset as off_t + consumed as off_t,
);
if res <= 0 {
let error = std::io::Error::last_os_error();
let error_raw = error.raw_os_error();
if error.kind() == std::io::ErrorKind::Interrupted {
continue;
}
if res == 0 {
return new_error(mid, FFErr::Eof, error);
}
if error_raw == Some(EROFS) {
return new_error(mid, FFErr::Wrt, error);
}
if error_raw == Some(ENOSPC) || error_raw == Some(EDQUOT) {
return new_error(mid, FFErr::Nsp, error);
}
if error_raw == Some(EINVAL)
|| error_raw == Some(EBADF)
|| error_raw == Some(EFAULT)
|| error_raw == Some(ESPIPE)
|| error_raw == Some(EMSGSIZE)
{
return new_error(mid, FFErr::Hcf, error);
}
return new_error(mid, FFErr::Unk, error);
}
let mut remaining = res as usize;
while remaining > 0 {
let iov = &mut iovecs[0];
if remaining >= iov.iov_len {
remaining -= iov.iov_len;
consumed += iov.iov_len;
iovecs.remove(0);
} else {
iov.iov_base = (iov.iov_base as *mut u8).add(remaining) as *mut c_void;
iov.iov_len -= remaining;
consumed += remaining;
remaining = 0;
}
}
}
Ok(())
}
}
unsafe fn open_with_flags(path: &PathBuf, mut flags: i32, mid: u8) -> FRes<i32> {
let cpath = path_to_cstring(path, mid)?;
let mut tried_noatime = false;
loop {
let fd = if flags & O_CREAT != 0 {
open(
cpath.as_ptr(),
flags,
S_IRUSR | S_IWUSR, )
} else {
open(cpath.as_ptr(), flags)
};
if fd < 0 {
let error = last_os_error();
let err_raw = error.raw_os_error();
if error.kind() == io::ErrorKind::Interrupted {
continue;
}
if err_raw == Some(EPERM) && (flags & O_NOATIME) != 0 && !tried_noatime {
flags &= !O_NOATIME;
tried_noatime = true;
continue;
}
if err_raw == Some(ENOSPC) {
return new_error(mid, FFErr::Nsp, error);
}
if err_raw == Some(EISDIR) {
return new_error(mid, FFErr::Hcf, error);
}
return new_error(mid, FFErr::Unk, error);
}
return Ok(fd);
}
}
/// Computes the `open(2)` flags for a data file.
///
/// Every file is opened read/write with `O_NOATIME | O_CLOEXEC`. A new file additionally
/// gets `O_CREAT | O_TRUNC`, so it exists and starts out empty.
const fn prep_flags(is_new: bool) -> i32 {
    let common = O_RDWR | O_NOATIME | O_CLOEXEC;
    if is_new {
        common | O_CREAT | O_TRUNC
    } else {
        common
    }
}
fn path_to_cstring(path: &std::path::PathBuf, mid: u8) -> FRes<CString> {
match CString::new(path.as_os_str().as_bytes()) {
Ok(cs) => Ok(cs),
Err(e) => {
let error = io::Error::new(io::ErrorKind::Other, e.to_string());
new_error(mid, FFErr::Inv, error)
}
}
}
/// Captures the calling thread's current `errno` as an [`io::Error`].
///
/// Call it immediately after the failing libc call, before anything else can overwrite
/// `errno`.
#[inline]
fn last_os_error() -> io::Error {
    std::io::Error::last_os_error()
}
#[cfg(all(test, target_os = "linux"))]
mod tests {
    use super::*;
    use crate::fe::FECheckOk;
    use std::path::PathBuf;
    use tempfile::{tempdir, TempDir};

    const MID: u8 = 0x00;

    /// Creates a fresh, empty file inside a temp dir; keep the `TempDir` alive while using it.
    fn new_tmp() -> (TempDir, PathBuf, File) {
        let dir = tempdir().expect("temp dir");
        let tmp = dir.path().join("tmp_file");
        let file = unsafe { File::new(&tmp, true, MID) }.expect("new LinuxFile");
        (dir, tmp, file)
    }

    mod new_open {
        use super::*;

        #[test]
        fn new_works() {
            let (_dir, tmp, file) = new_tmp();
            assert!(file.fd() >= 0);
            assert!(tmp.exists());
            assert!(unsafe { file.close(MID).check_ok() });
        }

        #[test]
        fn open_works() {
            let (_dir, tmp, file) = new_tmp();
            unsafe {
                assert!(file.fd() >= 0);
                assert!(file.close(MID).check_ok());
                match File::new(&tmp, false, MID) {
                    Ok(file) => {
                        assert!(file.fd() >= 0);
                        assert!(file.close(MID).check_ok());
                    }
                    Err(e) => panic!("failed to open file due to E: {:?}", e),
                }
            }
        }

        #[test]
        fn open_fails_when_file_is_unlinked() {
            let (_dir, tmp, file) = new_tmp();
            unsafe {
                assert!(file.fd() >= 0);
                assert!(file.close(MID).check_ok());
                assert!(file.unlink(&tmp, MID).check_ok());
                let file = File::new(&tmp, false, MID);
                assert!(file.is_err());
            }
        }
    }

    mod close {
        use super::*;

        #[test]
        fn close_works() {
            let (_dir, tmp, file) = new_tmp();
            unsafe {
                assert!(file.close(MID).check_ok());
                assert!(tmp.exists());
            }
        }

        #[test]
        fn close_after_close_does_not_fail() {
            let (_dir, tmp, file) = new_tmp();
            unsafe {
                assert!(file.close(MID).check_ok());
                assert!(file.close(MID).check_ok());
                assert!(file.close(MID).check_ok());
                assert!(tmp.exists());
            }
        }
    }

    mod sync {
        use super::*;

        #[test]
        fn sync_works() {
            let (_dir, _tmp, file) = new_tmp();
            unsafe {
                assert!(file.sync(MID).check_ok());
                assert!(file.close(MID).check_ok());
            }
        }
    }

    mod unlink {
        use super::*;

        #[test]
        fn unlink_correctly_deletes_file() {
            let (_dir, tmp, file) = new_tmp();
            unsafe {
                assert!(file.close(MID).check_ok());
                assert!(file.unlink(&tmp, MID).check_ok());
                assert!(!tmp.exists());
            }
        }

        #[test]
        fn unlink_fails_on_unlinked_file() {
            let (_dir, tmp, file) = new_tmp();
            unsafe {
                assert!(file.close(MID).check_ok());
                assert!(file.unlink(&tmp, MID).check_ok());
                assert!(!tmp.exists());
                assert!(file.unlink(&tmp, MID).is_err());
            }
        }
    }

    mod resize {
        use super::*;

        #[test]
        fn extend_zero_extends_file() {
            const NEW_LEN: u64 = 0x80;
            let (_dir, tmp, file) = new_tmp();
            unsafe {
                assert!(file.resize(NEW_LEN, MID).check_ok());
                let curr_len = file.length(MID).expect("fetch metadata");
                assert_eq!(curr_len, NEW_LEN);
                assert!(file.close(MID).check_ok());
            }
            let file_contents = std::fs::read(&tmp).expect("read from file");
            assert_eq!(file_contents.len(), NEW_LEN as usize, "len mismatch for file");
            assert!(
                file_contents.iter().all(|b| *b == 0u8),
                "file must be zero byte extended"
            );
        }

        #[test]
        fn open_preserves_existing_length() {
            const NEW_LEN: u64 = 0x80;
            let (_dir, tmp, file) = new_tmp();
            unsafe {
                assert!(file.resize(NEW_LEN, MID).check_ok());
                let curr_len = file.length(MID).expect("fetch metadata");
                assert_eq!(curr_len, NEW_LEN);
                assert!(file.sync(MID).check_ok());
                assert!(file.close(MID).check_ok());
                match File::new(&tmp, false, MID) {
                    Err(e) => panic!("{:?}", e),
                    Ok(file) => {
                        let curr_len = file.length(MID).expect("fetch metadata");
                        assert_eq!(curr_len, NEW_LEN);
                        // Release the reopened descriptor instead of leaking it.
                        assert!(file.close(MID).check_ok());
                    }
                }
            }
        }
    }

    mod write_read {
        use super::*;

        #[test]
        fn pwrite_pread_cycle() {
            let (_dir, _tmp, file) = new_tmp();
            const LEN: usize = 0x20;
            const DATA: [u8; LEN] = [0x1A; LEN];
            unsafe {
                file.resize(LEN as u64, MID).expect("resize file");
                assert!(file.pwrite(DATA.as_ptr(), 0, LEN, MID).check_ok());
                let mut buf = vec![0u8; LEN];
                assert!(file.pread(buf.as_mut_ptr(), 0, LEN, MID).check_ok());
                assert_eq!(DATA.to_vec(), buf, "mismatch between read and write");
                assert!(file.close(MID).check_ok());
            }
        }

        #[test]
        fn pwritev_pread_cycle() {
            let (_dir, _tmp, file) = new_tmp();
            const LEN: usize = 0x20;
            const DATA: [u8; LEN] = [0x1A; LEN];
            let ptrs = vec![DATA.as_ptr(); 0x10];
            let total_len = ptrs.len() * LEN;
            unsafe {
                file.resize(total_len as u64, MID).expect("resize file");
                assert!(file.pwritev(&ptrs, 0, LEN, MID).check_ok());
                let mut buf = vec![0u8; total_len];
                assert!(
                    file.pread(buf.as_mut_ptr(), 0, total_len, MID).check_ok(),
                    "pread failed"
                );
                for chunk in buf.chunks_exact(LEN) {
                    assert_eq!(chunk, DATA, "data mismatch in pwritev readback");
                }
                assert!(file.close(MID).check_ok());
            }
        }

        #[test]
        fn pwritev_preadv_cycle() {
            let (_dir, _tmp, file) = new_tmp();
            const LEN: usize = 0x20;
            const DATA: [u8; LEN] = [0x1A; LEN];
            let ptrs = vec![DATA.as_ptr(); 0x10];
            let total_len = ptrs.len() * LEN;
            unsafe {
                file.resize(total_len as u64, MID).expect("resize file");
                assert!(file.pwritev(&ptrs, 0, LEN, MID).check_ok());
                let mut bufs: Vec<Vec<u8>> = (0..ptrs.len()).map(|_| vec![0u8; LEN]).collect();
                let buf_ptrs: Vec<*mut u8> = bufs.iter_mut().map(|b| b.as_mut_ptr()).collect();
                assert!(file.preadv(&buf_ptrs, 0, LEN, MID).check_ok(), "preadv failed");
                for buf in bufs {
                    assert_eq!(buf, DATA, "data mismatch in pwritev/preadv cycle");
                }
                assert!(file.close(MID).check_ok());
            }
        }

        #[test]
        fn pwrite_pread_cycle_across_sessions() {
            let (_dir, tmp, file) = new_tmp();
            const LEN: usize = 0x20;
            const DATA: [u8; LEN] = [0x1A; LEN];
            unsafe {
                assert!(file.resize(LEN as u64, MID).check_ok());
                assert!(file.pwrite(DATA.as_ptr(), 0, LEN, MID).check_ok());
                assert!(file.sync(MID).check_ok());
                assert!(file.close(MID).check_ok());
            }
            unsafe {
                let mut buf = vec![0u8; LEN];
                let file = File::new(&tmp, false, MID).expect("open file");
                assert!(file.pread(buf.as_mut_ptr(), 0, LEN, MID).check_ok());
                assert_eq!(DATA.to_vec(), buf, "mismatch between read and write");
                assert!(file.close(MID).check_ok());
            }
        }

        #[test]
        fn pread_past_eof_fails() {
            let (_dir, _tmp, file) = new_tmp();
            const LEN: usize = 0x20;
            unsafe {
                // A freshly created file is empty, so any read runs into EOF.
                let mut buf = vec![0u8; LEN];
                assert!(file.pread(buf.as_mut_ptr(), 0, LEN, MID).is_err());
                assert!(file.close(MID).check_ok());
            }
        }

        #[test]
        fn preadv_past_eof_fails() {
            let (_dir, _tmp, file) = new_tmp();
            const LEN: usize = 0x20;
            unsafe {
                // Only the first buffer can be filled; the second one hits EOF.
                assert!(file.resize(LEN as u64, MID).check_ok());
                let mut bufs = vec![vec![0u8; LEN]; 2];
                let ptrs: Vec<*mut u8> = bufs.iter_mut().map(|b| b.as_mut_ptr()).collect();
                assert!(file.preadv(&ptrs, 0, LEN, MID).is_err());
                assert!(file.close(MID).check_ok());
            }
        }
    }

    mod concurrency {
        use super::*;

        #[test]
        fn concurrent_writes_then_read() {
            const THREADS: usize = 8;
            const CHUNK: usize = 0x100;
            let (_dir, _tmp, file) = new_tmp();
            let file = std::sync::Arc::new(file);
            unsafe { file.resize((THREADS * CHUNK) as u64, MID).expect("extend") };
            let handles: Vec<_> = (0..THREADS)
                .map(|i| {
                    let f = std::sync::Arc::clone(&file);
                    std::thread::spawn(move || {
                        let data = vec![i as u8; CHUNK];
                        unsafe { f.pwrite(data.as_ptr(), i * CHUNK, CHUNK, MID).expect("write") };
                    })
                })
                .collect();
            for h in handles {
                h.join().expect("writer thread panicked");
            }
            let mut read_buf = vec![0u8; THREADS * CHUNK];
            unsafe { assert!(file.pread(read_buf.as_mut_ptr(), 0, read_buf.len(), MID).check_ok()) };
            for (i, chunk) in read_buf.chunks_exact(CHUNK).enumerate() {
                assert!(chunk.iter().all(|b| *b == i as u8));
            }
            // All writer threads have joined, so no I/O can race with this close.
            unsafe { assert!(file.close(MID).check_ok()) };
        }
    }
}