1use std::fmt::{Debug, Display, Formatter};
2use std::mem;
3use std::ops::IndexMut;
4
5use super::dummy_vector::DummyIndex;
6
/// Sentinel stored in `Shape::DIMENSION` by shapes whose number of dimensions
/// is only known at runtime (see `DynShape`).
const DYN_DIMENSION: usize = <usize>::MAX;
8
/// Error returned when a requested dimension, or the dimension count of an
/// index vector, does not fit the shape it is used with.
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct DimensionMismatchingError {
    /// Number of dimensions the shape has.
    pub dimension: usize,
    /// Dimension index (or dimension count) that was actually supplied.
    pub vector_dimension: usize,
}

impl Display for DimensionMismatchingError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "Dimension should be {}, not {}.",
            self.dimension, self.vector_dimension
        )
    }
}

// `Debug` intentionally prints the same human-readable message as `Display`.
impl Debug for DimensionMismatchingError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        Display::fmt(self, f)
    }
}

impl std::error::Error for DimensionMismatchingError {}
34
/// Error returned when a component of an index vector lies outside the
/// length of its dimension.
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct OutOfShapeError {
    /// Dimension whose component is out of range.
    pub dimension: usize,
    /// Length of that dimension.
    pub len: usize,
    /// The offending component value.
    pub vector_index: isize,
}

impl Display for OutOfShapeError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "Length of dimension {} is {}, but it get {}.",
            self.dimension, self.len, self.vector_index
        )
    }
}

// `Debug` intentionally prints the same human-readable message as `Display`.
impl Debug for OutOfShapeError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        Display::fmt(self, f)
    }
}

impl std::error::Error for OutOfShapeError {}
61
62#[derive(Debug, Clone, Copy)]
63pub enum IndexCalculationError {
64 DimensionMismatching(DimensionMismatchingError),
65 OutOfShape(OutOfShapeError),
66}
67
68impl Display for IndexCalculationError {
69 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
70 match self {
71 IndexCalculationError::DimensionMismatching(err) => {
72 write!(f, "{}", err)
73 }
74 IndexCalculationError::OutOfShape(err) => {
75 write!(f, "{}", err)
76 }
77 }
78 }
79}
80
81pub trait Shape {
82 const DIMENSION: usize;
83 type VectorType: IndexMut<usize, Output = usize>;
84 type DummyVectorType: IndexMut<usize, Output = DummyIndex>;
85
86 fn zero(&self) -> Self::VectorType;
87
88 fn len(&self) -> usize;
89 fn dimension(&self) -> usize {
90 Self::DIMENSION
91 }
92 fn dimension_of(_: &Self::VectorType) -> usize {
93 Self::DIMENSION
94 }
95
96 fn shape(&self) -> &[usize];
97 fn offset(&self) -> &[usize];
98
99 fn len_of_dimension(&self, dimension: usize) -> Result<usize, DimensionMismatchingError> {
100 if dimension > Self::DIMENSION {
101 Err(DimensionMismatchingError {
102 dimension: Self::DIMENSION,
103 vector_dimension: dimension,
104 })
105 } else {
106 Ok(self.shape()[dimension])
107 }
108 }
109
110 fn offset_of_dimension(&self, dimension: usize) -> Result<usize, DimensionMismatchingError> {
111 if dimension > Self::DIMENSION {
112 Err(DimensionMismatchingError {
113 dimension: Self::DIMENSION,
114 vector_dimension: dimension,
115 })
116 } else {
117 Ok(self.offset()[dimension])
118 }
119 }
120
121 fn index(&self, vector: &Self::VectorType) -> Result<usize, IndexCalculationError> {
122 if Self::dimension_of(vector) > self.dimension() {
123 Err(IndexCalculationError::DimensionMismatching(
124 DimensionMismatchingError {
125 dimension: self.dimension(),
126 vector_dimension: Self::dimension_of(vector),
127 },
128 ))
129 } else {
130 let mut index = 0;
131 for i in 0..self.dimension() {
132 if vector[i] > self.len_of_dimension(i).unwrap() {
133 return Err(IndexCalculationError::OutOfShape(OutOfShapeError {
134 dimension: i,
135 len: self.len_of_dimension(i).unwrap(),
136 vector_index: vector[i] as isize,
137 }));
138 }
139 index += vector[i] * self.offset_of_dimension(i).unwrap();
140 }
141 Ok(index)
142 }
143 }
144
145 fn vector(&self, mut index: usize) -> Self::VectorType {
146 let mut vector = self.zero();
147 for i in 0..self.dimension() {
148 let offset = self.offset_of_dimension(i).unwrap();
149 vector[i] = index / offset;
150 index = index % offset;
151 }
152 vector
153 }
154
155 fn next_vector(&self, vector: &mut Self::VectorType) -> bool {
156 let mut carry = false;
157 vector[self.dimension() - 1] += 1;
158
159 for i in (0..self.dimension()).rev() {
160 if carry {
161 vector[i] += 1;
162 carry = false;
163 }
164 if vector[i] == self.len_of_dimension(i).unwrap() {
165 vector[i] = 0;
166 carry = true;
167 }
168 }
169 !carry
170 }
171
172 fn actual_index(&self, dimension: usize, index: isize) -> Option<usize> {
173 let len = self.len_of_dimension(dimension).unwrap();
174 if index >= (len as isize) || index < -(len as isize) {
175 None
176 } else {
177 Some((index % (len as isize)) as usize)
178 }
179 }
180}
181
/// Computes row-major strides for `shape` together with its total element count.
///
/// The last dimension gets stride 1, and every other dimension's stride is the
/// product of the lengths of all dimensions after it. The returned count is the
/// product of all lengths, which is `1` for a zero-dimensional shape.
pub(self) fn offset<const DIMENSION: usize>(
    shape: &[usize; DIMENSION],
) -> ([usize; DIMENSION], usize) {
    let mut offset = [0; DIMENSION];
    // Walk from the innermost (last) dimension outwards, accumulating the stride;
    // after the last step the accumulator equals the total element count.
    let mut stride = 1;
    for (slot, &len) in offset.iter_mut().zip(shape.iter()).rev() {
        *slot = stride;
        stride *= len;
    }
    (offset, stride)
}
196
197#[derive(Clone, Copy)]
198pub struct Shape1 {
199 pub(self) shape: [usize; 1],
200}
201
202impl Shape1 {
203 pub fn new(shape: [usize; 1]) -> Self {
204 Self { shape }
205 }
206}
207
208impl Shape for Shape1 {
209 const DIMENSION: usize = 1;
210 type VectorType = [usize; 1];
211 type DummyVectorType = [DummyIndex; 1];
212
213 fn zero(&self) -> Self::VectorType {
214 [0]
215 }
216
217 fn len(&self) -> usize {
218 self.shape[0]
219 }
220
221 fn shape(&self) -> &[usize] {
222 &self.shape
223 }
224
225 fn offset(&self) -> &[usize] {
226 &self.shape
227 }
228}
229
/// Defines a fixed-rank shape type `$type` with `$dim` dimensions.
///
/// The generated type stores the per-dimension lengths together with the
/// strides and total element count returned by the module-level `offset`
/// function at construction time.
macro_rules! shape {
    ($type:ident, $dim:expr) => {
        /// Fixed-rank shape generated by the `shape!` macro.
        #[derive(Debug, Clone, Copy)]
        pub struct $type {
            pub(self) shape: [usize; $dim],
            pub(self) offset: [usize; $dim],
            pub(self) len: usize,
        }

        impl $type {
            /// Creates a shape with the given per-dimension lengths.
            pub fn new(shape: [usize; $dim]) -> Self {
                let (offset, len) = offset(&shape);
                Self { shape, offset, len }
            }
        }

        impl Shape for $type {
            const DIMENSION: usize = $dim;
            type VectorType = [usize; $dim];
            type DummyVectorType = [DummyIndex; $dim];

            // A plain array literal; no `unsafe` zero-initialisation needed.
            fn zero(&self) -> Self::VectorType {
                [0; $dim]
            }

            fn len(&self) -> usize {
                self.len
            }

            fn shape(&self) -> &[usize] {
                &self.shape
            }

            fn offset(&self) -> &[usize] {
                &self.offset
            }
        }
    };
}
273
274shape!(Shape2, 2);
275shape!(Shape3, 3);
276shape!(Shape4, 4);
277shape!(Shape5, 5);
278shape!(Shape6, 6);
279shape!(Shape7, 7);
280shape!(Shape8, 8);
281shape!(Shape9, 9);
282shape!(Shape10, 10);
283shape!(Shape11, 11);
284shape!(Shape12, 12);
285shape!(Shape13, 13);
286shape!(Shape14, 14);
287shape!(Shape15, 15);
288shape!(Shape16, 16);
289shape!(Shape17, 17);
290shape!(Shape18, 18);
291shape!(Shape19, 19);
292shape!(Shape20, 20);
293
/// Shape whose number of dimensions is only known at runtime.
#[derive(Debug, Clone)]
pub struct DynShape {
    pub(self) shape: Vec<usize>,
    pub(self) offset: Vec<usize>,
    pub(self) len: usize,
}

impl DynShape {
    /// Creates a shape with the given per-dimension lengths.
    ///
    /// An empty `shape` describes a zero-dimensional shape with one element.
    pub fn new(shape: Vec<usize>) -> Self {
        let (offset, len) = Self::offset(&shape);
        Self { shape, offset, len }
    }

    /// Computes row-major strides for `shape` together with its total element count.
    ///
    /// The last dimension gets stride 1, and every other dimension's stride is
    /// the product of the lengths of all dimensions after it. The count is the
    /// product of all lengths, which is `1` for an empty `shape`.
    pub(self) fn offset(shape: &[usize]) -> (Vec<usize>, usize) {
        let mut offset = vec![0; shape.len()];
        // Walk from the innermost (last) dimension outwards, accumulating the stride;
        // after the last step the accumulator equals the total element count.
        let mut stride = 1;
        for (slot, &len) in offset.iter_mut().zip(shape).rev() {
            *slot = stride;
            stride *= len;
        }
        (offset, stride)
    }
}
318
319impl Shape for DynShape {
320 const DIMENSION: usize = DYN_DIMENSION;
321 type VectorType = Vec<usize>;
322 type DummyVectorType = Vec<DummyIndex>;
323
324 fn zero(&self) -> Self::VectorType {
325 (0..self.shape.len()).map(|_| 0).collect()
326 }
327
328 fn len(&self) -> usize {
329 self.len
330 }
331
332 fn dimension(&self) -> usize {
333 self.shape.len()
334 }
335
336 fn dimension_of(vector: &Self::VectorType) -> usize {
337 vector.len()
338 }
339
340 fn shape(&self) -> &[usize] {
341 &self.shape
342 }
343
344 fn offset(&self) -> &[usize] {
345 &self.offset
346 }
347}