
rnn/trainer/trainer.rs

use crate::activations::ActivationKind;
use crate::layers::{build_dense_specs_from_layers, LayerSpec};
use crate::losses::{loss_and_gradient, LossError, LossKind};

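/// Hyperparameters for [`dense_sgd_step`].
///
/// `hidden_activation` is applied to every layer except the last, which uses
/// `output_activation`; `learning_rate` must be finite and positive or the
/// step is rejected. When `gradient_clip` is `Some(limit)`, every weight and
/// bias gradient is clamped element-wise to `[-limit, limit]` before the
/// update.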
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct DenseSgdConfig {
    pub learning_rate: f32,
    pub hidden_activation: ActivationKind,
    pub output_activation: ActivationKind,
    pub loss: LossKind,
    pub gradient_clip: Option<f32>,
}

impl DenseSgdConfig {
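    /// Builds a config with gradient clipping disabled; set
    /// `gradient_clip` on the returned value to enable it.
    ///
    /// ```ignore
    /// // Sketch only: the `ActivationKind`/`LossKind` variant names below
    /// // are assumptions about the rest of the crate.
    /// let mut config = DenseSgdConfig::new(
    ///     0.01,
    ///     ActivationKind::Relu,
    ///     ActivationKind::Sigmoid,
    ///     LossKind::MeanSquaredError,
    /// );
    /// config.gradient_clip = Some(1.0);
    /// ```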
    pub const fn new(
        learning_rate: f32,
        hidden_activation: ActivationKind,
        output_activation: ActivationKind,
        loss: LossKind,
    ) -> Self {
        Self {
            learning_rate,
            hidden_activation,
            output_activation,
            loss,
            gradient_clip: None,
        }
    }
}

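/// Why a training step was rejected or aborted.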
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TrainError {
    /// `layers`, `input`, or `target` has an inconsistent shape, or a size
    /// computation overflowed.
    InvalidShape,
    /// The learning rate is non-positive or non-finite.
    InvalidConfig,
    /// The layer specs could not be built for the given parameter buffers.
    CountMismatch,
    /// A scratch buffer (or an internal fixed-size buffer) is too small.
    BufferTooSmall,
    /// The forward pass produced a non-finite activation.
    ForwardNaN,
    /// The loss function rejected its inputs.
    LossError,
}

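/// Returns how many `f32` slots `activations_scratch` and `deltas_scratch`
/// must each provide for the given layer sizes: the sum of all entries.
/// An empty slice and an overflowing sum both yield `None`.
///
/// ```ignore
/// assert_eq!(required_train_buffer_len(&[2, 3, 1]), Some(6));
/// ```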
pub fn required_train_buffer_len(layers: &[usize]) -> Option<usize> {
    if layers.is_empty() {
        return None;
    }
    let mut total = 0usize;
    for &size in layers {
        total = total.checked_add(size)?;
    }
    Some(total)
}

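/// Runs one full stochastic-gradient-descent step on a dense feed-forward
/// network: forward pass, loss evaluation, backpropagation, and an in-place
/// parameter update.
///
/// `layers` lists every layer size, input and output included, so a 2-3-1
/// network is `[2, 3, 1]`. `weights` and `biases` are the flat parameter
/// buffers that `build_dense_specs_from_layers` lays out.
/// `activations_scratch` and `deltas_scratch` must each hold at least
/// `required_train_buffer_len(layers)` elements, and `layer_specs_scratch`
/// needs one slot per weight layer (typically `layers.len() - 1`).
///
/// Two fixed internal buffers bound the supported sizes: at most 128 entries
/// in `layers`, and at most 4096 output units.
///
/// Returns the loss measured on this step's forward pass, i.e. before the
/// parameters are updated. The example below is a sketch and is not compiled;
/// names flagged in its comments are assumptions about the rest of the crate.
///
/// ```ignore
/// // Sketch only: the `ActivationKind`/`LossKind` variant names and the
/// // `placeholder_layer_spec` helper are assumptions, not taken from this file.
/// let layers = [2usize, 3, 1];
/// let mut weights = vec![0.1f32; 2 * 3 + 3 * 1];
/// let mut biases = vec![0.0f32; 3 + 1];
/// let mut specs = vec![placeholder_layer_spec(); layers.len() - 1];
/// let mut acts = vec![0.0f32; required_train_buffer_len(&layers).unwrap()];
/// let mut deltas = vec![0.0f32; acts.len()];
/// let config = DenseSgdConfig::new(
///     0.05,
///     ActivationKind::Relu,
///     ActivationKind::Sigmoid,
///     LossKind::MeanSquaredError,
/// );
/// let loss = dense_sgd_step(
///     &layers, &mut weights, &mut biases,
///     &[0.0, 1.0], &[1.0],
///     &mut specs, &mut acts, &mut deltas,
///     config,
/// ).unwrap();
/// ```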
pub fn dense_sgd_step(
    layers: &[usize],
    weights: &mut [f32],
    biases: &mut [f32],
    input: &[f32],
    target: &[f32],
    layer_specs_scratch: &mut [LayerSpec],
    activations_scratch: &mut [f32],
    deltas_scratch: &mut [f32],
    config: DenseSgdConfig,
) -> Result<f32, TrainError> {
    // Reject degenerate shapes and hyperparameters up front.
    if layers.len() < 2 {
        return Err(TrainError::InvalidShape);
    }
    if !config.learning_rate.is_finite() || config.learning_rate <= 0.0 {
        return Err(TrainError::InvalidConfig);
    }
    if input.len() != layers[0] || target.len() != layers[layers.len() - 1] {
        return Err(TrainError::InvalidShape);
    }

    let layer_count = build_dense_specs_from_layers(
        layers,
        config.hidden_activation,
        config.output_activation,
        weights.len(),
        biases.len(),
        layer_specs_scratch,
    )
    .map_err(map_layer_error)?;

    let layer_specs = &layer_specs_scratch[..layer_count];

    // Prefix sums of the layer sizes give each layer's offset into the
    // flat activation and delta buffers.
    let mut layer_offsets = [0usize; 128];
    if layers.len() > layer_offsets.len() {
        return Err(TrainError::BufferTooSmall);
    }
    let mut running = 0usize;
    for (i, &size) in layers.iter().enumerate() {
        layer_offsets[i] = running;
        running = running.checked_add(size).ok_or(TrainError::BufferTooSmall)?;
    }

    let required = running;
    if activations_scratch.len() < required || deltas_scratch.len() < required {
        return Err(TrainError::BufferTooSmall);
    }

    // The input occupies the first segment of the activation buffer.
    activations_scratch[..layers[0]].copy_from_slice(input);

    // Forward pass: each layer reads the previous segment and writes its own.
    for (layer_idx, spec) in layer_specs.iter().enumerate() {
        let dense = match spec {
            LayerSpec::Dense(d) => d,
        };

        let prev_off = layer_offsets[layer_idx];
        let curr_off = layer_offsets[layer_idx + 1];
        // split to obtain non-overlapping mutable slices, then borrow as needed
        let (left, right) = activations_scratch.split_at_mut(curr_off);
        let prev = &left[prev_off..prev_off + dense.input_size];
        let curr = &mut right[..dense.output_size];

        let w_len = dense
            .input_size
            .checked_mul(dense.output_size)
            .ok_or(TrainError::InvalidShape)?;
        let w = &weights[dense.weight_offset..dense.weight_offset + w_len];
        let b = &biases[dense.bias_offset..dense.bias_offset + dense.output_size];

        forward_dense_one(prev, curr, w, b, dense.input_size, dense.output_size, dense.activation)?;
    }

    let out_idx = layers.len() - 1;
    let out_off = layer_offsets[out_idx];
    let out_size = layers[out_idx];

    let out_activations = &activations_scratch[out_off..out_off + out_size];
    let out_deltas = &mut deltas_scratch[out_off..out_off + out_size];

    // Evaluate the loss and its gradient with respect to the outputs.
    let mut loss_grad = [0.0f32; 4096];
    if out_size > loss_grad.len() {
        return Err(TrainError::BufferTooSmall);
    }
    let loss = loss_and_gradient(config.loss, out_activations, target, &mut loss_grad[..out_size])
        .map_err(map_loss_error)?;

    let output_activation = match layer_specs[layer_count - 1] {
        LayerSpec::Dense(d) => d.activation,
    };

    // Output-layer delta: dL/dy scaled by the activation derivative.
    for i in 0..out_size {
        let deriv = output_activation.derivative_from_output(out_activations[i]);
        out_deltas[i] = loss_grad[i] * deriv;
    }

    // Backward pass: propagate deltas from the output layer toward the
    // input, one hidden layer at a time.
    for rev in 1..layer_count {
        let curr_idx = layer_count - 1 - rev;
        let curr_spec = match layer_specs[curr_idx] {
            LayerSpec::Dense(d) => d,
        };
        let next_spec = match layer_specs[curr_idx + 1] {
            LayerSpec::Dense(d) => d,
        };

        let curr_off = layer_offsets[curr_idx + 1];
        let next_off = layer_offsets[curr_idx + 2];

        let curr_out_size = curr_spec.output_size;
        let next_out_size = next_spec.output_size;

        // split deltas to get non-overlapping mutable slices
        let (left_d, right_d) = deltas_scratch.split_at_mut(next_off);
        let curr_acts = &activations_scratch[curr_off..curr_off + curr_out_size];
        let next_deltas = &right_d[..next_out_size];
        let curr_deltas = &mut left_d[curr_off..curr_off + curr_out_size];

        let next_weights_len = next_spec
            .input_size
            .checked_mul(next_spec.output_size)
            .ok_or(TrainError::InvalidShape)?;
        let next_weights = &weights[next_spec.weight_offset..next_spec.weight_offset + next_weights_len];

        // delta_curr[i] = act'(y_i) * Σ_o w_next[o][i] * delta_next[o]
        for i in 0..curr_out_size {
            let mut sum = 0.0f32;
            for o in 0..next_out_size {
                let w = next_weights[o * curr_out_size + i];
                sum += w * next_deltas[o];
            }
            let deriv = curr_spec.activation.derivative_from_output(curr_acts[i]);
            curr_deltas[i] = sum * deriv;
        }
    }

    // Update pass: every layer's parameters move against their gradient,
    // using the activations and deltas cached above.
    for (layer_idx, spec) in layer_specs.iter().enumerate() {
        let dense = match spec {
            LayerSpec::Dense(d) => d,
        };

        let prev_off = layer_offsets[layer_idx];
        let curr_off = layer_offsets[layer_idx + 1];
        let prev = &activations_scratch[prev_off..prev_off + dense.input_size];
        let curr_delta = &deltas_scratch[curr_off..curr_off + dense.output_size];

        let w_len = dense
            .input_size
            .checked_mul(dense.output_size)
            .ok_or(TrainError::InvalidShape)?;
        let w = &mut weights[dense.weight_offset..dense.weight_offset + w_len];
        let b = &mut biases[dense.bias_offset..dense.bias_offset + dense.output_size];

        apply_sgd_update(
            w,
            b,
            prev,
            curr_delta,
            dense.input_size,
            dense.output_size,
            config.learning_rate,
            config.gradient_clip,
        );
    }

    Ok(loss)
}

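/// Computes one dense layer, `output[o] = act(biases[o] + Σ_i weights[o * in_size + i] * input[i])`,
/// with the weight matrix stored row-major (one row of `in_size` weights per
/// output unit). Any non-finite activation aborts the step with `ForwardNaN`.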
fn forward_dense_one(
    input: &[f32],
    output: &mut [f32],
    weights: &[f32],
    biases: &[f32],
    in_size: usize,
    out_size: usize,
    activation: ActivationKind,
) -> Result<(), TrainError> {
    for o in 0..out_size {
        // Row-major weights: row `o` holds the fan-in weights of unit `o`.
        let row = o * in_size;
        let mut acc = biases[o];
        for i in 0..in_size {
            acc += weights[row + i] * input[i];
        }
        let y = activation.apply(acc);
        if !y.is_finite() {
            return Err(TrainError::ForwardNaN);
        }
        output[o] = y;
    }
    Ok(())
}

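/// Applies the plain SGD rule `p -= learning_rate * grad` to one layer, where
/// `grad_b[o] = delta[o]` and `grad_w[o][i] = delta[o] * prev_activation[i]`.
/// With `clip = Some(limit)`, each gradient element is clamped to
/// `[-limit, limit]` before the step.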
fn apply_sgd_update(
    weights: &mut [f32],
    biases: &mut [f32],
    prev_activation: &[f32],
    delta: &[f32],
    in_size: usize,
    out_size: usize,
    learning_rate: f32,
    clip: Option<f32>,
) {
    for o in 0..out_size {
        let mut grad_b = delta[o];
        if let Some(limit) = clip {
            grad_b = clamp(grad_b, -limit, limit);
        }
        biases[o] -= learning_rate * grad_b;

        let row = o * in_size;
        for i in 0..in_size {
            let mut grad = delta[o] * prev_activation[i];
            if let Some(limit) = clip {
                grad = clamp(grad, -limit, limit);
            }
            weights[row + i] -= learning_rate * grad;
        }
    }
}

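/// Local scalar clamp; unlike `f32::clamp` it never panics, and a NaN input
/// is passed through unchanged.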
fn clamp(v: f32, min: f32, max: f32) -> f32 {
    if v < min {
        min
    } else if v > max {
        max
    } else {
        v
    }
}

// Any spec-building failure surfaces as `CountMismatch`; the detailed layer
// error is discarded.
fn map_layer_error(_err: crate::layers::LayerError) -> TrainError {
    TrainError::CountMismatch
}

// Likewise, every loss failure collapses to `LossError`.
fn map_loss_error(_err: LossError) -> TrainError {
    TrainError::LossError
}
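
// A small smoke-test module for the two self-contained helpers above; the
// full `dense_sgd_step` path is only sketched in its `ignore`d doc example,
// since constructing `LayerSpec` values here would require assumptions about
// the rest of the crate.
#[cfg(test)]
mod tests {
    use super::*;

    // `required_train_buffer_len` sums every layer size: a 2-3-1 network
    // needs 6 slots in each scratch buffer, and degenerate inputs yield None.
    #[test]
    fn buffer_len_sums_layer_sizes() {
        assert_eq!(required_train_buffer_len(&[2, 3, 1]), Some(6));
        assert_eq!(required_train_buffer_len(&[]), None);
        assert_eq!(required_train_buffer_len(&[usize::MAX, 2]), None);
    }

    // `clamp` saturates at the limits and leaves in-range values untouched.
    #[test]
    fn clamp_saturates_at_limits() {
        assert_eq!(clamp(2.0, -1.0, 1.0), 1.0);
        assert_eq!(clamp(-2.0, -1.0, 1.0), -1.0);
        assert_eq!(clamp(0.5, -1.0, 1.0), 0.5);
    }
}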