easy_ml/interop/mod.rs
/*!
 * Interoperability APIs between
 * [Matrix](crate::matrices::Matrix)/[MatrixView](crate::matrices::views::MatrixView) and
 * [Tensor](crate::tensors::Tensor)/[TensorView](crate::tensors::views::TensorView).
 */
6
7use crate::matrices::views::DataLayout as MDataLayout;
8use crate::matrices::views::{MatrixMut, MatrixRef, NoInteriorMutability};
9use crate::matrices::{Column, Row};
10use crate::tensors::views::DataLayout as TDataLayout;
11use crate::tensors::views::{TensorMut, TensorRef};
12use crate::tensors::{Dimension, InvalidShapeError};
13
14use std::marker::PhantomData;
15
/**
 * A wrapper around a Matrix type that implements TensorRef, allowing the matrix to be used in a
 * [TensorView](crate::tensors::views::TensorView).
 *
 * ```
 * use easy_ml::matrices::Matrix;
 * use easy_ml::tensors::views::TensorView;
 * use easy_ml::interop::TensorRefMatrix;
 * let matrix = Matrix::from(vec![
 *     vec![ 1, 3, 5, 7 ],
 *     vec![ 2, 4, 6, 8 ]
 * ]);
 * // A 2x4 matrix is always a valid input, so this unwrap cannot fail
 * let tensor_view = TensorView::from(TensorRefMatrix::from(&matrix).unwrap());
 * assert!(matrix.row_iter(1).eq(tensor_view.select([("row", 1)]).iter()));
 * ```
 */
#[derive(Debug, Clone)]
pub struct TensorRefMatrix<T, S, N> {
    // The wrapped matrix; all indexing is delegated to it.
    source: S,
    // Supplies the two dimension names reported on the view shape.
    names: N,
    // Records the element type T without storing any value of it.
    _type: PhantomData<T>,
}
42
43/**
44 * The first and second dimension name a Matrix type wrapped in a [TensorRefMatrix] will report
45 * on its view shape. If you don't care what the dimension names are, [RowAndColumn] can be used
46 * which will hardcode the dimension names to "row" and "column" respectively.
47 */
48pub trait DimensionNames {
49 fn names(&self) -> [Dimension; 2];
50}
51
52/**
53 * A zero size DimensionNames type that always returns `["row", "column"]`.
54 */
55#[derive(Clone, Debug)]
56pub struct RowAndColumn;
57
58impl DimensionNames for RowAndColumn {
59 fn names(&self) -> [Dimension; 2] {
60 ["row", "column"]
61 }
62}
63
64/**
65 * Any array of two dimension names will implement DimensionNames returning those names in the
66 * same order.
67 */
68impl DimensionNames for [Dimension; 2] {
69 fn names(&self) -> [Dimension; 2] {
70 *self
71 }
72}
73
74impl<T, S> TensorRefMatrix<T, S, RowAndColumn>
75where
76 S: MatrixRef<T> + NoInteriorMutability,
77{
78 /**
79 * Creates a TensorRefMatrix wrapping a MatrixRef type and defaulting the dimension names
80 * to "row" and "column" respectively.
81 *
82 * Result::Err is returned if the matrix dimension lengths are not at least 1x1. This is
83 * allowed for MatrixRef but not for TensorRef and hence cannot be converted.
84 */
85 pub fn from(source: S) -> Result<TensorRefMatrix<T, S, RowAndColumn>, InvalidShapeError<2>> {
86 TensorRefMatrix::with_names(source, RowAndColumn)
87 }
88}
89
90impl<T, S, N> TensorRefMatrix<T, S, N>
91where
92 S: MatrixRef<T> + NoInteriorMutability,
93 N: DimensionNames,
94{
95 /**
96 * Creates a TensorRefMatrix wrapping a MatrixRef type and provided dimension names.
97 *
98 * Result::Err is returned if the provided dimension names are not unique, or the matrix
99 * dimension lengths are not at least 1x1. This is allowed for MatrixRef but not for
100 * TensorRef and hence cannot be converted.
101 *
102 * ```
103 * use easy_ml::matrices::Matrix;
104 * use easy_ml::tensors::views::TensorRef;
105 * use easy_ml::interop::TensorRefMatrix;
106 * assert_eq!(
107 * // We can always unwrap here because we know the input is 1x1 and "x" and "y" are unique
108 * // dimension names
109 * TensorRefMatrix::with_names(Matrix::from_scalar(1.0), ["x", "y"]).unwrap().view_shape(),
110 * [("x", 1), ("y", 1)]
111 * );
112 * ```
113 */
114 pub fn with_names(
115 source: S,
116 names: N,
117 ) -> Result<TensorRefMatrix<T, S, N>, InvalidShapeError<2>> {
118 let dimensions = names.names();
119 let shape = InvalidShapeError::new([
120 (dimensions[0], source.view_rows()),
121 (dimensions[1], source.view_columns()),
122 ]);
123 if shape.is_valid() {
124 Ok(TensorRefMatrix {
125 source,
126 names,
127 _type: PhantomData,
128 })
129 } else {
130 Err(shape)
131 }
132 }
133}
134
135// # Safety
136// The contract of MatrixRef<T> + NoInteriorMutability is essentially the ungeneralised version of
137// TensorRef, so we're good on no interior mutability and valid indexing behaviour. The
138// TensorRef only requirements are that "all dimension names in the view_shape must be unique"
139// and "all dimension lengths in the view_shape must be non zero". We enforce both of these during
140// construction, and the NoInteriorMutability bounds ensures these invariants remain valid.
141unsafe impl<T, S, N> TensorRef<T, 2> for TensorRefMatrix<T, S, N>
142where
143 S: MatrixRef<T> + NoInteriorMutability,
144 N: DimensionNames,
145{
146 fn get_reference(&self, indexes: [usize; 2]) -> Option<&T> {
147 self.source.try_get_reference(indexes[0], indexes[1])
148 }
149
150 fn view_shape(&self) -> [(Dimension, usize); 2] {
151 let (rows, columns) = (self.source.view_rows(), self.source.view_columns());
152 let [row, column] = self.names.names();
153 [(row, rows), (column, columns)]
154 }
155
156 unsafe fn get_reference_unchecked(&self, indexes: [usize; 2]) -> &T {
157 unsafe { self.source.get_reference_unchecked(indexes[0], indexes[1]) }
158 }
159
160 fn data_layout(&self) -> TDataLayout<2> {
161 let [rows_dimension, columns_dimension] = self.names.names();
162 // Row major and column major are the less generalised versions of
163 // a linear data layout. Since our view shape is hardcoded here to rows
164 // then columns, a row major matrix means the most significant dimension
165 // is the first, and the least significant dimension is the second. Similarly
166 // a column major matrix means the opposite.
167 match self.source.data_layout() {
168 MDataLayout::RowMajor => TDataLayout::Linear([rows_dimension, columns_dimension]),
169 MDataLayout::ColumnMajor => TDataLayout::Linear([columns_dimension, rows_dimension]),
170 MDataLayout::Other => TDataLayout::Other,
171 }
172 }
173}
174
175// # Safety
176// The contract of MatrixMut<T> + NoInteriorMutability is essentially the ungeneralised version of
177// TensorMut, so we're good on no interior mutability and valid indexing behaviour. The
178// TensorMut only requirements are that "all dimension names in the view_shape must be unique"
179// and "all dimension lengths in the view_shape must be non zero". We enforce both of these during
180// construction, and the NoInteriorMutability bounds ensures these invariants remain valid.
181unsafe impl<T, S, N> TensorMut<T, 2> for TensorRefMatrix<T, S, N>
182where
183 S: MatrixMut<T> + NoInteriorMutability,
184 N: DimensionNames,
185{
186 fn get_reference_mut(&mut self, indexes: [usize; 2]) -> Option<&mut T> {
187 self.source.try_get_reference_mut(indexes[0], indexes[1])
188 }
189
190 unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; 2]) -> &mut T {
191 unsafe {
192 self.source
193 .get_reference_unchecked_mut(indexes[0], indexes[1])
194 }
195 }
196}
197
/**
 * A wrapper around a Tensor<_, 2> type that implements MatrixRef and can thus be used in a
 * [MatrixView](crate::matrices::views::MatrixView)
 *
 * ```
 * use easy_ml::tensors::Tensor;
 * use easy_ml::matrices::views::MatrixView;
 * use easy_ml::interop::MatrixRefTensor;
 * let tensor = Tensor::from([("row", 2), ("column", 4)], vec![
 *     1, 3, 5, 7,
 *     2, 4, 6, 8
 * ]);
 * let matrix_view = MatrixView::from(MatrixRefTensor::from(&tensor));
 * assert_eq!(
 *     matrix_view.row_iter(1).eq(tensor.select([("row", 1)]).iter()),
 *     true
 * );
 * ```
 */
// Derives match TensorRefMatrix so both interop wrappers can be cloned and debug printed.
#[derive(Clone, Debug)]
pub struct MatrixRefTensor<T, S> {
    // The wrapped 2 dimensional tensor; all indexing is delegated to it.
    source: S,
    // Records the element type T without storing any value of it.
    _type: PhantomData<T>,
}
221
222impl<T, S> MatrixRefTensor<T, S>
223where
224 S: TensorRef<T, 2>,
225{
226 /**
227 * Creates a MatrixRefTensor wrapping a TensorRef type and stripping the dimension names.
228 *
229 * The first dimension in the TensorRef type becomes the rows, and the second dimension the
230 * columns. If your tensor is the other way around,
231 * [reorder it first](crate::tensors::indexing::TensorAccess).
232 */
233 pub fn from(source: S) -> MatrixRefTensor<T, S> {
234 MatrixRefTensor {
235 source,
236 _type: PhantomData,
237 }
238 }
239}
240
241// # Safety
242// The contract of TensorRef<T, 2> covers everything the compiler can't check for MatrixRef<T>
243// so if we just delegate to the tensor source and hide the dimension names, the index based API
244// meets every requirement by default.
245unsafe impl<T, S> MatrixRef<T> for MatrixRefTensor<T, S>
246where
247 S: TensorRef<T, 2>,
248{
249 fn try_get_reference(&self, row: Row, column: Column) -> Option<&T> {
250 self.source.get_reference([row, column])
251 }
252
253 fn view_rows(&self) -> Row {
254 self.source.view_shape()[0].1
255 }
256
257 fn view_columns(&self) -> Column {
258 self.source.view_shape()[1].1
259 }
260
261 unsafe fn get_reference_unchecked(&self, row: Row, column: Column) -> &T {
262 unsafe { self.source.get_reference_unchecked([row, column]) }
263 }
264
265 fn data_layout(&self) -> MDataLayout {
266 let rows_dimension = self.source.view_shape()[0].0;
267 let columns_dimension = self.source.view_shape()[1].0;
268 // Row major and column major are the less generalised versions of
269 // a linear data layout. Since our view shape is always interpreted here as rows
270 // then columns, a row major matrix means the most significant dimension
271 // is the first, and the least significant dimension is the second. Similarly
272 // a column major matrix means the opposite.
273 let data_layout = self.source.data_layout();
274 if data_layout == TDataLayout::Linear([rows_dimension, columns_dimension]) {
275 MDataLayout::RowMajor
276 } else if data_layout == TDataLayout::Linear([columns_dimension, rows_dimension]) {
277 MDataLayout::ColumnMajor
278 } else {
279 match self.source.data_layout() {
280 TDataLayout::NonLinear => MDataLayout::Other,
281 TDataLayout::Other => MDataLayout::Other,
282 // This branch should never happen as no other Linear layouts are valid according
283 // to the docs the source implementation must follow but we need to keep the Rust
284 // compiler happy
285 TDataLayout::Linear([_, _]) => MDataLayout::Other,
286 }
287 }
288 }
289}
290
291// # Safety
292// The contract of TensorMut<T, 2> covers everything the compiler can't check for MatrixMut<T>
293// so if we just delegate to the tensor source and hide the dimension names, the index based API
294// meets every requirement by default.
295unsafe impl<T, S> MatrixMut<T> for MatrixRefTensor<T, S>
296where
297 S: TensorMut<T, 2>,
298{
299 fn try_get_reference_mut(&mut self, row: Row, column: Column) -> Option<&mut T> {
300 self.source.get_reference_mut([row, column])
301 }
302
303 unsafe fn get_reference_unchecked_mut(&mut self, row: Row, column: Column) -> &mut T {
304 unsafe { self.source.get_reference_unchecked_mut([row, column]) }
305 }
306}
307
308// # Safety
309// No interior mutability is implied by TensorRef
310unsafe impl<T, S> NoInteriorMutability for MatrixRefTensor<T, S> where S: TensorRef<T, 2> {}