use crate::checkpoint::EarlyStopping;
use crate::distributed::{GradientSync, LocalGradientSync};
use crate::error::ModelResult;
use scirs2_core::ndarray::{Array1, Array2};
use serde::{Deserialize, Serialize};
/// Source of training samples for the training loop.
///
/// Implementors expose dataset dimensions plus random access to batches;
/// a deterministic shuffle is provided as a default method.
pub trait DataProvider: Send {
    /// Total number of samples available.
    fn num_samples(&self) -> usize;

    /// Dimensionality of a single feature vector.
    fn num_features(&self) -> usize;

    /// Materializes the samples at `indices` as a `(features, targets)` pair.
    fn get_batch(&self, indices: &[usize]) -> (Array2<f32>, Array1<f32>);

    /// Deterministic Fisher–Yates permutation of `0..num_samples()`, driven
    /// by a 64-bit LCG so the same seed always yields the same order.
    fn shuffle_indices(&self, rng_seed: u64) -> Vec<usize> {
        const MULT: u64 = 6_364_136_223_846_793_005;
        const INC: u64 = 1_442_695_040_888_963_407;
        let mut order: Vec<usize> = (0..self.num_samples()).collect();
        let mut state = rng_seed.wrapping_add(1);
        for i in (1..order.len()).rev() {
            state = state.wrapping_mul(MULT).wrapping_add(INC);
            // Draw from the high bits of the LCG state.
            let j = (state >> 33) as usize % (i + 1);
            order.swap(i, j);
        }
        order
    }
}
/// In-memory `DataProvider` backed by dense ndarray buffers.
pub struct ArrayDataProvider {
    /// Feature matrix, one sample per row.
    features: Array2<f32>,
    /// One target value per feature row.
    targets: Array1<f32>,
}

impl ArrayDataProvider {
    /// Builds a provider over the given arrays.
    ///
    /// # Panics
    /// Panics if `features` and `targets` disagree on the number of samples.
    pub fn new(features: Array2<f32>, targets: Array1<f32>) -> Self {
        // Enforce the invariant in release builds too (was debug_assert_eq!):
        // `get_batch` clamps out-of-range indices, so a silent length
        // mismatch here would quietly corrupt batches instead of failing
        // fast at construction time.
        assert_eq!(
            features.nrows(),
            targets.len(),
            "features and targets must have the same number of samples"
        );
        Self { features, targets }
    }
}
impl DataProvider for ArrayDataProvider {
    /// Number of rows in the feature matrix.
    fn num_samples(&self) -> usize {
        self.features.nrows()
    }

    /// Width of the feature matrix.
    fn num_features(&self) -> usize {
        self.features.ncols()
    }

    /// Gathers the requested rows into a freshly allocated batch.
    /// Out-of-range indices are clamped to the last sample rather than
    /// panicking.
    fn get_batch(&self, indices: &[usize]) -> (Array2<f32>, Array1<f32>) {
        let cols = self.num_features();
        let rows = indices.len();
        // Highest valid row; loop-invariant, so computed once.
        let last = self.features.nrows().saturating_sub(1);
        let mut batch_features = Array2::<f32>::zeros((rows, cols));
        let mut batch_targets = Array1::<f32>::zeros(rows);
        for (dst, &requested) in indices.iter().enumerate() {
            let src = requested.min(last);
            batch_features.row_mut(dst).assign(&self.features.row(src));
            batch_targets[dst] = self.targets[src];
        }
        (batch_features, batch_targets)
    }
}
/// First-order optimizer for a linear model's parameters: a weight vector
/// plus a scalar bias.
pub trait Optimizer: Send {
    /// Applies one update step to `weights` and `bias` using the supplied
    /// gradients.
    fn step(
        &mut self,
        weights: &mut Array1<f32>,
        bias: &mut f32,
        weight_grad: &Array1<f32>,
        bias_grad: f32,
    );
    /// Learning rate currently in effect.
    fn learning_rate(&self) -> f32;
    /// Overrides the learning rate (called by the training loop after each
    /// scheduler step).
    fn set_learning_rate(&mut self, lr: f32);
}
/// Plain stochastic gradient descent with an externally schedulable
/// learning rate.
pub struct SgdOptimizer {
    /// Current learning rate.
    lr: f32,
}

impl SgdOptimizer {
    /// Creates an SGD optimizer with learning rate `lr`.
    pub fn new(lr: f32) -> Self {
        Self { lr }
    }
}

impl Optimizer for SgdOptimizer {
    /// `w ← w − lr·∇w`, `b ← b − lr·∇b`.
    fn step(
        &mut self,
        weights: &mut Array1<f32>,
        bias: &mut f32,
        weight_grad: &Array1<f32>,
        bias_grad: f32,
    ) {
        // In-place update: avoids allocating/cloning the whole weight vector
        // every step as the previous `weights.clone() - lr * grad` did.
        // `w + (-lr)*g` is bit-identical to `w - lr*g` in IEEE arithmetic.
        weights.scaled_add(-self.lr, weight_grad);
        *bias -= self.lr * bias_grad;
    }

    fn learning_rate(&self) -> f32 {
        self.lr
    }

    fn set_learning_rate(&mut self, lr: f32) {
        self.lr = lr;
    }
}
/// Adam optimizer state (Kingma & Ba), with lazily allocated moment
/// estimates for the weight vector and scalar moments for the bias.
pub struct AdamOptimizer {
    // Base learning rate.
    lr: f32,
    // Exponential decay rate for the first-moment estimate.
    beta1: f32,
    // Exponential decay rate for the second-moment estimate.
    beta2: f32,
    // Small constant added to the denominator for numerical stability.
    epsilon: f32,
    // First moment of the weight gradient; allocated on the first `step`.
    m_w: Option<Array1<f32>>,
    // Second (uncentered) moment of the weight gradient; allocated lazily.
    v_w: Option<Array1<f32>>,
    // First moment of the bias gradient.
    m_b: f32,
    // Second moment of the bias gradient.
    v_b: f32,
    // Step counter used for bias correction.
    t: u64,
}
impl AdamOptimizer {
    /// Creates an Adam optimizer with the standard defaults
    /// (beta1 = 0.9, beta2 = 0.999, epsilon = 1e-8).
    pub fn new(lr: f32) -> Self {
        Self {
            lr,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
            m_w: None,
            v_w: None,
            m_b: 0.0,
            v_b: 0.0,
            t: 0,
        }
    }
}
impl Optimizer for AdamOptimizer {
    /// One Adam update: exponential moving averages of the gradient and its
    /// square, bias-corrected, then an element-wise parameter step.
    fn step(
        &mut self,
        weights: &mut Array1<f32>,
        bias: &mut f32,
        weight_grad: &Array1<f32>,
        bias_grad: f32,
    ) {
        self.t += 1;
        let t = self.t as f32;
        let n = weights.len();
        let (beta1, beta2) = (self.beta1, self.beta2);
        let (lr, eps) = (self.lr, self.epsilon);
        // Lazily allocate the moment buffers on the first step.
        let m_w = self.m_w.get_or_insert_with(|| Array1::<f32>::zeros(n));
        let v_w = self.v_w.get_or_insert_with(|| Array1::<f32>::zeros(n));
        // In-place EMA updates: same per-element arithmetic as before, but
        // without cloning both moment arrays every step.
        m_w.zip_mut_with(weight_grad, |m, &g| *m = beta1 * *m + (1.0 - beta1) * g);
        v_w.zip_mut_with(weight_grad, |v, &g| *v = beta2 * *v + (1.0 - beta2) * g * g);
        // Bias-correction denominators for step `t`.
        let bc1 = 1.0 - beta1.powf(t);
        let bc2 = 1.0 - beta2.powf(t);
        // Element-wise parameter update with bias-corrected moments;
        // replaces the previous clone-heavy whole-array expression.
        for ((w, &m), &v) in weights.iter_mut().zip(m_w.iter()).zip(v_w.iter()) {
            let m_hat = m / bc1;
            let v_hat = v / bc2;
            *w -= lr * m_hat / (v_hat.sqrt() + eps);
        }
        // The scalar bias follows the same scheme.
        self.m_b = beta1 * self.m_b + (1.0 - beta1) * bias_grad;
        self.v_b = beta2 * self.v_b + (1.0 - beta2) * bias_grad * bias_grad;
        let mb_hat = self.m_b / bc1;
        let vb_hat = self.v_b / bc2;
        *bias -= lr * mb_hat / (vb_hat.sqrt() + eps);
    }

    fn learning_rate(&self) -> f32 {
        self.lr
    }

    fn set_learning_rate(&mut self, lr: f32) {
        self.lr = lr;
    }
}
/// Learning-rate schedule consulted once per epoch by the training loop.
pub trait LrScheduler: Send {
    /// Advances the schedule by one epoch and returns the new learning rate.
    fn step(&mut self, epoch: usize, val_loss: Option<f32>) -> f32;
    /// Learning rate currently in effect.
    fn current_lr(&self) -> f32;
}

/// Schedule that never changes the learning rate.
pub struct ConstantScheduler {
    lr: f32,
}

impl ConstantScheduler {
    /// Wraps a fixed learning rate.
    pub fn new(lr: f32) -> Self {
        ConstantScheduler { lr }
    }
}

impl LrScheduler for ConstantScheduler {
    fn step(&mut self, _epoch: usize, _val_loss: Option<f32>) -> f32 {
        // Nothing to advance; the rate is constant.
        self.current_lr()
    }

    fn current_lr(&self) -> f32 {
        self.lr
    }
}
/// Geometric decay: each epoch multiplies the rate by `decay_rate`,
/// clamped from below at `min_lr`.
pub struct ExponentialScheduler {
    decay_rate: f32,
    min_lr: f32,
    current: f32,
}

impl ExponentialScheduler {
    /// Starts at `initial_lr` and decays toward (never below) `min_lr`.
    pub fn new(initial_lr: f32, decay_rate: f32, min_lr: f32) -> Self {
        ExponentialScheduler {
            decay_rate,
            min_lr,
            current: initial_lr,
        }
    }
}

impl LrScheduler for ExponentialScheduler {
    fn step(&mut self, _epoch: usize, _val_loss: Option<f32>) -> f32 {
        let decayed = self.current * self.decay_rate;
        self.current = decayed.max(self.min_lr);
        self.current
    }

    fn current_lr(&self) -> f32 {
        self.current
    }
}
/// Multiplies the learning rate by `gamma` once every `step_size` epochs.
pub struct StepDecayScheduler {
    step_size: usize,
    gamma: f32,
    current: f32,
}

impl StepDecayScheduler {
    /// Creates a step-decay schedule starting at `initial_lr`.
    pub fn new(initial_lr: f32, step_size: usize, gamma: f32) -> Self {
        StepDecayScheduler {
            step_size,
            gamma,
            current: initial_lr,
        }
    }
}

impl LrScheduler for StepDecayScheduler {
    fn step(&mut self, epoch: usize, _val_loss: Option<f32>) -> f32 {
        // Decay only at positive multiples of `step_size`; `is_multiple_of`
        // also stays well-defined when `step_size` is 0 (no decay).
        let at_boundary = epoch > 0 && epoch.is_multiple_of(self.step_size);
        if at_boundary {
            self.current *= self.gamma;
        }
        self.current
    }

    fn current_lr(&self) -> f32 {
        self.current
    }
}
/// Observer hooks invoked by the training loop; every method defaults to a
/// no-op so implementors override only what they need.
pub trait TrainingCallback: Send {
    /// Called before an epoch's batches are processed.
    fn on_epoch_start(&mut self, _epoch: usize) {}
    /// Called after each epoch with the mean train loss and the optional
    /// validation loss.
    fn on_epoch_end(&mut self, _epoch: usize, _train_loss: f32, _val_loss: Option<f32>) {}
    /// Called after every mini-batch with that batch's loss.
    fn on_batch_end(&mut self, _epoch: usize, _batch: usize, _loss: f32) {}
    /// Called once when training finishes (normally or via early stop).
    fn on_training_end(&mut self, _result: &TrainingResult) {}
}
/// Summary of a completed training run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingResult {
    /// Mean train loss per epoch, in epoch order.
    pub train_losses: Vec<f32>,
    /// Validation loss per epoch; `None` when no validation split exists.
    pub val_losses: Vec<Option<f32>>,
    /// Epoch with the lowest validation loss (0 if validation never ran).
    pub best_epoch: usize,
    /// Lowest validation loss observed, if any.
    pub best_val_loss: Option<f32>,
    /// Number of epochs actually executed (may be cut short by early stop).
    pub epochs_trained: usize,
    /// Train loss of the last executed epoch (NaN if no epoch ran).
    pub final_train_loss: f32,
}
/// Hyper-parameters controlling a training run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingConfig {
    /// Upper bound on epochs (early stopping may end training sooner).
    pub max_epochs: usize,
    /// Mini-batch size; values below 1 are treated as 1 by the loop.
    pub batch_size: usize,
    /// Initial learning rate handed to the optimizer.
    pub learning_rate: f32,
    /// Fraction of samples held out for validation (clamped to [0, 0.99]).
    pub val_fraction: f32,
    /// Seed for the deterministic shuffles and the train/val split.
    pub rng_seed: u64,
    /// Emit a log line every N epochs; 0 disables logging.
    pub log_every_n_epochs: usize,
}
impl Default for TrainingConfig {
    fn default() -> Self {
        Self {
            max_epochs: 100,
            batch_size: 32,
            learning_rate: 0.01,
            val_fraction: 0.1,
            rng_seed: 42,
            log_every_n_epochs: 10,
        }
    }
}
/// Forward + backward pass of a linear model under mean-squared error.
///
/// Computes predictions `X·w + b`, and returns
/// `(loss, d loss/d w, d loss/d b)` where `loss = mean((pred - y)^2)`.
/// An empty batch returns zero loss and zero gradients instead of the NaN
/// the previous division-by-zero produced.
fn mse_linear_backward(
    features: &Array2<f32>,
    targets: &Array1<f32>,
    weights: &Array1<f32>,
    bias: f32,
) -> (f32, Array1<f32>, f32) {
    let rows = features.nrows();
    let nf = features.ncols();
    // Guard: with no samples, `n` would be 0 and every division below NaN.
    if rows == 0 {
        return (0.0, Array1::<f32>::zeros(nf), 0.0);
    }
    let n = rows as f32;
    // predictions = X·w + b (matrix-vector product instead of manual loops).
    let predictions = features.dot(weights) + bias;
    let residuals = &predictions - targets;
    // MSE: (1/n) Σ r_i².
    let loss = residuals.dot(&residuals) / n;
    // ∂L/∂w = (2/n) Xᵀ r ; ∂L/∂b = (2/n) Σ r.
    let weight_grad = features.t().dot(&residuals) * (2.0 / n);
    let bias_grad = 2.0 * residuals.sum() / n;
    (loss, weight_grad, bias_grad)
}
/// Splits `0..n` into shuffled (train, validation) index sets.
///
/// `val_fraction` is clamped to [0.0, 0.99] so at least one training sample
/// remains; the permutation is deterministic for a given `rng_seed`.
fn train_val_split(n: usize, val_fraction: f32, rng_seed: u64) -> (Vec<usize>, Vec<usize>) {
    let frac = val_fraction.clamp(0.0, 0.99);
    let shuffled = lcg_shuffle((0..n).collect(), rng_seed);
    let val_count = ((n as f32 * frac).round() as usize).min(n.saturating_sub(1));
    let (train, val) = shuffled.split_at(n - val_count);
    (train.to_vec(), val.to_vec())
}

/// Fisher–Yates shuffle driven by a 64-bit LCG; fully deterministic for a
/// given `seed`.
fn lcg_shuffle(mut items: Vec<usize>, seed: u64) -> Vec<usize> {
    const MULT: u64 = 6_364_136_223_846_793_005;
    const INC: u64 = 1_442_695_040_888_963_407;
    let mut state = seed.wrapping_add(1);
    for i in (1..items.len()).rev() {
        state = state.wrapping_mul(MULT).wrapping_add(INC);
        // Draw the swap index from the high bits of the state.
        let pick = (state >> 33) as usize % (i + 1);
        items.swap(i, pick);
    }
    items
}
/// Drives mini-batch gradient-descent training over a `DataProvider`,
/// coordinating the optimizer, LR scheduler, optional early stopping,
/// registered callbacks, and a pluggable gradient-synchronization hook.
pub struct TrainingLoop {
    // Epochs, batch size, initial LR, validation split, seed, log cadence.
    config: TrainingConfig,
    // Observers fired at epoch/batch boundaries and at training end.
    callbacks: Vec<Box<dyn TrainingCallback>>,
    // Gradient synchronization strategy; `LocalGradientSync` by default.
    gradient_sync: Box<dyn GradientSync>,
}
impl TrainingLoop {
    /// Creates a loop with the given config, no callbacks, and local
    /// (single-process) gradient synchronization.
    pub fn new(config: TrainingConfig) -> Self {
        Self {
            config,
            callbacks: Vec::new(),
            gradient_sync: Box::new(LocalGradientSync::new()),
        }
    }
    /// Registers a callback; callbacks are invoked in registration order.
    pub fn add_callback(&mut self, cb: Box<dyn TrainingCallback>) {
        self.callbacks.push(cb);
    }
    /// Swaps in a different gradient-synchronization strategy (builder style).
    pub fn with_gradient_sync(mut self, sync: Box<dyn GradientSync>) -> Self {
        self.gradient_sync = sync;
        self
    }
    /// Runs training, mutating `model_weights` and `model_bias` in place.
    ///
    /// Each epoch: reshuffles the training indices (seeded by epoch), walks
    /// mini-batches computing MSE loss/gradients for the linear model,
    /// synchronizes the weight gradient, applies the optimizer step, then
    /// evaluates validation loss (when a validation split exists), advances
    /// the LR scheduler, and consults `early_stopping` (on validation loss
    /// when available, otherwise train loss).
    ///
    /// Returns the per-epoch loss history and best-epoch bookkeeping.
    /// Errors from `gradient_sync` are propagated.
    pub fn run(
        &mut self,
        data: &dyn DataProvider,
        optimizer: &mut dyn Optimizer,
        lr_scheduler: &mut dyn LrScheduler,
        mut early_stopping: Option<&mut EarlyStopping>,
        model_weights: &mut Array1<f32>,
        model_bias: &mut f32,
    ) -> ModelResult<TrainingResult> {
        let n = data.num_samples();
        // Deterministic train/validation partition from the configured seed.
        let (train_indices, val_indices) =
            train_val_split(n, self.config.val_fraction, self.config.rng_seed);
        optimizer.set_learning_rate(self.config.learning_rate);
        let mut train_losses: Vec<f32> = Vec::with_capacity(self.config.max_epochs);
        let mut val_losses: Vec<Option<f32>> = Vec::with_capacity(self.config.max_epochs);
        let mut best_val_loss: Option<f32> = None;
        // Stays 0 when no validation split exists.
        let mut best_epoch = 0_usize;
        'epoch_loop: for epoch in 0..self.config.max_epochs {
            for cb in self.callbacks.iter_mut() {
                cb.on_epoch_start(epoch);
            }
            // Fresh shuffle each epoch, seeded so runs are reproducible.
            let shuffled = lcg_shuffle(
                train_indices.clone(),
                self.config.rng_seed.wrapping_add(epoch as u64),
            );
            let batch_size = self.config.batch_size.max(1);
            let mut epoch_loss_sum = 0.0_f32;
            let mut epoch_batches = 0_usize;
            let mut batch_idx = 0_usize;
            let mut offset = 0_usize;
            while offset < shuffled.len() {
                // The final batch may be smaller than `batch_size`.
                let end = (offset + batch_size).min(shuffled.len());
                let batch_sample_ids = &shuffled[offset..end];
                let (batch_feat, batch_tgt) = data.get_batch(batch_sample_ids);
                let (loss, mut weight_grad, bias_grad) =
                    mse_linear_backward(&batch_feat, &batch_tgt, model_weights, *model_bias);
                // NOTE(review): only the weight gradient goes through the
                // sync hook; bias_grad is applied unsynchronized — confirm
                // this is intended for multi-worker setups.
                self.gradient_sync.sync_gradients(&mut weight_grad)?;
                optimizer.step(model_weights, model_bias, &weight_grad, bias_grad);
                epoch_loss_sum += loss;
                epoch_batches += 1;
                for cb in self.callbacks.iter_mut() {
                    cb.on_batch_end(epoch, batch_idx, loss);
                }
                offset += batch_size;
                batch_idx += 1;
            }
            // Mean batch loss for the epoch (unweighted by batch size).
            let epoch_train_loss = if epoch_batches > 0 {
                epoch_loss_sum / epoch_batches as f32
            } else {
                0.0
            };
            // Validation loss over the whole hold-out set in one batch;
            // the gradients from the backward pass are discarded.
            let epoch_val_loss = if !val_indices.is_empty() {
                let (val_feat, val_tgt) = data.get_batch(&val_indices);
                let (vloss, _, _) =
                    mse_linear_backward(&val_feat, &val_tgt, model_weights, *model_bias);
                Some(vloss)
            } else {
                None
            };
            train_losses.push(epoch_train_loss);
            val_losses.push(epoch_val_loss);
            if let Some(vl) = epoch_val_loss {
                if best_val_loss.is_none_or(|best| vl < best) {
                    best_val_loss = Some(vl);
                    best_epoch = epoch;
                }
            }
            // Scheduler output becomes the optimizer's LR for the next epoch.
            let new_lr = lr_scheduler.step(epoch, epoch_val_loss);
            optimizer.set_learning_rate(new_lr);
            if let Some(ref mut es) = early_stopping {
                let check_loss = epoch_val_loss.unwrap_or(epoch_train_loss);
                if es.should_stop(check_loss) {
                    // Still notify callbacks for the final (stopped) epoch
                    // before leaving the loop.
                    for cb in self.callbacks.iter_mut() {
                        cb.on_epoch_end(epoch, epoch_train_loss, epoch_val_loss);
                    }
                    break 'epoch_loop;
                }
            }
            if self.config.log_every_n_epochs > 0 && epoch % self.config.log_every_n_epochs == 0 {
                if let Some(vl) = epoch_val_loss {
                    tracing::info!(
                        "Epoch {:>4} | train_loss={:.6} | val_loss={:.6} | lr={:.6}",
                        epoch,
                        epoch_train_loss,
                        vl,
                        lr_scheduler.current_lr()
                    );
                } else {
                    tracing::info!(
                        "Epoch {:>4} | train_loss={:.6} | lr={:.6}",
                        epoch,
                        epoch_train_loss,
                        lr_scheduler.current_lr()
                    );
                }
            }
            for cb in self.callbacks.iter_mut() {
                cb.on_epoch_end(epoch, epoch_train_loss, epoch_val_loss);
            }
        }
        let epochs_trained = train_losses.len();
        let final_train_loss = train_losses.last().copied().unwrap_or(f32::NAN);
        let result = TrainingResult {
            train_losses,
            val_losses,
            best_epoch,
            best_val_loss,
            epochs_trained,
            final_train_loss,
        };
        for cb in self.callbacks.iter_mut() {
            cb.on_training_end(&result);
        }
        Ok(result)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array1, Array2};

    /// Builds a one-feature dataset following y = 2x + 1 with deterministic
    /// LCG noise scaled to roughly [-noise, +noise].
    fn make_linear_dataset(n: usize, noise: f32) -> ArrayDataProvider {
        let mut feat_data = vec![0.0_f32; n];
        let mut tgt_data = vec![0.0_f32; n];
        let mut state: u64 = 12345;
        for i in 0..n {
            state = state
                .wrapping_mul(6_364_136_223_846_793_005)
                .wrapping_add(1_442_695_040_888_963_407);
            let x = i as f32 / n as f32;
            let eps = ((state >> 33) as f32 / u32::MAX as f32 - 0.5) * 2.0 * noise;
            feat_data[i] = x;
            tgt_data[i] = 2.0 * x + 1.0 + eps;
        }
        let features = Array2::from_shape_vec((n, 1), feat_data).expect("shape ok");
        let targets = Array1::from_vec(tgt_data);
        ArrayDataProvider::new(features, targets)
    }

    // Batch extraction yields the requested shapes.
    #[test]
    fn test_array_data_provider_batch() {
        let provider = make_linear_dataset(50, 0.0);
        assert_eq!(provider.num_samples(), 50);
        assert_eq!(provider.num_features(), 1);
        let indices: Vec<usize> = (0..10).collect();
        let (feat, tgt) = provider.get_batch(&indices);
        assert_eq!(feat.shape(), &[10, 1]);
        assert_eq!(tgt.len(), 10);
    }

    // SGD on noisy linear data should drive MSE well below the noise floor.
    #[test]
    fn test_training_loop_linear_regression_convergence() {
        let data = make_linear_dataset(100, 0.05);
        let config = TrainingConfig {
            max_epochs: 200,
            batch_size: 32,
            learning_rate: 0.1,
            val_fraction: 0.2,
            rng_seed: 7,
            log_every_n_epochs: 0,
        };
        let mut optimizer = SgdOptimizer::new(config.learning_rate);
        let mut scheduler = ConstantScheduler::new(config.learning_rate);
        let mut weights = Array1::<f32>::zeros(1);
        let mut bias = 0.0_f32;
        let mut training_loop = TrainingLoop::new(config);
        let result = training_loop
            .run(
                &data,
                &mut optimizer,
                &mut scheduler,
                None,
                &mut weights,
                &mut bias,
            )
            .expect("training should succeed");
        assert!(
            result.final_train_loss < 0.1,
            "expected final loss < 0.1, got {}",
            result.final_train_loss
        );
    }

    // Noise-free data converges quickly, so early stopping should trigger
    // before the epoch budget is exhausted.
    #[test]
    fn test_training_loop_early_stopping() {
        let data = make_linear_dataset(60, 0.0);
        let config = TrainingConfig {
            max_epochs: 500,
            batch_size: 60,
            learning_rate: 0.05,
            val_fraction: 0.3,
            rng_seed: 99,
            log_every_n_epochs: 0,
        };
        let mut optimizer = SgdOptimizer::new(config.learning_rate);
        let mut scheduler = ConstantScheduler::new(config.learning_rate);
        let mut es = EarlyStopping::new(3, 0.001);
        let mut weights = Array1::<f32>::zeros(1);
        let mut bias = 0.0_f32;
        let mut training_loop = TrainingLoop::new(config.clone());
        let result = training_loop
            .run(
                &data,
                &mut optimizer,
                &mut scheduler,
                Some(&mut es),
                &mut weights,
                &mut bias,
            )
            .expect("training should succeed");
        assert!(
            result.epochs_trained < config.max_epochs,
            "expected early stop before {} epochs, trained {} epochs",
            config.max_epochs,
            result.epochs_trained
        );
    }

    // A step-decay scheduler must end up below its starting rate after
    // several decay boundaries.
    #[test]
    fn test_training_loop_lr_scheduling() {
        let data = make_linear_dataset(40, 0.0);
        let initial_lr = 0.1_f32;
        let config = TrainingConfig {
            max_epochs: 20,
            batch_size: 40,
            learning_rate: initial_lr,
            val_fraction: 0.0,
            rng_seed: 1,
            log_every_n_epochs: 0,
        };
        let mut optimizer = SgdOptimizer::new(initial_lr);
        let mut scheduler = StepDecayScheduler::new(initial_lr, 2, 0.9);
        let mut weights = Array1::<f32>::zeros(1);
        let mut bias = 0.0_f32;
        let mut training_loop = TrainingLoop::new(config.clone());
        training_loop
            .run(
                &data,
                &mut optimizer,
                &mut scheduler,
                None,
                &mut weights,
                &mut bias,
            )
            .expect("training should succeed");
        assert!(
            scheduler.current_lr() < initial_lr,
            "scheduler should have reduced LR from {initial_lr} but got {}",
            scheduler.current_lr()
        );
    }

    // Loss histories must align with the number of epochs executed.
    #[test]
    fn test_training_result_history() {
        let data = make_linear_dataset(30, 0.0);
        let config = TrainingConfig {
            max_epochs: 10,
            batch_size: 10,
            learning_rate: 0.01,
            val_fraction: 0.0,
            rng_seed: 5,
            log_every_n_epochs: 0,
        };
        let mut optimizer = SgdOptimizer::new(0.01);
        let mut scheduler = ConstantScheduler::new(0.01);
        let mut weights = Array1::<f32>::zeros(1);
        let mut bias = 0.0_f32;
        let mut training_loop = TrainingLoop::new(config.clone());
        let result = training_loop
            .run(
                &data,
                &mut optimizer,
                &mut scheduler,
                None,
                &mut weights,
                &mut bias,
            )
            .expect("training should succeed");
        assert_eq!(
            result.train_losses.len(),
            result.epochs_trained,
            "train_losses length must match epochs_trained"
        );
        assert_eq!(
            result.val_losses.len(),
            result.epochs_trained,
            "val_losses length must match epochs_trained"
        );
        assert_eq!(result.epochs_trained, config.max_epochs);
    }

    /// Minimal callback that counts epoch-end notifications.
    struct EpochCounter {
        count: usize,
    }
    impl TrainingCallback for EpochCounter {
        fn on_epoch_end(&mut self, _epoch: usize, _train_loss: f32, _val_loss: Option<f32>) {
            self.count += 1;
        }
    }

    // Smoke test: a registered callback does not break the training run.
    // NOTE(review): the counter is moved into the loop, so its final count
    // is not asserted here — only that training completes with it attached.
    #[test]
    fn test_training_callback_fired() {
        let data = make_linear_dataset(20, 0.0);
        let max_epochs = 7;
        let config = TrainingConfig {
            max_epochs,
            batch_size: 20,
            learning_rate: 0.01,
            val_fraction: 0.0,
            rng_seed: 3,
            log_every_n_epochs: 0,
        };
        let mut optimizer = SgdOptimizer::new(0.01);
        let mut scheduler = ConstantScheduler::new(0.01);
        let mut weights = Array1::<f32>::zeros(1);
        let mut bias = 0.0_f32;
        let counter = EpochCounter { count: 0 };
        let mut training_loop = TrainingLoop::new(config.clone());
        training_loop.add_callback(Box::new(counter));
        training_loop
            .run(
                &data,
                &mut optimizer,
                &mut scheduler,
                None,
                &mut weights,
                &mut bias,
            )
            .expect("training should succeed");
    }
}