easy_ml/matrices/views/reverse.rs
1use crate::matrices::views::{DataLayout, MatrixMut, MatrixRef, NoInteriorMutability};
2use crate::matrices::{Column, Row};
3use crate::tensors::views::reverse_indexes;
4
5use std::marker::PhantomData;
6
7/**
8 * A view over a matrix where some or all of the rows and columns are iterated in reverse order.
9 *
10 * ```
11 * use easy_ml::matrices::Matrix;
12 * use easy_ml::matrices::views::{MatrixView, MatrixReverse, Reverse};
13 * let ab = Matrix::from(vec![
14 * vec![ 0, 1, 2 ],
15 * vec![ 3, 4, 5 ]
16 * ]);
17 * let reversed = ab.reverse(Reverse { rows: true, ..Default::default() });
18 * let also_reversed = MatrixView::from(
19 * MatrixReverse::from(&ab, Reverse { rows: true, columns: false })
20 * );
21 * assert_eq!(reversed, also_reversed);
22 * assert_eq!(
23 * reversed,
24 * Matrix::from(vec![
25 * vec![ 3, 4, 5 ],
26 * vec![ 0, 1, 2 ]
27 * ])
28 * );
29 * ```
30 */
31#[derive(Clone, Debug)]
32pub struct MatrixReverse<T, S> {
33 source: S,
34 rows: bool,
35 columns: bool,
36 _type: PhantomData<T>,
37}
38
39/**
40 * Helper struct for declaring which of `rows` and `columns` should be reversed for iteration.
41 *
42 * If a dimension is set to `false` it will iterate in its normal order. If a dimension is
43 * set to `true` the iteration order will be reversed, so the first index 0 becomes the last
44 * length-1, and the last index length-1 becomes 0
45 */
46// NB: Default impl for bool is false, which is what we want here
47#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)]
48pub struct Reverse {
49 pub rows: bool,
50 pub columns: bool,
51}
52
53impl<T, S> MatrixReverse<T, S>
54where
55 S: MatrixRef<T>,
56{
57 /**
58 * Creates a MatrixReverse from a source and a struct for which dimensions to reverse the
59 * order of iteration for. If either or both of rows and columns in [Reverse] are set to false
60 * the iteration order for that dimension will continue to iterate in its normal order.
61 */
62 pub fn from(source: S, reverse: Reverse) -> MatrixReverse<T, S> {
63 MatrixReverse {
64 source,
65 rows: reverse.rows,
66 columns: reverse.columns,
67 _type: PhantomData,
68 }
69 }
70
71 /**
72 * Consumes the MatrixReverse, yielding the source it was created from.
73 */
74 #[allow(dead_code)]
75 pub fn source(self) -> S {
76 self.source
77 }
78
79 /**
80 * Gives a reference to the MatrixReverse's source (in which the data is not reversed).
81 */
82 #[allow(dead_code)]
83 pub fn source_ref(&self) -> &S {
84 &self.source
85 }
86
87 /**
88 * Gives a mutable reference to the MatrixReverse's source (in which the data is not reversed).
89 */
90 #[allow(dead_code)]
91 pub fn source_ref_mut(&mut self) -> &mut S {
92 &mut self.source
93 }
94}
95
96// # Safety
97//
98// Since our source implements NoInteriorMutability correctly, so do we by delegating to it, as
99// we don't introduce any interior mutability.
100/**
101 * A MatrixReverse of a NoInteriorMutability type implements NoInteriorMutability.
102 */
103unsafe impl<T, S> NoInteriorMutability for MatrixReverse<T, S> where S: NoInteriorMutability {}
104
105// # Safety
106//
107// The type implementing MatrixRef must implement it correctly, so by delegating to it
108// by only reversing some indexes and not introducing interior mutability, we implement
109// MatrixRef correctly as well.
110/**
111 * A MatrixReverse implements MatrixRef, with the dimension names the MatrixReverse was created
112 * with iterating in reverse order compared to the dimension names in the original source.
113 */
114unsafe impl<T, S> MatrixRef<T> for MatrixReverse<T, S>
115where
116 S: MatrixRef<T>,
117{
118 fn try_get_reference(&self, row: Row, column: Column) -> Option<&T> {
119 // If the Matrix has 0 length rows or columns the Tensor reverse_indexes function
120 // would reach out of bounds as it does not need to handle this case for tensors.
121 // Since the caller can expect to be able to query a 0x0 matrix and get None for
122 // any index, we must ensure this out of bounds calculation doesn't happen.
123 if self.source.view_rows() == 0 || self.source.view_columns() == 0 {
124 return None;
125 }
126 let [row, column] = reverse_indexes(
127 &[row, column],
128 &[
129 ("row", self.source.view_rows()),
130 ("column", self.source.view_columns()),
131 ],
132 &[self.rows, self.columns],
133 );
134 self.source.try_get_reference(row, column)
135 }
136
137 fn view_rows(&self) -> Row {
138 self.source.view_rows()
139 }
140
141 fn view_columns(&self) -> Column {
142 self.source.view_columns()
143 }
144
145 unsafe fn get_reference_unchecked(&self, row: Row, column: Column) -> &T {
146 unsafe {
147 // It is the caller's responsibiltiy to call this unsafe function with only valid
148 // indexes. If the source matrix is not at least 1x1, there are no valid indexes and hence
149 // the caller must not call this function.
150 // Given we can assume the matrix is at least 1x1 if we're being called, this calculation
151 // will return a new index which is also in range if the input was, so we won't
152 // introduce any out of bounds reads.
153 let [row, column] = reverse_indexes(
154 &[row, column],
155 &[
156 ("row", self.source.view_rows()),
157 ("column", self.source.view_columns()),
158 ],
159 &[self.rows, self.columns],
160 );
161 self.source.get_reference_unchecked(row, column)
162 }
163 }
164
165 fn data_layout(&self) -> DataLayout {
166 // There might be some specific cases where reversing maintains a linear order but
167 // in general I think reversing only some indexes is going to mean any attempt at being
168 // able to take a slice that matches up with our view_shape is gone.
169 DataLayout::Other
170 }
171}
172
173// # Safety
174//
175// The type implementing MatrixMut must implement it correctly, so by delegating to it
176// by only reversing some indexes and not introducing interior mutability, we implement
177// MatrixMut correctly as well.
178/**
179 * A MatrixReverse implements MatrixMut, with the dimension names the MatrixReverse was created
180 * with iterating in reverse order compared to the dimension names in the original source.
181 */
182unsafe impl<T, S> MatrixMut<T> for MatrixReverse<T, S>
183where
184 S: MatrixMut<T>,
185{
186 fn try_get_reference_mut(&mut self, row: Row, column: Column) -> Option<&mut T> {
187 // If the Matrix has 0 length rows or columns the Tensor reverse_indexes function
188 // would reach out of bounds as it does not need to handle this case for tensors.
189 // Since the caller can expect to be able to query a 0x0 matrix and get None for
190 // any index, we must ensure this out of bounds calculation doesn't happen.
191 if self.source.view_rows() == 0 || self.source.view_columns() == 0 {
192 return None;
193 }
194 let [row, column] = reverse_indexes(
195 &[row, column],
196 &[
197 ("row", self.source.view_rows()),
198 ("column", self.source.view_columns()),
199 ],
200 &[self.rows, self.columns],
201 );
202 self.source.try_get_reference_mut(row, column)
203 }
204
205 unsafe fn get_reference_unchecked_mut(&mut self, row: Row, column: Column) -> &mut T {
206 unsafe {
207 // It is the caller's responsibiltiy to call this unsafe function with only valid
208 // indexes. If the source matrix is not at least 1x1, there are no valid indexes and hence
209 // the caller must not call this function.
210 // Given we can assume the matrix is at least 1x1 if we're being called, this calculation
211 // will return a new index which is also in range if the input was, so we won't
212 // introduce any out of bounds reads.
213 let [row, column] = reverse_indexes(
214 &[row, column],
215 &[
216 ("row", self.source.view_rows()),
217 ("column", self.source.view_columns()),
218 ],
219 &[self.rows, self.columns],
220 );
221 self.source.get_reference_unchecked_mut(row, column)
222 }
223 }
224}