dendritic_ndarray/
shape.rs1use serde::{Serialize, Deserialize};
2
3
4#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
5pub struct Shape {
6 values: Vec<usize>
7}
8
9impl Shape {
10
11 pub fn new(shape_vals: Vec<usize>) -> Self {
13 Self { values: shape_vals }
14 }
15
16 pub fn dim(&self, index: usize) -> usize {
18 self.values[index]
19 }
20
21 pub fn values(&self) -> Vec<usize> {
23 self.values.clone()
24 }
25
26 pub fn reverse(&self) -> Vec<usize> {
28 let mut cloned_shape = self.values.clone();
29 cloned_shape.reverse();
30 cloned_shape
31 }
32
33 pub fn remove(&mut self, index: usize) {
35 self.values.remove(index);
36 }
37
38 pub fn push(&mut self, value: usize) {
40 self.values.push(value);
41 }
42
43 pub fn permute(&self, indice_order: Vec<usize>) -> Vec<usize> {
45 let mut new_shape: Vec<usize> = Vec::new();
46 for item in &indice_order {
47 new_shape.push(self.values[*item]);
48 }
49 new_shape
50 }
51
52 pub fn idx(&self, indices: Vec<usize>) -> usize {
54 let mut stride = 1;
55 let mut index = 0;
56 let mut counter = indices.len();
57 for _n in 0..indices.len() {
58 let temp = stride * indices[counter-1];
59 let curr_shape = self.values[counter-1];
60 stride *= curr_shape;
61 index += temp;
62 counter -= 1;
63 }
64 index
65 }
66
67 pub fn indices(&self, index: usize, rank: usize) -> Vec<usize> {
69 let mut indexs = vec![0; rank];
70 let mut count = rank-1;
71 let mut curr_index = index;
72 for _n in 0..rank-1 {
73 let dim_size = self.values[count];
74 indexs[count] = curr_index % dim_size;
75 curr_index /= dim_size;
76 count -= 1;
77 }
78 indexs[0] = curr_index;
79 indexs
80 }
81
82 pub fn multi_index(&self, flat_index: usize) -> Vec<usize> {
84 let mut indices = Vec::new();
85 let mut flat_index = flat_index;
86 for dim in self.values.iter().rev() {
87 indices.push(flat_index % dim);
88 flat_index /= dim;
89 }
90 indices.reverse();
91 indices
92 }
93
94 pub fn strides(&self) -> Vec<usize> {
96 let mut counter = self.values().len();
97 let mut stride = 1;
98 let mut strides: Vec<usize> = Vec::new();
99 for _n in 0..self.values().len() {
100 let curr_shape = self.values[counter-1];
101 strides.push(stride);
102 stride *= curr_shape;
103 counter -= 1;
104 }
105 strides
106 }
107
108}