burn_train/checkpoint/strategy/
base.rs1use std::ops::DerefMut;
2
3use crate::metric::store::EventStoreClient;
4
/// An action requested by a [`CheckpointingStrategy`].
///
/// A strategy returns a list of these from `checkpointing`; the caller is expected to
/// carry them out (NOTE(review): the executing component is not visible in this file).
///
/// `Eq` is derived alongside `PartialEq` because every payload (`usize`) has total equality.
#[derive(Clone, PartialEq, Eq, Debug)]
pub enum CheckpointingAction {
    /// Delete a previously saved checkpoint.
    ///
    /// NOTE(review): the `usize` is presumably the epoch the checkpoint was saved at, since
    /// strategies are driven by an `epoch: usize` — confirm against the checkpointer.
    Delete(usize),
    /// Save a checkpoint (presumably for the epoch passed to `checkpointing` — confirm).
    Save,
}
13
14pub trait CheckpointingStrategy: Send {
16 fn checkpointing(
18 &mut self,
19 epoch: usize,
20 collector: &EventStoreClient,
21 ) -> Vec<CheckpointingAction>;
22}
23
24impl CheckpointingStrategy for Box<dyn CheckpointingStrategy> {
27 fn checkpointing(
28 &mut self,
29 epoch: usize,
30 collector: &EventStoreClient,
31 ) -> Vec<CheckpointingAction> {
32 self.deref_mut().checkpointing(epoch, collector)
33 }
34}