1use std::{alloc::Layout, iter::FusedIterator, marker::PhantomData, ptr::NonNull};
31
32use diskann_utils::{Reborrow, ReborrowMut, views::MatrixView};
33use thiserror::Error;
34
35use crate::utils;
36
/// Describes how a matrix is laid out behind a single base pointer.
///
/// Implementations are small `Copy` descriptors (for example [`Standard`]) that report the
/// number of rows, the [`Layout`] of the backing memory, and how to turn the base pointer plus
/// a row index into a row view. [`Mat`], [`MatRef`], [`MatMut`] and [`Rows`] are generic over
/// this trait.
///
/// # Safety
///
/// NOTE(review): the original safety contract is not visible in this excerpt. The containers
/// in this module trust [`Repr::get_row`] to return a valid row for every `i < self.nrows()`
/// when handed storage created for this representation — confirm the exact requirements.
pub unsafe trait Repr: Copy {
    /// Read-only view of a single row, borrowed for `'a` (for [`Standard<T>`], `&'a [T]`).
    type Row<'a>
    where
        Self: 'a;

    /// Returns the number of rows in the matrix.
    fn nrows(&self) -> usize;

    /// Returns the memory layout required to store a matrix with this representation.
    ///
    /// # Errors
    ///
    /// Returns [`LayoutError`] if the layout cannot be expressed as a valid [`Layout`].
    fn layout(&self) -> Result<Layout, LayoutError>;

    /// Returns a view of row `i` of the matrix whose storage starts at `ptr`.
    ///
    /// # Safety
    ///
    /// The callers in this module (`Mat::get_row`, `MatRef::get_row`, `MatMut::get_row`,
    /// `Rows::next`) only call this after checking `i < self.nrows()`, and pass the pointer of
    /// a matrix built for this representation. The returned lifetime `'a` is chosen by the
    /// caller and must not outlive that storage.
    unsafe fn get_row<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::Row<'a>;
}
95
/// A [`Repr`] that can also hand out mutable row views.
///
/// # Safety
///
/// NOTE(review): the original contract is not visible here. [`RowsMut`] yields the rows of
/// distinct indices with the same lifetime, so they can all be alive at once. Implementations
/// therefore presumably must guarantee that rows for different indices never overlap — confirm.
pub unsafe trait ReprMut: Repr {
    /// Mutable view of a single row, borrowed for `'a` (for [`Standard<T>`], `&'a mut [T]`).
    type RowMut<'a>
    where
        Self: 'a;

    /// Returns a mutable view of row `i` of the matrix whose storage starts at `ptr`.
    ///
    /// # Safety
    ///
    /// Callers in this module check `i < self.nrows()` first and pass the pointer of a matrix
    /// they hold exclusively (`&mut Mat`, `&mut MatMut`, or a `RowsMut` that visits each index
    /// once).
    unsafe fn get_row_mut<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::RowMut<'a>;
}
131
/// A [`ReprMut`] whose storage can be owned by a [`Mat`] and released by the representation.
///
/// # Safety
///
/// NOTE(review): the original contract is not visible here. From this module's usage,
/// [`ReprOwned::drop`] must correctly free storage created for this representation (for
/// [`Standard`], a `Box<[T]>` of `nrows * ncols` elements) — confirm.
pub unsafe trait ReprOwned: ReprMut {
    /// Releases the storage at `ptr`.
    ///
    /// # Safety
    ///
    /// Called from `Mat`'s `Drop` implementation with the matrix's own pointer. Presumably it
    /// must be called at most once per allocation, and `ptr` must not be used afterwards.
    unsafe fn drop(self, ptr: NonNull<u8>);
}
151
/// Error returned when a [`Repr`] cannot describe its memory with a valid [`Layout`].
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub struct LayoutError;

impl LayoutError {
    /// Creates a new [`LayoutError`].
    ///
    /// The type is `#[non_exhaustive]`, so code outside this crate must use this constructor
    /// (or [`Default`]) instead of the unit-struct literal.
    pub fn new() -> Self {
        LayoutError
    }
}

impl Default for LayoutError {
    fn default() -> Self {
        LayoutError::new()
    }
}

impl std::fmt::Display for LayoutError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("LayoutError")
    }
}

impl std::error::Error for LayoutError {}

impl From<std::alloc::LayoutError> for LayoutError {
    /// Discards the details of the standard-library error and keeps only the failure itself.
    fn from(_err: std::alloc::LayoutError) -> Self {
        Self::new()
    }
}
187
/// Construction of a borrowed, read-only [`MatRef`] from a slice of `T`.
///
/// # Safety
///
/// NOTE(review): the original contract is not visible here. Presumably an implementation may
/// only return a view whose rows lie entirely inside `slice` — confirm.
pub unsafe trait NewRef<T>: Repr {
    /// Error returned when `slice` cannot be viewed with this representation.
    type Error;

    /// Views `slice` as a matrix with this representation; the view borrows `slice`.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` when `slice` is incompatible with the representation (for
    /// [`Standard`], when its length is not `nrows * ncols`).
    fn new_ref(self, slice: &[T]) -> Result<MatRef<'_, Self>, Self::Error>;
}
205
/// Construction of a borrowed, mutable [`MatMut`] from a mutable slice of `T`.
///
/// # Safety
///
/// NOTE(review): the original contract is not visible here. Presumably an implementation may
/// only return a view whose rows lie entirely inside `slice` — confirm.
pub unsafe trait NewMut<T>: ReprMut {
    /// Error returned when `slice` cannot be viewed with this representation.
    type Error;

    /// Views `slice` mutably as a matrix with this representation; the view borrows `slice`.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` when `slice` is incompatible with the representation (for
    /// [`Standard`], when its length is not `nrows * ncols`).
    fn new_mut(self, slice: &mut [T]) -> Result<MatMut<'_, Self>, Self::Error>;
}
219
/// Allocation of an owned [`Mat`] from an initializer of type `T`.
///
/// [`Standard<E>`] accepts either an element value `E` (every entry is set to it) or
/// [`Defaulted`] (every entry is `E::default()`).
///
/// # Safety
///
/// NOTE(review): the original contract is not visible here. The returned `Mat` will release
/// its storage through [`ReprOwned::drop`], so it presumably must own storage compatible with
/// that implementation — confirm.
pub unsafe trait NewOwned<T>: ReprOwned {
    /// Error returned when the matrix cannot be created.
    type Error;

    /// Allocates a new matrix with this representation, initialized from `init`.
    fn new_owned(self, init: T) -> Result<Mat<Self>, Self::Error>;
}
233
/// Initializer for [`Mat::new`] that fills every element with its type's `Default` value.
///
/// Accepted by `Standard<T>` whenever `T: Default`.
#[derive(Debug, Clone, Copy)]
pub struct Defaulted;
246
/// Deep copy of a matrix view into a newly allocated [`Mat`].
///
/// Used by `Clone for Mat`, [`MatRef::to_owned`] and [`MatMut::to_owned`].
pub trait NewCloned: ReprOwned {
    /// Returns a new owned matrix holding a copy of the contents of `v`.
    fn new_cloned(v: MatRef<'_, Self>) -> Mat<Self>;
}
254
/// Representation of a dense, row-major `nrows x ncols` matrix of `T`.
///
/// Row `i` occupies elements `i * ncols .. (i + 1) * ncols` of one contiguous buffer.
/// Construct it with [`Standard::new`], which rejects dimensions whose total size overflows.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Standard<T> {
    nrows: usize,
    ncols: usize,
    /// Ties the representation to its element type (variance and auto traits follow `T`).
    _elem: PhantomData<T>,
}
274
275impl<T: Copy> Standard<T> {
276 pub fn new(nrows: usize, ncols: usize) -> Result<Self, Overflow> {
285 Overflow::check::<T>(nrows, ncols)?;
286 Ok(Self {
287 nrows,
288 ncols,
289 _elem: PhantomData,
290 })
291 }
292
293 pub fn num_elements(&self) -> usize {
295 self.nrows() * self.ncols()
297 }
298
299 fn nrows(&self) -> usize {
301 self.nrows
302 }
303
304 fn ncols(&self) -> usize {
306 self.ncols
307 }
308
309 fn check_slice(&self, slice: &[T]) -> Result<(), SliceError> {
314 let len = self.num_elements();
315
316 if slice.len() != len {
317 Err(SliceError::LengthMismatch {
318 expected: len,
319 found: slice.len(),
320 })
321 } else {
322 Ok(())
323 }
324 }
325
326 unsafe fn box_to_mat(self, b: Box<[T]>) -> Mat<Self> {
332 debug_assert_eq!(b.len(), self.num_elements(), "safety contract violated");
333
334 let ptr = utils::box_into_nonnull(b).cast::<u8>();
335
336 unsafe { Mat::from_raw_parts(self, ptr) }
340 }
341}
342
/// Error returned when a requested matrix is too large to represent.
///
/// For element types with a non-zero size, the matrix would need more than `isize::MAX` bytes.
/// For zero-sized types, `nrows * ncols` overflows `usize`.
#[derive(Debug, Clone, Copy)]
pub struct Overflow {
    nrows: usize,
    ncols: usize,
    elsize: usize,
}

impl Overflow {
    /// Builds the error for an `nrows x ncols` matrix of `T`, recording `size_of::<T>()`.
    pub(crate) fn for_type<T>(nrows: usize, ncols: usize) -> Self {
        let elsize = std::mem::size_of::<T>();
        Self {
            nrows,
            ncols,
            elsize,
        }
    }

    /// Checks that `capacity` values of `T` fit within `isize::MAX` bytes.
    ///
    /// `nrows` and `ncols` are only used to populate the returned error.
    pub(crate) fn check_byte_budget<T>(
        capacity: usize,
        nrows: usize,
        ncols: usize,
    ) -> Result<(), Self> {
        // Saturating keeps an overflowing byte count above the limit instead of wrapping.
        let bytes = capacity.saturating_mul(std::mem::size_of::<T>());
        if bytes > isize::MAX as usize {
            Err(Self::for_type::<T>(nrows, ncols))
        } else {
            Ok(())
        }
    }

    /// Checks that an `nrows x ncols` matrix of `T` has a representable element count and
    /// fits within the `isize::MAX` byte budget.
    pub(crate) fn check<T>(nrows: usize, ncols: usize) -> Result<(), Self> {
        match nrows.checked_mul(ncols) {
            Some(capacity) => Self::check_byte_budget::<T>(capacity, nrows, ncols),
            None => Err(Self::for_type::<T>(nrows, ncols)),
        }
    }
}

impl std::fmt::Display for Overflow {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            nrows,
            ncols,
            elsize,
        } = *self;
        if elsize != 0 {
            write!(
                f,
                "a matrix of size {nrows} x {ncols} with element size {elsize} would exceed isize::MAX bytes",
            )
        } else {
            write!(
                f,
                "ZST matrix with dimensions {nrows} x {ncols} has more than `usize::MAX` elements",
            )
        }
    }
}

impl std::error::Error for Overflow {}
408
/// Error returned when a slice cannot be viewed as a matrix.
#[derive(Debug, Clone, Copy, Error)]
#[non_exhaustive]
pub enum SliceError {
    /// The slice length differs from the number of elements the representation requires.
    #[error("Length mismatch: expected {expected}, found {found}")]
    LengthMismatch { expected: usize, found: usize },
}
416
417unsafe impl<T: Copy> Repr for Standard<T> {
421 type Row<'a>
422 = &'a [T]
423 where
424 T: 'a;
425
426 fn nrows(&self) -> usize {
427 self.nrows
428 }
429
430 fn layout(&self) -> Result<Layout, LayoutError> {
431 Ok(Layout::array::<T>(self.num_elements())?)
432 }
433
434 unsafe fn get_row<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::Row<'a> {
435 debug_assert!(ptr.cast::<T>().is_aligned());
436 debug_assert!(i < self.nrows);
437
438 let row_ptr = unsafe { ptr.as_ptr().cast::<T>().add(i * self.ncols) };
442
443 unsafe { std::slice::from_raw_parts(row_ptr, self.ncols) }
445 }
446}
447
448unsafe impl<T: Copy> ReprMut for Standard<T> {
451 type RowMut<'a>
452 = &'a mut [T]
453 where
454 T: 'a;
455
456 unsafe fn get_row_mut<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::RowMut<'a> {
457 debug_assert!(ptr.cast::<T>().is_aligned());
458 debug_assert!(i < self.nrows);
459
460 let row_ptr = unsafe { ptr.as_ptr().cast::<T>().add(i * self.ncols) };
464
465 unsafe { std::slice::from_raw_parts_mut(row_ptr, self.ncols) }
468 }
469}
470
471unsafe impl<T: Copy> ReprOwned for Standard<T> {
475 unsafe fn drop(self, ptr: NonNull<u8>) {
476 unsafe {
482 let slice_ptr = std::ptr::slice_from_raw_parts_mut(
483 ptr.cast::<T>().as_ptr(),
484 self.nrows * self.ncols,
485 );
486 let _ = Box::from_raw(slice_ptr);
487 }
488 }
489}
490
491unsafe impl<T> NewOwned<T> for Standard<T>
494where
495 T: Copy,
496{
497 type Error = crate::error::Infallible;
498 fn new_owned(self, value: T) -> Result<Mat<Self>, Self::Error> {
499 let b: Box<[T]> = (0..self.num_elements()).map(|_| value).collect();
500
501 Ok(unsafe { self.box_to_mat(b) })
503 }
504}
505
506unsafe impl<T> NewOwned<Defaulted> for Standard<T>
508where
509 T: Copy + Default,
510{
511 type Error = crate::error::Infallible;
512 fn new_owned(self, _: Defaulted) -> Result<Mat<Self>, Self::Error> {
513 self.new_owned(T::default())
514 }
515}
516
517unsafe impl<T> NewRef<T> for Standard<T>
520where
521 T: Copy,
522{
523 type Error = SliceError;
524 fn new_ref(self, data: &[T]) -> Result<MatRef<'_, Self>, Self::Error> {
525 self.check_slice(data)?;
526
527 Ok(unsafe { MatRef::from_raw_parts(self, utils::as_nonnull(data).cast::<u8>()) })
532 }
533}
534
535unsafe impl<T> NewMut<T> for Standard<T>
538where
539 T: Copy,
540{
541 type Error = SliceError;
542 fn new_mut(self, data: &mut [T]) -> Result<MatMut<'_, Self>, Self::Error> {
543 self.check_slice(data)?;
544
545 Ok(unsafe { MatMut::from_raw_parts(self, utils::as_nonnull_mut(data).cast::<u8>()) })
550 }
551}
552
553impl<T> NewCloned for Standard<T>
554where
555 T: Copy,
556{
557 fn new_cloned(v: MatRef<'_, Self>) -> Mat<Self> {
558 let b: Box<[T]> = v.rows().flatten().copied().collect();
559
560 unsafe { v.repr().box_to_mat(b) }
562 }
563}
564
/// An owning matrix whose memory layout is described by the representation `T`.
///
/// The storage is addressed through a single base pointer. It is released through
/// [`ReprOwned::drop`] when the `Mat` is dropped.
#[derive(Debug)]
pub struct Mat<T: ReprOwned> {
    ptr: NonNull<u8>,
    repr: T,
    /// Makes `Mat` invariant in `T`.
    _invariant: PhantomData<fn(T) -> T>,
}
579
// SAFETY: `Mat` uniquely owns its storage, so moving it to another thread is sound when the
// representation is `Send` (for `Standard<E>`, that is exactly when `E: Send`).
// NOTE(review): for other representations this relies on their `Send` bound reflecting the
// element type — confirm against the `ReprOwned` safety contract.
unsafe impl<T> Send for Mat<T> where T: ReprOwned + Send {}

// SAFETY: `&Mat` only yields shared row views (mutable access requires `&mut Mat`), so sharing
// is sound when the representation is `Sync` (same caveat as above).
unsafe impl<T> Sync for Mat<T> where T: ReprOwned + Sync {}
585
586impl<T: ReprOwned> Mat<T> {
587 pub fn new<U>(repr: T, init: U) -> Result<Self, <T as NewOwned<U>>::Error>
589 where
590 T: NewOwned<U>,
591 {
592 repr.new_owned(init)
593 }
594
595 #[inline]
597 pub fn num_vectors(&self) -> usize {
598 self.repr.nrows()
599 }
600
601 pub fn repr(&self) -> &T {
603 &self.repr
604 }
605
606 #[must_use]
608 pub fn get_row(&self, i: usize) -> Option<T::Row<'_>> {
609 if i < self.num_vectors() {
610 let row = unsafe { self.get_row_unchecked(i) };
613 Some(row)
614 } else {
615 None
616 }
617 }
618
619 pub(crate) unsafe fn get_row_unchecked(&self, i: usize) -> T::Row<'_> {
620 unsafe { self.repr.get_row(self.ptr, i) }
623 }
624
625 #[must_use]
627 pub fn get_row_mut(&mut self, i: usize) -> Option<T::RowMut<'_>> {
628 if i < self.num_vectors() {
629 Some(unsafe { self.get_row_mut_unchecked(i) })
631 } else {
632 None
633 }
634 }
635
636 pub(crate) unsafe fn get_row_mut_unchecked(&mut self, i: usize) -> T::RowMut<'_> {
637 unsafe { self.repr.get_row_mut(self.ptr, i) }
640 }
641
642 #[inline]
644 pub fn as_view(&self) -> MatRef<'_, T> {
645 MatRef {
646 ptr: self.ptr,
647 repr: self.repr,
648 _lifetime: PhantomData,
649 }
650 }
651
652 #[inline]
654 pub fn as_view_mut(&mut self) -> MatMut<'_, T> {
655 MatMut {
656 ptr: self.ptr,
657 repr: self.repr,
658 _lifetime: PhantomData,
659 }
660 }
661
662 pub fn rows(&self) -> Rows<'_, T> {
664 Rows::new(self.reborrow())
665 }
666
667 pub fn rows_mut(&mut self) -> RowsMut<'_, T> {
669 RowsMut::new(self.reborrow_mut())
670 }
671
672 pub(crate) unsafe fn from_raw_parts(repr: T, ptr: NonNull<u8>) -> Self {
682 Self {
683 ptr,
684 repr,
685 _invariant: PhantomData,
686 }
687 }
688
689 pub fn as_raw_ptr(&self) -> *const u8 {
691 self.ptr.as_ptr()
692 }
693
694 pub(crate) fn as_raw_mut_ptr(&mut self) -> *mut u8 {
696 self.ptr.as_ptr()
697 }
698}
699
700impl<T: ReprOwned> Drop for Mat<T> {
701 fn drop(&mut self) {
702 unsafe { self.repr.drop(self.ptr) };
705 }
706}
707
708impl<T: NewCloned> Clone for Mat<T> {
709 fn clone(&self) -> Self {
710 T::new_cloned(self.as_view())
711 }
712}
713
714impl<T: Copy> Mat<Standard<T>> {
715 #[inline]
717 pub fn vector_dim(&self) -> usize {
718 self.repr.ncols()
719 }
720
721 #[inline]
725 pub fn as_slice(&self) -> &[T] {
726 self.as_view().as_slice()
727 }
728
729 #[inline]
731 pub fn as_matrix_view(&self) -> MatrixView<'_, T> {
732 self.as_view().as_matrix_view()
733 }
734}
735
/// A borrowed, read-only view of a matrix with representation `T`.
///
/// `MatRef` is `Copy` and behaves like a shared reference to the underlying storage for `'a`.
#[derive(Debug, Clone, Copy)]
pub struct MatRef<'a, T: Repr> {
    ptr: NonNull<u8>,
    repr: T,
    /// Borrows the storage for `'a`; keeps the view covariant in both `'a` and `T`.
    _lifetime: PhantomData<&'a T>,
}
758
// SAFETY: a `MatRef` only permits shared access to its storage.
// NOTE(review): `MatRef` is `Copy`, so sending it effectively shares the storage across threads.
// This bound treats the representation's `Send` as sufficient (for `Standard<E>` it requires
// `E: Send`, not `E: Sync`) — confirm this matches the `Repr` safety contract.
unsafe impl<T> Send for MatRef<'_, T> where T: Repr + Send {}

// SAFETY: `&MatRef` only yields shared row views, matching the representation's `Sync` bound.
unsafe impl<T> Sync for MatRef<'_, T> where T: Repr + Sync {}
764
765impl<'a, T: Repr> MatRef<'a, T> {
766 pub fn new<U>(repr: T, data: &'a [U]) -> Result<Self, T::Error>
768 where
769 T: NewRef<U>,
770 {
771 repr.new_ref(data)
772 }
773
774 #[inline]
776 pub fn num_vectors(&self) -> usize {
777 self.repr.nrows()
778 }
779
780 pub fn repr(&self) -> &T {
782 &self.repr
783 }
784
785 #[must_use]
787 pub fn get_row(&self, i: usize) -> Option<T::Row<'_>> {
788 if i < self.num_vectors() {
789 let row = unsafe { self.get_row_unchecked(i) };
792 Some(row)
793 } else {
794 None
795 }
796 }
797
798 #[inline]
804 pub(crate) unsafe fn get_row_unchecked(&self, i: usize) -> T::Row<'_> {
805 unsafe { self.repr.get_row(self.ptr, i) }
807 }
808
809 pub fn rows(&self) -> Rows<'_, T> {
811 Rows::new(*self)
812 }
813
814 pub fn to_owned(&self) -> Mat<T>
816 where
817 T: NewCloned,
818 {
819 T::new_cloned(*self)
820 }
821
822 pub unsafe fn from_raw_parts(repr: T, ptr: NonNull<u8>) -> Self {
830 Self {
831 ptr,
832 repr,
833 _lifetime: PhantomData,
834 }
835 }
836
837 pub fn as_raw_ptr(&self) -> *const u8 {
839 self.ptr.as_ptr()
840 }
841}
842
843impl<'a, T: Copy> MatRef<'a, Standard<T>> {
844 #[inline]
846 pub fn vector_dim(&self) -> usize {
847 self.repr.ncols()
848 }
849
850 #[inline]
854 pub fn as_slice(&self) -> &'a [T] {
855 let len = self.repr.num_elements();
856 unsafe { std::slice::from_raw_parts(self.ptr.as_ptr().cast::<T>(), len) }
859 }
860
861 #[allow(clippy::expect_used)]
863 #[inline]
864 pub fn as_matrix_view(&self) -> MatrixView<'a, T> {
865 MatrixView::try_from(self.as_slice(), self.num_vectors(), self.vector_dim())
868 .expect("Standard<T> has valid dimensions")
869 }
870}
871
872impl<'this, T: ReprOwned> Reborrow<'this> for Mat<T> {
874 type Target = MatRef<'this, T>;
875
876 fn reborrow(&'this self) -> Self::Target {
877 self.as_view()
878 }
879}
880
881impl<'this, T: ReprOwned> ReborrowMut<'this> for Mat<T> {
883 type Target = MatMut<'this, T>;
884
885 fn reborrow_mut(&'this mut self) -> Self::Target {
886 self.as_view_mut()
887 }
888}
889
890impl<'this, 'a, T: Repr> Reborrow<'this> for MatRef<'a, T> {
892 type Target = MatRef<'this, T>;
893
894 fn reborrow(&'this self) -> Self::Target {
895 MatRef {
896 ptr: self.ptr,
897 repr: self.repr,
898 _lifetime: PhantomData,
899 }
900 }
901}
902
/// A borrowed, mutable view of a matrix with representation `T`.
///
/// Unlike [`MatRef`], it is neither `Copy` nor `Clone`, so it acts like an exclusive borrow.
#[derive(Debug)]
pub struct MatMut<'a, T: ReprMut> {
    ptr: NonNull<u8>,
    repr: T,
    /// Borrows the storage mutably for `'a`: covariant in `'a`, invariant in `T`.
    _lifetime: PhantomData<&'a mut T>,
}
926
// SAFETY: a `MatMut` is an exclusive borrow of its storage. Moving it to another thread is sound
// when the representation is `Send` (for `Standard<E>`, exactly when `E: Send`).
unsafe impl<T> Send for MatMut<'_, T> where T: ReprMut + Send {}

// SAFETY: `&MatMut` only yields shared row views (mutable access requires `&mut MatMut`).
unsafe impl<T> Sync for MatMut<'_, T> where T: ReprMut + Sync {}
932
933impl<'a, T: ReprMut> MatMut<'a, T> {
934 pub fn new<U>(repr: T, data: &'a mut [U]) -> Result<Self, T::Error>
936 where
937 T: NewMut<U>,
938 {
939 repr.new_mut(data)
940 }
941
942 #[inline]
944 pub fn num_vectors(&self) -> usize {
945 self.repr.nrows()
946 }
947
948 pub fn repr(&self) -> &T {
950 &self.repr
951 }
952
953 #[inline]
955 #[must_use]
956 pub fn get_row(&self, i: usize) -> Option<T::Row<'_>> {
957 if i < self.num_vectors() {
958 Some(unsafe { self.get_row_unchecked(i) })
960 } else {
961 None
962 }
963 }
964
965 #[inline]
971 pub(crate) unsafe fn get_row_unchecked(&self, i: usize) -> T::Row<'_> {
972 unsafe { self.repr.get_row(self.ptr, i) }
974 }
975
976 #[inline]
978 #[must_use]
979 pub fn get_row_mut(&mut self, i: usize) -> Option<T::RowMut<'_>> {
980 if i < self.num_vectors() {
981 Some(unsafe { self.get_row_mut_unchecked(i) })
983 } else {
984 None
985 }
986 }
987
988 #[inline]
994 pub(crate) unsafe fn get_row_mut_unchecked(&mut self, i: usize) -> T::RowMut<'_> {
995 unsafe { self.repr.get_row_mut(self.ptr, i) }
998 }
999
1000 pub fn as_view(&self) -> MatRef<'_, T> {
1002 MatRef {
1003 ptr: self.ptr,
1004 repr: self.repr,
1005 _lifetime: PhantomData,
1006 }
1007 }
1008
1009 pub fn rows(&self) -> Rows<'_, T> {
1011 Rows::new(self.reborrow())
1012 }
1013
1014 pub fn rows_mut(&mut self) -> RowsMut<'_, T> {
1016 RowsMut::new(self.reborrow_mut())
1017 }
1018
1019 pub fn to_owned(&self) -> Mat<T>
1021 where
1022 T: NewCloned,
1023 {
1024 T::new_cloned(self.as_view())
1025 }
1026
1027 pub unsafe fn from_raw_parts(repr: T, ptr: NonNull<u8>) -> Self {
1034 Self {
1035 ptr,
1036 repr,
1037 _lifetime: PhantomData,
1038 }
1039 }
1040
1041 pub fn as_raw_ptr(&self) -> *const u8 {
1043 self.ptr.as_ptr()
1044 }
1045
1046 pub(crate) fn as_raw_mut_ptr(&mut self) -> *mut u8 {
1048 self.ptr.as_ptr()
1049 }
1050}
1051
1052impl<'this, 'a, T: ReprMut> Reborrow<'this> for MatMut<'a, T> {
1054 type Target = MatRef<'this, T>;
1055
1056 fn reborrow(&'this self) -> Self::Target {
1057 self.as_view()
1058 }
1059}
1060
1061impl<'this, 'a, T: ReprMut> ReborrowMut<'this> for MatMut<'a, T> {
1063 type Target = MatMut<'this, T>;
1064
1065 fn reborrow_mut(&'this mut self) -> Self::Target {
1066 MatMut {
1067 ptr: self.ptr,
1068 repr: self.repr,
1069 _lifetime: PhantomData,
1070 }
1071 }
1072}
1073
1074impl<'a, T: Copy> MatMut<'a, Standard<T>> {
1075 #[inline]
1077 pub fn vector_dim(&self) -> usize {
1078 self.repr.ncols()
1079 }
1080
1081 #[inline]
1085 pub fn as_slice(&self) -> &[T] {
1086 self.as_view().as_slice()
1087 }
1088
1089 #[inline]
1091 pub fn as_matrix_view(&self) -> MatrixView<'_, T> {
1092 self.as_view().as_matrix_view()
1093 }
1094}
1095
1096#[derive(Debug)]
1104pub struct Rows<'a, T: Repr> {
1105 matrix: MatRef<'a, T>,
1106 current: usize,
1107}
1108
1109impl<'a, T> Rows<'a, T>
1110where
1111 T: Repr,
1112{
1113 fn new(matrix: MatRef<'a, T>) -> Self {
1114 Self { matrix, current: 0 }
1115 }
1116}
1117
1118impl<'a, T> Iterator for Rows<'a, T>
1119where
1120 T: Repr + 'a,
1121{
1122 type Item = T::Row<'a>;
1123
1124 fn next(&mut self) -> Option<Self::Item> {
1125 let current = self.current;
1126 if current >= self.matrix.num_vectors() {
1127 None
1128 } else {
1129 self.current += 1;
1130 Some(unsafe { self.matrix.repr.get_row(self.matrix.ptr, current) })
1136 }
1137 }
1138
1139 fn size_hint(&self) -> (usize, Option<usize>) {
1140 let remaining = self.matrix.num_vectors() - self.current;
1141 (remaining, Some(remaining))
1142 }
1143}
1144
1145impl<'a, T> ExactSizeIterator for Rows<'a, T> where T: Repr + 'a {}
1146impl<'a, T> FusedIterator for Rows<'a, T> where T: Repr + 'a {}
1147
1148#[derive(Debug)]
1156pub struct RowsMut<'a, T: ReprMut> {
1157 matrix: MatMut<'a, T>,
1158 current: usize,
1159}
1160
1161impl<'a, T> RowsMut<'a, T>
1162where
1163 T: ReprMut,
1164{
1165 fn new(matrix: MatMut<'a, T>) -> Self {
1166 Self { matrix, current: 0 }
1167 }
1168}
1169
1170impl<'a, T> Iterator for RowsMut<'a, T>
1171where
1172 T: ReprMut + 'a,
1173{
1174 type Item = T::RowMut<'a>;
1175
1176 fn next(&mut self) -> Option<Self::Item> {
1177 let current = self.current;
1178 if current >= self.matrix.num_vectors() {
1179 None
1180 } else {
1181 self.current += 1;
1182 Some(unsafe { self.matrix.repr.get_row_mut(self.matrix.ptr, current) })
1191 }
1192 }
1193
1194 fn size_hint(&self) -> (usize, Option<usize>) {
1195 let remaining = self.matrix.num_vectors() - self.current;
1196 (remaining, Some(remaining))
1197 }
1198}
1199
1200impl<'a, T> ExactSizeIterator for RowsMut<'a, T> where T: ReprMut + 'a {}
1201impl<'a, T> FusedIterator for RowsMut<'a, T> where T: ReprMut + 'a {}
1202
#[cfg(test)]
mod tests {
    use super::*;

    use std::fmt::Display;

    use diskann_utils::lazy_format;

    /// Compiles only if the argument's type is `Copy`.
    fn assert_copy<T: Copy>(_: &T) {}

    /// Compiles only if `MatRef` is covariant in its lifetime.
    fn _assert_matref_covariant_lifetime<'long: 'short, 'short, T: Repr>(
        v: MatRef<'long, T>,
    ) -> MatRef<'short, T> {
        v
    }

    /// Compiles only if `MatRef` is covariant in its representation type.
    fn _assert_matref_covariant_repr<'long: 'short, 'short, 'a>(
        v: MatRef<'a, Standard<&'long u8>>,
    ) -> MatRef<'a, Standard<&'short u8>> {
        v
    }

    /// Compiles only if `MatMut` is covariant in its lifetime.
    fn _assert_matmut_covariant_lifetime<'long: 'short, 'short, T: ReprMut>(
        v: MatMut<'long, T>,
    ) -> MatMut<'short, T> {
        v
    }

    /// Row indices that are out of bounds for a matrix with `nrows` rows, ranging from `nrows`
    /// itself up to `usize::MAX`.
    fn edge_cases(nrows: usize) -> Vec<usize> {
        let max = usize::MAX;

        vec![
            nrows,
            nrows + 1,
            nrows + 11,
            nrows + 20,
            max / 2,
            max.div_ceil(2),
            max - 1,
            max,
        ]
    }

    /// The value the fill helpers store at `(row, col)` and the check helpers expect there.
    fn expected(row: usize, col: usize) -> usize {
        10 * row + col
    }

    /// Asserts that row `i` has `ncols` entries, then fills column `j` with `expected(i, j)`.
    fn fill_row(row: &mut [usize], i: usize, ncols: usize) {
        assert_eq!(row.len(), ncols);
        row.iter_mut()
            .enumerate()
            .for_each(|(j, r)| *r = expected(i, j));
    }

    /// Asserts that row `i` has `ncols` entries and that column `j` holds `expected(i, j)`.
    fn check_row(row: &[usize], i: usize, ncols: usize, ctx: &dyn Display) {
        assert_eq!(row.len(), ncols, "ctx: {ctx}");
        row.iter().enumerate().for_each(|(j, r)| {
            assert_eq!(
                *r,
                expected(i, j),
                "mismatched entry at row {}, col {} -- ctx: {}",
                i,
                j,
                ctx
            )
        });
    }

    fn fill_mat(x: &mut Mat<Standard<usize>>, repr: Standard<usize>) {
        assert_eq!(x.repr(), &repr);
        assert_eq!(x.num_vectors(), repr.nrows());
        assert_eq!(x.vector_dim(), repr.ncols());

        for i in 0..x.num_vectors() {
            fill_row(x.get_row_mut(i).unwrap(), i, repr.ncols());
        }

        for i in edge_cases(repr.nrows()) {
            assert!(x.get_row_mut(i).is_none());
        }
    }

    fn fill_mat_mut(mut x: MatMut<'_, Standard<usize>>, repr: Standard<usize>) {
        assert_eq!(x.repr(), &repr);
        assert_eq!(x.num_vectors(), repr.nrows());
        assert_eq!(x.vector_dim(), repr.ncols());

        for i in 0..x.num_vectors() {
            fill_row(x.get_row_mut(i).unwrap(), i, repr.ncols());
        }

        for i in edge_cases(repr.nrows()) {
            assert!(x.get_row_mut(i).is_none());
        }
    }

    fn fill_rows_mut(x: RowsMut<'_, Standard<usize>>, repr: Standard<usize>) {
        assert_eq!(x.len(), repr.nrows());
        // Collect first so that every mutable row is alive at the same time.
        let all_rows: Vec<_> = x.collect();
        assert_eq!(all_rows.len(), repr.nrows());
        for (i, row) in all_rows.into_iter().enumerate() {
            fill_row(row, i, repr.ncols());
        }
    }

    fn check_mat(x: &Mat<Standard<usize>>, repr: Standard<usize>, ctx: &dyn Display) {
        assert_eq!(x.repr(), &repr);
        assert_eq!(x.num_vectors(), repr.nrows());
        assert_eq!(x.vector_dim(), repr.ncols());

        for i in 0..x.num_vectors() {
            check_row(x.get_row(i).unwrap(), i, repr.ncols(), ctx);
        }

        for i in edge_cases(repr.nrows()) {
            assert!(x.get_row(i).is_none(), "ctx: {ctx}");
        }
    }

    fn check_mat_ref(x: MatRef<'_, Standard<usize>>, repr: Standard<usize>, ctx: &dyn Display) {
        assert_eq!(x.repr(), &repr);
        assert_eq!(x.num_vectors(), repr.nrows());
        assert_eq!(x.vector_dim(), repr.ncols());

        assert_copy(&x);
        for i in 0..x.num_vectors() {
            check_row(x.get_row(i).unwrap(), i, repr.ncols(), ctx);
        }

        for i in edge_cases(repr.nrows()) {
            assert!(x.get_row(i).is_none(), "ctx: {ctx}");
        }
    }

    fn check_mat_mut(x: MatMut<'_, Standard<usize>>, repr: Standard<usize>, ctx: &dyn Display) {
        assert_eq!(x.repr(), &repr);
        assert_eq!(x.num_vectors(), repr.nrows());
        assert_eq!(x.vector_dim(), repr.ncols());

        for i in 0..x.num_vectors() {
            check_row(x.get_row(i).unwrap(), i, repr.ncols(), ctx);
        }

        for i in edge_cases(repr.nrows()) {
            assert!(x.get_row(i).is_none(), "ctx: {ctx}");
        }
    }

    fn check_rows(x: Rows<'_, Standard<usize>>, repr: Standard<usize>, ctx: &dyn Display) {
        assert_eq!(x.len(), repr.nrows(), "ctx: {ctx}");
        let all_rows: Vec<_> = x.collect();
        assert_eq!(all_rows.len(), repr.nrows(), "ctx: {ctx}");
        for (i, row) in all_rows.into_iter().enumerate() {
            check_row(row, i, repr.ncols(), ctx);
        }
    }

    #[test]
    fn standard_representation() {
        let repr = Standard::<f32>::new(4, 3).unwrap();
        assert_eq!(repr.nrows(), 4);
        assert_eq!(repr.ncols(), 3);

        let layout = repr.layout().unwrap();
        assert_eq!(layout.size(), 4 * 3 * std::mem::size_of::<f32>());
        assert_eq!(layout.align(), std::mem::align_of::<f32>());
    }

    #[test]
    fn standard_zero_dimensions() {
        for (nrows, ncols) in [(0, 0), (0, 5), (5, 0)] {
            let repr = Standard::<u8>::new(nrows, ncols).unwrap();
            assert_eq!(repr.nrows(), nrows);
            assert_eq!(repr.ncols(), ncols);
            let layout = repr.layout().unwrap();
            assert_eq!(layout.size(), 0);
        }
    }

    #[test]
    fn standard_check_slice() {
        let repr = Standard::<u32>::new(3, 4).unwrap();

        let data = vec![0u32; 12];
        assert!(repr.check_slice(&data).is_ok());

        let short = vec![0u32; 11];
        assert!(matches!(
            repr.check_slice(&short),
            Err(SliceError::LengthMismatch {
                expected: 12,
                found: 11
            })
        ));

        let long = vec![0u32; 13];
        assert!(matches!(
            repr.check_slice(&long),
            Err(SliceError::LengthMismatch {
                expected: 12,
                found: 13
            })
        ));

        // The error must record the rejected dimensions and the element size. (The previous
        // `matches!(err, Overflow { .. })` was always true because the type is `Overflow`.)
        let overflow = Standard::<u8>::new(usize::MAX, 2).unwrap_err();
        assert_eq!(overflow.nrows, usize::MAX);
        assert_eq!(overflow.ncols, 2);
        assert_eq!(overflow.elsize, std::mem::size_of::<u8>());
    }

    #[test]
    fn standard_new_rejects_element_count_overflow() {
        assert!(Standard::<u8>::new(usize::MAX, 2).is_err());
        assert!(Standard::<u8>::new(2, usize::MAX).is_err());
        assert!(Standard::<u8>::new(usize::MAX, usize::MAX).is_err());
    }

    #[test]
    fn standard_new_rejects_byte_count_exceeding_isize_max() {
        let half = (isize::MAX as usize / std::mem::size_of::<u64>()) + 1;
        assert!(Standard::<u64>::new(half, 1).is_err());
        assert!(Standard::<u64>::new(1, half).is_err());
    }

    #[test]
    fn standard_new_accepts_boundary_below_isize_max() {
        let max_elems = isize::MAX as usize / std::mem::size_of::<u64>();
        let repr = Standard::<u64>::new(max_elems, 1).unwrap();
        assert_eq!(repr.num_elements(), max_elems);
    }

    #[test]
    fn standard_new_zst_rejects_element_count_overflow() {
        assert!(Standard::<()>::new(usize::MAX, 2).is_err());
        assert!(Standard::<()>::new(usize::MAX / 2 + 1, 3).is_err());
    }

    #[test]
    fn standard_new_zst_accepts_large_non_overflowing() {
        let repr = Standard::<()>::new(usize::MAX, 1).unwrap();
        assert_eq!(repr.num_elements(), usize::MAX);
        assert_eq!(repr.layout().unwrap().size(), 0);
    }

    #[test]
    fn standard_new_overflow_error_display() {
        let err = Standard::<u32>::new(usize::MAX, 2).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("would exceed isize::MAX bytes"), "{msg}");

        let zst_err = Standard::<()>::new(usize::MAX, 2).unwrap_err();
        let zst_msg = zst_err.to_string();
        assert!(zst_msg.contains("ZST matrix"), "{zst_msg}");
        assert!(zst_msg.contains("usize::MAX"), "{zst_msg}");
    }

    #[test]
    fn mat_new_and_basic_accessors() {
        let mat = Mat::new(Standard::<usize>::new(3, 4).unwrap(), 42usize).unwrap();
        let base: *const u8 = mat.as_raw_ptr();

        assert_eq!(mat.num_vectors(), 3);
        assert_eq!(mat.vector_dim(), 4);

        let repr = mat.repr();
        assert_eq!(repr.nrows(), 3);
        assert_eq!(repr.ncols(), 4);

        for (i, r) in mat.rows().enumerate() {
            assert_eq!(r, &[42, 42, 42, 42]);
            let ptr = r.as_ptr().cast::<u8>();
            assert_eq!(
                ptr,
                base.wrapping_add(std::mem::size_of::<usize>() * mat.repr().ncols() * i),
            );
        }
    }

    #[test]
    fn mat_new_with_default() {
        let mat = Mat::new(Standard::<usize>::new(2, 3).unwrap(), Defaulted).unwrap();
        let base: *const u8 = mat.as_raw_ptr();

        assert_eq!(mat.num_vectors(), 2);
        for (i, row) in mat.rows().enumerate() {
            assert!(row.iter().all(|&v| v == 0));

            let ptr = row.as_ptr().cast::<u8>();
            assert_eq!(
                ptr,
                base.wrapping_add(std::mem::size_of::<usize>() * mat.repr().ncols() * i),
            );
        }
    }

    const ROWS: &[usize] = &[0, 1, 2, 3, 5, 10];
    const COLS: &[usize] = &[0, 1, 2, 3, 5, 10];

    #[test]
    fn test_mat() {
        for nrows in ROWS {
            for ncols in COLS {
                let repr = Standard::<usize>::new(*nrows, *ncols).unwrap();
                let ctx = &lazy_format!("nrows = {}, ncols = {}", nrows, ncols);

                // Fill through `Mat::get_row_mut` directly.
                {
                    let ctx = &lazy_format!("{ctx} - direct");
                    let mut mat = Mat::new(repr, Defaulted).unwrap();

                    assert_eq!(mat.num_vectors(), *nrows);
                    assert_eq!(mat.vector_dim(), *ncols);

                    fill_mat(&mut mat, repr);

                    check_mat(&mat, repr, ctx);
                    check_mat_ref(mat.reborrow(), repr, ctx);
                    check_mat_mut(mat.reborrow_mut(), repr, ctx);
                    check_rows(mat.rows(), repr, ctx);

                    // Reborrowed views must alias the owned storage.
                    assert_eq!(mat.as_raw_ptr(), mat.reborrow().as_raw_ptr());
                    assert_eq!(mat.as_raw_ptr(), mat.reborrow_mut().as_raw_ptr());
                }

                // Fill through a `MatMut` reborrow.
                {
                    let ctx = &lazy_format!("{ctx} - matmut");
                    let mut mat = Mat::new(repr, Defaulted).unwrap();
                    let matmut = mat.reborrow_mut();

                    assert_eq!(matmut.num_vectors(), *nrows);
                    assert_eq!(matmut.vector_dim(), *ncols);

                    fill_mat_mut(matmut, repr);

                    check_mat(&mat, repr, ctx);
                    check_mat_ref(mat.reborrow(), repr, ctx);
                    check_mat_mut(mat.reborrow_mut(), repr, ctx);
                    check_rows(mat.rows(), repr, ctx);
                }

                // Fill through the `rows_mut` iterator.
                {
                    let ctx = &lazy_format!("{ctx} - rows_mut");
                    let mut mat = Mat::new(repr, Defaulted).unwrap();
                    fill_rows_mut(mat.rows_mut(), repr);

                    check_mat(&mat, repr, ctx);
                    check_mat_ref(mat.reborrow(), repr, ctx);
                    check_mat_mut(mat.reborrow_mut(), repr, ctx);
                    check_rows(mat.rows(), repr, ctx);
                }
            }
        }
    }

    #[test]
    fn test_mat_clone() {
        for nrows in ROWS {
            for ncols in COLS {
                let repr = Standard::<usize>::new(*nrows, *ncols).unwrap();
                let ctx = &lazy_format!("nrows = {}, ncols = {}", nrows, ncols);

                let mut mat = Mat::new(repr, Defaulted).unwrap();
                fill_mat(&mut mat, repr);

                {
                    let ctx = &lazy_format!("{ctx} - Mat::clone");
                    let cloned = mat.clone();

                    assert_eq!(cloned.num_vectors(), *nrows);
                    assert_eq!(cloned.vector_dim(), *ncols);

                    check_mat(&cloned, repr, ctx);
                    check_mat_ref(cloned.reborrow(), repr, ctx);
                    check_rows(cloned.rows(), repr, ctx);

                    // Non-empty clones must own a distinct allocation.
                    if repr.num_elements() > 0 {
                        assert_ne!(mat.as_raw_ptr(), cloned.as_raw_ptr());
                    }
                }

                {
                    let ctx = &lazy_format!("{ctx} - MatRef::to_owned");
                    let owned = mat.as_view().to_owned();

                    check_mat(&owned, repr, ctx);
                    check_mat_ref(owned.reborrow(), repr, ctx);
                    check_rows(owned.rows(), repr, ctx);

                    if repr.num_elements() > 0 {
                        assert_ne!(mat.as_raw_ptr(), owned.as_raw_ptr());
                    }
                }

                {
                    let ctx = &lazy_format!("{ctx} - MatMut::to_owned");
                    let owned = mat.as_view_mut().to_owned();

                    check_mat(&owned, repr, ctx);
                    check_mat_ref(owned.reborrow(), repr, ctx);
                    check_rows(owned.rows(), repr, ctx);

                    if repr.num_elements() > 0 {
                        assert_ne!(mat.as_raw_ptr(), owned.as_raw_ptr());
                    }
                }
            }
        }
    }

    #[test]
    fn test_mat_refmut() {
        for nrows in ROWS {
            for ncols in COLS {
                let repr = Standard::<usize>::new(*nrows, *ncols).unwrap();
                let ctx = &lazy_format!("nrows = {}, ncols = {}", nrows, ncols);

                // Fill a borrowed buffer through `MatMut::get_row_mut`.
                {
                    let ctx = &lazy_format!("{ctx} - by matmut");
                    let mut b: Box<[_]> = (0..repr.num_elements()).map(|_| 0usize).collect();
                    let ptr = b.as_ptr().cast::<u8>();
                    let mut matmut = MatMut::new(repr, &mut b).unwrap();

                    assert_eq!(
                        ptr,
                        matmut.as_raw_ptr(),
                        "underlying memory should be preserved",
                    );

                    fill_mat_mut(matmut.reborrow_mut(), repr);

                    check_mat_mut(matmut.reborrow_mut(), repr, ctx);
                    check_mat_ref(matmut.reborrow(), repr, ctx);
                    check_rows(matmut.rows(), repr, ctx);
                    check_rows(matmut.reborrow().rows(), repr, ctx);

                    let matref = MatRef::new(repr, &b).unwrap();
                    check_mat_ref(matref, repr, ctx);
                    check_mat_ref(matref.reborrow(), repr, ctx);
                    check_rows(matref.rows(), repr, ctx);
                }

                // Fill a borrowed buffer through `MatMut::rows_mut`.
                {
                    let ctx = &lazy_format!("{ctx} - by rows");
                    let mut b: Box<[_]> = (0..repr.num_elements()).map(|_| 0usize).collect();
                    let ptr = b.as_ptr().cast::<u8>();
                    let mut matmut = MatMut::new(repr, &mut b).unwrap();

                    assert_eq!(
                        ptr,
                        matmut.as_raw_ptr(),
                        "underlying memory should be preserved",
                    );

                    fill_rows_mut(matmut.rows_mut(), repr);

                    check_mat_mut(matmut.reborrow_mut(), repr, ctx);
                    check_mat_ref(matmut.reborrow(), repr, ctx);
                    check_rows(matmut.rows(), repr, ctx);
                    check_rows(matmut.reborrow().rows(), repr, ctx);

                    let matref = MatRef::new(repr, &b).unwrap();
                    check_mat_ref(matref, repr, ctx);
                    check_mat_ref(matref.reborrow(), repr, ctx);
                    check_rows(matref.rows(), repr, ctx);
                }
            }
        }
    }

    #[test]
    fn test_standard_new_owned() {
        for &nrows in ROWS {
            for &ncols in COLS {
                let m = Mat::new(Standard::new(nrows, ncols).unwrap(), 1usize).unwrap();
                let rows_iter = m.rows();
                let len = <_ as ExactSizeIterator>::len(&rows_iter);
                assert_eq!(len, nrows);
                for r in rows_iter {
                    assert_eq!(r.len(), ncols);
                    assert!(r.iter().all(|i| *i == 1usize));
                }
            }
        }
    }

    #[test]
    fn matref_new_slice_length_error() {
        let repr = Standard::<u32>::new(3, 4).unwrap();

        let data = vec![0u32; 12];
        assert!(MatRef::new(repr, &data).is_ok());

        let short = vec![0u32; 11];
        assert!(matches!(
            MatRef::new(repr, &short),
            Err(SliceError::LengthMismatch {
                expected: 12,
                found: 11
            })
        ));

        let long = vec![0u32; 13];
        assert!(matches!(
            MatRef::new(repr, &long),
            Err(SliceError::LengthMismatch {
                expected: 12,
                found: 13
            })
        ));
    }

    #[test]
    fn matmut_new_slice_length_error() {
        let repr = Standard::<u32>::new(3, 4).unwrap();

        let mut data = vec![0u32; 12];
        assert!(MatMut::new(repr, &mut data).is_ok());

        let mut short = vec![0u32; 11];
        assert!(matches!(
            MatMut::new(repr, &mut short),
            Err(SliceError::LengthMismatch {
                expected: 12,
                found: 11
            })
        ));

        let mut long = vec![0u32; 13];
        assert!(matches!(
            MatMut::new(repr, &mut long),
            Err(SliceError::LengthMismatch {
                expected: 12,
                found: 13
            })
        ));
    }

    #[test]
    fn as_matrix_view_roundtrip() {
        let data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];

        // Borrowed, read-only view.
        let matref = MatRef::new(Standard::new(2, 3).unwrap(), &data).unwrap();
        let view = matref.as_matrix_view();
        assert_eq!(view.nrows(), 2);
        assert_eq!(view.ncols(), 3);
        for row in 0..2 {
            for col in 0..3 {
                assert_eq!(view[(row, col)], data[row * 3 + col]);
            }
        }
        assert_eq!(matref.as_slice(), &data);

        // Owned matrix filled row by row.
        let mut mat = Mat::new(Standard::<f32>::new(2, 3).unwrap(), 0.0f32).unwrap();
        for (i, chunk) in data.chunks_exact(3).enumerate() {
            mat.get_row_mut(i).unwrap().copy_from_slice(chunk);
        }
        let view = mat.as_matrix_view();
        assert_eq!(view.nrows(), 2);
        assert_eq!(view.ncols(), 3);
        for row in 0..2 {
            for col in 0..3 {
                assert_eq!(view[(row, col)], data[row * 3 + col]);
            }
        }
        assert_eq!(mat.as_slice(), &data);

        // Borrowed, mutable view.
        let mut buf = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let matmut = MatMut::new(Standard::new(2, 3).unwrap(), &mut buf).unwrap();
        let view = matmut.as_matrix_view();
        assert_eq!(view.nrows(), 2);
        assert_eq!(view.ncols(), 3);
        for row in 0..2 {
            for col in 0..3 {
                assert_eq!(view[(row, col)], data[row * 3 + col]);
            }
        }
        assert_eq!(matmut.as_slice(), &data);
    }
}
1875}