#![warn(missing_docs)]
#![warn(clippy::undocumented_unsafe_blocks)]
use rustix::{
fs::{AtFlags, Mode, OFlags, copy_file_range, ftruncate, ioctl_ficlone, linkat, open, rename},
io::Errno,
mm::{MapFlags, MremapFlags, MsyncFlags, ProtFlags, mmap, mremap, msync, munmap},
};
use std::{
ffi::c_void,
fs::File,
io,
ops::{Deref, DerefMut},
os::fd::AsFd,
path::{Path, PathBuf},
};
/// Copies the contents of `fd_in` into `fd_out`, preferring a reflink.
///
/// First tries `FICLONE`. If the filesystem answers `EOPNOTSUPP`, `fd_out` is sized to `len`
/// bytes and filled with the first `len` bytes of `fd_in` using `copy_file_range`.
///
/// Returns `Ok(false)` when the reflink succeeded and `Ok(true)` when the copy fallback was
/// used.
///
/// # Errors
///
/// Any other `FICLONE` error, any `ftruncate`/`copy_file_range` error, and
/// [`io::ErrorKind::UnexpectedEof`] if `fd_in` turns out to hold fewer than `len` bytes.
fn ficlone(fd_out: impl AsFd, fd_in: impl AsFd, len: usize) -> io::Result<bool> {
    match ioctl_ficlone(&fd_out, &fd_in) {
        Ok(()) => return Ok(false),
        Err(Errno::OPNOTSUPP) => {}
        Err(e) => return Err(e.into()),
    }

    let total = len as u64;
    ftruncate(&fd_out, total)?;
    let (mut read_pos, mut write_pos) = (0u64, 0u64);
    while read_pos < total {
        let remaining = len - read_pos as usize;
        // Both offsets are advanced by the kernel by the number of bytes copied.
        let copied = copy_file_range(
            &fd_in,
            Some(&mut read_pos),
            &fd_out,
            Some(&mut write_pos),
            remaining,
        )?;
        assert_eq!(read_pos, write_pos);
        assert!(
            copied <= remaining,
            "copy_file_range() copied more bytes than requested"
        );
        if copied == 0 {
            // The source ended before `len` bytes were copied.
            return Err(io::ErrorKind::UnexpectedEof.into());
        }
    }
    assert_eq!(write_pos, total);
    Ok(true)
}
/// A read-only memory map of a private snapshot of a file, created by [`Mmap::open`].
///
/// Dereferences to the snapshot's bytes. The mapping is backed by an unnamed temporary
/// copy of the file rather than the file itself. Overwriting or deleting the original
/// therefore does not change what this map shows.
pub struct Mmap {
    /// Start of the mapping; null exactly when `len` is 0.
    ptr: *mut c_void,
    /// Length of the mapping in bytes.
    len: usize,
}

// SAFETY: `Mmap` exclusively owns its mapping, which may be read and unmapped from any
// thread; the raw pointer is the only thing preventing the auto impl.
unsafe impl Send for Mmap {}

// SAFETY: the mapping is created with `PROT_READ` only and is exposed solely as `&[u8]`,
// so concurrent shared access cannot produce a data race.
unsafe impl Sync for Mmap {}
impl Mmap {
pub fn open(path: impl AsRef<Path>) -> io::Result<Self> {
let path = path.as_ref();
let original = File::open(path)?;
let len = original.metadata()?.len() as usize;
if len >= isize::MAX as usize {
return Err(io::ErrorKind::FileTooLarge.into());
}
let dir = path.parent().filter(|x| *x != "").unwrap_or(Path::new("."));
let private: File =
open(dir, OFlags::TMPFILE | OFlags::RDWR, Mode::RUSR | Mode::WUSR)?.into();
ficlone(&private, &original, len)?;
let ptr;
if len == 0 {
ptr = std::ptr::null_mut();
} else {
unsafe {
ptr = mmap(
std::ptr::null_mut(),
len,
ProtFlags::READ,
MapFlags::SHARED,
&private,
0,
)?;
}
};
assert!(ptr.is_null() == (len == 0));
Ok(Self { ptr, len })
}
}
impl Deref for Mmap {
    type Target = [u8];

    /// Views the snapshot as a byte slice; an empty snapshot yields an empty slice.
    fn deref(&self) -> &[u8] {
        match self.len {
            // `from_raw_parts` needs a non-null pointer even for empty slices, so the null
            // sentinel used for empty snapshots must not reach it.
            0 => &[],
            len => {
                // SAFETY: for a non-zero length `ptr` is a live, readable mapping of exactly
                // `len` bytes owned by `self`. It stays mapped until `self` is dropped, which
                // the returned borrow cannot outlive.
                unsafe { core::slice::from_raw_parts(self.ptr.cast::<u8>(), len) }
            }
        }
    }
}
impl Drop for Mmap {
    /// Unmaps the snapshot. A failure is reported on stderr because `drop` cannot return it.
    fn drop(&mut self) {
        if self.len == 0 {
            // Nothing was mapped for an empty snapshot.
            return;
        }
        // SAFETY: `ptr`/`len` describe the mapping created in `Mmap::open`, and no borrow
        // handed out by `Deref` can outlive `self`.
        let result = unsafe { munmap(self.ptr, self.len) };
        if let Err(e) = result {
            eprintln!("munmap failed: {e}");
        }
    }
}
/// A writable memory map of a private snapshot of a file.
///
/// Writes through the mapping go to an unnamed temporary copy, not to the original file.
/// They become visible there only when committed or linked into place.
pub struct MmapMut {
    /// How changes are published back to the original file.
    original: OriginalFile,
    /// The unnamed temporary file holding the snapshot that is mapped.
    private: File,
    /// Start of the mapping; null exactly when `len` is 0.
    ptr: *mut c_void,
    /// Length of the mapping (and of `private`) in bytes.
    len: usize,
}

/// Where committed changes of an [`MmapMut`] are written.
enum OriginalFile {
    /// The open original file; commits reflink the snapshot into it.
    Fd(File),
    /// The original's path; used when reflinks are unsupported, so commits replace the
    /// file at this path with a copy of the snapshot.
    Path(PathBuf),
}

// SAFETY: `MmapMut` exclusively owns its mapping and file handles, all of which may be
// used and released from any thread; the raw pointer only suppresses the auto impl.
unsafe impl Send for MmapMut {}

// SAFETY: mutation of the mapping is only reachable through `&mut self`, so sharing
// `&MmapMut` across threads allows reads only.
unsafe impl Sync for MmapMut {}
impl MmapMut {
pub fn open(path: impl AsRef<Path>) -> io::Result<Self> {
let path = path.as_ref();
let original = File::options().read(true).write(true).open(path)?;
let len = original.metadata()?.len() as usize;
if len >= isize::MAX as usize {
return Err(io::ErrorKind::FileTooLarge.into());
}
let dir = path.parent().filter(|x| *x != "").unwrap_or(Path::new("."));
let private: File =
open(dir, OFlags::TMPFILE | OFlags::RDWR, Mode::RUSR | Mode::WUSR)?.into();
let fellback = ficlone(&private, &original, len)?;
let ptr;
if len == 0 {
ptr = std::ptr::null_mut();
} else {
unsafe {
ptr = mmap(
std::ptr::null_mut(),
len,
ProtFlags::READ | ProtFlags::WRITE,
MapFlags::SHARED,
&private,
0,
)?;
}
};
assert!(ptr.is_null() == (len == 0));
Ok(Self {
private,
ptr,
len,
original: if fellback {
OriginalFile::Path(path.to_owned())
} else {
OriginalFile::Fd(original)
},
})
}
pub fn create(path: impl AsRef<Path>) -> io::Result<Self> {
let path = path.as_ref();
let original = File::create(path)?;
let dir = path.parent().filter(|x| *x != "").unwrap_or(Path::new("."));
let private: File =
open(dir, OFlags::TMPFILE | OFlags::RDWR, Mode::RUSR | Mode::WUSR)?.into();
let ptr = std::ptr::null_mut();
let len = 0;
let fellback = matches!(ioctl_ficlone(&private, &original), Err(Errno::OPNOTSUPP));
Ok(Self {
private,
ptr,
len,
original: if fellback {
OriginalFile::Path(path.to_owned())
} else {
OriginalFile::Fd(original)
},
})
}
pub fn commit(&mut self) -> io::Result<()> {
self.sync()?;
match &self.original {
OriginalFile::Fd(original) => ioctl_ficlone(original, &self.private)?,
OriginalFile::Path(path) => {
let dir = path.parent().filter(|x| *x != "").unwrap_or(Path::new("."));
let private2: File =
open(dir, OFlags::TMPFILE | OFlags::RDWR, Mode::RUSR | Mode::WUSR)?.into();
ficlone(&private2, &self.private, self.len)?;
link(&private2, path)?;
}
}
Ok(())
}
pub fn commit_and_close(self) -> io::Result<()> {
self.sync()?;
match &self.original {
OriginalFile::Fd(original) => ioctl_ficlone(original, &self.private)?,
OriginalFile::Path(path) => {
link(&self.private, path)?;
}
}
Ok(())
}
pub fn link(self, path: impl AsRef<Path>) -> io::Result<()> {
self.sync()?;
link(&self.private, path.as_ref())?;
Ok(())
}
fn sync(&self) -> io::Result<()> {
if self.len != 0 {
unsafe {
msync(self.ptr, self.len, MsyncFlags::SYNC)?;
}
}
Ok(())
}
pub fn resize(&mut self, new_len: usize) -> io::Result<()> {
if new_len >= isize::MAX as usize {
return Err(io::ErrorKind::FileTooLarge.into());
}
if new_len == self.len {
return Ok(());
}
ftruncate(&self.private, new_len as u64)?;
if new_len == 0 {
unsafe {
munmap(self.ptr, self.len)?;
}
self.ptr = std::ptr::null_mut();
} else if self.len == 0 {
unsafe {
self.ptr = mmap(
std::ptr::null_mut(),
new_len,
ProtFlags::READ | ProtFlags::WRITE,
MapFlags::SHARED,
&self.private,
0,
)?;
}
} else {
unsafe {
self.ptr = mremap(self.ptr, self.len, new_len, MremapFlags::MAYMOVE)?;
}
}
self.len = new_len;
assert!(self.ptr.is_null() == (self.len == 0));
Ok(())
}
}
fn link(fd: &File, path: &Path) -> io::Result<()> {
let mut tmppath = path.with_added_extension(".tmp");
loop {
match linkat(fd, "", rustix::fs::CWD, &tmppath, AtFlags::EMPTY_PATH) {
Ok(()) => {
rename(tmppath, path)?;
break; }
Err(Errno::EXIST) => {
tmppath = tmppath.with_added_extension(".tmp");
}
Err(e) => Err(e)?,
}
}
Ok(())
}
impl Deref for MmapMut {
    type Target = [u8];

    /// Views the snapshot as a byte slice; an empty snapshot yields an empty slice.
    fn deref(&self) -> &[u8] {
        let len = self.len;
        if len == 0 {
            // The null sentinel for empty snapshots must never reach `from_raw_parts`.
            return &[];
        }
        // SAFETY: for a non-zero length `ptr` is a live mapping of exactly `len` readable
        // bytes owned by `self`. It cannot be unmapped or moved while this shared borrow
        // of `self` is alive.
        unsafe { core::slice::from_raw_parts(self.ptr.cast::<u8>(), len) }
    }
}
impl DerefMut for MmapMut {
    /// Views the snapshot as a mutable byte slice. Writes land in the private snapshot,
    /// not in the original file.
    fn deref_mut(&mut self) -> &mut [u8] {
        match self.len {
            // The null sentinel for empty snapshots must never reach `from_raw_parts_mut`.
            0 => &mut [],
            len => {
                // SAFETY: for a non-zero length `ptr` is a live, writable mapping of exactly
                // `len` bytes owned by `self`. The `&mut self` borrow guarantees this slice
                // is the only access to it for its lifetime.
                unsafe { core::slice::from_raw_parts_mut(self.ptr.cast::<u8>(), len) }
            }
        }
    }
}
impl Drop for MmapMut {
    /// Unmaps the snapshot without committing it. A failure is reported on stderr because
    /// `drop` cannot return it.
    fn drop(&mut self) {
        if self.len == 0 {
            // An empty snapshot has no mapping.
            return;
        }
        // SAFETY: `ptr`/`len` describe the mapping last set up by `open` or `resize`, and no
        // slice handed out by `Deref`/`DerefMut` can outlive `self`.
        let result = unsafe { munmap(self.ptr, self.len) };
        if let Err(e) = result {
            eprintln!("munmap failed: {e}");
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
// Yields `name` inside a `mmap-snapshot` directory under both `/tmp` and `/var/tmp`,
// creating each directory on demand.
// NOTE(review): the two roots are presumably meant to land on different filesystems
// (often tmpfs vs. a disk filesystem) so that both the reflink and the copy fallback
// get exercised; which path actually runs depends on the host.
fn paths(name: &str) -> impl Iterator<Item = PathBuf> {
["/tmp", "/var/tmp"].into_iter().map(move |d| {
let d = Path::new(d).join("mmap-snapshot");
std::fs::create_dir_all(&d).unwrap();
d.join(name)
})
}
// A read-only snapshot must keep its contents after the file is overwritten and then
// deleted.
#[test]
fn mmap() -> std::io::Result<()> {
for p in paths("mmap") {
std::fs::write(&p, b"Hello world!")?;
let f = Mmap::open(&p)?;
std::fs::write(&p, b"Goodbye world!")?;
assert_eq!(&*f, b"Hello world!");
std::fs::remove_file(&p)?;
assert_eq!(&*f, b"Hello world!");
}
Ok(())
}
// Edits to a writable snapshot stay invisible in the file until `commit`, and remain in
// the file after the snapshot is dropped.
#[test]
fn mmap_mut() -> std::io::Result<()> {
for p in paths("mmap_mut") {
std::fs::write(&p, b"Hello world!")?;
let mut f = MmapMut::open(&p)?;
assert_eq!(&*f, b"Hello world!");
f[6..11].copy_from_slice(b"sekai");
assert_eq!(&*f, b"Hello sekai!");
// Not committed yet: the file still has the old contents.
assert_eq!(std::fs::read_to_string(&p)?, "Hello world!");
f.commit()?;
std::mem::drop(f);
assert_eq!(std::fs::read_to_string(&p)?, "Hello sekai!");
std::fs::remove_file(&p)?;
}
Ok(())
}
// An empty file maps to an empty slice, before and after the file is deleted.
#[test]
fn zero_len() -> std::io::Result<()> {
for p in paths("zero_len") {
File::create(&p)?;
let f = Mmap::open(&p)?;
assert_eq!(&*f, b"");
std::fs::remove_file(&p)?;
assert_eq!(&*f, b"");
}
Ok(())
}
// Grows an empty writable snapshot with `resize`, then commits twice on the same
// `MmapMut`. This checks that edits made after a commit are again private until the next
// commit.
#[test]
fn zero_len_mut() -> std::io::Result<()> {
for p in paths("zero_len_mut") {
File::create(&p)?;
let mut f = MmapMut::open(&p)?;
assert_eq!(&*f, b"");
f.resize(12)?;
f.copy_from_slice(b"Hello world!");
assert_eq!(std::fs::read_to_string(&p)?, "");
f.commit()?;
assert_eq!(std::fs::read_to_string(&p)?, "Hello world!");
// Edits after the first commit must not leak into the file before the second one.
f[6..11].copy_from_slice(b"sekai");
assert_eq!(&*f, b"Hello sekai!");
assert_eq!(std::fs::read_to_string(&p)?, "Hello world!");
f.commit()?;
std::mem::drop(f);
assert_eq!(std::fs::read_to_string(&p)?, "Hello sekai!");
std::fs::remove_file(&p)?;
}
Ok(())
}
}