use crate::LearningComponentsMarker;
use crate::checkpoint::{
AsyncCheckpointer, Checkpointer, CheckpointingAction, CheckpointingStrategy,
};
use crate::components::{LearningComponentsTypes, TrainingBackend};
use crate::metric::store::EventStoreClient;
use crate::{
CloneEarlyStoppingStrategy, InferenceStep, TrainOutput, TrainStep, TrainingModelInput,
TrainingModelOutput,
};
use burn_core::module::{AutodiffModule, Module};
use burn_core::tensor::Device;
use burn_core::tensor::backend::AutodiffBackend;
use burn_optim::lr_scheduler::LrScheduler;
use burn_optim::{GradientsParams, MultiGradientsParams, Optimizer};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
/// Serializable record of the training model, as used for checkpointing.
pub type LearnerModelRecord<LC> =
    <<LC as LearningComponentsTypes>::TrainingModel as Module<TrainingBackend<LC>>>::Record;
/// Serializable record of the optimizer state for the training model.
pub type LearnerOptimizerRecord<LC> = <<LC as LearningComponentsTypes>::Optimizer as Optimizer<
    <LC as LearningComponentsTypes>::TrainingModel,
    TrainingBackend<LC>,
>>::Record;
/// Serializable record of the learning rate scheduler state.
pub type LearnerSchedulerRecord<LC> =
    <<LC as LearningComponentsTypes>::LrScheduler as LrScheduler>::Record<TrainingBackend<LC>>;
/// Bundles the training model with its optimizer and learning rate
/// scheduler, along with the most recently computed learning rate.
pub struct Learner<LC: LearningComponentsTypes> {
    /// The model being trained.
    pub(crate) model: LC::TrainingModel,
    // Optimizer state, mutated by the `optimizer_step*` methods.
    optim: LC::Optimizer,
    // Scheduler queried by `lr_step` to produce the next learning rate.
    lr_scheduler: LC::LrScheduler,
    // Learning rate from the last `lr_step` call; 0.0 until the first call.
    lr: f64,
}
// Manual impl: deriving `Clone` would require `LC: Clone`, which the
// component marker type does not (and need not) satisfy.
impl<LC: LearningComponentsTypes> Clone for Learner<LC> {
    fn clone(&self) -> Self {
        let Self {
            model,
            optim,
            lr_scheduler,
            lr,
        } = self;
        Self {
            model: model.clone(),
            optim: optim.clone(),
            lr_scheduler: lr_scheduler.clone(),
            // `f64` is `Copy`; duplicate the value directly.
            lr: *lr,
        }
    }
}
impl<B, LR, M, O> Learner<LearningComponentsMarker<B, LR, M, O>>
where
    B: AutodiffBackend,
    LR: LrScheduler + 'static,
    M: TrainStep + AutodiffModule<B> + core::fmt::Display + 'static,
    M::InnerModule: InferenceStep,
    O: Optimizer<M, B> + 'static,
{
    /// Assemble a learner from a model, an optimizer and a learning rate
    /// scheduler.
    ///
    /// The cached learning rate starts at `0.0`; call [`Learner::lr_step`]
    /// to pull the first value from the scheduler before optimizing.
    pub fn new(model: M, optim: O, lr_scheduler: LR) -> Self {
        let lr = 0.0;
        Self {
            model,
            optim,
            lr_scheduler,
            lr,
        }
    }
}
impl<LC: LearningComponentsTypes> Learner<LC> {
    /// Replace the model with a copy forked to `device` (see `Module::fork`).
    pub fn fork(&mut self, device: &Device<TrainingBackend<LC>>) {
        let forked = self.model.clone().fork(device);
        self.model = forked;
    }

    /// A clone of the current training model.
    pub fn model(&self) -> LC::TrainingModel {
        self.model.clone()
    }

    /// The learning rate produced by the last [`Self::lr_step`] call
    /// (`0.0` before the first call).
    pub fn lr_current(&self) -> f64 {
        self.lr
    }

    /// Advance the scheduler and cache the learning rate it returns.
    pub fn lr_step(&mut self) {
        self.lr = self.lr_scheduler.step();
    }

    /// Run one training step of the model on `item`.
    pub fn train_step(&self, item: TrainingModelInput<LC>) -> TrainOutput<TrainingModelOutput<LC>> {
        self.model.step(item)
    }

    /// Apply `grads` to the model through the optimizer at the current
    /// learning rate.
    pub fn optimizer_step(&mut self, grads: GradientsParams) {
        let updated = self.model.clone().optimize(&mut self.optim, self.lr, grads);
        self.model = updated;
    }

    /// Same as [`Self::optimizer_step`], but for gradients split across
    /// multiple parameter groups.
    pub fn optimizer_step_multi(&mut self, grads: MultiGradientsParams) {
        let updated = self
            .model
            .clone()
            .optimize_multi(&mut self.optim, self.lr, grads);
        self.model = updated;
    }

    /// Replace the model state with the one stored in `record`.
    pub fn load_model(&mut self, record: LearnerModelRecord<LC>) {
        self.model = self.model.clone().load_record(record);
    }

    /// Replace the optimizer state with the one stored in `record`.
    pub fn load_optim(&mut self, record: LearnerOptimizerRecord<LC>) {
        self.optim = self.optim.clone().load_record(record);
    }

    /// Replace the scheduler state with the one stored in `record`.
    pub fn load_scheduler(&mut self, record: LearnerSchedulerRecord<LC>) {
        self.lr_scheduler = self.lr_scheduler.clone().load_record(record);
    }
}
/// Bundles the three asynchronous checkpointers (model, optimizer,
/// scheduler) with the strategy that decides when to save or delete.
#[derive(new)]
pub struct LearningCheckpointer<LC: LearningComponentsTypes> {
    // Persists and restores the model record.
    model: AsyncCheckpointer<LearnerModelRecord<LC>, LC::Backend>,
    // Persists and restores the optimizer record.
    optim: AsyncCheckpointer<LearnerOptimizerRecord<LC>, LC::Backend>,
    // Persists and restores the learning rate scheduler record.
    lr_scheduler: AsyncCheckpointer<LearnerSchedulerRecord<LC>, LC::Backend>,
    // Decides which checkpointing actions to take at each epoch.
    strategy: Box<dyn CheckpointingStrategy>,
}
impl<LC: LearningComponentsTypes> LearningCheckpointer<LC> {
    /// Ask the strategy what to do at `epoch` and carry out every action:
    /// `Save` persists all three records for the current epoch, while
    /// `Delete` removes the records of the epoch named by the strategy.
    pub fn checkpoint(&mut self, learner: &Learner<LC>, epoch: usize, store: &EventStoreClient) {
        for action in self.strategy.checkpointing(epoch, store) {
            match action {
                CheckpointingAction::Save => {
                    self.model
                        .save(epoch, learner.model.clone().into_record())
                        .expect("Can save model checkpoint.");
                    self.optim
                        .save(epoch, learner.optim.to_record())
                        .expect("Can save optimizer checkpoint.");
                    self.lr_scheduler
                        .save(epoch, learner.lr_scheduler.to_record())
                        .expect("Can save learning rate scheduler checkpoint.");
                }
                CheckpointingAction::Delete(target) => {
                    self.model
                        .delete(target)
                        .expect("Can delete model checkpoint.");
                    self.optim
                        .delete(target)
                        .expect("Can delete optimizer checkpoint.");
                    self.lr_scheduler
                        .delete(target)
                        .expect("Can delete learning rate scheduler checkpoint.");
                }
            }
        }
    }

    /// Restore model, optimizer and scheduler state from the checkpoint
    /// saved at `epoch`, returning the updated learner.
    pub fn load_checkpoint(
        &self,
        mut learner: Learner<LC>,
        device: &Device<LC::Backend>,
        epoch: usize,
    ) -> Learner<LC> {
        learner.load_model(
            self.model
                .restore(epoch, device)
                .expect("Can load model checkpoint."),
        );
        learner.load_optim(
            self.optim
                .restore(epoch, device)
                .expect("Can load optimizer checkpoint."),
        );
        learner.load_scheduler(
            self.lr_scheduler
                .restore(epoch, device)
                .expect("Can load learning rate scheduler checkpoint."),
        );
        learner
    }
}
/// Boxed early-stopping strategy that can be cloned along with the learner.
pub(crate) type EarlyStoppingStrategyRef = Box<dyn CloneEarlyStoppingStrategy>;
/// Signals a cooperative stop request to a running training loop.
///
/// Cloning shares the underlying flag and message, so any clone can request
/// a stop that every other holder observes.
#[derive(Clone, Default)]
pub struct Interrupter {
    // True once a stop has been requested; polled via `should_stop`.
    state: Arc<AtomicBool>,
    // Optional human-readable reason supplied with the stop request.
    message: Arc<Mutex<Option<String>>>,
}

impl Interrupter {
    /// Create a new interrupter with no stop requested.
    pub fn new() -> Self {
        Self::default()
    }

    /// Request a stop, optionally recording a reason retrievable via
    /// [`Self::get_message`]. A `None` reason leaves the stored message
    /// unchanged.
    pub fn stop(&self, reason: Option<&str>) {
        self.state.store(true, Ordering::Relaxed);
        if let Some(reason) = reason {
            *self.message.lock().unwrap() = Some(reason.to_string());
        }
    }

    /// Clear the stop request and any previously recorded reason.
    ///
    /// Fix: also clear the message so a stale reason from a prior stop is
    /// not reported after the interrupter is reset and reused.
    pub fn reset(&self) {
        self.state.store(false, Ordering::Relaxed);
        *self.message.lock().unwrap() = None;
    }

    /// Whether a stop has been requested.
    pub fn should_stop(&self) -> bool {
        self.state.load(Ordering::Relaxed)
    }

    /// The reason recorded with the most recent stop request, if any.
    pub fn get_message(&self) -> Option<String> {
        self.message.lock().unwrap().clone()
    }
}