1use std::{alloc::Layout, marker::PhantomData, ptr::NonNull};
81
82use diskann_utils::{
83 ReborrowMut,
84 strided::StridedView,
85 views::{MatrixView, MutMatrixView},
86};
87
88use super::matrix::{
89 Defaulted, LayoutError, Mat, MatMut, MatRef, NewMut, NewOwned, NewRef, Overflow, Repr, ReprMut,
90 ReprOwned, SliceError,
91};
92use crate::bits::{AsMutPtr, AsPtr, MutSlicePtr, SlicePtr};
93use crate::utils;
94
/// Rounds `ncols` up to the next multiple of the packing factor `PACK`.
///
/// Panics on overflow (see `usize::next_multiple_of`); callers needing a
/// fallible computation go through `checked_compute_capacity` instead.
#[inline]
fn padded_ncols<const PACK: usize>(ncols: usize) -> usize {
    ncols.next_multiple_of(PACK)
}
100
/// Number of elements the padded storage must hold: rows rounded up to a
/// multiple of `GROUP` times columns rounded up to a multiple of `PACK`.
///
/// Panics on overflow; `checked_compute_capacity` is the fallible variant.
#[inline]
fn compute_capacity<const GROUP: usize, const PACK: usize>(nrows: usize, ncols: usize) -> usize {
    let padded_rows = nrows.next_multiple_of(GROUP);
    let padded_cols = ncols.next_multiple_of(PACK);
    padded_rows * padded_cols
}
114
/// Overflow-checked variant of `compute_capacity`.
///
/// Returns `None` when either padded dimension or their product would not
/// fit in a `usize`.
#[inline]
fn checked_compute_capacity<const GROUP: usize, const PACK: usize>(
    nrows: usize,
    ncols: usize,
) -> Option<usize> {
    let padded_rows = nrows.checked_next_multiple_of(GROUP)?;
    let padded_cols = ncols.checked_next_multiple_of(PACK)?;
    padded_rows.checked_mul(padded_cols)
}
127
/// Maps logical `(row, col)` coordinates to the element's linear offset in
/// block-transposed storage.
///
/// Layout order: block of `GROUP` rows, then column group of `PACK`
/// columns, then row within the block, then column within the pack.
#[inline]
fn linear_index<const GROUP: usize, const PACK: usize>(
    row: usize,
    col: usize,
    ncols: usize,
) -> usize {
    let pncols = ncols.next_multiple_of(PACK);
    let (block, row_in_block) = (row / GROUP, row % GROUP);
    let (col_group, col_in_pack) = (col / PACK, col % PACK);
    block * GROUP * pncols + col_group * GROUP * PACK + row_in_block * PACK + col_in_pack
}
141
/// Offset of column `col` relative to a row's base element (i.e.
/// `linear_index(row, col, _) - linear_index(row, 0, _)`), which is
/// independent of the row.
#[inline]
fn col_offset<const GROUP: usize, const PACK: usize>(col: usize) -> usize {
    let col_group = col / PACK;
    let col_in_pack = col % PACK;
    col_group * (GROUP * PACK) + col_in_pack
}
150
/// Layout descriptor for a block-transposed matrix: rows are grouped into
/// blocks of `GROUP`, and columns are stored in packs of `PACK` contiguous
/// elements. Holds only the logical dimensions; element storage is managed
/// separately by `Mat` / `MatRef` / `MatMut`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct BlockTransposedRepr<T, const GROUP: usize, const PACK: usize = 1> {
    // Logical (unpadded) number of rows.
    nrows: usize,
    // Logical (unpadded) number of columns.
    ncols: usize,
    // Ties the representation to element type `T` without storing one.
    _elem: PhantomData<T>,
}
161
impl<T: Copy, const GROUP: usize, const PACK: usize> BlockTransposedRepr<T, GROUP, PACK> {
    // Compile-time validation of the const parameters; referenced from `new`
    // so the assertions fire when the type is instantiated.
    const _ASSERTIONS: () = {
        assert!(GROUP > 0, "group size GROUP must be positive");
        assert!(PACK > 0, "packing factor PACK must be positive");
        assert!(
            GROUP.is_multiple_of(PACK),
            "GROUP must be divisible by PACK"
        );
    };

    /// Creates a representation for an `nrows x ncols` matrix, validating
    /// that the padded capacity fits in `usize` and within the byte budget.
    pub fn new(nrows: usize, ncols: usize) -> Result<Self, Overflow> {
        // Force evaluation of the const assertions for this instantiation.
        let () = Self::_ASSERTIONS;
        let capacity = checked_compute_capacity::<GROUP, PACK>(nrows, ncols)
            .ok_or_else(|| Overflow::for_type::<T>(nrows, ncols))?;
        Overflow::check_byte_budget::<T>(capacity, nrows, ncols)?;
        Ok(Self {
            nrows,
            ncols,
            _elem: PhantomData,
        })
    }

    /// Total number of `T` elements in the (padded) backing storage.
    ///
    /// Unchecked arithmetic is acceptable here: `new` already proved the
    /// padded capacity does not overflow for these dimensions.
    #[inline]
    fn storage_len(&self) -> usize {
        compute_capacity::<GROUP, PACK>(self.nrows, self.ncols)
    }

    /// Logical number of rows.
    #[inline]
    fn nrows(&self) -> usize {
        self.nrows
    }

    /// Logical (unpadded) number of columns.
    #[inline]
    pub fn ncols(&self) -> usize {
        self.ncols
    }

    /// Column count rounded up to a multiple of `PACK`.
    #[inline]
    pub fn padded_ncols(&self) -> usize {
        padded_ncols::<PACK>(self.ncols)
    }

    /// Number of blocks containing exactly `GROUP` valid rows.
    #[inline]
    pub fn full_blocks(&self) -> usize {
        self.nrows / GROUP
    }

    /// Total number of blocks, including a trailing partial block, if any.
    #[inline]
    pub fn num_blocks(&self) -> usize {
        self.nrows.div_ceil(GROUP)
    }

    /// Number of valid rows in the trailing partial block (0 when `nrows`
    /// is a multiple of `GROUP`).
    #[inline]
    pub fn remainder(&self) -> usize {
        self.nrows % GROUP
    }

    /// Elements per block: `GROUP` rows of `padded_ncols()` elements each.
    #[inline]
    fn block_stride(&self) -> usize {
        GROUP * self.padded_ncols()
    }

    /// Element offset of the start of `block` within the storage.
    #[inline]
    fn block_offset(&self, block: usize) -> usize {
        block * self.block_stride()
    }

    /// Validates that a user-provided slice holds exactly `storage_len()`
    /// elements.
    fn check_slice(&self, slice: &[T]) -> Result<(), SliceError> {
        let cap = self.storage_len();
        if slice.len() != cap {
            Err(SliceError::LengthMismatch {
                expected: cap,
                found: slice.len(),
            })
        } else {
            Ok(())
        }
    }

    /// Converts owned boxed storage into a `Mat`.
    ///
    /// # Safety
    ///
    /// `b.len()` must equal `self.storage_len()`.
    unsafe fn box_to_mat(self, b: Box<[T]>) -> Mat<Self> {
        debug_assert_eq!(b.len(), self.storage_len(), "safety contract violated");

        let ptr = utils::box_into_nonnull(b).cast::<u8>();

        // SAFETY: `ptr` owns exactly the storage described by `self`.
        unsafe { Mat::from_raw_parts(self, ptr) }
    }
}
273
/// Immutable view of a single logical row of a block-transposed matrix.
///
/// Elements are not contiguous: element `col` lives at
/// `base + col_offset::<GROUP, PACK>(col)`.
#[derive(Debug, Clone, Copy)]
pub struct Row<'a, T, const GROUP: usize, const PACK: usize = 1> {
    // Pointer to this row's first element within its block.
    base: SlicePtr<'a, T>,
    // Number of accessible (unpadded) columns in the row.
    ncols: usize,
}
288
impl<T: Copy, const GROUP: usize, const PACK: usize> Row<'_, T, GROUP, PACK> {
    /// Number of accessible (unpadded) columns in this row.
    #[inline]
    pub fn len(&self) -> usize {
        self.ncols
    }

    /// Returns `true` when the row has no accessible columns.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.ncols == 0
    }

    /// Returns a reference to the element at `col`, or `None` when out of
    /// bounds.
    #[inline]
    pub fn get(&self, col: usize) -> Option<&T> {
        if col < self.ncols {
            // SAFETY: `col < ncols`, so the strided offset stays inside the
            // storage this row was created from.
            Some(unsafe { &*self.base.as_ptr().add(col_offset::<GROUP, PACK>(col)) })
        } else {
            None
        }
    }

    /// Iterates over the row's elements by value.
    #[inline]
    pub fn iter(&self) -> RowIter<'_, T, GROUP, PACK> {
        RowIter {
            base: self.base,
            col: 0,
            ncols: self.ncols,
        }
    }
}
323
324impl<T: Copy, const GROUP: usize, const PACK: usize> std::ops::Index<usize>
325 for Row<'_, T, GROUP, PACK>
326{
327 type Output = T;
328
329 #[inline]
330 #[allow(clippy::panic)] fn index(&self, col: usize) -> &Self::Output {
332 self.get(col)
333 .unwrap_or_else(|| panic!("column index {col} out of bounds (ncols = {})", self.ncols))
334 }
335}
336
/// Iterator over the elements of a [`Row`], yielding values by copy.
#[derive(Debug, Clone)]
pub struct RowIter<'a, T, const GROUP: usize, const PACK: usize = 1> {
    // Pointer to the row's first element.
    base: SlicePtr<'a, T>,
    // Next column to yield.
    col: usize,
    // Total number of accessible columns.
    ncols: usize,
}
344
impl<T: Copy, const GROUP: usize, const PACK: usize> Iterator for RowIter<'_, T, GROUP, PACK> {
    type Item = T;

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        if self.col >= self.ncols {
            return None;
        }
        // SAFETY: `col < ncols` here, so the strided offset is in bounds for
        // the storage the row view was created from.
        let val = unsafe { *self.base.as_ptr().add(col_offset::<GROUP, PACK>(self.col)) };
        self.col += 1;
        Some(val)
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        // `col` never exceeds `ncols` (see `next`), so this cannot underflow.
        let remaining = self.ncols - self.col;
        (remaining, Some(remaining))
    }
}
365
// `size_hint` reports the exact remaining count, so the exact-size contract
// holds.
impl<T: Copy, const GROUP: usize, const PACK: usize> ExactSizeIterator
    for RowIter<'_, T, GROUP, PACK>
{
}
// Once `next` returns `None`, `col >= ncols` remains true, so the iterator
// keeps returning `None` — it is fused.
impl<T: Copy, const GROUP: usize, const PACK: usize> std::iter::FusedIterator
    for RowIter<'_, T, GROUP, PACK>
{
}
374
/// Mutable view of a single logical row of a block-transposed matrix.
///
/// Elements are not contiguous: element `col` lives at
/// `base + col_offset::<GROUP, PACK>(col)`.
#[derive(Debug)]
pub struct RowMut<'a, T, const GROUP: usize, const PACK: usize = 1> {
    // Pointer to this row's first element within its block.
    base: MutSlicePtr<'a, T>,
    // Number of accessible (unpadded) columns in the row.
    ncols: usize,
}
381
382impl<T: Copy, const GROUP: usize, const PACK: usize> RowMut<'_, T, GROUP, PACK> {
383 #[inline]
385 pub fn len(&self) -> usize {
386 self.ncols
387 }
388
389 #[inline]
391 pub fn is_empty(&self) -> bool {
392 self.ncols == 0
393 }
394
395 #[inline]
397 pub fn get(&self, col: usize) -> Option<&T> {
398 if col < self.ncols {
399 Some(unsafe { &*self.base.as_ptr().add(col_offset::<GROUP, PACK>(col)) })
401 } else {
402 None
403 }
404 }
405
406 #[inline]
408 pub fn get_mut(&mut self, col: usize) -> Option<&mut T> {
409 if col < self.ncols {
410 Some(unsafe { &mut *self.base.as_mut_ptr().add(col_offset::<GROUP, PACK>(col)) })
412 } else {
413 None
414 }
415 }
416
417 #[inline]
423 pub fn set(&mut self, col: usize, value: T) {
424 assert!(
425 col < self.ncols,
426 "column index {col} out of bounds (ncols = {})",
427 self.ncols
428 );
429 unsafe { *self.base.as_mut_ptr().add(col_offset::<GROUP, PACK>(col)) = value };
431 }
432}
433
434impl<T: Copy, const GROUP: usize, const PACK: usize> std::ops::Index<usize>
435 for RowMut<'_, T, GROUP, PACK>
436{
437 type Output = T;
438
439 #[inline]
440 #[allow(clippy::panic)] fn index(&self, col: usize) -> &Self::Output {
442 self.get(col)
443 .unwrap_or_else(|| panic!("column index {col} out of bounds (ncols = {})", self.ncols))
444 }
445}
446
447impl<T: Copy, const GROUP: usize, const PACK: usize> std::ops::IndexMut<usize>
448 for RowMut<'_, T, GROUP, PACK>
449{
450 #[inline]
451 #[allow(clippy::panic)] fn index_mut(&mut self, col: usize) -> &mut Self::Output {
453 let ncols = self.ncols;
454 self.get_mut(col)
455 .unwrap_or_else(|| panic!("column index {col} out of bounds (ncols = {ncols})"))
456 }
457}
458
// SAFETY: `layout()` sizes the allocation to `storage_len()` elements, and
// `get_row` only hands out pointers derived from in-bounds linear indices.
unsafe impl<T: Copy, const GROUP: usize, const PACK: usize> Repr
    for BlockTransposedRepr<T, GROUP, PACK>
{
    type Row<'a>
        = Row<'a, T, GROUP, PACK>
    where
        Self: 'a;

    fn nrows(&self) -> usize {
        self.nrows
    }

    /// Memory layout for the full (padded) element storage.
    fn layout(&self) -> Result<Layout, LayoutError> {
        Ok(Layout::array::<T>(self.storage_len())?)
    }

    unsafe fn get_row<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::Row<'a> {
        debug_assert!(i < self.nrows);

        if self.ncols == 0 {
            // An empty row never dereferences `base`, so a dangling pointer
            // is sufficient.
            return Row {
                base: unsafe { SlicePtr::new_unchecked(NonNull::dangling()) },
                ncols: 0,
            };
        }

        let base_ptr = ptr.as_ptr().cast::<T>();
        // Offset of this row's first element (column 0) within the storage.
        let offset = linear_index::<GROUP, PACK>(i, 0, self.ncols);

        // SAFETY (caller contract): `ptr` points at `storage_len()` elements
        // and `i < nrows`, so `offset` is in bounds.
        let row_base = unsafe { base_ptr.add(offset) };

        Row {
            base: unsafe { SlicePtr::new_unchecked(NonNull::new_unchecked(row_base)) },
            ncols: self.ncols,
        }
    }
}
510
// SAFETY: same layout reasoning as the `Repr` impl above; `get_row_mut` only
// hands out a pointer derived from an in-bounds linear index.
unsafe impl<T: Copy, const GROUP: usize, const PACK: usize> ReprMut
    for BlockTransposedRepr<T, GROUP, PACK>
{
    type RowMut<'a>
        = RowMut<'a, T, GROUP, PACK>
    where
        Self: 'a;

    unsafe fn get_row_mut<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::RowMut<'a> {
        debug_assert!(i < self.nrows);

        if self.ncols == 0 {
            // An empty row never dereferences `base`, so a dangling pointer
            // is sufficient.
            return RowMut {
                base: unsafe { MutSlicePtr::new_unchecked(NonNull::dangling()) },
                ncols: 0,
            };
        }

        let base_ptr = ptr.as_ptr().cast::<T>();
        // Offset of this row's first element (column 0) within the storage.
        let offset = linear_index::<GROUP, PACK>(i, 0, self.ncols);

        // SAFETY (caller contract): `ptr` points at `storage_len()` elements
        // and `i < nrows`, so `offset` is in bounds.
        let row_base = unsafe { base_ptr.add(offset) };

        RowMut {
            base: unsafe { MutSlicePtr::new_unchecked(NonNull::new_unchecked(row_base)) },
            ncols: self.ncols,
        }
    }
}
551
unsafe impl<T: Copy, const GROUP: usize, const PACK: usize> ReprOwned
    for BlockTransposedRepr<T, GROUP, PACK>
{
    /// Reconstitutes and drops the boxed storage previously produced by
    /// `box_to_mat`.
    unsafe fn drop(self, ptr: NonNull<u8>) {
        unsafe {
            // SAFETY (caller contract): `ptr` originated from a `Box<[T]>`
            // of exactly `storage_len()` elements, so rebuilding the box
            // here to drop it is sound.
            let slice_ptr =
                std::ptr::slice_from_raw_parts_mut(ptr.cast::<T>().as_ptr(), self.storage_len());
            let _ = Box::from_raw(slice_ptr);
        }
    }
}
566
/// Allocates owned storage with every element (including padding) set to
/// `value`.
unsafe impl<T: Copy, const GROUP: usize, const PACK: usize> NewOwned<T>
    for BlockTransposedRepr<T, GROUP, PACK>
{
    type Error = crate::error::Infallible;

    fn new_owned(self, value: T) -> Result<Mat<Self>, Self::Error> {
        let b: Box<[T]> = vec![value; self.storage_len()].into_boxed_slice();

        // SAFETY: the box was sized to exactly `storage_len()` elements, as
        // `box_to_mat` requires.
        Ok(unsafe { self.box_to_mat(b) })
    }
}
584
/// Allocates owned storage with every element set to `T::default()`.
unsafe impl<T: Copy + Default, const GROUP: usize, const PACK: usize> NewOwned<Defaulted>
    for BlockTransposedRepr<T, GROUP, PACK>
{
    type Error = crate::error::Infallible;

    fn new_owned(self, _: Defaulted) -> Result<Mat<Self>, Self::Error> {
        // Delegate to the value-filling constructor.
        self.new_owned(T::default())
    }
}
595
unsafe impl<T: Copy, const GROUP: usize, const PACK: usize> NewRef<T>
    for BlockTransposedRepr<T, GROUP, PACK>
{
    type Error = SliceError;

    /// Borrows `data` as a block-transposed matrix; `data.len()` must equal
    /// the padded `storage_len()`.
    fn new_ref(self, data: &[T]) -> Result<MatRef<'_, Self>, Self::Error> {
        self.check_slice(data)?;

        // SAFETY: the slice length was just validated against `storage_len()`.
        Ok(unsafe { MatRef::from_raw_parts(self, utils::as_nonnull(data).cast::<u8>()) })
    }
}
609
unsafe impl<T: Copy, const GROUP: usize, const PACK: usize> NewMut<T>
    for BlockTransposedRepr<T, GROUP, PACK>
{
    type Error = SliceError;

    /// Mutably borrows `data` as a block-transposed matrix; `data.len()`
    /// must equal the padded `storage_len()`.
    fn new_mut(self, data: &mut [T]) -> Result<MatMut<'_, Self>, Self::Error> {
        self.check_slice(data)?;

        // SAFETY: the slice length was just validated against `storage_len()`.
        Ok(unsafe { MatMut::from_raw_parts(self, utils::as_nonnull_mut(data).cast::<u8>()) })
    }
}
623
// Generates inherent methods on the owning/mutable wrappers that simply
// forward to the method of the same name on `BlockTransposedRef`, obtained
// via `as_view()`.
macro_rules! delegate_to_ref {
    // Safe forwarding methods.
    ($(#[$m:meta])* $vis:vis fn $name:ident(&self $(, $a:ident: $t:ty)*) $(-> $r:ty)?) => {
        #[doc = concat!("See [`BlockTransposedRef::", stringify!($name), "`].")]
        $(#[$m])*
        #[inline]
        $vis fn $name(&self $(, $a: $t)*) $(-> $r)? {
            self.as_view().$name($($a),*)
        }
    };
    // Unsafe forwarding methods; the caller's obligations are those of the
    // wrapped `BlockTransposedRef` method.
    ($(#[$m:meta])* unsafe $vis:vis fn $name:ident(&self $(, $a:ident: $t:ty)*) $(-> $r:ty)?) => {
        #[doc = concat!("See [`BlockTransposedRef::", stringify!($name), "`].")]
        $(#[$m])*
        #[inline]
        $vis unsafe fn $name(&self $(, $a: $t)*) $(-> $r)? {
            unsafe { self.as_view().$name($($a),*) }
        }
    };
}
653
/// Owning block-transposed matrix: rows are grouped into blocks of `GROUP`
/// and columns into packs of `PACK`. See the free functions at the top of
/// this module for the exact linearization.
#[derive(Debug)]
pub struct BlockTransposed<T: Copy, const GROUP: usize, const PACK: usize = 1> {
    data: Mat<BlockTransposedRepr<T, GROUP, PACK>>,
}
675
/// Borrowed, immutable view of a [`BlockTransposed`] matrix.
#[derive(Debug, Clone, Copy)]
pub struct BlockTransposedRef<'a, T: Copy, const GROUP: usize, const PACK: usize = 1> {
    data: MatRef<'a, BlockTransposedRepr<T, GROUP, PACK>>,
}
683
/// Borrowed, mutable view of a [`BlockTransposed`] matrix.
pub struct BlockTransposedMut<'a, T: Copy, const GROUP: usize, const PACK: usize = 1> {
    data: MatMut<'a, BlockTransposedRepr<T, GROUP, PACK>>,
}
690
impl<'a, T: Copy, const GROUP: usize, const PACK: usize> BlockTransposedRef<'a, T, GROUP, PACK> {
    /// Logical number of rows.
    #[inline]
    pub fn nrows(&self) -> usize {
        self.data.repr().nrows()
    }

    /// Logical (unpadded) number of columns.
    #[inline]
    pub fn ncols(&self) -> usize {
        self.data.repr().ncols()
    }

    /// Column count rounded up to a multiple of `PACK`.
    #[inline]
    pub fn padded_ncols(&self) -> usize {
        self.data.repr().padded_ncols()
    }

    /// Row-group size (the `GROUP` const parameter).
    pub const fn group_size(&self) -> usize {
        GROUP
    }

    /// `GROUP` without needing an instance.
    pub const fn const_group_size() -> usize {
        GROUP
    }

    /// Column packing factor (the `PACK` const parameter).
    pub const fn pack_size(&self) -> usize {
        PACK
    }

    /// Number of blocks containing exactly `GROUP` valid rows.
    #[inline]
    pub fn full_blocks(&self) -> usize {
        self.data.repr().full_blocks()
    }

    /// Total number of blocks, including a trailing partial block, if any.
    #[inline]
    pub fn num_blocks(&self) -> usize {
        self.data.repr().num_blocks()
    }

    /// Number of valid rows in the trailing partial block (0 when `nrows`
    /// is a multiple of `GROUP`).
    #[inline]
    pub fn remainder(&self) -> usize {
        self.data.repr().remainder()
    }

    /// Pointer to the first element of the underlying storage.
    #[inline]
    pub fn as_ptr(&self) -> *const T {
        self.data.as_raw_ptr().cast::<T>()
    }

    /// The entire backing storage, including padding elements.
    #[inline]
    pub fn as_slice(&self) -> &'a [T] {
        let len = self.data.repr().storage_len();
        // SAFETY: the view borrows `storage_len()` contiguous elements for 'a.
        unsafe { std::slice::from_raw_parts(self.as_ptr(), len) }
    }

    /// Pointer to the start of block `block` without bounds checking.
    ///
    /// # Safety
    ///
    /// `block` must be less than `num_blocks()`.
    #[inline]
    pub unsafe fn block_ptr_unchecked(&self, block: usize) -> *const T {
        debug_assert!(block < self.num_blocks());
        unsafe { self.as_ptr().add(self.data.repr().block_offset(block)) }
    }

    /// Views full block `block` as a `(padded_ncols / PACK) x (GROUP * PACK)`
    /// matrix.
    ///
    /// # Panics
    ///
    /// Panics when `block >= full_blocks()`.
    #[allow(clippy::expect_used)]
    pub fn block(&self, block: usize) -> MatrixView<'a, T> {
        assert!(block < self.full_blocks());
        let offset = self.data.repr().block_offset(block);
        let stride = self.data.repr().block_stride();
        // SAFETY: `offset + stride <= storage_len()` for any in-bounds block.
        let data: &[T] = unsafe { std::slice::from_raw_parts(self.as_ptr().add(offset), stride) };
        MatrixView::try_from(data, self.padded_ncols() / PACK, GROUP * PACK)
            .expect("base data should have been sized correctly")
    }

    /// Views the trailing partial block, or `None` when `remainder() == 0`.
    ///
    /// The returned view has the same shape as a full block; entries beyond
    /// the `remainder()` valid rows hold padding.
    #[allow(clippy::expect_used)]
    pub fn remainder_block(&self) -> Option<MatrixView<'a, T>> {
        if self.remainder() == 0 {
            None
        } else {
            let offset = self.data.repr().block_offset(self.full_blocks());
            let stride = self.data.repr().block_stride();
            // SAFETY: even a partial block is allocated at full
            // `block_stride()` size (`storage_len` rounds rows up to GROUP).
            let data: &[T] =
                unsafe { std::slice::from_raw_parts(self.as_ptr().add(offset), stride) };
            Some(
                MatrixView::try_from(data, self.padded_ncols() / PACK, GROUP * PACK)
                    .expect("base data should have been sized correctly"),
            )
        }
    }

    /// Copies out element `(row, col)`.
    ///
    /// # Panics
    ///
    /// Panics when `row >= nrows()` or `col >= ncols()`.
    #[inline]
    pub fn get_element(&self, row: usize, col: usize) -> T {
        assert!(
            row < self.nrows(),
            "row {row} out of bounds (nrows = {})",
            self.nrows()
        );
        assert!(
            col < self.ncols(),
            "col {col} out of bounds (ncols = {})",
            self.ncols()
        );
        let idx = linear_index::<GROUP, PACK>(row, col, self.ncols());
        // SAFETY: both indices were bounds-checked above.
        unsafe { *self.as_ptr().add(idx) }
    }

    /// Returns a view of row `i`, or `None` when `i >= nrows()`.
    #[inline]
    pub fn get_row(&self, i: usize) -> Option<Row<'_, T, GROUP, PACK>> {
        self.data.get_row(i)
    }
}
853
impl<'a, T: Copy, const GROUP: usize, const PACK: usize> BlockTransposedMut<'a, T, GROUP, PACK> {
    /// Reborrows as an immutable view.
    #[inline]
    pub fn as_view(&self) -> BlockTransposedRef<'_, T, GROUP, PACK> {
        BlockTransposedRef {
            data: self.data.as_view(),
        }
    }

    // Read-only accessors forwarded to `BlockTransposedRef`.
    delegate_to_ref!(pub fn nrows(&self) -> usize);
    delegate_to_ref!(pub fn ncols(&self) -> usize);
    delegate_to_ref!(pub fn padded_ncols(&self) -> usize);
    delegate_to_ref!(pub fn full_blocks(&self) -> usize);
    delegate_to_ref!(pub fn num_blocks(&self) -> usize);
    delegate_to_ref!(pub fn remainder(&self) -> usize);
    delegate_to_ref!(pub fn as_ptr(&self) -> *const T);
    delegate_to_ref!(pub fn as_slice(&self) -> &[T]);
    delegate_to_ref!(#[allow(clippy::missing_safety_doc)] unsafe pub fn block_ptr_unchecked(&self, block: usize) -> *const T);
    delegate_to_ref!(#[allow(clippy::expect_used)] pub fn block(&self, block: usize) -> MatrixView<'_, T>);
    delegate_to_ref!(#[allow(clippy::expect_used)] pub fn remainder_block(&self) -> Option<MatrixView<'_, T>>);
    delegate_to_ref!(pub fn get_element(&self, row: usize, col: usize) -> T);

    /// Row-group size (the `GROUP` const parameter).
    pub const fn group_size(&self) -> usize {
        GROUP
    }

    /// `GROUP` without needing an instance.
    pub const fn const_group_size() -> usize {
        GROUP
    }

    /// Column packing factor (the `PACK` const parameter).
    pub const fn pack_size(&self) -> usize {
        PACK
    }

    /// Returns an immutable view of row `i`, or `None` when `i >= nrows()`.
    #[inline]
    pub fn get_row(&self, i: usize) -> Option<Row<'_, T, GROUP, PACK>> {
        self.data.get_row(i)
    }

    /// The entire backing storage (including padding), mutably.
    #[inline]
    pub fn as_mut_slice(&mut self) -> &mut [T] {
        self.reborrow_mut().mut_slice_inner()
    }

    // Consuming helper so the returned slice can carry the full lifetime 'a.
    fn mut_slice_inner(mut self) -> &'a mut [T] {
        let len = self.data.repr().storage_len();
        // SAFETY: the view mutably borrows `storage_len()` contiguous
        // elements for 'a, and consuming `self` forfeits any other access.
        unsafe { std::slice::from_raw_parts_mut(self.data.as_raw_mut_ptr().cast::<T>(), len) }
    }

    /// Mutable view of full block `block`.
    ///
    /// # Panics
    ///
    /// Panics when `block >= full_blocks()`.
    #[allow(clippy::expect_used)]
    pub fn block_mut(&mut self, block: usize) -> MutMatrixView<'_, T> {
        self.reborrow_mut().block_mut_inner(block)
    }

    // Consuming implementation of `block_mut`, carrying lifetime 'a.
    #[allow(clippy::expect_used)]
    fn block_mut_inner(mut self, block: usize) -> MutMatrixView<'a, T> {
        let repr = *self.data.repr();
        assert!(block < repr.full_blocks());
        let offset = repr.block_offset(block);
        let stride = repr.block_stride();
        let pncols = repr.padded_ncols();
        // SAFETY: `offset + stride <= storage_len()` for any in-bounds block,
        // and consuming `self` forfeits any other access.
        let data: &mut [T] = unsafe {
            std::slice::from_raw_parts_mut(
                self.data.as_raw_mut_ptr().cast::<T>().add(offset),
                stride,
            )
        };
        MutMatrixView::try_from(data, pncols / PACK, GROUP * PACK)
            .expect("base data should have been sized correctly")
    }

    /// Mutable view of the trailing partial block, or `None` when
    /// `remainder() == 0`.
    #[allow(clippy::expect_used)]
    pub fn remainder_block_mut(&mut self) -> Option<MutMatrixView<'_, T>> {
        self.reborrow_mut().remainder_block_mut_inner()
    }

    // Consuming implementation of `remainder_block_mut`, carrying lifetime 'a.
    #[allow(clippy::expect_used)]
    fn remainder_block_mut_inner(mut self) -> Option<MutMatrixView<'a, T>> {
        let repr = *self.data.repr();
        if repr.remainder() == 0 {
            None
        } else {
            let offset = repr.block_offset(repr.full_blocks());
            let stride = repr.block_stride();
            let pncols = repr.padded_ncols();
            // SAFETY: even a partial block is allocated at full
            // `block_stride()` size.
            let data: &mut [T] = unsafe {
                std::slice::from_raw_parts_mut(
                    self.data.as_raw_mut_ptr().cast::<T>().add(offset),
                    stride,
                )
            };
            Some(
                MutMatrixView::try_from(data, pncols / PACK, GROUP * PACK)
                    .expect("base data should have been sized correctly"),
            )
        }
    }

    /// Returns a mutable view of row `i`, or `None` when `i >= nrows()`.
    #[inline]
    pub fn get_row_mut(&mut self, i: usize) -> Option<RowMut<'_, T, GROUP, PACK>> {
        self.data.get_row_mut(i)
    }

    // Reborrows the inner `MatMut` with a shorter lifetime so consuming
    // helpers above can be reused from `&mut self` methods.
    fn reborrow_mut(&mut self) -> BlockTransposedMut<'_, T, GROUP, PACK> {
        BlockTransposedMut {
            data: self.data.reborrow_mut(),
        }
    }
}
994
impl<T: Copy, const GROUP: usize, const PACK: usize> BlockTransposed<T, GROUP, PACK> {
    /// Borrows as an immutable view.
    pub fn as_view(&self) -> BlockTransposedRef<'_, T, GROUP, PACK> {
        BlockTransposedRef {
            data: self.data.as_view(),
        }
    }

    /// Borrows as a mutable view.
    pub fn as_view_mut(&mut self) -> BlockTransposedMut<'_, T, GROUP, PACK> {
        BlockTransposedMut {
            data: self.data.as_view_mut(),
        }
    }

    // Read-only accessors forwarded to `BlockTransposedRef`.
    delegate_to_ref!(pub fn nrows(&self) -> usize);
    delegate_to_ref!(pub fn ncols(&self) -> usize);
    delegate_to_ref!(pub fn padded_ncols(&self) -> usize);
    delegate_to_ref!(pub fn full_blocks(&self) -> usize);
    delegate_to_ref!(pub fn num_blocks(&self) -> usize);
    delegate_to_ref!(pub fn remainder(&self) -> usize);
    delegate_to_ref!(pub fn as_ptr(&self) -> *const T);
    delegate_to_ref!(pub fn as_slice(&self) -> &[T]);
    delegate_to_ref!(#[allow(clippy::missing_safety_doc)] unsafe pub fn block_ptr_unchecked(&self, block: usize) -> *const T);
    delegate_to_ref!(#[allow(clippy::expect_used)] pub fn block(&self, block: usize) -> MatrixView<'_, T>);
    delegate_to_ref!(#[allow(clippy::expect_used)] pub fn remainder_block(&self) -> Option<MatrixView<'_, T>>);
    delegate_to_ref!(pub fn get_element(&self, row: usize, col: usize) -> T);

    /// Row-group size (the `GROUP` const parameter).
    pub const fn group_size(&self) -> usize {
        GROUP
    }

    /// `GROUP` without needing an instance.
    pub const fn const_group_size() -> usize {
        GROUP
    }

    /// Column packing factor (the `PACK` const parameter).
    pub const fn pack_size(&self) -> usize {
        PACK
    }

    /// Returns an immutable view of row `i`, or `None` when `i >= nrows()`.
    #[inline]
    pub fn get_row(&self, i: usize) -> Option<Row<'_, T, GROUP, PACK>> {
        self.data.get_row(i)
    }

    /// The entire backing storage (including padding), mutably.
    #[inline]
    pub fn as_mut_slice(&mut self) -> &mut [T] {
        self.as_view_mut().mut_slice_inner()
    }

    /// Mutable view of full block `block`; panics when
    /// `block >= full_blocks()`.
    #[allow(clippy::expect_used)]
    pub fn block_mut(&mut self, block: usize) -> MutMatrixView<'_, T> {
        self.as_view_mut().block_mut_inner(block)
    }

    /// Mutable view of the trailing partial block, or `None` when
    /// `remainder() == 0`.
    #[allow(clippy::expect_used)]
    pub fn remainder_block_mut(&mut self) -> Option<MutMatrixView<'_, T>> {
        self.as_view_mut().remainder_block_mut_inner()
    }

    /// Returns a mutable view of row `i`, or `None` when `i >= nrows()`.
    #[inline]
    pub fn get_row_mut(&mut self, i: usize) -> Option<RowMut<'_, T, GROUP, PACK>> {
        self.data.get_row_mut(i)
    }
}
1074
impl<T: Copy + Default, const GROUP: usize, const PACK: usize> BlockTransposed<T, GROUP, PACK> {
    /// Creates a default-initialized `nrows x ncols` matrix.
    ///
    /// # Panics
    ///
    /// Panics when the padded dimensions overflow; use [`Self::try_new`] for
    /// a fallible variant.
    #[allow(clippy::expect_used)]
    pub fn new(nrows: usize, ncols: usize) -> Self {
        let repr = BlockTransposedRepr::<T, GROUP, PACK>::new(nrows, ncols)
            .expect("dimensions should not overflow");
        Self {
            data: Mat::new(repr, Defaulted).expect("infallible"),
        }
    }

    /// Fallible variant of [`Self::new`].
    pub fn try_new(nrows: usize, ncols: usize) -> Result<Self, Overflow> {
        let repr = BlockTransposedRepr::<T, GROUP, PACK>::new(nrows, ncols)?;
        Ok(Self {
            data: Mat::new(repr, Defaulted).expect("infallible"),
        })
    }

    /// Builds a block-transposed copy of a row-major strided matrix.
    ///
    /// Walks the destination storage sequentially in layout order (block,
    /// column group, row-in-block, element-in-pack) while gathering from the
    /// source; padding cells keep the `T::default()` from [`Self::new`].
    pub fn from_strided(v: StridedView<'_, T>) -> Self {
        let nrows = v.nrows();
        let ncols = v.ncols();
        let mut mat = Self::new(nrows, ncols);

        let repr = *mat.data.repr();
        let num_blocks = repr.num_blocks();
        let pncols = repr.padded_ncols();
        let num_col_groups = pncols / PACK;

        // `dst` walks the destination linearly; every path below advances it
        // so that one full pass covers exactly `storage_len()` elements.
        let mut dst = mat.data.as_raw_mut_ptr().cast::<T>();
        for block in 0..num_blocks {
            let row_base = block * GROUP;
            for cg in 0..num_col_groups {
                let col_base = cg * PACK;
                for rib in 0..GROUP {
                    let row = row_base + rib;
                    if row < nrows {
                        // SAFETY: `row < nrows` was just checked.
                        let src_row = unsafe { v.get_row_unchecked(row) };
                        for p in 0..PACK {
                            let col = col_base + p;
                            if col < ncols {
                                // SAFETY: `col < ncols`, and `dst` is still
                                // inside the freshly allocated storage.
                                unsafe { *dst = *src_row.get_unchecked(col) };
                            }
                            // Advance even over column padding so the layout
                            // stays aligned.
                            dst = unsafe { dst.add(1) };
                        }
                    } else {
                        // Row padding: skip the whole pack, leaving defaults.
                        dst = unsafe { dst.add(PACK) };
                    }
                }
            }
        }

        mat
    }

    /// Convenience wrapper over [`Self::from_strided`].
    pub fn from_matrix_view(v: MatrixView<'_, T>) -> Self {
        Self::from_strided(v.into())
    }
}
1166
impl<T: Copy, const GROUP: usize, const PACK: usize> std::ops::Index<(usize, usize)>
    for BlockTransposed<T, GROUP, PACK>
{
    type Output = T;

    /// `matrix[(row, col)]`; panics when either index is out of bounds.
    #[inline]
    fn index(&self, (row, col): (usize, usize)) -> &Self::Output {
        assert!(row < self.nrows());
        assert!(col < self.ncols());
        let idx = linear_index::<GROUP, PACK>(row, col, self.ncols());
        // SAFETY: both indices were bounds-checked above, so `idx` is within
        // the backing storage.
        unsafe { &*self.as_ptr().add(idx) }
    }
}
1185
1186#[cfg(test)]
1191mod tests {
1192 use diskann_utils::{lazy_format, views::Matrix};
1207
1208 use super::*;
1209 use crate::utils::div_round_up;
1210
1211 fn gen_f32(i: usize) -> f32 {
1217 (i + 1) as f32
1218 }
1219 fn gen_i32(i: usize) -> i32 {
1220 (i + 1) as i32
1221 }
1222 fn gen_u8(i: usize) -> u8 {
1223 ((i % 255) + 1) as u8
1224 }
1225
1226 fn test_full_api<
1239 T: Copy + Default + PartialEq + std::fmt::Debug + 'static,
1240 const GROUP: usize,
1241 const PACK: usize,
1242 >(
1243 nrows: usize,
1244 ncols: usize,
1245 gen_element: fn(usize) -> T,
1246 ) {
1247 let context = lazy_format!(
1248 "T={}, GROUP={}, PACK={}, nrows={}, ncols={}",
1249 std::any::type_name::<T>(),
1250 GROUP,
1251 PACK,
1252 nrows,
1253 ncols,
1254 );
1255
1256 let mut data = Matrix::new(T::default(), nrows, ncols);
1259 data.as_mut_slice()
1260 .iter_mut()
1261 .enumerate()
1262 .for_each(|(i, d)| *d = gen_element(i));
1263
1264 let mut transpose = BlockTransposed::<T, GROUP, PACK>::from_strided(data.as_view().into());
1265
1266 let expected_padded = div_round_up(ncols, PACK) * PACK;
1267 let expected_remainder = nrows % GROUP;
1268 let storage_len = transpose.as_slice().len();
1269
1270 assert_eq!(transpose.nrows(), nrows, "{}", context);
1273 assert_eq!(transpose.ncols(), ncols, "{}", context);
1274 assert_eq!(transpose.group_size(), GROUP, "{}", context);
1275 assert_eq!(
1276 BlockTransposed::<T, GROUP, PACK>::const_group_size(),
1277 GROUP,
1278 "{}",
1279 context
1280 );
1281 assert_eq!(transpose.pack_size(), PACK, "{}", context);
1282 assert_eq!(transpose.full_blocks(), nrows / GROUP, "{}", context);
1283 assert_eq!(
1284 transpose.num_blocks(),
1285 div_round_up(nrows, GROUP),
1286 "{}",
1287 context,
1288 );
1289 assert_eq!(transpose.remainder(), expected_remainder, "{}", context);
1290 assert_eq!(transpose.padded_ncols(), expected_padded, "{}", context);
1291
1292 for row in 0..nrows {
1295 for col in 0..ncols {
1296 assert_eq!(
1297 data[(row, col)],
1298 transpose[(row, col)],
1299 "Index at ({}, {}) -- {}",
1300 row,
1301 col,
1302 context,
1303 );
1304 assert_eq!(
1305 data[(row, col)],
1306 transpose.get_element(row, col),
1307 "get_element at ({}, {}) -- {}",
1308 row,
1309 col,
1310 context,
1311 );
1312 }
1313 }
1314
1315 let view = transpose.as_view();
1318 for row in 0..nrows {
1319 let row_view = view.get_row(row).unwrap();
1320 assert_eq!(row_view.len(), ncols, "{}", context);
1321 assert_eq!(row_view.is_empty(), ncols == 0, "{}", context);
1322 for col in 0..ncols {
1323 assert_eq!(
1324 data[(row, col)],
1325 row_view[col],
1326 "row view at ({}, {}) -- {}",
1327 row,
1328 col,
1329 context,
1330 );
1331 }
1332 if ncols > 0 {
1334 assert_eq!(row_view.get(0), Some(&data[(row, 0)]), "{}", context);
1335 }
1336 assert_eq!(row_view.get(ncols), None, "{}", context);
1337
1338 let iter = row_view.iter();
1340 assert_eq!(iter.len(), ncols, "{}", context);
1341 let (lo, hi) = iter.size_hint();
1342 assert_eq!(lo, ncols, "{}", context);
1343 assert_eq!(hi, Some(ncols), "{}", context);
1344
1345 let collected: Vec<T> = row_view.iter().collect();
1346 assert_eq!(collected.len(), ncols, "{}", context);
1347 for col in 0..ncols {
1348 assert_eq!(data[(row, col)], collected[col], "{}", context);
1349 }
1350 }
1351 assert!(view.get_row(nrows).is_none(), "{}", context);
1353 let _ = view;
1354
1355 {
1358 let view = transpose.as_view();
1359 assert_eq!(view.nrows(), nrows, "{}", context);
1360 assert_eq!(view.ncols(), ncols, "{}", context);
1361 assert_eq!(view.padded_ncols(), expected_padded, "{}", context);
1362 assert_eq!(view.group_size(), GROUP, "{}", context);
1363 assert_eq!(
1364 BlockTransposedRef::<T, GROUP, PACK>::const_group_size(),
1365 GROUP,
1366 );
1367 assert_eq!(view.pack_size(), PACK, "{}", context);
1368 assert_eq!(view.full_blocks(), nrows / GROUP, "{}", context);
1369 assert_eq!(view.num_blocks(), div_round_up(nrows, GROUP), "{}", context,);
1370 assert_eq!(view.remainder(), expected_remainder, "{}", context);
1371 assert_eq!(view.as_ptr(), transpose.as_ptr(), "{}", context);
1372 assert_eq!(view.as_slice(), transpose.as_slice(), "{}", context);
1373
1374 for row in 0..nrows {
1375 for col in 0..ncols {
1376 assert_eq!(
1377 data[(row, col)],
1378 view.get_element(row, col),
1379 "Ref get_element at ({}, {}) -- {}",
1380 row,
1381 col,
1382 context,
1383 );
1384 }
1385 let row_view = view.get_row(row).unwrap();
1386 for col in 0..ncols {
1387 assert_eq!(data[(row, col)], row_view[col], "{}", context);
1388 }
1389 }
1390 assert!(view.get_row(nrows).is_none(), "{}", context);
1391 }
1392
1393 let expected_ptr = transpose.as_ptr();
1396 {
1397 let mut_view = transpose.as_view_mut();
1398 assert_eq!(mut_view.nrows(), nrows, "{}", context);
1399 assert_eq!(mut_view.ncols(), ncols, "{}", context);
1400 assert_eq!(mut_view.padded_ncols(), expected_padded, "{}", context);
1401 assert_eq!(mut_view.group_size(), GROUP, "{}", context);
1402 assert_eq!(
1403 BlockTransposedMut::<T, GROUP, PACK>::const_group_size(),
1404 GROUP,
1405 );
1406 assert_eq!(mut_view.pack_size(), PACK, "{}", context);
1407 assert_eq!(mut_view.full_blocks(), nrows / GROUP, "{}", context);
1408 assert_eq!(
1409 mut_view.num_blocks(),
1410 div_round_up(nrows, GROUP),
1411 "{}",
1412 context,
1413 );
1414 assert_eq!(mut_view.remainder(), expected_remainder, "{}", context);
1415 assert_eq!(mut_view.as_ptr(), expected_ptr, "{}", context);
1416 assert_eq!(mut_view.as_slice().len(), storage_len, "{}", context);
1417
1418 for row in 0..nrows {
1419 for col in 0..ncols {
1420 assert_eq!(
1421 data[(row, col)],
1422 mut_view.get_element(row, col),
1423 "Mut get_element at ({}, {}) -- {}",
1424 row,
1425 col,
1426 context,
1427 );
1428 }
1429 let row_view = mut_view.get_row(row).unwrap();
1430 for col in 0..ncols {
1431 assert_eq!(data[(row, col)], row_view[col], "{}", context);
1432 }
1433 }
1434 assert!(mut_view.get_row(nrows).is_none(), "{}", context);
1435 }
1436
1437 {
1440 let mut_view = transpose.as_view_mut();
1441 let ref_from_mut = mut_view.as_view();
1442 assert_eq!(ref_from_mut.nrows(), nrows, "{}", context);
1443 for row in 0..nrows {
1444 for col in 0..ncols {
1445 assert_eq!(
1446 data[(row, col)],
1447 ref_from_mut.get_element(row, col),
1448 "{}",
1449 context,
1450 );
1451 }
1452 }
1453 }
1454
1455 {
1459 let mut mut_view = transpose.as_view_mut();
1460 assert_eq!(mut_view.as_mut_slice().len(), storage_len, "{}", context);
1461 }
1462 assert_eq!(transpose.as_mut_slice().len(), storage_len, "{}", context);
1464
1465 let expected_block_nrows = expected_padded / PACK;
1468 let expected_block_ncols = GROUP * PACK;
1469
1470 for b in 0..transpose.full_blocks() {
1471 let block_data: Vec<T>;
1472 let ptr: *const T;
1473 {
1474 let block = transpose.block(b);
1475 assert_eq!(block.nrows(), expected_block_nrows, "{}", context);
1476 assert_eq!(block.ncols(), expected_block_ncols, "{}", context);
1477
1478 ptr = unsafe { transpose.block_ptr_unchecked(b) };
1480 assert_eq!(ptr, block.as_slice().as_ptr(), "{}", context);
1481
1482 block_data = block.as_slice().to_vec();
1483 }
1484
1485 {
1487 let view = transpose.as_view();
1488 assert_eq!(view.block(b).as_slice(), &block_data[..], "{}", context);
1489 assert_eq!(unsafe { view.block_ptr_unchecked(b) }, ptr, "{}", context);
1491 }
1492
1493 {
1495 let mut_view = transpose.as_view_mut();
1496 assert_eq!(mut_view.block(b).as_slice(), &block_data[..], "{}", context);
1497 assert_eq!(
1498 unsafe { mut_view.block_ptr_unchecked(b) },
1500 ptr,
1501 "{}",
1502 context,
1503 );
1504 }
1505 }
1506
1507 if expected_remainder != 0 {
1509 let remainder_data: Vec<T>;
1510 let ptr: *const T;
1511 let fb = transpose.full_blocks();
1512 {
1513 let block = transpose.remainder_block().unwrap();
1514 assert_eq!(block.nrows(), expected_block_nrows, "{}", context);
1515 assert_eq!(block.ncols(), expected_block_ncols, "{}", context);
1516
1517 ptr = unsafe { transpose.block_ptr_unchecked(fb) };
1519 assert_eq!(ptr, block.as_slice().as_ptr(), "{}", context);
1520
1521 remainder_data = block.as_slice().to_vec();
1522 }
1523
1524 {
1526 let view = transpose.as_view();
1527 let ref_block = view.remainder_block().unwrap();
1528 assert_eq!(ref_block.as_slice(), &remainder_data[..], "{}", context);
1529 }
1530 {
1532 let mut_view = transpose.as_view_mut();
1533 let mut_block = mut_view.remainder_block().unwrap();
1534 assert_eq!(mut_block.as_slice(), &remainder_data[..], "{}", context);
1535 }
1536 } else {
1537 assert!(transpose.remainder_block().is_none(), "{}", context);
1538 {
1539 let view = transpose.as_view();
1540 assert!(view.remainder_block().is_none(), "{}", context);
1541 }
1542 {
1543 let mut_view = transpose.as_view_mut();
1544 assert!(mut_view.remainder_block().is_none(), "{}", context);
1545 }
1546 }
1547
1548 {
1551 let mut mut_view = transpose.as_view_mut();
1552 for b in 0..mut_view.full_blocks() {
1553 let block_mut = mut_view.block_mut(b);
1554 assert_eq!(block_mut.nrows(), expected_block_nrows, "{}", context);
1555 assert_eq!(block_mut.ncols(), expected_block_ncols, "{}", context);
1556 }
1557 if expected_remainder != 0 {
1558 let rem = mut_view.remainder_block_mut().unwrap();
1559 assert_eq!(rem.nrows(), expected_block_nrows, "{}", context);
1560 assert_eq!(rem.ncols(), expected_block_ncols, "{}", context);
1561 } else {
1562 assert!(mut_view.remainder_block_mut().is_none(), "{}", context);
1563 }
1564 }
1565
1566 for b in 0..transpose.full_blocks() {
1568 let block_mut = transpose.block_mut(b);
1569 assert_eq!(block_mut.nrows(), expected_block_nrows, "{}", context);
1570 assert_eq!(block_mut.ncols(), expected_block_ncols, "{}", context);
1571 }
1572 if expected_remainder != 0 {
1573 let rem = transpose.remainder_block_mut().unwrap();
1574 assert_eq!(rem.nrows(), expected_block_nrows, "{}", context);
1575 assert_eq!(rem.ncols(), expected_block_ncols, "{}", context);
1576 } else {
1577 assert!(transpose.remainder_block_mut().is_none(), "{}", context);
1578 }
1579
1580 {
1583 let mut mut_view = transpose.as_view_mut();
1584 for row in 0..nrows {
1585 let row_view = mut_view.get_row_mut(row).unwrap();
1586 assert_eq!(row_view.len(), ncols, "{}", context);
1587 assert_eq!(row_view.is_empty(), ncols == 0, "{}", context);
1588 for col in 0..ncols {
1589 assert_eq!(data[(row, col)], row_view[col], "{}", context);
1590 }
1591 }
1592 assert!(mut_view.get_row_mut(nrows).is_none(), "{}", context);
1593 }
1594
1595 if nrows > 0 && ncols > 0 {
1598 {
1600 let view = transpose.as_view();
1601 let row = view.get_row(0).unwrap();
1602 assert_eq!(row.get(ncols), None, "{}", context);
1603 assert_eq!(row.get(usize::MAX), None, "{}", context);
1604 }
1605
1606 let row = transpose.get_row_mut(0).unwrap();
1608 assert_eq!(row.get(ncols), None, "{}", context);
1609
1610 let mut row = transpose.get_row_mut(0).unwrap();
1612 let sentinel = gen_element(usize::MAX / 2);
1613 let original = row[0];
1614 if let Some(v) = row.get_mut(0) {
1615 *v = sentinel;
1616 }
1617 assert_eq!(row.get_mut(ncols), None, "{}", context);
1618 let _ = row;
1620 assert_eq!(transpose.get_element(0, 0), sentinel, "{}", context);
1621 transpose.get_row_mut(0).unwrap().set(0, original);
1623 }
1624
1625 for b in 0..transpose.full_blocks() {
1628 transpose.block_mut(b).as_mut_slice().fill(T::default());
1629 }
1630 if transpose.remainder() != 0 {
1631 transpose
1632 .remainder_block_mut()
1633 .unwrap()
1634 .as_mut_slice()
1635 .fill(T::default());
1636 }
1637 assert!(
1638 transpose.as_slice().iter().all(|v| *v == T::default()),
1639 "not fully zeroed -- {}",
1640 context,
1641 );
1642
1643 let transpose = BlockTransposed::<T, GROUP, PACK>::from_strided(data.as_view().into());
1646 let raw = transpose.as_slice();
1647
1648 for row in 0..nrows {
1650 for col in ncols..expected_padded {
1651 let idx = linear_index::<GROUP, PACK>(row, col, ncols);
1652 assert_eq!(
1653 raw[idx],
1654 T::default(),
1655 "col padding at ({}, {}) -- {}",
1656 row,
1657 col,
1658 context,
1659 );
1660 }
1661 }
1662
1663 let padded_nrows = nrows.next_multiple_of(GROUP);
1665 for row in nrows..padded_nrows {
1666 for col in 0..expected_padded {
1667 let idx = linear_index::<GROUP, PACK>(row, col, ncols);
1668 assert_eq!(
1669 raw[idx],
1670 T::default(),
1671 "row padding at ({}, {}) -- {}",
1672 row,
1673 col,
1674 context,
1675 );
1676 }
1677 }
1678
1679 if nrows > 0 && ncols > 0 {
1682 let via_matrix = BlockTransposed::<T, GROUP, PACK>::from_matrix_view(data.as_view());
1683 assert_eq!(via_matrix.as_slice(), transpose.as_slice(), "{}", context);
1684 }
1685 }
1686
1687 #[test]
1692 fn test_api_pack1_group16() {
1693 let rows: Vec<usize> = if cfg!(miri) {
1696 vec![0, 1, 15, 16, 17, 33]
1697 } else {
1698 (0..128).collect()
1699 };
1700 let cols: Vec<usize> = if cfg!(miri) {
1701 vec![0, 1, 2]
1702 } else {
1703 (0..5).collect()
1704 };
1705 for &nrows in &rows {
1706 for &ncols in &cols {
1707 test_full_api::<f32, 16, 1>(nrows, ncols, gen_f32);
1708 }
1709 }
1710 }
1711
1712 #[test]
1713 fn test_api_pack1_group8() {
1714 let rows: Vec<usize> = if cfg!(miri) {
1717 vec![0, 1, 7, 8, 9, 17]
1718 } else {
1719 (0..128).collect()
1720 };
1721 let cols: Vec<usize> = if cfg!(miri) {
1722 vec![0, 1, 2]
1723 } else {
1724 (0..5).collect()
1725 };
1726 for &nrows in &rows {
1727 for &ncols in &cols {
1728 test_full_api::<f32, 8, 1>(nrows, ncols, gen_f32);
1729 }
1730 }
1731 }
1732
1733 #[test]
1734 fn test_api_pack2() {
1735 let rows: Vec<usize> = if cfg!(miri) {
1738 vec![0, 1, 3, 4, 5, 7, 8, 9, 15, 16, 17]
1739 } else {
1740 (0..48).collect()
1741 };
1742 let cols: Vec<usize> = if cfg!(miri) {
1743 vec![0, 1, 2, 3, 4, 5]
1744 } else {
1745 (0..9).collect()
1746 };
1747 for &nrows in &rows {
1748 for &ncols in &cols {
1749 test_full_api::<f32, 4, 2>(nrows, ncols, gen_f32);
1750 test_full_api::<f32, 8, 2>(nrows, ncols, gen_f32);
1751 test_full_api::<f32, 16, 2>(nrows, ncols, gen_f32);
1752 }
1753 }
1754 }
1755
1756 #[test]
1757 fn test_api_pack4() {
1758 let rows: Vec<usize> = if cfg!(miri) {
1761 vec![0, 1, 3, 4, 5, 7, 8, 9, 15, 16, 17]
1762 } else {
1763 (0..48).collect()
1764 };
1765 let cols: Vec<usize> = if cfg!(miri) {
1766 vec![0, 1, 3, 4, 5, 8]
1767 } else {
1768 (0..9).collect()
1769 };
1770 for &nrows in &rows {
1771 for &ncols in &cols {
1772 test_full_api::<f32, 4, 4>(nrows, ncols, gen_f32);
1773 test_full_api::<f32, 8, 4>(nrows, ncols, gen_f32);
1774 test_full_api::<f32, 16, 4>(nrows, ncols, gen_f32);
1775 }
1776 }
1777 }
1778
1779 #[test]
1781 fn test_api_non_f32() {
1782 test_full_api::<i32, 4, 1>(10, 7, gen_i32);
1784 test_full_api::<i32, 8, 2>(12, 5, gen_i32);
1785
1786 test_full_api::<u8, 4, 2>(12, 5, gen_u8);
1788 test_full_api::<u8, 8, 1>(10, 7, gen_u8);
1789 }
1790
1791 fn test_block_layout_pack1<
1798 T: Copy + Default + PartialEq + std::fmt::Debug + 'static,
1799 const GROUP: usize,
1800 >(
1801 nrows: usize,
1802 ncols: usize,
1803 gen_element: fn(usize) -> T,
1804 ) {
1805 let mut data = Matrix::new(T::default(), nrows, ncols);
1806 data.as_mut_slice()
1807 .iter_mut()
1808 .enumerate()
1809 .for_each(|(i, d)| *d = gen_element(i));
1810
1811 let transpose = BlockTransposed::<T, GROUP, 1>::from_strided(data.as_view().into());
1812
1813 for b in 0..transpose.full_blocks() {
1815 let block = transpose.block(b);
1816 for i in 0..block.nrows() {
1817 for j in 0..block.ncols() {
1818 assert_eq!(
1819 block[(i, j)],
1820 data[(GROUP * b + j, i)],
1821 "block {} at ({}, {}) -- GROUP={}, nrows={}, ncols={}",
1822 b,
1823 i,
1824 j,
1825 GROUP,
1826 nrows,
1827 ncols,
1828 );
1829 }
1830 }
1831 }
1832
1833 if transpose.remainder() != 0 {
1835 let fb = transpose.full_blocks();
1836 let block = transpose.remainder_block().unwrap();
1837 for i in 0..block.nrows() {
1838 for j in 0..transpose.remainder() {
1839 assert_eq!(
1840 block[(i, j)],
1841 data[(GROUP * fb + j, i)],
1842 "remainder at ({}, {}) -- GROUP={}, nrows={}, ncols={}",
1843 i,
1844 j,
1845 GROUP,
1846 nrows,
1847 ncols,
1848 );
1849 }
1850 }
1851 }
1852 }
1853
1854 #[test]
1855 fn test_block_layout_pack1_group16() {
1856 let rows: Vec<usize> = if cfg!(miri) {
1857 vec![0, 1, 15, 16, 17, 33]
1858 } else {
1859 (0..128).collect()
1860 };
1861 let cols: Vec<usize> = if cfg!(miri) {
1862 vec![0, 1, 2]
1863 } else {
1864 (0..5).collect()
1865 };
1866 for &nrows in &rows {
1867 for &ncols in &cols {
1868 test_block_layout_pack1::<f32, 16>(nrows, ncols, gen_f32);
1869 }
1870 }
1871 }
1872
1873 #[test]
1874 fn test_block_layout_pack1_group8() {
1875 let rows: Vec<usize> = if cfg!(miri) {
1876 vec![0, 1, 7, 8, 9, 17]
1877 } else {
1878 (0..128).collect()
1879 };
1880 let cols: Vec<usize> = if cfg!(miri) {
1881 vec![0, 1, 2]
1882 } else {
1883 (0..5).collect()
1884 };
1885 for &nrows in &rows {
1886 for &ncols in &cols {
1887 test_block_layout_pack1::<f32, 8>(nrows, ncols, gen_f32);
1888 }
1889 }
1890 }
1891
1892 #[test]
1899 fn test_row_view_send_sync() {
1900 fn assert_send<T: Send>() {}
1901 fn assert_sync<T: Sync>() {}
1902
1903 assert_send::<Row<'_, f32, 16>>();
1904 assert_sync::<Row<'_, f32, 16>>();
1905 assert_send::<Row<'_, u8, 8, 2>>();
1906 assert_sync::<Row<'_, u8, 8, 2>>();
1907
1908 assert_send::<RowMut<'_, f32, 16>>();
1909 assert_sync::<RowMut<'_, f32, 16>>();
1910 assert_send::<RowMut<'_, i32, 4, 4>>();
1911 assert_sync::<RowMut<'_, i32, 4, 4>>();
1912 }
1913
1914 #[test]
1917 fn test_new_ref_and_new_mut() {
1918 let nrows = 5;
1919 let ncols = 3;
1920 let repr = BlockTransposedRepr::<f32, 4>::new(nrows, ncols).unwrap();
1921
1922 let mat = BlockTransposed::<f32, 4>::new(nrows, ncols);
1923 let raw: &[f32] = mat.as_slice();
1924
1925 let mat_ref = BlockTransposedRef {
1926 data: repr.new_ref(raw).unwrap(),
1927 };
1928 assert_eq!(mat_ref.nrows(), nrows);
1929 assert_eq!(mat_ref.ncols(), ncols);
1930 for row in 0..nrows {
1931 for col in 0..ncols {
1932 assert_eq!(mat_ref.get_element(row, col), mat.get_element(row, col));
1933 }
1934 }
1935
1936 let mut buf = raw.to_vec();
1937 let mat_mut = BlockTransposedMut {
1938 data: repr.new_mut(&mut buf).unwrap(),
1939 };
1940 assert_eq!(mat_mut.nrows(), nrows);
1941 assert_eq!(mat_mut.ncols(), ncols);
1942
1943 let mut short = vec![0.0_f32; 2];
1945 assert!(repr.new_ref(&short).is_err());
1946 assert!(repr.new_mut(&mut short).is_err());
1947 }
1948
1949 #[test]
1952 fn test_row_view_empty() {
1953 fn check_empty<const GROUP: usize, const PACK: usize>() {
1956 let mut mat = BlockTransposed::<f32, GROUP, PACK>::new(4, 0);
1957
1958 let view = mat.as_view();
1960 for i in 0..4 {
1961 let row = view.get_row(i).unwrap();
1962 assert!(row.is_empty());
1963 assert_eq!(row.len(), 0);
1964 assert_eq!(row.iter().count(), 0);
1965 }
1966
1967 for i in 0..4 {
1969 let row = mat.get_row_mut(i).unwrap();
1970 assert!(row.is_empty());
1971 assert_eq!(row.len(), 0);
1972 }
1973 }
1974
1975 check_empty::<16, 1>(); check_empty::<4, 2>(); check_empty::<4, 4>(); }
1979
1980 #[test]
1983 #[should_panic(expected = "column index 3 out of bounds")]
1984 fn test_row_view_index_oob() {
1985 let mat = BlockTransposed::<f32, 4>::new(4, 3);
1986 let view = mat.as_view();
1987 let row = view.get_row(0).unwrap();
1988 let _ = row[3];
1989 }
1990
1991 #[test]
1992 #[should_panic(expected = "column index 3 out of bounds")]
1993 fn test_row_view_mut_index_oob() {
1994 let mut mat = BlockTransposed::<f32, 4>::new(4, 3);
1995 let row = mat.get_row_mut(0).unwrap();
1996 let _ = row[3];
1997 }
1998
1999 #[test]
2000 #[should_panic(expected = "column index 3 out of bounds")]
2001 fn test_row_view_mut_index_mut_oob() {
2002 let mut mat = BlockTransposed::<f32, 4>::new(4, 3);
2003 let mut row = mat.get_row_mut(0).unwrap();
2004 row[3] = 1.0;
2005 }
2006
2007 #[test]
2008 #[should_panic(expected = "column index 3 out of bounds")]
2009 fn test_row_view_set_oob() {
2010 let mut mat = BlockTransposed::<f32, 4>::new(4, 3);
2011 let mut row = mat.get_row_mut(0).unwrap();
2012 row.set(3, 1.0);
2013 }
2014
2015 #[test]
2016 #[should_panic(expected = "row 4 out of bounds")]
2017 fn test_get_element_row_oob() {
2018 let mat = BlockTransposed::<f32, 4>::new(4, 3);
2019 mat.get_element(4, 0);
2020 }
2021
2022 #[test]
2023 #[should_panic(expected = "col 3 out of bounds")]
2024 fn test_get_element_col_oob() {
2025 let mat = BlockTransposed::<f32, 4>::new(4, 3);
2026 mat.get_element(0, 3);
2027 }
2028
2029 #[test]
2030 #[should_panic(expected = "assertion failed")]
2031 fn test_index_tuple_row_oob() {
2032 let mat = BlockTransposed::<f32, 4>::new(4, 3);
2033 let _ = mat[(4, 0)];
2034 }
2035
2036 #[test]
2037 #[should_panic(expected = "assertion failed")]
2038 fn test_index_tuple_col_oob() {
2039 let mat = BlockTransposed::<f32, 4>::new(4, 3);
2040 let _ = mat[(0, 3)];
2041 }
2042
2043 #[test]
2044 #[should_panic]
2045 fn test_block_oob() {
2046 let mat = BlockTransposed::<f32, 4>::new(4, 3);
2047 let _ = mat.block(1);
2048 }
2049
2050 #[test]
2051 #[should_panic]
2052 fn test_block_mut_oob() {
2053 let mut mat = BlockTransposed::<f32, 4>::new(4, 3);
2054 let _ = mat.block_mut(1);
2055 }
2056
2057 #[test]
2060 fn test_from_strided_nonunit_stride() {
2061 use diskann_utils::strided::StridedView;
2062
2063 const GROUP: usize = 4;
2064 const PACK: usize = 2;
2065 let nrows = 5;
2066 let ncols = 3;
2067 let cstride = 8;
2068
2069 let required_len = (nrows - 1) * cstride + ncols;
2070 let mut flat = vec![0.0_f32; required_len];
2071 for row in 0..nrows {
2072 for col in 0..ncols {
2073 flat[row * cstride + col] = (row * 100 + col + 1) as f32;
2074 }
2075 }
2076
2077 let strided = StridedView::try_shrink_from(&flat, nrows, ncols, cstride)
2078 .expect("should construct strided view");
2079 let transpose = BlockTransposed::<f32, GROUP, PACK>::from_strided(strided);
2080
2081 assert_eq!(transpose.nrows(), nrows);
2082 assert_eq!(transpose.ncols(), ncols);
2083
2084 for row in 0..nrows {
2085 for col in 0..ncols {
2086 let expected = (row * 100 + col + 1) as f32;
2087 assert_eq!(
2088 transpose[(row, col)],
2089 expected,
2090 "mismatch at ({}, {})",
2091 row,
2092 col,
2093 );
2094 }
2095 }
2096
2097 let padded_ncols = ncols.next_multiple_of(PACK);
2098 let raw: &[f32] = transpose.as_slice();
2099 for row in 0..nrows {
2100 for col in ncols..padded_ncols {
2101 let idx = linear_index::<GROUP, PACK>(row, col, ncols);
2102 assert_eq!(
2103 raw[idx], 0.0,
2104 "column-padding at ({}, {}) should be zero",
2105 row, col,
2106 );
2107 }
2108 }
2109 }
2110
    #[test]
    // Disjoint `RowMut` views obtained from `rows_mut()` are distributed to
    // separate threads; each thread writes only its own rows and the result
    // is verified single-threaded afterwards. Under miri the problem size is
    // shrunk to keep the run fast.
    fn test_concurrent_row_mutation() {
        const GROUP: usize = 8;
        const PACK: usize = 2;

        let (nrows, ncols, num_threads) = if cfg!(miri) { (8, 4, 2) } else { (64, 16, 4) };

        let mut mat = BlockTransposed::<f32, GROUP, PACK>::new(nrows, ncols);
        // Collect one independent mutable view per row; boxing the slice lets
        // `split_at_mut` hand non-overlapping chunks to the worker threads.
        let rows: Vec<RowMut<'_, f32, GROUP, PACK>> = mat.data.rows_mut().collect();
        let rows_per_thread = nrows / num_threads;
        let mut rows = rows.into_boxed_slice();

        // Scoped threads may borrow `rows` without 'static bounds.
        std::thread::scope(|s| {
            let mut remaining = &mut rows[..];
            for thread_id in 0..num_threads {
                // The last thread absorbs any leftover rows (a no-op for the
                // sizes above, where nrows divides evenly by num_threads).
                let chunk_len = if thread_id == num_threads - 1 {
                    remaining.len()
                } else {
                    rows_per_thread
                };
                let (chunk, rest) = remaining.split_at_mut(chunk_len);
                remaining = rest;
                let start_row = thread_id * rows_per_thread;

                s.spawn(move || {
                    for (offset, row_view) in chunk.iter_mut().enumerate() {
                        let row = start_row + offset;
                        for col in 0..ncols {
                            // Value encodes (thread, row, col) so any cross-thread
                            // interference would be visible in the verification pass.
                            let value = (thread_id * 10000 + row * 100 + col) as f32;
                            row_view.set(col, value);
                        }
                    }
                });
            }
        });

        for row in 0..nrows {
            // Recover which thread owned this row; `.min` mirrors the
            // last-thread-takes-the-remainder assignment above.
            let thread_id = (row / rows_per_thread).min(num_threads - 1);
            for col in 0..ncols {
                let expected = (thread_id * 10000 + row * 100 + col) as f32;
                assert_eq!(
                    mat.get_element(row, col),
                    expected,
                    "mismatch at ({}, {})",
                    row,
                    col,
                );
            }
        }
    }
2163}