use std::{
fs::{File, OpenOptions},
io,
path::{Path, PathBuf},
sync::Mutex,
};
use crate::error::{Result, WalError};
/// Positional byte storage used by the write-ahead log.
///
/// Every method takes `&self`, and implementors must be `Send + Sync`, so a
/// store can be shared between threads behind a shared reference.
pub trait WalStore: Send + Sync {
    /// Writes all of `bytes` starting at byte position `offset`.
    fn write_at(&self, offset: u64, bytes: &[u8]) -> Result<()>;

    /// Reads into `buf` starting at byte position `offset` and returns how many
    /// bytes were copied.
    ///
    /// The implementations in this file return a count smaller than
    /// `buf.len()` (possibly zero) when the store ends before `buf` is full.
    fn read_at(&self, offset: u64, buf: &mut [u8]) -> Result<usize>;

    /// Cuts the store down to `len` bytes.
    fn truncate(&self, len: u64) -> Result<()>;

    /// Flushes previously written bytes to the backing medium.
    fn sync(&self) -> Result<()>;

    /// Reports the current size of the store in bytes.
    fn len(&self) -> Result<u64>;

    /// Reports whether the store currently holds zero bytes.
    ///
    /// The default implementation is derived from [`WalStore::len`] and
    /// forwards any error it returns.
    fn is_empty(&self) -> Result<bool> {
        self.len().map(|size| size == 0)
    }
}
/// A store backed by a single file on disk.
#[derive(Debug)]
pub struct FileStore {
    /// Handle opened for both reading and writing.
    file: File,
    /// The path the handle was opened from, kept for [`FileStore::path`].
    path: PathBuf,
}

impl FileStore {
    /// Opens the file at `path` for reading and writing, creating it when it
    /// does not exist yet. Existing contents are kept (no truncation on open).
    ///
    /// # Errors
    ///
    /// Any failure from the underlying open call is wrapped with
    /// `WalError::io` under the context "opening the log file".
    pub fn open(path: impl AsRef<Path>) -> Result<Self> {
        let path: PathBuf = path.as_ref().to_owned();

        let mut options = OpenOptions::new();
        options.read(true).write(true).create(true).truncate(false);

        match options.open(&path) {
            Ok(file) => Ok(Self { file, path }),
            Err(source) => Err(WalError::io("opening the log file", source)),
        }
    }

    /// Returns the path this store was opened with.
    #[must_use]
    pub fn path(&self) -> &Path {
        self.path.as_path()
    }
}
/// File-backed implementation: each method delegates to the file handle (or to
/// one of the positional helpers in this file) and wraps I/O failures in
/// `WalError::io` with a short description of the operation.
impl WalStore for FileStore {
    /// Writes every byte of `bytes` at `offset` via `pwrite_all`.
    fn write_at(&self, offset: u64, bytes: &[u8]) -> Result<()> {
        let outcome = pwrite_all(&self.file, offset, bytes);
        outcome.map_err(|source| WalError::io("writing a record", source))
    }

    /// Fills as much of `buf` as the file provides from `offset` via
    /// `pread_fill`, returning the number of bytes read.
    fn read_at(&self, offset: u64, buf: &mut [u8]) -> Result<usize> {
        let outcome = pread_fill(&self.file, offset, buf);
        outcome.map_err(|source| WalError::io("reading from the log", source))
    }

    /// Sets the file length to `len` with `File::set_len`.
    fn truncate(&self, len: u64) -> Result<()> {
        let outcome = self.file.set_len(len);
        outcome.map_err(|source| WalError::io("truncating the log", source))
    }

    /// Pushes written data to stable storage via `durable_sync`.
    fn sync(&self) -> Result<()> {
        let outcome = durable_sync(&self.file);
        outcome.map_err(|source| WalError::io("flushing to stable storage", source))
    }

    /// Returns the file size reported by the handle's metadata.
    fn len(&self) -> Result<u64> {
        let metadata = self
            .file
            .metadata()
            .map_err(|source| WalError::io("reading log file metadata", source))?;
        Ok(metadata.len())
    }
}
/// A store that keeps its bytes in memory: a `Vec<u8>` guarded by a mutex so
/// it can be mutated through a shared reference.
#[derive(Debug, Default)]
pub struct MemStore {
    /// The stored bytes.
    data: Mutex<Vec<u8>>,
}

impl MemStore {
    /// Creates an empty store.
    #[must_use]
    pub fn new() -> Self {
        Self::from_bytes(Vec::new())
    }

    /// Creates an empty store whose buffer has room for at least `capacity`
    /// bytes before it needs to reallocate.
    #[must_use]
    pub fn with_capacity(capacity: usize) -> Self {
        Self::from_bytes(Vec::with_capacity(capacity))
    }

    /// Creates a store whose initial contents are `bytes`.
    #[must_use]
    pub fn from_bytes(bytes: Vec<u8>) -> Self {
        Self {
            data: Mutex::new(bytes),
        }
    }

    /// Acquires the buffer guard.
    ///
    /// A poisoned mutex is not treated as fatal: the guard is recovered from
    /// the poison error and returned as usual.
    fn lock(&self) -> std::sync::MutexGuard<'_, Vec<u8>> {
        match self.data.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        }
    }

    /// Returns a copy of the current contents (test builds only).
    #[cfg(test)]
    pub(crate) fn snapshot(&self) -> Vec<u8> {
        let guard = self.lock();
        Vec::clone(&guard)
    }
}
impl Clone for MemStore {
fn clone(&self) -> Self {
MemStore {
data: Mutex::new(self.lock().clone()),
}
}
}
/// In-memory implementation operating on the mutex-guarded byte vector.
impl WalStore for MemStore {
    /// Copies `bytes` into the buffer starting at `offset`.
    ///
    /// The buffer grows when the write ends past its current length; any gap
    /// between the old end and `offset` is zero-filled. An empty `bytes` is a
    /// no-op and never grows the buffer, which matches `pwrite_all`: its loop
    /// issues no write at all for an empty slice, so the file-backed store is
    /// left untouched in the same situation.
    ///
    /// # Errors
    ///
    /// Returns `WalError::io` with the context "writing to memory" when
    /// `offset` does not fit in `usize` or when `offset + bytes.len()`
    /// overflows `usize`.
    fn write_at(&self, offset: u64, bytes: &[u8]) -> Result<()> {
        // Nothing to copy: return before the resize below can extend the
        // buffer up to `offset`.
        if bytes.is_empty() {
            return Ok(());
        }
        let start = usize::try_from(offset).map_err(|_| {
            WalError::io(
                "writing to memory",
                io::Error::other("offset exceeds usize"),
            )
        })?;
        let end = start.checked_add(bytes.len()).ok_or_else(|| {
            WalError::io(
                "writing to memory",
                io::Error::other("write overflows usize"),
            )
        })?;
        let mut data = self.lock();
        if data.len() < end {
            // Grow to cover the write; new bytes (including any gap before
            // `start`) are zero.
            data.resize(end, 0);
        }
        data[start..end].copy_from_slice(bytes);
        Ok(())
    }

    /// Copies bytes starting at `offset` into `buf` and returns how many were
    /// copied.
    ///
    /// Returns `Ok(0)` when `offset` is at or past the end of the buffer, or
    /// does not fit in `usize`. Otherwise copies
    /// `min(bytes available, buf.len())` bytes.
    fn read_at(&self, offset: u64, buf: &mut [u8]) -> Result<usize> {
        let data = self.lock();
        let start = match usize::try_from(offset) {
            Ok(start) if start < data.len() => start,
            _ => return Ok(0),
        };
        let available = &data[start..];
        let n = available.len().min(buf.len());
        buf[..n].copy_from_slice(&available[..n]);
        Ok(n)
    }

    /// Shortens the buffer to at most `len` bytes.
    ///
    /// A `len` that is not smaller than the current length (including one
    /// that does not fit in `usize`) leaves the buffer unchanged, because
    /// `Vec::truncate` never grows.
    ///
    /// NOTE(review): `FileStore::truncate` goes through `File::set_len`, which
    /// can also extend a file; confirm callers never rely on growth here.
    fn truncate(&self, len: u64) -> Result<()> {
        let len = usize::try_from(len).unwrap_or(usize::MAX);
        self.lock().truncate(len);
        Ok(())
    }

    /// Nothing to flush for an in-memory buffer; always succeeds.
    fn sync(&self) -> Result<()> {
        Ok(())
    }

    /// Returns the current buffer length in bytes.
    fn len(&self) -> Result<u64> {
        Ok(self.lock().len() as u64)
    }
}
/// Flushes `file` to stable storage on macOS by issuing
/// `fcntl(fd, F_FULLFSYNC)`.
///
/// Per Apple's `fcntl(2)` documentation, `F_FULLFSYNC` additionally asks the
/// drive to flush its own buffers to permanent storage.
///
/// The call is retried when it is interrupted by a signal, mirroring how the
/// positional read/write helpers in this file treat
/// `io::ErrorKind::Interrupted`.
///
/// # Errors
///
/// Returns the OS error reported by `fcntl` for any failure other than an
/// interruption.
#[cfg(target_os = "macos")]
pub(crate) fn durable_sync(file: &File) -> io::Result<()> {
    use std::os::unix::io::AsRawFd;
    let fd = file.as_raw_fd();
    loop {
        // SAFETY: `fd` comes from `file`, which is borrowed for the whole
        // call, so the descriptor is open and valid. `F_FULLFSYNC` takes no
        // pointer argument, so no memory is read or written through the call.
        let ret = unsafe { libc::fcntl(fd, libc::F_FULLFSYNC) };
        if ret != -1 {
            return Ok(());
        }
        let err = io::Error::last_os_error();
        if err.kind() != io::ErrorKind::Interrupted {
            return Err(err);
        }
        // Interrupted by a signal before completing: issue the flush again.
    }
}
/// Flushes `file`'s written data to the storage device on every target other
/// than macOS by delegating to [`File::sync_data`].
///
/// # Errors
///
/// Forwards whatever error `File::sync_data` reports.
#[cfg(not(target_os = "macos"))]
pub(crate) fn durable_sync(file: &File) -> io::Result<()> {
    File::sync_data(file)
}
/// Writes the whole of `buf` to `file` starting at byte position `offset`
/// (Unix), issuing positional writes until nothing is left.
///
/// Short writes advance both the offset and the remaining slice; writes that
/// fail with `io::ErrorKind::Interrupted` are retried. An empty `buf` returns
/// `Ok(())` without issuing any write.
///
/// # Errors
///
/// Returns `io::ErrorKind::WriteZero` when a write accepts zero bytes while
/// data remains, and forwards any other non-interruption error unchanged.
#[cfg(unix)]
pub(crate) fn pwrite_all(file: &File, mut offset: u64, mut buf: &[u8]) -> io::Result<()> {
    use std::os::unix::fs::FileExt;
    loop {
        if buf.is_empty() {
            return Ok(());
        }
        let written = match file.write_at(buf, offset) {
            Ok(count) => count,
            // A signal interrupted the call before anything was written; retry.
            Err(err) if err.kind() == io::ErrorKind::Interrupted => continue,
            Err(err) => return Err(err),
        };
        if written == 0 {
            // No progress is possible; fail instead of spinning forever.
            return Err(io::Error::new(
                io::ErrorKind::WriteZero,
                "the store accepted zero bytes mid-record",
            ));
        }
        buf = &buf[written..];
        offset += written as u64;
    }
}
/// Writes the whole of `buf` to `file` starting at byte position `offset`
/// (Windows), issuing `seek_write` calls until nothing is left.
///
/// Short writes advance both the offset and the remaining slice; writes that
/// fail with `io::ErrorKind::Interrupted` are retried. An empty `buf` returns
/// `Ok(())` without issuing any write.
///
/// # Errors
///
/// Returns `io::ErrorKind::WriteZero` when a write accepts zero bytes while
/// data remains, and forwards any other non-interruption error unchanged.
#[cfg(windows)]
pub(crate) fn pwrite_all(file: &File, mut offset: u64, mut buf: &[u8]) -> io::Result<()> {
    use std::os::windows::fs::FileExt;
    loop {
        if buf.is_empty() {
            return Ok(());
        }
        let written = match file.seek_write(buf, offset) {
            Ok(count) => count,
            // Interrupted before anything was written; retry the same write.
            Err(err) if err.kind() == io::ErrorKind::Interrupted => continue,
            Err(err) => return Err(err),
        };
        if written == 0 {
            // No progress is possible; fail instead of spinning forever.
            return Err(io::Error::new(
                io::ErrorKind::WriteZero,
                "the store accepted zero bytes mid-record",
            ));
        }
        buf = &buf[written..];
        offset += written as u64;
    }
}
/// Reads from `file` at byte position `offset` until `buf` is full or the file
/// has no more data (Unix), and returns the number of bytes placed in `buf`.
///
/// Short reads are continued from where they stopped; reads that fail with
/// `io::ErrorKind::Interrupted` are retried. A result smaller than
/// `buf.len()` means a read returned zero bytes before the buffer was filled.
///
/// # Errors
///
/// Forwards any non-interruption error from the positional read unchanged.
#[cfg(unix)]
pub(crate) fn pread_fill(file: &File, mut offset: u64, buf: &mut [u8]) -> io::Result<usize> {
    use std::os::unix::fs::FileExt;
    let mut filled = 0;
    loop {
        let unfilled = &mut buf[filled..];
        if unfilled.is_empty() {
            return Ok(filled);
        }
        match file.read_at(unfilled, offset) {
            // Zero bytes: nothing more to read at this offset.
            Ok(0) => return Ok(filled),
            Ok(count) => {
                filled += count;
                offset += count as u64;
            }
            // Interrupted before any data arrived; retry the same read.
            Err(err) if err.kind() == io::ErrorKind::Interrupted => continue,
            Err(err) => return Err(err),
        }
    }
}
/// Reads from `file` at byte position `offset` until `buf` is full or the file
/// has no more data (Windows), and returns the number of bytes placed in
/// `buf`.
///
/// Short reads are continued from where they stopped; reads that fail with
/// `io::ErrorKind::Interrupted` are retried. A result smaller than
/// `buf.len()` means a read returned zero bytes before the buffer was filled.
///
/// # Errors
///
/// Forwards any non-interruption error from `seek_read` unchanged.
#[cfg(windows)]
pub(crate) fn pread_fill(file: &File, mut offset: u64, buf: &mut [u8]) -> io::Result<usize> {
    use std::os::windows::fs::FileExt;
    let mut filled = 0;
    loop {
        let unfilled = &mut buf[filled..];
        if unfilled.is_empty() {
            return Ok(filled);
        }
        match file.seek_read(unfilled, offset) {
            // Zero bytes: nothing more to read at this offset.
            Ok(0) => return Ok(filled),
            Ok(count) => {
                filled += count;
                offset += count as u64;
            }
            // Interrupted before any data arrived; retry the same read.
            Err(err) if err.kind() == io::ErrorKind::Interrupted => continue,
            Err(err) => return Err(err),
        }
    }
}
// Unit tests for the in-memory and file-backed stores.
// `unwrap`/`expect` are allowed here: a failed call should fail the test.
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
use super::*;
// Back-to-back writes grow the reported length to the end of the last write.
#[test]
fn test_memstore_write_at_advances_len() {
let store = MemStore::new();
assert_eq!(store.len().unwrap(), 0);
store.write_at(0, b"abc").unwrap();
assert_eq!(store.len().unwrap(), 3);
store.write_at(3, b"de").unwrap();
assert_eq!(store.len().unwrap(), 5);
}
// Writing beyond the current end fills the skipped bytes with zeros.
#[test]
fn test_memstore_write_past_end_zero_fills_gap() {
let store = MemStore::new();
store.write_at(4, b"XY").unwrap();
assert_eq!(store.len().unwrap(), 6);
// Pre-fill with 0xFF so zeros in the result must come from the store.
let mut buf = [0xFFu8; 6];
assert_eq!(store.read_at(0, &mut buf).unwrap(), 6);
assert_eq!(&buf, &[0, 0, 0, 0, b'X', b'Y']);
}
// A read that runs off the end returns only the bytes that exist.
#[test]
fn test_memstore_read_past_end_is_short() {
let store = MemStore::new();
store.write_at(0, b"abc").unwrap();
let mut buf = [0u8; 8];
assert_eq!(store.read_at(1, &mut buf).unwrap(), 2);
assert_eq!(&buf[..2], b"bc");
}
// Truncating to a smaller length reduces the reported length.
#[test]
fn test_memstore_truncate_shrinks() {
let store = MemStore::new();
store.write_at(0, b"0123456789").unwrap();
store.truncate(4).unwrap();
assert_eq!(store.len().unwrap(), 4);
}
// Data written and synced through one handle is readable after reopening.
#[test]
fn test_filestore_roundtrip_through_disk() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("store.bin");
{
// Inner scope drops the first store (and its file handle) before reopening.
let store = FileStore::open(&path).unwrap();
store.write_at(0, b"hello world").unwrap();
store.sync().unwrap();
assert_eq!(store.len().unwrap(), 11);
}
let store = FileStore::open(&path).unwrap();
assert_eq!(store.len().unwrap(), 11);
let mut buf = [0u8; 5];
assert_eq!(store.read_at(6, &mut buf).unwrap(), 5);
assert_eq!(&buf, b"world");
}
// Eight threads each write a distinct 4-byte run at a non-overlapping offset
// through a shared store; every run must be intact afterwards.
#[test]
fn test_filestore_concurrent_disjoint_writes() {
use std::sync::Arc;
use std::thread;
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("concurrent.bin");
let store = Arc::new(FileStore::open(&path).unwrap());
let mut handles = Vec::new();
for i in 0..8u64 {
let store = Arc::clone(&store);
handles.push(thread::spawn(move || {
// Thread i writes four copies of the i-th letter starting from 'A'.
let byte = b'A' + i as u8;
store.write_at(i * 4, &[byte; 4]).unwrap();
}));
}
// Wait for every writer before syncing and reading back.
for h in handles {
h.join().unwrap();
}
store.sync().unwrap();
let mut buf = [0u8; 32];
assert_eq!(store.read_at(0, &mut buf).unwrap(), 32);
for i in 0..8 {
let expected = b'A' + i as u8;
assert_eq!(&buf[i * 4..i * 4 + 4], &[expected; 4]);
}
}
// Bytes written and synced before the handle is dropped are read back in full
// from a freshly opened handle.
#[test]
fn test_filestore_sync_durable_across_reopen() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("durable.bin");
{
let store = FileStore::open(&path).unwrap();
store.write_at(0, b"persisted").unwrap();
store.sync().unwrap();
}
let store = FileStore::open(&path).unwrap();
let mut buf = [0u8; 9];
assert_eq!(store.read_at(0, &mut buf).unwrap(), 9);
assert_eq!(&buf, b"persisted");
}
}