1use std::collections::HashSet;
2
3use crate::dataset::dataloader::{Dataloader, BatchIterator};
4use crate::dataset::{Dataset, TaskLabelType};
5use crate::model::Model;
6
7pub trait EvaluateArgTrait<'a, T: TaskLabelType + Copy> {
13 fn dataloader_iter(self, batch: usize) -> BatchIterator<T>;
14}
15
16impl<'a, T: TaskLabelType + Copy> EvaluateArgTrait<'a, T> for &Dataset<T> {
17 fn dataloader_iter(self, batch: usize) -> BatchIterator<T> {
18 let mut loader = Dataloader::new(self, batch, false, None);
19 loader.iter_mut()
20 }
21}
22
23impl<'a, T: TaskLabelType + Copy> EvaluateArgTrait<'a, T> for &mut Dataloader<T, &Dataset<T>> {
24 fn dataloader_iter(self, batch: usize) -> BatchIterator<T> {
25 if self.batch_size != batch {
26 self.batch_size = batch;
27 }
28 self.iter_mut()
29 }
30}
31
32impl<'a, T: TaskLabelType + Copy> EvaluateArgTrait<'a, T> for &mut Dataloader<T, Dataset<T>> {
33 fn dataloader_iter(self, batch: usize) -> BatchIterator<T> {
34 if self.batch_size != batch {
35 self.batch_size = batch;
36 }
37 self.iter_mut()
38 }
39}
40
41pub fn evaluate<'a, T: EvaluateArgTrait<'a, usize>>(data: T, model: &impl Model<usize>) -> (usize, f32)
45{
46 let mut correct = 0;
47 let mut total = 0;
48 for (feature, label) in data.dataloader_iter(128) {
49 model.predict_with_batch(&feature).iter().zip(label.iter()).for_each(|(p, l)| {
50 if p == l {
51 correct += 1;
52 }
53 total += 1;
54 })
55 }
56 (correct, correct as f32 / total as f32)
57}
58
59pub fn evaluate_regression<'a, T: EvaluateArgTrait<'a, f32>>(dataset: T, model: &impl Model<f32>) -> f32 {
63 let mut error = 0.0;
64 let mut total = 0;
65
66 for (feature, label) in dataset.dataloader_iter(128) {
67 error += model.predict_with_batch(&feature).iter().zip(label.iter())
68 .fold(0.0, |s, (p, l)| {
69 total += 1;
70 s + (p - l).abs()
71 });
72 }
73 error / total as f32
74}
75
76
77
/// A small, seedable linear congruential pseudo-random number generator.
///
/// Not suitable for cryptographic use.
///
/// NOTE(review): the type derives `Copy`. Passing a generator by value therefore
/// duplicates its state, and both copies then produce the same sequence.
#[derive(Clone, Copy, Debug)]
pub struct RandGenerator {
    /// Current state; also the value most recently returned by `gen_u32`.
    seed: usize,
    /// LCG multiplier.
    a: usize,
    /// LCG increment.
    c: usize,
    /// LCG modulus.
    m: usize,
}
90
/// Numeric types that `RandGenerator::gen_range` can sample.
///
/// Sampling is done in `f32`, so implementors only need to convert to and from `f32`.
pub trait RandRangeTrait {
    /// Converts `self` into an `f32`.
    fn to_f32(self) -> f32;

    /// Converts an `f32` back into `Self`.
    fn f32_to_self(n: f32) -> Self;
}
96
97
/// Implements `RandRangeTrait` for one or more primitive numeric types using `as` casts.
///
/// Accepts a single type (`rand_range_trait_for!(i32);`) or a comma-separated list with
/// an optional trailing comma (`rand_range_trait_for!(u8, i64, f64);`).
///
/// Caveats of the `as` conversions:
/// - `f32 -> integer` rounds toward zero and saturates at the target's bounds (`NaN` maps to 0).
/// - 64-bit integers and `f64` lose precision when routed through `f32`.
macro_rules! rand_range_trait_for {
    ($($name:ty),+ $(,)?) => {
        $(
            impl RandRangeTrait for $name {
                fn to_f32(self) -> f32 {
                    self as f32
                }

                fn f32_to_self(n: f32) -> Self {
                    n as Self
                }
            }
        )+
    };
}
111
112rand_range_trait_for!(i32);
113rand_range_trait_for!(f32);
114rand_range_trait_for!(usize);
115
116impl RandGenerator {
117 pub fn new(seed: usize) -> Self {
118 Self { seed: seed & 0x7f_ff_ff_ff, a: 1103515245, c: 12345, m: 0x7f_ff_ff_ff }
120 }
121
122 pub fn gen_u32(&mut self) -> usize {
124 self.seed = (self.a * self.seed + self.c) % self.m;
125 self.seed
126 }
127
128 pub fn gen_f32(&mut self) -> f32 {
130 self.gen_u32() as f32 / (self.m - 1) as f32
131 }
132
133 pub fn gen_range<T: RandRangeTrait>(&mut self, low: T, upper: T) -> T {
135 T::f32_to_self(self.gen_f32() * upper.to_f32() + low.to_f32())
136 }
137
138 pub fn shuffle<T>(&mut self, arr: &mut Vec<T>) {
147 for i in 0..arr.len() {
148 if self.gen_f32() > 1.0 / (i+1) as f32 {
149 arr.swap(i, self.gen_range(0, i));
150 } }
152 }
153
154 pub fn choice<T: Clone>(&mut self, arr: &Vec<T>, num: usize, replacement: bool) -> Vec<T> {
162 if replacement {
163 (0..num).map(|_| arr[self.gen_range(0, arr.len())].clone()).collect()
164 } else {
165 assert!(num <= arr.len());
166 let mut set = HashSet::new();
167 let mut samples = vec![];
168
169 if arr.len() >= 500 && num as f32 / arr.len() as f32 > 0.85 {
171 let mut temp = arr.clone();
172 self.shuffle(&mut temp);
173 temp.into_iter().take(num).for_each(|item| samples.push(item));
174 } else {
175 while samples.len() < num {
176 let i = self.gen_range(0, arr.len());
177 if ! set.contains(&i) {
178 set.insert(i);
179 samples.push(arr[i].clone());
180 }
181 }
182 }
183 samples
184 }
185 }
186}
187
#[cfg(test)]
mod test {
    use super::RandGenerator;

    /// Smoke test for the generator.
    ///
    /// Draw order matches the previous version (unit, float range, int range, then
    /// the high/low samples), so the extreme-value assertions see the same sequence.
    #[test]
    fn test_rand() {
        let mut rng = RandGenerator::new(0);
        // Explicit type: calling `.max` on an unannotated `{integer}` is E0689.
        let mut high: usize = 0;
        let mut low = usize::MAX;
        for _ in 0..10000 {
            let unit = rng.gen_f32();
            assert!((0.0..=1.0).contains(&unit), "gen_f32 out of range: {unit}");
            let x: f32 = rng.gen_range(-10.0f32, 11.0f32);
            assert!((-10.0..=11.0).contains(&x), "gen_range(-10, 11) out of range: {x}");
            let n: i32 = rng.gen_range(0, 100);
            assert!((0..=100).contains(&n), "gen_range(0, 100) out of range: {n}");
            high = high.max(rng.gen_range(0, 100));
            low = low.min(rng.gen_range(0, 100));
        }
        assert_eq!(low, 0);
        assert_eq!(high, 99);

        let mut a: Vec<usize> = (0..10).collect();
        rng.shuffle(&mut a);
        let mut sorted = a.clone();
        sorted.sort_unstable();
        assert_eq!(sorted, (0..10).collect::<Vec<_>>(), "shuffle must be a permutation: {a:?}");
    }
}
210}