primitives/types/heap_array/
matrix.rs

1use std::{
2    fmt::{Debug, Display},
3    marker::PhantomData,
4    ops::Mul,
5};
6
7use derive_more::derive::Display;
8use serde::{Deserialize, Serialize};
9use typenum::Prod;
10
11use super::HeapArray;
12use crate::{
13    errors::PrimitiveError,
14    random::{CryptoRngCore, Random},
15    types::Positive,
16};
17
18/// Indicates that the matrix is stored in column-major order (Fortran order).
19#[derive(Display, Default, Clone, Copy)]
20pub struct ColumnMajor;
21/// Indicates that the matrix is stored in row-major order (C order).
22#[derive(Display, Default, Clone, Copy)]
23pub struct RowMajor;
24
25pub type RowMajorHeapMatrix<T, M, N> = HeapMatrix<T, M, N, RowMajor>;
26
27/// A matrix with M rows and N columns on the heap that encodes its shape in the type system.
28///
29/// The matrix is stored as a contiguous memory chunk.
30#[derive(Clone, PartialEq, Eq)]
31pub struct HeapMatrix<T: Sized, M: Positive, N: Positive, O = ColumnMajor> {
32    pub(super) data: Box<[T]>,
33
34    // `fn() -> (M, N)` is used instead of `(M, N)` so `HeapMatrix<T, M, N>` doesn't need `(M, N)`
35    // to implement `Send + Sync` to be `Send + Sync` itself. This would be the case if `(M, N)`
36    // was used directly.
37    #[allow(clippy::type_complexity)]
38    pub(super) _len: PhantomData<fn() -> (M, N, O)>,
39}
40impl<T: Sized, M: Positive, N: Positive, O> HeapMatrix<T, M, N, O> {
41    fn new(data: Box<[T]>) -> Self {
42        Self {
43            data,
44            _len: PhantomData,
45        }
46    }
47
48    /// All matrix elements iterator in order.
49    pub fn flat_iter(&self) -> impl ExactSizeIterator<Item = &T> {
50        self.data.iter()
51    }
52
53    /// All matrix elements mutable iterator in order.
54    pub fn flat_iter_mut(&mut self) -> impl ExactSizeIterator<Item = &mut T> {
55        self.data.iter_mut()
56    }
57
58    /// Convert matrix into an iterator over all elements in order.
59    pub fn into_flat_iter(self) -> impl ExactSizeIterator<Item = T> {
60        self.data.into_vec().into_iter()
61    }
62
63    /// Length of total number of elements in the matrix
64    pub const fn len(&self) -> usize {
65        M::USIZE * N::USIZE
66    }
67
68    /// Check if the matrix is empty
69    pub const fn is_empty(&self) -> bool {
70        self.len() == 0
71    }
72
73    /// Number of rows in the matrix
74    pub const fn rows(&self) -> usize {
75        M::USIZE
76    }
77
78    /// Number of columns in the matrix
79    pub const fn cols(&self) -> usize {
80        N::USIZE
81    }
82}
83
84impl<T: Sized, M: Positive, N: Positive, O> HeapMatrix<T, M, N, O> {
85    pub fn map<F, U>(self, f: F) -> HeapMatrix<U, M, N, O>
86    where
87        F: FnMut(T) -> U,
88    {
89        HeapMatrix::new(
90            self.data
91                .into_vec()
92                .into_iter()
93                .map(f)
94                .collect::<Box<[U]>>(),
95        )
96    }
97}
98
99impl<T: Sized, M: Positive + Mul<N, Output: Positive>, N: Positive, O> HeapMatrix<T, M, N, O> {
100    /// Flatten matrix into an heap array.
101    pub fn into_flat_array(self) -> HeapArray<T, Prod<M, N>> {
102        HeapArray::new(self.data)
103    }
104
105    /// Build matrix from an heap array in column-major order
106    pub fn from_flat_array(value: HeapArray<T, Prod<M, N>>) -> Self {
107        Self::new(value.data)
108    }
109}
110
111// --------------------- Column Major (Fortran-style) ----------------------- //
112
113impl<T: Sized, M: Positive, N: Positive> HeapMatrix<T, M, N, ColumnMajor> {
114    /// Matrix column iterator
115    pub fn col_iter(&self) -> impl ExactSizeIterator<Item = &[T]> {
116        self.data.chunks_exact(M::USIZE)
117    }
118
119    /// Matrix column mutable iterator
120    pub fn col_iter_mut(&mut self) -> impl ExactSizeIterator<Item = &mut [T]> {
121        self.data.chunks_exact_mut(M::USIZE)
122    }
123
124    /// Get a reference to an element at position (row, col)
125    pub fn get(&self, row: usize, col: usize) -> Option<&T> {
126        (row < M::USIZE && col < N::USIZE)
127            .then(|| unsafe { self.data.get_unchecked(col * M::USIZE + row) })
128    }
129
130    /// Get a mutable reference to an element at position (row, col)
131    pub fn get_mut(&mut self, row: usize, col: usize) -> Option<&mut T> {
132        (row < M::USIZE && col < N::USIZE)
133            .then(|| unsafe { self.data.get_unchecked_mut(col * M::USIZE + row) })
134    }
135}
136
137impl<T: Sized, M: Positive + Mul<N, Output: Positive>, N: Positive>
138    HeapMatrix<T, M, N, ColumnMajor>
139{
140    /// Build a matrix from an array of columns
141    pub fn from_cols(val: HeapArray<HeapArray<T, M>, N>) -> Self {
142        Self::try_from(val.into_iter().flatten().collect::<Box<[T]>>()).unwrap()
143    }
144}
145
146// --------------------- Row Major (C-style) ----------------------- //
147
148impl<T: Sized, M: Positive, N: Positive> HeapMatrix<T, M, N, RowMajor> {
149    /// Matrix row iterator
150    pub fn row_iter(&self) -> impl ExactSizeIterator<Item = &[T]> {
151        self.data.chunks_exact(N::USIZE)
152    }
153
154    /// Matrix row mutable iterator
155    pub fn row_iter_mut(&mut self) -> impl ExactSizeIterator<Item = &mut [T]> {
156        self.data.chunks_exact_mut(N::USIZE)
157    }
158
159    /// Get a reference to an element at position (row, col)
160    pub fn get(&self, row: usize, col: usize) -> Option<&T> {
161        (row < M::USIZE && col < N::USIZE)
162            .then(|| unsafe { self.data.get_unchecked(row * N::USIZE + col) })
163    }
164
165    /// Get a mutable reference to an element at position (row, col)
166    pub fn get_mut(&mut self, row: usize, col: usize) -> Option<&mut T> {
167        (row < M::USIZE && col < N::USIZE)
168            .then(|| unsafe { self.data.get_unchecked_mut(row * N::USIZE + col) })
169    }
170}
171
172impl<T: Sized, M: Positive + Mul<N, Output: Positive>, N: Positive>
173    HeapMatrix<T, M, N, ColumnMajor>
174{
175    /// Build a matrix from an array of rows
176    pub fn from_rows(val: HeapArray<HeapArray<T, N>, M>) -> Self {
177        Self::try_from(val.into_iter().flatten().collect::<Box<[T]>>()).unwrap()
178    }
179}
180
181// ------------------------ Common Implementations -------------------------- //
182
183impl<T: Sized + Default, M: Positive, N: Positive, O> Default for HeapMatrix<T, M, N, O> {
184    fn default() -> Self {
185        Self::new(
186            (0..M::USIZE * N::USIZE)
187                .map(|_| T::default())
188                .collect::<Box<[T]>>(),
189        )
190    }
191}
192
193impl<T: Sized + Debug, M: Positive, N: Positive, O: Display + Default> Debug
194    for HeapMatrix<T, M, N, O>
195{
196    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
197        f.debug_struct(format!("Matrix[{}]<{}, {}>", O::default(), M::USIZE, N::USIZE).as_str())
198            .field("data", &self.data)
199            .finish()
200    }
201}
202
203impl<T: Sized + Serialize, M: Positive, N: Positive, O> Serialize for HeapMatrix<T, M, N, O> {
204    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
205        self.data.serialize(serializer)
206    }
207}
208
209impl<T: Random + Sized, M: Positive, N: Positive, O> Random for HeapMatrix<T, M, N, O> {
210    fn random(mut rng: impl CryptoRngCore) -> Self {
211        Self::new(
212            (0..M::USIZE * N::USIZE)
213                .map(|_| T::random(&mut rng))
214                .collect(),
215        )
216    }
217}
218
219impl<'de, T: Sized + Deserialize<'de>, M: Positive, N: Positive, O> Deserialize<'de>
220    for HeapMatrix<T, M, N, O>
221{
222    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
223        let data = Box::<[T]>::deserialize(deserializer)?;
224        if data.len() != M::USIZE * N::USIZE {
225            return Err(serde::de::Error::custom(format!(
226                "Expected matrix of length {}, got {}",
227                M::USIZE * N::USIZE,
228                data.len()
229            )));
230        }
231        Ok(Self::new(data))
232    }
233}
234
235impl<T: Sized, M: Positive, N: Positive, O> AsRef<[T]> for HeapMatrix<T, M, N, O> {
236    fn as_ref(&self) -> &[T] {
237        &self.data
238    }
239}
240
241impl<T: Sized, M: Positive, N: Positive, O> AsMut<[T]> for HeapMatrix<T, M, N, O> {
242    fn as_mut(&mut self) -> &mut [T] {
243        &mut self.data
244    }
245}
246
247impl<T: Sized, M: Positive, N: Positive, O> From<HeapMatrix<T, M, N, O>> for Vec<T> {
248    fn from(matrix: HeapMatrix<T, M, N, O>) -> Self {
249        matrix.data.into_vec()
250    }
251}
252
253impl<T: Sized, M: Positive, N: Positive, O> TryFrom<Vec<T>> for HeapMatrix<T, M, N, O> {
254    type Error = PrimitiveError;
255
256    fn try_from(matrix: Vec<T>) -> Result<Self, Self::Error> {
257        if matrix.len() != M::USIZE * N::USIZE {
258            return Err(PrimitiveError::InvalidSize(
259                M::USIZE * N::USIZE,
260                matrix.len(),
261            ));
262        }
263        Ok(Self::new(matrix.into_boxed_slice()))
264    }
265}
266
267impl<T: Sized, M: Positive, N: Positive, O> From<HeapMatrix<T, M, N, O>> for Box<[T]> {
268    fn from(matrix: HeapMatrix<T, M, N, O>) -> Self {
269        matrix.data
270    }
271}
272
273impl<T: Sized, M: Positive, N: Positive, O> TryFrom<Box<[T]>> for HeapMatrix<T, M, N, O> {
274    type Error = PrimitiveError;
275
276    fn try_from(matrix: Box<[T]>) -> Result<Self, Self::Error> {
277        if matrix.len() != M::USIZE * N::USIZE {
278            return Err(PrimitiveError::InvalidSize(
279                M::USIZE * N::USIZE,
280                matrix.len(),
281            ));
282        }
283        Ok(Self::new(matrix))
284    }
285}
286
/// Unit tests for `HeapMatrix`: flat access, both storage orders, `Default`, `map` and the
/// `Vec` / `Box<[T]>` conversions.
#[cfg(test)]
pub mod tests {
    use itertools::Itertools;
    use typenum::{U2, U3, U4, U6};

    use super::{ColumnMajor, RowMajor};
    use crate::types::{HeapArray, HeapMatrix};

    #[test]
    fn test_flat_operations() {
        let data: Vec<u32> = (0..6).collect();
        let matrix: HeapMatrix<u32, U2, U3> = HeapMatrix::new(data.clone().into_boxed_slice());

        // flat_iter
        assert_eq!(matrix.flat_iter().copied().collect_vec(), data);

        // flat_iter_mut
        let mut matrix3: HeapMatrix<u32, U2, U3> = HeapMatrix::new(data.clone().into_boxed_slice());
        for x in matrix3.flat_iter_mut() {
            *x *= 2;
        }
        assert_eq!(matrix3.get(0, 0), Some(&0));
        assert_eq!(matrix3.get(1, 0), Some(&2));

        // into_flat_array / from_flat_array
        let array: HeapArray<u32, U6> = matrix.into_flat_array();
        assert_eq!(array.as_ref(), data.as_slice());
        let matrix4: HeapMatrix<u32, U2, U3> = HeapMatrix::from_flat_array(array);
        assert_eq!(matrix4.as_ref(), data.as_slice());
    }

    #[test]
    fn test_column_first_operations() {
        // 3x2 matrix in column-major: col0=[0,1,2], col1=[3,4,5]
        let data: Vec<u32> = vec![0, 1, 2, 3, 4, 5];
        let mut matrix: HeapMatrix<u32, U3, U2, ColumnMajor> =
            HeapMatrix::new(data.into_boxed_slice());

        // get - column-major layout
        assert_eq!(matrix.get(0, 0), Some(&0));
        assert_eq!(matrix.get(1, 0), Some(&1));
        assert_eq!(matrix.get(2, 0), Some(&2));
        assert_eq!(matrix.get(0, 1), Some(&3));
        assert_eq!(matrix.get(1, 1), Some(&4));
        assert_eq!(matrix.get(2, 1), Some(&5));
        assert_eq!(matrix.get(3, 0), None);
        assert_eq!(matrix.get(0, 2), None);

        // get_mut
        *matrix.get_mut(1, 1).unwrap() = 42;
        assert_eq!(matrix.get(1, 1), Some(&42));

        // col_iter
        let data2: Vec<u32> = (0..12).collect();
        let matrix2: HeapMatrix<u32, U4, U3, ColumnMajor> =
            HeapMatrix::new(data2.into_boxed_slice());
        let cols: Vec<Vec<u32>> = matrix2.col_iter().map(|col| col.to_vec()).collect();
        assert_eq!(cols.len(), 3);
        assert_eq!(cols[0], vec![0, 1, 2, 3]);
        assert_eq!(cols[1], vec![4, 5, 6, 7]);
        assert_eq!(cols[2], vec![8, 9, 10, 11]);

        // col_iter_mut
        let data3: Vec<u32> = (0..12).collect();
        let mut matrix3: HeapMatrix<u32, U4, U3, ColumnMajor> =
            HeapMatrix::new(data3.into_boxed_slice());
        for col in matrix3.col_iter_mut() {
            col[0] = 99;
        }
        assert_eq!(matrix3.get(0, 0), Some(&99));
        assert_eq!(matrix3.get(0, 1), Some(&99));
        assert_eq!(matrix3.get(0, 2), Some(&99));
    }

    #[test]
    fn test_row_first_operations() {
        // 3x2 matrix in row-major: row0=[0,1], row1=[2,3], row2=[4,5]
        let data: Vec<u32> = vec![0, 1, 2, 3, 4, 5];
        let mut matrix: HeapMatrix<u32, U3, U2, RowMajor> =
            HeapMatrix::new(data.into_boxed_slice());

        // get - row-major layout
        assert_eq!(matrix.get(0, 0), Some(&0));
        assert_eq!(matrix.get(0, 1), Some(&1));
        assert_eq!(matrix.get(1, 0), Some(&2));
        assert_eq!(matrix.get(1, 1), Some(&3));
        assert_eq!(matrix.get(2, 0), Some(&4));
        assert_eq!(matrix.get(2, 1), Some(&5));
        assert_eq!(matrix.get(3, 0), None);
        assert_eq!(matrix.get(0, 2), None);

        // get_mut
        *matrix.get_mut(1, 1).unwrap() = 42;
        assert_eq!(matrix.get(1, 1), Some(&42));

        // row_iter
        let data2: Vec<u32> = (0..12).collect();
        let matrix2: HeapMatrix<u32, U3, U4, RowMajor> = HeapMatrix::new(data2.into_boxed_slice());
        let rows: Vec<Vec<u32>> = matrix2.row_iter().map(|row| row.to_vec()).collect();
        assert_eq!(rows.len(), 3);
        assert_eq!(rows[0], vec![0, 1, 2, 3]);
        assert_eq!(rows[1], vec![4, 5, 6, 7]);
        assert_eq!(rows[2], vec![8, 9, 10, 11]);

        // row_iter_mut
        let data3: Vec<u32> = (0..12).collect();
        let mut matrix3: HeapMatrix<u32, U3, U4, RowMajor> =
            HeapMatrix::new(data3.into_boxed_slice());
        for row in matrix3.row_iter_mut() {
            row[0] = 99;
        }
        assert_eq!(matrix3.get(0, 0), Some(&99));
        assert_eq!(matrix3.get(1, 0), Some(&99));
        assert_eq!(matrix3.get(2, 0), Some(&99));
    }

    #[test]
    fn test_try_from_vec() {
        // Success case
        let data: Vec<u32> = (0..12).collect();
        let result: Result<HeapMatrix<u32, U3, U4>, _> = data.try_into();
        assert!(result.is_ok());

        // Failure case - wrong size
        let data: Vec<u32> = (0..10).collect();
        let result: Result<HeapMatrix<u32, U3, U4>, _> = data.try_into();
        assert!(result.is_err());
    }

    #[test]
    fn test_default_map_and_conversions() {
        // Default fills every slot with `T::default()`.
        let zeros: HeapMatrix<u32, U2, U3> = HeapMatrix::default();
        assert_eq!(zeros.len(), 6);
        assert!(zeros.flat_iter().all(|&x| x == 0));

        // map keeps the shape and the (column-major) storage order.
        let data: Vec<u32> = (0..6).collect();
        let matrix: HeapMatrix<u32, U2, U3> = HeapMatrix::new(data.clone().into_boxed_slice());
        let doubled = matrix.map(|x| x * 2);
        assert_eq!(doubled.flat_iter().copied().collect_vec(), vec![0, 2, 4, 6, 8, 10]);
        assert_eq!(doubled.get(1, 2), Some(&10));

        // Box<[T]> round trip and length validation.
        let boxed: Box<[u32]> = doubled.into();
        assert_eq!(boxed.to_vec(), vec![0, 2, 4, 6, 8, 10]);
        let back: Result<HeapMatrix<u32, U2, U3>, _> = boxed.try_into();
        assert!(back.is_ok());
        let short: Result<HeapMatrix<u32, U2, U3>, _> = vec![1u32, 2].into_boxed_slice().try_into();
        assert!(short.is_err());

        // Vec round trip.
        let matrix: HeapMatrix<u32, U2, U3> = HeapMatrix::new(data.clone().into_boxed_slice());
        let as_vec: Vec<u32> = matrix.into();
        assert_eq!(as_vec, data);
    }
}
415}