use scivex_core::Float;
/// Decision returned by [`Callback::on_epoch_end`]: whether training should keep going.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum CallbackAction {
    /// Proceed with the next epoch.
    Continue,
    /// Request that training halt.
    Stop,
}
/// Hook invoked by a training loop around and between epochs.
///
/// Only [`Callback::on_epoch_end`] is required; the begin/end hooks default to no-ops.
pub trait Callback<T>
where
    T: Float,
{
    /// Receives the epoch index and that epoch's loss and decides whether training
    /// should continue.
    fn on_epoch_end(&mut self, epoch: usize, loss: T) -> CallbackAction;

    /// Hook intended to run before training starts; does nothing by default.
    /// (The implementations in this module use it to reset their state.)
    fn on_train_begin(&mut self) {}

    /// Hook intended to run after training finishes; does nothing by default.
    fn on_train_end(&mut self) {}
}
/// Callback that requests a stop once the loss has stopped improving.
///
/// An epoch counts as an improvement when `best - loss > min_delta`, i.e. the loss fell
/// below the best value accepted so far by strictly more than `min_delta`. After
/// `patience` consecutive epochs without an improvement, [`Callback::on_epoch_end`]
/// returns [`CallbackAction::Stop`].
#[derive(Debug, Clone)]
pub struct EarlyStopping<T: Float> {
    /// Number of consecutive non-improving epochs tolerated before stopping.
    patience: usize,
    /// Minimum decrease in loss that counts as an improvement.
    min_delta: T,
    /// Loss of the most recent improving epoch; `None` until the first epoch.
    best_loss: Option<T>,
    /// Consecutive non-improving epochs since the last improvement.
    wait: usize,
}

impl<T: Float> EarlyStopping<T> {
    /// Creates an early-stopping callback.
    ///
    /// * `patience` — consecutive epochs without improvement that trigger a stop; with
    ///   `patience == 0` the first non-improving epoch already stops.
    /// * `min_delta` — required decrease in loss for an epoch to count as an improvement.
    ///   Expected to be non-negative: a negative value lets slightly worse losses count
    ///   as improvements (and become the new best).
    pub fn new(patience: usize, min_delta: T) -> Self {
        Self {
            patience,
            min_delta,
            best_loss: None,
            wait: 0,
        }
    }

    /// Returns `true` when `x` is unordered even with itself, which for IEEE floats
    /// means it is NaN. Uses only `PartialOrd`, so no extra `Float` API is required.
    fn is_unordered(x: T) -> bool {
        x.partial_cmp(&x).is_none()
    }
}

impl<T: Float> Callback<T> for EarlyStopping<T> {
    /// Records `loss` and returns [`CallbackAction::Stop`] once patience is exhausted.
    ///
    /// Once a best loss exists, a NaN loss never counts as an improvement. If the stored
    /// best is itself NaN (e.g. the very first epoch diverged), any non-NaN loss replaces
    /// it. Every comparison with NaN is false, so without this rule a NaN best could
    /// never be beaten and the callback would stop even though training recovered.
    fn on_epoch_end(&mut self, _epoch: usize, loss: T) -> CallbackAction {
        let improved = match self.best_loss {
            None => true,
            Some(best) if Self::is_unordered(best) => !Self::is_unordered(loss),
            Some(best) => best - loss > self.min_delta,
        };

        if improved {
            self.best_loss = Some(loss);
            self.wait = 0;
            return CallbackAction::Continue;
        }

        // Saturate so that repeated calls after a Stop cannot overflow the counter.
        self.wait = self.wait.saturating_add(1);
        if self.wait >= self.patience {
            CallbackAction::Stop
        } else {
            CallbackAction::Continue
        }
    }

    /// Clears the best loss and the wait counter so the callback can be reused.
    fn on_train_begin(&mut self) {
        self.best_loss = None;
        self.wait = 0;
    }
}
/// Callback that tracks which epoch produced the lowest loss.
///
/// Despite its name, it does not save any model state itself. It only records the best
/// epoch and loss, which callers read via [`ModelCheckpoint::best_epoch`] and
/// [`ModelCheckpoint::best_loss`]. It always returns [`CallbackAction::Continue`].
#[derive(Debug, Clone)]
pub struct ModelCheckpoint<T: Float> {
    /// Lowest loss seen so far; `None` until the first epoch is recorded.
    best_loss: Option<T>,
    /// Epoch index at which `best_loss` was observed (0 before any epoch).
    best_epoch: usize,
}

impl<T: Float> ModelCheckpoint<T> {
    /// Creates a tracker with no recorded epochs.
    pub fn new() -> Self {
        Self {
            best_loss: None,
            best_epoch: 0,
        }
    }

    /// Epoch index of the best loss seen so far.
    ///
    /// Returns 0 before any epoch has been recorded; use
    /// [`best_loss`](Self::best_loss) to tell that case apart from a genuine best at
    /// epoch 0.
    pub fn best_epoch(&self) -> usize {
        self.best_epoch
    }

    /// Lowest loss seen so far, or `None` if no epoch has been recorded.
    ///
    /// This can be NaN only if every loss reported so far was NaN.
    pub fn best_loss(&self) -> Option<T> {
        self.best_loss
    }

    /// Returns `true` when `x` is unordered even with itself, which for IEEE floats
    /// means it is NaN. Uses only `PartialOrd`, so no extra `Float` API is required.
    fn is_unordered(x: T) -> bool {
        x.partial_cmp(&x).is_none()
    }
}

impl<T: Float> Default for ModelCheckpoint<T> {
    fn default() -> Self {
        Self::new()
    }
}

impl<T: Float> Callback<T> for ModelCheckpoint<T> {
    /// Records `epoch` as the best one if `loss` is strictly lower than the best so far.
    ///
    /// Ties keep the earlier epoch. A NaN best is replaced by any non-NaN loss. Every
    /// comparison with NaN is false, so otherwise a NaN first loss would pin
    /// `best_epoch` forever.
    fn on_epoch_end(&mut self, epoch: usize, loss: T) -> CallbackAction {
        let is_best = match self.best_loss {
            None => true,
            Some(prev) if Self::is_unordered(prev) => !Self::is_unordered(loss),
            Some(prev) => loss < prev,
        };
        if is_best {
            self.best_loss = Some(loss);
            self.best_epoch = epoch;
        }
        CallbackAction::Continue
    }

    /// Forgets the recorded best so the tracker can be reused for a new run.
    fn on_train_begin(&mut self) {
        self.best_loss = None;
        self.best_epoch = 0;
    }
}
/// Callback that records every reported loss in the order it was received.
///
/// It always returns [`CallbackAction::Continue`].
#[derive(Debug, Clone)]
pub struct LossLogger<T: Float> {
    /// One entry per `on_epoch_end` call since the last `on_train_begin`.
    losses: Vec<T>,
}

impl<T: Float> LossLogger<T> {
    /// Creates an empty logger.
    pub fn new() -> Self {
        Self { losses: Vec::new() }
    }

    /// Losses recorded so far, oldest first.
    pub fn losses(&self) -> &[T] {
        &self.losses
    }
}

impl<T: Float> Default for LossLogger<T> {
    fn default() -> Self {
        Self::new()
    }
}

impl<T: Float> Callback<T> for LossLogger<T> {
    /// Appends `loss` to the history; the epoch index is not stored.
    fn on_epoch_end(&mut self, _epoch: usize, loss: T) -> CallbackAction {
        self.losses.push(loss);
        CallbackAction::Continue
    }

    /// Discards the history from any previous run. `clear` keeps the existing
    /// allocation for reuse.
    fn on_train_begin(&mut self) {
        self.losses.clear();
    }
}