// nevermind_neu/dataloader/simple.rs
use std::{cell::RefCell, error::Error, fs::File};

use crate::dataloader::{DataLoader, LabeledEntry, MiniBatch};

5pub struct SimpleDataLoader {
6 pub id: RefCell<usize>,
7 pub data: Vec<LabeledEntry>,
8}
9
10impl DataLoader for SimpleDataLoader {
11 fn next(&self) -> &LabeledEntry {
12 assert!(self.data.len() > 0);
13
14 let mut self_id = self.id.borrow_mut();
15
16 if *self_id < self.data.len() {
17 let ret = &self.data[*self_id];
18 *self_id += 1;
19 return ret;
20 } else {
21 *self_id = 0;
22 drop(self_id);
23
24 return self.next();
25 }
26 }
27
28 fn next_batch(&self, size: usize) -> MiniBatch {
29 let mut mb = Vec::with_capacity(size);
30
31 for _i in 0..size {
32 mb.push(self.next());
33 }
34
35 MiniBatch::new(mb)
36 }
37
38 fn reset(&mut self) {
39 *self.id.borrow_mut() = 0;
40 }
41
42 fn len(&self) -> Option< usize > {
43 Some(self.data.len())
44 }
45
46 fn pos(&self) -> Option< usize > {
47 Some(*self.id.borrow())
48 }
49}
50
51impl SimpleDataLoader {
52 pub fn new(data: Vec<LabeledEntry>) -> Self {
53 Self {
54 id: RefCell::new(0),
55 data,
56 }
57 }
58
59 pub fn from_csv_file(filepath: &str, lbl_col_count: usize) -> Result<Self, Box<dyn Error>> {
60 let file = File::open(filepath)?;
61 let mut rdr = csv::Reader::from_reader(file);
62
63 let records = rdr.records();
64
65 let mut data = Vec::new();
66
67 for row in records {
68 let row = row?;
69
70 let inp_len = row.len() - lbl_col_count;
71
72 let mut inp_vec = Vec::with_capacity(inp_len);
73 let mut out_vec = Vec::with_capacity(lbl_col_count);
74
75 for (idx, val) in row.iter().enumerate() {
76 if idx > inp_len {
77 break;
78 }
79
80 inp_vec.push(val.parse::<f32>()?);
81 }
82
83 for val in row.iter().skip(inp_len) {
84 out_vec.push(val.parse::<f32>()?);
85 }
86
87 let lbl_entry = LabeledEntry::new(inp_vec, out_vec);
88 data.push(lbl_entry);
89 }
90
91 Ok(SimpleDataLoader::new(data))
92 }
93
94 pub fn empty() -> Self {
95 Self {
96 id: RefCell::new(0),
97 data: vec![],
98 }
99 }
100}