1use nklave_core::state::integrity::StateIntegrity;
6use nklave_core::state::validator::ValidatorState;
7use serde::{Deserialize, Deserializer, Serialize, Serializer};
8use std::collections::HashMap;
9use std::fs::File;
10use std::io::{BufReader, BufWriter};
11use std::path::Path;
12use thiserror::Error;
13
14fn serialize_hash<S>(bytes: &[u8; 32], serializer: S) -> Result<S::Ok, S::Error>
16where
17 S: Serializer,
18{
19 serializer.serialize_str(&hex::encode(bytes))
20}
21
22fn deserialize_hash<'de, D>(deserializer: D) -> Result<[u8; 32], D::Error>
24where
25 D: Deserializer<'de>,
26{
27 let s: String = Deserialize::deserialize(deserializer)?;
28 let s = s.strip_prefix("0x").unwrap_or(&s);
29 let bytes = hex::decode(s).map_err(serde::de::Error::custom)?;
30 let mut arr = [0u8; 32];
31 arr.copy_from_slice(&bytes);
32 Ok(arr)
33}
34
35fn serialize_option_hash<S>(bytes: &Option<[u8; 32]>, serializer: S) -> Result<S::Ok, S::Error>
37where
38 S: Serializer,
39{
40 match bytes {
41 Some(b) => serializer.serialize_some(&hex::encode(b)),
42 None => serializer.serialize_none(),
43 }
44}
45
46fn deserialize_option_hash<'de, D>(deserializer: D) -> Result<Option<[u8; 32]>, D::Error>
48where
49 D: Deserializer<'de>,
50{
51 let opt: Option<String> = Deserialize::deserialize(deserializer)?;
52 match opt {
53 Some(s) => {
54 let s = s.strip_prefix("0x").unwrap_or(&s);
55 let bytes = hex::decode(s).map_err(serde::de::Error::custom)?;
56 let mut arr = [0u8; 32];
57 arr.copy_from_slice(&bytes);
58 Ok(Some(arr))
59 }
60 None => Ok(None),
61 }
62}
63
64fn serialize_validators<S>(
66 validators: &HashMap<[u8; 48], ValidatorState>,
67 serializer: S,
68) -> Result<S::Ok, S::Error>
69where
70 S: Serializer,
71{
72 use serde::ser::SerializeMap;
73 let mut map = serializer.serialize_map(Some(validators.len()))?;
74 for (k, v) in validators {
75 map.serialize_entry(&hex::encode(k), v)?;
76 }
77 map.end()
78}
79
80fn deserialize_validators<'de, D>(
82 deserializer: D,
83) -> Result<HashMap<[u8; 48], ValidatorState>, D::Error>
84where
85 D: Deserializer<'de>,
86{
87 let string_map: HashMap<String, ValidatorState> = Deserialize::deserialize(deserializer)?;
88 let mut result = HashMap::new();
89 for (k, v) in string_map {
90 let k = k.strip_prefix("0x").unwrap_or(&k);
91 let bytes = hex::decode(k).map_err(serde::de::Error::custom)?;
92 let mut arr = [0u8; 48];
93 arr.copy_from_slice(&bytes);
94 result.insert(arr, v);
95 }
96 Ok(result)
97}
98
99#[derive(Debug, Clone, Serialize, Deserialize)]
101pub struct Checkpoint {
102 pub sequence: u64,
104
105 #[serde(serialize_with = "serialize_hash", deserialize_with = "deserialize_hash")]
107 pub state_hash: [u8; 32],
108
109 #[serde(serialize_with = "serialize_option_hash", deserialize_with = "deserialize_option_hash")]
111 pub genesis_validators_root: Option<[u8; 32]>,
112
113 #[serde(serialize_with = "serialize_validators", deserialize_with = "deserialize_validators")]
115 pub validators: HashMap<[u8; 48], ValidatorState>,
116
117 pub timestamp: u64,
119}
120
121impl Checkpoint {
122 pub fn new(
124 integrity: &StateIntegrity,
125 validators: HashMap<[u8; 48], ValidatorState>,
126 ) -> Self {
127 Self {
128 sequence: integrity.sequence_number,
129 state_hash: integrity.current_hash,
130 genesis_validators_root: integrity.genesis_validators_root,
131 validators,
132 timestamp: std::time::SystemTime::now()
133 .duration_since(std::time::UNIX_EPOCH)
134 .unwrap_or_default()
135 .as_secs(),
136 }
137 }
138
139 pub fn save(&self, path: impl AsRef<Path>) -> Result<(), CheckpointError> {
141 let file = File::create(path).map_err(|e| CheckpointError::Io(e.to_string()))?;
142 let writer = BufWriter::new(file);
143
144 serde_json::to_writer_pretty(writer, self)
145 .map_err(|e| CheckpointError::Serialize(e.to_string()))?;
146
147 Ok(())
148 }
149
150 pub fn save_atomic(&self, path: impl AsRef<Path>, backup_count: u32) -> Result<(), CheckpointError> {
158 let path = path.as_ref();
159 let parent = path.parent().ok_or_else(|| {
160 CheckpointError::Io("Checkpoint path has no parent directory".to_string())
161 })?;
162
163 std::fs::create_dir_all(parent)
165 .map_err(|e| CheckpointError::Io(format!("Failed to create directory: {}", e)))?;
166
167 let temp_path = path.with_extension("json.tmp");
169 {
170 let file = File::create(&temp_path)
171 .map_err(|e| CheckpointError::Io(format!("Failed to create temp file: {}", e)))?;
172 let mut writer = BufWriter::new(file);
173 serde_json::to_writer_pretty(&mut writer, self)
174 .map_err(|e| CheckpointError::Serialize(e.to_string()))?;
175 writer.into_inner()
176 .map_err(|e| CheckpointError::Io(format!("Failed to flush temp file: {}", e)))?
177 .sync_all()
178 .map_err(|e| CheckpointError::Io(format!("Failed to sync temp file: {}", e)))?;
179 }
180
181 if path.exists() {
183 Self::rotate_backups(path, backup_count)?;
184 }
185
186 std::fs::rename(&temp_path, path)
188 .map_err(|e| CheckpointError::Io(format!("Failed to rename temp to checkpoint: {}", e)))?;
189
190 if let Ok(dir) = File::open(parent) {
192 let _ = dir.sync_all();
193 }
194
195 Ok(())
196 }
197
198 fn rotate_backups(path: &Path, backup_count: u32) -> Result<(), CheckpointError> {
203 if backup_count == 0 {
204 return Ok(());
205 }
206
207 let oldest = path.with_extension(format!("json.{}", backup_count));
209 if oldest.exists() {
210 std::fs::remove_file(&oldest)
211 .map_err(|e| CheckpointError::Io(format!("Failed to remove oldest backup: {}", e)))?;
212 }
213
214 for i in (1..backup_count).rev() {
216 let from = path.with_extension(format!("json.{}", i));
217 let to = path.with_extension(format!("json.{}", i + 1));
218 if from.exists() {
219 std::fs::rename(&from, &to)
220 .map_err(|e| CheckpointError::Io(format!("Failed to rotate backup: {}", e)))?;
221 }
222 }
223
224 let backup1 = path.with_extension("json.1");
226 std::fs::rename(path, &backup1)
227 .map_err(|e| CheckpointError::Io(format!("Failed to create backup: {}", e)))?;
228
229 Ok(())
230 }
231
232 pub fn load(path: impl AsRef<Path>) -> Result<Self, CheckpointError> {
234 let file = File::open(path).map_err(|e| CheckpointError::Io(e.to_string()))?;
235 let reader = BufReader::new(file);
236
237 serde_json::from_reader(reader).map_err(|e| CheckpointError::Parse(e.to_string()))
238 }
239
240 pub fn load_with_recovery(path: impl AsRef<Path>, backup_count: u32) -> Result<Self, CheckpointError> {
244 let path = path.as_ref();
245
246 if let Ok(checkpoint) = Self::load(path) {
248 return Ok(checkpoint);
249 }
250
251 for i in 1..=backup_count {
253 let backup_path = path.with_extension(format!("json.{}", i));
254 if let Ok(checkpoint) = Self::load(&backup_path) {
255 tracing::warn!(
256 backup = i,
257 "Recovered checkpoint from backup"
258 );
259 return Ok(checkpoint);
260 }
261 }
262
263 Err(CheckpointError::Io("No valid checkpoint found (tried primary and all backups)".to_string()))
264 }
265
266 pub fn restore_integrity(&self) -> StateIntegrity {
268 StateIntegrity::from_checkpoint(
269 self.state_hash,
270 self.sequence,
271 self.genesis_validators_root,
272 )
273 }
274}
275
276#[derive(Debug, Error)]
278pub enum CheckpointError {
279 #[error("I/O error: {0}")]
280 Io(String),
281
282 #[error("Serialization error: {0}")]
283 Serialize(String),
284
285 #[error("Parse error: {0}")]
286 Parse(String),
287}
288
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// A checkpoint written with `save` reads back through `load` with the
    /// same sequence number and state hash.
    #[test]
    fn test_checkpoint_save_load() {
        let tmp = TempDir::new().unwrap();
        let checkpoint_path = tmp.path().join("checkpoint.json");

        let original = Checkpoint::new(&StateIntegrity::new(), HashMap::new());
        original.save(&checkpoint_path).unwrap();

        let restored = Checkpoint::load(&checkpoint_path).unwrap();
        assert_eq!(restored.sequence, original.sequence);
        assert_eq!(restored.state_hash, original.state_hash);
    }
}