diskann_disk/build/chunking/checkpoint/
checkpoint_record_manager.rs1use diskann::ANNResult;
7use tracing::info;
8
9use super::{Progress, WorkStage};
10
11pub trait CheckpointManager: Send + Sync + CheckpointManagerClone {
16 fn get_resumption_point(&self, stage: WorkStage) -> ANNResult<Option<usize>>;
23
24 fn update(&mut self, progress: Progress, next_stage: WorkStage) -> ANNResult<()>;
31
32 fn mark_as_invalid(&mut self) -> ANNResult<()>;
39}
40
41pub trait CheckpointManagerExt {
42 fn execute_stage<F, S, U>(
43 &mut self,
44 stage: WorkStage,
45 next_stage: WorkStage,
46 operation: F,
47 skip_handler: S,
48 ) -> ANNResult<U>
49 where
50 F: FnOnce() -> ANNResult<U>,
51 S: FnOnce() -> ANNResult<U>;
52}
53
54impl<T: ?Sized> CheckpointManagerExt for T
55where
56 T: CheckpointManager,
57{
58 fn execute_stage<F, S, U>(
59 &mut self,
60 stage: WorkStage,
61 next_stage: WorkStage,
62 operation: F,
63 skip_handler: S,
64 ) -> ANNResult<U>
65 where
66 F: FnOnce() -> ANNResult<U>,
67 S: FnOnce() -> ANNResult<U>,
68 {
69 match self.get_resumption_point(stage)? {
70 Some(_) => {
71 let result = operation()?;
72 self.update(Progress::Completed, next_stage)?;
73 Ok(result)
74 }
75 None => {
76 info!("[Stage:{:?}] Skip stage - invalid checkpoint", stage);
77 skip_handler()
78 }
79 }
80 }
81}
82
83pub trait CheckpointManagerClone {
85 fn clone_box(&self) -> Box<dyn CheckpointManager>;
86}
87
88impl<T> CheckpointManagerClone for T
89where
90 T: 'static + CheckpointManager + Clone,
91{
92 fn clone_box(&self) -> Box<dyn CheckpointManager> {
93 Box::new(self.clone())
94 }
95}
96
#[cfg(test)]
mod tests {
    use super::super::NaiveCheckpointRecordManager;
    use super::*;

    /// Test double that reports a configurable resumption point and records how
    /// `execute_stage` drives `update`.
    #[derive(Clone, Default)]
    struct RecordingCheckpointManager {
        resumption_point: Option<usize>,
        update_calls: usize,
        last_update_was_completed_to_end: bool,
    }

    impl CheckpointManager for RecordingCheckpointManager {
        fn get_resumption_point(&self, _stage: WorkStage) -> ANNResult<Option<usize>> {
            Ok(self.resumption_point)
        }

        fn update(&mut self, progress: Progress, next_stage: WorkStage) -> ANNResult<()> {
            self.update_calls += 1;
            self.last_update_was_completed_to_end =
                matches!(progress, Progress::Completed) && matches!(next_stage, WorkStage::End);
            Ok(())
        }

        fn mark_as_invalid(&mut self) -> ANNResult<()> {
            Ok(())
        }
    }

    #[test]
    fn test_checkpoint_manager_ext_execute_stage_with_resumption() {
        let mut manager = NaiveCheckpointRecordManager;
        let mut executed = false;

        let result = manager.execute_stage(
            WorkStage::Start,
            WorkStage::End,
            || {
                executed = true;
                Ok(42)
            },
            || Ok(0),
        );

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);
        assert!(executed);
    }

    #[test]
    fn test_execute_stage_runs_operation_and_advances_checkpoint() {
        let mut manager = RecordingCheckpointManager {
            resumption_point: Some(3),
            ..Default::default()
        };
        let mut skipped = false;

        let result = manager.execute_stage(
            WorkStage::Start,
            WorkStage::End,
            || Ok(7),
            || {
                skipped = true;
                Ok(0)
            },
        );

        assert_eq!(result.unwrap(), 7);
        assert!(!skipped);
        assert_eq!(manager.update_calls, 1);
        assert!(manager.last_update_was_completed_to_end);
    }

    #[test]
    fn test_execute_stage_skips_when_no_resumption_point() {
        let mut manager = RecordingCheckpointManager::default();
        let mut executed = false;

        let result = manager.execute_stage(
            WorkStage::Start,
            WorkStage::End,
            || {
                executed = true;
                Ok(42)
            },
            || Ok(-1),
        );

        assert_eq!(result.unwrap(), -1);
        assert!(!executed);
        // A skipped stage must not advance the checkpoint.
        assert_eq!(manager.update_calls, 0);
    }

    #[test]
    fn test_checkpoint_manager_clone_box() {
        let manager = NaiveCheckpointRecordManager;
        let boxed = manager.clone_box();

        let result = boxed.get_resumption_point(WorkStage::Start);

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), Some(0));
    }
}