use std::fs::{create_dir_all, File, OpenOptions};
use std::io::{BufReader, BufWriter, Read, Seek, SeekFrom, Write};
use std::path::PathBuf;
use std::sync::{Arc, Mutex};
use std::time::{SystemTime, UNIX_EPOCH};
use super::state::{OperationType, TransactionId};
/// Magic number written at the start of every WAL file header and of every serialized entry.
const WAL_MAGIC: u32 = 0x53594E57;
/// Format version recorded in the WAL file header.
const WAL_VERSION: u16 = 1;
/// Size threshold (64 MiB) past which `PersistentWAL::write_entry` rotates to a new file.
const MAX_WAL_FILE_SIZE: u64 = 64 * 1024 * 1024;
/// Kind of record stored in the WAL.
///
/// The explicit discriminants are the byte values written by `WALEntry::serialize`
/// and accepted by `WALEntry::deserialize`.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum WALEntryType {
/// Encoded as byte `1`.
TransactionBegin = 1,
/// Encoded as byte `2`.
TransactionOperation = 2,
/// Encoded as byte `3`; written by `PersistentWAL::mark_committed`.
TransactionCommit = 3,
/// Encoded as byte `4`; written by `PersistentWAL::mark_rolled_back`.
TransactionRollback = 4,
}
/// A single write-ahead-log record.
#[derive(Debug, Clone)]
pub struct WALEntry {
/// Which kind of record this is.
pub entry_type: WALEntryType,
/// Transaction the record belongs to.
pub transaction_id: TransactionId,
/// Log-wide sequence number; `PersistentWAL::initialize` recovers the maximum seen on disk.
pub global_sequence: u64,
/// Sequence number within the transaction (the commit/rollback helpers pass `0`).
// NOTE(review): presumably a per-transaction operation counter — confirm against callers.
pub txn_sequence: u64,
/// Creation time; serialized as nanoseconds since the Unix epoch.
pub timestamp: SystemTime,
/// Operation that produced the record, if any (`None` is encoded as byte `255`).
pub operation_type: Option<OperationType>,
/// Free-form text; also scanned for keywords by `affects_catalog`.
pub description: String,
}
/// File-backed write-ahead log made of numbered `wal_NNNNNN.log` files inside `wal_dir`.
#[derive(Debug)]
pub struct PersistentWAL {
/// Directory holding the log files (`<db_path>/wal`).
wal_dir: PathBuf,
/// Buffered writer for the active log file; `None` when no file is open.
current_writer: Arc<Mutex<Option<BufWriter<File>>>>,
/// Number of the most recently created log file.
current_file_number: Arc<Mutex<u64>>,
/// Last global sequence number handed out by `next_global_sequence`.
global_sequence: Arc<Mutex<u64>>,
/// Path of the active log file, once one has been created.
current_file_path: Arc<Mutex<Option<PathBuf>>>,
/// Bytes written to the active file so far (header included); drives rotation.
current_file_size: Arc<Mutex<u64>>,
/// Secondary log that receives a copy of every catalog-affecting entry.
catalog_wal: Option<Arc<CatalogWAL>>,
}
/// Separate log for catalog-affecting entries, stored as `catalog_NNNNNN.log` files.
#[derive(Debug)]
pub struct CatalogWAL {
/// Directory holding the catalog log files (`<db_path>/wal/catalog`).
catalog_wal_dir: PathBuf,
/// Buffered writer for the active catalog file; opened lazily on the first write.
writer: Arc<Mutex<Option<BufWriter<File>>>>,
/// Number of the most recently created catalog file (starts at `0`).
file_number: Arc<Mutex<u64>>,
}
impl WALEntry {
pub fn new(
entry_type: WALEntryType,
transaction_id: TransactionId,
global_sequence: u64,
txn_sequence: u64,
operation_type: Option<OperationType>,
description: String,
) -> Self {
Self {
entry_type,
transaction_id,
global_sequence,
txn_sequence,
timestamp: SystemTime::now(),
operation_type,
description,
}
}
/// Reports whether this entry should also be copied to the catalog WAL.
///
/// True for create/drop of tables and graphs; otherwise true when the description
/// contains one of `SCHEMA`, `INDEX`, `CONSTRAINT` or `VIEW` (case-sensitive).
pub fn affects_catalog(&self) -> bool {
    const CATALOG_KEYWORDS: [&str; 4] = ["SCHEMA", "INDEX", "CONSTRAINT", "VIEW"];

    let is_catalog_ddl = matches!(
        self.operation_type,
        Some(OperationType::CreateTable)
            | Some(OperationType::CreateGraph)
            | Some(OperationType::DropTable)
            | Some(OperationType::DropGraph)
    );
    if is_catalog_ddl {
        return true;
    }

    // Fall back to a keyword scan of the free-form description.
    CATALOG_KEYWORDS
        .iter()
        .any(|keyword| self.description.contains(*keyword))
}
/// Encodes the entry into its little-endian on-disk form.
///
/// Layout: magic (4) | entry type (1) | transaction id (8) | global sequence (8) |
/// txn sequence (8) | timestamp, ns since the Unix epoch (8) | operation type (1) |
/// description length (4) | description bytes | CRC32 of all preceding bytes (4).
pub fn serialize(&self) -> Vec<u8> {
    // Everything except the variable-length description.
    const FIXED_LEN: usize = 4 + 1 + 8 + 8 + 8 + 8 + 1 + 4 + 4;

    let desc_bytes = self.description.as_bytes();
    let mut buffer = Vec::with_capacity(FIXED_LEN + desc_bytes.len());

    buffer.extend_from_slice(&WAL_MAGIC.to_le_bytes());
    buffer.push(self.entry_type as u8);
    buffer.extend_from_slice(&self.transaction_id.id().to_le_bytes());
    buffer.extend_from_slice(&self.global_sequence.to_le_bytes());
    buffer.extend_from_slice(&self.txn_sequence.to_le_bytes());

    // A timestamp before the epoch is stored as 0; the u128 nanosecond count is
    // narrowed to the 8 bytes the format reserves for it.
    let timestamp_nanos = self
        .timestamp
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_nanos() as u64;
    buffer.extend_from_slice(&timestamp_nanos.to_le_bytes());

    // Byte codes must stay in sync with the table in `deserialize`.
    let op_type_byte = match &self.operation_type {
        Some(OperationType::Select) => 0,
        Some(OperationType::Match) => 1,
        Some(OperationType::Insert) => 10,
        Some(OperationType::Update) => 11,
        Some(OperationType::Set) => 12,
        Some(OperationType::Delete) => 13,
        Some(OperationType::Remove) => 14,
        Some(OperationType::CreateTable) => 20,
        Some(OperationType::CreateGraph) => 21,
        Some(OperationType::AlterTable) => 22,
        Some(OperationType::DropTable) => 23,
        Some(OperationType::DropGraph) => 24,
        Some(OperationType::CreateUser) => 25,
        Some(OperationType::DropUser) => 26,
        Some(OperationType::CreateRole) => 27,
        Some(OperationType::DropRole) => 28,
        Some(OperationType::GrantRole) => 29,
        Some(OperationType::RevokeRole) => 30,
        Some(OperationType::Begin) => 31,
        Some(OperationType::Commit) => 32,
        Some(OperationType::Rollback) => 33,
        Some(OperationType::Other) => 99,
        None => 255,
    };
    buffer.push(op_type_byte);

    // The length prefix is a u32, so the description length is narrowed here.
    buffer.extend_from_slice(&(desc_bytes.len() as u32).to_le_bytes());
    buffer.extend_from_slice(desc_bytes);

    // The checksum covers every byte written so far, magic included.
    let checksum = crc32fast::hash(&buffer);
    buffer.extend_from_slice(&checksum.to_le_bytes());
    buffer
}
/// Decodes one entry from the start of `data`; bytes after the entry are ignored.
///
/// # Errors
///
/// Returns `WALError::CorruptedEntry` when `data` is shorter than the smallest
/// possible entry, the magic number, entry type or operation type byte is not
/// recognised, the description is truncated or not valid UTF-8, or the stored
/// CRC32 does not match the bytes that precede it.
pub fn deserialize(data: &[u8]) -> Result<Self, WALError> {
    // magic + entry type + txn id + global seq + txn seq + timestamp + op type
    const FIXED_FIELDS_LEN: usize = 4 + 1 + 8 + 8 + 8 + 8 + 1;
    // Smallest valid entry: fixed fields, description length, empty description, checksum.
    const MIN_ENTRY_LEN: usize = FIXED_FIELDS_LEN + 4 + 4;

    /// Reads a little-endian `u32` at `at`; the caller guarantees the bounds.
    fn le_u32(data: &[u8], at: usize) -> u32 {
        let mut raw = [0u8; 4];
        raw.copy_from_slice(&data[at..at + 4]);
        u32::from_le_bytes(raw)
    }

    /// Reads a little-endian `u64` at `at`; the caller guarantees the bounds.
    fn le_u64(data: &[u8], at: usize) -> u64 {
        let mut raw = [0u8; 8];
        raw.copy_from_slice(&data[at..at + 8]);
        u64::from_le_bytes(raw)
    }

    // This single check makes every fixed-offset read below in-bounds.
    if data.len() < MIN_ENTRY_LEN {
        return Err(WALError::CorruptedEntry("Entry too small".to_string()));
    }

    if le_u32(data, 0) != WAL_MAGIC {
        return Err(WALError::CorruptedEntry("Invalid magic number".to_string()));
    }
    let mut offset = 4;

    let entry_type = match data[offset] {
        1 => WALEntryType::TransactionBegin,
        2 => WALEntryType::TransactionOperation,
        3 => WALEntryType::TransactionCommit,
        4 => WALEntryType::TransactionRollback,
        _ => return Err(WALError::CorruptedEntry("Invalid entry type".to_string())),
    };
    offset += 1;

    let transaction_id = TransactionId::from_u64(le_u64(data, offset));
    offset += 8;
    let global_sequence = le_u64(data, offset);
    offset += 8;
    let txn_sequence = le_u64(data, offset);
    offset += 8;
    let timestamp = UNIX_EPOCH + std::time::Duration::from_nanos(le_u64(data, offset));
    offset += 8;

    // Byte codes must stay in sync with the table in `serialize`.
    let operation_type = match data[offset] {
        0 => Some(OperationType::Select),
        1 => Some(OperationType::Match),
        10 => Some(OperationType::Insert),
        11 => Some(OperationType::Update),
        12 => Some(OperationType::Set),
        13 => Some(OperationType::Delete),
        14 => Some(OperationType::Remove),
        20 => Some(OperationType::CreateTable),
        21 => Some(OperationType::CreateGraph),
        22 => Some(OperationType::AlterTable),
        23 => Some(OperationType::DropTable),
        24 => Some(OperationType::DropGraph),
        25 => Some(OperationType::CreateUser),
        26 => Some(OperationType::DropUser),
        27 => Some(OperationType::CreateRole),
        28 => Some(OperationType::DropRole),
        29 => Some(OperationType::GrantRole),
        30 => Some(OperationType::RevokeRole),
        31 => Some(OperationType::Begin),
        32 => Some(OperationType::Commit),
        33 => Some(OperationType::Rollback),
        99 => Some(OperationType::Other),
        255 => None,
        _ => {
            return Err(WALError::CorruptedEntry(
                "Invalid operation type".to_string(),
            ))
        }
    };
    offset += 1;

    let desc_len = le_u32(data, offset) as usize;
    offset += 4;

    // At least 4 bytes remain after `offset` (minimum-length check), so the
    // subtraction cannot underflow and no addition can overflow.
    if desc_len > data.len() - offset - 4 {
        return Err(WALError::CorruptedEntry(
            "Truncated description or checksum".to_string(),
        ));
    }

    let description = String::from_utf8(data[offset..offset + desc_len].to_vec())
        .map_err(|_| WALError::CorruptedEntry("Invalid UTF-8 in description".to_string()))?;
    offset += desc_len;

    // The stored CRC32 covers every byte of the entry that precedes it.
    let expected_checksum = le_u32(data, offset);
    let actual_checksum = crc32fast::hash(&data[..offset]);
    if expected_checksum != actual_checksum {
        return Err(WALError::CorruptedEntry("Checksum mismatch".to_string()));
    }

    Ok(Self {
        entry_type,
        transaction_id,
        global_sequence,
        txn_sequence,
        timestamp,
        operation_type,
        description,
    })
}
}
impl PersistentWAL {
pub fn new(db_path: PathBuf) -> Result<Self, WALError> {
Self::new_with_path(db_path)
}
pub fn new_with_path(db_path: PathBuf) -> Result<Self, WALError> {
let wal_dir = db_path.join("wal");
create_dir_all(&wal_dir)
.map_err(|e| WALError::IOError(format!("Failed to create WAL directory: {}", e)))?;
let catalog_wal_dir = wal_dir.join("catalog");
create_dir_all(&catalog_wal_dir).map_err(|e| {
WALError::IOError(format!("Failed to create catalog WAL directory: {}", e))
})?;
let catalog_wal = CatalogWAL {
catalog_wal_dir,
writer: Arc::new(Mutex::new(None)),
file_number: Arc::new(Mutex::new(0)),
};
let mut wal = Self {
wal_dir,
current_writer: Arc::new(Mutex::new(None)),
current_file_number: Arc::new(Mutex::new(0)),
global_sequence: Arc::new(Mutex::new(0)),
current_file_path: Arc::new(Mutex::new(None)),
current_file_size: Arc::new(Mutex::new(0)),
catalog_wal: Some(Arc::new(catalog_wal)),
};
wal.initialize()?;
Ok(wal)
}
/// Recovers counters from existing `wal_<n>.log` files and opens a fresh file.
///
/// Scans `wal_dir`, remembering the highest file number and the highest global
/// sequence found in any readable file. Directory or file read failures are
/// ignored. Rotation then creates the file numbered one past the highest.
fn initialize(&mut self) -> Result<(), WALError> {
    let mut highest_file = 0u64;
    let mut highest_sequence = 0u64;

    if let Ok(listing) = std::fs::read_dir(&self.wal_dir) {
        // `flatten` skips directory entries that failed to read.
        for dir_entry in listing.flatten() {
            let os_name = dir_entry.file_name();
            let parsed = os_name
                .to_str()
                .and_then(|name| name.strip_prefix("wal_"))
                .and_then(|rest| rest.strip_suffix(".log"))
                .and_then(|digits| digits.parse::<u64>().ok());
            let file_number = match parsed {
                Some(number) => number,
                None => continue,
            };

            highest_file = highest_file.max(file_number);
            if let Ok(recovered) = self.read_wal_file(file_number) {
                highest_sequence = recovered
                    .iter()
                    .map(|record| record.global_sequence)
                    .fold(highest_sequence, u64::max);
            }
        }
    }

    *self.current_file_number.lock().unwrap() = highest_file;
    *self.global_sequence.lock().unwrap() = highest_sequence;
    self.rotate_wal_file()?;
    Ok(())
}
/// Appends `entry` to the active WAL file and syncs it to disk.
///
/// Rotates to a new file first when the entry would push the current file past
/// `MAX_WAL_FILE_SIZE`. Entries for which `affects_catalog` is true are also
/// written to the catalog WAL.
///
/// # Errors
///
/// Returns `WALError::IOError` when no file is open or when writing, flushing
/// or syncing fails; rotation and catalog-write errors are propagated.
pub fn write_entry(&self, entry: WALEntry) -> Result<(), WALError> {
    let bytes = entry.serialize();
    let entry_len = bytes.len() as u64;

    // The size lock is released at the end of this statement, before rotating.
    let size_before = *self.current_file_size.lock().unwrap();
    if size_before + entry_len > MAX_WAL_FILE_SIZE {
        self.rotate_wal_file()?;
    }

    {
        let mut guard = self.current_writer.lock().unwrap();
        let writer = match guard.as_mut() {
            Some(active) => active,
            None => return Err(WALError::IOError("No active WAL file".to_string())),
        };
        writer
            .write_all(&bytes)
            .map_err(|e| WALError::IOError(format!("Failed to write WAL entry: {}", e)))?;
        writer
            .flush()
            .map_err(|e| WALError::IOError(format!("Failed to flush WAL: {}", e)))?;
        writer
            .get_mut()
            .sync_data()
            .map_err(|e| WALError::IOError(format!("Failed to sync WAL: {}", e)))?;
        *self.current_file_size.lock().unwrap() += entry_len;
    }

    if !entry.affects_catalog() {
        return Ok(());
    }
    match &self.catalog_wal {
        Some(catalog) => {
            catalog.write_entry(&bytes)?;
            Ok(())
        }
        None => Ok(()),
    }
}
/// Writes a `TransactionCommit` record for `transaction_id` using the next
/// global sequence number and a transaction sequence of `0`.
#[allow(dead_code)]
pub fn mark_committed(&self, transaction_id: TransactionId) -> Result<(), WALError> {
    let global_sequence = self.next_global_sequence();
    let description = format!("Transaction {} committed", transaction_id.id());
    let commit_record = WALEntry::new(
        WALEntryType::TransactionCommit,
        transaction_id,
        global_sequence,
        0,
        Some(OperationType::Commit),
        description,
    );
    self.write_entry(commit_record)
}
/// Writes a `TransactionRollback` record for `transaction_id` using the next
/// global sequence number and a transaction sequence of `0`.
#[allow(dead_code)]
pub fn mark_rolled_back(&self, transaction_id: TransactionId) -> Result<(), WALError> {
    let global_sequence = self.next_global_sequence();
    let description = format!("Transaction {} rolled back", transaction_id.id());
    let rollback_record = WALEntry::new(
        WALEntryType::TransactionRollback,
        transaction_id,
        global_sequence,
        0,
        Some(OperationType::Rollback),
        description,
    );
    self.write_entry(rollback_record)
}
/// Closes out the active WAL file and switches to the next numbered file.
///
/// The old writer is flushed and synced, then `wal_<n+1>.log` is opened in
/// append mode and the file header is written to it. The file counter, writer,
/// path and size are updated only after all of that succeeds, so a failed
/// rotation leaves the previous writer in place and does not skip a file number.
///
/// # Errors
///
/// Returns `WALError::IOError` when flushing/syncing the old file, or creating
/// or writing the header of the new file, fails.
fn rotate_wal_file(&self) -> Result<(), WALError> {
    // Held for the whole rotation so no writer can observe a missing file.
    let mut writer_guard = self.current_writer.lock().unwrap();
    if let Some(old_writer) = writer_guard.as_mut() {
        old_writer
            .flush()
            .map_err(|e| WALError::IOError(format!("Failed to flush old WAL: {}", e)))?;
        old_writer
            .get_mut()
            .sync_all()
            .map_err(|e| WALError::IOError(format!("Failed to sync old WAL: {}", e)))?;
    }

    let mut file_number = self.current_file_number.lock().unwrap();
    let next_number = *file_number + 1;
    let file_path = self.wal_dir.join(format!("wal_{:06}.log", next_number));

    // `append(true)` implies write access and never truncates.
    let file = OpenOptions::new()
        .create(true)
        .append(true)
        .open(&file_path)
        .map_err(|e| WALError::IOError(format!("Failed to create WAL file: {}", e)))?;
    let mut writer = BufWriter::new(file);

    let header = self.create_file_header()?;
    writer
        .write_all(&header)
        .map_err(|e| WALError::IOError(format!("Failed to write WAL header: {}", e)))?;
    writer
        .flush()
        .map_err(|e| WALError::IOError(format!("Failed to flush WAL header: {}", e)))?;

    // Every fallible step is done; commit the new state.
    *file_number = next_number;
    *writer_guard = Some(writer);
    *self.current_file_path.lock().unwrap() = Some(file_path);
    *self.current_file_size.lock().unwrap() = header.len() as u64;
    Ok(())
}
/// Builds the 64-byte header written at the start of every WAL file.
///
/// Layout: magic (4) | version (2) | creation time, ns since the Unix epoch (8) |
/// 50 zero bytes of padding. `read_wal_file` skips exactly this many bytes.
fn create_file_header(&self) -> Result<Vec<u8>, WALError> {
    let mut header = Vec::with_capacity(64);
    header.extend_from_slice(&WAL_MAGIC.to_le_bytes());
    header.extend_from_slice(&WAL_VERSION.to_le_bytes());

    // A clock reading before the epoch is stored as 0.
    let timestamp = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_nanos() as u64;
    header.extend_from_slice(&timestamp.to_le_bytes());

    // Pad 4 + 2 + 8 = 14 bytes of fields up to 64.
    header.extend_from_slice(&[0u8; 50]);
    Ok(header)
}
/// Increments the global sequence counter and returns the new value.
pub fn next_global_sequence(&self) -> u64 {
    let mut counter = self.global_sequence.lock().unwrap();
    let issued = *counter + 1;
    *counter = issued;
    issued
}
/// Returns the number of the most recently created WAL file.
#[allow(dead_code)]
pub fn current_file_number(&self) -> u64 {
    let guard = self.current_file_number.lock().unwrap();
    *guard
}
/// Flushes buffered bytes of the active WAL file and syncs it to disk.
///
/// Does nothing (and succeeds) when no file is open.
///
/// # Errors
///
/// Returns `WALError::IOError` when the flush or the sync fails.
#[allow(dead_code)]
pub fn flush(&self) -> Result<(), WALError> {
    let mut guard = self.current_writer.lock().unwrap();
    let writer = match guard.as_mut() {
        Some(active) => active,
        None => return Ok(()),
    };
    writer
        .flush()
        .map_err(|e| WALError::IOError(format!("Failed to flush WAL: {}", e)))?;
    writer
        .get_mut()
        .sync_all()
        .map_err(|e| WALError::IOError(format!("Failed to sync WAL: {}", e)))?;
    Ok(())
}
/// Reads every decodable entry from `wal_<file_number>.log`.
///
/// The 64-byte file header is skipped. The remainder is scanned for the entry
/// magic number. A candidate that fails to decode is skipped one byte at a time,
/// so corrupted regions are dropped silently rather than aborting the read.
///
/// # Errors
///
/// Returns `WALError::IOError` when the file does not exist or cannot be opened,
/// seeked or read.
pub fn read_wal_file(&self, file_number: u64) -> Result<Vec<WALEntry>, WALError> {
    // Size of the header produced by `create_file_header`.
    const FILE_HEADER_LEN: u64 = 64;
    // Serialized size of an entry with an empty description:
    // 38 bytes of fixed fields + 4 (description length) + 4 (checksum).
    const MIN_ENTRY_LEN: usize = 46;

    let file_path = self.wal_dir.join(format!("wal_{:06}.log", file_number));
    if !file_path.exists() {
        return Err(WALError::IOError(format!(
            "WAL file not found: {}",
            file_path.display()
        )));
    }

    let mut file = File::open(&file_path)
        .map_err(|e| WALError::IOError(format!("Failed to open WAL file: {}", e)))?;
    file.seek(SeekFrom::Start(FILE_HEADER_LEN))
        .map_err(|e| WALError::IOError(format!("Failed to seek in WAL file: {}", e)))?;

    let mut buffer = Vec::new();
    BufReader::new(file)
        .read_to_end(&mut buffer)
        .map_err(|e| WALError::IOError(format!("Failed to read WAL file: {}", e)))?;

    let magic = WAL_MAGIC.to_le_bytes();
    let mut entries = Vec::new();
    let mut offset = 0;

    // Stop as soon as the tail is too short to hold even the smallest entry.
    while buffer.len() - offset >= MIN_ENTRY_LEN {
        let candidate = &buffer[offset..];
        if !candidate.starts_with(&magic) {
            offset += 1;
            continue;
        }
        match WALEntry::deserialize(candidate) {
            Ok(entry) => {
                // The on-disk size follows from the description length.
                offset += MIN_ENTRY_LEN + entry.description.len();
                entries.push(entry);
            }
            // Not a valid entry here: resume the scan at the next byte.
            Err(_) => offset += 1,
        }
    }
    Ok(entries)
}
}
/// Errors reported by the write-ahead log; each variant carries a message.
#[derive(Debug)]
#[allow(dead_code)]
pub enum WALError {
    /// A filesystem operation failed.
    IOError(String),
    /// Bytes read from a log did not decode into a valid entry.
    CorruptedEntry(String),
    /// Configuration problem (not constructed in this file).
    ConfigError(String),
}

impl std::fmt::Display for WALError {
    /// Renders the variant's label followed by its message.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        use WALError::*;
        match self {
            IOError(detail) => f.write_fmt(format_args!("WAL IO Error: {}", detail)),
            CorruptedEntry(detail) => {
                f.write_fmt(format_args!("WAL Corrupted Entry: {}", detail))
            }
            ConfigError(detail) => f.write_fmt(format_args!("WAL Config Error: {}", detail)),
        }
    }
}

impl std::error::Error for WALError {}
impl CatalogWAL {
/// Appends an already-serialized entry to the catalog log and syncs it to disk.
///
/// The first call opens a catalog file via `rotate_catalog_file`.
///
/// # Errors
///
/// Returns `WALError::IOError` when opening the file, writing, flushing or
/// syncing fails.
pub fn write_entry(&self, serialized: &[u8]) -> Result<(), WALError> {
    // The lock is released at the end of this statement, before rotating.
    let needs_file = self.writer.lock().unwrap().is_none();
    if needs_file {
        self.rotate_catalog_file()?;
    }

    let mut guard = self.writer.lock().unwrap();
    match guard.as_mut() {
        Some(writer) => {
            writer.write_all(serialized).map_err(|e| {
                WALError::IOError(format!("Failed to write catalog WAL entry: {}", e))
            })?;
            writer
                .flush()
                .map_err(|e| WALError::IOError(format!("Failed to flush catalog WAL: {}", e)))?;
            writer
                .get_mut()
                .sync_all()
                .map_err(|e| WALError::IOError(format!("Failed to sync catalog WAL: {}", e)))?;
        }
        None => {}
    }
    Ok(())
}
/// Advances the catalog file counter and opens `catalog_<n>.log` in append mode
/// as the active catalog writer.
///
/// # Errors
///
/// Returns `WALError::IOError` when the file cannot be opened or created.
fn rotate_catalog_file(&self) -> Result<(), WALError> {
    let mut counter = self.file_number.lock().unwrap();
    *counter += 1;

    let file_path = self
        .catalog_wal_dir
        .join(format!("catalog_{:06}.log", *counter));

    let mut options = OpenOptions::new();
    options.create(true).append(true);
    let file = match options.open(&file_path) {
        Ok(handle) => handle,
        Err(e) => {
            return Err(WALError::IOError(format!(
                "Failed to create catalog WAL file: {}",
                e
            )))
        }
    };

    let mut slot = self.writer.lock().unwrap();
    *slot = Some(BufWriter::new(file));
    Ok(())
}
}