1use std::{alloc::Layout, marker::PhantomData, ptr::NonNull};
81
82use diskann_utils::{
83 Reborrow, ReborrowMut,
84 strided::StridedView,
85 views::{MatrixView, MutMatrixView},
86};
87
88use super::matrix::{
89 Defaulted, LayoutError, Mat, MatMut, MatRef, NewMut, NewOwned, NewRef, Overflow, Repr, ReprMut,
90 ReprOwned, SliceError,
91};
92use crate::bits::{AsMutPtr, AsPtr, MutSlicePtr, SlicePtr};
93use crate::utils;
94
/// Rounds `ncols` up to the next multiple of the packing factor `PACK`.
///
/// This is the number of column slots physically stored per row (padding
/// included). `const` so it can also be evaluated in constant contexts.
#[inline]
const fn padded_ncols<const PACK: usize>(ncols: usize) -> usize {
    ncols.next_multiple_of(PACK)
}
100
/// Total number of `T` slots needed for an `nrows x ncols` matrix with rows
/// padded to a multiple of `GROUP` and columns padded to a multiple of `PACK`.
///
/// Unchecked variant: overflow panics under debug assertions and wraps in
/// release builds — callers must have validated the dimensions first (see
/// `checked_compute_capacity`). `const` so it is usable in constant contexts.
#[inline]
const fn compute_capacity<const GROUP: usize, const PACK: usize>(nrows: usize, ncols: usize) -> usize {
    // Column padding inlined (equivalent to `padded_ncols::<PACK>`) so this
    // function is self-contained in `const` evaluation.
    nrows.next_multiple_of(GROUP) * ncols.next_multiple_of(PACK)
}
114
/// Checked variant of `compute_capacity`.
///
/// Returns `None` when padding either dimension, or multiplying the padded
/// dimensions, would overflow `usize`. `const`-compatible, which is why the
/// options are unpacked with `let-else` instead of `?` (not allowed in
/// `const fn`).
#[inline]
const fn checked_compute_capacity<const GROUP: usize, const PACK: usize>(
    nrows: usize,
    ncols: usize,
) -> Option<usize> {
    let Some(rows) = nrows.checked_next_multiple_of(GROUP) else {
        return None;
    };
    let Some(cols) = ncols.checked_next_multiple_of(PACK) else {
        return None;
    };
    rows.checked_mul(cols)
}
127
/// Maps a logical `(row, col)` position to its index in the flat
/// block-transposed storage.
///
/// Layout: rows are grouped into blocks of `GROUP`; within a block, elements
/// are ordered by column pack (groups of `PACK` columns), then by row within
/// the block, then by column within the pack. `const` so offsets can be
/// precomputed at compile time.
#[inline]
const fn linear_index<const GROUP: usize, const PACK: usize>(
    row: usize,
    col: usize,
    ncols: usize,
) -> usize {
    // Physical row width, padded up to a whole number of packs
    // (equivalent to `padded_ncols::<PACK>(ncols)`, inlined for `const`).
    let pncols = ncols.next_multiple_of(PACK);
    let block = row / GROUP;
    let row_in_block = row % GROUP;
    block * GROUP * pncols + (col / PACK) * GROUP * PACK + row_in_block * PACK + (col % PACK)
}
141
/// Offset of column `col` relative to the start of its row's storage, i.e.
/// `linear_index(r, col, _) - linear_index(r, 0, _)` for any row `r`.
///
/// Used by the `Row`/`RowMut` views, whose base pointer already points at
/// element `(r, 0)`. `const` so it can be evaluated in constant contexts.
#[inline]
const fn col_offset<const GROUP: usize, const PACK: usize>(col: usize) -> usize {
    (col / PACK) * GROUP * PACK + (col % PACK)
}
150
/// Dimension metadata for a block-transposed matrix of `T`.
///
/// Rows are stored in blocks of `GROUP`, with columns interleaved in packs of
/// `PACK`; see `linear_index` for the exact element ordering. The struct holds
/// no element data — it only describes the layout of an external allocation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct BlockTransposedRepr<T, const GROUP: usize, const PACK: usize = 1> {
    nrows: usize,
    ncols: usize,
    // Ties the representation to the element type without storing one.
    _elem: PhantomData<T>,
}
161
impl<T: Copy, const GROUP: usize, const PACK: usize> BlockTransposedRepr<T, GROUP, PACK> {
    // Compile-time checks on the const parameters; forced to evaluate in
    // `new` via `let () = Self::_ASSERTIONS`.
    const _ASSERTIONS: () = {
        assert!(GROUP > 0, "group size GROUP must be positive");
        assert!(PACK > 0, "packing factor PACK must be positive");
        assert!(
            GROUP.is_multiple_of(PACK),
            "GROUP must be divisible by PACK"
        );
    };

    /// Creates the representation for an `nrows x ncols` matrix, rejecting
    /// dimensions whose padded capacity overflows `usize` or exceeds the
    /// byte budget for `T`.
    pub fn new(nrows: usize, ncols: usize) -> Result<Self, Overflow> {
        // Force evaluation of the compile-time parameter assertions.
        let () = Self::_ASSERTIONS;
        let capacity = checked_compute_capacity::<GROUP, PACK>(nrows, ncols)
            .ok_or_else(|| Overflow::for_type::<T>(nrows, ncols))?;
        Overflow::check_byte_budget::<T>(capacity, nrows, ncols)?;
        Ok(Self {
            nrows,
            ncols,
            _elem: PhantomData,
        })
    }

    /// Total number of `T` slots in the backing storage, padding included.
    /// Overflow was ruled out when the representation was built by `new`.
    #[inline]
    fn storage_len(&self) -> usize {
        compute_capacity::<GROUP, PACK>(self.nrows, self.ncols)
    }

    /// Number of logical rows.
    #[inline]
    fn nrows(&self) -> usize {
        self.nrows
    }

    /// Number of logical columns.
    #[inline]
    pub fn ncols(&self) -> usize {
        self.ncols
    }

    /// Stored columns per row, rounded up to a multiple of `PACK`.
    #[inline]
    pub fn padded_ncols(&self) -> usize {
        padded_ncols::<PACK>(self.ncols)
    }

    /// Number of row blocks completely filled with logical rows.
    #[inline]
    pub fn full_blocks(&self) -> usize {
        self.nrows / GROUP
    }

    /// Total number of row blocks, including a trailing partial block.
    #[inline]
    pub fn num_blocks(&self) -> usize {
        self.nrows.div_ceil(GROUP)
    }

    /// Logical rows in the trailing partial block (0 when `nrows` divides
    /// evenly into blocks).
    #[inline]
    pub fn remainder(&self) -> usize {
        self.nrows % GROUP
    }

    /// Stored rows, rounded up to a whole number of blocks.
    #[inline]
    pub fn padded_nrows(&self) -> usize {
        self.num_blocks() * GROUP
    }

    /// Number of `T` slots occupied by one row block.
    #[inline]
    fn block_stride(&self) -> usize {
        GROUP * self.padded_ncols()
    }

    /// Storage offset of the first slot of row block `block`.
    #[inline]
    fn block_offset(&self, block: usize) -> usize {
        block * self.block_stride()
    }

    /// Validates that `slice` has exactly `storage_len()` elements.
    fn check_slice(&self, slice: &[T]) -> Result<(), SliceError> {
        let cap = self.storage_len();
        if slice.len() != cap {
            Err(SliceError::LengthMismatch {
                expected: cap,
                found: slice.len(),
            })
        } else {
            Ok(())
        }
    }

    /// Wraps a boxed slice into an owned [`Mat`] described by `self`.
    ///
    /// # Safety
    /// `b.len()` must equal `self.storage_len()`.
    unsafe fn box_to_mat(self, b: Box<[T]>) -> Mat<Self> {
        debug_assert_eq!(b.len(), self.storage_len(), "safety contract violated");

        let ptr = utils::box_into_nonnull(b).cast::<u8>();

        // SAFETY: `ptr` owns exactly the `storage_len()` elements this
        // representation describes; ownership is handed to `Mat` and later
        // reclaimed by `ReprOwned::drop`.
        unsafe { Mat::from_raw_parts(self, ptr) }
    }
}
282
/// Immutable view of one logical row of a block-transposed matrix.
///
/// Row elements are not contiguous: element `col` lives at
/// `col_offset::<GROUP, PACK>(col)` past `base`.
#[derive(Debug, Clone, Copy)]
pub struct Row<'a, T, const GROUP: usize, const PACK: usize = 1> {
    // Pointer to element (row, 0) inside the block-transposed storage;
    // dangling (and never dereferenced) when `ncols == 0`.
    base: SlicePtr<'a, T>,
    ncols: usize,
}
297
298impl<T: Copy, const GROUP: usize, const PACK: usize> Row<'_, T, GROUP, PACK> {
299 #[inline]
301 pub fn len(&self) -> usize {
302 self.ncols
303 }
304
305 #[inline]
307 pub fn is_empty(&self) -> bool {
308 self.ncols == 0
309 }
310
311 #[inline]
313 pub fn get(&self, col: usize) -> Option<&T> {
314 if col < self.ncols {
315 Some(unsafe { &*self.base.as_ptr().add(col_offset::<GROUP, PACK>(col)) })
317 } else {
318 None
319 }
320 }
321
322 #[inline]
324 pub fn iter(&self) -> RowIter<'_, T, GROUP, PACK> {
325 RowIter {
326 base: self.base,
327 col: 0,
328 ncols: self.ncols,
329 }
330 }
331}
332
333impl<T: Copy, const GROUP: usize, const PACK: usize> std::ops::Index<usize>
334 for Row<'_, T, GROUP, PACK>
335{
336 type Output = T;
337
338 #[inline]
339 #[allow(clippy::panic)] fn index(&self, col: usize) -> &Self::Output {
341 self.get(col)
342 .unwrap_or_else(|| panic!("column index {col} out of bounds (ncols = {})", self.ncols))
343 }
344}
345
/// Iterator over the elements of a [`Row`], yielding values by copy.
#[derive(Debug, Clone)]
pub struct RowIter<'a, T, const GROUP: usize, const PACK: usize = 1> {
    // Pointer to element (row, 0) of the row being iterated.
    base: SlicePtr<'a, T>,
    // Next column to yield; runs from 0 up to `ncols`.
    col: usize,
    ncols: usize,
}
353
354impl<T: Copy, const GROUP: usize, const PACK: usize> Iterator for RowIter<'_, T, GROUP, PACK> {
355 type Item = T;
356
357 #[inline]
358 fn next(&mut self) -> Option<Self::Item> {
359 if self.col >= self.ncols {
360 return None;
361 }
362 let val = unsafe { *self.base.as_ptr().add(col_offset::<GROUP, PACK>(self.col)) };
364 self.col += 1;
365 Some(val)
366 }
367
368 #[inline]
369 fn size_hint(&self) -> (usize, Option<usize>) {
370 let remaining = self.ncols - self.col;
371 (remaining, Some(remaining))
372 }
373}
374
// `RowIter` reports an exact `size_hint` and keeps returning `None` once
// exhausted, so the exact-size and fused marker traits hold.
impl<T: Copy, const GROUP: usize, const PACK: usize> ExactSizeIterator
    for RowIter<'_, T, GROUP, PACK>
{
}
impl<T: Copy, const GROUP: usize, const PACK: usize> std::iter::FusedIterator
    for RowIter<'_, T, GROUP, PACK>
{
}
383
/// Mutable view of one logical row of a block-transposed matrix.
///
/// Same strided layout as [`Row`]; element `col` lives at
/// `col_offset::<GROUP, PACK>(col)` past `base`.
#[derive(Debug)]
pub struct RowMut<'a, T, const GROUP: usize, const PACK: usize = 1> {
    // Exclusive pointer to element (row, 0); dangling when `ncols == 0`.
    base: MutSlicePtr<'a, T>,
    ncols: usize,
}
390
391impl<T: Copy, const GROUP: usize, const PACK: usize> RowMut<'_, T, GROUP, PACK> {
392 #[inline]
394 pub fn len(&self) -> usize {
395 self.ncols
396 }
397
398 #[inline]
400 pub fn is_empty(&self) -> bool {
401 self.ncols == 0
402 }
403
404 #[inline]
406 pub fn get(&self, col: usize) -> Option<&T> {
407 if col < self.ncols {
408 Some(unsafe { &*self.base.as_ptr().add(col_offset::<GROUP, PACK>(col)) })
410 } else {
411 None
412 }
413 }
414
415 #[inline]
417 pub fn get_mut(&mut self, col: usize) -> Option<&mut T> {
418 if col < self.ncols {
419 Some(unsafe { &mut *self.base.as_mut_ptr().add(col_offset::<GROUP, PACK>(col)) })
421 } else {
422 None
423 }
424 }
425
426 #[inline]
432 pub fn set(&mut self, col: usize, value: T) {
433 assert!(
434 col < self.ncols,
435 "column index {col} out of bounds (ncols = {})",
436 self.ncols
437 );
438 unsafe { *self.base.as_mut_ptr().add(col_offset::<GROUP, PACK>(col)) = value };
440 }
441}
442
443impl<T: Copy, const GROUP: usize, const PACK: usize> std::ops::Index<usize>
444 for RowMut<'_, T, GROUP, PACK>
445{
446 type Output = T;
447
448 #[inline]
449 #[allow(clippy::panic)] fn index(&self, col: usize) -> &Self::Output {
451 self.get(col)
452 .unwrap_or_else(|| panic!("column index {col} out of bounds (ncols = {})", self.ncols))
453 }
454}
455
456impl<T: Copy, const GROUP: usize, const PACK: usize> std::ops::IndexMut<usize>
457 for RowMut<'_, T, GROUP, PACK>
458{
459 #[inline]
460 #[allow(clippy::panic)] fn index_mut(&mut self, col: usize) -> &mut Self::Output {
462 let ncols = self.ncols;
463 self.get_mut(col)
464 .unwrap_or_else(|| panic!("column index {col} out of bounds (ncols = {ncols})"))
465 }
466}
467
unsafe impl<T: Copy, const GROUP: usize, const PACK: usize> Repr
    for BlockTransposedRepr<T, GROUP, PACK>
{
    type Row<'a>
        = Row<'a, T, GROUP, PACK>
    where
        Self: 'a;

    fn nrows(&self) -> usize {
        self.nrows
    }

    /// The backing allocation is a flat array of `storage_len()` elements.
    fn layout(&self) -> Result<Layout, LayoutError> {
        Ok(Layout::array::<T>(self.storage_len())?)
    }

    /// Builds a [`Row`] view over logical row `i` of the storage at `ptr`.
    unsafe fn get_row<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::Row<'a> {
        debug_assert!(i < self.nrows);

        if self.ncols == 0 {
            // SAFETY: an empty row never dereferences its base pointer, so a
            // dangling but well-aligned pointer suffices.
            return Row {
                base: unsafe { SlicePtr::new_unchecked(NonNull::dangling()) },
                ncols: 0,
            };
        }

        let base_ptr = ptr.as_ptr().cast::<T>();
        // Offset of element (i, 0) in the block-transposed storage.
        let offset = linear_index::<GROUP, PACK>(i, 0, self.ncols);

        // SAFETY: `ptr` covers `storage_len()` elements per the trait's
        // contract and `i < nrows`, so the offset stays in bounds.
        let row_base = unsafe { base_ptr.add(offset) };

        Row {
            // SAFETY: `row_base` was derived from a valid, non-null
            // allocation just above.
            base: unsafe { SlicePtr::new_unchecked(NonNull::new_unchecked(row_base)) },
            ncols: self.ncols,
        }
    }
}
519
unsafe impl<T: Copy, const GROUP: usize, const PACK: usize> ReprMut
    for BlockTransposedRepr<T, GROUP, PACK>
{
    type RowMut<'a>
        = RowMut<'a, T, GROUP, PACK>
    where
        Self: 'a;

    /// Mutable counterpart of `Repr::get_row`; identical offset computation.
    unsafe fn get_row_mut<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::RowMut<'a> {
        debug_assert!(i < self.nrows);

        if self.ncols == 0 {
            // SAFETY: an empty row never dereferences its base pointer, so a
            // dangling but well-aligned pointer suffices.
            return RowMut {
                base: unsafe { MutSlicePtr::new_unchecked(NonNull::dangling()) },
                ncols: 0,
            };
        }

        let base_ptr = ptr.as_ptr().cast::<T>();
        // Offset of element (i, 0) in the block-transposed storage.
        let offset = linear_index::<GROUP, PACK>(i, 0, self.ncols);

        // SAFETY: `ptr` covers `storage_len()` elements per the trait's
        // contract and `i < nrows`, so the offset stays in bounds.
        let row_base = unsafe { base_ptr.add(offset) };

        RowMut {
            // SAFETY: `row_base` was derived from a valid, non-null
            // allocation just above.
            base: unsafe { MutSlicePtr::new_unchecked(NonNull::new_unchecked(row_base)) },
            ncols: self.ncols,
        }
    }
}
560
unsafe impl<T: Copy, const GROUP: usize, const PACK: usize> ReprOwned
    for BlockTransposedRepr<T, GROUP, PACK>
{
    /// Frees storage that was created from a `Box<[T]>` of `storage_len()`
    /// elements (see `box_to_mat`).
    unsafe fn drop(self, ptr: NonNull<u8>) {
        unsafe {
            // Reconstitute the original boxed slice and let it drop,
            // releasing the allocation.
            let slice_ptr =
                std::ptr::slice_from_raw_parts_mut(ptr.cast::<T>().as_ptr(), self.storage_len());
            let _ = Box::from_raw(slice_ptr);
        }
    }
}
575
unsafe impl<T: Copy, const GROUP: usize, const PACK: usize> NewOwned<T>
    for BlockTransposedRepr<T, GROUP, PACK>
{
    type Error = crate::error::Infallible;

    /// Allocates owned storage with every slot — padding included — set to
    /// `value`.
    fn new_owned(self, value: T) -> Result<Mat<Self>, Self::Error> {
        let b: Box<[T]> = vec![value; self.storage_len()].into_boxed_slice();

        // SAFETY: the box was sized to exactly `storage_len()` elements.
        Ok(unsafe { self.box_to_mat(b) })
    }
}
593
unsafe impl<T: Copy + Default, const GROUP: usize, const PACK: usize> NewOwned<Defaulted>
    for BlockTransposedRepr<T, GROUP, PACK>
{
    type Error = crate::error::Infallible;

    /// Allocates owned storage filled with `T::default()`.
    fn new_owned(self, _: Defaulted) -> Result<Mat<Self>, Self::Error> {
        self.new_owned(T::default())
    }
}
604
unsafe impl<T: Copy, const GROUP: usize, const PACK: usize> NewRef<T>
    for BlockTransposedRepr<T, GROUP, PACK>
{
    type Error = SliceError;

    /// Borrows `data` (already in block-transposed layout) as a matrix view.
    ///
    /// Fails when `data.len() != storage_len()`.
    fn new_ref(self, data: &[T]) -> Result<MatRef<'_, Self>, Self::Error> {
        self.check_slice(data)?;

        // SAFETY: the length check above guarantees the slice covers exactly
        // the storage this representation describes.
        Ok(unsafe { MatRef::from_raw_parts(self, utils::as_nonnull(data).cast::<u8>()) })
    }
}
618
unsafe impl<T: Copy, const GROUP: usize, const PACK: usize> NewMut<T>
    for BlockTransposedRepr<T, GROUP, PACK>
{
    type Error = SliceError;

    /// Mutably borrows `data` (already in block-transposed layout) as a
    /// matrix view. Fails when `data.len() != storage_len()`.
    fn new_mut(self, data: &mut [T]) -> Result<MatMut<'_, Self>, Self::Error> {
        self.check_slice(data)?;

        // SAFETY: the length check above guarantees the slice covers exactly
        // the storage this representation describes.
        Ok(unsafe { MatMut::from_raw_parts(self, utils::as_nonnull_mut(data).cast::<u8>()) })
    }
}
632
/// Generates an inherent method that forwards to the identically named method
/// on the [`BlockTransposedRef`] produced by `self.as_view()`.
///
/// The second arm handles `unsafe` methods: it propagates the `unsafe`
/// qualifier and wraps the forwarded call in an `unsafe` block.
macro_rules! delegate_to_ref {
    ($(#[$m:meta])* $vis:vis fn $name:ident(&self $(, $a:ident: $t:ty)*) $(-> $r:ty)?) => {
        #[doc = concat!("See [`BlockTransposedRef::", stringify!($name), "`].")]
        $(#[$m])*
        #[inline]
        $vis fn $name(&self $(, $a: $t)*) $(-> $r)? {
            self.as_view().$name($($a),*)
        }
    };
    ($(#[$m:meta])* unsafe $vis:vis fn $name:ident(&self $(, $a:ident: $t:ty)*) $(-> $r:ty)?) => {
        #[doc = concat!("See [`BlockTransposedRef::", stringify!($name), "`].")]
        $(#[$m])*
        #[inline]
        $vis unsafe fn $name(&self $(, $a: $t)*) $(-> $r)? {
            unsafe { self.as_view().$name($($a),*) }
        }
    };
}
662
/// Owned block-transposed matrix; see [`BlockTransposedRepr`] for the
/// element layout.
#[derive(Debug)]
pub struct BlockTransposed<T: Copy, const GROUP: usize, const PACK: usize = 1> {
    data: Mat<BlockTransposedRepr<T, GROUP, PACK>>,
}

/// Shared (copyable) view of a block-transposed matrix.
#[derive(Debug, Clone, Copy)]
pub struct BlockTransposedRef<'a, T: Copy, const GROUP: usize, const PACK: usize = 1> {
    data: MatRef<'a, BlockTransposedRepr<T, GROUP, PACK>>,
}

/// Exclusive (mutable) view of a block-transposed matrix.
pub struct BlockTransposedMut<'a, T: Copy, const GROUP: usize, const PACK: usize = 1> {
    data: MatMut<'a, BlockTransposedRepr<T, GROUP, PACK>>,
}
699
impl<'a, T: Copy, const GROUP: usize, const PACK: usize> BlockTransposedRef<'a, T, GROUP, PACK> {
    /// Number of logical rows.
    #[inline]
    pub fn nrows(&self) -> usize {
        self.data.repr().nrows()
    }

    /// Number of logical columns.
    #[inline]
    pub fn ncols(&self) -> usize {
        self.data.repr().ncols()
    }

    /// Stored columns per row, padded to a multiple of `PACK`.
    #[inline]
    pub fn padded_ncols(&self) -> usize {
        self.data.repr().padded_ncols()
    }

    /// Row-group size (the `GROUP` const parameter).
    pub const fn group_size(&self) -> usize {
        GROUP
    }

    /// Row-group size, available without an instance.
    pub const fn const_group_size() -> usize {
        GROUP
    }

    /// Column packing factor (the `PACK` const parameter).
    pub const fn pack_size(&self) -> usize {
        PACK
    }

    /// Number of row blocks completely filled with logical rows.
    #[inline]
    pub fn full_blocks(&self) -> usize {
        self.data.repr().full_blocks()
    }

    /// Total number of row blocks, including a trailing partial block.
    #[inline]
    pub fn num_blocks(&self) -> usize {
        self.data.repr().num_blocks()
    }

    /// Logical rows in the trailing partial block (0 when `nrows` divides
    /// evenly into blocks).
    #[inline]
    pub fn remainder(&self) -> usize {
        self.data.repr().remainder()
    }

    /// Stored rows, padded to a whole number of blocks.
    #[inline]
    pub fn padded_nrows(&self) -> usize {
        self.data.repr().padded_nrows()
    }

    /// Pointer to the first element of the underlying storage.
    #[inline]
    pub fn as_ptr(&self) -> *const T {
        self.data.as_raw_ptr().cast::<T>()
    }

    /// The entire underlying storage (padding slots included) as a slice
    /// carrying the full view lifetime `'a`.
    #[inline]
    pub fn as_slice(&self) -> &'a [T] {
        let len = self.data.repr().storage_len();
        // SAFETY: the view borrows `storage_len()` contiguous elements of
        // `T` for `'a`.
        unsafe { std::slice::from_raw_parts(self.as_ptr(), len) }
    }

    /// Pointer to the first slot of row block `block`, without bounds checks.
    ///
    /// # Safety
    /// `block` must be less than `num_blocks()`.
    #[inline]
    pub unsafe fn block_ptr_unchecked(&self, block: usize) -> *const T {
        debug_assert!(block < self.num_blocks());
        // SAFETY: per the contract above, the block offset is in bounds.
        unsafe { self.as_ptr().add(self.data.repr().block_offset(block)) }
    }

    /// Dense view of full block `block`, shaped `padded_ncols() / PACK` rows
    /// by `GROUP * PACK` columns. Panics when `block >= full_blocks()`.
    #[allow(clippy::expect_used)]
    pub fn block(&self, block: usize) -> MatrixView<'a, T> {
        assert!(block < self.full_blocks());
        let offset = self.data.repr().block_offset(block);
        let stride = self.data.repr().block_stride();
        // SAFETY: the assert above keeps `offset..offset + stride` inside
        // the storage.
        let data: &[T] = unsafe { std::slice::from_raw_parts(self.as_ptr().add(offset), stride) };
        MatrixView::try_from(data, self.padded_ncols() / PACK, GROUP * PACK)
            .expect("base data should have been sized correctly")
    }

    /// Dense view of the trailing partial block, or `None` when `nrows` is a
    /// multiple of `GROUP`. The view has the same shape as a full block; the
    /// slots past `remainder()` rows are padding.
    #[allow(clippy::expect_used)]
    pub fn remainder_block(&self) -> Option<MatrixView<'a, T>> {
        if self.remainder() == 0 {
            None
        } else {
            let offset = self.data.repr().block_offset(self.full_blocks());
            let stride = self.data.repr().block_stride();
            // SAFETY: storage is padded to whole blocks, so the partial
            // block still spans a full `block_stride()`.
            let data: &[T] =
                unsafe { std::slice::from_raw_parts(self.as_ptr().add(offset), stride) };
            Some(
                MatrixView::try_from(data, self.padded_ncols() / PACK, GROUP * PACK)
                    .expect("base data should have been sized correctly"),
            )
        }
    }

    /// Value at logical position `(row, col)`.
    ///
    /// Panics when either index is out of bounds.
    #[inline]
    pub fn get_element(&self, row: usize, col: usize) -> T {
        assert!(
            row < self.nrows(),
            "row {row} out of bounds (nrows = {})",
            self.nrows()
        );
        assert!(
            col < self.ncols(),
            "col {col} out of bounds (ncols = {})",
            self.ncols()
        );
        let idx = linear_index::<GROUP, PACK>(row, col, self.ncols());
        // SAFETY: both indices were bounds-checked above.
        unsafe { *self.as_ptr().add(idx) }
    }

    /// Row view for row `i`, or `None` when `i >= nrows()`.
    #[inline]
    pub fn get_row(&self, i: usize) -> Option<Row<'_, T, GROUP, PACK>> {
        self.data.get_row(i)
    }
}
871
impl<'a, T: Copy, const GROUP: usize, const PACK: usize> BlockTransposedMut<'a, T, GROUP, PACK> {
    /// Immutable view borrowing from this mutable view.
    #[inline]
    pub fn as_view(&self) -> BlockTransposedRef<'_, T, GROUP, PACK> {
        BlockTransposedRef {
            data: self.data.as_view(),
        }
    }

    // Read-only accessors forwarded to the immutable view.
    delegate_to_ref!(pub fn nrows(&self) -> usize);
    delegate_to_ref!(pub fn ncols(&self) -> usize);
    delegate_to_ref!(pub fn padded_ncols(&self) -> usize);
    delegate_to_ref!(pub fn full_blocks(&self) -> usize);
    delegate_to_ref!(pub fn num_blocks(&self) -> usize);
    delegate_to_ref!(pub fn remainder(&self) -> usize);
    delegate_to_ref!(pub fn padded_nrows(&self) -> usize);
    delegate_to_ref!(pub fn as_ptr(&self) -> *const T);
    delegate_to_ref!(pub fn as_slice(&self) -> &[T]);
    delegate_to_ref!(#[allow(clippy::missing_safety_doc)] unsafe pub fn block_ptr_unchecked(&self, block: usize) -> *const T);
    delegate_to_ref!(#[allow(clippy::expect_used)] pub fn block(&self, block: usize) -> MatrixView<'_, T>);
    delegate_to_ref!(#[allow(clippy::expect_used)] pub fn remainder_block(&self) -> Option<MatrixView<'_, T>>);
    delegate_to_ref!(pub fn get_element(&self, row: usize, col: usize) -> T);

    /// Row-group size (the `GROUP` const parameter).
    pub const fn group_size(&self) -> usize {
        GROUP
    }

    /// Row-group size, available without an instance.
    pub const fn const_group_size() -> usize {
        GROUP
    }

    /// Column packing factor (the `PACK` const parameter).
    pub const fn pack_size(&self) -> usize {
        PACK
    }

    /// Row view for row `i`, or `None` when `i >= nrows()`.
    #[inline]
    pub fn get_row(&self, i: usize) -> Option<Row<'_, T, GROUP, PACK>> {
        self.data.get_row(i)
    }

    /// The entire underlying storage (padding included) as a mutable slice.
    #[inline]
    pub fn as_mut_slice(&mut self) -> &mut [T] {
        self.reborrow_mut().mut_slice_inner()
    }

    // By-value worker so the returned slice can carry the full lifetime `'a`.
    fn mut_slice_inner(mut self) -> &'a mut [T] {
        let len = self.data.repr().storage_len();
        // SAFETY: this view exclusively borrows `storage_len()` contiguous
        // elements of `T` for `'a`.
        unsafe { std::slice::from_raw_parts_mut(self.data.as_raw_mut_ptr().cast::<T>(), len) }
    }

    /// Mutable dense view of full block `block`; panics when
    /// `block >= full_blocks()`.
    #[allow(clippy::expect_used)]
    pub fn block_mut(&mut self, block: usize) -> MutMatrixView<'_, T> {
        self.reborrow_mut().block_mut_inner(block)
    }

    // By-value worker for `block_mut` (see `mut_slice_inner`).
    #[allow(clippy::expect_used)]
    fn block_mut_inner(mut self, block: usize) -> MutMatrixView<'a, T> {
        let repr = *self.data.repr();
        assert!(block < repr.full_blocks());
        let offset = repr.block_offset(block);
        let stride = repr.block_stride();
        let pncols = repr.padded_ncols();
        // SAFETY: the assert above keeps `offset..offset + stride` inside
        // the storage, and this view is exclusive.
        let data: &mut [T] = unsafe {
            std::slice::from_raw_parts_mut(
                self.data.as_raw_mut_ptr().cast::<T>().add(offset),
                stride,
            )
        };
        MutMatrixView::try_from(data, pncols / PACK, GROUP * PACK)
            .expect("base data should have been sized correctly")
    }

    /// Mutable dense view of the trailing partial block, or `None` when
    /// `nrows` divides evenly into blocks.
    #[allow(clippy::expect_used)]
    pub fn remainder_block_mut(&mut self) -> Option<MutMatrixView<'_, T>> {
        self.reborrow_mut().remainder_block_mut_inner()
    }

    // By-value worker for `remainder_block_mut`.
    #[allow(clippy::expect_used)]
    fn remainder_block_mut_inner(mut self) -> Option<MutMatrixView<'a, T>> {
        let repr = *self.data.repr();
        if repr.remainder() == 0 {
            None
        } else {
            let offset = repr.block_offset(repr.full_blocks());
            let stride = repr.block_stride();
            let pncols = repr.padded_ncols();
            // SAFETY: storage is padded to whole blocks, so the partial
            // block still spans a full `block_stride()`.
            let data: &mut [T] = unsafe {
                std::slice::from_raw_parts_mut(
                    self.data.as_raw_mut_ptr().cast::<T>().add(offset),
                    stride,
                )
            };
            Some(
                MutMatrixView::try_from(data, pncols / PACK, GROUP * PACK)
                    .expect("base data should have been sized correctly"),
            )
        }
    }

    /// Mutable row view for row `i`, or `None` when `i >= nrows()`.
    #[inline]
    pub fn get_row_mut(&mut self, i: usize) -> Option<RowMut<'_, T, GROUP, PACK>> {
        self.data.get_row_mut(i)
    }

    // Fresh mutable view with a shorter borrow, leaving `self` usable after
    // the returned view is dropped.
    fn reborrow_mut(&mut self) -> BlockTransposedMut<'_, T, GROUP, PACK> {
        BlockTransposedMut {
            data: self.data.reborrow_mut(),
        }
    }
}
1013
impl<T: Copy, const GROUP: usize, const PACK: usize> BlockTransposed<T, GROUP, PACK> {
    /// Immutable view of the whole matrix.
    pub fn as_view(&self) -> BlockTransposedRef<'_, T, GROUP, PACK> {
        BlockTransposedRef {
            data: self.data.as_view(),
        }
    }

    /// Mutable view of the whole matrix.
    pub fn as_view_mut(&mut self) -> BlockTransposedMut<'_, T, GROUP, PACK> {
        BlockTransposedMut {
            data: self.data.as_view_mut(),
        }
    }

    // Read-only accessors forwarded to the immutable view.
    delegate_to_ref!(pub fn nrows(&self) -> usize);
    delegate_to_ref!(pub fn ncols(&self) -> usize);
    delegate_to_ref!(pub fn padded_ncols(&self) -> usize);
    delegate_to_ref!(pub fn full_blocks(&self) -> usize);
    delegate_to_ref!(pub fn num_blocks(&self) -> usize);
    delegate_to_ref!(pub fn remainder(&self) -> usize);
    delegate_to_ref!(pub fn padded_nrows(&self) -> usize);
    delegate_to_ref!(pub fn as_ptr(&self) -> *const T);
    delegate_to_ref!(pub fn as_slice(&self) -> &[T]);
    delegate_to_ref!(#[allow(clippy::missing_safety_doc)] unsafe pub fn block_ptr_unchecked(&self, block: usize) -> *const T);
    delegate_to_ref!(#[allow(clippy::expect_used)] pub fn block(&self, block: usize) -> MatrixView<'_, T>);
    delegate_to_ref!(#[allow(clippy::expect_used)] pub fn remainder_block(&self) -> Option<MatrixView<'_, T>>);
    delegate_to_ref!(pub fn get_element(&self, row: usize, col: usize) -> T);

    /// Row-group size (the `GROUP` const parameter).
    pub const fn group_size(&self) -> usize {
        GROUP
    }

    /// Row-group size, available without an instance.
    pub const fn const_group_size() -> usize {
        GROUP
    }

    /// Column packing factor (the `PACK` const parameter).
    pub const fn pack_size(&self) -> usize {
        PACK
    }

    /// Row view for row `i`, or `None` when `i >= nrows()`.
    #[inline]
    pub fn get_row(&self, i: usize) -> Option<Row<'_, T, GROUP, PACK>> {
        self.data.get_row(i)
    }

    /// The entire underlying storage (padding included) as a mutable slice.
    #[inline]
    pub fn as_mut_slice(&mut self) -> &mut [T] {
        self.as_view_mut().mut_slice_inner()
    }

    /// Mutable dense view of full block `block`; panics when
    /// `block >= full_blocks()`.
    #[allow(clippy::expect_used)]
    pub fn block_mut(&mut self, block: usize) -> MutMatrixView<'_, T> {
        self.as_view_mut().block_mut_inner(block)
    }

    /// Mutable dense view of the trailing partial block, if any.
    #[allow(clippy::expect_used)]
    pub fn remainder_block_mut(&mut self) -> Option<MutMatrixView<'_, T>> {
        self.as_view_mut().remainder_block_mut_inner()
    }

    /// Mutable row view for row `i`, or `None` when `i >= nrows()`.
    #[inline]
    pub fn get_row_mut(&mut self, i: usize) -> Option<RowMut<'_, T, GROUP, PACK>> {
        self.data.get_row_mut(i)
    }
}
1094
impl<'this, T: Copy, const GROUP: usize, const PACK: usize> Reborrow<'this>
    for BlockTransposed<T, GROUP, PACK>
{
    type Target = BlockTransposedRef<'this, T, GROUP, PACK>;

    /// Reborrows the owned matrix as an immutable view.
    #[inline]
    fn reborrow(&'this self) -> Self::Target {
        self.as_view()
    }
}
1107
impl<T: Copy + Default, const GROUP: usize, const PACK: usize> BlockTransposed<T, GROUP, PACK> {
    /// Allocates an `nrows x ncols` matrix with every slot (padding
    /// included) set to `T::default()`.
    ///
    /// Panics when the padded dimensions overflow; use [`Self::try_new`] for
    /// the fallible variant.
    #[allow(clippy::expect_used)]
    pub fn new(nrows: usize, ncols: usize) -> Self {
        let repr = BlockTransposedRepr::<T, GROUP, PACK>::new(nrows, ncols)
            .expect("dimensions should not overflow");
        Self {
            data: Mat::new(repr, Defaulted).expect("infallible"),
        }
    }

    /// Fallible variant of [`Self::new`].
    pub fn try_new(nrows: usize, ncols: usize) -> Result<Self, Overflow> {
        let repr = BlockTransposedRepr::<T, GROUP, PACK>::new(nrows, ncols)?;
        Ok(Self {
            data: Mat::new(repr, Defaulted).expect("infallible"),
        })
    }

    /// Builds a block-transposed copy of the row-major strided matrix `v`.
    ///
    /// Walks the destination storage strictly forward in layout order
    /// (block, column pack, row-in-block, column-in-pack), fetching each
    /// value from its source position. Padding slots keep the
    /// `T::default()` written by [`Self::new`].
    pub fn from_strided(v: StridedView<'_, T>) -> Self {
        let nrows = v.nrows();
        let ncols = v.ncols();
        let mut mat = Self::new(nrows, ncols);

        let repr = *mat.data.repr();
        let num_blocks = repr.num_blocks();
        let pncols = repr.padded_ncols();
        let num_col_groups = pncols / PACK;

        // `dst` visits every destination slot exactly once, in order.
        let mut dst = mat.data.as_raw_mut_ptr().cast::<T>();
        for block in 0..num_blocks {
            let row_base = block * GROUP;
            for cg in 0..num_col_groups {
                let col_base = cg * PACK;
                for rib in 0..GROUP {
                    let row = row_base + rib;
                    if row < nrows {
                        // SAFETY: `row < nrows` was just checked.
                        let src_row = unsafe { v.get_row_unchecked(row) };
                        for p in 0..PACK {
                            let col = col_base + p;
                            if col < ncols {
                                // SAFETY: `col < ncols`, and `dst` stays
                                // within the allocation sized by `new`.
                                unsafe { *dst = *src_row.get_unchecked(col) };
                            }
                            // Advance past padding columns too, so `dst`
                            // tracks the layout position exactly.
                            dst = unsafe { dst.add(1) };
                        }
                    } else {
                        // Padding row: skip the entire pack unwritten.
                        dst = unsafe { dst.add(PACK) };
                    }
                }
            }
        }

        mat
    }

    /// Convenience wrapper over [`Self::from_strided`] for a dense view.
    pub fn from_matrix_view(v: MatrixView<'_, T>) -> Self {
        Self::from_strided(v.into())
    }
}
1199
impl<T: Copy, const GROUP: usize, const PACK: usize> std::ops::Index<(usize, usize)>
    for BlockTransposed<T, GROUP, PACK>
{
    type Output = T;

    /// Element access by `(row, col)`; panics when either index is out of
    /// bounds.
    #[inline]
    fn index(&self, (row, col): (usize, usize)) -> &Self::Output {
        assert!(row < self.nrows());
        assert!(col < self.ncols());
        let idx = linear_index::<GROUP, PACK>(row, col, self.ncols());
        // SAFETY: both indices were bounds-checked above, so `idx` lies
        // within the allocation.
        unsafe { &*self.as_ptr().add(idx) }
    }
}
1218
1219#[cfg(test)]
1224mod tests {
1225 use diskann_utils::{lazy_format, views::Matrix};
1240
1241 use super::*;
1242 use crate::utils::div_round_up;
1243
1244 fn gen_f32(i: usize) -> f32 {
1250 (i + 1) as f32
1251 }
1252 fn gen_i32(i: usize) -> i32 {
1253 (i + 1) as i32
1254 }
1255 fn gen_u8(i: usize) -> u8 {
1256 ((i % 255) + 1) as u8
1257 }
1258
1259 fn test_full_api<
1272 T: Copy + Default + PartialEq + std::fmt::Debug + 'static,
1273 const GROUP: usize,
1274 const PACK: usize,
1275 >(
1276 nrows: usize,
1277 ncols: usize,
1278 gen_element: fn(usize) -> T,
1279 ) {
1280 let context = lazy_format!(
1281 "T={}, GROUP={}, PACK={}, nrows={}, ncols={}",
1282 std::any::type_name::<T>(),
1283 GROUP,
1284 PACK,
1285 nrows,
1286 ncols,
1287 );
1288
1289 let mut data = Matrix::new(T::default(), nrows, ncols);
1292 data.as_mut_slice()
1293 .iter_mut()
1294 .enumerate()
1295 .for_each(|(i, d)| *d = gen_element(i));
1296
1297 let mut transpose = BlockTransposed::<T, GROUP, PACK>::from_strided(data.as_view().into());
1298
1299 let expected_padded = div_round_up(ncols, PACK) * PACK;
1300 let expected_remainder = nrows % GROUP;
1301 let storage_len = transpose.as_slice().len();
1302
1303 assert_eq!(transpose.nrows(), nrows, "{}", context);
1306 assert_eq!(transpose.ncols(), ncols, "{}", context);
1307 assert_eq!(transpose.group_size(), GROUP, "{}", context);
1308 assert_eq!(
1309 BlockTransposed::<T, GROUP, PACK>::const_group_size(),
1310 GROUP,
1311 "{}",
1312 context
1313 );
1314 assert_eq!(transpose.pack_size(), PACK, "{}", context);
1315 assert_eq!(transpose.full_blocks(), nrows / GROUP, "{}", context);
1316 assert_eq!(
1317 transpose.num_blocks(),
1318 div_round_up(nrows, GROUP),
1319 "{}",
1320 context,
1321 );
1322 assert_eq!(transpose.remainder(), expected_remainder, "{}", context);
1323 assert_eq!(transpose.padded_ncols(), expected_padded, "{}", context);
1324
1325 for row in 0..nrows {
1328 for col in 0..ncols {
1329 assert_eq!(
1330 data[(row, col)],
1331 transpose[(row, col)],
1332 "Index at ({}, {}) -- {}",
1333 row,
1334 col,
1335 context,
1336 );
1337 assert_eq!(
1338 data[(row, col)],
1339 transpose.get_element(row, col),
1340 "get_element at ({}, {}) -- {}",
1341 row,
1342 col,
1343 context,
1344 );
1345 }
1346 }
1347
1348 let view = transpose.as_view();
1351 for row in 0..nrows {
1352 let row_view = view.get_row(row).unwrap();
1353 assert_eq!(row_view.len(), ncols, "{}", context);
1354 assert_eq!(row_view.is_empty(), ncols == 0, "{}", context);
1355 for col in 0..ncols {
1356 assert_eq!(
1357 data[(row, col)],
1358 row_view[col],
1359 "row view at ({}, {}) -- {}",
1360 row,
1361 col,
1362 context,
1363 );
1364 }
1365 if ncols > 0 {
1367 assert_eq!(row_view.get(0), Some(&data[(row, 0)]), "{}", context);
1368 }
1369 assert_eq!(row_view.get(ncols), None, "{}", context);
1370
1371 let iter = row_view.iter();
1373 assert_eq!(iter.len(), ncols, "{}", context);
1374 let (lo, hi) = iter.size_hint();
1375 assert_eq!(lo, ncols, "{}", context);
1376 assert_eq!(hi, Some(ncols), "{}", context);
1377
1378 let collected: Vec<T> = row_view.iter().collect();
1379 assert_eq!(collected.len(), ncols, "{}", context);
1380 for col in 0..ncols {
1381 assert_eq!(data[(row, col)], collected[col], "{}", context);
1382 }
1383 }
1384 assert!(view.get_row(nrows).is_none(), "{}", context);
1386 let _ = view;
1387
1388 {
1391 let view = transpose.as_view();
1392 assert_eq!(view.nrows(), nrows, "{}", context);
1393 assert_eq!(view.ncols(), ncols, "{}", context);
1394 assert_eq!(view.padded_ncols(), expected_padded, "{}", context);
1395 assert_eq!(view.group_size(), GROUP, "{}", context);
1396 assert_eq!(
1397 BlockTransposedRef::<T, GROUP, PACK>::const_group_size(),
1398 GROUP,
1399 );
1400 assert_eq!(view.pack_size(), PACK, "{}", context);
1401 assert_eq!(view.full_blocks(), nrows / GROUP, "{}", context);
1402 assert_eq!(view.num_blocks(), div_round_up(nrows, GROUP), "{}", context,);
1403 assert_eq!(view.remainder(), expected_remainder, "{}", context);
1404 assert_eq!(view.as_ptr(), transpose.as_ptr(), "{}", context);
1405 assert_eq!(view.as_slice(), transpose.as_slice(), "{}", context);
1406
1407 for row in 0..nrows {
1408 for col in 0..ncols {
1409 assert_eq!(
1410 data[(row, col)],
1411 view.get_element(row, col),
1412 "Ref get_element at ({}, {}) -- {}",
1413 row,
1414 col,
1415 context,
1416 );
1417 }
1418 let row_view = view.get_row(row).unwrap();
1419 for col in 0..ncols {
1420 assert_eq!(data[(row, col)], row_view[col], "{}", context);
1421 }
1422 }
1423 assert!(view.get_row(nrows).is_none(), "{}", context);
1424 }
1425
1426 let expected_ptr = transpose.as_ptr();
1429 {
1430 let mut_view = transpose.as_view_mut();
1431 assert_eq!(mut_view.nrows(), nrows, "{}", context);
1432 assert_eq!(mut_view.ncols(), ncols, "{}", context);
1433 assert_eq!(mut_view.padded_ncols(), expected_padded, "{}", context);
1434 assert_eq!(mut_view.group_size(), GROUP, "{}", context);
1435 assert_eq!(
1436 BlockTransposedMut::<T, GROUP, PACK>::const_group_size(),
1437 GROUP,
1438 );
1439 assert_eq!(mut_view.pack_size(), PACK, "{}", context);
1440 assert_eq!(mut_view.full_blocks(), nrows / GROUP, "{}", context);
1441 assert_eq!(
1442 mut_view.num_blocks(),
1443 div_round_up(nrows, GROUP),
1444 "{}",
1445 context,
1446 );
1447 assert_eq!(mut_view.remainder(), expected_remainder, "{}", context);
1448 assert_eq!(mut_view.as_ptr(), expected_ptr, "{}", context);
1449 assert_eq!(mut_view.as_slice().len(), storage_len, "{}", context);
1450
1451 for row in 0..nrows {
1452 for col in 0..ncols {
1453 assert_eq!(
1454 data[(row, col)],
1455 mut_view.get_element(row, col),
1456 "Mut get_element at ({}, {}) -- {}",
1457 row,
1458 col,
1459 context,
1460 );
1461 }
1462 let row_view = mut_view.get_row(row).unwrap();
1463 for col in 0..ncols {
1464 assert_eq!(data[(row, col)], row_view[col], "{}", context);
1465 }
1466 }
1467 assert!(mut_view.get_row(nrows).is_none(), "{}", context);
1468 }
1469
1470 {
1473 let mut_view = transpose.as_view_mut();
1474 let ref_from_mut = mut_view.as_view();
1475 assert_eq!(ref_from_mut.nrows(), nrows, "{}", context);
1476 for row in 0..nrows {
1477 for col in 0..ncols {
1478 assert_eq!(
1479 data[(row, col)],
1480 ref_from_mut.get_element(row, col),
1481 "{}",
1482 context,
1483 );
1484 }
1485 }
1486 }
1487
1488 {
1492 let mut mut_view = transpose.as_view_mut();
1493 assert_eq!(mut_view.as_mut_slice().len(), storage_len, "{}", context);
1494 }
1495 assert_eq!(transpose.as_mut_slice().len(), storage_len, "{}", context);
1497
1498 let expected_block_nrows = expected_padded / PACK;
1501 let expected_block_ncols = GROUP * PACK;
1502
1503 for b in 0..transpose.full_blocks() {
1504 let block_data: Vec<T>;
1505 let ptr: *const T;
1506 {
1507 let block = transpose.block(b);
1508 assert_eq!(block.nrows(), expected_block_nrows, "{}", context);
1509 assert_eq!(block.ncols(), expected_block_ncols, "{}", context);
1510
1511 ptr = unsafe { transpose.block_ptr_unchecked(b) };
1513 assert_eq!(ptr, block.as_slice().as_ptr(), "{}", context);
1514
1515 block_data = block.as_slice().to_vec();
1516 }
1517
1518 {
1520 let view = transpose.as_view();
1521 assert_eq!(view.block(b).as_slice(), &block_data[..], "{}", context);
1522 assert_eq!(unsafe { view.block_ptr_unchecked(b) }, ptr, "{}", context);
1524 }
1525
1526 {
1528 let mut_view = transpose.as_view_mut();
1529 assert_eq!(mut_view.block(b).as_slice(), &block_data[..], "{}", context);
1530 assert_eq!(
1531 unsafe { mut_view.block_ptr_unchecked(b) },
1533 ptr,
1534 "{}",
1535 context,
1536 );
1537 }
1538 }
1539
1540 if expected_remainder != 0 {
1542 let remainder_data: Vec<T>;
1543 let ptr: *const T;
1544 let fb = transpose.full_blocks();
1545 {
1546 let block = transpose.remainder_block().unwrap();
1547 assert_eq!(block.nrows(), expected_block_nrows, "{}", context);
1548 assert_eq!(block.ncols(), expected_block_ncols, "{}", context);
1549
1550 ptr = unsafe { transpose.block_ptr_unchecked(fb) };
1552 assert_eq!(ptr, block.as_slice().as_ptr(), "{}", context);
1553
1554 remainder_data = block.as_slice().to_vec();
1555 }
1556
1557 {
1559 let view = transpose.as_view();
1560 let ref_block = view.remainder_block().unwrap();
1561 assert_eq!(ref_block.as_slice(), &remainder_data[..], "{}", context);
1562 }
1563 {
1565 let mut_view = transpose.as_view_mut();
1566 let mut_block = mut_view.remainder_block().unwrap();
1567 assert_eq!(mut_block.as_slice(), &remainder_data[..], "{}", context);
1568 }
1569 } else {
1570 assert!(transpose.remainder_block().is_none(), "{}", context);
1571 {
1572 let view = transpose.as_view();
1573 assert!(view.remainder_block().is_none(), "{}", context);
1574 }
1575 {
1576 let mut_view = transpose.as_view_mut();
1577 assert!(mut_view.remainder_block().is_none(), "{}", context);
1578 }
1579 }
1580
1581 {
1584 let mut mut_view = transpose.as_view_mut();
1585 for b in 0..mut_view.full_blocks() {
1586 let block_mut = mut_view.block_mut(b);
1587 assert_eq!(block_mut.nrows(), expected_block_nrows, "{}", context);
1588 assert_eq!(block_mut.ncols(), expected_block_ncols, "{}", context);
1589 }
1590 if expected_remainder != 0 {
1591 let rem = mut_view.remainder_block_mut().unwrap();
1592 assert_eq!(rem.nrows(), expected_block_nrows, "{}", context);
1593 assert_eq!(rem.ncols(), expected_block_ncols, "{}", context);
1594 } else {
1595 assert!(mut_view.remainder_block_mut().is_none(), "{}", context);
1596 }
1597 }
1598
1599 for b in 0..transpose.full_blocks() {
1601 let block_mut = transpose.block_mut(b);
1602 assert_eq!(block_mut.nrows(), expected_block_nrows, "{}", context);
1603 assert_eq!(block_mut.ncols(), expected_block_ncols, "{}", context);
1604 }
1605 if expected_remainder != 0 {
1606 let rem = transpose.remainder_block_mut().unwrap();
1607 assert_eq!(rem.nrows(), expected_block_nrows, "{}", context);
1608 assert_eq!(rem.ncols(), expected_block_ncols, "{}", context);
1609 } else {
1610 assert!(transpose.remainder_block_mut().is_none(), "{}", context);
1611 }
1612
1613 {
1616 let mut mut_view = transpose.as_view_mut();
1617 for row in 0..nrows {
1618 let row_view = mut_view.get_row_mut(row).unwrap();
1619 assert_eq!(row_view.len(), ncols, "{}", context);
1620 assert_eq!(row_view.is_empty(), ncols == 0, "{}", context);
1621 for col in 0..ncols {
1622 assert_eq!(data[(row, col)], row_view[col], "{}", context);
1623 }
1624 }
1625 assert!(mut_view.get_row_mut(nrows).is_none(), "{}", context);
1626 }
1627
1628 if nrows > 0 && ncols > 0 {
1631 {
1633 let view = transpose.as_view();
1634 let row = view.get_row(0).unwrap();
1635 assert_eq!(row.get(ncols), None, "{}", context);
1636 assert_eq!(row.get(usize::MAX), None, "{}", context);
1637 }
1638
1639 let row = transpose.get_row_mut(0).unwrap();
1641 assert_eq!(row.get(ncols), None, "{}", context);
1642
1643 let mut row = transpose.get_row_mut(0).unwrap();
1645 let sentinel = gen_element(usize::MAX / 2);
1646 let original = row[0];
1647 if let Some(v) = row.get_mut(0) {
1648 *v = sentinel;
1649 }
1650 assert_eq!(row.get_mut(ncols), None, "{}", context);
1651 let _ = row;
1653 assert_eq!(transpose.get_element(0, 0), sentinel, "{}", context);
1654 transpose.get_row_mut(0).unwrap().set(0, original);
1656 }
1657
1658 for b in 0..transpose.full_blocks() {
1661 transpose.block_mut(b).as_mut_slice().fill(T::default());
1662 }
1663 if transpose.remainder() != 0 {
1664 transpose
1665 .remainder_block_mut()
1666 .unwrap()
1667 .as_mut_slice()
1668 .fill(T::default());
1669 }
1670 assert!(
1671 transpose.as_slice().iter().all(|v| *v == T::default()),
1672 "not fully zeroed -- {}",
1673 context,
1674 );
1675
1676 let transpose = BlockTransposed::<T, GROUP, PACK>::from_strided(data.as_view().into());
1679 let raw = transpose.as_slice();
1680
1681 for row in 0..nrows {
1683 for col in ncols..expected_padded {
1684 let idx = linear_index::<GROUP, PACK>(row, col, ncols);
1685 assert_eq!(
1686 raw[idx],
1687 T::default(),
1688 "col padding at ({}, {}) -- {}",
1689 row,
1690 col,
1691 context,
1692 );
1693 }
1694 }
1695
1696 let padded_nrows = nrows.next_multiple_of(GROUP);
1698 for row in nrows..padded_nrows {
1699 for col in 0..expected_padded {
1700 let idx = linear_index::<GROUP, PACK>(row, col, ncols);
1701 assert_eq!(
1702 raw[idx],
1703 T::default(),
1704 "row padding at ({}, {}) -- {}",
1705 row,
1706 col,
1707 context,
1708 );
1709 }
1710 }
1711
1712 assert_eq!(
1715 transpose.as_view().padded_nrows(),
1716 padded_nrows,
1717 "padded_nrows() mismatch -- {}",
1718 context,
1719 );
1720
1721 if nrows > 0 && ncols > 0 {
1724 let via_matrix = BlockTransposed::<T, GROUP, PACK>::from_matrix_view(data.as_view());
1725 assert_eq!(via_matrix.as_slice(), transpose.as_slice(), "{}", context);
1726 }
1727 }
1728
1729 #[test]
1734 fn test_api_pack1_group16() {
1735 let rows: Vec<usize> = if cfg!(miri) {
1738 vec![0, 1, 15, 16, 17, 33]
1739 } else {
1740 (0..128).collect()
1741 };
1742 let cols: Vec<usize> = if cfg!(miri) {
1743 vec![0, 1, 2]
1744 } else {
1745 (0..5).collect()
1746 };
1747 for &nrows in &rows {
1748 for &ncols in &cols {
1749 test_full_api::<f32, 16, 1>(nrows, ncols, gen_f32);
1750 }
1751 }
1752 }
1753
1754 #[test]
1755 fn test_api_pack1_group8() {
1756 let rows: Vec<usize> = if cfg!(miri) {
1759 vec![0, 1, 7, 8, 9, 17]
1760 } else {
1761 (0..128).collect()
1762 };
1763 let cols: Vec<usize> = if cfg!(miri) {
1764 vec![0, 1, 2]
1765 } else {
1766 (0..5).collect()
1767 };
1768 for &nrows in &rows {
1769 for &ncols in &cols {
1770 test_full_api::<f32, 8, 1>(nrows, ncols, gen_f32);
1771 }
1772 }
1773 }
1774
1775 #[test]
1776 fn test_api_pack2() {
1777 let rows: Vec<usize> = if cfg!(miri) {
1780 vec![0, 1, 3, 4, 5, 7, 8, 9, 15, 16, 17]
1781 } else {
1782 (0..48).collect()
1783 };
1784 let cols: Vec<usize> = if cfg!(miri) {
1785 vec![0, 1, 2, 3, 4, 5]
1786 } else {
1787 (0..9).collect()
1788 };
1789 for &nrows in &rows {
1790 for &ncols in &cols {
1791 test_full_api::<f32, 4, 2>(nrows, ncols, gen_f32);
1792 test_full_api::<f32, 8, 2>(nrows, ncols, gen_f32);
1793 test_full_api::<f32, 16, 2>(nrows, ncols, gen_f32);
1794 }
1795 }
1796 }
1797
1798 #[test]
1799 fn test_api_pack4() {
1800 let rows: Vec<usize> = if cfg!(miri) {
1803 vec![0, 1, 3, 4, 5, 7, 8, 9, 15, 16, 17]
1804 } else {
1805 (0..48).collect()
1806 };
1807 let cols: Vec<usize> = if cfg!(miri) {
1808 vec![0, 1, 3, 4, 5, 8]
1809 } else {
1810 (0..9).collect()
1811 };
1812 for &nrows in &rows {
1813 for &ncols in &cols {
1814 test_full_api::<f32, 4, 4>(nrows, ncols, gen_f32);
1815 test_full_api::<f32, 8, 4>(nrows, ncols, gen_f32);
1816 test_full_api::<f32, 16, 4>(nrows, ncols, gen_f32);
1817 }
1818 }
1819 }
1820
1821 #[test]
1823 fn test_api_non_f32() {
1824 test_full_api::<i32, 4, 1>(10, 7, gen_i32);
1826 test_full_api::<i32, 8, 2>(12, 5, gen_i32);
1827
1828 test_full_api::<u8, 4, 2>(12, 5, gen_u8);
1830 test_full_api::<u8, 8, 1>(10, 7, gen_u8);
1831 }
1832
1833 fn test_block_layout_pack1<
1840 T: Copy + Default + PartialEq + std::fmt::Debug + 'static,
1841 const GROUP: usize,
1842 >(
1843 nrows: usize,
1844 ncols: usize,
1845 gen_element: fn(usize) -> T,
1846 ) {
1847 let mut data = Matrix::new(T::default(), nrows, ncols);
1848 data.as_mut_slice()
1849 .iter_mut()
1850 .enumerate()
1851 .for_each(|(i, d)| *d = gen_element(i));
1852
1853 let transpose = BlockTransposed::<T, GROUP, 1>::from_strided(data.as_view().into());
1854
1855 for b in 0..transpose.full_blocks() {
1857 let block = transpose.block(b);
1858 for i in 0..block.nrows() {
1859 for j in 0..block.ncols() {
1860 assert_eq!(
1861 block[(i, j)],
1862 data[(GROUP * b + j, i)],
1863 "block {} at ({}, {}) -- GROUP={}, nrows={}, ncols={}",
1864 b,
1865 i,
1866 j,
1867 GROUP,
1868 nrows,
1869 ncols,
1870 );
1871 }
1872 }
1873 }
1874
1875 if transpose.remainder() != 0 {
1877 let fb = transpose.full_blocks();
1878 let block = transpose.remainder_block().unwrap();
1879 for i in 0..block.nrows() {
1880 for j in 0..transpose.remainder() {
1881 assert_eq!(
1882 block[(i, j)],
1883 data[(GROUP * fb + j, i)],
1884 "remainder at ({}, {}) -- GROUP={}, nrows={}, ncols={}",
1885 i,
1886 j,
1887 GROUP,
1888 nrows,
1889 ncols,
1890 );
1891 }
1892 }
1893 }
1894 }
1895
1896 #[test]
1897 fn test_block_layout_pack1_group16() {
1898 let rows: Vec<usize> = if cfg!(miri) {
1899 vec![0, 1, 15, 16, 17, 33]
1900 } else {
1901 (0..128).collect()
1902 };
1903 let cols: Vec<usize> = if cfg!(miri) {
1904 vec![0, 1, 2]
1905 } else {
1906 (0..5).collect()
1907 };
1908 for &nrows in &rows {
1909 for &ncols in &cols {
1910 test_block_layout_pack1::<f32, 16>(nrows, ncols, gen_f32);
1911 }
1912 }
1913 }
1914
1915 #[test]
1916 fn test_block_layout_pack1_group8() {
1917 let rows: Vec<usize> = if cfg!(miri) {
1918 vec![0, 1, 7, 8, 9, 17]
1919 } else {
1920 (0..128).collect()
1921 };
1922 let cols: Vec<usize> = if cfg!(miri) {
1923 vec![0, 1, 2]
1924 } else {
1925 (0..5).collect()
1926 };
1927 for &nrows in &rows {
1928 for &ncols in &cols {
1929 test_block_layout_pack1::<f32, 8>(nrows, ncols, gen_f32);
1930 }
1931 }
1932 }
1933
1934 #[test]
1941 fn test_row_view_send_sync() {
1942 fn assert_send<T: Send>() {}
1943 fn assert_sync<T: Sync>() {}
1944
1945 assert_send::<Row<'_, f32, 16>>();
1946 assert_sync::<Row<'_, f32, 16>>();
1947 assert_send::<Row<'_, u8, 8, 2>>();
1948 assert_sync::<Row<'_, u8, 8, 2>>();
1949
1950 assert_send::<RowMut<'_, f32, 16>>();
1951 assert_sync::<RowMut<'_, f32, 16>>();
1952 assert_send::<RowMut<'_, i32, 4, 4>>();
1953 assert_sync::<RowMut<'_, i32, 4, 4>>();
1954 }
1955
1956 #[test]
1959 fn test_new_ref_and_new_mut() {
1960 let nrows = 5;
1961 let ncols = 3;
1962 let repr = BlockTransposedRepr::<f32, 4>::new(nrows, ncols).unwrap();
1963
1964 let mat = BlockTransposed::<f32, 4>::new(nrows, ncols);
1965 let raw: &[f32] = mat.as_slice();
1966
1967 let mat_ref = BlockTransposedRef {
1968 data: repr.new_ref(raw).unwrap(),
1969 };
1970 assert_eq!(mat_ref.nrows(), nrows);
1971 assert_eq!(mat_ref.ncols(), ncols);
1972 for row in 0..nrows {
1973 for col in 0..ncols {
1974 assert_eq!(mat_ref.get_element(row, col), mat.get_element(row, col));
1975 }
1976 }
1977
1978 let mut buf = raw.to_vec();
1979 let mat_mut = BlockTransposedMut {
1980 data: repr.new_mut(&mut buf).unwrap(),
1981 };
1982 assert_eq!(mat_mut.nrows(), nrows);
1983 assert_eq!(mat_mut.ncols(), ncols);
1984
1985 let mut short = vec![0.0_f32; 2];
1987 assert!(repr.new_ref(&short).is_err());
1988 assert!(repr.new_mut(&mut short).is_err());
1989 }
1990
1991 #[test]
1994 fn test_row_view_empty() {
1995 fn check_empty<const GROUP: usize, const PACK: usize>() {
1998 let mut mat = BlockTransposed::<f32, GROUP, PACK>::new(4, 0);
1999
2000 let view = mat.as_view();
2002 for i in 0..4 {
2003 let row = view.get_row(i).unwrap();
2004 assert!(row.is_empty());
2005 assert_eq!(row.len(), 0);
2006 assert_eq!(row.iter().count(), 0);
2007 }
2008
2009 for i in 0..4 {
2011 let row = mat.get_row_mut(i).unwrap();
2012 assert!(row.is_empty());
2013 assert_eq!(row.len(), 0);
2014 }
2015 }
2016
2017 check_empty::<16, 1>(); check_empty::<4, 2>(); check_empty::<4, 4>(); }
2021
2022 #[test]
2025 #[should_panic(expected = "column index 3 out of bounds")]
2026 fn test_row_view_index_oob() {
2027 let mat = BlockTransposed::<f32, 4>::new(4, 3);
2028 let view = mat.as_view();
2029 let row = view.get_row(0).unwrap();
2030 let _ = row[3];
2031 }
2032
2033 #[test]
2034 #[should_panic(expected = "column index 3 out of bounds")]
2035 fn test_row_view_mut_index_oob() {
2036 let mut mat = BlockTransposed::<f32, 4>::new(4, 3);
2037 let row = mat.get_row_mut(0).unwrap();
2038 let _ = row[3];
2039 }
2040
2041 #[test]
2042 #[should_panic(expected = "column index 3 out of bounds")]
2043 fn test_row_view_mut_index_mut_oob() {
2044 let mut mat = BlockTransposed::<f32, 4>::new(4, 3);
2045 let mut row = mat.get_row_mut(0).unwrap();
2046 row[3] = 1.0;
2047 }
2048
2049 #[test]
2050 #[should_panic(expected = "column index 3 out of bounds")]
2051 fn test_row_view_set_oob() {
2052 let mut mat = BlockTransposed::<f32, 4>::new(4, 3);
2053 let mut row = mat.get_row_mut(0).unwrap();
2054 row.set(3, 1.0);
2055 }
2056
2057 #[test]
2058 #[should_panic(expected = "row 4 out of bounds")]
2059 fn test_get_element_row_oob() {
2060 let mat = BlockTransposed::<f32, 4>::new(4, 3);
2061 mat.get_element(4, 0);
2062 }
2063
2064 #[test]
2065 #[should_panic(expected = "col 3 out of bounds")]
2066 fn test_get_element_col_oob() {
2067 let mat = BlockTransposed::<f32, 4>::new(4, 3);
2068 mat.get_element(0, 3);
2069 }
2070
2071 #[test]
2072 #[should_panic(expected = "assertion failed")]
2073 fn test_index_tuple_row_oob() {
2074 let mat = BlockTransposed::<f32, 4>::new(4, 3);
2075 let _ = mat[(4, 0)];
2076 }
2077
2078 #[test]
2079 #[should_panic(expected = "assertion failed")]
2080 fn test_index_tuple_col_oob() {
2081 let mat = BlockTransposed::<f32, 4>::new(4, 3);
2082 let _ = mat[(0, 3)];
2083 }
2084
2085 #[test]
2086 #[should_panic]
2087 fn test_block_oob() {
2088 let mat = BlockTransposed::<f32, 4>::new(4, 3);
2089 let _ = mat.block(1);
2090 }
2091
2092 #[test]
2093 #[should_panic]
2094 fn test_block_mut_oob() {
2095 let mut mat = BlockTransposed::<f32, 4>::new(4, 3);
2096 let _ = mat.block_mut(1);
2097 }
2098
2099 #[test]
2102 fn test_from_strided_nonunit_stride() {
2103 use diskann_utils::strided::StridedView;
2104
2105 const GROUP: usize = 4;
2106 const PACK: usize = 2;
2107 let nrows = 5;
2108 let ncols = 3;
2109 let cstride = 8;
2110
2111 let required_len = (nrows - 1) * cstride + ncols;
2112 let mut flat = vec![0.0_f32; required_len];
2113 for row in 0..nrows {
2114 for col in 0..ncols {
2115 flat[row * cstride + col] = (row * 100 + col + 1) as f32;
2116 }
2117 }
2118
2119 let strided = StridedView::try_shrink_from(&flat, nrows, ncols, cstride)
2120 .expect("should construct strided view");
2121 let transpose = BlockTransposed::<f32, GROUP, PACK>::from_strided(strided);
2122
2123 assert_eq!(transpose.nrows(), nrows);
2124 assert_eq!(transpose.ncols(), ncols);
2125
2126 for row in 0..nrows {
2127 for col in 0..ncols {
2128 let expected = (row * 100 + col + 1) as f32;
2129 assert_eq!(
2130 transpose[(row, col)],
2131 expected,
2132 "mismatch at ({}, {})",
2133 row,
2134 col,
2135 );
2136 }
2137 }
2138
2139 let padded_ncols = ncols.next_multiple_of(PACK);
2140 let raw: &[f32] = transpose.as_slice();
2141 for row in 0..nrows {
2142 for col in ncols..padded_ncols {
2143 let idx = linear_index::<GROUP, PACK>(row, col, ncols);
2144 assert_eq!(
2145 raw[idx], 0.0,
2146 "column-padding at ({}, {}) should be zero",
2147 row, col,
2148 );
2149 }
2150 }
2151 }
2152
2153 #[test]
2156 fn test_concurrent_row_mutation() {
2157 const GROUP: usize = 8;
2158 const PACK: usize = 2;
2159
2160 let (nrows, ncols, num_threads) = if cfg!(miri) { (8, 4, 2) } else { (64, 16, 4) };
2161
2162 let mut mat = BlockTransposed::<f32, GROUP, PACK>::new(nrows, ncols);
2163 let rows: Vec<RowMut<'_, f32, GROUP, PACK>> = mat.data.rows_mut().collect();
2164 let rows_per_thread = nrows / num_threads;
2165 let mut rows = rows.into_boxed_slice();
2166
2167 std::thread::scope(|s| {
2168 let mut remaining = &mut rows[..];
2169 for thread_id in 0..num_threads {
2170 let chunk_len = if thread_id == num_threads - 1 {
2171 remaining.len()
2172 } else {
2173 rows_per_thread
2174 };
2175 let (chunk, rest) = remaining.split_at_mut(chunk_len);
2176 remaining = rest;
2177 let start_row = thread_id * rows_per_thread;
2178
2179 s.spawn(move || {
2180 for (offset, row_view) in chunk.iter_mut().enumerate() {
2181 let row = start_row + offset;
2182 for col in 0..ncols {
2183 let value = (thread_id * 10000 + row * 100 + col) as f32;
2184 row_view.set(col, value);
2185 }
2186 }
2187 });
2188 }
2189 });
2190
2191 for row in 0..nrows {
2192 let thread_id = (row / rows_per_thread).min(num_threads - 1);
2193 for col in 0..ncols {
2194 let expected = (thread_id * 10000 + row * 100 + col) as f32;
2195 assert_eq!(
2196 mat.get_element(row, col),
2197 expected,
2198 "mismatch at ({}, {})",
2199 row,
2200 col,
2201 );
2202 }
2203 }
2204 }
2205}