Skip to main content

raft_hpc_core/
state_machine.rs

1//! Generic Raft state machine with persistent snapshot support.
2//!
3//! The state machine is parameterized by:
4//! - `C: RaftTypeConfig` — the openraft type configuration
5//! - `S: StateMachineState<C>` — the application state type
6//!
7//! Snapshot management (build, install, persist, prune, load) is handled
8//! generically. The application provides the `apply()` logic via the
9//! `StateMachineState` trait.
10
11use std::io;
12use std::io::Cursor;
13use std::marker::PhantomData;
14use std::path::{Path, PathBuf};
15use std::sync::Arc;
16
17use openraft::storage::{RaftStateMachine, Snapshot};
18use openraft::{
19    EntryPayload, LogId, OptionalSend, RaftSnapshotBuilder, RaftTypeConfig, SnapshotMeta,
20    StoredMembership,
21};
22use serde::de::DeserializeOwned;
23use serde::{Deserialize, Serialize};
24use tokio::sync::RwLock;
25use tracing::{debug, warn};
26
27use crate::StateMachineState;
28
/// Maximum number of `snap-*.json` files kept in the snapshot directory.
///
/// `prune_old_snapshots` deletes every snapshot file beyond the
/// `MAX_SNAPSHOTS` highest-indexed ones, so this count includes the newest
/// snapshot, not just older ones.
const MAX_SNAPSHOTS: usize = 3;
31
/// The generic Raft state machine wrapping an application state.
///
/// `C` is the openraft type configuration (its snapshot data must be an
/// in-memory `Cursor<Vec<u8>>`); `S` is the application state that committed
/// commands are applied to.
pub struct HpcStateMachine<C, S>
where
    C: RaftTypeConfig<SnapshotData = Cursor<Vec<u8>>>,
    S: StateMachineState<C>,
{
    /// Application state; the same `Arc` is handed out by `state()`.
    state: Arc<RwLock<S>>,
    /// Log id of the last applied entry, or of the last loaded/installed snapshot.
    last_applied: Option<LogId<C>>,
    /// Latest membership seen in an applied entry or a loaded/installed snapshot.
    last_membership: StoredMembership<C>,
    /// Counter used to derive `snap-{n}` snapshot ids.
    snapshot_idx: u64,
    /// Directory for persisted snapshots; `None` disables on-disk persistence.
    snapshot_dir: Option<PathBuf>,
    /// Marker tying the struct to `C`; carries no data.
    _phantom: PhantomData<C>,
}
45
46impl<C, S> HpcStateMachine<C, S>
47where
48    C: RaftTypeConfig<Entry = openraft::Entry<C>, SnapshotData = Cursor<Vec<u8>>>,
49    S: StateMachineState<C>,
50    LogId<C>: Serialize + DeserializeOwned,
51    StoredMembership<C>: Serialize + DeserializeOwned,
52{
53    pub fn new(state: Arc<RwLock<S>>) -> Self {
54        Self {
55            state,
56            last_applied: None,
57            last_membership: StoredMembership::default(),
58            snapshot_idx: 0,
59            snapshot_dir: None,
60            _phantom: PhantomData,
61        }
62    }
63
64    /// Create a state machine with persistent snapshot directory.
65    ///
66    /// On startup, loads the latest snapshot from disk if available.
67    pub fn with_snapshot_dir(state: Arc<RwLock<S>>, snapshot_dir: PathBuf) -> io::Result<Self> {
68        std::fs::create_dir_all(&snapshot_dir)?;
69
70        let mut sm = Self {
71            state,
72            last_applied: None,
73            last_membership: StoredMembership::default(),
74            snapshot_idx: 0,
75            snapshot_dir: Some(snapshot_dir),
76            _phantom: PhantomData,
77        };
78
79        // Try to load the latest snapshot from disk
80        if let Some(ref dir) = sm.snapshot_dir {
81            if let Some((meta, app_state)) = load_latest_snapshot::<C, S>(dir)? {
82                debug!("Loaded snapshot from disk at {:?}", meta.last_log_id);
83                let mut state = sm.state.blocking_write();
84                *state = app_state;
85                sm.last_applied = meta.last_log_id;
86                sm.last_membership = meta.last_membership;
87                sm.snapshot_idx += 1;
88            }
89        }
90
91        Ok(sm)
92    }
93
94    /// Get a read handle to the application state (for queries).
95    pub fn state(&self) -> Arc<RwLock<S>> {
96        Arc::clone(&self.state)
97    }
98}
99
/// Persisted snapshot format (stored alongside the state).
///
/// This is the document `persist_snapshot` serializes as pretty-printed JSON
/// into each `snap-{term}-{index}.json` file and `load_latest_snapshot`
/// reads back.
#[derive(Serialize, Deserialize)]
// Explicit bounds: only `S` is constrained here, replacing the bounds serde
// would otherwise infer for the generic parameters.
#[serde(bound(
    serialize = "S: Serialize",
    deserialize = "S: serde::de::DeserializeOwned"
))]
struct PersistedSnapshot<C: RaftTypeConfig, S> {
    /// Raft metadata describing the log position the snapshot covers.
    meta: PersistedSnapshotMeta<C>,
    /// The application state captured by the snapshot.
    state: S,
}
110
/// On-disk copy of the three `SnapshotMeta` fields this module stores:
/// `last_log_id`, `last_membership` and `snapshot_id`.
#[derive(Serialize, Deserialize)]
// An empty bound stops serde from adding inferred bounds on `C`.
// NOTE(review): presumably because `C` itself is never (de)serialized, only
// its associated types are — confirm against the openraft serde feature.
#[serde(bound = "")]
struct PersistedSnapshotMeta<C: RaftTypeConfig> {
    /// Last log id included in the snapshot, if any.
    last_log_id: Option<LogId<C>>,
    /// Membership configuration as of the snapshot.
    last_membership: StoredMembership<C>,
    /// The snapshot id string copied from the `SnapshotMeta`.
    snapshot_id: String,
}
118
119fn snapshot_filename<C: RaftTypeConfig>(meta: &SnapshotMeta<C>) -> String
120where
121    LogId<C>: Serialize,
122{
123    let (term, index) = meta.last_log_id.as_ref().map_or((0, 0), |log_id| {
124        // Extract term from CommittedLeaderId via serde, since it's an opaque
125        // associated type for generic C. Supports both adv mode (object with
126        // "term" field) and std mode (bare integer).
127        let term = serde_json::to_value(&log_id.leader_id)
128            .ok()
129            .and_then(|v| match v {
130                serde_json::Value::Number(n) => n.as_u64(),
131                serde_json::Value::Object(m) => m.get("term").and_then(serde_json::Value::as_u64),
132                _ => None,
133            })
134            .unwrap_or(0);
135        (term, log_id.index)
136    });
137    format!("snap-{term}-{index}.json")
138}
139
140fn persist_snapshot<C, S>(dir: &Path, meta: &SnapshotMeta<C>, data: &[u8]) -> io::Result<()>
141where
142    C: RaftTypeConfig,
143    S: DeserializeOwned + Serialize,
144    LogId<C>: Serialize,
145    StoredMembership<C>: Serialize + Clone,
146{
147    let filename = snapshot_filename::<C>(meta);
148    let path = dir.join(&filename);
149
150    // Parse the state to persist it in our format
151    let app_state: S =
152        serde_json::from_slice(data).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
153
154    let persisted = PersistedSnapshot::<C, S> {
155        meta: PersistedSnapshotMeta {
156            last_log_id: meta.last_log_id.clone(),
157            last_membership: meta.last_membership.clone(),
158            snapshot_id: meta.snapshot_id.clone(),
159        },
160        state: app_state,
161    };
162
163    let json = serde_json::to_vec_pretty(&persisted).map_err(io::Error::other)?;
164
165    // Atomic write
166    let tmp = path.with_extension("tmp");
167    std::fs::write(&tmp, &json)?;
168    if let Ok(f) = std::fs::File::open(&tmp) {
169        let _ = f.sync_all();
170    }
171    std::fs::rename(&tmp, &path)?;
172
173    // Update "current" pointer
174    let current = dir.join("current");
175    let _ = std::fs::remove_file(&current);
176    std::fs::write(&current, filename.as_bytes())?;
177
178    // Prune old snapshots
179    prune_old_snapshots(dir)?;
180
181    debug!("Persisted snapshot to {}", path.display());
182    Ok(())
183}
184
185fn prune_old_snapshots(dir: &Path) -> io::Result<()> {
186    let mut snaps: Vec<(PathBuf, u64)> = Vec::new();
187
188    for entry in std::fs::read_dir(dir)? {
189        let entry = entry?;
190        let name = entry.file_name();
191        let name_str = name.to_string_lossy();
192        if name_str.starts_with("snap-") && name_str.ends_with(".json") {
193            let parts: Vec<&str> = name_str
194                .trim_start_matches("snap-")
195                .trim_end_matches(".json")
196                .split('-')
197                .collect();
198            if parts.len() == 2 {
199                if let Ok(index) = parts[1].parse::<u64>() {
200                    snaps.push((entry.path(), index));
201                }
202            }
203        }
204    }
205
206    snaps.sort_by(|a, b| b.1.cmp(&a.1));
207
208    for (path, _) in snaps.iter().skip(MAX_SNAPSHOTS) {
209        debug!("Pruning old snapshot: {}", path.display());
210        let _ = std::fs::remove_file(path);
211    }
212
213    Ok(())
214}
215
216fn load_latest_snapshot<C, S>(dir: &Path) -> io::Result<Option<(SnapshotMeta<C>, S)>>
217where
218    C: RaftTypeConfig,
219    S: DeserializeOwned,
220    LogId<C>: DeserializeOwned,
221    StoredMembership<C>: DeserializeOwned,
222{
223    let current_path = dir.join("current");
224    if !current_path.exists() {
225        return Ok(None);
226    }
227
228    let filename = std::fs::read_to_string(&current_path)?.trim().to_string();
229    let snap_path = dir.join(&filename);
230
231    if !snap_path.exists() {
232        warn!("Current snapshot file {} not found", snap_path.display());
233        return Ok(None);
234    }
235
236    let data = std::fs::read_to_string(&snap_path)?;
237    let persisted: PersistedSnapshot<C, S> =
238        serde_json::from_str(&data).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
239
240    let meta = SnapshotMeta {
241        last_log_id: persisted.meta.last_log_id,
242        last_membership: persisted.meta.last_membership,
243        snapshot_id: persisted.meta.snapshot_id,
244    };
245
246    Ok(Some((meta, persisted.state)))
247}
248
249impl<C, S> RaftStateMachine<C> for HpcStateMachine<C, S>
250where
251    C: RaftTypeConfig<Entry = openraft::Entry<C>, SnapshotData = Cursor<Vec<u8>>>,
252    S: StateMachineState<C>,
253    LogId<C>: Serialize + DeserializeOwned,
254    StoredMembership<C>: Serialize + DeserializeOwned + Clone,
255{
256    type SnapshotBuilder = HpcSnapshotBuilder<C, S>;
257
258    async fn applied_state(
259        &mut self,
260    ) -> Result<(Option<LogId<C>>, StoredMembership<C>), io::Error> {
261        Ok((self.last_applied.clone(), self.last_membership.clone()))
262    }
263
264    async fn apply<Strm>(&mut self, entries: Strm) -> Result<(), io::Error>
265    where
266        Strm: futures::Stream<Item = Result<openraft::storage::EntryResponder<C>, io::Error>>
267            + Unpin
268            + OptionalSend,
269    {
270        use futures::StreamExt;
271
272        let mut stream = entries;
273        while let Some(item) = stream.next().await {
274            let (entry, responder) = item?;
275
276            self.last_applied = Some(entry.log_id.clone());
277
278            let response = match entry.payload {
279                EntryPayload::Blank => S::blank_response(),
280                EntryPayload::Normal(cmd) => {
281                    let mut state = self.state.write().await;
282                    state.apply(cmd)
283                }
284                EntryPayload::Membership(mem) => {
285                    self.last_membership = StoredMembership::new(self.last_applied.clone(), mem);
286                    S::blank_response()
287                }
288            };
289
290            if let Some(r) = responder {
291                r.send(response);
292            }
293        }
294
295        Ok(())
296    }
297
298    async fn get_snapshot_builder(&mut self) -> Self::SnapshotBuilder {
299        HpcSnapshotBuilder {
300            state: Arc::clone(&self.state),
301            last_applied: self.last_applied.clone(),
302            last_membership: self.last_membership.clone(),
303            snapshot_idx: self.snapshot_idx,
304            snapshot_dir: self.snapshot_dir.clone(),
305            _phantom: PhantomData,
306        }
307    }
308
309    async fn begin_receiving_snapshot(&mut self) -> Result<Cursor<Vec<u8>>, io::Error> {
310        Ok(Cursor::new(Vec::new()))
311    }
312
313    async fn install_snapshot(
314        &mut self,
315        meta: &SnapshotMeta<C>,
316        snapshot: Cursor<Vec<u8>>,
317    ) -> Result<(), io::Error> {
318        let data = snapshot.into_inner();
319        let new_state: S = serde_json::from_slice(&data)
320            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
321
322        // Persist snapshot to disk if configured
323        if let Some(ref dir) = self.snapshot_dir {
324            persist_snapshot::<C, S>(dir, meta, &data)?;
325        }
326
327        let mut state = self.state.write().await;
328        *state = new_state;
329
330        self.last_applied.clone_from(&meta.last_log_id);
331        self.last_membership.clone_from(&meta.last_membership);
332        self.snapshot_idx += 1;
333
334        debug!("Installed snapshot at {:?}", meta.last_log_id);
335        Ok(())
336    }
337
338    async fn get_current_snapshot(&mut self) -> Result<Option<Snapshot<C>>, io::Error> {
339        // If we have a snapshot dir, try loading from disk first (cold start case)
340        if self.last_applied.is_none() {
341            if let Some(ref dir) = self.snapshot_dir {
342                if let Some((meta, app_state)) = load_latest_snapshot::<C, S>(dir)? {
343                    let data = serde_json::to_vec(&app_state).map_err(io::Error::other)?;
344                    let mut state = self.state.write().await;
345                    *state = app_state;
346                    self.last_applied.clone_from(&meta.last_log_id);
347                    self.last_membership.clone_from(&meta.last_membership);
348                    self.snapshot_idx += 1;
349
350                    return Ok(Some(Snapshot {
351                        meta,
352                        snapshot: Cursor::new(data),
353                    }));
354                }
355            }
356        }
357
358        let state = self.state.read().await;
359        let data = serde_json::to_vec(&*state).map_err(io::Error::other)?;
360
361        if self.last_applied.is_none() {
362            return Ok(None);
363        }
364
365        let snapshot = Snapshot {
366            meta: SnapshotMeta {
367                last_log_id: self.last_applied.clone(),
368                last_membership: self.last_membership.clone(),
369                snapshot_id: format!("snap-{}", self.snapshot_idx),
370            },
371            snapshot: Cursor::new(data),
372        };
373
374        Ok(Some(snapshot))
375    }
376}
377
/// Builds snapshots from the current application state.
///
/// Instances are produced by `HpcStateMachine::get_snapshot_builder`, which
/// shares the state `Arc` and copies the remaining fields from the state
/// machine at that moment.
pub struct HpcSnapshotBuilder<C, S>
where
    C: RaftTypeConfig,
    S: StateMachineState<C>,
{
    /// Shared application state that gets serialized into the snapshot.
    state: Arc<RwLock<S>>,
    /// Log id recorded as the snapshot's `last_log_id`.
    last_applied: Option<LogId<C>>,
    /// Membership recorded as the snapshot's `last_membership`.
    last_membership: StoredMembership<C>,
    /// Builder-local copy of the snapshot counter, used for `snap-{n}` ids.
    snapshot_idx: u64,
    /// Directory to persist built snapshots into; `None` skips persistence.
    snapshot_dir: Option<PathBuf>,
    /// Marker tying the struct to `C`; carries no data.
    _phantom: PhantomData<C>,
}
391
392impl<C, S> RaftSnapshotBuilder<C> for HpcSnapshotBuilder<C, S>
393where
394    C: RaftTypeConfig<Entry = openraft::Entry<C>, SnapshotData = Cursor<Vec<u8>>>,
395    S: StateMachineState<C>,
396    LogId<C>: Serialize + DeserializeOwned,
397    StoredMembership<C>: Serialize + DeserializeOwned + Clone,
398{
399    async fn build_snapshot(&mut self) -> Result<Snapshot<C>, io::Error> {
400        let state = self.state.read().await;
401        let data = serde_json::to_vec(&*state).map_err(io::Error::other)?;
402
403        self.snapshot_idx += 1;
404
405        let meta = SnapshotMeta {
406            last_log_id: self.last_applied.clone(),
407            last_membership: self.last_membership.clone(),
408            snapshot_id: format!("snap-{}", self.snapshot_idx),
409        };
410
411        // Persist snapshot to disk if configured
412        if let Some(ref dir) = self.snapshot_dir {
413            persist_snapshot::<C, S>(dir, &meta, &data)?;
414        }
415
416        let snapshot = Snapshot {
417            meta,
418            snapshot: Cursor::new(data),
419        };
420
421        debug!("Built snapshot at {:?}", self.last_applied);
422        Ok(snapshot)
423    }
424}
425
// Unit tests for the state machine, the snapshot builder and the on-disk
// snapshot helpers. `TestTypeConfig` and `TestState` come from
// `crate::test_types`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_types::*;
    use openraft::storage::RaftStateMachine;
    use openraft::vote::RaftLeaderId;
    use openraft::vote::leader_id_adv::CommittedLeaderId;

    // Build a log id from its parts for the test type configuration.
    fn make_log_id(term: u64, node: u64, index: u64) -> LogId<TestTypeConfig> {
        LogId::new(CommittedLeaderId::new(term, node), index)
    }

    // A freshly constructed state machine has applied nothing and has no
    // membership log id.
    #[tokio::test]
    async fn new_state_machine_initial_state() {
        let state = Arc::new(RwLock::new(TestState::default()));
        let mut sm = HpcStateMachine::<TestTypeConfig, TestState>::new(state);
        let (last_applied, membership) = sm.applied_state().await.unwrap();
        assert!(last_applied.is_none());
        assert!(membership.log_id().is_none());
    }

    // The receive buffer handed out for incoming snapshots starts empty.
    #[tokio::test]
    async fn begin_receiving_snapshot_returns_empty_cursor() {
        let state = Arc::new(RwLock::new(TestState::default()));
        let mut sm = HpcStateMachine::<TestTypeConfig, TestState>::new(state);
        let cursor = sm.begin_receiving_snapshot().await.unwrap();
        assert!(cursor.into_inner().is_empty());
    }

    // Installing a snapshot replaces the shared state and advances
    // `last_applied` to the snapshot's log id.
    #[tokio::test]
    async fn install_snapshot_updates_state() {
        let state = Arc::new(RwLock::new(TestState::default()));
        let mut sm = HpcStateMachine::<TestTypeConfig, TestState>::new(state.clone());

        let new_state = TestState {
            data: [("k".into(), "v".into())].into(),
        };
        let snapshot_data = serde_json::to_vec(&new_state).unwrap();
        let meta = SnapshotMeta {
            last_log_id: Some(make_log_id(1, 1, 5)),
            last_membership: StoredMembership::default(),
            snapshot_id: "snap-1".to_string(),
        };

        sm.install_snapshot(&meta, Cursor::new(snapshot_data))
            .await
            .unwrap();

        let s = state.read().await;
        assert_eq!(s.data.get("k").unwrap(), "v");

        let (last_applied, _) = sm.applied_state().await.unwrap();
        assert_eq!(last_applied.unwrap().index, 5);
    }

    // With nothing applied and no snapshot directory there is no snapshot.
    #[tokio::test]
    async fn get_current_snapshot_none_when_no_applied() {
        let state = Arc::new(RwLock::new(TestState::default()));
        let mut sm = HpcStateMachine::<TestTypeConfig, TestState>::new(state);
        let snap = sm.get_current_snapshot().await.unwrap();
        assert!(snap.is_none());
    }

    // After an install, the current snapshot carries the installed log id
    // and state.
    #[tokio::test]
    async fn get_current_snapshot_returns_data_after_install() {
        let state = Arc::new(RwLock::new(TestState::default()));
        let mut sm = HpcStateMachine::<TestTypeConfig, TestState>::new(state);

        let new_state = TestState {
            data: [("x".into(), "y".into())].into(),
        };
        let snapshot_data = serde_json::to_vec(&new_state).unwrap();
        let meta = SnapshotMeta {
            last_log_id: Some(make_log_id(1, 1, 10)),
            last_membership: StoredMembership::default(),
            snapshot_id: "snap-1".to_string(),
        };
        sm.install_snapshot(&meta, Cursor::new(snapshot_data))
            .await
            .unwrap();

        let snap = sm.get_current_snapshot().await.unwrap();
        assert!(snap.is_some());
        let snap = snap.unwrap();
        assert_eq!(snap.meta.last_log_id.as_ref().unwrap().index, 10);

        let loaded: TestState = serde_json::from_slice(&snap.snapshot.into_inner()).unwrap();
        assert_eq!(loaded.data.get("x").unwrap(), "y");
    }

    // A builder obtained from the state machine snapshots the installed
    // state at the installed log id.
    #[tokio::test]
    async fn get_snapshot_builder_and_build() {
        let state = Arc::new(RwLock::new(TestState::default()));
        let mut sm = HpcStateMachine::<TestTypeConfig, TestState>::new(state);

        // Install some state first
        let new_state = TestState {
            data: [("a".into(), "b".into())].into(),
        };
        let snapshot_data = serde_json::to_vec(&new_state).unwrap();
        let meta = SnapshotMeta {
            last_log_id: Some(make_log_id(1, 1, 3)),
            last_membership: StoredMembership::default(),
            snapshot_id: "snap-0".to_string(),
        };
        sm.install_snapshot(&meta, Cursor::new(snapshot_data))
            .await
            .unwrap();

        let mut builder = sm.get_snapshot_builder().await;
        let snap = builder.build_snapshot().await.unwrap();
        assert_eq!(snap.meta.last_log_id.as_ref().unwrap().index, 3);

        let loaded: TestState = serde_json::from_slice(&snap.snapshot.into_inner()).unwrap();
        assert_eq!(loaded.data.get("a").unwrap(), "b");
    }

    // `with_snapshot_dir` creates the directory when it does not exist.
    #[test]
    fn with_snapshot_dir_creates_directory() {
        let dir = tempfile::tempdir().unwrap();
        let snap_dir = dir.path().join("snapshots");
        let state = Arc::new(RwLock::new(TestState::default()));
        let _sm = HpcStateMachine::<TestTypeConfig, TestState>::with_snapshot_dir(
            state,
            snap_dir.clone(),
        )
        .unwrap();
        assert!(snap_dir.exists());
    }

    // Installing a snapshot writes the `current` pointer and the file it
    // names.
    #[tokio::test]
    async fn install_snapshot_persists_to_disk() {
        let dir = tempfile::tempdir().unwrap();
        let snap_dir = dir.path().join("snapshots");
        let state = Arc::new(RwLock::new(TestState::default()));
        let mut sm = HpcStateMachine::<TestTypeConfig, TestState>::with_snapshot_dir(
            state,
            snap_dir.clone(),
        )
        .unwrap();

        let new_state = TestState {
            data: [("disk".into(), "test".into())].into(),
        };
        let snapshot_data = serde_json::to_vec(&new_state).unwrap();
        let meta = SnapshotMeta {
            last_log_id: Some(make_log_id(1, 1, 7)),
            last_membership: StoredMembership::default(),
            snapshot_id: "snap-1".to_string(),
        };
        sm.install_snapshot(&meta, Cursor::new(snapshot_data))
            .await
            .unwrap();

        // Verify files written
        assert!(snap_dir.join("current").exists());
        let current = std::fs::read_to_string(snap_dir.join("current")).unwrap();
        assert!(snap_dir.join(current.trim()).exists());
    }

    // Building a snapshot through the builder also leaves a `current`
    // pointer on disk.
    #[tokio::test]
    async fn build_snapshot_persists_to_disk() {
        let dir = tempfile::tempdir().unwrap();
        let snap_dir = dir.path().join("snapshots");
        let state = Arc::new(RwLock::new(TestState::default()));
        let mut sm = HpcStateMachine::<TestTypeConfig, TestState>::with_snapshot_dir(
            state,
            snap_dir.clone(),
        )
        .unwrap();

        // Install state so we have something to snapshot
        let new_state = TestState {
            data: [("build".into(), "snap".into())].into(),
        };
        let snapshot_data = serde_json::to_vec(&new_state).unwrap();
        let meta = SnapshotMeta {
            last_log_id: Some(make_log_id(1, 1, 2)),
            last_membership: StoredMembership::default(),
            snapshot_id: "snap-0".to_string(),
        };
        sm.install_snapshot(&meta, Cursor::new(snapshot_data))
            .await
            .unwrap();

        let mut builder = sm.get_snapshot_builder().await;
        let _snap = builder.build_snapshot().await.unwrap();

        assert!(snap_dir.join("current").exists());
    }

    // A second state machine opened on the same directory restores the
    // state persisted by the first.
    #[tokio::test]
    async fn load_latest_snapshot_roundtrip() {
        let dir = tempfile::tempdir().unwrap();
        let snap_dir = dir.path().join("snapshots");
        let state = Arc::new(RwLock::new(TestState::default()));
        let mut sm = HpcStateMachine::<TestTypeConfig, TestState>::with_snapshot_dir(
            state.clone(),
            snap_dir.clone(),
        )
        .unwrap();

        // Install a snapshot
        let new_state = TestState {
            data: [("load".into(), "test".into())].into(),
        };
        let snapshot_data = serde_json::to_vec(&new_state).unwrap();
        let meta = SnapshotMeta {
            last_log_id: Some(make_log_id(1, 1, 15)),
            last_membership: StoredMembership::default(),
            snapshot_id: "snap-1".to_string(),
        };
        sm.install_snapshot(&meta, Cursor::new(snapshot_data))
            .await
            .unwrap();

        // Create a new state machine from the same dir in a blocking context
        // (with_snapshot_dir uses blocking_write internally)
        let snap_dir_clone = snap_dir.clone();
        let fresh_state = tokio::task::spawn_blocking(move || {
            let fresh_state = Arc::new(RwLock::new(TestState::default()));
            let _sm2 = HpcStateMachine::<TestTypeConfig, TestState>::with_snapshot_dir(
                fresh_state.clone(),
                snap_dir_clone,
            )
            .unwrap();
            fresh_state
        })
        .await
        .unwrap();

        let s = fresh_state.read().await;
        assert_eq!(s.data.get("load").unwrap(), "test");
    }

    // Installing more than `MAX_SNAPSHOTS` snapshots leaves at most
    // `MAX_SNAPSHOTS` snapshot files behind.
    #[tokio::test]
    async fn prune_old_snapshots_keeps_max() {
        let dir = tempfile::tempdir().unwrap();
        let snap_dir = dir.path().join("snapshots");
        let state = Arc::new(RwLock::new(TestState::default()));
        let mut sm = HpcStateMachine::<TestTypeConfig, TestState>::with_snapshot_dir(
            state,
            snap_dir.clone(),
        )
        .unwrap();

        // Install many snapshots to trigger pruning
        for i in 1..=6u64 {
            let new_state = TestState {
                data: [(format!("k{i}"), format!("v{i}"))].into(),
            };
            let snapshot_data = serde_json::to_vec(&new_state).unwrap();
            let meta = SnapshotMeta {
                last_log_id: Some(make_log_id(1, 1, i)),
                last_membership: StoredMembership::default(),
                snapshot_id: format!("snap-{i}"),
            };
            sm.install_snapshot(&meta, Cursor::new(snapshot_data))
                .await
                .unwrap();
        }

        // Count snap-*.json files
        let snap_count = std::fs::read_dir(&snap_dir)
            .unwrap()
            .filter_map(Result::ok)
            .filter(|e| {
                let name = e.file_name().to_string_lossy().to_string();
                name.starts_with("snap-")
                    && std::path::Path::new(&name)
                        .extension()
                        .is_some_and(|ext| ext.eq_ignore_ascii_case("json"))
            })
            .count();

        assert!(
            snap_count <= MAX_SNAPSHOTS,
            "Expected at most {MAX_SNAPSHOTS} snapshots, found {snap_count}"
        );
    }

    // The file name encodes the term and index of the last log id.
    #[test]
    fn snapshot_filename_format() {
        let meta: SnapshotMeta<TestTypeConfig> = SnapshotMeta {
            last_log_id: Some(make_log_id(2, 1, 42)),
            last_membership: StoredMembership::default(),
            snapshot_id: "test".to_string(),
        };
        let name = snapshot_filename::<TestTypeConfig>(&meta);
        assert_eq!(name, "snap-2-42.json");
    }

    // Without a last log id the file name falls back to term 0 / index 0.
    #[test]
    fn snapshot_filename_none_log_id() {
        let meta: SnapshotMeta<TestTypeConfig> = SnapshotMeta {
            last_log_id: None,
            last_membership: StoredMembership::default(),
            snapshot_id: "test".to_string(),
        };
        let name = snapshot_filename::<TestTypeConfig>(&meta);
        assert_eq!(name, "snap-0-0.json");
    }

    // A state machine that has applied nothing picks up the persisted
    // snapshot when asked for the current one.
    #[tokio::test]
    async fn get_current_snapshot_loads_from_disk_on_cold_start() {
        let dir = tempfile::tempdir().unwrap();
        let snap_dir = dir.path().join("snapshots");
        let state = Arc::new(RwLock::new(TestState::default()));

        // First: create SM, install snapshot, drop it
        {
            let mut sm = HpcStateMachine::<TestTypeConfig, TestState>::with_snapshot_dir(
                state.clone(),
                snap_dir.clone(),
            )
            .unwrap();
            let new_state = TestState {
                data: [("cold".into(), "start".into())].into(),
            };
            let snapshot_data = serde_json::to_vec(&new_state).unwrap();
            let meta = SnapshotMeta {
                last_log_id: Some(make_log_id(1, 1, 20)),
                last_membership: StoredMembership::default(),
                snapshot_id: "snap-1".to_string(),
            };
            sm.install_snapshot(&meta, Cursor::new(snapshot_data))
                .await
                .unwrap();
        }

        // Second: create a fresh SM with no loaded state, call get_current_snapshot
        let fresh_state = Arc::new(RwLock::new(TestState::default()));
        let mut sm2 = HpcStateMachine::<TestTypeConfig, TestState>::new(fresh_state.clone());
        sm2.snapshot_dir = Some(snap_dir);

        let snap = sm2.get_current_snapshot().await.unwrap();
        assert!(snap.is_some());
        let snap = snap.unwrap();
        let loaded: TestState = serde_json::from_slice(&snap.snapshot.into_inner()).unwrap();
        assert_eq!(loaded.data.get("cold").unwrap(), "start");
    }

    // A `current` pointer naming a missing file is treated as "no snapshot"
    // rather than an error.
    #[tokio::test]
    async fn load_latest_snapshot_missing_file_returns_none() {
        let dir = tempfile::tempdir().unwrap();
        let snap_dir = dir.path().join("snapshots");
        std::fs::create_dir_all(&snap_dir).unwrap();

        // Create "current" pointing to a nonexistent file
        std::fs::write(snap_dir.join("current"), b"snap-0-999.json").unwrap();

        let result = load_latest_snapshot::<TestTypeConfig, TestState>(&snap_dir).unwrap();
        assert!(result.is_none());
    }

    // `state()` returns the very same `Arc` that was passed to `new`.
    #[test]
    fn state_accessor() {
        let state = Arc::new(RwLock::new(TestState::default()));
        let sm = HpcStateMachine::<TestTypeConfig, TestState>::new(state.clone());
        assert!(Arc::ptr_eq(&sm.state(), &state));
    }
}