diskann_disk/build/chunking/checkpoint/
checkpoint_context.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5use diskann::ANNResult;
6
7use super::{CheckpointManager, Progress, WorkStage};
8
9// Context for managing checkpoint operations during various processing stages.
10pub struct CheckpointContext<'a> {
11    checkpoint_manager: &'a dyn CheckpointManager,
12    current_stage: WorkStage,
13    next_stage: WorkStage,
14}
15
16impl<'a> CheckpointContext<'a> {
17    pub fn new(
18        checkpoint_manager: &'a dyn CheckpointManager,
19        current_stage: WorkStage,
20        next_stage: WorkStage,
21    ) -> Self {
22        Self {
23            checkpoint_manager,
24            current_stage,
25            next_stage,
26        }
27    }
28
29    pub fn current_stage(&self) -> WorkStage {
30        self.current_stage
31    }
32
33    pub fn to_owned(&self) -> OwnedCheckpointContext {
34        OwnedCheckpointContext::new(
35            self.checkpoint_manager.clone_box(),
36            self.current_stage,
37            self.next_stage,
38        )
39    }
40
41    pub fn get_resumption_point(&self) -> ANNResult<Option<usize>> {
42        self.checkpoint_manager
43            .get_resumption_point(self.current_stage)
44    }
45}
46
47/// Context for managing checkpoint operations with an owned checkpoint manager
48pub struct OwnedCheckpointContext {
49    checkpoint_manager: Box<dyn CheckpointManager>,
50    current_stage: WorkStage,
51    next_stage: WorkStage,
52}
53
54impl OwnedCheckpointContext {
55    pub fn new(
56        checkpoint_manager: Box<dyn CheckpointManager>,
57        current_stage: WorkStage,
58        next_stage: WorkStage,
59    ) -> Self {
60        Self {
61            checkpoint_manager,
62            current_stage,
63            next_stage,
64        }
65    }
66
67    pub fn current_stage(&self) -> WorkStage {
68        self.current_stage
69    }
70
71    pub fn get_resumption_point(&mut self) -> ANNResult<Option<usize>> {
72        self.checkpoint_manager
73            .get_resumption_point(self.current_stage)
74    }
75
76    pub fn update(&mut self, progress: Progress) -> ANNResult<()> {
77        self.checkpoint_manager.update(progress, self.next_stage)
78    }
79
80    pub fn mark_as_invalid(&mut self) -> ANNResult<()> {
81        self.checkpoint_manager.mark_as_invalid()
82    }
83}