1use std::{alloc::Layout, iter::FusedIterator, marker::PhantomData, ptr::NonNull};
29
30use diskann_utils::{Reborrow, ReborrowMut};
31use thiserror::Error;
32
33use crate::utils;
34
/// Describes the shape of a matrix stored in an untyped byte buffer (`NonNull<u8>`) and
/// how to produce a view of a single row from that buffer.
///
/// The containers in this module ([`Mat`], [`MatRef`], [`MatMut`]) store only a pointer
/// and a value of a `Repr` type, and delegate all row access to it.
///
/// # Safety
///
/// The containers in this module call [`Repr::get_row`] on their stored pointer for any
/// `i < nrows()` without further validation, so implementations are trusted to return
/// views that are valid for such indices.
/// NOTE(review): the complete implementor contract (for example, that every row lies
/// inside the region described by [`Repr::layout`]) is not visible in this chunk —
/// confirm against the original documentation.
pub unsafe trait Repr: Copy {
    /// The read-only view type for one row (for [`Standard<T>`], `&'a [T]`).
    type Row<'a>
    where
        Self: 'a;

    /// Returns the number of rows described by this representation.
    fn nrows(&self) -> usize;

    /// Returns the memory [`Layout`] of the buffer this representation describes.
    ///
    /// For [`Standard<T>`] this is an array of `nrows * ncols` elements of `T`.
    ///
    /// # Errors
    ///
    /// Returns [`LayoutError`] if no valid `Layout` can be formed.
    fn layout(&self) -> Result<Layout, LayoutError>;

    /// Returns a view of row `i` of the buffer at `ptr`.
    ///
    /// # Safety
    ///
    /// * `i` must be less than `self.nrows()`.
    /// * `ptr` must point to a buffer laid out as described by this representation (the
    ///   `Standard` implementation debug-asserts that it is aligned for its element type).
    /// * The lifetime `'a` is chosen by the caller, who must keep the buffer valid — and,
    ///   presumably, free of conflicting mutable access — for all of `'a`.
    unsafe fn get_row<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::Row<'a>;
}
93
/// Extends [`Repr`] with mutable row access.
///
/// # Safety
///
/// [`MatMut`] and [`RowsMut`] call [`ReprMut::get_row_mut`] for `i < nrows()` without
/// further validation. `RowsMut` hands out the mutable rows for distinct indices at the same
/// time, so rows for different indices must not alias (true for [`Standard`], whose rows are
/// disjoint `ncols`-element ranges).
/// NOTE(review): confirm the full contract against the original documentation.
pub unsafe trait ReprMut: Repr {
    /// The mutable view type for one row (for [`Standard<T>`], `&'a mut [T]`).
    type RowMut<'a>
    where
        Self: 'a;

    /// Returns a mutable view of row `i` of the buffer at `ptr`.
    ///
    /// # Safety
    ///
    /// Same requirements as [`Repr::get_row`]. In addition, the caller must ensure that
    /// no other live reference overlaps the returned row for the caller-chosen lifetime
    /// `'a`.
    unsafe fn get_row_mut<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::RowMut<'a>;
}
129
/// Extends [`ReprMut`] with the ability to release a buffer owned by a [`Mat`].
///
/// # Safety
///
/// [`Mat`]'s `Drop` implementation calls [`ReprOwned::drop`] exactly once on its stored
/// pointer. Implementations must therefore correctly release whatever allocation their
/// constructors (for `Standard`, `box_to_mat`) handed to `Mat`.
pub unsafe trait ReprOwned: ReprMut {
    /// Releases the buffer at `ptr`.
    ///
    /// # Safety
    ///
    /// `ptr` must be an owned buffer created for this exact representation (for
    /// [`Standard`], a `Box<[T]>` of `nrows * ncols` elements), and it must not be used
    /// again after this call.
    unsafe fn drop(self, ptr: NonNull<u8>);
}
149
/// Error returned by [`Repr::layout`] when a representation's memory requirements cannot
/// be expressed as a valid [`Layout`].
///
/// The error is intentionally opaque and carries no payload. It converts from
/// [`std::alloc::LayoutError`], so `?` can be used directly on `Layout` constructors.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub struct LayoutError;

impl LayoutError {
    /// Creates a new [`LayoutError`].
    ///
    /// This is a `const fn`, so the error can also be built in constant contexts.
    pub const fn new() -> Self {
        Self
    }
}

impl Default for LayoutError {
    fn default() -> Self {
        Self::new()
    }
}

impl std::fmt::Display for LayoutError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("LayoutError")
    }
}

impl std::error::Error for LayoutError {}

impl From<std::alloc::LayoutError> for LayoutError {
    /// Keeps only the fact that layout construction failed. The standard-library error
    /// carries no further detail.
    fn from(_: std::alloc::LayoutError) -> Self {
        Self::new()
    }
}
185
/// Constructs a read-only [`MatRef`] view over a borrowed slice of `T`.
///
/// `MatRef::new` forwards to this trait.
///
/// # Safety
///
/// The returned view is trusted by [`MatRef`]'s safe accessors. Implementations must
/// only succeed when `slice` actually satisfies this representation's layout. For
/// [`Standard`], this means `slice.len() == nrows * ncols`.
pub unsafe trait NewRef<T>: Repr {
    /// Error returned when `slice` is incompatible with the representation.
    type Error;

    /// Wraps `slice` as a matrix view borrowing it for the returned lifetime.
    fn new_ref(self, slice: &[T]) -> Result<MatRef<'_, Self>, Self::Error>;
}
203
/// Constructs a mutable [`MatMut`] view over a borrowed slice of `T`.
///
/// `MatMut::new` forwards to this trait.
///
/// # Safety
///
/// Same obligations as [`NewRef`]: only succeed when `slice` satisfies this
/// representation's layout, because [`MatMut`]'s safe accessors rely on it.
pub unsafe trait NewMut<T>: ReprMut {
    /// Error returned when `slice` is incompatible with the representation.
    type Error;

    /// Wraps `slice` as a mutable matrix view borrowing it for the returned lifetime.
    fn new_mut(self, slice: &mut [T]) -> Result<MatMut<'_, Self>, Self::Error>;
}
217
/// Constructs an owned [`Mat`] from an initializer of type `T`.
///
/// `Mat::new` forwards to this trait. For [`Standard<E>`], `T` is either a fill value of
/// type `E` or the [`Defaulted`] marker.
///
/// # Safety
///
/// The returned `Mat` will eventually be released with [`ReprOwned::drop`].
/// Implementations must hand it a buffer that this representation's `drop` can free.
pub unsafe trait NewOwned<T>: ReprOwned {
    /// Error returned when construction fails.
    type Error;

    /// Allocates and initializes a new matrix from `init`.
    fn new_owned(self, init: T) -> Result<Mat<Self>, Self::Error>;
}
231
/// Marker initializer asking [`Mat::new`] to default-initialize every element.
///
/// For [`Standard<T>`] (with `T: Default`), every element is set to `T::default()`.
#[derive(Debug, Clone, Copy)]
pub struct Defaulted;
244
/// Representations that can deep-copy a matrix view into a new owned [`Mat`].
///
/// `Mat::clone`, `MatRef::to_owned`, and `MatMut::to_owned` are built on this trait.
pub trait NewCloned: ReprOwned {
    /// Returns a freshly allocated [`Mat`] holding a copy of the contents of `v`.
    fn new_cloned(v: MatRef<'_, Self>) -> Mat<Self>;
}
252
/// Dense, row-major representation: `nrows` rows, each a contiguous run of `ncols`
/// elements of `T`, with row `i` starting at element `i * ncols`.
///
/// Instances are created with [`Standard::new`], which rejects dimensions whose element
/// count overflows `usize` or whose byte size exceeds `isize::MAX`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Standard<T> {
    // Number of rows.
    nrows: usize,
    // Number of elements in each row.
    ncols: usize,
    // Records the element type without storing any `T`.
    _elem: PhantomData<T>,
}
272
273impl<T: Copy> Standard<T> {
274 pub fn new(nrows: usize, ncols: usize) -> Result<Self, Overflow> {
283 Overflow::check::<T>(nrows, ncols)?;
284 Ok(Self {
285 nrows,
286 ncols,
287 _elem: PhantomData,
288 })
289 }
290
291 pub fn num_elements(&self) -> usize {
293 self.nrows() * self.ncols()
295 }
296
297 fn nrows(&self) -> usize {
299 self.nrows
300 }
301
302 fn ncols(&self) -> usize {
304 self.ncols
305 }
306
307 fn check_slice(&self, slice: &[T]) -> Result<(), SliceError> {
312 let len = self.num_elements();
313
314 if slice.len() != len {
315 Err(SliceError::LengthMismatch {
316 expected: len,
317 found: slice.len(),
318 })
319 } else {
320 Ok(())
321 }
322 }
323
324 unsafe fn box_to_mat(self, b: Box<[T]>) -> Mat<Self> {
330 debug_assert_eq!(b.len(), self.num_elements(), "safety contract violated");
331
332 let ptr = unsafe { NonNull::new_unchecked(Box::into_raw(b)) }.cast::<u8>();
335
336 unsafe { Mat::from_raw_parts(self, ptr) }
340 }
341}
342
/// Error returned by [`Standard::new`] when the requested dimensions are too large.
///
/// Stores the requested dimensions and element size so the message can explain which
/// limit was exceeded.
#[derive(Debug, Clone, Copy)]
pub struct Overflow {
    nrows: usize,
    ncols: usize,
    elsize: usize,
}

impl Overflow {
    /// Verifies that an `nrows x ncols` matrix of `T` is representable.
    ///
    /// Two conditions are checked:
    /// * `nrows * ncols` must not overflow `usize`.
    /// * The total byte size must not exceed `isize::MAX`, the largest size a Rust
    ///   allocation may have. For zero-sized `T`, this always holds.
    fn check<T>(nrows: usize, ncols: usize) -> Result<(), Self> {
        let elsize = std::mem::size_of::<T>();

        let fits = match nrows.checked_mul(ncols) {
            // Saturating is sufficient: a saturated product is already above `isize::MAX`.
            Some(elements) => elsize.saturating_mul(elements) <= isize::MAX as usize,
            None => false,
        };

        if fits {
            Ok(())
        } else {
            Err(Self {
                nrows,
                ncols,
                elsize,
            })
        }
    }
}

impl std::fmt::Display for Overflow {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            nrows,
            ncols,
            elsize,
        } = *self;

        match elsize {
            // A zero-sized element can only fail the element-count check.
            0 => write!(
                f,
                "ZST matrix with dimensions {} x {} has more than `usize::MAX` elements",
                nrows, ncols,
            ),
            _ => write!(
                f,
                "a matrix of size {} x {} with element size {} would exceed isize::MAX bytes",
                nrows, ncols, elsize,
            ),
        }
    }
}

impl std::error::Error for Overflow {}
393
/// Error returned when a borrowed slice cannot back a [`Standard`] matrix view.
#[derive(Debug, Clone, Copy, Error)]
#[non_exhaustive]
pub enum SliceError {
    /// The slice length differs from the number of elements (`nrows * ncols`) the
    /// representation requires.
    #[error("Length mismatch: expected {expected}, found {found}")]
    LengthMismatch { expected: usize, found: usize },
}
401
402unsafe impl<T: Copy> Repr for Standard<T> {
406 type Row<'a>
407 = &'a [T]
408 where
409 T: 'a;
410
411 fn nrows(&self) -> usize {
412 self.nrows
413 }
414
415 fn layout(&self) -> Result<Layout, LayoutError> {
416 Ok(Layout::array::<T>(self.num_elements())?)
417 }
418
419 unsafe fn get_row<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::Row<'a> {
420 debug_assert!(ptr.cast::<T>().is_aligned());
421 debug_assert!(i < self.nrows);
422
423 let row_ptr = unsafe { ptr.as_ptr().cast::<T>().add(i * self.ncols) };
427
428 unsafe { std::slice::from_raw_parts(row_ptr, self.ncols) }
430 }
431}
432
433unsafe impl<T: Copy> ReprMut for Standard<T> {
436 type RowMut<'a>
437 = &'a mut [T]
438 where
439 T: 'a;
440
441 unsafe fn get_row_mut<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::RowMut<'a> {
442 debug_assert!(ptr.cast::<T>().is_aligned());
443 debug_assert!(i < self.nrows);
444
445 let row_ptr = unsafe { ptr.as_ptr().cast::<T>().add(i * self.ncols) };
449
450 unsafe { std::slice::from_raw_parts_mut(row_ptr, self.ncols) }
453 }
454}
455
456unsafe impl<T: Copy> ReprOwned for Standard<T> {
460 unsafe fn drop(self, ptr: NonNull<u8>) {
461 unsafe {
467 let slice_ptr = std::ptr::slice_from_raw_parts_mut(
468 ptr.cast::<T>().as_ptr(),
469 self.nrows * self.ncols,
470 );
471 let _ = Box::from_raw(slice_ptr);
472 }
473 }
474}
475
476unsafe impl<T> NewOwned<T> for Standard<T>
479where
480 T: Copy,
481{
482 type Error = crate::error::Infallible;
483 fn new_owned(self, value: T) -> Result<Mat<Self>, Self::Error> {
484 let b: Box<[T]> = (0..self.num_elements()).map(|_| value).collect();
485
486 Ok(unsafe { self.box_to_mat(b) })
488 }
489}
490
491unsafe impl<T> NewOwned<Defaulted> for Standard<T>
493where
494 T: Copy + Default,
495{
496 type Error = crate::error::Infallible;
497 fn new_owned(self, _: Defaulted) -> Result<Mat<Self>, Self::Error> {
498 self.new_owned(T::default())
499 }
500}
501
502unsafe impl<T> NewRef<T> for Standard<T>
505where
506 T: Copy,
507{
508 type Error = SliceError;
509 fn new_ref(self, data: &[T]) -> Result<MatRef<'_, Self>, Self::Error> {
510 self.check_slice(data)?;
511
512 Ok(unsafe { MatRef::from_raw_parts(self, utils::as_nonnull(data).cast::<u8>()) })
517 }
518}
519
520unsafe impl<T> NewMut<T> for Standard<T>
523where
524 T: Copy,
525{
526 type Error = SliceError;
527 fn new_mut(self, data: &mut [T]) -> Result<MatMut<'_, Self>, Self::Error> {
528 self.check_slice(data)?;
529
530 Ok(unsafe { MatMut::from_raw_parts(self, utils::as_nonnull_mut(data).cast::<u8>()) })
535 }
536}
537
538impl<T> NewCloned for Standard<T>
539where
540 T: Copy,
541{
542 fn new_cloned(v: MatRef<'_, Self>) -> Mat<Self> {
543 let b: Box<[T]> = v.rows().flatten().copied().collect();
544
545 unsafe { v.repr().box_to_mat(b) }
547 }
548}
549
/// An owning matrix: a pointer to its storage plus a representation `T` describing it.
///
/// Row access is delegated to `T` ([`Repr::get_row`] / [`ReprMut::get_row_mut`]). When a
/// `Mat` is dropped, its storage is released through [`ReprOwned::drop`].
#[derive(Debug)]
pub struct Mat<T: ReprOwned> {
    // Start of the owned buffer, interpreted by `repr`.
    ptr: NonNull<u8>,
    // Shape/layout description of the buffer.
    repr: T,
}
563
// SAFETY: `Mat` holds only a `NonNull<u8>` (which is neither `Send` nor `Sync`) and the
// representation `T`. These impls forward thread-safety to `T`. For `Standard<E>`, the
// `PhantomData<E>` field makes `T` `Send`/`Sync` exactly when `E` is.
// NOTE(review): other representations must likewise encode their element type's
// thread-safety in `T` — confirm this is part of the `Repr` contract.
unsafe impl<T> Send for Mat<T> where T: ReprOwned + Send {}

// SAFETY: see the `Send` impl above; shared access only yields read-only rows.
unsafe impl<T> Sync for Mat<T> where T: ReprOwned + Sync {}
569
570impl<T: ReprOwned> Mat<T> {
571 pub fn new<U>(repr: T, init: U) -> Result<Self, <T as NewOwned<U>>::Error>
573 where
574 T: NewOwned<U>,
575 {
576 repr.new_owned(init)
577 }
578
579 #[inline]
581 pub fn num_vectors(&self) -> usize {
582 self.repr.nrows()
583 }
584
585 pub fn repr(&self) -> &T {
587 &self.repr
588 }
589
590 #[must_use]
592 pub fn get_row(&self, i: usize) -> Option<T::Row<'_>> {
593 if i < self.num_vectors() {
594 let row = unsafe { self.get_row_unchecked(i) };
597 Some(row)
598 } else {
599 None
600 }
601 }
602
603 pub(crate) unsafe fn get_row_unchecked(&self, i: usize) -> T::Row<'_> {
604 unsafe { self.repr.get_row(self.ptr, i) }
607 }
608
609 #[must_use]
611 pub fn get_row_mut(&mut self, i: usize) -> Option<T::RowMut<'_>> {
612 if i < self.num_vectors() {
613 Some(unsafe { self.get_row_mut_unchecked(i) })
615 } else {
616 None
617 }
618 }
619
620 pub(crate) unsafe fn get_row_mut_unchecked(&mut self, i: usize) -> T::RowMut<'_> {
621 unsafe { self.repr.get_row_mut(self.ptr, i) }
624 }
625
626 #[inline]
628 pub fn as_view(&self) -> MatRef<'_, T> {
629 MatRef {
630 ptr: self.ptr,
631 repr: self.repr,
632 _lifetime: PhantomData,
633 }
634 }
635
636 #[inline]
638 pub fn as_view_mut(&mut self) -> MatMut<'_, T> {
639 MatMut {
640 ptr: self.ptr,
641 repr: self.repr,
642 _lifetime: PhantomData,
643 }
644 }
645
646 pub fn rows(&self) -> Rows<'_, T> {
648 Rows::new(self.reborrow())
649 }
650
651 pub fn rows_mut(&mut self) -> RowsMut<'_, T> {
653 RowsMut::new(self.reborrow_mut())
654 }
655
656 pub(crate) unsafe fn from_raw_parts(repr: T, ptr: NonNull<u8>) -> Self {
666 Self { ptr, repr }
667 }
668
669 pub fn as_raw_ptr(&self) -> *const u8 {
671 self.ptr.as_ptr()
672 }
673}
674
675impl<T: ReprOwned> Drop for Mat<T> {
676 fn drop(&mut self) {
677 unsafe { self.repr.drop(self.ptr) };
680 }
681}
682
683impl<T: NewCloned> Clone for Mat<T> {
684 fn clone(&self) -> Self {
685 T::new_cloned(self.as_view())
686 }
687}
688
689impl<T: Copy> Mat<Standard<T>> {
690 #[inline]
692 pub fn vector_dim(&self) -> usize {
693 self.repr.ncols()
694 }
695}
696
/// A borrowed, read-only, `Copy` view of a matrix.
///
/// The view holds a pointer to the data plus the representation `T`. The lifetime `'a`
/// ties it to the borrowed storage, like a `&'a [u8]`.
#[derive(Debug, Clone, Copy)]
pub struct MatRef<'a, T: Repr> {
    pub(crate) ptr: NonNull<u8>,
    pub(crate) repr: T,
    // Ties the view to the borrowed storage without storing a reference.
    pub(crate) _lifetime: PhantomData<&'a [u8]>,
}
719
// SAFETY: like `Mat`, these impls forward thread-safety to the representation `T`, which
// for `Standard<E>` is `Send`/`Sync` exactly when `E` is.
// NOTE(review): `MatRef` behaves like a shared reference, for which `Send` would normally
// require the element data to be `Sync`. This impl relies on `T: Send` instead — confirm
// this is intended by the `Repr` safety contract.
unsafe impl<T> Send for MatRef<'_, T> where T: Repr + Send {}

// SAFETY: see the `Send` impl above; a `MatRef` only yields read-only rows.
unsafe impl<T> Sync for MatRef<'_, T> where T: Repr + Sync {}
725
726impl<'a, T: Repr> MatRef<'a, T> {
727 pub fn new<U>(repr: T, data: &'a [U]) -> Result<Self, T::Error>
729 where
730 T: NewRef<U>,
731 {
732 repr.new_ref(data)
733 }
734
735 #[inline]
737 pub fn num_vectors(&self) -> usize {
738 self.repr.nrows()
739 }
740
741 pub fn repr(&self) -> &T {
743 &self.repr
744 }
745
746 #[must_use]
748 pub fn get_row(&self, i: usize) -> Option<T::Row<'_>> {
749 if i < self.num_vectors() {
750 let row = unsafe { self.get_row_unchecked(i) };
753 Some(row)
754 } else {
755 None
756 }
757 }
758
759 #[inline]
765 pub(crate) unsafe fn get_row_unchecked(&self, i: usize) -> T::Row<'_> {
766 unsafe { self.repr.get_row(self.ptr, i) }
768 }
769
770 pub fn rows(&self) -> Rows<'_, T> {
772 Rows::new(*self)
773 }
774
775 pub fn to_owned(&self) -> Mat<T>
777 where
778 T: NewCloned,
779 {
780 T::new_cloned(*self)
781 }
782
783 pub unsafe fn from_raw_parts(repr: T, ptr: NonNull<u8>) -> Self {
791 Self {
792 ptr,
793 repr,
794 _lifetime: PhantomData,
795 }
796 }
797
798 pub fn as_raw_ptr(&self) -> *const u8 {
800 self.ptr.as_ptr()
801 }
802}
803
804impl<'a, T: Copy> MatRef<'a, Standard<T>> {
805 #[inline]
807 pub fn vector_dim(&self) -> usize {
808 self.repr.ncols()
809 }
810}
811
812impl<'this, T: ReprOwned> Reborrow<'this> for Mat<T> {
814 type Target = MatRef<'this, T>;
815
816 fn reborrow(&'this self) -> Self::Target {
817 self.as_view()
818 }
819}
820
821impl<'this, T: ReprOwned> ReborrowMut<'this> for Mat<T> {
823 type Target = MatMut<'this, T>;
824
825 fn reborrow_mut(&'this mut self) -> Self::Target {
826 self.as_view_mut()
827 }
828}
829
830impl<'this, 'a, T: Repr> Reborrow<'this> for MatRef<'a, T> {
832 type Target = MatRef<'this, T>;
833
834 fn reborrow(&'this self) -> Self::Target {
835 MatRef {
836 ptr: self.ptr,
837 repr: self.repr,
838 _lifetime: PhantomData,
839 }
840 }
841}
842
/// A borrowed, mutable view of a matrix.
///
/// Unlike [`MatRef`], this type is not `Copy`. The lifetime `'a` ties it to exclusively
/// borrowed storage, like a `&'a mut [u8]`.
#[derive(Debug)]
pub struct MatMut<'a, T: ReprMut> {
    pub(crate) ptr: NonNull<u8>,
    pub(crate) repr: T,
    // Ties the view to the exclusively borrowed storage without storing a reference.
    pub(crate) _lifetime: PhantomData<&'a mut [u8]>,
}
866
// SAFETY: these impls forward thread-safety to the representation `T`, as for `Mat`.
// For `Standard<E>`, `T` is `Send`/`Sync` exactly when `E` is.
unsafe impl<T> Send for MatMut<'_, T> where T: ReprMut + Send {}

// SAFETY: shared access to a `MatMut` only yields read-only rows (`get_row`, `as_view`).
unsafe impl<T> Sync for MatMut<'_, T> where T: ReprMut + Sync {}
872
873impl<'a, T: ReprMut> MatMut<'a, T> {
874 pub fn new<U>(repr: T, data: &'a mut [U]) -> Result<Self, T::Error>
876 where
877 T: NewMut<U>,
878 {
879 repr.new_mut(data)
880 }
881
882 #[inline]
884 pub fn num_vectors(&self) -> usize {
885 self.repr.nrows()
886 }
887
888 pub fn repr(&self) -> &T {
890 &self.repr
891 }
892
893 #[inline]
895 #[must_use]
896 pub fn get_row(&self, i: usize) -> Option<T::Row<'_>> {
897 if i < self.num_vectors() {
898 Some(unsafe { self.get_row_unchecked(i) })
900 } else {
901 None
902 }
903 }
904
905 #[inline]
911 pub(crate) unsafe fn get_row_unchecked(&self, i: usize) -> T::Row<'_> {
912 unsafe { self.repr.get_row(self.ptr, i) }
914 }
915
916 #[inline]
918 #[must_use]
919 pub fn get_row_mut(&mut self, i: usize) -> Option<T::RowMut<'_>> {
920 if i < self.num_vectors() {
921 Some(unsafe { self.get_row_mut_unchecked(i) })
923 } else {
924 None
925 }
926 }
927
928 #[inline]
934 pub(crate) unsafe fn get_row_mut_unchecked(&mut self, i: usize) -> T::RowMut<'_> {
935 unsafe { self.repr.get_row_mut(self.ptr, i) }
938 }
939
940 pub fn as_view(&self) -> MatRef<'_, T> {
942 MatRef {
943 ptr: self.ptr,
944 repr: self.repr,
945 _lifetime: PhantomData,
946 }
947 }
948
949 pub fn rows(&self) -> Rows<'_, T> {
951 Rows::new(self.reborrow())
952 }
953
954 pub fn rows_mut(&mut self) -> RowsMut<'_, T> {
956 RowsMut::new(self.reborrow_mut())
957 }
958
959 pub fn to_owned(&self) -> Mat<T>
961 where
962 T: NewCloned,
963 {
964 T::new_cloned(self.as_view())
965 }
966
967 pub unsafe fn from_raw_parts(repr: T, ptr: NonNull<u8>) -> Self {
974 Self {
975 ptr,
976 repr,
977 _lifetime: PhantomData,
978 }
979 }
980
981 pub fn as_raw_ptr(&self) -> *const u8 {
983 self.ptr.as_ptr()
984 }
985}
986
987impl<'this, 'a, T: ReprMut> Reborrow<'this> for MatMut<'a, T> {
989 type Target = MatRef<'this, T>;
990
991 fn reborrow(&'this self) -> Self::Target {
992 self.as_view()
993 }
994}
995
996impl<'this, 'a, T: ReprMut> ReborrowMut<'this> for MatMut<'a, T> {
998 type Target = MatMut<'this, T>;
999
1000 fn reborrow_mut(&'this mut self) -> Self::Target {
1001 MatMut {
1002 ptr: self.ptr,
1003 repr: self.repr,
1004 _lifetime: PhantomData,
1005 }
1006 }
1007}
1008
1009impl<'a, T: Copy> MatMut<'a, Standard<T>> {
1010 #[inline]
1012 pub fn vector_dim(&self) -> usize {
1013 self.repr.ncols()
1014 }
1015}
1016
/// Iterator over the read-only rows of a matrix, in order from row 0.
///
/// Created by `Mat::rows`, `MatRef::rows`, and `MatMut::rows`.
#[derive(Debug)]
pub struct Rows<'a, T: Repr> {
    // The view being iterated.
    matrix: MatRef<'a, T>,
    // Index of the next row to yield; never exceeds `matrix.num_vectors()`.
    current: usize,
}
1029
1030impl<'a, T> Rows<'a, T>
1031where
1032 T: Repr,
1033{
1034 fn new(matrix: MatRef<'a, T>) -> Self {
1035 Self { matrix, current: 0 }
1036 }
1037}
1038
1039impl<'a, T> Iterator for Rows<'a, T>
1040where
1041 T: Repr + 'a,
1042{
1043 type Item = T::Row<'a>;
1044
1045 fn next(&mut self) -> Option<Self::Item> {
1046 let current = self.current;
1047 if current >= self.matrix.num_vectors() {
1048 None
1049 } else {
1050 self.current += 1;
1051 Some(unsafe { self.matrix.repr.get_row(self.matrix.ptr, current) })
1057 }
1058 }
1059
1060 fn size_hint(&self) -> (usize, Option<usize>) {
1061 let remaining = self.matrix.num_vectors() - self.current;
1062 (remaining, Some(remaining))
1063 }
1064}
1065
// `size_hint` reports the exact remaining row count, and `next` keeps returning `None`
// once `current` reaches the row count, so both marker traits hold.
impl<'a, T> ExactSizeIterator for Rows<'a, T> where T: Repr + 'a {}
impl<'a, T> FusedIterator for Rows<'a, T> where T: Repr + 'a {}
1068
/// Iterator over the mutable rows of a matrix, in order from row 0.
///
/// Created by `Mat::rows_mut` and `MatMut::rows_mut`. Each row is yielded at most once,
/// so all yielded rows may be held at the same time.
#[derive(Debug)]
pub struct RowsMut<'a, T: ReprMut> {
    // The exclusively borrowed view being iterated.
    matrix: MatMut<'a, T>,
    // Index of the next row to yield; never exceeds `matrix.num_vectors()`.
    current: usize,
}
1081
1082impl<'a, T> RowsMut<'a, T>
1083where
1084 T: ReprMut,
1085{
1086 fn new(matrix: MatMut<'a, T>) -> Self {
1087 Self { matrix, current: 0 }
1088 }
1089}
1090
1091impl<'a, T> Iterator for RowsMut<'a, T>
1092where
1093 T: ReprMut + 'a,
1094{
1095 type Item = T::RowMut<'a>;
1096
1097 fn next(&mut self) -> Option<Self::Item> {
1098 let current = self.current;
1099 if current >= self.matrix.num_vectors() {
1100 None
1101 } else {
1102 self.current += 1;
1103 Some(unsafe { self.matrix.repr.get_row_mut(self.matrix.ptr, current) })
1112 }
1113 }
1114
1115 fn size_hint(&self) -> (usize, Option<usize>) {
1116 let remaining = self.matrix.num_vectors() - self.current;
1117 (remaining, Some(remaining))
1118 }
1119}
1120
// `size_hint` reports the exact remaining row count, and `next` keeps returning `None`
// once `current` reaches the row count, so both marker traits hold.
impl<'a, T> ExactSizeIterator for RowsMut<'a, T> where T: ReprMut + 'a {}
impl<'a, T> FusedIterator for RowsMut<'a, T> where T: ReprMut + 'a {}
1123
#[cfg(test)]
mod tests {
    use super::*;

    use std::fmt::Display;

    use diskann_utils::lazy_format;

    // Shared helpers. Every fill helper writes `pattern(i, j)` into row `i`, column `j`,
    // and every check helper verifies exactly that pattern.

    /// Compile-time assertion that the argument's type implements `Copy`.
    fn assert_copy<T: Copy>(_: &T) {}

    /// Row indices that bounds-checked accessors must reject for a matrix with `nrows`
    /// rows. Values near `usize::MAX` guard against overflow inside the checks.
    fn edge_cases(nrows: usize) -> Vec<usize> {
        let max = usize::MAX;

        vec![
            nrows,
            nrows + 1,
            nrows + 11,
            nrows + 20,
            max / 2,
            max.div_ceil(2),
            max - 1,
            max,
        ]
    }

    /// The value stored at row `i`, column `j` by the fill helpers.
    fn pattern(i: usize, j: usize) -> usize {
        10 * i + j
    }

    /// Writes `pattern(i, j)` into every column `j` of row `i`.
    fn fill_row(i: usize, row: &mut [usize], repr: Standard<usize>) {
        assert_eq!(row.len(), repr.ncols());
        for (j, r) in row.iter_mut().enumerate() {
            *r = pattern(i, j);
        }
    }

    /// Asserts that row `i` has `repr.ncols()` entries and holds `pattern(i, j)` in
    /// column `j`.
    fn check_row(i: usize, row: &[usize], repr: Standard<usize>, ctx: &dyn Display) {
        assert_eq!(row.len(), repr.ncols(), "ctx: {ctx}");
        for (j, r) in row.iter().enumerate() {
            assert_eq!(
                *r,
                pattern(i, j),
                "mismatched entry at row {}, col {} -- ctx: {}",
                i,
                j,
                ctx
            );
        }
    }

    fn fill_mat(x: &mut Mat<Standard<usize>>, repr: Standard<usize>) {
        assert_eq!(x.repr(), &repr);
        assert_eq!(x.num_vectors(), repr.nrows());
        assert_eq!(x.vector_dim(), repr.ncols());

        for i in 0..x.num_vectors() {
            fill_row(i, x.get_row_mut(i).unwrap(), repr);
        }

        for i in edge_cases(repr.nrows()) {
            assert!(x.get_row_mut(i).is_none());
        }
    }

    fn fill_mat_mut(mut x: MatMut<'_, Standard<usize>>, repr: Standard<usize>) {
        assert_eq!(x.repr(), &repr);
        assert_eq!(x.num_vectors(), repr.nrows());
        assert_eq!(x.vector_dim(), repr.ncols());

        for i in 0..x.num_vectors() {
            fill_row(i, x.get_row_mut(i).unwrap(), repr);
        }

        for i in edge_cases(repr.nrows()) {
            assert!(x.get_row_mut(i).is_none());
        }
    }

    fn fill_rows_mut(x: RowsMut<'_, Standard<usize>>, repr: Standard<usize>) {
        assert_eq!(x.len(), repr.nrows());
        // Collect first so that every mutable row is alive at the same time.
        let all_rows: Vec<_> = x.collect();
        assert_eq!(all_rows.len(), repr.nrows());
        for (i, row) in all_rows.into_iter().enumerate() {
            fill_row(i, row, repr);
        }
    }

    fn check_mat(x: &Mat<Standard<usize>>, repr: Standard<usize>, ctx: &dyn Display) {
        assert_eq!(x.repr(), &repr);
        assert_eq!(x.num_vectors(), repr.nrows());
        assert_eq!(x.vector_dim(), repr.ncols());

        for i in 0..x.num_vectors() {
            check_row(i, x.get_row(i).unwrap(), repr, ctx);
        }

        for i in edge_cases(repr.nrows()) {
            assert!(x.get_row(i).is_none(), "ctx: {ctx}");
        }
    }

    fn check_mat_ref(x: MatRef<'_, Standard<usize>>, repr: Standard<usize>, ctx: &dyn Display) {
        assert_eq!(x.repr(), &repr);
        assert_eq!(x.num_vectors(), repr.nrows());
        assert_eq!(x.vector_dim(), repr.ncols());

        assert_copy(&x);
        for i in 0..x.num_vectors() {
            check_row(i, x.get_row(i).unwrap(), repr, ctx);
        }

        for i in edge_cases(repr.nrows()) {
            assert!(x.get_row(i).is_none(), "ctx: {ctx}");
        }
    }

    fn check_mat_mut(x: MatMut<'_, Standard<usize>>, repr: Standard<usize>, ctx: &dyn Display) {
        assert_eq!(x.repr(), &repr);
        assert_eq!(x.num_vectors(), repr.nrows());
        assert_eq!(x.vector_dim(), repr.ncols());

        for i in 0..x.num_vectors() {
            check_row(i, x.get_row(i).unwrap(), repr, ctx);
        }

        for i in edge_cases(repr.nrows()) {
            assert!(x.get_row(i).is_none(), "ctx: {ctx}");
        }
    }

    fn check_rows(x: Rows<'_, Standard<usize>>, repr: Standard<usize>, ctx: &dyn Display) {
        assert_eq!(x.len(), repr.nrows(), "ctx: {ctx}");
        let all_rows: Vec<_> = x.collect();
        assert_eq!(all_rows.len(), repr.nrows(), "ctx: {ctx}");
        for (i, row) in all_rows.into_iter().enumerate() {
            check_row(i, row, repr, ctx);
        }
    }

    #[test]
    fn standard_representation() {
        let repr = Standard::<f32>::new(4, 3).unwrap();
        assert_eq!(repr.nrows(), 4);
        assert_eq!(repr.ncols(), 3);

        let layout = repr.layout().unwrap();
        assert_eq!(layout.size(), 4 * 3 * std::mem::size_of::<f32>());
        assert_eq!(layout.align(), std::mem::align_of::<f32>());
    }

    #[test]
    fn standard_zero_dimensions() {
        for (nrows, ncols) in [(0, 0), (0, 5), (5, 0)] {
            let repr = Standard::<u8>::new(nrows, ncols).unwrap();
            assert_eq!(repr.nrows(), nrows);
            assert_eq!(repr.ncols(), ncols);
            let layout = repr.layout().unwrap();
            assert_eq!(layout.size(), 0);
        }
    }

    #[test]
    fn standard_check_slice() {
        let repr = Standard::<u32>::new(3, 4).unwrap();

        let data = vec![0u32; 12];
        assert!(repr.check_slice(&data).is_ok());

        let short = vec![0u32; 11];
        assert!(matches!(
            repr.check_slice(&short),
            Err(SliceError::LengthMismatch {
                expected: 12,
                found: 11
            })
        ));

        let long = vec![0u32; 13];
        assert!(matches!(
            repr.check_slice(&long),
            Err(SliceError::LengthMismatch {
                expected: 12,
                found: 13
            })
        ));

        let overflow_repr = Standard::<u8>::new(usize::MAX, 2).unwrap_err();
        assert!(matches!(overflow_repr, Overflow { .. }));
    }

    #[test]
    fn standard_new_rejects_element_count_overflow() {
        assert!(Standard::<u8>::new(usize::MAX, 2).is_err());
        assert!(Standard::<u8>::new(2, usize::MAX).is_err());
        assert!(Standard::<u8>::new(usize::MAX, usize::MAX).is_err());
    }

    #[test]
    fn standard_new_rejects_byte_count_exceeding_isize_max() {
        let half = (isize::MAX as usize / std::mem::size_of::<u64>()) + 1;
        assert!(Standard::<u64>::new(half, 1).is_err());
        assert!(Standard::<u64>::new(1, half).is_err());
    }

    #[test]
    fn standard_new_accepts_boundary_below_isize_max() {
        let max_elems = isize::MAX as usize / std::mem::size_of::<u64>();
        let repr = Standard::<u64>::new(max_elems, 1).unwrap();
        assert_eq!(repr.num_elements(), max_elems);
    }

    #[test]
    fn standard_new_zst_rejects_element_count_overflow() {
        assert!(Standard::<()>::new(usize::MAX, 2).is_err());
        assert!(Standard::<()>::new(usize::MAX / 2 + 1, 3).is_err());
    }

    #[test]
    fn standard_new_zst_accepts_large_non_overflowing() {
        let repr = Standard::<()>::new(usize::MAX, 1).unwrap();
        assert_eq!(repr.num_elements(), usize::MAX);
        assert_eq!(repr.layout().unwrap().size(), 0);
    }

    #[test]
    fn standard_new_overflow_error_display() {
        let err = Standard::<u32>::new(usize::MAX, 2).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("would exceed isize::MAX bytes"), "{msg}");

        let zst_err = Standard::<()>::new(usize::MAX, 2).unwrap_err();
        let zst_msg = zst_err.to_string();
        assert!(zst_msg.contains("ZST matrix"), "{zst_msg}");
        assert!(zst_msg.contains("usize::MAX"), "{zst_msg}");
    }

    #[test]
    fn mat_new_and_basic_accessors() {
        let mat = Mat::new(Standard::<usize>::new(3, 4).unwrap(), 42usize).unwrap();
        let base: *const u8 = mat.as_raw_ptr();

        assert_eq!(mat.num_vectors(), 3);
        assert_eq!(mat.vector_dim(), 4);

        let repr = mat.repr();
        assert_eq!(repr.nrows(), 3);
        assert_eq!(repr.ncols(), 4);

        for (i, r) in mat.rows().enumerate() {
            assert_eq!(r, &[42, 42, 42, 42]);
            let ptr = r.as_ptr().cast::<u8>();
            assert_eq!(
                ptr,
                base.wrapping_add(std::mem::size_of::<usize>() * mat.repr().ncols() * i),
            );
        }
    }

    #[test]
    fn mat_new_with_default() {
        let mat = Mat::new(Standard::<usize>::new(2, 3).unwrap(), Defaulted).unwrap();
        let base: *const u8 = mat.as_raw_ptr();

        assert_eq!(mat.num_vectors(), 2);
        for (i, row) in mat.rows().enumerate() {
            assert!(row.iter().all(|&v| v == 0));

            let ptr = row.as_ptr().cast::<u8>();
            assert_eq!(
                ptr,
                base.wrapping_add(std::mem::size_of::<usize>() * mat.repr().ncols() * i),
            );
        }
    }

    #[test]
    fn mat_new_with_zero_value() {
        let mat = Mat::new(Standard::<f32>::new(3, 5).unwrap(), 0.0f32).unwrap();
        assert_eq!(mat.num_vectors(), 3);
        for row in mat.rows() {
            assert_eq!(row, &[0.0f32; 5]);
        }
    }

    #[test]
    fn rows_len_tracks_progress() {
        let mut mat = Mat::new(Standard::<usize>::new(3, 2).unwrap(), Defaulted).unwrap();

        let mut rows = mat.rows();
        assert_eq!(rows.len(), 3);
        assert!(rows.next().is_some());
        assert_eq!(rows.len(), 2);
        assert_eq!(rows.by_ref().count(), 2);
        assert_eq!(rows.len(), 0);
        assert!(rows.next().is_none());

        let mut rows_mut = mat.rows_mut();
        assert_eq!(rows_mut.len(), 3);
        assert!(rows_mut.next().is_some());
        assert_eq!(rows_mut.len(), 2);
    }

    const ROWS: &[usize] = &[0, 1, 2, 3, 5, 10];
    const COLS: &[usize] = &[0, 1, 2, 3, 5, 10];

    #[test]
    fn test_mat() {
        for nrows in ROWS {
            for ncols in COLS {
                let repr = Standard::<usize>::new(*nrows, *ncols).unwrap();
                let ctx = &lazy_format!("nrows = {}, ncols = {}", nrows, ncols);

                // Fill directly through `Mat::get_row_mut`.
                {
                    let ctx = &lazy_format!("{ctx} - direct");
                    let mut mat = Mat::new(repr, Defaulted).unwrap();

                    assert_eq!(mat.num_vectors(), *nrows);
                    assert_eq!(mat.vector_dim(), *ncols);

                    fill_mat(&mut mat, repr);

                    check_mat(&mat, repr, ctx);
                    check_mat_ref(mat.reborrow(), repr, ctx);
                    check_mat_mut(mat.reborrow_mut(), repr, ctx);
                    check_rows(mat.rows(), repr, ctx);

                    // Views must share the owner's storage.
                    assert_eq!(mat.as_raw_ptr(), mat.reborrow().as_raw_ptr());
                    assert_eq!(mat.as_raw_ptr(), mat.reborrow_mut().as_raw_ptr());
                }

                // Fill through a `MatMut` reborrow.
                {
                    let ctx = &lazy_format!("{ctx} - matmut");
                    let mut mat = Mat::new(repr, Defaulted).unwrap();
                    let matmut = mat.reborrow_mut();

                    assert_eq!(matmut.num_vectors(), *nrows);
                    assert_eq!(matmut.vector_dim(), *ncols);

                    fill_mat_mut(matmut, repr);

                    check_mat(&mat, repr, ctx);
                    check_mat_ref(mat.reborrow(), repr, ctx);
                    check_mat_mut(mat.reborrow_mut(), repr, ctx);
                    check_rows(mat.rows(), repr, ctx);
                }

                // Fill through the mutable row iterator.
                {
                    let ctx = &lazy_format!("{ctx} - rows_mut");
                    let mut mat = Mat::new(repr, Defaulted).unwrap();
                    fill_rows_mut(mat.rows_mut(), repr);

                    check_mat(&mat, repr, ctx);
                    check_mat_ref(mat.reborrow(), repr, ctx);
                    check_mat_mut(mat.reborrow_mut(), repr, ctx);
                    check_rows(mat.rows(), repr, ctx);
                }
            }
        }
    }

    #[test]
    fn test_mat_clone() {
        for nrows in ROWS {
            for ncols in COLS {
                let repr = Standard::<usize>::new(*nrows, *ncols).unwrap();
                let ctx = &lazy_format!("nrows = {}, ncols = {}", nrows, ncols);

                let mut mat = Mat::new(repr, Defaulted).unwrap();
                fill_mat(&mut mat, repr);

                {
                    let ctx = &lazy_format!("{ctx} - Mat::clone");
                    let cloned = mat.clone();

                    assert_eq!(cloned.num_vectors(), *nrows);
                    assert_eq!(cloned.vector_dim(), *ncols);

                    check_mat(&cloned, repr, ctx);
                    check_mat_ref(cloned.reborrow(), repr, ctx);
                    check_rows(cloned.rows(), repr, ctx);

                    // A non-empty clone must own a distinct allocation.
                    if repr.num_elements() > 0 {
                        assert_ne!(mat.as_raw_ptr(), cloned.as_raw_ptr());
                    }
                }

                {
                    let ctx = &lazy_format!("{ctx} - MatRef::to_owned");
                    let owned = mat.as_view().to_owned();

                    check_mat(&owned, repr, ctx);
                    check_mat_ref(owned.reborrow(), repr, ctx);
                    check_rows(owned.rows(), repr, ctx);

                    if repr.num_elements() > 0 {
                        assert_ne!(mat.as_raw_ptr(), owned.as_raw_ptr());
                    }
                }

                {
                    let ctx = &lazy_format!("{ctx} - MatMut::to_owned");
                    let owned = mat.as_view_mut().to_owned();

                    check_mat(&owned, repr, ctx);
                    check_mat_ref(owned.reborrow(), repr, ctx);
                    check_rows(owned.rows(), repr, ctx);

                    if repr.num_elements() > 0 {
                        assert_ne!(mat.as_raw_ptr(), owned.as_raw_ptr());
                    }
                }
            }
        }
    }

    #[test]
    fn test_mat_refmut() {
        for nrows in ROWS {
            for ncols in COLS {
                let repr = Standard::<usize>::new(*nrows, *ncols).unwrap();
                let ctx = &lazy_format!("nrows = {}, ncols = {}", nrows, ncols);

                // Fill a borrowed slice through `MatMut::get_row_mut`.
                {
                    let ctx = &lazy_format!("{ctx} - by matmut");
                    let mut b: Box<[_]> = (0..repr.num_elements()).map(|_| 0usize).collect();
                    let ptr = b.as_ptr().cast::<u8>();
                    let mut matmut = MatMut::new(repr, &mut b).unwrap();

                    assert_eq!(
                        ptr,
                        matmut.as_raw_ptr(),
                        "underlying memory should be preserved",
                    );

                    fill_mat_mut(matmut.reborrow_mut(), repr);

                    check_mat_mut(matmut.reborrow_mut(), repr, ctx);
                    check_mat_ref(matmut.reborrow(), repr, ctx);
                    check_rows(matmut.rows(), repr, ctx);
                    check_rows(matmut.reborrow().rows(), repr, ctx);

                    let matref = MatRef::new(repr, &b).unwrap();
                    check_mat_ref(matref, repr, ctx);
                    check_mat_ref(matref.reborrow(), repr, ctx);
                    check_rows(matref.rows(), repr, ctx);
                }

                // Fill a borrowed slice through the mutable row iterator.
                {
                    let ctx = &lazy_format!("{ctx} - by rows");
                    let mut b: Box<[_]> = (0..repr.num_elements()).map(|_| 0usize).collect();
                    let ptr = b.as_ptr().cast::<u8>();
                    let mut matmut = MatMut::new(repr, &mut b).unwrap();

                    assert_eq!(
                        ptr,
                        matmut.as_raw_ptr(),
                        "underlying memory should be preserved",
                    );

                    fill_rows_mut(matmut.rows_mut(), repr);

                    check_mat_mut(matmut.reborrow_mut(), repr, ctx);
                    check_mat_ref(matmut.reborrow(), repr, ctx);
                    check_rows(matmut.rows(), repr, ctx);
                    check_rows(matmut.reborrow().rows(), repr, ctx);

                    let matref = MatRef::new(repr, &b).unwrap();
                    check_mat_ref(matref, repr, ctx);
                    check_mat_ref(matref.reborrow(), repr, ctx);
                    check_rows(matref.rows(), repr, ctx);
                }
            }
        }
    }

    #[test]
    fn test_standard_new_owned() {
        for &nrows in ROWS {
            for &ncols in COLS {
                let m = Mat::new(Standard::new(nrows, ncols).unwrap(), 1usize).unwrap();
                let rows_iter = m.rows();
                let len = <_ as ExactSizeIterator>::len(&rows_iter);
                assert_eq!(len, nrows);
                for r in rows_iter {
                    assert_eq!(r.len(), ncols);
                    assert!(r.iter().all(|i| *i == 1usize));
                }
            }
        }
    }

    #[test]
    fn matref_new_slice_length_error() {
        let repr = Standard::<u32>::new(3, 4).unwrap();

        let data = vec![0u32; 12];
        assert!(MatRef::new(repr, &data).is_ok());

        let short = vec![0u32; 11];
        assert!(matches!(
            MatRef::new(repr, &short),
            Err(SliceError::LengthMismatch {
                expected: 12,
                found: 11
            })
        ));

        let long = vec![0u32; 13];
        assert!(matches!(
            MatRef::new(repr, &long),
            Err(SliceError::LengthMismatch {
                expected: 12,
                found: 13
            })
        ));
    }

    #[test]
    fn matmut_new_slice_length_error() {
        let repr = Standard::<u32>::new(3, 4).unwrap();

        let mut data = vec![0u32; 12];
        assert!(MatMut::new(repr, &mut data).is_ok());

        let mut short = vec![0u32; 11];
        assert!(matches!(
            MatMut::new(repr, &mut short),
            Err(SliceError::LengthMismatch {
                expected: 12,
                found: 11
            })
        ));

        let mut long = vec![0u32; 13];
        assert!(matches!(
            MatMut::new(repr, &mut long),
            Err(SliceError::LengthMismatch {
                expected: 12,
                found: 13
            })
        ));
    }
}