// nevermind_neu/dataloader/simple.rs

use std::{cell::RefCell, error::Error, fs::File};

use crate::dataloader::{DataLoader, LabeledEntry, MiniBatch};
4
5pub struct SimpleDataLoader {
6    pub id: RefCell<usize>,
7    pub data: Vec<LabeledEntry>,
8}
9
10impl DataLoader for SimpleDataLoader {
11    fn next(&self) -> &LabeledEntry {
12        assert!(self.data.len() > 0);
13
14        let mut self_id = self.id.borrow_mut();
15
16        if *self_id < self.data.len() {
17            let ret = &self.data[*self_id];
18            *self_id += 1;
19            return ret;
20        } else {
21            *self_id = 0;
22            drop(self_id);
23
24            return self.next();
25        }
26    }
27
28    fn next_batch(&self, size: usize) -> MiniBatch {
29        let mut mb = Vec::with_capacity(size);
30
31        for _i in 0..size {
32            mb.push(self.next());
33        }
34
35        MiniBatch::new(mb)
36    }
37
38    fn reset(&mut self) {
39        *self.id.borrow_mut() = 0;
40    }
41
42    fn len(&self) -> Option< usize > {
43        Some(self.data.len())
44    }
45
46    fn pos(&self) -> Option< usize > {
47        Some(*self.id.borrow())
48    }
49}
50
51impl SimpleDataLoader {
52    pub fn new(data: Vec<LabeledEntry>) -> Self {
53        Self {
54            id: RefCell::new(0),
55            data,
56        }
57    }
58
59    pub fn from_csv_file(filepath: &str, lbl_col_count: usize) -> Result<Self, Box<dyn Error>> {
60        let file = File::open(filepath)?;
61        let mut rdr = csv::Reader::from_reader(file);
62
63        let records = rdr.records();
64
65        let mut data = Vec::new();
66
67        for row in records {
68            let row = row?;
69
70            let inp_len = row.len() - lbl_col_count;
71
72            let mut inp_vec = Vec::with_capacity(inp_len);
73            let mut out_vec = Vec::with_capacity(lbl_col_count);
74
75            for (idx, val) in row.iter().enumerate() {
76                if idx > inp_len {
77                    break;
78                }
79
80                inp_vec.push(val.parse::<f32>()?);
81            }
82
83            for val in row.iter().skip(inp_len) {
84                out_vec.push(val.parse::<f32>()?);
85            }
86
87            let lbl_entry = LabeledEntry::new(inp_vec, out_vec);
88            data.push(lbl_entry);
89        }
90
91        Ok(SimpleDataLoader::new(data))
92    }
93
94    pub fn empty() -> Self {
95        Self {
96            id: RefCell::new(0),
97            data: vec![],
98        }
99    }
100}