use std::{
fs::{self, File},
io::{Read, Write},
path::{Path, PathBuf},
};
use serde::{Deserialize, Serialize};
use zerocopy::{
FromBytes, Immutable, IntoBytes, KnownLayout,
byteorder::{LE, U64},
};
use crate::{CommitSeq, DbError, TransactionId, state::DatabaseState};
/// File name of the published store inside the database directory.
const STORE_FILE: &str = "store.oxgdb";
/// Scratch file that `write_store` fills and fsyncs before renaming it over `STORE_FILE`.
const TEMP_STORE_FILE: &str = "store.oxgdb.tmp";
/// Magic bytes every store file must start with (checked by `read_header`).
const STORE_MAGIC: [u8; 8] = *b"OXGDB02\0";
/// On-disk format version written into the header; `read_header` rejects any other value.
const STORE_VERSION: u64 = 2;
/// Size in bytes of the fixed header that precedes the JSON payload.
const HEADER_LEN: usize = core::mem::size_of::<RawStoreHeader>();
/// Fixed-size header written verbatim at the start of the store file, ahead of the JSON payload.
///
/// `#[repr(C)]` plus zerocopy's little-endian `U64` fields (alignment 1, and `IntoBytes` forbids
/// padding) make this a stable byte layout: the field order here IS the on-disk format, so do not
/// reorder, resize, or insert fields without bumping `STORE_VERSION`.
#[derive(Clone, Copy, Debug, FromBytes, Immutable, IntoBytes, KnownLayout)]
#[repr(C)]
struct RawStoreHeader {
    /// Must equal `STORE_MAGIC`.
    magic: [u8; 8],
    /// Must equal `STORE_VERSION`.
    version: U64<LE>,
    /// Copy of the payload's `commit_seq`; `read_store` checks that the two agree.
    commit_seq: U64<LE>,
    /// Copy of the payload's `transaction_id`; `read_store` checks that the two agree.
    transaction_id: U64<LE>,
    /// Number of payload bytes that follow the header (`read_store` rejects trailing bytes).
    payload_len: U64<LE>,
    /// `checksum_bytes` of the payload bytes.
    payload_checksum: U64<LE>,
    /// Unused; written as zeros and rejected by `read_header` if any byte is non-zero.
    reserved: [u8; 16],
}
/// Snapshot persisted as the store payload (serialized with `serde_json` by `write_store`).
///
/// `commit_seq` and `transaction_id` are also duplicated into the binary header, and
/// `read_store` rejects files where the header and payload copies disagree.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub(crate) struct StoredDatabase {
    /// Commit sequence recorded with this snapshot.
    pub(crate) commit_seq: CommitSeq,
    /// Transaction id recorded with this snapshot.
    pub(crate) transaction_id: TransactionId,
    /// Database contents; `read_store` runs `validate()` on it after decoding.
    pub(crate) state: DatabaseState,
}
impl StoredDatabase {
    /// Builds the initial snapshot: both counters at zero and an empty `DatabaseState`.
    #[must_use]
    pub(crate) const fn empty() -> Self {
        let commit_seq = CommitSeq::new(0);
        let transaction_id = TransactionId::new(0);
        let state = DatabaseState::empty();
        Self {
            commit_seq,
            transaction_id,
            state,
        }
    }
}
/// Returns the location of the published store file inside the database directory `root`.
#[must_use]
pub(crate) fn store_path(root: &Path) -> PathBuf {
    let mut path = root.to_path_buf();
    path.push(STORE_FILE);
    path
}
pub(crate) fn write_store(root: &Path, stored: &StoredDatabase) -> Result<(), DbError> {
fs::create_dir_all(root).map_err(|error| DbError::io("create database directory", error))?;
let payload = serde_json::to_vec(stored)?;
let header = RawStoreHeader {
magic: STORE_MAGIC,
version: U64::new(STORE_VERSION),
commit_seq: U64::new(stored.commit_seq.get()),
transaction_id: U64::new(stored.transaction_id.get()),
payload_len: U64::new(u64::try_from(payload.len()).map_err(|_error| DbError::IdOverflow)?),
payload_checksum: U64::new(checksum_bytes(&payload)),
reserved: [0; 16],
};
let temp_path = root.join(TEMP_STORE_FILE);
let mut file = File::create(&temp_path).map_err(|error| DbError::io("create store", error))?;
file.write_all(header.as_bytes())
.map_err(|error| DbError::io("write store header", error))?;
file.write_all(&payload)
.map_err(|error| DbError::io("write store payload", error))?;
file.flush()
.map_err(|error| DbError::io("flush store", error))?;
file.sync_all()
.map_err(|error| DbError::io("sync store", error))?;
fs::rename(temp_path, store_path(root)).map_err(|error| DbError::io("publish store", error))?;
sync_directory(root)?;
Ok(())
}
pub(crate) fn read_store(root: &Path) -> Result<StoredDatabase, DbError> {
let mut file = File::open(store_path(root)).map_err(|error| match error.kind() {
std::io::ErrorKind::NotFound => DbError::NotFound,
_kind => DbError::io("open store", error),
})?;
let header = read_header(&mut file)?;
let payload_len = usize::try_from(header.payload_len.get())
.map_err(|_error| DbError::invalid_store("payload length does not fit usize"))?;
let mut payload = vec![0_u8; payload_len];
file.read_exact(&mut payload)
.map_err(|error| DbError::io("read store payload", error))?;
reject_trailing_bytes(&mut file)?;
if checksum_bytes(&payload) != header.payload_checksum.get() {
return Err(DbError::invalid_store("store payload checksum mismatch"));
}
let stored: StoredDatabase = serde_json::from_slice(&payload)?;
if stored.commit_seq.get() != header.commit_seq.get()
|| stored.transaction_id.get() != header.transaction_id.get()
{
return Err(DbError::invalid_store(
"store header does not match payload",
));
}
stored.state.validate()?;
Ok(stored)
}
/// Checks that the store under `root` loads and passes every `read_store` check.
///
/// The decoded snapshot is discarded; only success or the first error is reported.
///
/// # Errors
///
/// Exactly the errors `read_store` can return.
pub(crate) fn validate_store(root: &Path) -> Result<(), DbError> {
    let _stored = read_store(root)?;
    Ok(())
}
fn read_header(file: &mut File) -> Result<RawStoreHeader, DbError> {
let mut bytes = [0_u8; HEADER_LEN];
file.read_exact(&mut bytes)
.map_err(|error| DbError::io("read store header", error))?;
let header = RawStoreHeader::read_from_bytes(bytes.as_slice())
.map_err(|_error| DbError::invalid_store("store header layout mismatch"))?;
if header.magic != STORE_MAGIC || header.version.get() != STORE_VERSION {
return Err(DbError::invalid_store("store magic or version mismatch"));
}
if header.reserved != [0; 16] {
return Err(DbError::invalid_store("store reserved bytes are non-zero"));
}
Ok(header)
}
fn reject_trailing_bytes(file: &mut File) -> Result<(), DbError> {
let mut extra = [0_u8; 1];
match file
.read(&mut extra)
.map_err(|error| DbError::io("read store trailer", error))?
{
0 => Ok(()),
_count => Err(DbError::invalid_store("store has trailing bytes")),
}
}
#[cfg(unix)]
fn sync_directory(path: &Path) -> Result<(), DbError> {
let directory =
File::open(path).map_err(|error| DbError::io("open database directory", error))?;
directory
.sync_all()
.map_err(|error| DbError::io("sync database directory", error))
}
#[cfg(not(unix))]
fn sync_directory(_path: &Path) -> Result<(), DbError> {
Ok(())
}
/// 64-bit FNV-1a hash of `bytes`, used as the store payload checksum.
///
/// For each byte, it is XOR-ed into the running hash, which is then multiplied (wrapping) by
/// the FNV prime. This detects accidental corruption but is not a cryptographic integrity check.
fn checksum_bytes(bytes: &[u8]) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
    let mut hash = FNV_OFFSET_BASIS;
    for &byte in bytes {
        hash ^= u64::from(byte);
        hash = hash.wrapping_mul(FNV_PRIME);
    }
    hash
}