use std::collections::HashSet;
use std::sync::Arc;
use super::log::install_file_logger;
use super::Learner;
use crate::checkpoint::{
AsyncCheckpointer, CheckpointingStrategy, ComposedCheckpointingStrategy, FileCheckpointer,
KeepLastNCheckpoints, MetricCheckpointingStrategy,
};
use crate::components::LearnerComponentsMarker;
use crate::learner::base::TrainingInterrupter;
use crate::learner::EarlyStoppingStrategy;
use crate::logger::{FileMetricLogger, MetricLogger};
use crate::metric::processor::{FullEventProcessor, Metrics};
use crate::metric::store::{Aggregate, Direction, EventStoreClient, LogEventStore, Split};
use crate::metric::{Adaptor, LossMetric, Metric};
use crate::renderer::{default_renderer, MetricsRenderer};
use crate::{LearnerCheckpointer, LearnerSummaryConfig};
use burn_core::lr_scheduler::LrScheduler;
use burn_core::module::AutodiffModule;
use burn_core::optim::Optimizer;
use burn_core::record::FileRecorder;
use burn_core::tensor::backend::AutodiffBackend;
/// Builder that accumulates the configuration consumed by [`LearnerBuilder::build`]
/// to assemble a [`Learner`].
///
/// Type parameters: `B` is the autodiff backend, `M` the model, `O` the optimizer and
/// `S` the learning-rate scheduler. `T` and `V` parameterise the train/valid sides of
/// the metric registry (see the `Adaptor` bounds on the `metric_*` methods).
pub struct LearnerBuilder<B, T, V, M, O, S>
where
    B: AutodiffBackend,
    T: Send + 'static,
    V: Send + 'static,
    M: AutodiffModule<B>,
    O: Optimizer<M, B>,
    S: LrScheduler<B>,
{
    /// Model, optimizer and scheduler checkpointers (in that order); `None` until a
    /// checkpointer has been configured.
    // The three-element tuple of generic checkpointers is verbose but kept inline.
    #[allow(clippy::type_complexity)]
    checkpointers: Option<(
        AsyncCheckpointer<M::Record, B>,
        AsyncCheckpointer<O::Record, B>,
        AsyncCheckpointer<S::Record, B>,
    )>,
    /// Number of epochs forwarded to the learner.
    num_epochs: usize,
    /// Optional checkpoint index forwarded to the learner and to the default renderer.
    checkpoint: Option<usize>,
    /// Base directory under which logs, metric files and checkpoints are placed.
    directory: String,
    /// Optional gradient accumulation setting forwarded to the learner.
    grad_accumulation: Option<usize>,
    /// Devices forwarded to the learner.
    devices: Vec<B::Device>,
    /// Custom renderer; when `None`, `build` falls back to `default_renderer`.
    renderer: Option<Box<dyn MetricsRenderer + 'static>>,
    /// Registry of the train and valid metrics.
    metrics: Metrics<T, V>,
    /// Event store on which the metric loggers are registered.
    event_store: LogEventStore,
    /// Interrupter handle; clones are handed out by `interrupter()`.
    interrupter: TrainingInterrupter,
    /// Whether `build` installs the experiment file logger.
    log_to_file: bool,
    /// Number of `metric_loggers` calls; zero makes `build` register file metric loggers.
    num_loggers: usize,
    /// Strategy handed to the `LearnerCheckpointer` when checkpointers are configured.
    checkpointer_strategy: Box<dyn CheckpointingStrategy>,
    /// Optional early stopping strategy forwarded to the learner.
    early_stopping: Option<Box<dyn EarlyStoppingStrategy>>,
    /// Names of the numeric metrics registered so far; used for the summary.
    summary_metrics: HashSet<String>,
    /// Whether `build` attaches a summary configuration to the learner.
    summary: bool,
}
impl<B, T, V, M, O, S> LearnerBuilder<B, T, V, M, O, S>
where
B: AutodiffBackend,
T: Send + 'static,
V: Send + 'static,
M: AutodiffModule<B> + core::fmt::Display + 'static,
O: Optimizer<M, B>,
S: LrScheduler<B>,
{
/// Creates a builder with default settings rooted at `directory`.
///
/// Defaults set here: one epoch, no checkpoint index, no checkpointers, a single
/// `B::Device::default()` device, no custom renderer, file logging enabled, no early
/// stopping and no summary. The default checkpointing strategy composes
/// `KeepLastNCheckpoints::new(2)` with a `MetricCheckpointingStrategy` over
/// `LossMetric<B>` (mean aggregate, lowest direction, valid split).
///
/// # Arguments
///
/// * `directory` - Base directory used for logs, metric files and checkpoints.
pub fn new(directory: &str) -> Self {
    // Non-trivial defaults are built up front so the struct literal stays a flat list.
    let directory = directory.to_string();
    let devices = vec![B::Device::default()];
    let metrics = Metrics::default();
    let event_store = LogEventStore::default();
    let interrupter = TrainingInterrupter::new();
    let default_strategy = ComposedCheckpointingStrategy::builder()
        .add(KeepLastNCheckpoints::new(2))
        .add(MetricCheckpointingStrategy::new::<LossMetric<B>>(
            Aggregate::Mean,
            Direction::Lowest,
            Split::Valid,
        ))
        .build();

    Self {
        num_epochs: 1,
        checkpoint: None,
        checkpointers: None,
        directory,
        grad_accumulation: None,
        devices,
        metrics,
        event_store,
        renderer: None,
        interrupter,
        log_to_file: true,
        num_loggers: 0,
        checkpointer_strategy: Box::new(default_strategy),
        early_stopping: None,
        summary_metrics: HashSet::new(),
        summary: false,
    }
}
/// Registers a pair of metric loggers: one for the train split, one for the valid split.
///
/// Each call bumps the internal logger count; once that count is non-zero, `build`
/// no longer registers its own `FileMetricLogger`s.
///
/// # Arguments
///
/// * `logger_train` - Logger registered on the train side of the event store.
/// * `logger_valid` - Logger registered on the valid side of the event store.
pub fn metric_loggers<MT: MetricLogger + 'static, MV: MetricLogger + 'static>(
    mut self,
    logger_train: MT,
    logger_valid: MV,
) -> Self {
    let store = &mut self.event_store;
    store.register_logger_train(logger_train);
    store.register_logger_valid(logger_valid);

    self.num_loggers += 1;
    self
}
/// Replaces the checkpointing strategy with `strategy`.
///
/// NOTE(review): unlike the other configuration methods on this builder, this one
/// takes `&mut self` and returns nothing, so it cannot be chained. The signature is
/// kept as-is so existing callers keep compiling.
pub fn with_checkpointing_strategy<CS: CheckpointingStrategy + 'static>(
    &mut self,
    strategy: CS,
) {
    let boxed: Box<dyn CheckpointingStrategy> = Box::new(strategy);
    self.checkpointer_strategy = boxed;
}
pub fn renderer<MR>(mut self, renderer: MR) -> Self
where
MR: MetricsRenderer + 'static,
{
self.renderer = Some(Box::new(renderer));
self
}
/// Registers `metric` on the train side of the metric registry.
///
/// Requires `T` to implement `Adaptor` for the metric's input type.
pub fn metric_train<Me>(mut self, metric: Me) -> Self
where
    Me: Metric + 'static,
    T: Adaptor<Me::Input>,
{
    let registry = &mut self.metrics;
    registry.register_train_metric(metric);
    self
}
/// Registers `metric` on the valid side of the metric registry.
///
/// Requires `V` to implement `Adaptor` for the metric's input type.
pub fn metric_valid<Me>(mut self, metric: Me) -> Self
where
    Me: Metric + 'static,
    V: Adaptor<Me::Input>,
{
    let registry = &mut self.metrics;
    registry.register_valid_metric(metric);
    self
}
pub fn grads_accumulation(mut self, accumulation: usize) -> Self {
self.grad_accumulation = Some(accumulation);
self
}
/// Registers a numeric `metric` on the train side of the metric registry.
///
/// The metric's `NAME` is also recorded in the set of summary metrics.
/// Requires `T` to implement `Adaptor` for the metric's input type.
pub fn metric_train_numeric<Me: Metric + crate::metric::Numeric + 'static>(
    mut self,
    metric: Me,
) -> Self
where
    T: Adaptor<Me::Input>,
{
    let metric_name = Me::NAME.to_string();
    self.summary_metrics.insert(metric_name);
    self.metrics.register_train_metric_numeric(metric);
    self
}
/// Registers a numeric `metric` on the valid side of the metric registry.
///
/// The metric's `NAME` is also recorded in the set of summary metrics.
/// Requires `V` to implement `Adaptor` for the metric's input type.
pub fn metric_valid_numeric<Me>(mut self, metric: Me) -> Self
where
    Me: Metric + crate::metric::Numeric + 'static,
    V: Adaptor<Me::Input>,
{
    let metric_name = Me::NAME.to_string();
    self.summary_metrics.insert(metric_name);
    self.metrics.register_valid_metric_numeric(metric);
    self
}
pub fn num_epochs(mut self, num_epochs: usize) -> Self {
self.num_epochs = num_epochs;
self
}
pub fn devices(mut self, devices: Vec<B::Device>) -> Self {
self.devices = devices;
self
}
pub fn checkpoint(mut self, checkpoint: usize) -> Self {
self.checkpoint = Some(checkpoint);
self
}
/// Returns a clone of the builder's training interrupter.
///
/// NOTE(review): the clone presumably shares its interrupt state with the learner
/// produced by `build` — confirm against `TrainingInterrupter`.
pub fn interrupter(&self) -> TrainingInterrupter {
    Clone::clone(&self.interrupter)
}
pub fn early_stopping<Strategy>(mut self, strategy: Strategy) -> Self
where
Strategy: EarlyStoppingStrategy + 'static,
{
self.early_stopping = Some(Box::new(strategy));
self
}
pub fn log_to_file(mut self, enabled: bool) -> Self {
self.log_to_file = enabled;
self
}
/// Configures file-based checkpointing for the model, optimizer and scheduler.
///
/// Three `FileCheckpointer`s are created under `<directory>/checkpoint`, named
/// `"model"`, `"optim"` and `"scheduler"`, and each is wrapped in an
/// `AsyncCheckpointer`.
///
/// # Arguments
///
/// * `recorder` - File recorder used by all three checkpointers (cloned for the first
///   two, moved into the last).
pub fn with_file_checkpointer<FR>(mut self, recorder: FR) -> Self
where
    FR: FileRecorder<B> + FileRecorder<B::InnerBackend> + 'static,
    M::Record: 'static,
    O::Record: 'static,
    S::Record: 'static,
{
    // All three checkpointers share the same target directory.
    let checkpoint_dir = format!("{}/checkpoint", self.directory);

    let model = FileCheckpointer::new(recorder.clone(), checkpoint_dir.as_str(), "model");
    let optim = FileCheckpointer::new(recorder.clone(), checkpoint_dir.as_str(), "optim");
    let scheduler = FileCheckpointer::new(recorder, checkpoint_dir.as_str(), "scheduler");

    self.checkpointers = Some((
        AsyncCheckpointer::new(model),
        AsyncCheckpointer::new(optim),
        AsyncCheckpointer::new(scheduler),
    ));
    self
}
pub fn summary(mut self) -> Self {
self.summary = true;
self
}
#[allow(clippy::type_complexity)] pub fn build(
mut self,
model: M,
optim: O,
lr_scheduler: S,
) -> Learner<
LearnerComponentsMarker<
B,
S,
M,
O,
AsyncCheckpointer<M::Record, B>,
AsyncCheckpointer<O::Record, B>,
AsyncCheckpointer<S::Record, B>,
FullEventProcessor<T, V>,
Box<dyn CheckpointingStrategy>,
>,
>
where
M::Record: 'static,
O::Record: 'static,
S::Record: 'static,
{
if self.log_to_file {
self.init_logger();
}
let renderer = self.renderer.unwrap_or_else(|| {
Box::new(default_renderer(self.interrupter.clone(), self.checkpoint))
});
let directory = &self.directory;
if self.num_loggers == 0 {
self.event_store
.register_logger_train(FileMetricLogger::new(
format!("{directory}/train").as_str(),
));
self.event_store
.register_logger_valid(FileMetricLogger::new(
format!("{directory}/valid").as_str(),
));
}
let event_store = Arc::new(EventStoreClient::new(self.event_store));
let event_processor = FullEventProcessor::new(self.metrics, renderer, event_store.clone());
let checkpointer = self.checkpointers.map(|(model, optim, scheduler)| {
LearnerCheckpointer::new(model, optim, scheduler, self.checkpointer_strategy)
});
let summary = if self.summary {
Some(LearnerSummaryConfig {
directory: self.directory,
metrics: self.summary_metrics.into_iter().collect::<Vec<_>>(),
})
} else {
None
};
Learner {
model,
optim,
lr_scheduler,
checkpointer,
num_epochs: self.num_epochs,
event_processor,
event_store,
checkpoint: self.checkpoint,
grad_accumulation: self.grad_accumulation,
devices: self.devices,
interrupter: self.interrupter,
early_stopping: self.early_stopping,
summary,
}
}
/// Installs the file logger writing to `<directory>/experiment.log`.
fn init_logger(&self) {
    install_file_logger(format!("{}/experiment.log", self.directory).as_str());
}
}