use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use parking_lot::{Mutex, RwLock};
use serde::{Deserialize, Serialize};
use crate::hlc::HybridLogicalClock;
use sochdb_core::{Result, SochDBError};
/// Log sequence number. `CheckpointManager::next_lsn` hands these out from a
/// counter that only ever increments by one.
pub type Lsn = u64;
/// Identifier of a page whose dirty state is tracked by `DirtyPageTracker`.
pub type PageId = u64;
/// Knobs deciding when `CheckpointManager` considers a checkpoint due and what
/// it does once one has been taken.
#[derive(Debug, Clone)]
pub struct CheckpointConfig {
    /// WAL bytes reported since the last checkpoint that make a new one due.
    pub max_wal_size: u64,
    /// Time since the last checkpoint that makes a new one due.
    pub max_interval: Duration,
    /// WAL records reported since the last checkpoint that make a new one due.
    pub min_records: u64,
    /// Whether the WAL-truncation step runs after each checkpoint.
    pub truncate_wal: bool,
    /// Master switch: when `false`, no checkpoint is ever reported as due.
    pub enabled: bool,
}

impl Default for CheckpointConfig {
    /// 64 MiB of WAL, 60 seconds, or 100 000 records (whichever is reached
    /// first), with truncation and checkpointing both switched on.
    fn default() -> Self {
        const MIB: u64 = 1 << 20;
        let max_wal_size = 64 * MIB;
        let max_interval = Duration::from_secs(60);
        Self {
            enabled: true,
            truncate_wal: true,
            min_records: 100_000,
            max_interval,
            max_wal_size,
        }
    }
}
/// One in-flight transaction as captured by a checkpoint.
///
/// Serialized (via `CheckpointMeta`) into `checkpoint.meta` with bincode, which
/// encodes fields in declaration order — do not reorder fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ActiveTransactionEntry {
/// Transaction id.
pub txn_id: u64,
/// First LSN the transaction wrote; a lower bound on the checkpoint's
/// `oldest_required_lsn`.
pub first_lsn: Lsn,
/// Most recent LSN the transaction wrote.
pub last_lsn: Lsn,
/// Start timestamp supplied to `ActiveTransactionTracker::register`.
pub start_ts: u64,
}
/// One dirty page as captured by a checkpoint.
///
/// Serialized (via `CheckpointMeta`) into `checkpoint.meta` with bincode, which
/// encodes fields in declaration order — do not reorder fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DirtyPageEntry {
/// Page id.
pub page_id: PageId,
/// LSN recorded when the page was first marked dirty; a lower bound on the
/// checkpoint's `oldest_required_lsn`.
pub recovery_lsn: Lsn,
}
/// Snapshot produced by `CheckpointManager::checkpoint` and persisted inside
/// `CheckpointMeta`.
///
/// bincode encodes fields in declaration order — do not reorder fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CheckpointData {
/// Sequential id; `CheckpointManager::new` resumes numbering after the last
/// persisted checkpoint.
pub checkpoint_id: u64,
/// LSN allocated when the checkpoint began.
pub begin_checkpoint_lsn: Lsn,
/// LSN allocated after the dirty pages were flushed; 0 until the checkpoint
/// completes.
pub end_checkpoint_lsn: Lsn,
/// Active transactions captured at checkpoint start.
pub active_transactions: Vec<ActiveTransactionEntry>,
/// Dirty pages captured at checkpoint start (and handed to the flush callback).
pub dirty_pages: Vec<DirtyPageEntry>,
/// Creation time in microseconds since the Unix epoch.
pub timestamp: u64,
/// Minimum of `begin_checkpoint_lsn` and every `first_lsn` / `recovery_lsn`
/// above; returned by `CheckpointManager::recovery_lsn`.
pub oldest_required_lsn: Lsn,
}
impl CheckpointData {
    /// Builds the record for checkpoint `checkpoint_id`, begun at `begin_lsn`.
    ///
    /// `oldest_required_lsn` is the smallest of `begin_lsn`, every active
    /// transaction's `first_lsn`, and every dirty page's `recovery_lsn`.
    /// `end_checkpoint_lsn` is left at 0 for the caller to fill in once the
    /// checkpoint completes. `timestamp` is the current wall-clock time in
    /// microseconds since the Unix epoch, or 0 if the system clock reads
    /// earlier than the epoch.
    pub fn new(
        checkpoint_id: u64,
        begin_lsn: Lsn,
        active_txns: Vec<ActiveTransactionEntry>,
        dirty_pages: Vec<DirtyPageEntry>,
    ) -> Self {
        let oldest_required_lsn = active_txns
            .iter()
            .map(|t| t.first_lsn)
            .chain(dirty_pages.iter().map(|p| p.recovery_lsn))
            .fold(begin_lsn, Lsn::min);
        // A clock set before 1970 must not abort a checkpoint with a panic;
        // the timestamp is informational, so fall back to 0.
        let timestamp = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_micros() as u64;
        Self {
            checkpoint_id,
            begin_checkpoint_lsn: begin_lsn,
            end_checkpoint_lsn: 0,
            active_transactions: active_txns,
            dirty_pages,
            timestamp,
            oldest_required_lsn,
        }
    }
}
/// Checkpoint bookkeeping persisted (bincode-encoded) in `checkpoint.meta`.
///
/// bincode encodes fields in declaration order, so field order is part of the
/// on-disk format — do not reorder fields. `Default` is the empty state used
/// when no metadata file exists yet.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CheckpointMeta {
    /// Most recently completed checkpoint, if any.
    pub last_checkpoint: Option<CheckpointData>,
    /// Number of checkpoints completed.
    pub total_checkpoints: u64,
    /// Running total accumulated by the WAL-truncation step.
    pub total_bytes_truncated: u64,
}
/// Tracks which pages are dirty and the LSN recorded when each became dirty.
pub struct DirtyPageTracker {
    /// page id -> recovery LSN (the LSN from the first `mark_dirty` call).
    dirty_pages: RwLock<HashMap<PageId, Lsn>>,
}

impl DirtyPageTracker {
    /// Creates a tracker with no dirty pages.
    pub fn new() -> Self {
        let dirty_pages = RwLock::new(HashMap::new());
        Self { dirty_pages }
    }

    /// Records that `page_id` was modified at `lsn`.
    ///
    /// If the page is already tracked, the LSN from its first `mark_dirty` call
    /// is kept and `lsn` is ignored.
    pub fn mark_dirty(&self, page_id: PageId, lsn: Lsn) {
        self.dirty_pages.write().entry(page_id).or_insert(lsn);
    }

    /// Stops tracking `page_id`; a no-op if it is not dirty.
    pub fn mark_clean(&self, page_id: PageId) {
        let mut guard = self.dirty_pages.write();
        guard.remove(&page_id);
    }

    /// Returns a snapshot of every tracked page, in unspecified order.
    pub fn get_dirty_pages(&self) -> Vec<DirtyPageEntry> {
        let guard = self.dirty_pages.read();
        let mut entries = Vec::with_capacity(guard.len());
        for (&page_id, &recovery_lsn) in guard.iter() {
            entries.push(DirtyPageEntry { page_id, recovery_lsn });
        }
        entries
    }

    /// Number of pages currently tracked as dirty.
    pub fn dirty_count(&self) -> usize {
        let guard = self.dirty_pages.read();
        guard.len()
    }
}

impl Default for DirtyPageTracker {
    fn default() -> Self {
        Self::new()
    }
}
/// Tracks in-flight transactions and the range of LSNs each has written.
pub struct ActiveTransactionTracker {
    /// txn id -> (first LSN, last LSN, start timestamp). The first LSN stays at
    /// `Lsn::MAX` until the transaction's first `update_lsn`.
    active_txns: RwLock<HashMap<u64, (Lsn, Lsn, u64)>>,
}

impl ActiveTransactionTracker {
    /// Placeholder first LSN for a transaction that has not logged anything yet.
    const UNSET_LSN: Lsn = Lsn::MAX;

    /// Creates a tracker with no transactions.
    pub fn new() -> Self {
        Self {
            active_txns: RwLock::new(HashMap::new()),
        }
    }

    /// Starts tracking `txn_id`, replacing any existing entry for that id.
    pub fn register(&self, txn_id: u64, start_ts: u64) {
        let fresh = (Self::UNSET_LSN, 0, start_ts);
        self.active_txns.write().insert(txn_id, fresh);
    }

    /// Notes that `txn_id` logged a record at `lsn`; unknown ids are ignored.
    ///
    /// The first call fixes the transaction's first LSN; every call moves its
    /// last LSN to `lsn`.
    pub fn update_lsn(&self, txn_id: u64, lsn: Lsn) {
        let mut txns = self.active_txns.write();
        if let Some((first_lsn, last_lsn, _)) = txns.get_mut(&txn_id) {
            if *first_lsn == Self::UNSET_LSN {
                *first_lsn = lsn;
            }
            *last_lsn = lsn;
        }
    }

    /// Stops tracking `txn_id`; a no-op if it is unknown.
    pub fn remove(&self, txn_id: u64) {
        let mut txns = self.active_txns.write();
        txns.remove(&txn_id);
    }

    /// Snapshot of tracked transactions that have logged at least one record,
    /// in unspecified order.
    pub fn get_active_transactions(&self) -> Vec<ActiveTransactionEntry> {
        let txns = self.active_txns.read();
        let mut out = Vec::new();
        for (&txn_id, &(first_lsn, last_lsn, start_ts)) in txns.iter() {
            if first_lsn != Self::UNSET_LSN {
                out.push(ActiveTransactionEntry {
                    txn_id,
                    first_lsn,
                    last_lsn,
                    start_ts,
                });
            }
        }
        out
    }

    /// Number of tracked transactions, including those that have not logged
    /// anything yet (which `get_active_transactions` leaves out).
    pub fn active_count(&self) -> usize {
        let txns = self.active_txns.read();
        txns.len()
    }
}

impl Default for ActiveTransactionTracker {
    fn default() -> Self {
        Self::new()
    }
}
/// Decides when a checkpoint is due, takes checkpoints, and persists checkpoint
/// metadata to `<data_dir>/checkpoint.meta`.
pub struct CheckpointManager {
/// Trigger thresholds and switches.
config: CheckpointConfig,
/// Path of the persisted `CheckpointMeta` (`<data_dir>/checkpoint.meta`).
meta_path: PathBuf,
/// `<data_dir>/wal`, created by `new`.
// NOTE(review): stored but never read in this file, hence the allow; presumably
// reserved for real WAL segment truncation — confirm.
#[allow(dead_code)]
wal_dir: PathBuf,
/// In-memory checkpoint metadata; written to `meta_path` by `checkpoint`.
meta: RwLock<CheckpointMeta>,
/// Shared dirty-page table, snapshotted and cleaned by each checkpoint.
dirty_pages: Arc<DirtyPageTracker>,
/// Shared active-transaction table, snapshotted by each checkpoint.
active_txns: Arc<ActiveTransactionTracker>,
/// Next LSN to hand out; `next_lsn` returns it and increments it.
current_lsn: AtomicU64,
/// WAL records reported via `record_wal_write` since the last checkpoint.
records_since_checkpoint: AtomicU64,
/// WAL bytes reported via `record_wal_write` since the last checkpoint.
wal_bytes_since_checkpoint: AtomicU64,
/// When the last checkpoint completed (initially: when the manager was created).
last_checkpoint_time: Mutex<Instant>,
/// Set while a checkpoint runs; makes concurrent `checkpoint` calls fail fast.
checkpoint_in_progress: AtomicBool,
/// Id to assign to the next checkpoint.
next_checkpoint_id: AtomicU64,
// NOTE(review): stored but never read in this file, hence the allow.
#[allow(dead_code)]
hlc: Arc<HybridLogicalClock>,
}
impl CheckpointManager {
pub fn new(
data_dir: &Path,
config: CheckpointConfig,
dirty_pages: Arc<DirtyPageTracker>,
active_txns: Arc<ActiveTransactionTracker>,
hlc: Arc<HybridLogicalClock>,
) -> Result<Self> {
let meta_path = data_dir.join("checkpoint.meta");
let wal_dir = data_dir.join("wal");
fs::create_dir_all(&wal_dir)?;
let meta = if meta_path.exists() {
let data = fs::read(&meta_path)?;
bincode::deserialize(&data).unwrap_or_default()
} else {
CheckpointMeta::default()
};
let next_id = meta.last_checkpoint.as_ref().map(|c| c.checkpoint_id + 1).unwrap_or(1);
let last_lsn = meta.last_checkpoint.as_ref().map(|c| c.end_checkpoint_lsn).unwrap_or(0);
Ok(Self {
config,
meta_path,
wal_dir,
meta: RwLock::new(meta),
dirty_pages,
active_txns,
current_lsn: AtomicU64::new(last_lsn),
records_since_checkpoint: AtomicU64::new(0),
wal_bytes_since_checkpoint: AtomicU64::new(0),
last_checkpoint_time: Mutex::new(Instant::now()),
checkpoint_in_progress: AtomicBool::new(false),
next_checkpoint_id: AtomicU64::new(next_id),
hlc,
})
}
#[inline]
pub fn next_lsn(&self) -> Lsn {
self.current_lsn.fetch_add(1, Ordering::SeqCst)
}
pub fn record_wal_write(&self, bytes: u64) {
self.records_since_checkpoint.fetch_add(1, Ordering::Relaxed);
self.wal_bytes_since_checkpoint.fetch_add(bytes, Ordering::Relaxed);
}
pub fn should_checkpoint(&self) -> bool {
if !self.config.enabled {
return false;
}
if self.checkpoint_in_progress.load(Ordering::Relaxed) {
return false;
}
let records = self.records_since_checkpoint.load(Ordering::Relaxed);
let bytes = self.wal_bytes_since_checkpoint.load(Ordering::Relaxed);
let elapsed = self.last_checkpoint_time.lock().elapsed();
records >= self.config.min_records
|| bytes >= self.config.max_wal_size
|| elapsed >= self.config.max_interval
}
pub fn checkpoint<F>(&self, flush_dirty_pages: F) -> Result<CheckpointData>
where
F: FnOnce(&[DirtyPageEntry]) -> Result<()>,
{
if self
.checkpoint_in_progress
.compare_exchange(false, true, Ordering::SeqCst, Ordering::Relaxed)
.is_err()
{
return Err(SochDBError::Internal("Checkpoint already in progress".into()));
}
struct CheckpointGuard<'a>(&'a AtomicBool);
impl<'a> Drop for CheckpointGuard<'a> {
fn drop(&mut self) {
self.0.store(false, Ordering::SeqCst);
}
}
let _guard = CheckpointGuard(&self.checkpoint_in_progress);
let checkpoint_id = self.next_checkpoint_id.fetch_add(1, Ordering::SeqCst);
let begin_lsn = self.next_lsn();
let active_txns = self.active_txns.get_active_transactions();
let dirty_pages = self.dirty_pages.get_dirty_pages();
let mut checkpoint = CheckpointData::new(checkpoint_id, begin_lsn, active_txns, dirty_pages.clone());
flush_dirty_pages(&dirty_pages)?;
for page in &dirty_pages {
self.dirty_pages.mark_clean(page.page_id);
}
let end_lsn = self.next_lsn();
checkpoint.end_checkpoint_lsn = end_lsn;
{
let mut meta = self.meta.write();
meta.last_checkpoint = Some(checkpoint.clone());
meta.total_checkpoints += 1;
let data = bincode::serialize(&*meta).map_err(|e| SochDBError::Serialization(e.to_string()))?;
fs::write(&self.meta_path, data)?;
}
self.records_since_checkpoint.store(0, Ordering::Relaxed);
self.wal_bytes_since_checkpoint.store(0, Ordering::Relaxed);
*self.last_checkpoint_time.lock() = Instant::now();
if self.config.truncate_wal {
self.truncate_wal(checkpoint.oldest_required_lsn)?;
}
Ok(checkpoint)
}
fn truncate_wal(&self, safe_lsn: Lsn) -> Result<()> {
let mut meta = self.meta.write();
if let Some(ref checkpoint) = meta.last_checkpoint {
let truncated = checkpoint.begin_checkpoint_lsn.saturating_sub(safe_lsn);
meta.total_bytes_truncated += truncated;
}
Ok(())
}
pub fn recovery_lsn(&self) -> Option<Lsn> {
self.meta
.read()
.last_checkpoint
.as_ref()
.map(|c| c.oldest_required_lsn)
}
pub fn last_checkpoint(&self) -> Option<CheckpointData> {
self.meta.read().last_checkpoint.clone()
}
pub fn stats(&self) -> CheckpointStats {
let meta = self.meta.read();
CheckpointStats {
total_checkpoints: meta.total_checkpoints,
total_bytes_truncated: meta.total_bytes_truncated,
records_since_checkpoint: self.records_since_checkpoint.load(Ordering::Relaxed),
wal_bytes_since_checkpoint: self.wal_bytes_since_checkpoint.load(Ordering::Relaxed),
dirty_pages: self.dirty_pages.dirty_count(),
active_transactions: self.active_txns.active_count(),
}
}
}
/// Point-in-time counters returned by `CheckpointManager::stats`.
#[derive(Debug, Clone)]
pub struct CheckpointStats {
/// Checkpoints completed, from the checkpoint metadata.
pub total_checkpoints: u64,
/// Running total maintained by the WAL-truncation step.
pub total_bytes_truncated: u64,
/// WAL records reported since the last checkpoint.
pub records_since_checkpoint: u64,
/// WAL bytes reported since the last checkpoint.
pub wal_bytes_since_checkpoint: u64,
/// Pages currently tracked as dirty.
pub dirty_pages: usize,
/// Tracked transactions, including ones that have not logged a record yet.
pub active_transactions: usize,
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Shorthand for an `ActiveTransactionEntry` fixture.
    fn txn(txn_id: u64, first_lsn: Lsn, last_lsn: Lsn, start_ts: u64) -> ActiveTransactionEntry {
        ActiveTransactionEntry {
            txn_id,
            first_lsn,
            last_lsn,
            start_ts,
        }
    }

    /// Shorthand for a `DirtyPageEntry` fixture.
    fn page(page_id: PageId, recovery_lsn: Lsn) -> DirtyPageEntry {
        DirtyPageEntry {
            page_id,
            recovery_lsn,
        }
    }

    #[test]
    fn test_checkpoint_data_creation() {
        let txns = vec![txn(1, 100, 150, 1000), txn(2, 120, 180, 1100)];
        let pages = vec![page(10, 90), page(20, 110)];
        let checkpoint = CheckpointData::new(1, 200, txns, pages);
        // Page 10's recovery LSN (90) is below every txn's first LSN and the begin LSN.
        assert_eq!(checkpoint.oldest_required_lsn, 90);
    }

    #[test]
    fn test_dirty_page_tracker() {
        let tracker = DirtyPageTracker::new();
        for &(page_id, lsn) in &[(1, 100), (2, 110), (1, 120)] {
            tracker.mark_dirty(page_id, lsn);
        }
        assert_eq!(tracker.dirty_count(), 2);

        let snapshot = tracker.get_dirty_pages();
        assert_eq!(snapshot.len(), 2);
        // Re-dirtying page 1 at LSN 120 must not move its recovery LSN.
        let first = snapshot
            .iter()
            .find(|p| p.page_id == 1)
            .expect("page 1 should be dirty");
        assert_eq!(first.recovery_lsn, 100);

        tracker.mark_clean(1);
        assert_eq!(tracker.dirty_count(), 1);
    }

    #[test]
    fn test_active_transaction_tracker() {
        let tracker = ActiveTransactionTracker::new();
        tracker.register(1, 1000);
        tracker.update_lsn(1, 100);
        tracker.update_lsn(1, 150);
        tracker.register(2, 1100);
        tracker.update_lsn(2, 120);
        assert_eq!(tracker.active_count(), 2);

        let active = tracker.get_active_transactions();
        assert_eq!(active.len(), 2);
        let t1 = active
            .iter()
            .find(|t| t.txn_id == 1)
            .expect("txn 1 should be active");
        assert_eq!((t1.first_lsn, t1.last_lsn), (100, 150));

        tracker.remove(1);
        assert_eq!(tracker.active_count(), 1);
    }

    #[test]
    fn test_checkpoint_manager() -> Result<()> {
        let dir = TempDir::new().unwrap();
        let pages = Arc::new(DirtyPageTracker::new());
        let txns = Arc::new(ActiveTransactionTracker::new());
        let manager = CheckpointManager::new(
            dir.path(),
            CheckpointConfig::default(),
            Arc::clone(&pages),
            Arc::clone(&txns),
            Arc::new(HybridLogicalClock::new()),
        )?;

        for page_id in 1..=2 {
            pages.mark_dirty(page_id, manager.next_lsn());
        }
        txns.register(100, 1000);
        txns.update_lsn(100, manager.next_lsn());

        let checkpoint = manager.checkpoint(|_| Ok(()))?;
        assert_eq!(checkpoint.checkpoint_id, 1);
        assert_eq!(checkpoint.dirty_pages.len(), 2);
        assert_eq!(checkpoint.active_transactions.len(), 1);
        Ok(())
    }
}