use crate::ops::{CoreState, StateOp};
use crate::sync::{SyncMessage, SyncResponse};
use crate::types::{compare_entries, Entry, Metadata, NodeId, PeerState};
use crate::wal::WriteAheadLog;
use anyhow::{anyhow, bail, Context, Result};
use chrono::Utc;
use fs_err::{self as fs, File, OpenOptions};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::io::{BufReader, BufWriter, Read, Write};
use std::path::{Path, PathBuf};
use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};
use std::time::Duration;
use tokio::sync::watch;
use tracing::{debug, info, trace, warn};
/// Default cap on a per-peer replication log; passed as `max_entries` to
/// `StateOp::PushPeerLog` so logs don't grow without bound.
const DEFAULT_MAX_LOG_ENTRIES: usize = 1000;
/// Diagnostic summary of a node, produced by `NodeState::status`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeStatus {
    /// This node's identifier.
    pub id: NodeId,
    /// Number of entries in the KV map (tombstones included).
    pub n_kvs: usize,
    /// Next sequence number this node will allocate.
    pub next_seq: u64,
    /// Whether in-memory state has changed since the last snapshot persist.
    pub dirty: bool,
    /// Whether a write-ahead log is configured.
    pub wal: bool,
    /// Per-peer replication progress.
    pub peers: Vec<PeerStatus>,
}
/// Per-peer replication progress reported inside `NodeStatus`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PeerStatus {
    /// Peer node id.
    pub id: NodeId,
    /// `local_ack`: highest seq from this peer we have acknowledged.
    pub ack: u64,
    /// `peer_ack`: highest of our seqs this peer has acknowledged.
    pub pack: u64,
    /// Number of entries currently buffered in this peer's log.
    pub logs: usize,
}
/// Mutable state of a single node: the replicated KV core plus local
/// persistence (WAL + snapshot) and key-watch bookkeeping.
pub struct NodeState {
    /// This node's identifier.
    pub id: NodeId,
    // Replicated state machine: KV data, peer logs/acks, sequence counter.
    core: CoreState,
    // Write-ahead log; `None` when persistence is disabled.
    wal: Option<WriteAheadLog>,
    // Cap handed to `StateOp::PushPeerLog` when appending to peer logs.
    max_log_entries: usize,
    // Snapshot file location; `None` when persistence is disabled.
    snapshot_path: Option<PathBuf>,
    // Active key subscriptions; pruned lazily in `notify_watchers`.
    watchers: Vec<Watcher>,
    // True when in-memory state diverges from the on-disk snapshot.
    dirty: bool,
}
/// Key selector for a watcher: an exact key or a key prefix.
enum WatchPattern {
    Exact(String),
    Prefix(String),
}

/// A registered change subscription. The matching receiver half is handed
/// to the caller by `watch_key` / `watch_prefix`.
struct Watcher {
    pattern: WatchPattern,
    // Notification channel; a watcher whose receiver was dropped is removed
    // the next time its pattern matches (see `notify_watchers`).
    sender: watch::Sender<()>,
}
/// On-disk snapshot payload, MessagePack-encoded by `persist_to_disk`.
#[derive(Debug, Serialize, Deserialize)]
struct SnapshotFile {
    // Magic header bytes; must equal `SnapshotFile::MAGIC`.
    magic: [u8; 4],
    // Format version; must equal `SnapshotFile::VERSION`.
    version: u32,
    // Node the snapshot was taken from; checked against the loading node.
    node_id: NodeId,
    // Full copy of the replicated state machine.
    core: CoreState,
}
impl SnapshotFile {
const VERSION: u32 = 1;
const MAGIC: [u8; 4] = *b"WVKV";
fn from_state(state: &NodeState) -> Self {
Self {
magic: Self::MAGIC,
version: Self::VERSION,
node_id: state.id,
core: state.core.clone(),
}
}
fn validate(&self, expected_node: NodeId) -> Result<()> {
if self.magic != Self::MAGIC {
bail!("Invalid snapshot magic header");
}
if self.version != Self::VERSION {
bail!(
"Unsupported snapshot version: expected {}, found {}",
Self::VERSION,
self.version
);
}
if self.node_id != expected_node {
bail!(
"Snapshot node_id mismatch: expected {}, found {}",
expected_node,
self.node_id
);
}
Ok(())
}
}
/// Cheaply-cloneable handle to a shared `NodeState`; every clone points at
/// the same state behind one `RwLock`.
#[derive(Clone)]
pub struct Node {
    state: Arc<RwLock<NodeState>>,
}
impl NodeState {
/// Returns the configured snapshot path, or an error when this node was
/// created without persistence.
fn snapshot_path(&self) -> Result<&Path> {
    match self.snapshot_path.as_deref() {
        Some(path) => Ok(path),
        None => Err(anyhow!("Snapshot path not configured")),
    }
}
fn load_snapshot_if_exists(&mut self) -> Result<bool> {
let Some(path) = self.snapshot_path.clone() else {
return Ok(false);
};
if !path.exists() {
return Ok(false);
}
let mut reader = BufReader::new(File::open(&path)?);
let mut buf = Vec::new();
reader.read_to_end(&mut buf)?;
let snapshot: SnapshotFile =
rmp_serde::from_slice(&buf).context("Failed to deserialize snapshot")?;
snapshot.validate(self.id)?;
self.core = snapshot.core;
self.dirty = false;
Ok(true)
}
/// Atomically persists the current state to the snapshot file.
///
/// Write protocol: encode -> write a sibling `*.snapshot.tmp` -> fsync the
/// temp file -> rename over the real snapshot -> best-effort fsync of the
/// parent directory -> reset the WAL. The ordering matters: the WAL is only
/// reset after the snapshot is durably in place, so a crash at any point
/// leaves either the old snapshot + WAL, or the new snapshot.
pub fn persist_to_disk(&mut self) -> Result<()> {
    let snapshot_path = self.snapshot_path()?.to_path_buf();
    if let Some(parent) = snapshot_path.parent() {
        fs::create_dir_all(parent)?;
    }
    let tmp_path = snapshot_path.with_extension("snapshot.tmp");
    let snapshot = SnapshotFile::from_state(self);
    let encoded = rmp_serde::to_vec(&snapshot)?;
    {
        let mut writer = BufWriter::new(
            OpenOptions::new()
                .create(true)
                .write(true)
                .truncate(true)
                .open(&tmp_path)?,
        );
        writer.write_all(&encoded)?;
        writer.flush()?;
        // fsync before the rename so the rename can never publish a
        // partially-written snapshot.
        writer.get_ref().sync_all()?;
    }
    fs::rename(&tmp_path, &snapshot_path)?;
    if let Some(parent) = snapshot_path.parent() {
        // Best-effort directory fsync so the rename itself is durable;
        // errors are deliberately ignored.
        if let Ok(dir_file) = File::open(parent) {
            let _ = dir_file.sync_all();
        }
    }
    // The snapshot now supersedes everything in the WAL.
    if let Some(wal) = self.wal.as_mut() {
        wal.reset()?;
    }
    info!("Persisted snapshot to {:?}", snapshot_path);
    self.dirty = false;
    Ok(())
}
/// Wakes every watcher whose pattern matches `key`.
///
/// Non-matching watchers are always kept; a matching watcher is kept only
/// while its receiver is still alive (i.e. `send` succeeds), which prunes
/// dead subscriptions lazily.
fn notify_watchers(&mut self, key: &str) {
    self.watchers.retain(|w| {
        let hit = match &w.pattern {
            WatchPattern::Exact(k) => k == key,
            WatchPattern::Prefix(p) => key.starts_with(p),
        };
        // Short-circuit: `send` only runs (and only matters) on a hit.
        !hit || w.sender.send(()).is_ok()
    });
}
pub fn watch_key(&mut self, key: &str) -> watch::Receiver<()> {
let (sender, receiver) = watch::channel(());
self.watchers.push(Watcher {
pattern: WatchPattern::Exact(key.to_string()),
sender,
});
receiver
}
pub fn watch_prefix(&mut self, prefix: &str) -> watch::Receiver<()> {
let (sender, receiver) = watch::channel(());
self.watchers.push(Watcher {
pattern: WatchPattern::Prefix(prefix.to_string()),
sender,
});
receiver
}
/// Applies `ops` to local state, journaling them to the WAL first.
fn execute_ops(&mut self, ops: Vec<StateOp>) -> Result<()> {
    self.execute_ops_impl(ops, true)
}

/// Shared op-application path for both live execution and WAL replay.
///
/// No-op operations are dropped up front; if anything remains, it is
/// (optionally) journaled, applied in order, and the state marked dirty.
fn execute_ops_impl(&mut self, mut ops: Vec<StateOp>, write_to_wal: bool) -> Result<()> {
    ops.retain(|op| {
        let noop = self.core.is_noop(op);
        if noop {
            trace!("Skipping noop op: {op:?}");
        }
        !noop
    });
    if ops.is_empty() {
        return Ok(());
    }
    // Journal before applying so a crash mid-apply can be replayed.
    if write_to_wal {
        if let Some(wal) = self.wal.as_mut() {
            wal.write_ops(&ops)?;
        }
    }
    for op in ops {
        self.execute_op(op);
    }
    self.mark_dirty();
    Ok(())
}

/// Applies a single op and wakes watchers when it sets a key.
fn execute_op(&mut self, op: StateOp) {
    let changed_key = match &op {
        StateOp::Set(entry) => Some(entry.key.clone()),
        _ => None,
    };
    self.core.execute(op);
    if let Some(key) = changed_key {
        self.notify_watchers(&key);
    }
}

/// Re-applies ops recovered from the WAL without journaling them again.
fn replay_ops(&mut self, ops: Vec<StateOp>) -> Result<()> {
    self.execute_ops_impl(ops, false)
}
/// Ingests a replicated entry from a peer.
///
/// The entry is always recorded in the originating node's peer log; it is
/// applied to the KV data only when it wins against the currently stored
/// entry per `compare_entries` (or when the key is new). Returns whether
/// the KV data was updated.
pub fn sync(&mut self, entry: Entry) -> Result<bool> {
    debug!(entry.key, "Syncing entry, meta: {:?}", entry.meta);
    let should_update = self.core.data().get(&entry.key).map_or(true, |existing| {
        compare_entries(existing, &entry) == std::cmp::Ordering::Less
    });
    let mut ops = vec![StateOp::PushPeerLog {
        peer_id: entry.meta.node,
        entry: entry.clone(),
        max_entries: self.max_log_entries,
    }];
    if should_update {
        ops.push(StateOp::Set(entry));
    }
    self.execute_ops(ops)
        .context("Failed to execute ops in sync")?;
    Ok(should_update)
}
/// Records that `peer_id` acknowledged our entries up to `ack_seq`.
/// Issued with `monotonic: false`, i.e. the stored value is overwritten
/// even if it moves backwards.
pub fn update_peer_ack(&mut self, peer_id: NodeId, ack_seq: u64) -> Result<()> {
    let op = StateOp::UpdatePeerAck {
        peer_id,
        ack_seq,
        monotonic: false,
    };
    self.execute_ops(vec![op])
}

/// Updates our local ack for each peer listed in `progress`, with the
/// `monotonic: true` flag so the core should not regress acks.
fn update_local_ack(&mut self, progress: &HashMap<NodeId, u64>) -> Result<()> {
    let mut ops = Vec::with_capacity(progress.len());
    for (&peer_id, &ack_seq) in progress {
        ops.push(StateOp::UpdateLocalAck {
            peer_id,
            ack_seq,
            monotonic: true,
        });
    }
    self.execute_ops(ops)
}
/// Applies a pull-sync response from a peer.
///
/// For a snapshot response, first adopt the peer's progress map as our
/// local acks (monotonically). Then sync every entry; finally record what
/// the peer has acknowledged of *our* entries — its `progress` value for
/// our own id, defaulting to 0 when absent.
pub fn apply_pulled_entries(&mut self, sync_message: SyncResponse) -> Result<()> {
    if sync_message.is_snapshot {
        debug!("Applying pulled snapshot");
        self.update_local_ack(&sync_message.progress)?;
    }
    for entry in sync_message.entries {
        self.sync(entry).context("Failed to sync entry")?;
    }
    let peer_ack = sync_message.progress.get(&self.id).copied().unwrap_or(0);
    self.update_peer_ack(sync_message.peer_id, peer_ack)?;
    Ok(())
}
/// Applies a batch of entries pushed by a peer.
///
/// The batch must extend our local ack for the sender without a gap: when
/// the first pushed seq is beyond `local_ack + 1`, the whole batch is
/// dropped with a warning (a later pull is expected to fill the hole).
/// Overlapping batches (first seq <= local_ack + 1) are re-offered to
/// `sync`, which decides per entry whether it still wins.
pub fn apply_pushed_entries(&mut self, sync_message: SyncMessage) -> Result<()> {
    // Empty batch: nothing to check or apply.
    let Some(first) = sync_message.entries.first() else {
        return Ok(());
    };
    let local_ack = self
        .core
        .peers()
        .get(&sync_message.sender_id)
        .map(|p| p.local_ack)
        .unwrap_or(0);
    let sender_id = sync_message.sender_id;
    let expected_since = local_ack + 1;
    let actual_since = first.meta.seq;
    if actual_since > expected_since {
        // Gap detected: applying would silently skip entries, so bail out
        // (not an error — the sync protocol recovers via pulls).
        warn!(
            sender_id,
            expected_since, actual_since, "Received entries with gap"
        );
        return Ok(());
    }
    for entry in sync_message.entries {
        self.sync(entry).context("Failed to sync entry")?;
    }
    Ok(())
}
/// Allocates metadata (seq + wall-clock timestamp) for a locally-authored
/// entry.
///
/// NOTE(review): `IncrementSeq` is executed directly on the core, bypassing
/// `execute_ops` — it is not WAL-journaled and does not mark the state
/// dirty by itself. Presumably the subsequent `Set`/`PushPeerLog` ops
/// (which carry the allocated seq) make this recoverable — confirm.
fn alloc_entry_meta(&mut self) -> Metadata {
    let seq = self.core.next_seq();
    self.core.execute(StateOp::IncrementSeq);
    // Millisecond-resolution wall-clock time stored in the entry metadata.
    let timestamp = Utc::now().timestamp_millis();
    Metadata::new(self.id, seq, timestamp)
}
/// Stores `value` under `key` as a new locally-authored entry and returns
/// the entry that was written.
///
/// One op batch appends the entry to our own peer log (for replication)
/// and applies it to the KV data.
pub fn put(&mut self, key: String, value: impl Into<Vec<u8>>) -> Result<Entry> {
    let value: Vec<u8> = value.into();
    let meta = self.alloc_entry_meta();
    let entry = Entry::new(key.clone(), Some(value), meta);
    let log_op = StateOp::PushPeerLog {
        peer_id: self.id,
        entry: entry.clone(),
        max_entries: self.max_log_entries,
    };
    let set_op = StateOp::Set(entry.clone());
    self.execute_ops(vec![log_op, set_op])?;
    Ok(entry)
}
pub fn get(&self, key: &str) -> Option<Entry> {
self.core
.data()
.get(key)
.cloned()
.filter(|entry| entry.value.is_some())
}
pub fn get_including_tombstones(&self, key: &str) -> Option<Entry> {
self.core.data().get(key).cloned()
}
pub fn get_by_prefix(&self, prefix: &str) -> HashMap<String, Entry> {
self.iter_by_prefix(prefix)
.map(|(k, v)| (k.clone(), v.clone()))
.collect()
}
pub fn iter_by_prefix<'a, 'b>(
&'a self,
prefix: &'b str,
) -> impl Iterator<Item = (&'a String, &'a Entry)> + use<'a, 'b> {
self.core
.data()
.range(prefix.to_string()..)
.take_while(move |(k, _)| k.starts_with(prefix))
.filter(|(_, v)| v.value.is_some())
}
pub fn get_all_including_tombstones(&self) -> HashMap<String, Entry> {
self.iter_all_including_tombstones()
.map(|(k, v)| (k.clone(), v.clone()))
.collect()
}
pub fn iter_all_including_tombstones(&self) -> impl Iterator<Item = (&String, &Entry)> {
self.core.data().iter()
}
/// Deletes `key` by writing a tombstone (an entry with `value: None`).
///
/// Returns whatever was previously stored under `key`, which may itself
/// be an older tombstone.
pub fn delete(&mut self, key: String) -> Result<Option<Entry>> {
    let meta = self.alloc_entry_meta();
    let tombstone = Entry::new(key.clone(), None, meta);
    let previous = self.core.data().get(&key).cloned();
    let log_op = StateOp::PushPeerLog {
        peer_id: self.id,
        entry: tombstone.clone(),
        max_entries: self.max_log_entries,
    };
    let set_op = StateOp::Set(tombstone);
    self.execute_ops(vec![log_op, set_op])?;
    Ok(previous)
}
/// Computes the entries a peer is missing, given the acks it reported.
///
/// Returns `None` when the peer needs a full dump instead: either it
/// reported no progress at all, or *our own* log has already truncated
/// past the seq it asked for. When some other node's log is truncated past
/// the requested seq, that log is merely skipped — NOTE(review): presumably
/// the originating node (or a full dump from it) covers that gap; confirm.
/// Otherwise returns every logged entry whose seq is above the peer's ack
/// for its originating node.
pub fn get_peer_missing_logs(
    &self,
    peer_progress: &HashMap<NodeId, u64>,
) -> Option<Vec<Entry>> {
    if peer_progress.is_empty() {
        debug!("Peer has no progress, returning full dump");
        return None;
    }
    let mut missing_entries = Vec::new();
    for (node_id, peer_state) in self.core.peers().iter() {
        let node_log = &peer_state.log;
        // Unknown nodes count as "nothing acknowledged yet".
        let peer_ack = peer_progress.get(node_id).cloned().unwrap_or(0);
        if let Some(oldest_entry) = node_log.front() {
            // Ack older than our oldest retained entry => truncated log.
            if peer_ack < oldest_entry.meta.seq {
                if node_id == &self.id {
                    debug!(
                        node_id,
                        peer_ack,
                        oldest_log = oldest_entry.meta.seq,
                        "Requested my seq has been truncated, need full dump"
                    );
                    return None;
                } else {
                    debug!(
                        node_id,
                        peer_ack,
                        oldest_log = oldest_entry.meta.seq,
                        "Requested peer seq has been truncated, skipping"
                    );
                    continue;
                }
            }
        }
        for entry in node_log {
            if entry.meta.seq > peer_ack {
                missing_entries.push(entry.clone());
            }
        }
    }
    Some(missing_entries)
}
pub fn kv_to_log_entries(&self) -> Vec<Entry> {
self.core.data().values().cloned().collect()
}
pub fn get_local_ack(&self) -> HashMap<NodeId, u64> {
self.core
.peers()
.iter()
.map(|(id, peer_state)| (*id, peer_state.local_ack))
.collect()
}
pub fn get_peers(&self) -> Vec<NodeId> {
self.core
.peers()
.keys()
.filter(|&&id| id != self.id)
.copied()
.collect()
}
/// Registers a new peer.
///
/// Returns `false` without touching state when the peer is already known.
pub fn add_peer(&mut self, peer_id: NodeId) -> Result<bool> {
    if self.core.peers().contains_key(&peer_id) {
        Ok(false)
    } else {
        self.execute_ops(vec![StateOp::AddPeer { peer_id }])?;
        info!("Added peer node: {}", peer_id);
        Ok(true)
    }
}

/// Unregisters a peer.
///
/// Returns `false` without touching state when the target is ourselves or
/// is not a known peer.
pub fn remove_peer(&mut self, peer_id: NodeId) -> Result<bool> {
    let removable = peer_id != self.id && self.core.peers().contains_key(&peer_id);
    if !removable {
        return Ok(false);
    }
    self.execute_ops(vec![StateOp::RemovePeer { peer_id }])?;
    info!("Removed peer node: {}", peer_id);
    Ok(true)
}
/// Every node id we track, our own included.
pub fn get_all_nodes(&self) -> Vec<NodeId> {
    self.core.peers().keys().copied().collect()
}

/// Clones the tracked replication state for `node_id`, if known.
pub fn get_peer_state(&self, node_id: NodeId) -> Option<PeerState> {
    self.core.peers().get(&node_id).cloned()
}

/// Returns the tail of `node_id`'s log strictly after seq `since`.
///
/// Returns `None` both when the node is unknown and when no logged entry
/// is newer than `since` — callers cannot distinguish the two cases.
pub fn get_peer_logs_since(&self, node_id: NodeId, since: u64) -> Option<Vec<Entry>> {
    let log = &self.core.peers().get(&node_id)?.log;
    // `position` returning `None` (nothing newer) propagates the `None`.
    let start = log.iter().position(|entry| entry.meta.seq > since)?;
    let tail: Vec<Entry> = log.iter().skip(start).cloned().collect();
    Some(tail)
}
pub fn cleanup_expired_tombstones(&mut self, ttl: Duration) -> Result<usize> {
let now = Utc::now().timestamp_millis();
let ttl_ms = ttl.as_millis() as i64;
let expired_keys = self
.core
.data()
.iter()
.flat_map(|(k, item)| {
if item.is_expired_tombstone(ttl_ms, now) {
Some(StateOp::Clear(k.clone()))
} else {
None
}
})
.collect::<Vec<_>>();
let removed = expired_keys.len();
self.execute_ops(expired_keys)?;
Ok(removed)
}
/// Flattens every peer log into one list ordered by
/// `(timestamp, node, seq)`.
pub fn get_log_snapshot(&self) -> Vec<Entry> {
    let mut all_entries: Vec<Entry> = self
        .core
        .peers()
        .values()
        .flat_map(|peer_state| peer_state.log.iter().cloned())
        .collect();
    all_entries.sort_by_key(|e| (e.meta.timestamp, e.meta.node, e.meta.seq));
    all_entries
}

/// Applies `SetNextSeq(min_next_seq)` to the core. NOTE(review): the name
/// suggests the core treats this as a lower bound on the sequence counter —
/// confirm in `CoreState`. Any error from the op is deliberately ignored.
pub fn ensure_next_seq(&mut self, min_next_seq: u64) {
    let _ = self.execute_ops(vec![StateOp::SetNextSeq(min_next_seq)]);
}

/// The next sequence number this node would allocate.
pub fn get_next_seq(&self) -> u64 {
    self.core.next_seq()
}
/// Borrows the full peer-state table.
pub fn get_all_peer_states(&self) -> &HashMap<NodeId, PeerState> {
    self.core.peers()
}

/// Writes a snapshot only when state changed since the last persist.
/// Returns whether a snapshot was actually written.
fn persist_if_dirty(&mut self) -> Result<bool> {
    if self.dirty {
        self.persist_to_disk()?;
        Ok(true)
    } else {
        Ok(false)
    }
}

/// Flags in-memory state as diverged from the on-disk snapshot.
#[inline]
fn mark_dirty(&mut self) {
    self.dirty = true;
}
/// Builds a diagnostic summary of this node and its peers.
pub fn status(&self) -> NodeStatus {
    let peer_table = self.core.peers();
    let mut peers = Vec::with_capacity(peer_table.len());
    for (&peer_id, peer_state) in peer_table {
        peers.push(PeerStatus {
            id: peer_id,
            ack: peer_state.local_ack,
            pack: peer_state.peer_ack,
            logs: peer_state.log.len(),
        });
    }
    NodeStatus {
        id: self.id,
        n_kvs: self.core.data().len(),
        next_seq: self.core.next_seq(),
        dirty: self.dirty,
        wal: self.wal.is_some(),
        peers,
    }
}
}
impl Node {
pub fn new(id: NodeId, peer_ids: Vec<NodeId>) -> Self {
let core = CoreState::new(id, peer_ids);
let state = NodeState {
id,
core,
wal: None, max_log_entries: DEFAULT_MAX_LOG_ENTRIES,
snapshot_path: None,
watchers: Vec::new(),
dirty: false,
};
Self {
state: Arc::new(RwLock::new(state)),
}
}
pub fn new_with_persistence<P: Into<PathBuf>>(
id: NodeId,
peers: Vec<NodeId>,
data_dir: P,
) -> Result<Self> {
let data_dir = data_dir.into();
let wal_path = data_dir.join(format!("node_{id}.wal"));
let snapshot_path = data_dir.join(format!("node_{id}.snapshot"));
if let Some(parent) = wal_path.parent() {
fs_err::create_dir_all(parent)?;
}
let wal = WriteAheadLog::new(&wal_path, id)?;
let existing_ops = wal.read_all_ops()?;
let core = CoreState::new(id, peers);
let mut state = NodeState {
id,
core,
wal: Some(wal),
max_log_entries: DEFAULT_MAX_LOG_ENTRIES,
snapshot_path: Some(snapshot_path.clone()),
watchers: Vec::new(),
dirty: false,
};
let loaded = state.load_snapshot_if_exists()?;
if loaded {
info!("Loaded snapshot from {}", snapshot_path.display());
let status = state.status();
info!("Node status: {:#?}", status);
}
if !existing_ops.is_empty() {
info!(
"Recovering {} state operations from WAL",
existing_ops.len()
);
state.replay_ops(existing_ops)?;
let status = state.status();
info!("Node status after recovery: {:#?}", status);
}
let store = Self {
state: Arc::new(RwLock::new(state)),
};
info!(
"Persistence enabled: WAL={:?}, snapshot={:?}",
wal_path, snapshot_path
);
Ok(store)
}
/// Acquires the state write lock.
///
/// # Panics
/// Panics if the lock is poisoned.
pub fn write(&self) -> RwLockWriteGuard<NodeState> {
    self.state.write().expect("Failed to lock store state")
}

/// Acquires the state read lock.
///
/// # Panics
/// Panics if the lock is poisoned.
pub fn read(&self) -> RwLockReadGuard<NodeState> {
    self.state.read().expect("Failed to lock store state")
}

/// Forces a snapshot persist regardless of the dirty flag.
pub fn persist(&self) -> Result<()> {
    let mut state = self.write();
    state.persist_to_disk()
}

/// Persists only when state has changed; returns whether a snapshot was
/// written.
pub fn persist_if_dirty(&self) -> Result<bool> {
    let mut state = self.write();
    state.persist_if_dirty()
}

/// Subscribes to changes of exactly `key`.
pub fn watch(&self, key: &str) -> watch::Receiver<()> {
    let mut state = self.write();
    state.watch_key(key)
}

/// Subscribes to changes of any key starting with `prefix`.
pub fn watch_prefix(&self, prefix: &str) -> watch::Receiver<()> {
    let mut state = self.write();
    state.watch_prefix(prefix)
}
}