dendritic_ndarray/
ndarray.rs1use serde::{Serialize, Deserialize};
2use crate::shape::*;
3
4#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
5pub struct NDArray<T> {
6 pub shape: Shape,
7 pub size: usize,
8 pub rank: usize,
9 pub values: Vec<T>
10}
11
12
13impl<T: Default + Clone + std::fmt::Debug + std::cmp::PartialEq> NDArray<T> {
14
15 pub fn rank(&self) -> usize {
17 self.rank
18 }
19
20 pub fn shape(&self) -> &Shape {
22 &self.shape
23 }
24
25 pub fn values(&self) -> &Vec<T> {
27 &self.values
28 }
29
30 pub fn size(&self) -> usize {
32 self.size
33 }
34
35 pub fn get(&self, indices: Vec<usize>) -> &T {
37 &self.values[self.index(indices).unwrap()]
38 }
39
40 pub fn idx(&self, index: usize) -> &T {
42 &self.values[index]
43 }
44
45 pub fn set_rank(&mut self, new_rank: usize) {
47 self.rank = new_rank;
48 }
49
50 pub fn new(shape: Vec<usize>) -> Result<NDArray<T>, String> {
52
53 let calculated_rank = shape.len();
54 let mut calculated_size = 1;
55 for item in &shape {
56 calculated_size *= item;
57 }
58
59 Ok(Self {
60 shape: Shape::new(shape),
61 size: calculated_size,
62 rank: calculated_rank,
63 values: vec![T::default(); calculated_size],
64 })
65 }
66
67
68 pub fn array(shape: Vec<usize>, values: Vec<T>) -> Result<NDArray<T>, String> {
70
71 let calculated_rank = shape.len();
72 let mut calculated_size = 1;
73 for item in &shape {
74 calculated_size *= item;
75 }
76
77 if values.len() != calculated_size {
78 return Err("Values don't match size based on dimensions".to_string())
79 }
80
81 Ok(Self {
82 shape: Shape::new(shape),
83 size: calculated_size,
84 rank: calculated_rank,
85 values: values,
86 })
87 }
88
89 pub fn fill(shape: Vec<usize>, value: T) -> Result<NDArray<T>, String> {
91 let calculated_rank = shape.len();
92 let mut calculated_size = 1;
93 for item in &shape {
94 calculated_size *= item;
95 }
96
97 let mut values = Vec::new();
98 for _item in 0..calculated_size {
99 values.push(value.clone());
100 }
101
102
103 Ok(Self {
104 shape: Shape::new(shape),
105 size: calculated_size,
106 rank: calculated_rank,
107 values: values,
108 })
109 }
110
111 pub fn reshape(&mut self, shape_vals: Vec<usize>) -> Result<(), String> {
113
114 if shape_vals.len() != self.rank {
115 return Err("New Shape values don't match rank of array".to_string());
116 }
117
118 let mut size_validate = 1;
119 for item in &shape_vals {
120 size_validate *= item;
121 }
122
123 if size_validate != self.size {
124 return Err("New Shape values don't match size of array".to_string());
125 }
126
127 self.shape = Shape::new(shape_vals);
128 Ok(())
129 }
130
131 pub fn index(&self, indices: Vec<usize>) -> Result<usize, String> {
133
134 if indices.len() != self.rank {
135 return Err("Indexing doesn't match rank of ndarray".to_string());
136 }
137
138 let mut stride = 1;
139 let mut index = 0;
140 let mut counter = self.rank;
141 for _n in 0..self.rank {
142 let temp = stride * indices[counter-1];
143 let curr_shape = self.shape.dim(counter-1);
144 stride *= curr_shape;
145 index += temp;
146 counter -= 1;
147 }
148
149 if index > self.size-1 {
150 return Err("Index out of bounds".to_string());
151 }
152
153 Ok(index)
154 }
155
156 pub fn indices(&self, index: usize) -> Result<Vec<usize>, String> {
158
159 if index > self.size-1 {
160 return Err("Index out of bounds".to_string());
161 }
162
163 let mut indexs = vec![0; self.rank];
164 let mut count = self.rank-1;
165 let mut curr_index = index;
166 for _n in 0..self.rank-1 {
167 let dim_size = self.shape.dim(count);
168 indexs[count] = curr_index % dim_size;
169 curr_index /= dim_size;
170 count -= 1;
171 }
172
173 indexs[0] = curr_index;
174 Ok(indexs)
175 }
176
177 pub fn set_idx(&mut self, idx: usize, value: T) -> Result<(), String> {
179
180 if idx > self.size {
181 return Err("Index out of bounds".to_string());
182 }
183
184 self.values[idx] = value;
185 Ok(())
186 }
187
188 pub fn set(&mut self, indices: Vec<usize>, value: T) -> Result<(), String> {
190
191 if indices.len() != self.rank {
192 return Err("Indices length don't match rank of ndarray".to_string());
193 }
194
195 let index = self.index(indices).unwrap();
196 self.values[index] = value;
197 Ok(())
198 }
199
200
201 pub fn rows(&self, index: usize) -> Result<Vec<T>, String> {
203
204 let dim_shape = self.shape.dim(0);
205 let result_length = self.size() / dim_shape;
206 let values = self.values();
207 let mut start_index = index * result_length;
208 let mut result = Vec::new();
209
210 for _i in 0..result_length {
211 let value = &values[start_index];
212 result.push(value.clone());
213 start_index += 1;
214 }
215
216 Ok(result)
217
218 }
219
220 pub fn cols(&self, index: usize) -> Result<Vec<T>, String> {
222
223 let mut result = Vec::new();
224 let dim_shape = self.shape.dim(1);
225 let values = self.values();
226 let result_length = self.size() / dim_shape;
227 let stride = dim_shape;
228 let mut start = index;
229
230 for _i in 0..result_length {
231 let value = &values[start];
232 result.push(value.clone());
233 start += stride;
234 }
235
236 Ok(result)
237 }
238
239 pub fn axis(&self, axis: usize, index: usize) -> Result<NDArray<T>, String> {
241
242 if axis > self.rank() - 1 {
243 return Err("Axis: Selected axis larger than rank".to_string());
244 }
245
246 if index > self.shape().dim(axis)-1 {
247 return Err("Axis: Index for value is too large".to_string());
248 }
249
250 let mut values: Vec<T> = Vec::new();
251 let mut new_shape = self.shape().clone();
252 new_shape.remove(axis);
253 let outer_size = new_shape.values().iter().product::<usize>();
254
255 for item in 0..outer_size {
256 let multi_index = new_shape.multi_index(item);
257 let mut full_index = multi_index.clone();
258 full_index.insert(axis, index);
259 let flat_index = self.index(full_index).unwrap();
260 let val = &self.values()[flat_index];
261 values.push(val.clone());
262 }
263
264 if new_shape.values().len() == 1 {
265 new_shape.push(1);
266 }
267
268 Ok(NDArray::array(new_shape.values(),values).unwrap())
269 }
270
271 pub fn axis_indices(&self, axis: usize, indices: Vec<usize>) -> Result<NDArray<T>, String> {
273
274 if axis > self.rank() - 1 {
275 return Err("Axis Indices: Selected axis larger than rank".to_string());
276 }
277
278 let mut feature_vec: Vec<T> = Vec::new();
279
280 for idx in &indices {
281 let axis_call = self.axis(axis, *idx).unwrap();
282 let mut axis_values = axis_call.values().clone();
283 feature_vec.append(&mut axis_values);
284 }
285
286 let mut shape = self.shape().values().clone();
287 shape[axis] = indices.len();
288
289 Ok(NDArray::array(shape, feature_vec).unwrap())
290
291 }
292
293
294 pub fn drop_axis(&self, axis: usize, index: usize) -> Result<NDArray<T>, String> {
296
297 if axis > self.rank() - 1 {
298 let msg = "Drop Axis: Selected axis larger than rank";
299 return Err(msg.to_string());
300 }
301
302 if index > self.shape().dim(axis) {
303 let msg = "Drop Axis: Selected indice too large for axis";
304 return Err(msg.to_string());
305 }
306
307 if self.rank() > 2 {
308 let msg = "Drop Axis: Only supported for rank 2 values";
309 return Err(msg.to_string());
310 }
311
312 let mut shape_vals = self.shape().values();
313 shape_vals[axis] -= 1;
314 let mut result: NDArray<T> = NDArray::new(shape_vals).unwrap();
315
316 let mut coords: Vec<usize> = vec![0, 0];
317 let coord_len = coords.len() - 1;
318 let axis_shape = self.shape().dim(axis);
319 for item in 0..axis_shape {
320 let value = self.axis(axis, item).unwrap();
321 if item != index {
322 for val in value.values() {
323 result.set(coords.clone(), val.clone()).unwrap();
324 coords[coord_len - axis] += 1;
325 }
326 coords[coord_len - axis] = 0;
327 coords[axis] += 1;
328 }
329 }
330
331 Ok(result)
332 }
333
334
335 pub fn batch(&self, batch_size: usize) -> Result<Vec<NDArray<T>>, String> {
337
338 if batch_size == 0 || batch_size >= self.size() {
339 return Err("Batch size out of bounds".to_string())
340 }
341
342 if self.rank() != 2 {
343 return Err("NDArray must be of rank 2".to_string())
344 }
345
346 let dim_size = batch_size * self.shape.dim(1);
347 let mut start_index = 0;
348 let mut end_index = start_index + dim_size;
349
350 let mut batches: Vec<NDArray<T>> = Vec::new();
351
352 for _item in 0..self.size() {
353
354 if end_index >= self.size()+1 {
355 break;
356 }
357
358 let temp_vec: Vec<T> = self.values()[start_index..end_index].to_vec();
359 let ndarray_batch: NDArray<T> = NDArray::array(
360 vec![batch_size, self.shape.dim(1)],
361 temp_vec.clone()
362 ).unwrap();
363
364 batches.push(ndarray_batch);
365 start_index += self.shape.dim(1);
366 end_index += self.shape.dim(1);
367
368 }
369
370 Ok(batches)
371 }
372
373
374 pub fn value_indices(&self, value: T) -> Vec<usize> {
375 self.values().iter()
376 .enumerate()
377 .filter_map(|(i, &ref x)| if *x == value { Some(i) } else { None })
378 .collect()
379 }
380
381
382 pub fn indice_query(&self, indices: Vec<usize>) -> Result<NDArray<T>, String> {
383
384 if indices.len() > self.size() {
385 let msg = "Indices length is greater than array size";
386 return Err(msg.to_string());
387 }
388
389 let mut values: Vec<T> = Vec::new();
390 for idx in &indices {
391
392 if *idx > self.size() {
393 let msg = "Specified index greater than array size";
394 return Err(msg.to_string());
395 }
396
397 let val = self.idx(*idx);
398 values.push(val.clone());
399 }
400
401 Ok(NDArray::array(vec![values.len(), 1], values).unwrap())
402 }
403
404 pub fn split(
405 &self,
406 axis: usize,
407 percentage: f64) -> Result<(NDArray<T>, NDArray<T>), String> {
408
409 if axis > self.shape().values().len() {
410 let msg = "AXIS greater than current NDArray shape";
411 return Err(msg.to_string());
412 }
413
414 let axis_shape = self.shape().dim(axis);
415 let split_dist = (percentage * axis_shape as f64).ceil();
416 let rem = (axis_shape as f64 - split_dist).ceil();
417
418 let mut x_shape: Vec<usize> = self.shape().values();
419 let mut y_shape: Vec<usize> = self.shape().values();
420 x_shape[axis] = split_dist as usize;
421 y_shape[axis] = rem as usize;
422
423 let mut x_vals: Vec<T> = Vec::new();
424 let mut y_vals: Vec<T> = Vec::new();
425
426 for axis_idx in 0..axis_shape {
427 let item = self.axis(axis, axis_idx).unwrap();
428 let mut x_item = item.values().clone();
429 if axis_idx < split_dist as usize {
430 x_vals.append(&mut x_item);
431 } else {
432 y_vals.append(&mut x_item);
433 }
434 }
435
436 let x: NDArray<T> = NDArray::array(x_shape, x_vals).unwrap();
437 let y: NDArray<T> = NDArray::array(y_shape, y_vals).unwrap();
438 Ok((x, y))
439 }
440
441
442}