use crate::encryption::EncryptionEngine;
use byteorder::{LittleEndian, ReadBytesExt};
use parking_lot::Mutex;
use sochdb_core::{Result, SochDBError, WalRecordType};
use std::cell::Cell;
use std::collections::HashSet;
use std::fs::{File, OpenOptions};
use std::io::{BufReader, BufWriter, Read, Write};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Instant, SystemTime, UNIX_EPOCH};
/// How long (in nanoseconds) a thread's cached wall-clock reading is reused
/// before `SystemTime::now()` is queried again.
const CACHE_VALIDITY_NS: u64 = 1_000_000;

/// Reads the wall clock as microseconds since the UNIX epoch.
///
/// # Panics
/// Panics if the system clock reports a time before 1970-01-01.
fn system_time_us() -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("system clock set before UNIX epoch (1970-01-01)")
        .as_micros() as u64
}

thread_local! {
    // Per-thread `(instant_of_reading, wall_clock_us_at_that_instant)` pair.
    // It is seeded with a real clock reading. Seeding it with `0` made every
    // thread's first calls (those within CACHE_VALIDITY_NS of initialisation)
    // return timestamps a few microseconds after 1970.
    static TS_CACHE: Cell<(Instant, u64)> = Cell::new((Instant::now(), system_time_us()));
}

/// Returns the current wall-clock time in microseconds since the UNIX epoch.
///
/// The system clock is queried at most once per `CACHE_VALIDITY_NS` per thread.
/// Between refreshes the value is extrapolated from the cached reading using the
/// monotonic clock, so calls inside one window never go backwards. A refresh may
/// step slightly backwards if the wall clock was adjusted in the meantime.
///
/// # Panics
/// Panics if the system clock reports a time before 1970-01-01.
#[inline(always)]
pub fn cached_timestamp_us() -> u64 {
    TS_CACHE.with(|cache| {
        let (taken_at, base_us) = cache.get();
        let elapsed_ns = taken_at.elapsed().as_nanos() as u64;
        if elapsed_ns < CACHE_VALIDITY_NS {
            base_us + elapsed_ns / 1000
        } else {
            let now_us = system_time_us();
            cache.set((Instant::now(), now_us));
            now_us
        }
    })
}
/// Fixed bytes at the start of a framed plaintext record:
/// length prefix (`u32`) + record type (`u8`) + transaction id (`u64`)
/// + timestamp (`u64`) + key length (`u32`) + value length (`u32`).
const RECORD_HEADER_SIZE: usize = std::mem::size_of::<u32>()
    + std::mem::size_of::<u8>()
    + std::mem::size_of::<u64>()
    + std::mem::size_of::<u64>()
    + std::mem::size_of::<u32>()
    + std::mem::size_of::<u32>();
/// Size of the trailing CRC32 appended to every record.
const CHECKSUM_SIZE: usize = std::mem::size_of::<u32>();
/// Initial byte capacity (32 KiB) reserved by `TxnWalBuffer::new`.
const DEFAULT_TXN_BUFFER_CAPACITY: usize = 32 << 10;
/// Serialises one WAL record *without* its 4-byte length prefix.
///
/// Layout (integers little-endian):
/// `type:u8 | txn_id:u64 | timestamp_us:u64 | key_len:u32 | value_len:u32 | key | value | crc32:u32`.
/// The CRC32 covers every byte that precedes it in the body, so the output equals
/// `TxnWalEntry::to_bytes()` minus its first four bytes.
fn encode_record_body(
    record_type: WalRecordType,
    txn_id: u64,
    timestamp_us: u64,
    key: &[u8],
    value: &[u8],
) -> Vec<u8> {
    let capacity = (RECORD_HEADER_SIZE - 4) + key.len() + value.len() + CHECKSUM_SIZE;
    let mut out = Vec::with_capacity(capacity);

    out.push(record_type as u8);
    out.extend_from_slice(&txn_id.to_le_bytes());
    out.extend_from_slice(&timestamp_us.to_le_bytes());
    out.extend_from_slice(&(key.len() as u32).to_le_bytes());
    out.extend_from_slice(&(value.len() as u32).to_le_bytes());
    out.extend_from_slice(key);
    out.extend_from_slice(value);

    // CRC32 is a streaming checksum. Hashing the assembled prefix in one pass gives
    // the same value as the field-by-field updates used by the other encoders.
    let mut crc = crc32fast::Hasher::new();
    crc.update(&out);
    let checksum = crc.finalize();
    out.extend_from_slice(&checksum.to_le_bytes());
    out
}
/// Version tag stored in the first byte of every encrypted frame's associated data.
const WAL_AAD_VERSION: u8 = 1;
/// Associated-data length: version (`u8`) + database UUID (16 bytes)
/// + DEK epoch (`u32`) + record ordinal (`u64`).
const WAL_AAD_LEN: usize = std::mem::size_of::<u8>()
    + std::mem::size_of::<[u8; 16]>()
    + std::mem::size_of::<u32>()
    + std::mem::size_of::<u64>();
/// Upper bound (512 MiB) on a frame's declared length. Larger values are reported
/// as corruption instead of being allocated.
const MAX_WAL_FRAME_LEN: u32 = 512 << 20;
/// Per-transaction staging buffer of already-framed `Data` records.
///
/// Each record uses the same plaintext frame as `TxnWalEntry::to_bytes`, so the
/// whole buffer can be handed to `TxnWal::flush_buffer` as-is.
#[derive(Debug)]
pub struct TxnWalBuffer {
    /// Transaction id stamped into every buffered record.
    txn_id: u64,
    /// Concatenated length-prefixed records.
    buffer: Vec<u8>,
    /// Number of records currently held in `buffer`.
    entry_count: usize,
}

impl TxnWalBuffer {
    /// Creates an empty buffer for `txn_id` using the default 32 KiB capacity.
    #[inline]
    pub fn new(txn_id: u64) -> Self {
        Self::with_capacity(txn_id, DEFAULT_TXN_BUFFER_CAPACITY)
    }

    /// Creates an empty buffer for `txn_id` that preallocates `capacity` bytes.
    #[inline]
    pub fn with_capacity(txn_id: u64, capacity: usize) -> Self {
        Self {
            txn_id,
            buffer: Vec::with_capacity(capacity),
            entry_count: 0,
        }
    }

    /// Frames a `Data` record for (`key`, `value`) and appends it to the buffer.
    ///
    /// The record is stamped with `cached_timestamp_us()`. Its CRC32 covers every
    /// byte after the 4-byte length prefix.
    #[inline]
    pub fn append(&mut self, key: &[u8], value: &[u8]) {
        let stamp = cached_timestamp_us();
        let frame_len = RECORD_HEADER_SIZE + key.len() + value.len() + CHECKSUM_SIZE;

        // The length prefix counts every byte that follows it.
        let prefix = (frame_len - 4) as u32;
        self.buffer.extend_from_slice(&prefix.to_le_bytes());

        let body_start = self.buffer.len();
        self.buffer.push(WalRecordType::Data as u8);
        self.buffer.extend_from_slice(&self.txn_id.to_le_bytes());
        self.buffer.extend_from_slice(&stamp.to_le_bytes());
        self.buffer.extend_from_slice(&(key.len() as u32).to_le_bytes());
        self.buffer.extend_from_slice(&(value.len() as u32).to_le_bytes());
        self.buffer.extend_from_slice(key);
        self.buffer.extend_from_slice(value);

        // Hashing the freshly written body in one pass equals the incremental CRC.
        let mut crc = crc32fast::Hasher::new();
        crc.update(&self.buffer[body_start..]);
        let checksum = crc.finalize();
        self.buffer.extend_from_slice(&checksum.to_le_bytes());

        self.entry_count += 1;
    }

    /// Writes every buffered record to `wal` via `TxnWal::flush_buffer`.
    ///
    /// The buffer is not cleared; call `clear` before reusing it.
    #[inline]
    pub fn flush_to_wal(&self, wal: &TxnWal) -> Result<u64> {
        wal.flush_buffer(self)
    }

    /// Drops all buffered records while keeping the allocation.
    #[inline]
    pub fn clear(&mut self) {
        self.entry_count = 0;
        self.buffer.clear();
    }

    /// Number of records appended since creation or the last `clear`.
    #[inline]
    pub fn entry_count(&self) -> usize {
        self.entry_count
    }

    /// Total encoded bytes currently buffered.
    #[inline]
    pub fn bytes_buffered(&self) -> usize {
        self.buffer.len()
    }

    /// `true` when no bytes are buffered.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.buffer.is_empty()
    }
}
#[derive(Debug, Clone)]
pub struct TxnWalEntry {
pub record_type: WalRecordType,
pub txn_id: u64,
pub timestamp_us: u64,
pub key: Vec<u8>,
pub value: Vec<u8>,
}
impl TxnWalEntry {
pub fn data(txn_id: u64, key: Vec<u8>, value: Vec<u8>) -> Self {
Self {
record_type: WalRecordType::Data,
txn_id,
timestamp_us: Self::now_us(),
key,
value,
}
}
pub fn txn_begin(txn_id: u64) -> Self {
Self {
record_type: WalRecordType::TxnBegin,
txn_id,
timestamp_us: Self::now_us(),
key: Vec::new(),
value: Vec::new(),
}
}
pub fn txn_commit(txn_id: u64) -> Self {
Self {
record_type: WalRecordType::TxnCommit,
txn_id,
timestamp_us: Self::now_us(),
key: Vec::new(),
value: Vec::new(),
}
}
pub fn txn_abort(txn_id: u64) -> Self {
Self {
record_type: WalRecordType::TxnAbort,
txn_id,
timestamp_us: Self::now_us(),
key: Vec::new(),
value: Vec::new(),
}
}
pub fn checkpoint(txn_id: u64) -> Self {
Self {
record_type: WalRecordType::Checkpoint,
txn_id,
timestamp_us: Self::now_us(),
key: Vec::new(),
value: Vec::new(),
}
}
pub fn schema_change(txn_id: u64, schema_data: Vec<u8>) -> Self {
Self {
record_type: WalRecordType::SchemaChange,
txn_id,
timestamp_us: Self::now_us(),
key: Vec::new(),
value: schema_data,
}
}
#[inline]
fn now_us() -> u64 {
cached_timestamp_us()
}
pub fn checksum(&self) -> u32 {
let mut hasher = crc32fast::Hasher::new();
hasher.update(&[self.record_type as u8]);
hasher.update(&self.txn_id.to_le_bytes());
hasher.update(&self.timestamp_us.to_le_bytes());
hasher.update(&(self.key.len() as u32).to_le_bytes());
hasher.update(&(self.value.len() as u32).to_le_bytes());
hasher.update(&self.key);
hasher.update(&self.value);
hasher.finalize()
}
pub fn to_bytes(&self) -> Vec<u8> {
let total_len = RECORD_HEADER_SIZE + self.key.len() + self.value.len() + CHECKSUM_SIZE;
let mut buf = Vec::with_capacity(total_len);
let mut hasher = crc32fast::Hasher::new();
let content_len = (total_len - 4) as u32;
buf.extend_from_slice(&content_len.to_le_bytes());
let record_type_byte = self.record_type as u8;
buf.push(record_type_byte);
hasher.update(&[record_type_byte]);
let txn_bytes = self.txn_id.to_le_bytes();
buf.extend_from_slice(&txn_bytes);
hasher.update(&txn_bytes);
let ts_bytes = self.timestamp_us.to_le_bytes();
buf.extend_from_slice(&ts_bytes);
hasher.update(&ts_bytes);
let key_len_bytes = (self.key.len() as u32).to_le_bytes();
buf.extend_from_slice(&key_len_bytes);
hasher.update(&key_len_bytes);
let val_len_bytes = (self.value.len() as u32).to_le_bytes();
buf.extend_from_slice(&val_len_bytes);
hasher.update(&val_len_bytes);
buf.extend_from_slice(&self.key);
hasher.update(&self.key);
buf.extend_from_slice(&self.value);
hasher.update(&self.value);
buf.extend_from_slice(&hasher.finalize().to_le_bytes());
buf
}
pub fn from_reader<R: Read>(reader: &mut R) -> Result<Self> {
let content_len = reader.read_u32::<LittleEndian>()?;
if content_len < (RECORD_HEADER_SIZE - 4 + CHECKSUM_SIZE) as u32 {
return Err(SochDBError::Corruption("WAL entry too short".into()));
}
if content_len > MAX_WAL_FRAME_LEN {
return Err(SochDBError::Corruption(format!(
"WAL entry length {content_len} exceeds maximum {MAX_WAL_FRAME_LEN}"
)));
}
let mut body = vec![0u8; content_len as usize];
reader.read_exact(&mut body)?;
Self::parse_body(&body)
}
pub fn parse_body(body: &[u8]) -> Result<Self> {
let mut cur = std::io::Cursor::new(body);
let record_type_byte = cur.read_u8()?;
let record_type = WalRecordType::try_from(record_type_byte).map_err(|_| {
SochDBError::Corruption(format!("Invalid record type: {}", record_type_byte))
})?;
let txn_id = cur.read_u64::<LittleEndian>()?;
let timestamp_us = cur.read_u64::<LittleEndian>()?;
let key_len = cur.read_u32::<LittleEndian>()? as usize;
let value_len = cur.read_u32::<LittleEndian>()? as usize;
let mut key = vec![0u8; key_len];
cur.read_exact(&mut key)?;
let mut value = vec![0u8; value_len];
cur.read_exact(&mut value)?;
let stored_checksum = cur.read_u32::<LittleEndian>()?;
let entry = Self {
record_type,
txn_id,
timestamp_us,
key,
value,
};
if entry.checksum() != stored_checksum {
return Err(SochDBError::Corruption(format!(
"WAL checksum mismatch for txn_id {}: expected {}, got {}",
txn_id,
entry.checksum(),
stored_checksum
)));
}
Ok(entry)
}
fn body_bytes(&self) -> Vec<u8> {
let full = self.to_bytes();
full[4..].to_vec()
}
}
/// Append-only transaction write-ahead log backed by a single file, with optional
/// per-record encryption.
///
/// Every write goes through one mutex-guarded `BufWriter`. Plaintext records use
/// the `TxnWalEntry::to_bytes` frame. When the encryption engine is enabled, each
/// record body is sealed on its own and framed as `len:u32 | envelope`. The
/// envelope's associated data binds it to `db_uuid`, `dek_epoch` and the record's
/// ordinal position in the file (see `record_aad`).
pub struct TxnWal {
    /// Location of the log file.
    path: PathBuf,
    /// Buffered appender; every write, flush and truncate holds this lock.
    writer: Mutex<BufWriter<File>>,
    /// Next id handed out by `alloc_txn_id`.
    next_txn_id: AtomicU64,
    /// Count of records appended (restored from the file on open, reset by `truncate`).
    sequence: AtomicU64,
    /// Frame bytes written since the last `sync` or `truncate`.
    bytes_since_sync: AtomicU64,
    /// Engine used to seal and open frames; when disabled, plaintext frames are written.
    encryption: Arc<EncryptionEngine>,
    /// Database identity mixed into each frame's associated data.
    db_uuid: [u8; 16],
    /// Key epoch mixed into each frame's associated data.
    dek_epoch: u32,
    /// Ordinal to assign to the next encrypted frame (reset by `truncate`).
    records_in_file: AtomicU64,
}

impl TxnWal {
    /// Opens (creating if needed) an unencrypted WAL at `path`.
    ///
    /// Equivalent to `new_with_encryption` with a disabled engine, an all-zero
    /// database UUID and key epoch 0.
    pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
        Self::new_with_encryption(path, Arc::new(EncryptionEngine::disabled()), [0u8; 16], 0)
    }

    /// Opens a WAL at `path`, creating it and any missing parent directories.
    /// Transaction-id, sequence and ordinal counters are then restored from the
    /// file's existing records.
    ///
    /// # Errors
    /// I/O errors from creating, opening or reading the file. With encryption
    /// enabled, an existing record that fails to decrypt or parse (for example
    /// because the key is wrong) is returned as an error. A truncated final record
    /// is tolerated.
    pub fn new_with_encryption<P: AsRef<Path>>(
        path: P,
        encryption: Arc<EncryptionEngine>,
        db_uuid: [u8; 16],
        dek_epoch: u32,
    ) -> Result<Self> {
        let path = path.as_ref().to_path_buf();
        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent)?;
        }
        let file = OpenOptions::new()
            .create(true)
            .append(true)
            .read(true)
            .open(&path)?;
        let wal = Self {
            path,
            writer: Mutex::new(BufWriter::with_capacity(256 * 1024, file)),
            next_txn_id: AtomicU64::new(1),
            sequence: AtomicU64::new(0),
            bytes_since_sync: AtomicU64::new(0),
            encryption,
            db_uuid,
            dek_epoch,
            records_in_file: AtomicU64::new(0),
        };
        wal.recover_state()?;
        Ok(wal)
    }

    /// Associated data for the frame at `ordinal`:
    /// `version:u8 | db_uuid:[u8; 16] | dek_epoch:u32 (LE) | ordinal:u64 (LE)`.
    #[inline]
    fn record_aad(&self, ordinal: u64) -> [u8; WAL_AAD_LEN] {
        let mut aad = [0u8; WAL_AAD_LEN];
        aad[0] = WAL_AAD_VERSION;
        aad[1..17].copy_from_slice(&self.db_uuid);
        aad[17..21].copy_from_slice(&self.dek_epoch.to_le_bytes());
        aad[21..29].copy_from_slice(&ordinal.to_le_bytes());
        aad
    }

    /// Seals `body` for position `ordinal` and prefixes the envelope with its
    /// little-endian `u32` length.
    #[inline]
    fn encrypt_frame(&self, body: &[u8], ordinal: u64) -> Result<Vec<u8>> {
        let env = self
            .encryption
            .encrypt_with_aad(body, &self.record_aad(ordinal))?;
        let mut out = Vec::with_capacity(4 + env.len());
        out.extend_from_slice(&(env.len() as u32).to_le_bytes());
        out.extend_from_slice(&env);
        Ok(out)
    }

    /// Reads the next record, decrypting it first when encryption is enabled.
    ///
    /// `ordinal` is the record's position in the file. It must equal the ordinal
    /// used when the frame was sealed, and it is advanced only on success.
    fn read_record<R: Read>(&self, reader: &mut R, ordinal: &mut u64) -> Result<TxnWalEntry> {
        if !self.encryption.is_enabled() {
            let entry = TxnWalEntry::from_reader(reader)?;
            *ordinal += 1;
            return Ok(entry);
        }
        let outer_len = reader.read_u32::<LittleEndian>()?;
        if outer_len > MAX_WAL_FRAME_LEN {
            return Err(SochDBError::Corruption(format!(
                "encrypted WAL frame length {outer_len} exceeds maximum {MAX_WAL_FRAME_LEN}"
            )));
        }
        let mut env = vec![0u8; outer_len as usize];
        reader.read_exact(&mut env)?;
        let body = self
            .encryption
            .decrypt_with_aad(&env, &self.record_aad(*ordinal))?;
        let entry = TxnWalEntry::parse_body(&body)?;
        *ordinal += 1;
        Ok(entry)
    }

    /// Scans the file to restore the in-memory counters.
    ///
    /// Transaction ids are `(pid << 32) | counter`. `next_txn_id` resumes after the
    /// largest counter already used by a record carrying this process's pid.
    /// `sequence` and `records_in_file` become the number of records read
    /// successfully. An `UnexpectedEof` ends the scan. Any other error ends it in
    /// plaintext mode and is returned in encrypted mode.
    ///
    /// NOTE(review): the scan stops at the first bad record, but new records are
    /// appended at the physical end of the file, after any torn tail. Later scans
    /// therefore cannot reach them (and in encrypted mode, the first misaligned frame
    /// makes the next open fail). Consider truncating to the last good record here.
    fn recover_state(&self) -> Result<()> {
        let file = File::open(&self.path)?;
        let mut reader = BufReader::new(file);
        let mut count: u64 = 0;
        let our_pid = std::process::id() as u64;
        let pid_base = our_pid << 32;
        let mut max_our_counter: u64 = 0;
        let mut ordinal: u64 = 0;
        loop {
            match self.read_record(&mut reader, &mut ordinal) {
                Ok(entry) => {
                    count += 1;
                    let entry_pid = entry.txn_id >> 32;
                    if entry_pid == our_pid {
                        let entry_counter = entry.txn_id & 0xFFFF_FFFF;
                        if entry_counter > max_our_counter {
                            max_our_counter = entry_counter;
                        }
                    }
                }
                Err(SochDBError::Io(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
                    break;
                }
                Err(e) => {
                    if self.encryption.is_enabled() {
                        return Err(e);
                    }
                    break;
                }
            }
        }
        let next_id = pid_base + max_our_counter + 1;
        self.next_txn_id.store(next_id, Ordering::SeqCst);
        self.sequence.store(count, Ordering::SeqCst);
        self.records_in_file.store(count, Ordering::SeqCst);
        Ok(())
    }

    /// Timestamp for records built on the hot write path.
    ///
    /// This delegates to the thread-local `cached_timestamp_us`, whose staleness is
    /// bounded by time. A per-WAL cache refreshed only when `sequence` hit an exact
    /// multiple of 1024 is bounded by record count instead. At low write rates, or
    /// when `flush_buffer` jumps past a multiple, it could stamp records with
    /// timestamps minutes or hours old.
    #[inline]
    fn get_cached_timestamp(&self) -> u64 {
        cached_timestamp_us()
    }

    /// Appends `entry` to the in-process write buffer; nothing is flushed or fsynced.
    ///
    /// Returns the sequence number assigned to the record.
    pub fn append(&self, entry: &TxnWalEntry) -> Result<u64> {
        let mut writer = self.writer.lock();
        let bytes = self.frame_under_lock(entry)?;
        writer.write_all(&bytes)?;
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);
        self.bytes_since_sync
            .fetch_add(bytes.len() as u64, Ordering::Relaxed);
        Ok(seq)
    }

    /// Encodes `entry` as an on-disk frame. When encryption is enabled, this
    /// consumes the next ordinal.
    ///
    /// Callers hold the writer lock so that ordinal assignment matches write order.
    #[inline]
    fn frame_under_lock(&self, entry: &TxnWalEntry) -> Result<Vec<u8>> {
        if self.encryption.is_enabled() {
            let ord = self.records_in_file.fetch_add(1, Ordering::SeqCst);
            self.encrypt_frame(&entry.body_bytes(), ord)
        } else {
            Ok(entry.to_bytes())
        }
    }

    /// Same as `append`: buffers the record without flushing or fsyncing.
    #[inline]
    pub fn append_no_flush(&self, entry: &TxnWalEntry) -> Result<u64> {
        self.append(entry)
    }

    /// Buffers a `Data` record for `txn_id` built from owned key/value bytes.
    #[inline]
    pub fn write_no_flush(&self, txn_id: u64, key: Vec<u8>, value: Vec<u8>) -> Result<u64> {
        let entry = TxnWalEntry::data(txn_id, key, value);
        self.append_no_flush(&entry)
    }

    /// Buffers a `Data` record for `txn_id` from borrowed key/value bytes.
    ///
    /// In plaintext mode the frame is streamed straight into the `BufWriter`
    /// (hashing as it goes), so no intermediate `TxnWalEntry` or `Vec` is built.
    #[inline]
    pub fn write_no_flush_refs(&self, txn_id: u64, key: &[u8], value: &[u8]) -> Result<u64> {
        let timestamp_us = self.get_cached_timestamp();
        let mut writer = self.writer.lock();
        if self.encryption.is_enabled() {
            let body = encode_record_body(WalRecordType::Data, txn_id, timestamp_us, key, value);
            let ord = self.records_in_file.fetch_add(1, Ordering::SeqCst);
            let frame = self.encrypt_frame(&body, ord)?;
            writer.write_all(&frame)?;
            let seq = self.sequence.fetch_add(1, Ordering::SeqCst);
            self.bytes_since_sync
                .fetch_add(frame.len() as u64, Ordering::Relaxed);
            return Ok(seq);
        }
        let total_len = RECORD_HEADER_SIZE + key.len() + value.len() + CHECKSUM_SIZE;
        let mut hasher = crc32fast::Hasher::new();
        let content_len = (total_len - 4) as u32;
        writer.write_all(&content_len.to_le_bytes())?;
        let record_type_byte = WalRecordType::Data as u8;
        writer.write_all(&[record_type_byte])?;
        hasher.update(&[record_type_byte]);
        let txn_bytes = txn_id.to_le_bytes();
        writer.write_all(&txn_bytes)?;
        hasher.update(&txn_bytes);
        let ts_bytes = timestamp_us.to_le_bytes();
        writer.write_all(&ts_bytes)?;
        hasher.update(&ts_bytes);
        let key_len_bytes = (key.len() as u32).to_le_bytes();
        writer.write_all(&key_len_bytes)?;
        hasher.update(&key_len_bytes);
        let val_len_bytes = (value.len() as u32).to_le_bytes();
        writer.write_all(&val_len_bytes)?;
        hasher.update(&val_len_bytes);
        writer.write_all(key)?;
        hasher.update(key);
        writer.write_all(value)?;
        hasher.update(value);
        writer.write_all(&hasher.finalize().to_le_bytes())?;
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);
        self.bytes_since_sync
            .fetch_add(total_len as u64, Ordering::Relaxed);
        Ok(seq)
    }

    /// Pushes buffered bytes to the OS (no fsync).
    pub fn flush(&self) -> Result<()> {
        let mut writer = self.writer.lock();
        writer.flush()?;
        Ok(())
    }

    /// Appends `entry`, then flushes and fsyncs the file.
    pub fn append_sync(&self, entry: &TxnWalEntry) -> Result<u64> {
        let seq = self.append(entry)?;
        self.sync()?;
        Ok(seq)
    }

    /// Flushes buffered bytes, calls `sync_all` on the file and resets `bytes_since_sync`.
    pub fn sync(&self) -> Result<()> {
        let mut writer = self.writer.lock();
        writer.flush()?;
        writer.get_ref().sync_all()?;
        self.bytes_since_sync.store(0, Ordering::Relaxed);
        Ok(())
    }

    /// Writes every record staged in `buffer` (encrypting each one if enabled).
    ///
    /// Returns the sequence number of the first record written, or `0` without
    /// writing anything when `buffer` is empty. The buffer is not cleared.
    ///
    /// # Errors
    /// `Corruption` if the buffer ends mid-record while encrypting; I/O and
    /// encryption errors are propagated.
    #[inline]
    pub fn flush_buffer(&self, buffer: &TxnWalBuffer) -> Result<u64> {
        if buffer.is_empty() {
            return Ok(0);
        }
        let mut writer = self.writer.lock();
        if self.encryption.is_enabled() {
            // Re-frame each buffered plaintext record as its own sealed envelope.
            let buf = &buffer.buffer;
            let mut pos = 0usize;
            let mut total_written = 0u64;
            while pos + 4 <= buf.len() {
                let content_len =
                    u32::from_le_bytes([buf[pos], buf[pos + 1], buf[pos + 2], buf[pos + 3]])
                        as usize;
                pos += 4;
                if pos + content_len > buf.len() {
                    return Err(SochDBError::Corruption(
                        "txn buffer truncated mid-record during encrypted flush".into(),
                    ));
                }
                let body = &buf[pos..pos + content_len];
                pos += content_len;
                let ord = self.records_in_file.fetch_add(1, Ordering::SeqCst);
                let frame = self.encrypt_frame(body, ord)?;
                writer.write_all(&frame)?;
                total_written += frame.len() as u64;
            }
            let seq = self
                .sequence
                .fetch_add(buffer.entry_count as u64, Ordering::SeqCst);
            self.bytes_since_sync
                .fetch_add(total_written, Ordering::Relaxed);
            return Ok(seq);
        }
        // Plaintext: the buffer already holds on-disk frames.
        writer.write_all(&buffer.buffer)?;
        let seq = self
            .sequence
            .fetch_add(buffer.entry_count as u64, Ordering::SeqCst);
        self.bytes_since_sync
            .fetch_add(buffer.buffer.len() as u64, Ordering::Relaxed);
        Ok(seq)
    }

    /// On-disk file size (excludes bytes still in the write buffer); `0` if the
    /// metadata cannot be read.
    pub fn size_bytes(&self) -> u64 {
        std::fs::metadata(&self.path).map(|m| m.len()).unwrap_or(0)
    }

    /// Reserves and returns a fresh transaction id.
    pub fn alloc_txn_id(&self) -> u64 {
        self.next_txn_id.fetch_add(1, Ordering::SeqCst)
    }

    /// Allocates a transaction id and buffers its `TxnBegin` record.
    pub fn begin_transaction(&self) -> Result<u64> {
        let txn_id = self.alloc_txn_id();
        let entry = TxnWalEntry::txn_begin(txn_id);
        self.append(&entry)?;
        Ok(txn_id)
    }

    /// Flushes pending records, appends a `TxnCommit` for `txn_id` and fsyncs.
    ///
    /// The leading flush runs before the commit marker is buffered, so a flush
    /// failure is reported without a commit record entering the buffer.
    pub fn commit_transaction(&self, txn_id: u64) -> Result<()> {
        self.flush()?;
        let entry = TxnWalEntry::txn_commit(txn_id);
        self.append_sync(&entry)?;
        Ok(())
    }

    /// Buffers a `TxnCommit` for each id in `txn_ids`, then flushes and fsyncs once.
    pub fn commit_durable_batch(&self, txn_ids: &[u64]) -> Result<()> {
        for &txn_id in txn_ids {
            let entry = TxnWalEntry::txn_commit(txn_id);
            self.append_no_flush(&entry)?;
        }
        self.flush()?;
        self.sync()?;
        Ok(())
    }

    /// Buffers a `TxnAbort` record for `txn_id`.
    pub fn abort_transaction(&self, txn_id: u64) -> Result<()> {
        let entry = TxnWalEntry::txn_abort(txn_id);
        self.append(&entry)?;
        Ok(())
    }

    /// Buffers a `Data` record for `txn_id`.
    pub fn write(&self, txn_id: u64, key: Vec<u8>, value: Vec<u8>) -> Result<u64> {
        let entry = TxnWalEntry::data(txn_id, key, value);
        self.append(&entry)
    }

    /// Replays the file and returns the writes of committed transactions (in commit
    /// order) plus the number of committed transactions.
    ///
    /// `Data` records are collected even without a preceding `TxnBegin`.
    #[allow(clippy::type_complexity)]
    pub fn replay_for_recovery(&self) -> Result<(Vec<(Vec<u8>, Vec<u8>)>, usize)> {
        let file = File::open(&self.path)?;
        let mut reader = BufReader::new(file);
        let mut pending_writes: std::collections::HashMap<u64, Vec<(Vec<u8>, Vec<u8>)>> =
            std::collections::HashMap::new();
        let mut result = Vec::new();
        let mut txn_count = 0;
        let mut ordinal: u64 = 0;
        loop {
            match self.read_record(&mut reader, &mut ordinal) {
                Ok(entry) => match entry.record_type {
                    WalRecordType::TxnBegin => {
                        pending_writes.insert(entry.txn_id, Vec::new());
                    }
                    WalRecordType::Data => {
                        pending_writes
                            .entry(entry.txn_id)
                            .or_insert_with(Vec::new)
                            .push((entry.key, entry.value));
                    }
                    WalRecordType::TxnCommit => {
                        if let Some(writes) = pending_writes.remove(&entry.txn_id) {
                            result.extend(writes);
                            txn_count += 1;
                        }
                    }
                    WalRecordType::TxnAbort => {
                        pending_writes.remove(&entry.txn_id);
                    }
                    _ => {}
                },
                Err(SochDBError::Io(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
                    break;
                }
                Err(e) => {
                    if self.encryption.is_enabled() {
                        return Err(e);
                    }
                    break;
                }
            }
        }
        Ok((result, txn_count))
    }

    /// Feeds every readable record to `callback`, returning how many were delivered.
    ///
    /// In plaintext mode a non-EOF read error is printed to stderr and ends the
    /// replay. In encrypted mode it is returned. Errors from `callback` abort the
    /// replay.
    pub fn replay<F>(&self, mut callback: F) -> Result<u64>
    where
        F: FnMut(TxnWalEntry) -> Result<()>,
    {
        let file = File::open(&self.path)?;
        let mut reader = BufReader::new(file);
        let mut count = 0u64;
        let mut ordinal: u64 = 0;
        loop {
            match self.read_record(&mut reader, &mut ordinal) {
                Ok(entry) => {
                    callback(entry)?;
                    count += 1;
                }
                Err(SochDBError::Io(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
                    break;
                }
                Err(e) => {
                    if self.encryption.is_enabled() {
                        return Err(e);
                    }
                    eprintln!("WAL replay warning: {:?}", e);
                    break;
                }
            }
        }
        Ok(count)
    }

    /// Discards every record: flushes, truncates the file to zero length, fsyncs,
    /// and resets the sequence, byte and ordinal counters.
    ///
    /// `next_txn_id` is left untouched, so ids are not reused by this handle.
    pub fn truncate(&self) -> Result<()> {
        let mut writer = self.writer.lock();
        writer.flush()?;
        let file = writer.get_ref();
        file.set_len(0)?;
        file.sync_all()?;
        self.sequence.store(0, Ordering::SeqCst);
        self.bytes_since_sync.store(0, Ordering::Relaxed);
        self.records_in_file.store(0, Ordering::SeqCst);
        Ok(())
    }

    /// Appends an empty `Checkpoint` record (txn id 0) and fsyncs.
    pub fn write_checkpoint(&self) -> Result<u64> {
        let entry = TxnWalEntry::checkpoint(0);
        self.append_sync(&entry)
    }

    /// Buffers a compensation log record for `txn_id`.
    ///
    /// The key holds `undo_next_lsn` as a little-endian `u64` (`0` when `None`), and
    /// the value holds `undo_data`. `_original_lsn` is currently not recorded.
    pub fn append_clr(
        &self,
        txn_id: u64,
        _original_lsn: u64,
        undo_next_lsn: Option<u64>,
        undo_data: &[u8],
    ) -> Result<u64> {
        let key = undo_next_lsn.unwrap_or(0).to_le_bytes().to_vec();
        let entry = TxnWalEntry {
            record_type: WalRecordType::CompensationLogRecord,
            txn_id,
            timestamp_us: TxnWalEntry::now_us(),
            key,
            value: undo_data.to_vec(),
        };
        self.append(&entry)
    }

    /// Appends a `Checkpoint` record (txn id 0) carrying `checkpoint_data` and fsyncs.
    pub fn write_checkpoint_with_data(&self, checkpoint_data: &[u8]) -> Result<u64> {
        let entry = TxnWalEntry {
            record_type: WalRecordType::Checkpoint,
            txn_id: 0,
            timestamp_us: TxnWalEntry::now_us(),
            key: Vec::new(),
            value: checkpoint_data.to_vec(),
        };
        self.append_sync(&entry)
    }

    /// Appends a `CheckpointEnd` record (txn id 0) carrying `checkpoint_data` and fsyncs.
    pub fn write_checkpoint_end(&self, checkpoint_data: &[u8]) -> Result<u64> {
        let entry = TxnWalEntry {
            record_type: WalRecordType::CheckpointEnd,
            txn_id: 0,
            timestamp_us: TxnWalEntry::now_us(),
            key: Vec::new(),
            value: checkpoint_data.to_vec(),
        };
        self.append_sync(&entry)
    }

    /// Current value of the record sequence counter.
    pub fn sequence(&self) -> u64 {
        self.sequence.load(Ordering::SeqCst)
    }

    /// Frame bytes written since the last `sync` or `truncate`.
    pub fn bytes_since_sync(&self) -> u64 {
        self.bytes_since_sync.load(Ordering::Relaxed)
    }

    /// Location of the log file.
    pub fn path(&self) -> &Path {
        &self.path
    }
}
/// Point-in-time snapshot of a `TxnWal`'s in-memory counters, as produced by `TxnWal::stats`.
#[derive(Debug, Clone, Default)]
pub struct TxnWalStats {
/// Value of the WAL's `sequence` counter: records appended, including those counted while recovering the file on open.
pub entries_written: u64,
/// Frame bytes written since the last `sync` (or `truncate`).
pub bytes_since_sync: u64,
/// Id that the next `alloc_txn_id` call would return.
pub next_txn_id: u64,
}
#[allow(dead_code)]
pub struct ShardedWal {
shards: Vec<parking_lot::Mutex<WalShard>>,
num_shards: usize,
central_writer: parking_lot::Mutex<BufWriter<File>>,
next_txn_id: AtomicU64,
sequence: AtomicU64,
path: PathBuf,
}
struct WalShard {
buffer: Vec<u8>,
entry_count: usize,
}
impl WalShard {
fn new() -> Self {
Self {
buffer: Vec::with_capacity(64 * 1024), entry_count: 0,
}
}
fn append(&mut self, entry: &TxnWalEntry) {
let bytes = entry.to_bytes();
self.buffer.extend_from_slice(&bytes);
self.entry_count += 1;
}
fn is_empty(&self) -> bool {
self.buffer.is_empty()
}
fn drain(&mut self) -> Vec<u8> {
self.entry_count = 0;
std::mem::take(&mut self.buffer)
}
}
impl ShardedWal {
pub fn new<P: AsRef<Path>>(path: P, num_shards: usize) -> Result<Self> {
let path = path.as_ref().to_path_buf();
if let Some(parent) = path.parent() {
std::fs::create_dir_all(parent)?;
}
let file = std::fs::OpenOptions::new()
.create(true)
.append(true)
.read(true)
.open(&path)?;
let num_shards = num_shards.next_power_of_two();
let shards: Vec<_> = (0..num_shards)
.map(|_| parking_lot::Mutex::new(WalShard::new()))
.collect();
Ok(Self {
shards,
num_shards,
central_writer: parking_lot::Mutex::new(BufWriter::with_capacity(256 * 1024, file)),
next_txn_id: AtomicU64::new(1),
sequence: AtomicU64::new(0),
path,
})
}
#[inline]
fn shard_idx(&self, txn_id: u64) -> usize {
(txn_id as usize) & (self.num_shards - 1)
}
pub fn append(&self, entry: &TxnWalEntry) -> u64 {
let shard_idx = self.shard_idx(entry.txn_id);
let mut shard = self.shards[shard_idx].lock();
shard.append(entry);
self.sequence.fetch_add(1, Ordering::SeqCst)
}
pub fn alloc_txn_id(&self) -> u64 {
self.next_txn_id.fetch_add(1, Ordering::SeqCst)
}
pub fn flush(&self) -> Result<()> {
let mut central = self.central_writer.lock();
for shard in &self.shards {
let mut shard_guard = shard.lock();
if !shard_guard.is_empty() {
let data = shard_guard.drain();
central.write_all(&data)?;
}
}
central.flush()?;
Ok(())
}
pub fn sync(&self) -> Result<()> {
self.flush()?;
let central = self.central_writer.lock();
central.get_ref().sync_all()?;
Ok(())
}
pub fn begin_transaction(&self) -> Result<u64> {
let txn_id = self.alloc_txn_id();
let entry = TxnWalEntry::txn_begin(txn_id);
self.append(&entry);
Ok(txn_id)
}
pub fn write(&self, txn_id: u64, key: Vec<u8>, value: Vec<u8>) -> Result<u64> {
let entry = TxnWalEntry::data(txn_id, key, value);
Ok(self.append(&entry))
}
pub fn commit_transaction(&self, txn_id: u64) -> Result<u64> {
let entry = TxnWalEntry::txn_commit(txn_id);
let seq = self.append(&entry);
self.sync()?; Ok(seq)
}
pub fn stats(&self) -> ShardedWalStats {
let mut shard_entry_counts = Vec::with_capacity(self.num_shards);
for shard in &self.shards {
shard_entry_counts.push(shard.lock().entry_count);
}
ShardedWalStats {
num_shards: self.num_shards,
total_entries: self.sequence.load(Ordering::SeqCst),
next_txn_id: self.next_txn_id.load(Ordering::SeqCst),
shard_entry_counts,
}
}
}
#[derive(Debug, Clone)]
pub struct ShardedWalStats {
pub num_shards: usize,
pub total_entries: u64,
pub next_txn_id: u64,
pub shard_entry_counts: Vec<usize>,
}
/// Statistics gathered by `TxnWal::crash_recovery`.
#[derive(Debug, Clone, Default)]
pub struct CrashRecoveryStats {
/// Records decoded successfully.
pub total_records: u64,
/// Begun transactions whose commit was seen and whose writes were returned.
pub committed_txns: u64,
/// Transactions that began but never reached a commit or abort record; their writes are discarded.
pub rolled_back_txns: u64,
/// Distinct transaction ids that have an abort record.
pub aborted_txns: u64,
/// Key/value writes returned to the caller.
pub recovered_writes: u64,
/// 1 if a plaintext scan stopped at an undecodable (torn or corrupt) record, otherwise 0.
pub torn_records: u64,
/// Size of the WAL file when recovery started (not necessarily the number of bytes parsed).
pub bytes_read: u64,
/// Wall time spent in recovery, in microseconds.
pub recovery_duration_us: u64,
/// Largest transaction id seen in any decoded record.
pub max_txn_id: u64,
}
impl TxnWal {
    /// Snapshot of this WAL's in-memory counters.
    pub fn stats(&self) -> TxnWalStats {
        TxnWalStats {
            entries_written: self.sequence.load(Ordering::SeqCst),
            bytes_since_sync: self.bytes_since_sync.load(Ordering::Relaxed),
            next_txn_id: self.next_txn_id.load(Ordering::SeqCst),
        }
    }

    /// Scans the whole log and returns the writes of every committed transaction,
    /// in commit order, together with recovery statistics.
    ///
    /// A transaction's `Data` records are collected only after its `TxnBegin`. They
    /// are emitted when its `TxnCommit` is read, and dropped on `TxnAbort` or if no
    /// commit is ever seen. Emitting in commit order (rather than iterating a hash
    /// map) lets callers apply the result sequentially: when two transactions wrote
    /// the same key, the later commit wins.
    ///
    /// # Errors
    /// I/O errors opening the file. With encryption enabled, any undecryptable or
    /// malformed record is an error. In plaintext mode the scan stops at the first
    /// such record and counts it in `torn_records`. `UnexpectedEof` always ends the
    /// scan quietly.
    #[allow(clippy::type_complexity)]
    pub fn crash_recovery(&self) -> Result<(Vec<(Vec<u8>, Vec<u8>)>, CrashRecoveryStats)> {
        let start_time = std::time::Instant::now();
        let file = File::open(&self.path)?;
        let file_size = file.metadata()?.len();
        let mut reader = BufReader::new(file);
        let mut stats = CrashRecoveryStats {
            bytes_read: file_size,
            ..Default::default()
        };
        // Writes of transactions that have begun but not yet committed or aborted.
        let mut pending_writes: std::collections::HashMap<u64, Vec<(Vec<u8>, Vec<u8>)>> =
            std::collections::HashMap::new();
        let mut aborted_txns: HashSet<u64> = HashSet::new();
        // Committed writes, appended in the order their commit records appear.
        let mut result: Vec<(Vec<u8>, Vec<u8>)> = Vec::new();
        let mut ordinal: u64 = 0;
        loop {
            match self.read_record(&mut reader, &mut ordinal) {
                Ok(entry) => {
                    stats.total_records += 1;
                    stats.max_txn_id = stats.max_txn_id.max(entry.txn_id);
                    match entry.record_type {
                        WalRecordType::TxnBegin => {
                            pending_writes.insert(entry.txn_id, Vec::new());
                        }
                        WalRecordType::Data => {
                            if let Some(writes) = pending_writes.get_mut(&entry.txn_id) {
                                writes.push((entry.key, entry.value));
                            }
                        }
                        WalRecordType::TxnCommit => {
                            if let Some(writes) = pending_writes.remove(&entry.txn_id) {
                                stats.committed_txns += 1;
                                stats.recovered_writes += writes.len() as u64;
                                result.extend(writes);
                            }
                        }
                        WalRecordType::TxnAbort => {
                            pending_writes.remove(&entry.txn_id);
                            aborted_txns.insert(entry.txn_id);
                        }
                        _ => {}
                    }
                }
                Err(SochDBError::Io(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
                    break;
                }
                Err(e) => {
                    if self.encryption.is_enabled() {
                        return Err(e);
                    }
                    stats.torn_records += 1;
                    break;
                }
            }
        }
        stats.aborted_txns = aborted_txns.len() as u64;
        // Whatever is still pending began but never reached a commit or abort. Counting
        // it directly cannot underflow, unlike subtracting committed and aborted totals
        // from the number of begun transactions.
        stats.rolled_back_txns = pending_writes.len() as u64;
        stats.recovery_duration_us = start_time.elapsed().as_micros() as u64;
        Ok((result, stats))
    }
}
// Unit tests for the plaintext WAL format: frame round-trips, commit/abort/rollback
// handling in replay_for_recovery and crash_recovery, checksum and torn-write detection.
#[cfg(test)]
mod tests {
use super::*;
use tempfile::tempdir;
// to_bytes followed by from_reader must reproduce the record's fields.
#[test]
fn test_wal_entry_roundtrip() {
let entry = TxnWalEntry::data(42, b"key".to_vec(), b"value".to_vec());
let bytes = entry.to_bytes();
let mut cursor = std::io::Cursor::new(bytes);
let recovered = TxnWalEntry::from_reader(&mut cursor).unwrap();
assert_eq!(recovered.record_type, WalRecordType::Data);
assert_eq!(recovered.txn_id, 42);
assert_eq!(recovered.key, b"key");
assert_eq!(recovered.value, b"value");
}
// A committed transaction's writes are returned in write order after reopening.
#[test]
fn test_wal_append_and_replay() {
let dir = tempdir().unwrap();
let wal_path = dir.path().join("test.wal");
{
let wal = TxnWal::new(&wal_path).unwrap();
let txn_id = wal.begin_transaction().unwrap();
wal.write(txn_id, b"k1".to_vec(), b"v1".to_vec()).unwrap();
wal.write(txn_id, b"k2".to_vec(), b"v2".to_vec()).unwrap();
wal.commit_transaction(txn_id).unwrap();
}
{
let wal = TxnWal::new(&wal_path).unwrap();
let (writes, txn_count) = wal.replay_for_recovery().unwrap();
assert_eq!(txn_count, 1);
assert_eq!(writes.len(), 2);
assert_eq!(writes[0], (b"k1".to_vec(), b"v1".to_vec()));
assert_eq!(writes[1], (b"k2".to_vec(), b"v2".to_vec()));
}
}
// txn2 never commits. Dropping the WAL lets its BufWriter flush the buffered
// begin/data records (presumably; no explicit flush here), and replay must skip them.
#[test]
fn test_uncommitted_transaction_rollback() {
let dir = tempdir().unwrap();
let wal_path = dir.path().join("test.wal");
{
let wal = TxnWal::new(&wal_path).unwrap();
let txn1 = wal.begin_transaction().unwrap();
wal.write(txn1, b"committed".to_vec(), b"yes".to_vec())
.unwrap();
wal.commit_transaction(txn1).unwrap();
let txn2 = wal.begin_transaction().unwrap();
wal.write(txn2, b"uncommitted".to_vec(), b"no".to_vec())
.unwrap();
}
{
let wal = TxnWal::new(&wal_path).unwrap();
let (writes, txn_count) = wal.replay_for_recovery().unwrap();
assert_eq!(txn_count, 1); assert_eq!(writes.len(), 1);
assert_eq!(writes[0], (b"committed".to_vec(), b"yes".to_vec()));
}
}
// An aborted transaction contributes neither writes nor a transaction count.
#[test]
fn test_aborted_transaction() {
let dir = tempdir().unwrap();
let wal_path = dir.path().join("test.wal");
{
let wal = TxnWal::new(&wal_path).unwrap();
let txn = wal.begin_transaction().unwrap();
wal.write(txn, b"aborted".to_vec(), b"data".to_vec())
.unwrap();
wal.abort_transaction(txn).unwrap();
}
{
let wal = TxnWal::new(&wal_path).unwrap();
let (writes, txn_count) = wal.replay_for_recovery().unwrap();
assert_eq!(txn_count, 0);
assert!(writes.is_empty());
}
}
// Flipping a bit in the trailing CRC must make decoding fail.
#[test]
fn test_checksum_validation() {
let entry = TxnWalEntry::data(1, b"key".to_vec(), b"value".to_vec());
let mut bytes = entry.to_bytes();
let len = bytes.len();
bytes[len - 1] ^= 0xFF;
let mut cursor = std::io::Cursor::new(bytes);
let result = TxnWalEntry::from_reader(&mut cursor);
assert!(result.is_err());
}
// Mixed committed (txn1, txn3), aborted (txn2) and unfinished (txn4) transactions.
#[test]
fn test_crash_recovery_with_stats() {
let dir = tempdir().unwrap();
let wal_path = dir.path().join("test.wal");
{
let wal = TxnWal::new(&wal_path).unwrap();
let txn1 = wal.begin_transaction().unwrap();
wal.write(txn1, b"k1".to_vec(), b"v1".to_vec()).unwrap();
wal.write(txn1, b"k2".to_vec(), b"v2".to_vec()).unwrap();
wal.commit_transaction(txn1).unwrap();
let txn2 = wal.begin_transaction().unwrap();
wal.write(txn2, b"aborted_key".to_vec(), b"aborted_val".to_vec())
.unwrap();
wal.abort_transaction(txn2).unwrap();
let txn3 = wal.begin_transaction().unwrap();
wal.write(txn3, b"k3".to_vec(), b"v3".to_vec()).unwrap();
wal.commit_transaction(txn3).unwrap();
let txn4 = wal.begin_transaction().unwrap();
wal.write(txn4, b"uncommitted".to_vec(), b"data".to_vec())
.unwrap();
}
{
let wal = TxnWal::new(&wal_path).unwrap();
let (writes, stats) = wal.crash_recovery().unwrap();
assert_eq!(writes.len(), 3);
assert_eq!(stats.committed_txns, 2);
assert_eq!(stats.aborted_txns, 1);
assert_eq!(stats.rolled_back_txns, 1); assert_eq!(stats.recovered_writes, 3);
// NOTE(review): assumes recovery takes at least 1 µs; could be flaky on very fast machines.
assert!(stats.recovery_duration_us > 0);
}
}
// Appends a 6-byte fragment whose length prefix (0x10 = 16) is below the minimum
// record size. This simulates a torn write that recovery must count and skip.
#[test]
fn test_torn_write_detection() {
let dir = tempdir().unwrap();
let wal_path = dir.path().join("test.wal");
{
let wal = TxnWal::new(&wal_path).unwrap();
let txn = wal.begin_transaction().unwrap();
wal.write(txn, b"key".to_vec(), b"value".to_vec()).unwrap();
wal.commit_transaction(txn).unwrap();
}
{
use std::io::Write;
let mut file = std::fs::OpenOptions::new()
.append(true)
.open(&wal_path)
.unwrap();
file.write_all(&[0x10, 0x00, 0x00, 0x00, 0xFF, 0xFF])
.unwrap();
}
{
let wal = TxnWal::new(&wal_path).unwrap();
let (writes, stats) = wal.crash_recovery().unwrap();
assert_eq!(writes.len(), 1);
assert_eq!(stats.committed_txns, 1);
assert_eq!(stats.torn_records, 1);
}
}
// Timestamps are pinned so the checksum depends only on record contents.
#[test]
fn test_crc32_determinism() {
let mut entry1 = TxnWalEntry::data(42, b"key".to_vec(), b"value".to_vec());
entry1.timestamp_us = 12345;
let mut entry2 = TxnWalEntry::data(42, b"key".to_vec(), b"value".to_vec());
entry2.timestamp_us = 12345;
assert_eq!(entry1.checksum(), entry2.checksum());
let mut entry3 = TxnWalEntry::data(42, b"key".to_vec(), b"different".to_vec());
entry3.timestamp_us = 12345;
assert_ne!(entry1.checksum(), entry3.checksum());
let bytes = entry1.to_bytes();
let mut cursor = std::io::Cursor::new(bytes);
let recovered = TxnWalEntry::from_reader(&mut cursor).unwrap();
assert_eq!(recovered.checksum(), entry1.checksum());
}
}
// Tests for the encrypted WAL path: round-trip through every write API, failure on
// a wrong key or tampering, and byte-compatibility of the disabled engine.
#[cfg(test)]
mod encryption_wal_tests {
use super::*;
use crate::encryption::EncryptionEngine;
use tempfile::tempdir;
// Engine keyed with 32 copies of `key`; different `key` values give different keys.
fn enc(key: u8) -> Arc<EncryptionEngine> {
Arc::new(EncryptionEngine::new(&[key; 32]))
}
// Fixed database UUID mixed into every frame's associated data.
const UUID: [u8; 16] = [9u8; 16];
// Exercises write, write_no_flush_refs and flush_buffer in one transaction, then
// checks plaintext keys/values never appear on disk and only committed data is recovered.
#[test]
fn encrypted_write_then_recover_roundtrip() {
let dir = tempdir().unwrap();
let path = dir.path().join("enc.wal");
{
let wal = TxnWal::new_with_encryption(&path, enc(7), UUID, 0).unwrap();
let t1 = wal.begin_transaction().unwrap();
wal.write(t1, b"alpha".to_vec(), b"one".to_vec()).unwrap();
wal.write_no_flush_refs(t1, b"beta", b"two").unwrap();
let mut buf = TxnWalBuffer::new(t1);
buf.append(b"gamma", b"three");
wal.flush_buffer(&buf).unwrap();
wal.commit_transaction(t1).unwrap();
let t2 = wal.begin_transaction().unwrap();
wal.write(t2, b"ghost".to_vec(), b"x".to_vec()).unwrap();
wal.abort_transaction(t2).unwrap();
wal.sync().unwrap();
}
let raw = std::fs::read(&path).unwrap();
assert!(!contains(&raw, b"alpha"));
assert!(!contains(&raw, b"three"));
let wal = TxnWal::new_with_encryption(&path, enc(7), UUID, 0).unwrap();
let (writes, stats) = wal.crash_recovery().unwrap();
let keys: Vec<_> = writes.iter().map(|(k, _)| k.clone()).collect();
assert!(keys.contains(&b"alpha".to_vec()));
assert!(keys.contains(&b"beta".to_vec()));
assert!(keys.contains(&b"gamma".to_vec()));
assert!(!keys.contains(&b"ghost".to_vec()), "aborted txn leaked");
assert_eq!(stats.committed_txns, 1);
}
// Reopening with a different key must fail with an Encryption error, rather than
// opening an apparently empty WAL.
#[test]
fn wrong_key_fails_loud_not_empty() {
let dir = tempdir().unwrap();
let path = dir.path().join("enc.wal");
{
let wal = TxnWal::new_with_encryption(&path, enc(1), UUID, 0).unwrap();
let t = wal.begin_transaction().unwrap();
wal.write(t, b"k".to_vec(), b"v".to_vec()).unwrap();
wal.commit_transaction(t).unwrap();
wal.sync().unwrap();
}
let opened = TxnWal::new_with_encryption(&path, enc(2), UUID, 0);
assert!(opened.is_err(), "wrong key opened silently (data-loss bug)");
match opened.err().unwrap() {
SochDBError::Encryption(_) => {}
other => panic!("expected Encryption error, got {other:?}"),
}
}
// Flips one byte in the middle of the file. Either opening or crash_recovery must fail.
#[test]
fn tamper_midstream_fails_loud() {
let dir = tempdir().unwrap();
let path = dir.path().join("enc.wal");
{
let wal = TxnWal::new_with_encryption(&path, enc(5), UUID, 0).unwrap();
let t = wal.begin_transaction().unwrap();
wal.write(t, b"k1".to_vec(), b"v1".to_vec()).unwrap();
wal.commit_transaction(t).unwrap();
wal.sync().unwrap();
}
let mut raw = std::fs::read(&path).unwrap();
let mid = raw.len() / 2;
raw[mid] ^= 0xFF;
std::fs::write(&path, &raw).unwrap();
let opened = TxnWal::new_with_encryption(&path, enc(5), UUID, 0);
let failed = opened.is_err()
|| opened
.ok()
.map(|w| w.crash_recovery().is_err())
.unwrap_or(true);
assert!(failed, "tampered encrypted WAL was silently accepted");
}
// A disabled engine must write exactly `to_bytes()`. An enabled engine must produce
// different bytes that do not parse cleanly as one plaintext frame.
#[test]
fn disabled_engine_is_byte_identical_to_plaintext() {
let dir = tempdir().unwrap();
let entry = TxnWalEntry::data(42, b"k1".to_vec(), b"v1".to_vec());
let golden = entry.to_bytes();
let p_dis = dir.path().join("disabled.wal");
{
let wal = TxnWal::new_with_encryption(
&p_dis,
Arc::new(EncryptionEngine::disabled()),
[0u8; 16],
0,
)
.unwrap();
wal.append(&entry).unwrap();
wal.sync().unwrap();
}
assert_eq!(
std::fs::read(&p_dis).unwrap(),
golden,
"disabled-engine append diverged from legacy plaintext frame"
);
let p_enc = dir.path().join("enc.wal");
{
let wal = TxnWal::new_with_encryption(&p_enc, enc(3), UUID, 0).unwrap();
wal.append(&entry).unwrap();
wal.sync().unwrap();
}
let enc_bytes = std::fs::read(&p_enc).unwrap();
assert_ne!(enc_bytes, golden);
let mut cur = std::io::Cursor::new(&enc_bytes);
assert!(
TxnWalEntry::from_reader(&mut cur).is_err()
|| cur.position() as usize != enc_bytes.len(),
"ciphertext frame must not parse cleanly as a plaintext record"
);
}
// Naive substring search over bytes.
fn contains(haystack: &[u8], needle: &[u8]) -> bool {
haystack.windows(needle.len()).any(|w| w == needle)
}
}