use std::collections::HashMap;
use crate::constants::*;
use crate::core::pager::FilePager;
use crate::error::{KiteError, Result};
use crate::types::*;
use crate::util::binary::*;
use super::record::{parse_wal_record, ParsedWalRecord, WalRecord};
/// Fraction of the total WAL capacity assigned to the primary region; the
/// remaining capacity forms the secondary region.
const PRIMARY_REGION_RATIO: f64 = 0.75;
/// Write-ahead-log buffer split into a primary and a secondary region.
///
/// All positions (`head`, `tail`, `primary_head`, `secondary_head`) are byte
/// offsets relative to the start of the WAL area; `base_offset` is added to
/// turn them into file offsets. Writes are staged page by page in
/// `pending_writes` until `flush` hands them to the pager.
pub struct WalBuffer {
/// File offset at which the WAL area starts.
base_offset: u64,
/// Total size of the WAL area in bytes (both regions together).
capacity: u64,
/// Head of the active region; mirrors `primary_head` or `secondary_head`
/// after every write or region switch.
head: u64,
/// Position at which scanning of the primary region starts.
tail: u64,
/// Page size in bytes, used to map file offsets onto pager pages.
page_size: usize,
/// Staged page images keyed by the file offset of the page's first byte.
pending_writes: HashMap<u64, Vec<u8>>,
/// Size in bytes of the primary region, which starts at position 0.
primary_region_size: u64,
/// Position at which the secondary region starts (equals `primary_region_size`).
secondary_region_start: u64,
/// Size in bytes of the secondary region.
secondary_region_size: u64,
/// Region receiving new writes: 0 selects the primary region, any other
/// value is treated as the secondary region (this type only ever sets 1).
active_region: u8,
/// Next write position inside the primary region.
primary_head: u64,
/// Next write position inside the secondary region (an absolute buffer
/// position, so an empty secondary region has it at `secondary_region_start`).
secondary_head: u64,
}
impl WalBuffer {
/// Creates an empty buffer covering `capacity` bytes starting at file offset
/// `base_offset`.
///
/// The primary region receives `PRIMARY_REGION_RATIO` of the capacity
/// (rounded down); the secondary region takes the rest and starts right
/// after the primary one. The primary region is active initially.
pub fn new(base_offset: u64, capacity: u64, page_size: usize) -> Self {
    let primary_bytes = (capacity as f64 * PRIMARY_REGION_RATIO) as u64;
    let secondary_bytes = capacity - primary_bytes;
    Self {
        base_offset,
        capacity,
        page_size,
        head: 0,
        tail: 0,
        active_region: 0,
        primary_head: 0,
        primary_region_size: primary_bytes,
        // The secondary region begins where the primary region ends.
        secondary_region_start: primary_bytes,
        secondary_region_size: secondary_bytes,
        secondary_head: primary_bytes,
        pending_writes: HashMap::new(),
    }
}
/// Rebuilds the buffer state from a database header.
///
/// The WAL area's offset and capacity are derived from the header's page
/// numbers and page size. Region heads are taken from the header with two
/// fallbacks:
/// * a stored primary head of 0 is replaced by `wal_head` when that is
///   non-zero;
/// * a stored secondary head of 0 is replaced by the secondary region start,
///   and when the secondary region is active but its head still sits at (or
///   before) the region start while `wal_head` lies inside the secondary
///   region, `wal_head` is used instead.
///
/// NOTE(review): the fallbacks look like compatibility handling for headers
/// that only recorded `wal_head` — confirm against the header format.
pub fn from_header(header: &DbHeaderV1) -> Self {
    let page_bytes = header.page_size as u64;
    let base_offset = header.wal_start_page * page_bytes;
    let capacity = header.wal_page_count * page_bytes;

    let primary_region_size = (capacity as f64 * PRIMARY_REGION_RATIO) as u64;
    let secondary_region_start = primary_region_size;
    let secondary_region_size = capacity - primary_region_size;
    let active_region = header.active_wal_region;

    let primary_head = if header.wal_primary_head == 0 && header.wal_head > 0 {
        header.wal_head
    } else {
        header.wal_primary_head
    };

    let stored_secondary_head = if header.wal_secondary_head == 0 {
        secondary_region_start
    } else {
        header.wal_secondary_head
    };
    let head_in_secondary = header.wal_head >= secondary_region_start;
    let secondary_head = if active_region == 1
        && stored_secondary_head <= secondary_region_start
        && head_in_secondary
    {
        header.wal_head
    } else {
        stored_secondary_head
    };

    Self {
        base_offset,
        capacity,
        page_size: header.page_size as usize,
        head: header.wal_head,
        tail: header.wal_tail,
        active_region,
        primary_head,
        secondary_head,
        primary_region_size,
        secondary_region_start,
        secondary_region_size,
        pending_writes: HashMap::new(),
    }
}
/// File offset at which the WAL area starts.
pub fn base_offset(&self) -> u64 {
    self.base_offset
}

/// Total size of the WAL area in bytes.
pub fn capacity(&self) -> u64 {
    self.capacity
}

/// Current head position of the active region.
pub fn head(&self) -> u64 {
    self.head
}

/// Current tail position.
pub fn tail(&self) -> u64 {
    self.tail
}

/// Next write position in the primary region.
pub fn primary_head(&self) -> u64 {
    self.primary_head
}

/// Next write position in the secondary region.
pub fn secondary_head(&self) -> u64 {
    self.secondary_head
}

/// Identifier of the region currently receiving writes (0 = primary).
pub fn active_region(&self) -> u8 {
    self.active_region
}

/// Size of the primary region in bytes.
pub fn primary_region_size(&self) -> u64 {
    self.primary_region_size
}

/// Size of the secondary region in bytes.
pub fn secondary_region_size(&self) -> u64 {
    self.secondary_region_size
}

/// Returns `true` when `head` and `tail` coincide.
pub fn is_empty(&self) -> bool {
    self.tail == self.head
}
/// Number of bytes in use in the active region.
///
/// For the primary region this is the distance from `tail` to
/// `primary_head`; for the secondary region it is the distance from the
/// region start to `secondary_head`.
///
/// The subtraction saturates at zero: `advance_tail` accepts any value and
/// a header may carry an out-of-range secondary head, so an inconsistent
/// pair must not panic (debug) or wrap to a huge value (release).
pub fn used(&self) -> u64 {
    if self.active_region == 0 {
        self.primary_head.saturating_sub(self.tail)
    } else {
        self.secondary_head
            .saturating_sub(self.secondary_region_start)
    }
}
/// Number of bytes still available for appending in the active region.
///
/// Space is measured from the region's head to the region's end; bytes
/// before `tail` are not counted as free.
///
/// A secondary head that lies below the secondary region start is invalid;
/// in that case 0 is returned so that no write is admitted at that position,
/// instead of underflowing the subtraction.
pub fn free(&self) -> u64 {
    if self.active_region == 0 {
        return self.primary_region_size.saturating_sub(self.primary_head);
    }
    match self.secondary_head.checked_sub(self.secondary_region_start) {
        Some(consumed) => self.secondary_region_size.saturating_sub(consumed),
        None => 0,
    }
}
/// Fraction of the active region that has been consumed by its head.
///
/// For the primary region the numerator is `primary_head` (space before
/// `tail` still counts as consumed); for the secondary region it is the
/// distance of `secondary_head` from the region start.
///
/// Returns `1.0` when the region has zero size or when the secondary head
/// lies below the region start: in both cases nothing can be appended, and
/// the plain division would yield NaN/infinity or underflow.
pub fn usage_ratio(&self) -> f64 {
    let (consumed, region_size) = if self.active_region == 0 {
        (Some(self.primary_head), self.primary_region_size)
    } else {
        (
            self.secondary_head.checked_sub(self.secondary_region_start),
            self.secondary_region_size,
        )
    };
    match consumed {
        Some(bytes) if region_size > 0 => bytes as f64 / region_size as f64,
        _ => 1.0,
    }
}
/// Returns `true` when `size` bytes, rounded up to `WAL_RECORD_ALIGNMENT`,
/// fit into the free space of the active region.
pub fn can_fit(&self, size: usize) -> bool {
    let needed = align_up(size, WAL_RECORD_ALIGNMENT) as u64;
    self.free() >= needed
}
/// Makes the secondary region the active one and points `head` at its
/// current head. Does nothing if the secondary region is already active.
pub fn switch_to_secondary(&mut self) {
    if self.active_region != 1 {
        self.active_region = 1;
        self.head = self.secondary_head;
    }
}
/// Makes the primary region the active one.
///
/// With `reset_primary` set, the primary head and the tail are rewound to 0
/// first. Without it, the call is a no-op when the primary region is
/// already active. In all other cases `head` is set to the primary head.
pub fn switch_to_primary(&mut self, reset_primary: bool) {
    let already_primary = self.active_region == 0;
    if reset_primary {
        self.primary_head = 0;
        self.tail = 0;
    } else if already_primary {
        return;
    }
    self.active_region = 0;
    self.head = self.primary_head;
}
/// Replaces the primary region's contents with the records currently held
/// in the secondary region.
///
/// The secondary records are read first (skipped when the secondary region
/// is empty), then both regions are rewound, the primary region becomes
/// active, and each record is rebuilt and appended to the primary region.
/// Existing primary records are not carried over.
///
/// # Errors
/// Propagates scan/read failures and `KiteError::WalBufferFull` if a record
/// does not fit into the primary region.
pub fn merge_secondary_into_primary(&mut self, pager: &mut FilePager) -> Result<()> {
    let carried_over = if self.secondary_head > self.secondary_region_start {
        self.scan_region(1, pager)?
    } else {
        Vec::new()
    };

    // Rewind both regions; the primary region is refilled below.
    self.head = 0;
    self.tail = 0;
    self.primary_head = 0;
    self.secondary_head = self.secondary_region_start;
    self.active_region = 0;

    for parsed in carried_over {
        let rebuilt = WalRecord::new(parsed.record_type, parsed.txid, parsed.payload).build();
        self.write_record_bytes_to_primary(&rebuilt, pager)?;
    }
    Ok(())
}
/// Consolidates the records of both regions into the primary region.
///
/// Both regions are scanned up front, then the buffer is rewound with the
/// primary region active, and the records are rebuilt and appended to the
/// primary region — primary records first, secondary records after them.
///
/// # Errors
/// Propagates scan/read failures and `KiteError::WalBufferFull` if the
/// combined records do not fit into the primary region.
pub fn recover_incomplete_checkpoint(&mut self, pager: &mut FilePager) -> Result<()> {
    let from_primary = self.scan_region(0, pager)?;
    let from_secondary = self.scan_region(1, pager)?;

    // Rewind both regions before rewriting everything into the primary one.
    self.head = 0;
    self.tail = 0;
    self.primary_head = 0;
    self.secondary_head = self.secondary_region_start;
    self.active_region = 0;

    for parsed in from_primary.into_iter().chain(from_secondary) {
        let rebuilt = WalRecord::new(parsed.record_type, parsed.txid, parsed.payload).build();
        self.write_record_bytes_to_primary(&rebuilt, pager)?;
    }
    Ok(())
}
/// Parses the records stored in one region.
///
/// Region 0 is scanned from `tail` to `primary_head`; any other value scans
/// from the secondary region start to `secondary_head`. Staged (unflushed)
/// pages are honoured because reads go through `read_at_offset`.
///
/// Scanning stops, without error, at the first record whose length field is
/// zero, whose padded length would run past the end of the scanned range,
/// or which `parse_wal_record` rejects. The range check keeps a corrupted
/// length field from triggering an oversized read beyond the region.
///
/// # Errors
/// Propagates page read failures.
pub fn scan_region(&mut self, region: u8, pager: &mut FilePager) -> Result<Vec<ParsedWalRecord>> {
    let (mut pos, end_pos) = if region == 0 {
        (self.tail, self.primary_head)
    } else {
        (self.secondary_region_start, self.secondary_head)
    };

    let mut records = Vec::new();
    while pos < end_pos {
        let file_offset = self.file_offset(pos);
        let header_bytes = self.read_at_offset(file_offset, 8, pager)?;
        let rec_len = read_u32(&header_bytes, 0) as usize;
        if rec_len == 0 {
            break;
        }
        let total_len = rec_len + padding_for(rec_len, WAL_RECORD_ALIGNMENT);
        // A record can never extend past the region's head; a larger length
        // means the header is garbage, so treat it as the end of the log.
        if total_len as u64 > end_pos - pos {
            break;
        }
        let record_bytes = self.read_at_offset(file_offset, total_len, pager)?;
        match parse_wal_record(&record_bytes, 0) {
            Some(record) => {
                records.push(record);
                pos += total_len as u64;
            }
            None => break,
        }
    }
    Ok(records)
}
/// Appends already-built record bytes to the primary region, regardless of
/// which region is active, and returns the new primary head.
///
/// The head advances by the record length rounded up to
/// `WAL_RECORD_ALIGNMENT`; `head` is updated to match.
///
/// # Errors
/// `KiteError::WalBufferFull` if the aligned record would end past the
/// primary region; otherwise any error from staging the write.
fn write_record_bytes_to_primary(
    &mut self,
    record_bytes: &[u8],
    pager: &mut FilePager,
) -> Result<u64> {
    let advance = align_up(record_bytes.len(), WAL_RECORD_ALIGNMENT) as u64;
    let new_head = self.primary_head + advance;
    if new_head > self.primary_region_size {
        return Err(KiteError::WalBufferFull);
    }

    let target = self.file_offset(self.primary_head);
    self.buffer_write(target, record_bytes, pager)?;

    self.primary_head = new_head;
    self.head = new_head;
    Ok(new_head)
}
/// Translates a position inside the WAL area into an absolute file offset.
pub fn file_offset(&self, buffer_pos: u64) -> u64 {
    buffer_pos + self.base_offset
}
/// Reserves `size` bytes (rounded up to `WAL_RECORD_ALIGNMENT`) in the
/// active region without writing anything.
///
/// Returns the position at which the reservation starts, or `None` when the
/// aligned size does not fit. On success the active region's head and
/// `head` advance past the reservation.
pub fn reserve(&mut self, size: usize) -> Option<u64> {
    let needed = align_up(size, WAL_RECORD_ALIGNMENT) as u64;
    if !self.can_fit(needed as usize) {
        return None;
    }

    let start = if self.active_region == 0 {
        if self.primary_head + needed > self.primary_region_size {
            return None;
        }
        let start = self.primary_head;
        self.primary_head = start + needed;
        self.head = self.primary_head;
        start
    } else {
        let start = self.secondary_head;
        self.secondary_head = start + needed;
        self.head = self.secondary_head;
        start
    };
    Some(start)
}
/// Serializes `record` and appends it to the active region, returning the
/// new head position.
///
/// # Errors
/// Same as `write_record_bytes`: `KiteError::WalBufferFull` when the record
/// does not fit, or any error from staging the write.
pub fn write_record(&mut self, record: &WalRecord, pager: &mut FilePager) -> Result<u64> {
    self.write_record_bytes(&record.build(), pager)
}
/// Appends a pre-built batch of record bytes to the active region and
/// returns the new head position.
///
/// An empty batch is a no-op that returns the current head. The batch length
/// must be a multiple of `WAL_RECORD_ALIGNMENT`; the head advances by
/// exactly that length.
///
/// # Errors
/// * `KiteError::Internal` if the batch length is not alignment-sized.
/// * `KiteError::WalBufferFull` if the batch does not fit into the active
///   region.
/// * Any error from staging the write.
pub fn write_record_bytes_batch(
    &mut self,
    record_bytes: &[u8],
    pager: &mut FilePager,
) -> Result<u64> {
    let batch_len = record_bytes.len();
    if batch_len == 0 {
        return Ok(self.head);
    }
    if batch_len % WAL_RECORD_ALIGNMENT != 0 {
        return Err(KiteError::Internal(
            "WAL batch bytes must be alignment-sized".to_string(),
        ));
    }
    if !self.can_fit(batch_len) {
        return Err(KiteError::WalBufferFull);
    }

    let writing_primary = self.active_region == 0;
    // Current head and exclusive end of whichever region is active.
    let (write_pos, region_end) = if writing_primary {
        (self.primary_head, self.primary_region_size)
    } else {
        (
            self.secondary_head,
            self.secondary_region_start + self.secondary_region_size,
        )
    };

    let new_head = write_pos + batch_len as u64;
    if new_head > region_end {
        return Err(KiteError::WalBufferFull);
    }

    let target = self.file_offset(write_pos);
    self.buffer_write(target, record_bytes, pager)?;

    if writing_primary {
        self.primary_head = new_head;
    } else {
        self.secondary_head = new_head;
    }
    self.head = new_head;
    Ok(new_head)
}
/// Appends one record's bytes to the active region and returns the new head.
///
/// The head advances by the record length rounded up to
/// `WAL_RECORD_ALIGNMENT`.
///
/// # Errors
/// `KiteError::WalBufferFull` when the aligned record does not fit;
/// otherwise any error from staging the write.
fn write_record_bytes(&mut self, record_bytes: &[u8], pager: &mut FilePager) -> Result<u64> {
    let aligned = align_up(record_bytes.len(), WAL_RECORD_ALIGNMENT);
    if !self.can_fit(aligned) {
        return Err(KiteError::WalBufferFull);
    }
    let advance = aligned as u64;

    let new_head = if self.active_region == 0 {
        let next = self.primary_head + advance;
        if next > self.primary_region_size {
            return Err(KiteError::WalBufferFull);
        }
        let target = self.file_offset(self.primary_head);
        self.buffer_write(target, record_bytes, pager)?;
        self.primary_head = next;
        next
    } else {
        // The secondary region relies on the `can_fit` check above.
        let target = self.file_offset(self.secondary_head);
        self.buffer_write(target, record_bytes, pager)?;
        self.secondary_head += advance;
        self.secondary_head
    };

    self.head = new_head;
    Ok(new_head)
}
/// Stages `data` at file offset `offset` in the pending page images.
///
/// Every page touched by the write is looked up in `pending_writes`; a page
/// that is not staged yet is first loaded from the pager so the bytes around
/// the write are preserved. Nothing reaches the pager until `flush`.
///
/// An empty `data` slice is a no-op: it touches no page, and the
/// `offset + len - 1` end-of-range computation would underflow for it at
/// offset 0.
///
/// # Errors
/// Propagates failures from `pager.read_page`.
fn buffer_write(&mut self, offset: u64, data: &[u8], pager: &mut FilePager) -> Result<()> {
    if data.is_empty() {
        return Ok(());
    }

    let page_size = self.page_size as u64;
    let write_end = offset + data.len() as u64;
    let first_page = offset / page_size;
    let last_page = (write_end - 1) / page_size;

    let mut remaining = data;
    for page_idx in first_page..=last_page {
        let page_start = page_idx * page_size;
        let page_buffer = match self.pending_writes.entry(page_start) {
            std::collections::hash_map::Entry::Occupied(entry) => entry.into_mut(),
            std::collections::hash_map::Entry::Vacant(entry) => {
                let existing = pager.read_page(page_idx as u32)?;
                entry.insert(existing)
            }
        };

        // Portion of the write that falls inside this page.
        let chunk_start = offset.max(page_start);
        let chunk_end = write_end.min(page_start + page_size);
        let chunk_len = (chunk_end - chunk_start) as usize;
        let in_page = (chunk_start - page_start) as usize;

        let (chunk, rest) = remaining.split_at(chunk_len);
        page_buffer[in_page..in_page + chunk_len].copy_from_slice(chunk);
        remaining = rest;
    }
    Ok(())
}
/// Reads `length` bytes starting at file offset `offset`.
///
/// For every page covered by the range, a staged image in `pending_writes`
/// takes precedence over the pager's copy, so unflushed writes are visible.
///
/// A zero `length` returns an empty vector without touching any page; the
/// `offset + length - 1` end-of-range computation would underflow for it at
/// offset 0. Page bytes are copied straight into the result, without
/// cloning a staged page first.
///
/// # Errors
/// Propagates failures from `pager.read_page`.
fn read_at_offset(&self, offset: u64, length: usize, pager: &mut FilePager) -> Result<Vec<u8>> {
    if length == 0 {
        return Ok(Vec::new());
    }

    let page_size = self.page_size as u64;
    let read_end = offset + length as u64;
    let first_page = offset / page_size;
    let last_page = (read_end - 1) / page_size;

    let mut result = Vec::with_capacity(length);
    for page_idx in first_page..=last_page {
        let page_start = page_idx * page_size;
        // Byte range of this page that overlaps the requested range.
        let from = (offset.max(page_start) - page_start) as usize;
        let to = (read_end.min(page_start + page_size) - page_start) as usize;

        if let Some(pending) = self.pending_writes.get(&page_start) {
            result.extend_from_slice(&pending[from..to]);
        } else {
            let page = pager.read_page(page_idx as u32)?;
            result.extend_from_slice(&page[from..to]);
        }
    }
    Ok(result)
}
/// Hands every staged page to the pager and clears the staging map.
///
/// If a page write fails the error is returned immediately and the staging
/// map is left untouched.
pub fn flush(&mut self, pager: &mut FilePager) -> Result<()> {
    let page_size = self.page_size as u64;
    for (page_file_offset, page) in self.pending_writes.iter() {
        let page_num = (*page_file_offset / page_size) as u32;
        pager.write_page(page_num, page)?;
    }
    self.pending_writes.clear();
    Ok(())
}
/// Flushes all staged pages and then asks the pager to sync.
pub fn sync(&mut self, pager: &mut FilePager) -> Result<()> {
    self.flush(pager)?;
    pager.sync()?;
    Ok(())
}

/// Returns `true` while at least one page is staged and not yet flushed.
pub fn has_pending_writes(&self) -> bool {
    !self.pending_writes.is_empty()
}

/// Moves the tail to `new_tail`. The value is stored as given.
pub fn advance_tail(&mut self, new_tail: u64) {
    self.tail = new_tail;
}

/// Returns the buffer to its initial empty state: both regions rewound, the
/// primary region active, and all staged pages dropped.
pub fn reset(&mut self) {
    self.pending_writes.clear();
    self.active_region = 0;
    self.primary_head = 0;
    self.secondary_head = self.secondary_region_start;
    self.head = 0;
    self.tail = 0;
}

/// Drops all staged pages without writing them; positions are unchanged.
pub fn discard_pending(&mut self) {
    self.pending_writes.clear();
}
/// Parses the records between `tail` and `head` as one linear range.
///
/// Returns an empty list when `head == tail`. Scanning stops, without
/// error, at the first record whose length field is zero, whose padded
/// length would run past `head`, or which `parse_wal_record` rejects. The
/// range check keeps a corrupted length field from triggering an oversized
/// read beyond the scanned range.
///
/// # Errors
/// * `KiteError::InvalidWal` if `head` is behind `tail`.
/// * Page read failures.
pub fn scan_records(&mut self, pager: &mut FilePager) -> Result<Vec<ParsedWalRecord>> {
    let mut records = Vec::new();
    if self.is_empty() {
        return Ok(records);
    }
    if self.head < self.tail {
        return Err(KiteError::InvalidWal(
            "WAL head cannot be behind tail in linear mode".to_string(),
        ));
    }

    let end_pos = self.head;
    let mut pos = self.tail;
    while pos < end_pos {
        let file_offset = self.file_offset(pos);
        let header_bytes = self.read_at_offset(file_offset, 8, pager)?;
        let rec_len = read_u32(&header_bytes, 0) as usize;
        if rec_len == 0 {
            break;
        }
        let total_len = rec_len + padding_for(rec_len, WAL_RECORD_ALIGNMENT);
        // A record can never extend past the head; a larger length means the
        // header is garbage, so treat it as the end of the log.
        if total_len as u64 > end_pos - pos {
            break;
        }
        let record_bytes = self.read_at_offset(file_offset, total_len, pager)?;
        match parse_wal_record(&record_bytes, 0) {
            Some(record) => {
                records.push(record);
                pos += total_len as u64;
            }
            None => break,
        }
    }
    Ok(records)
}
/// Captures a snapshot of the buffer's positions, usage figures and the
/// number of staged pages.
pub fn stats(&self) -> WalBufferStats {
    let used = self.used();
    let free = self.free();
    let pending_pages = self.pending_writes.len();
    WalBufferStats {
        capacity: self.capacity,
        used,
        free,
        head: self.head,
        tail: self.tail,
        pending_pages,
        primary_head: self.primary_head,
        secondary_head: self.secondary_head,
        active_region: self.active_region,
    }
}
/// Collects the records of the primary region followed by those of the
/// secondary region; the secondary region is only scanned when its head has
/// moved past the region start.
pub fn records_for_recovery(&mut self, pager: &mut FilePager) -> Result<Vec<ParsedWalRecord>> {
    let mut recovered = self.scan_region(0, pager)?;
    let secondary_in_use = self.secondary_head > self.secondary_region_start;
    if secondary_in_use {
        recovered.extend(self.scan_region(1, pager)?);
    }
    Ok(recovered)
}
}
/// Point-in-time snapshot of a WAL buffer's positions and usage.
///
/// Derives `PartialEq`/`Eq` so two snapshots can be compared directly, for
/// example to detect whether the buffer changed between two observations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WalBufferStats {
    /// Total size of the WAL area in bytes.
    pub capacity: u64,
    /// Bytes in use in the active region.
    pub used: u64,
    /// Bytes still available in the active region.
    pub free: u64,
    /// Head position of the active region.
    pub head: u64,
    /// Tail position.
    pub tail: u64,
    /// Number of staged pages not yet flushed.
    pub pending_pages: usize,
    /// Next write position in the primary region.
    pub primary_head: u64,
    /// Next write position in the secondary region.
    pub secondary_head: u64,
    /// Region receiving writes (0 = primary).
    pub active_region: u8,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::pager::create_pager;
    use crate::core::wal::record::build_create_node_payload;
    use tempfile::NamedTempFile;

    /// Creates a pager over a fresh temporary file and allocates ten pages.
    /// The temp-file handle is returned alongside the pager so callers can
    /// keep it alive for the duration of the test.
    fn create_test_pager() -> (FilePager, tempfile::NamedTempFile) {
        let backing_file = NamedTempFile::new().expect("expected value");
        let mut pager = create_pager(backing_file.path(), 4096).expect("expected value");
        pager.allocate_pages(10).expect("expected value");
        (pager, backing_file)
    }

    /// A new buffer is empty and reports the requested capacity.
    #[test]
    fn test_wal_buffer_new() {
        let wal = WalBuffer::new(4096, 1024 * 1024, 4096);
        assert!(wal.is_empty());
        assert_eq!(wal.capacity(), 1024 * 1024);
        assert_eq!(wal.used(), 0);
    }

    /// Reservations start at 0 and move forward.
    #[test]
    fn test_wal_buffer_reserve() {
        let mut wal = WalBuffer::new(4096, 1024, 4096);
        let first = wal.reserve(100).expect("expected value");
        assert_eq!(first, 0);
        assert!(!wal.is_empty());
        let second = wal.reserve(100).expect("expected value");
        assert!(second > first);
    }

    /// Once the active region is exhausted, `reserve` returns `None`.
    #[test]
    fn test_wal_buffer_full() {
        let mut wal = WalBuffer::new(4096, 512, 4096);
        for _ in 0..3 {
            wal.reserve(100).expect("expected value");
        }
        assert!(wal.reserve(100).is_none());
    }

    /// `reset` rewinds head and tail to 0.
    #[test]
    fn test_wal_buffer_reset() {
        let mut wal = WalBuffer::new(4096, 1024, 4096);
        wal.reserve(500).expect("expected value");
        assert!(!wal.is_empty());
        wal.reset();
        assert!(wal.is_empty());
        assert_eq!(wal.head(), 0);
        assert_eq!(wal.tail(), 0);
    }

    /// Writing stages pages; flushing clears them.
    #[test]
    fn test_wal_buffer_write_record() {
        let (mut pager, _temp) = create_test_pager();
        let mut wal = WalBuffer::new(4096, 4 * 4096, 4096);
        let payload = build_create_node_payload(100, Some("test_key"));
        let record = WalRecord::new(WalRecordType::CreateNode, 1, payload);
        let head_after = wal
            .write_record(&record, &mut pager)
            .expect("expected value");
        assert!(head_after > 0);
        assert!(wal.has_pending_writes());
        wal.flush(&mut pager).expect("expected value");
        assert!(!wal.has_pending_writes());
    }

    /// Records written in sequence are scanned back in the same order.
    #[test]
    fn test_wal_buffer_write_and_scan() {
        let (mut pager, _temp) = create_test_pager();
        let mut wal = WalBuffer::new(4096, 4 * 4096, 4096);
        for i in 0..5 {
            let payload = build_create_node_payload(100 + i, None);
            let record = WalRecord::new(WalRecordType::CreateNode, i, payload);
            wal.write_record(&record, &mut pager)
                .expect("expected value");
        }
        wal.flush(&mut pager).expect("expected value");
        let scanned = wal.scan_records(&mut pager).expect("expected value");
        assert_eq!(scanned.len(), 5);
        for (record, expected_txid) in scanned.iter().zip(0u64..) {
            assert_eq!(record.txid, expected_txid);
            assert_eq!(record.record_type, WalRecordType::CreateNode);
        }
    }

    /// Stats reflect capacity and non-zero usage after a reservation.
    #[test]
    fn test_wal_buffer_stats() {
        let mut wal = WalBuffer::new(4096, 1024, 4096);
        wal.reserve(100).expect("expected value");
        let snapshot = wal.stats();
        assert_eq!(snapshot.capacity, 1024);
        assert!(snapshot.used > 0);
    }

    /// `discard_pending` drops staged pages.
    #[test]
    fn test_wal_buffer_discard_pending() {
        let (mut pager, _temp) = create_test_pager();
        let mut wal = WalBuffer::new(4096, 4 * 4096, 4096);
        let begin = WalRecord::new(WalRecordType::Begin, 1, Vec::new());
        wal.write_record(&begin, &mut pager)
            .expect("expected value");
        assert!(wal.has_pending_writes());
        wal.discard_pending();
        assert!(!wal.has_pending_writes());
    }

    /// After switching regions, writes land in the secondary region and
    /// leave the primary head untouched.
    #[test]
    fn test_dual_region_switch() {
        let (mut pager, _temp) = create_test_pager();
        let mut wal = WalBuffer::new(4096, 4 * 4096, 4096);
        assert_eq!(wal.active_region(), 0);

        let first = WalRecord::new(WalRecordType::Begin, 1, Vec::new());
        wal.write_record(&first, &mut pager)
            .expect("expected value");
        wal.flush(&mut pager).expect("expected value");
        let primary_head_before = wal.primary_head();
        assert!(primary_head_before > 0);

        wal.switch_to_secondary();
        assert_eq!(wal.active_region(), 1);

        let second = WalRecord::new(WalRecordType::Begin, 2, Vec::new());
        wal.write_record(&second, &mut pager)
            .expect("expected value");
        wal.flush(&mut pager).expect("expected value");
        assert_eq!(wal.primary_head(), primary_head_before);
        assert!(wal.secondary_head() > wal.secondary_region_start);
    }

    /// Merging keeps only the secondary-region record, now in the primary
    /// region.
    #[test]
    fn test_dual_region_merge() {
        let (mut pager, _temp) = create_test_pager();
        let mut wal = WalBuffer::new(4096, 4 * 4096, 4096);

        let in_primary = WalRecord::new(
            WalRecordType::CreateNode,
            1,
            build_create_node_payload(100, Some("node1")),
        );
        wal.write_record(&in_primary, &mut pager)
            .expect("expected value");
        wal.flush(&mut pager).expect("expected value");

        wal.switch_to_secondary();
        let in_secondary = WalRecord::new(
            WalRecordType::CreateNode,
            2,
            build_create_node_payload(101, Some("node2")),
        );
        wal.write_record(&in_secondary, &mut pager)
            .expect("expected value");
        wal.flush(&mut pager).expect("expected value");
        assert!(wal.primary_head() > 0);
        assert!(wal.secondary_head() > wal.secondary_region_start);

        wal.merge_secondary_into_primary(&mut pager)
            .expect("expected value");
        wal.flush(&mut pager).expect("expected value");
        assert_eq!(wal.active_region(), 0);
        assert_eq!(wal.tail(), 0);

        let scanned = wal.scan_records(&mut pager).expect("expected value");
        assert_eq!(scanned.len(), 1);
        assert_eq!(scanned[0].txid, 2);
    }
}