use crate::error::{DbxError, DbxResult};
use crate::wal::WalRecord;
use dashmap::DashMap;
use rayon::prelude::*;
use std::path::PathBuf;
use std::sync::atomic::{AtomicU64, Ordering};
/// Write-ahead-log writer that buffers records per table ("partition")
/// and appends each partition to its own `<table>.wal` file under
/// `wal_dir` once the partition reaches `flush_threshold` records.
pub struct PartitionedWalWriter {
// In-memory buffers of not-yet-flushed records, keyed by table name.
partitions: DashMap<String, Vec<WalRecord>>,
// Directory holding the per-table `.wal` files.
wal_dir: PathBuf,
// Monotonic sequence counter shared across all partitions.
sequence: AtomicU64,
// A partition is flushed to disk once it holds this many records.
flush_threshold: usize,
}
impl PartitionedWalWriter {
    /// Creates a writer whose partitions flush to `wal_dir` once they
    /// reach `flush_threshold` records.
    ///
    /// The WAL directory is created if it does not already exist.
    ///
    /// # Errors
    /// Returns `DbxError::Io` if the directory cannot be created.
    pub fn new(wal_dir: PathBuf, flush_threshold: usize) -> DbxResult<Self> {
        if !wal_dir.exists() {
            std::fs::create_dir_all(&wal_dir).map_err(|source| DbxError::Io { source })?;
        }
        Ok(Self {
            partitions: DashMap::new(),
            wal_dir,
            sequence: AtomicU64::new(0),
            flush_threshold,
        })
    }

    /// Convenience constructor using a flush threshold of 100 records.
    pub fn with_defaults(wal_dir: PathBuf) -> DbxResult<Self> {
        Self::new(wal_dir, 100)
    }

    /// Appends a single record and returns its assigned sequence number.
    ///
    /// If the record's partition reaches the flush threshold it is
    /// flushed to disk synchronously before this call returns.
    ///
    /// # Errors
    /// Returns `DbxError::Io` if a threshold-triggered flush fails; the
    /// taken records are not re-buffered in that case.
    pub fn append(&self, record: WalRecord) -> DbxResult<u64> {
        let seq = self.sequence.fetch_add(1, Ordering::SeqCst);
        let table = Self::extract_table(&record);
        // Keep the table name for the flush below instead of re-deriving
        // it from `records[0]` after the fact (the original did the
        // extraction twice).
        let mut partition = self.partitions.entry(table.clone()).or_default();
        partition.push(record);
        if partition.len() >= self.flush_threshold {
            let records = std::mem::take(&mut *partition);
            // Release the DashMap shard lock before doing file I/O so
            // other writers hashing to the same shard are not blocked
            // on the disk.
            drop(partition);
            self.flush_records(&table, &records)?;
        }
        Ok(seq)
    }

    /// Appends many records at once, returning their sequence numbers
    /// in input order.
    ///
    /// Records are grouped by table and each group is buffered (and, if
    /// the threshold is reached, flushed) in parallel via rayon.
    ///
    /// # Errors
    /// All groups are processed before the first flush error (if any)
    /// is returned, so some partitions may have flushed successfully.
    pub fn append_batch(&self, records: Vec<WalRecord>) -> DbxResult<Vec<u64>> {
        // Reserve one sequence number per record up front so numbering
        // matches input order even though buffering runs in parallel.
        let sequences: Vec<u64> = records
            .iter()
            .map(|_| self.sequence.fetch_add(1, Ordering::SeqCst))
            .collect();
        let mut grouped: std::collections::HashMap<String, Vec<WalRecord>> =
            std::collections::HashMap::new();
        for record in records {
            let table = Self::extract_table(&record);
            grouped.entry(table).or_default().push(record);
        }
        let results: Vec<DbxResult<()>> = grouped
            .into_par_iter()
            .map(|(table, partition_records)| {
                let mut partition = self.partitions.entry(table.clone()).or_default();
                partition.extend(partition_records);
                if partition.len() >= self.flush_threshold {
                    let records = std::mem::take(&mut *partition);
                    // Drop the shard lock before file I/O (see `append`).
                    drop(partition);
                    self.flush_records(&table, &records)?;
                }
                Ok(())
            })
            .collect();
        // Surface the first error only after every task has finished.
        for result in results {
            result?;
        }
        Ok(sequences)
    }

    /// Flushes every buffered partition to disk in parallel.
    ///
    /// Returns the total number of records written.
    ///
    /// # Errors
    /// Returns the first flush error encountered; other partitions may
    /// already have been flushed.
    pub fn flush_all(&self) -> DbxResult<usize> {
        // Snapshot the key set first so we never hold a DashMap iterator
        // while taking per-entry locks below.
        let tables: Vec<String> = self.partitions.iter().map(|e| e.key().clone()).collect();
        let flushed: Vec<DbxResult<usize>> = tables
            .par_iter()
            .map(|table| {
                if let Some(mut partition) = self.partitions.get_mut(table) {
                    if partition.is_empty() {
                        return Ok(0);
                    }
                    let records = std::mem::take(&mut *partition);
                    let count = records.len();
                    // Drop the shard lock before file I/O (see `append`).
                    drop(partition);
                    self.flush_records(table, &records)?;
                    Ok(count)
                } else {
                    Ok(0)
                }
            })
            .collect();
        let mut total = 0;
        for result in flushed {
            total += result?;
        }
        Ok(total)
    }

    /// Number of distinct partitions (tables) seen so far, including
    /// ones whose buffers are currently empty.
    pub fn partition_count(&self) -> usize {
        self.partitions.len()
    }

    /// Total number of records currently buffered across all partitions.
    pub fn buffered_count(&self) -> usize {
        self.partitions.iter().map(|e| e.value().len()).sum()
    }

    /// The next sequence number that will be assigned.
    pub fn current_sequence(&self) -> u64 {
        self.sequence.load(Ordering::SeqCst)
    }

    /// Maps a record to its partition key. Table-scoped records use the
    /// table name; checkpoint and transaction records go to synthetic
    /// `__checkpoint__` / `__tx__` partitions.
    fn extract_table(record: &WalRecord) -> String {
        match record {
            WalRecord::Insert { table, .. } => table.clone(),
            WalRecord::Delete { table, .. } => table.clone(),
            WalRecord::Batch { table, .. } => table.clone(),
            WalRecord::Checkpoint { .. } => "__checkpoint__".to_string(),
            WalRecord::Commit { .. } => "__tx__".to_string(),
            WalRecord::Rollback { .. } => "__tx__".to_string(),
        }
    }

    /// Serializes `records` as `[u32 little-endian length][bincode payload]`
    /// frames and appends them to `<wal_dir>/<table>.wal`, fsyncing
    /// before returning.
    ///
    /// # Errors
    /// Returns `DbxError::Io` if serialization, a frame length overflows
    /// `u32`, or any file operation fails. Nothing is written unless the
    /// entire batch serialized successfully.
    fn flush_records(&self, table: &str, records: &[WalRecord]) -> DbxResult<()> {
        use std::io::Write;
        // Table names may contain path separators; sanitize so the file
        // always lands inside `wal_dir`.
        let safe_name = table.replace(['/', '\\', ':'], "_");
        let path = self.wal_dir.join(format!("{safe_name}.wal"));
        // Serialize the whole batch up front so a mid-batch failure can
        // never leave a truncated frame in the log. The previous
        // `unwrap_or_default()` silently wrote an empty frame (losing
        // the record) whenever serialization failed.
        let mut serialized: Vec<u8> = Vec::new();
        for record in records {
            let payload = bincode::serialize(record).map_err(|e| DbxError::Io {
                source: std::io::Error::new(std::io::ErrorKind::InvalidData, e),
            })?;
            // Reject frames whose length does not fit the u32 header
            // instead of truncating with `as`.
            let len = u32::try_from(payload.len()).map_err(|e| DbxError::Io {
                source: std::io::Error::new(std::io::ErrorKind::InvalidData, e),
            })?;
            serialized.extend_from_slice(&len.to_le_bytes());
            serialized.extend_from_slice(&payload);
        }
        let mut file = std::fs::OpenOptions::new()
            .create(true)
            .append(true)
            .open(&path)
            .map_err(|source| DbxError::Io { source })?;
        file.write_all(&serialized)
            .map_err(|source| DbxError::Io { source })?;
        // `Write::flush` is a no-op for `File`; a WAL must reach stable
        // storage before the append is acknowledged, so fsync instead.
        file.sync_all().map_err(|source| DbxError::Io { source })?;
        Ok(())
    }
}
/// Promotes per-table `.wal` files to `.checkpoint` files, processing
/// tables in parallel with rayon.
pub struct ParallelCheckpointManager {
// Directory containing the `.wal` files produced by the writer.
wal_dir: PathBuf,
}
impl ParallelCheckpointManager {
    /// Creates a manager operating on WAL files under `wal_dir`.
    pub fn new(wal_dir: PathBuf) -> Self {
        Self { wal_dir }
    }

    /// Renames each table's `.wal` file to `.checkpoint`, processing
    /// tables in parallel. A pre-existing `.checkpoint` file for a
    /// table is overwritten by the rename.
    ///
    /// Returns the number of tables that actually had a WAL file to
    /// checkpoint. (The previous implementation counted every input
    /// table, even those with no WAL on disk, so it always returned
    /// `tables.len()` on success.)
    ///
    /// # Errors
    /// Returns `DbxError::Io` on the first rename failure; renames for
    /// other tables may already have completed.
    pub fn checkpoint_tables(&self, tables: &[String]) -> DbxResult<usize> {
        let results: Vec<DbxResult<usize>> = tables
            .par_iter()
            .map(|table| {
                // Sanitize exactly like the writer so we address the
                // same file it created.
                let safe_name = table.replace(['/', '\\', ':'], "_");
                let wal_path = self.wal_dir.join(format!("{safe_name}.wal"));
                let checkpoint_path = self.wal_dir.join(format!("{safe_name}.checkpoint"));
                if wal_path.exists() {
                    std::fs::rename(&wal_path, &checkpoint_path)
                        .map_err(|source| DbxError::Io { source })?;
                    Ok(1)
                } else {
                    Ok(0)
                }
            })
            .collect();
        let mut count = 0;
        for result in results {
            count += result?;
        }
        Ok(count)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Builds an `Insert` record for `table` with the given key/value.
    fn insert_record(table: &str, key: &[u8], value: &[u8]) -> WalRecord {
        WalRecord::Insert {
            table: table.to_string(),
            key: key.to_vec(),
            value: value.to_vec(),
            ts: 0,
        }
    }

    #[test]
    fn test_partitioned_wal_basic() {
        let tmp = tempdir().unwrap();
        let writer = PartitionedWalWriter::new(tmp.path().to_path_buf(), 100).unwrap();
        // Sequence numbers are assigned monotonically starting at 0.
        assert_eq!(writer.append(insert_record("users", b"k1", b"v1")).unwrap(), 0);
        assert_eq!(writer.append(insert_record("orders", b"k2", b"v2")).unwrap(), 1);
        // One partition per distinct table; both records still buffered.
        assert_eq!(writer.partition_count(), 2);
        assert_eq!(writer.buffered_count(), 2);
    }

    #[test]
    fn test_partitioned_wal_batch() {
        let tmp = tempdir().unwrap();
        let writer = PartitionedWalWriter::new(tmp.path().to_path_buf(), 100).unwrap();
        let batch = vec![
            insert_record("users", b"k1", b"v1"),
            insert_record("users", b"k2", b"v2"),
            insert_record("orders", b"k3", b"v3"),
        ];
        // One sequence number per input record; two distinct tables.
        let seqs = writer.append_batch(batch).unwrap();
        assert_eq!(seqs.len(), 3);
        assert_eq!(writer.partition_count(), 2);
    }

    #[test]
    fn test_partitioned_wal_flush() {
        let tmp = tempdir().unwrap();
        let writer = PartitionedWalWriter::new(tmp.path().to_path_buf(), 100).unwrap();
        for i in 0..10 {
            let key = format!("k{i}");
            writer
                .append(insert_record("users", key.as_bytes(), b"v"))
                .unwrap();
        }
        // Explicit flush drains the buffer and creates the WAL file.
        assert_eq!(writer.flush_all().unwrap(), 10);
        assert_eq!(writer.buffered_count(), 0);
        assert!(tmp.path().join("users.wal").exists());
    }

    #[test]
    fn test_partitioned_wal_auto_flush() {
        let tmp = tempdir().unwrap();
        // Threshold of 5: the fifth append triggers an automatic flush.
        let writer = PartitionedWalWriter::new(tmp.path().to_path_buf(), 5).unwrap();
        for i in 0..5 {
            let key = format!("k{i}");
            writer
                .append(insert_record("users", key.as_bytes(), b"v"))
                .unwrap();
        }
        assert_eq!(writer.buffered_count(), 0);
        assert!(tmp.path().join("users.wal").exists());
    }

    #[test]
    fn test_parallel_checkpoint() {
        let tmp = tempdir().unwrap();
        let writer = PartitionedWalWriter::new(tmp.path().to_path_buf(), 100).unwrap();
        writer.append(insert_record("users", b"k1", b"v1")).unwrap();
        writer.append(insert_record("orders", b"k2", b"v2")).unwrap();
        writer.flush_all().unwrap();

        // Both flushed WAL files should be promoted to checkpoints.
        let mgr = ParallelCheckpointManager::new(tmp.path().to_path_buf());
        let tables = ["users".to_string(), "orders".to_string()];
        assert_eq!(mgr.checkpoint_tables(&tables).unwrap(), 2);
        assert!(tmp.path().join("users.checkpoint").exists());
        assert!(tmp.path().join("orders.checkpoint").exists());
    }
}