1use std::io;
12use std::io::Cursor;
13use std::marker::PhantomData;
14use std::path::{Path, PathBuf};
15use std::sync::Arc;
16
17use openraft::storage::{RaftStateMachine, Snapshot};
18use openraft::{
19 EntryPayload, LogId, OptionalSend, RaftSnapshotBuilder, RaftTypeConfig, SnapshotMeta,
20 StoredMembership,
21};
22use serde::de::DeserializeOwned;
23use serde::{Deserialize, Serialize};
24use tokio::sync::RwLock;
25use tracing::{debug, warn};
26
27use crate::StateMachineState;
28
/// Number of `snap-{term}-{index}.json` files kept on disk by
/// `prune_old_snapshots`; older ones (lower index) are deleted.
const MAX_SNAPSHOTS: usize = 3;
31
32pub struct HpcStateMachine<C, S>
34where
35 C: RaftTypeConfig<SnapshotData = Cursor<Vec<u8>>>,
36 S: StateMachineState<C>,
37{
38 state: Arc<RwLock<S>>,
39 last_applied: Option<LogId<C>>,
40 last_membership: StoredMembership<C>,
41 snapshot_idx: u64,
42 snapshot_dir: Option<PathBuf>,
43 _phantom: PhantomData<C>,
44}
45
46impl<C, S> HpcStateMachine<C, S>
47where
48 C: RaftTypeConfig<Entry = openraft::Entry<C>, SnapshotData = Cursor<Vec<u8>>>,
49 S: StateMachineState<C>,
50 LogId<C>: Serialize + DeserializeOwned,
51 StoredMembership<C>: Serialize + DeserializeOwned,
52{
53 pub fn new(state: Arc<RwLock<S>>) -> Self {
54 Self {
55 state,
56 last_applied: None,
57 last_membership: StoredMembership::default(),
58 snapshot_idx: 0,
59 snapshot_dir: None,
60 _phantom: PhantomData,
61 }
62 }
63
64 pub fn with_snapshot_dir(state: Arc<RwLock<S>>, snapshot_dir: PathBuf) -> io::Result<Self> {
68 std::fs::create_dir_all(&snapshot_dir)?;
69
70 let mut sm = Self {
71 state,
72 last_applied: None,
73 last_membership: StoredMembership::default(),
74 snapshot_idx: 0,
75 snapshot_dir: Some(snapshot_dir),
76 _phantom: PhantomData,
77 };
78
79 if let Some(ref dir) = sm.snapshot_dir {
81 if let Some((meta, app_state)) = load_latest_snapshot::<C, S>(dir)? {
82 debug!("Loaded snapshot from disk at {:?}", meta.last_log_id);
83 let mut state = sm.state.blocking_write();
84 *state = app_state;
85 sm.last_applied = meta.last_log_id;
86 sm.last_membership = meta.last_membership;
87 sm.snapshot_idx += 1;
88 }
89 }
90
91 Ok(sm)
92 }
93
94 pub fn state(&self) -> Arc<RwLock<S>> {
96 Arc::clone(&self.state)
97 }
98}
99
/// On-disk representation of a snapshot file: metadata plus the full
/// application state, serialized together as one JSON document.
#[derive(Serialize, Deserialize)]
// Explicit bounds: only `S` needs (de)serialization bounds; serde would
// otherwise infer bounds on `C` as well.
#[serde(bound(
    serialize = "S: Serialize",
    deserialize = "S: serde::de::DeserializeOwned"
))]
struct PersistedSnapshot<C: RaftTypeConfig, S> {
    /// Snapshot metadata (log id, membership, id).
    meta: PersistedSnapshotMeta<C>,
    /// The application state at `meta.last_log_id`.
    state: S,
}
110
/// Serializable mirror of `openraft::SnapshotMeta`, stored inside each
/// snapshot file.
#[derive(Serialize, Deserialize)]
// `bound = ""` suppresses serde's inferred `C: Serialize/Deserialize` bounds;
// the field types' own impls are relied upon instead.
#[serde(bound = "")]
struct PersistedSnapshotMeta<C: RaftTypeConfig> {
    /// Last log id covered by the snapshot (`None` for an empty snapshot).
    last_log_id: Option<LogId<C>>,
    /// Membership config in effect at `last_log_id`.
    last_membership: StoredMembership<C>,
    /// Opaque snapshot identifier.
    snapshot_id: String,
}
118
119fn snapshot_filename<C: RaftTypeConfig>(meta: &SnapshotMeta<C>) -> String
120where
121 LogId<C>: Serialize,
122{
123 let (term, index) = meta.last_log_id.as_ref().map_or((0, 0), |log_id| {
124 let term = serde_json::to_value(&log_id.leader_id)
128 .ok()
129 .and_then(|v| match v {
130 serde_json::Value::Number(n) => n.as_u64(),
131 serde_json::Value::Object(m) => m.get("term").and_then(serde_json::Value::as_u64),
132 _ => None,
133 })
134 .unwrap_or(0);
135 (term, log_id.index)
136 });
137 format!("snap-{term}-{index}.json")
138}
139
140fn persist_snapshot<C, S>(dir: &Path, meta: &SnapshotMeta<C>, data: &[u8]) -> io::Result<()>
141where
142 C: RaftTypeConfig,
143 S: DeserializeOwned + Serialize,
144 LogId<C>: Serialize,
145 StoredMembership<C>: Serialize + Clone,
146{
147 let filename = snapshot_filename::<C>(meta);
148 let path = dir.join(&filename);
149
150 let app_state: S =
152 serde_json::from_slice(data).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
153
154 let persisted = PersistedSnapshot::<C, S> {
155 meta: PersistedSnapshotMeta {
156 last_log_id: meta.last_log_id.clone(),
157 last_membership: meta.last_membership.clone(),
158 snapshot_id: meta.snapshot_id.clone(),
159 },
160 state: app_state,
161 };
162
163 let json = serde_json::to_vec_pretty(&persisted).map_err(io::Error::other)?;
164
165 let tmp = path.with_extension("tmp");
167 std::fs::write(&tmp, &json)?;
168 if let Ok(f) = std::fs::File::open(&tmp) {
169 let _ = f.sync_all();
170 }
171 std::fs::rename(&tmp, &path)?;
172
173 let current = dir.join("current");
175 let _ = std::fs::remove_file(¤t);
176 std::fs::write(¤t, filename.as_bytes())?;
177
178 prune_old_snapshots(dir)?;
180
181 debug!("Persisted snapshot to {}", path.display());
182 Ok(())
183}
184
185fn prune_old_snapshots(dir: &Path) -> io::Result<()> {
186 let mut snaps: Vec<(PathBuf, u64)> = Vec::new();
187
188 for entry in std::fs::read_dir(dir)? {
189 let entry = entry?;
190 let name = entry.file_name();
191 let name_str = name.to_string_lossy();
192 if name_str.starts_with("snap-") && name_str.ends_with(".json") {
193 let parts: Vec<&str> = name_str
194 .trim_start_matches("snap-")
195 .trim_end_matches(".json")
196 .split('-')
197 .collect();
198 if parts.len() == 2 {
199 if let Ok(index) = parts[1].parse::<u64>() {
200 snaps.push((entry.path(), index));
201 }
202 }
203 }
204 }
205
206 snaps.sort_by(|a, b| b.1.cmp(&a.1));
207
208 for (path, _) in snaps.iter().skip(MAX_SNAPSHOTS) {
209 debug!("Pruning old snapshot: {}", path.display());
210 let _ = std::fs::remove_file(path);
211 }
212
213 Ok(())
214}
215
216fn load_latest_snapshot<C, S>(dir: &Path) -> io::Result<Option<(SnapshotMeta<C>, S)>>
217where
218 C: RaftTypeConfig,
219 S: DeserializeOwned,
220 LogId<C>: DeserializeOwned,
221 StoredMembership<C>: DeserializeOwned,
222{
223 let current_path = dir.join("current");
224 if !current_path.exists() {
225 return Ok(None);
226 }
227
228 let filename = std::fs::read_to_string(¤t_path)?.trim().to_string();
229 let snap_path = dir.join(&filename);
230
231 if !snap_path.exists() {
232 warn!("Current snapshot file {} not found", snap_path.display());
233 return Ok(None);
234 }
235
236 let data = std::fs::read_to_string(&snap_path)?;
237 let persisted: PersistedSnapshot<C, S> =
238 serde_json::from_str(&data).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
239
240 let meta = SnapshotMeta {
241 last_log_id: persisted.meta.last_log_id,
242 last_membership: persisted.meta.last_membership,
243 snapshot_id: persisted.meta.snapshot_id,
244 };
245
246 Ok(Some((meta, persisted.state)))
247}
248
249impl<C, S> RaftStateMachine<C> for HpcStateMachine<C, S>
250where
251 C: RaftTypeConfig<Entry = openraft::Entry<C>, SnapshotData = Cursor<Vec<u8>>>,
252 S: StateMachineState<C>,
253 LogId<C>: Serialize + DeserializeOwned,
254 StoredMembership<C>: Serialize + DeserializeOwned + Clone,
255{
256 type SnapshotBuilder = HpcSnapshotBuilder<C, S>;
257
258 async fn applied_state(
259 &mut self,
260 ) -> Result<(Option<LogId<C>>, StoredMembership<C>), io::Error> {
261 Ok((self.last_applied.clone(), self.last_membership.clone()))
262 }
263
264 async fn apply<Strm>(&mut self, entries: Strm) -> Result<(), io::Error>
265 where
266 Strm: futures::Stream<Item = Result<openraft::storage::EntryResponder<C>, io::Error>>
267 + Unpin
268 + OptionalSend,
269 {
270 use futures::StreamExt;
271
272 let mut stream = entries;
273 while let Some(item) = stream.next().await {
274 let (entry, responder) = item?;
275
276 self.last_applied = Some(entry.log_id.clone());
277
278 let response = match entry.payload {
279 EntryPayload::Blank => S::blank_response(),
280 EntryPayload::Normal(cmd) => {
281 let mut state = self.state.write().await;
282 state.apply(cmd)
283 }
284 EntryPayload::Membership(mem) => {
285 self.last_membership = StoredMembership::new(self.last_applied.clone(), mem);
286 S::blank_response()
287 }
288 };
289
290 if let Some(r) = responder {
291 r.send(response);
292 }
293 }
294
295 Ok(())
296 }
297
298 async fn get_snapshot_builder(&mut self) -> Self::SnapshotBuilder {
299 HpcSnapshotBuilder {
300 state: Arc::clone(&self.state),
301 last_applied: self.last_applied.clone(),
302 last_membership: self.last_membership.clone(),
303 snapshot_idx: self.snapshot_idx,
304 snapshot_dir: self.snapshot_dir.clone(),
305 _phantom: PhantomData,
306 }
307 }
308
309 async fn begin_receiving_snapshot(&mut self) -> Result<Cursor<Vec<u8>>, io::Error> {
310 Ok(Cursor::new(Vec::new()))
311 }
312
313 async fn install_snapshot(
314 &mut self,
315 meta: &SnapshotMeta<C>,
316 snapshot: Cursor<Vec<u8>>,
317 ) -> Result<(), io::Error> {
318 let data = snapshot.into_inner();
319 let new_state: S = serde_json::from_slice(&data)
320 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
321
322 if let Some(ref dir) = self.snapshot_dir {
324 persist_snapshot::<C, S>(dir, meta, &data)?;
325 }
326
327 let mut state = self.state.write().await;
328 *state = new_state;
329
330 self.last_applied.clone_from(&meta.last_log_id);
331 self.last_membership.clone_from(&meta.last_membership);
332 self.snapshot_idx += 1;
333
334 debug!("Installed snapshot at {:?}", meta.last_log_id);
335 Ok(())
336 }
337
338 async fn get_current_snapshot(&mut self) -> Result<Option<Snapshot<C>>, io::Error> {
339 if self.last_applied.is_none() {
341 if let Some(ref dir) = self.snapshot_dir {
342 if let Some((meta, app_state)) = load_latest_snapshot::<C, S>(dir)? {
343 let data = serde_json::to_vec(&app_state).map_err(io::Error::other)?;
344 let mut state = self.state.write().await;
345 *state = app_state;
346 self.last_applied.clone_from(&meta.last_log_id);
347 self.last_membership.clone_from(&meta.last_membership);
348 self.snapshot_idx += 1;
349
350 return Ok(Some(Snapshot {
351 meta,
352 snapshot: Cursor::new(data),
353 }));
354 }
355 }
356 }
357
358 let state = self.state.read().await;
359 let data = serde_json::to_vec(&*state).map_err(io::Error::other)?;
360
361 if self.last_applied.is_none() {
362 return Ok(None);
363 }
364
365 let snapshot = Snapshot {
366 meta: SnapshotMeta {
367 last_log_id: self.last_applied.clone(),
368 last_membership: self.last_membership.clone(),
369 snapshot_id: format!("snap-{}", self.snapshot_idx),
370 },
371 snapshot: Cursor::new(data),
372 };
373
374 Ok(Some(snapshot))
375 }
376}
377
378pub struct HpcSnapshotBuilder<C, S>
380where
381 C: RaftTypeConfig,
382 S: StateMachineState<C>,
383{
384 state: Arc<RwLock<S>>,
385 last_applied: Option<LogId<C>>,
386 last_membership: StoredMembership<C>,
387 snapshot_idx: u64,
388 snapshot_dir: Option<PathBuf>,
389 _phantom: PhantomData<C>,
390}
391
392impl<C, S> RaftSnapshotBuilder<C> for HpcSnapshotBuilder<C, S>
393where
394 C: RaftTypeConfig<Entry = openraft::Entry<C>, SnapshotData = Cursor<Vec<u8>>>,
395 S: StateMachineState<C>,
396 LogId<C>: Serialize + DeserializeOwned,
397 StoredMembership<C>: Serialize + DeserializeOwned + Clone,
398{
399 async fn build_snapshot(&mut self) -> Result<Snapshot<C>, io::Error> {
400 let state = self.state.read().await;
401 let data = serde_json::to_vec(&*state).map_err(io::Error::other)?;
402
403 self.snapshot_idx += 1;
404
405 let meta = SnapshotMeta {
406 last_log_id: self.last_applied.clone(),
407 last_membership: self.last_membership.clone(),
408 snapshot_id: format!("snap-{}", self.snapshot_idx),
409 };
410
411 if let Some(ref dir) = self.snapshot_dir {
413 persist_snapshot::<C, S>(dir, &meta, &data)?;
414 }
415
416 let snapshot = Snapshot {
417 meta,
418 snapshot: Cursor::new(data),
419 };
420
421 debug!("Built snapshot at {:?}", self.last_applied);
422 Ok(snapshot)
423 }
424}
425
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_types::*;
    use openraft::storage::RaftStateMachine;
    use openraft::vote::RaftLeaderId;
    use openraft::vote::leader_id_adv::CommittedLeaderId;

    /// Builds a `LogId` for `TestTypeConfig` with leader `(term, node)` at `index`.
    fn make_log_id(term: u64, node: u64, index: u64) -> LogId<TestTypeConfig> {
        LogId::new(CommittedLeaderId::new(term, node), index)
    }

    // A fresh state machine reports nothing applied and a membership with no log id.
    #[tokio::test]
    async fn new_state_machine_initial_state() {
        let state = Arc::new(RwLock::new(TestState::default()));
        let mut sm = HpcStateMachine::<TestTypeConfig, TestState>::new(state);
        let (last_applied, membership) = sm.applied_state().await.unwrap();
        assert!(last_applied.is_none());
        assert!(membership.log_id().is_none());
    }

    #[tokio::test]
    async fn begin_receiving_snapshot_returns_empty_cursor() {
        let state = Arc::new(RwLock::new(TestState::default()));
        let mut sm = HpcStateMachine::<TestTypeConfig, TestState>::new(state);
        let cursor = sm.begin_receiving_snapshot().await.unwrap();
        assert!(cursor.into_inner().is_empty());
    }

    // Installing a snapshot replaces the shared state and advances last_applied.
    #[tokio::test]
    async fn install_snapshot_updates_state() {
        let state = Arc::new(RwLock::new(TestState::default()));
        let mut sm = HpcStateMachine::<TestTypeConfig, TestState>::new(state.clone());

        let new_state = TestState {
            data: [("k".into(), "v".into())].into(),
        };
        let snapshot_data = serde_json::to_vec(&new_state).unwrap();
        let meta = SnapshotMeta {
            last_log_id: Some(make_log_id(1, 1, 5)),
            last_membership: StoredMembership::default(),
            snapshot_id: "snap-1".to_string(),
        };

        sm.install_snapshot(&meta, Cursor::new(snapshot_data))
            .await
            .unwrap();

        let s = state.read().await;
        assert_eq!(s.data.get("k").unwrap(), "v");

        let (last_applied, _) = sm.applied_state().await.unwrap();
        assert_eq!(last_applied.unwrap().index, 5);
    }

    #[tokio::test]
    async fn get_current_snapshot_none_when_no_applied() {
        let state = Arc::new(RwLock::new(TestState::default()));
        let mut sm = HpcStateMachine::<TestTypeConfig, TestState>::new(state);
        let snap = sm.get_current_snapshot().await.unwrap();
        assert!(snap.is_none());
    }

    // After an install, get_current_snapshot returns the installed data and log id.
    #[tokio::test]
    async fn get_current_snapshot_returns_data_after_install() {
        let state = Arc::new(RwLock::new(TestState::default()));
        let mut sm = HpcStateMachine::<TestTypeConfig, TestState>::new(state);

        let new_state = TestState {
            data: [("x".into(), "y".into())].into(),
        };
        let snapshot_data = serde_json::to_vec(&new_state).unwrap();
        let meta = SnapshotMeta {
            last_log_id: Some(make_log_id(1, 1, 10)),
            last_membership: StoredMembership::default(),
            snapshot_id: "snap-1".to_string(),
        };
        sm.install_snapshot(&meta, Cursor::new(snapshot_data))
            .await
            .unwrap();

        let snap = sm.get_current_snapshot().await.unwrap();
        assert!(snap.is_some());
        let snap = snap.unwrap();
        assert_eq!(snap.meta.last_log_id.as_ref().unwrap().index, 10);

        let loaded: TestState = serde_json::from_slice(&snap.snapshot.into_inner()).unwrap();
        assert_eq!(loaded.data.get("x").unwrap(), "y");
    }

    // A builder obtained after an install produces a snapshot of that state.
    #[tokio::test]
    async fn get_snapshot_builder_and_build() {
        let state = Arc::new(RwLock::new(TestState::default()));
        let mut sm = HpcStateMachine::<TestTypeConfig, TestState>::new(state);

        let new_state = TestState {
            data: [("a".into(), "b".into())].into(),
        };
        let snapshot_data = serde_json::to_vec(&new_state).unwrap();
        let meta = SnapshotMeta {
            last_log_id: Some(make_log_id(1, 1, 3)),
            last_membership: StoredMembership::default(),
            snapshot_id: "snap-0".to_string(),
        };
        sm.install_snapshot(&meta, Cursor::new(snapshot_data))
            .await
            .unwrap();

        let mut builder = sm.get_snapshot_builder().await;
        let snap = builder.build_snapshot().await.unwrap();
        assert_eq!(snap.meta.last_log_id.as_ref().unwrap().index, 3);

        let loaded: TestState = serde_json::from_slice(&snap.snapshot.into_inner()).unwrap();
        assert_eq!(loaded.data.get("a").unwrap(), "b");
    }

    #[test]
    fn with_snapshot_dir_creates_directory() {
        let dir = tempfile::tempdir().unwrap();
        let snap_dir = dir.path().join("snapshots");
        let state = Arc::new(RwLock::new(TestState::default()));
        let _sm = HpcStateMachine::<TestTypeConfig, TestState>::with_snapshot_dir(
            state,
            snap_dir.clone(),
        )
        .unwrap();
        assert!(snap_dir.exists());
    }

    // install_snapshot writes both the `current` pointer and the file it names.
    #[tokio::test]
    async fn install_snapshot_persists_to_disk() {
        let dir = tempfile::tempdir().unwrap();
        let snap_dir = dir.path().join("snapshots");
        let state = Arc::new(RwLock::new(TestState::default()));
        let mut sm = HpcStateMachine::<TestTypeConfig, TestState>::with_snapshot_dir(
            state,
            snap_dir.clone(),
        )
        .unwrap();

        let new_state = TestState {
            data: [("disk".into(), "test".into())].into(),
        };
        let snapshot_data = serde_json::to_vec(&new_state).unwrap();
        let meta = SnapshotMeta {
            last_log_id: Some(make_log_id(1, 1, 7)),
            last_membership: StoredMembership::default(),
            snapshot_id: "snap-1".to_string(),
        };
        sm.install_snapshot(&meta, Cursor::new(snapshot_data))
            .await
            .unwrap();

        assert!(snap_dir.join("current").exists());
        // `current` must reference a snapshot file that actually exists.
        let current = std::fs::read_to_string(snap_dir.join("current")).unwrap();
        assert!(snap_dir.join(current.trim()).exists());
    }

    #[tokio::test]
    async fn build_snapshot_persists_to_disk() {
        let dir = tempfile::tempdir().unwrap();
        let snap_dir = dir.path().join("snapshots");
        let state = Arc::new(RwLock::new(TestState::default()));
        let mut sm = HpcStateMachine::<TestTypeConfig, TestState>::with_snapshot_dir(
            state,
            snap_dir.clone(),
        )
        .unwrap();

        // Seed a non-empty state via install so the builder has a last_applied.
        let new_state = TestState {
            data: [("build".into(), "snap".into())].into(),
        };
        let snapshot_data = serde_json::to_vec(&new_state).unwrap();
        let meta = SnapshotMeta {
            last_log_id: Some(make_log_id(1, 1, 2)),
            last_membership: StoredMembership::default(),
            snapshot_id: "snap-0".to_string(),
        };
        sm.install_snapshot(&meta, Cursor::new(snapshot_data))
            .await
            .unwrap();

        let mut builder = sm.get_snapshot_builder().await;
        let _snap = builder.build_snapshot().await.unwrap();

        assert!(snap_dir.join("current").exists());
    }

    // A second state machine opened on the same directory restores the state.
    #[tokio::test]
    async fn load_latest_snapshot_roundtrip() {
        let dir = tempfile::tempdir().unwrap();
        let snap_dir = dir.path().join("snapshots");
        let state = Arc::new(RwLock::new(TestState::default()));
        let mut sm = HpcStateMachine::<TestTypeConfig, TestState>::with_snapshot_dir(
            state.clone(),
            snap_dir.clone(),
        )
        .unwrap();

        let new_state = TestState {
            data: [("load".into(), "test".into())].into(),
        };
        let snapshot_data = serde_json::to_vec(&new_state).unwrap();
        let meta = SnapshotMeta {
            last_log_id: Some(make_log_id(1, 1, 15)),
            last_membership: StoredMembership::default(),
            snapshot_id: "snap-1".to_string(),
        };
        sm.install_snapshot(&meta, Cursor::new(snapshot_data))
            .await
            .unwrap();

        let snap_dir_clone = snap_dir.clone();
        // Run on a blocking thread: when `with_snapshot_dir` finds a snapshot
        // on disk it may take the state lock synchronously, which Tokio does
        // not allow directly on a runtime thread.
        let fresh_state = tokio::task::spawn_blocking(move || {
            let fresh_state = Arc::new(RwLock::new(TestState::default()));
            let _sm2 = HpcStateMachine::<TestTypeConfig, TestState>::with_snapshot_dir(
                fresh_state.clone(),
                snap_dir_clone,
            )
            .unwrap();
            fresh_state
        })
        .await
        .unwrap();

        let s = fresh_state.read().await;
        assert_eq!(s.data.get("load").unwrap(), "test");
    }

    // After six installs with increasing indices, at most MAX_SNAPSHOTS files remain.
    #[tokio::test]
    async fn prune_old_snapshots_keeps_max() {
        let dir = tempfile::tempdir().unwrap();
        let snap_dir = dir.path().join("snapshots");
        let state = Arc::new(RwLock::new(TestState::default()));
        let mut sm = HpcStateMachine::<TestTypeConfig, TestState>::with_snapshot_dir(
            state,
            snap_dir.clone(),
        )
        .unwrap();

        for i in 1..=6u64 {
            let new_state = TestState {
                data: [(format!("k{i}"), format!("v{i}"))].into(),
            };
            let snapshot_data = serde_json::to_vec(&new_state).unwrap();
            let meta = SnapshotMeta {
                last_log_id: Some(make_log_id(1, 1, i)),
                last_membership: StoredMembership::default(),
                snapshot_id: format!("snap-{i}"),
            };
            sm.install_snapshot(&meta, Cursor::new(snapshot_data))
                .await
                .unwrap();
        }

        // Count only `snap-*.json` files; `current` and any temp files are excluded.
        let snap_count = std::fs::read_dir(&snap_dir)
            .unwrap()
            .filter_map(Result::ok)
            .filter(|e| {
                let name = e.file_name().to_string_lossy().to_string();
                name.starts_with("snap-")
                    && std::path::Path::new(&name)
                        .extension()
                        .is_some_and(|ext| ext.eq_ignore_ascii_case("json"))
            })
            .count();

        assert!(
            snap_count <= MAX_SNAPSHOTS,
            "Expected at most {MAX_SNAPSHOTS} snapshots, found {snap_count}"
        );
    }

    #[test]
    fn snapshot_filename_format() {
        let meta: SnapshotMeta<TestTypeConfig> = SnapshotMeta {
            last_log_id: Some(make_log_id(2, 1, 42)),
            last_membership: StoredMembership::default(),
            snapshot_id: "test".to_string(),
        };
        let name = snapshot_filename::<TestTypeConfig>(&meta);
        assert_eq!(name, "snap-2-42.json");
    }

    #[test]
    fn snapshot_filename_none_log_id() {
        let meta: SnapshotMeta<TestTypeConfig> = SnapshotMeta {
            last_log_id: None,
            last_membership: StoredMembership::default(),
            snapshot_id: "test".to_string(),
        };
        let name = snapshot_filename::<TestTypeConfig>(&meta);
        assert_eq!(name, "snap-0-0.json");
    }

    // get_current_snapshot on a machine with nothing applied falls back to disk.
    #[tokio::test]
    async fn get_current_snapshot_loads_from_disk_on_cold_start() {
        let dir = tempfile::tempdir().unwrap();
        let snap_dir = dir.path().join("snapshots");
        let state = Arc::new(RwLock::new(TestState::default()));

        {
            // First "process": write a snapshot to disk, then drop the machine.
            let mut sm = HpcStateMachine::<TestTypeConfig, TestState>::with_snapshot_dir(
                state.clone(),
                snap_dir.clone(),
            )
            .unwrap();
            let new_state = TestState {
                data: [("cold".into(), "start".into())].into(),
            };
            let snapshot_data = serde_json::to_vec(&new_state).unwrap();
            let meta = SnapshotMeta {
                last_log_id: Some(make_log_id(1, 1, 20)),
                last_membership: StoredMembership::default(),
                snapshot_id: "snap-1".to_string(),
            };
            sm.install_snapshot(&meta, Cursor::new(snapshot_data))
                .await
                .unwrap();
        }

        let fresh_state = Arc::new(RwLock::new(TestState::default()));
        // Built with `new` and the directory set directly, so nothing is
        // loaded at construction time and get_current_snapshot must do it.
        let mut sm2 = HpcStateMachine::<TestTypeConfig, TestState>::new(fresh_state.clone());
        sm2.snapshot_dir = Some(snap_dir);

        let snap = sm2.get_current_snapshot().await.unwrap();
        assert!(snap.is_some());
        let snap = snap.unwrap();
        let loaded: TestState = serde_json::from_slice(&snap.snapshot.into_inner()).unwrap();
        assert_eq!(loaded.data.get("cold").unwrap(), "start");
    }

    #[tokio::test]
    async fn load_latest_snapshot_missing_file_returns_none() {
        let dir = tempfile::tempdir().unwrap();
        let snap_dir = dir.path().join("snapshots");
        std::fs::create_dir_all(&snap_dir).unwrap();

        // `current` points at a file that was never written.
        std::fs::write(snap_dir.join("current"), b"snap-0-999.json").unwrap();

        let result = load_latest_snapshot::<TestTypeConfig, TestState>(&snap_dir).unwrap();
        assert!(result.is_none());
    }

    // `state()` hands out the very same Arc that was passed in.
    #[test]
    fn state_accessor() {
        let state = Arc::new(RwLock::new(TestState::default()));
        let sm = HpcStateMachine::<TestTypeConfig, TestState>::new(state.clone());
        assert!(Arc::ptr_eq(&sm.state(), &state));
    }
}