mlinrust/
utils.rs

1use std::collections::HashSet;
2
3use crate::dataset::dataloader::{Dataloader, BatchIterator};
4use crate::dataset::{Dataset, TaskLabelType};
5use crate::model::Model;
6
7/// Trait for 
8/// * &Dataset<T>
9/// * &mut Dataloader<T, &Dataset<T>>
10/// 
11/// pass a dataloader will reduce the construction and enable batch prediction
12pub trait EvaluateArgTrait<'a, T: TaskLabelType + Copy> {
13    fn dataloader_iter(self, batch: usize) -> BatchIterator<T>;
14}
15
16impl<'a, T: TaskLabelType + Copy> EvaluateArgTrait<'a, T> for &Dataset<T> {
17    fn dataloader_iter(self, batch: usize) -> BatchIterator<T> {
18        let mut loader = Dataloader::new(self, batch, false, None);
19        loader.iter_mut()
20    }
21}
22
23impl<'a, T: TaskLabelType + Copy> EvaluateArgTrait<'a, T> for &mut Dataloader<T, &Dataset<T>> {
24    fn dataloader_iter(self, batch: usize) -> BatchIterator<T> {
25        if self.batch_size != batch {
26            self.batch_size = batch;
27        }
28        self.iter_mut()
29    }
30}
31
32impl<'a, T: TaskLabelType + Copy> EvaluateArgTrait<'a, T> for &mut Dataloader<T, Dataset<T>> {
33    fn dataloader_iter(self, batch: usize) -> BatchIterator<T> {
34        if self.batch_size != batch {
35            self.batch_size = batch;
36        }
37        self.iter_mut()
38    }
39}
40
41/// evaluate classification dataset<usize>
42/// * data: &Dataset<usize> or &mut Dataloader<usize, &Dataset<usize>>
43/// * return: (correct_num, accuracy)
44pub fn evaluate<'a, T: EvaluateArgTrait<'a, usize>>(data: T, model: &impl Model<usize>) -> (usize, f32)
45{
46    let mut correct = 0;
47    let mut total = 0;
48    for (feature, label) in data.dataloader_iter(128) {
49        model.predict_with_batch(&feature).iter().zip(label.iter()).for_each(|(p, l)| {
50            if p == l {
51                correct += 1;
52            }
53            total += 1;
54        })
55    }
56    (correct, correct as f32 / total as f32)
57}
58
59/// evaluate regression dataset<f32>
60/// * data: &Dataset<f32> or &mut Dataloader<f32, &Dataset<f32>>
61/// * return: mean absolute error
62pub fn evaluate_regression<'a, T: EvaluateArgTrait<'a, f32>>(dataset: T, model: &impl Model<f32>) -> f32 {
63    let mut error = 0.0;
64    let mut total = 0;
65
66    for (feature, label) in dataset.dataloader_iter(128) {
67        error += model.predict_with_batch(&feature).iter().zip(label.iter())
68        .fold(0.0, |s, (p, l)| {
69            total += 1;
70            s + (p - l).abs()
71        });
72    }
73    error / total as f32
74}
75
76
77
/// Pseudo random number generator based on a linear congruential generator.
///
/// [Wiki linear congruential generator](https://en.wikipedia.org/wiki/Linear_congruential_generator)
///
/// Each step computes `seed = (a * seed + c) % m`. With the parameters chosen by
/// `RandGenerator::new` (`m = 2^31 - 1`), raw draws therefore lie in
/// `[0, 2^31 - 2]`. The same seed always reproduces the same sequence.
#[derive(Debug, Clone, Copy)]
pub struct RandGenerator {
    /// current state; each raw draw replaces it and is returned
    seed: usize,
    /// multiplier
    a: usize,
    /// increment
    c: usize,
    /// modulus
    m: usize,
}
90
/// Numeric types that `RandGenerator::gen_range` can sample.
///
/// Sampling happens in `f32` space, so an implementor only needs a round trip
/// through `f32`.
pub trait RandRangeTrait {
    /// Convert `self` to `f32` (large integers may lose precision).
    fn to_f32(self) -> f32;

    /// Convert a sampled `f32` back into `Self`.
    fn f32_to_self(n: f32) -> Self;
}

/// Implement [`RandRangeTrait`] for a primitive numeric type.
///
/// * `rand_range_trait_for!(int T)`: integer target; `f32_to_self` floors
///   before casting.
/// * `rand_range_trait_for!(T)`: plain `as` cast (used for floats).
macro_rules! rand_range_trait_for {
    (int $name:ty) => {
        impl RandRangeTrait for $name {
            fn to_f32(self) -> f32 {
                self as f32
            }

            fn f32_to_self(n: f32) -> Self {
                // `as` truncates toward zero, which would fold (-1, 0) onto 0
                // (doubling its weight) and make a negative lower bound almost
                // unreachable; flooring keeps every integer bucket equally wide.
                n.floor() as Self
            }
        }
    };
    ($name:ty) => {
        impl RandRangeTrait for $name {
            fn to_f32(self) -> f32 {
                self as f32
            }

            fn f32_to_self(n: f32) -> Self {
                n as Self
            }
        }
    };
}

rand_range_trait_for!(int i32);
rand_range_trait_for!(f32);
rand_range_trait_for!(int usize);
115
116impl RandGenerator {
117    pub fn new(seed: usize) -> Self {
118        // to avoid overflow, so init seed is up to 2^31
119        Self { seed: seed & 0x7f_ff_ff_ff, a: 1103515245, c: 12345, m: 0x7f_ff_ff_ff }
120    }
121
122    /// Watch out the range, [0, 2^31 - 1] instead of uszie::max 2^64 - 1
123    pub fn gen_u32(&mut self) -> usize {
124        self.seed = (self.a * self.seed + self.c) % self.m;
125        self.seed
126    }
127
128    /// generate rand f32 from [0.0, 1.0)
129    pub fn gen_f32(&mut self) -> f32 {
130        self.gen_u32() as f32 / (self.m - 1) as f32
131    }
132
133    /// generate rand usize(u32)/f32/i32 from [lower, upper)
134    pub fn gen_range<T: RandRangeTrait>(&mut self, low: T, upper: T) -> T {
135        T::f32_to_self(self.gen_f32() * upper.to_f32() + low.to_f32())
136    }
137
138    /// provable evenly shuffle
139    /// 
140    /// each element is swapped with **equal probability** to any positions
141    /// 
142    /// ## proof:
143    /// 
144    /// for the last element, there is no probability for other elements swap with it before, each position is 1/n;
145    /// for the second last element, the probability for position [0, n-1] is 1/(n-1), and the probability for finally swapping at the last position is depending on the last element, i.e. 1/n; then for the first (n-1) positions, the probability is 1/(n-1) * (1 - 1/n) = 1/n; so as the (n-2), (n-3)...0th element
146    pub fn shuffle<T>(&mut self, arr: &mut Vec<T>) {
147        for i in 0..arr.len() {
148            if self.gen_f32() > 1.0 / (i+1) as f32 {
149                arr.swap(i, self.gen_range(0, i));
150            } // otherwise keep
151        }
152    }
153
154    /// randomly choose samples from the given array
155    /// * arr: the pool of candidates
156    /// * num: the number of you want; **note that the num should <= arr.len() if w/o replacement**
157    /// * replacement: whether allow repeated samples in the choice
158    ///     * true: allowed
159    ///     * false: not allowed
160    /// * return: Vector of randomly choosen samples
161    pub fn choice<T: Clone>(&mut self, arr: &Vec<T>, num: usize, replacement: bool) -> Vec<T> {
162        if replacement {
163            (0..num).map(|_| arr[self.gen_range(0, arr.len())].clone()).collect()
164        } else {
165            assert!(num <= arr.len());
166            let mut set = HashSet::new();
167            let mut samples = vec![];
168
169            // optimize
170            if arr.len() >= 500 && num as f32 / arr.len() as f32 > 0.85 {
171                let mut temp = arr.clone();
172                self.shuffle(&mut temp);
173                temp.into_iter().take(num).for_each(|item| samples.push(item));
174            } else {
175                while samples.len() < num {
176                    let i = self.gen_range(0, arr.len());
177                    if ! set.contains(&i) {
178                        set.insert(i);
179                        samples.push(arr[i].clone());
180                    }
181                }
182            }
183            samples
184        }
185    }
186}
187
#[cfg(test)]
mod test {
    use super::RandGenerator;

    #[test]
    fn test_rand() {
        let mut rng = RandGenerator::new(0);
        let mut high = 0usize;
        let mut low = usize::MAX;
        for _ in 0..10000 {
            // Draw in the same order as before (unit, real, integer) so the
            // seeded sequence feeding `high`/`low` is unchanged, but check the
            // values instead of printing 10 000 lines.
            let unit = rng.gen_f32();
            assert!((0.0..=1.0).contains(&unit));
            let real = rng.gen_range(-10.0f32, 11.0f32);
            assert!((-10.0..11.0).contains(&real));
            let whole = rng.gen_range(0i32, 100i32);
            assert!((0..=100).contains(&whole));
            high = high.max(rng.gen_range(0usize, 100usize));
            low = low.min(rng.gen_range(0usize, 100usize));
        }
        assert_eq!(low, 0);
        assert_eq!(high, 99);

        let mut a: Vec<usize> = (0..10).collect();
        rng.shuffle(&mut a);
        // a shuffle must be a permutation of its input
        let mut sorted = a.clone();
        sorted.sort_unstable();
        assert_eq!(sorted, (0..10).collect::<Vec<usize>>());
    }
}