// nevermind_neu/dataloader/databatch.rs
use ndarray::{Array, Axis};

use crate::util::{Array2D, DataVec};

6#[derive(Clone, Default)]
7pub struct LabeledEntry {
8 pub input: DataVec,
9 pub expected: DataVec,
10}
11
12impl LabeledEntry {
13 pub fn new(input: Vec<f32>, expected: Vec<f32>) -> Self {
14 Self {
15 input: Array::from_vec(input),
16 expected: Array::from_vec(expected),
17 }
18 }
19}
20
21#[derive(Default, Clone)]
22pub struct MiniBatch {
23 pub input: Array2D,
24 pub output: Array2D,
25}
26
27impl MiniBatch {
28 pub fn new(b: Vec<&LabeledEntry>) -> Self {
29 assert!( !b.is_empty() );
30
31 let mut inp_arr = Array2D::zeros( (b.len(), b.first().unwrap().input.shape()[0]) );
32 let mut out_arr = Array2D::zeros( (b.len(), b.first().unwrap().expected.shape()[0]) );
33
34 for (idx, it) in b.iter().enumerate() {
37 let mut inp_entry = inp_arr.index_axis_mut(Axis(0), idx);
38 inp_entry.assign(&it.input);
39
40 let mut out_entry = out_arr.index_axis_mut(Axis(0), idx);
41 out_entry.assign(&it.expected);
42 }
43
44 Self {
45 input: inp_arr,
46 output: out_arr,
47 }
48 }
49
50 pub fn new_no_ref(b: Vec<LabeledEntry>) -> Self {
51 assert!( !b.is_empty() );
53
54 let mut inp_arr = Array2D::zeros( (b.len(), b.first().unwrap().input.shape()[0]) );
55 let mut out_arr = Array2D::zeros( (b.len(), b.first().unwrap().expected.shape()[0]) );
56
57 for (idx, it) in b.iter().enumerate() {
60 let mut inp_entry = inp_arr.index_axis_mut(Axis(0), idx);
61 inp_entry.assign(&it.input);
62
63 let mut out_entry = out_arr.index_axis_mut(Axis(0), idx);
64 out_entry.assign(&it.expected);
65 }
66
67 Self {
68 input: inp_arr,
69 output: out_arr,
70 }
71 }
72}