use crate::layers::LayerSpec;
use crate::trainer::{dense_sgd_step, DenseSgdConfig, DenseSgdScratch, TrainError};
/// Borrowed inputs, trainable parameters, and reusable scratch buffers for a
/// single [`dense_sgd_epoch`] call, bundled into one struct to keep the
/// function signature manageable.
pub struct DenseSgdEpochArgs<'a> {
// Layer widths, input layer first and output layer last; must have >= 2 entries.
pub layers: &'a [usize],
// Flattened weight matrices for all layers, updated in place by the epoch.
pub weights: &'a mut [f32],
// Flattened bias vectors for all layers, updated in place by the epoch.
pub biases: &'a mut [f32],
// Row-major batch of input samples; length must be batch_size * layers[0].
pub inputs: &'a [f32],
// Row-major batch of target vectors; length must be batch_size * layers[last].
pub targets: &'a [f32],
// Number of (input, target) sample pairs in the batch; must be non-zero.
pub batch_size: usize,
// Scratch buffers reused across every per-sample `dense_sgd_step` call so the
// epoch itself performs no allocation. Required sizes are dictated by
// `dense_sgd_step` — see its documentation/definition.
pub layer_specs_scratch: &'a mut [LayerSpec],
pub activations_scratch: &'a mut [f32],
pub deltas_scratch: &'a mut [f32],
// Step hyperparameters forwarded verbatim to each `dense_sgd_step` call.
pub config: DenseSgdConfig,
}
/// Runs one epoch of dense SGD: one [`dense_sgd_step`] per sample in the
/// batch, updating `weights` and `biases` in place.
///
/// Returns the mean per-sample loss over the batch.
///
/// # Errors
///
/// Returns [`TrainError::InvalidShape`] when `batch_size` is zero, when
/// fewer than two layers are given, or when `inputs`/`targets` lengths do
/// not match `batch_size * in_size` / `batch_size * out_size`. Any error
/// from `dense_sgd_step` is propagated.
pub fn dense_sgd_epoch(args: DenseSgdEpochArgs) -> Result<f32, TrainError> {
    let DenseSgdEpochArgs {
        layers,
        weights,
        biases,
        inputs,
        targets,
        batch_size,
        layer_specs_scratch,
        activations_scratch,
        deltas_scratch,
        config,
    } = args;
    // A network needs at least an input and an output layer, and an empty
    // batch would make the mean-loss division below divide by zero.
    if batch_size == 0 || layers.len() < 2 {
        return Err(TrainError::InvalidShape);
    }
    let in_size = layers[0];
    let out_size = layers[layers.len() - 1];
    // saturating_mul keeps the comparison well-defined if the product would
    // overflow usize; a saturated value can only accidentally match a
    // pathologically huge slice, which cannot exist in practice.
    if inputs.len() != batch_size.saturating_mul(in_size)
        || targets.len() != batch_size.saturating_mul(out_size)
    {
        return Err(TrainError::InvalidShape);
    }
    let mut loss_sum = 0.0f32;
    for b in 0..batch_size {
        let in_off = b * in_size;
        let out_off = b * out_size;
        // BUGFIX: reborrow the scratch slices with `&mut *` instead of using
        // field-init shorthand. Struct literals are not auto-reborrow sites,
        // so the shorthand MOVED the `&mut` references into the struct on the
        // first iteration, making every later iteration a "use of moved
        // value" compile error. A reborrow lends them out per iteration.
        let mut scratch = DenseSgdScratch {
            layer_specs_scratch: &mut *layer_specs_scratch,
            activations_scratch: &mut *activations_scratch,
            deltas_scratch: &mut *deltas_scratch,
        };
        // NOTE(review): `config` is passed by value each iteration, which
        // assumes `DenseSgdConfig: Copy` — confirm against its definition.
        let loss = dense_sgd_step(
            layers,
            weights,
            biases,
            &inputs[in_off..in_off + in_size],
            &targets[out_off..out_off + out_size],
            &mut scratch,
            config,
        )?;
        loss_sum += loss;
    }
    // batch_size != 0 was checked above, so this is a well-defined mean.
    Ok(loss_sum / batch_size as f32)
}