use crate::SkipList;
use crate::db::IsolationLevel;
use crate::mem::MemSize;
use crate::transaction::Workspace;
use log::error;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs::{self, File, OpenOptions};
use std::io::{self, BufReader, Seek, SeekFrom, Write};
use std::marker::PhantomData;
use std::path::PathBuf;
use std::sync::{
Arc, Condvar, Mutex,
atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering},
mpsc,
};
use std::thread::JoinHandle;
use std::time::Duration;
use tokio::sync::Notify;
#[cfg(unix)]
use rustix::fs::{FallocateFlags, fallocate};
/// Default capacity of a single WAL segment file: 1 MiB.
pub const WAL_SEGMENT_SIZE_BYTES: u64 = 1024 * 1024;
/// Configuration for on-disk persistence: where the WAL lives and how it is sized.
#[derive(Debug, Clone)]
pub struct PersistenceOptions {
    /// Directory that holds the WAL segment files (`wal.{i}`) and `snapshot.db`.
    pub wal_path: PathBuf,
    /// Number of WAL segment files in the rotation ring.
    pub wal_pool_size: usize,
    /// Capacity of each WAL segment in bytes; a write that would cross this
    /// boundary triggers rotation to the next segment.
    pub wal_segment_size_bytes: u64,
}
impl PersistenceOptions {
    /// Creates options rooted at `wal_path` with the defaults:
    /// a ring of 4 segments, each 1 MiB (same value as `WAL_SEGMENT_SIZE_BYTES`).
    pub fn new(wal_path: PathBuf) -> Self {
        Self {
            wal_path,
            wal_pool_size: 4,
            // 1 MiB; keep in sync with `WAL_SEGMENT_SIZE_BYTES`.
            wal_segment_size_bytes: 1024 * 1024,
        }
    }
}
/// How durable a committed write must be before it is considered safe.
#[derive(Debug, Clone)]
pub enum DurabilityLevel {
    /// No persistence at all: `PersistenceEngine::new` returns `None`.
    InMemory,
    /// Writes land in the WAL without an immediate fsync; a background
    /// flusher thread syncs when any of the configured thresholds fires.
    Relaxed {
        options: PersistenceOptions,
        /// Periodic time-based flush; `None` disables the timer (the flusher
        /// then waits only on explicit flush signals).
        flush_interval_ms: Option<u64>,
        /// Request a flush once this many commits were logged since the last sync.
        flush_after_n_commits: Option<usize>,
        /// Request a flush once this many bytes were logged since the last sync.
        flush_after_m_bytes: Option<u64>,
    },
    /// Every WAL write is fsync'd (`sync_all`) before `log` returns.
    Full {
        options: PersistenceOptions,
    },
}
/// Lifecycle of one WAL segment within the rotation ring.
#[derive(Debug, PartialEq, Eq)]
pub enum WalSegmentState {
    /// Currently receiving new log records.
    Writing,
    /// Full; queued for the snapshotter thread to fold into the snapshot.
    PendingSnapshot,
    /// Snapshotted and truncated; ready to be claimed as the next write target.
    Available,
}
/// One file of the WAL ring plus the shared state used to coordinate the
/// writer (`rotate_wal`) with the background snapshotter thread.
#[derive(Debug)]
struct WalSegment {
    /// The segment file; locked for writes, truncation, and fsync.
    file: Arc<Mutex<File>>,
    /// Current `WalSegmentState`, with a condvar so a writer can block in
    /// `rotate_wal` until the snapshotter marks the segment `Available`.
    state: Arc<(Mutex<WalSegmentState>, Condvar)>,
}
impl Clone for WalSegment {
fn clone(&self) -> Self {
Self {
file: self.file.clone(),
state: self.state.clone(),
}
}
}
/// On-disk snapshot payload: the materialized key/value map plus a marker of
/// which WAL segment was most recently folded into it.
#[derive(Serialize, Deserialize, Debug)]
#[serde(bound(
    serialize = "K: Eq + std::hash::Hash + Serialize + MemSize, V: Serialize + MemSize",
    deserialize = "K: Eq + std::hash::Hash + Deserialize<'de> + MemSize, V: Deserialize<'de> + MemSize"
))]
struct SnapshotData<K, V> {
    /// Index of the last WAL segment applied to this snapshot; `None` for a
    /// fresh snapshot. Recovery uses this to skip already-applied segments.
    last_processed_segment_idx: Option<usize>,
    /// The snapshotted key/value pairs.
    data: HashMap<K, Arc<V>>,
}
impl<K: MemSize, V: MemSize> Default for SnapshotData<K, V> {
fn default() -> Self {
Self {
last_processed_segment_idx: None,
data: HashMap::new(),
}
}
}
/// Coordinates WAL logging, background flushing, and background snapshotting.
#[derive(Debug)]
pub struct PersistenceEngine<K, V> {
    /// Durability configuration this engine was created with.
    config: DurabilityLevel,
    /// Directory containing the WAL segments and the snapshot file.
    wal_path: PathBuf,
    /// Path to `snapshot.db` inside `wal_path`.
    snapshot_path: PathBuf,
    /// Ring of WAL segment files.
    wal_segments: Vec<WalSegment>,
    /// Index of the segment currently being written.
    current_segment_idx: Arc<AtomicUsize>,
    /// Byte offset of the next write within the current segment.
    writer_position: Arc<AtomicU64>,
    /// Set in `drop` to tell the background threads to exit.
    shutdown: Arc<AtomicBool>,
    /// Wakes the flusher thread when a flush threshold is crossed.
    flush_signal: Arc<Notify>,
    /// Commits logged since the last successful sync (Relaxed mode).
    pub commits_since_flush: Arc<AtomicUsize>,
    /// Bytes logged since the last successful sync (Relaxed mode).
    pub bytes_since_flush: Arc<AtomicU64>,
    /// Sends indices of full segments to the snapshotter thread.
    snapshot_queue_tx: mpsc::Sender<usize>,
    /// Keeps the receiver alive alongside the snapshotter's clone.
    _snapshot_queue_rx: Arc<Mutex<mpsc::Receiver<usize>>>,
    /// Background flusher thread; `None` unless durability is `Relaxed`.
    _flusher_handle: Option<JoinHandle<()>>,
    /// Background snapshotter thread.
    _snapshotter_handle: Option<JoinHandle<()>>,
    /// First unrecoverable error hit by a background thread, if any.
    fatal_error: Arc<Mutex<Option<String>>>,
    // Ties K/V to the engine without storing either.
    _phantom: PhantomData<(K, V)>,
}
/// Preallocates `size` bytes for a WAL segment file.
///
/// On Unix this uses `fallocate` in default mode (which extends the file
/// length); on Windows it falls back to `set_len`. On all other targets it
/// is a deliberate no-op, so segments simply grow on demand there.
fn preallocate_file(file: &File, size: u64) -> io::Result<()> {
    #[cfg(unix)]
    {
        fallocate(file, FallocateFlags::empty(), 0, size)?;
    }
    #[cfg(windows)]
    {
        file.set_len(size)?;
    }
    #[cfg(not(any(unix, windows)))]
    {
        // Silence unused-argument warnings on targets without a
        // preallocation API.
        let _ = file;
        let _ = size;
    }
    Ok(())
}
/// Creates the WAL directory and the ring of `wal.{i}` segment files.
///
/// Existing files are opened without truncation — their contents may still be
/// needed for recovery. Segment 0 starts in the `Writing` state; every other
/// segment starts `Available`.
fn setup_wal_files(options: &PersistenceOptions) -> io::Result<Vec<WalSegment>> {
    fs::create_dir_all(&options.wal_path)?;
    (0..options.wal_pool_size)
        .map(|idx| {
            let segment_path = options.wal_path.join(format!("wal.{}", idx));
            let file = OpenOptions::new()
                .write(true)
                .create(true)
                .open(&segment_path)?;
            preallocate_file(&file, options.wal_segment_size_bytes)?;
            // Only the first segment is the active write target.
            let state = if idx == 0 {
                WalSegmentState::Writing
            } else {
                WalSegmentState::Available
            };
            Ok(WalSegment {
                file: Arc::new(Mutex::new(file)),
                state: Arc::new((Mutex::new(state), Condvar::new())),
            })
        })
        .collect()
}
impl<K, V> PersistenceEngine<K, V>
where
K: Ord
+ Clone
+ Send
+ Sync
+ 'static
+ std::hash::Hash
+ Eq
+ Serialize
+ for<'de> Deserialize<'de>
+ MemSize
+ std::borrow::Borrow<str>,
V: Clone + Send + Sync + 'static + Serialize + for<'de> Deserialize<'de> + MemSize,
{
pub fn new(
config: DurabilityLevel,
fatal_error: Arc<Mutex<Option<String>>>,
) -> io::Result<Option<Self>> {
if let DurabilityLevel::InMemory = &config {
return Ok(None);
}
let options = match &config {
DurabilityLevel::Relaxed { options, .. } => options,
DurabilityLevel::Full { options } => options,
DurabilityLevel::InMemory => unreachable!(),
};
let snapshot_path = options.wal_path.join("snapshot.db");
let wal_segments = setup_wal_files(options)?;
let (tx, rx) = mpsc::channel();
let snapshot_queue_rx = Arc::new(Mutex::new(rx));
let current_segment_idx = Arc::new(AtomicUsize::new(0));
let shutdown = Arc::new(AtomicBool::new(false));
let flush_signal = Arc::new(Notify::new());
let commits_since_flush = Arc::new(AtomicUsize::new(0));
let bytes_since_flush = Arc::new(AtomicU64::new(0));
let segments_clone_snap = wal_segments.clone();
let rx_clone_snap = Arc::clone(&snapshot_queue_rx);
let snapshot_path_clone = snapshot_path.clone();
let wal_path_clone = options.wal_path.clone();
let shutdown_clone_snap = Arc::clone(&shutdown);
let fatal_error_clone_snap = fatal_error.clone();
let snapshotter_handle = std::thread::spawn(move || {
while !shutdown_clone_snap.load(Ordering::Relaxed) {
if fatal_error_clone_snap.lock().unwrap().is_some() {
break;
}
let segment_idx = match rx_clone_snap
.lock()
.expect("Snapshot queue mutex poisoned")
.recv_timeout(Duration::from_millis(100))
{
Ok(idx) => idx,
Err(mpsc::RecvTimeoutError::Timeout) => continue, Err(mpsc::RecvTimeoutError::Disconnected) => {
return;
}
};
let segment: &WalSegment = &segments_clone_snap[segment_idx];
{
let (lock, _) = &*segment.state;
let state = lock.lock().expect("WalSegment state mutex poisoned");
if *state != WalSegmentState::PendingSnapshot {
error!(
"Snapshotter: Segment {} in unexpected state: {:?}",
segment_idx, *state
);
continue;
}
}
let mut snapshot_data: SnapshotData<K, V> = if snapshot_path_clone.exists() {
match File::open(&snapshot_path_clone) {
Ok(file) => match ciborium::from_reader(file) {
Ok(data) => data,
Err(e) => {
let err_msg = format!(
"Snapshot file at {:?} is corrupted and could not be deserialized: {}. Manual intervention required.",
snapshot_path_clone, e
);
error!("FATAL: {}", err_msg);
*fatal_error_clone_snap.lock().unwrap() = Some(err_msg);
return; }
},
Err(e) => {
error!(
"Snapshotter: Failed to open existing snapshot file {:?}: {}",
snapshot_path_clone, e
);
SnapshotData::default()
}
}
} else {
SnapshotData::default()
};
let wal_file_path = wal_path_clone.join(format!("wal.{}", segment_idx));
if let Ok(wal_file) = File::open(&wal_file_path) {
let mut reader = BufReader::new(wal_file);
loop {
match ciborium::from_reader::<Workspace<K, V>, _>(&mut reader) {
Ok(workspace) => {
for (key, value) in workspace {
match value {
Some(val) => {
snapshot_data.data.insert(key, val);
}
None => {
snapshot_data.data.remove((&key).borrow());
}
}
}
}
Err(ciborium::de::Error::Io(ref e))
if e.kind() == io::ErrorKind::UnexpectedEof =>
{
break; }
Err(e) => {
error!(
"Snapshotter: Failed to deserialize workspace from WAL segment {:?}: {}",
wal_file_path, e
);
break; }
}
}
} else {
error!(
"Snapshotter: Failed to open WAL segment file {:?}.",
wal_file_path
);
}
snapshot_data.last_processed_segment_idx = Some(segment_idx);
let tmp_snapshot_path = snapshot_path_clone.with_extension("db.tmp");
let write_result = (|| {
let tmp_file = File::create(&tmp_snapshot_path)?;
ciborium::into_writer(&snapshot_data, tmp_file).map_err(|e| match e {
ciborium::ser::Error::Io(io_err) => io_err,
ciborium::ser::Error::Value(msg) => {
io::Error::new(io::ErrorKind::InvalidData, msg)
}
})?;
let file_to_sync = File::open(&tmp_snapshot_path)?;
file_to_sync.sync_all()?;
fs::rename(&tmp_snapshot_path, &snapshot_path_clone)?;
Ok::<(), io::Error>(())
})();
if let Err(e) = write_result {
let _ = fs::remove_file(&tmp_snapshot_path); let err_msg = format!(
"Snapshotter failed to write or rename snapshot file {:?}: {}. Shutting down to prevent data loss or deadlock.",
snapshot_path_clone, e
);
error!("FATAL: {}", err_msg);
*fatal_error_clone_snap.lock().unwrap() = Some(err_msg);
return; } else {
if let Ok(mut file) = segment.file.lock() {
if let Err(e) = file.set_len(0) {
error!(
"Snapshotter: Failed to truncate WAL segment file {}: {}",
segment_idx, e
);
}
if let Err(e) = file.seek(SeekFrom::Start(0)) {
error!(
"Snapshotter: Failed to seek WAL segment file {}: {}",
segment_idx, e
);
}
} else {
error!(
"Snapshotter: Failed to lock WAL segment file {} for truncation.",
segment_idx
);
}
let (lock, cvar) = &*segment.state;
let mut state = lock.lock().unwrap();
*state = WalSegmentState::Available;
cvar.notify_one();
}
}
});
let flusher_handle = if let DurabilityLevel::Relaxed {
flush_interval_ms, ..
} = config
{
let segments_clone_flush = wal_segments.clone();
let current_idx_clone_flush = Arc::clone(¤t_segment_idx);
let shutdown_clone_flush = Arc::clone(&shutdown);
let flush_signal_clone = Arc::clone(&flush_signal);
let commits_clone = Arc::clone(&commits_since_flush);
let bytes_clone = Arc::clone(&bytes_since_flush);
let fatal_error_clone_flush = fatal_error.clone();
Some(std::thread::spawn(move || {
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
runtime.block_on(async move {
loop {
if shutdown_clone_flush.load(Ordering::Relaxed)
|| fatal_error_clone_flush.lock().unwrap().is_some()
{
break;
}
let wait_future = flush_signal_clone.notified();
let timeout_future = if let Some(interval) = flush_interval_ms {
tokio::time::sleep(Duration::from_millis(interval))
} else {
tokio::time::sleep(Duration::from_secs(u64::MAX))
};
tokio::select! {
_ = wait_future => {
},
_ = timeout_future => {
}
};
if shutdown_clone_flush.load(Ordering::Relaxed) {
break;
}
let idx = current_idx_clone_flush.load(Ordering::Relaxed);
let segment = &segments_clone_flush[idx];
if let Ok(file) = segment.file.lock() {
if let Err(e) = file.sync_all() {
error!("Flusher thread failed to sync WAL segment {}: {}", idx, e);
} else {
commits_clone.store(0, Ordering::Relaxed);
bytes_clone.store(0, Ordering::Relaxed);
}
}
}
});
}))
} else {
None
};
Ok(Some(Self {
config: config.clone(),
wal_path: options.wal_path.clone(),
snapshot_path,
wal_segments,
current_segment_idx,
writer_position: Arc::new(AtomicU64::new(0)),
shutdown,
flush_signal,
commits_since_flush,
bytes_since_flush,
snapshot_queue_tx: tx,
_snapshot_queue_rx: snapshot_queue_rx,
_flusher_handle: flusher_handle,
_snapshotter_handle: Some(snapshotter_handle),
fatal_error,
_phantom: PhantomData,
}))
}
    /// Rebuilds a `SkipList` from the on-disk snapshot plus any WAL segments
    /// written after that snapshot was taken.
    ///
    /// Replay order: load `snapshot.db`, then apply unprocessed WAL segments
    /// ordered by file modification time, all inside one transaction that is
    /// committed at the end.
    pub async fn recover(
        &self,
        current_memory_bytes: Arc<AtomicU64>,
        access_clock: Arc<AtomicU64>,
        p_factor: Option<f64>,
    ) -> Result<SkipList<K, V>, crate::error::FluxError> {
        let skiplist = match p_factor {
            Some(p) => SkipList::with_p(p, current_memory_bytes, access_clock),
            None => SkipList::new(current_memory_bytes, access_clock),
        };
        let tx_manager = skiplist.transaction_manager();
        // NOTE(review): a corrupt snapshot silently falls back to an empty one
        // here (`unwrap_or_default`), while the snapshotter thread treats
        // snapshot corruption as fatal — confirm this asymmetry is intended.
        let snapshot_data: SnapshotData<K, V> = if self.snapshot_path.exists() {
            let file = File::open(&self.snapshot_path)?;
            ciborium::from_reader(file).unwrap_or_default()
        } else {
            SnapshotData::default()
        };
        let last_snap_idx = snapshot_data.last_processed_segment_idx;
        let tx = tx_manager.begin();
        // Seed the skiplist with the snapshot contents.
        for (key, value) in snapshot_data.data {
            let _ = skiplist.insert(key, value, &tx).await;
        }
        let wal_pool_size = match &self.config {
            DurabilityLevel::Relaxed { options, .. } => options.wal_pool_size,
            DurabilityLevel::Full { options } => options.wal_pool_size,
            DurabilityLevel::InMemory => unreachable!(),
        };
        // Collect the non-empty segment files together with their mtimes.
        let mut wal_files = Vec::new();
        for i in 0..wal_pool_size {
            let path = self.wal_path.join(format!("wal.{}", i));
            if path.exists() {
                let meta = fs::metadata(&path)?;
                if meta.len() > 0 {
                    wal_files.push((i, path, meta.modified()?));
                }
            }
        }
        // Oldest first, so newer writes overwrite older ones during replay.
        // NOTE(review): this relies on filesystem mtime resolution; segments
        // rotated within the same timestamp tick could replay out of order.
        wal_files.sort_by_key(|k| k.2);
        // Skip segments already folded into the snapshot: everything up to
        // and including `last_processed_segment_idx` in mtime order.
        let replay_order_indices: Vec<usize> = if let Some(last_idx) = last_snap_idx {
            if let Some(pos) = wal_files.iter().position(|(idx, _, _)| *idx == last_idx) {
                wal_files
                    .iter()
                    .skip(pos + 1)
                    .map(|(idx, _, _)| *idx)
                    .collect()
            } else {
                wal_files.iter().map(|(idx, _, _)| *idx).collect()
            }
        } else {
            wal_files.iter().map(|(idx, _, _)| *idx).collect()
        };
        for idx in replay_order_indices {
            let path = self.wal_path.join(format!("wal.{}", idx));
            if let Ok(wal_file) = File::open(&path) {
                let mut reader = BufReader::new(wal_file);
                // Read workspace records until the first error (EOF included).
                while let Ok(workspace) = ciborium::from_reader::<Workspace<K, V>, _>(&mut reader) {
                    for (key, value) in workspace {
                        match value {
                            Some(val) => {
                                let _ = skiplist.insert(key, val, &tx).await;
                            }
                            None => {
                                // A `None` value is a deletion record.
                                skiplist.remove(&key, &tx).await;
                            }
                        }
                    }
                }
            }
        }
        tx_manager.commit(&tx, || Ok(()), IsolationLevel::Serializable)?;
        Ok(skiplist)
    }
pub fn log(&self, data: &[u8]) -> io::Result<()> {
loop {
let current_pos = self.writer_position.load(Ordering::Relaxed);
let segment_size = match &self.config {
DurabilityLevel::Relaxed { options, .. } => options.wal_segment_size_bytes,
DurabilityLevel::Full { options } => options.wal_segment_size_bytes,
DurabilityLevel::InMemory => unreachable!(),
};
if current_pos + data.len() as u64 > segment_size {
self.rotate_wal()?;
continue;
}
let segment_idx = self.current_segment_idx.load(Ordering::Relaxed);
let segment = &self.wal_segments[segment_idx];
{
let mut file_lock = segment.file.lock().unwrap();
file_lock.seek(SeekFrom::Start(current_pos))?;
file_lock.write_all(data)?;
if let DurabilityLevel::Full { .. } = self.config {
file_lock.sync_all()?;
}
}
let bytes_written = data.len() as u64;
self.writer_position
.fetch_add(bytes_written, Ordering::Relaxed);
if let DurabilityLevel::Relaxed {
flush_after_n_commits,
flush_after_m_bytes,
..
} = &self.config
{
let old_commits = self.commits_since_flush.fetch_add(1, Ordering::Relaxed);
let old_bytes = self
.bytes_since_flush
.fetch_add(bytes_written, Ordering::Relaxed);
let mut should_flush = false;
if let Some(n) = flush_after_n_commits {
if old_commits + 1 >= *n {
should_flush = true;
}
}
if let Some(m) = flush_after_m_bytes {
if old_bytes + bytes_written >= *m {
should_flush = true;
}
}
if should_flush {
self.flush_signal.notify_one();
}
}
return Ok(());
}
}
    /// Rotates the writer to the next segment in the ring.
    ///
    /// Blocks until the next segment is `Available`, claims it as `Writing`,
    /// marks the old segment `PendingSnapshot`, queues it for the snapshotter,
    /// then resets the writer position and flush counters and truncates the
    /// freshly claimed segment.
    fn rotate_wal(&self) -> io::Result<()> {
        let wal_pool_size = match &self.config {
            DurabilityLevel::Relaxed { options, .. } => options.wal_pool_size,
            DurabilityLevel::Full { options } => options.wal_pool_size,
            DurabilityLevel::InMemory => unreachable!(),
        };
        let current_idx = self.current_segment_idx.load(Ordering::Relaxed);
        let next_idx = (current_idx + 1) % wal_pool_size;
        // Back-pressure point: if snapshotting falls behind, writers stall
        // here until the snapshotter frees the next segment.
        let (next_lock, next_cvar) = &*self.wal_segments[next_idx].state;
        let mut state = next_lock.lock().unwrap();
        while *state != WalSegmentState::Available {
            state = next_cvar.wait(state).unwrap();
        }
        *state = WalSegmentState::Writing;
        // Release the state lock before touching the old segment's state.
        drop(state);
        let (old_lock, _) = &*self.wal_segments[current_idx].state;
        *old_lock.lock().unwrap() = WalSegmentState::PendingSnapshot;
        if let Err(e) = self.snapshot_queue_tx.send(current_idx) {
            // The snapshotter is gone: record a fatal error so callers stop,
            // rather than leaving a segment that will never be freed.
            let err_msg = format!(
                "Snapshotter thread has died. Cannot send segment {} for snapshotting. Error: {}",
                current_idx, e
            );
            error!("FATAL: {}", err_msg);
            *self.fatal_error.lock().unwrap() = Some(err_msg);
        }
        self.writer_position.store(0, Ordering::Relaxed);
        self.current_segment_idx.store(next_idx, Ordering::Relaxed);
        self.commits_since_flush.store(0, Ordering::Relaxed);
        self.bytes_since_flush.store(0, Ordering::Relaxed);
        // Start the claimed segment from a clean slate.
        let segment = &self.wal_segments[next_idx];
        let mut file_lock = segment.file.lock().unwrap();
        file_lock.set_len(0)?;
        file_lock.seek(SeekFrom::Start(0))?;
        Ok(())
    }
}
impl<K, V> Drop for PersistenceEngine<K, V> {
    /// Signals both background threads to stop, then joins them.
    fn drop(&mut self) {
        self.shutdown.store(true, Ordering::Relaxed);
        // Wake the flusher so it re-checks the shutdown flag immediately.
        // Without this, the `join` below could block for a full flush
        // interval — or forever when `flush_interval_ms` is `None`, since
        // the flusher then waits solely on this signal. `notify_one` stores
        // a permit even if the flusher is not currently awaiting it.
        self.flush_signal.notify_one();
        if let Some(handle) = self._flusher_handle.take() {
            if let Err(e) = handle.join() {
                error!("Flusher thread panicked: {:?}", e);
            }
        }
        // The snapshotter polls its queue with a 100ms timeout and re-checks
        // the shutdown flag each iteration, so it exits on its own.
        if let Some(handle) = self._snapshotter_handle.take() {
            if let Err(e) = handle.join() {
                error!("Snapshotter thread panicked: {:?}", e);
            }
        }
    }
}