use crate::activations::ActivationKind;
use crate::layers::{build_dense_specs_from_layers, LayerSpec};
use crate::losses::{loss_and_gradient, LossError, LossKind};

/// Hyperparameters for a single dense-network SGD training step.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct DenseSgdConfig {
    /// Step size; must be finite and strictly positive.
    pub learning_rate: f32,
    /// Activation applied to every hidden layer.
    pub hidden_activation: ActivationKind,
    /// Activation applied to the output layer.
    pub output_activation: ActivationKind,
    /// Loss used to score the network output against the target.
    pub loss: LossKind,
    /// Optional symmetric per-element gradient clip: each gradient component
    /// is clamped to `[-limit, limit]` before the update.
    pub gradient_clip: Option<f32>,
}

impl DenseSgdConfig {
    /// Builds a config with gradient clipping disabled.
    pub const fn new(
        learning_rate: f32,
        hidden_activation: ActivationKind,
        output_activation: ActivationKind,
        loss: LossKind,
    ) -> Self {
        Self {
            learning_rate,
            hidden_activation,
            output_activation,
            loss,
            gradient_clip: None,
        }
    }
}
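// Illustrative construction (the activation and loss variant names below are
// placeholders, not guaranteed to exist in this crate):
//
//     let mut cfg = DenseSgdConfig::new(
//         0.01,
//         ActivationKind::Relu,    // hidden layers
//         ActivationKind::Sigmoid, // output layer
//         LossKind::Mse,
//     );
//     cfg.gradient_clip = Some(1.0); // clamp each gradient element to [-1, 1]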

/// Reasons a training step can fail before or during execution.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TrainError {
    /// `layers` has fewer than two entries, or a slice length does not
    /// match the declared network shape.
    InvalidShape,
    /// The learning rate is non-positive or non-finite.
    InvalidConfig,
    /// Building the layer specs failed (parameter counts do not match).
    CountMismatch,
    /// A scratch buffer is too small for the requested network.
    BufferTooSmall,
    /// A forward-pass activation came out non-finite (NaN or infinity).
    ForwardNaN,
    /// The loss function rejected its inputs.
    LossError,
}

/// Number of `f32` slots that each of `activations_scratch` and
/// `deltas_scratch` must provide for `dense_sgd_step`: the sum of all layer
/// widths. Returns `None` for an empty layer list or on overflow.
pub fn required_train_buffer_len(layers: &[usize]) -> Option<usize> {
    if layers.is_empty() {
        return None;
    }
    let mut total = 0usize;
    for &size in layers {
        total = total.checked_add(size)?;
    }
    Some(total)
}
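// Sanity check for the sizing contract: a 3-4-2 network needs 3 + 4 + 2 = 9
// slots per scratch buffer, and an empty layer list is rejected.
#[cfg(test)]
mod buffer_len_tests {
    use super::required_train_buffer_len;

    #[test]
    fn sums_layer_widths() {
        assert_eq!(required_train_buffer_len(&[3, 4, 2]), Some(9));
        assert_eq!(required_train_buffer_len(&[]), None);
    }
}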

/// Runs one SGD step (forward pass, backpropagation, in-place parameter
/// update) on the dense network described by `layers`, where `layers[0]` is
/// the input width and `layers[layers.len() - 1]` the output width.
///
/// `activations_scratch` and `deltas_scratch` must each hold at least
/// `required_train_buffer_len(layers)` elements, and `layer_specs_scratch`
/// at least `layers.len() - 1` entries. Returns the loss measured with the
/// pre-update parameters.
pub fn dense_sgd_step(
    layers: &[usize],
    weights: &mut [f32],
    biases: &mut [f32],
    input: &[f32],
    target: &[f32],
    layer_specs_scratch: &mut [LayerSpec],
    activations_scratch: &mut [f32],
    deltas_scratch: &mut [f32],
    config: DenseSgdConfig,
) -> Result<f32, TrainError> {
    // A network needs at least an input and an output layer.
    if layers.len() < 2 {
        return Err(TrainError::InvalidShape);
    }
    if !config.learning_rate.is_finite() || config.learning_rate <= 0.0 {
        return Err(TrainError::InvalidConfig);
    }
    if input.len() != layers[0] || target.len() != layers[layers.len() - 1] {
        return Err(TrainError::InvalidShape);
    }

    let layer_count = build_dense_specs_from_layers(
        layers,
        config.hidden_activation,
        config.output_activation,
        weights.len(),
        biases.len(),
        layer_specs_scratch,
    )
    .map_err(map_layer_error)?;

    let layer_specs = &layer_specs_scratch[..layer_count];

    // Prefix sums of the layer widths: layer_offsets[i] is where layer i's
    // activations (and deltas) start inside the scratch buffers. The fixed
    // capacity caps the supported depth at 128 layers.
    let mut layer_offsets = [0usize; 128];
    if layers.len() > layer_offsets.len() {
        return Err(TrainError::BufferTooSmall);
    }
    let mut running = 0usize;
    for (i, &size) in layers.iter().enumerate() {
        layer_offsets[i] = running;
        running = running.checked_add(size).ok_or(TrainError::BufferTooSmall)?;
    }

    let required = running;
    if activations_scratch.len() < required || deltas_scratch.len() < required {
        return Err(TrainError::BufferTooSmall);
    }

    // Forward pass: layer 0's activations are the input itself.
    activations_scratch[..layers[0]].copy_from_slice(input);

    for (layer_idx, spec) in layer_specs.iter().enumerate() {
        let dense = match spec {
            LayerSpec::Dense(d) => d,
        };

        let prev_off = layer_offsets[layer_idx];
        let curr_off = layer_offsets[layer_idx + 1];
        // Split at the current layer's offset so the previous layer's
        // activations can be read while the current layer's are written,
        // without aliasing one mutable slice.
        let (left, right) = activations_scratch.split_at_mut(curr_off);
        let prev = &left[prev_off..prev_off + dense.input_size];
        let curr = &mut right[..dense.output_size];

        let w_len = dense
            .input_size
            .checked_mul(dense.output_size)
            .ok_or(TrainError::InvalidShape)?;
        let w = &weights[dense.weight_offset..dense.weight_offset + w_len];
        let b = &biases[dense.bias_offset..dense.bias_offset + dense.output_size];

        forward_dense_one(prev, curr, w, b, dense.input_size, dense.output_size, dense.activation)?;
    }

    let out_idx = layers.len() - 1;
    let out_off = layer_offsets[out_idx];
    let out_size = layers[out_idx];

    let out_activations = &activations_scratch[out_off..out_off + out_size];
    let out_deltas = &mut deltas_scratch[out_off..out_off + out_size];

    // Fixed stack buffer for dL/dy; caps the output width at 4096 (16 KiB).
    let mut loss_grad = [0.0f32; 4096];
    if out_size > loss_grad.len() {
        return Err(TrainError::BufferTooSmall);
    }
    let loss = loss_and_gradient(config.loss, out_activations, target, &mut loss_grad[..out_size])
        .map_err(map_loss_error)?;

    let output_activation = match layer_specs[layer_count - 1] {
        LayerSpec::Dense(d) => d.activation,
    };

    // Chain rule at the output: delta = dL/dy * f'(z), where f' is recovered
    // from the post-activation value (assumes an element-wise activation).
    for i in 0..out_size {
        let deriv = output_activation.derivative_from_output(out_activations[i]);
        out_deltas[i] = loss_grad[i] * deriv;
    }

    // Backward pass: propagate deltas from the output toward the input,
    // using delta_l[i] = f'(a_l[i]) * sum_o W_{l+1}[o][i] * delta_{l+1}[o].
    for rev in 1..layer_count {
        let curr_idx = layer_count - 1 - rev;
        let curr_spec = match layer_specs[curr_idx] {
            LayerSpec::Dense(d) => d,
        };
        let next_spec = match layer_specs[curr_idx + 1] {
            LayerSpec::Dense(d) => d,
        };

        let curr_off = layer_offsets[curr_idx + 1];
        let next_off = layer_offsets[curr_idx + 2];

        let curr_out_size = curr_spec.output_size;
        let next_out_size = next_spec.output_size;

        // Split so the next layer's (already computed) deltas can be read
        // while the current layer's deltas are written.
        let (left_d, right_d) = deltas_scratch.split_at_mut(next_off);
        let curr_acts = &activations_scratch[curr_off..curr_off + curr_out_size];
        let next_deltas = &right_d[..next_out_size];
        let curr_deltas = &mut left_d[curr_off..curr_off + curr_out_size];

        let next_weights_len = next_spec
            .input_size
            .checked_mul(next_spec.output_size)
            .ok_or(TrainError::InvalidShape)?;
        let next_weights = &weights[next_spec.weight_offset..next_spec.weight_offset + next_weights_len];

        for i in 0..curr_out_size {
            let mut sum = 0.0f32;
            for o in 0..next_out_size {
                // Row stride is the next layer's input width, which equals
                // this layer's output width in a chain built from `layers`.
                let w = next_weights[o * next_spec.input_size + i];
                sum += w * next_deltas[o];
            }
            let deriv = curr_spec.activation.derivative_from_output(curr_acts[i]);
            curr_deltas[i] = sum * deriv;
        }
    }

    // Parameter update: with the deltas in place, each layer's gradients are
    // grad_w[o][i] = delta[o] * prev_activation[i] and grad_b[o] = delta[o].
    for (layer_idx, spec) in layer_specs.iter().enumerate() {
        let dense = match spec {
            LayerSpec::Dense(d) => d,
        };

        let prev_off = layer_offsets[layer_idx];
        let curr_off = layer_offsets[layer_idx + 1];
        let prev = &activations_scratch[prev_off..prev_off + dense.input_size];
        let curr_delta = &deltas_scratch[curr_off..curr_off + dense.output_size];

        let w_len = dense
            .input_size
            .checked_mul(dense.output_size)
            .ok_or(TrainError::InvalidShape)?;
        let w = &mut weights[dense.weight_offset..dense.weight_offset + w_len];
        let b = &mut biases[dense.bias_offset..dense.bias_offset + dense.output_size];

        apply_sgd_update(
            w,
            b,
            prev,
            curr_delta,
            dense.input_size,
            dense.output_size,
            config.learning_rate,
            config.gradient_clip,
        );
    }

    Ok(loss)
}
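// Minimal calling sketch, comments only: `LayerSpec` construction lives in
// the `layers` module rather than this file, so the scratch initialization
// below is an assumption, not verified API.
//
//     let layers = [2, 3, 1];
//     // weights: 2*3 + 3*1 = 9, biases: 3 + 1 = 4, scratch: 2 + 3 + 1 = 6
//     let mut weights = [0.1f32; 9];
//     let mut biases = [0.0f32; 4];
//     let mut specs = /* [LayerSpec; 2], overwritten by
//                        build_dense_specs_from_layers */;
//     let (mut acts, mut deltas) = ([0.0f32; 6], [0.0f32; 6]);
//     let loss = dense_sgd_step(&layers, &mut weights, &mut biases,
//         &[0.5, -0.5], &[1.0], &mut specs, &mut acts, &mut deltas, cfg)?;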

/// Computes one dense layer: `output[o] = activation(biases[o] + dot(row_o, input))`,
/// where row `o` of the row-major weight matrix occupies
/// `weights[o * in_size .. (o + 1) * in_size]`. Fails with `ForwardNaN` if
/// any post-activation value is non-finite (NaN or infinity).
fn forward_dense_one(
    input: &[f32],
    output: &mut [f32],
    weights: &[f32],
    biases: &[f32],
    in_size: usize,
    out_size: usize,
    activation: ActivationKind,
) -> Result<(), TrainError> {
    for o in 0..out_size {
        let row = o * in_size;
        let mut acc = biases[o];
        for i in 0..in_size {
            acc += weights[row + i] * input[i];
        }
        let y = activation.apply(acc);
        if !y.is_finite() {
            return Err(TrainError::ForwardNaN);
        }
        output[o] = y;
    }
    Ok(())
}
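// Worked example of the layout above (illustrative numbers): with
// in_size = 2, weights = [1.0, 2.0], bias = 0.5, and input = [3.0, 4.0],
// the single output neuron's pre-activation is
// 0.5 + 1.0 * 3.0 + 2.0 * 4.0 = 11.5, which is then passed through
// `activation.apply`.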

/// Applies the SGD update `w[o][i] -= lr * delta[o] * prev[i]` and
/// `b[o] -= lr * delta[o]`. When `clip` is set, each gradient element is
/// clamped to `[-limit, limit]` individually before the step (element-wise
/// clipping, not norm-based).
fn apply_sgd_update(
    weights: &mut [f32],
    biases: &mut [f32],
    prev_activation: &[f32],
    delta: &[f32],
    in_size: usize,
    out_size: usize,
    learning_rate: f32,
    clip: Option<f32>,
) {
    for o in 0..out_size {
        let mut grad_b = delta[o];
        if let Some(limit) = clip {
            grad_b = clamp(grad_b, -limit, limit);
        }
        biases[o] -= learning_rate * grad_b;

        let row = o * in_size;
        for i in 0..in_size {
            let mut grad = delta[o] * prev_activation[i];
            if let Some(limit) = clip {
                grad = clamp(grad, -limit, limit);
            }
            weights[row + i] -= learning_rate * grad;
        }
    }
}
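// Numeric sanity check of the update rule on a single 1-in/1-out neuron,
// with clipping disabled.
#[cfg(test)]
mod sgd_update_tests {
    use super::apply_sgd_update;

    #[test]
    fn single_neuron_step() {
        let mut w = [0.5f32];
        let mut b = [0.1f32];
        apply_sgd_update(&mut w, &mut b, &[2.0], &[0.25], 1, 1, 0.1, None);
        // grad_w = 0.25 * 2.0 = 0.5  -> w = 0.5 - 0.1 * 0.5  = 0.45
        // grad_b = 0.25              -> b = 0.1 - 0.1 * 0.25 = 0.075
        assert!((w[0] - 0.45).abs() < 1e-6);
        assert!((b[0] - 0.075).abs() < 1e-6);
    }
}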

fn clamp(v: f32, min: f32, max: f32) -> f32 {
    if v < min {
        min
    } else if v > max {
        max
    } else {
        v
    }
}
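// Clipping behavior check: values outside [-limit, limit] saturate at the
// bound, values inside pass through unchanged.
#[cfg(test)]
mod clamp_tests {
    use super::clamp;

    #[test]
    fn saturates_at_bounds() {
        assert_eq!(clamp(3.0, -1.0, 1.0), 1.0);
        assert_eq!(clamp(-3.0, -1.0, 1.0), -1.0);
        assert_eq!(clamp(0.25, -1.0, 1.0), 0.25);
    }
}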

// Both mappings are deliberately lossy: the detailed layer/loss error is
// collapsed into a single `TrainError` variant.
fn map_layer_error(_err: crate::layers::LayerError) -> TrainError {
    TrainError::CountMismatch
}

fn map_loss_error(_err: LossError) -> TrainError {
    TrainError::LossError
}