1use std::{alloc::Layout, iter::FusedIterator, marker::PhantomData, ptr::NonNull};
31
32use diskann_utils::{Reborrow, ReborrowMut};
33use thiserror::Error;
34
35use crate::utils;
36
/// Describes how the rows of a matrix are laid out behind a raw byte pointer.
///
/// The matrix types in this module (`Mat`, `MatRef`, `MatMut`) store only a
/// `NonNull<u8>` and a copy of a `Repr`; every row access is delegated to the `Repr`.
///
/// # Safety
///
/// This trait is `unsafe` because the matrix types trust it when materializing rows
/// from raw memory. NOTE(review): the precise contract is not visible in this excerpt;
/// presumably `get_row` must only produce views inside the memory described by
/// `layout()` — confirm against the full trait documentation.
pub unsafe trait Repr: Copy {
    /// The immutable view of a single row (for [`Standard`], `&'a [T]`).
    type Row<'a>
    where
        Self: 'a;

    /// The number of rows in the matrix.
    fn nrows(&self) -> usize;

    /// The layout of the backing memory described by this representation.
    ///
    /// # Errors
    ///
    /// Returns [`LayoutError`] when the layout cannot be formed (for [`Standard`],
    /// when `Layout::array` fails).
    fn layout(&self) -> Result<Layout, LayoutError>;

    /// Materialize row `i` of the matrix whose base pointer is `ptr`.
    ///
    /// # Safety
    ///
    /// The bounds-checked accessors in this module only call this with
    /// `i < self.nrows()`. NOTE(review): callers presumably must also guarantee that
    /// `ptr` points to live memory laid out for `self` and valid for `'a` — confirm.
    unsafe fn get_row<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::Row<'a>;
}
95
/// Extension of [`Repr`] for representations that also support mutable row views.
///
/// # Safety
///
/// NOTE(review): the full contract is not visible here. `RowsMut` hands out mutable
/// views of every row at once, so presumably views for distinct row indices must never
/// alias — confirm.
pub unsafe trait ReprMut: Repr {
    /// The mutable view of a single row (for [`Standard`], `&'a mut [T]`).
    type RowMut<'a>
    where
        Self: 'a;

    /// Materialize a mutable view of row `i` of the matrix whose base pointer is `ptr`.
    ///
    /// # Safety
    ///
    /// Same index requirement as [`Repr::get_row`] (`i < self.nrows()`).
    /// NOTE(review): presumably the caller must also ensure no other live view
    /// aliases row `i` for `'a` — confirm.
    unsafe fn get_row_mut<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::RowMut<'a>;
}
131
/// Extension of [`ReprMut`] for representations that can own, and free, their storage.
///
/// # Safety
///
/// NOTE(review): the full contract is not visible in this excerpt. `Mat` calls `drop`
/// from its `Drop` impl with the pointer it was constructed with.
pub unsafe trait ReprOwned: ReprMut {
    /// Release the storage at `ptr` that was allocated for this representation.
    ///
    /// # Safety
    ///
    /// `ptr` must come from an allocation created for `self` (for [`Standard`], a
    /// `Box<[T]>` of `nrows * ncols` elements) and must not be used afterwards.
    unsafe fn drop(self, ptr: NonNull<u8>);
}
151
/// Error returned by [`Repr::layout`] when a representation's memory layout is invalid.
///
/// The struct is `#[non_exhaustive]`, so code outside this crate obtains one through
/// [`LayoutError::new`] (or `Default`/`From`) rather than the unit-struct literal.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub struct LayoutError;

impl LayoutError {
    /// Create a new `LayoutError`.
    pub fn new() -> Self {
        LayoutError
    }
}

impl Default for LayoutError {
    fn default() -> Self {
        LayoutError::new()
    }
}

impl std::fmt::Display for LayoutError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("LayoutError")
    }
}

impl std::error::Error for LayoutError {}

/// Collapse the standard library's layout error into this module's error type.
impl From<std::alloc::LayoutError> for LayoutError {
    fn from(_err: std::alloc::LayoutError) -> Self {
        Self::new()
    }
}
187
/// Construct an immutable matrix view over a borrowed slice of `T`.
///
/// # Safety
///
/// NOTE(review): full contract not visible here. Implementations build a `MatRef` from
/// a raw pointer, so presumably they must validate `slice` against `self` first (as
/// `Standard` does with a length check) — confirm.
pub unsafe trait NewRef<T>: Repr {
    /// Error returned when `slice` is incompatible with this representation.
    type Error;

    /// Create a `MatRef` that borrows `slice` for its lifetime.
    fn new_ref(self, slice: &[T]) -> Result<MatRef<'_, Self>, Self::Error>;
}

/// Construct a mutable matrix view over a borrowed mutable slice of `T`.
///
/// # Safety
///
/// NOTE(review): full contract not visible here; presumably the same validation as
/// `NewRef` applies — confirm.
pub unsafe trait NewMut<T>: ReprMut {
    /// Error returned when `slice` is incompatible with this representation.
    type Error;

    /// Create a `MatMut` that mutably borrows `slice` for its lifetime.
    fn new_mut(self, slice: &mut [T]) -> Result<MatMut<'_, Self>, Self::Error>;
}

/// Construct an owning matrix from an initializer value of type `T`.
///
/// For [`Standard`], `T` is either the element type (every entry is set to `init`) or
/// [`Defaulted`] (every entry is the element type's `Default` value).
///
/// # Safety
///
/// NOTE(review): full contract not visible here. The returned `Mat` is later freed
/// through [`ReprOwned::drop`], so the allocation must match what that expects.
pub unsafe trait NewOwned<T>: ReprOwned {
    /// Error returned when construction fails.
    type Error;

    /// Allocate and initialize a new matrix.
    fn new_owned(self, init: T) -> Result<Mat<Self>, Self::Error>;
}

/// Initializer marker that fills a new matrix with the element type's `Default` value.
///
/// Used as `Mat::new(repr, Defaulted)`.
#[derive(Debug, Clone, Copy)]
pub struct Defaulted;

/// Representations that can deep-copy a `MatRef` into a freshly owned `Mat`.
///
/// Backs `Clone for Mat` and the `to_owned` methods on `MatRef` and `MatMut`.
pub trait NewCloned: ReprOwned {
    /// Copy every row of `v` into a new, independently owned matrix.
    fn new_cloned(v: MatRef<'_, Self>) -> Mat<Self>;
}
254
/// A dense, row-major matrix representation: `nrows` rows of `ncols` elements of `T`
/// stored contiguously.
///
/// Row `i` occupies elements `i * ncols .. (i + 1) * ncols` of the backing buffer.
/// Build one with [`Standard::new`], which rejects shapes whose total size would
/// overflow.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Standard<T> {
    /// Number of rows.
    nrows: usize,
    /// Number of elements in each row.
    ncols: usize,
    /// Records the element type without storing a value of it.
    _elem: PhantomData<T>,
}
274
275impl<T: Copy> Standard<T> {
276 pub fn new(nrows: usize, ncols: usize) -> Result<Self, Overflow> {
285 Overflow::check::<T>(nrows, ncols)?;
286 Ok(Self {
287 nrows,
288 ncols,
289 _elem: PhantomData,
290 })
291 }
292
293 pub fn num_elements(&self) -> usize {
295 self.nrows() * self.ncols()
297 }
298
299 fn nrows(&self) -> usize {
301 self.nrows
302 }
303
304 fn ncols(&self) -> usize {
306 self.ncols
307 }
308
309 fn check_slice(&self, slice: &[T]) -> Result<(), SliceError> {
314 let len = self.num_elements();
315
316 if slice.len() != len {
317 Err(SliceError::LengthMismatch {
318 expected: len,
319 found: slice.len(),
320 })
321 } else {
322 Ok(())
323 }
324 }
325
326 unsafe fn box_to_mat(self, b: Box<[T]>) -> Mat<Self> {
332 debug_assert_eq!(b.len(), self.num_elements(), "safety contract violated");
333
334 let ptr = utils::box_into_nonnull(b).cast::<u8>();
335
336 unsafe { Mat::from_raw_parts(self, ptr) }
340 }
341}
342
/// Error returned by [`Standard::new`] when a requested shape is too large.
///
/// Records the requested dimensions and the element size (0 for zero-sized types).
#[derive(Debug, Clone, Copy)]
pub struct Overflow {
    nrows: usize,
    ncols: usize,
    elsize: usize,
}

impl Overflow {
    /// Build the error for an `nrows x ncols` matrix of `T`.
    pub(crate) fn for_type<T>(nrows: usize, ncols: usize) -> Self {
        let elsize = std::mem::size_of::<T>();
        Self {
            nrows,
            ncols,
            elsize,
        }
    }

    /// Succeed iff `capacity` elements of `T` occupy at most `isize::MAX` bytes.
    ///
    /// `nrows` and `ncols` are only used to populate the error.
    pub(crate) fn check_byte_budget<T>(
        capacity: usize,
        nrows: usize,
        ncols: usize,
    ) -> Result<(), Self> {
        // A `usize` overflow of the byte count is treated the same as exceeding the budget.
        match std::mem::size_of::<T>().checked_mul(capacity) {
            Some(bytes) if bytes <= isize::MAX as usize => Ok(()),
            _ => Err(Self::for_type::<T>(nrows, ncols)),
        }
    }

    /// Validate that an `nrows x ncols` matrix of `T` has a representable element
    /// count and fits within the `isize::MAX` byte budget.
    pub(crate) fn check<T>(nrows: usize, ncols: usize) -> Result<(), Self> {
        match nrows.checked_mul(ncols) {
            Some(capacity) => Self::check_byte_budget::<T>(capacity, nrows, ncols),
            None => Err(Self::for_type::<T>(nrows, ncols)),
        }
    }
}

impl std::fmt::Display for Overflow {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            nrows,
            ncols,
            elsize,
        } = *self;

        if elsize != 0 {
            return write!(
                f,
                "a matrix of size {} x {} with element size {} would exceed isize::MAX bytes",
                nrows, ncols, elsize,
            );
        }

        // Zero-sized elements never exceed the byte budget, so only the element count
        // can have overflowed.
        write!(
            f,
            "ZST matrix with dimensions {} x {} has more than `usize::MAX` elements",
            nrows, ncols,
        )
    }
}

impl std::error::Error for Overflow {}
408
/// Error returned when a slice cannot back a matrix view of the requested shape.
#[derive(Debug, Clone, Copy, Error)]
#[non_exhaustive]
pub enum SliceError {
    /// The slice length differs from the number of elements the representation needs.
    #[error("Length mismatch: expected {expected}, found {found}")]
    LengthMismatch { expected: usize, found: usize },
}
416
417unsafe impl<T: Copy> Repr for Standard<T> {
421 type Row<'a>
422 = &'a [T]
423 where
424 T: 'a;
425
426 fn nrows(&self) -> usize {
427 self.nrows
428 }
429
430 fn layout(&self) -> Result<Layout, LayoutError> {
431 Ok(Layout::array::<T>(self.num_elements())?)
432 }
433
434 unsafe fn get_row<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::Row<'a> {
435 debug_assert!(ptr.cast::<T>().is_aligned());
436 debug_assert!(i < self.nrows);
437
438 let row_ptr = unsafe { ptr.as_ptr().cast::<T>().add(i * self.ncols) };
442
443 unsafe { std::slice::from_raw_parts(row_ptr, self.ncols) }
445 }
446}
447
448unsafe impl<T: Copy> ReprMut for Standard<T> {
451 type RowMut<'a>
452 = &'a mut [T]
453 where
454 T: 'a;
455
456 unsafe fn get_row_mut<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::RowMut<'a> {
457 debug_assert!(ptr.cast::<T>().is_aligned());
458 debug_assert!(i < self.nrows);
459
460 let row_ptr = unsafe { ptr.as_ptr().cast::<T>().add(i * self.ncols) };
464
465 unsafe { std::slice::from_raw_parts_mut(row_ptr, self.ncols) }
468 }
469}
470
471unsafe impl<T: Copy> ReprOwned for Standard<T> {
475 unsafe fn drop(self, ptr: NonNull<u8>) {
476 unsafe {
482 let slice_ptr = std::ptr::slice_from_raw_parts_mut(
483 ptr.cast::<T>().as_ptr(),
484 self.nrows * self.ncols,
485 );
486 let _ = Box::from_raw(slice_ptr);
487 }
488 }
489}
490
491unsafe impl<T> NewOwned<T> for Standard<T>
494where
495 T: Copy,
496{
497 type Error = crate::error::Infallible;
498 fn new_owned(self, value: T) -> Result<Mat<Self>, Self::Error> {
499 let b: Box<[T]> = (0..self.num_elements()).map(|_| value).collect();
500
501 Ok(unsafe { self.box_to_mat(b) })
503 }
504}
505
506unsafe impl<T> NewOwned<Defaulted> for Standard<T>
508where
509 T: Copy + Default,
510{
511 type Error = crate::error::Infallible;
512 fn new_owned(self, _: Defaulted) -> Result<Mat<Self>, Self::Error> {
513 self.new_owned(T::default())
514 }
515}
516
517unsafe impl<T> NewRef<T> for Standard<T>
520where
521 T: Copy,
522{
523 type Error = SliceError;
524 fn new_ref(self, data: &[T]) -> Result<MatRef<'_, Self>, Self::Error> {
525 self.check_slice(data)?;
526
527 Ok(unsafe { MatRef::from_raw_parts(self, utils::as_nonnull(data).cast::<u8>()) })
532 }
533}
534
535unsafe impl<T> NewMut<T> for Standard<T>
538where
539 T: Copy,
540{
541 type Error = SliceError;
542 fn new_mut(self, data: &mut [T]) -> Result<MatMut<'_, Self>, Self::Error> {
543 self.check_slice(data)?;
544
545 Ok(unsafe { MatMut::from_raw_parts(self, utils::as_nonnull_mut(data).cast::<u8>()) })
550 }
551}
552
553impl<T> NewCloned for Standard<T>
554where
555 T: Copy,
556{
557 fn new_cloned(v: MatRef<'_, Self>) -> Mat<Self> {
558 let b: Box<[T]> = v.rows().flatten().copied().collect();
559
560 unsafe { v.repr().box_to_mat(b) }
562 }
563}
564
/// An owning matrix whose storage layout is described by the representation `T`.
///
/// Holds a pointer to the storage plus a copy of `T`. The storage is released through
/// [`ReprOwned::drop`] when the `Mat` is dropped.
#[derive(Debug)]
pub struct Mat<T: ReprOwned> {
    ptr: NonNull<u8>,
    repr: T,
    // `fn(T) -> T` makes `Mat` invariant in `T`.
    _invariant: PhantomData<fn(T) -> T>,
}

// SAFETY: thread-safety is delegated to the representation, so `Mat<T>` is `Send`
// exactly when `T` is. For `Standard<E>` this follows `E` through `PhantomData<E>`,
// mirroring `Box<[E]>`.
unsafe impl<T> Send for Mat<T> where T: ReprOwned + Send {}

// SAFETY: delegated to the representation as for `Send`; `Mat<T>` is `Sync` exactly
// when `T` is.
unsafe impl<T> Sync for Mat<T> where T: ReprOwned + Sync {}
585
586impl<T: ReprOwned> Mat<T> {
587 pub fn new<U>(repr: T, init: U) -> Result<Self, <T as NewOwned<U>>::Error>
589 where
590 T: NewOwned<U>,
591 {
592 repr.new_owned(init)
593 }
594
595 #[inline]
597 pub fn num_vectors(&self) -> usize {
598 self.repr.nrows()
599 }
600
601 pub fn repr(&self) -> &T {
603 &self.repr
604 }
605
606 #[must_use]
608 pub fn get_row(&self, i: usize) -> Option<T::Row<'_>> {
609 if i < self.num_vectors() {
610 let row = unsafe { self.get_row_unchecked(i) };
613 Some(row)
614 } else {
615 None
616 }
617 }
618
619 pub(crate) unsafe fn get_row_unchecked(&self, i: usize) -> T::Row<'_> {
620 unsafe { self.repr.get_row(self.ptr, i) }
623 }
624
625 #[must_use]
627 pub fn get_row_mut(&mut self, i: usize) -> Option<T::RowMut<'_>> {
628 if i < self.num_vectors() {
629 Some(unsafe { self.get_row_mut_unchecked(i) })
631 } else {
632 None
633 }
634 }
635
636 pub(crate) unsafe fn get_row_mut_unchecked(&mut self, i: usize) -> T::RowMut<'_> {
637 unsafe { self.repr.get_row_mut(self.ptr, i) }
640 }
641
642 #[inline]
644 pub fn as_view(&self) -> MatRef<'_, T> {
645 MatRef {
646 ptr: self.ptr,
647 repr: self.repr,
648 _lifetime: PhantomData,
649 }
650 }
651
652 #[inline]
654 pub fn as_view_mut(&mut self) -> MatMut<'_, T> {
655 MatMut {
656 ptr: self.ptr,
657 repr: self.repr,
658 _lifetime: PhantomData,
659 }
660 }
661
662 pub fn rows(&self) -> Rows<'_, T> {
664 Rows::new(self.reborrow())
665 }
666
667 pub fn rows_mut(&mut self) -> RowsMut<'_, T> {
669 RowsMut::new(self.reborrow_mut())
670 }
671
672 pub(crate) unsafe fn from_raw_parts(repr: T, ptr: NonNull<u8>) -> Self {
682 Self {
683 ptr,
684 repr,
685 _invariant: PhantomData,
686 }
687 }
688
689 pub fn as_raw_ptr(&self) -> *const u8 {
691 self.ptr.as_ptr()
692 }
693
694 pub(crate) fn as_raw_mut_ptr(&mut self) -> *mut u8 {
696 self.ptr.as_ptr()
697 }
698}
699
700impl<T: ReprOwned> Drop for Mat<T> {
701 fn drop(&mut self) {
702 unsafe { self.repr.drop(self.ptr) };
705 }
706}
707
708impl<T: NewCloned> Clone for Mat<T> {
709 fn clone(&self) -> Self {
710 T::new_cloned(self.as_view())
711 }
712}
713
714impl<T: Copy> Mat<Standard<T>> {
715 #[inline]
717 pub fn vector_dim(&self) -> usize {
718 self.repr.ncols()
719 }
720}
721
/// A borrowed, read-only view of a matrix described by the representation `T`.
///
/// `MatRef` is `Copy` and behaves like a shared reference: it is covariant in both
/// `'a` and `T`.
#[derive(Debug, Clone, Copy)]
pub struct MatRef<'a, T: Repr> {
    ptr: NonNull<u8>,
    repr: T,
    // Ties the view to a shared borrow for `'a`; covariant in `'a` and `T`.
    _lifetime: PhantomData<&'a T>,
}

// SAFETY: thread-safety is delegated to the representation `T`.
// NOTE(review): a shared view is analogous to `&[E]`, which is `Send` only when
// `E: Sync`. Bounding on `T: Send` relies on the representation's auto traits
// reflecting its element type — confirm this is intended.
unsafe impl<T> Send for MatRef<'_, T> where T: Repr + Send {}

// SAFETY: delegated to the representation; `MatRef<T>` is `Sync` exactly when `T` is.
unsafe impl<T> Sync for MatRef<'_, T> where T: Repr + Sync {}
750
751impl<'a, T: Repr> MatRef<'a, T> {
752 pub fn new<U>(repr: T, data: &'a [U]) -> Result<Self, T::Error>
754 where
755 T: NewRef<U>,
756 {
757 repr.new_ref(data)
758 }
759
760 #[inline]
762 pub fn num_vectors(&self) -> usize {
763 self.repr.nrows()
764 }
765
766 pub fn repr(&self) -> &T {
768 &self.repr
769 }
770
771 #[must_use]
773 pub fn get_row(&self, i: usize) -> Option<T::Row<'_>> {
774 if i < self.num_vectors() {
775 let row = unsafe { self.get_row_unchecked(i) };
778 Some(row)
779 } else {
780 None
781 }
782 }
783
784 #[inline]
790 pub(crate) unsafe fn get_row_unchecked(&self, i: usize) -> T::Row<'_> {
791 unsafe { self.repr.get_row(self.ptr, i) }
793 }
794
795 pub fn rows(&self) -> Rows<'_, T> {
797 Rows::new(*self)
798 }
799
800 pub fn to_owned(&self) -> Mat<T>
802 where
803 T: NewCloned,
804 {
805 T::new_cloned(*self)
806 }
807
808 pub unsafe fn from_raw_parts(repr: T, ptr: NonNull<u8>) -> Self {
816 Self {
817 ptr,
818 repr,
819 _lifetime: PhantomData,
820 }
821 }
822
823 pub fn as_raw_ptr(&self) -> *const u8 {
825 self.ptr.as_ptr()
826 }
827}
828
829impl<'a, T: Copy> MatRef<'a, Standard<T>> {
830 #[inline]
832 pub fn vector_dim(&self) -> usize {
833 self.repr.ncols()
834 }
835}
836
837impl<'this, T: ReprOwned> Reborrow<'this> for Mat<T> {
839 type Target = MatRef<'this, T>;
840
841 fn reborrow(&'this self) -> Self::Target {
842 self.as_view()
843 }
844}
845
846impl<'this, T: ReprOwned> ReborrowMut<'this> for Mat<T> {
848 type Target = MatMut<'this, T>;
849
850 fn reborrow_mut(&'this mut self) -> Self::Target {
851 self.as_view_mut()
852 }
853}
854
855impl<'this, 'a, T: Repr> Reborrow<'this> for MatRef<'a, T> {
857 type Target = MatRef<'this, T>;
858
859 fn reborrow(&'this self) -> Self::Target {
860 MatRef {
861 ptr: self.ptr,
862 repr: self.repr,
863 _lifetime: PhantomData,
864 }
865 }
866}
867
/// A borrowed, mutable view of a matrix described by the representation `T`.
///
/// Like `&'a mut T`, it is covariant in `'a` but invariant in `T`.
#[derive(Debug)]
pub struct MatMut<'a, T: ReprMut> {
    ptr: NonNull<u8>,
    repr: T,
    // Ties the view to an exclusive borrow of the data for `'a`.
    _lifetime: PhantomData<&'a mut T>,
}

// SAFETY: thread-safety is delegated to the representation, mirroring `&mut [E]`
// (which is `Send` when `E: Send`).
unsafe impl<T> Send for MatMut<'_, T> where T: ReprMut + Send {}

// SAFETY: mirrors `&mut [E]`, which is `Sync` when `E: Sync`.
unsafe impl<T> Sync for MatMut<'_, T> where T: ReprMut + Sync {}
897
898impl<'a, T: ReprMut> MatMut<'a, T> {
899 pub fn new<U>(repr: T, data: &'a mut [U]) -> Result<Self, T::Error>
901 where
902 T: NewMut<U>,
903 {
904 repr.new_mut(data)
905 }
906
907 #[inline]
909 pub fn num_vectors(&self) -> usize {
910 self.repr.nrows()
911 }
912
913 pub fn repr(&self) -> &T {
915 &self.repr
916 }
917
918 #[inline]
920 #[must_use]
921 pub fn get_row(&self, i: usize) -> Option<T::Row<'_>> {
922 if i < self.num_vectors() {
923 Some(unsafe { self.get_row_unchecked(i) })
925 } else {
926 None
927 }
928 }
929
930 #[inline]
936 pub(crate) unsafe fn get_row_unchecked(&self, i: usize) -> T::Row<'_> {
937 unsafe { self.repr.get_row(self.ptr, i) }
939 }
940
941 #[inline]
943 #[must_use]
944 pub fn get_row_mut(&mut self, i: usize) -> Option<T::RowMut<'_>> {
945 if i < self.num_vectors() {
946 Some(unsafe { self.get_row_mut_unchecked(i) })
948 } else {
949 None
950 }
951 }
952
953 #[inline]
959 pub(crate) unsafe fn get_row_mut_unchecked(&mut self, i: usize) -> T::RowMut<'_> {
960 unsafe { self.repr.get_row_mut(self.ptr, i) }
963 }
964
965 pub fn as_view(&self) -> MatRef<'_, T> {
967 MatRef {
968 ptr: self.ptr,
969 repr: self.repr,
970 _lifetime: PhantomData,
971 }
972 }
973
974 pub fn rows(&self) -> Rows<'_, T> {
976 Rows::new(self.reborrow())
977 }
978
979 pub fn rows_mut(&mut self) -> RowsMut<'_, T> {
981 RowsMut::new(self.reborrow_mut())
982 }
983
984 pub fn to_owned(&self) -> Mat<T>
986 where
987 T: NewCloned,
988 {
989 T::new_cloned(self.as_view())
990 }
991
992 pub unsafe fn from_raw_parts(repr: T, ptr: NonNull<u8>) -> Self {
999 Self {
1000 ptr,
1001 repr,
1002 _lifetime: PhantomData,
1003 }
1004 }
1005
1006 pub fn as_raw_ptr(&self) -> *const u8 {
1008 self.ptr.as_ptr()
1009 }
1010
1011 pub(crate) fn as_raw_mut_ptr(&mut self) -> *mut u8 {
1013 self.ptr.as_ptr()
1014 }
1015}
1016
1017impl<'this, 'a, T: ReprMut> Reborrow<'this> for MatMut<'a, T> {
1019 type Target = MatRef<'this, T>;
1020
1021 fn reborrow(&'this self) -> Self::Target {
1022 self.as_view()
1023 }
1024}
1025
1026impl<'this, 'a, T: ReprMut> ReborrowMut<'this> for MatMut<'a, T> {
1028 type Target = MatMut<'this, T>;
1029
1030 fn reborrow_mut(&'this mut self) -> Self::Target {
1031 MatMut {
1032 ptr: self.ptr,
1033 repr: self.repr,
1034 _lifetime: PhantomData,
1035 }
1036 }
1037}
1038
1039impl<'a, T: Copy> MatMut<'a, Standard<T>> {
1040 #[inline]
1042 pub fn vector_dim(&self) -> usize {
1043 self.repr.ncols()
1044 }
1045}
1046
1047#[derive(Debug)]
1055pub struct Rows<'a, T: Repr> {
1056 matrix: MatRef<'a, T>,
1057 current: usize,
1058}
1059
1060impl<'a, T> Rows<'a, T>
1061where
1062 T: Repr,
1063{
1064 fn new(matrix: MatRef<'a, T>) -> Self {
1065 Self { matrix, current: 0 }
1066 }
1067}
1068
1069impl<'a, T> Iterator for Rows<'a, T>
1070where
1071 T: Repr + 'a,
1072{
1073 type Item = T::Row<'a>;
1074
1075 fn next(&mut self) -> Option<Self::Item> {
1076 let current = self.current;
1077 if current >= self.matrix.num_vectors() {
1078 None
1079 } else {
1080 self.current += 1;
1081 Some(unsafe { self.matrix.repr.get_row(self.matrix.ptr, current) })
1087 }
1088 }
1089
1090 fn size_hint(&self) -> (usize, Option<usize>) {
1091 let remaining = self.matrix.num_vectors() - self.current;
1092 (remaining, Some(remaining))
1093 }
1094}
1095
1096impl<'a, T> ExactSizeIterator for Rows<'a, T> where T: Repr + 'a {}
1097impl<'a, T> FusedIterator for Rows<'a, T> where T: Repr + 'a {}
1098
1099#[derive(Debug)]
1107pub struct RowsMut<'a, T: ReprMut> {
1108 matrix: MatMut<'a, T>,
1109 current: usize,
1110}
1111
1112impl<'a, T> RowsMut<'a, T>
1113where
1114 T: ReprMut,
1115{
1116 fn new(matrix: MatMut<'a, T>) -> Self {
1117 Self { matrix, current: 0 }
1118 }
1119}
1120
1121impl<'a, T> Iterator for RowsMut<'a, T>
1122where
1123 T: ReprMut + 'a,
1124{
1125 type Item = T::RowMut<'a>;
1126
1127 fn next(&mut self) -> Option<Self::Item> {
1128 let current = self.current;
1129 if current >= self.matrix.num_vectors() {
1130 None
1131 } else {
1132 self.current += 1;
1133 Some(unsafe { self.matrix.repr.get_row_mut(self.matrix.ptr, current) })
1142 }
1143 }
1144
1145 fn size_hint(&self) -> (usize, Option<usize>) {
1146 let remaining = self.matrix.num_vectors() - self.current;
1147 (remaining, Some(remaining))
1148 }
1149}
1150
1151impl<'a, T> ExactSizeIterator for RowsMut<'a, T> where T: ReprMut + 'a {}
1152impl<'a, T> FusedIterator for RowsMut<'a, T> where T: ReprMut + 'a {}
1153
#[cfg(test)]
mod tests {
    use super::*;

    use std::fmt::Display;

    use diskann_utils::lazy_format;

    /// Compiles only if `T: Copy`; used to assert that a value's type is `Copy`.
    fn assert_copy<T: Copy>(_: &T) {}

    // The `_assert_*` functions below are never called. Each returns its argument with
    // a shorter lifetime, so it compiles only if the type is covariant in the parameter
    // being shortened.

    /// `MatRef` must be covariant in its borrow lifetime.
    fn _assert_matref_covariant_lifetime<'long: 'short, 'short, T: Repr>(
        v: MatRef<'long, T>,
    ) -> MatRef<'short, T> {
        v
    }

    /// `MatRef` must be covariant in its representation type.
    fn _assert_matref_covariant_repr<'long: 'short, 'short, 'a>(
        v: MatRef<'a, Standard<&'long u8>>,
    ) -> MatRef<'a, Standard<&'short u8>> {
        v
    }

    /// `MatMut` must be covariant in its borrow lifetime.
    fn _assert_matmut_covariant_lifetime<'long: 'short, 'short, T: ReprMut>(
        v: MatMut<'long, T>,
    ) -> MatMut<'short, T> {
        v
    }

    /// Row indices at or beyond `nrows` (including values near `usize::MAX`) that every
    /// bounds-checked row accessor must reject.
    fn edge_cases(nrows: usize) -> Vec<usize> {
        let max = usize::MAX;

        vec![
            nrows,
            nrows + 1,
            nrows + 11,
            nrows + 20,
            max / 2,
            max.div_ceil(2),
            max - 1,
            max,
        ]
    }

    /// Write `10 * i + j` into entry `(i, j)` through `Mat::get_row_mut`, then check that
    /// out-of-bounds rows yield `None`.
    fn fill_mat(x: &mut Mat<Standard<usize>>, repr: Standard<usize>) {
        assert_eq!(x.repr(), &repr);
        assert_eq!(x.num_vectors(), repr.nrows());
        assert_eq!(x.vector_dim(), repr.ncols());

        for i in 0..x.num_vectors() {
            let row = x.get_row_mut(i).unwrap();
            assert_eq!(row.len(), repr.ncols());
            row.iter_mut()
                .enumerate()
                .for_each(|(j, r)| *r = 10 * i + j);
        }

        for i in edge_cases(repr.nrows()).into_iter() {
            assert!(x.get_row_mut(i).is_none());
        }
    }

    /// Same fill pattern as `fill_mat`, but through a `MatMut` view.
    fn fill_mat_mut(mut x: MatMut<'_, Standard<usize>>, repr: Standard<usize>) {
        assert_eq!(x.repr(), &repr);
        assert_eq!(x.num_vectors(), repr.nrows());
        assert_eq!(x.vector_dim(), repr.ncols());

        for i in 0..x.num_vectors() {
            let row = x.get_row_mut(i).unwrap();
            assert_eq!(row.len(), repr.ncols());

            row.iter_mut()
                .enumerate()
                .for_each(|(j, r)| *r = 10 * i + j);
        }

        for i in edge_cases(repr.nrows()).into_iter() {
            assert!(x.get_row_mut(i).is_none());
        }
    }

    /// Same fill pattern as `fill_mat`, but through a `RowsMut` iterator. All rows are
    /// collected first, so every mutable row is alive at the same time.
    fn fill_rows_mut(x: RowsMut<'_, Standard<usize>>, repr: Standard<usize>) {
        assert_eq!(x.len(), repr.nrows());
        let mut all_rows: Vec<_> = x.collect();
        assert_eq!(all_rows.len(), repr.nrows());
        for (i, row) in all_rows.iter_mut().enumerate() {
            assert_eq!(row.len(), repr.ncols());
            row.iter_mut()
                .enumerate()
                .for_each(|(j, r)| *r = 10 * i + j);
        }
    }

    /// Check that entry `(i, j)` of `x` equals `10 * i + j` and that out-of-bounds rows
    /// yield `None`. `ctx` is included in failure messages.
    fn check_mat(x: &Mat<Standard<usize>>, repr: Standard<usize>, ctx: &dyn Display) {
        assert_eq!(x.repr(), &repr);
        assert_eq!(x.num_vectors(), repr.nrows());
        assert_eq!(x.vector_dim(), repr.ncols());

        for i in 0..x.num_vectors() {
            let row = x.get_row(i).unwrap();

            assert_eq!(row.len(), repr.ncols(), "ctx: {ctx}");
            row.iter().enumerate().for_each(|(j, r)| {
                assert_eq!(
                    *r,
                    10 * i + j,
                    "mismatched entry at row {}, col {} -- ctx: {}",
                    i,
                    j,
                    ctx
                )
            });
        }

        for i in edge_cases(repr.nrows()).into_iter() {
            assert!(x.get_row(i).is_none(), "ctx: {ctx}");
        }
    }

    /// `check_mat` for a `MatRef`; also asserts that `MatRef` is `Copy`.
    fn check_mat_ref(x: MatRef<'_, Standard<usize>>, repr: Standard<usize>, ctx: &dyn Display) {
        assert_eq!(x.repr(), &repr);
        assert_eq!(x.num_vectors(), repr.nrows());
        assert_eq!(x.vector_dim(), repr.ncols());

        assert_copy(&x);
        for i in 0..x.num_vectors() {
            let row = x.get_row(i).unwrap();
            assert_eq!(row.len(), repr.ncols(), "ctx: {ctx}");

            row.iter().enumerate().for_each(|(j, r)| {
                assert_eq!(
                    *r,
                    10 * i + j,
                    "mismatched entry at row {}, col {} -- ctx: {}",
                    i,
                    j,
                    ctx
                )
            });
        }

        for i in edge_cases(repr.nrows()).into_iter() {
            assert!(x.get_row(i).is_none(), "ctx: {ctx}");
        }
    }

    /// `check_mat` for a `MatMut`, using its immutable `get_row`.
    fn check_mat_mut(x: MatMut<'_, Standard<usize>>, repr: Standard<usize>, ctx: &dyn Display) {
        assert_eq!(x.repr(), &repr);
        assert_eq!(x.num_vectors(), repr.nrows());
        assert_eq!(x.vector_dim(), repr.ncols());

        for i in 0..x.num_vectors() {
            let row = x.get_row(i).unwrap();
            assert_eq!(row.len(), repr.ncols(), "ctx: {ctx}");

            row.iter().enumerate().for_each(|(j, r)| {
                assert_eq!(
                    *r,
                    10 * i + j,
                    "mismatched entry at row {}, col {} -- ctx: {}",
                    i,
                    j,
                    ctx
                )
            });
        }

        for i in edge_cases(repr.nrows()).into_iter() {
            assert!(x.get_row(i).is_none(), "ctx: {ctx}");
        }
    }

    /// Check the `10 * i + j` pattern through a `Rows` iterator, including its
    /// `ExactSizeIterator` length.
    fn check_rows(x: Rows<'_, Standard<usize>>, repr: Standard<usize>, ctx: &dyn Display) {
        assert_eq!(x.len(), repr.nrows(), "ctx: {ctx}");
        let all_rows: Vec<_> = x.collect();
        assert_eq!(all_rows.len(), repr.nrows(), "ctx: {ctx}");
        for (i, row) in all_rows.iter().enumerate() {
            assert_eq!(row.len(), repr.ncols(), "ctx: {ctx}");
            row.iter().enumerate().for_each(|(j, r)| {
                assert_eq!(
                    *r,
                    10 * i + j,
                    "mismatched entry at row {}, col {} -- ctx: {}",
                    i,
                    j,
                    ctx
                )
            });
        }
    }

    /// Dimensions and `layout()` of a small `Standard` representation.
    #[test]
    fn standard_representation() {
        let repr = Standard::<f32>::new(4, 3).unwrap();
        assert_eq!(repr.nrows(), 4);
        assert_eq!(repr.ncols(), 3);

        let layout = repr.layout().unwrap();
        assert_eq!(layout.size(), 4 * 3 * std::mem::size_of::<f32>());
        assert_eq!(layout.align(), std::mem::align_of::<f32>());
    }

    /// Shapes with a zero dimension are valid and have a zero-sized layout.
    #[test]
    fn standard_zero_dimensions() {
        for (nrows, ncols) in [(0, 0), (0, 5), (5, 0)] {
            let repr = Standard::<u8>::new(nrows, ncols).unwrap();
            assert_eq!(repr.nrows(), nrows);
            assert_eq!(repr.ncols(), ncols);
            let layout = repr.layout().unwrap();
            assert_eq!(layout.size(), 0);
        }
    }

    /// `check_slice` accepts only slices of exactly `nrows * ncols` elements.
    #[test]
    fn standard_check_slice() {
        let repr = Standard::<u32>::new(3, 4).unwrap();

        let data = vec![0u32; 12];
        assert!(repr.check_slice(&data).is_ok());

        let short = vec![0u32; 11];
        assert!(matches!(
            repr.check_slice(&short),
            Err(SliceError::LengthMismatch {
                expected: 12,
                found: 11
            })
        ));

        let long = vec![0u32; 13];
        assert!(matches!(
            repr.check_slice(&long),
            Err(SliceError::LengthMismatch {
                expected: 12,
                found: 13
            })
        ));

        // NOTE(review): `unwrap_err` already yields an `Overflow`, so this `matches!` is
        // always true; the real check is that `new` returned `Err`.
        let overflow_repr = Standard::<u8>::new(usize::MAX, 2).unwrap_err();
        assert!(matches!(overflow_repr, Overflow { .. }));
    }

    /// `new` rejects shapes whose element count overflows `usize`.
    #[test]
    fn standard_new_rejects_element_count_overflow() {
        assert!(Standard::<u8>::new(usize::MAX, 2).is_err());
        assert!(Standard::<u8>::new(2, usize::MAX).is_err());
        assert!(Standard::<u8>::new(usize::MAX, usize::MAX).is_err());
    }

    /// `new` rejects shapes whose byte size exceeds `isize::MAX` by one element.
    #[test]
    fn standard_new_rejects_byte_count_exceeding_isize_max() {
        let half = (isize::MAX as usize / std::mem::size_of::<u64>()) + 1;
        assert!(Standard::<u64>::new(half, 1).is_err());
        assert!(Standard::<u64>::new(1, half).is_err());
    }

    /// The largest element count that fits in `isize::MAX` bytes is accepted.
    #[test]
    fn standard_new_accepts_boundary_below_isize_max() {
        let max_elems = isize::MAX as usize / std::mem::size_of::<u64>();
        let repr = Standard::<u64>::new(max_elems, 1).unwrap();
        assert_eq!(repr.num_elements(), max_elems);
    }

    /// Zero-sized elements are still limited by the `usize` element count.
    #[test]
    fn standard_new_zst_rejects_element_count_overflow() {
        assert!(Standard::<()>::new(usize::MAX, 2).is_err());
        assert!(Standard::<()>::new(usize::MAX / 2 + 1, 3).is_err());
    }

    /// Zero-sized elements may use the full `usize` element count with a zero-sized
    /// layout.
    #[test]
    fn standard_new_zst_accepts_large_non_overflowing() {
        let repr = Standard::<()>::new(usize::MAX, 1).unwrap();
        assert_eq!(repr.num_elements(), usize::MAX);
        assert_eq!(repr.layout().unwrap().size(), 0);
    }

    /// The `Overflow` message differs between sized and zero-sized element types.
    #[test]
    fn standard_new_overflow_error_display() {
        let err = Standard::<u32>::new(usize::MAX, 2).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("would exceed isize::MAX bytes"), "{msg}");

        let zst_err = Standard::<()>::new(usize::MAX, 2).unwrap_err();
        let zst_msg = zst_err.to_string();
        assert!(zst_msg.contains("ZST matrix"), "{zst_msg}");
        assert!(zst_msg.contains("usize::MAX"), "{zst_msg}");
    }

    /// `Mat::new` with a value fills every entry, and rows are laid out contiguously
    /// from the base pointer.
    #[test]
    fn mat_new_and_basic_accessors() {
        let mat = Mat::new(Standard::<usize>::new(3, 4).unwrap(), 42usize).unwrap();
        let base: *const u8 = mat.as_raw_ptr();

        assert_eq!(mat.num_vectors(), 3);
        assert_eq!(mat.vector_dim(), 4);

        let repr = mat.repr();
        assert_eq!(repr.nrows(), 3);
        assert_eq!(repr.ncols(), 4);

        for (i, r) in mat.rows().enumerate() {
            assert_eq!(r, &[42, 42, 42, 42]);
            let ptr = r.as_ptr().cast::<u8>();
            assert_eq!(
                ptr,
                base.wrapping_add(std::mem::size_of::<usize>() * mat.repr().ncols() * i),
            );
        }
    }

    /// `Mat::new` with `Defaulted` fills every entry with `usize::default()` (0).
    #[test]
    fn mat_new_with_default() {
        let mat = Mat::new(Standard::<usize>::new(2, 3).unwrap(), Defaulted).unwrap();
        let base: *const u8 = mat.as_raw_ptr();

        assert_eq!(mat.num_vectors(), 2);
        for (i, row) in mat.rows().enumerate() {
            assert!(row.iter().all(|&v| v == 0));

            let ptr = row.as_ptr().cast::<u8>();
            assert_eq!(
                ptr,
                base.wrapping_add(std::mem::size_of::<usize>() * mat.repr().ncols() * i),
            );
        }
    }

    // Shapes exercised by the exhaustive tests below, including zero dimensions.
    const ROWS: &[usize] = &[0, 1, 2, 3, 5, 10];
    const COLS: &[usize] = &[0, 1, 2, 3, 5, 10];

    /// Fill a `Mat` three ways (directly, via `MatMut`, via `RowsMut`) and verify the
    /// contents through every read accessor.
    #[test]
    fn test_mat() {
        for nrows in ROWS {
            for ncols in COLS {
                let repr = Standard::<usize>::new(*nrows, *ncols).unwrap();
                let ctx = &lazy_format!("nrows = {}, ncols = {}", nrows, ncols);

                {
                    // Fill through `Mat::get_row_mut`.
                    let ctx = &lazy_format!("{ctx} - direct");
                    let mut mat = Mat::new(repr, Defaulted).unwrap();

                    assert_eq!(mat.num_vectors(), *nrows);
                    assert_eq!(mat.vector_dim(), *ncols);

                    fill_mat(&mut mat, repr);

                    check_mat(&mat, repr, ctx);
                    check_mat_ref(mat.reborrow(), repr, ctx);
                    check_mat_mut(mat.reborrow_mut(), repr, ctx);
                    check_rows(mat.rows(), repr, ctx);

                    // Views must share the owner's storage.
                    assert_eq!(mat.as_raw_ptr(), mat.reborrow().as_raw_ptr());
                    assert_eq!(mat.as_raw_ptr(), mat.reborrow_mut().as_raw_ptr());
                }

                {
                    // Fill through a `MatMut` view.
                    let ctx = &lazy_format!("{ctx} - matmut");
                    let mut mat = Mat::new(repr, Defaulted).unwrap();
                    let matmut = mat.reborrow_mut();

                    assert_eq!(matmut.num_vectors(), *nrows);
                    assert_eq!(matmut.vector_dim(), *ncols);

                    fill_mat_mut(matmut, repr);

                    check_mat(&mat, repr, ctx);
                    check_mat_ref(mat.reborrow(), repr, ctx);
                    check_mat_mut(mat.reborrow_mut(), repr, ctx);
                    check_rows(mat.rows(), repr, ctx);
                }

                {
                    // Fill through the `RowsMut` iterator.
                    let ctx = &lazy_format!("{ctx} - rows_mut");
                    let mut mat = Mat::new(repr, Defaulted).unwrap();
                    fill_rows_mut(mat.rows_mut(), repr);

                    check_mat(&mat, repr, ctx);
                    check_mat_ref(mat.reborrow(), repr, ctx);
                    check_mat_mut(mat.reborrow_mut(), repr, ctx);
                    check_rows(mat.rows(), repr, ctx);
                }
            }
        }
    }

    /// `Mat::clone`, `MatRef::to_owned`, and `MatMut::to_owned` produce equal contents
    /// in a distinct allocation (whenever the matrix is non-empty).
    #[test]
    fn test_mat_clone() {
        for nrows in ROWS {
            for ncols in COLS {
                let repr = Standard::<usize>::new(*nrows, *ncols).unwrap();
                let ctx = &lazy_format!("nrows = {}, ncols = {}", nrows, ncols);

                let mut mat = Mat::new(repr, Defaulted).unwrap();
                fill_mat(&mut mat, repr);

                {
                    // `Clone for Mat`.
                    let ctx = &lazy_format!("{ctx} - Mat::clone");
                    let cloned = mat.clone();

                    assert_eq!(cloned.num_vectors(), *nrows);
                    assert_eq!(cloned.vector_dim(), *ncols);

                    check_mat(&cloned, repr, ctx);
                    check_mat_ref(cloned.reborrow(), repr, ctx);
                    check_rows(cloned.rows(), repr, ctx);

                    if repr.num_elements() > 0 {
                        // Empty buffers may share a dangling pointer, so only compare
                        // addresses for non-empty matrices.
                        assert_ne!(mat.as_raw_ptr(), cloned.as_raw_ptr());
                    }
                }

                {
                    // `MatRef::to_owned`.
                    let ctx = &lazy_format!("{ctx} - MatRef::to_owned");
                    let owned = mat.as_view().to_owned();

                    check_mat(&owned, repr, ctx);
                    check_mat_ref(owned.reborrow(), repr, ctx);
                    check_rows(owned.rows(), repr, ctx);

                    if repr.num_elements() > 0 {
                        assert_ne!(mat.as_raw_ptr(), owned.as_raw_ptr());
                    }
                }

                {
                    // `MatMut::to_owned`.
                    let ctx = &lazy_format!("{ctx} - MatMut::to_owned");
                    let owned = mat.as_view_mut().to_owned();

                    check_mat(&owned, repr, ctx);
                    check_mat_ref(owned.reborrow(), repr, ctx);
                    check_rows(owned.rows(), repr, ctx);

                    if repr.num_elements() > 0 {
                        assert_ne!(mat.as_raw_ptr(), owned.as_raw_ptr());
                    }
                }
            }
        }
    }

    /// Views built over caller-owned buffers via `MatMut::new` and `MatRef::new` use
    /// that buffer in place.
    #[test]
    fn test_mat_refmut() {
        for nrows in ROWS {
            for ncols in COLS {
                let repr = Standard::<usize>::new(*nrows, *ncols).unwrap();
                let ctx = &lazy_format!("nrows = {}, ncols = {}", nrows, ncols);

                {
                    // Fill through the `MatMut` row accessors.
                    let ctx = &lazy_format!("{ctx} - by matmut");
                    let mut b: Box<[_]> = (0..repr.num_elements()).map(|_| 0usize).collect();
                    let ptr = b.as_ptr().cast::<u8>();
                    let mut matmut = MatMut::new(repr, &mut b).unwrap();

                    assert_eq!(
                        ptr,
                        matmut.as_raw_ptr(),
                        "underlying memory should be preserved",
                    );

                    fill_mat_mut(matmut.reborrow_mut(), repr);

                    check_mat_mut(matmut.reborrow_mut(), repr, ctx);
                    check_mat_ref(matmut.reborrow(), repr, ctx);
                    check_rows(matmut.rows(), repr, ctx);
                    check_rows(matmut.reborrow().rows(), repr, ctx);

                    let matref = MatRef::new(repr, &b).unwrap();
                    check_mat_ref(matref, repr, ctx);
                    check_mat_ref(matref.reborrow(), repr, ctx);
                    check_rows(matref.rows(), repr, ctx);
                }

                {
                    // Fill through the `RowsMut` iterator.
                    let ctx = &lazy_format!("{ctx} - by rows");
                    let mut b: Box<[_]> = (0..repr.num_elements()).map(|_| 0usize).collect();
                    let ptr = b.as_ptr().cast::<u8>();
                    let mut matmut = MatMut::new(repr, &mut b).unwrap();

                    assert_eq!(
                        ptr,
                        matmut.as_raw_ptr(),
                        "underlying memory should be preserved",
                    );

                    fill_rows_mut(matmut.rows_mut(), repr);

                    check_mat_mut(matmut.reborrow_mut(), repr, ctx);
                    check_mat_ref(matmut.reborrow(), repr, ctx);
                    check_rows(matmut.rows(), repr, ctx);
                    check_rows(matmut.reborrow().rows(), repr, ctx);

                    let matref = MatRef::new(repr, &b).unwrap();
                    check_mat_ref(matref, repr, ctx);
                    check_mat_ref(matref.reborrow(), repr, ctx);
                    check_rows(matref.rows(), repr, ctx);
                }
            }
        }
    }

    /// `Mat::new` with a value produces the right shape and fills every entry.
    #[test]
    fn test_standard_new_owned() {
        let rows = [0, 1, 2, 3, 5, 10];
        let cols = [0, 1, 2, 3, 5, 10];

        for nrows in rows {
            for ncols in cols {
                let m = Mat::new(Standard::new(nrows, ncols).unwrap(), 1usize).unwrap();
                let rows_iter = m.rows();
                let len = <_ as ExactSizeIterator>::len(&rows_iter);
                assert_eq!(len, nrows);
                for r in rows_iter {
                    assert_eq!(r.len(), ncols);
                    assert!(r.iter().all(|i| *i == 1usize));
                }
            }
        }
    }

    /// `MatRef::new` reports `LengthMismatch` for slices that are too short or too long.
    #[test]
    fn matref_new_slice_length_error() {
        let repr = Standard::<u32>::new(3, 4).unwrap();

        let data = vec![0u32; 12];
        assert!(MatRef::new(repr, &data).is_ok());

        let short = vec![0u32; 11];
        assert!(matches!(
            MatRef::new(repr, &short),
            Err(SliceError::LengthMismatch {
                expected: 12,
                found: 11
            })
        ));

        let long = vec![0u32; 13];
        assert!(matches!(
            MatRef::new(repr, &long),
            Err(SliceError::LengthMismatch {
                expected: 12,
                found: 13
            })
        ));
    }

    /// `MatMut::new` reports `LengthMismatch` for slices that are too short or too long.
    #[test]
    fn matmut_new_slice_length_error() {
        let repr = Standard::<u32>::new(3, 4).unwrap();

        let mut data = vec![0u32; 12];
        assert!(MatMut::new(repr, &mut data).is_ok());

        let mut short = vec![0u32; 11];
        assert!(matches!(
            MatMut::new(repr, &mut short),
            Err(SliceError::LengthMismatch {
                expected: 12,
                found: 11
            })
        ));

        let mut long = vec![0u32; 13];
        assert!(matches!(
            MatMut::new(repr, &mut long),
            Err(SliceError::LengthMismatch {
                expected: 12,
                found: 13
            })
        ));
    }
}