use super::TrainingHistory;
use crate::error::Result;
use scirs2_core::ndarray::ArrayD;
use std::collections::HashMap;
/// Hook interface for objects that want to observe training progress.
///
/// Implementors must be `Send + Sync`.
pub trait Callback: Send + Sync {
/// Receives an epoch index and a reference to the training history.
///
/// Presumably invoked once after every epoch by the training loop — the
/// call site is not part of this file, so confirm against the caller.
///
/// # Errors
///
/// Implementations may return an error; the implementations in this file
/// always return `Ok(())`.
fn on_epoch_end(&self, epoch: usize, history: &TrainingHistory) -> Result<()>;
}
/// Configuration and bookkeeping state for an early-stopping callback.
///
/// This block only initialises the fields; nothing here reads them back.
pub struct EarlyStopping {
    /// Name of the metric to watch, stored exactly as passed to [`EarlyStopping::new`].
    monitor: String,
    /// Threshold supplied by the caller; presumably the smallest change that
    /// counts as an improvement — confirm once the callback logic exists.
    min_delta: f64,
    /// Epoch budget supplied by the caller; presumably the number of
    /// non-improving epochs tolerated — confirm once the callback logic exists.
    patience: usize,
    /// Best metric value seen so far; starts at positive infinity.
    best: f64,
    /// Counter that starts at zero.
    wait: usize,
    /// Flag that starts out `false`.
    stopped: bool,
}

impl EarlyStopping {
    /// Creates a callback with the given settings and pristine bookkeeping:
    /// `best` is positive infinity, `wait` is `0` and `stopped` is `false`.
    pub fn new(monitor: String, min_delta: f64, patience: usize) -> Self {
        // Spell out the initial bookkeeping state separately from the
        // caller-provided configuration.
        let (best, wait, stopped) = (f64::INFINITY, 0, false);
        EarlyStopping {
            monitor,
            min_delta,
            patience,
            best,
            wait,
            stopped,
        }
    }
}
impl Callback for EarlyStopping {
/// Placeholder epoch hook: ignores both arguments and always returns `Ok(())`.
///
/// None of the struct's fields are read or updated here, so no stopping
/// decision is made yet.
fn on_epoch_end(&self, _epoch: usize, _history: &TrainingHistory) -> Result<()> {
Ok(())
}
}
/// Configuration and bookkeeping state for a model-checkpoint callback.
///
/// Built with [`ModelCheckpoint::new`] and refined through the chained
/// setters `monitor`, `save_best_only` and `mode`.
pub struct ModelCheckpoint {
    /// Destination path, stored exactly as given to `new`.
    filepath: String,
    /// Name of the metric to watch; defaults to `"val_loss"`.
    monitor: String,
    /// Defaults to `false`; presumably restricts saving to improving epochs
    /// once the callback logic exists — confirm.
    save_best_only: bool,
    /// Comparison direction as a string; defaults to `"min"`. Only `"max"`
    /// is treated specially by [`ModelCheckpoint::mode`].
    mode: String,
    /// Baseline for the best metric value: negative infinity when `mode` is
    /// `"max"`, positive infinity otherwise.
    best: f64,
}

impl ModelCheckpoint {
    /// Creates a checkpoint configuration for `filepath` with the defaults
    /// `monitor = "val_loss"`, `save_best_only = false`, `mode = "min"` and
    /// a `best` baseline of positive infinity.
    pub fn new(filepath: impl Into<String>) -> Self {
        Self {
            filepath: filepath.into(),
            monitor: "val_loss".to_string(),
            save_best_only: false,
            mode: "min".to_string(),
            best: f64::INFINITY,
        }
    }

    /// Replaces the name of the monitored metric and returns the updated value.
    pub fn monitor(mut self, monitor: impl Into<String>) -> Self {
        self.monitor = monitor.into();
        self
    }

    /// Sets the `save_best_only` flag and returns the updated value.
    pub fn save_best_only(mut self, save_best_only: bool) -> Self {
        self.save_best_only = save_best_only;
        self
    }

    /// Sets the comparison mode and resets the `best` baseline to match it.
    ///
    /// `"max"` yields a baseline of negative infinity; every other string
    /// yields positive infinity, matching the default set by `new`.
    pub fn mode(mut self, mode: impl Into<String>) -> Self {
        self.mode = mode.into();
        // Re-derive the baseline on every call. Adjusting it only for "max"
        // would leave a stale negative-infinity baseline behind when the mode
        // is later switched away from "max".
        self.best = if self.mode == "max" {
            f64::NEG_INFINITY
        } else {
            f64::INFINITY
        };
        self
    }
}
impl Callback for ModelCheckpoint {
    /// Placeholder epoch hook: looks at the most recent entry of
    /// `history.loss`, discards it, and returns `Ok(())`.
    ///
    /// No checkpoint is written and none of the struct's fields are read.
    fn on_epoch_end(&self, _epoch: usize, history: &TrainingHistory) -> Result<()> {
        // The `&current` pattern copies the latest value out of the reference
        // yielded by `last()`; this assumes `history.loss` is a slice-like
        // collection of `Copy` values (presumably `f64`) — TODO confirm
        // against the `TrainingHistory` definition.
        if let Some(&current) = history.loss.last() {
            // Deliberately unused: the comparison/saving logic is not
            // implemented in this file.
            let _ = current;
        }
        Ok(())
    }
}
/// Configuration and in-memory record store for a CSV logging callback.
pub struct CSVLogger {
    /// Output file name, stored exactly as given to `new`.
    filename: String,
    /// Field separator; defaults to `","`.
    separator: String,
    /// Defaults to `false`; presumably selects appending over truncating the
    /// file once writing is implemented — confirm.
    append: bool,
    /// Records exposed through `get_logged_data`; starts empty and is not
    /// modified by any method in this block.
    logged_data: Vec<HashMap<String, f64>>,
}

impl CSVLogger {
    /// Creates a logger for `filename` with a `","` separator, `append`
    /// turned off and no logged records.
    pub fn new(filename: impl Into<String>) -> Self {
        let filename = filename.into();
        let separator = String::from(",");
        CSVLogger {
            filename,
            separator,
            append: false,
            logged_data: Vec::new(),
        }
    }

    /// Returns the logger with its separator replaced.
    pub fn separator(self, separator: impl Into<String>) -> Self {
        Self {
            separator: separator.into(),
            ..self
        }
    }

    /// Returns the logger with its `append` flag replaced.
    pub fn append(self, append: bool) -> Self {
        Self { append, ..self }
    }

    /// Borrows the records currently held in memory.
    pub fn get_logged_data(&self) -> &[HashMap<String, f64>] {
        self.logged_data.as_slice()
    }
}
impl Callback for CSVLogger {
    /// Placeholder epoch hook: discards both arguments and returns `Ok(())`.
    ///
    /// Nothing is written to disk and `logged_data` is left untouched.
    fn on_epoch_end(&self, epoch: usize, history: &TrainingHistory) -> Result<()> {
        // Explicitly mark the inputs as intentionally unused.
        let _ = epoch;
        let _ = history;
        Ok(())
    }
}
/// Configuration and bookkeeping state for a reduce-learning-rate-on-plateau
/// callback.
pub struct ReduceLROnPlateau {
    /// Name of the metric to watch; defaults to `"val_loss"`.
    monitor: String,
    /// Defaults to `0.1`; presumably the multiplier applied to the learning
    /// rate on a plateau — confirm once the callback logic exists.
    factor: f64,
    /// Defaults to `10`; presumably the number of non-improving epochs
    /// tolerated before reducing — confirm.
    patience: usize,
    /// Defaults to `1e-7`; presumably a lower bound for the learning rate — confirm.
    min_lr: f64,
    /// Comparison direction as a string; defaults to `"min"`.
    mode: String,
    /// Best metric value seen so far; starts at positive infinity.
    best: f64,
    /// Counter that starts at zero.
    wait: usize,
    /// Learning rate reported by `get_lr`; starts at `0.001`.
    current_lr: f64,
}

impl ReduceLROnPlateau {
    /// Creates a scheduler with the default settings (see the `Default` impl).
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the scheduler with the monitored metric name replaced.
    pub fn monitor(self, monitor: impl Into<String>) -> Self {
        Self {
            monitor: monitor.into(),
            ..self
        }
    }

    /// Returns the scheduler with `factor` replaced.
    pub fn factor(self, factor: f64) -> Self {
        Self { factor, ..self }
    }

    /// Returns the scheduler with `patience` replaced.
    pub fn patience(self, patience: usize) -> Self {
        Self { patience, ..self }
    }

    /// Returns the currently stored learning rate.
    pub fn get_lr(&self) -> f64 {
        self.current_lr
    }
}

impl Default for ReduceLROnPlateau {
    /// Default settings: monitor `"val_loss"`, factor `0.1`, patience `10`,
    /// `min_lr` `1e-7`, mode `"min"`, `best` at positive infinity, `wait` at
    /// zero and a learning rate of `0.001`.
    fn default() -> Self {
        ReduceLROnPlateau {
            monitor: String::from("val_loss"),
            factor: 0.1,
            patience: 10,
            min_lr: 1e-7,
            mode: String::from("min"),
            best: f64::INFINITY,
            wait: 0,
            current_lr: 0.001,
        }
    }
}
impl Callback for ReduceLROnPlateau {
    /// Placeholder epoch hook: looks at the most recent entry of
    /// `history.loss`, discards it, and returns `Ok(())`.
    ///
    /// The stored learning rate and the other fields are neither read nor
    /// changed.
    fn on_epoch_end(&self, _epoch: usize, history: &TrainingHistory) -> Result<()> {
        // The `&current` pattern copies the latest value out of the reference
        // yielded by `last()`; this assumes `history.loss` is a slice-like
        // collection of `Copy` values (presumably `f64`) — TODO confirm
        // against the `TrainingHistory` definition.
        if let Some(&current) = history.loss.last() {
            // Deliberately unused: the plateau detection is not implemented
            // in this file.
            let _ = current;
        }
        Ok(())
    }
}
/// Weight penalty selector; each variant carries its own coefficient(s).
#[derive(Debug, Clone)]
pub enum RegularizerType {
    /// Coefficient applied to the sum of absolute weight values.
    L1(f64),
    /// Coefficient applied to the sum of squared weight values.
    L2(f64),
    /// Both penalties at once, each with its own coefficient.
    L1L2 { l1: f64, l2: f64 },
}

impl RegularizerType {
    /// Evaluates the penalty for `weights`.
    ///
    /// * `L1(c)` gives `c * Σ|w|`
    /// * `L2(c)` gives `c * Σw²`
    /// * `L1L2 { l1, l2 }` gives `l1 * Σ|w| + l2 * Σw²`
    ///
    /// Every element of the array contributes, in iteration order.
    pub fn compute(&self, weights: &ArrayD<f64>) -> f64 {
        // Lazily evaluated so each variant only pays for the sums it needs.
        let abs_total = || weights.iter().copied().map(f64::abs).sum::<f64>();
        let square_total = || weights.iter().map(|&w| w * w).sum::<f64>();

        match *self {
            Self::L1(coeff) => coeff * abs_total(),
            Self::L2(coeff) => coeff * square_total(),
            Self::L1L2 { l1, l2 } => l1 * abs_total() + l2 * square_total(),
        }
    }
}