Skip to main content

diskann_disk/build/chunking/checkpoint/
checkpoint_record.rs

/*
 * Copyright (c) Microsoft Corporation.
 * Licensed under the MIT license.
 */

6use diskann::ANNResult;
7use serde::{Deserialize, Serialize};
8use tracing::info;
9
10use super::WorkStage;
11
12/// Represents a checkpoint record in the index build process.
13/// The checkpoint record can be marked as in-valid to indicate that the exising intermediate data should be discarded.
14/// This can happen because of a crash or an unexpected shutdown during the in-memory index build.
15#[derive(Serialize, Deserialize, Debug, Clone)]
16pub struct CheckpointRecord {
17    /// The work type represents the current stage of the index build process.
18    stage: WorkStage,
19
20    /// Indicates if the checkpoint record is dirty.
21    is_valid: bool,
22
23    progress: usize,
24}
25
26impl Default for CheckpointRecord {
27    fn default() -> Self {
28        CheckpointRecord::new()
29    }
30}
31
32impl CheckpointRecord {
33    /// Create a new CheckpointRecord with the work type set to Start.
34    pub fn new() -> CheckpointRecord {
35        CheckpointRecord {
36            stage: WorkStage::Start,
37            is_valid: true,
38            progress: 0,
39        }
40    }
41
42    pub fn is_valid(&self) -> bool {
43        self.is_valid
44    }
45
46    pub fn get_resumption_point(&self, stage: WorkStage) -> Option<usize> {
47        if self.stage == stage {
48            info!(
49                "The resumption point is at {} for stage {:?}",
50                self.progress, stage
51            );
52            Some(if self.is_valid { self.progress } else { 0 })
53        } else {
54            info!(
55                "Failed to get resumption point for {:?} since the current stage is {:?}.",
56                stage, self.stage
57            );
58            None
59        }
60    }
61
62    // Advance the work type to the next stage in the index build process.
63    // This method is used in each individual step of the index build process
64    // ..t o update the checkpoint record.
65    pub fn advance_work_type(&self, next_stage: WorkStage) -> ANNResult<CheckpointRecord> {
66        info!(
67            "Advancing work type from {:?} to {:?}.",
68            self.stage, next_stage
69        );
70        Ok(CheckpointRecord {
71            stage: next_stage,
72            is_valid: true,
73            progress: 0,
74        })
75    }
76
77    // Mark the checkpoint record as invalid.
78    pub fn mark_as_invalid(&self) -> CheckpointRecord {
79        CheckpointRecord {
80            stage: self.stage,
81            is_valid: false,
82            progress: self.progress,
83        }
84    }
85
86    // Update the progress of the current work type.
87    pub fn update_progress(&self, progress: usize) -> CheckpointRecord {
88        info!("Updating progress to {:?}={}", self.stage, progress);
89        CheckpointRecord {
90            stage: self.stage,
91            is_valid: true,
92            progress,
93        }
94    }
95
96    #[allow(unused)]
97    // This function is used for testing purposes only.
98    pub(crate) fn get_work_stage(&self) -> WorkStage {
99        self.stage
100    }
101}
102
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;

    /// Snapshot of the legacy `WorkStage` variants, used to check on-disk compatibility.
    #[derive(Serialize, Deserialize, PartialEq, Debug, Clone, Copy)]
    enum LegacyWorkStage {
        QuantizeFPV,
        InMemIndexBuild,
        WriteDiskLayout,
        End,
    }

    /// Snapshot of the legacy `CheckpointRecord` layout (same field order as the current one).
    #[derive(Serialize, Deserialize, Debug, Clone)]
    struct LegacyCheckpointRecord {
        stage: LegacyWorkStage,
        is_valid: bool,
        progress: usize,
    }

    #[rstest]
    #[case(LegacyWorkStage::QuantizeFPV, WorkStage::QuantizeFPV, true, 0)]
    #[case(
        LegacyWorkStage::InMemIndexBuild,
        WorkStage::InMemIndexBuild,
        false,
        42
    )]
    #[case(
        LegacyWorkStage::WriteDiskLayout,
        WorkStage::WriteDiskLayout,
        true,
        100
    )]
    #[case(LegacyWorkStage::End, WorkStage::End, false, 0)]
    fn test_backward_compatibility(
        #[case] legacy_stage: LegacyWorkStage,
        #[case] stage: WorkStage,
        #[case] is_valid: bool,
        #[case] progress: usize,
    ) {
        // Test backward compatibility: Newer code (current) reading older data format (legacy)
        let legacy_record = LegacyCheckpointRecord {
            stage: legacy_stage,
            is_valid,
            progress,
        };
        let serialized = bincode::serialize(&legacy_record).unwrap();
        let deserialized: CheckpointRecord = bincode::deserialize(&serialized).unwrap();
        assert_eq!(deserialized.stage, stage);
        assert_eq!(deserialized.is_valid, is_valid);
        assert_eq!(deserialized.progress, progress);
    }

    #[rstest]
    #[case(WorkStage::QuantizeFPV, LegacyWorkStage::QuantizeFPV, true, 10)]
    #[case(
        WorkStage::InMemIndexBuild,
        LegacyWorkStage::InMemIndexBuild,
        false,
        30
    )]
    #[case(WorkStage::WriteDiskLayout, LegacyWorkStage::WriteDiskLayout, true, 80)]
    #[case(WorkStage::End, LegacyWorkStage::End, false, 0)]
    fn test_forward_compatibility(
        #[case] current_stage: WorkStage,
        #[case] expected_legacy_stage: LegacyWorkStage,
        #[case] is_valid: bool,
        #[case] progress: usize,
    ) {
        // Test forward compatibility: Older code (legacy) reading newer data format (current)
        // This simulates rolling back to an older version after using a newer version
        let current_record = CheckpointRecord {
            stage: current_stage,
            is_valid,
            progress,
        };

        let serialized = bincode::serialize(&current_record).unwrap();

        // Legacy code should still be able to deserialize common enum variants
        let deserialized: LegacyCheckpointRecord = bincode::deserialize(&serialized).unwrap();
        assert_eq!(deserialized.stage, expected_legacy_stage);
        assert_eq!(deserialized.is_valid, is_valid);
        assert_eq!(deserialized.progress, progress);
    }

    #[rstest]
    #[case(WorkStage::PartitionData, true, 25)]
    #[case(WorkStage::BuildIndicesOnShards(0), true, 75)]
    #[case(WorkStage::BuildIndicesOnShards(10), true, 75)]
    #[case(WorkStage::MergeIndices, false, 75)]
    fn test_rolling_back_with_new_variants(
        #[case] stage: WorkStage,
        #[case] is_valid: bool,
        #[case] progress: usize,
    ) {
        // When rolling back to older versions, newer variants should fail to deserialize
        // This is expected behavior and we should test for it
        let current_record = CheckpointRecord {
            stage,
            is_valid,
            progress,
        };

        let serialized = bincode::serialize(&current_record).unwrap();

        // Legacy code should fail to deserialize newer enum variants
        // This is expected behavior - we're testing that it fails
        let result: Result<LegacyCheckpointRecord, bincode::Error> =
            bincode::deserialize(&serialized);
        assert!(
            result.is_err(),
            "Legacy code should not be able to deserialize newer enum variants"
        );
    }

    #[rstest]
    #[case(WorkStage::Start, true, 0)]
    #[case(WorkStage::QuantizeFPV, false, 5)]
    #[case(WorkStage::PartitionData, true, 25)]
    #[case(WorkStage::BuildIndicesOnShards(3), false, 75)]
    #[case(WorkStage::MergeIndices, true, 90)]
    #[case(WorkStage::End, true, 0)]
    fn test_current_format_round_trip(
        #[case] stage: WorkStage,
        #[case] is_valid: bool,
        #[case] progress: usize,
    ) {
        // The current format must round-trip every stage, including variants unknown to
        // legacy readers.
        let record = CheckpointRecord {
            stage,
            is_valid,
            progress,
        };
        let serialized = bincode::serialize(&record).unwrap();
        let deserialized: CheckpointRecord = bincode::deserialize(&serialized).unwrap();
        assert_eq!(deserialized.stage, stage);
        assert_eq!(deserialized.is_valid, is_valid);
        assert_eq!(deserialized.progress, progress);
    }

    #[test]
    fn test_checkpoint_record_default() {
        let record = CheckpointRecord::default();
        assert!(record.is_valid());
        assert_eq!(record.get_work_stage(), WorkStage::Start);
        assert_eq!(record.get_resumption_point(WorkStage::Start), Some(0));
    }

    #[test]
    fn test_checkpoint_record_is_valid() {
        let record = CheckpointRecord::new();
        assert!(record.is_valid());

        let invalid_record = record.mark_as_invalid();
        assert!(!invalid_record.is_valid());
    }

    #[test]
    fn test_mark_as_invalid_preserves_stage_and_progress() {
        let record = CheckpointRecord::new()
            .update_progress(7)
            .mark_as_invalid();
        assert_eq!(record.get_work_stage(), WorkStage::Start);
        assert_eq!(record.progress, 7);
    }

    #[test]
    fn test_update_progress_revalidates_invalid_record() {
        let record = CheckpointRecord::new()
            .mark_as_invalid()
            .update_progress(9);
        assert!(record.is_valid());
        assert_eq!(record.get_resumption_point(WorkStage::Start), Some(9));
    }

    #[test]
    fn test_get_resumption_point_with_matching_stage() {
        let record = CheckpointRecord::new().update_progress(42);
        let resumption = record.get_resumption_point(WorkStage::Start);
        assert_eq!(resumption, Some(42));
    }

    #[test]
    fn test_get_resumption_point_with_different_stage() {
        let record = CheckpointRecord::new();
        let resumption = record.get_resumption_point(WorkStage::QuantizeFPV);
        assert_eq!(resumption, None);
    }

    #[test]
    fn test_get_resumption_point_when_invalid() {
        let record = CheckpointRecord::new()
            .update_progress(100)
            .mark_as_invalid();
        let resumption = record.get_resumption_point(WorkStage::Start);
        assert_eq!(resumption, Some(0)); // Should return 0 when invalid
    }

    #[test]
    fn test_advance_work_type() {
        let record = CheckpointRecord::new();
        let advanced = record.advance_work_type(WorkStage::QuantizeFPV).unwrap();
        assert_eq!(advanced.get_work_stage(), WorkStage::QuantizeFPV);
        assert!(advanced.is_valid());
    }

    #[test]
    fn test_advance_work_type_resets_progress_and_validity() {
        let record = CheckpointRecord::new()
            .update_progress(42)
            .mark_as_invalid();
        let advanced = record.advance_work_type(WorkStage::QuantizeFPV).unwrap();
        assert!(advanced.is_valid());
        assert_eq!(
            advanced.get_resumption_point(WorkStage::QuantizeFPV),
            Some(0)
        );
        assert_eq!(advanced.get_resumption_point(WorkStage::Start), None);
    }
}
264        let advanced = record.advance_work_type(WorkStage::QuantizeFPV).unwrap();
265        assert_eq!(advanced.get_work_stage(), WorkStage::QuantizeFPV);
266        assert!(advanced.is_valid());
267    }
268}