use std::collections::VecDeque;
use std::fs::{File, OpenOptions};
use std::io::{self, BufReader, BufWriter, Read, Write};
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Mutex, MutexGuard, RwLock, RwLockReadGuard, RwLockWriteGuard};
use std::time::{SystemTime, UNIX_EPOCH};
/// Identifier of a transaction recorded in the WAL.
pub type TxnId = u64;
/// Log sequence number. `TransactionLog` hands these out starting at 1; 0 serves as the
/// placeholder set by `LogEntry::new` and as the "no checkpoint yet" value.
pub type Lsn = u64;
/// Entry creation time; `LogEntry::new` fills it with microseconds since the Unix epoch.
pub type Timestamp = u64;
/// Builds the error reported when the lock named by `context` has been poisoned
/// (a thread panicked while holding it). The message is "`<context>` lock poisoned".
fn io_lock_error(context: &'static str) -> io::Error {
    let mut message = String::from(context);
    message.push_str(" lock poisoned");
    io::Error::other(message)
}

/// Takes a shared read guard on `lock`, turning poisoning into an `io::Error`.
fn io_read_guard<'a, T>(
    lock: &'a RwLock<T>,
    context: &'static str,
) -> io::Result<RwLockReadGuard<'a, T>> {
    let Ok(guard) = lock.read() else {
        return Err(io_lock_error(context));
    };
    Ok(guard)
}

/// Takes an exclusive write guard on `lock`, turning poisoning into an `io::Error`.
fn io_write_guard<'a, T>(
    lock: &'a RwLock<T>,
    context: &'static str,
) -> io::Result<RwLockWriteGuard<'a, T>> {
    let Ok(guard) = lock.write() else {
        return Err(io_lock_error(context));
    };
    Ok(guard)
}

/// Locks `lock`, turning poisoning into an `io::Error`.
fn io_mutex_guard<'a, T>(
    lock: &'a Mutex<T>,
    context: &'static str,
) -> io::Result<MutexGuard<'a, T>> {
    let Ok(guard) = lock.lock() else {
        return Err(io_lock_error(context));
    };
    Ok(guard)
}

/// Takes a read guard on `lock`, ignoring poisoning: a poisoned lock still yields its
/// guard, so read-only accessors keep working after a writer panicked.
fn recover_read_guard<'a, T>(lock: &'a RwLock<T>) -> RwLockReadGuard<'a, T> {
    lock.read().unwrap_or_else(|poisoned| poisoned.into_inner())
}
/// Converts a `reddb_file` error from decoding a WAL record frame into an `io::Error`.
///
/// Messages that mention missing or short input ("missing", "empty", "truncated",
/// "too short") map to `UnexpectedEof`, which `LogReader::read_all` treats as the end of
/// the log; every other message maps to `InvalidData`. The original message is kept.
/// NOTE(review): this depends on the wording of `RdbFileError`'s `Display` output —
/// a typed error kind from `reddb_file` would be sturdier if one exists.
fn transaction_wal_frame_error(err: reddb_file::RdbFileError) -> io::Error {
    const SHORT_INPUT_MARKERS: [&str; 4] = ["missing", "empty", "truncated", "too short"];
    let text = err.to_string();
    let looks_short = SHORT_INPUT_MARKERS
        .iter()
        .any(|marker| text.contains(marker));
    let kind = match looks_short {
        true => io::ErrorKind::UnexpectedEof,
        false => io::ErrorKind::InvalidData,
    };
    io::Error::new(kind, text)
}

/// Converts a `reddb_file` error from decoding an entry-type payload; classified exactly
/// like frame errors (see `transaction_wal_frame_error`).
fn transaction_wal_payload_error(err: reddb_file::RdbFileError) -> io::Error {
    let converted = transaction_wal_frame_error(err);
    converted
}
/// Kind and payload of a single WAL record.
///
/// Converted to and from `reddb_file::TransactionWalEntryPayload` for on-disk encoding
/// (see `LogEntryType::to_bytes` / `LogEntryType::from_bytes`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LogEntryType {
/// Start of a transaction.
Begin,
/// Transaction commit; `TransactionLog::append` fsyncs after it when `sync_on_commit` is set.
Commit,
/// Transaction abort.
Abort,
/// Key/value pair written by the transaction.
Insert { key: Vec<u8>, value: Vec<u8> },
/// Value change for `key`, carrying both the previous and the new value.
Update {
key: Vec<u8>,
old_value: Vec<u8>,
new_value: Vec<u8>,
},
/// Removal of `key`, carrying the value it held.
Delete { key: Vec<u8>, old_value: Vec<u8> },
/// Checkpoint record with the caller-supplied list of active transactions.
Checkpoint { active_txns: Vec<TxnId> },
/// Named savepoint inside a transaction.
Savepoint { name: String },
/// Rollback to the named savepoint.
RollbackToSavepoint { name: String },
/// Refers back to an earlier record — NOTE(review): presumably a compensation log record
/// undoing `original_lsn`; confirm against the recovery code.
Compensate { original_lsn: Lsn },
/// NOTE(review): end marker; its exact meaning is not established in this file.
End,
}
impl LogEntryType {
pub fn is_commit(&self) -> bool {
matches!(self, LogEntryType::Commit)
}
pub fn is_abort(&self) -> bool {
matches!(self, LogEntryType::Abort)
}
pub fn is_data_modification(&self) -> bool {
matches!(
self,
LogEntryType::Insert { .. } | LogEntryType::Update { .. } | LogEntryType::Delete { .. }
)
}
pub fn to_bytes(&self) -> Vec<u8> {
reddb_file::encode_transaction_wal_entry_payload(&self.to_file_payload())
}
pub fn from_bytes(data: &[u8]) -> io::Result<(Self, usize)> {
let (payload, consumed) = reddb_file::decode_transaction_wal_entry_payload(data)
.map_err(transaction_wal_payload_error)?;
Ok((Self::from_file_payload(payload), consumed))
}
fn to_file_payload(&self) -> reddb_file::TransactionWalEntryPayload {
match self {
LogEntryType::Begin => reddb_file::TransactionWalEntryPayload::Begin,
LogEntryType::Commit => reddb_file::TransactionWalEntryPayload::Commit,
LogEntryType::Abort => reddb_file::TransactionWalEntryPayload::Abort,
LogEntryType::Insert { key, value } => reddb_file::TransactionWalEntryPayload::Insert {
key: key.clone(),
value: value.clone(),
},
LogEntryType::Update {
key,
old_value,
new_value,
} => reddb_file::TransactionWalEntryPayload::Update {
key: key.clone(),
old_value: old_value.clone(),
new_value: new_value.clone(),
},
LogEntryType::Delete { key, old_value } => {
reddb_file::TransactionWalEntryPayload::Delete {
key: key.clone(),
old_value: old_value.clone(),
}
}
LogEntryType::Checkpoint { active_txns } => {
reddb_file::TransactionWalEntryPayload::Checkpoint {
active_txns: active_txns.clone(),
}
}
LogEntryType::Savepoint { name } => {
reddb_file::TransactionWalEntryPayload::Savepoint { name: name.clone() }
}
LogEntryType::RollbackToSavepoint { name } => {
reddb_file::TransactionWalEntryPayload::RollbackToSavepoint { name: name.clone() }
}
LogEntryType::Compensate { original_lsn } => {
reddb_file::TransactionWalEntryPayload::Compensate {
original_lsn: *original_lsn,
}
}
LogEntryType::End => reddb_file::TransactionWalEntryPayload::End,
}
}
fn from_file_payload(payload: reddb_file::TransactionWalEntryPayload) -> Self {
match payload {
reddb_file::TransactionWalEntryPayload::Begin => LogEntryType::Begin,
reddb_file::TransactionWalEntryPayload::Commit => LogEntryType::Commit,
reddb_file::TransactionWalEntryPayload::Abort => LogEntryType::Abort,
reddb_file::TransactionWalEntryPayload::Insert { key, value } => {
LogEntryType::Insert { key, value }
}
reddb_file::TransactionWalEntryPayload::Update {
key,
old_value,
new_value,
} => LogEntryType::Update {
key,
old_value,
new_value,
},
reddb_file::TransactionWalEntryPayload::Delete { key, old_value } => {
LogEntryType::Delete { key, old_value }
}
reddb_file::TransactionWalEntryPayload::Checkpoint { active_txns } => {
LogEntryType::Checkpoint { active_txns }
}
reddb_file::TransactionWalEntryPayload::Savepoint { name } => {
LogEntryType::Savepoint { name }
}
reddb_file::TransactionWalEntryPayload::RollbackToSavepoint { name } => {
LogEntryType::RollbackToSavepoint { name }
}
reddb_file::TransactionWalEntryPayload::Compensate { original_lsn } => {
LogEntryType::Compensate { original_lsn }
}
reddb_file::TransactionWalEntryPayload::End => LogEntryType::End,
}
}
}
/// One WAL record: the entry payload plus its sequencing metadata.
#[derive(Debug, Clone)]
pub struct LogEntry {
/// Sequence number; `LogEntry::new` sets 0 and `TransactionLog::append` assigns the real one.
pub lsn: Lsn,
/// Transaction that produced the record (`TransactionLog::checkpoint` logs under 0).
pub txn_id: TxnId,
/// LSN of the previous record of the same transaction, if any; `TransactionLog::append`
/// overwrites whatever the caller supplied.
pub prev_lsn: Option<Lsn>,
/// Creation time; `LogEntry::new` uses microseconds since the Unix epoch.
pub timestamp: Timestamp,
/// What the record describes.
pub entry_type: LogEntryType,
}
impl LogEntry {
pub fn new(txn_id: TxnId, prev_lsn: Option<Lsn>, entry_type: LogEntryType) -> Self {
Self {
lsn: 0, txn_id,
prev_lsn,
timestamp: SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_micros() as Timestamp,
entry_type,
}
}
pub fn to_bytes(&self) -> Vec<u8> {
reddb_file::encode_transaction_wal_record_frame(&reddb_file::TransactionWalRecordFrame {
lsn: self.lsn,
txn_id: self.txn_id,
prev_lsn: self.prev_lsn,
timestamp: self.timestamp,
entry_type_payload: self.entry_type.to_bytes(),
})
}
pub fn from_bytes(data: &[u8]) -> io::Result<Self> {
let frame = reddb_file::decode_transaction_wal_record_frame(data)
.map_err(transaction_wal_frame_error)?;
let (entry_type, consumed) = LogEntryType::from_bytes(&frame.entry_type_payload)?;
if consumed != frame.entry_type_payload.len() {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"WAL entry type length mismatch",
));
}
Ok(Self {
lsn: frame.lsn,
txn_id: frame.txn_id,
prev_lsn: frame.prev_lsn,
timestamp: frame.timestamp,
entry_type,
})
}
pub fn serialized_size(&self) -> usize {
reddb_file::transaction_wal_record_encoded_len(self.entry_type.to_bytes().len())
}
}
/// Configuration for a `TransactionLog`.
#[derive(Debug, Clone)]
pub struct WalConfig {
/// WAL file location; an empty path means no backing file (entries kept in memory only).
pub path: PathBuf,
/// When true, `TransactionLog::append` flushes and fsyncs the file after a `Commit` record.
pub sync_on_commit: bool,
/// Capacity in bytes of the `BufWriter` wrapping the WAL file.
pub buffer_size: usize,
/// Size limit for the WAL file (presumably bytes). NOTE(review): not read anywhere in this file.
pub max_file_size: u64,
/// Checkpoint interval. NOTE(review): not read anywhere in this file; unit unclear.
pub checkpoint_interval: u64,
}
impl Default for WalConfig {
    /// Standard WAL location from the file layout, fsync on commit, a 64 KiB write
    /// buffer, a 100 MiB size limit and a checkpoint interval of 1000.
    fn default() -> Self {
        const KIB: usize = 1024;
        const MIB: u64 = 1024 * 1024;
        Self {
            path: reddb_file::layout::default_transaction_wal_path(),
            sync_on_commit: true,
            buffer_size: 64 * KIB,
            max_file_size: 100 * MIB,
            checkpoint_interval: 1000,
        }
    }
}
impl WalConfig {
    /// Default configuration, except that the WAL file lives at `path`.
    pub fn with_path<P: AsRef<Path>>(path: P) -> Self {
        let path = path.as_ref().to_owned();
        Self {
            path,
            ..Self::default()
        }
    }
}
/// Counters maintained by `TransactionLog` and returned by `TransactionLog::stats`.
#[derive(Debug, Clone, Default)]
pub struct WalStats {
/// Number of records appended.
pub entries_written: u64,
/// Bytes appended, counting each record's 4-byte length prefix.
pub bytes_written: u64,
/// Number of commit-triggered fsyncs performed while appending.
pub syncs: u64,
/// Number of checkpoints taken.
pub checkpoints: u64,
/// Running WAL size in bytes, as tracked by `TransactionLog`.
pub file_size: u64,
}
/// Transaction write-ahead log.
///
/// Records are optionally appended to a file as a little-endian `u32` length prefix
/// followed by the encoded frame. The most recent records are also kept in memory for
/// lookup by transaction or LSN.
pub struct TransactionLog {
config: WalConfig,
/// Next LSN to hand out (starts at 1).
next_lsn: AtomicU64,
/// Buffered WAL file writer; `None` for in-memory logs.
file: Option<Mutex<BufWriter<File>>>,
/// Bounded window of the most recently appended entries.
buffer: RwLock<VecDeque<LogEntry>>,
/// Last LSN logged per transaction, used to fill in each entry's `prev_lsn`.
txn_prev_lsn: RwLock<std::collections::HashMap<TxnId, Lsn>>,
/// Counters returned by `stats()`.
stats: RwLock<WalStats>,
/// LSN of the most recent checkpoint record; 0 until the first checkpoint.
last_checkpoint_lsn: AtomicU64,
}
impl TransactionLog {
pub fn new(config: WalConfig) -> io::Result<Self> {
let file = if config.path.as_os_str().is_empty() {
None
} else {
let f = OpenOptions::new()
.create(true)
.append(true)
.read(true)
.open(&config.path)?;
Some(Mutex::new(BufWriter::with_capacity(config.buffer_size, f)))
};
Ok(Self {
config,
next_lsn: AtomicU64::new(1),
file,
buffer: RwLock::new(VecDeque::new()),
txn_prev_lsn: RwLock::new(std::collections::HashMap::new()),
stats: RwLock::new(WalStats::default()),
last_checkpoint_lsn: AtomicU64::new(0),
})
}
pub fn in_memory() -> Self {
Self {
config: WalConfig {
path: PathBuf::new(),
..Default::default()
},
next_lsn: AtomicU64::new(1),
file: None,
buffer: RwLock::new(VecDeque::new()),
txn_prev_lsn: RwLock::new(std::collections::HashMap::new()),
stats: RwLock::new(WalStats::default()),
last_checkpoint_lsn: AtomicU64::new(0),
}
}
pub fn append(&self, mut entry: LogEntry) -> io::Result<Lsn> {
let lsn = self.next_lsn.fetch_add(1, Ordering::SeqCst);
entry.lsn = lsn;
{
let mut prev_lsns = io_write_guard(&self.txn_prev_lsn, "wal prev_lsn map")?;
entry.prev_lsn = prev_lsns.get(&entry.txn_id).copied();
prev_lsns.insert(entry.txn_id, lsn);
}
let bytes = entry.to_bytes();
if let Some(ref file) = self.file {
let mut writer = io_mutex_guard(file, "wal file")?;
writer.write_all(&(bytes.len() as u32).to_le_bytes())?;
writer.write_all(&bytes)?;
if self.config.sync_on_commit && entry.entry_type.is_commit() {
writer.flush()?;
writer.get_mut().sync_all()?;
let mut stats = io_write_guard(&self.stats, "wal stats")?;
stats.syncs += 1;
}
}
{
let mut buffer = io_write_guard(&self.buffer, "wal buffer")?;
buffer.push_back(entry);
while buffer.len() > 10000 {
buffer.pop_front();
}
}
{
let mut stats = io_write_guard(&self.stats, "wal stats")?;
stats.entries_written += 1;
stats.bytes_written += bytes.len() as u64 + 4;
stats.file_size += bytes.len() as u64 + 4;
}
Ok(lsn)
}
pub fn log_begin(&self, txn_id: TxnId) -> io::Result<Lsn> {
self.append(LogEntry::new(txn_id, None, LogEntryType::Begin))
}
pub fn log_commit(&self, txn_id: TxnId) -> io::Result<Lsn> {
let lsn = self.append(LogEntry::new(txn_id, None, LogEntryType::Commit))?;
{
let mut prev_lsns = io_write_guard(&self.txn_prev_lsn, "wal prev_lsn map")?;
prev_lsns.remove(&txn_id);
}
Ok(lsn)
}
pub fn log_abort(&self, txn_id: TxnId) -> io::Result<Lsn> {
let lsn = self.append(LogEntry::new(txn_id, None, LogEntryType::Abort))?;
{
let mut prev_lsns = io_write_guard(&self.txn_prev_lsn, "wal prev_lsn map")?;
prev_lsns.remove(&txn_id);
}
Ok(lsn)
}
pub fn log_insert(&self, txn_id: TxnId, key: Vec<u8>, value: Vec<u8>) -> io::Result<Lsn> {
self.append(LogEntry::new(
txn_id,
None,
LogEntryType::Insert { key, value },
))
}
pub fn log_update(
&self,
txn_id: TxnId,
key: Vec<u8>,
old_value: Vec<u8>,
new_value: Vec<u8>,
) -> io::Result<Lsn> {
self.append(LogEntry::new(
txn_id,
None,
LogEntryType::Update {
key,
old_value,
new_value,
},
))
}
pub fn log_delete(&self, txn_id: TxnId, key: Vec<u8>, old_value: Vec<u8>) -> io::Result<Lsn> {
self.append(LogEntry::new(
txn_id,
None,
LogEntryType::Delete { key, old_value },
))
}
pub fn log_savepoint(&self, txn_id: TxnId, name: String) -> io::Result<Lsn> {
self.append(LogEntry::new(
txn_id,
None,
LogEntryType::Savepoint { name },
))
}
pub fn checkpoint(&self, active_txns: Vec<TxnId>) -> io::Result<Lsn> {
let lsn = self.append(LogEntry::new(
0, None,
LogEntryType::Checkpoint { active_txns },
))?;
if let Some(ref file) = self.file {
let mut writer = io_mutex_guard(file, "wal file")?;
writer.flush()?;
writer.get_mut().sync_all()?;
}
self.last_checkpoint_lsn.store(lsn, Ordering::SeqCst);
{
let mut stats = io_write_guard(&self.stats, "wal stats")?;
stats.checkpoints += 1;
}
Ok(lsn)
}
pub fn flush(&self) -> io::Result<()> {
if let Some(ref file) = self.file {
let mut writer = io_mutex_guard(file, "wal file")?;
writer.flush()?;
writer.get_mut().sync_all()?;
}
Ok(())
}
pub fn get_txn_entries(&self, txn_id: TxnId) -> Vec<LogEntry> {
let buffer = recover_read_guard(&self.buffer);
buffer
.iter()
.filter(|e| e.txn_id == txn_id)
.cloned()
.collect()
}
pub fn get_entries_since(&self, lsn: Lsn) -> Vec<LogEntry> {
let buffer = recover_read_guard(&self.buffer);
buffer.iter().filter(|e| e.lsn >= lsn).cloned().collect()
}
pub fn current_lsn(&self) -> Lsn {
self.next_lsn.load(Ordering::SeqCst) - 1
}
pub fn last_checkpoint(&self) -> Lsn {
self.last_checkpoint_lsn.load(Ordering::SeqCst)
}
pub fn stats(&self) -> WalStats {
recover_read_guard(&self.stats).clone()
}
pub fn config(&self) -> &WalConfig {
&self.config
}
}
/// Sequential reader over a WAL file in which each record is a little-endian `u32`
/// length prefix followed by an encoded `LogEntry` frame.
pub struct LogReader {
reader: BufReader<File>,
}
impl LogReader {
    /// Opens the WAL file at `path` for reading from the beginning.
    ///
    /// # Errors
    /// Returns any error from [`File::open`].
    pub fn open<P: AsRef<Path>>(path: P) -> io::Result<Self> {
        let file = File::open(path)?;
        Ok(Self {
            reader: BufReader::new(file),
        })
    }

    /// Reads records until the end of the file.
    ///
    /// Any `UnexpectedEof` ends the scan, and the records read so far are returned. That
    /// covers a clean end of file, a torn trailing record, and a decode error classified
    /// as truncated input. NOTE(review): a truncated record in the middle of the file
    /// therefore silently ends the scan; confirm that is intended.
    ///
    /// # Errors
    /// Returns the first error of any other kind (e.g. `InvalidData` for a corrupt record).
    pub fn read_all(&mut self) -> io::Result<Vec<LogEntry>> {
        let mut entries = Vec::new();
        loop {
            match self.read_entry() {
                Ok(entry) => entries.push(entry),
                Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => break,
                Err(e) => return Err(e),
            }
        }
        Ok(entries)
    }

    /// Reads one length-prefixed record (little-endian `u32` length, then the frame).
    ///
    /// # Errors
    /// Returns `UnexpectedEof` if the file ends inside the prefix or the record body.
    /// Otherwise returns any read error, or the decode error from [`LogEntry::from_bytes`].
    pub fn read_entry(&mut self) -> io::Result<LogEntry> {
        let mut len_buf = [0u8; 4];
        self.reader.read_exact(&mut len_buf)?;
        let len = u64::from(u32::from_le_bytes(len_buf));

        // Read through `take` rather than pre-allocating `len` bytes: the prefix comes from
        // disk, and a corrupt or torn value could otherwise force a ~4 GiB allocation.
        let mut data = Vec::new();
        self.reader.by_ref().take(len).read_to_end(&mut data)?;
        if data.len() as u64 != len {
            return Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                "truncated WAL record",
            ));
        }
        LogEntry::from_bytes(&data)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_log_entry_serialize() {
        let entry = LogEntry {
            lsn: 42,
            txn_id: 1,
            prev_lsn: Some(40),
            timestamp: 1234567890,
            entry_type: LogEntryType::Insert {
                key: b"key1".to_vec(),
                value: b"value1".to_vec(),
            },
        };
        let bytes = entry.to_bytes();
        let recovered = LogEntry::from_bytes(&bytes).unwrap();
        assert_eq!(recovered.lsn, entry.lsn);
        assert_eq!(recovered.txn_id, entry.txn_id);
        assert_eq!(recovered.prev_lsn, entry.prev_lsn);
        assert_eq!(recovered.timestamp, entry.timestamp);
        assert_eq!(recovered.entry_type, entry.entry_type);
    }

    #[test]
    fn test_log_entry_serialize_without_prev_lsn() {
        let entry = LogEntry {
            lsn: 1,
            txn_id: 5,
            prev_lsn: None,
            timestamp: 0,
            entry_type: LogEntryType::Begin,
        };
        let recovered = LogEntry::from_bytes(&entry.to_bytes()).unwrap();
        assert_eq!(recovered.lsn, 1);
        assert_eq!(recovered.txn_id, 5);
        assert_eq!(recovered.prev_lsn, None);
        assert_eq!(recovered.entry_type, LogEntryType::Begin);
    }

    #[test]
    fn test_entry_type_predicates() {
        assert!(LogEntryType::Commit.is_commit());
        assert!(!LogEntryType::Abort.is_commit());
        assert!(LogEntryType::Abort.is_abort());
        assert!(!LogEntryType::Commit.is_abort());
        assert!(LogEntryType::Delete {
            key: Vec::new(),
            old_value: Vec::new(),
        }
        .is_data_modification());
        assert!(!LogEntryType::Begin.is_data_modification());
        assert!(!LogEntryType::Checkpoint {
            active_txns: Vec::new(),
        }
        .is_data_modification());
    }

    #[test]
    fn test_in_memory_log() {
        let log = TransactionLog::in_memory();
        let lsn1 = log.log_begin(1).unwrap();
        let lsn2 = log
            .log_insert(1, b"key".to_vec(), b"value".to_vec())
            .unwrap();
        let lsn3 = log.log_commit(1).unwrap();
        assert_eq!(lsn1, 1);
        assert_eq!(lsn2, 2);
        assert_eq!(lsn3, 3);
        assert_eq!(log.current_lsn(), 3);
        let entries = log.get_txn_entries(1);
        assert_eq!(entries.len(), 3);
        assert_eq!(log.stats().entries_written, 3);
        assert_eq!(log.get_entries_since(2).len(), 2);
    }

    #[test]
    fn test_checkpoint() {
        let log = TransactionLog::in_memory();
        log.log_begin(1).unwrap();
        log.log_begin(2).unwrap();
        let cp_lsn = log.checkpoint(vec![1, 2]).unwrap();
        assert_eq!(log.last_checkpoint(), cp_lsn);
        assert_eq!(log.stats().checkpoints, 1);
    }

    #[test]
    fn test_log_entry_types() {
        let types = vec![
            LogEntryType::Begin,
            LogEntryType::Commit,
            LogEntryType::Abort,
            LogEntryType::Insert {
                key: b"k".to_vec(),
                value: b"v".to_vec(),
            },
            LogEntryType::Update {
                key: b"k".to_vec(),
                old_value: b"old".to_vec(),
                new_value: b"new".to_vec(),
            },
            LogEntryType::Delete {
                key: b"k".to_vec(),
                old_value: b"v".to_vec(),
            },
            LogEntryType::Checkpoint {
                active_txns: vec![1, 2, 3],
            },
            LogEntryType::Savepoint {
                name: "sp1".to_string(),
            },
            LogEntryType::RollbackToSavepoint {
                name: "sp1".to_string(),
            },
            LogEntryType::Compensate { original_lsn: 17 },
            LogEntryType::End,
        ];
        for t in types {
            let bytes = t.to_bytes();
            let (recovered, consumed) = LogEntryType::from_bytes(&bytes).unwrap();
            assert_eq!(recovered, t);
            assert_eq!(consumed, bytes.len());
        }
    }

    #[test]
    fn test_prev_lsn_chain() {
        let log = TransactionLog::in_memory();
        log.log_begin(1).unwrap();
        log.log_insert(1, b"k1".to_vec(), b"v1".to_vec()).unwrap();
        log.log_insert(1, b"k2".to_vec(), b"v2".to_vec()).unwrap();
        let entries = log.get_txn_entries(1);
        assert_eq!(entries[0].prev_lsn, None);
        assert_eq!(entries[1].prev_lsn, Some(1));
        assert_eq!(entries[2].prev_lsn, Some(2));
    }

    #[test]
    fn test_commit_resets_prev_lsn_chain() {
        let log = TransactionLog::in_memory();
        log.log_begin(7).unwrap();
        let commit_lsn = log.log_commit(7).unwrap();
        log.log_begin(7).unwrap();
        let entries = log.get_txn_entries(7);
        assert_eq!(entries.len(), 3);
        assert_eq!(commit_lsn, 2);
        assert_eq!(entries[1].prev_lsn, Some(1));
        assert_eq!(entries[2].prev_lsn, None);
    }

    #[test]
    fn test_file_round_trip_with_reader() {
        let path = std::env::temp_dir().join(format!(
            "reddb_txn_wal_round_trip_{}.log",
            std::process::id()
        ));
        let _ = std::fs::remove_file(&path);
        {
            let log = TransactionLog::new(WalConfig::with_path(&path)).unwrap();
            log.log_begin(1).unwrap();
            log.log_insert(1, b"k".to_vec(), b"v".to_vec()).unwrap();
            log.log_commit(1).unwrap();
            log.flush().unwrap();
        }
        let entries = LogReader::open(&path).unwrap().read_all().unwrap();
        let _ = std::fs::remove_file(&path);
        let lsns: Vec<Lsn> = entries.iter().map(|e| e.lsn).collect();
        assert_eq!(lsns, vec![1, 2, 3]);
        assert!(entries[2].entry_type.is_commit());
    }

    #[test]
    fn test_log_entry_type_rejects_truncated_insert() {
        let err = LogEntryType::from_bytes(&[3, 4, 0, 0, 0, b'k'])
            .expect_err("truncated insert should fail");
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }

    #[test]
    fn test_log_entry_rejects_truncated_type_payload() {
        let entry = LogEntry {
            lsn: 7,
            txn_id: 9,
            prev_lsn: Some(3),
            timestamp: 42,
            entry_type: LogEntryType::Insert {
                key: b"hello".to_vec(),
                value: b"world".to_vec(),
            },
        };
        let mut bytes = entry.to_bytes();
        bytes.truncate(bytes.len() - 2);
        let err = LogEntry::from_bytes(&bytes).expect_err("truncated entry should fail");
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }
}