Skip to main content

diskann_disk/build/chunking/checkpoint/checkpoint_context.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5use diskann::ANNResult;
6
7use super::{CheckpointManager, Progress, WorkStage};
8
9// Context for managing checkpoint operations during various processing stages.
10pub struct CheckpointContext<'a> {
11    checkpoint_manager: &'a dyn CheckpointManager,
12    current_stage: WorkStage,
13    next_stage: WorkStage,
14}
15
16impl<'a> CheckpointContext<'a> {
17    pub fn new(
18        checkpoint_manager: &'a dyn CheckpointManager,
19        current_stage: WorkStage,
20        next_stage: WorkStage,
21    ) -> Self {
22        Self {
23            checkpoint_manager,
24            current_stage,
25            next_stage,
26        }
27    }
28
29    pub fn current_stage(&self) -> WorkStage {
30        self.current_stage
31    }
32
33    pub fn to_owned(&self) -> OwnedCheckpointContext {
34        OwnedCheckpointContext::new(
35            self.checkpoint_manager.clone_box(),
36            self.current_stage,
37            self.next_stage,
38        )
39    }
40
41    pub fn get_resumption_point(&self) -> ANNResult<Option<usize>> {
42        self.checkpoint_manager
43            .get_resumption_point(self.current_stage)
44    }
45}
46
47/// Context for managing checkpoint operations with an owned checkpoint manager
48pub struct OwnedCheckpointContext {
49    checkpoint_manager: Box<dyn CheckpointManager>,
50    current_stage: WorkStage,
51    next_stage: WorkStage,
52}
53
54impl OwnedCheckpointContext {
55    pub fn new(
56        checkpoint_manager: Box<dyn CheckpointManager>,
57        current_stage: WorkStage,
58        next_stage: WorkStage,
59    ) -> Self {
60        Self {
61            checkpoint_manager,
62            current_stage,
63            next_stage,
64        }
65    }
66
67    pub fn current_stage(&self) -> WorkStage {
68        self.current_stage
69    }
70
71    pub fn get_resumption_point(&mut self) -> ANNResult<Option<usize>> {
72        self.checkpoint_manager
73            .get_resumption_point(self.current_stage)
74    }
75
76    pub fn update(&mut self, progress: Progress) -> ANNResult<()> {
77        self.checkpoint_manager.update(progress, self.next_stage)
78    }
79
80    pub fn mark_as_invalid(&mut self) -> ANNResult<()> {
81        self.checkpoint_manager.mark_as_invalid()
82    }
83}
84
#[cfg(test)]
mod tests {
    use super::super::NaiveCheckpointRecordManager;
    use super::*;

    #[test]
    fn test_checkpoint_context_new() {
        let manager = NaiveCheckpointRecordManager;
        let context = CheckpointContext::new(&manager, WorkStage::Start, WorkStage::End);

        assert_eq!(context.current_stage(), WorkStage::Start);
    }

    #[test]
    fn test_checkpoint_context_current_stage() {
        let manager = NaiveCheckpointRecordManager;
        let context =
            CheckpointContext::new(&manager, WorkStage::QuantizeFPV, WorkStage::InMemIndexBuild);

        assert_eq!(context.current_stage(), WorkStage::QuantizeFPV);
    }

    #[test]
    fn test_checkpoint_context_get_resumption_point() {
        let manager = NaiveCheckpointRecordManager;
        let context = CheckpointContext::new(&manager, WorkStage::Start, WorkStage::End);

        let result = context.get_resumption_point();
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), Some(0));
    }

    #[test]
    fn test_checkpoint_context_to_owned() {
        let manager = NaiveCheckpointRecordManager;
        let context = CheckpointContext::new(&manager, WorkStage::Start, WorkStage::End);

        let owned = context.to_owned();
        assert_eq!(owned.current_stage(), WorkStage::Start);
    }

    #[test]
    fn test_owned_checkpoint_context_new() {
        let manager = Box::new(NaiveCheckpointRecordManager);
        let context = OwnedCheckpointContext::new(
            manager,
            WorkStage::TrainBuildQuantizer,
            WorkStage::PartitionData,
        );

        assert_eq!(context.current_stage(), WorkStage::TrainBuildQuantizer);
    }

    #[test]
    fn test_owned_checkpoint_context_get_resumption_point() {
        let manager = Box::new(NaiveCheckpointRecordManager);
        let mut context = OwnedCheckpointContext::new(manager, WorkStage::Start, WorkStage::End);

        // For the same manager type and stage, the owned context should report
        // what the borrowed context reports in the test above.
        let result = context.get_resumption_point();
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), Some(0));

        // Querying must not interfere with recording progress afterwards.
        assert!(context.update(Progress::Completed).is_ok());
    }

    #[test]
    fn test_owned_checkpoint_context_update() {
        let manager = Box::new(NaiveCheckpointRecordManager);
        let mut context = OwnedCheckpointContext::new(manager, WorkStage::Start, WorkStage::End);

        let result = context.update(Progress::Completed);
        assert!(result.is_ok());
    }

    #[test]
    fn test_owned_checkpoint_context_mark_as_invalid() {
        let manager = Box::new(NaiveCheckpointRecordManager);
        let mut context = OwnedCheckpointContext::new(manager, WorkStage::Start, WorkStage::End);

        let result = context.mark_as_invalid();
        assert!(result.is_ok());
    }
}