dendritic_ndarray/
shape.rs

1use serde::{Serialize, Deserialize};
2
3
4#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
5pub struct Shape {
6    values: Vec<usize>
7}
8
9impl Shape {
10
11    /// Create instance of shape object for stride calculation
12    pub fn new(shape_vals: Vec<usize>) -> Self {
13        Self { values: shape_vals }
14    }
15
16    /// Get's the corresponding dimension of a shape vector
17    pub fn dim(&self, index: usize) -> usize {
18        self.values[index]
19    }
20
21    /// Get the vector values associated with shape object 
22    pub fn values(&self) -> Vec<usize> {
23        self.values.clone()
24    }
25
26    /// Reverse the shape indices
27    pub fn reverse(&self) -> Vec<usize> {
28        let mut cloned_shape = self.values.clone();
29        cloned_shape.reverse();
30        cloned_shape
31    }
32
33    /// Remove axis from shape
34    pub fn remove(&mut self, index: usize) {
35        self.values.remove(index);
36    }
37
38    /// Add axis to shape
39    pub fn push(&mut self, value: usize) {
40        self.values.push(value);
41    }
42
43    /// Permute indices in shape vector
44    pub fn permute(&self, indice_order: Vec<usize>) -> Vec<usize> {
45        let mut new_shape: Vec<usize> = Vec::new();
46        for item in &indice_order {
47            new_shape.push(self.values[*item]);
48        }
49        new_shape
50    }
51
52    /// Produce 1d index from ndarray using higher rank index coordinates
53    pub fn idx(&self, indices: Vec<usize>) -> usize { 
54        let mut stride = 1; 
55        let mut index = 0;
56        let mut counter = indices.len();  
57        for _n in 0..indices.len() {
58            let temp = stride * indices[counter-1]; 
59            let curr_shape = self.values[counter-1];
60            stride *= curr_shape;
61            index += temp;  
62            counter -= 1; 
63        }
64        index
65    }
66
67    /// Produce multi index coordinate with 1d index supplied
68    pub fn indices(&self, index: usize, rank: usize) -> Vec<usize> {
69        let mut indexs = vec![0; rank]; 
70        let mut count = rank-1; 
71        let mut curr_index = index; 
72        for _n in 0..rank-1 {
73            let dim_size = self.values[count];
74            indexs[count] = curr_index % dim_size; 
75            curr_index /= dim_size; 
76            count -= 1;
77        }
78        indexs[0] = curr_index;
79        indexs
80    }
81
82    /// Get associated multi dimensional index with a single index
83    pub fn multi_index(&self, flat_index: usize) -> Vec<usize> {
84        let mut indices = Vec::new(); 
85        let mut flat_index = flat_index; 
86        for dim in self.values.iter().rev() {
87            indices.push(flat_index % dim); 
88            flat_index /= dim; 
89        }
90        indices.reverse();
91        indices
92    }
93
94    /// Get stride for provided axis (dimension)
95    pub fn strides(&self) -> Vec<usize> {
96        let mut counter = self.values().len();  
97        let mut stride = 1; 
98        let mut strides: Vec<usize> = Vec::new();
99        for _n in 0..self.values().len() {
100            let curr_shape = self.values[counter-1];
101            strides.push(stride);
102            stride *= curr_shape;
103            counter -= 1;
104        }        
105        strides
106    }
107
108}