1use super::NdArray;
2
/// Validates an axis index against an array of rank `max_dim` and normalises it.
///
/// Negative values count from the end (`-1` is the last axis), so the accepted range is
/// `-(max_dim)..max_dim`. The returned value is the equivalent non-negative axis.
///
/// # Panics
/// Panics if `dim` lies outside that range.
pub fn check_dim_is_legal(dim: i32, max_dim: usize) -> usize {
    let rank = max_dim as i32;
    let in_range = if dim < 0 { dim >= -rank } else { dim < rank };
    assert!(in_range, "check your dim {} should < max {}", dim, max_dim);

    let normalised = if dim < 0 { rank + dim } else { dim };
    normalised as usize
}
15
/// Fills `tgt_data` with a reduction of `src_data` along axis `dim_op_on`.
///
/// The function recurses over axes `pos..max_dim`, accumulating flat offsets into the source
/// (`src_idx`, advanced by `src_base_sizes[pos]`) and the target (`tgt_idx`, advanced by
/// `tgt_base_sizes[pos]`). The reduced axis is not iterated during the walk. Instead, once the
/// innermost axis is reached, the values along `dim_op_on` are collected (stride
/// `src_base_sizes[dim_op_on]`, count `shapes[dim_op_on]`). `gather` then turns each such lane
/// into the single value written to the corresponding target slot.
///
/// * `pos` – axis currently being expanded; callers start at `0`.
/// * `max_dim` – number of axes of the source array; must be at least 1.
/// * `shapes` – source shape, one length per axis.
///
/// NOTE(review): the innermost axis is advanced with `+ offset` in both buffers. This assumes
/// its stride is 1 in `src_base_sizes` and `tgt_base_sizes`; confirm against
/// `NdArray::index_base_sizes`.
fn collect_by_recursive_then_gather_to<F>(
    pos: usize,
    max_dim: usize,
    dim_op_on: usize,
    src_idx: usize,
    tgt_idx: usize,
    src_base_sizes: &[usize],
    tgt_base_sizes: &[usize],
    shapes: &[usize],
    src_data: &[f32],
    tgt_data: &mut [f32],
    gather: &F,
) where
    F: Fn(Vec<&f32>) -> f32,
{
    if pos == max_dim - 1 {
        // If the innermost axis is the reduced one, only one target slot exists at this level.
        let last_size = if dim_op_on == pos { 1 } else { shapes[pos] };
        // Loop-invariant description of one lane along the reduced axis.
        let lane_len = shapes[dim_op_on];
        let lane_stride = src_base_sizes[dim_op_on];
        tgt_data
            .iter_mut()
            .skip(tgt_idx)
            .take(last_size)
            .enumerate()
            .for_each(|(offset, dst)| {
                let lane: Vec<&f32> = src_data
                    .iter()
                    .skip(src_idx + offset)
                    .step_by(lane_stride)
                    .take(lane_len)
                    .collect();
                *dst = gather(lane);
            });
    } else if pos == dim_op_on {
        // The reduced axis is not expanded here; its values are gathered at the innermost level.
        collect_by_recursive_then_gather_to(
            pos + 1,
            max_dim,
            dim_op_on,
            src_idx,
            tgt_idx,
            src_base_sizes,
            tgt_base_sizes,
            shapes,
            src_data,
            tgt_data,
            gather,
        );
    } else {
        for i in 0..shapes[pos] {
            collect_by_recursive_then_gather_to(
                pos + 1,
                max_dim,
                dim_op_on,
                src_idx + i * src_base_sizes[pos],
                tgt_idx + i * tgt_base_sizes[pos],
                src_base_sizes,
                tgt_base_sizes,
                shapes,
                src_data,
                tgt_data,
                gather,
            );
        }
    }
}
42
43fn gather_by_a_specific_dim_and_do(x: &NdArray, dim: i32, gather: &dyn Fn(Vec<&f32>) -> f32) -> NdArray {
48 let dim = check_dim_is_legal(dim, x.dim());
51 let tgt_shape = x.shape.iter().enumerate().fold(vec![], |mut s, (i, item)| {
52 if i == dim {
53 s.push(1);
54 s
55 } else {
56 s.push(*item);
57 s
58 }
59 });
60 let mut target = NdArray::new(tgt_shape);
61
62 collect_by_recursive_then_gather_to(0, x.dim(), dim, 0, 0, &NdArray::index_base_sizes(&x.shape), &NdArray::index_base_sizes(&target.shape), &x.shape, &x.data, &mut target.data, &gather);
63
64 target.squeeze(dim as i32);
65
66 target
67}
68
69pub fn softmax(x: &mut NdArray, dim: i32) {
71 let dim = check_dim_is_legal(dim, x.dim());
72
73 let index_base_sizes = NdArray::index_base_sizes(&x.shape);
74
75 fn retrieval_by_recursive(pos: usize, max_dim: usize, softmax_dim: usize, idx: usize, index_base_sizes: &Vec<usize>, shapes: &Vec<usize>, data: &mut Vec<f32>) {
77 if pos == max_dim - 1 {
78 let last_size = if softmax_dim == pos {
79 1
80 } else {
81 shapes[pos]
82 };
83 (0..last_size).for_each(|i| {
84 let idxs: Vec<usize> = (0..shapes[softmax_dim]).map(|sd_i| idx + i + index_base_sizes[softmax_dim] * sd_i).collect();
85 let mut softmax_data: Vec<f32> = idxs.iter().map(|ii| data[*ii]).collect();
86
87 let constant = softmax_data.iter().cloned().reduce(f32::max).unwrap();
88 softmax_data.iter_mut().for_each(|x| {
89 *x = (*x - constant).exp();
90 });
91 let sum: f32 = softmax_data.iter().sum();
92 softmax_data.iter_mut().for_each(|x| *x /= sum);
93 idxs.into_iter().zip(softmax_data.into_iter()).for_each(|(sd_i, d)| {
94 data[sd_i] = d;
95 })
96 });
97 } else if pos == softmax_dim {
98 retrieval_by_recursive(pos + 1, max_dim, softmax_dim, idx, index_base_sizes, shapes, data);
99 } else {
100 for i in 0..shapes[pos] {
101 retrieval_by_recursive(pos + 1, max_dim, softmax_dim, idx + i * index_base_sizes[pos], index_base_sizes, shapes, data);
102 }
103 }
104 }
105
106 retrieval_by_recursive(0, x.dim(), dim, 0, &index_base_sizes, &x.shape, &mut x.data);
108}
109
110pub fn sum_ndarray(x: &NdArray, dim: i32) -> NdArray {
112 fn sum_value(src_data: Vec<&f32>) -> f32 {
113 src_data.iter().fold(0.0, |s, i| s + **i)
114 }
115 gather_by_a_specific_dim_and_do(x, dim, &sum_value)
116}
117
118pub fn argmax(x: &NdArray, dim: i32) -> NdArray {
120 fn get_arg_by_max(src_data: Vec<&f32>) -> f32 {
123 src_data.iter().enumerate().fold((0.0, f32::MIN), |s, i| {
124 if **i.1 > s.1 {
125 (i.0 as f32, **i.1)
126 } else {
127 s
128 }
129 }).0
130 }
131 gather_by_a_specific_dim_and_do(x, dim, &get_arg_by_max)
132}
133
134pub fn mean(x: &NdArray, dim: i32) -> NdArray {
136 fn avg_value(src_data: Vec<&f32>) -> f32 {
137 src_data.iter().fold(0.0, |s, i| s + **i) / src_data.len() as f32
138 }
139 gather_by_a_specific_dim_and_do(x, dim, &avg_value)
140}
141
142pub fn min(x: &NdArray, dim: i32) -> NdArray {
144 fn min_value(src_data: Vec<&f32>) -> f32 {
145 src_data.iter().fold(f32::MAX, |s, i| s.min(**i))
146 }
147 gather_by_a_specific_dim_and_do(x, dim, &min_value)
148}
149
150pub fn max(x: &NdArray, dim: i32) -> NdArray {
152 fn max_value(src_data: Vec<&f32>) -> f32 {
153 src_data.iter().fold(f32::MIN, |s, i| s.max(**i))
154 }
155 gather_by_a_specific_dim_and_do(x, dim, &max_value)
156}
157
158pub fn std(x: &NdArray, dim: i32, unbiased: bool) -> NdArray {
163 fn std_value(src_data: Vec<&f32>) -> f32 {
164 let mean = src_data.iter().fold(0.0, |s, i| s + **i) / src_data.len() as f32;
165 (src_data.iter().fold(0.0, |s, i| s + (**i - mean).powf(2.0)) / f32::max((src_data.len() - 1) as f32, 1e-6)).sqrt()
166 }
167 fn std_value_biased(src_data: Vec<&f32>) -> f32 {
168 let mean = src_data.iter().fold(0.0, |s, i| s + **i) / src_data.len() as f32;
169 (src_data.iter().fold(0.0, |s, i| s + (**i - mean).powf(2.0)) / src_data.len() as f32).sqrt()
170 }
171 if unbiased {
172 gather_by_a_specific_dim_and_do(x, dim, &std_value)
173 } else {
174 gather_by_a_specific_dim_and_do(x, dim, &std_value_biased)
175 }
176
177}
178
179
180pub fn sigmoid(x: &NdArray) -> NdArray {
184 let mut out = x.clone();
185 out.data_as_mut_vector().iter_mut().for_each(|i| {
186 if *i < 0.0 {
188 let t = i.exp();
190 *i = t / (1.0 + t);
191 } else {
192 *i = 1.0 / (1.0 + (-*i).exp());
194 }
195 });
196 out
197}
198
199pub fn tanh(x: &NdArray) -> NdArray {
203 let mut out = x * 2.0;
204 out.data_as_mut_vector().iter_mut().for_each(|i| {
205 if *i < 0.0 {
206 let t = i.exp();
207 *i = t / (1.0 + t);
208 } else {
209 *i = 1.0 / (1.0 + (-*i).exp());
210 } *i = 2.0 * *i - 1.0;
213 });
214 out
215}
216
217pub fn relu(x: &NdArray) -> NdArray {
221 let mut out = x.clone();
222 out.data_as_mut_vector().iter_mut().for_each(|i| *i = f32::max(*i, 0.0));
223 out
224}
225
#[cfg(test)]
mod test {
    use super::{softmax, sum_ndarray, argmax, relu, sigmoid, tanh, mean, std, min, max};
    use super::NdArray;

    #[test]
    fn test_activation_functions() {
        let x = NdArray::new(vec![1.0, 232.0, -1.0, -22.0, 0.0]);
        // ReLU is exact in f32, so its output can be compared directly.
        assert_eq!(relu(&x), NdArray::new(vec![1.0, 232.0, 0.0, 0.0, 0.0]));
        println!("{}", relu(&x));
        println!("{}", sigmoid(&x));
        println!("{}", tanh(&x));
    }

    #[test]
    fn test_argmax() {
        let mut x = NdArray::new(vec![1.0, -123.0, 5.8, 2.3, 11.3, 5.0]);
        x.reshape(vec![2, 3]);
        let t = argmax(&x, -1);
        println!("argmax {x}");
        let tt = NdArray::new(vec![2.0, 1.0]);
        assert!(tt == t);
        println!("{t}");
    }

    #[test]
    fn test_sum_ndarray() {
        let x = NdArray::new(vec![vec![1.0; 3]; 2]);
        let t = sum_ndarray(&x, 0);
        assert_eq!(NdArray::new(vec![2.0, 2.0, 2.0]), t);
        println!("{t}");

        let mut x = NdArray::new((0..12).map(|i| i as f32).collect::<Vec<f32>>());
        x.reshape(vec![2, 3, 2]);
        let t = sum_ndarray(&x, 1);
        assert!(t.shape == vec![2, 2]);
        // Middle-axis sums of 0..12 laid out as [2, 3, 2].
        assert_eq!(t, NdArray::new(vec![vec![6.0, 9.0], vec![24.0, 27.0]]));
        println!("{t}");
    }

    #[test]
    fn test_softmax() {
        let mut x = NdArray::new(vec![vec![1.0; 3]; 2]);
        softmax(&mut x, -1);
        println!("{x}");

        softmax(&mut x, -2);
        println!("{x}");

        let mut x = NdArray::new(vec![vec![1.1, -3.7, 341.23, 46.6], vec![3.23, 6.2, 0.4, -2.87]]);
        softmax(&mut x, -1);
        let xx = NdArray::new(vec![
            vec![0.0, 0.0, 1.0, 0.0],
            vec![0.048654296, 0.94836545, 0.002871229, 0.000109125154]
        ]);
        println!("{x}");
        assert_eq!(xx, x);
    }

    #[test]
    fn test_mean_std_min_max() {
        let a = NdArray::random(vec![2, 3], None);
        println!("{a}\nmean:{}\nstd:{}", mean(&a, 1), std(&a, 1, true));

        println!("min{}\nmax{}", min(&a, 1), max(&a, 1));

        // Deterministic input whose reductions are exactly representable in f32.
        let b = NdArray::new(vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]);
        assert_eq!(mean(&b, 1), NdArray::new(vec![2.0, 5.0]));
        assert_eq!(std(&b, 1, true), NdArray::new(vec![1.0, 1.0]));
        assert_eq!(min(&b, 1), NdArray::new(vec![1.0, 4.0]));
        assert_eq!(max(&b, -1), NdArray::new(vec![3.0, 6.0]));
        assert_eq!(min(&b, 0), NdArray::new(vec![1.0, 2.0, 3.0]));
    }
}