1use std::{alloc::Layout, iter::FusedIterator, marker::PhantomData, ptr::NonNull};
29
30use diskann_utils::{Reborrow, ReborrowMut};
31use thiserror::Error;
32
33use crate::utils;
34
/// Describes how the rows of a matrix are laid out in a single untyped allocation.
///
/// Implementations are `Copy` descriptors (for example [`Standard`] stores only the row and
/// column counts). The data itself is addressed through a `NonNull<u8>` held by [`Mat`],
/// [`MatRef`], or [`MatMut`].
///
/// # Safety
///
/// NOTE(review): the full contract is not visible in this chunk. The callers in this file
/// rely on `get_row(ptr, i)` producing a valid row view whenever `i < nrows()` and `ptr`
/// points to memory laid out as this representation describes.
pub unsafe trait Repr: Copy {
    /// Borrowed, read-only view of a single row (for [`Standard<T>`] this is `&'a [T]`).
    type Row<'a>
    where
        Self: 'a;

    /// Number of rows in the matrix. The checked accessors of [`Mat`], [`MatRef`], and
    /// [`MatMut`] bound row indices by this value.
    fn nrows(&self) -> usize;

    /// Memory layout of the whole backing allocation.
    ///
    /// # Errors
    ///
    /// Returns [`LayoutError`] when the layout cannot be represented.
    fn layout(&self) -> Result<Layout, LayoutError>;

    /// Return row `i` of the matrix whose data begins at `ptr`.
    ///
    /// # Safety
    ///
    /// Callers in this file guarantee that `i < self.nrows()`. They also guarantee that
    /// `ptr` points to memory described by this representation that remains readable, and
    /// is not mutably aliased, for `'a`. NOTE(review): confirm against the complete trait
    /// documentation.
    unsafe fn get_row<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::Row<'a>;
}
93
/// A [`Repr`] whose rows can also be borrowed mutably.
///
/// # Safety
///
/// NOTE(review): the full contract is not visible in this chunk. The callers in this file
/// rely on `get_row_mut(ptr, i)` producing a valid, exclusive view of row `i` when
/// `i < nrows()`. [`RowsMut`] hands out views of several rows at once, which additionally
/// relies on views of distinct rows never overlapping.
pub unsafe trait ReprMut: Repr {
    /// Mutable view of a single row (for [`Standard<T>`] this is `&'a mut [T]`).
    type RowMut<'a>
    where
        Self: 'a;

    /// Return a mutable view of row `i` of the matrix whose data begins at `ptr`.
    ///
    /// # Safety
    ///
    /// Callers in this file guarantee that `i < self.nrows()`. They also guarantee that
    /// `ptr` points to memory described by this representation that they may access
    /// exclusively for `'a`. NOTE(review): confirm against the complete trait documentation.
    unsafe fn get_row_mut<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::RowMut<'a>;
}
129
/// A [`ReprMut`] that can own its backing allocation, as used by [`Mat`].
///
/// # Safety
///
/// NOTE(review): the full contract is not visible in this chunk. [`Mat`]'s `Drop` impl
/// calls [`ReprOwned::drop`] with the pointer the `Mat` was built from. Implementors
/// therefore presumably must be able to free every allocation their constructors (e.g.
/// [`NewOwned`]) produce.
pub unsafe trait ReprOwned: ReprMut {
    /// Release the allocation at `ptr` that was created for this representation.
    ///
    /// # Safety
    ///
    /// `Mat`'s `Drop` impl calls this with the pointer the `Mat` owns. Afterwards `ptr` is
    /// dangling and must not be used again.
    unsafe fn drop(self, ptr: NonNull<u8>);
}
149
/// Error reported when a representation cannot describe its data as a valid [`Layout`].
///
/// It carries no payload. Conversions from [`std::alloc::LayoutError`] discard the
/// underlying cause.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub struct LayoutError;

impl LayoutError {
    /// Construct a `LayoutError`.
    pub fn new() -> Self {
        LayoutError
    }
}

impl Default for LayoutError {
    /// Same as [`LayoutError::new`].
    fn default() -> Self {
        LayoutError
    }
}

impl std::fmt::Display for LayoutError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("LayoutError")
    }
}

impl std::error::Error for LayoutError {}

impl From<std::alloc::LayoutError> for LayoutError {
    /// Collapse a standard-library layout error into this crate's payload-free error.
    fn from(_cause: std::alloc::LayoutError) -> Self {
        Self::new()
    }
}
185
/// Construct a read-only [`MatRef`] over a slice of `T` using this representation.
///
/// # Safety
///
/// NOTE(review): the contract is not visible in this chunk. A successful `new_ref`
/// presumably must return a view whose pointer and representation satisfy
/// [`Repr::get_row`]'s requirements for the lifetime of `slice`.
pub unsafe trait NewRef<T>: Repr {
    /// Error returned when `slice` cannot back this representation.
    type Error;

    /// Wrap `slice` as a read-only matrix view.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` if `slice` is incompatible with this representation. For
    /// [`Standard`] this happens when its length differs from `nrows * ncols`.
    fn new_ref(self, slice: &[T]) -> Result<MatRef<'_, Self>, Self::Error>;
}
203
/// Construct a mutable [`MatMut`] over a slice of `T` using this representation.
///
/// # Safety
///
/// NOTE(review): the contract is not visible in this chunk. A successful `new_mut`
/// presumably must return a view satisfying [`ReprMut::get_row_mut`]'s requirements for
/// the lifetime of `slice`.
pub unsafe trait NewMut<T>: ReprMut {
    /// Error returned when `slice` cannot back this representation.
    type Error;

    /// Wrap `slice` as a mutable matrix view.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` if `slice` is incompatible with this representation. For
    /// [`Standard`] this happens when its length differs from `nrows * ncols`.
    fn new_mut(self, slice: &mut [T]) -> Result<MatMut<'_, Self>, Self::Error>;
}
217
/// Allocate an owned [`Mat`] with this representation from an initializer of type `T`.
///
/// For [`Standard`], `T` is either the element type (used as a fill value) or
/// [`Defaulted`].
///
/// # Safety
///
/// NOTE(review): the contract is not visible in this chunk. The returned [`Mat`] is
/// eventually freed through [`ReprOwned::drop`], so the allocation presumably must be one
/// that `drop` can release.
pub unsafe trait NewOwned<T>: ReprOwned {
    /// Error returned when construction fails.
    type Error;

    /// Allocate and initialize a new matrix.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` on failure. The [`Standard`] impls in this file use
    /// `crate::error::Infallible`.
    fn new_owned(self, init: T) -> Result<Mat<Self>, Self::Error>;
}
231
/// Initializer marker for [`Mat::new`] that sets every element to `T::default()`.
///
/// For example, `Mat::new(Standard::<usize>::new(2, 3)?, Defaulted)` builds a 2 x 3 matrix
/// filled with `0usize`.
#[derive(Debug, Clone, Copy)]
pub struct Defaulted;
244
/// Representations whose matrices can be deep-copied into a new owned [`Mat`].
///
/// This powers `Clone for Mat`, `MatRef::to_owned`, and `MatMut::to_owned`.
pub trait NewCloned: ReprOwned {
    /// Allocate a new [`Mat`] holding a copy of every row of `v`.
    fn new_cloned(v: MatRef<'_, Self>) -> Mat<Self>;
}
252
/// Dense row-major representation of an `nrows x ncols` matrix of `T`.
///
/// Rows are stored back to back: row `i` occupies elements `i * ncols .. (i + 1) * ncols`
/// of one contiguous buffer.
///
/// A `Standard<T>` is only a pair of dimensions, and `T` exists purely at the type level.
/// For that reason `Clone`, `Copy`, `PartialEq`, `Eq`, and `Debug` are implemented by hand
/// instead of derived. `#[derive]` would require `T` to implement the same traits, so for
/// example `Standard<f32>` would not be `Eq`. Equality compares the dimensions only.
pub struct Standard<T> {
    nrows: usize,
    ncols: usize,
    _elem: PhantomData<T>,
}

impl<T> Clone for Standard<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for Standard<T> {}

impl<T> PartialEq for Standard<T> {
    fn eq(&self, other: &Self) -> bool {
        self.nrows == other.nrows && self.ncols == other.ncols
    }
}

impl<T> Eq for Standard<T> {}

impl<T> std::fmt::Debug for Standard<T> {
    // Produces the same output `#[derive(Debug)]` did, without requiring `T: Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Standard")
            .field("nrows", &self.nrows)
            .field("ncols", &self.ncols)
            .field("_elem", &self._elem)
            .finish()
    }
}
272
273impl<T: Copy> Standard<T> {
274 pub fn new(nrows: usize, ncols: usize) -> Result<Self, Overflow> {
283 Overflow::check::<T>(nrows, ncols)?;
284 Ok(Self {
285 nrows,
286 ncols,
287 _elem: PhantomData,
288 })
289 }
290
291 pub fn num_elements(&self) -> usize {
293 self.nrows() * self.ncols()
295 }
296
297 fn nrows(&self) -> usize {
299 self.nrows
300 }
301
302 fn ncols(&self) -> usize {
304 self.ncols
305 }
306
307 fn check_slice(&self, slice: &[T]) -> Result<(), SliceError> {
312 let len = self.num_elements();
313
314 if slice.len() != len {
315 Err(SliceError::LengthMismatch {
316 expected: len,
317 found: slice.len(),
318 })
319 } else {
320 Ok(())
321 }
322 }
323
324 unsafe fn box_to_mat(self, b: Box<[T]>) -> Mat<Self> {
330 debug_assert_eq!(b.len(), self.num_elements(), "safety contract violated");
331
332 let ptr = unsafe { NonNull::new_unchecked(Box::into_raw(b)) }.cast::<u8>();
335
336 unsafe { Mat::from_raw_parts(self, ptr) }
340 }
341}
342
/// Error from [`Standard::new`] when the requested dimensions cannot be represented.
///
/// This happens in two cases:
///
/// - `nrows * ncols` overflows `usize`.
/// - The total size in bytes would exceed `isize::MAX`, the maximum size of a Rust
///   allocation.
#[derive(Debug, Clone, Copy)]
pub struct Overflow {
    nrows: usize,
    ncols: usize,
    elsize: usize,
}

impl Overflow {
    /// Check that an `nrows x ncols` matrix of `T` stays within allocation limits.
    fn check<T>(nrows: usize, ncols: usize) -> Result<(), Self> {
        let error = Self {
            nrows,
            ncols,
            elsize: std::mem::size_of::<T>(),
        };

        let Some(count) = nrows.checked_mul(ncols) else {
            return Err(error);
        };

        // For zero-sized `T` the byte count is always 0, so only the element count above
        // can fail. Saturation keeps an oversized product above the `isize::MAX` limit.
        let total_bytes = error.elsize.saturating_mul(count);
        if total_bytes > isize::MAX as usize {
            return Err(error);
        }
        Ok(())
    }
}

impl std::fmt::Display for Overflow {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            nrows,
            ncols,
            elsize,
        } = *self;
        match elsize {
            0 => write!(
                f,
                "ZST matrix with dimensions {} x {} has more than `usize::MAX` elements",
                nrows, ncols
            ),
            _ => write!(
                f,
                "a matrix of size {} x {} with element size {} would exceed isize::MAX bytes",
                nrows, ncols, elsize
            ),
        }
    }
}

impl std::error::Error for Overflow {}
393
/// Error returned when a slice cannot back a [`Standard`] matrix view
/// (see `NewRef`/`NewMut` for `Standard`).
#[derive(Debug, Clone, Copy, Error)]
#[non_exhaustive]
pub enum SliceError {
    /// The slice length (`found`) differs from the `nrows * ncols` elements the
    /// representation requires (`expected`).
    #[error("Length mismatch: expected {expected}, found {found}")]
    LengthMismatch { expected: usize, found: usize },
}
401
402unsafe impl<T: Copy> Repr for Standard<T> {
406 type Row<'a>
407 = &'a [T]
408 where
409 T: 'a;
410
411 fn nrows(&self) -> usize {
412 self.nrows
413 }
414
415 fn layout(&self) -> Result<Layout, LayoutError> {
416 Ok(Layout::array::<T>(self.num_elements())?)
417 }
418
419 unsafe fn get_row<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::Row<'a> {
420 debug_assert!(ptr.cast::<T>().is_aligned());
421 debug_assert!(i < self.nrows);
422
423 let row_ptr = unsafe { ptr.as_ptr().cast::<T>().add(i * self.ncols) };
427
428 unsafe { std::slice::from_raw_parts(row_ptr, self.ncols) }
430 }
431}
432
433unsafe impl<T: Copy> ReprMut for Standard<T> {
436 type RowMut<'a>
437 = &'a mut [T]
438 where
439 T: 'a;
440
441 unsafe fn get_row_mut<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::RowMut<'a> {
442 debug_assert!(ptr.cast::<T>().is_aligned());
443 debug_assert!(i < self.nrows);
444
445 let row_ptr = unsafe { ptr.as_ptr().cast::<T>().add(i * self.ncols) };
449
450 unsafe { std::slice::from_raw_parts_mut(row_ptr, self.ncols) }
453 }
454}
455
456unsafe impl<T: Copy> ReprOwned for Standard<T> {
460 unsafe fn drop(self, ptr: NonNull<u8>) {
461 unsafe {
467 let slice_ptr = std::ptr::slice_from_raw_parts_mut(
468 ptr.cast::<T>().as_ptr(),
469 self.nrows * self.ncols,
470 );
471 let _ = Box::from_raw(slice_ptr);
472 }
473 }
474}
475
476unsafe impl<T> NewOwned<T> for Standard<T>
479where
480 T: Copy,
481{
482 type Error = crate::error::Infallible;
483 fn new_owned(self, value: T) -> Result<Mat<Self>, Self::Error> {
484 let b: Box<[T]> = (0..self.num_elements()).map(|_| value).collect();
485
486 Ok(unsafe { self.box_to_mat(b) })
488 }
489}
490
491unsafe impl<T> NewOwned<Defaulted> for Standard<T>
493where
494 T: Copy + Default,
495{
496 type Error = crate::error::Infallible;
497 fn new_owned(self, _: Defaulted) -> Result<Mat<Self>, Self::Error> {
498 self.new_owned(T::default())
499 }
500}
501
502unsafe impl<T> NewRef<T> for Standard<T>
505where
506 T: Copy,
507{
508 type Error = SliceError;
509 fn new_ref(self, data: &[T]) -> Result<MatRef<'_, Self>, Self::Error> {
510 self.check_slice(data)?;
511
512 Ok(unsafe { MatRef::from_raw_parts(self, utils::as_nonnull(data).cast::<u8>()) })
517 }
518}
519
520unsafe impl<T> NewMut<T> for Standard<T>
523where
524 T: Copy,
525{
526 type Error = SliceError;
527 fn new_mut(self, data: &mut [T]) -> Result<MatMut<'_, Self>, Self::Error> {
528 self.check_slice(data)?;
529
530 Ok(unsafe { MatMut::from_raw_parts(self, utils::as_nonnull_mut(data).cast::<u8>()) })
535 }
536}
537
538impl<T> NewCloned for Standard<T>
539where
540 T: Copy,
541{
542 fn new_cloned(v: MatRef<'_, Self>) -> Mat<Self> {
543 let b: Box<[T]> = v.rows().flatten().copied().collect();
544
545 unsafe { v.repr().box_to_mat(b) }
547 }
548}
549
/// An owned matrix: a single allocation whose layout is described by the representation
/// `T`.
///
/// The allocation is released in `Drop` through [`ReprOwned::drop`].
#[derive(Debug)]
pub struct Mat<T: ReprOwned> {
    /// Start of the owned backing allocation.
    ptr: NonNull<u8>,
    /// Describes the shape and layout of the data at `ptr`.
    repr: T,
    /// `fn(T) -> T` makes `Mat` invariant in `T`. NOTE(review): presumably this stops the
    /// representation from being coerced to a subtype that might free the allocation
    /// differently. Confirm the intent.
    _invariant: PhantomData<fn(T) -> T>,
}
564
// SAFETY: `Mat` uniquely owns the allocation at `ptr` (it is freed in `Drop` through
// `ReprOwned::drop`), so sending a `Mat` moves that ownership to the other thread. Whether
// the element data may cross threads is delegated to the representation's own `Send`.
// NOTE(review): this assumes every `ReprOwned` carries its element type, as `Standard<T>`
// does via `PhantomData<T>`.
unsafe impl<T> Send for Mat<T> where T: ReprOwned + Send {}

// SAFETY: through `&Mat`, data is only reachable via read-only accessors (`get_row`,
// `as_view`, `rows`, `as_raw_ptr`); mutation requires `&mut Mat`. NOTE(review): this
// relies on the same assumption about representations as the `Send` impl.
unsafe impl<T> Sync for Mat<T> where T: ReprOwned + Sync {}
570
571impl<T: ReprOwned> Mat<T> {
572 pub fn new<U>(repr: T, init: U) -> Result<Self, <T as NewOwned<U>>::Error>
574 where
575 T: NewOwned<U>,
576 {
577 repr.new_owned(init)
578 }
579
580 #[inline]
582 pub fn num_vectors(&self) -> usize {
583 self.repr.nrows()
584 }
585
586 pub fn repr(&self) -> &T {
588 &self.repr
589 }
590
591 #[must_use]
593 pub fn get_row(&self, i: usize) -> Option<T::Row<'_>> {
594 if i < self.num_vectors() {
595 let row = unsafe { self.get_row_unchecked(i) };
598 Some(row)
599 } else {
600 None
601 }
602 }
603
604 pub(crate) unsafe fn get_row_unchecked(&self, i: usize) -> T::Row<'_> {
605 unsafe { self.repr.get_row(self.ptr, i) }
608 }
609
610 #[must_use]
612 pub fn get_row_mut(&mut self, i: usize) -> Option<T::RowMut<'_>> {
613 if i < self.num_vectors() {
614 Some(unsafe { self.get_row_mut_unchecked(i) })
616 } else {
617 None
618 }
619 }
620
621 pub(crate) unsafe fn get_row_mut_unchecked(&mut self, i: usize) -> T::RowMut<'_> {
622 unsafe { self.repr.get_row_mut(self.ptr, i) }
625 }
626
627 #[inline]
629 pub fn as_view(&self) -> MatRef<'_, T> {
630 MatRef {
631 ptr: self.ptr,
632 repr: self.repr,
633 _lifetime: PhantomData,
634 }
635 }
636
637 #[inline]
639 pub fn as_view_mut(&mut self) -> MatMut<'_, T> {
640 MatMut {
641 ptr: self.ptr,
642 repr: self.repr,
643 _lifetime: PhantomData,
644 }
645 }
646
647 pub fn rows(&self) -> Rows<'_, T> {
649 Rows::new(self.reborrow())
650 }
651
652 pub fn rows_mut(&mut self) -> RowsMut<'_, T> {
654 RowsMut::new(self.reborrow_mut())
655 }
656
657 pub(crate) unsafe fn from_raw_parts(repr: T, ptr: NonNull<u8>) -> Self {
667 Self {
668 ptr,
669 repr,
670 _invariant: PhantomData,
671 }
672 }
673
674 pub fn as_raw_ptr(&self) -> *const u8 {
676 self.ptr.as_ptr()
677 }
678}
679
680impl<T: ReprOwned> Drop for Mat<T> {
681 fn drop(&mut self) {
682 unsafe { self.repr.drop(self.ptr) };
685 }
686}
687
688impl<T: NewCloned> Clone for Mat<T> {
689 fn clone(&self) -> Self {
690 T::new_cloned(self.as_view())
691 }
692}
693
694impl<T: Copy> Mat<Standard<T>> {
695 #[inline]
697 pub fn vector_dim(&self) -> usize {
698 self.repr.ncols()
699 }
700}
701
/// A borrowed, read-only matrix view whose data is borrowed for the lifetime `'a`.
///
/// `MatRef` is `Copy`. It pairs a data pointer with the representation that describes it.
#[derive(Debug, Clone, Copy)]
pub struct MatRef<'a, T: Repr> {
    /// Start of the borrowed data.
    pub(crate) ptr: NonNull<u8>,
    /// Describes the layout of the data at `ptr`.
    pub(crate) repr: T,
    /// Gives `MatRef` the variance of `&'a T`: covariant in both `'a` and `T`. The tests
    /// check this at compile time.
    pub(crate) _lifetime: PhantomData<&'a T>,
}
724
// SAFETY: `MatRef` is a shared, read-only borrow. Thread-safety of the element data is
// delegated to the representation's auto traits. NOTE(review): a plain `&[T]` is only
// `Send` when `T: Sync`, but this impl gates `Send` on the representation being `Send`
// (for `Standard<T>`, `T: Send`). Confirm this is intended. It only differs for `Copy`
// element types that are `Send` but not `Sync`.
unsafe impl<T> Send for MatRef<'_, T> where T: Repr + Send {}

// SAFETY: sharing `&MatRef` only permits further read-only access to the borrowed data.
unsafe impl<T> Sync for MatRef<'_, T> where T: Repr + Sync {}
730
731impl<'a, T: Repr> MatRef<'a, T> {
732 pub fn new<U>(repr: T, data: &'a [U]) -> Result<Self, T::Error>
734 where
735 T: NewRef<U>,
736 {
737 repr.new_ref(data)
738 }
739
740 #[inline]
742 pub fn num_vectors(&self) -> usize {
743 self.repr.nrows()
744 }
745
746 pub fn repr(&self) -> &T {
748 &self.repr
749 }
750
751 #[must_use]
753 pub fn get_row(&self, i: usize) -> Option<T::Row<'_>> {
754 if i < self.num_vectors() {
755 let row = unsafe { self.get_row_unchecked(i) };
758 Some(row)
759 } else {
760 None
761 }
762 }
763
764 #[inline]
770 pub(crate) unsafe fn get_row_unchecked(&self, i: usize) -> T::Row<'_> {
771 unsafe { self.repr.get_row(self.ptr, i) }
773 }
774
775 pub fn rows(&self) -> Rows<'_, T> {
777 Rows::new(*self)
778 }
779
780 pub fn to_owned(&self) -> Mat<T>
782 where
783 T: NewCloned,
784 {
785 T::new_cloned(*self)
786 }
787
788 pub unsafe fn from_raw_parts(repr: T, ptr: NonNull<u8>) -> Self {
796 Self {
797 ptr,
798 repr,
799 _lifetime: PhantomData,
800 }
801 }
802
803 pub fn as_raw_ptr(&self) -> *const u8 {
805 self.ptr.as_ptr()
806 }
807}
808
809impl<'a, T: Copy> MatRef<'a, Standard<T>> {
810 #[inline]
812 pub fn vector_dim(&self) -> usize {
813 self.repr.ncols()
814 }
815}
816
817impl<'this, T: ReprOwned> Reborrow<'this> for Mat<T> {
819 type Target = MatRef<'this, T>;
820
821 fn reborrow(&'this self) -> Self::Target {
822 self.as_view()
823 }
824}
825
826impl<'this, T: ReprOwned> ReborrowMut<'this> for Mat<T> {
828 type Target = MatMut<'this, T>;
829
830 fn reborrow_mut(&'this mut self) -> Self::Target {
831 self.as_view_mut()
832 }
833}
834
835impl<'this, 'a, T: Repr> Reborrow<'this> for MatRef<'a, T> {
837 type Target = MatRef<'this, T>;
838
839 fn reborrow(&'this self) -> Self::Target {
840 MatRef {
841 ptr: self.ptr,
842 repr: self.repr,
843 _lifetime: PhantomData,
844 }
845 }
846}
847
/// A borrowed, mutable matrix view whose data is exclusively borrowed for `'a`.
#[derive(Debug)]
pub struct MatMut<'a, T: ReprMut> {
    /// Start of the exclusively borrowed data.
    pub(crate) ptr: NonNull<u8>,
    /// Describes the layout of the data at `ptr`.
    pub(crate) repr: T,
    /// Gives `MatMut` the variance of `&'a mut T`: covariant in `'a`, invariant in `T`.
    pub(crate) _lifetime: PhantomData<&'a mut T>,
}
871
// SAFETY: `MatMut` is an exclusive borrow, like `&mut [T]`, which is `Send` when `T: Send`.
// The element type's thread-safety is delegated to the representation (for `Standard<T>`,
// via `PhantomData<T>`).
unsafe impl<T> Send for MatMut<'_, T> where T: ReprMut + Send {}

// SAFETY: `&MatMut` only exposes read-only access (`get_row`, `as_view`, `rows`), mirroring
// `&mut [T]: Sync` when `T: Sync`.
unsafe impl<T> Sync for MatMut<'_, T> where T: ReprMut + Sync {}
877
878impl<'a, T: ReprMut> MatMut<'a, T> {
879 pub fn new<U>(repr: T, data: &'a mut [U]) -> Result<Self, T::Error>
881 where
882 T: NewMut<U>,
883 {
884 repr.new_mut(data)
885 }
886
887 #[inline]
889 pub fn num_vectors(&self) -> usize {
890 self.repr.nrows()
891 }
892
893 pub fn repr(&self) -> &T {
895 &self.repr
896 }
897
898 #[inline]
900 #[must_use]
901 pub fn get_row(&self, i: usize) -> Option<T::Row<'_>> {
902 if i < self.num_vectors() {
903 Some(unsafe { self.get_row_unchecked(i) })
905 } else {
906 None
907 }
908 }
909
910 #[inline]
916 pub(crate) unsafe fn get_row_unchecked(&self, i: usize) -> T::Row<'_> {
917 unsafe { self.repr.get_row(self.ptr, i) }
919 }
920
921 #[inline]
923 #[must_use]
924 pub fn get_row_mut(&mut self, i: usize) -> Option<T::RowMut<'_>> {
925 if i < self.num_vectors() {
926 Some(unsafe { self.get_row_mut_unchecked(i) })
928 } else {
929 None
930 }
931 }
932
933 #[inline]
939 pub(crate) unsafe fn get_row_mut_unchecked(&mut self, i: usize) -> T::RowMut<'_> {
940 unsafe { self.repr.get_row_mut(self.ptr, i) }
943 }
944
945 pub fn as_view(&self) -> MatRef<'_, T> {
947 MatRef {
948 ptr: self.ptr,
949 repr: self.repr,
950 _lifetime: PhantomData,
951 }
952 }
953
954 pub fn rows(&self) -> Rows<'_, T> {
956 Rows::new(self.reborrow())
957 }
958
959 pub fn rows_mut(&mut self) -> RowsMut<'_, T> {
961 RowsMut::new(self.reborrow_mut())
962 }
963
964 pub fn to_owned(&self) -> Mat<T>
966 where
967 T: NewCloned,
968 {
969 T::new_cloned(self.as_view())
970 }
971
972 pub unsafe fn from_raw_parts(repr: T, ptr: NonNull<u8>) -> Self {
979 Self {
980 ptr,
981 repr,
982 _lifetime: PhantomData,
983 }
984 }
985
986 pub fn as_raw_ptr(&self) -> *const u8 {
988 self.ptr.as_ptr()
989 }
990}
991
992impl<'this, 'a, T: ReprMut> Reborrow<'this> for MatMut<'a, T> {
994 type Target = MatRef<'this, T>;
995
996 fn reborrow(&'this self) -> Self::Target {
997 self.as_view()
998 }
999}
1000
1001impl<'this, 'a, T: ReprMut> ReborrowMut<'this> for MatMut<'a, T> {
1003 type Target = MatMut<'this, T>;
1004
1005 fn reborrow_mut(&'this mut self) -> Self::Target {
1006 MatMut {
1007 ptr: self.ptr,
1008 repr: self.repr,
1009 _lifetime: PhantomData,
1010 }
1011 }
1012}
1013
1014impl<'a, T: Copy> MatMut<'a, Standard<T>> {
1015 #[inline]
1017 pub fn vector_dim(&self) -> usize {
1018 self.repr.ncols()
1019 }
1020}
1021
/// Iterator over the rows of a [`MatRef`], from the first row to the last.
#[derive(Debug)]
pub struct Rows<'a, T: Repr> {
    /// The matrix being iterated.
    matrix: MatRef<'a, T>,
    /// Index of the next row to yield. It never exceeds `matrix.num_vectors()`.
    current: usize,
}
1034
1035impl<'a, T> Rows<'a, T>
1036where
1037 T: Repr,
1038{
1039 fn new(matrix: MatRef<'a, T>) -> Self {
1040 Self { matrix, current: 0 }
1041 }
1042}
1043
1044impl<'a, T> Iterator for Rows<'a, T>
1045where
1046 T: Repr + 'a,
1047{
1048 type Item = T::Row<'a>;
1049
1050 fn next(&mut self) -> Option<Self::Item> {
1051 let current = self.current;
1052 if current >= self.matrix.num_vectors() {
1053 None
1054 } else {
1055 self.current += 1;
1056 Some(unsafe { self.matrix.repr.get_row(self.matrix.ptr, current) })
1062 }
1063 }
1064
1065 fn size_hint(&self) -> (usize, Option<usize>) {
1066 let remaining = self.matrix.num_vectors() - self.current;
1067 (remaining, Some(remaining))
1068 }
1069}
1070
// `size_hint` reports the exact number of rows remaining.
impl<'a, T> ExactSizeIterator for Rows<'a, T> where T: Repr + 'a {}
// Once `current` reaches `num_vectors()` it is never incremented again, so `next` keeps
// returning `None`.
impl<'a, T> FusedIterator for Rows<'a, T> where T: Repr + 'a {}
1073
/// Iterator over mutable views of the rows of a [`MatMut`], from the first row to the last.
///
/// Each row is yielded at most once, so the views it hands out never refer to the same
/// row.
#[derive(Debug)]
pub struct RowsMut<'a, T: ReprMut> {
    /// The matrix being iterated.
    matrix: MatMut<'a, T>,
    /// Index of the next row to yield. It never exceeds `matrix.num_vectors()`.
    current: usize,
}
1086
1087impl<'a, T> RowsMut<'a, T>
1088where
1089 T: ReprMut,
1090{
1091 fn new(matrix: MatMut<'a, T>) -> Self {
1092 Self { matrix, current: 0 }
1093 }
1094}
1095
1096impl<'a, T> Iterator for RowsMut<'a, T>
1097where
1098 T: ReprMut + 'a,
1099{
1100 type Item = T::RowMut<'a>;
1101
1102 fn next(&mut self) -> Option<Self::Item> {
1103 let current = self.current;
1104 if current >= self.matrix.num_vectors() {
1105 None
1106 } else {
1107 self.current += 1;
1108 Some(unsafe { self.matrix.repr.get_row_mut(self.matrix.ptr, current) })
1117 }
1118 }
1119
1120 fn size_hint(&self) -> (usize, Option<usize>) {
1121 let remaining = self.matrix.num_vectors() - self.current;
1122 (remaining, Some(remaining))
1123 }
1124}
1125
// `size_hint` reports the exact number of rows remaining.
impl<'a, T> ExactSizeIterator for RowsMut<'a, T> where T: ReprMut + 'a {}
// Once exhausted, `current` stays at `num_vectors()`, so `next` keeps returning `None`.
impl<'a, T> FusedIterator for RowsMut<'a, T> where T: ReprMut + 'a {}
1128
#[cfg(test)]
mod tests {
    use super::*;

    use std::fmt::Display;

    use diskann_utils::lazy_format;

    /// Compile-time assertion that the referenced value's type is `Copy`.
    fn assert_copy<T: Copy>(_: &T) {}

    // The `_assert_*` functions below are never called; they only need to compile. Each
    // returns its argument under a different type, which type-checks only if the named
    // variance holds.

    /// `MatRef` must be covariant in its lifetime.
    fn _assert_matref_covariant_lifetime<'long: 'short, 'short, T: Repr>(
        v: MatRef<'long, T>,
    ) -> MatRef<'short, T> {
        v
    }

    /// `MatRef` must be covariant in its representation type.
    fn _assert_matref_covariant_repr<'long: 'short, 'short, 'a>(
        v: MatRef<'a, Standard<&'long u8>>,
    ) -> MatRef<'a, Standard<&'short u8>> {
        v
    }

    /// `MatMut` must be covariant in its lifetime.
    fn _assert_matmut_covariant_lifetime<'long: 'short, 'short, T: ReprMut>(
        v: MatMut<'long, T>,
    ) -> MatMut<'short, T> {
        v
    }

    /// Row indices that are out of bounds for a matrix with `nrows` rows. Values near
    /// `usize::MAX` probe for overflow in the bounds checks.
    fn edge_cases(nrows: usize) -> Vec<usize> {
        let max = usize::MAX;

        vec![
            nrows,
            nrows + 1,
            nrows + 11,
            nrows + 20,
            max / 2,
            max.div_ceil(2),
            max - 1,
            max,
        ]
    }

    /// Fill `x` through `Mat::get_row_mut` so that entry `(i, j)` becomes `10 * i + j`,
    /// then check that out-of-bounds rows are rejected.
    fn fill_mat(x: &mut Mat<Standard<usize>>, repr: Standard<usize>) {
        assert_eq!(x.repr(), &repr);
        assert_eq!(x.num_vectors(), repr.nrows());
        assert_eq!(x.vector_dim(), repr.ncols());

        for i in 0..x.num_vectors() {
            let row = x.get_row_mut(i).unwrap();
            assert_eq!(row.len(), repr.ncols());
            row.iter_mut()
                .enumerate()
                .for_each(|(j, r)| *r = 10 * i + j);
        }

        for i in edge_cases(repr.nrows()).into_iter() {
            assert!(x.get_row_mut(i).is_none());
        }
    }

    /// Same fill pattern as `fill_mat`, applied through a `MatMut` view.
    fn fill_mat_mut(mut x: MatMut<'_, Standard<usize>>, repr: Standard<usize>) {
        assert_eq!(x.repr(), &repr);
        assert_eq!(x.num_vectors(), repr.nrows());
        assert_eq!(x.vector_dim(), repr.ncols());

        for i in 0..x.num_vectors() {
            let row = x.get_row_mut(i).unwrap();
            assert_eq!(row.len(), repr.ncols());

            row.iter_mut()
                .enumerate()
                .for_each(|(j, r)| *r = 10 * i + j);
        }

        for i in edge_cases(repr.nrows()).into_iter() {
            assert!(x.get_row_mut(i).is_none());
        }
    }

    /// Same fill pattern, applied through a `RowsMut` iterator. All rows are collected
    /// first, which relies on the mutable row views not overlapping.
    fn fill_rows_mut(x: RowsMut<'_, Standard<usize>>, repr: Standard<usize>) {
        assert_eq!(x.len(), repr.nrows());
        let mut all_rows: Vec<_> = x.collect();
        assert_eq!(all_rows.len(), repr.nrows());
        for (i, row) in all_rows.iter_mut().enumerate() {
            assert_eq!(row.len(), repr.ncols());
            row.iter_mut()
                .enumerate()
                .for_each(|(j, r)| *r = 10 * i + j);
        }
    }

    /// Verify that an owned `Mat` holds the `10 * i + j` pattern and rejects out-of-bounds
    /// rows.
    fn check_mat(x: &Mat<Standard<usize>>, repr: Standard<usize>, ctx: &dyn Display) {
        assert_eq!(x.repr(), &repr);
        assert_eq!(x.num_vectors(), repr.nrows());
        assert_eq!(x.vector_dim(), repr.ncols());

        for i in 0..x.num_vectors() {
            let row = x.get_row(i).unwrap();

            assert_eq!(row.len(), repr.ncols(), "ctx: {ctx}");
            row.iter().enumerate().for_each(|(j, r)| {
                assert_eq!(
                    *r,
                    10 * i + j,
                    "mismatched entry at row {}, col {} -- ctx: {}",
                    i,
                    j,
                    ctx
                )
            });
        }

        for i in edge_cases(repr.nrows()).into_iter() {
            assert!(x.get_row(i).is_none(), "ctx: {ctx}");
        }
    }

    /// Same checks as `check_mat` through a `MatRef`, and additionally asserts that
    /// `MatRef` is `Copy`.
    fn check_mat_ref(x: MatRef<'_, Standard<usize>>, repr: Standard<usize>, ctx: &dyn Display) {
        assert_eq!(x.repr(), &repr);
        assert_eq!(x.num_vectors(), repr.nrows());
        assert_eq!(x.vector_dim(), repr.ncols());

        assert_copy(&x);
        for i in 0..x.num_vectors() {
            let row = x.get_row(i).unwrap();
            assert_eq!(row.len(), repr.ncols(), "ctx: {ctx}");

            row.iter().enumerate().for_each(|(j, r)| {
                assert_eq!(
                    *r,
                    10 * i + j,
                    "mismatched entry at row {}, col {} -- ctx: {}",
                    i,
                    j,
                    ctx
                )
            });
        }

        for i in edge_cases(repr.nrows()).into_iter() {
            assert!(x.get_row(i).is_none(), "ctx: {ctx}");
        }
    }

    /// Same checks as `check_mat`, using the read-only accessors of a `MatMut`.
    fn check_mat_mut(x: MatMut<'_, Standard<usize>>, repr: Standard<usize>, ctx: &dyn Display) {
        assert_eq!(x.repr(), &repr);
        assert_eq!(x.num_vectors(), repr.nrows());
        assert_eq!(x.vector_dim(), repr.ncols());

        for i in 0..x.num_vectors() {
            let row = x.get_row(i).unwrap();
            assert_eq!(row.len(), repr.ncols(), "ctx: {ctx}");

            row.iter().enumerate().for_each(|(j, r)| {
                assert_eq!(
                    *r,
                    10 * i + j,
                    "mismatched entry at row {}, col {} -- ctx: {}",
                    i,
                    j,
                    ctx
                )
            });
        }

        for i in edge_cases(repr.nrows()).into_iter() {
            assert!(x.get_row(i).is_none(), "ctx: {ctx}");
        }
    }

    /// Verify the `10 * i + j` pattern by draining a `Rows` iterator, and check that its
    /// reported length is exact.
    fn check_rows(x: Rows<'_, Standard<usize>>, repr: Standard<usize>, ctx: &dyn Display) {
        assert_eq!(x.len(), repr.nrows(), "ctx: {ctx}");
        let all_rows: Vec<_> = x.collect();
        assert_eq!(all_rows.len(), repr.nrows(), "ctx: {ctx}");
        for (i, row) in all_rows.iter().enumerate() {
            assert_eq!(row.len(), repr.ncols(), "ctx: {ctx}");
            row.iter().enumerate().for_each(|(j, r)| {
                assert_eq!(
                    *r,
                    10 * i + j,
                    "mismatched entry at row {}, col {} -- ctx: {}",
                    i,
                    j,
                    ctx
                )
            });
        }
    }

    /// Dimensions and layout of a simple `Standard<f32>`.
    #[test]
    fn standard_representation() {
        let repr = Standard::<f32>::new(4, 3).unwrap();
        assert_eq!(repr.nrows(), 4);
        assert_eq!(repr.ncols(), 3);

        let layout = repr.layout().unwrap();
        assert_eq!(layout.size(), 4 * 3 * std::mem::size_of::<f32>());
        assert_eq!(layout.align(), std::mem::align_of::<f32>());
    }

    /// Degenerate dimensions produce a zero-sized layout.
    #[test]
    fn standard_zero_dimensions() {
        for (nrows, ncols) in [(0, 0), (0, 5), (5, 0)] {
            let repr = Standard::<u8>::new(nrows, ncols).unwrap();
            assert_eq!(repr.nrows(), nrows);
            assert_eq!(repr.ncols(), ncols);
            let layout = repr.layout().unwrap();
            assert_eq!(layout.size(), 0);
        }
    }

    /// `check_slice` accepts only slices of exactly `nrows * ncols` elements.
    #[test]
    fn standard_check_slice() {
        let repr = Standard::<u32>::new(3, 4).unwrap();

        // Exact length.
        let data = vec![0u32; 12];
        assert!(repr.check_slice(&data).is_ok());

        // One element too few.
        let short = vec![0u32; 11];
        assert!(matches!(
            repr.check_slice(&short),
            Err(SliceError::LengthMismatch {
                expected: 12,
                found: 11
            })
        ));

        // One element too many.
        let long = vec![0u32; 13];
        assert!(matches!(
            repr.check_slice(&long),
            Err(SliceError::LengthMismatch {
                expected: 12,
                found: 13
            })
        ));

        // Dimensions whose product overflows are rejected at construction.
        let overflow_repr = Standard::<u8>::new(usize::MAX, 2).unwrap_err();
        assert!(matches!(overflow_repr, Overflow { .. }));
    }

    /// `nrows * ncols` overflowing `usize` is rejected in either argument order.
    #[test]
    fn standard_new_rejects_element_count_overflow() {
        assert!(Standard::<u8>::new(usize::MAX, 2).is_err());
        assert!(Standard::<u8>::new(2, usize::MAX).is_err());
        assert!(Standard::<u8>::new(usize::MAX, usize::MAX).is_err());
    }

    /// A byte size just above `isize::MAX` is rejected.
    #[test]
    fn standard_new_rejects_byte_count_exceeding_isize_max() {
        let half = (isize::MAX as usize / std::mem::size_of::<u64>()) + 1;
        assert!(Standard::<u64>::new(half, 1).is_err());
        assert!(Standard::<u64>::new(1, half).is_err());
    }

    /// The largest element count that fits in `isize::MAX` bytes is accepted.
    #[test]
    fn standard_new_accepts_boundary_below_isize_max() {
        let max_elems = isize::MAX as usize / std::mem::size_of::<u64>();
        let repr = Standard::<u64>::new(max_elems, 1).unwrap();
        assert_eq!(repr.num_elements(), max_elems);
    }

    /// For zero-sized types, only element-count overflow is an error.
    #[test]
    fn standard_new_zst_rejects_element_count_overflow() {
        assert!(Standard::<()>::new(usize::MAX, 2).is_err());
        assert!(Standard::<()>::new(usize::MAX / 2 + 1, 3).is_err());
    }

    /// Zero-sized types allow any element count up to `usize::MAX`.
    #[test]
    fn standard_new_zst_accepts_large_non_overflowing() {
        let repr = Standard::<()>::new(usize::MAX, 1).unwrap();
        assert_eq!(repr.num_elements(), usize::MAX);
        assert_eq!(repr.layout().unwrap().size(), 0);
    }

    /// Each overflow flavour produces its own `Display` message.
    #[test]
    fn standard_new_overflow_error_display() {
        let err = Standard::<u32>::new(usize::MAX, 2).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("would exceed isize::MAX bytes"), "{msg}");

        let zst_err = Standard::<()>::new(usize::MAX, 2).unwrap_err();
        let zst_msg = zst_err.to_string();
        assert!(zst_msg.contains("ZST matrix"), "{zst_msg}");
        assert!(zst_msg.contains("usize::MAX"), "{zst_msg}");
    }

    /// `Mat::new` with a fill value. Rows are contiguous and start `ncols` elements apart.
    #[test]
    fn mat_new_and_basic_accessors() {
        let mat = Mat::new(Standard::<usize>::new(3, 4).unwrap(), 42usize).unwrap();
        let base: *const u8 = mat.as_raw_ptr();

        assert_eq!(mat.num_vectors(), 3);
        assert_eq!(mat.vector_dim(), 4);

        let repr = mat.repr();
        assert_eq!(repr.nrows(), 3);
        assert_eq!(repr.ncols(), 4);

        for (i, r) in mat.rows().enumerate() {
            assert_eq!(r, &[42, 42, 42, 42]);
            let ptr = r.as_ptr().cast::<u8>();
            assert_eq!(
                ptr,
                base.wrapping_add(std::mem::size_of::<usize>() * mat.repr().ncols() * i),
            );
        }
    }

    /// `Mat::new` with `Defaulted` fills with `T::default()`.
    #[test]
    fn mat_new_with_default() {
        let mat = Mat::new(Standard::<usize>::new(2, 3).unwrap(), Defaulted).unwrap();
        let base: *const u8 = mat.as_raw_ptr();

        assert_eq!(mat.num_vectors(), 2);
        for (i, row) in mat.rows().enumerate() {
            assert!(row.iter().all(|&v| v == 0));

            let ptr = row.as_ptr().cast::<u8>();
            assert_eq!(
                ptr,
                base.wrapping_add(std::mem::size_of::<usize>() * mat.repr().ncols() * i),
            );
        }
    }

    // Dimensions swept by the table-driven tests below (including empty matrices).
    const ROWS: &[usize] = &[0, 1, 2, 3, 5, 10];
    const COLS: &[usize] = &[0, 1, 2, 3, 5, 10];

    /// Fill a `Mat` in three ways (direct, through `MatMut`, through `RowsMut`) and verify
    /// the result through every read path.
    #[test]
    fn test_mat() {
        for nrows in ROWS {
            for ncols in COLS {
                let repr = Standard::<usize>::new(*nrows, *ncols).unwrap();
                let ctx = &lazy_format!("nrows = {}, ncols = {}", nrows, ncols);

                {
                    // Fill through `Mat::get_row_mut`.
                    let ctx = &lazy_format!("{ctx} - direct");
                    let mut mat = Mat::new(repr, Defaulted).unwrap();

                    assert_eq!(mat.num_vectors(), *nrows);
                    assert_eq!(mat.vector_dim(), *ncols);

                    fill_mat(&mut mat, repr);

                    check_mat(&mat, repr, ctx);
                    check_mat_ref(mat.reborrow(), repr, ctx);
                    check_mat_mut(mat.reborrow_mut(), repr, ctx);
                    check_rows(mat.rows(), repr, ctx);

                    // Views must alias the owned allocation.
                    assert_eq!(mat.as_raw_ptr(), mat.reborrow().as_raw_ptr());
                    assert_eq!(mat.as_raw_ptr(), mat.reborrow_mut().as_raw_ptr());
                }

                {
                    // Fill through a `MatMut` view.
                    let ctx = &lazy_format!("{ctx} - matmut");
                    let mut mat = Mat::new(repr, Defaulted).unwrap();
                    let matmut = mat.reborrow_mut();

                    assert_eq!(matmut.num_vectors(), *nrows);
                    assert_eq!(matmut.vector_dim(), *ncols);

                    fill_mat_mut(matmut, repr);

                    check_mat(&mat, repr, ctx);
                    check_mat_ref(mat.reborrow(), repr, ctx);
                    check_mat_mut(mat.reborrow_mut(), repr, ctx);
                    check_rows(mat.rows(), repr, ctx);
                }

                {
                    // Fill through the `RowsMut` iterator.
                    let ctx = &lazy_format!("{ctx} - rows_mut");
                    let mut mat = Mat::new(repr, Defaulted).unwrap();
                    fill_rows_mut(mat.rows_mut(), repr);

                    check_mat(&mat, repr, ctx);
                    check_mat_ref(mat.reborrow(), repr, ctx);
                    check_mat_mut(mat.reborrow_mut(), repr, ctx);
                    check_rows(mat.rows(), repr, ctx);
                }
            }
        }
    }

    /// Cloning a `Mat` and calling `to_owned` on its views produce independent copies with
    /// identical contents.
    #[test]
    fn test_mat_clone() {
        for nrows in ROWS {
            for ncols in COLS {
                let repr = Standard::<usize>::new(*nrows, *ncols).unwrap();
                let ctx = &lazy_format!("nrows = {}, ncols = {}", nrows, ncols);

                let mut mat = Mat::new(repr, Defaulted).unwrap();
                fill_mat(&mut mat, repr);

                {
                    // `Mat::clone`.
                    let ctx = &lazy_format!("{ctx} - Mat::clone");
                    let cloned = mat.clone();

                    assert_eq!(cloned.num_vectors(), *nrows);
                    assert_eq!(cloned.vector_dim(), *ncols);

                    check_mat(&cloned, repr, ctx);
                    check_mat_ref(cloned.reborrow(), repr, ctx);
                    check_rows(cloned.rows(), repr, ctx);

                    // Empty allocations may share a dangling pointer, so only compare
                    // addresses when there is real data.
                    if repr.num_elements() > 0 {
                        assert_ne!(mat.as_raw_ptr(), cloned.as_raw_ptr());
                    }
                }

                {
                    // `MatRef::to_owned`.
                    let ctx = &lazy_format!("{ctx} - MatRef::to_owned");
                    let owned = mat.as_view().to_owned();

                    check_mat(&owned, repr, ctx);
                    check_mat_ref(owned.reborrow(), repr, ctx);
                    check_rows(owned.rows(), repr, ctx);

                    if repr.num_elements() > 0 {
                        assert_ne!(mat.as_raw_ptr(), owned.as_raw_ptr());
                    }
                }

                {
                    // `MatMut::to_owned`.
                    let ctx = &lazy_format!("{ctx} - MatMut::to_owned");
                    let owned = mat.as_view_mut().to_owned();

                    check_mat(&owned, repr, ctx);
                    check_mat_ref(owned.reborrow(), repr, ctx);
                    check_rows(owned.rows(), repr, ctx);

                    if repr.num_elements() > 0 {
                        assert_ne!(mat.as_raw_ptr(), owned.as_raw_ptr());
                    }
                }
            }
        }
    }

    /// `MatMut::new` and `MatRef::new` over caller-owned buffers preserve the buffer
    /// address, and writes through the views are visible afterwards.
    #[test]
    fn test_mat_refmut() {
        for nrows in ROWS {
            for ncols in COLS {
                let repr = Standard::<usize>::new(*nrows, *ncols).unwrap();
                let ctx = &lazy_format!("nrows = {}, ncols = {}", nrows, ncols);

                {
                    // Fill through `MatMut::get_row_mut`.
                    let ctx = &lazy_format!("{ctx} - by matmut");
                    let mut b: Box<[_]> = (0..repr.num_elements()).map(|_| 0usize).collect();
                    let ptr = b.as_ptr().cast::<u8>();
                    let mut matmut = MatMut::new(repr, &mut b).unwrap();

                    assert_eq!(
                        ptr,
                        matmut.as_raw_ptr(),
                        "underlying memory should be preserved",
                    );

                    fill_mat_mut(matmut.reborrow_mut(), repr);

                    check_mat_mut(matmut.reborrow_mut(), repr, ctx);
                    check_mat_ref(matmut.reborrow(), repr, ctx);
                    check_rows(matmut.rows(), repr, ctx);
                    check_rows(matmut.reborrow().rows(), repr, ctx);

                    let matref = MatRef::new(repr, &b).unwrap();
                    check_mat_ref(matref, repr, ctx);
                    check_mat_ref(matref.reborrow(), repr, ctx);
                    check_rows(matref.rows(), repr, ctx);
                }

                {
                    // Fill through `MatMut::rows_mut`.
                    let ctx = &lazy_format!("{ctx} - by rows");
                    let mut b: Box<[_]> = (0..repr.num_elements()).map(|_| 0usize).collect();
                    let ptr = b.as_ptr().cast::<u8>();
                    let mut matmut = MatMut::new(repr, &mut b).unwrap();

                    assert_eq!(
                        ptr,
                        matmut.as_raw_ptr(),
                        "underlying memory should be preserved",
                    );

                    fill_rows_mut(matmut.rows_mut(), repr);

                    check_mat_mut(matmut.reborrow_mut(), repr, ctx);
                    check_mat_ref(matmut.reborrow(), repr, ctx);
                    check_rows(matmut.rows(), repr, ctx);
                    check_rows(matmut.reborrow().rows(), repr, ctx);

                    let matref = MatRef::new(repr, &b).unwrap();
                    check_mat_ref(matref, repr, ctx);
                    check_mat_ref(matref.reborrow(), repr, ctx);
                    check_rows(matref.rows(), repr, ctx);
                }
            }
        }
    }

    /// `NewOwned<T>` fills every element with the given value. The rows iterator reports
    /// an exact length.
    #[test]
    fn test_standard_new_owned() {
        let rows = [0, 1, 2, 3, 5, 10];
        let cols = [0, 1, 2, 3, 5, 10];

        for nrows in rows {
            for ncols in cols {
                let m = Mat::new(Standard::new(nrows, ncols).unwrap(), 1usize).unwrap();
                let rows_iter = m.rows();
                let len = <_ as ExactSizeIterator>::len(&rows_iter);
                assert_eq!(len, nrows);
                for r in rows_iter {
                    assert_eq!(r.len(), ncols);
                    assert!(r.iter().all(|i| *i == 1usize));
                }
            }
        }
    }

    /// `MatRef::new` surfaces `SliceError::LengthMismatch` for wrongly sized slices.
    #[test]
    fn matref_new_slice_length_error() {
        let repr = Standard::<u32>::new(3, 4).unwrap();

        // Exact length.
        let data = vec![0u32; 12];
        assert!(MatRef::new(repr, &data).is_ok());

        // Too short.
        let short = vec![0u32; 11];
        assert!(matches!(
            MatRef::new(repr, &short),
            Err(SliceError::LengthMismatch {
                expected: 12,
                found: 11
            })
        ));

        // Too long.
        let long = vec![0u32; 13];
        assert!(matches!(
            MatRef::new(repr, &long),
            Err(SliceError::LengthMismatch {
                expected: 12,
                found: 13
            })
        ));
    }

    /// `MatMut::new` surfaces `SliceError::LengthMismatch` for wrongly sized slices.
    #[test]
    fn matmut_new_slice_length_error() {
        let repr = Standard::<u32>::new(3, 4).unwrap();

        // Exact length.
        let mut data = vec![0u32; 12];
        assert!(MatMut::new(repr, &mut data).is_ok());

        // Too short.
        let mut short = vec![0u32; 11];
        assert!(matches!(
            MatMut::new(repr, &mut short),
            Err(SliceError::LengthMismatch {
                expected: 12,
                found: 11
            })
        ));

        // Too long.
        let mut long = vec![0u32; 13];
        assert!(matches!(
            MatMut::new(repr, &mut long),
            Err(SliceError::LengthMismatch {
                expected: 12,
                found: 13
            })
        ));
    }
}