// dendritic_ndarray/ndarray.rs

1use serde::{Serialize, Deserialize};
2use crate::shape::*;
3
4#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
5pub struct NDArray<T> {
6    pub shape: Shape,
7    pub size: usize,
8    pub rank: usize,
9    pub values: Vec<T>
10}
11
12
13impl<T: Default + Clone + std::fmt::Debug + std::cmp::PartialEq> NDArray<T> {
14
15    /// Gets the rank of the current array
16    pub fn rank(&self) -> usize {
17        self.rank
18    }
19
20    /// Returns the shape dimensions of the array
21    pub fn shape(&self) -> &Shape {
22        &self.shape
23    }
24
25    /// Get the generic values stored in the array
26    pub fn values(&self) -> &Vec<T> {
27        &self.values
28    }
29    
30    /// Get the current calculated size of the contigous array
31    pub fn size(&self) -> usize {
32        self.size
33    }
34
35    /// Get generic value from provided indices
36    pub fn get(&self, indices: Vec<usize>) -> &T {
37        &self.values[self.index(indices).unwrap()]
38    }
39
40    /// Get generic value from provided indices
41    pub fn idx(&self, index: usize) -> &T {
42        &self.values[index]
43    }
44
45    /// Lets you change the rank of the current ndarray
46    pub fn set_rank(&mut self, new_rank: usize) {
47        self.rank = new_rank;
48    }
49
50    /// Create instance of NDArray, provide shape dimensions as parameter
51    pub fn new(shape: Vec<usize>) -> Result<NDArray<T>, String> {
52
53        let calculated_rank = shape.len(); 
54        let mut calculated_size = 1; 
55        for item in &shape {
56            calculated_size *= item; 
57        }
58
59        Ok(Self {
60            shape: Shape::new(shape),
61            size: calculated_size,
62            rank: calculated_rank,
63            values: vec![T::default(); calculated_size],
64        })
65    }
66
67    
68    /// Create instance of NDArray, provide shape dimensions and array values as parameter
69    pub fn array(shape: Vec<usize>, values: Vec<T>) -> Result<NDArray<T>, String> {
70
71        let calculated_rank = shape.len(); 
72        let mut calculated_size = 1; 
73        for item in &shape {
74            calculated_size *= item; 
75        }
76
77        if values.len() != calculated_size {
78            return Err("Values don't match size based on dimensions".to_string()) 
79        }
80
81        Ok(Self {
82            shape: Shape::new(shape),
83            size: calculated_size,
84            rank: calculated_rank,
85            values: values,
86        })
87    }
88
89    /// Fill ndarray with values
90    pub fn fill(shape: Vec<usize>, value: T) -> Result<NDArray<T>, String> {
91        let calculated_rank = shape.len(); 
92        let mut calculated_size = 1; 
93        for item in &shape {
94            calculated_size *= item; 
95        }
96
97        let mut values = Vec::new(); 
98        for _item in 0..calculated_size {
99            values.push(value.clone());
100        }
101
102
103        Ok(Self {
104            shape: Shape::new(shape),
105            size: calculated_size,
106            rank: calculated_rank,
107            values: values,
108        })
109    }
110
111    /// Reshape dimensions of array to new shape. Shape must match current size
112    pub fn reshape(&mut self, shape_vals: Vec<usize>) -> Result<(), String> {
113
114        if shape_vals.len() != self.rank {
115            return Err("New Shape values don't match rank of array".to_string());
116        }
117
118        let mut size_validate = 1;
119        for item in &shape_vals {
120            size_validate *= item; 
121        }
122
123        if size_validate != self.size {
124            return Err("New Shape values don't match size of array".to_string());
125        }
126
127        self.shape = Shape::new(shape_vals);
128        Ok(())
129    }
130
131    /// Get contigous index of array using provided indices as parameter
132    pub fn index(&self, indices: Vec<usize>) -> Result<usize, String> {
133
134        if indices.len() != self.rank {
135            return Err("Indexing doesn't match rank of ndarray".to_string());
136        }
137
138        let mut stride = 1; 
139        let mut index = 0;
140        let mut counter = self.rank;  
141        for _n in 0..self.rank {
142            let temp = stride * indices[counter-1]; 
143            let curr_shape = self.shape.dim(counter-1);
144            stride *= curr_shape;
145            index += temp;  
146            counter -= 1; 
147        }
148
149        if index > self.size-1 {
150            return Err("Index out of bounds".to_string());
151        }
152
153        Ok(index)
154    }
155
156    /// Get indices from provided contigous index as parameter
157    pub fn indices(&self, index: usize) -> Result<Vec<usize>, String> {
158
159        if index > self.size-1 {
160            return Err("Index out of bounds".to_string());
161        }
162
163        let mut indexs = vec![0; self.rank]; 
164        let mut count = self.rank-1; 
165        let mut curr_index = index; 
166        for _n in 0..self.rank-1 {
167            let dim_size = self.shape.dim(count);
168            indexs[count] = curr_index % dim_size; 
169            curr_index /= dim_size; 
170            count -= 1;
171        }
172
173        indexs[0] = curr_index;
174        Ok(indexs)       
175    }
176
177    /// Set index and generic value, index must be within size of array
178    pub fn set_idx(&mut self, idx: usize, value: T) -> Result<(), String> {
179
180        if idx > self.size {
181            return Err("Index out of bounds".to_string());
182        }
183
184        self.values[idx] = value;
185        Ok(())
186    }
187
188    /// Set generic value using provided indices. Indices must match rank of array
189    pub fn set(&mut self, indices: Vec<usize>, value: T) -> Result<(), String> {
190
191        if indices.len() != self.rank {
192            return Err("Indices length don't match rank of ndarray".to_string());
193        }
194
195        let index = self.index(indices).unwrap();
196        self.values[index] = value;
197        Ok(())
198    }
199
200
201    /// Get rows dimension associated with multi dimensional array
202    pub fn rows(&self, index: usize) -> Result<Vec<T>, String> {
203
204        let dim_shape = self.shape.dim(0);
205        let result_length = self.size() / dim_shape;
206        let values = self.values();
207        let mut start_index = index * result_length;
208        let mut result = Vec::new();
209
210        for _i in 0..result_length {
211            let value = &values[start_index];
212            result.push(value.clone());
213            start_index += 1; 
214        }
215 
216        Ok(result)
217
218    }
219
220    /// Get column dimension associated with multi dimensional array
221    pub fn cols(&self, index: usize) -> Result<Vec<T>, String> {
222
223        let mut result = Vec::new();
224        let dim_shape = self.shape.dim(1);
225        let values = self.values();
226        let result_length = self.size() / dim_shape;
227        let stride = dim_shape;
228        let mut start = index; 
229
230        for _i in 0..result_length {
231            let value = &values[start];
232            result.push(value.clone());
233            start += stride; 
234        }
235 
236        Ok(result)
237    }
238
239    /// Get values from a specific axis/slice
240    pub fn axis(&self, axis: usize, index: usize) -> Result<NDArray<T>, String> {
241
242        if axis > self.rank() - 1 { 
243            return Err("Axis: Selected axis larger than rank".to_string());
244        }
245
246        if index > self.shape().dim(axis)-1 {
247            return Err("Axis: Index for value is too large".to_string()); 
248        }
249
250        let mut values: Vec<T> = Vec::new();
251        let mut new_shape = self.shape().clone();
252        new_shape.remove(axis);
253        let outer_size = new_shape.values().iter().product::<usize>();
254
255        for item in 0..outer_size {
256            let multi_index = new_shape.multi_index(item);
257            let mut full_index = multi_index.clone();
258            full_index.insert(axis, index); 
259            let flat_index = self.index(full_index).unwrap();
260            let val = &self.values()[flat_index];
261            values.push(val.clone());
262        }
263
264        if new_shape.values().len() == 1 {
265            new_shape.push(1);
266        }
267 
268        Ok(NDArray::array(new_shape.values(),values).unwrap()) 
269    }
270
271    /// Get mutiple axis values with provided indices
272    pub fn axis_indices(&self, axis: usize, indices: Vec<usize>) -> Result<NDArray<T>, String> {
273 
274        if axis > self.rank() - 1 { 
275            return Err("Axis Indices: Selected axis larger than rank".to_string());
276        }
277
278        let mut feature_vec: Vec<T> = Vec::new();
279
280        for idx in &indices {
281            let axis_call = self.axis(axis, *idx).unwrap();
282            let mut axis_values = axis_call.values().clone();
283            feature_vec.append(&mut axis_values);
284        }
285
286        let mut shape = self.shape().values().clone();
287        shape[axis] = indices.len();
288
289        Ok(NDArray::array(shape, feature_vec).unwrap()) 
290
291    }
292
293
294    /// Drop specified axis of ndarray
295    pub fn drop_axis(&self, axis: usize, index: usize) -> Result<NDArray<T>, String> {
296
297        if axis > self.rank() - 1 { 
298            let msg = "Drop Axis: Selected axis larger than rank";
299            return Err(msg.to_string());
300        }
301
302        if index > self.shape().dim(axis) { 
303            let msg = "Drop Axis: Selected indice too large for axis";
304            return Err(msg.to_string());
305        }
306
307        if self.rank() > 2 {
308            let msg = "Drop Axis: Only supported for rank 2 values";
309            return Err(msg.to_string()); 
310        }
311
312        let mut shape_vals = self.shape().values();
313        shape_vals[axis] -= 1;
314        let mut result: NDArray<T> = NDArray::new(shape_vals).unwrap();
315
316        let mut coords: Vec<usize> = vec![0, 0];
317        let coord_len = coords.len() - 1;
318        let axis_shape = self.shape().dim(axis); 
319        for item in 0..axis_shape {
320            let value = self.axis(axis, item).unwrap();
321            if item != index {
322                for val in value.values() {
323                    result.set(coords.clone(), val.clone()).unwrap();
324                    coords[coord_len - axis] += 1;
325                }
326                coords[coord_len - axis] = 0; 
327                coords[axis] += 1;
328            }
329        }
330
331        Ok(result)
332    }
333   
334
335    /// Batch ndarray in specified amount of chunks of rows, cols etc.
336    pub fn batch(&self, batch_size: usize) -> Result<Vec<NDArray<T>>, String> {
337       
338        if batch_size == 0 || batch_size >= self.size() {
339            return Err("Batch size out of bounds".to_string())
340        }
341
342        if self.rank() != 2 {
343            return Err("NDArray must be of rank 2".to_string())
344        }
345
346        let dim_size = batch_size * self.shape.dim(1);
347        let mut start_index = 0; 
348        let mut end_index = start_index + dim_size;
349
350        let mut batches: Vec<NDArray<T>> = Vec::new();
351        
352        for _item in 0..self.size() {
353
354            if end_index >= self.size()+1 {
355                break;
356            }
357
358            let temp_vec: Vec<T> = self.values()[start_index..end_index].to_vec(); 
359            let ndarray_batch: NDArray<T> = NDArray::array(
360                vec![batch_size, self.shape.dim(1)], 
361                temp_vec.clone()
362            ).unwrap();
363
364            batches.push(ndarray_batch); 
365            start_index += self.shape.dim(1); 
366            end_index += self.shape.dim(1); 
367             
368        }
369
370        Ok(batches) 
371    }
372
373
374    pub fn value_indices(&self, value: T) -> Vec<usize> {
375        self.values().iter()
376            .enumerate()
377            .filter_map(|(i, &ref x)| if *x == value { Some(i) } else { None })
378            .collect()
379    }
380
381
382    pub fn indice_query(&self, indices: Vec<usize>) -> Result<NDArray<T>, String> {
383
384        if indices.len() > self.size() {
385            let msg = "Indices length is greater than array size";
386            return Err(msg.to_string());
387        }
388
389        let mut values: Vec<T> = Vec::new();
390        for idx in &indices {
391        
392            if *idx > self.size() {
393                let msg = "Specified index greater than array size";
394                return Err(msg.to_string()); 
395            }
396
397            let val = self.idx(*idx);
398            values.push(val.clone());
399        }
400
401        Ok(NDArray::array(vec![values.len(), 1], values).unwrap())
402    }
403
404    pub fn split(
405        &self, 
406        axis: usize,
407        percentage: f64) -> Result<(NDArray<T>, NDArray<T>), String> {
408
409        if axis > self.shape().values().len() {
410            let msg = "AXIS greater than current NDArray shape";
411            return Err(msg.to_string());
412        } 
413        
414        let axis_shape = self.shape().dim(axis);
415        let split_dist = (percentage * axis_shape as f64).ceil();
416        let rem = (axis_shape as f64 - split_dist).ceil();
417
418        let mut x_shape: Vec<usize> = self.shape().values();
419        let mut y_shape: Vec<usize> = self.shape().values();
420        x_shape[axis] = split_dist as usize;
421        y_shape[axis] = rem as usize;
422
423        let mut x_vals: Vec<T> = Vec::new();
424        let mut y_vals: Vec<T> = Vec::new();
425
426        for axis_idx in 0..axis_shape {
427            let item = self.axis(axis, axis_idx).unwrap();
428            let mut x_item = item.values().clone();
429            if axis_idx < split_dist as usize { 
430                x_vals.append(&mut x_item);
431            } else {
432                y_vals.append(&mut x_item);
433            }
434        }
435
436        let x: NDArray<T> = NDArray::array(x_shape, x_vals).unwrap();
437        let y: NDArray<T> = NDArray::array(y_shape, y_vals).unwrap();
438        Ok((x, y))
439    }
440
441
442}