easy_ml/interop/
mod.rs

/*!
 * Interoperability APIs between
 * [Matrix](crate::matrices::Matrix)/[MatrixView](crate::matrices::views::MatrixView) and
 * [Tensor](crate::tensors::Tensor)/[TensorView](crate::tensors::views::TensorView).
 */
6
7use crate::matrices::views::DataLayout as MDataLayout;
8use crate::matrices::views::{MatrixMut, MatrixRef, NoInteriorMutability};
9use crate::matrices::{Column, Row};
10use crate::tensors::views::DataLayout as TDataLayout;
11use crate::tensors::views::{TensorMut, TensorRef};
12use crate::tensors::{Dimension, InvalidShapeError};
13
14use std::marker::PhantomData;
15
/**
 * A wrapper around a Matrix type that implements TensorRef and can thus be used in a
 * [TensorView](crate::tensors::views::TensorView)
 *
 * ```
 * use easy_ml::matrices::Matrix;
 * use easy_ml::tensors::views::TensorView;
 * use easy_ml::interop::TensorRefMatrix;
 * let matrix = Matrix::from(vec![
 *     vec![ 1, 3, 5, 7 ],
 *     vec![ 2, 4, 6, 8 ]
 * ]);
 * // We can always unwrap here because we know a 2x4 matrix is a valid input
 * let tensor_view = TensorView::from(TensorRefMatrix::from(&matrix).unwrap());
 * assert!(matrix.row_iter(1).eq(tensor_view.select([("row", 1)]).iter()));
 * ```
 */
#[derive(Clone, Debug)]
pub struct TensorRefMatrix<T, S, N> {
    // The wrapped matrix source that element access and shape queries are delegated to.
    source: S,
    // Provider of the two dimension names reported for the rows and the columns.
    names: N,
    // Records the element type T, which is not otherwise stored in any field.
    _type: PhantomData<T>,
}
42
43/**
44 * The first and second dimension name a Matrix type wrapped in a [TensorRefMatrix] will report
45 * on its view shape. If you don't care what the dimension names are, [RowAndColumn] can be used
46 * which will hardcode the dimension names to "row" and "column" respectively.
47 */
48pub trait DimensionNames {
49    fn names(&self) -> [Dimension; 2];
50}
51
/**
 * A zero size DimensionNames type that always returns `["row", "column"]`.
 */
#[derive(Clone, Debug)]
pub struct RowAndColumn;
57
58impl DimensionNames for RowAndColumn {
59    fn names(&self) -> [Dimension; 2] {
60        ["row", "column"]
61    }
62}
63
64/**
65 * Any array of two dimension names will implement DimensionNames returning those names in the
66 * same order.
67 */
68impl DimensionNames for [Dimension; 2] {
69    fn names(&self) -> [Dimension; 2] {
70        *self
71    }
72}
73
74impl<T, S> TensorRefMatrix<T, S, RowAndColumn>
75where
76    S: MatrixRef<T> + NoInteriorMutability,
77{
78    /**
79     * Creates a TensorRefMatrix wrapping a MatrixRef type and defaulting the dimension names
80     * to "row" and "column" respectively.
81     *
82     * Result::Err is returned if the matrix dimension lengths are not at least 1x1. This is
83     * allowed for MatrixRef but not for TensorRef and hence cannot be converted.
84     */
85    pub fn from(source: S) -> Result<TensorRefMatrix<T, S, RowAndColumn>, InvalidShapeError<2>> {
86        TensorRefMatrix::with_names(source, RowAndColumn)
87    }
88}
89
90impl<T, S, N> TensorRefMatrix<T, S, N>
91where
92    S: MatrixRef<T> + NoInteriorMutability,
93    N: DimensionNames,
94{
95    /**
96     * Creates a TensorRefMatrix wrapping a MatrixRef type and provided dimension names.
97     *
98     * Result::Err is returned if the provided dimension names are not unique, or the matrix
99     * dimension lengths are not at least 1x1. This is allowed for MatrixRef but not for
100     * TensorRef and hence cannot be converted.
101     *
102     * ```
103     * use easy_ml::matrices::Matrix;
104     * use easy_ml::tensors::views::TensorRef;
105     * use easy_ml::interop::TensorRefMatrix;
106     * assert_eq!(
107     *     // We can always unwrap here because we know the input is 1x1 and "x" and "y" are unique
108     *     // dimension names
109     *     TensorRefMatrix::with_names(Matrix::from_scalar(1.0), ["x", "y"]).unwrap().view_shape(),
110     *     [("x", 1), ("y", 1)]
111     * );
112     * ```
113     */
114    pub fn with_names(
115        source: S,
116        names: N,
117    ) -> Result<TensorRefMatrix<T, S, N>, InvalidShapeError<2>> {
118        let dimensions = names.names();
119        let shape = InvalidShapeError::new([
120            (dimensions[0], source.view_rows()),
121            (dimensions[1], source.view_columns()),
122        ]);
123        if shape.is_valid() {
124            Ok(TensorRefMatrix {
125                source,
126                names,
127                _type: PhantomData,
128            })
129        } else {
130            Err(shape)
131        }
132    }
133}
134
135// # Safety
136// The contract of MatrixRef<T> + NoInteriorMutability is essentially the ungeneralised version of
137// TensorRef, so we're good on no interior mutability and valid indexing behaviour. The
138// TensorRef only requirements are that "all dimension names in the view_shape must be unique"
139// and "all dimension lengths in the view_shape must be non zero". We enforce both of these during
140// construction, and the NoInteriorMutability bounds ensures these invariants remain valid.
141unsafe impl<T, S, N> TensorRef<T, 2> for TensorRefMatrix<T, S, N>
142where
143    S: MatrixRef<T> + NoInteriorMutability,
144    N: DimensionNames,
145{
146    fn get_reference(&self, indexes: [usize; 2]) -> Option<&T> {
147        self.source.try_get_reference(indexes[0], indexes[1])
148    }
149
150    fn view_shape(&self) -> [(Dimension, usize); 2] {
151        let (rows, columns) = (self.source.view_rows(), self.source.view_columns());
152        let [row, column] = self.names.names();
153        [(row, rows), (column, columns)]
154    }
155
156    unsafe fn get_reference_unchecked(&self, indexes: [usize; 2]) -> &T {
157        unsafe { self.source.get_reference_unchecked(indexes[0], indexes[1]) }
158    }
159
160    fn data_layout(&self) -> TDataLayout<2> {
161        let [rows_dimension, columns_dimension] = self.names.names();
162        // Row major and column major are the less generalised versions of
163        // a linear data layout. Since our view shape is hardcoded here to rows
164        // then columns, a row major matrix means the most significant dimension
165        // is the first, and the least significant dimension is the second. Similarly
166        // a column major matrix means the opposite.
167        match self.source.data_layout() {
168            MDataLayout::RowMajor => TDataLayout::Linear([rows_dimension, columns_dimension]),
169            MDataLayout::ColumnMajor => TDataLayout::Linear([columns_dimension, rows_dimension]),
170            MDataLayout::Other => TDataLayout::Other,
171        }
172    }
173}
174
175// # Safety
176// The contract of MatrixMut<T> + NoInteriorMutability is essentially the ungeneralised version of
177// TensorMut, so we're good on no interior mutability and valid indexing behaviour. The
178// TensorMut only requirements are that "all dimension names in the view_shape must be unique"
179// and "all dimension lengths in the view_shape must be non zero". We enforce both of these during
180// construction, and the NoInteriorMutability bounds ensures these invariants remain valid.
181unsafe impl<T, S, N> TensorMut<T, 2> for TensorRefMatrix<T, S, N>
182where
183    S: MatrixMut<T> + NoInteriorMutability,
184    N: DimensionNames,
185{
186    fn get_reference_mut(&mut self, indexes: [usize; 2]) -> Option<&mut T> {
187        self.source.try_get_reference_mut(indexes[0], indexes[1])
188    }
189
190    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; 2]) -> &mut T {
191        unsafe {
192            self.source
193                .get_reference_unchecked_mut(indexes[0], indexes[1])
194        }
195    }
196}
197
/**
 * A wrapper around a Tensor<_, 2> type that implements MatrixRef and can thus be used in a
 * [MatrixView](crate::matrices::views::MatrixView)
 *
 * ```
 * use easy_ml::tensors::Tensor;
 * use easy_ml::matrices::views::MatrixView;
 * use easy_ml::interop::MatrixRefTensor;
 * let tensor = Tensor::from([("row", 2), ("column", 4)], vec![
 *     1, 3, 5, 7,
 *     2, 4, 6, 8
 * ]);
 * let matrix_view = MatrixView::from(MatrixRefTensor::from(&tensor));
 * assert!(matrix_view.row_iter(1).eq(tensor.select([("row", 1)]).iter()));
 * ```
 */
// Derives match the sibling TensorRefMatrix wrapper so both interop types can be cloned
// and debug printed.
#[derive(Clone, Debug)]
pub struct MatrixRefTensor<T, S> {
    // The wrapped two dimensional tensor source that all access is delegated to.
    source: S,
    // Records the element type T, which is not otherwise stored in any field.
    _type: PhantomData<T>,
}
221
222impl<T, S> MatrixRefTensor<T, S>
223where
224    S: TensorRef<T, 2>,
225{
226    /**
227     * Creates a MatrixRefTensor wrapping a TensorRef type and stripping the dimension names.
228     *
229     * The first dimension in the TensorRef type becomes the rows, and the second dimension the
230     * columns. If your tensor is the other way around,
231     * [reorder it first](crate::tensors::indexing::TensorAccess).
232     */
233    pub fn from(source: S) -> MatrixRefTensor<T, S> {
234        MatrixRefTensor {
235            source,
236            _type: PhantomData,
237        }
238    }
239}
240
241// # Safety
242// The contract of TensorRef<T, 2> covers everything the compiler can't check for MatrixRef<T>
243// so if we just delegate to the tensor source and hide the dimension names, the index based API
244// meets every requirement by default.
245unsafe impl<T, S> MatrixRef<T> for MatrixRefTensor<T, S>
246where
247    S: TensorRef<T, 2>,
248{
249    fn try_get_reference(&self, row: Row, column: Column) -> Option<&T> {
250        self.source.get_reference([row, column])
251    }
252
253    fn view_rows(&self) -> Row {
254        self.source.view_shape()[0].1
255    }
256
257    fn view_columns(&self) -> Column {
258        self.source.view_shape()[1].1
259    }
260
261    unsafe fn get_reference_unchecked(&self, row: Row, column: Column) -> &T {
262        unsafe { self.source.get_reference_unchecked([row, column]) }
263    }
264
265    fn data_layout(&self) -> MDataLayout {
266        let rows_dimension = self.source.view_shape()[0].0;
267        let columns_dimension = self.source.view_shape()[1].0;
268        // Row major and column major are the less generalised versions of
269        // a linear data layout. Since our view shape is always interpreted here as rows
270        // then columns, a row major matrix means the most significant dimension
271        // is the first, and the least significant dimension is the second. Similarly
272        // a column major matrix means the opposite.
273        let data_layout = self.source.data_layout();
274        if data_layout == TDataLayout::Linear([rows_dimension, columns_dimension]) {
275            MDataLayout::RowMajor
276        } else if data_layout == TDataLayout::Linear([columns_dimension, rows_dimension]) {
277            MDataLayout::ColumnMajor
278        } else {
279            match self.source.data_layout() {
280                TDataLayout::NonLinear => MDataLayout::Other,
281                TDataLayout::Other => MDataLayout::Other,
282                // This branch should never happen as no other Linear layouts are valid according
283                // to the docs the source implementation must follow but we need to keep the Rust
284                // compiler happy
285                TDataLayout::Linear([_, _]) => MDataLayout::Other,
286            }
287        }
288    }
289}
290
291// # Safety
292// The contract of TensorMut<T, 2> covers everything the compiler can't check for MatrixMut<T>
293// so if we just delegate to the tensor source and hide the dimension names, the index based API
294// meets every requirement by default.
295unsafe impl<T, S> MatrixMut<T> for MatrixRefTensor<T, S>
296where
297    S: TensorMut<T, 2>,
298{
299    fn try_get_reference_mut(&mut self, row: Row, column: Column) -> Option<&mut T> {
300        self.source.get_reference_mut([row, column])
301    }
302
303    unsafe fn get_reference_unchecked_mut(&mut self, row: Row, column: Column) -> &mut T {
304        unsafe { self.source.get_reference_unchecked_mut([row, column]) }
305    }
306}
307
308// # Safety
309// No interior mutability is implied by TensorRef
310unsafe impl<T, S> NoInteriorMutability for MatrixRefTensor<T, S> where S: TensorRef<T, 2> {}