easy_ml/matrices/views/
reverse.rs

1use crate::matrices::views::{DataLayout, MatrixMut, MatrixRef, NoInteriorMutability};
2use crate::matrices::{Column, Row};
3use crate::tensors::views::reverse_indexes;
4
5use std::marker::PhantomData;
6
/**
 * A view of a matrix which iterates some, all or none of its rows and columns in reverse.
 *
 * Reading the view at a given row and column reads the source at the mirrored position for
 * every dimension that is reversed, and at the same position for every dimension that is not.
 *
 * ```
 * use easy_ml::matrices::Matrix;
 * use easy_ml::matrices::views::{MatrixView, MatrixReverse, Reverse};
 * let ab = Matrix::from(vec![
 *     vec![ 0, 1, 2 ],
 *     vec![ 3, 4, 5 ]
 * ]);
 * let reversed = ab.reverse(Reverse { rows: true, ..Default::default() });
 * let also_reversed = MatrixView::from(
 *     MatrixReverse::from(&ab, Reverse { rows: true, columns: false })
 * );
 * assert_eq!(reversed, also_reversed);
 * assert_eq!(
 *     reversed,
 *     Matrix::from(vec![
 *         vec![ 3, 4, 5 ],
 *         vec![ 0, 1, 2 ]
 *     ])
 * );
 * ```
 */
#[derive(Debug, Clone)]
pub struct MatrixReverse<T, S> {
    // The matrix (or view of a matrix) that is read from, stored in its original order.
    source: S,
    // When true, row index 0 of this view refers to the last row of the source.
    rows: bool,
    // When true, column index 0 of this view refers to the last column of the source.
    columns: bool,
    // Records the element type T without storing any value of it.
    _type: PhantomData<T>,
}
38
/**
 * Selects which of a matrix's `rows` and `columns` should be iterated in reverse.
 *
 * A dimension flagged `false` keeps its normal order. A dimension flagged `true` is mirrored:
 * index 0 becomes index length - 1, and index length - 1 becomes index 0.
 */
// The derived Default leaves both flags false, meaning nothing is reversed by default.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)]
pub struct Reverse {
    /** Whether the order of the rows is reversed. */
    pub rows: bool,
    /** Whether the order of the columns is reversed. */
    pub columns: bool,
}
52
53impl<T, S> MatrixReverse<T, S>
54where
55    S: MatrixRef<T>,
56{
57    /**
58     * Creates a MatrixReverse from a source and a struct for which dimensions to reverse the
59     * order of iteration for. If either or both of rows and columns in [Reverse] are set to false
60     * the iteration order for that dimension will continue to iterate in its normal order.
61     */
62    pub fn from(source: S, reverse: Reverse) -> MatrixReverse<T, S> {
63        MatrixReverse {
64            source,
65            rows: reverse.rows,
66            columns: reverse.columns,
67            _type: PhantomData,
68        }
69    }
70
71    /**
72     * Consumes the MatrixReverse, yielding the source it was created from.
73     */
74    #[allow(dead_code)]
75    pub fn source(self) -> S {
76        self.source
77    }
78
79    /**
80     * Gives a reference to the MatrixReverse's source (in which the data is not reversed).
81     */
82    #[allow(dead_code)]
83    pub fn source_ref(&self) -> &S {
84        &self.source
85    }
86
87    /**
88     * Gives a mutable reference to the MatrixReverse's source (in which the data is not reversed).
89     */
90    #[allow(dead_code)]
91    pub fn source_ref_mut(&mut self) -> &mut S {
92        &mut self.source
93    }
94}
95
96// # Safety
97//
98// Since our source implements NoInteriorMutability correctly, so do we by delegating to it, as
99// we don't introduce any interior mutability.
100/**
101 * A MatrixReverse of a NoInteriorMutability type implements NoInteriorMutability.
102 */
103unsafe impl<T, S> NoInteriorMutability for MatrixReverse<T, S> where S: NoInteriorMutability {}
104
105// # Safety
106//
107// The type implementing MatrixRef must implement it correctly, so by delegating to it
108// by only reversing some indexes and not introducing interior mutability, we implement
109// MatrixRef correctly as well.
110/**
111 * A MatrixReverse implements MatrixRef, with the dimension names the MatrixReverse was created
112 * with iterating in reverse order compared to the dimension names in the original source.
113 */
114unsafe impl<T, S> MatrixRef<T> for MatrixReverse<T, S>
115where
116    S: MatrixRef<T>,
117{
118    fn try_get_reference(&self, row: Row, column: Column) -> Option<&T> {
119        // If the Matrix has 0 length rows or columns the Tensor reverse_indexes function
120        // would reach out of bounds as it does not need to handle this case for tensors.
121        // Since the caller can expect to be able to query a 0x0 matrix and get None for
122        // any index, we must ensure this out of bounds calculation doesn't happen.
123        if self.source.view_rows() == 0 || self.source.view_columns() == 0 {
124            return None;
125        }
126        let [row, column] = reverse_indexes(
127            &[row, column],
128            &[
129                ("row", self.source.view_rows()),
130                ("column", self.source.view_columns()),
131            ],
132            &[self.rows, self.columns],
133        );
134        self.source.try_get_reference(row, column)
135    }
136
137    fn view_rows(&self) -> Row {
138        self.source.view_rows()
139    }
140
141    fn view_columns(&self) -> Column {
142        self.source.view_columns()
143    }
144
145    unsafe fn get_reference_unchecked(&self, row: Row, column: Column) -> &T {
146        unsafe {
147            // It is the caller's responsibiltiy to call this unsafe function with only valid
148            // indexes. If the source matrix is not at least 1x1, there are no valid indexes and hence
149            // the caller must not call this function.
150            // Given we can assume the matrix is at least 1x1 if we're being called, this calculation
151            // will return a new index which is also in range if the input was, so we won't
152            // introduce any out of bounds reads.
153            let [row, column] = reverse_indexes(
154                &[row, column],
155                &[
156                    ("row", self.source.view_rows()),
157                    ("column", self.source.view_columns()),
158                ],
159                &[self.rows, self.columns],
160            );
161            self.source.get_reference_unchecked(row, column)
162        }
163    }
164
165    fn data_layout(&self) -> DataLayout {
166        // There might be some specific cases where reversing maintains a linear order but
167        // in general I think reversing only some indexes is going to mean any attempt at being
168        // able to take a slice that matches up with our view_shape is gone.
169        DataLayout::Other
170    }
171}
172
173// # Safety
174//
175// The type implementing MatrixMut must implement it correctly, so by delegating to it
176// by only reversing some indexes and not introducing interior mutability, we implement
177// MatrixMut correctly as well.
178/**
179 * A MatrixReverse implements MatrixMut, with the dimension names the MatrixReverse was created
180 * with iterating in reverse order compared to the dimension names in the original source.
181 */
182unsafe impl<T, S> MatrixMut<T> for MatrixReverse<T, S>
183where
184    S: MatrixMut<T>,
185{
186    fn try_get_reference_mut(&mut self, row: Row, column: Column) -> Option<&mut T> {
187        // If the Matrix has 0 length rows or columns the Tensor reverse_indexes function
188        // would reach out of bounds as it does not need to handle this case for tensors.
189        // Since the caller can expect to be able to query a 0x0 matrix and get None for
190        // any index, we must ensure this out of bounds calculation doesn't happen.
191        if self.source.view_rows() == 0 || self.source.view_columns() == 0 {
192            return None;
193        }
194        let [row, column] = reverse_indexes(
195            &[row, column],
196            &[
197                ("row", self.source.view_rows()),
198                ("column", self.source.view_columns()),
199            ],
200            &[self.rows, self.columns],
201        );
202        self.source.try_get_reference_mut(row, column)
203    }
204
205    unsafe fn get_reference_unchecked_mut(&mut self, row: Row, column: Column) -> &mut T {
206        unsafe {
207            // It is the caller's responsibiltiy to call this unsafe function with only valid
208            // indexes. If the source matrix is not at least 1x1, there are no valid indexes and hence
209            // the caller must not call this function.
210            // Given we can assume the matrix is at least 1x1 if we're being called, this calculation
211            // will return a new index which is also in range if the input was, so we won't
212            // introduce any out of bounds reads.
213            let [row, column] = reverse_indexes(
214                &[row, column],
215                &[
216                    ("row", self.source.view_rows()),
217                    ("column", self.source.view_columns()),
218                ],
219                &[self.rows, self.columns],
220            );
221            self.source.get_reference_unchecked_mut(row, column)
222        }
223    }
224}