nevermind_neu/dataloader/
databatch.rs

1use ndarray::{Array, Axis};
2
3use crate::util::{DataVec, Array2D};
4
5
6#[derive(Clone, Default)]
7pub struct LabeledEntry {
8    pub input: DataVec,
9    pub expected: DataVec,
10}
11
12impl LabeledEntry {
13    pub fn new(input: Vec<f32>, expected: Vec<f32>) -> Self {
14        Self {
15            input: Array::from_vec(input),
16            expected: Array::from_vec(expected),
17        }
18    }
19}
20
21#[derive(Default, Clone)]
22pub struct MiniBatch {
23    pub input: Array2D,
24    pub output: Array2D,
25}
26
27impl MiniBatch {
28    pub fn new(b: Vec<&LabeledEntry>) -> Self {
29        assert!( !b.is_empty() );
30
31        let mut inp_arr = Array2D::zeros( (b.len(), b.first().unwrap().input.shape()[0]) );
32        let mut out_arr = Array2D::zeros( (b.len(), b.first().unwrap().expected.shape()[0]) );
33
34        // Copies memory into batch
35
36        for (idx, it) in b.iter().enumerate() {
37            let mut inp_entry = inp_arr.index_axis_mut(Axis(0), idx);
38            inp_entry.assign(&it.input);
39
40            let mut out_entry = out_arr.index_axis_mut(Axis(0), idx);
41            out_entry.assign(&it.expected);
42        }
43
44        Self {
45            input: inp_arr,
46            output: out_arr,
47        }
48    }
49
50    pub fn new_no_ref(b: Vec<LabeledEntry>) -> Self {
51        // TODO : refactor cause dublicating constructors
52        assert!( !b.is_empty() );
53
54        let mut inp_arr = Array2D::zeros( (b.len(), b.first().unwrap().input.shape()[0]) );
55        let mut out_arr = Array2D::zeros( (b.len(), b.first().unwrap().expected.shape()[0]) );
56
57        // Copies memory into batch
58
59        for (idx, it) in b.iter().enumerate() {
60            let mut inp_entry = inp_arr.index_axis_mut(Axis(0), idx);
61            inp_entry.assign(&it.input);
62
63            let mut out_entry = out_arr.index_axis_mut(Axis(0), idx);
64            out_entry.assign(&it.expected);
65        }
66
67        Self {
68            input: inp_arr,
69            output: out_arr,
70        }
71    }
72}