#![deny(missing_docs)]
use std::{
error::Error,
fs::{File, OpenOptions},
io::{Seek, SeekFrom, Write},
marker::PhantomData,
ops::{Deref, DerefMut},
path::Path,
};
use bytemuck::{try_cast_slice, try_cast_slice_mut, Pod};
use fs2::FileExt;
use memmap2::MmapOptions;
/// A buffer of plain-old-data elements that lives either in a memory-mapped
/// file or in an ordinary heap allocation.
///
/// Both variants dereference to `[T]`, so callers can use slice methods
/// without caring where the elements are stored.
pub enum Buffer<T: Pod> {
/// Elements are stored in a memory-mapped file managed by a [`BackedBuffer`].
Disk(BackedBuffer<T>),
/// Elements are stored in a heap-allocated [`Vec`].
Memory(Vec<T>),
}
impl<T: Pod> Buffer<T> {
    /// Creates a file-backed buffer of `capacity` zeroed elements stored at `path`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`BackedBuffer::new`].
    pub fn new_on_disk(capacity: usize, path: impl AsRef<Path>) -> Result<Self, Box<dyn Error>> {
        Ok(Self::Disk(BackedBuffer::new(capacity, path)?))
    }

    /// Creates a heap-backed buffer of `capacity` zeroed elements.
    pub fn new_in_memory(capacity: usize) -> Self {
        Self::Memory(vec![T::zeroed(); capacity])
    }

    /// Opens the existing file at `path` as a file-backed buffer.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`BackedBuffer::load`].
    pub fn load_from_disk(path: impl AsRef<Path>) -> Result<Self, Box<dyn Error>> {
        Ok(Self::Disk(BackedBuffer::load(path)?))
    }

    /// Wraps an existing vector as a heap-backed buffer without copying it.
    pub fn from_vec_in_memory(data: Vec<T>) -> Self {
        Self::Memory(data)
    }

    /// Creates a file-backed buffer at `path` containing a copy of `data`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`BackedBuffer::copy_from_slice`].
    pub fn from_slice_on_disk(data: &[T], path: impl AsRef<Path>) -> Result<Self, Box<dyn Error>> {
        Ok(Self::Disk(BackedBuffer::copy_from_slice(data, path)?))
    }

    /// Reduces the number of visible elements to `new_len`.
    ///
    /// For the `Disk` variant this delegates to [`BackedBuffer::shrink`]; for
    /// the `Memory` variant the vector is truncated.
    ///
    /// # Panics
    ///
    /// Panics if `new_len` is greater than the current length.
    pub fn shrink(&mut self, new_len: usize) {
        assert!(
            new_len <= self.len(),
            "`new_len` must be less than current length!"
        );
        match self {
            Self::Disk(buffer) => buffer.shrink(new_len),
            // The assertion above guarantees this can only shrink, so no
            // filler value is ever needed.
            Self::Memory(buffer) => buffer.truncate(new_len),
        }
    }
}
/// A buffer of `T` elements stored in a memory-mapped file.
///
/// The backing file is locked exclusively when the buffer is constructed and
/// unlocked again when the buffer is dropped. The buffer dereferences to
/// `[T]`.
pub struct BackedBuffer<T: Pod> {
// Writable mapping of the backing file; reinterpreted as `[T]` on every access.
mmap: memmap2::MmapMut,
// Number of `T` elements exposed through `Deref`; `shrink` can lower it.
len: usize,
// Backing file handle that carries the exclusive lock. Wrapped in `Option`
// so that `Drop` can move it out with `take()`.
file: Option<File>,
// Ties the buffer to its element type; the mapping itself only holds bytes.
_ph: PhantomData<T>,
}
impl<T: Pod> BackedBuffer<T> {
    /// Creates (or overwrites) the file at `path` and maps it as a buffer of
    /// `capacity` zeroed elements.
    ///
    /// The file is locked exclusively *before* its previous contents are
    /// discarded. If the path is currently held by another `BackedBuffer`, the
    /// call fails without truncating the file underneath that buffer's live
    /// mapping.
    ///
    /// # Errors
    ///
    /// Returns an error if `capacity * size_of::<T>()` overflows `usize`, if
    /// the file cannot be opened, locked, resized or written, or if it cannot
    /// be memory-mapped and viewed as `[T]`.
    pub fn new(capacity: usize, path: impl AsRef<Path>) -> Result<Self, Box<dyn Error>> {
        // An unchecked multiplication would wrap in release builds and
        // silently create a file that is far too small.
        let capacity_bytes = capacity
            .checked_mul(std::mem::size_of::<T>())
            .ok_or("buffer capacity in bytes overflows `usize`")?;

        // Deliberately open WITHOUT truncating: the old contents may only be
        // thrown away once we hold the exclusive lock.
        let mut file = OpenOptions::new()
            .read(true)
            .write(true)
            .create(true)
            .truncate(false)
            .open(path)?;
        file.try_lock_exclusive()?;
        file.set_len(0)?;

        // Some platform allocation calls (e.g. `posix_fallocate`) reject a
        // zero length, so an empty buffer simply keeps the empty file.
        if capacity_bytes > 0 {
            // `usize` is at most 64 bits wide on supported targets, so this
            // cast is lossless.
            file.allocate(capacity_bytes as u64)?;
        }

        // Explicitly write zeroes over the whole range.
        // NOTE(review): presumably this guarantees zeroed contents regardless
        // of how the platform's `allocate` initialises the space — confirm.
        const BLOCK_SIZE: usize = 4096;
        let zeroes = [0u8; BLOCK_SIZE];
        file.seek(SeekFrom::Start(0))?;
        let mut remaining = capacity_bytes;
        while remaining > 0 {
            let chunk = remaining.min(BLOCK_SIZE);
            file.write_all(&zeroes[..chunk])?;
            remaining -= chunk;
        }

        // SAFETY: `file` was locked exclusively above and has just been given
        // its final size; see `map_locked` for the remaining obligations.
        unsafe { Self::map_locked(file) }
    }

    /// Opens the existing file at `path` and maps its whole contents as `[T]`.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be opened for reading and writing,
    /// if the exclusive lock cannot be acquired immediately, if mapping fails,
    /// or if the mapped bytes cannot be reinterpreted as `[T]` (for example
    /// when the file size is not a multiple of `size_of::<T>()`).
    pub fn load(path: impl AsRef<Path>) -> Result<Self, Box<dyn Error>> {
        let file = OpenOptions::new().read(true).write(true).open(path)?;
        // SAFETY: the file was just opened by us; `from_file` takes the
        // exclusive lock before mapping it.
        unsafe { Self::from_file(file) }
    }

    /// Creates a new file-backed buffer at `path` holding a copy of `slice`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`BackedBuffer::new`].
    pub fn copy_from_slice(slice: &[T], path: impl AsRef<Path>) -> Result<Self, Box<dyn Error>> {
        let mut buf = Self::new(slice.len(), path)?;
        // Resolves to `<[T]>::copy_from_slice` through `DerefMut`; the lengths
        // match because the buffer was created with `slice.len()` elements.
        buf.copy_from_slice(slice);
        Ok(buf)
    }

    /// Reduces the number of elements visible through this buffer to `new_len`.
    ///
    /// Only the logical length is changed; the mapping and the backing file
    /// keep their size.
    ///
    /// # Panics
    ///
    /// Panics if `new_len` is greater than the current length.
    pub fn shrink(&mut self, new_len: usize) {
        assert!(
            new_len <= self.len(),
            "`new_len` must be less than current length!"
        );
        self.len = new_len;
    }

    /// Locks `file` exclusively (without blocking) and maps it.
    ///
    /// # Safety
    ///
    /// Inherits the contract of `memmap2::MmapOptions::map_mut`: the caller
    /// must make sure the underlying file is not truncated or otherwise
    /// modified behind the mapping's back for as long as the buffer lives.
    unsafe fn from_file(file: File) -> Result<Self, Box<dyn Error>> {
        file.try_lock_exclusive()?;
        // SAFETY: the lock is held; the rest is forwarded to our own caller.
        unsafe { Self::map_locked(file) }
    }

    /// Maps an already-locked `file` and derives the element count from the
    /// size of the mapping.
    ///
    /// # Safety
    ///
    /// Same contract as [`BackedBuffer::from_file`]; additionally the caller
    /// is expected to already hold the exclusive lock on `file`.
    unsafe fn map_locked(file: File) -> Result<Self, Box<dyn Error>> {
        // SAFETY: upheld by the caller as described in `# Safety`.
        let mmap = unsafe { MmapOptions::new().populate().map_mut(&file)? };
        // Validate once that the bytes form a well-aligned, whole number of
        // `T`s; `Deref`/`DerefMut` rely on this when they unwrap the cast.
        let len = try_cast_slice::<u8, T>(&mmap[..])?.len();
        Ok(Self {
            mmap,
            len,
            file: Some(file),
            _ph: PhantomData,
        })
    }
}
impl<T: Pod> AsRef<[T]> for BackedBuffer<T> {
    /// Borrows the visible elements as a slice; same view as `Deref`.
    fn as_ref(&self) -> &[T] {
        Deref::deref(self)
    }
}
impl<T: Pod> AsMut<[T]> for BackedBuffer<T> {
    /// Mutably borrows the visible elements as a slice; same view as `DerefMut`.
    fn as_mut(&mut self) -> &mut [T] {
        DerefMut::deref_mut(self)
    }
}
impl<T: Pod> Deref for BackedBuffer<T> {
    type Target = [T];

    /// Reinterprets the mapped bytes as `[T]` and returns the first `len` elements.
    #[inline]
    fn deref(&self) -> &Self::Target {
        // The same cast was already checked when the buffer was constructed.
        let elements: &[T] = try_cast_slice(&self.mmap[..]).unwrap();
        &elements[..self.len]
    }
}
impl<T: Pod> DerefMut for BackedBuffer<T> {
    /// Reinterprets the mapped bytes as `[T]` and returns the first `len`
    /// elements mutably.
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        let visible = self.len;
        // The same cast was already checked when the buffer was constructed.
        let elements: &mut [T] = try_cast_slice_mut(&mut self.mmap[..]).unwrap();
        &mut elements[..visible]
    }
}
impl<T: Pod> Deref for Buffer<T> {
    type Target = [T];

    /// Returns the elements of whichever storage variant is active.
    #[inline]
    fn deref(&self) -> &Self::Target {
        match self {
            Self::Disk(backed) => &backed[..],
            Self::Memory(items) => items.as_slice(),
        }
    }
}
impl<T: Pod> DerefMut for Buffer<T> {
    /// Returns the elements of whichever storage variant is active, mutably.
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        match self {
            Self::Disk(backed) => &mut backed[..],
            Self::Memory(items) => items.as_mut_slice(),
        }
    }
}
impl<T: Pod> AsRef<[T]> for Buffer<T> {
    /// Borrows the elements as a slice, regardless of where they are stored.
    fn as_ref(&self) -> &[T] {
        // `Buffer`'s own `Deref` performs the per-variant dispatch.
        Deref::deref(self)
    }
}
impl<T: Pod> AsMut<[T]> for Buffer<T> {
    /// Mutably borrows the elements as a slice, regardless of where they are stored.
    fn as_mut(&mut self) -> &mut [T] {
        // `Buffer`'s own `DerefMut` performs the per-variant dispatch.
        DerefMut::deref_mut(self)
    }
}
impl<T: Pod> Drop for BackedBuffer<T> {
    /// Releases the file lock taken when the buffer was constructed.
    fn drop(&mut self) {
        if let Some(file) = self.file.take() {
            // Best effort: there is nothing useful to do with an unlock
            // failure during drop, so the result is deliberately discarded.
            let _ = file.unlock();
        }
    }
}
#[cfg(test)]
impl<T: Pod> std::fmt::Debug for BackedBuffer<T> {
    /// Prints the backing file handle followed by the number of visible elements.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:?} of length {}", self.file, self.len())
    }
}
#[cfg(test)]
mod tests {
    use super::BackedBuffer;
    use std::{error::Error, fs::File, io::Write};

    /// Loading an existing file exposes exactly the bytes that were written to it.
    #[test]
    fn read() -> Result<(), Box<dyn Error>> {
        let tempdir = tempfile::tempdir()?;
        let file_path = tempdir.path().join("test");
        // `write_all` rather than `write`: a bare `write` may store only part
        // of the buffer and its returned byte count was being ignored.
        File::create(&file_path)?.write_all("hello, world!".as_bytes())?;
        let mmap = BackedBuffer::<u8>::load(&file_path)?;
        assert_eq!(&mmap[..], "hello, world!".as_bytes());
        Ok(())
    }

    /// Bytes written through the mapping are visible when read back from it.
    #[test]
    fn write() -> Result<(), Box<dyn Error>> {
        let tempdir = tempfile::tempdir()?;
        let file_path = tempdir.path().join("test");
        File::create(&file_path)?.write_all("hello, world!".as_bytes())?;
        let mut mmap = BackedBuffer::<u8>::load(&file_path)?;
        mmap.copy_from_slice("halle, werld!".as_bytes());
        assert_eq!(&mmap[..], "halle, werld!".as_bytes());
        Ok(())
    }

    /// A file can be mapped by only one buffer at a time; dropping the buffer
    /// releases the lock again.
    #[test]
    fn locking() -> Result<(), Box<dyn Error>> {
        let tempdir = tempfile::tempdir()?;
        let file_path = tempdir.path().join("test");
        File::create(&file_path)?;
        let first = BackedBuffer::<u8>::load(&file_path)?;
        let _err = BackedBuffer::<u8>::load(&file_path)
            .expect_err("a second buffer must be refused while the first holds the lock");
        drop(first);
        let _second = BackedBuffer::<u8>::load(&file_path)?;
        Ok(())
    }
}