primitives/types/heap_array/matrix.rs

1use std::{fmt::Debug, marker::PhantomData};
2
3use itertools::Itertools;
4use serde::{Deserialize, Serialize};
5use typenum::Prod;
6
7use super::HeapArray;
8use crate::{types::Positive, utils::IntoExactSizeIterator};
9
10/// A matrix with M rows and N columns on the heap that encodes its shape in the type system.
11///
12/// The matrix is stored as a contiguous memory chunk in column-major order.
13#[derive(Clone, PartialEq, Eq)]
14pub struct HeapMatrix<T: Sized, M: Positive, N: Positive> {
15    pub(super) data: Box<[T]>,
16
17    // `fn() -> (M, N)` is used instead of `(M, N)` so `HeapMatrix<T, M, N>` doesn't need `(M, N)`
18    // to implement `Send + Sync` to be `Send + Sync` itself. This would be the case if `(M, N)`
19    // was used directly.
20    pub(super) _len: PhantomData<fn() -> (M, N)>,
21}
22
23impl<T: Sized, M: Positive, N: Positive> HeapMatrix<T, M, N> {
24    fn new(data: Box<[T]>) -> Self {
25        Self {
26            data,
27            _len: PhantomData,
28        }
29    }
30
31    /// All matrix elements iterator in column-major order
32    pub fn flat_iter(&self) -> impl ExactSizeIterator<Item = &T> {
33        self.data.iter()
34    }
35
36    /// All matrix elements mutable iterator in column-major order
37    pub fn flat_iter_mut(&mut self) -> impl ExactSizeIterator<Item = &mut T> {
38        self.data.iter_mut()
39    }
40
41    /// Transform all matrix elements into owned-values iterator in column-major order
42    pub fn into_flat_iter(self) -> impl ExactSizeIterator<Item = T> {
43        self.data.into_vec().into_iter()
44    }
45
46    /// Matrix column iterator
47    pub fn col_iter(&self) -> impl ExactSizeIterator<Item = &[T]> {
48        self.data.chunks_exact(M::USIZE)
49    }
50
51    /// Matrix column mutable iterator
52    pub fn col_iter_mut(&mut self) -> impl ExactSizeIterator<Item = &mut [T]> {
53        self.data.chunks_exact_mut(M::USIZE)
54    }
55
56    /// Build a matrix from a elements in column-major order
57    pub fn from_flat_iter(it: impl IntoExactSizeIterator<Item = T>) -> Self {
58        let it = it.into_iter();
59        assert_eq!(it.len(), M::USIZE * N::USIZE);
60        Self::new(it.collect_vec().into_boxed_slice())
61    }
62}
63
64impl<T: Sized, M: Positive, N: Positive> HeapMatrix<T, M, N>
65where
66    M: std::ops::Mul<N, Output: Positive>,
67{
68    /// Flatten matrix into an heap array in column-major order
69    pub fn into_flat_array(self) -> HeapArray<T, Prod<M, N>> {
70        HeapArray {
71            data: self.data,
72            _len: PhantomData,
73        }
74    }
75
76    /// Build matrix from an heap array in column-major order
77    pub fn from_flat_array(value: HeapArray<T, Prod<M, N>>) -> Self {
78        Self::new(value.data)
79    }
80
81    /// Build a matrix from an array of columns
82    pub fn from_cols_array(val: HeapArray<HeapArray<T, M>, N>) -> Self {
83        Self::new(val.into_iter().flatten().collect())
84    }
85}
86
87impl<T: Sized + Default, M: Positive, N: Positive> Default for HeapMatrix<T, M, N> {
88    fn default() -> Self {
89        Self::new((0..M::USIZE * N::USIZE).map(|_| T::default()).collect())
90    }
91}
92
93impl<T: Sized + Debug, M: Positive, N: Positive> Debug for HeapMatrix<T, M, N> {
94    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95        f.debug_struct(format!("HeapMatrix<{}, {}>", M::USIZE, N::USIZE).as_str())
96            .field("data", &self.data)
97            .finish()
98    }
99}
100
101impl<T: Sized + Serialize, M: Positive, N: Positive> Serialize for HeapMatrix<T, M, N> {
102    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
103        self.data.serialize(serializer)
104    }
105}
106
107impl<'de, T: Sized + Deserialize<'de>, M: Positive, N: Positive> Deserialize<'de>
108    for HeapMatrix<T, M, N>
109{
110    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
111        let data = Box::<[T]>::deserialize(deserializer)?;
112
113        if data.len() != M::USIZE * N::USIZE {
114            return Err(serde::de::Error::custom(format!(
115                "Expected array of length {}, got {}",
116                M::USIZE,
117                data.len()
118            )));
119        }
120
121        Ok(Self {
122            data,
123            _len: PhantomData,
124        })
125    }
126}
127
#[cfg(test)]
pub mod tests {
    use itertools::Itertools;
    use typenum::{U15, U3, U5};

    use crate::types::{heap_array::matrix::HeapMatrix, HeapArray};

    #[test]
    fn test_default() {
        let mut m = HeapMatrix::<usize, U3, U5>::default();
        assert_eq!(m.col_iter().len(), 5);
        for col in m.col_iter() {
            assert_eq!(col, &[0, 0, 0]);
        }

        // Change second column
        {
            let mut it = m.col_iter_mut();
            let _ = it.next();
            let c = it.next().unwrap();
            c[0] = 3;
            c[1] = 2;
            c[2] = 1;
        }

        assert_eq!(
            m.into_flat_iter().collect_vec(),
            vec![0, 0, 0, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
        );
    }

    #[test]
    fn test_into_from_flat_array() {
        let a = HeapArray::<usize, U15>::from_fn(|k| k);
        let m: HeapMatrix<usize, U3, U5> = HeapMatrix::from_flat_iter(a.iter().copied());
        let a1 = m.into_flat_array();
        assert_eq!(a, a1);
    }

    #[test]
    fn test_from_flat_array() {
        let a = HeapArray::<usize, U15>::from_fn(|k| k);
        let m: HeapMatrix<usize, U3, U5> = HeapMatrix::from_flat_array(a);
        let mut it = m.col_iter();
        assert_eq!(it.next().unwrap(), &[0, 1, 2]);
        assert_eq!(it.next().unwrap(), &[3, 4, 5]);
        assert_eq!(it.next().unwrap(), &[6, 7, 8]);
        assert_eq!(it.next().unwrap(), &[9, 10, 11]);
        assert_eq!(it.next().unwrap(), &[12, 13, 14]);
    }

    #[test]
    fn test_from_cols_array() {
        // Column `c` holds the values 3c, 3c + 1, 3c + 2.
        let cols = HeapArray::<HeapArray<usize, U3>, U5>::from_fn(|c| {
            HeapArray::from_fn(|r| 3 * c + r)
        });
        let m = HeapMatrix::<usize, U3, U5>::from_cols_array(cols);
        assert_eq!(m.into_flat_iter().collect_vec(), (0..15usize).collect_vec());
    }

    #[test]
    #[should_panic]
    fn test_from_flat_iter_wrong_length() {
        let _ = HeapMatrix::<usize, U3, U5>::from_flat_iter(vec![0usize; 14]);
    }
}