1use std::{alloc::Layout, iter::FusedIterator, marker::PhantomData, ptr::NonNull};
29
30use diskann_utils::{Reborrow, ReborrowMut};
31use thiserror::Error;
32
33use crate::utils;
34
35pub unsafe trait Repr: Copy {
58 type Row<'a>
60 where
61 Self: 'a;
62
63 fn nrows(&self) -> usize;
70
71 fn layout(&self) -> Result<Layout, LayoutError>;
79
80 unsafe fn get_row<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::Row<'a>;
92}
93
94pub unsafe trait ReprMut: Repr {
112 type RowMut<'a>
114 where
115 Self: 'a;
116
117 unsafe fn get_row_mut<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::RowMut<'a>;
128}
129
130pub unsafe trait ReprOwned: ReprMut {
140 unsafe fn drop(self, ptr: NonNull<u8>);
148}
149
/// Opaque error signalling that a memory layout could not be computed.
///
/// It carries no payload; its `Display` output is the fixed text
/// `LayoutError`.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub struct LayoutError;

impl LayoutError {
    /// Creates a new `LayoutError`.
    pub fn new() -> Self {
        LayoutError
    }
}

impl Default for LayoutError {
    /// Equivalent to [`LayoutError::new`].
    fn default() -> Self {
        LayoutError::new()
    }
}

impl std::fmt::Display for LayoutError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("LayoutError")
    }
}

impl std::error::Error for LayoutError {}

impl From<std::alloc::LayoutError> for LayoutError {
    /// Discards the details of the standard-library error.
    fn from(_: std::alloc::LayoutError) -> Self {
        Self::new()
    }
}
185
186pub unsafe trait NewRef<T>: Repr {
197 type Error;
199
200 fn new_ref(self, slice: &[T]) -> Result<MatRef<'_, Self>, Self::Error>;
202}
203
204pub unsafe trait NewMut<T>: ReprMut {
211 type Error;
213
214 fn new_mut(self, slice: &mut [T]) -> Result<MatMut<'_, Self>, Self::Error>;
216}
217
218pub unsafe trait NewOwned<T>: ReprOwned {
225 type Error;
227
228 fn new_owned(self, init: T) -> Result<Mat<Self>, Self::Error>;
230}
231
/// Marker initializer: asks for every element of a new matrix to be set to
/// the element type's `Default` value (see the `NewOwned<Defaulted>`
/// implementation for `Standard`).
#[derive(Clone, Copy, Debug)]
pub struct Defaulted;
244
/// Representation of a dense matrix of `T` with `nrows` rows of `ncols`
/// elements each, stored row after row.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Standard<T> {
    /// Number of rows.
    nrows: usize,
    /// Number of elements in every row.
    ncols: usize,
    /// Ties the representation to its element type without storing a `T`.
    _elem: PhantomData<T>,
}
264
265impl<T: Copy> Standard<T> {
266 pub fn new(nrows: usize, ncols: usize) -> Self {
268 Self {
269 nrows,
270 ncols,
271 _elem: PhantomData,
272 }
273 }
274
275 pub fn num_elements(&self) -> Option<usize> {
278 self.nrows.checked_mul(self.ncols())
279 }
280
281 fn ncols(&self) -> usize {
283 self.ncols
284 }
285
286 fn check_slice(&self, slice: &[T]) -> Result<(), SliceError> {
291 let len = self.num_elements().ok_or(SliceError::Overflow)?;
292
293 if slice.len() != len {
294 Err(SliceError::LengthMismatch {
295 expected: len,
296 found: slice.len(),
297 })
298 } else {
299 Ok(())
300 }
301 }
302}
303
304#[derive(Debug, Clone, Copy, Error)]
306#[non_exhaustive]
307pub enum SliceError {
308 #[error("Length mismatch: expected {expected}, found {found}")]
309 LengthMismatch { expected: usize, found: usize },
310 #[error("Computing slice length overflowed.")]
311 Overflow,
312}
313
314unsafe impl<T: Copy> Repr for Standard<T> {
318 type Row<'a>
319 = &'a [T]
320 where
321 T: 'a;
322
323 fn nrows(&self) -> usize {
324 self.nrows
325 }
326
327 fn layout(&self) -> Result<Layout, LayoutError> {
328 let elements = self.num_elements().ok_or(LayoutError::new())?;
329 Ok(Layout::array::<T>(elements)?)
330 }
331
332 unsafe fn get_row<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::Row<'a> {
333 debug_assert!(ptr.cast::<T>().is_aligned());
334 debug_assert!(i < self.nrows);
335
336 let row_ptr = ptr.as_ptr().cast::<T>().add(i * self.ncols);
337 std::slice::from_raw_parts(row_ptr, self.ncols)
338 }
339}
340
341unsafe impl<T: Copy> ReprMut for Standard<T> {
344 type RowMut<'a>
345 = &'a mut [T]
346 where
347 T: 'a;
348
349 unsafe fn get_row_mut<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::RowMut<'a> {
350 debug_assert!(ptr.cast::<T>().is_aligned());
351 debug_assert!(i < self.nrows);
352
353 let row_ptr = ptr.as_ptr().cast::<T>().add(i * self.ncols);
354 std::slice::from_raw_parts_mut(row_ptr, self.ncols)
355 }
356}
357
358unsafe impl<T: Copy> ReprOwned for Standard<T> {
362 unsafe fn drop(self, ptr: NonNull<u8>) {
363 unsafe {
369 let slice_ptr = std::ptr::slice_from_raw_parts_mut(
370 ptr.cast::<T>().as_ptr(),
371 self.nrows * self.ncols,
372 );
373 let _ = Box::from_raw(slice_ptr);
374 }
375 }
376}
377
378unsafe impl<T> NewOwned<T> for Standard<T>
381where
382 T: Copy,
383{
384 type Error = crate::error::Infallible;
385 fn new_owned(self, value: T) -> Result<Mat<Self>, Self::Error> {
386 let b: Box<[T]> = (0..self.nrows() * self.ncols()).map(|_| value).collect();
387 let ptr = unsafe { NonNull::new_unchecked(Box::into_raw(b)) }.cast::<u8>();
390
391 Ok(unsafe { Mat::from_raw_parts(self, ptr) })
395 }
396}
397
398unsafe impl<T> NewOwned<Defaulted> for Standard<T>
400where
401 T: Copy + Default,
402{
403 type Error = crate::error::Infallible;
404 fn new_owned(self, _: Defaulted) -> Result<Mat<Self>, Self::Error> {
405 self.new_owned(T::default())
406 }
407}
408
409unsafe impl<T> NewRef<T> for Standard<T>
412where
413 T: Copy,
414{
415 type Error = SliceError;
416 fn new_ref(self, data: &[T]) -> Result<MatRef<'_, Self>, Self::Error> {
417 self.check_slice(data)?;
418
419 Ok(unsafe { MatRef::from_raw_parts(self, utils::as_nonnull(data).cast::<u8>()) })
424 }
425}
426
427unsafe impl<T> NewMut<T> for Standard<T>
430where
431 T: Copy,
432{
433 type Error = SliceError;
434 fn new_mut(self, data: &mut [T]) -> Result<MatMut<'_, Self>, Self::Error> {
435 self.check_slice(data)?;
436
437 Ok(unsafe { MatMut::from_raw_parts(self, utils::as_nonnull_mut(data).cast::<u8>()) })
442 }
443}
444
445#[derive(Debug)]
454pub struct Mat<T: ReprOwned> {
455 ptr: NonNull<u8>,
456 repr: T,
457}
458
459unsafe impl<T> Send for Mat<T> where T: ReprOwned + Send {}
461
462unsafe impl<T> Sync for Mat<T> where T: ReprOwned + Sync {}
464
465impl<T: ReprOwned> Mat<T> {
466 pub fn new<U>(repr: T, init: U) -> Result<Self, <T as NewOwned<U>>::Error>
468 where
469 T: NewOwned<U>,
470 {
471 repr.new_owned(init)
472 }
473
474 #[inline]
476 pub fn num_vectors(&self) -> usize {
477 self.repr.nrows()
478 }
479
480 pub fn repr(&self) -> &T {
482 &self.repr
483 }
484
485 #[must_use]
487 pub fn get_row(&self, i: usize) -> Option<T::Row<'_>> {
488 if i < self.num_vectors() {
489 let row = unsafe { self.get_row_unchecked(i) };
492 Some(row)
493 } else {
494 None
495 }
496 }
497
498 pub(crate) unsafe fn get_row_unchecked(&self, i: usize) -> T::Row<'_> {
499 unsafe { self.repr.get_row(self.ptr, i) }
502 }
503
504 #[must_use]
506 pub fn get_row_mut(&mut self, i: usize) -> Option<T::RowMut<'_>> {
507 if i < self.num_vectors() {
508 Some(unsafe { self.get_row_mut_unchecked(i) })
510 } else {
511 None
512 }
513 }
514
515 pub(crate) unsafe fn get_row_mut_unchecked(&mut self, i: usize) -> T::RowMut<'_> {
516 unsafe { self.repr.get_row_mut(self.ptr, i) }
519 }
520
521 #[inline]
523 pub fn as_view(&self) -> MatRef<'_, T> {
524 MatRef {
525 ptr: self.ptr,
526 repr: self.repr,
527 _lifetime: PhantomData,
528 }
529 }
530
531 #[inline]
533 pub fn as_view_mut(&mut self) -> MatMut<'_, T> {
534 MatMut {
535 ptr: self.ptr,
536 repr: self.repr,
537 _lifetime: PhantomData,
538 }
539 }
540
541 pub fn rows(&self) -> Rows<'_, T> {
543 Rows::new(self.reborrow())
544 }
545
546 pub fn rows_mut(&mut self) -> RowsMut<'_, T> {
548 RowsMut::new(self.reborrow_mut())
549 }
550
551 pub(crate) unsafe fn from_raw_parts(repr: T, ptr: NonNull<u8>) -> Self {
561 Self { ptr, repr }
562 }
563
564 #[cfg(test)]
565 fn as_ptr(&self) -> NonNull<u8> {
566 self.ptr
567 }
568}
569
570impl<T: ReprOwned> Drop for Mat<T> {
571 fn drop(&mut self) {
572 unsafe { self.repr.drop(self.ptr) };
575 }
576}
577
578impl<T: Copy> Mat<Standard<T>> {
579 #[inline]
581 pub fn vector_dim(&self) -> usize {
582 self.repr.ncols()
583 }
584}
585
586#[derive(Debug, Clone, Copy)]
602pub struct MatRef<'a, T: Repr> {
603 pub(crate) ptr: NonNull<u8>,
604 pub(crate) repr: T,
605 pub(crate) _lifetime: PhantomData<&'a [u8]>,
607}
608
609unsafe impl<T> Send for MatRef<'_, T> where T: Repr + Send {}
611
612unsafe impl<T> Sync for MatRef<'_, T> where T: Repr + Sync {}
614
615impl<'a, T: Repr> MatRef<'a, T> {
616 pub fn new<U>(repr: T, data: &'a [U]) -> Result<Self, T::Error>
618 where
619 T: NewRef<U>,
620 {
621 repr.new_ref(data)
622 }
623
624 #[inline]
626 pub fn num_vectors(&self) -> usize {
627 self.repr.nrows()
628 }
629
630 pub fn repr(&self) -> &T {
632 &self.repr
633 }
634
635 #[must_use]
637 pub fn get_row(&self, i: usize) -> Option<T::Row<'_>> {
638 if i < self.num_vectors() {
639 let row = unsafe { self.get_row_unchecked(i) };
642 Some(row)
643 } else {
644 None
645 }
646 }
647
648 #[inline]
654 pub(crate) unsafe fn get_row_unchecked(&self, i: usize) -> T::Row<'_> {
655 unsafe { self.repr.get_row(self.ptr, i) }
657 }
658
659 pub fn rows(&self) -> Rows<'_, T> {
661 Rows::new(*self)
662 }
663
664 pub unsafe fn from_raw_parts(repr: T, ptr: NonNull<u8>) -> Self {
672 Self {
673 ptr,
674 repr,
675 _lifetime: PhantomData,
676 }
677 }
678}
679
680impl<'a, T: Copy> MatRef<'a, Standard<T>> {
681 #[inline]
683 pub fn vector_dim(&self) -> usize {
684 self.repr.ncols()
685 }
686}
687
688impl<'this, T: ReprOwned> Reborrow<'this> for Mat<T> {
690 type Target = MatRef<'this, T>;
691
692 fn reborrow(&'this self) -> Self::Target {
693 self.as_view()
694 }
695}
696
697impl<'this, T: ReprOwned> ReborrowMut<'this> for Mat<T> {
699 type Target = MatMut<'this, T>;
700
701 fn reborrow_mut(&'this mut self) -> Self::Target {
702 self.as_view_mut()
703 }
704}
705
706impl<'this, 'a, T: Repr> Reborrow<'this> for MatRef<'a, T> {
708 type Target = MatRef<'this, T>;
709
710 fn reborrow(&'this self) -> Self::Target {
711 MatRef {
712 ptr: self.ptr,
713 repr: self.repr,
714 _lifetime: PhantomData,
715 }
716 }
717}
718
719#[derive(Debug)]
736pub struct MatMut<'a, T: ReprMut> {
737 pub(crate) ptr: NonNull<u8>,
738 pub(crate) repr: T,
739 pub(crate) _lifetime: PhantomData<&'a mut [u8]>,
741}
742
743unsafe impl<T> Send for MatMut<'_, T> where T: ReprMut + Send {}
745
746unsafe impl<T> Sync for MatMut<'_, T> where T: ReprMut + Sync {}
748
749impl<'a, T: ReprMut> MatMut<'a, T> {
750 pub fn new<U>(repr: T, data: &'a mut [U]) -> Result<Self, T::Error>
752 where
753 T: NewMut<U>,
754 {
755 repr.new_mut(data)
756 }
757
758 #[inline]
760 pub fn num_vectors(&self) -> usize {
761 self.repr.nrows()
762 }
763
764 pub fn repr(&self) -> &T {
766 &self.repr
767 }
768
769 #[inline]
771 #[must_use]
772 pub fn get_row(&self, i: usize) -> Option<T::Row<'_>> {
773 if i < self.num_vectors() {
774 Some(unsafe { self.get_row_unchecked(i) })
776 } else {
777 None
778 }
779 }
780
781 #[inline]
787 pub(crate) unsafe fn get_row_unchecked(&self, i: usize) -> T::Row<'_> {
788 unsafe { self.repr.get_row(self.ptr, i) }
790 }
791
792 #[inline]
794 #[must_use]
795 pub fn get_row_mut(&mut self, i: usize) -> Option<T::RowMut<'_>> {
796 if i < self.num_vectors() {
797 Some(unsafe { self.get_row_mut_unchecked(i) })
799 } else {
800 None
801 }
802 }
803
804 #[inline]
810 pub(crate) unsafe fn get_row_mut_unchecked(&mut self, i: usize) -> T::RowMut<'_> {
811 unsafe { self.repr.get_row_mut(self.ptr, i) }
814 }
815
816 pub fn as_view(&self) -> MatRef<'_, T> {
818 MatRef {
819 ptr: self.ptr,
820 repr: self.repr,
821 _lifetime: PhantomData,
822 }
823 }
824
825 pub fn rows(&self) -> Rows<'_, T> {
827 Rows::new(self.reborrow())
828 }
829
830 pub fn rows_mut(&mut self) -> RowsMut<'_, T> {
832 RowsMut::new(self.reborrow_mut())
833 }
834
835 pub unsafe fn from_raw_parts(repr: T, ptr: NonNull<u8>) -> Self {
842 Self {
843 ptr,
844 repr,
845 _lifetime: PhantomData,
846 }
847 }
848}
849
850impl<'this, 'a, T: ReprMut> Reborrow<'this> for MatMut<'a, T> {
852 type Target = MatRef<'this, T>;
853
854 fn reborrow(&'this self) -> Self::Target {
855 self.as_view()
856 }
857}
858
859impl<'this, 'a, T: ReprMut> ReborrowMut<'this> for MatMut<'a, T> {
861 type Target = MatMut<'this, T>;
862
863 fn reborrow_mut(&'this mut self) -> Self::Target {
864 MatMut {
865 ptr: self.ptr,
866 repr: self.repr,
867 _lifetime: PhantomData,
868 }
869 }
870}
871
872impl<'a, T: Copy> MatMut<'a, Standard<T>> {
873 #[inline]
875 pub fn vector_dim(&self) -> usize {
876 self.repr.ncols()
877 }
878}
879
880#[derive(Debug)]
888pub struct Rows<'a, T: Repr> {
889 matrix: MatRef<'a, T>,
890 current: usize,
891}
892
893impl<'a, T> Rows<'a, T>
894where
895 T: Repr,
896{
897 fn new(matrix: MatRef<'a, T>) -> Self {
898 Self { matrix, current: 0 }
899 }
900}
901
902impl<'a, T> Iterator for Rows<'a, T>
903where
904 T: Repr + 'a,
905{
906 type Item = T::Row<'a>;
907
908 fn next(&mut self) -> Option<Self::Item> {
909 let current = self.current;
910 if current >= self.matrix.num_vectors() {
911 None
912 } else {
913 self.current += 1;
914 Some(unsafe { self.matrix.repr.get_row(self.matrix.ptr, current) })
920 }
921 }
922
923 fn size_hint(&self) -> (usize, Option<usize>) {
924 let remaining = self.matrix.num_vectors() - self.current;
925 (remaining, Some(remaining))
926 }
927}
928
929impl<'a, T> ExactSizeIterator for Rows<'a, T> where T: Repr + 'a {}
930impl<'a, T> FusedIterator for Rows<'a, T> where T: Repr + 'a {}
931
932#[derive(Debug)]
940pub struct RowsMut<'a, T: ReprMut> {
941 matrix: MatMut<'a, T>,
942 current: usize,
943}
944
945impl<'a, T> RowsMut<'a, T>
946where
947 T: ReprMut,
948{
949 fn new(matrix: MatMut<'a, T>) -> Self {
950 Self { matrix, current: 0 }
951 }
952}
953
954impl<'a, T> Iterator for RowsMut<'a, T>
955where
956 T: ReprMut + 'a,
957{
958 type Item = T::RowMut<'a>;
959
960 fn next(&mut self) -> Option<Self::Item> {
961 let current = self.current;
962 if current >= self.matrix.num_vectors() {
963 None
964 } else {
965 self.current += 1;
966 Some(unsafe { self.matrix.repr.get_row_mut(self.matrix.ptr, current) })
975 }
976 }
977
978 fn size_hint(&self) -> (usize, Option<usize>) {
979 let remaining = self.matrix.num_vectors() - self.current;
980 (remaining, Some(remaining))
981 }
982}
983
984impl<'a, T> ExactSizeIterator for RowsMut<'a, T> where T: ReprMut + 'a {}
985impl<'a, T> FusedIterator for RowsMut<'a, T> where T: ReprMut + 'a {}
986
987#[cfg(test)]
992mod tests {
993 use super::*;
994
995 use std::fmt::Display;
996
997 use diskann_utils::lazy_format;
998
999 fn assert_copy<T: Copy>(_: &T) {}
1001
1002 fn edge_cases(nrows: usize) -> Vec<usize> {
1003 let max = usize::MAX;
1004
1005 vec![
1006 nrows,
1007 nrows + 1,
1008 nrows + 11,
1009 nrows + 20,
1010 max / 2,
1011 max.div_ceil(2),
1012 max - 1,
1013 max,
1014 ]
1015 }
1016
1017 fn fill_mat(x: &mut Mat<Standard<usize>>, repr: Standard<usize>) {
1018 assert_eq!(x.repr(), &repr);
1019 assert_eq!(x.num_vectors(), repr.nrows());
1020 assert_eq!(x.vector_dim(), repr.ncols());
1021
1022 for i in 0..x.num_vectors() {
1023 let row = x.get_row_mut(i).unwrap();
1024 assert_eq!(row.len(), repr.ncols());
1025 row.iter_mut()
1026 .enumerate()
1027 .for_each(|(j, r)| *r = 10 * i + j);
1028 }
1029
1030 for i in edge_cases(repr.nrows()).into_iter() {
1031 assert!(x.get_row_mut(i).is_none());
1032 }
1033 }
1034
1035 fn fill_mat_mut(mut x: MatMut<'_, Standard<usize>>, repr: Standard<usize>) {
1036 assert_eq!(x.repr(), &repr);
1037 assert_eq!(x.num_vectors(), repr.nrows());
1038 assert_eq!(x.vector_dim(), repr.ncols());
1039
1040 for i in 0..x.num_vectors() {
1041 let row = x.get_row_mut(i).unwrap();
1042 assert_eq!(row.len(), repr.ncols());
1043
1044 row.iter_mut()
1045 .enumerate()
1046 .for_each(|(j, r)| *r = 10 * i + j);
1047 }
1048
1049 for i in edge_cases(repr.nrows()).into_iter() {
1050 assert!(x.get_row_mut(i).is_none());
1051 }
1052 }
1053
1054 fn fill_rows_mut(x: RowsMut<'_, Standard<usize>>, repr: Standard<usize>) {
1055 assert_eq!(x.len(), repr.nrows());
1056 let mut all_rows: Vec<_> = x.collect();
1058 assert_eq!(all_rows.len(), repr.nrows());
1059 for (i, row) in all_rows.iter_mut().enumerate() {
1060 assert_eq!(row.len(), repr.ncols());
1061 row.iter_mut()
1062 .enumerate()
1063 .for_each(|(j, r)| *r = 10 * i + j);
1064 }
1065 }
1066
1067 fn check_mat(x: &Mat<Standard<usize>>, repr: Standard<usize>, ctx: &dyn Display) {
1068 assert_eq!(x.repr(), &repr);
1069 assert_eq!(x.num_vectors(), repr.nrows());
1070 assert_eq!(x.vector_dim(), repr.ncols());
1071
1072 for i in 0..x.num_vectors() {
1073 let row = x.get_row(i).unwrap();
1074
1075 assert_eq!(row.len(), repr.ncols(), "ctx: {ctx}");
1076 row.iter().enumerate().for_each(|(j, r)| {
1077 assert_eq!(
1078 *r,
1079 10 * i + j,
1080 "mismatched entry at row {}, col {} -- ctx: {}",
1081 i,
1082 j,
1083 ctx
1084 )
1085 });
1086 }
1087
1088 for i in edge_cases(repr.nrows()).into_iter() {
1089 assert!(x.get_row(i).is_none(), "ctx: {ctx}");
1090 }
1091 }
1092
1093 fn check_mat_ref(x: MatRef<'_, Standard<usize>>, repr: Standard<usize>, ctx: &dyn Display) {
1094 assert_eq!(x.repr(), &repr);
1095 assert_eq!(x.num_vectors(), repr.nrows());
1096 assert_eq!(x.vector_dim(), repr.ncols());
1097
1098 assert_copy(&x);
1099 for i in 0..x.num_vectors() {
1100 let row = x.get_row(i).unwrap();
1101 assert_eq!(row.len(), repr.ncols(), "ctx: {ctx}");
1102
1103 row.iter().enumerate().for_each(|(j, r)| {
1104 assert_eq!(
1105 *r,
1106 10 * i + j,
1107 "mismatched entry at row {}, col {} -- ctx: {}",
1108 i,
1109 j,
1110 ctx
1111 )
1112 });
1113 }
1114
1115 for i in edge_cases(repr.nrows()).into_iter() {
1116 assert!(x.get_row(i).is_none(), "ctx: {ctx}");
1117 }
1118 }
1119
1120 fn check_mat_mut(x: MatMut<'_, Standard<usize>>, repr: Standard<usize>, ctx: &dyn Display) {
1121 assert_eq!(x.repr(), &repr);
1122 assert_eq!(x.num_vectors(), repr.nrows());
1123 assert_eq!(x.vector_dim(), repr.ncols());
1124
1125 for i in 0..x.num_vectors() {
1126 let row = x.get_row(i).unwrap();
1127 assert_eq!(row.len(), repr.ncols(), "ctx: {ctx}");
1128
1129 row.iter().enumerate().for_each(|(j, r)| {
1130 assert_eq!(
1131 *r,
1132 10 * i + j,
1133 "mismatched entry at row {}, col {} -- ctx: {}",
1134 i,
1135 j,
1136 ctx
1137 )
1138 });
1139 }
1140
1141 for i in edge_cases(repr.nrows()).into_iter() {
1142 assert!(x.get_row(i).is_none(), "ctx: {ctx}");
1143 }
1144 }
1145
1146 fn check_rows(x: Rows<'_, Standard<usize>>, repr: Standard<usize>, ctx: &dyn Display) {
1147 assert_eq!(x.len(), repr.nrows(), "ctx: {ctx}");
1148 let all_rows: Vec<_> = x.collect();
1149 assert_eq!(all_rows.len(), repr.nrows(), "ctx: {ctx}");
1150 for (i, row) in all_rows.iter().enumerate() {
1151 assert_eq!(row.len(), repr.ncols(), "ctx: {ctx}");
1152 row.iter().enumerate().for_each(|(j, r)| {
1153 assert_eq!(
1154 *r,
1155 10 * i + j,
1156 "mismatched entry at row {}, col {} -- ctx: {}",
1157 i,
1158 j,
1159 ctx
1160 )
1161 });
1162 }
1163 }
1164
1165 #[test]
1170 fn standard_representation() {
1171 let repr = Standard::<f32>::new(4, 3);
1172 assert_eq!(repr.nrows(), 4);
1173 assert_eq!(repr.ncols(), 3);
1174
1175 let layout = repr.layout().unwrap();
1176 assert_eq!(layout.size(), 4 * 3 * std::mem::size_of::<f32>());
1177 assert_eq!(layout.align(), std::mem::align_of::<f32>());
1178 }
1179
1180 #[test]
1181 fn standard_zero_dimensions() {
1182 for (nrows, ncols) in [(0, 0), (0, 5), (5, 0)] {
1183 let repr = Standard::<u8>::new(nrows, ncols);
1184 assert_eq!(repr.nrows(), nrows);
1185 assert_eq!(repr.ncols(), ncols);
1186 let layout = repr.layout().unwrap();
1187 assert_eq!(layout.size(), 0);
1188 }
1189 }
1190
1191 #[test]
1192 fn standard_check_slice() {
1193 let repr = Standard::<u32>::new(3, 4);
1194
1195 let data = vec![0u32; 12];
1197 assert!(repr.check_slice(&data).is_ok());
1198
1199 let short = vec![0u32; 11];
1201 assert!(matches!(
1202 repr.check_slice(&short),
1203 Err(SliceError::LengthMismatch {
1204 expected: 12,
1205 found: 11
1206 })
1207 ));
1208
1209 let long = vec![0u32; 13];
1211 assert!(matches!(
1212 repr.check_slice(&long),
1213 Err(SliceError::LengthMismatch {
1214 expected: 12,
1215 found: 13
1216 })
1217 ));
1218
1219 let overflow_repr = Standard::<u8>::new(usize::MAX, 2);
1221 assert!(matches!(
1222 overflow_repr.check_slice(&[]),
1223 Err(SliceError::Overflow)
1224 ));
1225 }
1226
1227 #[test]
1228 fn standard_layout_errors() {
1229 let overflow_repr = Standard::<u8>::new(usize::MAX, 2);
1231 assert!(overflow_repr.layout().is_err());
1232
1233 let large_repr = Standard::<u64>::new(isize::MAX as usize / 4, 2);
1237 assert!(large_repr.layout().is_err());
1238 }
1239
1240 #[test]
1245 fn mat_new_and_basic_accessors() {
1246 let mat = Mat::new(Standard::<usize>::new(3, 4), 42usize).unwrap();
1247 let base: *const u8 = mat.as_ptr().as_ptr();
1248
1249 assert_eq!(mat.num_vectors(), 3);
1250 assert_eq!(mat.vector_dim(), 4);
1251
1252 let repr = mat.repr();
1253 assert_eq!(repr.nrows(), 3);
1254 assert_eq!(repr.ncols(), 4);
1255
1256 for (i, r) in mat.rows().enumerate() {
1257 assert_eq!(r, &[42, 42, 42, 42]);
1258 let ptr = r.as_ptr().cast::<u8>();
1259 assert_eq!(
1260 ptr,
1261 base.wrapping_add(std::mem::size_of::<usize>() * mat.repr().ncols() * i),
1262 );
1263 }
1264 }
1265
1266 #[test]
1267 fn mat_new_with_default() {
1268 let mat = Mat::new(Standard::<usize>::new(2, 3), Defaulted).unwrap();
1269 let base: *const u8 = mat.as_ptr().as_ptr();
1270
1271 assert_eq!(mat.num_vectors(), 2);
1272 for (i, row) in mat.rows().enumerate() {
1273 assert!(row.iter().all(|&v| v == 0));
1274
1275 let ptr = row.as_ptr().cast::<u8>();
1276 assert_eq!(
1277 ptr,
1278 base.wrapping_add(std::mem::size_of::<usize>() * mat.repr().ncols() * i),
1279 );
1280 }
1281 }
1282
1283 const ROWS: &[usize] = &[0, 1, 2, 3, 5, 10];
1284 const COLS: &[usize] = &[0, 1, 2, 3, 5, 10];
1285
1286 #[test]
1287 fn test_mat() {
1288 for nrows in ROWS {
1289 for ncols in COLS {
1290 let repr = Standard::<usize>::new(*nrows, *ncols);
1291 let ctx = &lazy_format!("nrows = {}, ncols = {}", nrows, ncols);
1292
1293 {
1295 let ctx = &lazy_format!("{ctx} - direct");
1296 let mut mat = Mat::new(repr, Defaulted).unwrap();
1297
1298 assert_eq!(mat.num_vectors(), *nrows);
1299 assert_eq!(mat.vector_dim(), *ncols);
1300
1301 fill_mat(&mut mat, repr);
1302
1303 check_mat(&mat, repr, ctx);
1304 check_mat_ref(mat.reborrow(), repr, ctx);
1305 check_mat_mut(mat.reborrow_mut(), repr, ctx);
1306 check_rows(mat.rows(), repr, ctx);
1307 }
1308
1309 {
1311 let ctx = &lazy_format!("{ctx} - matmut");
1312 let mut mat = Mat::new(repr, Defaulted).unwrap();
1313 let matmut = mat.reborrow_mut();
1314
1315 assert_eq!(matmut.num_vectors(), *nrows);
1316 assert_eq!(matmut.vector_dim(), *ncols);
1317
1318 fill_mat_mut(matmut, repr);
1319
1320 check_mat(&mat, repr, ctx);
1321 check_mat_ref(mat.reborrow(), repr, ctx);
1322 check_mat_mut(mat.reborrow_mut(), repr, ctx);
1323 check_rows(mat.rows(), repr, ctx);
1324 }
1325
1326 {
1328 let ctx = &lazy_format!("{ctx} - rows_mut");
1329 let mut mat = Mat::new(repr, Defaulted).unwrap();
1330 fill_rows_mut(mat.rows_mut(), repr);
1331
1332 check_mat(&mat, repr, ctx);
1333 check_mat_ref(mat.reborrow(), repr, ctx);
1334 check_mat_mut(mat.reborrow_mut(), repr, ctx);
1335 check_rows(mat.rows(), repr, ctx);
1336 }
1337 }
1338 }
1339 }
1340
1341 #[test]
1342 fn test_mat_refmut() {
1343 for nrows in ROWS {
1344 for ncols in COLS {
1345 let repr = Standard::<usize>::new(*nrows, *ncols);
1346 let ctx = &lazy_format!("nrows = {}, ncols = {}", nrows, ncols);
1347
1348 {
1350 let ctx = &lazy_format!("{ctx} - by matmut");
1351 let mut b: Box<[_]> =
1352 (0..repr.num_elements().unwrap()).map(|_| 0usize).collect();
1353 let mut matmut = MatMut::new(repr, &mut b).unwrap();
1354
1355 fill_mat_mut(matmut.reborrow_mut(), repr);
1356
1357 check_mat_mut(matmut.reborrow_mut(), repr, ctx);
1358 check_mat_ref(matmut.reborrow(), repr, ctx);
1359 check_rows(matmut.rows(), repr, ctx);
1360 check_rows(matmut.reborrow().rows(), repr, ctx);
1361
1362 let matref = MatRef::new(repr, &b).unwrap();
1363 check_mat_ref(matref, repr, ctx);
1364 check_mat_ref(matref.reborrow(), repr, ctx);
1365 check_rows(matref.rows(), repr, ctx);
1366 }
1367
1368 {
1370 let ctx = &lazy_format!("{ctx} - by rows");
1371 let mut b: Box<[_]> =
1372 (0..repr.num_elements().unwrap()).map(|_| 0usize).collect();
1373 let mut matmut = MatMut::new(repr, &mut b).unwrap();
1374
1375 fill_rows_mut(matmut.rows_mut(), repr);
1376
1377 check_mat_mut(matmut.reborrow_mut(), repr, ctx);
1378 check_mat_ref(matmut.reborrow(), repr, ctx);
1379 check_rows(matmut.rows(), repr, ctx);
1380 check_rows(matmut.reborrow().rows(), repr, ctx);
1381
1382 let matref = MatRef::new(repr, &b).unwrap();
1383 check_mat_ref(matref, repr, ctx);
1384 check_mat_ref(matref.reborrow(), repr, ctx);
1385 check_rows(matref.rows(), repr, ctx);
1386 }
1387 }
1388 }
1389 }
1390
1391 #[test]
1396 fn test_standard_new_owned() {
1397 let rows = [0, 1, 2, 3, 5, 10];
1398 let cols = [0, 1, 2, 3, 5, 10];
1399
1400 for nrows in rows {
1401 for ncols in cols {
1402 let m = Mat::new(Standard::new(nrows, ncols), 1usize).unwrap();
1403 let rows_iter = m.rows();
1404 let len = <_ as ExactSizeIterator>::len(&rows_iter);
1405 assert_eq!(len, nrows);
1406 for r in rows_iter {
1407 assert_eq!(r.len(), ncols);
1408 assert!(r.iter().all(|i| *i == 1usize));
1409 }
1410 }
1411 }
1412 }
1413
1414 #[test]
1415 fn matref_new_slice_length_error() {
1416 let repr = Standard::<u32>::new(3, 4);
1417
1418 let data = vec![0u32; 12];
1420 assert!(MatRef::new(repr, &data).is_ok());
1421
1422 let short = vec![0u32; 11];
1424 assert!(matches!(
1425 MatRef::new(repr, &short),
1426 Err(SliceError::LengthMismatch {
1427 expected: 12,
1428 found: 11
1429 })
1430 ));
1431
1432 let long = vec![0u32; 13];
1434 assert!(matches!(
1435 MatRef::new(repr, &long),
1436 Err(SliceError::LengthMismatch {
1437 expected: 12,
1438 found: 13
1439 })
1440 ));
1441 }
1442
1443 #[test]
1444 fn matmut_new_slice_length_error() {
1445 let repr = Standard::<u32>::new(3, 4);
1446
1447 let mut data = vec![0u32; 12];
1449 assert!(MatMut::new(repr, &mut data).is_ok());
1450
1451 let mut short = vec![0u32; 11];
1453 assert!(matches!(
1454 MatMut::new(repr, &mut short),
1455 Err(SliceError::LengthMismatch {
1456 expected: 12,
1457 found: 11
1458 })
1459 ));
1460
1461 let mut long = vec![0u32; 13];
1463 assert!(matches!(
1464 MatMut::new(repr, &mut long),
1465 Err(SliceError::LengthMismatch {
1466 expected: 12,
1467 found: 13
1468 })
1469 ));
1470 }
1471}