use std::marker::PhantomData;
use std::{io, slice};
use std::fs::{OpenOptions, File};
use std::path::Path;
use std::mem;
use std::io::{Read, Write};
use memmap::MmapMut;
use fs2::FileExt;
/// Version of the on-disk persistence format. `MmapedVec::try_new` stores it in
/// the `persistence_format_version` field of the header it writes to new files.
const PERSISTENCE_FORMAT_VERSION: [u8; 3] = [0, 0, 5];
/// Header stored at the very beginning of a file backing a `MmapedVec<T>`.
///
/// The struct is `repr(C, packed)`, so the fields are laid out in declaration
/// order without padding between them; `MmapedVec::try_new` writes the raw
/// bytes of this struct to a newly created file.
///
/// NOTE(review): when opening an existing file, `try_new` currently validates
/// only `magic_bytes` and `endianness`; the remaining fields are written but
/// not checked in the code visible here.
#[repr(C, packed)]
struct FileHeader<T>
{
/// Caller-supplied file-type identifier, compared on every open.
magic_bytes: [u8; 8],
/// Written as `0x1234` in native byte order; a byte-swapped value read back
/// from a file is reported as "Wrong endianness".
endianness: u16,
/// Set to `PERSISTENCE_FORMAT_VERSION` when the header is created.
persistence_format_version: [u8; 3],
/// Caller-supplied version of the data stored after the header.
data_contained_version: [u8; 3],
/// `T::default()` as of the moment the file was created.
default_data: T,
/// Number of bytes following the header so that header plus padding is a
/// multiple of 4096 bytes.
number_of_padding_bytes_after_header: u16,
}
/// A file-backed, memory-mapped container of `T` values.
///
/// Instances are created with `MmapedVec::try_new`, which opens (or creates)
/// the backing file, takes an exclusive lock on it, validates or writes the
/// `FileHeader<T>` and maps the file into memory.
pub struct MmapedVec<T>
{
// Handle on which `try_new` acquired the exclusive lock. The tests in this
// file expect the lock to be held while the value is alive and to be gone
// after it is dropped.
file: File,
// Mutable memory map of the whole backing file.
mm: MmapMut,
// Ties the otherwise unused type parameter `T` to the struct.
_marker: PhantomData<T>,
}
impl<T: Sized + Default> MmapedVec<T>
{
    /// Opens the file at `path` (creating it if necessary), locks it
    /// exclusively and maps it into memory.
    ///
    /// * `path` - location of the backing file.
    /// * `magic_bytes` - file-type identifier; written into the header of a new
    ///   file and compared against the header of an existing one.
    /// * `data_contained_version` - caller-defined version of the stored data;
    ///   written into the header of a new file.
    ///
    /// An empty file is initialised with a `FileHeader<T>` followed by zero
    /// padding up to the next multiple of 4096 bytes. For a non-empty file the
    /// magic bytes and the endianness marker of the stored header are checked.
    ///
    /// # Errors
    ///
    /// Returns the underlying I/O error if the file cannot be opened, locked,
    /// inspected, read, written, resized or mapped. Returns an error of kind
    /// `InvalidData` if
    ///
    /// * the file is non-empty but shorter than the header,
    /// * the magic bytes differ from `magic_bytes`,
    /// * the endianness marker is byte-swapped or otherwise invalid, or
    /// * the bytes following header and padding are not a whole number of
    ///   `T` values.
    pub fn try_new (path: &Path, magic_bytes: [u8; 8], data_contained_version: [u8; 3]) -> io::Result<Self>
    {
        // Value stored in the `endianness` header field, in native byte order.
        const ENDIANNESS_MARKER: u16 = 0x1234;
        // Header plus padding is rounded up to a multiple of this many bytes.
        const HEADER_ALIGNMENT: usize = 4096;

        // Builds the `InvalidData` error used for every validation failure.
        fn invalid_data (message: String) -> io::Error
        {
            io::Error::new(io::ErrorKind::InvalidData, message)
        }

        let mut file = OpenOptions::new().read(true).write(true).create(true).open(path)?;
        file.try_lock_exclusive()?;

        let fhs = mem::size_of::<FileHeader<T>>();
        // The remainder is below 4096, so the padding always fits into a `u16`.
        let number_of_padding_bytes_after_header = match fhs % HEADER_ALIGNMENT
        {
            0 => 0,
            remainder => (HEADER_ALIGNMENT - remainder) as u16,
        };
        let len_fh_and_padding = fhs as u64 + u64::from(number_of_padding_bytes_after_header);
        let flen = file.metadata()?.len();

        if flen == 0
        {
            // Brand-new (or empty) file: write a fresh header and extend the
            // file so that it also covers the padding.
            let fh: FileHeader<T> = FileHeader
            {
                magic_bytes,
                endianness: ENDIANNESS_MARKER,
                persistence_format_version: PERSISTENCE_FORMAT_VERSION,
                data_contained_version,
                default_data: T::default(),
                number_of_padding_bytes_after_header,
            };
            // SAFETY: `fh` is a live value of exactly `fhs` bytes and the slice
            // does not outlive it. NOTE(review): this assumes `T` contains no
            // uninitialised padding bytes; nothing in the bounds enforces that.
            let buf = unsafe
            {
                slice::from_raw_parts(&fh as *const FileHeader<T> as *const u8, fhs)
            };
            // `write_all` instead of `write`: a short write would otherwise
            // leave a silently truncated header behind.
            file.write_all(buf)?;
            file.set_len(len_fh_and_padding)?;
        }
        else if flen < fhs as u64
        {
            return Err(invalid_data(
                format!("File `{:?}` has non-zero size ({} bytes), but it is shorter than \
                    the expected header size ({} bytes).", path, flen, fhs)));
        }
        else
        {
            let mut fh_buf = vec![0u8; fhs];
            // `read_exact` guarantees the whole header was read; a plain `read`
            // may legally return fewer bytes.
            file.read_exact(fh_buf.as_mut_slice())?;

            let on_disk = fh_buf.as_ptr() as *const FileHeader<T>;
            // SAFETY: `fh_buf` holds `fhs` initialised bytes, the struct is
            // packed (alignment 1), and only the two plain-integer fields are
            // copied out. No `T` is ever materialised from the untrusted file
            // contents, so nothing invalid is created or dropped.
            let (file_magic_bytes, file_endianness) =
                unsafe { ((*on_disk).magic_bytes, (*on_disk).endianness) };

            if file_magic_bytes != magic_bytes
            {
                return Err(invalid_data(
                    format!("File `{:?}`: Magic bytes mismatch.", path)));
            }
            if file_endianness != ENDIANNESS_MARKER
            {
                // A marker that matches after swapping its bytes means the file
                // was written on a machine with the opposite byte order;
                // anything else is garbage.
                let message = if file_endianness.swap_bytes() != ENDIANNESS_MARKER
                {
                    format!("File `{:?}`: Endianness-marker invalid.", path)
                }
                else
                {
                    format!("File `{:?}`: Wrong endianness.", path)
                };
                return Err(invalid_data(message));
            }
        }

        // NOTE(review): a file that contains a complete header but is shorter
        // than header + padding is accepted unchanged here (the original code
        // had an empty branch for this case). Decide whether such a file should
        // be rejected or re-extended.

        if flen > len_fh_and_padding
        {
            let body_len = flen - len_fh_and_padding;
            let element_size = mem::size_of::<T>() as u64;
            // `checked_rem` yields `None` for zero-sized `T`; a non-empty body
            // can never be a whole number of zero-sized values, so that case is
            // rejected as well instead of panicking on a division by zero.
            if body_len.checked_rem(element_size).map_or(true, |remainder| remainder != 0)
            {
                return Err(invalid_data(
                    format!("File `{:?}` has non-zero size, but file size minus header size and padding \
                        bytes is not an integer multiple of the size of the data type that the file supposedly \
                        contains. This indicates that the file might be corrupt, incorrectly versioned or \
                        malformed.", path)));
            }
        }

        // SAFETY: mapping a file is only sound while nobody else modifies or
        // truncates it. The exclusive lock taken above is meant to guarantee
        // that among cooperating processes (presumably an advisory lock, so
        // non-cooperating writers are not excluded).
        let mm = unsafe { MmapMut::map_mut(&file)? };
        Ok(Self
        {
            file,
            mm,
            _marker: PhantomData,
        })
    }
}
#[cfg(test)]
mod tests
{
    use super::*;
    use std::path::PathBuf;
    use std::process::{Command, ExitStatus, Stdio};
    use std::io::{Seek, SeekFrom};
    use tempfile::TempDir;
    use memoffset::offset_of;

    /// Minimal two-byte element type used by all tests.
    #[repr(C, packed)]
    struct Example
    {
        hello: u8,
        world: u8,
    }

    impl Default for Example
    {
        fn default () -> Self
        {
            Self
            {
                hello: 1,
                world: 2,
            }
        }
    }

    const EXAMPLE_MAGIC_BYTES: [u8; 8] = [b'T', b'E', b'S', b'T', b'F', b'I', b'L', b'E'];
    const EXAMPLE_CORRUPT_MAGIC_BYTES: [u8; 8] = [b'X', b'Y', b'Z', b'T', b'F', 0, 0, 0];
    const EXAMPLE_DATA_CONTAINED_VERSION: [u8; 3] = [0, 1, 0];

    type ExampleFileHeader = FileHeader<Example>;

    /// Creates a temporary directory and returns it together with the path of
    /// a (not yet existing) file inside it. The directory guard must be kept
    /// alive for as long as the path is used.
    fn tempdir_and_tempfile () -> io::Result<(TempDir, PathBuf)>
    {
        let dir = tempfile::tempdir()?;
        let pathbuf = dir.path().join("file.bin");
        Ok((dir, pathbuf))
    }

    /// Creates a fresh `MmapedVec<Example>` backed by a file in a new
    /// temporary directory.
    fn new_mmaped_vec_of_example_persisting_in_tempdir () -> io::Result<(TempDir, PathBuf, MmapedVec<Example>)>
    {
        let (dir, pathbuf) = tempdir_and_tempfile()?;
        let mv = MmapedVec::try_new(pathbuf.as_path(),
            EXAMPLE_MAGIC_BYTES, EXAMPLE_DATA_CONTAINED_VERSION)?;
        Ok((dir, pathbuf, mv))
    }

    /// Opens `path` again with the regular example parameters and returns the
    /// error that is expected to occur.
    fn reopen_expecting_error (path: &Path) -> io::Error
    {
        MmapedVec::<Example>::try_new(path, EXAMPLE_MAGIC_BYTES, EXAMPLE_DATA_CONTAINED_VERSION)
            .err()
            .expect("opening the manipulated file should have failed")
    }

    /// Runs the bundled Python script, fed through stdin, with `path` as its
    /// argument and returns the script's exit status.
    fn python3_try_lock_exclusive (path: &Path) -> io::Result<ExitStatus>
    {
        let mut child = Command::new("python3").arg("-").arg(path)
            .stdin(Stdio::piped()).stdout(Stdio::inherit()).stderr(Stdio::inherit())
            .spawn()?;
        let child_stdin = child.stdin.as_mut().unwrap();
        child_stdin.write_all(include_bytes!("../scripts/try_lock_exclusive.py"))?;
        child.wait()
    }

    #[test]
    pub fn test_create_mmaped_vec_onto_tempfile () -> Result<(), io::Error>
    {
        new_mmaped_vec_of_example_persisting_in_tempdir()?;
        Ok(())
    }

    #[test]
    pub fn test_file_is_locked_while_fd_is_held () -> Result<(), io::Error>
    {
        let (_dir, pathbuf, _mv) = new_mmaped_vec_of_example_persisting_in_tempdir()?;
        // Presumably the script exits with 35 when the lock is already held;
        // confirm against scripts/try_lock_exclusive.py.
        assert_eq!(python3_try_lock_exclusive(pathbuf.as_path())?.code(), Some(35));
        Ok(())
    }

    #[test]
    pub fn test_existing_file_is_locked_while_fd_is_held () -> Result<(), io::Error>
    {
        // The `_` pattern drops the first instance immediately, so the file
        // exists but is unlocked when it is opened a second time.
        let (_dir, pathbuf, _) = new_mmaped_vec_of_example_persisting_in_tempdir()?;
        let _mv = MmapedVec::<Example>::try_new(pathbuf.as_path(),
            EXAMPLE_MAGIC_BYTES, EXAMPLE_DATA_CONTAINED_VERSION)?;
        assert_eq!(python3_try_lock_exclusive(pathbuf.as_path())?.code(), Some(35));
        Ok(())
    }

    #[test]
    pub fn test_file_is_unlocked_after_drop () -> Result<(), io::Error>
    {
        let (_dir, pathbuf, _) = new_mmaped_vec_of_example_persisting_in_tempdir()?;
        assert_eq!(python3_try_lock_exclusive(pathbuf.as_path())?.code(), Some(0));
        Ok(())
    }

    #[test]
    pub fn test_detect_header_corrupt_magic_bytes () -> Result<(), io::Error>
    {
        let (_dir, pathbuf) = tempdir_and_tempfile()?;
        // Create the file with different magic bytes; the temporary is dropped
        // at the end of the statement, releasing the lock.
        MmapedVec::<Example>::try_new(pathbuf.as_path(),
            EXAMPLE_CORRUPT_MAGIC_BYTES, EXAMPLE_DATA_CONTAINED_VERSION)?;
        let mv_err = reopen_expecting_error(pathbuf.as_path());
        assert!(mv_err.to_string().ends_with("Magic bytes mismatch."));
        Ok(())
    }

    #[test]
    pub fn test_detect_file_corrupt_truncated_to_under_end_of_header () -> Result<(), io::Error>
    {
        let (_dir, pathbuf, _) = new_mmaped_vec_of_example_persisting_in_tempdir()?;
        let file = OpenOptions::new().read(true).write(true).open(pathbuf.as_path())?;
        let fhs = mem::size_of::<FileHeader<Example>>();
        file.set_len((fhs - 1) as u64)?;
        let mv_err = reopen_expecting_error(pathbuf.as_path());
        assert!(mv_err.to_string().contains("shorter than the expected header size"));
        Ok(())
    }

    #[test]
    pub fn test_detect_file_corrupt_body_not_integer_multiple_of_data_type () -> Result<(), io::Error>
    {
        let (_dir, pathbuf, _) = new_mmaped_vec_of_example_persisting_in_tempdir()?;
        let file = OpenOptions::new().read(true).write(true).open(pathbuf.as_path())?;
        // One extra byte after header and padding cannot hold a two-byte `Example`.
        let flen = file.metadata()?.len();
        file.set_len(flen + 1)?;
        let mv_err = reopen_expecting_error(pathbuf.as_path());
        assert!(mv_err.to_string().contains("not an integer multiple of the size of the data type"));
        Ok(())
    }

    #[test]
    pub fn test_detect_endianness_marker_invalid () -> Result<(), io::Error>
    {
        let (_dir, pathbuf, _) = new_mmaped_vec_of_example_persisting_in_tempdir()?;
        let mut file = OpenOptions::new().read(true).write(true).open(pathbuf.as_path())?;
        let offs = SeekFrom::Start(offset_of!(ExampleFileHeader, endianness) as u64);
        file.seek(offs)?;
        // Overwrite the marker with zeroes, which match it in neither byte order.
        file.write_all(&[0u8, 0])?;
        let mv_err = reopen_expecting_error(pathbuf.as_path());
        assert!(mv_err.to_string().ends_with("Endianness-marker invalid."));
        Ok(())
    }

    #[test]
    pub fn test_detect_wrong_endianness () -> Result<(), io::Error>
    {
        let (_dir, pathbuf, _) = new_mmaped_vec_of_example_persisting_in_tempdir()?;
        let mut file = OpenOptions::new().read(true).write(true).open(pathbuf.as_path())?;
        let offs = SeekFrom::Start(offset_of!(ExampleFileHeader, endianness) as u64);
        // Read the marker, swap its two bytes and write it back in place.
        file.seek(offs)?;
        let mut buf = [0u8, 0];
        file.read_exact(&mut buf)?;
        buf.reverse();
        file.seek(offs)?;
        file.write_all(&buf)?;
        let mv_err = reopen_expecting_error(pathbuf.as_path());
        assert!(mv_err.to_string().ends_with("Wrong endianness."));
        Ok(())
    }
}