use scirs2_core::ndarray::ScalarOperand;
use scirs2_core::numeric::{Float, NumAssign};
use std::collections::HashMap;
use std::fmt::Debug;
/// Record of the loss values and named metric series collected during training.
///
/// Nothing in this type appends to its collections; they start empty (see the
/// [`Default`] impl) and are filled in by whichever code drives training.
/// NOTE(review): presumably one entry per epoch — confirm against the training loop.
#[derive(Debug, Clone)]
pub struct History<F: Float + NumAssign> {
    /// Recorded training-loss values, in insertion order.
    pub train_loss: Vec<F>,
    /// Recorded validation-loss values, in insertion order.
    pub val_loss: Vec<F>,
    /// Additional metric series, keyed by metric name.
    pub metrics: HashMap<String, Vec<F>>,
}

/// Produces a `History` with no recorded losses and no metric series.
///
/// Implemented by hand instead of derived: `#[derive(Default)]` would add an
/// `F: Default` bound to the impl, whereas this keeps the bounds identical to
/// the struct's own.
impl<F: Float + NumAssign> Default for History<F> {
    fn default() -> Self {
        let metrics = HashMap::default();
        let (train_loss, val_loss) = (Vec::default(), Vec::default());
        History {
            train_loss,
            val_loss,
            metrics,
        }
    }
}
/// Settings that describe a training run.
///
/// This type only stores values; it performs no validation of its own. Start
/// from [`TrainingConfig::default`] and override individual fields with
/// struct-update syntax, e.g. `TrainingConfig { epochs: 50, ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct TrainingConfig {
    /// Number of epochs (default `10`).
    pub epochs: usize,
    /// Batch size (default `32`).
    pub batch_size: usize,
    /// Learning rate (default `0.01`).
    pub learning_rate: f64,
    /// Shuffle flag (default `true`); how it is applied is up to the consumer.
    pub shuffle: bool,
    /// Validation split (default `0.2`).
    /// NOTE(review): not range-checked here; presumably a fraction in `[0, 1)`.
    pub validation_split: f64,
    /// Verbosity level (default `1`); its interpretation is up to the consumer.
    pub verbose: usize,
}

impl Default for TrainingConfig {
    /// Returns the stock settings: 10 epochs, batch size 32, learning rate
    /// 0.01, shuffling enabled, validation split 0.2, verbosity 1.
    fn default() -> Self {
        let (epochs, batch_size, verbose) = (10, 32, 1);
        let learning_rate = 0.01;
        let validation_split = 0.2;
        let shuffle = true;

        TrainingConfig {
            epochs,
            batch_size,
            learning_rate,
            shuffle,
            validation_split,
            verbose,
        }
    }
}
/// Pairs a [`TrainingConfig`] with the [`History`] recorded while training.
///
/// `Debug` is derived so a trainer can be inspected or logged like the other
/// public types in this module (`History` and `TrainingConfig` both derive it).
/// The derive adds no new requirement, since `F: Debug` is already a bound.
#[derive(Debug)]
pub struct Trainer<F: Float + Debug + ScalarOperand + NumAssign> {
    /// Settings for the run.
    pub config: TrainingConfig,
    /// Losses and metrics recorded so far; empty right after [`Trainer::new`].
    pub history: History<F>,
}

impl<F: Float + Debug + ScalarOperand + NumAssign> Trainer<F> {
    /// Creates a trainer that uses `config` and starts with an empty history.
    #[must_use]
    pub fn new(config: TrainingConfig) -> Self {
        Self {
            config,
            history: History::default(),
        }
    }
}