// diskann_disk/build/chunking/checkpoint/checkpoint_record.rs
use diskann::ANNResult;
7use serde::{Deserialize, Serialize};
8use tracing::info;
9
10use super::WorkStage;
11
12#[derive(Serialize, Deserialize, Debug, Clone)]
16pub struct CheckpointRecord {
17 stage: WorkStage,
19
20 is_valid: bool,
22
23 progress: usize,
24}
25
26impl Default for CheckpointRecord {
27 fn default() -> Self {
28 CheckpointRecord::new()
29 }
30}
31
32impl CheckpointRecord {
33 pub fn new() -> CheckpointRecord {
35 CheckpointRecord {
36 stage: WorkStage::Start,
37 is_valid: true,
38 progress: 0,
39 }
40 }
41
42 pub fn is_valid(&self) -> bool {
43 self.is_valid
44 }
45
46 pub fn get_resumption_point(&self, stage: WorkStage) -> Option<usize> {
47 if self.stage == stage {
48 info!(
49 "The resumption point is at {} for stage {:?}",
50 self.progress, stage
51 );
52 Some(if self.is_valid { self.progress } else { 0 })
53 } else {
54 info!(
55 "Failed to get resumption point for {:?} since the current stage is {:?}.",
56 stage, self.stage
57 );
58 None
59 }
60 }
61
62 pub fn advance_work_type(&self, next_stage: WorkStage) -> ANNResult<CheckpointRecord> {
66 info!(
67 "Advancing work type from {:?} to {:?}.",
68 self.stage, next_stage
69 );
70 Ok(CheckpointRecord {
71 stage: next_stage,
72 is_valid: true,
73 progress: 0,
74 })
75 }
76
77 pub fn mark_as_invalid(&self) -> CheckpointRecord {
79 CheckpointRecord {
80 stage: self.stage,
81 is_valid: false,
82 progress: self.progress,
83 }
84 }
85
86 pub fn update_progress(&self, progress: usize) -> CheckpointRecord {
88 info!("Updating progress to {:?}={}", self.stage, progress);
89 CheckpointRecord {
90 stage: self.stage,
91 is_valid: true,
92 progress,
93 }
94 }
95
96 #[allow(unused)]
97 pub(crate) fn get_work_stage(&self) -> WorkStage {
99 self.stage
100 }
101}
102
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;

    /// Stand-in for an older `WorkStage` enum that only has these four variants.
    #[derive(Serialize, Deserialize, PartialEq, Debug, Clone, Copy)]
    enum LegacyWorkStage {
        QuantizeFPV,
        InMemIndexBuild,
        WriteDiskLayout,
        End,
    }

    /// Stand-in for an older checkpoint record built on [`LegacyWorkStage`].
    #[derive(Serialize, Deserialize, Debug, Clone)]
    struct LegacyCheckpointRecord {
        stage: LegacyWorkStage,
        is_valid: bool,
        progress: usize,
    }

    /// Bytes written from the legacy record must deserialize into the current record.
    #[rstest]
    #[case(LegacyWorkStage::QuantizeFPV, WorkStage::QuantizeFPV, true, 0)]
    #[case(
        LegacyWorkStage::InMemIndexBuild,
        WorkStage::InMemIndexBuild,
        false,
        42
    )]
    #[case(
        LegacyWorkStage::WriteDiskLayout,
        WorkStage::WriteDiskLayout,
        true,
        100
    )]
    #[case(LegacyWorkStage::End, WorkStage::End, false, 0)]
    fn test_backward_compatibility(
        #[case] legacy_stage: LegacyWorkStage,
        #[case] stage: WorkStage,
        #[case] is_valid: bool,
        #[case] progress: usize,
    ) {
        let legacy_record = LegacyCheckpointRecord {
            stage: legacy_stage,
            is_valid,
            progress,
        };
        let serialized = bincode::serialize(&legacy_record).unwrap();
        let deserialized: CheckpointRecord = bincode::deserialize(&serialized).unwrap();
        assert_eq!(deserialized.stage, stage);
        assert_eq!(deserialized.is_valid, is_valid);
        assert_eq!(deserialized.progress, progress);
    }

    /// Bytes written from the current record, when it uses a stage the legacy enum also has,
    /// must deserialize into the legacy record.
    #[rstest]
    #[case(WorkStage::QuantizeFPV, LegacyWorkStage::QuantizeFPV, true, 10)]
    #[case(
        WorkStage::InMemIndexBuild,
        LegacyWorkStage::InMemIndexBuild,
        false,
        30
    )]
    #[case(WorkStage::WriteDiskLayout, LegacyWorkStage::WriteDiskLayout, true, 80)]
    #[case(WorkStage::End, LegacyWorkStage::End, false, 0)]
    fn test_forward_compatibility(
        #[case] current_stage: WorkStage,
        #[case] expected_legacy_stage: LegacyWorkStage,
        #[case] is_valid: bool,
        #[case] progress: usize,
    ) {
        let current_record = CheckpointRecord {
            stage: current_stage,
            is_valid,
            progress,
        };

        let serialized = bincode::serialize(&current_record).unwrap();

        let deserialized: LegacyCheckpointRecord = bincode::deserialize(&serialized).unwrap();
        assert_eq!(deserialized.stage, expected_legacy_stage);
        assert_eq!(deserialized.is_valid, is_valid);
        assert_eq!(deserialized.progress, progress);
    }

    /// Records that use stages absent from the legacy enum must fail to deserialize as the
    /// legacy record rather than being silently misread.
    #[rstest]
    #[case(WorkStage::PartitionData, true, 25)]
    #[case(WorkStage::BuildIndicesOnShards(0), true, 75)]
    #[case(WorkStage::BuildIndicesOnShards(10), true, 75)]
    #[case(WorkStage::MergeIndices, false, 75)]
    fn test_rolling_back_with_new_variants(
        #[case] stage: WorkStage,
        #[case] is_valid: bool,
        #[case] progress: usize,
    ) {
        let current_record = CheckpointRecord {
            stage,
            is_valid,
            progress,
        };

        let serialized = bincode::serialize(&current_record).unwrap();

        let result: Result<LegacyCheckpointRecord, bincode::Error> =
            bincode::deserialize(&serialized);
        assert!(
            result.is_err(),
            "Legacy code should not be able to deserialize newer enum variants"
        );
    }

    #[test]
    fn test_checkpoint_record_default() {
        let record = CheckpointRecord::default();
        assert!(record.is_valid());
        assert_eq!(record.get_work_stage(), WorkStage::Start);
    }

    #[test]
    fn test_checkpoint_record_is_valid() {
        let record = CheckpointRecord::new();
        assert!(record.is_valid());

        let invalid_record = record.mark_as_invalid();
        assert!(!invalid_record.is_valid());
    }

    #[test]
    fn test_get_resumption_point_with_matching_stage() {
        let record = CheckpointRecord::new().update_progress(42);
        let resumption = record.get_resumption_point(WorkStage::Start);
        assert_eq!(resumption, Some(42));
    }

    #[test]
    fn test_get_resumption_point_with_different_stage() {
        let record = CheckpointRecord::new();
        let resumption = record.get_resumption_point(WorkStage::QuantizeFPV);
        assert_eq!(resumption, None);
    }

    #[test]
    fn test_get_resumption_point_when_invalid() {
        let record = CheckpointRecord::new()
            .update_progress(100)
            .mark_as_invalid();
        let resumption = record.get_resumption_point(WorkStage::Start);
        // An invalid record keeps its stage but restarts that stage from zero.
        assert_eq!(resumption, Some(0));
    }

    #[test]
    fn test_advance_work_type() {
        let record = CheckpointRecord::new();
        let advanced = record.advance_work_type(WorkStage::QuantizeFPV).unwrap();
        assert_eq!(advanced.get_work_stage(), WorkStage::QuantizeFPV);
        assert!(advanced.is_valid());
    }
}