Skip to main content

nklave_storage/
checkpoint.rs

1//! State checkpoints for fast recovery
2//!
3//! Checkpoints allow the enclave to skip replaying the entire log on startup
4
5use nklave_core::state::integrity::StateIntegrity;
6use nklave_core::state::validator::ValidatorState;
7use serde::{Deserialize, Deserializer, Serialize, Serializer};
8use std::collections::HashMap;
9use std::fs::File;
10use std::io::{BufReader, BufWriter};
11use std::path::Path;
12use thiserror::Error;
13
14/// Serialize [u8; 32] as hex
15fn serialize_hash<S>(bytes: &[u8; 32], serializer: S) -> Result<S::Ok, S::Error>
16where
17    S: Serializer,
18{
19    serializer.serialize_str(&hex::encode(bytes))
20}
21
22/// Deserialize [u8; 32] from hex
23fn deserialize_hash<'de, D>(deserializer: D) -> Result<[u8; 32], D::Error>
24where
25    D: Deserializer<'de>,
26{
27    let s: String = Deserialize::deserialize(deserializer)?;
28    let s = s.strip_prefix("0x").unwrap_or(&s);
29    let bytes = hex::decode(s).map_err(serde::de::Error::custom)?;
30    let mut arr = [0u8; 32];
31    arr.copy_from_slice(&bytes);
32    Ok(arr)
33}
34
35/// Serialize Option<[u8; 32]> as hex
36fn serialize_option_hash<S>(bytes: &Option<[u8; 32]>, serializer: S) -> Result<S::Ok, S::Error>
37where
38    S: Serializer,
39{
40    match bytes {
41        Some(b) => serializer.serialize_some(&hex::encode(b)),
42        None => serializer.serialize_none(),
43    }
44}
45
46/// Deserialize Option<[u8; 32]> from hex
47fn deserialize_option_hash<'de, D>(deserializer: D) -> Result<Option<[u8; 32]>, D::Error>
48where
49    D: Deserializer<'de>,
50{
51    let opt: Option<String> = Deserialize::deserialize(deserializer)?;
52    match opt {
53        Some(s) => {
54            let s = s.strip_prefix("0x").unwrap_or(&s);
55            let bytes = hex::decode(s).map_err(serde::de::Error::custom)?;
56            let mut arr = [0u8; 32];
57            arr.copy_from_slice(&bytes);
58            Ok(Some(arr))
59        }
60        None => Ok(None),
61    }
62}
63
64/// Serialize HashMap<[u8; 48], ValidatorState> with hex keys
65fn serialize_validators<S>(
66    validators: &HashMap<[u8; 48], ValidatorState>,
67    serializer: S,
68) -> Result<S::Ok, S::Error>
69where
70    S: Serializer,
71{
72    use serde::ser::SerializeMap;
73    let mut map = serializer.serialize_map(Some(validators.len()))?;
74    for (k, v) in validators {
75        map.serialize_entry(&hex::encode(k), v)?;
76    }
77    map.end()
78}
79
80/// Deserialize HashMap<[u8; 48], ValidatorState> with hex keys
81fn deserialize_validators<'de, D>(
82    deserializer: D,
83) -> Result<HashMap<[u8; 48], ValidatorState>, D::Error>
84where
85    D: Deserializer<'de>,
86{
87    let string_map: HashMap<String, ValidatorState> = Deserialize::deserialize(deserializer)?;
88    let mut result = HashMap::new();
89    for (k, v) in string_map {
90        let k = k.strip_prefix("0x").unwrap_or(&k);
91        let bytes = hex::decode(k).map_err(serde::de::Error::custom)?;
92        let mut arr = [0u8; 48];
93        arr.copy_from_slice(&bytes);
94        result.insert(arr, v);
95    }
96    Ok(result)
97}
98
99/// A checkpoint of the enclave state
100#[derive(Debug, Clone, Serialize, Deserialize)]
101pub struct Checkpoint {
102    /// Sequence number at this checkpoint
103    pub sequence: u64,
104
105    /// State hash at this checkpoint
106    #[serde(serialize_with = "serialize_hash", deserialize_with = "deserialize_hash")]
107    pub state_hash: [u8; 32],
108
109    /// Genesis validators root (if set)
110    #[serde(serialize_with = "serialize_option_hash", deserialize_with = "deserialize_option_hash")]
111    pub genesis_validators_root: Option<[u8; 32]>,
112
113    /// All validator states
114    #[serde(serialize_with = "serialize_validators", deserialize_with = "deserialize_validators")]
115    pub validators: HashMap<[u8; 48], ValidatorState>,
116
117    /// Unix timestamp when checkpoint was created
118    pub timestamp: u64,
119}
120
121impl Checkpoint {
122    /// Create a new checkpoint from current state
123    pub fn new(
124        integrity: &StateIntegrity,
125        validators: HashMap<[u8; 48], ValidatorState>,
126    ) -> Self {
127        Self {
128            sequence: integrity.sequence_number,
129            state_hash: integrity.current_hash,
130            genesis_validators_root: integrity.genesis_validators_root,
131            validators,
132            timestamp: std::time::SystemTime::now()
133                .duration_since(std::time::UNIX_EPOCH)
134                .unwrap_or_default()
135                .as_secs(),
136        }
137    }
138
139    /// Save checkpoint to a file (simple, non-atomic)
140    pub fn save(&self, path: impl AsRef<Path>) -> Result<(), CheckpointError> {
141        let file = File::create(path).map_err(|e| CheckpointError::Io(e.to_string()))?;
142        let writer = BufWriter::new(file);
143
144        serde_json::to_writer_pretty(writer, self)
145            .map_err(|e| CheckpointError::Serialize(e.to_string()))?;
146
147        Ok(())
148    }
149
150    /// Save checkpoint atomically with backup rotation
151    ///
152    /// This method:
153    /// 1. Writes to a temporary file
154    /// 2. Rotates existing checkpoint to backup
155    /// 3. Atomically renames temp to target
156    /// 4. Syncs the parent directory for durability
157    pub fn save_atomic(&self, path: impl AsRef<Path>, backup_count: u32) -> Result<(), CheckpointError> {
158        let path = path.as_ref();
159        let parent = path.parent().ok_or_else(|| {
160            CheckpointError::Io("Checkpoint path has no parent directory".to_string())
161        })?;
162
163        // Ensure parent directory exists
164        std::fs::create_dir_all(parent)
165            .map_err(|e| CheckpointError::Io(format!("Failed to create directory: {}", e)))?;
166
167        // Write to temporary file
168        let temp_path = path.with_extension("json.tmp");
169        {
170            let file = File::create(&temp_path)
171                .map_err(|e| CheckpointError::Io(format!("Failed to create temp file: {}", e)))?;
172            let mut writer = BufWriter::new(file);
173            serde_json::to_writer_pretty(&mut writer, self)
174                .map_err(|e| CheckpointError::Serialize(e.to_string()))?;
175            writer.into_inner()
176                .map_err(|e| CheckpointError::Io(format!("Failed to flush temp file: {}", e)))?
177                .sync_all()
178                .map_err(|e| CheckpointError::Io(format!("Failed to sync temp file: {}", e)))?;
179        }
180
181        // Rotate backups if the main checkpoint exists
182        if path.exists() {
183            Self::rotate_backups(path, backup_count)?;
184        }
185
186        // Atomic rename
187        std::fs::rename(&temp_path, path)
188            .map_err(|e| CheckpointError::Io(format!("Failed to rename temp to checkpoint: {}", e)))?;
189
190        // Sync parent directory for durability
191        if let Ok(dir) = File::open(parent) {
192            let _ = dir.sync_all();
193        }
194
195        Ok(())
196    }
197
198    /// Rotate backup files
199    ///
200    /// Shifts checkpoint.json.1 -> checkpoint.json.2, etc.
201    /// Then renames current checkpoint to checkpoint.json.1
202    fn rotate_backups(path: &Path, backup_count: u32) -> Result<(), CheckpointError> {
203        if backup_count == 0 {
204            return Ok(());
205        }
206
207        // Remove oldest backup if it exists
208        let oldest = path.with_extension(format!("json.{}", backup_count));
209        if oldest.exists() {
210            std::fs::remove_file(&oldest)
211                .map_err(|e| CheckpointError::Io(format!("Failed to remove oldest backup: {}", e)))?;
212        }
213
214        // Shift existing backups
215        for i in (1..backup_count).rev() {
216            let from = path.with_extension(format!("json.{}", i));
217            let to = path.with_extension(format!("json.{}", i + 1));
218            if from.exists() {
219                std::fs::rename(&from, &to)
220                    .map_err(|e| CheckpointError::Io(format!("Failed to rotate backup: {}", e)))?;
221            }
222        }
223
224        // Move current checkpoint to .1
225        let backup1 = path.with_extension("json.1");
226        std::fs::rename(path, &backup1)
227            .map_err(|e| CheckpointError::Io(format!("Failed to create backup: {}", e)))?;
228
229        Ok(())
230    }
231
232    /// Load checkpoint from a file
233    pub fn load(path: impl AsRef<Path>) -> Result<Self, CheckpointError> {
234        let file = File::open(path).map_err(|e| CheckpointError::Io(e.to_string()))?;
235        let reader = BufReader::new(file);
236
237        serde_json::from_reader(reader).map_err(|e| CheckpointError::Parse(e.to_string()))
238    }
239
240    /// Load checkpoint with fallback to backups
241    ///
242    /// Tries to load from the primary checkpoint, then falls back to backups
243    pub fn load_with_recovery(path: impl AsRef<Path>, backup_count: u32) -> Result<Self, CheckpointError> {
244        let path = path.as_ref();
245
246        // Try primary checkpoint
247        if let Ok(checkpoint) = Self::load(path) {
248            return Ok(checkpoint);
249        }
250
251        // Try backups in order
252        for i in 1..=backup_count {
253            let backup_path = path.with_extension(format!("json.{}", i));
254            if let Ok(checkpoint) = Self::load(&backup_path) {
255                tracing::warn!(
256                    backup = i,
257                    "Recovered checkpoint from backup"
258                );
259                return Ok(checkpoint);
260            }
261        }
262
263        Err(CheckpointError::Io("No valid checkpoint found (tried primary and all backups)".to_string()))
264    }
265
266    /// Restore state integrity from this checkpoint
267    pub fn restore_integrity(&self) -> StateIntegrity {
268        StateIntegrity::from_checkpoint(
269            self.state_hash,
270            self.sequence,
271            self.genesis_validators_root,
272        )
273    }
274}
275
276/// Errors related to checkpoints
277#[derive(Debug, Error)]
278pub enum CheckpointError {
279    #[error("I/O error: {0}")]
280    Io(String),
281
282    #[error("Serialization error: {0}")]
283    Serialize(String),
284
285    #[error("Parse error: {0}")]
286    Parse(String),
287}
288
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// A checkpoint with no validators survives a `save` / `load` round trip
    /// with its sequence number and state hash intact.
    #[test]
    fn test_checkpoint_save_load() {
        let tmp = TempDir::new().unwrap();
        let file_path = tmp.path().join("checkpoint.json");

        let original = Checkpoint::new(&StateIntegrity::new(), HashMap::new());
        original.save(&file_path).unwrap();

        let restored = Checkpoint::load(&file_path).unwrap();
        assert_eq!(restored.sequence, original.sequence);
        assert_eq!(restored.state_hash, original.state_hash);
    }
}