primitives/types/heap_array/
matrix.rs1use std::{fmt::Debug, marker::PhantomData};
2
3use itertools::Itertools;
4use serde::{Deserialize, Serialize};
5use typenum::Prod;
6
7use super::HeapArray;
8use crate::{types::Positive, utils::IntoExactSizeIterator};
9
10#[derive(Clone, PartialEq, Eq)]
14pub struct HeapMatrix<T: Sized, M: Positive, N: Positive> {
15 pub(super) data: Box<[T]>,
16
17 pub(super) _len: PhantomData<fn() -> (M, N)>,
21}
22
23impl<T: Sized, M: Positive, N: Positive> HeapMatrix<T, M, N> {
24 fn new(data: Box<[T]>) -> Self {
25 Self {
26 data,
27 _len: PhantomData,
28 }
29 }
30
31 pub fn flat_iter(&self) -> impl ExactSizeIterator<Item = &T> {
33 self.data.iter()
34 }
35
36 pub fn flat_iter_mut(&mut self) -> impl ExactSizeIterator<Item = &mut T> {
38 self.data.iter_mut()
39 }
40
41 pub fn into_flat_iter(self) -> impl ExactSizeIterator<Item = T> {
43 self.data.into_vec().into_iter()
44 }
45
46 pub fn col_iter(&self) -> impl ExactSizeIterator<Item = &[T]> {
48 self.data.chunks_exact(M::USIZE)
49 }
50
51 pub fn col_iter_mut(&mut self) -> impl ExactSizeIterator<Item = &mut [T]> {
53 self.data.chunks_exact_mut(M::USIZE)
54 }
55
56 pub fn from_flat_iter(it: impl IntoExactSizeIterator<Item = T>) -> Self {
58 let it = it.into_iter();
59 assert_eq!(it.len(), M::USIZE * N::USIZE);
60 Self::new(it.collect_vec().into_boxed_slice())
61 }
62}
63
64impl<T: Sized, M: Positive, N: Positive> HeapMatrix<T, M, N>
65where
66 M: std::ops::Mul<N, Output: Positive>,
67{
68 pub fn into_flat_array(self) -> HeapArray<T, Prod<M, N>> {
70 HeapArray {
71 data: self.data,
72 _len: PhantomData,
73 }
74 }
75
76 pub fn from_flat_array(value: HeapArray<T, Prod<M, N>>) -> Self {
78 Self::new(value.data)
79 }
80
81 pub fn from_cols_array(val: HeapArray<HeapArray<T, M>, N>) -> Self {
83 Self::new(val.into_iter().flatten().collect())
84 }
85}
86
87impl<T: Sized + Default, M: Positive, N: Positive> Default for HeapMatrix<T, M, N> {
88 fn default() -> Self {
89 Self::new((0..M::USIZE * N::USIZE).map(|_| T::default()).collect())
90 }
91}
92
93impl<T: Sized + Debug, M: Positive, N: Positive> Debug for HeapMatrix<T, M, N> {
94 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95 f.debug_struct(format!("HeapMatrix<{}, {}>", M::USIZE, N::USIZE).as_str())
96 .field("data", &self.data)
97 .finish()
98 }
99}
100
101impl<T: Sized + Serialize, M: Positive, N: Positive> Serialize for HeapMatrix<T, M, N> {
102 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
103 self.data.serialize(serializer)
104 }
105}
106
107impl<'de, T: Sized + Deserialize<'de>, M: Positive, N: Positive> Deserialize<'de>
108 for HeapMatrix<T, M, N>
109{
110 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
111 let data = Box::<[T]>::deserialize(deserializer)?;
112
113 if data.len() != M::USIZE * N::USIZE {
114 return Err(serde::de::Error::custom(format!(
115 "Expected array of length {}, got {}",
116 M::USIZE,
117 data.len()
118 )));
119 }
120
121 Ok(Self {
122 data,
123 _len: PhantomData,
124 })
125 }
126}
127
#[cfg(test)]
pub mod tests {
    use itertools::Itertools;
    use typenum::{U15, U3, U5};

    use crate::types::{heap_array::matrix::HeapMatrix, HeapArray};

    #[test]
    fn test_default() {
        let mut m = HeapMatrix::<usize, U3, U5>::default();
        assert_eq!(m.col_iter().len(), 5);
        for col in m.col_iter() {
            assert_eq!(col, &[0, 0, 0]);
        }

        {
            // Overwrite the second column through the mutable column iterator.
            let mut it = m.col_iter_mut();
            let _ = it.next();
            let c = it.next().unwrap();
            c[0] = 3;
            c[1] = 2;
            c[2] = 1;
        }

        assert_eq!(
            m.into_flat_iter().collect_vec(),
            vec![0, 0, 0, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
        );
    }

    #[test]
    fn test_into_from_flat_array() {
        let a = HeapArray::<usize, U15>::from_fn(|k| k);
        let m: HeapMatrix<usize, U3, U5> = HeapMatrix::from_flat_iter(a.iter().copied());
        let a1 = m.into_flat_array();
        assert_eq!(a, a1);
    }

    #[test]
    fn test_from_flat_array() {
        let a = HeapArray::<usize, U15>::from_fn(|k| k);
        let m: HeapMatrix<usize, U3, U5> = HeapMatrix::from_flat_array(a);
        let mut it = m.col_iter();
        assert_eq!(it.next().unwrap(), &[0, 1, 2]);
        assert_eq!(it.next().unwrap(), &[3, 4, 5]);
        assert_eq!(it.next().unwrap(), &[6, 7, 8]);
        assert_eq!(it.next().unwrap(), &[9, 10, 11]);
        assert_eq!(it.next().unwrap(), &[12, 13, 14]);
        assert!(it.next().is_none());
    }

    #[test]
    fn test_from_cols_array() {
        // Column j holds [3j, 3j + 1, 3j + 2], so storage order is 0..15.
        let cols = HeapArray::<HeapArray<usize, U3>, U5>::from_fn(|j| {
            HeapArray::from_fn(move |i| 3 * j + i)
        });
        let m: HeapMatrix<usize, U3, U5> = HeapMatrix::from_cols_array(cols);
        let expected: HeapMatrix<usize, U3, U5> =
            HeapMatrix::from_flat_array(HeapArray::<usize, U15>::from_fn(|k| k));
        assert_eq!(m, expected);
    }

    #[test]
    fn test_flat_iter_mut() {
        let mut m = HeapMatrix::<usize, U3, U5>::default();
        assert_eq!(m.flat_iter_mut().len(), 15);
        for (k, x) in m.flat_iter_mut().enumerate() {
            *x = k;
        }
        assert_eq!(m.flat_iter().copied().collect_vec(), (0..15).collect_vec());
        assert_eq!(m.into_flat_array(), HeapArray::<usize, U15>::from_fn(|k| k));
    }

    #[test]
    fn test_debug_format() {
        let m = HeapMatrix::<usize, U3, U5>::default();
        let s = format!("{m:?}");
        assert!(s.starts_with("HeapMatrix<3, 5>"));
    }
}