use std::path::Path;
use super::TrainingArguments;
use crate::{
error::Result,
session::SessionInputs,
training::{Checkpoint, Optimizer, Trainer}
};
/// Snapshot of the training run's progress, handed to every [`TrainerCallbacks`] hook.
///
/// Created by `TrainerState::new` from the run's `TrainingArguments`. The type is
/// `#[non_exhaustive]`, so more fields may be added without a breaking change.
///
/// Derives `Debug` (in addition to `Clone`) so callbacks can log or inspect the state directly.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct TrainerState {
    /// Current epoch, if known; `None` when the state is first created.
    /// NOTE(review): a fractional value presumably means partial progress through an epoch — confirm.
    pub epoch: Option<f32>,
    /// Global step counter; starts at 0.
    /// NOTE(review): presumably counts optimizer steps — confirm against the training loop.
    pub global_step: usize,
    /// Iteration counter; starts at 0.
    /// NOTE(review): presumably counts individual batches/iterations — confirm against the training loop.
    pub iter_step: usize,
    /// Copied from `TrainingArguments::gradient_accumulation_steps`.
    pub gradient_accumulation_steps: usize,
    /// Copied from `TrainingArguments::max_steps`.
    pub max_steps: usize,
    /// Learning rate; initialized from `TrainingArguments::lr`.
    /// NOTE(review): whether this tracks later `TrainerControl::set_lr` requests is decided by the
    /// training loop, which is not shown here.
    pub current_lr: f32
}
impl TrainerState {
    /// Builds the initial state for a training run described by `args`.
    ///
    /// Both step counters start at zero and the epoch is unset (`None`). The gradient
    /// accumulation step count, `max_steps`, and the starting learning rate are copied from `args`.
    pub(crate) fn new<I, L, const NI: usize, const NL: usize>(args: &TrainingArguments<I, L, NI, NL>) -> Self
    where
        I: Into<SessionInputs<'static, 'static, NI>>,
        L: Into<SessionInputs<'static, 'static, NL>>
    {
        let current_lr = args.lr;
        let max_steps = args.max_steps;
        let gradient_accumulation_steps = args.gradient_accumulation_steps;

        Self {
            epoch: None,
            global_step: 0,
            iter_step: 0,
            gradient_accumulation_steps,
            max_steps,
            current_lr
        }
    }
}
/// Handle given to [`TrainerCallbacks`] hooks so a callback can influence the training run
/// and reach the underlying [`Trainer`].
pub struct TrainerControl<'t> {
    /// Borrowed trainer that the export/optimizer/checkpoint accessors delegate to.
    trainer: &'t Trainer,
    /// Set to `true` by `halt()`.
    /// NOTE(review): presumably read by the training loop after the callback returns — confirm.
    pub(crate) halt: bool,
    /// Learning rate requested through `set_lr()`, or `None` if no change was requested.
    pub(crate) lr: Option<f32>
}
impl<'t> TrainerControl<'t> {
    /// Creates a control handle over `trainer` with no pending halt and no learning-rate request.
    pub(crate) fn new(trainer: &'t Trainer) -> Self {
        Self { trainer, lr: None, halt: false }
    }

    /// Requests that training stop.
    ///
    /// This only records the request. NOTE(review): acting on it is up to the training loop, which
    /// is not shown here.
    pub fn halt(&mut self) {
        self.halt = true;
    }

    /// Requests that the learning rate be changed to `lr`.
    ///
    /// A later call on the same handle overwrites an earlier request.
    pub fn set_lr(&mut self, lr: f32) {
        self.lr = Some(lr);
    }

    /// Exports the model through the underlying trainer.
    ///
    /// `out_path` and `output_names` are forwarded unchanged to the trainer's `export`.
    ///
    /// # Errors
    /// Returns whatever error the trainer's `export` reports.
    pub fn export<O: AsRef<str>>(&self, out_path: impl AsRef<Path>, output_names: impl AsRef<[O]>) -> Result<()> {
        let trainer = self.trainer;
        trainer.export(out_path, output_names)
    }

    /// Returns the trainer's optimizer handle.
    pub fn optimizer(&self) -> Optimizer<'_> {
        let trainer = self.trainer;
        trainer.optimizer()
    }

    /// Returns a reference to the trainer's checkpoint.
    pub fn checkpoint(&self) -> &Checkpoint {
        let trainer = self.trainer;
        trainer.checkpoint()
    }
}
#[allow(unused_variables)]
pub trait TrainerCallbacks: Send {
fn epoch(&mut self, state: &TrainerState, control: &mut TrainerControl<'_>) -> Result<()> {
Ok(())
}
fn eval_begin(&mut self, state: &TrainerState, control: &mut TrainerControl<'_>) -> Result<()> {
Ok(())
}
fn eval_end(&mut self, eval_loss: f32, state: &TrainerState, control: &mut TrainerControl<'_>) -> Result<()> {
Ok(())
}
fn train_step(&mut self, train_loss: f32, state: &TrainerState, control: &mut TrainerControl<'_>) -> Result<()> {
Ok(())
}
fn optimizer_step(&mut self, loss: f32, state: &TrainerState, control: &mut TrainerControl<'_>) -> Result<()> {
Ok(())
}
fn end(&mut self, state: &TrainerState, control: &mut TrainerControl<'_>) -> Result<()> {
Ok(())
}
}