use parking_lot::{Mutex, RwLock};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::fs::{File, OpenOptions};
use std::io::{BufWriter, Write};
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use super::config::IngestionSafetyLevel;
use super::write_buffer::{CommitRequest, WriteOp};
/// One entry in a WAL partition file.
///
/// Records are serialized with `bincode` and written behind a 4-byte
/// little-endian length prefix (see `WalPartition::append`).
///
/// NOTE(review): presumably the variant order is part of the on-disk format
/// (bincode encodes enum variants by index) — confirm before reordering.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum WalRecord {
/// A transaction whose operations are all logged in this single record.
Commit {
/// Identifier of the committing transaction.
txn_id: u64,
/// Commit timestamp supplied by the manager.
timestamp: u64,
/// Operations performed by the transaction.
operations: Vec<WalOp>,
},
/// Prepare phase of a transaction spanning several partitions: carries the
/// subset of operations routed to one partition.
Prepare {
/// Identifier of the prepared transaction.
txn_id: u64,
/// Id of the partition this record was written for.
partition: u16,
/// Operations of the transaction that belong to `partition`.
operations: Vec<WalOp>,
},
/// Marks a previously prepared transaction as committed.
CommitPrepared {
/// Identifier of the transaction being committed.
txn_id: u64,
},
/// Marks a previously prepared transaction as rolled back.
RollbackPrepared {
/// Identifier of the transaction being rolled back.
txn_id: u64,
},
/// Snapshot marker written to every partition by `checkpoint`.
Checkpoint {
/// Value of the manager's global timestamp when the checkpoint was taken.
timestamp: u64,
/// `(table, row id)` pairs; recovery treats each id as a lower bound for
/// the table's highest known row id.
row_id_state: Vec<(String, u64)>,
},
}
/// A single row-level operation as stored inside a [`WalRecord`].
///
/// Mirrors `WriteOp` (see the `From<&WriteOp>` impl) with owned fields so it
/// can be serialized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum WalOp {
/// A row insertion.
Insert {
/// Name of the target table.
table: String,
/// Id of the inserted row.
row_id: u64,
/// Row payload bytes, copied verbatim from the originating write.
data: Vec<u8>,
},
/// A row update.
Update {
/// Name of the target table.
table: String,
/// Id of the updated row.
row_id: u64,
/// Row payload bytes, copied verbatim from the originating write.
data: Vec<u8>,
},
/// A row deletion.
Delete {
/// Name of the target table.
table: String,
/// Id of the deleted row.
row_id: u64,
},
}
impl From<&WriteOp> for WalOp {
fn from(op: &WriteOp) -> Self {
match op {
WriteOp::Insert { table, row_id, data } => WalOp::Insert {
table: table.clone(),
row_id: *row_id,
data: data.clone(),
},
WriteOp::Update { table, row_id, data } => WalOp::Update {
table: table.clone(),
row_id: *row_id,
data: data.clone(),
},
WriteOp::Delete { table, row_id } => WalOp::Delete {
table: table.clone(),
row_id: *row_id,
},
}
}
}
/// Append-only handle to one WAL partition file.
pub struct WalPartition {
/// Numeric id of this partition.
id: u16,
/// Location of the backing log file.
path: PathBuf,
/// Buffered writer over the log file; the mutex serializes appends and syncs.
writer: Mutex<BufWriter<File>>,
/// Log sequence number: seeded by `recover_lsn` at open, then incremented
/// by one for every appended record.
lsn: AtomicU64,
/// Bytes appended since the last successful `sync`.
unflushed_bytes: AtomicU64,
/// Records appended since the last successful `sync`.
unflushed_records: AtomicU64,
}
impl WalPartition {
pub fn open(path: impl AsRef<Path>, id: u16) -> std::io::Result<Self> {
let path = path.as_ref().to_path_buf();
let file = OpenOptions::new()
.create(true)
.append(true)
.open(&path)?;
let lsn = Self::recover_lsn(&path)?;
Ok(Self {
id,
path,
writer: Mutex::new(BufWriter::with_capacity(64 * 1024, file)),
lsn: AtomicU64::new(lsn),
unflushed_bytes: AtomicU64::new(0),
unflushed_records: AtomicU64::new(0),
})
}
/// Derives the starting LSN for the log file at `path`.
///
/// Returns the file's current length in bytes, or `0` when `path` does not
/// exist (an empty file also yields `0`).
///
/// NOTE(review): the value is a byte length, whereas `append` advances the
/// LSN by one per record — the two units differ; confirm this is intended.
///
/// # Errors
/// Propagates the error from reading the metadata of an existing path.
fn recover_lsn(path: &Path) -> std::io::Result<u64> {
    if !path.exists() {
        return Ok(0);
    }
    std::fs::metadata(path).map(|meta| meta.len())
}
/// Serializes `record` with `bincode` and appends it to the buffered writer
/// as `[u32 little-endian length][payload]`.
///
/// The data is only buffered; call `sync` to push it to disk. On success the
/// partition LSN is advanced by one and the unflushed byte/record counters
/// are updated.
///
/// # Returns
/// The LSN assigned to the appended record.
///
/// # Errors
/// * `ErrorKind::Other` if serialization fails.
/// * `ErrorKind::InvalidInput` if the serialized payload does not fit in the
///   32-bit length prefix.
/// * Any I/O error raised while writing to the buffer/file.
pub fn append(&self, record: &WalRecord) -> std::io::Result<u64> {
    let payload = bincode::serialize(record)
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;

    // A plain `as u32` would silently truncate the prefix for oversized
    // payloads and make the rest of the file unreadable; refuse instead.
    if payload.len() > u32::MAX as usize {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!(
                "WAL record of {} bytes exceeds the 32-bit length prefix",
                payload.len()
            ),
        ));
    }
    let record_len = payload.len() as u32;

    // The guard is held until the counters are updated so that `sync`
    // cannot reset them between the write and the bookkeeping.
    let mut writer = self.writer.lock();
    writer.write_all(&record_len.to_le_bytes())?;
    writer.write_all(&payload)?;

    let lsn = self.lsn.fetch_add(1, Ordering::SeqCst) + 1;
    self.unflushed_bytes
        .fetch_add(u64::from(record_len) + 4, Ordering::Relaxed);
    self.unflushed_records.fetch_add(1, Ordering::Relaxed);
    Ok(lsn)
}
pub fn sync(&self) -> std::io::Result<()> {
let mut writer = self.writer.lock();
writer.flush()?;
writer.get_ref().sync_all()?;
self.unflushed_bytes.store(0, Ordering::Relaxed);
self.unflushed_records.store(0, Ordering::Relaxed);
Ok(())
}
/// Returns `(bytes, records)` appended since the last successful `sync`.
///
/// The two counters are read independently with relaxed ordering, so the
/// pair is not an atomic snapshot.
pub fn unflushed_stats(&self) -> (u64, u64) {
    let bytes = self.unflushed_bytes.load(Ordering::Relaxed);
    let records = self.unflushed_records.load(Ordering::Relaxed);
    (bytes, records)
}
pub fn lsn(&self) -> u64 {
self.lsn.load(Ordering::Acquire)
}
}
/// Routes WAL records to a fixed set of partition files and tracks in-flight
/// two-phase commits.
pub struct PartitionedWalManager {
/// Open partitions, indexed by partition id; empty when the WAL is disabled.
partitions: Vec<Arc<WalPartition>>,
/// Directory holding the `wal_NNNN.log` files.
base_path: PathBuf,
/// Safety level the manager was constructed with.
safety_level: IngestionSafetyLevel,
/// Monotonic counter seeded with the wall-clock time in milliseconds at
/// construction and bumped by `next_timestamp`.
global_timestamp: AtomicU64,
/// State of transactions that have been prepared but not yet committed or
/// rolled back, keyed by transaction id.
pending_2pc: RwLock<HashMap<u64, TwoPcState>>,
/// Whether WAL writes are active (taken from `safety_level.use_wal()`).
enabled: AtomicBool,
}
/// In-memory bookkeeping for one in-flight two-phase commit.
#[derive(Debug)]
struct TwoPcState {
/// Every partition that receives operations from the transaction.
partitions: HashSet<u16>,
/// Partitions whose `Prepare` record has been appended and synced so far.
prepared: HashSet<u16>,
/// The transaction's operations grouped by target partition.
operations: HashMap<u16, Vec<WalOp>>,
}
impl PartitionedWalManager {
/// Creates a manager rooted at `base_path`.
///
/// The directory is created if missing. When `safety_level.use_wal()` is
/// true, `partition_count` files named `wal_0000.log`, `wal_0001.log`, … are
/// opened; otherwise no partition is opened and every write becomes a no-op.
/// The global timestamp is seeded with the current wall-clock time in
/// milliseconds.
///
/// # Errors
/// * `ErrorKind::InvalidInput` if the WAL is enabled and `partition_count`
///   exceeds 65536 (partition ids are stored as `u16`).
/// * Any I/O error from creating the directory or opening a partition.
pub fn new(
    base_path: impl AsRef<Path>,
    partition_count: usize,
    safety_level: IngestionSafetyLevel,
) -> std::io::Result<Self> {
    let base_path = base_path.as_ref().to_path_buf();
    std::fs::create_dir_all(&base_path)?;
    let enabled = safety_level.use_wal();

    // Partition ids are `u16`; a larger count would wrap in the cast below
    // and hand the same id to several files.
    if enabled && partition_count > usize::from(u16::MAX) + 1 {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!(
                "partition_count {} exceeds the maximum of {}",
                partition_count,
                usize::from(u16::MAX) + 1
            ),
        ));
    }

    let partitions = if enabled {
        (0..partition_count)
            .map(|index| {
                let path = base_path.join(format!("wal_{:04}.log", index));
                // Lossless: `index < partition_count <= 65536` was checked above.
                WalPartition::open(&path, index as u16).map(Arc::new)
            })
            .collect::<std::io::Result<Vec<_>>>()?
    } else {
        Vec::new()
    };

    let seed_millis = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_millis() as u64;

    Ok(Self {
        partitions,
        base_path,
        safety_level,
        global_timestamp: AtomicU64::new(seed_millis),
        pending_2pc: RwLock::new(HashMap::new()),
        enabled: AtomicBool::new(enabled),
    })
}
/// Reports whether WAL writes are active for this manager.
pub fn is_enabled(&self) -> bool {
    let flag = &self.enabled;
    flag.load(Ordering::Acquire)
}
/// Maps a `(table, row_id)` pair to a partition id by hashing both values
/// and reducing the hash modulo the number of partitions.
///
/// Returns `0` when no partitions are open.
///
/// NOTE(review): `DefaultHasher`'s algorithm is not guaranteed to be stable
/// across Rust releases — confirm nothing relies on a row always mapping to
/// the same partition file across upgrades.
fn partition_for(&self, table: &str, row_id: u64) -> u16 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let partition_count = self.partitions.len() as u64;
    if partition_count == 0 {
        return 0;
    }

    let mut state = DefaultHasher::new();
    Hash::hash(table, &mut state);
    Hash::hash(&row_id, &mut state);
    let digest = state.finish();

    (digest % partition_count) as u16
}
/// Atomically advances the global timestamp by one and returns the new value.
pub fn next_timestamp(&self) -> u64 {
    let previous = self.global_timestamp.fetch_add(1, Ordering::AcqRel);
    previous + 1
}
pub fn commit_single(
&self,
txn_id: u64,
operations: &[WriteOp],
sync: bool,
) -> std::io::Result<u64> {
if !self.is_enabled() || self.partitions.is_empty() {
return Ok(self.next_timestamp());
}
let partition_id = if let Some(first) = operations.first() {
self.partition_for(first.table(), first.row_id())
} else {
return Ok(self.next_timestamp());
};
let timestamp = self.next_timestamp();
let wal_ops: Vec<WalOp> = operations.iter().map(|op| op.into()).collect();
let record = WalRecord::Commit {
txn_id,
timestamp,
operations: wal_ops,
};
let partition = self.partitions.get(partition_id as usize)
.ok_or_else(|| std::io::Error::new(std::io::ErrorKind::Other, format!("Invalid partition ID: {}", partition_id)))?;
partition.append(&record)?;
if sync {
partition.sync()?;
}
Ok(timestamp)
}
/// Appends a `Commit` record built from already-converted operations and a
/// caller-supplied timestamp.
///
/// The record goes to the partition selected by the first operation. The
/// call is a no-op when the WAL is disabled, has no partitions, or
/// `operations` is empty. The partition is not synced here.
///
/// # Errors
/// Propagates append failures, or `ErrorKind::Other` if the computed
/// partition id has no matching partition.
pub fn write_commit(
    &self,
    txn_id: u64,
    timestamp: u64,
    operations: Vec<WalOp>,
) -> std::io::Result<()> {
    let wal_active = self.is_enabled() && !self.partitions.is_empty();
    if !wal_active {
        return Ok(());
    }

    let partition_id = match operations.first() {
        Some(WalOp::Insert { table, row_id, .. })
        | Some(WalOp::Update { table, row_id, .. })
        | Some(WalOp::Delete { table, row_id }) => self.partition_for(table, *row_id),
        None => return Ok(()),
    };

    let record = WalRecord::Commit {
        txn_id,
        timestamp,
        operations,
    };

    match self.partitions.get(usize::from(partition_id)) {
        Some(partition) => {
            partition.append(&record)?;
            Ok(())
        }
        None => Err(std::io::Error::new(
            std::io::ErrorKind::Other,
            format!("Invalid partition ID: {}", partition_id),
        )),
    }
}
pub fn prepare_2pc(
&self,
txn_id: u64,
operations: &[WriteOp],
) -> std::io::Result<()> {
if !self.is_enabled() || self.partitions.is_empty() {
return Ok(());
}
let mut partition_ops: HashMap<u16, Vec<WalOp>> = HashMap::new();
for op in operations {
let partition_id = self.partition_for(op.table(), op.row_id());
partition_ops
.entry(partition_id)
.or_insert_with(Vec::new)
.push(op.into());
}
let partitions: HashSet<u16> = partition_ops.keys().copied().collect();
let state = TwoPcState {
partitions: partitions.clone(),
prepared: HashSet::new(),
operations: partition_ops.clone(),
};
self.pending_2pc.write().insert(txn_id, state);
for (partition_id, ops) in partition_ops {
let record = WalRecord::Prepare {
txn_id,
partition: partition_id,
operations: ops,
};
let partition = self.partitions.get(partition_id as usize)
.ok_or_else(|| std::io::Error::new(std::io::ErrorKind::Other, format!("Invalid partition ID: {}", partition_id)))?;
partition.append(&record)?;
partition.sync()?;
if let Some(state) = self.pending_2pc.write().get_mut(&txn_id) {
state.prepared.insert(partition_id);
}
}
Ok(())
}
/// Runs the commit phase of a two-phase commit.
///
/// A timestamp is always allocated. If the WAL is active and `txn_id` is
/// pending, a `CommitPrepared` record is appended to every participating
/// partition and then all partitions are synced. The pending entry is
/// removed on success; an unknown `txn_id` writes nothing.
///
/// # Returns
/// The commit timestamp.
///
/// # Errors
/// Propagates append/sync failures or an unknown partition id; in that case
/// the pending entry is left in place.
pub fn commit_2pc(&self, txn_id: u64) -> std::io::Result<u64> {
    let timestamp = self.next_timestamp();

    if self.is_enabled() && !self.partitions.is_empty() {
        // Copy the participant set out so the read lock is released before I/O.
        let participants = self
            .pending_2pc
            .read()
            .get(&txn_id)
            .map(|state| state.partitions.clone());

        if let Some(participants) = participants {
            let marker = WalRecord::CommitPrepared { txn_id };
            for partition_id in participants {
                let partition = match self.partitions.get(usize::from(partition_id)) {
                    Some(found) => found,
                    None => {
                        return Err(std::io::Error::new(
                            std::io::ErrorKind::Other,
                            format!("Invalid partition ID: {}", partition_id),
                        ))
                    }
                };
                partition.append(&marker)?;
            }
            for partition in self.partitions.iter() {
                partition.sync()?;
            }
        }
    }

    self.pending_2pc.write().remove(&txn_id);
    Ok(timestamp)
}
/// Aborts a pending two-phase commit.
///
/// If the WAL is active and `txn_id` is pending, a `RollbackPrepared` record
/// is appended to every partition that was successfully prepared. No sync is
/// performed here. The pending entry is removed on success.
///
/// # Errors
/// Propagates append failures or an unknown partition id; in that case the
/// pending entry is left in place.
pub fn rollback_2pc(&self, txn_id: u64) -> std::io::Result<()> {
    if self.is_enabled() && !self.partitions.is_empty() {
        // Copy the prepared set out so the read lock is released before I/O.
        let prepared = self
            .pending_2pc
            .read()
            .get(&txn_id)
            .map(|state| state.prepared.clone());

        if let Some(prepared) = prepared {
            let marker = WalRecord::RollbackPrepared { txn_id };
            for partition_id in prepared {
                let partition = match self.partitions.get(usize::from(partition_id)) {
                    Some(found) => found,
                    None => {
                        return Err(std::io::Error::new(
                            std::io::ErrorKind::Other,
                            format!("Invalid partition ID: {}", partition_id),
                        ))
                    }
                };
                partition.append(&marker)?;
            }
        }
    }

    self.pending_2pc.write().remove(&txn_id);
    Ok(())
}
/// Logs a batch of transactions under one shared commit timestamp.
///
/// For each request the operations are grouped by target partition:
/// * exactly one partition → a single `Commit` record on that partition;
/// * several partitions → on each partition a `Prepare` record with that
///   partition's operations, followed by a `CommitPrepared` record;
/// * no operations → nothing is written for that request.
///
/// Records are appended partition by partition, keeping request order within
/// each partition. When the WAL is disabled or has no partitions, only a
/// fresh timestamp is returned.
///
/// NOTE(review): for multi-partition requests the `CommitPrepared` marker is
/// appended together with the `Prepare` records and partitions are synced
/// one after another, so a crash between syncs could leave a marker durable
/// on one partition while another partition's `Prepare` is not — confirm
/// recovery tolerates this.
///
/// # Parameters
/// * `requests` – transactions to log.
/// * `sync` – when true, every partition is synced after all appends.
///
/// # Returns
/// The timestamp shared by all transactions in the batch.
///
/// # Errors
/// Propagates append/sync failures, or `ErrorKind::Other` if a computed
/// partition id has no matching partition.
pub fn commit_batch(
    &self,
    requests: &[CommitRequest],
    sync: bool,
) -> std::io::Result<u64> {
    if !self.is_enabled() || self.partitions.is_empty() {
        return Ok(self.next_timestamp());
    }
    let timestamp = self.next_timestamp();

    let mut partition_records: HashMap<u16, Vec<WalRecord>> = HashMap::new();
    for request in requests {
        // One pass: hash each operation once and convert it exactly once.
        let mut ops_by_partition: HashMap<u16, Vec<WalOp>> = HashMap::new();
        for op in &request.operations {
            let partition_id = self.partition_for(op.table(), op.row_id());
            ops_by_partition
                .entry(partition_id)
                .or_insert_with(Vec::new)
                .push(op.into());
        }

        let single_partition = ops_by_partition.len() == 1;
        for (partition_id, operations) in ops_by_partition {
            let records = partition_records
                .entry(partition_id)
                .or_insert_with(Vec::new);
            if single_partition {
                records.push(WalRecord::Commit {
                    txn_id: request.txn_id,
                    timestamp,
                    operations,
                });
            } else {
                records.push(WalRecord::Prepare {
                    txn_id: request.txn_id,
                    partition: partition_id,
                    operations,
                });
                records.push(WalRecord::CommitPrepared {
                    txn_id: request.txn_id,
                });
            }
        }
    }

    for (partition_id, records) in partition_records {
        let partition = self.partitions.get(partition_id as usize).ok_or_else(|| {
            std::io::Error::new(
                std::io::ErrorKind::Other,
                format!("Invalid partition ID: {}", partition_id),
            )
        })?;
        for record in &records {
            partition.append(record)?;
        }
    }

    if sync {
        for partition in &self.partitions {
            partition.sync()?;
        }
    }
    Ok(timestamp)
}
/// Writes a `Checkpoint` record carrying `row_id_state` and the current
/// global timestamp to every partition, syncing each partition right after
/// its append. No-op when the WAL is disabled or has no partitions.
///
/// # Errors
/// Stops at the first append or sync failure and returns it.
pub fn checkpoint(&self, row_id_state: Vec<(String, u64)>) -> std::io::Result<()> {
    let wal_active = self.is_enabled() && !self.partitions.is_empty();
    if !wal_active {
        return Ok(());
    }

    let timestamp = self.global_timestamp.load(Ordering::Acquire);
    let record = WalRecord::Checkpoint {
        timestamp,
        row_id_state,
    };

    self.partitions.iter().try_for_each(|partition| {
        partition.append(&record)?;
        partition.sync()
    })
}
/// Syncs every partition in order, stopping at the first failure.
pub fn sync_all(&self) -> std::io::Result<()> {
    self.partitions
        .iter()
        .try_for_each(|partition| partition.sync())
}
/// Returns the total number of bytes appended but not yet synced, summed
/// over all partitions.
pub fn unflushed_bytes(&self) -> u64 {
    let mut total: u64 = 0;
    for partition in &self.partitions {
        let (bytes, _records) = partition.unflushed_stats();
        total += bytes;
    }
    total
}
}
/// Read-only scanner that rebuilds state from the WAL files in a directory.
pub struct WalRecovery {
/// Directory whose `*.log` files are scanned.
base_path: PathBuf,
}
impl WalRecovery {
pub fn new(base_path: impl AsRef<Path>) -> Self {
Self {
base_path: base_path.as_ref().to_path_buf(),
}
}
/// Scans every `*.log` file in the base directory and returns, per table,
/// the highest row id seen in any record.
///
/// # Errors
/// Propagates failures from listing the directory or reading a log file.
pub fn recover_row_ids(&self) -> std::io::Result<HashMap<String, u64>> {
    let mut high_water_marks: HashMap<String, u64> = HashMap::new();

    for dir_entry in std::fs::read_dir(&self.base_path)? {
        let candidate = dir_entry?.path();
        let is_log = match candidate.extension() {
            Some(ext) => ext == "log",
            None => false,
        };
        if is_log {
            self.scan_wal_file(&candidate, &mut high_water_marks)?;
        }
    }

    Ok(high_water_marks)
}
/// Reads one WAL file and folds every decodable record into `row_ids`.
///
/// The file is a sequence of `[u32 little-endian length][payload]` frames.
/// Scanning stops at the first frame that is truncated (header or payload
/// running past the end of the file). A payload that fails to deserialize is
/// skipped and scanning continues with the next frame.
///
/// # Errors
/// Propagates failures from opening or reading the file.
fn scan_wal_file(
    &self,
    path: &Path,
    row_ids: &mut HashMap<String, u64>,
) -> std::io::Result<()> {
    use std::io::Read;

    let mut file = File::open(path)?;
    let mut buffer = Vec::new();
    file.read_to_end(&mut buffer)?;

    let mut pos: usize = 0;
    loop {
        // Checked arithmetic: a corrupt length prefix must not be able to
        // overflow the offset (possible on 32-bit targets).
        let payload_start = match pos.checked_add(4) {
            Some(offset) => offset,
            None => break,
        };
        let header = match buffer.get(pos..payload_start) {
            Some(bytes) => bytes,
            None => break,
        };
        let len_bytes: [u8; 4] = match header.try_into() {
            Ok(arr) => arr,
            Err(_) => break,
        };
        let len = u32::from_le_bytes(len_bytes) as usize;

        let payload_end = match payload_start.checked_add(len) {
            Some(offset) => offset,
            None => break,
        };
        // `None` means the payload is truncated: treat it as the end of the log.
        let payload = match buffer.get(payload_start..payload_end) {
            Some(bytes) => bytes,
            None => break,
        };

        if let Ok(record) = bincode::deserialize::<WalRecord>(payload) {
            self.process_record_for_recovery(record, row_ids);
        }
        pos = payload_end;
    }
    Ok(())
}
/// Updates the per-table row-id high-water marks from one record.
///
/// `Commit` and `Prepare` records contribute the row id of each operation;
/// `Checkpoint` records contribute their stored `(table, id)` pairs. The
/// 2PC marker records carry no row ids and are ignored.
fn process_record_for_recovery(
    &self,
    record: WalRecord,
    row_ids: &mut HashMap<String, u64>,
) {
    match record {
        WalRecord::Commit { operations, .. } | WalRecord::Prepare { operations, .. } => {
            for op in &operations {
                self.update_row_id_from_op(op, row_ids);
            }
        }
        WalRecord::Checkpoint { row_id_state, .. } => {
            for (table, max_id) in row_id_state {
                let slot = row_ids.entry(table).or_insert(0);
                if max_id > *slot {
                    *slot = max_id;
                }
            }
        }
        WalRecord::CommitPrepared { .. } | WalRecord::RollbackPrepared { .. } => {}
    }
}
/// Raises the recorded maximum row id for the operation's table if the
/// operation's row id is larger (tables not seen before start at 0).
fn update_row_id_from_op(&self, op: &WalOp, row_ids: &mut HashMap<String, u64>) {
    let (table, row_id) = match op {
        WalOp::Insert { table, row_id, .. }
        | WalOp::Update { table, row_id, .. }
        | WalOp::Delete { table, row_id } => (table, *row_id),
    };

    let slot = row_ids.entry(table.clone()).or_insert(0);
    if row_id > *slot {
        *slot = row_id;
    }
}
/// Scans every `*.log` file and classifies the prepared transactions found.
///
/// # Returns
/// `(to_complete, to_rollback)`:
/// * `to_complete` – prepared transactions with a `CommitPrepared` record;
/// * `to_rollback` – prepared transactions with neither a `CommitPrepared`
///   nor a `RollbackPrepared` record.
///
/// The order of ids inside each vector is unspecified.
///
/// # Errors
/// Propagates failures from listing the directory or reading a log file.
pub fn recover_2pc(&self) -> std::io::Result<(Vec<u64>, Vec<u64>)> {
    let mut prepared: HashMap<u64, HashSet<u16>> = HashMap::new();
    let mut committed: HashSet<u64> = HashSet::new();
    let mut rolled_back: HashSet<u64> = HashSet::new();

    for dir_entry in std::fs::read_dir(&self.base_path)? {
        let candidate = dir_entry?.path();
        let is_log = match candidate.extension() {
            Some(ext) => ext == "log",
            None => false,
        };
        if is_log {
            self.scan_2pc_state(&candidate, &mut prepared, &mut committed, &mut rolled_back)?;
        }
    }

    let mut to_complete: Vec<u64> = Vec::new();
    let mut to_rollback: Vec<u64> = Vec::new();
    for &txn_id in prepared.keys() {
        if committed.contains(&txn_id) {
            to_complete.push(txn_id);
        } else if !rolled_back.contains(&txn_id) {
            to_rollback.push(txn_id);
        }
    }

    Ok((to_complete, to_rollback))
}
/// Reads one WAL file and collects its two-phase-commit records.
///
/// `Prepare` records add their partition to `prepared[txn_id]`;
/// `CommitPrepared` / `RollbackPrepared` records add the transaction id to
/// `committed` / `rolled_back`. Other record kinds are ignored.
///
/// The file is a sequence of `[u32 little-endian length][payload]` frames.
/// Scanning stops at the first truncated frame; a payload that fails to
/// deserialize is skipped and scanning continues with the next frame.
///
/// # Errors
/// Propagates failures from opening or reading the file.
fn scan_2pc_state(
    &self,
    path: &Path,
    prepared: &mut HashMap<u64, HashSet<u16>>,
    committed: &mut HashSet<u64>,
    rolled_back: &mut HashSet<u64>,
) -> std::io::Result<()> {
    use std::io::Read;

    let mut file = File::open(path)?;
    let mut buffer = Vec::new();
    file.read_to_end(&mut buffer)?;

    let mut pos: usize = 0;
    loop {
        // Checked arithmetic: a corrupt length prefix must not be able to
        // overflow the offset (possible on 32-bit targets).
        let payload_start = match pos.checked_add(4) {
            Some(offset) => offset,
            None => break,
        };
        let header = match buffer.get(pos..payload_start) {
            Some(bytes) => bytes,
            None => break,
        };
        let len_bytes: [u8; 4] = match header.try_into() {
            Ok(arr) => arr,
            Err(_) => break,
        };
        let len = u32::from_le_bytes(len_bytes) as usize;

        let payload_end = match payload_start.checked_add(len) {
            Some(offset) => offset,
            None => break,
        };
        // `None` means the payload is truncated: treat it as the end of the log.
        let payload = match buffer.get(payload_start..payload_end) {
            Some(bytes) => bytes,
            None => break,
        };

        if let Ok(record) = bincode::deserialize::<WalRecord>(payload) {
            match record {
                WalRecord::Prepare { txn_id, partition, .. } => {
                    prepared
                        .entry(txn_id)
                        .or_insert_with(HashSet::new)
                        .insert(partition);
                }
                WalRecord::CommitPrepared { txn_id } => {
                    committed.insert(txn_id);
                }
                WalRecord::RollbackPrepared { txn_id } => {
                    rolled_back.insert(txn_id);
                }
                WalRecord::Commit { .. } | WalRecord::Checkpoint { .. } => {}
            }
        }
        pos = payload_end;
    }
    Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// A one-operation transaction on a single-partition WAL commits and
    /// yields a non-zero timestamp.
    #[test]
    fn test_single_partition_commit() {
        let scratch = tempdir().unwrap();
        let wal =
            PartitionedWalManager::new(scratch.path(), 1, IngestionSafetyLevel::Full).unwrap();

        let insert = WriteOp::Insert {
            table: "test".to_string(),
            row_id: 1,
            data: vec![1, 2, 3],
        };

        let timestamp = wal.commit_single(1, &[insert], true).unwrap();
        assert!(timestamp > 0);
    }

    /// A safety level that disables the WAL produces a disabled manager.
    #[test]
    fn test_disabled_wal() {
        let scratch = tempdir().unwrap();
        let level = IngestionSafetyLevel::Unsafe {
            disable_wal: true,
            checkpoint_interval_secs: 0,
        };

        let wal = PartitionedWalManager::new(scratch.path(), 1, level).unwrap();
        assert!(!wal.is_enabled());
    }
}