1use std::{
2 fmt::{Debug, Display},
3 marker::PhantomData,
4 mem::{ManuallyDrop, MaybeUninit},
5 ops::Mul,
6};
7
8use derive_more::derive::Display;
9use serde::{Deserialize, Serialize};
10use typenum::Prod;
11use wincode::{
12 io::{Reader, Writer},
13 ReadResult,
14 SchemaRead,
15 SchemaWrite,
16 TypeMeta,
17 WriteResult,
18};
19
20use super::HeapArray;
21use crate::{
22 errors::PrimitiveError,
23 random::{CryptoRngCore, Random},
24 types::{heap_array::array::SliceDropGuard, Positive},
25};
26
27#[derive(Display, Default, Clone, Copy, PartialEq, Eq)]
29pub struct ColumnMajor;
30#[derive(Display, Default, Clone, Copy, PartialEq, Eq)]
32pub struct RowMajor;
33
34pub type RowMajorHeapMatrix<T, M, N> = HeapMatrix<T, M, N, RowMajor>;
35
36#[derive(Clone, PartialEq, Eq)]
40#[repr(C)]
41pub struct HeapMatrix<T: Sized, M: Positive, N: Positive, O = ColumnMajor> {
42 pub(super) data: Box<[T]>,
43
44 #[allow(clippy::type_complexity)]
48 pub(super) _len: PhantomData<fn() -> (M, N, O)>,
49}
50impl<T: Sized, M: Positive, N: Positive, O> HeapMatrix<T, M, N, O> {
51 fn new(data: Box<[T]>) -> Self {
52 Self {
53 data,
54 _len: PhantomData,
55 }
56 }
57
58 pub fn flat_iter(&self) -> impl ExactSizeIterator<Item = &T> {
60 self.data.iter()
61 }
62
63 pub fn flat_iter_mut(&mut self) -> impl ExactSizeIterator<Item = &mut T> {
65 self.data.iter_mut()
66 }
67
68 pub fn into_flat_iter(self) -> impl ExactSizeIterator<Item = T> {
70 self.data.into_vec().into_iter()
71 }
72
73 pub const fn len(&self) -> usize {
75 M::USIZE * N::USIZE
76 }
77
78 pub const fn is_empty(&self) -> bool {
80 self.len() == 0
81 }
82
83 pub const fn rows(&self) -> usize {
85 M::USIZE
86 }
87
88 pub const fn cols(&self) -> usize {
90 N::USIZE
91 }
92}
93
94impl<T: Sized, M: Positive, N: Positive, O> HeapMatrix<T, M, N, O> {
95 pub fn map<F, U>(self, f: F) -> HeapMatrix<U, M, N, O>
96 where
97 F: FnMut(T) -> U,
98 {
99 HeapMatrix::new(
100 self.data
101 .into_vec()
102 .into_iter()
103 .map(f)
104 .collect::<Box<[U]>>(),
105 )
106 }
107}
108
109impl<T: Sized, M: Positive + Mul<N, Output: Positive>, N: Positive, O> HeapMatrix<T, M, N, O> {
110 pub fn into_flat_array(self) -> HeapArray<T, Prod<M, N>> {
112 HeapArray::new(self.data)
113 }
114
115 pub fn from_flat_array(value: HeapArray<T, Prod<M, N>>) -> Self {
117 Self::new(value.data)
118 }
119}
120
121impl<T: Sized, M: Positive, N: Positive> HeapMatrix<T, M, N, ColumnMajor> {
124 pub fn col_iter(&self) -> impl ExactSizeIterator<Item = &[T]> {
126 self.data.chunks_exact(M::USIZE)
127 }
128
129 pub fn col_iter_mut(&mut self) -> impl ExactSizeIterator<Item = &mut [T]> {
131 self.data.chunks_exact_mut(M::USIZE)
132 }
133
134 pub fn get(&self, row: usize, col: usize) -> Option<&T> {
136 (row < M::USIZE && col < N::USIZE)
137 .then(|| unsafe { self.data.get_unchecked(col * M::USIZE + row) })
138 }
139
140 pub fn get_mut(&mut self, row: usize, col: usize) -> Option<&mut T> {
142 (row < M::USIZE && col < N::USIZE)
143 .then(|| unsafe { self.data.get_unchecked_mut(col * M::USIZE + row) })
144 }
145}
146
147impl<T: Sized, M: Positive + Mul<N, Output: Positive>, N: Positive>
148 HeapMatrix<T, M, N, ColumnMajor>
149{
150 pub fn from_cols(val: HeapArray<HeapArray<T, M>, N>) -> Self {
152 Self::try_from(val.into_iter().flatten().collect::<Box<[T]>>()).unwrap()
153 }
154}
155
156impl<T: Sized, M: Positive, N: Positive> HeapMatrix<T, M, N, RowMajor> {
159 pub fn row_iter(&self) -> impl ExactSizeIterator<Item = &[T]> {
161 self.data.chunks_exact(N::USIZE)
162 }
163
164 pub fn row_iter_mut(&mut self) -> impl ExactSizeIterator<Item = &mut [T]> {
166 self.data.chunks_exact_mut(N::USIZE)
167 }
168
169 pub fn get(&self, row: usize, col: usize) -> Option<&T> {
171 (row < M::USIZE && col < N::USIZE)
172 .then(|| unsafe { self.data.get_unchecked(row * N::USIZE + col) })
173 }
174
175 pub fn get_mut(&mut self, row: usize, col: usize) -> Option<&mut T> {
177 (row < M::USIZE && col < N::USIZE)
178 .then(|| unsafe { self.data.get_unchecked_mut(row * N::USIZE + col) })
179 }
180}
181
182impl<T: Sized, M: Positive + Mul<N, Output: Positive>, N: Positive>
183 HeapMatrix<T, M, N, ColumnMajor>
184{
185 pub fn from_rows(val: HeapArray<HeapArray<T, N>, M>) -> Self {
187 Self::try_from(val.into_iter().flatten().collect::<Box<[T]>>()).unwrap()
188 }
189}
190
191impl<T: Sized + Default, M: Positive, N: Positive, O> Default for HeapMatrix<T, M, N, O> {
194 fn default() -> Self {
195 Self::new(
196 (0..M::USIZE * N::USIZE)
197 .map(|_| T::default())
198 .collect::<Box<[T]>>(),
199 )
200 }
201}
202
203impl<T: Sized + Debug, M: Positive, N: Positive, O: Display + Default> Debug
204 for HeapMatrix<T, M, N, O>
205{
206 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
207 f.debug_struct(format!("Matrix[{}]<{}, {}>", O::default(), M::USIZE, N::USIZE).as_str())
208 .field("data", &self.data)
209 .finish()
210 }
211}
212
213impl<T: Sized + Serialize, M: Positive, N: Positive, O> Serialize for HeapMatrix<T, M, N, O> {
214 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
215 self.data.serialize(serializer)
216 }
217}
218
219impl<T: Random + Sized, M: Positive, N: Positive, O> Random for HeapMatrix<T, M, N, O> {
220 fn random(mut rng: impl CryptoRngCore) -> Self {
221 Self::new(T::random_n(&mut rng, M::USIZE * N::USIZE))
222 }
223}
224
225impl<'de, T: Sized + Deserialize<'de>, M: Positive, N: Positive, O> Deserialize<'de>
226 for HeapMatrix<T, M, N, O>
227{
228 fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
229 let data = Box::<[T]>::deserialize(deserializer)?;
230 if data.len() != M::USIZE * N::USIZE {
231 return Err(serde::de::Error::custom(format!(
232 "Expected matrix of length {}, got {}",
233 M::USIZE * N::USIZE,
234 data.len()
235 )));
236 }
237 Ok(Self::new(data))
238 }
239}
240
241impl<T: Sized + SchemaWrite<Src = T>, M: Positive, N: Positive, O> SchemaWrite
242 for HeapMatrix<T, M, N, O>
243{
244 type Src = HeapMatrix<T::Src, M, N, O>;
245
246 const TYPE_META: wincode::TypeMeta = match <T as SchemaWrite>::TYPE_META {
247 TypeMeta::Static { size, zero_copy } => TypeMeta::Static {
248 size: size * M::USIZE * N::USIZE,
249 zero_copy,
250 },
251 TypeMeta::Dynamic => TypeMeta::Dynamic,
252 };
253
254 #[inline]
255 fn size_of(src: &Self::Src) -> WriteResult<usize> {
256 if let TypeMeta::Static { size, .. } = <Self as SchemaWrite>::TYPE_META {
257 return Ok(size);
258 }
259
260 src.data
262 .iter()
263 .map(T::size_of)
264 .try_fold(0usize, |acc, x| x.map(|x| acc + x))
265 }
266
267 #[inline]
268 fn write(writer: &mut impl Writer, src: &Self::Src) -> WriteResult<()> {
269 if let TypeMeta::Static {
270 size,
271 zero_copy: true,
272 } = <Self as SchemaWrite>::TYPE_META
273 {
274 let writer = &mut unsafe { writer.as_trusted_for(size) }?;
277 unsafe { writer.write_slice_t(&src.data)? };
280 writer.finish()?;
281 } else if let TypeMeta::Static { size, .. } = <Self as SchemaWrite>::TYPE_META {
282 #[allow(clippy::arithmetic_side_effects)]
283 let mut writer = unsafe { writer.as_trusted_for(size) }?;
287 for item in src.data.iter() {
288 T::write(&mut writer, item)?;
289 }
290 writer.finish()?;
291 } else {
292 for item in src.data.iter() {
293 T::write(writer, item)?;
294 }
295 }
296
297 Ok(())
298 }
299}
300
301impl<'de, T: Sized + SchemaRead<'de, Dst = T>, M: Positive, N: Positive, O> SchemaRead<'de>
302 for HeapMatrix<T, M, N, O>
303{
304 type Dst = HeapMatrix<T::Dst, M, N, O>;
305
306 const TYPE_META: TypeMeta = const {
307 match T::TYPE_META {
308 TypeMeta::Static { size, zero_copy } => TypeMeta::Static {
309 size: size * M::USIZE * N::USIZE,
310 zero_copy,
311 },
312 TypeMeta::Dynamic => TypeMeta::Dynamic,
313 }
314 };
315
316 #[inline]
317 fn read(reader: &mut impl Reader<'de>, dst: &mut MaybeUninit<Self::Dst>) -> ReadResult<()> {
318 struct DropGuardRawCopy<T>(*mut [MaybeUninit<T>]);
323 impl<T> Drop for DropGuardRawCopy<T> {
324 #[inline]
325 fn drop(&mut self) {
326 let container = unsafe { Box::from_raw(self.0) };
327 drop(container);
328 }
329 }
330 struct DropGuardElemCopy<T> {
335 inner: ManuallyDrop<SliceDropGuard<T>>,
336 fat: *mut [MaybeUninit<T>],
337 }
338 impl<T> DropGuardElemCopy<T> {
339 #[inline(always)]
340 fn new(fat: *mut [MaybeUninit<T>], raw: *mut MaybeUninit<T>) -> Self {
341 Self {
342 inner: ManuallyDrop::new(SliceDropGuard::new(raw)),
343 fat,
344 }
345 }
346 }
347 impl<T> Drop for DropGuardElemCopy<T> {
348 #[inline]
349 fn drop(&mut self) {
350 unsafe {
351 ManuallyDrop::drop(&mut self.inner);
352 }
353 let container = unsafe { Box::from_raw(self.fat) };
354 drop(container);
355 }
356 }
357 let mem = Box::<[T::Dst]>::new_uninit_slice(M::USIZE * N::USIZE);
358 let fat = Box::into_raw(mem);
359 match T::TYPE_META {
360 TypeMeta::Static {
361 zero_copy: true, ..
362 } => {
363 let guard = DropGuardRawCopy(fat);
364 let dst = unsafe { &mut *fat };
365 unsafe { reader.copy_into_slice_t(dst)? };
366 std::mem::forget(guard);
367 }
368 TypeMeta::Static {
369 size,
370 zero_copy: false,
371 } => {
372 let raw_base = unsafe { (*fat).as_mut_ptr() };
373 let mut guard: DropGuardElemCopy<T::Dst> = DropGuardElemCopy::new(fat, raw_base);
374 #[allow(clippy::arithmetic_side_effects)]
375 let reader = &mut unsafe { reader.as_trusted_for(size * M::USIZE * N::USIZE) }?;
376 for i in 0..M::USIZE * N::USIZE {
377 let slot = unsafe { &mut *raw_base.add(i) };
378 T::read(reader, slot)?;
379 guard.inner.inc_len();
380 }
381 std::mem::forget(guard);
382 }
383 TypeMeta::Dynamic => {
384 let raw_base = unsafe { (*fat).as_mut_ptr() };
385 let mut guard: DropGuardElemCopy<T::Dst> = DropGuardElemCopy::new(fat, raw_base);
386 for i in 0..M::USIZE * N::USIZE {
387 let slot = unsafe { &mut *raw_base.add(i) };
388 T::read(reader, slot)?;
389 guard.inner.inc_len();
390 }
391 std::mem::forget(guard);
392 }
393 }
394 let container = unsafe { Box::from_raw(fat) };
395 let container = unsafe { container.assume_init().try_into().unwrap() };
396 dst.write(container);
397 Ok(())
398 }
399}
400
401impl<T: Sized, M: Positive, N: Positive, O> AsRef<[T]> for HeapMatrix<T, M, N, O> {
402 fn as_ref(&self) -> &[T] {
403 &self.data
404 }
405}
406
407impl<T: Sized, M: Positive, N: Positive, O> AsMut<[T]> for HeapMatrix<T, M, N, O> {
408 fn as_mut(&mut self) -> &mut [T] {
409 &mut self.data
410 }
411}
412
413impl<T: Sized, M: Positive, N: Positive, O> From<HeapMatrix<T, M, N, O>> for Vec<T> {
414 fn from(matrix: HeapMatrix<T, M, N, O>) -> Self {
415 matrix.data.into_vec()
416 }
417}
418
419impl<T: Sized, M: Positive, N: Positive, O> TryFrom<Vec<T>> for HeapMatrix<T, M, N, O> {
420 type Error = PrimitiveError;
421
422 fn try_from(matrix: Vec<T>) -> Result<Self, Self::Error> {
423 if matrix.len() != M::USIZE * N::USIZE {
424 return Err(PrimitiveError::InvalidSize(
425 M::USIZE * N::USIZE,
426 matrix.len(),
427 ));
428 }
429 Ok(Self::new(matrix.into_boxed_slice()))
430 }
431}
432
433impl<T: Sized, M: Positive, N: Positive, O> From<HeapMatrix<T, M, N, O>> for Box<[T]> {
434 fn from(matrix: HeapMatrix<T, M, N, O>) -> Self {
435 matrix.data
436 }
437}
438
439impl<T: Sized, M: Positive, N: Positive, O> TryFrom<Box<[T]>> for HeapMatrix<T, M, N, O> {
440 type Error = PrimitiveError;
441
442 fn try_from(matrix: Box<[T]>) -> Result<Self, Self::Error> {
443 if matrix.len() != M::USIZE * N::USIZE {
444 return Err(PrimitiveError::InvalidSize(
445 M::USIZE * N::USIZE,
446 matrix.len(),
447 ));
448 }
449 Ok(Self::new(matrix))
450 }
451}
452
#[cfg(test)]
pub mod tests {
    use itertools::Itertools;
    use typenum::{U2, U3, U4, U6};

    use super::{ColumnMajor, RowMajor};
    use crate::types::{HeapArray, HeapMatrix};

    /// Flat iteration, in-place mutation and flat-array round trip.
    #[test]
    fn test_flat_operations() {
        let expected: Vec<u32> = (0..6).collect();
        let original: HeapMatrix<u32, U2, U3> =
            HeapMatrix::new(expected.clone().into_boxed_slice());

        assert_eq!(original.flat_iter().copied().collect_vec(), expected);

        let mut doubled: HeapMatrix<u32, U2, U3> =
            HeapMatrix::new(expected.clone().into_boxed_slice());
        doubled.flat_iter_mut().for_each(|value| *value *= 2);
        assert_eq!(doubled.get(0, 0), Some(&0));
        assert_eq!(doubled.get(1, 0), Some(&2));

        let flat: HeapArray<u32, U6> = original.into_flat_array();
        assert_eq!(flat.as_ref(), expected.as_slice());
        let rebuilt: HeapMatrix<u32, U2, U3> = HeapMatrix::from_flat_array(flat);
        assert_eq!(rebuilt.as_ref(), expected.as_slice());
    }

    /// Indexing and column iteration on column-major matrices.
    #[test]
    fn test_column_first_operations() {
        let mut small: HeapMatrix<u32, U3, U2, ColumnMajor> =
            HeapMatrix::new(vec![0u32, 1, 2, 3, 4, 5].into_boxed_slice());

        // (row, col, expected value): consecutive storage slots walk down a column.
        let in_bounds = [
            (0, 0, 0u32),
            (1, 0, 1),
            (2, 0, 2),
            (0, 1, 3),
            (1, 1, 4),
            (2, 1, 5),
        ];
        for (row, col, want) in in_bounds {
            assert_eq!(small.get(row, col), Some(&want));
        }
        assert_eq!(small.get(3, 0), None);
        assert_eq!(small.get(0, 2), None);

        *small.get_mut(1, 1).unwrap() = 42;
        assert_eq!(small.get(1, 1), Some(&42));

        let grid: HeapMatrix<u32, U4, U3, ColumnMajor> =
            HeapMatrix::new((0..12).collect::<Vec<u32>>().into_boxed_slice());
        let cols: Vec<Vec<u32>> = grid.col_iter().map(<[u32]>::to_vec).collect();
        assert_eq!(cols.len(), 3);
        assert_eq!(cols[0], vec![0, 1, 2, 3]);
        assert_eq!(cols[1], vec![4, 5, 6, 7]);
        assert_eq!(cols[2], vec![8, 9, 10, 11]);

        let mut patched: HeapMatrix<u32, U4, U3, ColumnMajor> =
            HeapMatrix::new((0..12).collect::<Vec<u32>>().into_boxed_slice());
        patched.col_iter_mut().for_each(|col| col[0] = 99);
        for col in 0..3 {
            assert_eq!(patched.get(0, col), Some(&99));
        }
    }

    /// Indexing and row iteration on row-major matrices.
    #[test]
    fn test_row_first_operations() {
        let mut small: HeapMatrix<u32, U3, U2, RowMajor> =
            HeapMatrix::new(vec![0u32, 1, 2, 3, 4, 5].into_boxed_slice());

        // (row, col, expected value): consecutive storage slots walk along a row.
        let in_bounds = [
            (0, 0, 0u32),
            (0, 1, 1),
            (1, 0, 2),
            (1, 1, 3),
            (2, 0, 4),
            (2, 1, 5),
        ];
        for (row, col, want) in in_bounds {
            assert_eq!(small.get(row, col), Some(&want));
        }
        assert_eq!(small.get(3, 0), None);
        assert_eq!(small.get(0, 2), None);

        *small.get_mut(1, 1).unwrap() = 42;
        assert_eq!(small.get(1, 1), Some(&42));

        let grid: HeapMatrix<u32, U3, U4, RowMajor> =
            HeapMatrix::new((0..12).collect::<Vec<u32>>().into_boxed_slice());
        let rows: Vec<Vec<u32>> = grid.row_iter().map(<[u32]>::to_vec).collect();
        assert_eq!(rows.len(), 3);
        assert_eq!(rows[0], vec![0, 1, 2, 3]);
        assert_eq!(rows[1], vec![4, 5, 6, 7]);
        assert_eq!(rows[2], vec![8, 9, 10, 11]);

        let mut patched: HeapMatrix<u32, U3, U4, RowMajor> =
            HeapMatrix::new((0..12).collect::<Vec<u32>>().into_boxed_slice());
        patched.row_iter_mut().for_each(|row| row[0] = 99);
        for row in 0..3 {
            assert_eq!(patched.get(row, 0), Some(&99));
        }
    }

    /// `TryFrom<Vec<T>>` accepts exactly `M * N` elements and rejects other lengths.
    #[test]
    fn test_try_from_vec() {
        let exact: Vec<u32> = (0..12).collect();
        let accepted: Result<HeapMatrix<u32, U3, U4>, _> = exact.try_into();
        assert!(accepted.is_ok());

        let too_short: Vec<u32> = (0..10).collect();
        let rejected: Result<HeapMatrix<u32, U3, U4>, _> = too_short.try_into();
        assert!(rejected.is_err());
    }

    /// wincode round trip with `u32` elements.
    #[test]
    fn test_heap_matrix_wincode_roundtrip_static_zerocopy() {
        let elements: Vec<u32> = (0..6).map(|i| i as u32).collect();
        let matrix: HeapMatrix<u32, U2, U3, ColumnMajor> =
            HeapMatrix::new(elements.into_boxed_slice());

        let bytes = wincode::serialize(&matrix).unwrap();
        let decoded: HeapMatrix<u32, U2, U3, ColumnMajor> = wincode::deserialize(&bytes).unwrap();

        assert_eq!(matrix, decoded);
    }

    /// wincode round trip with a small derived struct as the element type.
    #[test]
    fn test_heap_matrix_wincode_roundtrip_static_non_zerocopy() {
        use serde::{Deserialize, Serialize};
        use wincode::{SchemaRead, SchemaWrite};

        #[derive(
            Debug, Copy, Clone, PartialEq, Eq, SchemaRead, SchemaWrite, Serialize, Deserialize,
        )]
        struct NonZeroCopy {
            a: u8,
            b: u16,
        }

        let elements = vec![
            NonZeroCopy { a: 1, b: 100 },
            NonZeroCopy { a: 2, b: 200 },
            NonZeroCopy { a: 3, b: 300 },
            NonZeroCopy { a: 4, b: 400 },
        ];
        let matrix: HeapMatrix<NonZeroCopy, U2, U2, RowMajor> =
            HeapMatrix::new(elements.into_boxed_slice());

        let bytes = wincode::serialize(&matrix).unwrap();
        let decoded: HeapMatrix<NonZeroCopy, U2, U2, RowMajor> =
            wincode::deserialize(&bytes).unwrap();

        assert_eq!(matrix, decoded);
    }

    /// wincode round trip with `String` elements.
    #[test]
    fn test_heap_matrix_wincode_roundtrip_dynamic() {
        let elements: Vec<String> = ["a", "b", "c", "d", "e", "f"]
            .iter()
            .map(|s| s.to_string())
            .collect();
        let matrix: HeapMatrix<String, U3, U2, ColumnMajor> =
            HeapMatrix::new(elements.into_boxed_slice());

        let bytes = wincode::serialize(&matrix).unwrap();
        let decoded: HeapMatrix<String, U3, U2, ColumnMajor> =
            wincode::deserialize(&bytes).unwrap();

        assert_eq!(matrix, decoded);
    }
}