dendritic_ndarray/ops/
unary.rs

1use crate::ndarray::NDArray;
2
3
4pub trait UnaryOps {
5    fn transpose(self) -> Result<NDArray<f64>, String>;
6    fn permute(self, indice_order: Vec<usize>) -> Result<NDArray<f64>, String>; 
7    fn norm(&self, p: usize) -> Result<NDArray<f64>, String>;
8    fn signum(&self) -> Result<NDArray<f64>, String>;
9    fn sum_axis(&self, axis: usize) -> Result<NDArray<f64>, String>;
10    fn select_axis(&self, axis: usize, indices: Vec<usize>) -> Result<NDArray<f64>, String>;
11    fn apply(&self, loss_func: fn(value: f64) -> f64) -> Result<NDArray<f64>, String>;
12    fn argmax(&self, axis: usize) -> NDArray<f64>;
13    fn argmin(&self, axis: usize) -> Result<NDArray<f64>, String>;
14    fn nonzero(&self) -> NDArray<f64>;
15}
16
17
18impl UnaryOps for NDArray<f64> {
19
20    /// Tranpose current NDArray instance, works only on rank 2 values
21    fn transpose(self) -> Result<NDArray<f64>, String> {
22
23        if self.rank() != 2 {
24            return Err("Transpose must contain on rank 2 values".to_string());
25        }
26
27        let mut index = 0;
28        let mut result = NDArray::new(self.shape().reverse()).unwrap();
29
30        for _item in self.values() {
31
32            let indices = self.indices(index).unwrap();
33            let mut reversed_indices = indices.clone();
34            reversed_indices.reverse();
35
36            let idx = self.index(indices).unwrap();
37            let val = self.values()[idx]; 
38
39            /* set value from reversed */ 
40            let _ = result.set(reversed_indices ,val);
41            index += 1; 
42        }
43
44        Ok(result)
45
46    }
47
48    /// Permute indices of NDArray. Can be used to peform transposes/contraction on rank 3 or higher values.
49    fn permute(self, indice_order: Vec<usize>) -> Result<NDArray<f64>, String> {
50
51        if indice_order.len() != self.rank() {
52            return Err("Indice order must be same length as rank".to_string());
53        }
54
55        let mut index = 0;
56        let permuted_shape = self.shape().permute(indice_order.clone());
57        let mut result = NDArray::new(permuted_shape).unwrap();
58        for _item in self.values() {
59
60            let indices = self.indices(index).unwrap();
61            let mut new_indice_order = Vec::new();
62            for item in &indice_order {
63                new_indice_order.push(indices[*item])
64            }
65
66            let idx = self.index(indices.clone()).unwrap();
67            let val = self.values()[idx]; 
68
69            /* set value from reversed */ 
70            let _ = result.set(new_indice_order ,val);
71            index += 1; 
72        }
73
74        Ok(result)
75    }
76
77
78    /// L2 norm can also be  x^t x
79    fn norm(&self, p: usize) -> Result<NDArray<f64>, String> {
80
81        let mut result = NDArray::new(self.shape().values()).unwrap();
82        for index in 0..self.size() {
83            let value = self.values()[index]; 
84            let raised = value.powf(p as f64); 
85            let _ = result.set_idx(index, raised);
86        }
87        Ok(result)
88    }
89
90    
91    /// Adds values based on x < 0 < 1
92    fn signum(&self) -> Result<NDArray<f64>, String> {
93
94        let mut result = NDArray::new(self.shape().values()).unwrap();
95        for index in 0..self.size() {
96            let value = self.values()[index]; 
97            if value < 0.0 {
98                let _ = result.set_idx(index, -1.0);
99            } else if value > 0.0 {
100                let _ = result.set_idx(index, 1.0);
101            } else { 
102                let _ = result.set_idx(index, 0.0);
103            }
104        }
105
106        Ok(result)
107    }
108
109
110    /// Sum values along a specified axis
111    fn sum_axis(&self, axis: usize) -> Result<NDArray<f64>, String> {
112
113        if axis > self.rank()-1 {
114            return Err("Sum Axis: Axis greater than rank".to_string())
115        }
116
117        if self.rank() > 2 {
118            return Err("Sum Axis: Not supported for rank 2 or higher values yet".to_string());
119        }
120
121        if axis == 0 {
122            let mut result = NDArray::new(vec![1,1]).unwrap();
123            let sum: f64 = self.values().iter().sum();
124            let _ = result.set_idx(0, sum);
125            return Ok(result);
126        }
127
128
129        let sum_stride = self.size() / self.shape().dim(axis);
130        let axis_stride = self.shape().dim(axis.clone());
131        let result_shape: Vec<usize> = vec![axis, axis_stride];
132        let mut result = NDArray::new(result_shape.clone()).unwrap();
133
134        let mut idx = 0; 
135        let mut sum: f64 = 0.0; 
136        let mut stride_counter = 0; 
137        for item in self.values() {
138
139            if stride_counter == sum_stride {
140                let _ = result.set_idx(idx, sum);  
141                stride_counter = 0;
142                sum = 0.0;
143                idx += 1;  
144            }
145
146            sum += item;
147            stride_counter += 1;
148        }
149
150        let _ = result.set_idx(idx, sum); 
151        Ok(result)
152    }
153
154
155    /// Select specific indices from an axis
156    fn select_axis(&self, axis: usize, indices: Vec<usize>) -> Result<NDArray<f64>, String> {
157 
158        if axis > self.rank() - 1 { 
159            return Err(
160                "Axis Indices: Selected axis larger than rank".to_string()
161            );
162        }
163
164        if self.rank() > 2 {
165            return Err(
166                "Select Axis: Only works on rank 2 values and lower".to_string()
167            );
168
169        }
170
171        let mut curr_shape = self.shape().values(); 
172        curr_shape[axis] = indices.len();
173
174        let mut result: NDArray<f64> = NDArray::new(
175            curr_shape.clone()
176        ).unwrap();
177
178        for (index, indice) in indices.iter().enumerate() {
179            let axis_vals = self.axis(axis, *indice).unwrap();
180            for (idx, val) in axis_vals.values().iter().enumerate() {
181                let remainder_idx = self.rank() - 1 - axis;
182                let mut indices: Vec<usize> = vec![0; self.rank()];
183                indices[axis] = index; 
184                indices[remainder_idx] = idx; 
185                result.set(indices, *val).unwrap(); 
186            }
187        }
188
189        Ok(result)
190    }
191
192
193    /// Apply loss function on values in ndarray
194    fn apply(&self, loss_func: fn(value: f64) -> f64) -> Result<NDArray<f64>, String> {   
195        let mut index = 0; 
196        let mut result = NDArray::new(self.shape().values()).unwrap(); 
197        for x in self.values() {
198            let loss_val = loss_func(*x); 
199            let _ = result.set_idx(index, loss_val);
200            index += 1;  
201        }
202        Ok(result)
203    }
204
205    /// Get's the maximum values index along a specified axis
206    fn argmax(&self, axis: usize) -> NDArray<f64> {
207
208        // this only works for a row (for now)
209        let mut results = Vec::new();
210        let shape = self.shape().dim(axis);
211        for idx in 0..shape {
212            let axis_value = self.axis(axis, idx).unwrap();
213
214            let mut curr_max = 0.0; 
215            let mut index = 0; 
216            let mut final_index = 0;
217            for item in axis_value.values() {
218                if item > &curr_max {
219                    curr_max = *item;
220                    final_index = index; 
221                }
222                index += 1;
223            }
224
225            results.push(final_index as f64);
226        }
227
228        let result = NDArray::array(
229            vec![shape, 1],
230            results
231        ).unwrap();
232        result
233
234    }
235
236
237    /// Get's the minimum values index along a specified axis
238    fn argmin(&self, axis: usize) -> Result<NDArray<f64>, String> {
239
240        if axis > self.rank() - 1 {
241            let msg = "Argmin: Selected axis larger than rank";
242            return Err(msg.to_string());
243        }
244
245        // this only works for a row (for now)
246        let mut results: Vec<f64> = Vec::new();
247        let shape = self.shape().dim(axis);
248        for idx in 0..shape {
249            let axis_value = self.axis(axis, idx).unwrap();
250            let mut min_value = f64::INFINITY;
251            let mut min_idx = 0;
252            for (idx, val) in axis_value.values().iter().enumerate() {
253                if *val < min_value {
254                    min_value = *val; 
255                    min_idx = idx;
256                }
257            }
258            results.push(min_idx as f64);
259        }
260
261        Ok(NDArray::array(
262            vec![shape, 1],
263            results
264        ).unwrap())
265
266    }
267
268    /// Retrieve all non zero elements in an ndarray
269    fn nonzero(&self) -> NDArray<f64> {
270        let mut vals: Vec<f64> = Vec::new();
271        for item in self.values() {
272            if *item != 0.0 {
273                vals.push(*item);
274            }
275        }
276
277        NDArray::array(vec![vals.len(), 1], vals).unwrap()
278    }
279
280}