// rds_tensors/tindex.rs

/// A trait which adds indexing functions to simplify index manipulation.
///
/// An index holds one component per dimension. Every method that takes a
/// `shape` expects it to give the extent of each dimension, in the same order
/// as the index components.
pub trait TIndex {
    /// Increment an index in the row-major order.
    ///
    /// The last dimension varies fastest. In the `[usize]` implementation,
    /// a component that reaches its extent is reset to zero and the carry
    /// moves to the previous dimension, so incrementing the last valid index
    /// wraps around to all zeros.
    fn inc_ro(&mut self, shape: &[usize]);

    /// Decrement an index in the row-major order.
    ///
    /// The last dimension varies fastest. In the `[usize]` implementation,
    /// a zero component is set to `extent - 1` and the borrow moves to the
    /// previous dimension, so decrementing the all-zero index wraps around to
    /// the last valid index.
    fn dec_ro(&mut self, shape: &[usize]);

    /// Increment an index in the column-major order.
    ///
    /// Same as [`TIndex::inc_ro`], except that the first dimension varies
    /// fastest and the carry moves to the next dimension.
    fn inc_co(&mut self, shape: &[usize]);

    /// Decrement an index in the column-major order.
    ///
    /// Same as [`TIndex::dec_ro`], except that the first dimension varies
    /// fastest and the borrow moves to the next dimension.
    fn dec_co(&mut self, shape: &[usize]);

    /// Return true if the index is zero for all dimensions.
    ///
    /// NOTE(review): the receiver is `&mut self` although a zero check should
    /// not need to mutate; it is left unchanged so that existing implementors
    /// and callers keep compiling.
    fn is_zero(&mut self) -> bool;

    /// Compute the resulting position in a linear storage array.
    ///
    /// In the `[usize]` implementation the position is the sum over all
    /// dimensions of `index[i] * strides[i]`.
    ///
    /// # Panics
    ///
    /// The `[usize]` implementation panics if any component is not strictly
    /// smaller than the matching extent in `shape`.
    fn to_pos(&self, shape: &[usize], strides: &[usize]) -> usize;
}
22
23impl TIndex for [usize] {
24
25    fn inc_ro(&mut self, shape : &[usize]) {
26        let mut i = self.len();
27        while i > 0 {
28            self[i-1] += 1;
29            if self[i-1] >= shape[i-1] {
30                self[i-1] = 0;
31                i -= 1;
32            }
33            else {
34                break;
35            }
36        }
37    }
38
39    fn dec_ro(&mut self, shape : &[usize]) {
40        let mut i = self.len();
41        while i > 0 {
42            if self[i-1] == 0 {
43                self[i-1] = shape[i-1] - 1;
44                i -= 1;
45            }
46            else {
47                self[i-1] -= 1;
48                break;
49            }
50        }
51    }
52
53    fn inc_co(&mut self, shape : &[usize]) {
54        let mut i = 0;
55        while i < self.len() {
56            self[i] += 1;
57            if self[i] >= shape[i] {
58                self[i] = 0;
59                i += 1;
60            }
61            else {
62                break;
63            }
64        }
65    }
66
67    fn dec_co(&mut self, shape : &[usize]) {
68        let mut i = 0;
69        while i < self.len() {
70            if self[i] ==  0 {
71                self[i] = shape[i] - 1;
72                i += 1;
73            }
74            else {
75                self[i] -= 1;
76                break;
77            }
78        }
79    }
80
81    fn is_zero(&mut self) -> bool {
82        for i in 0..self.len() {
83            if self[i] != 0 {
84                return false;
85            }
86        }
87        return true;
88    }
89
90    fn to_pos(&self, shape : &[usize], strides : &[usize]) -> usize {
91        let mut pos = 0usize;
92        for i in 0..self.len() {
93            assert!(self[i] < shape[i], "TIndex::to_pos(): idx is out of bound.");
94            pos += self[i] * strides[i];
95        }
96        return pos;
97    }
98}