use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::AtomicU64;
use std::sync::atomic::Ordering;
use std::time::SystemTime;
use arc_swap::ArcSwap;
use bytes::Bytes;
use d_engine_core::ApplyResult;
use d_engine_core::Error;
use d_engine_core::Lease;
use d_engine_core::StateMachine;
use d_engine_core::StorageError;
use d_engine_proto::client::WriteCommand;
use d_engine_proto::client::write_command::CompareAndSwap;
use d_engine_proto::client::write_command::Delete;
use d_engine_proto::client::write_command::Insert;
use d_engine_proto::client::write_command::Operation;
use d_engine_proto::common::Entry;
use d_engine_proto::common::LogId;
use d_engine_proto::common::entry_payload::Payload;
use d_engine_proto::server::storage::SnapshotMetadata;
use parking_lot::RwLock;
use prost::Message;
use rocksdb::Cache;
use rocksdb::DB;
use rocksdb::IteratorMode;
use rocksdb::Options;
use rocksdb::WriteBatch;
use tonic::async_trait;
use tracing::debug;
use tracing::error;
use tracing::info;
use tracing::instrument;
use tracing::warn;
use crate::storage::DefaultLease;
/// Column family holding the replicated key/value data.
const STATE_MACHINE_CF: &str = "state_machine";
/// Column family holding bookkeeping: applied index/term, snapshot metadata, TTL state.
const STATE_MACHINE_META_CF: &str = "state_machine_meta";
/// Meta-CF key: last applied log index, stored as 8 big-endian bytes.
const LAST_APPLIED_INDEX_KEY: &[u8] = b"last_applied_index";
/// Meta-CF key: last applied log term, stored as 8 big-endian bytes.
const LAST_APPLIED_TERM_KEY: &[u8] = b"last_applied_term";
/// Meta-CF key: bincode-serialized `SnapshotMetadata`.
const SNAPSHOT_METADATA_KEY: &[u8] = b"snapshot_metadata";
/// Meta-CF key: serialized lease/TTL state (format owned by `DefaultLease`).
const TTL_STATE_KEY: &[u8] = b"ttl_state";
/// RocksDB-backed Raft state machine with atomic snapshot swap-in and
/// optional per-key TTL support via a lease component.
#[derive(Debug)]
pub struct RocksDBStateMachine {
    /// Current DB handle; atomically swapped during snapshot restoration.
    db: Arc<ArcSwap<DB>>,
    /// On-disk location of the live DB (also the snapshot rename target).
    db_path: PathBuf,
    /// False while a snapshot is being installed; reads are rejected then.
    is_serving: AtomicBool,
    /// Index of the last applied log entry (mirrored to the meta CF on persist).
    last_applied_index: AtomicU64,
    /// Term of the last applied log entry (mirrored to the meta CF on persist).
    last_applied_term: AtomicU64,
    /// Metadata of the most recent snapshot, if any.
    last_snapshot_metadata: RwLock<Option<SnapshotMetadata>>,
    /// TTL bookkeeping; `Some` iff `lease_enabled` (both set via `set_lease`).
    lease: Option<Arc<DefaultLease>>,
    /// Whether the TTL feature is enabled for this server.
    lease_enabled: bool,
}
impl RocksDBStateMachine {
pub fn new<P: AsRef<Path>>(path: P) -> Result<Self, Error> {
let db_path = path.as_ref().to_path_buf();
let opts = Self::configure_db_options();
let cfs = vec![STATE_MACHINE_CF, STATE_MACHINE_META_CF];
let db =
DB::open_cf(&opts, &db_path, cfs).map_err(|e| StorageError::DbError(e.to_string()))?;
let db_arc = Arc::new(db);
let (last_applied_index, last_applied_term) = Self::load_state_machine_metadata(&db_arc)?;
let last_snapshot_metadata = Self::load_snapshot_metadata(&db_arc)?;
Ok(Self {
db: Arc::new(ArcSwap::new(db_arc)),
db_path,
is_serving: AtomicBool::new(true),
last_applied_index: AtomicU64::new(last_applied_index),
last_applied_term: AtomicU64::new(last_applied_term),
last_snapshot_metadata: RwLock::new(last_snapshot_metadata),
lease: None, lease_enabled: false, })
}
pub fn set_lease(
&mut self,
lease: Arc<DefaultLease>,
) {
self.lease_enabled = true;
self.lease = Some(lease);
}
fn configure_db_options() -> Options {
let mut opts = Options::default();
opts.create_if_missing(true);
opts.create_missing_column_families(true);
opts.set_max_write_buffer_number(4);
opts.set_min_write_buffer_number_to_merge(2);
opts.set_write_buffer_size(128 * 1024 * 1024);
opts.set_compression_type(rocksdb::DBCompressionType::Lz4);
opts.set_bottommost_compression_type(rocksdb::DBCompressionType::Zstd);
opts.set_compression_options(-14, 0, 0, 0);
opts.set_wal_bytes_per_sync(1024 * 1024); opts.set_manual_wal_flush(true);
opts.set_use_fsync(false);
opts.set_max_background_jobs(4);
opts.set_max_open_files(5000);
opts.set_use_direct_io_for_flush_and_compaction(true);
opts.set_use_direct_reads(true);
opts.set_level_compaction_dynamic_level_bytes(true);
opts.set_target_file_size_base(64 * 1024 * 1024); opts.set_max_bytes_for_level_base(256 * 1024 * 1024);
let cache = Cache::new_lru_cache(128 * 1024 * 1024); opts.set_row_cache(&cache);
opts
}
fn open_db<P: AsRef<Path>>(path: P) -> Result<DB, Error> {
let opts = Self::configure_db_options();
let cfs = vec![STATE_MACHINE_CF, STATE_MACHINE_META_CF];
DB::open_cf(&opts, path, cfs).map_err(|e| StorageError::DbError(e.to_string()).into())
}
fn load_state_machine_metadata(db: &Arc<DB>) -> Result<(u64, u64), Error> {
let cf = db
.cf_handle(STATE_MACHINE_META_CF)
.ok_or_else(|| StorageError::DbError("State machine meta CF not found".to_string()))?;
let index = match db
.get_cf(&cf, LAST_APPLIED_INDEX_KEY)
.map_err(|e| StorageError::DbError(e.to_string()))?
{
Some(bytes) if bytes.len() == 8 => u64::from_be_bytes([
bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
]),
_ => 0,
};
let term = match db
.get_cf(&cf, LAST_APPLIED_TERM_KEY)
.map_err(|e| StorageError::DbError(e.to_string()))?
{
Some(bytes) if bytes.len() == 8 => u64::from_be_bytes([
bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
]),
_ => 0,
};
Ok((index, term))
}
fn load_snapshot_metadata(db: &Arc<DB>) -> Result<Option<SnapshotMetadata>, Error> {
let cf = db
.cf_handle(STATE_MACHINE_META_CF)
.ok_or_else(|| StorageError::DbError("State machine meta CF not found".to_string()))?;
match db
.get_cf(&cf, SNAPSHOT_METADATA_KEY)
.map_err(|e| StorageError::DbError(e.to_string()))?
{
Some(bytes) => {
let metadata = bincode::deserialize(&bytes).map_err(StorageError::BincodeError)?;
Ok(Some(metadata))
}
None => Ok(None),
}
}
fn persist_state_machine_metadata(&self) -> Result<(), Error> {
let db = self.db.load();
let cf = db
.cf_handle(STATE_MACHINE_META_CF)
.ok_or_else(|| StorageError::DbError("State machine meta CF not found".to_string()))?;
let index = self.last_applied_index.load(Ordering::SeqCst);
let term = self.last_applied_term.load(Ordering::SeqCst);
db.put_cf(&cf, LAST_APPLIED_INDEX_KEY, index.to_be_bytes())
.map_err(|e| StorageError::DbError(e.to_string()))?;
db.put_cf(&cf, LAST_APPLIED_TERM_KEY, term.to_be_bytes())
.map_err(|e| StorageError::DbError(e.to_string()))?;
Ok(())
}
fn persist_snapshot_metadata(&self) -> Result<(), Error> {
let db = self.db.load();
let cf = db
.cf_handle(STATE_MACHINE_META_CF)
.ok_or_else(|| StorageError::DbError("State machine meta CF not found".to_string()))?;
if let Some(metadata) = self.last_snapshot_metadata.read().clone() {
let bytes = bincode::serialize(&metadata).map_err(StorageError::BincodeError)?;
db.put_cf(&cf, SNAPSHOT_METADATA_KEY, bytes)
.map_err(|e| StorageError::DbError(e.to_string()))?;
}
Ok(())
}
fn persist_ttl_metadata(&self) -> Result<(), Error> {
if let Some(ref lease) = self.lease {
let db = self.db.load();
let cf = db.cf_handle(STATE_MACHINE_META_CF).ok_or_else(|| {
StorageError::DbError("State machine meta CF not found".to_string())
})?;
let ttl_snapshot = lease.to_snapshot();
db.put_cf(&cf, TTL_STATE_KEY, ttl_snapshot)
.map_err(|e| StorageError::DbError(e.to_string()))?;
debug!("Persisted TTL state to RocksDB");
}
Ok(())
}
pub async fn load_lease_data(&self) -> Result<(), Error> {
let Some(ref lease) = self.lease else {
return Ok(()); };
let db = self.db.load();
let cf = db
.cf_handle(STATE_MACHINE_META_CF)
.ok_or_else(|| StorageError::DbError("State machine meta CF not found".to_string()))?;
match db
.get_cf(&cf, TTL_STATE_KEY)
.map_err(|e| StorageError::DbError(e.to_string()))?
{
Some(ttl_data) => {
lease.reload(&ttl_data)?;
debug!("Loaded TTL state from RocksDB: {} active TTLs", lease.len());
}
None => {
debug!("No TTL state found in RocksDB");
}
}
Ok(())
}
#[allow(dead_code)]
fn maybe_cleanup_expired(
&self,
max_duration_ms: u64,
) -> usize {
let start = std::time::Instant::now();
let now = SystemTime::now();
let mut deleted_count = 0;
if let Some(ref lease) = self.lease {
if !lease.has_lease_keys() {
return 0; }
if !lease.may_have_expired_keys(now) {
return 0; }
} else {
return 0; }
let db = self.db.load();
let cf = match db.cf_handle(STATE_MACHINE_CF) {
Some(cf) => cf,
None => {
error!("State machine CF not found during TTL cleanup");
return 0;
}
};
let max_duration = std::time::Duration::from_millis(max_duration_ms);
loop {
if start.elapsed() >= max_duration {
debug!(
"Piggyback cleanup time budget exceeded: deleted {} keys in {:?}",
deleted_count,
start.elapsed()
);
break;
}
let expired_keys = if let Some(ref lease) = self.lease {
lease.get_expired_keys(now)
} else {
vec![]
};
if expired_keys.is_empty() {
break; }
let mut batch = WriteBatch::default();
for key in expired_keys {
batch.delete_cf(&cf, &key);
deleted_count += 1;
}
if let Err(e) = db.write(batch) {
error!("Failed to delete expired keys: {}", e);
break;
}
}
if deleted_count > 0 {
debug!(
"Piggyback cleanup: deleted {} expired keys in {:?}",
deleted_count,
start.elapsed()
);
}
deleted_count
}
fn apply_batch(
&self,
batch: WriteBatch,
) -> Result<(), Error> {
self.db.load().write(batch).map_err(|e| StorageError::DbError(e.to_string()))?;
Ok(())
}
}
#[async_trait]
impl StateMachine for RocksDBStateMachine {
    /// Marks the machine as serving and restores lease (TTL) state when the
    /// feature is enabled.
    async fn start(&self) -> Result<(), Error> {
        self.is_serving.store(true, Ordering::SeqCst);
        if let Some(ref _lease) = self.lease {
            self.load_lease_data().await?;
            debug!("Lease data loaded during state machine initialization");
        }
        info!("RocksDB state machine started");
        Ok(())
    }

    /// Stops serving and persists TTL bookkeeping so it survives a restart.
    fn stop(&self) -> Result<(), Error> {
        self.is_serving.store(false, Ordering::SeqCst);
        if let Err(e) = self.persist_ttl_metadata() {
            error!("Failed to persist TTL metadata on shutdown: {:?}", e);
            return Err(e);
        }
        info!("RocksDB state machine stopped");
        Ok(())
    }

    fn is_running(&self) -> bool {
        self.is_serving.load(Ordering::SeqCst)
    }

    /// Point lookup in the data CF. Rejected with `NotServing` while a
    /// snapshot is being installed.
    fn get(
        &self,
        key_buffer: &[u8],
    ) -> Result<Option<Bytes>, Error> {
        if !self.is_serving.load(Ordering::SeqCst) {
            return Err(StorageError::NotServing(
                "State machine is restoring from snapshot".to_string(),
            )
            .into());
        }
        let db = self.db.load();
        let cf = db
            .cf_handle(STATE_MACHINE_CF)
            .ok_or_else(|| StorageError::DbError("State machine CF not found".to_string()))?;
        match db.get_cf(&cf, key_buffer).map_err(|e| StorageError::DbError(e.to_string()))? {
            Some(value) => Ok(Some(Bytes::copy_from_slice(&value))),
            None => Ok(None),
        }
    }

    /// Per-entry terms are not tracked by this implementation.
    fn entry_term(
        &self,
        _entry_id: u64,
    ) -> Option<u64> {
        None
    }

    /// Applies a strictly index-ascending chunk of committed log entries as
    /// one atomic RocksDB write batch, then advances the in-memory
    /// last-applied marker to the highest entry in the chunk.
    ///
    /// NOTE(review): Config entries and WriteCommands without an operation
    /// push no `ApplyResult`, so `results.len()` can be smaller than
    /// `chunk.len()` — confirm callers tolerate that.
    #[instrument(skip(self, chunk))]
    async fn apply_chunk(
        &self,
        chunk: Vec<Entry>,
    ) -> Result<Vec<ApplyResult>, Error> {
        let db = self.db.load();
        let cf = db
            .cf_handle(STATE_MACHINE_CF)
            .ok_or_else(|| StorageError::DbError("State machine CF not found".to_string()))?;
        let mut batch = WriteBatch::default();
        let mut highest_index_entry: Option<LogId> = None;
        let mut results = Vec::with_capacity(chunk.len());
        for entry in chunk {
            assert!(entry.payload.is_some(), "Entry payload should not be None!");
            // Entries must arrive in strictly increasing index order.
            if let Some(prev) = highest_index_entry {
                assert!(
                    entry.index > prev.index,
                    "apply_chunk: received unordered entry at index {} (prev={})",
                    entry.index,
                    prev.index
                );
            }
            highest_index_entry = Some(LogId {
                index: entry.index,
                term: entry.term,
            });
            match entry.payload.unwrap().payload {
                Some(Payload::Noop(_)) => {
                    debug!("Handling NOOP command at index {}", entry.index);
                    results.push(ApplyResult::success(entry.index));
                }
                Some(Payload::Command(data)) => match WriteCommand::decode(&data[..]) {
                    Ok(write_cmd) => match write_cmd.operation {
                        Some(Operation::Insert(Insert {
                            key,
                            value,
                            ttl_secs,
                        })) => {
                            batch.put_cf(&cf, &key, &value);
                            if ttl_secs > 0 {
                                if !self.lease_enabled {
                                    return Err(StorageError::FeatureNotEnabled(
                                        "TTL feature is not enabled on this server. \
                                        Enable it in config: [raft.state_machine.lease] enabled = true".into()
                                    ).into());
                                }
                                // `lease_enabled` is only ever set together
                                // with `lease` (see `set_lease`), so the lease
                                // must be present here. A safe `expect`
                                // documents that invariant; the previous
                                // unsafe `unwrap_unchecked` would have been UB
                                // if the invariant were ever broken.
                                let lease = self
                                    .lease
                                    .as_ref()
                                    .expect("lease_enabled implies lease is set");
                                lease.register(key.clone(), ttl_secs);
                            }
                            results.push(ApplyResult::success(entry.index));
                        }
                        Some(Operation::Delete(Delete { key })) => {
                            batch.delete_cf(&cf, &key);
                            if let Some(ref lease) = self.lease {
                                lease.unregister(&key);
                            }
                            results.push(ApplyResult::success(entry.index));
                        }
                        Some(Operation::CompareAndSwap(CompareAndSwap {
                            key,
                            expected_value,
                            new_value,
                        })) => {
                            // CAS reads the committed value from the DB, not
                            // from the pending batch: an earlier write to the
                            // same key within this chunk is not observed here.
                            let current_value = db.get_cf(&cf, &key).map_err(|e| {
                                StorageError::DbError(format!("CAS read failed: {e}"))
                            })?;
                            let cas_success = match (current_value, &expected_value) {
                                (Some(current), Some(expected)) => current == expected.as_ref(),
                                // Expecting absence succeeds only when the key
                                // is actually absent.
                                (None, None) => true,
                                _ => false,
                            };
                            if cas_success {
                                batch.put_cf(&cf, &key, &new_value);
                            }
                            results.push(if cas_success {
                                ApplyResult::success(entry.index)
                            } else {
                                ApplyResult::failure(entry.index)
                            });
                            debug!(
                                "CAS at index {}: key={:?}, success={}",
                                entry.index,
                                String::from_utf8_lossy(&key),
                                cas_success
                            );
                        }
                        None => {
                            warn!("WriteCommand without operation at index {}", entry.index);
                        }
                    },
                    Err(e) => {
                        error!(
                            "Failed to decode WriteCommand at index {}: {:?}",
                            entry.index, e
                        );
                        return Err(StorageError::SerializationError(e.to_string()).into());
                    }
                },
                Some(Payload::Config(_config_change)) => {
                    debug!("Ignoring config change at index {}", entry.index);
                }
                None => panic!("Entry payload variant should not be None!"),
            }
        }
        self.apply_batch(batch)?;
        if let Some(highest) = highest_index_entry {
            self.update_last_applied(highest);
        }
        Ok(results)
    }

    /// Number of keys in the data CF. Full iteration — O(n) in key count.
    fn len(&self) -> usize {
        let db = self.db.load();
        let cf = match db.cf_handle(STATE_MACHINE_CF) {
            Some(cf) => cf,
            None => return 0,
        };
        let iter = db.iterator_cf(&cf, IteratorMode::Start);
        iter.count()
    }

    /// Updates the in-memory last-applied marker (not persisted).
    fn update_last_applied(
        &self,
        last_applied: LogId,
    ) {
        self.last_applied_index.store(last_applied.index, Ordering::SeqCst);
        self.last_applied_term.store(last_applied.term, Ordering::SeqCst);
    }

    fn last_applied(&self) -> LogId {
        LogId {
            index: self.last_applied_index.load(Ordering::SeqCst),
            term: self.last_applied_term.load(Ordering::SeqCst),
        }
    }

    /// Updates the in-memory marker and writes it through to the meta CF.
    fn persist_last_applied(
        &self,
        last_applied: LogId,
    ) -> Result<(), Error> {
        self.update_last_applied(last_applied);
        self.persist_state_machine_metadata()
    }

    /// Caches the snapshot metadata in memory only.
    fn update_last_snapshot_metadata(
        &self,
        snapshot_metadata: &SnapshotMetadata,
    ) -> Result<(), Error> {
        *self.last_snapshot_metadata.write() = Some(snapshot_metadata.clone());
        Ok(())
    }

    fn snapshot_metadata(&self) -> Option<SnapshotMetadata> {
        self.last_snapshot_metadata.read().clone()
    }

    /// Caches the snapshot metadata and writes it through to the meta CF.
    fn persist_last_snapshot_metadata(
        &self,
        snapshot_metadata: &SnapshotMetadata,
    ) -> Result<(), Error> {
        self.update_last_snapshot_metadata(snapshot_metadata)?;
        self.persist_snapshot_metadata()
    }

    /// Replaces the entire DB with a received snapshot checkpoint.
    ///
    /// Sequence: stop serving → flush and quiesce the old DB → swap in a
    /// throwaway temp DB (releases the file lock on `db_path`) → rename the
    /// current DB aside as a backup → rename the checkpoint into place →
    /// open and swap in the new DB → restore lease state → resume serving →
    /// delete the backup.
    ///
    /// On rename/open failure the backup directory is moved back into place,
    /// but the error is returned with `is_serving` still false and the temp
    /// DB still installed — recovery from that state is left to the caller.
    #[instrument(skip(self))]
    async fn apply_snapshot_from_file(
        &self,
        metadata: &SnapshotMetadata,
        snapshot_dir: std::path::PathBuf,
    ) -> Result<(), Error> {
        info!("Applying snapshot from checkpoint: {:?}", snapshot_dir);
        self.is_serving.store(false, Ordering::SeqCst);
        info!("Stopped serving requests for snapshot restoration");
        {
            let old_db = self.db.load();
            old_db.flush().map_err(|e| StorageError::DbError(e.to_string()))?;
            old_db.cancel_all_background_work(true);
            info!("Flushed and stopped background work on old DB");
        }
        // Swapping to a temp DB drops the last strong reference to the old DB
        // (once in-flight guards expire), releasing its directory lock.
        let temp_dir = tempfile::TempDir::new()?;
        let temp_db_path = temp_dir.path().join("temp_db");
        let temp_db = Self::open_db(&temp_db_path).map_err(|e| {
            error!("Failed to create temporary DB: {:?}", e);
            e
        })?;
        self.db.store(Arc::new(temp_db));
        info!("Swapped to temporary DB, old DB lock released");
        let backup_dir = self.db_path.with_extension("backup");
        if backup_dir.exists() {
            tokio::fs::remove_dir_all(&backup_dir).await?;
        }
        tokio::fs::rename(&self.db_path, &backup_dir).await?;
        info!("Backed up current DB to: {:?}", backup_dir);
        tokio::fs::rename(&snapshot_dir, &self.db_path).await.inspect_err(|_e| {
            // Best-effort rollback: put the original DB directory back.
            let _ = std::fs::rename(&backup_dir, &self.db_path);
        })?;
        info!("Moved checkpoint to DB path: {:?}", self.db_path);
        let new_db = Self::open_db(&self.db_path).map_err(|e| {
            // Best-effort rollback: restore the backup over the bad checkpoint.
            let _ = std::fs::rename(&backup_dir, &self.db_path);
            error!("Failed to open new DB, rolled back to backup: {:?}", e);
            e
        })?;
        self.db.store(Arc::new(new_db));
        info!("Atomically swapped to new DB instance");
        if let Some(ref lease) = self.lease {
            // TTL state travels as a sidecar file inside the checkpoint dir
            // (written by `generate_snapshot_data`).
            let ttl_path = self.db_path.join("ttl_state.bin");
            if ttl_path.exists() {
                let ttl_data = tokio::fs::read(&ttl_path).await?;
                lease.reload(&ttl_data)?;
                self.persist_ttl_metadata()?;
                info!("Lease state restored from snapshot and persisted to metadata CF");
            } else {
                warn!("No lease state found in snapshot");
            }
        }
        *self.last_snapshot_metadata.write() = Some(metadata.clone());
        if let Some(last_included) = &metadata.last_included {
            self.update_last_applied(*last_included);
        }
        self.is_serving.store(true, Ordering::SeqCst);
        info!("Resumed serving requests");
        if let Err(e) = tokio::fs::remove_dir_all(&backup_dir).await {
            warn!("Failed to remove backup directory: {}", e);
        } else {
            info!("Cleaned up backup directory");
        }
        info!("Snapshot applied successfully - full DB restoration complete");
        Ok(())
    }

    /// Creates a RocksDB checkpoint of the current DB in `new_snapshot_dir`,
    /// plus a `ttl_state.bin` sidecar when the lease feature is enabled, and
    /// persists the resulting snapshot metadata.
    ///
    /// NOTE(review): the returned/persisted checksum is a 32-byte zero
    /// placeholder — no digest of the checkpoint is actually computed.
    #[instrument(skip(self))]
    async fn generate_snapshot_data(
        &self,
        new_snapshot_dir: std::path::PathBuf,
        last_included: LogId,
    ) -> Result<Bytes, Error> {
        {
            // Scoped so the DB guard and checkpoint object drop promptly.
            let db = self.db.load();
            let checkpoint = rocksdb::checkpoint::Checkpoint::new(db.as_ref())
                .map_err(|e| StorageError::DbError(e.to_string()))?;
            checkpoint
                .create_checkpoint(&new_snapshot_dir)
                .map_err(|e| StorageError::DbError(e.to_string()))?;
        }
        if let Some(ref lease) = self.lease {
            let ttl_snapshot = lease.to_snapshot();
            let ttl_path = new_snapshot_dir.join("ttl_state.bin");
            tokio::fs::write(&ttl_path, ttl_snapshot).await?;
        }
        // Placeholder checksum (all zeros); see NOTE in the doc comment.
        let checksum = [0; 32];
        let snapshot_metadata = SnapshotMetadata {
            last_included: Some(last_included),
            checksum: Bytes::copy_from_slice(&checksum),
        };
        self.persist_last_snapshot_metadata(&snapshot_metadata)?;
        info!("Snapshot generated at {:?} with TTL data", new_snapshot_dir);
        Ok(Bytes::copy_from_slice(&checksum))
    }

    /// Persists applied-index/term and snapshot metadata to the meta CF.
    fn save_hard_state(&self) -> Result<(), Error> {
        self.persist_state_machine_metadata()?;
        self.persist_snapshot_metadata()?;
        Ok(())
    }

    /// Syncs the WAL, flushes memtables, and persists the applied marker.
    fn flush(&self) -> Result<(), Error> {
        let db = self.db.load();
        // The WAL is in manual-flush mode (see `configure_db_options`), so it
        // must be synced explicitly before flushing memtables.
        db.flush_wal(true).map_err(|e| StorageError::DbError(e.to_string()))?;
        db.flush().map_err(|e| StorageError::DbError(e.to_string()))?;
        self.persist_state_machine_metadata()?;
        Ok(())
    }

    /// Async wrapper over the synchronous `flush`.
    async fn flush_async(&self) -> Result<(), Error> {
        self.flush()
    }

    /// Deletes every key in the data CF and zeroes all applied/snapshot
    /// bookkeeping, both in memory and on disk.
    #[instrument(skip(self))]
    async fn reset(&self) -> Result<(), Error> {
        let db = self.db.load();
        let cf = db
            .cf_handle(STATE_MACHINE_CF)
            .ok_or_else(|| StorageError::DbError("State machine CF not found".to_string()))?;
        let mut batch = WriteBatch::default();
        let iter = db.iterator_cf(&cf, IteratorMode::Start);
        for item in iter {
            let (key, _) = item.map_err(|e| StorageError::DbError(e.to_string()))?;
            batch.delete_cf(&cf, &key);
        }
        db.write(batch).map_err(|e| StorageError::DbError(e.to_string()))?;
        self.last_applied_index.store(0, Ordering::SeqCst);
        self.last_applied_term.store(0, Ordering::SeqCst);
        *self.last_snapshot_metadata.write() = None;
        self.persist_state_machine_metadata()?;
        self.persist_snapshot_metadata()?;
        info!("RocksDB state machine reset completed");
        Ok(())
    }

    /// Deletes all currently-expired leased keys in one batch and returns
    /// them. Returns an empty list when the lease feature is disabled or
    /// nothing has expired.
    ///
    /// NOTE(review): like `maybe_cleanup_expired`, this does not unregister
    /// the keys from the lease — presumably the caller does so with the
    /// returned list; confirm.
    async fn lease_background_cleanup(&self) -> Result<Vec<Bytes>, Error> {
        let Some(ref lease) = self.lease else {
            return Ok(vec![]);
        };
        let now = SystemTime::now();
        let expired_keys = lease.get_expired_keys(now);
        if expired_keys.is_empty() {
            return Ok(vec![]);
        }
        debug!(
            "Lease background cleanup: found {} expired keys",
            expired_keys.len()
        );
        let db = self.db.load();
        let cf = db
            .cf_handle(STATE_MACHINE_CF)
            .ok_or_else(|| StorageError::DbError("State machine CF not found".to_string()))?;
        let mut batch = WriteBatch::default();
        for key in &expired_keys {
            batch.delete_cf(&cf, key);
        }
        self.apply_batch(batch)?;
        info!(
            "Lease background cleanup: deleted {} expired keys",
            expired_keys.len()
        );
        Ok(expired_keys)
    }
}
impl Drop for RocksDBStateMachine {
fn drop(&mut self) {
if let Err(e) = self.save_hard_state() {
error!("Failed to save hard state on drop: {}", e);
}
if let Err(e) = self.flush() {
error!("Failed to flush on drop: {}", e);
} else {
debug!("RocksDBStateMachine flushed successfully on drop");
}
self.db.load().cancel_all_background_work(true); debug!("RocksDB background work cancelled on drop");
}
}