use anyhow::{Context, Result};
use std::collections::HashMap;
use std::fs::{File, OpenOptions};
use std::io::{BufReader, BufWriter, Read, Write};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex, RwLock};
use tracing::{debug, info, warn};
use super::database::{Database, Tree};
/// A single write-ahead-log record: one mutation against a named tree.
///
/// Wire format (lengths are little-endian `u32` prefixes):
///   Put:    0x01 | tree_len | tree | key_len | key | value_len | value
///   Delete: 0x02 | tree_len | tree | key_len | key
#[derive(Debug, Clone)]
enum WalOp {
    /// Insert or overwrite `key` with `value` in the tree named `tree`.
    Put {
        tree: String,
        key: Vec<u8>,
        value: Vec<u8>,
    },
    /// Remove `key` from the tree named `tree`.
    Delete {
        tree: String,
        key: Vec<u8>,
    },
}

impl WalOp {
    /// Encodes this operation into the wire format described on [`WalOp`].
    fn serialize(&self) -> Vec<u8> {
        // Appends a u32 little-endian length prefix followed by the bytes.
        fn put_field(buf: &mut Vec<u8>, bytes: &[u8]) {
            buf.extend_from_slice(&(bytes.len() as u32).to_le_bytes());
            buf.extend_from_slice(bytes);
        }
        let mut buf = Vec::new();
        match self {
            WalOp::Put { tree, key, value } => {
                buf.push(0x01);
                put_field(&mut buf, tree.as_bytes());
                put_field(&mut buf, key);
                put_field(&mut buf, value);
            }
            WalOp::Delete { tree, key } => {
                buf.push(0x02);
                put_field(&mut buf, tree.as_bytes());
                put_field(&mut buf, key);
            }
        }
        buf
    }

    /// Decodes one operation from the front of `data`.
    ///
    /// Returns the operation and the number of bytes consumed, or an error if
    /// the entry is empty, truncated, carries a non-UTF-8 tree name, or starts
    /// with an unknown opcode.
    fn deserialize(data: &[u8]) -> Result<(WalOp, usize)> {
        // Reads one length-prefixed field at `*pos`, advancing `*pos` past it.
        fn read_field<'a>(data: &'a [u8], pos: &mut usize) -> Result<&'a [u8]> {
            let p = *pos;
            if p + 4 > data.len() {
                anyhow::bail!("Truncated WAL entry");
            }
            let len =
                u32::from_le_bytes([data[p], data[p + 1], data[p + 2], data[p + 3]]) as usize;
            let start = p + 4;
            if start + len > data.len() {
                anyhow::bail!("Truncated WAL entry");
            }
            *pos = start + len;
            Ok(&data[start..start + len])
        }

        if data.is_empty() {
            anyhow::bail!("Empty WAL entry");
        }
        let mut pos = 0;
        let op_type = data[pos];
        pos += 1;
        // Tree name and key are present for every opcode.
        let tree = String::from_utf8(read_field(data, &mut pos)?.to_vec())?;
        let key = read_field(data, &mut pos)?.to_vec();
        match op_type {
            0x01 => {
                let value = read_field(data, &mut pos)?.to_vec();
                Ok((WalOp::Put { tree, key, value }, pos))
            }
            0x02 => Ok((WalOp::Delete { tree, key }, pos)),
            _ => anyhow::bail!("Unknown WAL operation type: {}", op_type),
        }
    }
}
/// Tuning parameters for the WAL-buffered database wrapper.
#[derive(Debug, Clone)]
pub struct WalConfig {
    /// Flush once this many blocks have been processed since the last flush.
    pub flush_interval_blocks: u64,
    /// Flush once this many operations are buffered in memory.
    pub max_buffered_ops: usize,
    /// Flush the WAL writer after every logged operation.
    pub sync_wal: bool,
}

impl Default for WalConfig {
    /// Throughput-oriented defaults: flush every 1000 blocks or 100k buffered
    /// operations, without per-write WAL flushing.
    fn default() -> Self {
        WalConfig {
            flush_interval_blocks: 1_000,
            max_buffered_ops: 100_000,
            sync_wal: false,
        }
    }
}
/// Database wrapper that stages writes in an in-memory buffer backed by a
/// write-ahead log, flushing to the wrapped database in per-tree batches.
pub struct WalBufferedDb {
    /// The wrapped persistent database.
    inner: Arc<dyn Database>,
    /// Location of the WAL file ("ibd.wal" inside the configured directory).
    wal_path: PathBuf,
    /// Buffered WAL writer; `None` only during construction.
    wal_file: Mutex<Option<BufWriter<File>>>,
    /// Staged mutations: tree name -> key -> `Some(value)` for a put,
    /// `None` for a buffered delete (tombstone).
    buffer: RwLock<HashMap<String, HashMap<Vec<u8>, Option<Vec<u8>>>>>,
    /// Count of operations currently staged in `buffer`.
    buffered_ops: Mutex<usize>,
    /// Most recently reported block height (see `block_processed`).
    current_height: Mutex<u64>,
    /// Block height recorded at the time of the last flush.
    last_flush_height: Mutex<u64>,
    /// Flush thresholds and WAL sync behavior.
    config: WalConfig,
}
impl WalBufferedDb {
    /// Creates the wrapper, replaying (and then truncating) any WAL file left
    /// over from a previous run before opening a fresh WAL for writing.
    pub fn new(inner: Arc<dyn Database>, wal_dir: &Path, config: WalConfig) -> Result<Self> {
        std::fs::create_dir_all(wal_dir)?;
        let wal_path = wal_dir.join("ibd.wal");
        let db = Self {
            inner,
            wal_path: wal_path.clone(),
            wal_file: Mutex::new(None),
            buffer: RwLock::new(HashMap::new()),
            buffered_ops: Mutex::new(0),
            current_height: Mutex::new(0),
            last_flush_height: Mutex::new(0),
            config,
        };
        if wal_path.exists() {
            info!("Found existing WAL file, replaying...");
            db.replay_wal()?;
        }
        // Truncate (or create) the WAL only after any previous contents have
        // been replayed into the inner database.
        let wal_file = OpenOptions::new()
            .create(true)
            .write(true)
            .truncate(true)
            .open(&wal_path)
            .context("Failed to open WAL file")?;
        *db.wal_file.lock().unwrap() = Some(BufWriter::new(wal_file));
        Ok(db)
    }

    /// Replays every decodable operation from the on-disk WAL into the inner
    /// database, grouping operations per tree into a single batch.
    ///
    /// Replay stops with a warning at the first undecodable entry; a torn
    /// tail is expected when the previous process died mid-write.
    fn replay_wal(&self) -> Result<()> {
        let file = File::open(&self.wal_path)?;
        let mut reader = BufReader::new(file);
        let mut data = Vec::new();
        reader.read_to_end(&mut data)?;
        if data.is_empty() {
            return Ok(());
        }
        let mut pos = 0;
        let mut ops_count = 0;
        let mut tree_ops: HashMap<String, Vec<WalOp>> = HashMap::new();
        while pos < data.len() {
            match WalOp::deserialize(&data[pos..]) {
                Ok((op, consumed)) => {
                    let tree_name = match &op {
                        WalOp::Put { tree, .. } | WalOp::Delete { tree, .. } => tree.clone(),
                    };
                    tree_ops.entry(tree_name).or_default().push(op);
                    pos += consumed;
                    ops_count += 1;
                }
                Err(e) => {
                    warn!("WAL replay stopped at position {} due to: {}", pos, e);
                    break;
                }
            }
        }
        info!(
            "Replaying {} WAL operations across {} trees",
            ops_count,
            tree_ops.len()
        );
        for (tree_name, ops) in tree_ops {
            let tree = self.inner.open_tree(&tree_name)?;
            let mut batch = tree.batch()?;
            for op in ops {
                match op {
                    WalOp::Put { key, value, .. } => batch.put(&key, &value),
                    WalOp::Delete { key, .. } => batch.delete(&key),
                }
            }
            batch.commit()?;
        }
        // Persist the replayed operations before the caller truncates the WAL;
        // otherwise a crash right after truncation could lose them.
        self.inner.flush()?;
        info!("WAL replay complete");
        Ok(())
    }

    /// Logs a put to the WAL and stages it in the in-memory buffer.
    /// May trigger a flush when a buffering threshold is exceeded.
    pub fn buffered_put(&self, tree_name: &str, key: &[u8], value: &[u8]) -> Result<()> {
        let op = WalOp::Put {
            tree: tree_name.to_string(),
            key: key.to_vec(),
            value: value.to_vec(),
        };
        self.write_wal(&op)?;
        {
            // Scope the guard: maybe_flush() -> flush() re-acquires `buffer`,
            // and std's RwLock is not reentrant, so holding the write guard
            // across maybe_flush() would deadlock once a threshold-triggered
            // flush fires.
            let mut buffer = self.buffer.write().unwrap();
            buffer
                .entry(tree_name.to_string())
                .or_default()
                .insert(key.to_vec(), Some(value.to_vec()));
            *self.buffered_ops.lock().unwrap() += 1;
        }
        self.maybe_flush()
    }

    /// Logs a delete to the WAL and stages a tombstone in the buffer.
    /// May trigger a flush when a buffering threshold is exceeded.
    pub fn buffered_delete(&self, tree_name: &str, key: &[u8]) -> Result<()> {
        let op = WalOp::Delete {
            tree: tree_name.to_string(),
            key: key.to_vec(),
        };
        self.write_wal(&op)?;
        {
            // Same guard-scoping as buffered_put: never hold the buffer write
            // lock while maybe_flush() runs.
            let mut buffer = self.buffer.write().unwrap();
            buffer
                .entry(tree_name.to_string())
                .or_default()
                .insert(key.to_vec(), None);
            *self.buffered_ops.lock().unwrap() += 1;
        }
        self.maybe_flush()
    }

    /// Reads a key, preferring the in-memory buffer (which also shadows
    /// buffered deletes) and falling back to the inner database.
    pub fn buffered_get(&self, tree_name: &str, key: &[u8]) -> Result<Option<Vec<u8>>> {
        {
            let buffer = self.buffer.read().unwrap();
            if let Some(tree_buffer) = buffer.get(tree_name) {
                if let Some(value_opt) = tree_buffer.get(key) {
                    // A buffered `None` is a delete: report the key as absent
                    // even if the inner database still holds it.
                    return Ok(value_opt.clone());
                }
            }
        }
        let tree = self.inner.open_tree(tree_name)?;
        tree.get(key)
    }

    /// Appends one serialized operation to the WAL writer.
    ///
    /// With `sync_wal` set, the userspace buffer is flushed after every write.
    /// NOTE(review): `flush()` hands bytes to the OS but does not fsync; if
    /// `sync_wal` is meant to guarantee on-disk durability, a
    /// `get_ref().sync_data()` would also be needed — confirm the contract.
    fn write_wal(&self, op: &WalOp) -> Result<()> {
        let mut wal_file_guard = self.wal_file.lock().unwrap();
        if let Some(ref mut wal_file) = *wal_file_guard {
            let data = op.serialize();
            wal_file.write_all(&data)?;
            if self.config.sync_wal {
                wal_file.flush()?;
            }
        }
        Ok(())
    }

    /// Records the latest processed block height and flushes if due.
    pub fn block_processed(&self, height: u64) -> Result<()> {
        *self.current_height.lock().unwrap() = height;
        self.maybe_flush()
    }

    /// Flushes when enough blocks have elapsed since the last flush, or when
    /// the buffer has grown past its operation limit.
    fn maybe_flush(&self) -> Result<()> {
        let current = *self.current_height.lock().unwrap();
        let last_flush = *self.last_flush_height.lock().unwrap();
        let ops = *self.buffered_ops.lock().unwrap();
        // saturating_sub: heights are expected to be monotonic, but a height
        // rollback must not cause an unsigned-underflow panic here.
        let should_flush = current.saturating_sub(last_flush) >= self.config.flush_interval_blocks
            || ops >= self.config.max_buffered_ops;
        if should_flush {
            self.flush()?;
        }
        Ok(())
    }

    /// Drains the in-memory buffer into the inner database (one batch per
    /// tree), persists it, then truncates the WAL and resets the counters.
    pub fn flush(&self) -> Result<()> {
        let mut buffer = self.buffer.write().unwrap();
        let ops_count = *self.buffered_ops.lock().unwrap();
        if ops_count == 0 {
            return Ok(());
        }
        let start = std::time::Instant::now();
        debug!("Flushing {} buffered operations to database", ops_count);
        for (tree_name, tree_buffer) in buffer.drain() {
            if tree_buffer.is_empty() {
                continue;
            }
            let tree = self.inner.open_tree(&tree_name)?;
            let mut batch = tree.batch()?;
            for (key, value_opt) in tree_buffer {
                match value_opt {
                    Some(value) => batch.put(&key, &value),
                    None => batch.delete(&key),
                }
            }
            batch.commit()?;
        }
        // Make the batched writes durable before discarding the WAL.
        self.inner.flush()?;
        {
            let mut wal_file_guard = self.wal_file.lock().unwrap();
            if let Some(ref mut wal_file) = *wal_file_guard {
                wal_file.flush()?;
            }
            // Swap in a freshly truncated WAL file: everything previously
            // logged is now persisted in the inner database.
            let new_file = OpenOptions::new()
                .create(true)
                .write(true)
                .truncate(true)
                .open(&self.wal_path)?;
            *wal_file_guard = Some(BufWriter::new(new_file));
        }
        *self.buffered_ops.lock().unwrap() = 0;
        *self.last_flush_height.lock().unwrap() = *self.current_height.lock().unwrap();
        let elapsed = start.elapsed();
        info!(
            "Flushed {} operations in {:?} ({:.0} ops/sec)",
            ops_count,
            elapsed,
            ops_count as f64 / elapsed.as_secs_f64()
        );
        Ok(())
    }

    /// Direct access to the wrapped database.
    pub fn inner(&self) -> &Arc<dyn Database> {
        &self.inner
    }

    /// Returns `(buffered op count, current height, last flush height)`.
    pub fn stats(&self) -> (usize, u64, u64) {
        let ops = *self.buffered_ops.lock().unwrap();
        let current = *self.current_height.lock().unwrap();
        let last_flush = *self.last_flush_height.lock().unwrap();
        (ops, current, last_flush)
    }
}
impl Drop for WalBufferedDb {
    /// Best-effort final flush on shutdown; there is nobody left to propagate
    /// an error to, so a failure is only logged.
    fn drop(&mut self) {
        match self.flush() {
            Ok(()) => {}
            Err(e) => warn!("Failed to flush WAL buffer on shutdown: {}", e),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// In-memory mock database.
    ///
    /// Trees opened under the same name share storage, so data committed
    /// through one handle is visible through later handles — mirroring a real
    /// database. (Previously each `open_tree` call returned a fresh empty map,
    /// which made flushed/replayed data unobservable and left `data` unused.)
    struct MockDb {
        data: Mutex<HashMap<String, Arc<Mutex<HashMap<Vec<u8>, Vec<u8>>>>>>,
    }

    impl MockDb {
        fn new() -> Self {
            Self {
                data: Mutex::new(HashMap::new()),
            }
        }
    }

    impl Database for MockDb {
        fn as_any(&self) -> &dyn std::any::Any {
            self
        }
        fn open_tree(&self, name: &str) -> Result<Box<dyn Tree>> {
            // Hand out a handle onto the shared per-name map.
            let mut data = self.data.lock().unwrap();
            let shared = data
                .entry(name.to_string())
                .or_insert_with(|| Arc::new(Mutex::new(HashMap::new())))
                .clone();
            Ok(Box::new(MockTree {
                name: name.to_string(),
                db: shared,
            }))
        }
        fn flush(&self) -> Result<()> {
            Ok(())
        }
    }

    struct MockTree {
        #[allow(dead_code)]
        name: String,
        db: Arc<Mutex<HashMap<Vec<u8>, Vec<u8>>>>,
    }

    impl Tree for MockTree {
        fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> {
            self.db.lock().unwrap().insert(key.to_vec(), value.to_vec());
            Ok(())
        }
        fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
            Ok(self.db.lock().unwrap().get(key).cloned())
        }
        fn remove(&self, key: &[u8]) -> Result<()> {
            self.db.lock().unwrap().remove(key);
            Ok(())
        }
        fn contains_key(&self, key: &[u8]) -> Result<bool> {
            Ok(self.db.lock().unwrap().contains_key(key))
        }
        fn clear(&self) -> Result<()> {
            self.db.lock().unwrap().clear();
            Ok(())
        }
        fn len(&self) -> Result<usize> {
            Ok(self.db.lock().unwrap().len())
        }
        fn iter(&self) -> Box<dyn Iterator<Item = Result<(Vec<u8>, Vec<u8>)>> + '_> {
            Box::new(std::iter::empty())
        }
        fn batch(&self) -> Result<Box<dyn super::super::database::BatchWriter + '_>> {
            Ok(Box::new(MockBatch {
                ops: Vec::new(),
                db: self.db.clone(),
            }))
        }
    }

    /// Batch writer that applies all staged ops on commit.
    struct MockBatch {
        ops: Vec<(Vec<u8>, Option<Vec<u8>>)>,
        db: Arc<Mutex<HashMap<Vec<u8>, Vec<u8>>>>,
    }

    impl super::super::database::BatchWriter for MockBatch {
        fn put(&mut self, key: &[u8], value: &[u8]) {
            self.ops.push((key.to_vec(), Some(value.to_vec())));
        }
        fn delete(&mut self, key: &[u8]) {
            self.ops.push((key.to_vec(), None));
        }
        fn commit(self: Box<Self>) -> Result<()> {
            let mut db = self.db.lock().unwrap();
            for (key, value) in self.ops {
                match value {
                    Some(v) => db.insert(key, v),
                    None => db.remove(&key),
                };
            }
            Ok(())
        }
        fn len(&self) -> usize {
            self.ops.len()
        }
    }

    #[test]
    fn test_wal_basic_operations() {
        let dir = tempdir().unwrap();
        let mock_db = Arc::new(MockDb::new());
        let config = WalConfig {
            flush_interval_blocks: 10,
            max_buffered_ops: 100,
            sync_wal: false,
        };
        let wal_db = WalBufferedDb::new(mock_db.clone(), dir.path(), config).unwrap();
        wal_db.buffered_put("test", b"key1", b"value1").unwrap();
        wal_db.buffered_put("test", b"key2", b"value2").unwrap();
        // Reads must be served from the buffer before any flush.
        assert_eq!(
            wal_db.buffered_get("test", b"key1").unwrap(),
            Some(b"value1".to_vec())
        );
        wal_db.flush().unwrap();
        let (ops, _, _) = wal_db.stats();
        assert_eq!(ops, 0);
        // After the flush the data must be visible through the wrapper...
        assert_eq!(
            wal_db.buffered_get("test", b"key2").unwrap(),
            Some(b"value2".to_vec())
        );
        // ...and directly through the inner database.
        let inner_tree = mock_db.open_tree("test").unwrap();
        assert_eq!(inner_tree.get(b"key1").unwrap(), Some(b"value1".to_vec()));
    }

    #[test]
    fn test_wal_serialization() {
        let op = WalOp::Put {
            tree: "test".to_string(),
            key: vec![1, 2, 3],
            value: vec![4, 5, 6],
        };
        let serialized = op.serialize();
        let (deserialized, consumed) = WalOp::deserialize(&serialized).unwrap();
        // Round-trip must consume exactly what was produced.
        assert_eq!(consumed, serialized.len());
        match deserialized {
            WalOp::Put { tree, key, value } => {
                assert_eq!(tree, "test");
                assert_eq!(key, vec![1, 2, 3]);
                assert_eq!(value, vec![4, 5, 6]);
            }
            _ => panic!("Wrong operation type"),
        }
    }
}