diskann_disk/build/chunking/checkpoint/
checkpoint_context.rs1use diskann::ANNResult;
6
7use super::{CheckpointManager, Progress, WorkStage};
8
9pub struct CheckpointContext<'a> {
11 checkpoint_manager: &'a dyn CheckpointManager,
12 current_stage: WorkStage,
13 next_stage: WorkStage,
14}
15
16impl<'a> CheckpointContext<'a> {
17 pub fn new(
18 checkpoint_manager: &'a dyn CheckpointManager,
19 current_stage: WorkStage,
20 next_stage: WorkStage,
21 ) -> Self {
22 Self {
23 checkpoint_manager,
24 current_stage,
25 next_stage,
26 }
27 }
28
29 pub fn current_stage(&self) -> WorkStage {
30 self.current_stage
31 }
32
33 pub fn to_owned(&self) -> OwnedCheckpointContext {
34 OwnedCheckpointContext::new(
35 self.checkpoint_manager.clone_box(),
36 self.current_stage,
37 self.next_stage,
38 )
39 }
40
41 pub fn get_resumption_point(&self) -> ANNResult<Option<usize>> {
42 self.checkpoint_manager
43 .get_resumption_point(self.current_stage)
44 }
45}
46
47pub struct OwnedCheckpointContext {
49 checkpoint_manager: Box<dyn CheckpointManager>,
50 current_stage: WorkStage,
51 next_stage: WorkStage,
52}
53
54impl OwnedCheckpointContext {
55 pub fn new(
56 checkpoint_manager: Box<dyn CheckpointManager>,
57 current_stage: WorkStage,
58 next_stage: WorkStage,
59 ) -> Self {
60 Self {
61 checkpoint_manager,
62 current_stage,
63 next_stage,
64 }
65 }
66
67 pub fn current_stage(&self) -> WorkStage {
68 self.current_stage
69 }
70
71 pub fn get_resumption_point(&mut self) -> ANNResult<Option<usize>> {
72 self.checkpoint_manager
73 .get_resumption_point(self.current_stage)
74 }
75
76 pub fn update(&mut self, progress: Progress) -> ANNResult<()> {
77 self.checkpoint_manager.update(progress, self.next_stage)
78 }
79
80 pub fn mark_as_invalid(&mut self) -> ANNResult<()> {
81 self.checkpoint_manager.mark_as_invalid()
82 }
83}
84
#[cfg(test)]
mod tests {
    use super::super::NaiveCheckpointRecordManager;
    use super::*;

    #[test]
    fn test_checkpoint_context_new() {
        let manager = NaiveCheckpointRecordManager;
        let context = CheckpointContext::new(&manager, WorkStage::Start, WorkStage::End);

        assert_eq!(context.current_stage(), WorkStage::Start);
    }

    #[test]
    fn test_checkpoint_context_current_stage() {
        let manager = NaiveCheckpointRecordManager;
        let context =
            CheckpointContext::new(&manager, WorkStage::QuantizeFPV, WorkStage::InMemIndexBuild);

        assert_eq!(context.current_stage(), WorkStage::QuantizeFPV);
    }

    #[test]
    fn test_checkpoint_context_get_resumption_point() {
        let manager = NaiveCheckpointRecordManager;
        let context = CheckpointContext::new(&manager, WorkStage::Start, WorkStage::End);

        // `expect` surfaces the manager's error on failure instead of a bare
        // "assertion failed: result.is_ok()".
        let point = context
            .get_resumption_point()
            .expect("naive manager should report a resumption point");
        assert_eq!(point, Some(0));
    }

    #[test]
    fn test_checkpoint_context_to_owned() {
        let manager = NaiveCheckpointRecordManager;
        let context = CheckpointContext::new(&manager, WorkStage::Start, WorkStage::End);

        let owned = context.to_owned();
        assert_eq!(owned.current_stage(), WorkStage::Start);
    }

    #[test]
    fn test_owned_checkpoint_context_new() {
        let manager = Box::new(NaiveCheckpointRecordManager);
        let context = OwnedCheckpointContext::new(
            manager,
            WorkStage::TrainBuildQuantizer,
            WorkStage::PartitionData,
        );

        assert_eq!(context.current_stage(), WorkStage::TrainBuildQuantizer);
    }

    #[test]
    fn test_owned_checkpoint_context_get_resumption_point() {
        let manager = Box::new(NaiveCheckpointRecordManager);
        let mut context = OwnedCheckpointContext::new(manager, WorkStage::Start, WorkStage::End);

        // The owned context must agree with the borrowed one for the same
        // manager and stage (see test_checkpoint_context_get_resumption_point).
        let point = context
            .get_resumption_point()
            .expect("naive manager should report a resumption point");
        assert_eq!(point, Some(0));

        // Querying must not prevent recording progress afterwards.
        context
            .update(Progress::Completed)
            .expect("update after query should succeed");
    }

    #[test]
    fn test_owned_checkpoint_context_update() {
        let manager = Box::new(NaiveCheckpointRecordManager);
        let mut context = OwnedCheckpointContext::new(manager, WorkStage::Start, WorkStage::End);

        context
            .update(Progress::Completed)
            .expect("update should succeed");
    }

    #[test]
    fn test_owned_checkpoint_context_mark_as_invalid() {
        let manager = Box::new(NaiveCheckpointRecordManager);
        let mut context = OwnedCheckpointContext::new(manager, WorkStage::Start, WorkStage::End);

        context
            .mark_as_invalid()
            .expect("mark_as_invalid should succeed");
    }
}