use std::{alloc::Layout, marker::PhantomData, ptr::NonNull};
use diskann_utils::{
Reborrow, ReborrowMut,
strided::StridedView,
views::{MatrixView, MutMatrixView},
};
use super::matrix::{
Defaulted, LayoutError, Mat, MatMut, MatRef, NewMut, NewOwned, NewRef, Overflow, Repr, ReprMut,
ReprOwned, SliceError,
};
use crate::bits::{AsMutPtr, AsPtr, MutSlicePtr, SlicePtr};
use crate::utils;
/// Round `ncols` up to the next multiple of the packing factor `PACK`.
#[inline]
fn padded_ncols<const PACK: usize>(ncols: usize) -> usize {
    match ncols % PACK {
        0 => ncols,
        rem => ncols + (PACK - rem),
    }
}
/// Total element count needed once rows are padded to a multiple of `GROUP`
/// and columns to a multiple of `PACK`.
#[inline]
fn compute_capacity<const GROUP: usize, const PACK: usize>(nrows: usize, ncols: usize) -> usize {
    let padded_rows = nrows.next_multiple_of(GROUP);
    let padded_cols = ncols.next_multiple_of(PACK);
    padded_rows * padded_cols
}
/// Overflow-checked counterpart of [`compute_capacity`]; `None` when padding
/// or the final product would overflow `usize`.
#[inline]
fn checked_compute_capacity<const GROUP: usize, const PACK: usize>(
    nrows: usize,
    ncols: usize,
) -> Option<usize> {
    let padded_rows = nrows.checked_next_multiple_of(GROUP)?;
    let padded_cols = ncols.checked_next_multiple_of(PACK)?;
    padded_rows.checked_mul(padded_cols)
}
/// Map a logical `(row, col)` coordinate to its position in block-transposed
/// storage: rows grouped in blocks of `GROUP`, columns packed in runs of
/// `PACK`, with each pack of columns stored contiguously across the group.
#[inline]
fn linear_index<const GROUP: usize, const PACK: usize>(
    row: usize,
    col: usize,
    ncols: usize,
) -> usize {
    let padded_cols = ncols.next_multiple_of(PACK);
    let block_base = (row / GROUP) * GROUP * padded_cols;
    let pack_base = (col / PACK) * GROUP * PACK;
    let within = (row % GROUP) * PACK + col % PACK;
    block_base + pack_base + within
}
/// Offset of column `col` relative to a row's base element (column 0 of the
/// same row); this is [`linear_index`] with its row-dependent terms removed.
#[inline]
fn col_offset<const GROUP: usize, const PACK: usize>(col: usize) -> usize {
    let pack_index = col / PACK;
    let within_pack = col % PACK;
    pack_index * (GROUP * PACK) + within_pack
}
/// Layout descriptor for a block-transposed matrix: `GROUP` rows per block,
/// columns packed in runs of `PACK`. Holds only the logical dimensions; the
/// element type is tracked through `PhantomData`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct BlockTransposedRepr<T, const GROUP: usize, const PACK: usize = 1> {
    // Logical (unpadded) row count.
    nrows: usize,
    // Logical (unpadded) column count.
    ncols: usize,
    _elem: PhantomData<T>,
}
impl<T: Copy, const GROUP: usize, const PACK: usize> BlockTransposedRepr<T, GROUP, PACK> {
    // Compile-time validation of the const parameters; evaluation is forced in `new`.
    const _ASSERTIONS: () = {
        assert!(GROUP > 0, "group size GROUP must be positive");
        assert!(PACK > 0, "packing factor PACK must be positive");
        assert!(
            GROUP.is_multiple_of(PACK),
            "GROUP must be divisible by PACK"
        );
    };

    /// Build a descriptor for an `nrows` x `ncols` matrix, rejecting
    /// dimensions whose padded capacity overflows `usize` or exceeds the
    /// byte budget for `T`.
    pub fn new(nrows: usize, ncols: usize) -> Result<Self, Overflow> {
        // Referencing the associated const forces the compile-time assertions.
        let () = Self::_ASSERTIONS;
        let capacity = checked_compute_capacity::<GROUP, PACK>(nrows, ncols)
            .ok_or_else(|| Overflow::for_type::<T>(nrows, ncols))?;
        Overflow::check_byte_budget::<T>(capacity, nrows, ncols)?;
        Ok(Self {
            nrows,
            ncols,
            _elem: PhantomData,
        })
    }

    /// Number of `T` elements in the padded backing storage.
    #[inline]
    fn storage_len(&self) -> usize {
        compute_capacity::<GROUP, PACK>(self.nrows, self.ncols)
    }

    /// Logical (unpadded) row count.
    #[inline]
    fn nrows(&self) -> usize {
        self.nrows
    }

    /// Logical (unpadded) column count.
    #[inline]
    pub fn ncols(&self) -> usize {
        self.ncols
    }

    /// Column count rounded up to a multiple of `PACK`.
    #[inline]
    pub fn padded_ncols(&self) -> usize {
        padded_ncols::<PACK>(self.ncols)
    }

    /// Number of blocks containing exactly `GROUP` real rows.
    #[inline]
    pub fn full_blocks(&self) -> usize {
        self.nrows / GROUP
    }

    /// Total block count, including a trailing partial block if any.
    #[inline]
    pub fn num_blocks(&self) -> usize {
        self.nrows.div_ceil(GROUP)
    }

    /// Number of real rows in the trailing partial block (0 when none).
    #[inline]
    pub fn remainder(&self) -> usize {
        self.nrows % GROUP
    }

    /// Row count rounded up to a multiple of `GROUP`.
    #[inline]
    pub fn padded_nrows(&self) -> usize {
        self.num_blocks() * GROUP
    }

    /// Number of elements spanned by one row block.
    #[inline]
    fn block_stride(&self) -> usize {
        GROUP * self.padded_ncols()
    }

    /// Element offset of the start of `block` within the storage.
    #[inline]
    fn block_offset(&self, block: usize) -> usize {
        block * self.block_stride()
    }

    /// Verify that `slice` has exactly the padded storage length.
    fn check_slice(&self, slice: &[T]) -> Result<(), SliceError> {
        let cap = self.storage_len();
        if slice.len() != cap {
            Err(SliceError::LengthMismatch {
                expected: cap,
                found: slice.len(),
            })
        } else {
            Ok(())
        }
    }

    /// Convert a boxed slice into an owned matrix.
    ///
    /// Safety: `b.len()` must equal `self.storage_len()`.
    unsafe fn box_to_mat(self, b: Box<[T]>) -> Mat<Self> {
        debug_assert_eq!(b.len(), self.storage_len(), "safety contract violated");
        let ptr = utils::box_into_nonnull(b).cast::<u8>();
        // SAFETY: `ptr` owns exactly `storage_len()` elements, per the caller's contract.
        unsafe { Mat::from_raw_parts(self, ptr) }
    }
}
/// Immutable view of a single logical row; columns are reached through the
/// block-transposed stride (see [`col_offset`]).
#[derive(Debug, Clone, Copy)]
pub struct Row<'a, T, const GROUP: usize, const PACK: usize = 1> {
    // Points at element (row, 0); never dereferenced when `ncols == 0`.
    base: SlicePtr<'a, T>,
    // Logical column count of the row.
    ncols: usize,
}
impl<T: Copy, const GROUP: usize, const PACK: usize> Row<'_, T, GROUP, PACK> {
    /// Number of logical columns in the row.
    #[inline]
    pub fn len(&self) -> usize {
        self.ncols
    }

    /// `true` when the row has no columns.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.ncols == 0
    }

    /// Checked access: `None` when `col >= len()`.
    #[inline]
    pub fn get(&self, col: usize) -> Option<&T> {
        if col < self.ncols {
            // SAFETY: `col < ncols`, so the packed offset stays inside this row's storage.
            Some(unsafe { &*self.base.as_ptr().add(col_offset::<GROUP, PACK>(col)) })
        } else {
            None
        }
    }

    /// By-value iterator over the row's elements.
    #[inline]
    pub fn iter(&self) -> RowIter<'_, T, GROUP, PACK> {
        RowIter {
            base: self.base,
            col: 0,
            ncols: self.ncols,
        }
    }
}
impl<T: Copy, const GROUP: usize, const PACK: usize> std::ops::Index<usize>
    for Row<'_, T, GROUP, PACK>
{
    type Output = T;

    /// Panics when `col >= ncols`, mirroring slice indexing.
    #[inline]
    #[allow(clippy::panic)]
    fn index(&self, col: usize) -> &Self::Output {
        match self.get(col) {
            Some(value) => value,
            None => panic!("column index {col} out of bounds (ncols = {})", self.ncols),
        }
    }
}
/// By-value iterator over one row of a block-transposed matrix.
#[derive(Debug, Clone)]
pub struct RowIter<'a, T, const GROUP: usize, const PACK: usize = 1> {
    base: SlicePtr<'a, T>,
    // Next column to yield; `col == ncols` means exhausted.
    col: usize,
    ncols: usize,
}
impl<T: Copy, const GROUP: usize, const PACK: usize> Iterator for RowIter<'_, T, GROUP, PACK> {
    type Item = T;

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        if self.col >= self.ncols {
            return None;
        }
        // SAFETY: `col < ncols`, so the packed offset is inside the row's storage.
        let val = unsafe { *self.base.as_ptr().add(col_offset::<GROUP, PACK>(self.col)) };
        self.col += 1;
        Some(val)
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        // `col` never exceeds `ncols`, so this subtraction cannot underflow.
        let remaining = self.ncols - self.col;
        (remaining, Some(remaining))
    }
}
// `size_hint` is exact and `next` never yields again after `None`, so the
// iterator is both exact-sized and fused.
impl<T: Copy, const GROUP: usize, const PACK: usize> ExactSizeIterator
    for RowIter<'_, T, GROUP, PACK>
{
}
impl<T: Copy, const GROUP: usize, const PACK: usize> std::iter::FusedIterator
    for RowIter<'_, T, GROUP, PACK>
{
}
/// Mutable view of a single logical row (same stride scheme as [`Row`]).
#[derive(Debug)]
pub struct RowMut<'a, T, const GROUP: usize, const PACK: usize = 1> {
    // Points at element (row, 0); never dereferenced when `ncols == 0`.
    base: MutSlicePtr<'a, T>,
    // Logical column count of the row.
    ncols: usize,
}
impl<T: Copy, const GROUP: usize, const PACK: usize> RowMut<'_, T, GROUP, PACK> {
    /// Number of logical columns in the row.
    #[inline]
    pub fn len(&self) -> usize {
        self.ncols
    }

    /// `true` when the row has no columns.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.ncols == 0
    }

    /// Checked shared access: `None` when `col >= len()`.
    #[inline]
    pub fn get(&self, col: usize) -> Option<&T> {
        if col < self.ncols {
            // SAFETY: `col < ncols` keeps the packed offset inside this row's storage.
            Some(unsafe { &*self.base.as_ptr().add(col_offset::<GROUP, PACK>(col)) })
        } else {
            None
        }
    }

    /// Checked mutable access: `None` when `col >= len()`.
    #[inline]
    pub fn get_mut(&mut self, col: usize) -> Option<&mut T> {
        if col < self.ncols {
            // SAFETY: as in `get`, and `&mut self` guarantees unique access.
            Some(unsafe { &mut *self.base.as_mut_ptr().add(col_offset::<GROUP, PACK>(col)) })
        } else {
            None
        }
    }

    /// Write `value` at column `col`; panics when `col >= len()`.
    #[inline]
    pub fn set(&mut self, col: usize, value: T) {
        assert!(
            col < self.ncols,
            "column index {col} out of bounds (ncols = {})",
            self.ncols
        );
        // SAFETY: the assertion above bounds-checks `col`.
        unsafe { *self.base.as_mut_ptr().add(col_offset::<GROUP, PACK>(col)) = value };
    }
}
impl<T: Copy, const GROUP: usize, const PACK: usize> std::ops::Index<usize>
    for RowMut<'_, T, GROUP, PACK>
{
    type Output = T;

    /// Panics when `col >= ncols`, mirroring slice indexing.
    #[inline]
    #[allow(clippy::panic)]
    fn index(&self, col: usize) -> &Self::Output {
        match self.get(col) {
            Some(value) => value,
            None => panic!("column index {col} out of bounds (ncols = {})", self.ncols),
        }
    }
}
impl<T: Copy, const GROUP: usize, const PACK: usize> std::ops::IndexMut<usize>
    for RowMut<'_, T, GROUP, PACK>
{
    /// Panics when `col >= ncols`, mirroring slice indexing.
    #[inline]
    #[allow(clippy::panic)]
    fn index_mut(&mut self, col: usize) -> &mut Self::Output {
        // Copy the count out before the mutable borrow of `self` below.
        let ncols = self.ncols;
        match self.get_mut(col) {
            Some(value) => value,
            None => panic!("column index {col} out of bounds (ncols = {ncols})"),
        }
    }
}
// SAFETY: `layout()` describes exactly the `storage_len()` elements the repr
// addresses, and row base pointers computed via `linear_index` stay in bounds
// for `i < nrows`.
unsafe impl<T: Copy, const GROUP: usize, const PACK: usize> Repr
    for BlockTransposedRepr<T, GROUP, PACK>
{
    type Row<'a>
        = Row<'a, T, GROUP, PACK>
    where
        Self: 'a;

    fn nrows(&self) -> usize {
        self.nrows
    }

    /// Memory layout of the padded backing storage.
    fn layout(&self) -> Result<Layout, LayoutError> {
        Ok(Layout::array::<T>(self.storage_len())?)
    }

    unsafe fn get_row<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::Row<'a> {
        debug_assert!(i < self.nrows);
        // A zero-column row is never dereferenced, so a dangling base is fine.
        if self.ncols == 0 {
            return Row {
                base: unsafe { SlicePtr::new_unchecked(NonNull::dangling()) },
                ncols: 0,
            };
        }
        let base_ptr = ptr.as_ptr().cast::<T>();
        // The base is the position of element (i, 0); other columns are
        // reached from it via `col_offset`.
        let offset = linear_index::<GROUP, PACK>(i, 0, self.ncols);
        let row_base = unsafe { base_ptr.add(offset) };
        Row {
            base: unsafe { SlicePtr::new_unchecked(NonNull::new_unchecked(row_base)) },
            ncols: self.ncols,
        }
    }
}
// SAFETY: mirrors the `Repr` impl above; row base pointers derived via
// `linear_index` stay in bounds for `i < nrows`.
unsafe impl<T: Copy, const GROUP: usize, const PACK: usize> ReprMut
    for BlockTransposedRepr<T, GROUP, PACK>
{
    type RowMut<'a>
        = RowMut<'a, T, GROUP, PACK>
    where
        Self: 'a;

    unsafe fn get_row_mut<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::RowMut<'a> {
        debug_assert!(i < self.nrows);
        // A zero-column row is never dereferenced, so a dangling base is fine.
        if self.ncols == 0 {
            return RowMut {
                base: unsafe { MutSlicePtr::new_unchecked(NonNull::dangling()) },
                ncols: 0,
            };
        }
        let base_ptr = ptr.as_ptr().cast::<T>();
        // Base is the position of element (i, 0); columns reached via `col_offset`.
        let offset = linear_index::<GROUP, PACK>(i, 0, self.ncols);
        let row_base = unsafe { base_ptr.add(offset) };
        RowMut {
            base: unsafe { MutSlicePtr::new_unchecked(NonNull::new_unchecked(row_base)) },
            ncols: self.ncols,
        }
    }
}
// SAFETY: owned storage is always allocated as a boxed slice of exactly
// `storage_len()` elements (see `box_to_mat`), so it can be reconstituted
// and dropped the same way.
unsafe impl<T: Copy, const GROUP: usize, const PACK: usize> ReprOwned
    for BlockTransposedRepr<T, GROUP, PACK>
{
    unsafe fn drop(self, ptr: NonNull<u8>) {
        unsafe {
            // Rebuild the original `Box<[T]>` and let it free the allocation.
            let slice_ptr =
                std::ptr::slice_from_raw_parts_mut(ptr.cast::<T>().as_ptr(), self.storage_len());
            let _ = Box::from_raw(slice_ptr);
        }
    }
}
// Construction by filling every (padded) element with a copy of `value`.
unsafe impl<T: Copy, const GROUP: usize, const PACK: usize> NewOwned<T>
    for BlockTransposedRepr<T, GROUP, PACK>
{
    type Error = crate::error::Infallible;

    fn new_owned(self, value: T) -> Result<Mat<Self>, Self::Error> {
        let b: Box<[T]> = vec![value; self.storage_len()].into_boxed_slice();
        // SAFETY: the box was sized to exactly `storage_len()` elements.
        Ok(unsafe { self.box_to_mat(b) })
    }
}
// Construction filled with `T::default()`, delegating to `NewOwned<T>`.
unsafe impl<T: Copy + Default, const GROUP: usize, const PACK: usize> NewOwned<Defaulted>
    for BlockTransposedRepr<T, GROUP, PACK>
{
    type Error = crate::error::Infallible;

    fn new_owned(self, _: Defaulted) -> Result<Mat<Self>, Self::Error> {
        self.new_owned(T::default())
    }
}
// Borrow an existing slice as a read-only matrix; the slice is interpreted
// as already being in (padded) block-transposed order.
unsafe impl<T: Copy, const GROUP: usize, const PACK: usize> NewRef<T>
    for BlockTransposedRepr<T, GROUP, PACK>
{
    type Error = SliceError;

    fn new_ref(self, data: &[T]) -> Result<MatRef<'_, Self>, Self::Error> {
        self.check_slice(data)?;
        // SAFETY: `check_slice` confirmed the length equals `storage_len()`.
        Ok(unsafe { MatRef::from_raw_parts(self, utils::as_nonnull(data).cast::<u8>()) })
    }
}
// Borrow an existing mutable slice as a writable matrix; the slice is
// interpreted as already being in (padded) block-transposed order.
unsafe impl<T: Copy, const GROUP: usize, const PACK: usize> NewMut<T>
    for BlockTransposedRepr<T, GROUP, PACK>
{
    type Error = SliceError;

    fn new_mut(self, data: &mut [T]) -> Result<MatMut<'_, Self>, Self::Error> {
        self.check_slice(data)?;
        // SAFETY: `check_slice` confirmed the length equals `storage_len()`.
        Ok(unsafe { MatMut::from_raw_parts(self, utils::as_nonnull_mut(data).cast::<u8>()) })
    }
}
/// Generates an inherent method that forwards to the same-named method on
/// [`BlockTransposedRef`] via `self.as_view()`. The second arm handles
/// `unsafe fn`s, re-wrapping the call in an `unsafe` block.
macro_rules! delegate_to_ref {
    ($(#[$m:meta])* $vis:vis fn $name:ident(&self $(, $a:ident: $t:ty)*) $(-> $r:ty)?) => {
        #[doc = concat!("See [`BlockTransposedRef::", stringify!($name), "`].")]
        $(#[$m])*
        #[inline]
        $vis fn $name(&self $(, $a: $t)*) $(-> $r)? {
            self.as_view().$name($($a),*)
        }
    };
    ($(#[$m:meta])* unsafe $vis:vis fn $name:ident(&self $(, $a:ident: $t:ty)*) $(-> $r:ty)?) => {
        #[doc = concat!("See [`BlockTransposedRef::", stringify!($name), "`].")]
        $(#[$m])*
        #[inline]
        $vis unsafe fn $name(&self $(, $a: $t)*) $(-> $r)? {
            unsafe { self.as_view().$name($($a),*) }
        }
    };
}
/// Owned, heap-backed block-transposed matrix.
#[derive(Debug)]
pub struct BlockTransposed<T: Copy, const GROUP: usize, const PACK: usize = 1> {
    data: Mat<BlockTransposedRepr<T, GROUP, PACK>>,
}

/// Immutable borrowed view of a block-transposed matrix.
#[derive(Debug, Clone, Copy)]
pub struct BlockTransposedRef<'a, T: Copy, const GROUP: usize, const PACK: usize = 1> {
    data: MatRef<'a, BlockTransposedRepr<T, GROUP, PACK>>,
}

/// Mutable borrowed view of a block-transposed matrix.
pub struct BlockTransposedMut<'a, T: Copy, const GROUP: usize, const PACK: usize = 1> {
    data: MatMut<'a, BlockTransposedRepr<T, GROUP, PACK>>,
}
impl<'a, T: Copy, const GROUP: usize, const PACK: usize> BlockTransposedRef<'a, T, GROUP, PACK> {
    /// Logical (unpadded) row count.
    #[inline]
    pub fn nrows(&self) -> usize {
        self.data.repr().nrows()
    }

    /// Logical (unpadded) column count.
    #[inline]
    pub fn ncols(&self) -> usize {
        self.data.repr().ncols()
    }

    /// Column count rounded up to a multiple of `PACK`.
    #[inline]
    pub fn padded_ncols(&self) -> usize {
        self.data.repr().padded_ncols()
    }

    /// Row-block size (the `GROUP` const parameter).
    pub const fn group_size(&self) -> usize {
        GROUP
    }

    /// [`group_size`](Self::group_size) without needing an instance.
    pub const fn const_group_size() -> usize {
        GROUP
    }

    /// Column packing factor (the `PACK` const parameter).
    pub const fn pack_size(&self) -> usize {
        PACK
    }

    /// Number of blocks containing exactly `GROUP` real rows.
    #[inline]
    pub fn full_blocks(&self) -> usize {
        self.data.repr().full_blocks()
    }

    /// Total block count, including a trailing partial block if any.
    #[inline]
    pub fn num_blocks(&self) -> usize {
        self.data.repr().num_blocks()
    }

    /// Number of real rows in the trailing partial block (0 when none).
    #[inline]
    pub fn remainder(&self) -> usize {
        self.data.repr().remainder()
    }

    /// Row count rounded up to a multiple of `GROUP`.
    #[inline]
    pub fn padded_nrows(&self) -> usize {
        self.data.repr().padded_nrows()
    }

    /// Pointer to the first element of the backing storage.
    #[inline]
    pub fn as_ptr(&self) -> *const T {
        self.data.as_raw_ptr().cast::<T>()
    }

    /// Entire (padded) backing storage in physical order.
    #[inline]
    pub fn as_slice(&self) -> &'a [T] {
        let len = self.data.repr().storage_len();
        // SAFETY: the storage holds `storage_len()` contiguous elements,
        // borrowed for 'a.
        unsafe { std::slice::from_raw_parts(self.as_ptr(), len) }
    }

    /// Pointer to the start of `block`.
    ///
    /// Safety: requires `block < num_blocks()`.
    #[inline]
    pub unsafe fn block_ptr_unchecked(&self, block: usize) -> *const T {
        debug_assert!(block < self.num_blocks());
        unsafe { self.as_ptr().add(self.data.repr().block_offset(block)) }
    }

    /// View a full block as a `padded_ncols()/PACK` x `GROUP*PACK` matrix.
    /// Panics when `block >= full_blocks()`; use [`remainder_block`](Self::remainder_block)
    /// for the trailing partial block.
    #[allow(clippy::expect_used)]
    pub fn block(&self, block: usize) -> MatrixView<'a, T> {
        assert!(block < self.full_blocks());
        let offset = self.data.repr().block_offset(block);
        let stride = self.data.repr().block_stride();
        // SAFETY: each block spans `block_stride()` elements starting at `offset`.
        let data: &[T] = unsafe { std::slice::from_raw_parts(self.as_ptr().add(offset), stride) };
        MatrixView::try_from(data, self.padded_ncols() / PACK, GROUP * PACK)
            .expect("base data should have been sized correctly")
    }

    /// View of the trailing partial block (still a full `block_stride()` span
    /// of padded storage), or `None` when `nrows` divides evenly by `GROUP`.
    #[allow(clippy::expect_used)]
    pub fn remainder_block(&self) -> Option<MatrixView<'a, T>> {
        if self.remainder() == 0 {
            None
        } else {
            let offset = self.data.repr().block_offset(self.full_blocks());
            let stride = self.data.repr().block_stride();
            // SAFETY: the partial block still occupies a full `block_stride()` span.
            let data: &[T] =
                unsafe { std::slice::from_raw_parts(self.as_ptr().add(offset), stride) };
            Some(
                MatrixView::try_from(data, self.padded_ncols() / PACK, GROUP * PACK)
                    .expect("base data should have been sized correctly"),
            )
        }
    }

    /// Bounds-checked element read at logical `(row, col)`; panics out of range.
    #[inline]
    pub fn get_element(&self, row: usize, col: usize) -> T {
        assert!(
            row < self.nrows(),
            "row {row} out of bounds (nrows = {})",
            self.nrows()
        );
        assert!(
            col < self.ncols(),
            "col {col} out of bounds (ncols = {})",
            self.ncols()
        );
        let idx = linear_index::<GROUP, PACK>(row, col, self.ncols());
        // SAFETY: both coordinates were bounds-checked above.
        unsafe { *self.as_ptr().add(idx) }
    }

    /// Row view, or `None` when `i >= nrows()`.
    #[inline]
    pub fn get_row(&self, i: usize) -> Option<Row<'_, T, GROUP, PACK>> {
        self.data.get_row(i)
    }
}
impl<'a, T: Copy, const GROUP: usize, const PACK: usize> BlockTransposedMut<'a, T, GROUP, PACK> {
    /// Reborrow as an immutable view.
    #[inline]
    pub fn as_view(&self) -> BlockTransposedRef<'_, T, GROUP, PACK> {
        BlockTransposedRef {
            data: self.data.as_view(),
        }
    }
    // Read-only accessors forwarded to the immutable view.
    delegate_to_ref!(pub fn nrows(&self) -> usize);
    delegate_to_ref!(pub fn ncols(&self) -> usize);
    delegate_to_ref!(pub fn padded_ncols(&self) -> usize);
    delegate_to_ref!(pub fn full_blocks(&self) -> usize);
    delegate_to_ref!(pub fn num_blocks(&self) -> usize);
    delegate_to_ref!(pub fn remainder(&self) -> usize);
    delegate_to_ref!(pub fn padded_nrows(&self) -> usize);
    delegate_to_ref!(pub fn as_ptr(&self) -> *const T);
    delegate_to_ref!(pub fn as_slice(&self) -> &[T]);
    delegate_to_ref!(#[allow(clippy::missing_safety_doc)] unsafe pub fn block_ptr_unchecked(&self, block: usize) -> *const T);
    delegate_to_ref!(#[allow(clippy::expect_used)] pub fn block(&self, block: usize) -> MatrixView<'_, T>);
    delegate_to_ref!(#[allow(clippy::expect_used)] pub fn remainder_block(&self) -> Option<MatrixView<'_, T>>);
    delegate_to_ref!(pub fn get_element(&self, row: usize, col: usize) -> T);

    /// Row-block size (the `GROUP` const parameter).
    pub const fn group_size(&self) -> usize {
        GROUP
    }

    /// [`group_size`](Self::group_size) without needing an instance.
    pub const fn const_group_size() -> usize {
        GROUP
    }

    /// Column packing factor (the `PACK` const parameter).
    pub const fn pack_size(&self) -> usize {
        PACK
    }

    /// Immutable row view, or `None` when `i >= nrows()`.
    #[inline]
    pub fn get_row(&self, i: usize) -> Option<Row<'_, T, GROUP, PACK>> {
        self.data.get_row(i)
    }

    /// Mutable access to the entire (padded) backing storage.
    #[inline]
    pub fn as_mut_slice(&mut self) -> &mut [T] {
        self.reborrow_mut().mut_slice_inner()
    }

    // Consumes the view so the returned slice can carry the full `'a` lifetime.
    fn mut_slice_inner(mut self) -> &'a mut [T] {
        let len = self.data.repr().storage_len();
        // SAFETY: the storage holds `storage_len()` contiguous elements and
        // this view is the exclusive borrow for 'a.
        unsafe { std::slice::from_raw_parts_mut(self.data.as_raw_mut_ptr().cast::<T>(), len) }
    }

    /// Mutable view of a full block; panics when `block >= full_blocks()`.
    #[allow(clippy::expect_used)]
    pub fn block_mut(&mut self, block: usize) -> MutMatrixView<'_, T> {
        self.reborrow_mut().block_mut_inner(block)
    }

    // Consuming worker so the returned view can carry the full `'a` lifetime.
    #[allow(clippy::expect_used)]
    fn block_mut_inner(mut self, block: usize) -> MutMatrixView<'a, T> {
        let repr = *self.data.repr();
        assert!(block < repr.full_blocks());
        let offset = repr.block_offset(block);
        let stride = repr.block_stride();
        let pncols = repr.padded_ncols();
        // SAFETY: each block spans `block_stride()` elements starting at `offset`.
        let data: &mut [T] = unsafe {
            std::slice::from_raw_parts_mut(
                self.data.as_raw_mut_ptr().cast::<T>().add(offset),
                stride,
            )
        };
        MutMatrixView::try_from(data, pncols / PACK, GROUP * PACK)
            .expect("base data should have been sized correctly")
    }

    /// Mutable view of the trailing partial block, if any.
    #[allow(clippy::expect_used)]
    pub fn remainder_block_mut(&mut self) -> Option<MutMatrixView<'_, T>> {
        self.reborrow_mut().remainder_block_mut_inner()
    }

    // Consuming worker so the returned view can carry the full `'a` lifetime.
    #[allow(clippy::expect_used)]
    fn remainder_block_mut_inner(mut self) -> Option<MutMatrixView<'a, T>> {
        let repr = *self.data.repr();
        if repr.remainder() == 0 {
            None
        } else {
            let offset = repr.block_offset(repr.full_blocks());
            let stride = repr.block_stride();
            let pncols = repr.padded_ncols();
            // SAFETY: the partial block still occupies a full `block_stride()`
            // span of padded storage.
            let data: &mut [T] = unsafe {
                std::slice::from_raw_parts_mut(
                    self.data.as_raw_mut_ptr().cast::<T>().add(offset),
                    stride,
                )
            };
            Some(
                MutMatrixView::try_from(data, pncols / PACK, GROUP * PACK)
                    .expect("base data should have been sized correctly"),
            )
        }
    }

    /// Mutable row view, or `None` when `i >= nrows()`.
    #[inline]
    pub fn get_row_mut(&mut self, i: usize) -> Option<RowMut<'_, T, GROUP, PACK>> {
        self.data.get_row_mut(i)
    }

    // Shorter-lived mutable reborrow of the same storage.
    fn reborrow_mut(&mut self) -> BlockTransposedMut<'_, T, GROUP, PACK> {
        BlockTransposedMut {
            data: self.data.reborrow_mut(),
        }
    }
}
impl<T: Copy, const GROUP: usize, const PACK: usize> BlockTransposed<T, GROUP, PACK> {
    /// Immutable borrowed view of the matrix.
    pub fn as_view(&self) -> BlockTransposedRef<'_, T, GROUP, PACK> {
        BlockTransposedRef {
            data: self.data.as_view(),
        }
    }

    /// Mutable borrowed view of the matrix.
    pub fn as_view_mut(&mut self) -> BlockTransposedMut<'_, T, GROUP, PACK> {
        BlockTransposedMut {
            data: self.data.as_view_mut(),
        }
    }
    // Read-only accessors forwarded to the immutable view.
    delegate_to_ref!(pub fn nrows(&self) -> usize);
    delegate_to_ref!(pub fn ncols(&self) -> usize);
    delegate_to_ref!(pub fn padded_ncols(&self) -> usize);
    delegate_to_ref!(pub fn full_blocks(&self) -> usize);
    delegate_to_ref!(pub fn num_blocks(&self) -> usize);
    delegate_to_ref!(pub fn remainder(&self) -> usize);
    delegate_to_ref!(pub fn padded_nrows(&self) -> usize);
    delegate_to_ref!(pub fn as_ptr(&self) -> *const T);
    delegate_to_ref!(pub fn as_slice(&self) -> &[T]);
    delegate_to_ref!(#[allow(clippy::missing_safety_doc)] unsafe pub fn block_ptr_unchecked(&self, block: usize) -> *const T);
    delegate_to_ref!(#[allow(clippy::expect_used)] pub fn block(&self, block: usize) -> MatrixView<'_, T>);
    delegate_to_ref!(#[allow(clippy::expect_used)] pub fn remainder_block(&self) -> Option<MatrixView<'_, T>>);
    delegate_to_ref!(pub fn get_element(&self, row: usize, col: usize) -> T);

    /// Row-block size (the `GROUP` const parameter).
    pub const fn group_size(&self) -> usize {
        GROUP
    }

    /// [`group_size`](Self::group_size) without needing an instance.
    pub const fn const_group_size() -> usize {
        GROUP
    }

    /// Column packing factor (the `PACK` const parameter).
    pub const fn pack_size(&self) -> usize {
        PACK
    }

    /// Immutable row view, or `None` when `i >= nrows()`.
    #[inline]
    pub fn get_row(&self, i: usize) -> Option<Row<'_, T, GROUP, PACK>> {
        self.data.get_row(i)
    }

    /// Mutable access to the entire (padded) backing storage.
    #[inline]
    pub fn as_mut_slice(&mut self) -> &mut [T] {
        self.as_view_mut().mut_slice_inner()
    }

    /// Mutable view of a full block; panics when `block >= full_blocks()`.
    #[allow(clippy::expect_used)]
    pub fn block_mut(&mut self, block: usize) -> MutMatrixView<'_, T> {
        self.as_view_mut().block_mut_inner(block)
    }

    /// Mutable view of the trailing partial block, if any.
    #[allow(clippy::expect_used)]
    pub fn remainder_block_mut(&mut self) -> Option<MutMatrixView<'_, T>> {
        self.as_view_mut().remainder_block_mut_inner()
    }

    /// Mutable row view, or `None` when `i >= nrows()`.
    #[inline]
    pub fn get_row_mut(&mut self, i: usize) -> Option<RowMut<'_, T, GROUP, PACK>> {
        self.data.get_row_mut(i)
    }
}
impl<'this, T: Copy, const GROUP: usize, const PACK: usize> Reborrow<'this>
for BlockTransposed<T, GROUP, PACK>
{
type Target = BlockTransposedRef<'this, T, GROUP, PACK>;
#[inline]
fn reborrow(&'this self) -> Self::Target {
self.as_view()
}
}
impl<T: Copy + Default, const GROUP: usize, const PACK: usize> BlockTransposed<T, GROUP, PACK> {
    /// Allocate a default-initialized `nrows` x `ncols` matrix.
    /// Panics when the dimensions overflow; see [`try_new`](Self::try_new).
    #[allow(clippy::expect_used)]
    pub fn new(nrows: usize, ncols: usize) -> Self {
        let repr = BlockTransposedRepr::<T, GROUP, PACK>::new(nrows, ncols)
            .expect("dimensions should not overflow");
        Self {
            data: Mat::new(repr, Defaulted).expect("infallible"),
        }
    }

    /// Fallible variant of [`new`](Self::new).
    pub fn try_new(nrows: usize, ncols: usize) -> Result<Self, Overflow> {
        let repr = BlockTransposedRepr::<T, GROUP, PACK>::new(nrows, ncols)?;
        Ok(Self {
            data: Mat::new(repr, Defaulted).expect("infallible"),
        })
    }

    /// Copy a row-major strided view into block-transposed layout.
    /// Padding rows/columns keep the `T::default()` fill from `new`.
    pub fn from_strided(v: StridedView<'_, T>) -> Self {
        let nrows = v.nrows();
        let ncols = v.ncols();
        let mut mat = Self::new(nrows, ncols);
        let repr = *mat.data.repr();
        let num_blocks = repr.num_blocks();
        let pncols = repr.padded_ncols();
        let num_col_groups = pncols / PACK;
        // `dst` walks the destination sequentially; the loop nest mirrors the
        // physical order (block, column pack, row-in-block, element-in-pack).
        let mut dst = mat.data.as_raw_mut_ptr().cast::<T>();
        for block in 0..num_blocks {
            let row_base = block * GROUP;
            for cg in 0..num_col_groups {
                let col_base = cg * PACK;
                for rib in 0..GROUP {
                    let row = row_base + rib;
                    if row < nrows {
                        // SAFETY: `row < nrows` is in range for the source view.
                        let src_row = unsafe { v.get_row_unchecked(row) };
                        for p in 0..PACK {
                            let col = col_base + p;
                            if col < ncols {
                                // SAFETY: `col < ncols` is in range for the source
                                // row, and `dst` visits each storage slot once.
                                unsafe { *dst = *src_row.get_unchecked(col) };
                            }
                            dst = unsafe { dst.add(1) };
                        }
                    } else {
                        // Padding row: skip the whole pack, leaving defaults.
                        dst = unsafe { dst.add(PACK) };
                    }
                }
            }
        }
        mat
    }

    /// Convenience wrapper over [`from_strided`](Self::from_strided).
    pub fn from_matrix_view(v: MatrixView<'_, T>) -> Self {
        Self::from_strided(v.into())
    }
}
impl<T: Copy, const GROUP: usize, const PACK: usize> std::ops::Index<(usize, usize)>
    for BlockTransposed<T, GROUP, PACK>
{
    type Output = T;

    /// Bounds-checked indexing by logical `(row, col)`; panics out of range.
    #[inline]
    fn index(&self, (row, col): (usize, usize)) -> &Self::Output {
        assert!(row < self.nrows());
        assert!(col < self.ncols());
        let idx = linear_index::<GROUP, PACK>(row, col, self.ncols());
        // SAFETY: both coordinates were bounds-checked above.
        unsafe { &*self.as_ptr().add(idx) }
    }
}
#[cfg(test)]
mod tests {
use diskann_utils::{lazy_format, views::Matrix};
use super::*;
use crate::utils::div_round_up;
/// Deterministic test pattern: element `i` maps to `i + 1` as `f32`.
fn gen_f32(i: usize) -> f32 {
    let seed = i + 1;
    seed as f32
}
/// Deterministic test pattern: element `i` maps to `i + 1` as `i32`.
fn gen_i32(i: usize) -> i32 {
    let seed = i + 1;
    seed as i32
}
/// Deterministic non-zero byte pattern: cycles through 1..=255.
fn gen_u8(i: usize) -> u8 {
    let cycled = i % 255;
    (cycled + 1) as u8
}
/// Exhaustive smoke test of the owned, ref, and mut APIs for a single
/// `(T, GROUP, PACK, nrows, ncols)` configuration: accessors, row views,
/// block views, mutation round-trips, and padding invariants.
fn test_full_api<
    T: Copy + Default + PartialEq + std::fmt::Debug + 'static,
    const GROUP: usize,
    const PACK: usize,
>(
    nrows: usize,
    ncols: usize,
    gen_element: fn(usize) -> T,
) {
    // Attached to every assertion message to identify the failing configuration.
    let context = lazy_format!(
        "T={}, GROUP={}, PACK={}, nrows={}, ncols={}",
        std::any::type_name::<T>(),
        GROUP,
        PACK,
        nrows,
        ncols,
    );
    // Fill a row-major source matrix with a deterministic pattern.
    let mut data = Matrix::new(T::default(), nrows, ncols);
    data.as_mut_slice()
        .iter_mut()
        .enumerate()
        .for_each(|(i, d)| *d = gen_element(i));
    let mut transpose = BlockTransposed::<T, GROUP, PACK>::from_strided(data.as_view().into());
    let expected_padded = div_round_up(ncols, PACK) * PACK;
    let expected_remainder = nrows % GROUP;
    let storage_len = transpose.as_slice().len();
    // --- owned-type accessors ---
    assert_eq!(transpose.nrows(), nrows, "{}", context);
    assert_eq!(transpose.ncols(), ncols, "{}", context);
    assert_eq!(transpose.group_size(), GROUP, "{}", context);
    assert_eq!(
        BlockTransposed::<T, GROUP, PACK>::const_group_size(),
        GROUP,
        "{}",
        context
    );
    assert_eq!(transpose.pack_size(), PACK, "{}", context);
    assert_eq!(transpose.full_blocks(), nrows / GROUP, "{}", context);
    assert_eq!(
        transpose.num_blocks(),
        div_round_up(nrows, GROUP),
        "{}",
        context,
    );
    assert_eq!(transpose.remainder(), expected_remainder, "{}", context);
    assert_eq!(transpose.padded_ncols(), expected_padded, "{}", context);
    // Element-wise round-trip: Index and get_element agree with the source.
    for row in 0..nrows {
        for col in 0..ncols {
            assert_eq!(
                data[(row, col)],
                transpose[(row, col)],
                "Index at ({}, {}) -- {}",
                row,
                col,
                context,
            );
            assert_eq!(
                data[(row, col)],
                transpose.get_element(row, col),
                "get_element at ({}, {}) -- {}",
                row,
                col,
                context,
            );
        }
    }
    // --- row views and iterators on the owned type's view ---
    let view = transpose.as_view();
    for row in 0..nrows {
        let row_view = view.get_row(row).unwrap();
        assert_eq!(row_view.len(), ncols, "{}", context);
        assert_eq!(row_view.is_empty(), ncols == 0, "{}", context);
        for col in 0..ncols {
            assert_eq!(
                data[(row, col)],
                row_view[col],
                "row view at ({}, {}) -- {}",
                row,
                col,
                context,
            );
        }
        if ncols > 0 {
            assert_eq!(row_view.get(0), Some(&data[(row, 0)]), "{}", context);
        }
        assert_eq!(row_view.get(ncols), None, "{}", context);
        let iter = row_view.iter();
        assert_eq!(iter.len(), ncols, "{}", context);
        let (lo, hi) = iter.size_hint();
        assert_eq!(lo, ncols, "{}", context);
        assert_eq!(hi, Some(ncols), "{}", context);
        let collected: Vec<T> = row_view.iter().collect();
        assert_eq!(collected.len(), ncols, "{}", context);
        for col in 0..ncols {
            assert_eq!(data[(row, col)], collected[col], "{}", context);
        }
    }
    assert!(view.get_row(nrows).is_none(), "{}", context);
    let _ = view;
    // --- BlockTransposedRef accessors ---
    {
        let view = transpose.as_view();
        assert_eq!(view.nrows(), nrows, "{}", context);
        assert_eq!(view.ncols(), ncols, "{}", context);
        assert_eq!(view.padded_ncols(), expected_padded, "{}", context);
        assert_eq!(view.group_size(), GROUP, "{}", context);
        assert_eq!(
            BlockTransposedRef::<T, GROUP, PACK>::const_group_size(),
            GROUP,
        );
        assert_eq!(view.pack_size(), PACK, "{}", context);
        assert_eq!(view.full_blocks(), nrows / GROUP, "{}", context);
        assert_eq!(view.num_blocks(), div_round_up(nrows, GROUP), "{}", context,);
        assert_eq!(view.remainder(), expected_remainder, "{}", context);
        assert_eq!(view.as_ptr(), transpose.as_ptr(), "{}", context);
        assert_eq!(view.as_slice(), transpose.as_slice(), "{}", context);
        for row in 0..nrows {
            for col in 0..ncols {
                assert_eq!(
                    data[(row, col)],
                    view.get_element(row, col),
                    "Ref get_element at ({}, {}) -- {}",
                    row,
                    col,
                    context,
                );
            }
            let row_view = view.get_row(row).unwrap();
            for col in 0..ncols {
                assert_eq!(data[(row, col)], row_view[col], "{}", context);
            }
        }
        assert!(view.get_row(nrows).is_none(), "{}", context);
    }
    // --- BlockTransposedMut accessors (read side) ---
    let expected_ptr = transpose.as_ptr();
    {
        let mut_view = transpose.as_view_mut();
        assert_eq!(mut_view.nrows(), nrows, "{}", context);
        assert_eq!(mut_view.ncols(), ncols, "{}", context);
        assert_eq!(mut_view.padded_ncols(), expected_padded, "{}", context);
        assert_eq!(mut_view.group_size(), GROUP, "{}", context);
        assert_eq!(
            BlockTransposedMut::<T, GROUP, PACK>::const_group_size(),
            GROUP,
        );
        assert_eq!(mut_view.pack_size(), PACK, "{}", context);
        assert_eq!(mut_view.full_blocks(), nrows / GROUP, "{}", context);
        assert_eq!(
            mut_view.num_blocks(),
            div_round_up(nrows, GROUP),
            "{}",
            context,
        );
        assert_eq!(mut_view.remainder(), expected_remainder, "{}", context);
        assert_eq!(mut_view.as_ptr(), expected_ptr, "{}", context);
        assert_eq!(mut_view.as_slice().len(), storage_len, "{}", context);
        for row in 0..nrows {
            for col in 0..ncols {
                assert_eq!(
                    data[(row, col)],
                    mut_view.get_element(row, col),
                    "Mut get_element at ({}, {}) -- {}",
                    row,
                    col,
                    context,
                );
            }
            let row_view = mut_view.get_row(row).unwrap();
            for col in 0..ncols {
                assert_eq!(data[(row, col)], row_view[col], "{}", context);
            }
        }
        assert!(mut_view.get_row(nrows).is_none(), "{}", context);
    }
    // Ref obtained through a mut view must read the same data.
    {
        let mut_view = transpose.as_view_mut();
        let ref_from_mut = mut_view.as_view();
        assert_eq!(ref_from_mut.nrows(), nrows, "{}", context);
        for row in 0..nrows {
            for col in 0..ncols {
                assert_eq!(
                    data[(row, col)],
                    ref_from_mut.get_element(row, col),
                    "{}",
                    context,
                );
            }
        }
    }
    // Mutable whole-storage slices on both the view and the owned type.
    {
        let mut mut_view = transpose.as_view_mut();
        assert_eq!(mut_view.as_mut_slice().len(), storage_len, "{}", context);
    }
    assert_eq!(transpose.as_mut_slice().len(), storage_len, "{}", context);
    // --- full-block views: dimensions, pointers, and data agreement across
    // --- the owned, ref, and mut entry points ---
    let expected_block_nrows = expected_padded / PACK;
    let expected_block_ncols = GROUP * PACK;
    for b in 0..transpose.full_blocks() {
        let block_data: Vec<T>;
        let ptr: *const T;
        {
            let block = transpose.block(b);
            assert_eq!(block.nrows(), expected_block_nrows, "{}", context);
            assert_eq!(block.ncols(), expected_block_ncols, "{}", context);
            ptr = unsafe { transpose.block_ptr_unchecked(b) };
            assert_eq!(ptr, block.as_slice().as_ptr(), "{}", context);
            block_data = block.as_slice().to_vec();
        }
        {
            let view = transpose.as_view();
            assert_eq!(view.block(b).as_slice(), &block_data[..], "{}", context);
            assert_eq!(unsafe { view.block_ptr_unchecked(b) }, ptr, "{}", context);
        }
        {
            let mut_view = transpose.as_view_mut();
            assert_eq!(mut_view.block(b).as_slice(), &block_data[..], "{}", context);
            assert_eq!(
                unsafe { mut_view.block_ptr_unchecked(b) },
                ptr,
                "{}",
                context,
            );
        }
    }
    // --- remainder block (present iff nrows % GROUP != 0) ---
    if expected_remainder != 0 {
        let remainder_data: Vec<T>;
        let ptr: *const T;
        let fb = transpose.full_blocks();
        {
            let block = transpose.remainder_block().unwrap();
            assert_eq!(block.nrows(), expected_block_nrows, "{}", context);
            assert_eq!(block.ncols(), expected_block_ncols, "{}", context);
            ptr = unsafe { transpose.block_ptr_unchecked(fb) };
            assert_eq!(ptr, block.as_slice().as_ptr(), "{}", context);
            remainder_data = block.as_slice().to_vec();
        }
        {
            let view = transpose.as_view();
            let ref_block = view.remainder_block().unwrap();
            assert_eq!(ref_block.as_slice(), &remainder_data[..], "{}", context);
        }
        {
            let mut_view = transpose.as_view_mut();
            let mut_block = mut_view.remainder_block().unwrap();
            assert_eq!(mut_block.as_slice(), &remainder_data[..], "{}", context);
        }
    } else {
        assert!(transpose.remainder_block().is_none(), "{}", context);
        {
            let view = transpose.as_view();
            assert!(view.remainder_block().is_none(), "{}", context);
        }
        {
            let mut_view = transpose.as_view_mut();
            assert!(mut_view.remainder_block().is_none(), "{}", context);
        }
    }
    // --- mutable block views via the mut view ---
    {
        let mut mut_view = transpose.as_view_mut();
        for b in 0..mut_view.full_blocks() {
            let block_mut = mut_view.block_mut(b);
            assert_eq!(block_mut.nrows(), expected_block_nrows, "{}", context);
            assert_eq!(block_mut.ncols(), expected_block_ncols, "{}", context);
        }
        if expected_remainder != 0 {
            let rem = mut_view.remainder_block_mut().unwrap();
            assert_eq!(rem.nrows(), expected_block_nrows, "{}", context);
            assert_eq!(rem.ncols(), expected_block_ncols, "{}", context);
        } else {
            assert!(mut_view.remainder_block_mut().is_none(), "{}", context);
        }
    }
    // --- mutable block views via the owned type ---
    for b in 0..transpose.full_blocks() {
        let block_mut = transpose.block_mut(b);
        assert_eq!(block_mut.nrows(), expected_block_nrows, "{}", context);
        assert_eq!(block_mut.ncols(), expected_block_ncols, "{}", context);
    }
    if expected_remainder != 0 {
        let rem = transpose.remainder_block_mut().unwrap();
        assert_eq!(rem.nrows(), expected_block_nrows, "{}", context);
        assert_eq!(rem.ncols(), expected_block_ncols, "{}", context);
    } else {
        assert!(transpose.remainder_block_mut().is_none(), "{}", context);
    }
    // --- mutable row views ---
    {
        let mut mut_view = transpose.as_view_mut();
        for row in 0..nrows {
            let row_view = mut_view.get_row_mut(row).unwrap();
            assert_eq!(row_view.len(), ncols, "{}", context);
            assert_eq!(row_view.is_empty(), ncols == 0, "{}", context);
            for col in 0..ncols {
                assert_eq!(data[(row, col)], row_view[col], "{}", context);
            }
        }
        assert!(mut_view.get_row_mut(nrows).is_none(), "{}", context);
    }
    // Write a sentinel through a mutable row, observe it, then restore.
    if nrows > 0 && ncols > 0 {
        {
            let view = transpose.as_view();
            let row = view.get_row(0).unwrap();
            assert_eq!(row.get(ncols), None, "{}", context);
            assert_eq!(row.get(usize::MAX), None, "{}", context);
        }
        let row = transpose.get_row_mut(0).unwrap();
        assert_eq!(row.get(ncols), None, "{}", context);
        let mut row = transpose.get_row_mut(0).unwrap();
        let sentinel = gen_element(usize::MAX / 2);
        let original = row[0];
        if let Some(v) = row.get_mut(0) {
            *v = sentinel;
        }
        assert_eq!(row.get_mut(ncols), None, "{}", context);
        let _ = row;
        assert_eq!(transpose.get_element(0, 0), sentinel, "{}", context);
        transpose.get_row_mut(0).unwrap().set(0, original);
    }
    // Zero every block through the mutable block views; the whole storage
    // (including padding) must end up default-filled.
    for b in 0..transpose.full_blocks() {
        transpose.block_mut(b).as_mut_slice().fill(T::default());
    }
    if transpose.remainder() != 0 {
        transpose
            .remainder_block_mut()
            .unwrap()
            .as_mut_slice()
            .fill(T::default());
    }
    assert!(
        transpose.as_slice().iter().all(|v| *v == T::default()),
        "not fully zeroed -- {}",
        context,
    );
    // Rebuild and verify padding columns/rows stay default-initialized.
    let transpose = BlockTransposed::<T, GROUP, PACK>::from_strided(data.as_view().into());
    let raw = transpose.as_slice();
    for row in 0..nrows {
        for col in ncols..expected_padded {
            let idx = linear_index::<GROUP, PACK>(row, col, ncols);
            assert_eq!(
                raw[idx],
                T::default(),
                "col padding at ({}, {}) -- {}",
                row,
                col,
                context,
            );
        }
    }
    let padded_nrows = nrows.next_multiple_of(GROUP);
    for row in nrows..padded_nrows {
        for col in 0..expected_padded {
            let idx = linear_index::<GROUP, PACK>(row, col, ncols);
            assert_eq!(
                raw[idx],
                T::default(),
                "row padding at ({}, {}) -- {}",
                row,
                col,
                context,
            );
        }
    }
    assert_eq!(
        transpose.as_view().padded_nrows(),
        padded_nrows,
        "padded_nrows() mismatch -- {}",
        context,
    );
    // `from_matrix_view` must agree with `from_strided`.
    if nrows > 0 && ncols > 0 {
        let via_matrix = BlockTransposed::<T, GROUP, PACK>::from_matrix_view(data.as_view());
        assert_eq!(via_matrix.as_slice(), transpose.as_slice(), "{}", context);
    }
}
#[test]
fn test_api_pack1_group16() {
    // Keep the sweep small under miri; exercise a dense grid otherwise.
    let rows: Vec<usize> = if cfg!(miri) {
        vec![0, 1, 15, 16, 17, 33]
    } else {
        Vec::from_iter(0..128)
    };
    let cols: Vec<usize> = if cfg!(miri) {
        vec![0, 1, 2]
    } else {
        Vec::from_iter(0..5)
    };
    for &nrows in rows.iter() {
        for &ncols in cols.iter() {
            test_full_api::<f32, 16, 1>(nrows, ncols, gen_f32);
        }
    }
}
#[test]
fn test_api_pack1_group8() {
    // Keep the sweep small under miri; exercise a dense grid otherwise.
    let rows: Vec<usize> = if cfg!(miri) {
        vec![0, 1, 7, 8, 9, 17]
    } else {
        Vec::from_iter(0..128)
    };
    let cols: Vec<usize> = if cfg!(miri) {
        vec![0, 1, 2]
    } else {
        Vec::from_iter(0..5)
    };
    for &nrows in rows.iter() {
        for &ncols in cols.iter() {
            test_full_api::<f32, 8, 1>(nrows, ncols, gen_f32);
        }
    }
}
#[test]
fn test_api_pack2() {
    // Keep the sweep small under miri; exercise a dense grid otherwise.
    let rows: Vec<usize> = if cfg!(miri) {
        vec![0, 1, 3, 4, 5, 7, 8, 9, 15, 16, 17]
    } else {
        Vec::from_iter(0..48)
    };
    let cols: Vec<usize> = if cfg!(miri) {
        vec![0, 1, 2, 3, 4, 5]
    } else {
        Vec::from_iter(0..9)
    };
    for &nrows in rows.iter() {
        for &ncols in cols.iter() {
            test_full_api::<f32, 4, 2>(nrows, ncols, gen_f32);
            test_full_api::<f32, 8, 2>(nrows, ncols, gen_f32);
            test_full_api::<f32, 16, 2>(nrows, ncols, gen_f32);
        }
    }
}
#[test]
fn test_api_pack4() {
    // Keep the sweep small under miri; exercise a dense grid otherwise.
    let rows: Vec<usize> = if cfg!(miri) {
        vec![0, 1, 3, 4, 5, 7, 8, 9, 15, 16, 17]
    } else {
        Vec::from_iter(0..48)
    };
    let cols: Vec<usize> = if cfg!(miri) {
        vec![0, 1, 3, 4, 5, 8]
    } else {
        Vec::from_iter(0..9)
    };
    for &nrows in rows.iter() {
        for &ncols in cols.iter() {
            test_full_api::<f32, 4, 4>(nrows, ncols, gen_f32);
            test_full_api::<f32, 8, 4>(nrows, ncols, gen_f32);
            test_full_api::<f32, 16, 4>(nrows, ncols, gen_f32);
        }
    }
}
#[test]
fn test_api_non_f32() {
    // The layout machinery must be element-type agnostic; spot-check an
    // integer type and a byte type with mixed GROUP/PACK parameters.
    test_full_api::<i32, 8, 2>(12, 5, gen_i32);
    test_full_api::<i32, 4, 1>(10, 7, gen_i32);
    test_full_api::<u8, 8, 1>(10, 7, gen_u8);
    test_full_api::<u8, 4, 2>(12, 5, gen_u8);
}
/// Verifies the physical block layout for PACK=1: each full block holds GROUP
/// consecutive source rows, transposed, so block element (i, j) must equal
/// source element (GROUP * b + j, i).
fn test_block_layout_pack1<
    T: Copy + Default + PartialEq + std::fmt::Debug + 'static,
    const GROUP: usize,
>(
    nrows: usize,
    ncols: usize,
    gen_element: fn(usize) -> T,
) {
    // Fill a row-major source matrix with distinct generated values.
    let mut source = Matrix::new(T::default(), nrows, ncols);
    for (i, slot) in source.as_mut_slice().iter_mut().enumerate() {
        *slot = gen_element(i);
    }
    let transposed = BlockTransposed::<T, GROUP, 1>::from_strided(source.as_view().into());
    // Check every element of every full block against the source.
    for b in 0..transposed.full_blocks() {
        let blk = transposed.block(b);
        for i in 0..blk.nrows() {
            for j in 0..blk.ncols() {
                assert_eq!(
                    blk[(i, j)],
                    source[(GROUP * b + j, i)],
                    "block {} at ({}, {}) -- GROUP={}, nrows={}, ncols={}",
                    b,
                    i,
                    j,
                    GROUP,
                    nrows,
                    ncols,
                );
            }
        }
    }
    // A trailing partial block only has `remainder()` valid transposed columns.
    if transposed.remainder() != 0 {
        let full = transposed.full_blocks();
        let blk = transposed.remainder_block().unwrap();
        for i in 0..blk.nrows() {
            for j in 0..transposed.remainder() {
                assert_eq!(
                    blk[(i, j)],
                    source[(GROUP * full + j, i)],
                    "remainder at ({}, {}) -- GROUP={}, nrows={}, ncols={}",
                    i,
                    j,
                    GROUP,
                    nrows,
                    ncols,
                );
            }
        }
    }
}
#[test]
fn test_block_layout_pack1_group16() {
    // Shapes straddling the GROUP=16 boundary under Miri; full sweep otherwise.
    let row_counts: Vec<usize> = if cfg!(miri) {
        vec![0, 1, 15, 16, 17, 33]
    } else {
        (0..128).collect()
    };
    let col_counts: Vec<usize> = if cfg!(miri) {
        vec![0, 1, 2]
    } else {
        (0..5).collect()
    };
    for (&r, &c) in row_counts
        .iter()
        .flat_map(|r| col_counts.iter().map(move |c| (r, c)))
    {
        test_block_layout_pack1::<f32, 16>(r, c, gen_f32);
    }
}
#[test]
fn test_block_layout_pack1_group8() {
    // Shapes straddling the GROUP=8 boundary under Miri; full sweep otherwise.
    let row_counts: Vec<usize> = if cfg!(miri) {
        vec![0, 1, 7, 8, 9, 17]
    } else {
        (0..128).collect()
    };
    let col_counts: Vec<usize> = if cfg!(miri) {
        vec![0, 1, 2]
    } else {
        (0..5).collect()
    };
    for (&r, &c) in row_counts
        .iter()
        .flat_map(|r| col_counts.iter().map(move |c| (r, c)))
    {
        test_block_layout_pack1::<f32, 8>(r, c, gen_f32);
    }
}
#[test]
fn test_row_view_send_sync() {
    // Compile-time check: row views (shared and mutable) must be usable
    // across threads, i.e. implement Send and Sync.
    fn require_send<T: Send>() {}
    fn require_sync<T: Sync>() {}
    require_send::<Row<'_, f32, 16>>();
    require_send::<Row<'_, u8, 8, 2>>();
    require_send::<RowMut<'_, f32, 16>>();
    require_send::<RowMut<'_, i32, 4, 4>>();
    require_sync::<Row<'_, f32, 16>>();
    require_sync::<Row<'_, u8, 8, 2>>();
    require_sync::<RowMut<'_, f32, 16>>();
    require_sync::<RowMut<'_, i32, 4, 4>>();
}
#[test]
fn test_new_ref_and_new_mut() {
    let (nrows, ncols) = (5, 3);
    let repr = BlockTransposedRepr::<f32, 4>::new(nrows, ncols).unwrap();
    let owned = BlockTransposed::<f32, 4>::new(nrows, ncols);
    let backing: &[f32] = owned.as_slice();
    // Borrow the owned matrix's storage through a repr-constructed view and
    // confirm it observes the same shape and elements.
    let view = BlockTransposedRef {
        data: repr.new_ref(backing).unwrap(),
    };
    assert_eq!(view.nrows(), nrows);
    assert_eq!(view.ncols(), ncols);
    for r in 0..nrows {
        for c in 0..ncols {
            assert_eq!(view.get_element(r, c), owned.get_element(r, c));
        }
    }
    // Same storage again, this time through the mutable constructor.
    let mut scratch = backing.to_vec();
    let view_mut = BlockTransposedMut {
        data: repr.new_mut(&mut scratch).unwrap(),
    };
    assert_eq!(view_mut.nrows(), nrows);
    assert_eq!(view_mut.ncols(), ncols);
    // A buffer shorter than the required capacity must be rejected by both.
    let mut short = vec![0.0_f32; 2];
    assert!(repr.new_ref(&short).is_err());
    assert!(repr.new_mut(&mut short).is_err());
}
#[test]
fn test_row_view_empty() {
    // Zero-column matrices must still hand out valid (empty) row views,
    // both shared and mutable.
    fn check_empty<const GROUP: usize, const PACK: usize>() {
        let mut m = BlockTransposed::<f32, GROUP, PACK>::new(4, 0);
        let v = m.as_view();
        for i in 0..4 {
            let row = v.get_row(i).unwrap();
            assert!(row.is_empty());
            assert_eq!(row.len(), 0);
            assert_eq!(row.iter().count(), 0);
        }
        for i in 0..4 {
            let row = m.get_row_mut(i).unwrap();
            assert!(row.is_empty());
            assert_eq!(row.len(), 0);
        }
    }
    check_empty::<16, 1>();
    check_empty::<4, 2>();
    check_empty::<4, 4>();
}
#[test]
#[should_panic(expected = "column index 3 out of bounds")]
fn test_row_view_index_oob() {
    // Indexing one past the last column of a shared row view must panic.
    let matrix = BlockTransposed::<f32, 4>::new(4, 3);
    let v = matrix.as_view();
    let first = v.get_row(0).unwrap();
    let _ = first[3];
}
#[test]
#[should_panic(expected = "column index 3 out of bounds")]
fn test_row_view_mut_index_oob() {
    // Reading past the last column through a mutable row view must panic.
    let mut matrix = BlockTransposed::<f32, 4>::new(4, 3);
    let first = matrix.get_row_mut(0).unwrap();
    let _ = first[3];
}
#[test]
#[should_panic(expected = "column index 3 out of bounds")]
fn test_row_view_mut_index_mut_oob() {
    // Writing past the last column via IndexMut must panic.
    let mut matrix = BlockTransposed::<f32, 4>::new(4, 3);
    let mut first = matrix.get_row_mut(0).unwrap();
    first[3] = 1.0;
}
#[test]
#[should_panic(expected = "column index 3 out of bounds")]
fn test_row_view_set_oob() {
    // The explicit set() accessor must bounds-check like IndexMut does.
    let mut matrix = BlockTransposed::<f32, 4>::new(4, 3);
    let mut first = matrix.get_row_mut(0).unwrap();
    first.set(3, 1.0);
}
#[test]
#[should_panic(expected = "row 4 out of bounds")]
fn test_get_element_row_oob() {
    // get_element must reject a row index equal to nrows.
    let matrix = BlockTransposed::<f32, 4>::new(4, 3);
    matrix.get_element(4, 0);
}
#[test]
#[should_panic(expected = "col 3 out of bounds")]
fn test_get_element_col_oob() {
    // get_element must reject a column index equal to ncols.
    let matrix = BlockTransposed::<f32, 4>::new(4, 3);
    matrix.get_element(0, 3);
}
#[test]
#[should_panic(expected = "assertion failed")]
fn test_index_tuple_row_oob() {
    // Tuple indexing must bounds-check the row just like get_element.
    let matrix = BlockTransposed::<f32, 4>::new(4, 3);
    let _ = matrix[(4, 0)];
}
#[test]
#[should_panic(expected = "assertion failed")]
fn test_index_tuple_col_oob() {
    // Tuple indexing must bounds-check the column just like get_element.
    let matrix = BlockTransposed::<f32, 4>::new(4, 3);
    let _ = matrix[(0, 3)];
}
#[test]
#[should_panic]
fn test_block_oob() {
    // 4 rows with GROUP=4 form exactly one block; block index 1 is invalid.
    let matrix = BlockTransposed::<f32, 4>::new(4, 3);
    let _ = matrix.block(1);
}
#[test]
#[should_panic]
fn test_block_mut_oob() {
    // Same out-of-bounds block index, mutable accessor.
    let mut matrix = BlockTransposed::<f32, 4>::new(4, 3);
    let _ = matrix.block_mut(1);
}
#[test]
fn test_from_strided_nonunit_stride() {
    use diskann_utils::strided::StridedView;
    const GROUP: usize = 4;
    const PACK: usize = 2;
    let (nrows, ncols) = (5, 3);
    // Rows are spaced `cstride` elements apart in the flat buffer, leaving
    // gap elements between rows that the conversion must ignore.
    let cstride = 8;
    let required_len = (nrows - 1) * cstride + ncols;
    let mut flat = vec![0.0_f32; required_len];
    for r in 0..nrows {
        for c in 0..ncols {
            // Distinct nonzero value per logical element.
            flat[r * cstride + c] = (r * 100 + c + 1) as f32;
        }
    }
    let strided = StridedView::try_shrink_from(&flat, nrows, ncols, cstride)
        .expect("should construct strided view");
    let transposed = BlockTransposed::<f32, GROUP, PACK>::from_strided(strided);
    assert_eq!(transposed.nrows(), nrows);
    assert_eq!(transposed.ncols(), ncols);
    // Every logical element must survive the layout change intact.
    for r in 0..nrows {
        for c in 0..ncols {
            let expected = (r * 100 + c + 1) as f32;
            assert_eq!(
                transposed[(r, c)],
                expected,
                "mismatch at ({}, {})",
                r,
                c,
            );
        }
    }
    // Column padding introduced by PACK must be zeroed, never populated from
    // the stride gaps of the source buffer.
    let padded_ncols = ncols.next_multiple_of(PACK);
    let raw: &[f32] = transposed.as_slice();
    for r in 0..nrows {
        for c in ncols..padded_ncols {
            let idx = linear_index::<GROUP, PACK>(r, c, ncols);
            assert_eq!(
                raw[idx], 0.0,
                "column-padding at ({}, {}) should be zero",
                r, c,
            );
        }
    }
}
#[test]
fn test_concurrent_row_mutation() {
    const GROUP: usize = 8;
    const PACK: usize = 2;
    // Smaller problem under Miri to keep the (slow) interpreter run tractable.
    let (nrows, ncols, num_threads) = if cfg!(miri) { (8, 4, 2) } else { (64, 16, 4) };
    let mut mat = BlockTransposed::<f32, GROUP, PACK>::new(nrows, ncols);
    // One mutable row view per row; the views are handed out together, so each
    // must alias a disjoint part of the matrix storage to be sound here.
    let rows: Vec<RowMut<'_, f32, GROUP, PACK>> = mat.data.rows_mut().collect();
    let rows_per_thread = nrows / num_threads;
    let mut rows = rows.into_boxed_slice();
    std::thread::scope(|s| {
        // Carve the row views into per-thread chunks via split_at_mut so each
        // spawned thread gets exclusive ownership of its slice (no locking).
        let mut remaining = &mut rows[..];
        for thread_id in 0..num_threads {
            // The last thread also absorbs any leftover rows when nrows is
            // not an exact multiple of num_threads.
            let chunk_len = if thread_id == num_threads - 1 {
                remaining.len()
            } else {
                rows_per_thread
            };
            let (chunk, rest) = remaining.split_at_mut(chunk_len);
            remaining = rest;
            let start_row = thread_id * rows_per_thread;
            s.spawn(move || {
                // Write a value that encodes (thread, row, col) so any
                // cross-thread interference is detectable in the check below.
                for (offset, row_view) in chunk.iter_mut().enumerate() {
                    let row = start_row + offset;
                    for col in 0..ncols {
                        let value = (thread_id * 10000 + row * 100 + col) as f32;
                        row_view.set(col, value);
                    }
                }
            });
        }
    });
    // All threads have joined when the scope ends; verify every element.
    for row in 0..nrows {
        // Rows beyond the even split were written by the last thread.
        let thread_id = (row / rows_per_thread).min(num_threads - 1);
        for col in 0..ncols {
            let expected = (thread_id * 10000 + row * 100 + col) as f32;
            assert_eq!(
                mat.get_element(row, col),
                expected,
                "mismatch at ({}, {})",
                row,
                col,
            );
        }
    }
}
}