use nklave_core::state::integrity::StateIntegrity;
use nklave_core::state::validator::ValidatorState;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufReader, BufWriter};
use std::path::Path;
use thiserror::Error;
fn serialize_hash<S>(bytes: &[u8; 32], serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(&hex::encode(bytes))
}
fn deserialize_hash<'de, D>(deserializer: D) -> Result<[u8; 32], D::Error>
where
D: Deserializer<'de>,
{
let s: String = Deserialize::deserialize(deserializer)?;
let s = s.strip_prefix("0x").unwrap_or(&s);
let bytes = hex::decode(s).map_err(serde::de::Error::custom)?;
let mut arr = [0u8; 32];
arr.copy_from_slice(&bytes);
Ok(arr)
}
/// Serde `serialize_with` helper for an optional 32-byte hash.
///
/// `Some(hash)` is written as a hex string via `hex::encode`; `None` is written as the
/// serializer's "none" value.
fn serialize_option_hash<S: Serializer>(
    bytes: &Option<[u8; 32]>,
    serializer: S,
) -> Result<S::Ok, S::Error> {
    let encoded: Option<String> = bytes.as_ref().map(hex::encode);
    if let Some(hex_str) = encoded {
        serializer.serialize_some(&hex_str)
    } else {
        serializer.serialize_none()
    }
}
fn deserialize_option_hash<'de, D>(deserializer: D) -> Result<Option<[u8; 32]>, D::Error>
where
D: Deserializer<'de>,
{
let opt: Option<String> = Deserialize::deserialize(deserializer)?;
match opt {
Some(s) => {
let s = s.strip_prefix("0x").unwrap_or(&s);
let bytes = hex::decode(s).map_err(serde::de::Error::custom)?;
let mut arr = [0u8; 32];
arr.copy_from_slice(&bytes);
Ok(Some(arr))
}
None => Ok(None),
}
}
/// Serde `serialize_with` helper: writes the validator map as a map keyed by the
/// hex encoding (`hex::encode`) of each 48-byte key.
fn serialize_validators<S: Serializer>(
    validators: &HashMap<[u8; 48], ValidatorState>,
    serializer: S,
) -> Result<S::Ok, S::Error> {
    use serde::ser::SerializeMap;
    let mut out = serializer.serialize_map(Some(validators.len()))?;
    validators
        .iter()
        .try_for_each(|(key, state)| out.serialize_entry(&hex::encode(key), state))?;
    out.end()
}
fn deserialize_validators<'de, D>(
deserializer: D,
) -> Result<HashMap<[u8; 48], ValidatorState>, D::Error>
where
D: Deserializer<'de>,
{
let string_map: HashMap<String, ValidatorState> = Deserialize::deserialize(deserializer)?;
let mut result = HashMap::new();
for (k, v) in string_map {
let k = k.strip_prefix("0x").unwrap_or(&k);
let bytes = hex::decode(k).map_err(serde::de::Error::custom)?;
let mut arr = [0u8; 48];
arr.copy_from_slice(&bytes);
result.insert(arr, v);
}
Ok(result)
}
/// Snapshot of the state-integrity chain head plus per-validator state.
///
/// It is persisted as JSON by `Checkpoint::save` / `Checkpoint::save_atomic` and read
/// back by `Checkpoint::load`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Checkpoint {
/// Sequence number copied from `StateIntegrity::sequence_number` when the checkpoint was built.
pub sequence: u64,
/// Integrity hash copied from `StateIntegrity::current_hash`, stored as a hex string
/// (an optional `0x` prefix is accepted on load).
#[serde(serialize_with = "serialize_hash", deserialize_with = "deserialize_hash")]
pub state_hash: [u8; 32],
/// Optional genesis validators root, stored as a hex string when present.
#[serde(serialize_with = "serialize_option_hash", deserialize_with = "deserialize_option_hash")]
pub genesis_validators_root: Option<[u8; 32]>,
/// Per-validator state keyed by a 48-byte key (presumably a BLS public key — confirm),
/// stored as a map with hex-string keys.
#[serde(serialize_with = "serialize_validators", deserialize_with = "deserialize_validators")]
pub validators: HashMap<[u8; 48], ValidatorState>,
/// Creation time in whole seconds since the Unix epoch, set by `Checkpoint::new`.
pub timestamp: u64,
}
impl Checkpoint {
    /// Captures the current integrity chain head and validator states as a checkpoint.
    ///
    /// `sequence`, `state_hash`, and `genesis_validators_root` are copied from
    /// `integrity`. `timestamp` is the current wall-clock time in whole seconds since the
    /// Unix epoch; it is 0 if the system clock reads earlier than the epoch.
    pub fn new(
        integrity: &StateIntegrity,
        validators: HashMap<[u8; 48], ValidatorState>,
    ) -> Self {
        Self {
            sequence: integrity.sequence_number,
            state_hash: integrity.current_hash,
            genesis_validators_root: integrity.genesis_validators_root,
            validators,
            timestamp: std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default()
                .as_secs(),
        }
    }

    /// Writes the checkpoint as pretty-printed JSON to `path`, truncating any existing file.
    ///
    /// This is not crash-safe: an interruption midway can leave a partially written file.
    /// Use [`Checkpoint::save_atomic`] for durable persistence.
    ///
    /// # Errors
    /// - `CheckpointError::Io` if the file cannot be created or buffered data cannot be
    ///   flushed.
    /// - `CheckpointError::Serialize` if JSON encoding or writing fails.
    pub fn save(&self, path: impl AsRef<Path>) -> Result<(), CheckpointError> {
        let file = File::create(path).map_err(|e| CheckpointError::Io(e.to_string()))?;
        let mut writer = BufWriter::new(file);
        serde_json::to_writer_pretty(&mut writer, self)
            .map_err(|e| CheckpointError::Serialize(e.to_string()))?;
        // Flush explicitly: dropping a BufWriter silently discards any final write error.
        writer
            .into_inner()
            .map_err(|e| CheckpointError::Io(e.to_string()))?;
        Ok(())
    }

    /// Durably replaces the checkpoint at `path`, keeping up to `backup_count` older copies.
    ///
    /// Steps:
    /// 1. Write the new contents to a temp file (`path` with its extension replaced by
    ///    `json.tmp`) and fsync it.
    /// 2. If a checkpoint already exists at `path`, rotate backups (see `rotate_backups`).
    /// 3. Rename the temp file onto `path`.
    /// 4. Fsync the parent directory, best-effort.
    ///
    /// Backup rotation moves the existing primary to backup 1. Until the rename completes,
    /// `path` therefore does not exist briefly; `load_with_recovery` falls back to the
    /// backups if a crash lands in that window.
    ///
    /// # Errors
    /// - `CheckpointError::Io` for directory creation, file creation/flush/sync, backup
    ///   rotation, or rename failures.
    /// - `CheckpointError::Serialize` if JSON encoding fails.
    ///
    /// If writing the temp file fails, the temp file is removed.
    pub fn save_atomic(&self, path: impl AsRef<Path>, backup_count: u32) -> Result<(), CheckpointError> {
        let path = path.as_ref();
        let parent = path.parent().ok_or_else(|| {
            CheckpointError::Io("Checkpoint path has no parent directory".to_string())
        })?;
        // `Path::parent` yields "" for a bare file name. Use "." so the directory fsync
        // below targets the real (current) directory instead of silently failing.
        let parent = if parent.as_os_str().is_empty() {
            Path::new(".")
        } else {
            parent
        };
        std::fs::create_dir_all(parent)
            .map_err(|e| CheckpointError::Io(format!("Failed to create directory: {}", e)))?;

        let temp_path = path.with_extension("json.tmp");
        if let Err(e) = self.write_synced_temp(&temp_path) {
            // A partially written temp file is useless; remove it (best effort).
            let _ = std::fs::remove_file(&temp_path);
            return Err(e);
        }

        if path.exists() {
            Self::rotate_backups(path, backup_count)?;
        }
        std::fs::rename(&temp_path, path)
            .map_err(|e| CheckpointError::Io(format!("Failed to rename temp to checkpoint: {}", e)))?;

        // Persist the rename itself. Opening a directory as a file is not supported on
        // every platform (e.g. Windows), so this step is best-effort.
        if let Ok(dir) = File::open(parent) {
            let _ = dir.sync_all();
        }
        Ok(())
    }

    /// Writes `self` as pretty-printed JSON to a fresh file at `temp_path`, flushes it, and
    /// fsyncs it so the data is on disk before it is renamed into place.
    fn write_synced_temp(&self, temp_path: &Path) -> Result<(), CheckpointError> {
        let file = File::create(temp_path)
            .map_err(|e| CheckpointError::Io(format!("Failed to create temp file: {}", e)))?;
        let mut writer = BufWriter::new(file);
        serde_json::to_writer_pretty(&mut writer, self)
            .map_err(|e| CheckpointError::Serialize(e.to_string()))?;
        writer
            .into_inner()
            .map_err(|e| CheckpointError::Io(format!("Failed to flush temp file: {}", e)))?
            .sync_all()
            .map_err(|e| CheckpointError::Io(format!("Failed to sync temp file: {}", e)))
    }

    /// Shifts numbered backups up by one and moves the current checkpoint to backup 1.
    ///
    /// Backup `i` lives at `path` with its extension replaced by `json.{i}`. Rotation works
    /// as follows:
    /// - The backup numbered `backup_count` is deleted first, so at most `backup_count`
    ///   backups are kept.
    /// - Each remaining backup `i` is renamed to `i + 1`.
    /// - The primary at `path` is renamed to backup 1.
    ///
    /// A `backup_count` of 0 disables backups entirely.
    fn rotate_backups(path: &Path, backup_count: u32) -> Result<(), CheckpointError> {
        if backup_count == 0 {
            return Ok(());
        }
        let backup_path = |i: u32| path.with_extension(format!("json.{}", i));

        let oldest = backup_path(backup_count);
        if oldest.exists() {
            std::fs::remove_file(&oldest)
                .map_err(|e| CheckpointError::Io(format!("Failed to remove oldest backup: {}", e)))?;
        }
        // Walk from the highest index down so no rename clobbers a backup not yet moved.
        for i in (1..backup_count).rev() {
            let from = backup_path(i);
            if from.exists() {
                std::fs::rename(&from, backup_path(i + 1))
                    .map_err(|e| CheckpointError::Io(format!("Failed to rotate backup: {}", e)))?;
            }
        }
        std::fs::rename(path, backup_path(1))
            .map_err(|e| CheckpointError::Io(format!("Failed to create backup: {}", e)))?;
        Ok(())
    }

    /// Reads and parses a checkpoint from `path`.
    ///
    /// # Errors
    /// - `CheckpointError::Io` if the file cannot be opened.
    /// - `CheckpointError::Parse` if reading or decoding the JSON fails.
    pub fn load(path: impl AsRef<Path>) -> Result<Self, CheckpointError> {
        let file = File::open(path).map_err(|e| CheckpointError::Io(e.to_string()))?;
        let reader = BufReader::new(file);
        serde_json::from_reader(reader).map_err(|e| CheckpointError::Parse(e.to_string()))
    }

    /// Loads the checkpoint at `path`, falling back to backups `1..=backup_count` in order.
    ///
    /// Backups hold earlier versions written by `save_atomic`, so a recovered checkpoint
    /// reflects an older state than the primary did. A warning is logged on recovery,
    /// including the reason the primary could not be loaded.
    ///
    /// # Errors
    /// `CheckpointError::Io` if neither the primary nor any backup can be loaded. The
    /// message includes the primary's error.
    pub fn load_with_recovery(path: impl AsRef<Path>, backup_count: u32) -> Result<Self, CheckpointError> {
        let path = path.as_ref();
        let primary_err = match Self::load(path) {
            Ok(checkpoint) => return Ok(checkpoint),
            Err(e) => e,
        };
        for i in 1..=backup_count {
            let backup_path = path.with_extension(format!("json.{}", i));
            if let Ok(checkpoint) = Self::load(&backup_path) {
                tracing::warn!(
                    backup = i,
                    primary_error = %primary_err,
                    "Recovered checkpoint from backup"
                );
                return Ok(checkpoint);
            }
        }
        Err(CheckpointError::Io(format!(
            "No valid checkpoint found (tried primary and all backups); primary error: {}",
            primary_err
        )))
    }

    /// Rebuilds a `StateIntegrity` from this checkpoint's hash, sequence number, and
    /// genesis validators root via `StateIntegrity::from_checkpoint`.
    ///
    /// Validator states are not included.
    pub fn restore_integrity(&self) -> StateIntegrity {
        StateIntegrity::from_checkpoint(
            self.state_hash,
            self.sequence,
            self.genesis_validators_root,
        )
    }
}
/// Errors produced while saving or loading a [`Checkpoint`].
#[derive(Debug, Error)]
pub enum CheckpointError {
/// A filesystem operation failed (create, open, flush, sync, rename, remove), or
/// `load_with_recovery` found no loadable checkpoint.
#[error("I/O error: {0}")]
Io(String),
/// `serde_json::to_writer_pretty` failed while writing a checkpoint.
#[error("Serialization error: {0}")]
Serialize(String),
/// `serde_json::from_reader` failed while reading or decoding a checkpoint file.
#[error("Parse error: {0}")]
Parse(String),
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Round-trips a checkpoint through `save`/`load` and checks every field.
    #[test]
    fn test_checkpoint_save_load() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("checkpoint.json");
        let integrity = StateIntegrity::new();
        let validators = HashMap::new();
        let checkpoint = Checkpoint::new(&integrity, validators);
        checkpoint.save(&path).unwrap();
        let loaded = Checkpoint::load(&path).unwrap();
        assert_eq!(loaded.sequence, checkpoint.sequence);
        assert_eq!(loaded.state_hash, checkpoint.state_hash);
        assert_eq!(loaded.genesis_validators_root, checkpoint.genesis_validators_root);
        assert_eq!(loaded.timestamp, checkpoint.timestamp);
        assert!(loaded.validators.is_empty());
    }

    /// Repeated `save_atomic` calls shift backups and never keep more than `backup_count`.
    #[test]
    fn test_save_atomic_rotates_backups() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("checkpoint.json");
        let checkpoint = Checkpoint::new(&StateIntegrity::new(), HashMap::new());

        checkpoint.save_atomic(&path, 2).unwrap();
        assert!(path.exists());
        assert!(!path.with_extension("json.1").exists());

        checkpoint.save_atomic(&path, 2).unwrap();
        assert!(path.with_extension("json.1").exists());
        assert!(!path.with_extension("json.2").exists());

        checkpoint.save_atomic(&path, 2).unwrap();
        checkpoint.save_atomic(&path, 2).unwrap();
        assert!(path.exists());
        assert!(path.with_extension("json.1").exists());
        assert!(path.with_extension("json.2").exists());
        assert!(!path.with_extension("json.3").exists());
        assert!(!path.with_extension("json.tmp").exists());
    }

    /// A corrupt primary is skipped in favour of the newest readable backup.
    #[test]
    fn test_load_with_recovery_falls_back_to_backup() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("checkpoint.json");
        let checkpoint = Checkpoint::new(&StateIntegrity::new(), HashMap::new());
        checkpoint.save_atomic(&path, 2).unwrap();
        checkpoint.save_atomic(&path, 2).unwrap();

        std::fs::write(&path, b"not json").unwrap();
        assert!(Checkpoint::load(&path).is_err());

        let recovered = Checkpoint::load_with_recovery(&path, 2).unwrap();
        assert_eq!(recovered.sequence, checkpoint.sequence);
        assert_eq!(recovered.state_hash, checkpoint.state_hash);
    }

    /// With no primary and no backups, recovery reports an error instead of panicking.
    #[test]
    fn test_load_with_recovery_errors_when_nothing_exists() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("missing.json");
        assert!(Checkpoint::load_with_recovery(&path, 3).is_err());
    }

    /// Hex fields written by hand with a `0x` prefix are accepted on load.
    #[test]
    fn test_load_accepts_0x_prefixed_hashes() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("checkpoint.json");
        let json = format!(
            r#"{{"sequence":7,"state_hash":"0x{}","genesis_validators_root":"0x{}","validators":{{}},"timestamp":42}}"#,
            "ab".repeat(32),
            "cd".repeat(32)
        );
        std::fs::write(&path, json).unwrap();

        let loaded = Checkpoint::load(&path).unwrap();
        assert_eq!(loaded.sequence, 7);
        assert_eq!(loaded.state_hash, [0xab; 32]);
        assert_eq!(loaded.genesis_validators_root, Some([0xcd; 32]));
        assert!(loaded.validators.is_empty());
        assert_eq!(loaded.timestamp, 42);
    }
}