1use std::{
2 fmt::{Debug, Display},
3 marker::PhantomData,
4 mem::{ManuallyDrop, MaybeUninit},
5 ops::Mul,
6};
7
8use derive_more::derive::Display;
9use serde::{Deserialize, Serialize};
10use typenum::Prod;
11use wincode::{
12 io::{Reader, Writer},
13 ReadResult,
14 SchemaRead,
15 SchemaWrite,
16 TypeMeta,
17 WriteResult,
18};
19
20use super::HeapArray;
21use crate::{
22 errors::PrimitiveError,
23 random::{CryptoRngCore, Random},
24 types::{heap_array::array::SliceDropGuard, Positive},
25};
26
27#[derive(Display, Default, Clone, Copy, PartialEq, Eq)]
29pub struct ColumnMajor;
30#[derive(Display, Default, Clone, Copy, PartialEq, Eq)]
32pub struct RowMajor;
33
34pub type RowMajorHeapMatrix<T, M, N> = HeapMatrix<T, M, N, RowMajor>;
35
/// A heap-allocated `M x N` matrix of `T` whose dimensions are type-level positive
/// integers (`M` rows, `N` columns) and whose storage order is selected by the marker
/// type `O` ([`ColumnMajor`] unless specified).
#[derive(Clone, PartialEq, Eq)]
#[repr(C)]
pub struct HeapMatrix<T: Sized, M: Positive, N: Positive, O = ColumnMajor> {
    /// Flat element storage, ordered according to `O`.
    ///
    /// The private `new` constructor does not check the length itself. Its callers in this
    /// module either check that there are exactly `M * N` elements or take the buffer from
    /// a source sized `M * N` at the type level. The unchecked indexing in `get`/`get_mut`
    /// relies on this length.
    pub(super) data: Box<[T]>,

    // Zero-sized marker that ties `M`, `N` and `O` to the type without storing values of
    // them; the `fn() -> _` form means the auto traits (Send/Sync) of the matrix do not
    // depend on those marker parameters.
    #[allow(clippy::type_complexity)]
    pub(super) _len: PhantomData<fn() -> (M, N, O)>,
}
50impl<T: Sized, M: Positive, N: Positive, O> HeapMatrix<T, M, N, O> {
51 fn new(data: Box<[T]>) -> Self {
52 Self {
53 data,
54 _len: PhantomData,
55 }
56 }
57
58 pub fn flat_iter(&self) -> impl ExactSizeIterator<Item = &T> {
60 self.data.iter()
61 }
62
63 pub fn flat_iter_mut(&mut self) -> impl ExactSizeIterator<Item = &mut T> {
65 self.data.iter_mut()
66 }
67
68 pub fn into_flat_iter(self) -> impl ExactSizeIterator<Item = T> {
70 self.data.into_vec().into_iter()
71 }
72
73 pub const fn len(&self) -> usize {
75 M::USIZE * N::USIZE
76 }
77
78 pub const fn is_empty(&self) -> bool {
80 self.len() == 0
81 }
82
83 pub const fn rows(&self) -> usize {
85 M::USIZE
86 }
87
88 pub const fn cols(&self) -> usize {
90 N::USIZE
91 }
92}
93
94impl<T: Sized, M: Positive, N: Positive, O> HeapMatrix<T, M, N, O> {
95 pub fn map<F, U>(self, f: F) -> HeapMatrix<U, M, N, O>
96 where
97 F: FnMut(T) -> U,
98 {
99 HeapMatrix::new(
100 self.data
101 .into_vec()
102 .into_iter()
103 .map(f)
104 .collect::<Box<[U]>>(),
105 )
106 }
107}
108
109impl<T: Sized, M: Positive + Mul<N, Output: Positive>, N: Positive, O> HeapMatrix<T, M, N, O> {
110 pub fn into_flat_array(self) -> HeapArray<T, Prod<M, N>> {
112 HeapArray::new(self.data)
113 }
114
115 pub fn from_flat_array(value: HeapArray<T, Prod<M, N>>) -> Self {
117 Self::new(value.data)
118 }
119}
120
121impl<T: Sized, M: Positive, N: Positive> HeapMatrix<T, M, N, ColumnMajor> {
124 pub fn col_iter(&self) -> impl ExactSizeIterator<Item = &[T]> {
126 self.data.chunks_exact(M::USIZE)
127 }
128
129 pub fn col_iter_mut(&mut self) -> impl ExactSizeIterator<Item = &mut [T]> {
131 self.data.chunks_exact_mut(M::USIZE)
132 }
133
134 pub fn get(&self, row: usize, col: usize) -> Option<&T> {
136 (row < M::USIZE && col < N::USIZE)
137 .then(|| unsafe { self.data.get_unchecked(col * M::USIZE + row) })
138 }
139
140 pub fn get_mut(&mut self, row: usize, col: usize) -> Option<&mut T> {
142 (row < M::USIZE && col < N::USIZE)
143 .then(|| unsafe { self.data.get_unchecked_mut(col * M::USIZE + row) })
144 }
145}
146
147impl<T: Sized, M: Positive + Mul<N, Output: Positive>, N: Positive>
148 HeapMatrix<T, M, N, ColumnMajor>
149{
150 pub fn from_cols(val: HeapArray<HeapArray<T, M>, N>) -> Self {
152 Self::try_from(val.into_iter().flatten().collect::<Box<[T]>>()).unwrap()
153 }
154}
155
156impl<T: Sized, M: Positive, N: Positive> HeapMatrix<T, M, N, RowMajor> {
159 pub fn row_iter(&self) -> impl ExactSizeIterator<Item = &[T]> {
161 self.data.chunks_exact(N::USIZE)
162 }
163
164 pub fn row_iter_mut(&mut self) -> impl ExactSizeIterator<Item = &mut [T]> {
166 self.data.chunks_exact_mut(N::USIZE)
167 }
168
169 pub fn get(&self, row: usize, col: usize) -> Option<&T> {
171 (row < M::USIZE && col < N::USIZE)
172 .then(|| unsafe { self.data.get_unchecked(row * N::USIZE + col) })
173 }
174
175 pub fn get_mut(&mut self, row: usize, col: usize) -> Option<&mut T> {
177 (row < M::USIZE && col < N::USIZE)
178 .then(|| unsafe { self.data.get_unchecked_mut(row * N::USIZE + col) })
179 }
180}
181
182impl<T: Sized, M: Positive + Mul<N, Output: Positive>, N: Positive>
183 HeapMatrix<T, M, N, ColumnMajor>
184{
185 pub fn from_rows(val: HeapArray<HeapArray<T, N>, M>) -> Self {
187 Self::try_from(val.into_iter().flatten().collect::<Box<[T]>>()).unwrap()
188 }
189}
190
191impl<T: Sized + Default, M: Positive, N: Positive, O> Default for HeapMatrix<T, M, N, O> {
194 fn default() -> Self {
195 Self::new(
196 (0..M::USIZE * N::USIZE)
197 .map(|_| T::default())
198 .collect::<Box<[T]>>(),
199 )
200 }
201}
202
203impl<T: Sized + Debug, M: Positive, N: Positive, O: Display + Default> Debug
204 for HeapMatrix<T, M, N, O>
205{
206 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
207 f.debug_struct(format!("Matrix[{}]<{}, {}>", O::default(), M::USIZE, N::USIZE).as_str())
208 .field("data", &self.data)
209 .finish()
210 }
211}
212
213impl<T: Sized + Serialize, M: Positive, N: Positive, O> Serialize for HeapMatrix<T, M, N, O> {
214 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
215 self.data.serialize(serializer)
216 }
217}
218
219impl<T: Random + Sized, M: Positive, N: Positive, O> Random for HeapMatrix<T, M, N, O> {
220 fn random(mut rng: impl CryptoRngCore) -> Self {
221 Self::new(
222 (0..M::USIZE * N::USIZE)
223 .map(|_| T::random(&mut rng))
224 .collect(),
225 )
226 }
227}
228
229impl<'de, T: Sized + Deserialize<'de>, M: Positive, N: Positive, O> Deserialize<'de>
230 for HeapMatrix<T, M, N, O>
231{
232 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
233 let data = Box::<[T]>::deserialize(deserializer)?;
234 if data.len() != M::USIZE * N::USIZE {
235 return Err(serde::de::Error::custom(format!(
236 "Expected matrix of length {}, got {}",
237 M::USIZE * N::USIZE,
238 data.len()
239 )));
240 }
241 Ok(Self::new(data))
242 }
243}
244
245impl<T: Sized + SchemaWrite<Src = T>, M: Positive, N: Positive, O> SchemaWrite
246 for HeapMatrix<T, M, N, O>
247{
248 type Src = HeapMatrix<T::Src, M, N, O>;
249
250 const TYPE_META: wincode::TypeMeta = match <T as SchemaWrite>::TYPE_META {
251 TypeMeta::Static { size, zero_copy } => TypeMeta::Static {
252 size: size * M::USIZE * N::USIZE,
253 zero_copy,
254 },
255 TypeMeta::Dynamic => TypeMeta::Dynamic,
256 };
257
258 #[inline]
259 fn size_of(src: &Self::Src) -> WriteResult<usize> {
260 if let TypeMeta::Static { size, .. } = <Self as SchemaWrite>::TYPE_META {
261 return Ok(size);
262 }
263
264 src.data
266 .iter()
267 .map(T::size_of)
268 .try_fold(0usize, |acc, x| x.map(|x| acc + x))
269 }
270
271 #[inline]
272 fn write(writer: &mut impl Writer, src: &Self::Src) -> WriteResult<()> {
273 if let TypeMeta::Static {
274 size,
275 zero_copy: true,
276 } = <Self as SchemaWrite>::TYPE_META
277 {
278 let writer = &mut unsafe { writer.as_trusted_for(size) }?;
281 unsafe { writer.write_slice_t(&src.data)? };
284 writer.finish()?;
285 } else if let TypeMeta::Static { size, .. } = <Self as SchemaWrite>::TYPE_META {
286 #[allow(clippy::arithmetic_side_effects)]
287 let mut writer = unsafe { writer.as_trusted_for(size) }?;
291 for item in src.data.iter() {
292 T::write(&mut writer, item)?;
293 }
294 writer.finish()?;
295 } else {
296 for item in src.data.iter() {
297 T::write(writer, item)?;
298 }
299 }
300
301 Ok(())
302 }
303}
304
/// Decodes a matrix written by the matching `SchemaWrite` impl: exactly `M * N` elements
/// in storage order, with no length prefix.
impl<'de, T: Sized + SchemaRead<'de, Dst = T>, M: Positive, N: Positive, O> SchemaRead<'de>
    for HeapMatrix<T, M, N, O>
{
    type Dst = HeapMatrix<T::Dst, M, N, O>;

    // Mirrors the write side: static (and possibly zero-copy) exactly when `T` is.
    const TYPE_META: TypeMeta = const {
        match T::TYPE_META {
            TypeMeta::Static { size, zero_copy } => TypeMeta::Static {
                size: size * M::USIZE * N::USIZE,
                zero_copy,
            },
            TypeMeta::Dynamic => TypeMeta::Dynamic,
        }
    };

    /// Reads `M * N` elements into a freshly allocated buffer and writes the finished
    /// matrix into `dst`.
    ///
    /// On error or panic mid-read, the drop guards below release the allocation. The
    /// per-element guard also drops the elements already read.
    #[inline]
    fn read(reader: &mut impl Reader<'de>, dst: &mut MaybeUninit<Self::Dst>) -> ReadResult<()> {
        // Frees the uninitialized allocation without touching its contents. This is used
        // for the bulk zero-copy path, where there are no partially-initialized elements
        // to drop. (`MaybeUninit` has no drop glue, so freeing the box drops nothing.)
        struct DropGuardRawCopy<T>(*mut [MaybeUninit<T>]);
        impl<T> Drop for DropGuardRawCopy<T> {
            #[inline]
            fn drop(&mut self) {
                let container = unsafe { Box::from_raw(self.0) };
                drop(container);
            }
        }
        // Guard for element-by-element reads. `inner` is advanced with `inc_len` after
        // each successful element read; `fat` is the whole allocation.
        // NOTE(review): `SliceDropGuard` presumably drops the first `len` initialized
        // elements starting at `raw` (confirm in heap_array::array).
        struct DropGuardElemCopy<T> {
            inner: ManuallyDrop<SliceDropGuard<T>>,
            fat: *mut [MaybeUninit<T>],
        }
        impl<T> DropGuardElemCopy<T> {
            #[inline(always)]
            fn new(fat: *mut [MaybeUninit<T>], raw: *mut MaybeUninit<T>) -> Self {
                Self {
                    inner: ManuallyDrop::new(SliceDropGuard::new(raw)),
                    fat,
                }
            }
        }
        impl<T> Drop for DropGuardElemCopy<T> {
            #[inline]
            fn drop(&mut self) {
                // Order matters: drop the initialized elements first, then free the
                // backing allocation they live in.
                unsafe {
                    ManuallyDrop::drop(&mut self.inner);
                }
                let container = unsafe { Box::from_raw(self.fat) };
                drop(container);
            }
        }
        // Allocate storage for exactly M * N elements and take ownership as a raw fat
        // pointer, so the guards (not the Box) are responsible for cleanup until the end.
        let mem = Box::<[T::Dst]>::new_uninit_slice(M::USIZE * N::USIZE);
        let fat = Box::into_raw(mem);
        match T::TYPE_META {
            // Zero-copy: the encoded bytes are the in-memory representation, so copy the
            // whole slice in one call.
            TypeMeta::Static {
                zero_copy: true, ..
            } => {
                let guard = DropGuardRawCopy(fat);
                let dst = unsafe { &mut *fat };
                unsafe { reader.copy_into_slice_t(dst)? };
                // Success: disarm the guard; ownership passes to `Box::from_raw` below.
                std::mem::forget(guard);
            }
            // Statically sized but not zero-copy: reserve the total byte count up front
            // via a trusted reader, then decode each element in place.
            TypeMeta::Static {
                size,
                zero_copy: false,
            } => {
                let raw_base = unsafe { (*fat).as_mut_ptr() };
                let mut guard: DropGuardElemCopy<T::Dst> = DropGuardElemCopy::new(fat, raw_base);
                #[allow(clippy::arithmetic_side_effects)]
                let reader = &mut unsafe { reader.as_trusted_for(size * M::USIZE * N::USIZE) }?;
                for i in 0..M::USIZE * N::USIZE {
                    let slot = unsafe { &mut *raw_base.add(i) };
                    T::read(reader, slot)?;
                    // Record the element as initialized only after it was fully read.
                    guard.inner.inc_len();
                }
                std::mem::forget(guard);
            }
            // Dynamically sized elements: decode one by one from the untrusted reader.
            TypeMeta::Dynamic => {
                let raw_base = unsafe { (*fat).as_mut_ptr() };
                let mut guard: DropGuardElemCopy<T::Dst> = DropGuardElemCopy::new(fat, raw_base);
                for i in 0..M::USIZE * N::USIZE {
                    let slot = unsafe { &mut *raw_base.add(i) };
                    T::read(reader, slot)?;
                    guard.inner.inc_len();
                }
                std::mem::forget(guard);
            }
        }
        // Every slot is initialized at this point. Reclaim the Box and convert it
        // through `TryFrom<Box<[T]>>`, whose length check cannot fail because exactly
        // M * N slots were allocated.
        let container = unsafe { Box::from_raw(fat) };
        let container = unsafe { container.assume_init().try_into().unwrap() };
        dst.write(container);
        Ok(())
    }
}
404
405impl<T: Sized, M: Positive, N: Positive, O> AsRef<[T]> for HeapMatrix<T, M, N, O> {
406 fn as_ref(&self) -> &[T] {
407 &self.data
408 }
409}
410
411impl<T: Sized, M: Positive, N: Positive, O> AsMut<[T]> for HeapMatrix<T, M, N, O> {
412 fn as_mut(&mut self) -> &mut [T] {
413 &mut self.data
414 }
415}
416
417impl<T: Sized, M: Positive, N: Positive, O> From<HeapMatrix<T, M, N, O>> for Vec<T> {
418 fn from(matrix: HeapMatrix<T, M, N, O>) -> Self {
419 matrix.data.into_vec()
420 }
421}
422
423impl<T: Sized, M: Positive, N: Positive, O> TryFrom<Vec<T>> for HeapMatrix<T, M, N, O> {
424 type Error = PrimitiveError;
425
426 fn try_from(matrix: Vec<T>) -> Result<Self, Self::Error> {
427 if matrix.len() != M::USIZE * N::USIZE {
428 return Err(PrimitiveError::InvalidSize(
429 M::USIZE * N::USIZE,
430 matrix.len(),
431 ));
432 }
433 Ok(Self::new(matrix.into_boxed_slice()))
434 }
435}
436
437impl<T: Sized, M: Positive, N: Positive, O> From<HeapMatrix<T, M, N, O>> for Box<[T]> {
438 fn from(matrix: HeapMatrix<T, M, N, O>) -> Self {
439 matrix.data
440 }
441}
442
443impl<T: Sized, M: Positive, N: Positive, O> TryFrom<Box<[T]>> for HeapMatrix<T, M, N, O> {
444 type Error = PrimitiveError;
445
446 fn try_from(matrix: Box<[T]>) -> Result<Self, Self::Error> {
447 if matrix.len() != M::USIZE * N::USIZE {
448 return Err(PrimitiveError::InvalidSize(
449 M::USIZE * N::USIZE,
450 matrix.len(),
451 ));
452 }
453 Ok(Self::new(matrix))
454 }
455}
456
#[cfg(test)]
pub mod tests {
    use itertools::Itertools;
    use typenum::{U2, U3, U4, U6};

    use super::{ColumnMajor, RowMajor};
    use crate::types::{HeapArray, HeapMatrix};

    /// Flat iteration, flat mutation and the flat-array round trip preserve storage order.
    #[test]
    fn test_flat_operations() {
        let data = (0..6).collect::<Vec<u32>>();
        let matrix: HeapMatrix<u32, U2, U3> = HeapMatrix::new(data.clone().into_boxed_slice());
        assert_eq!(matrix.flat_iter().copied().collect_vec(), data);

        let mut doubled: HeapMatrix<u32, U2, U3> =
            HeapMatrix::new(data.clone().into_boxed_slice());
        doubled.flat_iter_mut().for_each(|x| *x *= 2);
        assert_eq!(doubled.get(0, 0), Some(&0));
        assert_eq!(doubled.get(1, 0), Some(&2));

        let array: HeapArray<u32, U6> = matrix.into_flat_array();
        assert_eq!(array.as_ref(), data.as_slice());
        let rebuilt: HeapMatrix<u32, U2, U3> = HeapMatrix::from_flat_array(array);
        assert_eq!(rebuilt.as_ref(), data.as_slice());
    }

    /// Column-major indexing, bounds checks and column iteration.
    #[test]
    fn test_column_first_operations() {
        let mut matrix: HeapMatrix<u32, U3, U2, ColumnMajor> =
            HeapMatrix::new(vec![0, 1, 2, 3, 4, 5].into_boxed_slice());

        // Element (row, col) lives at flat index col * 3 + row.
        let expected = [(0, 0, 0), (1, 0, 1), (2, 0, 2), (0, 1, 3), (1, 1, 4), (2, 1, 5)];
        for (row, col, value) in expected {
            assert_eq!(matrix.get(row, col), Some(&value));
        }
        assert_eq!(matrix.get(3, 0), None);
        assert_eq!(matrix.get(0, 2), None);

        *matrix.get_mut(1, 1).unwrap() = 42;
        assert_eq!(matrix.get(1, 1), Some(&42));

        let matrix2: HeapMatrix<u32, U4, U3, ColumnMajor> =
            HeapMatrix::new((0..12).collect::<Vec<u32>>().into_boxed_slice());
        let cols = matrix2.col_iter().map(|col| col.to_vec()).collect_vec();
        assert_eq!(cols, vec![vec![0, 1, 2, 3], vec![4, 5, 6, 7], vec![8, 9, 10, 11]]);

        let mut matrix3: HeapMatrix<u32, U4, U3, ColumnMajor> =
            HeapMatrix::new((0..12).collect::<Vec<u32>>().into_boxed_slice());
        matrix3.col_iter_mut().for_each(|col| col[0] = 99);
        for col in 0..3 {
            assert_eq!(matrix3.get(0, col), Some(&99));
        }
    }

    /// Row-major indexing, bounds checks and row iteration.
    #[test]
    fn test_row_first_operations() {
        let mut matrix: HeapMatrix<u32, U3, U2, RowMajor> =
            HeapMatrix::new(vec![0, 1, 2, 3, 4, 5].into_boxed_slice());

        // Element (row, col) lives at flat index row * 2 + col.
        let expected = [(0, 0, 0), (0, 1, 1), (1, 0, 2), (1, 1, 3), (2, 0, 4), (2, 1, 5)];
        for (row, col, value) in expected {
            assert_eq!(matrix.get(row, col), Some(&value));
        }
        assert_eq!(matrix.get(3, 0), None);
        assert_eq!(matrix.get(0, 2), None);

        *matrix.get_mut(1, 1).unwrap() = 42;
        assert_eq!(matrix.get(1, 1), Some(&42));

        let matrix2: HeapMatrix<u32, U3, U4, RowMajor> =
            HeapMatrix::new((0..12).collect::<Vec<u32>>().into_boxed_slice());
        let rows = matrix2.row_iter().map(|row| row.to_vec()).collect_vec();
        assert_eq!(rows, vec![vec![0, 1, 2, 3], vec![4, 5, 6, 7], vec![8, 9, 10, 11]]);

        let mut matrix3: HeapMatrix<u32, U3, U4, RowMajor> =
            HeapMatrix::new((0..12).collect::<Vec<u32>>().into_boxed_slice());
        matrix3.row_iter_mut().for_each(|row| row[0] = 99);
        for row in 0..3 {
            assert_eq!(matrix3.get(row, 0), Some(&99));
        }
    }

    /// `TryFrom<Vec<_>>` accepts exactly `M * N` elements and rejects anything else.
    #[test]
    fn test_try_from_vec() {
        let exact: Result<HeapMatrix<u32, U3, U4>, _> = (0..12).collect::<Vec<u32>>().try_into();
        assert!(exact.is_ok());

        let short: Result<HeapMatrix<u32, U3, U4>, _> = (0..10).collect::<Vec<u32>>().try_into();
        assert!(short.is_err());
    }

    #[test]
    fn test_heap_matrix_wincode_roundtrip_static_zerocopy() {
        let matrix: HeapMatrix<u32, U2, U3, ColumnMajor> =
            HeapMatrix::new((0u32..6).collect::<Vec<_>>().into_boxed_slice());

        let bytes = wincode::serialize(&matrix).unwrap();
        let decoded: HeapMatrix<u32, U2, U3, ColumnMajor> = wincode::deserialize(&bytes).unwrap();

        assert_eq!(matrix, decoded);
    }

    #[test]
    fn test_heap_matrix_wincode_roundtrip_static_non_zerocopy() {
        use serde::{Deserialize, Serialize};
        use wincode::{SchemaRead, SchemaWrite};

        #[derive(
            Debug, Copy, Clone, PartialEq, Eq, SchemaRead, SchemaWrite, Serialize, Deserialize,
        )]
        struct NonZeroCopy {
            a: u8,
            b: u16,
        }

        let matrix: HeapMatrix<NonZeroCopy, U2, U2, RowMajor> = HeapMatrix::new(
            (1u8..=4)
                .map(|i| NonZeroCopy {
                    a: i,
                    b: u16::from(i) * 100,
                })
                .collect::<Box<[NonZeroCopy]>>(),
        );

        let bytes = wincode::serialize(&matrix).unwrap();
        let decoded: HeapMatrix<NonZeroCopy, U2, U2, RowMajor> =
            wincode::deserialize(&bytes).unwrap();

        assert_eq!(matrix, decoded);
    }

    #[test]
    fn test_heap_matrix_wincode_roundtrip_dynamic() {
        let matrix: HeapMatrix<String, U3, U2, ColumnMajor> = HeapMatrix::new(
            ["a", "b", "c", "d", "e", "f"]
                .into_iter()
                .map(String::from)
                .collect::<Box<[String]>>(),
        );

        let bytes = wincode::serialize(&matrix).unwrap();
        let decoded: HeapMatrix<String, U3, U2, ColumnMajor> =
            wincode::deserialize(&bytes).unwrap();

        assert_eq!(matrix, decoded);
    }
}