use std::marker::PhantomData;
use burn::module::AutodiffModule;
use burn::prelude::*;
use burn::tensor::backend::AutodiffBackend;
use serde::{Deserialize, Serialize};
use crate::callback::{CallbackContext, CallbackList, ProgressCallback};
use crate::error::Result;
use crate::metrics::{Accuracy, Metric};
use crate::training::TrainingOutput;
use tsai_data::TSDataLoaders;
/// Hyper-parameters a [`Learner`] uses when it builds a training run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearnerConfig {
/// Learning rate; forwarded unchanged to the trainer configuration.
pub lr: f64,
/// Weight-decay coefficient; converted to `f32` when forwarded to the trainer.
pub weight_decay: f64,
/// Gradient-clipping value; converted to `f32` when forwarded to the trainer.
/// NOTE(review): the default of `0.0` presumably means "no clipping" — confirm in the trainer.
pub grad_clip: f64,
/// Mixed-precision flag.
/// NOTE(review): not read by the `Learner` fit/predict methods in this file.
pub mixed_precision: bool,
}
impl Default for LearnerConfig {
    /// Returns the baseline hyper-parameters: learning rate `1e-3`, weight decay `0.01`,
    /// gradient clipping `0.0`, and mixed precision turned off.
    fn default() -> Self {
        let lr = 1e-3;
        let weight_decay = 0.01;
        Self {
            lr,
            weight_decay,
            grad_clip: 0.0,
            mixed_precision: false,
        }
    }
}
/// Bookkeeping for a training run, owned by a [`Learner`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingState {
/// Epoch counter; `0` by default.
/// NOTE(review): not advanced anywhere in this file.
pub epoch: usize,
/// Step counter; `0` by default.
/// NOTE(review): not advanced anywhere in this file.
pub step: usize,
/// Best validation loss recorded for the run; `f32::INFINITY` by default.
pub best_valid_loss: f32,
/// Recorded losses, metrics and learning rates.
pub history: TrainingHistory,
}
impl Default for TrainingState {
    /// A fresh state: counters at zero, no best loss yet (`f32::INFINITY`), empty history.
    fn default() -> Self {
        let history = TrainingHistory::default();
        Self {
            best_valid_loss: f32::INFINITY,
            history,
            epoch: 0,
            step: 0,
        }
    }
}
/// Histories recorded over a training run; every vector is empty by default.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TrainingHistory {
/// Training losses as reported by the trainer (presumably one entry per epoch — confirm).
pub train_losses: Vec<f32>,
/// Validation losses as reported by the trainer (presumably one entry per epoch — confirm).
pub valid_losses: Vec<f32>,
/// Named metric values.
/// NOTE(review): not populated by the fit methods in this file.
pub metrics: Vec<std::collections::HashMap<String, f32>>,
/// Learning rates.
/// NOTE(review): not populated by the fit methods in this file.
pub lrs: Vec<f64>,
}
/// High-level training wrapper bundling a model, its data loaders, hyper-parameters,
/// callbacks and metrics for an autodiff backend `B`.
pub struct Learner<B, M>
where
B: AutodiffBackend,
M: AutodiffModule<B>,
{
// Model being trained; moved into the trainer by the fit methods.
model: M,
// Training and validation data loaders.
dls: TSDataLoaders,
// Hyper-parameters used to build the trainer configuration.
config: LearnerConfig,
// Run bookkeeping (counters, best loss, history).
state: TrainingState,
// Callbacks invoked around fitting.
callbacks: CallbackList,
// Registered metrics. NOTE(review): stored but not read by the methods in this file.
metrics: Vec<Box<dyn Metric<B>>>,
// Device used for training; converted to the inner backend's device for prediction.
device: B::Device,
// Ties the struct to `B`.
_backend: PhantomData<B>,
}
impl<B, M> Learner<B, M>
where
B: AutodiffBackend,
M: AutodiffModule<B>,
{
pub fn new(model: M, dls: TSDataLoaders, config: LearnerConfig, device: &B::Device) -> Self {
let mut callbacks = CallbackList::new();
callbacks.add(ProgressCallback::new(false));
let metrics: Vec<Box<dyn Metric<B>>> = vec![Box::new(Accuracy)];
Self {
model,
dls,
config,
state: TrainingState::default(),
callbacks,
metrics,
device: device.clone(),
_backend: PhantomData,
}
}
pub fn add_callback<C: crate::callback::Callback + 'static>(mut self, callback: C) -> Self {
self.callbacks.add(callback);
self
}
pub fn add_metric<M2: Metric<B> + 'static>(mut self, metric: M2) -> Self {
self.metrics.push(Box::new(metric));
self
}
pub fn model(&self) -> &M {
&self.model
}
pub fn model_mut(&mut self) -> &mut M {
&mut self.model
}
pub fn state(&self) -> &TrainingState {
&self.state
}
pub fn history(&self) -> &TrainingHistory {
&self.state.history
}
pub fn device(&self) -> &B::Device {
&self.device
}
pub fn fit_one_cycle<F, G>(
mut self,
n_epochs: usize,
forward_fn: F,
valid_forward_fn: G,
) -> Result<TrainingOutput<M>>
where
M: Clone,
B::FloatElem: From<f32>,
F: Fn(&M, Tensor<B, 3>) -> Tensor<B, 2>,
G: Fn(&M::InnerModule, Tensor<B::InnerBackend, 3>) -> Tensor<B::InnerBackend, 2>,
{
use crate::training::{ClassificationTrainer, ClassificationTrainerConfig};
let mut ctx = CallbackContext::new(n_epochs, self.dls.train().n_batches());
let _ = self.callbacks.before_fit(&mut ctx);
let trainer_config = ClassificationTrainerConfig {
n_epochs,
lr: self.config.lr,
weight_decay: self.config.weight_decay as f32,
grad_clip: self.config.grad_clip as f32,
verbose: true,
early_stopping_patience: 0,
early_stopping_min_delta: 0.001,
};
let trainer = ClassificationTrainer::<B>::new(trainer_config, self.device.clone());
let output = trainer.fit_with_forward(
self.model,
&self.dls,
forward_fn,
valid_forward_fn,
)?;
self.state.history.train_losses = output.train_losses.clone();
self.state.history.valid_losses = output.valid_losses.clone();
self.state.best_valid_loss = output.valid_losses.last().copied().unwrap_or(f32::INFINITY);
let _ = self.callbacks.after_fit(&mut ctx);
Ok(output)
}
pub fn fit_with_early_stopping<F, G>(
self,
n_epochs: usize,
patience: usize,
forward_fn: F,
valid_forward_fn: G,
) -> Result<TrainingOutput<M>>
where
M: Clone,
B::FloatElem: From<f32>,
F: Fn(&M, Tensor<B, 3>) -> Tensor<B, 2>,
G: Fn(&M::InnerModule, Tensor<B::InnerBackend, 3>) -> Tensor<B::InnerBackend, 2>,
{
use crate::training::{ClassificationTrainer, ClassificationTrainerConfig};
let trainer_config = ClassificationTrainerConfig {
n_epochs,
lr: self.config.lr,
weight_decay: self.config.weight_decay as f32,
grad_clip: self.config.grad_clip as f32,
verbose: true,
early_stopping_patience: patience,
early_stopping_min_delta: 0.001,
};
let trainer = ClassificationTrainer::<B>::new(trainer_config, self.device.clone());
trainer.fit_with_forward(
self.model,
&self.dls,
forward_fn,
valid_forward_fn,
)
}
pub fn get_preds<G>(&self, forward_fn: G) -> Result<Predictions<B::InnerBackend>>
where
G: Fn(&M::InnerModule, Tensor<B::InnerBackend, 3>) -> Tensor<B::InnerBackend, 2>,
{
let inner_model = self.model.clone().valid();
let inner_device: <B::InnerBackend as Backend>::Device = self.device.clone().into();
let mut all_preds: Vec<Tensor<B::InnerBackend, 2>> = Vec::new();
let mut all_targets: Vec<Tensor<B::InnerBackend, 2>> = Vec::new();
for batch_result in self.dls.valid().iter::<B::InnerBackend>(&inner_device) {
let batch = batch_result?;
let x = batch.x.inner().clone();
let logits = forward_fn(&inner_model, x);
all_preds.push(logits);
if let Some(y) = batch.y {
all_targets.push(y);
}
}
let preds = Tensor::cat(all_preds, 0);
let mut predictions = Predictions::new(preds);
if !all_targets.is_empty() {
let targets = Tensor::cat(all_targets, 0);
predictions = predictions.with_targets(targets);
}
Ok(predictions)
}
}
/// Model outputs collected over a dataset, with optional inputs, targets and losses.
#[derive(Debug)]
pub struct Predictions<B: Backend> {
/// Model inputs; only set through `with_x`.
pub x: Option<Tensor<B, 3>>,
/// Raw model outputs (2-D; presumably (samples, classes) logits — confirm with caller).
pub preds: Tensor<B, 2>,
/// Targets; only set through `with_targets`.
pub targets: Option<Tensor<B, 2>>,
/// Argmax of `preds` along dimension 1, squeezed to one dimension (computed in `new`).
pub decoded: Tensor<B, 1, Int>,
/// Per-sample losses; only set through `with_losses`.
pub losses: Option<Tensor<B, 1>>,
}
impl<B: Backend> Predictions<B> {
    /// Wraps `preds` and derives `decoded` as the argmax along dimension 1.
    /// All optional fields start as `None`.
    pub fn new(preds: Tensor<B, 2>) -> Self {
        let argmax_col = preds.clone().argmax(1);
        Self {
            decoded: argmax_col.squeeze(1),
            preds,
            x: None,
            targets: None,
            losses: None,
        }
    }

    /// Returns these predictions with the inputs `x` attached.
    pub fn with_x(self, x: Tensor<B, 3>) -> Self {
        Self { x: Some(x), ..self }
    }

    /// Returns these predictions with `targets` attached.
    pub fn with_targets(self, targets: Tensor<B, 2>) -> Self {
        Self {
            targets: Some(targets),
            ..self
        }
    }

    /// Returns these predictions with per-sample `losses` attached.
    pub fn with_losses(self, losses: Tensor<B, 1>) -> Self {
        Self {
            losses: Some(losses),
            ..self
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every field of the default config matches the documented defaults.
    #[test]
    fn test_learner_config_default() {
        let config = LearnerConfig::default();
        assert_eq!(config.lr, 1e-3);
        assert_eq!(config.weight_decay, 0.01);
        assert_eq!(config.grad_clip, 0.0);
        assert!(!config.mixed_precision);
    }

    /// A default state has zeroed counters, no best loss yet, and an empty history.
    #[test]
    fn test_training_state_default() {
        let state = TrainingState::default();
        assert_eq!(state.epoch, 0);
        assert_eq!(state.step, 0);
        assert_eq!(state.best_valid_loss, f32::INFINITY);
        assert!(state.history.train_losses.is_empty());
        assert!(state.history.valid_losses.is_empty());
        assert!(state.history.metrics.is_empty());
        assert!(state.history.lrs.is_empty());
    }
}