1use std::{
2 fmt::{Debug, Display},
3 marker::PhantomData,
4 ops::Mul,
5};
6
7use derive_more::derive::Display;
8use serde::{Deserialize, Serialize};
9use typenum::Prod;
10
11use super::HeapArray;
12use crate::{
13 errors::PrimitiveError,
14 random::{CryptoRngCore, Random},
15 types::Positive,
16};
17
18#[derive(Display, Default, Clone, Copy)]
20pub struct ColumnMajor;
21#[derive(Display, Default, Clone, Copy)]
23pub struct RowMajor;
24
25pub type RowMajorHeapMatrix<T, M, N> = HeapMatrix<T, M, N, RowMajor>;
26
27#[derive(Clone, PartialEq, Eq)]
31pub struct HeapMatrix<T: Sized, M: Positive, N: Positive, O = ColumnMajor> {
32 pub(super) data: Box<[T]>,
33
34 #[allow(clippy::type_complexity)]
38 pub(super) _len: PhantomData<fn() -> (M, N, O)>,
39}
40impl<T: Sized, M: Positive, N: Positive, O> HeapMatrix<T, M, N, O> {
41 fn new(data: Box<[T]>) -> Self {
42 Self {
43 data,
44 _len: PhantomData,
45 }
46 }
47
48 pub fn flat_iter(&self) -> impl ExactSizeIterator<Item = &T> {
50 self.data.iter()
51 }
52
53 pub fn flat_iter_mut(&mut self) -> impl ExactSizeIterator<Item = &mut T> {
55 self.data.iter_mut()
56 }
57
58 pub fn into_flat_iter(self) -> impl ExactSizeIterator<Item = T> {
60 self.data.into_vec().into_iter()
61 }
62
63 pub const fn len(&self) -> usize {
65 M::USIZE * N::USIZE
66 }
67
68 pub const fn is_empty(&self) -> bool {
70 self.len() == 0
71 }
72
73 pub const fn rows(&self) -> usize {
75 M::USIZE
76 }
77
78 pub const fn cols(&self) -> usize {
80 N::USIZE
81 }
82}
83
84impl<T: Sized, M: Positive, N: Positive, O> HeapMatrix<T, M, N, O> {
85 pub fn map<F, U>(self, f: F) -> HeapMatrix<U, M, N, O>
86 where
87 F: FnMut(T) -> U,
88 {
89 HeapMatrix::new(
90 self.data
91 .into_vec()
92 .into_iter()
93 .map(f)
94 .collect::<Box<[U]>>(),
95 )
96 }
97}
98
99impl<T: Sized, M: Positive + Mul<N, Output: Positive>, N: Positive, O> HeapMatrix<T, M, N, O> {
100 pub fn into_flat_array(self) -> HeapArray<T, Prod<M, N>> {
102 HeapArray::new(self.data)
103 }
104
105 pub fn from_flat_array(value: HeapArray<T, Prod<M, N>>) -> Self {
107 Self::new(value.data)
108 }
109}
110
111impl<T: Sized, M: Positive, N: Positive> HeapMatrix<T, M, N, ColumnMajor> {
114 pub fn col_iter(&self) -> impl ExactSizeIterator<Item = &[T]> {
116 self.data.chunks_exact(M::USIZE)
117 }
118
119 pub fn col_iter_mut(&mut self) -> impl ExactSizeIterator<Item = &mut [T]> {
121 self.data.chunks_exact_mut(M::USIZE)
122 }
123
124 pub fn get(&self, row: usize, col: usize) -> Option<&T> {
126 (row < M::USIZE && col < N::USIZE)
127 .then(|| unsafe { self.data.get_unchecked(col * M::USIZE + row) })
128 }
129
130 pub fn get_mut(&mut self, row: usize, col: usize) -> Option<&mut T> {
132 (row < M::USIZE && col < N::USIZE)
133 .then(|| unsafe { self.data.get_unchecked_mut(col * M::USIZE + row) })
134 }
135}
136
137impl<T: Sized, M: Positive + Mul<N, Output: Positive>, N: Positive>
138 HeapMatrix<T, M, N, ColumnMajor>
139{
140 pub fn from_cols(val: HeapArray<HeapArray<T, M>, N>) -> Self {
142 Self::try_from(val.into_iter().flatten().collect::<Box<[T]>>()).unwrap()
143 }
144}
145
146impl<T: Sized, M: Positive, N: Positive> HeapMatrix<T, M, N, RowMajor> {
149 pub fn row_iter(&self) -> impl ExactSizeIterator<Item = &[T]> {
151 self.data.chunks_exact(N::USIZE)
152 }
153
154 pub fn row_iter_mut(&mut self) -> impl ExactSizeIterator<Item = &mut [T]> {
156 self.data.chunks_exact_mut(N::USIZE)
157 }
158
159 pub fn get(&self, row: usize, col: usize) -> Option<&T> {
161 (row < M::USIZE && col < N::USIZE)
162 .then(|| unsafe { self.data.get_unchecked(row * N::USIZE + col) })
163 }
164
165 pub fn get_mut(&mut self, row: usize, col: usize) -> Option<&mut T> {
167 (row < M::USIZE && col < N::USIZE)
168 .then(|| unsafe { self.data.get_unchecked_mut(row * N::USIZE + col) })
169 }
170}
171
172impl<T: Sized, M: Positive + Mul<N, Output: Positive>, N: Positive>
173 HeapMatrix<T, M, N, ColumnMajor>
174{
175 pub fn from_rows(val: HeapArray<HeapArray<T, N>, M>) -> Self {
177 Self::try_from(val.into_iter().flatten().collect::<Box<[T]>>()).unwrap()
178 }
179}
180
181impl<T: Sized + Default, M: Positive, N: Positive, O> Default for HeapMatrix<T, M, N, O> {
184 fn default() -> Self {
185 Self::new(
186 (0..M::USIZE * N::USIZE)
187 .map(|_| T::default())
188 .collect::<Box<[T]>>(),
189 )
190 }
191}
192
193impl<T: Sized + Debug, M: Positive, N: Positive, O: Display + Default> Debug
194 for HeapMatrix<T, M, N, O>
195{
196 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
197 f.debug_struct(format!("Matrix[{}]<{}, {}>", O::default(), M::USIZE, N::USIZE).as_str())
198 .field("data", &self.data)
199 .finish()
200 }
201}
202
203impl<T: Sized + Serialize, M: Positive, N: Positive, O> Serialize for HeapMatrix<T, M, N, O> {
204 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
205 self.data.serialize(serializer)
206 }
207}
208
209impl<T: Random + Sized, M: Positive, N: Positive, O> Random for HeapMatrix<T, M, N, O> {
210 fn random(mut rng: impl CryptoRngCore) -> Self {
211 Self::new(
212 (0..M::USIZE * N::USIZE)
213 .map(|_| T::random(&mut rng))
214 .collect(),
215 )
216 }
217}
218
219impl<'de, T: Sized + Deserialize<'de>, M: Positive, N: Positive, O> Deserialize<'de>
220 for HeapMatrix<T, M, N, O>
221{
222 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
223 let data = Box::<[T]>::deserialize(deserializer)?;
224 if data.len() != M::USIZE * N::USIZE {
225 return Err(serde::de::Error::custom(format!(
226 "Expected matrix of length {}, got {}",
227 M::USIZE * N::USIZE,
228 data.len()
229 )));
230 }
231 Ok(Self::new(data))
232 }
233}
234
235impl<T: Sized, M: Positive, N: Positive, O> AsRef<[T]> for HeapMatrix<T, M, N, O> {
236 fn as_ref(&self) -> &[T] {
237 &self.data
238 }
239}
240
241impl<T: Sized, M: Positive, N: Positive, O> AsMut<[T]> for HeapMatrix<T, M, N, O> {
242 fn as_mut(&mut self) -> &mut [T] {
243 &mut self.data
244 }
245}
246
247impl<T: Sized, M: Positive, N: Positive, O> From<HeapMatrix<T, M, N, O>> for Vec<T> {
248 fn from(matrix: HeapMatrix<T, M, N, O>) -> Self {
249 matrix.data.into_vec()
250 }
251}
252
253impl<T: Sized, M: Positive, N: Positive, O> TryFrom<Vec<T>> for HeapMatrix<T, M, N, O> {
254 type Error = PrimitiveError;
255
256 fn try_from(matrix: Vec<T>) -> Result<Self, Self::Error> {
257 if matrix.len() != M::USIZE * N::USIZE {
258 return Err(PrimitiveError::InvalidSize(
259 M::USIZE * N::USIZE,
260 matrix.len(),
261 ));
262 }
263 Ok(Self::new(matrix.into_boxed_slice()))
264 }
265}
266
267impl<T: Sized, M: Positive, N: Positive, O> From<HeapMatrix<T, M, N, O>> for Box<[T]> {
268 fn from(matrix: HeapMatrix<T, M, N, O>) -> Self {
269 matrix.data
270 }
271}
272
273impl<T: Sized, M: Positive, N: Positive, O> TryFrom<Box<[T]>> for HeapMatrix<T, M, N, O> {
274 type Error = PrimitiveError;
275
276 fn try_from(matrix: Box<[T]>) -> Result<Self, Self::Error> {
277 if matrix.len() != M::USIZE * N::USIZE {
278 return Err(PrimitiveError::InvalidSize(
279 M::USIZE * N::USIZE,
280 matrix.len(),
281 ));
282 }
283 Ok(Self::new(matrix))
284 }
285}
286
#[cfg(test)]
pub mod tests {
    use itertools::Itertools;
    use typenum::{U2, U3, U4, U6};

    use super::{ColumnMajor, RowMajor};
    use crate::types::{HeapArray, HeapMatrix};

    /// Flat iteration, in-place mutation and the flat-array round trip.
    #[test]
    fn test_flat_operations() {
        let values: Vec<u32> = (0..6).collect();

        let original: HeapMatrix<u32, U2, U3> =
            HeapMatrix::new(values.clone().into_boxed_slice());
        assert_eq!(original.flat_iter().copied().collect_vec(), values);

        let mut doubled: HeapMatrix<u32, U2, U3> =
            HeapMatrix::new(values.clone().into_boxed_slice());
        doubled.flat_iter_mut().for_each(|x| *x *= 2);
        assert_eq!(doubled.get(0, 0), Some(&0));
        assert_eq!(doubled.get(1, 0), Some(&2));

        let flat: HeapArray<u32, U6> = original.into_flat_array();
        assert_eq!(flat.as_ref(), values.as_slice());
        let rebuilt: HeapMatrix<u32, U2, U3> = HeapMatrix::from_flat_array(flat);
        assert_eq!(rebuilt.as_ref(), values.as_slice());
    }

    /// Indexing and column iteration for column-major storage.
    #[test]
    fn test_column_first_operations() {
        let mut matrix: HeapMatrix<u32, U3, U2, ColumnMajor> =
            HeapMatrix::new(vec![0, 1, 2, 3, 4, 5].into_boxed_slice());

        // Column-major with 3 rows: element (row, col) sits at col * 3 + row.
        let expected = [(0, 0, 0), (1, 0, 1), (2, 0, 2), (0, 1, 3), (1, 1, 4), (2, 1, 5)];
        for (row, col, value) in expected {
            assert_eq!(matrix.get(row, col), Some(&value));
        }
        assert_eq!(matrix.get(3, 0), None);
        assert_eq!(matrix.get(0, 2), None);

        *matrix.get_mut(1, 1).unwrap() = 42;
        assert_eq!(matrix.get(1, 1), Some(&42));

        let by_cols: HeapMatrix<u32, U4, U3, ColumnMajor> = HeapMatrix::new((0..12).collect());
        let cols = by_cols.col_iter().map(<[u32]>::to_vec).collect_vec();
        assert_eq!(
            cols,
            vec![vec![0, 1, 2, 3], vec![4, 5, 6, 7], vec![8, 9, 10, 11]]
        );

        let mut stamped: HeapMatrix<u32, U4, U3, ColumnMajor> =
            HeapMatrix::new((0..12).collect());
        stamped.col_iter_mut().for_each(|col| col[0] = 99);
        for col in 0..3 {
            assert_eq!(stamped.get(0, col), Some(&99));
        }
    }

    /// Indexing and row iteration for row-major storage.
    #[test]
    fn test_row_first_operations() {
        let mut matrix: HeapMatrix<u32, U3, U2, RowMajor> =
            HeapMatrix::new(vec![0, 1, 2, 3, 4, 5].into_boxed_slice());

        // Row-major with 2 columns: element (row, col) sits at row * 2 + col.
        let expected = [(0, 0, 0), (0, 1, 1), (1, 0, 2), (1, 1, 3), (2, 0, 4), (2, 1, 5)];
        for (row, col, value) in expected {
            assert_eq!(matrix.get(row, col), Some(&value));
        }
        assert_eq!(matrix.get(3, 0), None);
        assert_eq!(matrix.get(0, 2), None);

        *matrix.get_mut(1, 1).unwrap() = 42;
        assert_eq!(matrix.get(1, 1), Some(&42));

        let by_rows: HeapMatrix<u32, U3, U4, RowMajor> = HeapMatrix::new((0..12).collect());
        let rows = by_rows.row_iter().map(<[u32]>::to_vec).collect_vec();
        assert_eq!(
            rows,
            vec![vec![0, 1, 2, 3], vec![4, 5, 6, 7], vec![8, 9, 10, 11]]
        );

        let mut stamped: HeapMatrix<u32, U3, U4, RowMajor> = HeapMatrix::new((0..12).collect());
        stamped.row_iter_mut().for_each(|row| row[0] = 99);
        for row in 0..3 {
            assert_eq!(stamped.get(row, 0), Some(&99));
        }
    }

    /// `TryFrom<Vec<_>>` accepts exactly `M * N` elements and rejects anything else.
    #[test]
    fn test_try_from_vec() {
        let exact: Vec<u32> = (0..12).collect();
        let accepted: Result<HeapMatrix<u32, U3, U4>, _> = exact.try_into();
        assert!(accepted.is_ok());

        let short: Vec<u32> = (0..10).collect();
        let rejected: Result<HeapMatrix<u32, U3, U4>, _> = short.try_into();
        assert!(rejected.is_err());
    }
}