use alloc::collections::BTreeMap;
use alloc::string::{String, ToString};
use alloc::vec::Vec;
use super::types::{
RecoveryStats, TXN_LOG_MAGIC, TXN_LOG_VERSION, TxnError, TxnId, TxnLogEntry, TxnLogType,
TxnOperation, TxnResultType, TxnState,
};
/// Metadata block describing a transaction log.
///
/// `checksum` is the XOR of every other field. Callers that modify a field
/// directly must call [`TxnLogHeader::update_checksum`] afterwards, otherwise
/// [`TxnLogHeader::verify`] reports a mismatch.
#[derive(Debug, Clone)]
pub struct TxnLogHeader {
    /// Format marker; initialised to `TXN_LOG_MAGIC`.
    pub magic: u64,
    /// Format version; initialised to `TXN_LOG_VERSION`.
    pub version: u32,
    /// Identifier of this log, as supplied to [`TxnLogHeader::new`].
    pub log_id: u64,
    /// Creation timestamp, as supplied to [`TxnLogHeader::new`].
    pub created_at: u64,
    /// Last LSN recorded in the header; starts at 0.
    pub last_lsn: u64,
    /// LSN of the most recent checkpoint; starts at 0.
    pub checkpoint_lsn: u64,
    /// XOR of all the fields above.
    pub checksum: u64,
}

impl TxnLogHeader {
    /// Builds a header for log `log_id` created at `timestamp`.
    ///
    /// Both LSN fields start at zero and the checksum is already valid on
    /// return.
    pub fn new(log_id: u64, timestamp: u64) -> Self {
        let mut fresh = Self {
            magic: TXN_LOG_MAGIC,
            version: TXN_LOG_VERSION,
            log_id,
            created_at: timestamp,
            last_lsn: 0,
            checkpoint_lsn: 0,
            checksum: 0,
        };
        fresh.checksum = fresh.folded_fields();
        fresh
    }

    /// Recomputes `checksum` from the current values of the other fields.
    pub fn update_checksum(&mut self) {
        self.checksum = self.folded_fields();
    }

    /// Returns `true` when the stored `checksum` matches the other fields.
    pub fn verify(&self) -> bool {
        self.checksum == self.folded_fields()
    }

    /// XORs every field except `checksum` into a single word.
    fn folded_fields(&self) -> u64 {
        let words = [
            self.magic,
            u64::from(self.version),
            self.log_id,
            self.created_at,
            self.last_lsn,
            self.checkpoint_lsn,
        ];
        words.iter().fold(0u64, |acc, word| acc ^ word)
    }
}
/// In-memory transaction log for a single dataset.
///
/// The log hands out increasing log sequence numbers (LSNs), buffers the
/// entries it creates, and tracks the state of every in-flight transaction
/// together with the operations logged for it.
#[derive(Debug)]
pub struct TxnLog {
    /// Dataset name supplied to [`TxnLog::new`].
    dataset: String,
    /// Log header; its checksum is refreshed whenever this type changes it.
    header: TxnLogHeader,
    /// Entries appended since the last checkpoint trimmed the buffer.
    entries: Vec<TxnLogEntry>,
    /// LSN that the next appended entry will receive; starts at 1.
    next_lsn: u64,
    /// In-flight transactions and their states; removed on commit/abort.
    active_txns: BTreeMap<TxnId, TxnState>,
    /// Operations logged so far for each in-flight transaction.
    pending_ops: BTreeMap<TxnId, Vec<TxnOperation>>,
    /// Configured buffer limit (default 1000). Stored only; no method in
    /// this impl reads it.
    max_buffer_entries: usize,
    /// Configured sync flag (default `true`). Stored only; no method in
    /// this impl reads it.
    sync_on_write: bool,
}

impl TxnLog {
    /// Creates an empty log for `dataset` whose header carries `log_id` and
    /// `timestamp`.
    pub fn new(dataset: &str, log_id: u64, timestamp: u64) -> Self {
        Self {
            dataset: dataset.to_string(),
            header: TxnLogHeader::new(log_id, timestamp),
            entries: Vec::new(),
            next_lsn: 1,
            active_txns: BTreeMap::new(),
            pending_ops: BTreeMap::new(),
            max_buffer_entries: 1000,
            sync_on_write: true,
        }
    }

    /// Name of the dataset this log belongs to.
    pub fn dataset(&self) -> &str {
        &self.dataset
    }

    /// Identifier stored in the log header.
    pub fn log_id(&self) -> u64 {
        self.header.log_id
    }

    /// LSN of the most recently allocated entry, or 0 if none was allocated.
    pub fn current_lsn(&self) -> u64 {
        self.next_lsn.saturating_sub(1)
    }

    /// LSN that the next appended entry will receive.
    pub fn next_lsn(&self) -> u64 {
        self.next_lsn
    }

    /// Checkpoint LSN currently stored in the header.
    pub fn checkpoint_lsn(&self) -> u64 {
        self.header.checkpoint_lsn
    }

    /// Number of entries currently held in the buffer.
    pub fn buffered_entries(&self) -> usize {
        self.entries.len()
    }

    /// Number of transactions that have begun and not yet committed/aborted.
    pub fn active_txn_count(&self) -> usize {
        self.active_txns.len()
    }

    /// Whether `txn_id` is currently tracked as in-flight.
    pub fn has_txn(&self, txn_id: TxnId) -> bool {
        self.active_txns.contains_key(&txn_id)
    }

    /// State of `txn_id`, or `None` when it is not tracked.
    pub fn txn_state(&self, txn_id: TxnId) -> Option<TxnState> {
        self.active_txns.get(&txn_id).copied()
    }

    /// Stores the sync-on-write flag.
    pub fn set_sync_on_write(&mut self, sync: bool) {
        self.sync_on_write = sync;
    }

    /// Stores the maximum buffer size.
    pub fn set_max_buffer_entries(&mut self, max: usize) {
        self.max_buffer_entries = max;
    }

    /// Starts tracking `txn_id` as `Active` and appends a begin entry.
    ///
    /// # Errors
    /// `TxnError::AlreadyExists` when `txn_id` is already tracked; in that
    /// case no LSN is consumed.
    pub fn begin(&mut self, txn_id: TxnId, timestamp: u64) -> TxnResultType<()> {
        if self.active_txns.contains_key(&txn_id) {
            return Err(TxnError::AlreadyExists(txn_id));
        }
        let lsn = self.allocate_lsn();
        self.append_entry(TxnLogEntry::begin(txn_id, lsn, timestamp));
        self.active_txns.insert(txn_id, TxnState::Active);
        self.pending_ops.insert(txn_id, Vec::new());
        Ok(())
    }

    /// Appends an operation entry for `txn_id` and remembers `op` as pending.
    ///
    /// Returns the zero-based index of the operation within the transaction.
    ///
    /// # Errors
    /// `TxnError::NotFound` when `txn_id` is not tracked, and
    /// `TxnError::InvalidState` when it is tracked but not `Active`.
    pub fn log_operation(
        &mut self,
        txn_id: TxnId,
        op: TxnOperation,
        timestamp: u64,
    ) -> TxnResultType<u32> {
        let state = self.require_state(txn_id)?;
        if state != TxnState::Active {
            return Err(TxnError::InvalidState {
                txn_id,
                current: state,
                expected: &[TxnState::Active],
            });
        }
        // The index is the number of operations already pending. The cast
        // truncates beyond u32::MAX pending operations.
        let op_index = self
            .pending_ops
            .get(&txn_id)
            .map(|ops| ops.len() as u32)
            .unwrap_or(0);
        let lsn = self.allocate_lsn();
        // One copy goes into the log entry, the other into the pending list.
        self.append_entry(TxnLogEntry::operation(
            txn_id,
            op_index,
            lsn,
            timestamp,
            op.clone(),
        ));
        if let Some(ops) = self.pending_ops.get_mut(&txn_id) {
            ops.push(op);
        }
        Ok(op_index)
    }

    /// Appends a prepare entry and moves `txn_id` from `Active` to `Prepared`.
    ///
    /// # Errors
    /// `TxnError::NotFound` when `txn_id` is not tracked, and
    /// `TxnError::InvalidState` when it is tracked but not `Active`.
    pub fn prepare(&mut self, txn_id: TxnId, timestamp: u64) -> TxnResultType<()> {
        let state = self.require_state(txn_id)?;
        if state != TxnState::Active {
            return Err(TxnError::InvalidState {
                txn_id,
                current: state,
                expected: &[TxnState::Active],
            });
        }
        let lsn = self.allocate_lsn();
        self.append_entry(TxnLogEntry::prepare(txn_id, lsn, timestamp));
        self.active_txns.insert(txn_id, TxnState::Prepared);
        Ok(())
    }

    /// Appends a commit entry, stops tracking `txn_id`, and returns the
    /// operations that were pending for it.
    ///
    /// # Errors
    /// `TxnError::NotFound` when `txn_id` is not tracked, and
    /// `TxnError::InvalidState` when it is neither `Active` nor `Prepared`.
    pub fn commit(&mut self, txn_id: TxnId, timestamp: u64) -> TxnResultType<Vec<TxnOperation>> {
        let state = self.require_state(txn_id)?;
        if !matches!(state, TxnState::Active | TxnState::Prepared) {
            return Err(TxnError::InvalidState {
                txn_id,
                current: state,
                expected: &[TxnState::Active, TxnState::Prepared],
            });
        }
        let lsn = self.allocate_lsn();
        self.append_entry(TxnLogEntry::commit(txn_id, lsn, timestamp));
        // A finished transaction is no longer tracked, so there is no point
        // in recording a terminal state before dropping it.
        self.active_txns.remove(&txn_id);
        Ok(self.pending_ops.remove(&txn_id).unwrap_or_default())
    }

    /// Appends an abort entry, stops tracking `txn_id`, and returns the
    /// operations that were pending for it.
    ///
    /// # Errors
    /// `TxnError::NotFound` when `txn_id` is not tracked, and
    /// `TxnError::InvalidState` when its state is `Committed`.
    pub fn abort(&mut self, txn_id: TxnId, timestamp: u64) -> TxnResultType<Vec<TxnOperation>> {
        let state = self.require_state(txn_id)?;
        if state == TxnState::Committed {
            return Err(TxnError::InvalidState {
                txn_id,
                current: state,
                expected: &[TxnState::Active, TxnState::Prepared],
            });
        }
        let lsn = self.allocate_lsn();
        self.append_entry(TxnLogEntry::abort(txn_id, lsn, timestamp));
        // As with commit: the transaction is simply dropped from tracking.
        self.active_txns.remove(&txn_id);
        Ok(self.pending_ops.remove(&txn_id).unwrap_or_default())
    }

    /// Appends a rollback entry for operation `op_index` of `txn_id`.
    ///
    /// No check is made that `txn_id` is tracked, so this can be called
    /// after [`TxnLog::abort`] has removed the transaction. Always `Ok`.
    pub fn log_rollback(
        &mut self,
        txn_id: TxnId,
        op_index: u32,
        op: TxnOperation,
        timestamp: u64,
    ) -> TxnResultType<()> {
        let lsn = self.allocate_lsn();
        self.append_entry(TxnLogEntry::rollback(txn_id, op_index, lsn, timestamp, op));
        Ok(())
    }

    /// Moves the checkpoint to the current LSN when no transaction is
    /// in flight, and returns the resulting checkpoint LSN.
    ///
    /// While any transaction is tracked the checkpoint is left untouched and
    /// the existing checkpoint LSN is returned. Otherwise the header's
    /// `checkpoint_lsn` and `last_lsn` are set to the current LSN, the header
    /// checksum is refreshed, and buffered entries at or below that LSN are
    /// discarded. Always `Ok`.
    pub fn checkpoint(&mut self, timestamp: u64) -> TxnResultType<u64> {
        // No checkpoint entry is appended here, so the timestamp has no
        // consumer yet; it is kept in the signature for callers.
        let _ = timestamp;
        if !self.active_txns.is_empty() {
            return Ok(self.header.checkpoint_lsn);
        }
        let lsn = self.current_lsn();
        self.header.checkpoint_lsn = lsn;
        self.header.last_lsn = lsn;
        self.header.update_checksum();
        self.entries.retain(|e| e.lsn > lsn);
        Ok(lsn)
    }

    /// Records the current LSN as the header's `last_lsn`, refreshes the
    /// header checksum, and returns the number of buffered entries.
    ///
    /// The buffer itself is left as it is. Always `Ok`.
    pub fn flush(&mut self) -> TxnResultType<usize> {
        let count = self.entries.len();
        self.header.last_lsn = self.current_lsn();
        self.header.update_checksum();
        Ok(count)
    }

    /// Lists every tracked transaction whose state reports
    /// `is_recoverable()`, with a copy of its pending operations.
    pub fn get_recovery_txns(&self) -> Vec<(TxnId, TxnState, Vec<TxnOperation>)> {
        self.active_txns
            .iter()
            .filter(|&(_, &state)| state.is_recoverable())
            .map(|(&txn_id, &state)| {
                let ops = self.pending_ops.get(&txn_id).cloned().unwrap_or_default();
                (txn_id, state, ops)
            })
            .collect()
    }

    /// Iterates over buffered entries whose LSN is at least `from_lsn`.
    pub fn replay_from(&self, from_lsn: u64) -> impl Iterator<Item = &TxnLogEntry> {
        self.entries.iter().filter(move |e| e.lsn >= from_lsn)
    }

    /// Rebuilds transaction tracking from previously written `entries`.
    ///
    /// Entries are consumed in order and are not added to the buffer. An
    /// entry that fails `verify_checksum()` is counted in `stats.errors` and
    /// otherwise ignored. Every accepted entry raises `next_lsn` past its own
    /// LSN when needed. Always `Ok`.
    pub fn load_entries(&mut self, entries: Vec<TxnLogEntry>) -> TxnResultType<RecoveryStats> {
        let mut stats = RecoveryStats::new();
        for entry in entries {
            if !entry.verify_checksum() {
                stats.errors += 1;
                continue;
            }
            stats.log_entries_processed += 1;
            match entry.entry_type {
                TxnLogType::Begin => {
                    self.active_txns.insert(entry.txn_id, TxnState::Active);
                    self.pending_ops.insert(entry.txn_id, Vec::new());
                }
                TxnLogType::Operation => {
                    // Operations for transactions with no begin entry seen
                    // are dropped.
                    if let Some(op) = entry.operation {
                        if let Some(ops) = self.pending_ops.get_mut(&entry.txn_id) {
                            ops.push(op);
                        }
                    }
                }
                TxnLogType::Prepare => {
                    if let Some(state) = self.active_txns.get_mut(&entry.txn_id) {
                        *state = TxnState::Prepared;
                    }
                }
                TxnLogType::Commit => {
                    self.active_txns.remove(&entry.txn_id);
                    self.pending_ops.remove(&entry.txn_id);
                    stats.txns_recovered += 1;
                }
                TxnLogType::Abort => {
                    self.active_txns.remove(&entry.txn_id);
                    self.pending_ops.remove(&entry.txn_id);
                    stats.txns_rolled_back += 1;
                }
                TxnLogType::Rollback => {
                    stats.ops_undone += 1;
                }
                TxnLogType::Checkpoint => {
                    self.header.checkpoint_lsn = entry.lsn;
                    // Keep the header self-consistent: without this the
                    // header would stop passing `verify()`.
                    self.header.update_checksum();
                }
            }
            // Saturate so an entry carrying u64::MAX cannot overflow the
            // counter (panic in debug builds, wrap to 0 in release builds).
            self.next_lsn = self.next_lsn.max(entry.lsn.saturating_add(1));
        }
        Ok(stats)
    }

    /// Looks up the state of `txn_id`, failing with `TxnError::NotFound`
    /// when the transaction is not tracked.
    fn require_state(&self, txn_id: TxnId) -> TxnResultType<TxnState> {
        self.active_txns
            .get(&txn_id)
            .copied()
            .ok_or(TxnError::NotFound(txn_id))
    }

    /// Stamps `entry` with its checksum and pushes it onto the buffer.
    fn append_entry(&mut self, mut entry: TxnLogEntry) {
        entry.calculate_checksum();
        self.entries.push(entry);
    }

    /// Returns the next LSN and advances the counter by one.
    fn allocate_lsn(&mut self) -> u64 {
        let lsn = self.next_lsn;
        self.next_lsn += 1;
        lsn
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    /// Fresh, empty log shared by every test.
    fn test_log() -> TxnLog {
        TxnLog::new("test/pool", 1, 1000)
    }

    /// Sample file-creation operation.
    fn create_op() -> TxnOperation {
        TxnOperation::Create {
            path: "/test.txt".into(),
            content: vec![1, 2, 3],
            mode: 0o644,
        }
    }

    /// Sample directory-creation operation.
    fn mkdir_op() -> TxnOperation {
        TxnOperation::Mkdir {
            path: "/dir".into(),
            mode: 0o755,
        }
    }

    /// Returns `entry` with its checksum filled in.
    fn sealed(mut entry: TxnLogEntry) -> TxnLogEntry {
        entry.calculate_checksum();
        entry
    }

    /// A new header carries the format constants and a valid checksum.
    #[test]
    fn test_log_header() {
        let hdr = TxnLogHeader::new(1, 1000);
        assert_eq!(hdr.magic, TXN_LOG_MAGIC);
        assert_eq!(hdr.version, TXN_LOG_VERSION);
        assert!(hdr.verify());
    }

    /// `begin` registers the transaction as active.
    #[test]
    fn test_begin_transaction() {
        let id = TxnId::new(1);
        let mut log = test_log();
        log.begin(id, 1000).unwrap();
        assert!(log.has_txn(id));
        assert_eq!(log.txn_state(id), Some(TxnState::Active));
    }

    /// Beginning the same transaction twice is rejected.
    #[test]
    fn test_begin_duplicate() {
        let id = TxnId::new(1);
        let mut log = test_log();
        log.begin(id, 1000).unwrap();
        assert!(log.begin(id, 1001).is_err());
    }

    /// The first logged operation gets index 0.
    #[test]
    fn test_log_operation() {
        let id = TxnId::new(1);
        let mut log = test_log();
        log.begin(id, 1000).unwrap();
        let op_index = log.log_operation(id, create_op(), 1001).unwrap();
        assert_eq!(op_index, 0);
    }

    /// `prepare` moves an active transaction to the prepared state.
    #[test]
    fn test_prepare() {
        let id = TxnId::new(1);
        let mut log = test_log();
        log.begin(id, 1000).unwrap();
        log.prepare(id, 1001).unwrap();
        assert_eq!(log.txn_state(id), Some(TxnState::Prepared));
    }

    /// `commit` hands back the pending operations and forgets the transaction.
    #[test]
    fn test_commit() {
        let id = TxnId::new(1);
        let mut log = test_log();
        log.begin(id, 1000).unwrap();
        log.log_operation(id, create_op(), 1001).unwrap();
        let committed = log.commit(id, 1002).unwrap();
        assert_eq!(committed.len(), 1);
        assert!(!log.has_txn(id));
    }

    /// `abort` hands back the pending operations and forgets the transaction.
    #[test]
    fn test_abort() {
        let id = TxnId::new(1);
        let mut log = test_log();
        log.begin(id, 1000).unwrap();
        log.log_operation(id, create_op(), 1001).unwrap();
        let aborted = log.abort(id, 1002).unwrap();
        assert_eq!(aborted.len(), 1);
        assert!(!log.has_txn(id));
    }

    /// Every appended entry consumes exactly one LSN, starting from 1.
    #[test]
    fn test_lsn_allocation() {
        let mut log = test_log();
        assert_eq!(log.next_lsn(), 1);
        log.begin(TxnId::new(1), 1000).unwrap();
        assert_eq!(log.next_lsn(), 2);
        log.log_operation(TxnId::new(1), mkdir_op(), 1001).unwrap();
        assert_eq!(log.next_lsn(), 3);
    }

    /// With no transaction in flight, a checkpoint advances past LSN 0.
    #[test]
    fn test_checkpoint() {
        let id = TxnId::new(1);
        let mut log = test_log();
        log.begin(id, 1000).unwrap();
        log.commit(id, 1001).unwrap();
        let cp_lsn = log.checkpoint(1002).unwrap();
        assert!(cp_lsn > 0);
        assert_eq!(log.checkpoint_lsn(), cp_lsn);
    }

    /// `flush` reports how many entries are buffered.
    #[test]
    fn test_flush() {
        let mut log = test_log();
        log.begin(TxnId::new(1), 1000).unwrap();
        log.log_operation(TxnId::new(1), mkdir_op(), 1001).unwrap();
        let flushed = log.flush().unwrap();
        assert_eq!(flushed, 2);
    }

    /// Both an active and a prepared transaction show up for recovery.
    #[test]
    fn test_recovery_txns() {
        let mut log = test_log();
        log.begin(TxnId::new(1), 1000).unwrap();
        log.log_operation(TxnId::new(1), mkdir_op(), 1001).unwrap();
        log.begin(TxnId::new(2), 1002).unwrap();
        log.prepare(TxnId::new(2), 1003).unwrap();
        let recoverable = log.get_recovery_txns();
        assert_eq!(recoverable.len(), 2);
    }

    /// Loading a begin/operation/commit sequence counts one recovered txn.
    #[test]
    fn test_load_entries() {
        let mut log = test_log();
        let entries = vec![
            sealed(TxnLogEntry::begin(TxnId::new(1), 1, 1000)),
            sealed(TxnLogEntry::operation(
                TxnId::new(1),
                0,
                2,
                1001,
                mkdir_op(),
            )),
            sealed(TxnLogEntry::commit(TxnId::new(1), 3, 1002)),
        ];
        let stats = log.load_entries(entries).unwrap();
        assert_eq!(stats.log_entries_processed, 3);
        assert_eq!(stats.txns_recovered, 1);
    }

    /// Replaying from LSN 2 skips the begin entry at LSN 1.
    #[test]
    fn test_replay_from() {
        let mut log = test_log();
        log.begin(TxnId::new(1), 1000).unwrap();
        log.log_operation(TxnId::new(1), mkdir_op(), 1001).unwrap();
        log.commit(TxnId::new(1), 1002).unwrap();
        let replayed: Vec<_> = log.replay_from(2).collect();
        assert_eq!(replayed.len(), 2);
    }
}