// burn_train/checkpoint/strategy/base.rs

1use std::ops::DerefMut;
2
3use crate::metric::store::EventStoreClient;
4
/// Action to be taken by a [checkpointer](crate::checkpoint::Checkpointer).
///
/// A [`CheckpointingStrategy`] returns a list of these actions for each epoch it is consulted on.
#[derive(Clone, PartialEq, Eq, Debug)]
pub enum CheckpointingAction {
    /// Delete the checkpoint of the given epoch.
    Delete(usize),
    /// Save the current record.
    Save,
}
13
14/// Define when checkpoint should be saved and deleted.
15pub trait CheckpointingStrategy: Send {
16    /// Based on the epoch, determine if the checkpoint should be saved.
17    fn checkpointing(
18        &mut self,
19        epoch: usize,
20        collector: &EventStoreClient,
21    ) -> Vec<CheckpointingAction>;
22}
23
24// We make dyn box implement the checkpointing strategy so that it can be used with generic, but
25// still be dynamic.
26impl CheckpointingStrategy for Box<dyn CheckpointingStrategy> {
27    fn checkpointing(
28        &mut self,
29        epoch: usize,
30        collector: &EventStoreClient,
31    ) -> Vec<CheckpointingAction> {
32        self.deref_mut().checkpointing(epoch, collector)
33    }
34}