1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
use super::CheckpointingStrategy;
use crate::{
    checkpoint::CheckpointingAction,
    metric::{
        store::{Aggregate, Direction, EventStoreClient, Split},
        Metric,
    },
};

/// Keep the best checkpoint based on a metric.
///
/// The strategy remembers which epoch's checkpoint is currently retained and, on each call to
/// `checkpointing`, asks the event store which epoch is the best one for the configured metric,
/// aggregate, direction and split.
pub struct MetricCheckpointingStrategy {
    /// Epoch selected as best by the previous `checkpointing` call; `None` before the first call.
    current: Option<usize>,
    /// Aggregate passed to the event store when looking up the best epoch.
    aggregate: Aggregate,
    /// Direction passed to the event store when looking up the best epoch.
    direction: Direction,
    /// Split passed to the event store when looking up the best epoch.
    split: Split,
    /// Name of the monitored metric, taken from `Metric::NAME` at construction.
    name: String,
}

impl MetricCheckpointingStrategy {
    /// Build a strategy that tracks the metric `M`.
    ///
    /// The metric is identified by `M::NAME`; `aggregate`, `direction` and `split` are stored
    /// as-is and later forwarded to the event store when searching for the best epoch.
    /// No checkpoint is considered retained until `checkpointing` has been called once.
    pub fn new<M: Metric>(aggregate: Aggregate, direction: Direction, split: Split) -> Self {
        let name = M::NAME.to_string();

        Self {
            current: None,
            aggregate,
            direction,
            split,
            name,
        }
    }
}

impl CheckpointingStrategy for MetricCheckpointingStrategy {
    /// Decide which checkpoint actions to perform at the end of `epoch`.
    ///
    /// The event store is queried for the best epoch of the tracked metric. The returned
    /// actions, in order, are:
    ///
    /// - `Delete(previous)` when a previously retained epoch exists and differs from the best
    ///   epoch;
    /// - `Save` when the best epoch is the current `epoch`.
    ///
    /// The best epoch is then remembered as the retained one for the next call.
    fn checkpointing(
        &mut self,
        epoch: usize,
        store: &EventStoreClient,
    ) -> Vec<CheckpointingAction> {
        // When the store has no answer, the current epoch is treated as the best one.
        let best = store
            .find_epoch(&self.name, self.aggregate, self.direction, self.split)
            .unwrap_or(epoch);

        // Record the new best epoch and recover the one that was retained before.
        let previous = self.current.replace(best);

        // The previously retained checkpoint is obsolete only if it is not the best anymore.
        let stale = previous
            .filter(|&kept| kept != best)
            .map(CheckpointingAction::Delete);
        // Only the current epoch can be saved; older epochs were handled in earlier calls.
        let fresh = (best == epoch).then_some(CheckpointingAction::Save);

        // Deletion (if any) always precedes saving (if any).
        stale.into_iter().chain(fresh).collect()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        logger::InMemoryMetricLogger,
        metric::{
            processor::{
                test_utils::{end_epoch, process_train},
                Metrics, MinimalEventProcessor,
            },
            store::LogEventStore,
            LossMetric,
        },
        TestBackend,
    };
    use std::sync::Arc;

    /// Runs three training epochs with mean losses 0.75, 0.4 and 2.0 and checks that the
    /// strategy only ever retains the checkpoint of the epoch with the lowest mean loss.
    #[test]
    fn always_keep_the_best_epoch() {
        // Event store with an in-memory logger registered for the training split.
        let mut event_store = LogEventStore::default();
        event_store.register_logger_train(InMemoryMetricLogger::default());

        // The loss is the only (numeric) training metric being tracked.
        let mut metrics = Metrics::<f64, f64>::default();
        metrics.register_train_metric_numeric(LossMetric::<TestBackend>::new());

        let client = Arc::new(EventStoreClient::new(event_store));
        let mut processor = MinimalEventProcessor::new(metrics, Arc::clone(&client));

        let mut strategy = MetricCheckpointingStrategy::new::<LossMetric<TestBackend>>(
            Aggregate::Mean,
            Direction::Lowest,
            Split::Train,
        );

        // First epoch: two points, mean 0.75.
        let epoch = 1;
        process_train(&mut processor, 1.0, epoch);
        process_train(&mut processor, 0.5, epoch);
        end_epoch(&mut processor, epoch);

        // Nothing is retained yet, so the current record must be saved.
        assert_eq!(
            strategy.checkpointing(epoch, &client),
            vec![CheckpointingAction::Save]
        );

        // Second epoch: two points, mean 0.4.
        let epoch = 2;
        process_train(&mut processor, 0.5, epoch);
        process_train(&mut processor, 0.3, epoch);
        end_epoch(&mut processor, epoch);

        // The new record is better: delete the previous one and save the current one.
        assert_eq!(
            strategy.checkpointing(epoch, &client),
            vec![CheckpointingAction::Delete(1), CheckpointingAction::Save]
        );

        // Last epoch: two points, mean 2.0.
        let epoch = 3;
        process_train(&mut processor, 1.0, epoch);
        process_train(&mut processor, 3.0, epoch);
        end_epoch(&mut processor, epoch);

        // The retained record is still the best one: no deletion and no new save.
        assert!(strategy.checkpointing(epoch, &client).is_empty());
    }
}