//! tml_utils/network.rs — feed-forward neural-network building blocks:
//! initializers, optimizers, layers, and a compile-time-checked model builder.

1use crate::conv::{Conv, conv_out_dim};
2use crate::{ConvGeometryIsValid, Float, Sample};
3use rand::{Rng, SeedableRng, rngs::StdRng, seq::SliceRandom};
4use std::fmt;
5
/// Strategy for filling a freshly allocated parameter buffer with initial values.
pub trait Initializer {
    /// Fills `values` in place. `fan_in`/`fan_out` describe the layer
    /// geometry so scale-aware schemes (Xavier, Kaiming) can size their
    /// sampling range; `rng` supplies the randomness.
    fn fill<R: Rng + ?Sized>(
        &self,
        values: &mut [Float],
        fan_in: usize,
        fan_out: usize,
        rng: &mut R,
    );
}
15
/// Initializer drawing each value uniformly from the half-open range `[low, high)`.
#[derive(Debug, Clone, Copy)]
pub struct Uniform {
    pub low: Float,
    pub high: Float,
}
21
22impl Uniform {
23    pub const fn new(low: Float, high: Float) -> Self {
24        Self { low, high }
25    }
26}
27
28impl Initializer for Uniform {
29    fn fill<R: Rng + ?Sized>(
30        &self,
31        values: &mut [Float],
32        _fan_in: usize,
33        _fan_out: usize,
34        rng: &mut R,
35    ) {
36        for value in values {
37            *value = rng.random_range(self.low..self.high);
38        }
39    }
40}
41
/// Xavier/Glorot uniform initialization: bound = sqrt(6 / (fan_in + fan_out)).
#[derive(Debug, Clone, Copy, Default)]
pub struct XavierUniform;
44
45impl Initializer for XavierUniform {
46    fn fill<R: Rng + ?Sized>(
47        &self,
48        values: &mut [Float],
49        fan_in: usize,
50        fan_out: usize,
51        rng: &mut R,
52    ) {
53        let denom = (fan_in + fan_out).max(1) as Float;
54        let bound = (6.0 / denom).sqrt();
55        Uniform::new(-bound, bound).fill(values, fan_in, fan_out, rng);
56    }
57}
58
/// Kaiming/He uniform initialization: bound = sqrt(6 / fan_in), suited to
/// ReLU-style activations.
#[derive(Debug, Clone, Copy, Default)]
pub struct KaimingUniform;
61
62impl Initializer for KaimingUniform {
63    fn fill<R: Rng + ?Sized>(
64        &self,
65        values: &mut [Float],
66        fan_in: usize,
67        fan_out: usize,
68        rng: &mut R,
69    ) {
70        let denom = fan_in.max(1) as Float;
71        let bound = (6.0 / denom).sqrt();
72        Uniform::new(-bound, bound).fill(values, fan_in, fan_out, rng);
73    }
74}
75
/// Gradient-descent update rule applied per parameter buffer.
pub trait Optimizer: fmt::Debug {
    /// Called once at the start of each optimization step (e.g. to advance
    /// Adam's bias-correction counter). Defaults to a no-op.
    fn begin_step(&mut self) {}
    /// Applies `grads` (multiplied by `scale`) to `params`. `slot`
    /// identifies the parameter buffer so stateful optimizers can keep
    /// per-buffer state.
    fn update_parameter(
        &mut self,
        slot: usize,
        params: &mut [Float],
        grads: &[Float],
        scale: Float,
    );
}
86
/// Plain stochastic gradient descent with optional L2 weight decay.
#[derive(Debug, Clone, Copy)]
pub struct Sgd {
    pub lr: Float,
    pub weight_decay: Float,
}
92
93impl Sgd {
94    pub const fn new(lr: Float) -> Self {
95        Self {
96            lr,
97            weight_decay: 0.0,
98        }
99    }
100
101    pub const fn with_weight_decay(mut self, weight_decay: Float) -> Self {
102        self.weight_decay = weight_decay;
103        self
104    }
105}
106
107impl Optimizer for Sgd {
108    fn update_parameter(
109        &mut self,
110        _slot: usize,
111        params: &mut [Float],
112        grads: &[Float],
113        scale: Float,
114    ) {
115        for (param, grad) in params.iter_mut().zip(grads.iter()) {
116            let update = *grad * scale + self.weight_decay * *param;
117            *param -= self.lr * update;
118        }
119    }
120}
121
/// Adam optimizer. Weight decay is folded into the gradient (classic L2
/// regularization, not decoupled AdamW).
#[derive(Debug, Clone)]
pub struct Adam {
    pub lr: Float,
    pub beta1: Float,
    pub beta2: Float,
    pub epsilon: Float,
    pub weight_decay: Float,
    // Step counter, advanced by begin_step(); drives bias correction.
    step: usize,
    // Per-slot exponential moving averages of gradients and squared
    // gradients, indexed by the optimizer slot.
    first_moment: Vec<Box<[Float]>>,
    second_moment: Vec<Box<[Float]>>,
}
133
134impl Adam {
135    pub fn new(lr: Float) -> Self {
136        Self {
137            lr,
138            beta1: 0.9,
139            beta2: 0.999,
140            epsilon: 1e-8,
141            weight_decay: 0.0,
142            step: 0,
143            first_moment: Vec::new(),
144            second_moment: Vec::new(),
145        }
146    }
147
148    pub const fn with_weight_decay(mut self, weight_decay: Float) -> Self {
149        self.weight_decay = weight_decay;
150        self
151    }
152
153    fn ensure_slot(&mut self, slot: usize, len: usize) {
154        while self.first_moment.len() <= slot {
155            self.first_moment.push(Vec::new().into_boxed_slice());
156            self.second_moment.push(Vec::new().into_boxed_slice());
157        }
158        if self.first_moment[slot].len() != len {
159            self.first_moment[slot] = vec![0.0; len].into_boxed_slice();
160            self.second_moment[slot] = vec![0.0; len].into_boxed_slice();
161        }
162    }
163}
164
165impl Optimizer for Adam {
166    fn begin_step(&mut self) {
167        self.step += 1;
168    }
169
170    fn update_parameter(
171        &mut self,
172        slot: usize,
173        params: &mut [Float],
174        grads: &[Float],
175        scale: Float,
176    ) {
177        self.ensure_slot(slot, params.len());
178        let bias_correction1 = 1.0 - self.beta1.powi(self.step as i32);
179        let bias_correction2 = 1.0 - self.beta2.powi(self.step as i32);
180
181        let first = &mut self.first_moment[slot];
182        let second = &mut self.second_moment[slot];
183        for i in 0..params.len() {
184            let grad = grads[i] * scale + self.weight_decay * params[i];
185            first[i] = self.beta1 * first[i] + (1.0 - self.beta1) * grad;
186            second[i] = self.beta2 * second[i] + (1.0 - self.beta2) * grad * grad;
187            let m_hat = first[i] / bias_correction1.max(f64::EPSILON);
188            let v_hat = second[i] / bias_correction2.max(f64::EPSILON);
189            params[i] -= self.lr * m_hat / (v_hat.sqrt() + self.epsilon);
190        }
191    }
192}
193
/// Training hyper-parameters plus the (boxed, stateful) optimizer.
pub struct TrainConfig {
    // Boxed so heterogeneous optimizers share one config type.
    optimizer: Box<dyn Optimizer>,
    pub epochs: usize,
    pub batch_size: usize,
    // When Some, sample order is reshuffled deterministically each epoch.
    pub shuffle_seed: Option<u64>,
}
200
201impl fmt::Debug for TrainConfig {
202    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
203        f.debug_struct("TrainConfig")
204            .field("optimizer", &self.optimizer)
205            .field("epochs", &self.epochs)
206            .field("batch_size", &self.batch_size)
207            .field("shuffle_seed", &self.shuffle_seed)
208            .finish()
209    }
210}
211
212impl TrainConfig {
213    pub fn new<O: Optimizer + 'static>(optimizer: O) -> Self {
214        Self {
215            optimizer: Box::new(optimizer),
216            epochs: 1,
217            batch_size: 1,
218            shuffle_seed: None,
219        }
220    }
221
222    pub fn sgd(lr: Float) -> Self {
223        Self::new(Sgd::new(lr))
224    }
225
226    pub fn adam(lr: Float) -> Self {
227        Self::new(Adam::new(lr))
228    }
229
230    pub fn epochs(mut self, epochs: usize) -> Self {
231        self.epochs = epochs;
232        self
233    }
234
235    pub fn batch_size(mut self, batch_size: usize) -> Self {
236        self.batch_size = batch_size.max(1);
237        self
238    }
239
240    pub fn shuffle_seed(mut self, shuffle_seed: u64) -> Self {
241        self.shuffle_seed = Some(shuffle_seed);
242        self
243    }
244
245    fn optimizer_mut(&mut self) -> &mut dyn Optimizer {
246        self.optimizer.as_mut()
247    }
248}
249
250impl Default for TrainConfig {
251    fn default() -> Self {
252        Self::adam(1e-3)
253    }
254}
255
/// Loss over a fixed-size output vector.
pub trait LossFunction<const N: usize>: fmt::Debug {
    /// Returns the scalar loss and writes d(loss)/d(output) into `grad`.
    fn loss_and_grad(
        &self,
        output: &[Float; N],
        target: &[Float; N],
        grad: &mut [Float; N],
    ) -> Float;
}
264
/// Mean-squared-error loss (see [`mse_loss`]).
#[derive(Debug, Clone, Copy, Default)]
pub struct MeanSquaredError;
267
268pub fn mse_loss<const N: usize>(
269    output: &[Float; N],
270    target: &[Float; N],
271    grad: &mut [Float; N],
272) -> Float {
273    let scale = 2.0 / N as Float;
274    let loss = output
275        .iter()
276        .zip(target.iter())
277        .zip(grad.iter_mut())
278        .map(|((&o, &t), g)| {
279            let diff = o - t;
280            *g = diff * scale;
281            diff * diff
282        })
283        .sum::<Float>();
284    loss / N as Float
285}
286
impl<const N: usize> LossFunction<N> for MeanSquaredError {
    /// Delegates to the free function [`mse_loss`].
    fn loss_and_grad(
        &self,
        output: &[Float; N],
        target: &[Float; N],
        grad: &mut [Float; N],
    ) -> Float {
        mse_loss(output, target, grad)
    }
}
297
/// A differentiable layer mapping a fixed-size input to a fixed-size output.
pub trait Layer<const IN: usize, const OUT: usize> {
    /// Computes the layer output for `input` into `output`.
    fn forward(&self, input: &[Float; IN], output: &mut [Float; OUT]);
    /// Propagates `output_grad` back to `input_grad`, accumulating any
    /// parameter gradients internally. `output` is the value produced by
    /// the matching `forward` call (activations such as Sigmoid use it).
    fn backward(
        &mut self,
        input: &[Float; IN],
        output: &[Float; OUT],
        output_grad: &[Float; OUT],
        input_grad: &mut [Float; IN],
    );

    /// Clears accumulated parameter gradients; no-op for stateless layers.
    fn zero_grad(&mut self) {}

    /// Applies accumulated gradients via `optimizer`, consuming one slot
    /// index per parameter buffer; no-op for stateless layers.
    fn apply_gradients(
        &mut self,
        _optimizer: &mut dyn Optimizer,
        _slot: &mut usize,
        _scale: Float,
    ) {
    }
}
318
/// Exposes a layer's compile-time input/output sizes as associated consts.
pub trait LayerDims {
    const INPUT: usize;
    const OUTPUT: usize;
}
323
/// Fully connected layer: OUT rows of IN weights (row-major) plus biases,
/// with matching gradient accumulators.
#[derive(Debug)]
pub struct DenseLayer<const IN: usize, const OUT: usize> {
    weights: Box<[Float]>,
    biases: Box<[Float; OUT]>,
    weight_grads: Box<[Float]>,
    bias_grads: Box<[Float; OUT]>,
}
331
/// Stateless rectified-linear activation over N elements.
#[derive(Debug)]
pub struct ReLU<const N: usize>;
334
/// Stateless logistic-sigmoid activation over N elements.
#[derive(Debug)]
pub struct Sigmoid<const N: usize>;
337
/// Identity layer over N elements (a no-op reshape at this representation).
#[derive(Debug)]
pub struct Flatten<const N: usize>;
340
341impl<const IN: usize, const OUT: usize> DenseLayer<IN, OUT> {
342    pub fn init() -> Self {
343        Self::with_initializer(XavierUniform)
344    }
345
346    pub fn seeded(seed: u64) -> Self {
347        Self::with_initializer_and_seed(XavierUniform, seed)
348    }
349
350    pub fn with_initializer<I: Initializer>(initializer: I) -> Self {
351        let mut rng = rand::rng();
352        Self::with_initializer_and_rng(initializer, &mut rng)
353    }
354
355    pub fn with_initializer_and_seed<I: Initializer>(initializer: I, seed: u64) -> Self {
356        let mut rng = StdRng::seed_from_u64(seed);
357        Self::with_initializer_and_rng(initializer, &mut rng)
358    }
359
360    pub fn with_initializer_and_rng<I: Initializer, R: Rng + ?Sized>(
361        initializer: I,
362        rng: &mut R,
363    ) -> Self {
364        let mut weights = vec![0.0; IN * OUT].into_boxed_slice();
365        initializer.fill(&mut weights, IN, OUT, rng);
366        Self {
367            weights,
368            biases: Box::new([0.0; OUT]),
369            weight_grads: vec![0.0; IN * OUT].into_boxed_slice(),
370            bias_grads: Box::new([0.0; OUT]),
371        }
372    }
373
374    pub fn forward(&self, input: &[Float; IN], output: &mut [Float; OUT]) {
375        for (o, out) in output.iter_mut().enumerate() {
376            let row = &self.weights[o * IN..(o + 1) * IN];
377            let mut sum = self.biases[o];
378            for (weight, inp) in row.iter().zip(input.iter()) {
379                sum += *weight * *inp;
380            }
381            *out = sum;
382        }
383    }
384
385    pub fn backward(
386        &mut self,
387        input: &[Float; IN],
388        _output: &[Float; OUT],
389        output_grad: &[Float; OUT],
390        input_grad: &mut [Float; IN],
391    ) {
392        input_grad.fill(0.0);
393
394        for (o, &grad) in output_grad.iter().enumerate() {
395            let row = &self.weights[o * IN..(o + 1) * IN];
396            for (in_grad, weight) in input_grad.iter_mut().zip(row.iter()) {
397                *in_grad += *weight * grad;
398            }
399        }
400
401        for (o, &grad) in output_grad.iter().enumerate() {
402            self.bias_grads[o] += grad;
403            let row_grads = &mut self.weight_grads[o * IN..(o + 1) * IN];
404            for (weight_grad, inp) in row_grads.iter_mut().zip(input.iter()) {
405                *weight_grad += grad * *inp;
406            }
407        }
408    }
409}
410
impl<const IN: usize, const OUT: usize> LayerDims for DenseLayer<IN, OUT> {
    const INPUT: usize = IN;
    const OUTPUT: usize = OUT;
}
415
416impl<const IN: usize, const OUT: usize> Layer<IN, OUT> for DenseLayer<IN, OUT> {
417    fn forward(&self, input: &[Float; IN], output: &mut [Float; OUT]) {
418        DenseLayer::forward(self, input, output);
419    }
420
421    fn backward(
422        &mut self,
423        input: &[Float; IN],
424        output: &[Float; OUT],
425        output_grad: &[Float; OUT],
426        input_grad: &mut [Float; IN],
427    ) {
428        DenseLayer::backward(self, input, output, output_grad, input_grad);
429    }
430
431    fn zero_grad(&mut self) {
432        self.weight_grads.fill(0.0);
433        self.bias_grads.fill(0.0);
434    }
435
436    fn apply_gradients(&mut self, optimizer: &mut dyn Optimizer, slot: &mut usize, scale: Float) {
437        optimizer.update_parameter(*slot, &mut self.weights, &self.weight_grads, scale);
438        *slot += 1;
439        optimizer.update_parameter(
440            *slot,
441            self.biases.as_mut_slice(),
442            self.bias_grads.as_slice(),
443            scale,
444        );
445        *slot += 1;
446        self.zero_grad();
447    }
448}
449
450impl<const N: usize> ReLU<N> {
451    pub fn init() -> Self {
452        ReLU
453    }
454
455    pub fn forward(&self, input: &[Float; N], output: &mut [Float; N]) {
456        for i in 0..N {
457            output[i] = input[i].max(0.0);
458        }
459    }
460
461    pub fn backward(
462        &self,
463        input: &[Float; N],
464        _output: &[Float; N],
465        output_grad: &[Float; N],
466        input_grad: &mut [Float; N],
467    ) {
468        for i in 0..N {
469            input_grad[i] = if input[i] > 0.0 { output_grad[i] } else { 0.0 };
470        }
471    }
472}
473
impl<const N: usize> LayerDims for ReLU<N> {
    const INPUT: usize = N;
    const OUTPUT: usize = N;
}
478
impl<const N: usize> Layer<N, N> for ReLU<N> {
    // Pure delegation to the inherent methods.
    fn forward(&self, input: &[Float; N], output: &mut [Float; N]) {
        ReLU::forward(self, input, output);
    }

    fn backward(
        &mut self,
        input: &[Float; N],
        output: &[Float; N],
        output_grad: &[Float; N],
        input_grad: &mut [Float; N],
    ) {
        ReLU::backward(self, input, output, output_grad, input_grad);
    }
}
494
495impl<const N: usize> Sigmoid<N> {
496    pub fn init() -> Self {
497        Sigmoid
498    }
499
500    pub fn forward(&self, input: &[Float; N], output: &mut [Float; N]) {
501        for i in 0..N {
502            output[i] = 1.0 / (1.0 + (-input[i]).exp());
503        }
504    }
505
506    pub fn backward(
507        &self,
508        _input: &[Float; N],
509        output: &[Float; N],
510        output_grad: &[Float; N],
511        input_grad: &mut [Float; N],
512    ) {
513        for i in 0..N {
514            let y = output[i];
515            input_grad[i] = output_grad[i] * y * (1.0 - y);
516        }
517    }
518}
519
impl<const N: usize> LayerDims for Sigmoid<N> {
    const INPUT: usize = N;
    const OUTPUT: usize = N;
}
524
impl<const N: usize> Layer<N, N> for Sigmoid<N> {
    // Pure delegation to the inherent methods.
    fn forward(&self, input: &[Float; N], output: &mut [Float; N]) {
        Sigmoid::forward(self, input, output);
    }

    fn backward(
        &mut self,
        input: &[Float; N],
        output: &[Float; N],
        output_grad: &[Float; N],
        input_grad: &mut [Float; N],
    ) {
        Sigmoid::backward(self, input, output, output_grad, input_grad);
    }
}
540
541impl<const N: usize> Flatten<N> {
542    pub fn init() -> Self {
543        Flatten
544    }
545
546    pub fn forward(&self, input: &[Float; N], output: &mut [Float; N]) {
547        output.copy_from_slice(input);
548    }
549
550    pub fn backward(
551        &self,
552        _input: &[Float; N],
553        _output: &[Float; N],
554        output_grad: &[Float; N],
555        input_grad: &mut [Float; N],
556    ) {
557        input_grad.copy_from_slice(output_grad);
558    }
559}
560
impl<const N: usize> LayerDims for Flatten<N> {
    const INPUT: usize = N;
    const OUTPUT: usize = N;
}
565
impl<const N: usize> Layer<N, N> for Flatten<N> {
    // Pure delegation to the inherent methods.
    fn forward(&self, input: &[Float; N], output: &mut [Float; N]) {
        Flatten::forward(self, input, output);
    }

    fn backward(
        &mut self,
        input: &[Float; N],
        output: &[Float; N],
        output_grad: &[Float; N],
        input_grad: &mut [Float; N],
    ) {
        Flatten::backward(self, input, output, output_grad, input_grad);
    }
}
581
582mod private {
583    use super::*;
584
    /// Marker terminating a compile-time layer chain (the "nil" list node).
    #[derive(Debug, Clone, Copy, Default)]
    pub struct End;
587
    /// Cons cell of the type-level layer list; `MID` is the head layer's
    /// output size (and the tail's input size).
    #[derive(Debug)]
    pub struct Chain<Head, Tail, const MID: usize> {
        pub(super) head: Head,
        pub(super) tail: Tail,
    }
593
    impl<Head, Tail, const MID: usize> Chain<Head, Tail, MID> {
        /// Builds a cons cell from a head layer and the rest of the chain.
        pub const fn new(head: Head, tail: Tail) -> Self {
            Self { head, tail }
        }
    }
599
    /// Type-level append: pushes `Next` (with output size `NEXT_OUTPUT`)
    /// onto the end of a layer chain.
    pub trait AppendLayer<Next, const NEXT_OUTPUT: usize>: Sized {
        type Output;
        fn then(self, next: Next) -> Self::Output;
    }
604
605    impl<Next, const NEXT_OUTPUT: usize> AppendLayer<Next, NEXT_OUTPUT> for End {
606        type Output = Chain<Next, End, NEXT_OUTPUT>;
607
608        fn then(self, next: Next) -> Self::Output {
609            Chain::new(next, End)
610        }
611    }
612
613    impl<Head, Tail, const MID: usize, Next, const NEXT_OUTPUT: usize>
614        AppendLayer<Next, NEXT_OUTPUT> for Chain<Head, Tail, MID>
615    where
616        Tail: AppendLayer<Next, NEXT_OUTPUT>,
617    {
618        type Output = Chain<Head, <Tail as AppendLayer<Next, NEXT_OUTPUT>>::Output, MID>;
619
620        fn then(self, next: Next) -> Self::Output {
621            Chain::new(self.head, self.tail.then(next))
622        }
623    }
624
    /// Scratch buffers for the last link: the network's final activation
    /// and the loss gradient with respect to it.
    #[derive(Debug)]
    pub struct TerminalWorkspace<const OUT: usize> {
        activation: Box<[Float; OUT]>,
        gradient: Box<[Float; OUT]>,
    }
630
    /// Scratch buffers for an interior link: the head layer's activation,
    /// the gradient flowing back into it, and the tail's own workspace.
    #[derive(Debug)]
    pub struct ChainWorkspace<const MID: usize, TailWorkspace> {
        activation: Box<[Float; MID]>,
        gradient: Box<[Float; MID]>,
        tail: TailWorkspace,
    }
637
    /// Whole-network workspace: the per-layer buffers plus a sink for the
    /// gradient with respect to the network input.
    #[derive(Debug)]
    pub struct StackWorkspace<BodyWorkspace, const INPUT: usize> {
        body: BodyWorkspace,
        input_grad: Box<[Float; INPUT]>,
    }
643
    /// Recursive operations over a whole layer chain, threaded through an
    /// externally owned workspace so forward/backward share activations.
    pub trait ModuleChain<const INPUT: usize, const OUTPUT: usize> {
        type Workspace;

        /// Allocates zeroed scratch buffers matching this chain's shape.
        fn workspace(&self) -> Self::Workspace;
        /// Runs the chain, storing every intermediate activation.
        fn forward_with_workspace(&self, input: &[Float; INPUT], workspace: &mut Self::Workspace);
        /// Borrows the final activation produced by the last forward pass.
        fn output(workspace: &Self::Workspace) -> &[Float; OUTPUT];
        /// Seeds the loss gradient at the chain's output.
        fn set_output_grad(workspace: &mut Self::Workspace, grad: &[Float; OUTPUT]);
        /// Back-propagates from the seeded output gradient to `input_grad`,
        /// accumulating parameter gradients along the way.
        fn backward_with_workspace(
            &mut self,
            input: &[Float; INPUT],
            input_grad: &mut [Float; INPUT],
            workspace: &mut Self::Workspace,
        );
        /// Clears accumulated parameter gradients in every layer.
        fn zero_grad(&mut self);
        /// Applies accumulated gradients layer by layer, advancing `slot`.
        fn apply_gradients(
            &mut self,
            optimizer: &mut dyn Optimizer,
            slot: &mut usize,
            scale: Float,
        );
    }
665
    // Base case: a one-element chain. The head maps INPUT -> OUTPUT and its
    // activation/gradient buffers are the network's output buffers.
    impl<Head, const INPUT: usize, const OUTPUT: usize> ModuleChain<INPUT, OUTPUT>
        for Chain<Head, End, OUTPUT>
    where
        Head: Layer<INPUT, OUTPUT>,
    {
        type Workspace = TerminalWorkspace<OUTPUT>;

        fn workspace(&self) -> Self::Workspace {
            TerminalWorkspace {
                activation: Box::new([0.0; OUTPUT]),
                gradient: Box::new([0.0; OUTPUT]),
            }
        }

        fn forward_with_workspace(&self, input: &[Float; INPUT], workspace: &mut Self::Workspace) {
            self.head.forward(input, workspace.activation.as_mut());
        }

        fn output(workspace: &Self::Workspace) -> &[Float; OUTPUT] {
            workspace.activation.as_ref()
        }

        fn set_output_grad(workspace: &mut Self::Workspace, grad: &[Float; OUTPUT]) {
            workspace.gradient.copy_from_slice(grad);
        }

        fn backward_with_workspace(
            &mut self,
            input: &[Float; INPUT],
            input_grad: &mut [Float; INPUT],
            workspace: &mut Self::Workspace,
        ) {
            // Gradient was seeded by set_output_grad; activation was stored
            // by the matching forward pass.
            self.head.backward(
                input,
                workspace.activation.as_ref(),
                workspace.gradient.as_ref(),
                input_grad,
            );
        }

        fn zero_grad(&mut self) {
            self.head.zero_grad();
        }

        fn apply_gradients(
            &mut self,
            optimizer: &mut dyn Optimizer,
            slot: &mut usize,
            scale: Float,
        ) {
            self.head.apply_gradients(optimizer, slot, scale);
        }
    }
719
    // Recursive case: head maps INPUT -> MID, tail handles MID -> OUTPUT.
    impl<Head, Tail, const INPUT: usize, const MID: usize, const OUTPUT: usize>
        ModuleChain<INPUT, OUTPUT> for Chain<Head, Tail, MID>
    where
        Head: Layer<INPUT, MID>,
        Tail: ModuleChain<MID, OUTPUT>,
    {
        type Workspace = ChainWorkspace<MID, Tail::Workspace>;

        fn workspace(&self) -> Self::Workspace {
            ChainWorkspace {
                activation: Box::new([0.0; MID]),
                gradient: Box::new([0.0; MID]),
                tail: self.tail.workspace(),
            }
        }

        fn forward_with_workspace(&self, input: &[Float; INPUT], workspace: &mut Self::Workspace) {
            // Run the head, then feed its stored activation to the tail.
            self.head.forward(input, workspace.activation.as_mut());
            self.tail
                .forward_with_workspace(workspace.activation.as_ref(), &mut workspace.tail);
        }

        fn output(workspace: &Self::Workspace) -> &[Float; OUTPUT] {
            Tail::output(&workspace.tail)
        }

        fn set_output_grad(workspace: &mut Self::Workspace, grad: &[Float; OUTPUT]) {
            Tail::set_output_grad(&mut workspace.tail, grad);
        }

        fn backward_with_workspace(
            &mut self,
            input: &[Float; INPUT],
            input_grad: &mut [Float; INPUT],
            workspace: &mut Self::Workspace,
        ) {
            // Order matters: the tail must run first so it writes the
            // gradient at MID that the head's backward pass consumes.
            self.tail.backward_with_workspace(
                workspace.activation.as_ref(),
                workspace.gradient.as_mut(),
                &mut workspace.tail,
            );
            self.head.backward(
                input,
                workspace.activation.as_ref(),
                workspace.gradient.as_ref(),
                input_grad,
            );
        }

        fn zero_grad(&mut self) {
            self.head.zero_grad();
            self.tail.zero_grad();
        }

        fn apply_gradients(
            &mut self,
            optimizer: &mut dyn Optimizer,
            slot: &mut usize,
            scale: Float,
        ) {
            // Head first, then tail — slot numbering must match every step.
            self.head.apply_gradients(optimizer, slot, scale);
            self.tail.apply_gradients(optimizer, slot, scale);
        }
    }
784
    /// A complete network: a statically typed layer chain from INPUT to OUTPUT.
    #[derive(Debug)]
    pub struct Stack<Layers, const INPUT: usize, const OUTPUT: usize>
    where
        Layers: ModuleChain<INPUT, OUTPUT>,
    {
        layers: Layers,
    }
792
793    impl<Layers, const INPUT: usize, const OUTPUT: usize> Stack<Layers, INPUT, OUTPUT>
794    where
795        Layers: ModuleChain<INPUT, OUTPUT>,
796    {
797        pub const fn new(layers: Layers) -> Self {
798            Self { layers }
799        }
800
801        pub fn predict(&self, input: &[Float; INPUT]) -> [Float; OUTPUT] {
802            let mut workspace = StackWorkspace {
803                body: self.layers.workspace(),
804                input_grad: Box::new([0.0; INPUT]),
805            };
806            self.layers
807                .forward_with_workspace(input, &mut workspace.body);
808            let mut result = [0.0; OUTPUT];
809            result.copy_from_slice(Layers::output(&workspace.body));
810            result
811        }
812
813        pub fn fit_with_loss(
814            &mut self,
815            samples: &[Sample<INPUT, OUTPUT>],
816            loss_fn: &dyn LossFunction<OUTPUT>,
817            mut config: TrainConfig,
818        ) -> Float {
819            if samples.is_empty() || config.epochs == 0 {
820                return 0.0;
821            }
822
823            let batch_size = config.batch_size.max(1);
824            let mut workspace = StackWorkspace {
825                body: self.layers.workspace(),
826                input_grad: Box::new([0.0; INPUT]),
827            };
828            let mut order = (0..samples.len()).collect::<Vec<_>>();
829            let mut shuffler = config.shuffle_seed.map(StdRng::seed_from_u64);
830            let mut total_loss = 0.0;
831            let mut steps = 0usize;
832
833            for _ in 0..config.epochs {
834                if let Some(rng) = shuffler.as_mut() {
835                    order.shuffle(rng);
836                }
837
838                for batch in order.chunks(batch_size) {
839                    self.layers.zero_grad();
840                    let mut batch_loss = 0.0;
841
842                    for &sample_idx in batch {
843                        let sample = &samples[sample_idx];
844                        self.layers
845                            .forward_with_workspace(&sample.input, &mut workspace.body);
846                        let mut grad = [0.0; OUTPUT];
847                        let loss = loss_fn.loss_and_grad(
848                            Layers::output(&workspace.body),
849                            &sample.target,
850                            &mut grad,
851                        );
852                        Layers::set_output_grad(&mut workspace.body, &grad);
853                        self.layers.backward_with_workspace(
854                            &sample.input,
855                            workspace.input_grad.as_mut(),
856                            &mut workspace.body,
857                        );
858                        batch_loss += loss;
859                    }
860
861                    config.optimizer_mut().begin_step();
862                    let mut slot = 0usize;
863                    self.layers.apply_gradients(
864                        config.optimizer_mut(),
865                        &mut slot,
866                        1.0 / batch.len() as Float,
867                    );
868                    total_loss += batch_loss / batch.len() as Float;
869                    steps += 1;
870                }
871            }
872
873            total_loss / steps as Float
874        }
875    }
876
    /// Object-safe facade over a `Stack`, letting `Sequential` box networks
    /// of different concrete chain types behind one trait object.
    pub trait ModelRuntime<const INPUT: usize, const OUTPUT: usize>: fmt::Debug {
        fn predict(&self, input: &[Float; INPUT]) -> [Float; OUTPUT];
        fn fit_with_loss(
            &mut self,
            samples: &[Sample<INPUT, OUTPUT>],
            loss_fn: &dyn LossFunction<OUTPUT>,
            config: TrainConfig,
        ) -> Float;
    }
886
    // Pure delegation to the inherent Stack methods.
    impl<Layers, const INPUT: usize, const OUTPUT: usize> ModelRuntime<INPUT, OUTPUT>
        for Stack<Layers, INPUT, OUTPUT>
    where
        Layers: ModuleChain<INPUT, OUTPUT> + fmt::Debug + 'static,
    {
        fn predict(&self, input: &[Float; INPUT]) -> [Float; OUTPUT] {
            Stack::predict(self, input)
        }

        fn fit_with_loss(
            &mut self,
            samples: &[Sample<INPUT, OUTPUT>],
            loss_fn: &dyn LossFunction<OUTPUT>,
            config: TrainConfig,
        ) -> Float {
            Stack::fit_with_loss(self, samples, loss_fn, config)
        }
    }
905}
906
/// Type-erased network handle: hides the concrete layer-chain type behind a
/// boxed runtime so callers only see the INPUT/OUTPUT sizes.
pub struct Sequential<const INPUT: usize, const OUTPUT: usize> {
    inner: Box<dyn private::ModelRuntime<INPUT, OUTPUT>>,
}
910
911impl<const INPUT: usize, const OUTPUT: usize> fmt::Debug for Sequential<INPUT, OUTPUT> {
912    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
913        f.debug_struct("Sequential")
914            .field("input", &INPUT)
915            .field("output", &OUTPUT)
916            .finish()
917    }
918}
919
impl<const INPUT: usize, const OUTPUT: usize> Sequential<INPUT, OUTPUT> {
    /// Boxes a concrete runtime (built by the builder types below).
    fn from_runtime<R>(runtime: R) -> Self
    where
        R: private::ModelRuntime<INPUT, OUTPUT> + 'static,
    {
        Self {
            inner: Box::new(runtime),
        }
    }

    /// Runs a forward pass and returns the network output.
    pub fn predict(&self, input: &[Float; INPUT]) -> [Float; OUTPUT] {
        self.inner.predict(input)
    }

    /// Alias for [`Self::predict`]. NOTE(review): despite the name this is
    /// not in-place — it returns a fresh array; presumably kept for
    /// backward compatibility.
    pub fn predict_in_place(&self, input: &[Float; INPUT]) -> [Float; OUTPUT] {
        self.predict(input)
    }

    /// Trains with mean-squared-error loss; returns the average loss.
    pub fn fit(&mut self, samples: &[Sample<INPUT, OUTPUT>], config: TrainConfig) -> Float {
        self.fit_with_loss(samples, &MeanSquaredError, config)
    }

    /// Trains with an arbitrary loss function; returns the average loss.
    pub fn fit_with_loss(
        &mut self,
        samples: &[Sample<INPUT, OUTPUT>],
        loss_fn: &dyn LossFunction<OUTPUT>,
        config: TrainConfig,
    ) -> Float {
        self.inner.fit_with_loss(samples, loss_fn, config)
    }
}
951
/// Entry point of the compile-time model builder DSL.
#[derive(Debug, Clone, Copy, Default)]
pub struct ModelBuilder;
954
impl ModelBuilder {
    pub const fn new() -> Self {
        Self
    }

    /// Starts a network taking a flat N-element vector.
    pub fn input<const N: usize>(self) -> VectorBuilder<private::End, N, N> {
        VectorBuilder {
            layers: private::End,
        }
    }

    /// Starts a network taking a C×H×W image, flattened to C*H*W floats.
    /// NOTE(review): the `[(); C * H * W]:` bound relies on the unstable
    /// `generic_const_exprs` feature — presumably enabled at the crate
    /// root; confirm.
    pub fn image_input<const C: usize, const H: usize, const W: usize>(
        self,
    ) -> ImageBuilder<private::End, { C * H * W }, C, H, W>
    where
        [(); C * H * W]:,
    {
        ImageBuilder {
            layers: private::End,
        }
    }
}
977
/// Builder state for vector-shaped data: `INPUT` is the network's input
/// size, `CURRENT` the width at the current position in the chain.
pub struct VectorBuilder<Layers, const INPUT: usize, const CURRENT: usize> {
    layers: Layers,
}
981
982impl<Layers, const INPUT: usize, const CURRENT: usize> fmt::Debug
983    for VectorBuilder<Layers, INPUT, CURRENT>
984where
985    Layers: fmt::Debug,
986{
987    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
988        f.debug_struct("VectorBuilder")
989            .field("input", &INPUT)
990            .field("current", &CURRENT)
991            .finish()
992    }
993}
994
impl<Layers, const INPUT: usize, const CURRENT: usize> VectorBuilder<Layers, INPUT, CURRENT> {
    /// No-op on already-flat data; kept so image and vector pipelines can
    /// share call sites.
    pub const fn flatten(self) -> Self {
        self
    }

    /// Appends a Xavier-initialized fully connected layer CURRENT -> NEXT.
    pub fn dense<const NEXT: usize>(
        self,
    ) -> VectorBuilder<
        <Layers as private::AppendLayer<DenseLayer<CURRENT, NEXT>, NEXT>>::Output,
        INPUT,
        NEXT,
    >
    where
        Layers: private::AppendLayer<DenseLayer<CURRENT, NEXT>, NEXT>,
    {
        VectorBuilder {
            layers: self.layers.then(DenseLayer::<CURRENT, NEXT>::init()),
        }
    }

    /// Appends an elementwise ReLU activation (width unchanged).
    pub fn relu(
        self,
    ) -> VectorBuilder<
        <Layers as private::AppendLayer<ReLU<CURRENT>, CURRENT>>::Output,
        INPUT,
        CURRENT,
    >
    where
        Layers: private::AppendLayer<ReLU<CURRENT>, CURRENT>,
    {
        VectorBuilder {
            layers: self.layers.then(ReLU::<CURRENT>::init()),
        }
    }

    /// Appends an elementwise sigmoid activation (width unchanged).
    pub fn sigmoid(
        self,
    ) -> VectorBuilder<
        <Layers as private::AppendLayer<Sigmoid<CURRENT>, CURRENT>>::Output,
        INPUT,
        CURRENT,
    >
    where
        Layers: private::AppendLayer<Sigmoid<CURRENT>, CURRENT>,
    {
        VectorBuilder {
            layers: self.layers.then(Sigmoid::<CURRENT>::init()),
        }
    }

    /// Finalizes the chain into a type-erased `Sequential` model.
    pub fn build(self) -> Sequential<INPUT, CURRENT>
    where
        Layers: private::ModuleChain<INPUT, CURRENT> + fmt::Debug + 'static,
    {
        Sequential::from_runtime(private::Stack::new(self.layers))
    }
}
1052
/// Type-state builder for the image (`C` x `H` x `W`) stage of a model under
/// construction. `Layers` is the compile-time chain of layers appended so
/// far, `INPUT` the model's flat input width, and `C`/`H`/`W` the current
/// channel count, height, and width of the feature map.
pub struct ImageBuilder<Layers, const INPUT: usize, const C: usize, const H: usize, const W: usize>
{
    // Heterogeneous, compile-time-typed stack of the layers appended so far.
    layers: Layers,
}
1057
1058impl<Layers, const INPUT: usize, const C: usize, const H: usize, const W: usize> fmt::Debug
1059    for ImageBuilder<Layers, INPUT, C, H, W>
1060where
1061    Layers: fmt::Debug,
1062{
1063    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1064        f.debug_struct("ImageBuilder")
1065            .field("input", &INPUT)
1066            .field("channels", &C)
1067            .field("height", &H)
1068            .field("width", &W)
1069            .finish()
1070    }
1071}
1072
impl<Layers, const INPUT: usize, const C: usize, const H: usize, const W: usize>
    ImageBuilder<Layers, INPUT, C, H, W>
where
    [(); INPUT]:,
{
    /// Appends a ReLU activation over all `C * H * W` features; the image
    /// shape is unchanged.
    pub fn relu(
        self,
    ) -> ImageBuilder<
        <Layers as private::AppendLayer<ReLU<{ C * H * W }>, { C * H * W }>>::Output,
        INPUT,
        C,
        H,
        W,
    >
    where
        [(); C * H * W]:,
        Layers: private::AppendLayer<ReLU<{ C * H * W }>, { C * H * W }>,
    {
        ImageBuilder {
            layers: self.layers.then(ReLU::<{ C * H * W }>::init()),
        }
    }

    /// Appends a sigmoid activation over all `C * H * W` features; the image
    /// shape is unchanged.
    pub fn sigmoid(
        self,
    ) -> ImageBuilder<
        <Layers as private::AppendLayer<Sigmoid<{ C * H * W }>, { C * H * W }>>::Output,
        INPUT,
        C,
        H,
        W,
    >
    where
        [(); C * H * W]:,
        Layers: private::AppendLayer<Sigmoid<{ C * H * W }>, { C * H * W }>,
    {
        ImageBuilder {
            layers: self.layers.then(Sigmoid::<{ C * H * W }>::init()),
        }
    }

    /// Appends a convolution with `OC` output channels, an `FH` x `FW`
    /// filter, stride `S`, and padding `P`. The output height/width are
    /// computed at compile time by `conv_out_dim`, and the
    /// `ConvGeometryIsValid` bound rejects geometries it deems invalid at
    /// compile time.
    pub fn conv<const OC: usize, const FH: usize, const FW: usize, const S: usize, const P: usize>(
        self,
    ) -> ImageBuilder<
        <Layers as private::AppendLayer<
            Conv<W, H, C, FH, FW, OC, S, P>,
            { OC * conv_out_dim(H, P, FH, S) * conv_out_dim(W, P, FW, S) },
        >>::Output,
        INPUT,
        OC,
        { conv_out_dim(H, P, FH, S) },
        { conv_out_dim(W, P, FW, S) },
    >
    where
        [(); C * H * W]:,
        [(); OC * conv_out_dim(H, P, FH, S) * conv_out_dim(W, P, FW, S)]:,
        (): ConvGeometryIsValid<H, W, FH, FW, S, P>,
        Layers: private::AppendLayer<
                Conv<W, H, C, FH, FW, OC, S, P>,
                { OC * conv_out_dim(H, P, FH, S) * conv_out_dim(W, P, FW, S) },
            >,
    {
        ImageBuilder {
            // NOTE: `Conv`'s const parameters are ordered W, H, C (width
            // first) — they intentionally differ from this impl's C, H, W
            // ordering.
            layers: self.layers.then(Conv::<W, H, C, FH, FW, OC, S, P>::init()),
        }
    }

    /// Transitions to the flat builder with `CURRENT = C * H * W`; no layer
    /// is appended, only the builder's type-state changes.
    pub fn flatten(self) -> VectorBuilder<Layers, INPUT, { C * H * W }>
    where
        [(); C * H * W]:,
    {
        VectorBuilder {
            layers: self.layers,
        }
    }

    /// Finalizes the builder into a runnable `Sequential` model taking
    /// `INPUT` features and producing `C * H * W` features.
    pub fn build(self) -> Sequential<INPUT, { C * H * W }>
    where
        [(); C * H * W]:,
        Layers: private::ModuleChain<INPUT, { C * H * W }> + fmt::Debug + 'static,
    {
        Sequential::from_runtime(private::Stack::new(self.layers))
    }
}
1157
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `a` and `b` differ by at most `eps`, with a diagnostic
    /// message on failure.
    fn approx_eq(a: Float, b: Float, eps: Float) {
        let diff = (a - b).abs();
        assert!(diff <= eps, "expected {a} ~= {b} (diff={diff}, eps={eps})");
    }

    #[test]
    fn mse_loss_matches_manual_computation() {
        // Errors are (2-1)=1 and (-1-1)=-2, so the mean squared error is
        // (1 + 4) / 2 = 2.5 and the expected gradient 2*(o-t)/N is
        // [1.0, -2.0].
        let output = [2.0, -1.0];
        let target = [1.0, 1.0];
        let mut grad = [0.0; 2];
        let loss = mse_loss(&output, &target, &mut grad);
        approx_eq(loss, 2.5, 1e-12);
        assert_eq!(grad, [1.0, -2.0]);
    }

    #[test]
    fn dense_input_gradient_matches_finite_difference() {
        // Pin the parameters to known values so the check is deterministic;
        // the seed only affects the initial values, which are overwritten
        // here.
        let mut layer = DenseLayer::<2, 2>::with_initializer_and_seed(Uniform::new(-0.3, 0.3), 7);
        layer.weights.copy_from_slice(&[0.4, -0.2, 0.1, 0.3]);
        *layer.biases = [0.05, -0.1];

        let input = [0.7, -1.2];
        let output_grad = [0.8, -0.4];
        let mut output = [0.0; 2];
        let mut input_grad = [0.0; 2];

        layer.zero_grad();
        layer.forward(&input, &mut output);
        layer.backward(&input, &output, &output_grad, &mut input_grad);

        // Compare the analytic input gradient against a central finite
        // difference of the scalar objective sum(output * output_grad),
        // whose gradient w.r.t. the layer's output is exactly `output_grad`.
        let eps = 1e-7;
        for i in 0..2 {
            let mut plus = input;
            let mut minus = input;
            plus[i] += eps;
            minus[i] -= eps;

            let mut plus_out = [0.0; 2];
            let mut minus_out = [0.0; 2];
            layer.forward(&plus, &mut plus_out);
            layer.forward(&minus, &mut minus_out);
            let objective_plus = plus_out
                .iter()
                .zip(output_grad.iter())
                .map(|(o, g)| o * g)
                .sum::<Float>();
            let objective_minus = minus_out
                .iter()
                .zip(output_grad.iter())
                .map(|(o, g)| o * g)
                .sum::<Float>();
            // Central difference: (f(x+eps) - f(x-eps)) / (2*eps).
            let numeric = (objective_plus - objective_minus) / (2.0 * eps);
            approx_eq(input_grad[i], numeric, 1e-6);
        }
    }

    #[test]
    fn dense_weight_gradient_matches_finite_difference() {
        // Same setup as the input-gradient test, but perturbing a single
        // weight instead of the input.
        let mut layer = DenseLayer::<2, 2>::with_initializer_and_seed(Uniform::new(-0.3, 0.3), 11);
        layer.weights.copy_from_slice(&[0.4, -0.2, 0.1, 0.3]);
        *layer.biases = [0.05, -0.1];

        let input = [0.7, -1.2];
        let output_grad = [0.8, -0.4];
        let mut output = [0.0; 2];
        let mut input_grad = [0.0; 2];

        layer.zero_grad();
        layer.forward(&input, &mut output);
        layer.backward(&input, &output, &output_grad, &mut input_grad);

        // Build two copies of the layer that differ from the original only in
        // weights[weight_idx] by +/- eps (seed 0 is irrelevant: all params
        // are copied over).
        let weight_idx = 1;
        let eps = 1e-7;
        let mut plus = DenseLayer::<2, 2>::with_initializer_and_seed(Uniform::new(-0.3, 0.3), 0);
        plus.weights.copy_from_slice(&layer.weights);
        plus.biases.copy_from_slice(layer.biases.as_ref());
        plus.weights[weight_idx] += eps;
        let mut minus = DenseLayer::<2, 2>::with_initializer_and_seed(Uniform::new(-0.3, 0.3), 0);
        minus.weights.copy_from_slice(&layer.weights);
        minus.biases.copy_from_slice(layer.biases.as_ref());
        minus.weights[weight_idx] -= eps;

        // Central finite difference of sum(output * output_grad) w.r.t. the
        // perturbed weight, compared against the accumulated analytic grad.
        let mut plus_out = [0.0; 2];
        let mut minus_out = [0.0; 2];
        plus.forward(&input, &mut plus_out);
        minus.forward(&input, &mut minus_out);
        let objective_plus = plus_out
            .iter()
            .zip(output_grad.iter())
            .map(|(o, g)| o * g)
            .sum::<Float>();
        let objective_minus = minus_out
            .iter()
            .zip(output_grad.iter())
            .map(|(o, g)| o * g)
            .sum::<Float>();
        let numeric = (objective_plus - objective_minus) / (2.0 * eps);

        approx_eq(layer.weight_grads[weight_idx], numeric, 1e-6);
    }

    #[test]
    fn seeded_initialization_is_reproducible() {
        // Identical seeds must yield bit-identical weights and biases.
        let a = DenseLayer::<3, 2>::seeded(42);
        let b = DenseLayer::<3, 2>::seeded(42);
        assert_eq!(&*a.weights, &*b.weights);
        assert_eq!(&*a.biases, &*b.biases);
    }

    #[test]
    fn builder_training_decreases_loss_with_seeded_shuffle() {
        // 1 -> 8 -> ReLU -> 1 network fitted to the line y = 2x - 0.5 on
        // 41 evenly spaced points in [-2, 2].
        let mut model = ModelBuilder::new()
            .input::<1>()
            .dense::<8>()
            .relu()
            .dense::<1>()
            .build();
        let samples = (-20..=20)
            .map(|i| {
                let x = i as Float / 10.0;
                Sample::new([x], [2.0 * x - 0.5])
            })
            .collect::<Vec<_>>();
        // Seeded shuffle keeps the training trajectory deterministic so the
        // loss thresholds below are stable across runs.
        let config = TrainConfig::adam(0.03)
            .epochs(250)
            .batch_size(8)
            .shuffle_seed(9);

        // Mean loss over the dataset before training.
        let before = samples
            .iter()
            .map(|sample| {
                let output = model.predict(&sample.input);
                let mut grad = [0.0; 1];
                MeanSquaredError.loss_and_grad(&output, &sample.target, &mut grad)
            })
            .sum::<Float>()
            / samples.len() as Float;
        let during = model.fit(&samples, config);
        // Mean loss over the dataset after training.
        let after = samples
            .iter()
            .map(|sample| {
                let output = model.predict(&sample.input);
                let mut grad = [0.0; 1];
                MeanSquaredError.loss_and_grad(&output, &sample.target, &mut grad)
            })
            .sum::<Float>()
            / samples.len() as Float;

        assert!(during < before, "training step average should improve");
        assert!(after < before * 0.2, "expected loss to fall sharply");
    }
}