// mlinrust/ndarray/utils.rs

use super::NdArray;

/// check whether the i32 dim (either a forward or a negative/reverse index) is legal for max_dim, and normalize it to a usize index
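///
/// a hedged sketch of the mapping (doctest ignored since the crate path may differ):
/// ```ignore
/// assert_eq!(check_dim_is_legal(1, 3), 1);  // a forward index stays as-is
/// assert_eq!(check_dim_is_legal(-1, 3), 2); // -1 addresses the last dim
/// ```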
pub fn check_dim_is_legal(dim: i32, max_dim: usize) -> usize {
    // check that dim lies in [-max_dim, max_dim)
    // convert a potentially negative dimension index to usize
    assert!(
        (dim >= 0 && dim < max_dim as i32) || (dim < 0 && dim >= -(max_dim as i32)),
        "dim {} is out of range for an array with {} dims", dim, max_dim
    );
    if dim < 0 {
        (max_dim as i32 + dim) as usize
    } else {
        dim as usize
    }
}

/// the base worker function; it will be called by gather_by_a_specific_dim_and_do
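///
/// e.g., for shape `[2, 3]` with `dim_op_on = 1` (hypothetical values): the recursion
/// loops over the 2 rows; for each row the last-dim branch walks the 3 elements with
/// `step_by(src_base_sizes[1])` and writes a single gathered f32 into `tgt_data`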
fn collect_by_recursive_then_gather_to<F>(pos: usize, max_dim: usize, dim_op_on: usize, src_idx: usize, tgt_idx: usize, src_base_sizes: &[usize], tgt_base_sizes: &[usize], shapes: &[usize], src_data: &[f32], tgt_data: &mut [f32], gather: &F)
where F: Fn(Vec<&f32>) -> f32
{
    if pos == max_dim - 1 {
        // on the last dim: if it is the gathered dim, only one output slot is written
        let last_size = if dim_op_on == pos {
            1
        } else {
            shapes[pos]
        };
        tgt_data.iter_mut().skip(tgt_idx).take(last_size).enumerate()
        .for_each(|(i, dst)| {
            // walk the gathered dim with its stride and hand the whole lane to `gather`
            let t: Vec<&f32> = src_data.iter().skip(src_idx + i).step_by(src_base_sizes[dim_op_on]).take(shapes[dim_op_on]).collect();
            *dst = gather(t);
        });
    } else if pos == dim_op_on {
        collect_by_recursive_then_gather_to(pos + 1, max_dim, dim_op_on, src_idx, tgt_idx, src_base_sizes, tgt_base_sizes, shapes, src_data, tgt_data, gather);
    } else {
        for i in 0..shapes[pos] {
            collect_by_recursive_then_gather_to(pos + 1, max_dim, dim_op_on, src_idx + i * src_base_sizes[pos], tgt_idx + i * tgt_base_sizes[pos], src_base_sizes, tgt_base_sizes, shapes, src_data, tgt_data, gather);
        }
    }
}

/// used for squeezing-like operations on a specific dim; fix the other dimensions, then traverse and collect the specific dim into a Vec
/// * Examples:
///     - argmax: vector -> single element
///     - mean: vector -> single element
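///
/// a hedged sketch (hypothetical `x` of shape `[2, 3, 4]`): gathering over dim 1
/// first fills a `[2, 1, 4]` target, then squeezes it down to `[2, 4]`:
/// ```ignore
/// let summed = gather_by_a_specific_dim_and_do(&x, 1, &|lane: Vec<&f32>| -> f32 { lane.iter().map(|f| **f).sum() });
/// assert_eq!(summed.shape, vec![2, 4]);
/// ```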
fn gather_by_a_specific_dim_and_do(x: &NdArray, dim: i32, gather: &dyn Fn(Vec<&f32>) -> f32) -> NdArray {
    // squeeze this dim, gathering each [....] lane to a single f32
    // e.g., sum by dim, argmax by dim, etc.
    let dim = check_dim_is_legal(dim, x.dim());
    // the target shape keeps every dim but sets the gathered one to 1
    let tgt_shape: Vec<usize> = x.shape.iter().enumerate()
        .map(|(i, &item)| if i == dim { 1 } else { item })
        .collect();
    let mut target = NdArray::new(tgt_shape);

    collect_by_recursive_then_gather_to(0, x.dim(), dim, 0, 0, &NdArray::index_base_sizes(&x.shape), &NdArray::index_base_sizes(&target.shape), &x.shape, &x.data, &mut target.data, &gather);

    target.squeeze(dim as i32);

    target
}

/// inplace operation: softmax along the given dim
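///
/// a hedged sketch (values rounded): each lane along the dim ends up summing to 1
/// ```ignore
/// let mut x = NdArray::new(vec![vec![1.0, 2.0]]);
/// softmax(&mut x, -1); // x is now roughly [[0.2689, 0.7311]]
/// ```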
pub fn softmax(x: &mut NdArray, dim: i32) {
    let dim = check_dim_is_legal(dim, x.dim());

    let index_base_sizes = NdArray::index_base_sizes(&x.shape);

    // visit each lane along the specified dim and overwrite it in place
    fn retrieval_by_recursive(pos: usize, max_dim: usize, softmax_dim: usize, idx: usize, index_base_sizes: &[usize], shapes: &[usize], data: &mut [f32]) {
        if pos == max_dim - 1 {
            let last_size = if softmax_dim == pos {
                1
            } else {
                shapes[pos]
            };
            (0..last_size).for_each(|i| {
                // flat indices of the lane along the softmax dim
                let idxs: Vec<usize> = (0..shapes[softmax_dim]).map(|sd_i| idx + i + index_base_sizes[softmax_dim] * sd_i).collect();
                let mut softmax_data: Vec<f32> = idxs.iter().map(|ii| data[*ii]).collect();

                // subtract the max for numerical stability before exponentiating
                let constant = softmax_data.iter().cloned().reduce(f32::max).unwrap();
                softmax_data.iter_mut().for_each(|x| {
                    *x = (*x - constant).exp();
                });
                let sum: f32 = softmax_data.iter().sum();
                softmax_data.iter_mut().for_each(|x| *x /= sum);
                idxs.into_iter().zip(softmax_data.into_iter()).for_each(|(sd_i, d)| {
                    data[sd_i] = d;
                })
            });
        } else if pos == softmax_dim {
            retrieval_by_recursive(pos + 1, max_dim, softmax_dim, idx, index_base_sizes, shapes, data);
        } else {
            for i in 0..shapes[pos] {
                retrieval_by_recursive(pos + 1, max_dim, softmax_dim, idx + i * index_base_sizes[pos], index_base_sizes, shapes, data);
            }
        }
    }

    // start
    retrieval_by_recursive(0, x.dim(), dim, 0, &index_base_sizes, &x.shape, &mut x.data);
}

/// non-inplace operation: sum along the given dim
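///
/// mirrors the test below (a hedged sketch):
/// ```ignore
/// let x = NdArray::new(vec![vec![1.0; 3]; 2]); // shape [2, 3]
/// assert_eq!(sum_ndarray(&x, 0), NdArray::new(vec![2.0, 2.0, 2.0]));
/// ```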
pub fn sum_ndarray(x: &NdArray, dim: i32) -> NdArray {
    fn sum_value(src_data: Vec<&f32>) -> f32 {
        src_data.iter().fold(0.0, |s, i| s + **i)
    }
    gather_by_a_specific_dim_and_do(x, dim, &sum_value)
}

/// non-inplace operation: argmax along the given dim
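///
/// mirrors the test below (a hedged sketch); the indices come back as f32:
/// ```ignore
/// let mut x = NdArray::new(vec![1.0, -123.0, 5.8, 2.3, 11.3, 5.0]);
/// x.reshape(vec![2, 3]);
/// assert_eq!(argmax(&x, -1), NdArray::new(vec![2.0, 1.0]));
/// ```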
pub fn argmax(x: &NdArray, dim: i32) -> NdArray {
    // since NdArray is not generic over usize yet, argmax has to return the indices as f32
    // todo
    fn get_arg_by_max(src_data: Vec<&f32>) -> f32 {
        src_data.iter().enumerate().fold((0.0, f32::MIN), |s, i| {
            if **i.1 > s.1 {
                (i.0 as f32, **i.1)
            } else {
                s
            }
        }).0
    }
    gather_by_a_specific_dim_and_do(x, dim, &get_arg_by_max)
}

/// non-inplace operation: mean
pub fn mean(x: &NdArray, dim: i32) -> NdArray {
    fn avg_value(src_data: Vec<&f32>) -> f32 {
        src_data.iter().fold(0.0, |s, i| s + **i) / src_data.len() as f32
    }
    gather_by_a_specific_dim_and_do(x, dim, &avg_value)
}

/// non-inplace operation: min element of the specific dim
pub fn min(x: &NdArray, dim: i32) -> NdArray {
    fn min_value(src_data: Vec<&f32>) -> f32 {
        src_data.iter().fold(f32::MAX, |s, i| s.min(**i))
    }
    gather_by_a_specific_dim_and_do(x, dim, &min_value)
}

/// non-inplace operation: max element of the specific dim
pub fn max(x: &NdArray, dim: i32) -> NdArray {
    fn max_value(src_data: Vec<&f32>) -> f32 {
        src_data.iter().fold(f32::MIN, |s, i| s.max(**i))
    }
    gather_by_a_specific_dim_and_do(x, dim, &max_value)
}

/// following PyTorch, calculates the standard deviation along a specific dim
///
/// * unbiased = true: unbiased estimation (Bessel's correction), i.e., sum / (N-1)
/// * unbiased = false: biased estimation, i.e., sum / N
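///
/// worked example: for `[1.0, 2.0, 3.0]` the mean is 2.0, so the unbiased std is
/// `sqrt(((1-2)^2 + (2-2)^2 + (3-2)^2) / 2) = 1.0`, while the biased one is
/// `sqrt(2/3) ≈ 0.8165`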
pub fn std(x: &NdArray, dim: i32, unbiased: bool) -> NdArray {
    fn std_value(src_data: Vec<&f32>) -> f32 {
        let mean = src_data.iter().fold(0.0, |s, i| s + **i) / src_data.len() as f32;
        // the 1e-6 floor guards against dividing by zero when N == 1
        (src_data.iter().fold(0.0, |s, i| s + (**i - mean).powf(2.0)) / f32::max((src_data.len() - 1) as f32, 1e-6)).sqrt()
    }
    fn std_value_biased(src_data: Vec<&f32>) -> f32 {
        let mean = src_data.iter().fold(0.0, |s, i| s + **i) / src_data.len() as f32;
        (src_data.iter().fold(0.0, |s, i| s + (**i - mean).powf(2.0)) / src_data.len() as f32).sqrt()
    }
    if unbiased {
        gather_by_a_specific_dim_and_do(x, dim, &std_value)
    } else {
        gather_by_a_specific_dim_and_do(x, dim, &std_value_biased)
    }
}

/// 1 / (1 + exp(-x))
///
/// range: (0.0, 1.0)
pub fn sigmoid(x: &NdArray) -> NdArray {
    let mut out = x.clone();
    out.data_as_mut_vector().iter_mut().for_each(|i| {
        // branch on the sign to avoid overflow in exp
        if *i < 0.0 {
            // exp(x) / (1 + exp(x))
            let t = i.exp();
            *i = t / (1.0 + t);
        } else {
            // 1 / (1 + exp(-x))
            *i = 1.0 / (1.0 + (-*i).exp());
        }
    });
    out
}

/// tanh(x) = 2 * sigmoid(2x) - 1
///
/// range: (-1.0, 1.0)
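///
/// uses that identity so it can reuse the overflow-safe sigmoid above;
/// e.g., `tanh(0) = 2 * 0.5 - 1 = 0`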
pub fn tanh(x: &NdArray) -> NdArray {
    let mut out = x * 2.0;
    out.data_as_mut_vector().iter_mut().for_each(|i| {
        // overflow-safe sigmoid; *i already holds 2x here
        if *i < 0.0 {
            let t = i.exp();
            *i = t / (1.0 + t);
        } else {
            *i = 1.0 / (1.0 + (-*i).exp());
        } // now *i = sigmoid(2x)
        // tanh(x) = 2 * sigmoid(2x) - 1
        *i = 2.0 * *i - 1.0;
    });
    out
}

/// max(0.0, x)
///
/// range: [0.0, +∞)
pub fn relu(x: &NdArray) -> NdArray {
    let mut out = x.clone();
    out.data_as_mut_vector().iter_mut().for_each(|i| *i = f32::max(*i, 0.0));
    out
}

#[cfg(test)]
mod test {
    use super::{softmax, sum_ndarray, argmax, relu, sigmoid, tanh, mean, std, min, max};
    use super::NdArray;

    #[test]
    fn test_activation_functions() {
        let x = NdArray::new(vec![1.0, 232.0, -1.0, -22.0, 0.0]);
        println!("{}", relu(&x));
        println!("{}", sigmoid(&x));
        println!("{}", tanh(&x));
    }

    #[test]
    fn test_argmax() {
        let mut x = NdArray::new(vec![1.0, -123.0, 5.8, 2.3, 11.3, 5.0]);
        x.reshape(vec![2, 3]);
        let t = argmax(&x, -1);
        println!("argmax {x}");
        let tt = NdArray::new(vec![2.0, 1.0]);
        assert!(tt == t);
        println!("{t}");
    }

    #[test]
    fn test_sum_ndarray() {
        let x = NdArray::new(vec![vec![1.0; 3]; 2]);
        let t = sum_ndarray(&x, 0);
        assert_eq!(NdArray::new(vec![2.0, 2.0, 2.0]), t);
        println!("{t}");

        let mut x = NdArray::new((0..12).map(|i| i as f32).collect::<Vec<f32>>());
        x.reshape(vec![2, 3, 2]);
        let t = sum_ndarray(&x, 1);
        assert!(t.shape == vec![2, 2]);
        println!("{t}");
    }

    #[test]
    fn test_softmax() {
        // example 1
        let mut x = NdArray::new(vec![vec![1.0; 3]; 2]);
        softmax(&mut x, -1);
        println!("{x}");

        // example 2
        softmax(&mut x, -2);
        println!("{x}");

        // example 3
        let mut x = NdArray::new(vec![vec![1.1, -3.7, 341.23, 46.6], vec![3.23, 6.2, 0.4, -2.87]]);
        softmax(&mut x, -1);
        let xx = NdArray::new(vec![
            vec![0.0, 0.0, 1.0, 0.0],
            vec![0.048654296, 0.94836545, 0.002871229, 0.000109125154],
        ]);
        println!("{x}");
        assert_eq!(xx, x);
    }

    #[test]
    fn test_mean_std_min_max() {
        // let a = NdArray::default();
        // mean(&a, 0); // assert error, since default has no data
        let a = NdArray::random(vec![2, 3], None);
        println!("{a}\nmean:{}\nstd:{}", mean(&a, 1), std(&a, 1, true));

        println!("min{}\nmax{}", min(&a, 1), max(&a, 1));
    }
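
    // hedged additions (not part of the original suite): they rely only on
    // NdArray::new / reshape and the `data` field already used in this file

    #[test]
    fn test_std_worked_example() {
        // mirrors the worked example in the `std` docs: rows [1, 2, 3] and
        // [4, 5, 6] both have unbiased std 1.0
        let mut x = NdArray::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        x.reshape(vec![2, 3]);
        let s = std(&x, -1, true);
        assert!(s.data.iter().all(|v| (v - 1.0).abs() < 1e-5));
    }

    #[test]
    fn test_tanh_matches_reference() {
        // the 2 * sigmoid(2x) - 1 formulation should agree with f32::tanh
        let points = vec![-2.0f32, -0.5, 0.0, 0.5, 2.0];
        let t = tanh(&NdArray::new(points.clone()));
        for (ours, reference) in t.data.iter().zip(points.iter().map(|v| v.tanh())) {
            assert!((ours - reference).abs() < 1e-5);
        }
    }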
}