use crate::activations::ActivationKind;
use crate::layers::{build_dense_specs_from_layers, LayerError, LayerSpec};
use crate::losses::{loss_and_gradient, LossError, LossKind};
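/// Hyperparameters for one dense-network SGD step. When `gradient_clip` is
/// `Some(limit)`, every per-element gradient is clamped to `[-limit, limit]`
/// before the update is applied.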
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct DenseSgdConfig {
pub learning_rate: f32,
pub hidden_activation: ActivationKind,
pub output_activation: ActivationKind,
pub loss: LossKind,
pub gradient_clip: Option<f32>,
}
impl DenseSgdConfig {
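    /// Builds a config with gradient clipping disabled; set `gradient_clip`
    /// on the returned value to enable it.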
pub const fn new(
learning_rate: f32,
hidden_activation: ActivationKind,
output_activation: ActivationKind,
loss: LossKind,
) -> Self {
Self {
learning_rate,
hidden_activation,
output_activation,
loss,
gradient_clip: None,
}
}
}
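/// Reasons a training step can fail. Note that `ForwardNaN` is raised for
/// any non-finite forward activation, so it covers infinities as well as NaN.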
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TrainError {
InvalidShape,
InvalidConfig,
CountMismatch,
BufferTooSmall,
ForwardNaN,
LossError,
}
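/// Total number of `f32` slots that `activations_scratch` and
/// `deltas_scratch` each need for the given layer widths: the sum of all
/// widths, or `None` if `layers` is empty or the sum overflows `usize`.
/// For example, `layers = [4, 8, 2]` requires `4 + 8 + 2 = 14` slots apiece.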
pub fn required_train_buffer_len(layers: &[usize]) -> Option<usize> {
if layers.is_empty() {
return None;
}
let mut total = 0usize;
for &size in layers {
total = total.checked_add(size)?;
}
Some(total)
}
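/// Caller-provided scratch space for `dense_sgd_step`, which allocates
/// nothing itself. The two `f32` buffers must each hold at least
/// `required_train_buffer_len(layers)` elements; `layer_specs_scratch`
/// needs room for one spec per weight layer (`layers.len() - 1`).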
pub struct DenseSgdScratch<'a> {
pub layer_specs_scratch: &'a mut [LayerSpec],
pub activations_scratch: &'a mut [f32],
pub deltas_scratch: &'a mut [f32],
}
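/// Runs one SGD training step over a fully connected network described by
/// `layers` (layer widths, input first): forward pass, loss evaluation,
/// backpropagation of deltas, and an in-place weight/bias update.
/// Returns the scalar loss for this sample.
///
/// The sketch below shows the intended call shape only. The
/// `ActivationKind`/`LossKind` variant names and the `Default` impl for
/// `LayerSpec` are assumptions about the sibling modules, not confirmed
/// here, so the block is marked `ignore`.
///
/// ```ignore
/// // Hypothetical variant names; adjust to the real enums.
/// let layers = [2usize, 3, 1];
/// let mut weights = [0.1f32; 2 * 3 + 3 * 1];
/// let mut biases = [0.0f32; 3 + 1];
/// let mut specs = [LayerSpec::default(); 2]; // overwritten by the step
/// let mut acts = [0.0f32; 6]; // required_train_buffer_len(&layers) == 6
/// let mut deltas = [0.0f32; 6];
/// let mut scratch = DenseSgdScratch {
///     layer_specs_scratch: &mut specs,
///     activations_scratch: &mut acts,
///     deltas_scratch: &mut deltas,
/// };
/// let config = DenseSgdConfig::new(
///     0.01,
///     ActivationKind::Sigmoid,
///     ActivationKind::Sigmoid,
///     LossKind::Mse,
/// );
/// let loss = dense_sgd_step(
///     &layers,
///     &mut weights,
///     &mut biases,
///     &[0.5, -0.5],
///     &[1.0],
///     &mut scratch,
///     config,
/// )?;
/// ```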
pub fn dense_sgd_step(
layers: &[usize],
weights: &mut [f32],
biases: &mut [f32],
input: &[f32],
target: &[f32],
scratch: &mut DenseSgdScratch,
config: DenseSgdConfig,
) -> Result<f32, TrainError> {
if layers.len() < 2 {
return Err(TrainError::InvalidShape);
}
if !config.learning_rate.is_finite() || config.learning_rate <= 0.0 {
return Err(TrainError::InvalidConfig);
}
if input.len() != layers[0] || target.len() != layers[layers.len() - 1] {
return Err(TrainError::InvalidShape);
}
let layer_count = build_dense_specs_from_layers(
layers,
config.hidden_activation,
config.output_activation,
weights.len(),
biases.len(),
scratch.layer_specs_scratch,
)
.map_err(map_layer_error)?;
let layer_specs = &scratch.layer_specs_scratch[..layer_count];
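    // Prefix-sum offsets locating each layer's slice inside the flat
    // activation/delta buffers; the fixed array caps the network at 128 layers.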
let mut layer_offsets = [0usize; 128];
if layers.len() > layer_offsets.len() {
return Err(TrainError::BufferTooSmall);
}
let mut running = 0usize;
for (i, &size) in layers.iter().enumerate() {
layer_offsets[i] = running;
running = running.checked_add(size).ok_or(TrainError::BufferTooSmall)?;
}
let required = running;
if scratch.activations_scratch.len() < required || scratch.deltas_scratch.len() < required {
return Err(TrainError::BufferTooSmall);
}
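    // Forward pass: seed the first activation slot with the input, then run
    // each dense layer on the previous layer's stored activations.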
scratch.activations_scratch[..layers[0]].copy_from_slice(input);
for (layer_idx, spec) in layer_specs.iter().enumerate() {
let LayerSpec::Dense(dense) = *spec;
let prev_off = layer_offsets[layer_idx];
let curr_off = layer_offsets[layer_idx + 1];
let (left, right) = scratch.activations_scratch.split_at_mut(curr_off);
let prev = &left[prev_off..prev_off + dense.input_size];
let curr = &mut right[..dense.output_size];
let w_len = dense
.input_size
.checked_mul(dense.output_size)
.ok_or(TrainError::InvalidShape)?;
let w = &weights[dense.weight_offset..dense.weight_offset + w_len];
let b = &biases[dense.bias_offset..dense.bias_offset + dense.output_size];
        forward_dense_one(
            prev,
            curr,
            w,
            b,
            dense.input_size,
            dense.output_size,
            dense.activation,
        )?;
}
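    // Output-layer deltas: elementwise loss gradient times the activation
    // derivative, both evaluated from the stored output activations.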
let out_idx = layers.len() - 1;
let out_off = layer_offsets[out_idx];
let out_size = layers[out_idx];
let out_activations = &scratch.activations_scratch[out_off..out_off + out_size];
let out_deltas = &mut scratch.deltas_scratch[out_off..out_off + out_size];
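    // Fixed-size stack buffer for the loss gradient; this caps the output
    // layer at 4096 units without requiring an extra caller-provided slice.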
let mut loss_grad = [0.0f32; 4096];
if out_size > loss_grad.len() {
return Err(TrainError::BufferTooSmall);
}
    let loss = loss_and_gradient(
        config.loss,
        out_activations,
        target,
        &mut loss_grad[..out_size],
    )
    .map_err(map_loss_error)?;
let LayerSpec::Dense(last_dense) = layer_specs[layer_count - 1];
let output_activation = last_dense.activation;
for i in 0..out_size {
let deriv = output_activation.derivative_from_output(out_activations[i]);
out_deltas[i] = loss_grad[i] * deriv;
}
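    // Backward pass: walk the hidden layers from the output toward the input,
    // accumulating each delta as the weight-weighted sum of next-layer deltas.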
for rev in 1..layer_count {
let curr_idx = layer_count - 1 - rev;
let LayerSpec::Dense(curr_spec) = layer_specs[curr_idx];
let LayerSpec::Dense(next_spec) = layer_specs[curr_idx + 1];
let curr_off = layer_offsets[curr_idx + 1];
let next_off = layer_offsets[curr_idx + 2];
let curr_out_size = curr_spec.output_size;
let next_out_size = next_spec.output_size;
let (left_d, right_d) = scratch.deltas_scratch.split_at_mut(next_off);
let curr_acts = &scratch.activations_scratch[curr_off..curr_off + curr_out_size];
let next_deltas = &right_d[..next_out_size];
let curr_deltas = &mut left_d[curr_off..curr_off + curr_out_size];
let next_weights_len = next_spec
.input_size
.checked_mul(next_spec.output_size)
.ok_or(TrainError::InvalidShape)?;
        let next_weights =
            &weights[next_spec.weight_offset..next_spec.weight_offset + next_weights_len];
for i in 0..curr_out_size {
let mut sum = 0.0f32;
for o in 0..next_out_size {
let w = next_weights[o * curr_out_size + i];
sum += w * next_deltas[o];
}
let deriv = curr_spec.activation.derivative_from_output(curr_acts[i]);
curr_deltas[i] = sum * deriv;
}
}
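    // Parameter update: all deltas are final at this point, so plain SGD can
    // be applied layer by layer in forward order.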
for (layer_idx, spec) in layer_specs.iter().enumerate() {
let LayerSpec::Dense(dense) = *spec;
let prev_off = layer_offsets[layer_idx];
let curr_off = layer_offsets[layer_idx + 1];
let prev = &scratch.activations_scratch[prev_off..prev_off + dense.input_size];
let curr_delta = &scratch.deltas_scratch[curr_off..curr_off + dense.output_size];
let w_len = dense
.input_size
.checked_mul(dense.output_size)
.ok_or(TrainError::InvalidShape)?;
let w = &mut weights[dense.weight_offset..dense.weight_offset + w_len];
let b = &mut biases[dense.bias_offset..dense.bias_offset + dense.output_size];
let args = ApplySgdArgs {
weights: w,
biases: b,
prev_activation: prev,
delta: curr_delta,
in_size: dense.input_size,
out_size: dense.output_size,
learning_rate: config.learning_rate,
clip: config.gradient_clip,
};
apply_sgd_update(args);
}
Ok(loss)
}
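/// Computes one dense layer, `output[o] = activation(biases[o] + dot(row_o, input))`,
/// with weights laid out row-major (one contiguous row of `in_size` weights
/// per output unit). Returns `ForwardNaN` if any activation is non-finite.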
fn forward_dense_one(
input: &[f32],
output: &mut [f32],
weights: &[f32],
biases: &[f32],
in_size: usize,
out_size: usize,
activation: ActivationKind,
) -> Result<(), TrainError> {
for o in 0..out_size {
let row = o * in_size;
let mut acc = biases[o];
for i in 0..in_size {
acc += weights[row + i] * input[i];
}
let y = activation.apply(acc);
if !y.is_finite() {
return Err(TrainError::ForwardNaN);
}
output[o] = y;
}
Ok(())
}
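/// Borrowed views of one layer's parameters and training signals, bundled
/// so `apply_sgd_update` does not take eight positional arguments.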
pub struct ApplySgdArgs<'a> {
pub weights: &'a mut [f32],
pub biases: &'a mut [f32],
pub prev_activation: &'a [f32],
pub delta: &'a [f32],
pub in_size: usize,
pub out_size: usize,
pub learning_rate: f32,
pub clip: Option<f32>,
}
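/// Applies one in-place SGD update to a layer's weights and biases,
/// optionally clamping each per-element gradient to `[-clip, clip]` first.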
fn apply_sgd_update(args: ApplySgdArgs) {
let ApplySgdArgs {
weights,
biases,
prev_activation,
delta,
in_size,
out_size,
learning_rate,
clip,
} = args;
for o in 0..out_size {
let mut grad_b = delta[o];
if let Some(limit) = clip {
grad_b = clamp(grad_b, -limit, limit);
}
biases[o] -= learning_rate * grad_b;
let row = o * in_size;
for i in 0..in_size {
let mut grad = delta[o] * prev_activation[i];
if let Some(limit) = clip {
grad = clamp(grad, -limit, limit);
}
weights[row + i] -= learning_rate * grad;
}
}
}
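/// Local clamp helper: unlike `f32::clamp`, it never panics on NaN bounds
/// and passes a NaN `v` straight through.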
fn clamp(v: f32, min: f32, max: f32) -> f32 {
if v < min {
min
} else if v > max {
max
} else {
v
}
}
fn map_layer_error(err: LayerError) -> TrainError {
    match err {
        // Preserve the variants that have direct counterparts instead of
        // collapsing every layer error into CountMismatch.
        LayerError::InvalidShape => TrainError::InvalidShape,
        LayerError::BufferTooSmall => TrainError::BufferTooSmall,
        LayerError::EmptyPlan
        | LayerError::InvalidRange
        | LayerError::IncompatibleChain
        | LayerError::CountMismatch => TrainError::CountMismatch,
    }
}
fn map_loss_error(err: LossError) -> TrainError {
    match err {
        // Every loss failure surfaces as the same TrainError variant; the
        // exhaustive match keeps this honest if LossError grows new variants.
        LossError::Empty | LossError::ShapeMismatch | LossError::NonFinite => {
            TrainError::LossError
        }
    }
}