// dendritic_ndarray/ops/unary.rs
use std::fmt::Debug;

use crate::ndarray::NDArray;

4pub trait UnaryOps {
5 fn transpose(self) -> Result<NDArray<f64>, String>;
6 fn permute(self, indice_order: Vec<usize>) -> Result<NDArray<f64>, String>;
7 fn norm(&self, p: usize) -> Result<NDArray<f64>, String>;
8 fn signum(&self) -> Result<NDArray<f64>, String>;
9 fn sum_axis(&self, axis: usize) -> Result<NDArray<f64>, String>;
10 fn select_axis(&self, axis: usize, indices: Vec<usize>) -> Result<NDArray<f64>, String>;
11 fn apply(&self, loss_func: fn(value: f64) -> f64) -> Result<NDArray<f64>, String>;
12 fn argmax(&self, axis: usize) -> NDArray<f64>;
13 fn argmin(&self, axis: usize) -> Result<NDArray<f64>, String>;
14 fn nonzero(&self) -> NDArray<f64>;
15}
16
17
18impl UnaryOps for NDArray<f64> {
19
20 fn transpose(self) -> Result<NDArray<f64>, String> {
22
23 if self.rank() != 2 {
24 return Err("Transpose must contain on rank 2 values".to_string());
25 }
26
27 let mut index = 0;
28 let mut result = NDArray::new(self.shape().reverse()).unwrap();
29
30 for _item in self.values() {
31
32 let indices = self.indices(index).unwrap();
33 let mut reversed_indices = indices.clone();
34 reversed_indices.reverse();
35
36 let idx = self.index(indices).unwrap();
37 let val = self.values()[idx];
38
39 let _ = result.set(reversed_indices ,val);
41 index += 1;
42 }
43
44 Ok(result)
45
46 }
47
48 fn permute(self, indice_order: Vec<usize>) -> Result<NDArray<f64>, String> {
50
51 if indice_order.len() != self.rank() {
52 return Err("Indice order must be same length as rank".to_string());
53 }
54
55 let mut index = 0;
56 let permuted_shape = self.shape().permute(indice_order.clone());
57 let mut result = NDArray::new(permuted_shape).unwrap();
58 for _item in self.values() {
59
60 let indices = self.indices(index).unwrap();
61 let mut new_indice_order = Vec::new();
62 for item in &indice_order {
63 new_indice_order.push(indices[*item])
64 }
65
66 let idx = self.index(indices.clone()).unwrap();
67 let val = self.values()[idx];
68
69 let _ = result.set(new_indice_order ,val);
71 index += 1;
72 }
73
74 Ok(result)
75 }
76
77
78 fn norm(&self, p: usize) -> Result<NDArray<f64>, String> {
80
81 let mut result = NDArray::new(self.shape().values()).unwrap();
82 for index in 0..self.size() {
83 let value = self.values()[index];
84 let raised = value.powf(p as f64);
85 let _ = result.set_idx(index, raised);
86 }
87 Ok(result)
88 }
89
90
91 fn signum(&self) -> Result<NDArray<f64>, String> {
93
94 let mut result = NDArray::new(self.shape().values()).unwrap();
95 for index in 0..self.size() {
96 let value = self.values()[index];
97 if value < 0.0 {
98 let _ = result.set_idx(index, -1.0);
99 } else if value > 0.0 {
100 let _ = result.set_idx(index, 1.0);
101 } else {
102 let _ = result.set_idx(index, 0.0);
103 }
104 }
105
106 Ok(result)
107 }
108
109
110 fn sum_axis(&self, axis: usize) -> Result<NDArray<f64>, String> {
112
113 if axis > self.rank()-1 {
114 return Err("Sum Axis: Axis greater than rank".to_string())
115 }
116
117 if self.rank() > 2 {
118 return Err("Sum Axis: Not supported for rank 2 or higher values yet".to_string());
119 }
120
121 if axis == 0 {
122 let mut result = NDArray::new(vec![1,1]).unwrap();
123 let sum: f64 = self.values().iter().sum();
124 let _ = result.set_idx(0, sum);
125 return Ok(result);
126 }
127
128
129 let sum_stride = self.size() / self.shape().dim(axis);
130 let axis_stride = self.shape().dim(axis.clone());
131 let result_shape: Vec<usize> = vec![axis, axis_stride];
132 let mut result = NDArray::new(result_shape.clone()).unwrap();
133
134 let mut idx = 0;
135 let mut sum: f64 = 0.0;
136 let mut stride_counter = 0;
137 for item in self.values() {
138
139 if stride_counter == sum_stride {
140 let _ = result.set_idx(idx, sum);
141 stride_counter = 0;
142 sum = 0.0;
143 idx += 1;
144 }
145
146 sum += item;
147 stride_counter += 1;
148 }
149
150 let _ = result.set_idx(idx, sum);
151 Ok(result)
152 }
153
154
155 fn select_axis(&self, axis: usize, indices: Vec<usize>) -> Result<NDArray<f64>, String> {
157
158 if axis > self.rank() - 1 {
159 return Err(
160 "Axis Indices: Selected axis larger than rank".to_string()
161 );
162 }
163
164 if self.rank() > 2 {
165 return Err(
166 "Select Axis: Only works on rank 2 values and lower".to_string()
167 );
168
169 }
170
171 let mut curr_shape = self.shape().values();
172 curr_shape[axis] = indices.len();
173
174 let mut result: NDArray<f64> = NDArray::new(
175 curr_shape.clone()
176 ).unwrap();
177
178 for (index, indice) in indices.iter().enumerate() {
179 let axis_vals = self.axis(axis, *indice).unwrap();
180 for (idx, val) in axis_vals.values().iter().enumerate() {
181 let remainder_idx = self.rank() - 1 - axis;
182 let mut indices: Vec<usize> = vec![0; self.rank()];
183 indices[axis] = index;
184 indices[remainder_idx] = idx;
185 result.set(indices, *val).unwrap();
186 }
187 }
188
189 Ok(result)
190 }
191
192
193 fn apply(&self, loss_func: fn(value: f64) -> f64) -> Result<NDArray<f64>, String> {
195 let mut index = 0;
196 let mut result = NDArray::new(self.shape().values()).unwrap();
197 for x in self.values() {
198 let loss_val = loss_func(*x);
199 let _ = result.set_idx(index, loss_val);
200 index += 1;
201 }
202 Ok(result)
203 }
204
205 fn argmax(&self, axis: usize) -> NDArray<f64> {
207
208 let mut results = Vec::new();
210 let shape = self.shape().dim(axis);
211 for idx in 0..shape {
212 let axis_value = self.axis(axis, idx).unwrap();
213
214 let mut curr_max = 0.0;
215 let mut index = 0;
216 let mut final_index = 0;
217 for item in axis_value.values() {
218 if item > &curr_max {
219 curr_max = *item;
220 final_index = index;
221 }
222 index += 1;
223 }
224
225 results.push(final_index as f64);
226 }
227
228 let result = NDArray::array(
229 vec![shape, 1],
230 results
231 ).unwrap();
232 result
233
234 }
235
236
237 fn argmin(&self, axis: usize) -> Result<NDArray<f64>, String> {
239
240 if axis > self.rank() - 1 {
241 let msg = "Argmin: Selected axis larger than rank";
242 return Err(msg.to_string());
243 }
244
245 let mut results: Vec<f64> = Vec::new();
247 let shape = self.shape().dim(axis);
248 for idx in 0..shape {
249 let axis_value = self.axis(axis, idx).unwrap();
250 let mut min_value = f64::INFINITY;
251 let mut min_idx = 0;
252 for (idx, val) in axis_value.values().iter().enumerate() {
253 if *val < min_value {
254 min_value = *val;
255 min_idx = idx;
256 }
257 }
258 results.push(min_idx as f64);
259 }
260
261 Ok(NDArray::array(
262 vec![shape, 1],
263 results
264 ).unwrap())
265
266 }
267
268 fn nonzero(&self) -> NDArray<f64> {
270 let mut vals: Vec<f64> = Vec::new();
271 for item in self.values() {
272 if *item != 0.0 {
273 vals.push(*item);
274 }
275 }
276
277 NDArray::array(vec![vals.len(), 1], vals).unwrap()
278 }
279
280}