Skip to main content

diskann_quantization/multi_vector/
block_transposed.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6//! Block-transposed matrix types with configurable packing.
7//!
8//! This module provides block-transposed matrix types — [`BlockTransposed`] (owned),
9//! [`BlockTransposedRef`] (shared view), and [`BlockTransposedMut`] (mutable view) —
10//! where groups of `GROUP` rows are stored in transposed form to enable efficient SIMD
11//! processing. An optional packing factor `PACK` interleaves adjacent columns within
12//! each group, which can be used to feed SIMD instructions that operate on packed pairs
13//! (e.g. `vpmaddwd` with `PACK = 2`).
14//!
15//! # Layout
16//!
17//! ## `PACK = 1` (standard block-transpose)
18//!
19//! Given a logical matrix with rows `a`, `b`, `c`, `d`, `e` (each with `K` columns)
20//! and `GROUP = 3`:
21//!
22//! ```text
23//!            Group Size (3)
24//!            <---------->
25//!
26//!            +----------+    ^
27//!            | a0 b0 c0 |    |
28//!            | a1 b1 c1 |    |
29//!            | a2 b2 c2 |    | Block Size (K)
30//!  Block 0   | ...      |    |
31//!  (Full)    | aK bK cK |    |
32//!            +----------+    v
33//!            +----------+
34//!            | d0 e0 XX |
35//!  Block 1   | d1 e1 XX |
36//!  (Partial) | ...      |
37//!            | dK eK XX |
38//!            +----------+
39//! ```
40//!
41//! ## `PACK = 2` (super-packed)
42//!
43//! With `GROUP = 4`, `PACK = 2`, and a logical matrix with rows `a`, `b`, `c`, `d`,
44//! `e`, `f` (each with **5** columns — odd, to show padding), adjacent column-pairs
45//! are interleaved per row within each group panel:
46//!
47//! ```text
48//!              GROUP × PACK (4 × 2 = 8)
49//!              <----------------------------->
50//!
51//!              +-----------------------------+    ^
52//!              | a0 a1  b0 b1  c0 c1  d0 d1  |    |  col-pair (0, 1)
53//!              | a2 a3  b2 b3  c2 c3  d2 d3  |    |  col-pair (2, 3)
54//!    Block 0   | a4 __  b4 __  c4 __  d4 __  |    |  col-pair (4, pad)
55//!    (Full)    +-----------------------------+    v
56//!              +-----------------------------+
57//!              | e0 e1  f0 f1  XX XX  XX XX  |       col-pair (0, 1)
58//!    Block 1   | e2 e3  f2 f3  XX XX  XX XX  |       col-pair (2, 3)
59//!    (Partial) | e4 __  f4 __  XX XX  XX XX  |       col-pair (4, pad)
60//!              +-----------------------------+
61//!
62//!    __ = zero (column padding)    XX = zero (row padding)
63//!    padded_ncols = 6  (5 rounded up to next multiple of PACK)
64//!    Block Size  = padded_ncols / PACK = 3 physical rows per block
65//! ```
66//!
67//! Each physical row of a block holds one column-pair across all `GROUP` rows.
68//! For example, the first physical row stores columns `(0, 1)` for rows
69//! `a, b, c, d` interleaved as `[a0, a1, b0, b1, c0, c1, d0, d1]`.
70//!
71//! Because `ncols = 5` is odd (not a multiple of `PACK = 2`), the last
72//! column-pair `(4, pad)` is zero-padded: `[a4, 0, b4, 0, c4, 0, d4, 0]`.
73//!
74//! # Constraints
75//!
76//! - `GROUP > 0`
77//! - `PACK > 0`
78//! - `GROUP % PACK == 0`
79
80use std::{alloc::Layout, marker::PhantomData, ptr::NonNull};
81
82use diskann_utils::{
83    ReborrowMut,
84    strided::StridedView,
85    views::{MatrixView, MutMatrixView},
86};
87
88use super::matrix::{
89    Defaulted, LayoutError, Mat, MatMut, MatRef, NewMut, NewOwned, NewRef, Overflow, Repr, ReprMut,
90    ReprOwned, SliceError,
91};
92use crate::bits::{AsMutPtr, AsPtr, MutSlicePtr, SlicePtr};
93use crate::utils;
94
/// Round a logical column count up to the smallest multiple of `PACK` that
/// is greater than or equal to it.
///
/// `PACK > 0` is guaranteed by [`BlockTransposedRepr::_ASSERTIONS`].
#[inline]
fn padded_ncols<const PACK: usize>(ncols: usize) -> usize {
    ncols.next_multiple_of(PACK)
}
100
/// Total number of `T` elements needed to back an `nrows x ncols` block-transposed
/// matrix with group size `GROUP` and packing factor `PACK`: rows are padded up to a
/// multiple of `GROUP`, columns up to a multiple of `PACK`.
///
/// This is the **unchecked** flavor — it assumes the caller has already validated that
/// the dimensions do not overflow (e.g. after construction). For use in the constructor,
/// prefer [`checked_compute_capacity`].
///
/// Compile-time constraints (`GROUP > 0`, `PACK > 0`, `GROUP % PACK == 0`) are enforced
/// by [`BlockTransposedRepr::_ASSERTIONS`]; this function does **not** duplicate them.
#[inline]
fn compute_capacity<const GROUP: usize, const PACK: usize>(nrows: usize, ncols: usize) -> usize {
    let padded_rows = nrows.next_multiple_of(GROUP);
    let padded_cols = ncols.next_multiple_of(PACK);
    padded_rows * padded_cols
}
114
/// Checked variant of [`compute_capacity`].
///
/// Returns `None` if padding the row count, padding the column count, or the final
/// multiplication would overflow `usize`. Used by the constructor to reject impossibly
/// large dimensions before committing to an allocation.
#[inline]
fn checked_compute_capacity<const GROUP: usize, const PACK: usize>(
    nrows: usize,
    ncols: usize,
) -> Option<usize> {
    let padded_rows = nrows.checked_next_multiple_of(GROUP)?;
    let padded_cols = ncols.checked_next_multiple_of(PACK)?;
    padded_rows.checked_mul(padded_cols)
}
127
/// Linear index of the element at logical `(row, col)` in a block-transposed layout
/// with group size `GROUP`, packing factor `PACK`, and `ncols` logical columns.
///
/// The element lives in block `row / GROUP`; within that block, column-pair
/// `col / PACK` occupies one physical row of `GROUP * PACK` elements, and the
/// element sits at lane `(row % GROUP) * PACK + col % PACK` of that physical row.
#[inline]
fn linear_index<const GROUP: usize, const PACK: usize>(
    row: usize,
    col: usize,
    ncols: usize,
) -> usize {
    let padded = ncols.next_multiple_of(PACK);
    let block_base = (row / GROUP) * GROUP * padded;
    let pair_base = (col / PACK) * (GROUP * PACK);
    let lane = (row % GROUP) * PACK + col % PACK;
    block_base + pair_base + lane
}
141
/// Offset from a row's base pointer (its element at `col = 0`) to the element at `col`.
///
/// Depends only on the column index and the const layout parameters — not on any
/// particular matrix's dimensions — because all rows within a block share the same
/// per-column stride pattern.
#[inline]
fn col_offset<const GROUP: usize, const PACK: usize>(col: usize) -> usize {
    let pair = col / PACK;
    let lane = col % PACK;
    pair * (GROUP * PACK) + lane
}
150
/// Internal layout descriptor for block-transposed matrices.
///
/// This is not part of the public API — use [`BlockTransposed`], [`BlockTransposedRef`],
/// or [`BlockTransposedMut`] instead.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct BlockTransposedRepr<T, const GROUP: usize, const PACK: usize = 1> {
    /// Number of logical rows (before padding up to a multiple of `GROUP`).
    nrows: usize,
    /// Number of logical columns (before padding up to a multiple of `PACK`).
    ncols: usize,
    /// Ties the descriptor to the element type `T` without storing any data.
    _elem: PhantomData<T>,
}
161
impl<T: Copy, const GROUP: usize, const PACK: usize> BlockTransposedRepr<T, GROUP, PACK> {
    // Compile-time assertions — evaluated whenever any method references this constant.
    // Referencing `Self::_ASSERTIONS` (as `new` does) forces the const block to be
    // evaluated during monomorphization, turning invalid GROUP/PACK combinations into
    // compile errors rather than runtime failures.
    const _ASSERTIONS: () = {
        assert!(GROUP > 0, "group size GROUP must be positive");
        assert!(PACK > 0, "packing factor PACK must be positive");
        assert!(
            GROUP.is_multiple_of(PACK),
            "GROUP must be divisible by PACK"
        );
    };

    /// Create a new `BlockTransposedRepr` descriptor.
    ///
    /// Successful construction requires that the total memory for the backing allocation
    /// does not exceed `isize::MAX`.
    ///
    /// # Errors
    ///
    /// Returns [`Overflow`] if the padded capacity computation overflows `usize`, or if
    /// the byte budget check for `T`-sized elements fails.
    pub fn new(nrows: usize, ncols: usize) -> Result<Self, Overflow> {
        // Force evaluation of the compile-time GROUP/PACK assertions above.
        let () = Self::_ASSERTIONS;
        let capacity = checked_compute_capacity::<GROUP, PACK>(nrows, ncols)
            .ok_or_else(|| Overflow::for_type::<T>(nrows, ncols))?;
        Overflow::check_byte_budget::<T>(capacity, nrows, ncols)?;
        Ok(Self {
            nrows,
            ncols,
            _elem: PhantomData,
        })
    }

    // ── Query helpers ────────────────────────────────────────────────

    /// The total number of `T` elements in the backing allocation (including padding).
    #[inline]
    fn storage_len(&self) -> usize {
        compute_capacity::<GROUP, PACK>(self.nrows, self.ncols)
    }

    /// Number of logical rows.
    #[inline]
    fn nrows(&self) -> usize {
        self.nrows
    }

    /// Number of logical columns (dimensionality).
    #[inline]
    pub fn ncols(&self) -> usize {
        self.ncols
    }

    /// Number of physical (padded) columns — logical columns rounded up to
    /// the next multiple of `PACK`.
    #[inline]
    pub fn padded_ncols(&self) -> usize {
        padded_ncols::<PACK>(self.ncols)
    }

    /// Number of completely full blocks (groups of exactly `GROUP` logical rows).
    #[inline]
    pub fn full_blocks(&self) -> usize {
        self.nrows / GROUP
    }

    /// Total number of blocks including a possible partially-filled tail.
    #[inline]
    pub fn num_blocks(&self) -> usize {
        self.nrows.div_ceil(GROUP)
    }

    /// Number of valid (logical) rows in the last block, or 0 if all blocks are full.
    #[inline]
    pub fn remainder(&self) -> usize {
        self.nrows % GROUP
    }

    /// The stride (in elements) between the start of consecutive blocks.
    /// Every block — including a partial tail — occupies a full stride.
    #[inline]
    fn block_stride(&self) -> usize {
        GROUP * self.padded_ncols()
    }

    /// The linear offset of the start of `block`.
    #[inline]
    fn block_offset(&self, block: usize) -> usize {
        block * self.block_stride()
    }

    /// Verify that `slice` has exactly `self.storage_len()` elements.
    fn check_slice(&self, slice: &[T]) -> Result<(), SliceError> {
        let cap = self.storage_len();
        if slice.len() != cap {
            Err(SliceError::LengthMismatch {
                expected: cap,
                found: slice.len(),
            })
        } else {
            Ok(())
        }
    }

    /// Helper: wrap a `Box<[T]>` into a [`Mat`] without any further checks.
    ///
    /// # Safety
    ///
    /// `b.len()` must equal `self.storage_len()`.
    unsafe fn box_to_mat(self, b: Box<[T]>) -> Mat<Self> {
        debug_assert_eq!(b.len(), self.storage_len(), "safety contract violated");

        // Leak the box into a raw pointer; ownership is reclaimed in `ReprOwned::drop`.
        let ptr = utils::box_into_nonnull(b).cast::<u8>();

        // SAFETY: `ptr` is properly aligned and compatible with our layout.
        unsafe { Mat::from_raw_parts(self, ptr) }
    }
}
273
274// ════════════════════════════════════════════════════════════════════
275// Row view types
276// ════════════════════════════════════════════════════════════════════
277
/// An immutable view of a single logical row in a block-transposed matrix.
///
/// Because the elements of a logical row are strided (not contiguous), this struct
/// provides indexed access and iteration over the row's elements.
#[derive(Debug, Clone, Copy)]
pub struct Row<'a, T, const GROUP: usize, const PACK: usize = 1> {
    /// Pointer to the element at `(row, col=0)` in the backing allocation.
    base: SlicePtr<'a, T>,
    /// Number of logical columns addressable through this row.
    ncols: usize,
}
288
289impl<T: Copy, const GROUP: usize, const PACK: usize> Row<'_, T, GROUP, PACK> {
290    /// Number of elements (columns) in this row.
291    #[inline]
292    pub fn len(&self) -> usize {
293        self.ncols
294    }
295
296    /// Whether the row is empty.
297    #[inline]
298    pub fn is_empty(&self) -> bool {
299        self.ncols == 0
300    }
301
302    /// Get a reference to the element at column `col`, or `None` if out of bounds.
303    #[inline]
304    pub fn get(&self, col: usize) -> Option<&T> {
305        if col < self.ncols {
306            // SAFETY: bounds checked, offset computed from validated layout.
307            Some(unsafe { &*self.base.as_ptr().add(col_offset::<GROUP, PACK>(col)) })
308        } else {
309            None
310        }
311    }
312
313    /// Return an iterator over the elements of this row.
314    #[inline]
315    pub fn iter(&self) -> RowIter<'_, T, GROUP, PACK> {
316        RowIter {
317            base: self.base,
318            col: 0,
319            ncols: self.ncols,
320        }
321    }
322}
323
324impl<T: Copy, const GROUP: usize, const PACK: usize> std::ops::Index<usize>
325    for Row<'_, T, GROUP, PACK>
326{
327    type Output = T;
328
329    #[inline]
330    #[allow(clippy::panic)] // Index is expected to panic on OOB
331    fn index(&self, col: usize) -> &Self::Output {
332        self.get(col)
333            .unwrap_or_else(|| panic!("column index {col} out of bounds (ncols = {})", self.ncols))
334    }
335}
336
/// Iterator over the elements of a [`Row`].
#[derive(Debug, Clone)]
pub struct RowIter<'a, T, const GROUP: usize, const PACK: usize = 1> {
    /// Pointer to the element at column 0 of the row being iterated.
    base: SlicePtr<'a, T>,
    /// Next column to yield; advances from 0 up to `ncols`.
    col: usize,
    /// Total number of logical columns in the row.
    ncols: usize,
}
344
345impl<T: Copy, const GROUP: usize, const PACK: usize> Iterator for RowIter<'_, T, GROUP, PACK> {
346    type Item = T;
347
348    #[inline]
349    fn next(&mut self) -> Option<Self::Item> {
350        if self.col >= self.ncols {
351            return None;
352        }
353        // SAFETY: col < ncols means the offset is within the backing allocation.
354        let val = unsafe { *self.base.as_ptr().add(col_offset::<GROUP, PACK>(self.col)) };
355        self.col += 1;
356        Some(val)
357    }
358
359    #[inline]
360    fn size_hint(&self) -> (usize, Option<usize>) {
361        let remaining = self.ncols - self.col;
362        (remaining, Some(remaining))
363    }
364}
365
// `size_hint` is exact (lower == upper), so `ExactSizeIterator` is sound.
impl<T: Copy, const GROUP: usize, const PACK: usize> ExactSizeIterator
    for RowIter<'_, T, GROUP, PACK>
{
}
// Once `col` reaches `ncols`, `next` returns `None` forever, so the iterator is fused.
impl<T: Copy, const GROUP: usize, const PACK: usize> std::iter::FusedIterator
    for RowIter<'_, T, GROUP, PACK>
{
}
374
/// A mutable view of a single logical row in a block-transposed matrix.
#[derive(Debug)]
pub struct RowMut<'a, T, const GROUP: usize, const PACK: usize = 1> {
    /// Pointer to the element at `(row, col=0)` in the backing allocation.
    base: MutSlicePtr<'a, T>,
    /// Number of logical columns addressable through this row.
    ncols: usize,
}
381
382impl<T: Copy, const GROUP: usize, const PACK: usize> RowMut<'_, T, GROUP, PACK> {
383    /// Number of elements (columns) in this row.
384    #[inline]
385    pub fn len(&self) -> usize {
386        self.ncols
387    }
388
389    /// Whether the row is empty.
390    #[inline]
391    pub fn is_empty(&self) -> bool {
392        self.ncols == 0
393    }
394
395    /// Get a reference to the element at column `col`, or `None` if out of bounds.
396    #[inline]
397    pub fn get(&self, col: usize) -> Option<&T> {
398        if col < self.ncols {
399            // SAFETY: bounds checked.
400            Some(unsafe { &*self.base.as_ptr().add(col_offset::<GROUP, PACK>(col)) })
401        } else {
402            None
403        }
404    }
405
406    /// Get a mutable reference to the element at column `col`, or `None` if out of bounds.
407    #[inline]
408    pub fn get_mut(&mut self, col: usize) -> Option<&mut T> {
409        if col < self.ncols {
410            // SAFETY: bounds checked.
411            Some(unsafe { &mut *self.base.as_mut_ptr().add(col_offset::<GROUP, PACK>(col)) })
412        } else {
413            None
414        }
415    }
416
417    /// Set the element at column `col`.
418    ///
419    /// # Panics
420    ///
421    /// Panics if `col >= self.len()`.
422    #[inline]
423    pub fn set(&mut self, col: usize, value: T) {
424        assert!(
425            col < self.ncols,
426            "column index {col} out of bounds (ncols = {})",
427            self.ncols
428        );
429        // SAFETY: bounds checked.
430        unsafe { *self.base.as_mut_ptr().add(col_offset::<GROUP, PACK>(col)) = value };
431    }
432}
433
434impl<T: Copy, const GROUP: usize, const PACK: usize> std::ops::Index<usize>
435    for RowMut<'_, T, GROUP, PACK>
436{
437    type Output = T;
438
439    #[inline]
440    #[allow(clippy::panic)] // Index is expected to panic on OOB
441    fn index(&self, col: usize) -> &Self::Output {
442        self.get(col)
443            .unwrap_or_else(|| panic!("column index {col} out of bounds (ncols = {})", self.ncols))
444    }
445}
446
447impl<T: Copy, const GROUP: usize, const PACK: usize> std::ops::IndexMut<usize>
448    for RowMut<'_, T, GROUP, PACK>
449{
450    #[inline]
451    #[allow(clippy::panic)] // IndexMut is expected to panic on OOB
452    fn index_mut(&mut self, col: usize) -> &mut Self::Output {
453        let ncols = self.ncols;
454        self.get_mut(col)
455            .unwrap_or_else(|| panic!("column index {col} out of bounds (ncols = {ncols})"))
456    }
457}
458
459// ════════════════════════════════════════════════════════════════════
460// Repr / ReprMut / ReprOwned
461// ════════════════════════════════════════════════════════════════════
462
// SAFETY: `get_row` produces a valid `Row` for valid indices. The layout
// reports the correct capacity for the block-transposed backing allocation.
unsafe impl<T: Copy, const GROUP: usize, const PACK: usize> Repr
    for BlockTransposedRepr<T, GROUP, PACK>
{
    type Row<'a>
        = Row<'a, T, GROUP, PACK>
    where
        Self: 'a;

    fn nrows(&self) -> usize {
        self.nrows
    }

    fn layout(&self) -> Result<Layout, LayoutError> {
        // Layout of the entire backing allocation, padding included.
        Ok(Layout::array::<T>(self.storage_len())?)
    }

    unsafe fn get_row<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::Row<'a> {
        debug_assert!(i < self.nrows);

        // When ncols == 0 the backing allocation is zero-sized, so we must not
        // compute any pointer offset.  Return a dangling base instead.
        if self.ncols == 0 {
            return Row {
                // SAFETY: The row is empty (ncols == 0) so the pointer will never be
                // dereferenced. A dangling `NonNull` satisfies the non-null invariant.
                base: unsafe { SlicePtr::new_unchecked(NonNull::dangling()) },
                ncols: 0,
            };
        }

        let base_ptr = ptr.as_ptr().cast::<T>();
        // Offset of this row's column 0; later columns are reached from this
        // base via `col_offset`.
        let offset = linear_index::<GROUP, PACK>(i, 0, self.ncols);

        // SAFETY: The caller asserts `i < self.nrows()`. The backing allocation has at
        // least `self.storage_len()` elements, so the computed offset is in bounds.
        let row_base = unsafe { base_ptr.add(offset) };

        Row {
            // SAFETY: `row_base` is derived from a `NonNull<u8>` with a valid offset,
            // so it is non-null. The lifetime is tied to the caller's `'a`.
            base: unsafe { SlicePtr::new_unchecked(NonNull::new_unchecked(row_base)) },
            ncols: self.ncols,
        }
    }
}
510
// SAFETY: `get_row_mut` produces a valid `RowMut`. Disjoint row indices
// produce disjoint base pointers because each row within a block starts at a unique
// offset modulo GROUP.
unsafe impl<T: Copy, const GROUP: usize, const PACK: usize> ReprMut
    for BlockTransposedRepr<T, GROUP, PACK>
{
    type RowMut<'a>
        = RowMut<'a, T, GROUP, PACK>
    where
        Self: 'a;

    unsafe fn get_row_mut<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::RowMut<'a> {
        debug_assert!(i < self.nrows);

        // When ncols == 0 the backing allocation is zero-sized, so we must not
        // compute any pointer offset.  Return a dangling base instead.
        if self.ncols == 0 {
            return RowMut {
                // SAFETY: The row is empty (ncols == 0) so the pointer will never be
                // dereferenced. A dangling `NonNull` satisfies the non-null invariant.
                base: unsafe { MutSlicePtr::new_unchecked(NonNull::dangling()) },
                ncols: 0,
            };
        }

        let base_ptr = ptr.as_ptr().cast::<T>();
        // Offset of this row's column 0 (mirrors `Repr::get_row`).
        let offset = linear_index::<GROUP, PACK>(i, 0, self.ncols);

        // SAFETY: `i < self.nrows` (debug-asserted) guarantees the offset is within
        // the backing allocation. Same reasoning as `get_row`.
        let row_base = unsafe { base_ptr.add(offset) };

        RowMut {
            // SAFETY: `row_base` is derived from a `NonNull<u8>` with a valid offset,
            // so it is non-null. The lifetime is tied to the caller's `'a`.
            base: unsafe { MutSlicePtr::new_unchecked(NonNull::new_unchecked(row_base)) },
            ncols: self.ncols,
        }
    }
}
551
// SAFETY: Memory is deallocated by reconstructing the `Box<[T]>` that was created during
// `NewOwned`.
unsafe impl<T: Copy, const GROUP: usize, const PACK: usize> ReprOwned
    for BlockTransposedRepr<T, GROUP, PACK>
{
    unsafe fn drop(self, ptr: NonNull<u8>) {
        // SAFETY: `ptr` was obtained from `Box::into_raw` with length `self.storage_len()`.
        unsafe {
            let slice_ptr =
                std::ptr::slice_from_raw_parts_mut(ptr.cast::<T>().as_ptr(), self.storage_len());
            // Rebuilding the box hands ownership back; dropping it frees the allocation.
            let _ = Box::from_raw(slice_ptr);
        }
    }
}
566
567// ════════════════════════════════════════════════════════════════════
568// Constructors
569// ════════════════════════════════════════════════════════════════════
570
// SAFETY: The returned `Mat` contains a `Box` with exactly `self.storage_len()` elements.
unsafe impl<T: Copy, const GROUP: usize, const PACK: usize> NewOwned<T>
    for BlockTransposedRepr<T, GROUP, PACK>
{
    type Error = crate::error::Infallible;

    /// Allocate an owned matrix with every element — padding included — set to `value`.
    fn new_owned(self, value: T) -> Result<Mat<Self>, Self::Error> {
        let b: Box<[T]> = vec![value; self.storage_len()].into_boxed_slice();

        // SAFETY: By construction, `b.len() == self.storage_len()`.
        Ok(unsafe { self.box_to_mat(b) })
    }
}
584
// SAFETY: This safely re-uses `<Self as NewOwned<T>>`.
unsafe impl<T: Copy + Default, const GROUP: usize, const PACK: usize> NewOwned<Defaulted>
    for BlockTransposedRepr<T, GROUP, PACK>
{
    type Error = crate::error::Infallible;

    /// Allocate an owned matrix filled with `T::default()`.
    fn new_owned(self, _: Defaulted) -> Result<Mat<Self>, Self::Error> {
        self.new_owned(T::default())
    }
}
595
// SAFETY: This checks slice length against storage_len.
unsafe impl<T: Copy, const GROUP: usize, const PACK: usize> NewRef<T>
    for BlockTransposedRepr<T, GROUP, PACK>
{
    type Error = SliceError;

    /// Borrow `data` as a block-transposed matrix view.
    ///
    /// # Errors
    ///
    /// Returns a [`SliceError`] if `data.len()` differs from `self.storage_len()`.
    fn new_ref(self, data: &[T]) -> Result<MatRef<'_, Self>, Self::Error> {
        self.check_slice(data)?;

        // SAFETY: `check_slice` verified the length.
        Ok(unsafe { MatRef::from_raw_parts(self, utils::as_nonnull(data).cast::<u8>()) })
    }
}
609
// SAFETY: This checks slice length against storage_len.
unsafe impl<T: Copy, const GROUP: usize, const PACK: usize> NewMut<T>
    for BlockTransposedRepr<T, GROUP, PACK>
{
    type Error = SliceError;

    /// Borrow `data` mutably as a block-transposed matrix view.
    ///
    /// # Errors
    ///
    /// Returns a [`SliceError`] if `data.len()` differs from `self.storage_len()`.
    fn new_mut(self, data: &mut [T]) -> Result<MatMut<'_, Self>, Self::Error> {
        self.check_slice(data)?;

        // SAFETY: `check_slice` verified the length.
        Ok(unsafe { MatMut::from_raw_parts(self, utils::as_nonnull_mut(data).cast::<u8>()) })
    }
}
623
624// ════════════════════════════════════════════════════════════════════
625// Delegation macro
626// ════════════════════════════════════════════════════════════════════
627
/// Generates a forwarding method that delegates to `self.as_view().$name(...)`.
///
/// The generated doc-comment links back to the canonical implementation on
/// [`BlockTransposedRef`], so documentation stays in sync automatically.
///
/// Only `&self` receivers are supported; there is one arm for safe methods and one
/// for `unsafe` methods (which re-states the caller's safety obligation).
macro_rules! delegate_to_ref {
    // Safe function.
    ($(#[$m:meta])* $vis:vis fn $name:ident(&self $(, $a:ident: $t:ty)*) $(-> $r:ty)?) => {
        #[doc = concat!("See [`BlockTransposedRef::", stringify!($name), "`].")]
        $(#[$m])*
        #[inline]
        $vis fn $name(&self $(, $a: $t)*) $(-> $r)? {
            self.as_view().$name($($a),*)
        }
    };
    // Unsafe function.
    ($(#[$m:meta])* unsafe $vis:vis fn $name:ident(&self $(, $a:ident: $t:ty)*) $(-> $r:ty)?) => {
        #[doc = concat!("See [`BlockTransposedRef::", stringify!($name), "`].")]
        $(#[$m])*
        #[inline]
        $vis unsafe fn $name(&self $(, $a: $t)*) $(-> $r)? {
            // SAFETY: Caller upholds the safety contract of the delegated method.
            unsafe { self.as_view().$name($($a),*) }
        }
    };
}
653
654// ════════════════════════════════════════════════════════════════════
655// Public wrapper types
656// ════════════════════════════════════════════════════════════════════
657
/// An owning block-transposed matrix.
///
/// Wraps an owned allocation of `T` elements laid out in block-transposed order.
/// See the [module-level documentation](self) for layout details.
///
/// For shared and mutable views, see [`BlockTransposedRef`] and [`BlockTransposedMut`].
///
/// # Row Types
///
/// Because rows are not contiguous in memory, the row types are view structs:
///
/// - [`Row`] — a `Copy` handle supporting `Index<usize>` and `.iter()`.
/// - [`RowMut`] — a mutable handle supporting `IndexMut<usize>`.
#[derive(Debug)]
pub struct BlockTransposed<T: Copy, const GROUP: usize, const PACK: usize = 1> {
    /// Owned backing storage plus the layout descriptor.
    data: Mat<BlockTransposedRepr<T, GROUP, PACK>>,
}
675
/// A shared (immutable) view of a block-transposed matrix.
///
/// Created by [`BlockTransposed::as_view`].
#[derive(Debug, Clone, Copy)]
pub struct BlockTransposedRef<'a, T: Copy, const GROUP: usize, const PACK: usize = 1> {
    /// Borrowed backing storage plus the layout descriptor.
    data: MatRef<'a, BlockTransposedRepr<T, GROUP, PACK>>,
}
683
/// A mutable view of a block-transposed matrix.
///
/// Created by [`BlockTransposed::as_view_mut`].
pub struct BlockTransposedMut<'a, T: Copy, const GROUP: usize, const PACK: usize = 1> {
    /// Mutably borrowed backing storage plus the layout descriptor.
    data: MatMut<'a, BlockTransposedRepr<T, GROUP, PACK>>,
}
690
691// ── BlockTransposedRef (core read implementations) ───────────────
692
impl<'a, T: Copy, const GROUP: usize, const PACK: usize> BlockTransposedRef<'a, T, GROUP, PACK> {
    /// Returns the number of logical rows.
    #[inline]
    pub fn nrows(&self) -> usize {
        self.data.repr().nrows()
    }

    /// Returns the number of logical columns (dimensionality).
    #[inline]
    pub fn ncols(&self) -> usize {
        self.data.repr().ncols()
    }

    /// Returns the number of physical (padded) columns.
    #[inline]
    pub fn padded_ncols(&self) -> usize {
        self.data.repr().padded_ncols()
    }

    /// Group size (blocking factor `GROUP`).
    pub const fn group_size(&self) -> usize {
        GROUP
    }

    /// Group size (blocking factor `GROUP`) as a `const` function on the *type*.
    pub const fn const_group_size() -> usize {
        GROUP
    }

    /// Packing factor `PACK`.
    pub const fn pack_size(&self) -> usize {
        PACK
    }

    /// Number of completely full blocks.
    #[inline]
    pub fn full_blocks(&self) -> usize {
        self.data.repr().full_blocks()
    }

    /// Total number of blocks including any partially-filled tail.
    #[inline]
    pub fn num_blocks(&self) -> usize {
        self.data.repr().num_blocks()
    }

    /// Number of valid rows in the last partially-full block, or 0 if all
    /// blocks are full.
    #[inline]
    pub fn remainder(&self) -> usize {
        self.data.repr().remainder()
    }

    /// Return a raw typed pointer to the start of the backing data.
    /// The allocation behind it spans `storage_len()` elements of `T`.
    #[inline]
    pub fn as_ptr(&self) -> *const T {
        self.data.as_raw_ptr().cast::<T>()
    }

    /// Return the backing data as a shared slice.
    ///
    /// The returned slice has `storage_len()` elements — this includes all padding
    /// for partial blocks and column-group alignment.
    #[inline]
    pub fn as_slice(&self) -> &'a [T] {
        let len = self.data.repr().storage_len();
        // SAFETY: The backing allocation has exactly `storage_len()` elements of type T.
        unsafe { std::slice::from_raw_parts(self.as_ptr(), len) }
    }

    /// Return a pointer to the start of the given block.
    ///
    /// The caller may assume that for the returned pointer `ptr`,
    /// `[ptr, ptr + GROUP * padded_ncols)` points to valid memory, even for the
    /// remainder block.
    ///
    /// # Safety
    ///
    /// `block` must be less than `self.num_blocks()`. No bounds check is
    /// performed in release builds; callers must verify the index themselves
    /// (e.g. by iterating `0..self.num_blocks()`).
    #[inline]
    pub unsafe fn block_ptr_unchecked(&self, block: usize) -> *const T {
        debug_assert!(block < self.num_blocks());
        // SAFETY: Caller asserts `block < self.num_blocks()`.
        unsafe { self.as_ptr().add(self.data.repr().block_offset(block)) }
    }

    /// Return a view over a full block as a [`MatrixView`].
    ///
    /// The returned view has `padded_ncols / PACK` rows and `GROUP * PACK`
    /// columns. For `PACK == 1` this simplifies to `ncols` rows and `GROUP`
    /// columns (the standard transposed interpretation).
    ///
    /// # Panics
    ///
    /// Panics if `block >= self.full_blocks()`.
    #[allow(clippy::expect_used)]
    pub fn block(&self, block: usize) -> MatrixView<'a, T> {
        assert!(block < self.full_blocks());
        let offset = self.data.repr().block_offset(block);
        let stride = self.data.repr().block_stride();
        // SAFETY: `block < full_blocks()` (asserted above) guarantees
        // `offset + stride` is within the backing allocation.
        let data: &[T] = unsafe { std::slice::from_raw_parts(self.as_ptr().add(offset), stride) };
        MatrixView::try_from(data, self.padded_ncols() / PACK, GROUP * PACK)
            .expect("base data should have been sized correctly")
    }

    /// Return a view over the remainder block, or `None` if there is no
    /// remainder.
    ///
    /// The returned view has the same dimensions as [`block()`](Self::block):
    /// `padded_ncols / PACK` rows and `GROUP * PACK` columns.
    #[allow(clippy::expect_used)]
    pub fn remainder_block(&self) -> Option<MatrixView<'a, T>> {
        if self.remainder() == 0 {
            None
        } else {
            // The remainder block sits directly after the last full block.
            let offset = self.data.repr().block_offset(self.full_blocks());
            let stride = self.data.repr().block_stride();
            // SAFETY: The remainder block exists (`remainder() != 0`),
            // so `offset + stride` is within the backing allocation.
            let data: &[T] =
                unsafe { std::slice::from_raw_parts(self.as_ptr().add(offset), stride) };
            Some(
                MatrixView::try_from(data, self.padded_ncols() / PACK, GROUP * PACK)
                    .expect("base data should have been sized correctly"),
            )
        }
    }

    /// Retrieve the value at the logical `(row, col)`.
    ///
    /// # Panics
    ///
    /// Panics if `row >= self.nrows()` or `col >= self.ncols()`.
    #[inline]
    pub fn get_element(&self, row: usize, col: usize) -> T {
        assert!(
            row < self.nrows(),
            "row {row} out of bounds (nrows = {})",
            self.nrows()
        );
        assert!(
            col < self.ncols(),
            "col {col} out of bounds (ncols = {})",
            self.ncols()
        );
        let idx = linear_index::<GROUP, PACK>(row, col, self.ncols());
        // SAFETY: bounds checked above.
        unsafe { *self.as_ptr().add(idx) }
    }

    /// Get an immutable row view, or `None` if `i` is out of bounds.
    /// The returned row borrows `self` (not the underlying `'a` data).
    #[inline]
    pub fn get_row(&self, i: usize) -> Option<Row<'_, T, GROUP, PACK>> {
        self.data.get_row(i)
    }
}
853
854// ── BlockTransposedMut ───────────────────────────────────────────
855
impl<'a, T: Copy, const GROUP: usize, const PACK: usize> BlockTransposedMut<'a, T, GROUP, PACK> {
    /// Borrow as an immutable [`BlockTransposedRef`].
    ///
    /// The returned view borrows from `self` (a reborrow), not for the full
    /// `'a` lifetime of the underlying allocation.
    #[inline]
    pub fn as_view(&self) -> BlockTransposedRef<'_, T, GROUP, PACK> {
        BlockTransposedRef {
            data: self.data.as_view(),
        }
    }

    // ── Delegated read methods ───────────────────────────────────
    //
    // Each macro invocation generates a forwarding method with the same
    // signature and semantics as the method on `BlockTransposedRef`.

    delegate_to_ref!(pub fn nrows(&self) -> usize);
    delegate_to_ref!(pub fn ncols(&self) -> usize);
    delegate_to_ref!(pub fn padded_ncols(&self) -> usize);
    delegate_to_ref!(pub fn full_blocks(&self) -> usize);
    delegate_to_ref!(pub fn num_blocks(&self) -> usize);
    delegate_to_ref!(pub fn remainder(&self) -> usize);
    delegate_to_ref!(pub fn as_ptr(&self) -> *const T);
    delegate_to_ref!(pub fn as_slice(&self) -> &[T]);
    delegate_to_ref!(#[allow(clippy::missing_safety_doc)] unsafe pub fn block_ptr_unchecked(&self, block: usize) -> *const T);
    delegate_to_ref!(#[allow(clippy::expect_used)] pub fn block(&self, block: usize) -> MatrixView<'_, T>);
    delegate_to_ref!(#[allow(clippy::expect_used)] pub fn remainder_block(&self) -> Option<MatrixView<'_, T>>);
    delegate_to_ref!(pub fn get_element(&self, row: usize, col: usize) -> T);

    /// Group size (blocking factor `GROUP`).
    pub const fn group_size(&self) -> usize {
        GROUP
    }

    /// Group size (blocking factor `GROUP`) as a `const` function on the *type*.
    pub const fn const_group_size() -> usize {
        GROUP
    }

    /// Packing factor `PACK`.
    pub const fn pack_size(&self) -> usize {
        PACK
    }

    /// Get an immutable row view, or `None` if `i` is out of bounds.
    #[inline]
    pub fn get_row(&self, i: usize) -> Option<Row<'_, T, GROUP, PACK>> {
        self.data.get_row(i)
    }

    // ── Mutable methods ──────────────────────────────────────────
    //
    // The `_inner` variants consume `self` by value so that the lifetime of
    // the returned view is tied to `'a` (the underlying allocation), not to
    // a temporary reborrow. Public `&mut self` methods reborrow into a
    // short-lived `BlockTransposedMut` and then call the inner variant.

    /// Return the backing data as a mutable slice.
    ///
    /// The returned slice has `storage_len()` elements (including all padding).
    #[inline]
    pub fn as_mut_slice(&mut self) -> &mut [T] {
        self.reborrow_mut().mut_slice_inner()
    }

    /// By-value core of [`as_mut_slice`](Self::as_mut_slice); consuming `self`
    /// lets the returned slice borrow for the full `'a` lifetime.
    fn mut_slice_inner(mut self) -> &'a mut [T] {
        let len = self.data.repr().storage_len();
        // SAFETY: We own exclusive access through `self`.
        unsafe { std::slice::from_raw_parts_mut(self.data.as_raw_mut_ptr().cast::<T>(), len) }
    }

    /// Return a mutable view over a full block.
    ///
    /// # Panics
    ///
    /// Panics if `block >= self.full_blocks()`.
    #[allow(clippy::expect_used)]
    pub fn block_mut(&mut self, block: usize) -> MutMatrixView<'_, T> {
        self.reborrow_mut().block_mut_inner(block)
    }

    /// By-value core of [`block_mut`](Self::block_mut) (see the note on
    /// `_inner` variants above).
    #[allow(clippy::expect_used)]
    fn block_mut_inner(mut self, block: usize) -> MutMatrixView<'a, T> {
        // Copy the repr so no shared borrow is held across the mutable
        // pointer access below.
        let repr = *self.data.repr();
        assert!(block < repr.full_blocks());
        let offset = repr.block_offset(block);
        let stride = repr.block_stride();
        let pncols = repr.padded_ncols();
        // SAFETY: `block < full_blocks()`, so the range is within the allocation.
        let data: &mut [T] = unsafe {
            std::slice::from_raw_parts_mut(
                self.data.as_raw_mut_ptr().cast::<T>().add(offset),
                stride,
            )
        };
        MutMatrixView::try_from(data, pncols / PACK, GROUP * PACK)
            .expect("base data should have been sized correctly")
    }

    /// Return a mutable view over the remainder block, or `None` if there is no
    /// remainder.
    #[allow(clippy::expect_used)]
    pub fn remainder_block_mut(&mut self) -> Option<MutMatrixView<'_, T>> {
        self.reborrow_mut().remainder_block_mut_inner()
    }

    /// By-value core of [`remainder_block_mut`](Self::remainder_block_mut).
    #[allow(clippy::expect_used)]
    fn remainder_block_mut_inner(mut self) -> Option<MutMatrixView<'a, T>> {
        let repr = *self.data.repr();
        if repr.remainder() == 0 {
            None
        } else {
            let offset = repr.block_offset(repr.full_blocks());
            let stride = repr.block_stride();
            let pncols = repr.padded_ncols();
            // SAFETY: Remainder block exists, so the range is within the allocation.
            let data: &mut [T] = unsafe {
                std::slice::from_raw_parts_mut(
                    self.data.as_raw_mut_ptr().cast::<T>().add(offset),
                    stride,
                )
            };
            Some(
                MutMatrixView::try_from(data, pncols / PACK, GROUP * PACK)
                    .expect("base data should have been sized correctly"),
            )
        }
    }

    /// Get a mutable row view, or `None` if `i` is out of bounds.
    #[inline]
    pub fn get_row_mut(&mut self, i: usize) -> Option<RowMut<'_, T, GROUP, PACK>> {
        self.data.get_row_mut(i)
    }

    // ── Private helpers ──────────────────────────────────────────

    /// Reborrow `self` as a shorter-lived `BlockTransposedMut` so that the
    /// by-value `_inner` methods can be called from `&mut self` methods.
    fn reborrow_mut(&mut self) -> BlockTransposedMut<'_, T, GROUP, PACK> {
        BlockTransposedMut {
            data: self.data.reborrow_mut(),
        }
    }
}
994
995// ── BlockTransposed (owned) ──────────────────────────────────────
996
impl<T: Copy, const GROUP: usize, const PACK: usize> BlockTransposed<T, GROUP, PACK> {
    /// Borrow as an immutable [`BlockTransposedRef`].
    pub fn as_view(&self) -> BlockTransposedRef<'_, T, GROUP, PACK> {
        BlockTransposedRef {
            data: self.data.as_view(),
        }
    }

    /// Borrow as a mutable [`BlockTransposedMut`].
    pub fn as_view_mut(&mut self) -> BlockTransposedMut<'_, T, GROUP, PACK> {
        BlockTransposedMut {
            data: self.data.as_view_mut(),
        }
    }

    // ── Delegated read methods ───────────────────────────────────
    //
    // Each macro invocation generates a forwarding method with the same
    // signature and semantics as the method on `BlockTransposedRef`.

    delegate_to_ref!(pub fn nrows(&self) -> usize);
    delegate_to_ref!(pub fn ncols(&self) -> usize);
    delegate_to_ref!(pub fn padded_ncols(&self) -> usize);
    delegate_to_ref!(pub fn full_blocks(&self) -> usize);
    delegate_to_ref!(pub fn num_blocks(&self) -> usize);
    delegate_to_ref!(pub fn remainder(&self) -> usize);
    delegate_to_ref!(pub fn as_ptr(&self) -> *const T);
    delegate_to_ref!(pub fn as_slice(&self) -> &[T]);
    delegate_to_ref!(#[allow(clippy::missing_safety_doc)] unsafe pub fn block_ptr_unchecked(&self, block: usize) -> *const T);
    delegate_to_ref!(#[allow(clippy::expect_used)] pub fn block(&self, block: usize) -> MatrixView<'_, T>);
    delegate_to_ref!(#[allow(clippy::expect_used)] pub fn remainder_block(&self) -> Option<MatrixView<'_, T>>);
    delegate_to_ref!(pub fn get_element(&self, row: usize, col: usize) -> T);

    /// Group size (blocking factor `GROUP`).
    pub const fn group_size(&self) -> usize {
        GROUP
    }

    /// Group size (blocking factor `GROUP`) as a `const` function on the *type*.
    pub const fn const_group_size() -> usize {
        GROUP
    }

    /// Packing factor `PACK`.
    pub const fn pack_size(&self) -> usize {
        PACK
    }

    /// Get an immutable row view, or `None` if `i` is out of bounds.
    #[inline]
    pub fn get_row(&self, i: usize) -> Option<Row<'_, T, GROUP, PACK>> {
        self.data.get_row(i)
    }

    // ── Mutable methods (delegated to BlockTransposedMut) ────────

    /// See [`BlockTransposedMut::as_mut_slice`].
    #[inline]
    pub fn as_mut_slice(&mut self) -> &mut [T] {
        self.as_view_mut().mut_slice_inner()
    }

    /// See [`BlockTransposedMut::block_mut`].
    ///
    /// # Panics
    ///
    /// Panics if `block >= self.full_blocks()`.
    #[allow(clippy::expect_used)]
    pub fn block_mut(&mut self, block: usize) -> MutMatrixView<'_, T> {
        self.as_view_mut().block_mut_inner(block)
    }

    /// See [`BlockTransposedMut::remainder_block_mut`].
    #[allow(clippy::expect_used)]
    pub fn remainder_block_mut(&mut self) -> Option<MutMatrixView<'_, T>> {
        self.as_view_mut().remainder_block_mut_inner()
    }

    /// Get a mutable row view, or `None` if `i` is out of bounds.
    #[inline]
    pub fn get_row_mut(&mut self, i: usize) -> Option<RowMut<'_, T, GROUP, PACK>> {
        self.data.get_row_mut(i)
    }
}
1074
1075// ── Factory methods ──────────────────────────────────────────────
1076
impl<T: Copy + Default, const GROUP: usize, const PACK: usize> BlockTransposed<T, GROUP, PACK> {
    /// Construct a default-initialized block-transposed matrix from dimensions.
    ///
    /// All storage (including padding positions) holds `T::default()`.
    ///
    /// # Panics
    ///
    /// Panics if the dimensions overflow the allocation budget.
    #[allow(clippy::expect_used)]
    pub fn new(nrows: usize, ncols: usize) -> Self {
        let repr = BlockTransposedRepr::<T, GROUP, PACK>::new(nrows, ncols)
            .expect("dimensions should not overflow");
        Self {
            data: Mat::new(repr, Defaulted).expect("infallible"),
        }
    }

    /// Fallible variant of [`new`](Self::new).
    ///
    /// # Errors
    ///
    /// Returns [`Overflow`] if the dimensions overflow the allocation budget.
    pub fn try_new(nrows: usize, ncols: usize) -> Result<Self, Overflow> {
        let repr = BlockTransposedRepr::<T, GROUP, PACK>::new(nrows, ncols)?;
        Ok(Self {
            data: Mat::new(repr, Defaulted).expect("infallible"),
        })
    }

    /// Construct a block-transposed matrix by copying data from a [`StridedView`].
    ///
    /// Each source element at `(row, col)` is placed at the correct offset in the
    /// block-transposed layout. Padding positions (both partial-block rows and
    /// column-group padding when `ncols % PACK != 0`) are filled with
    /// `T::default()`.
    ///
    /// The loop iterates in physical (block-transposed) order — block, column-group,
    /// row-within-block, pack-lane — so that writes to the backing allocation are
    /// sequential. Source reads stride across rows of the [`StridedView`], which is
    /// acceptable because read-side prefetch is more effective than write-side.
    pub fn from_strided(v: StridedView<'_, T>) -> Self {
        let nrows = v.nrows();
        let ncols = v.ncols();
        // `Self::new` default-initializes the allocation, so padding positions
        // already hold `T::default()` and only real data must be written below.
        let mut mat = Self::new(nrows, ncols);

        let repr = *mat.data.repr();
        // `num_blocks` includes the trailing partial block, if any.
        let num_blocks = repr.num_blocks();
        let pncols = repr.padded_ncols();
        // Number of PACK-wide column groups in each block panel.
        let num_col_groups = pncols / PACK;

        // Walk the backing allocation in physical order so that writes are
        // sequential. The allocation is default-initialized, so padding positions
        // already hold `T::default()` and can be skipped.
        let mut dst = mat.data.as_raw_mut_ptr().cast::<T>();
        for block in 0..num_blocks {
            let row_base = block * GROUP;
            for cg in 0..num_col_groups {
                let col_base = cg * PACK;
                for rib in 0..GROUP {
                    let row = row_base + rib;
                    if row < nrows {
                        // SAFETY: row < nrows is checked by the enclosing `if` condition.
                        let src_row = unsafe { v.get_row_unchecked(row) };
                        for p in 0..PACK {
                            let col = col_base + p;
                            if col < ncols {
                                // SAFETY: dst advances sequentially through the
                                // backing allocation which has exactly `storage_len`
                                // elements, and our loop visits each position once.
                                // `col < ncols` is checked above, and `src_row` has
                                // exactly `ncols` elements.
                                unsafe { *dst = *src_row.get_unchecked(col) };
                            }
                            // SAFETY: dst advances sequentially through the
                            // backing allocation which has exactly `storage_len`
                            // elements, and our loop visits each position once.
                            dst = unsafe { dst.add(1) };
                        }
                    } else {
                        // SAFETY: Entire row is padding — skip PACK positions.
                        // dst remains within the allocation.
                        dst = unsafe { dst.add(PACK) };
                    }
                }
            }
        }

        mat
    }

    /// Construct a block-transposed matrix by copying data from a [`MatrixView`].
    ///
    /// Convenience wrapper over [`from_strided`](Self::from_strided).
    pub fn from_matrix_view(v: MatrixView<'_, T>) -> Self {
        Self::from_strided(v.into())
    }
}
1166
1167// ════════════════════════════════════════════════════════════════════
1168// Index<(usize, usize)> for BlockTransposed
1169// ════════════════════════════════════════════════════════════════════
1170
1171impl<T: Copy, const GROUP: usize, const PACK: usize> std::ops::Index<(usize, usize)>
1172    for BlockTransposed<T, GROUP, PACK>
1173{
1174    type Output = T;
1175
1176    #[inline]
1177    fn index(&self, (row, col): (usize, usize)) -> &Self::Output {
1178        assert!(row < self.nrows());
1179        assert!(col < self.ncols());
1180        let idx = linear_index::<GROUP, PACK>(row, col, self.ncols());
1181        // SAFETY: bounds checked above and the backing allocation has `storage_len()` elements.
1182        unsafe { &*self.as_ptr().add(idx) }
1183    }
1184}
1185
1186// ════════════════════════════════════════════════════════════════════
1187// Tests
1188// ════════════════════════════════════════════════════════════════════
1189
1190#[cfg(test)]
1191mod tests {
1192    //! Test organisation:
1193    //!
1194    //!  1. **Helper functions** — `gen_*` element generators.
1195    //!  2. [`test_full_api`] — single parameterized function that exhaustively
1196    //!     exercises the full read + write API on all three wrapper types
1197    //!     (`BlockTransposed`, `BlockTransposedRef`, `BlockTransposedMut`).
1198    //!  3. **Test runners** — `#[test]` functions that call `test_full_api`
1199    //!     with various `(T, GROUP, PACK, nrows, ncols)` combinations.
1200    //!  4. [`test_block_layout_pack1`] — verifies that `PACK=1` blocks are
1201    //!     the standard row-to-column transposition.
1202    //!  5. **Focused tests** — edge cases that cannot be expressed as
1203    //!     parameters to `test_full_api` (`Send`/`Sync`, panic paths,
1204    //!     non-unit strides, concurrent mutation, etc.).
1205
1206    use diskann_utils::{lazy_format, views::Matrix};
1207
1208    use super::*;
1209    use crate::utils::div_round_up;
1210
1211    // ── Per-type element generators ──────────────────────────────────
1212    //
1213    // Each generator maps a flat index to a non-zero `T` value so that
1214    // `T::default()` (zero) can be used unambiguously to verify padding.
1215
1216    fn gen_f32(i: usize) -> f32 {
1217        (i + 1) as f32
1218    }
1219    fn gen_i32(i: usize) -> i32 {
1220        (i + 1) as i32
1221    }
1222    fn gen_u8(i: usize) -> u8 {
1223        ((i % 255) + 1) as u8
1224    }
1225
1226    // ── Unified parameterized test ──────────────────────────────────
1227
1228    /// Exhaustive test for the full `BlockTransposed` / `BlockTransposedRef` /
1229    /// `BlockTransposedMut` API surface, parameterized over element type `T`,
1230    /// group size `GROUP`, and packing factor `PACK`.
1231    ///
1232    /// Exercises: construction, query helpers, `Index` / `get_element`,
1233    /// immutable row views (`Row`), mutable row views (`RowMut`),
1234    /// `as_slice` / `as_mut_slice`, block views (immutable and mutable),
1235    /// `remainder_block` / `remainder_block_mut`, `block_ptr_unchecked`,
1236    /// `from_matrix_view`, OOB `get_row` returns, and both column and row
1237    /// padding verification.
1238    fn test_full_api<
1239        T: Copy + Default + PartialEq + std::fmt::Debug + 'static,
1240        const GROUP: usize,
1241        const PACK: usize,
1242    >(
1243        nrows: usize,
1244        ncols: usize,
1245        gen_element: fn(usize) -> T,
1246    ) {
1247        let context = lazy_format!(
1248            "T={}, GROUP={}, PACK={}, nrows={}, ncols={}",
1249            std::any::type_name::<T>(),
1250            GROUP,
1251            PACK,
1252            nrows,
1253            ncols,
1254        );
1255
1256        // ── Construction ─────────────────────────────────────────
1257
1258        let mut data = Matrix::new(T::default(), nrows, ncols);
1259        data.as_mut_slice()
1260            .iter_mut()
1261            .enumerate()
1262            .for_each(|(i, d)| *d = gen_element(i));
1263
1264        let mut transpose = BlockTransposed::<T, GROUP, PACK>::from_strided(data.as_view().into());
1265
1266        let expected_padded = div_round_up(ncols, PACK) * PACK;
1267        let expected_remainder = nrows % GROUP;
1268        let storage_len = transpose.as_slice().len();
1269
1270        // ── Query methods on owned type ──────────────────────────
1271
1272        assert_eq!(transpose.nrows(), nrows, "{}", context);
1273        assert_eq!(transpose.ncols(), ncols, "{}", context);
1274        assert_eq!(transpose.group_size(), GROUP, "{}", context);
1275        assert_eq!(
1276            BlockTransposed::<T, GROUP, PACK>::const_group_size(),
1277            GROUP,
1278            "{}",
1279            context
1280        );
1281        assert_eq!(transpose.pack_size(), PACK, "{}", context);
1282        assert_eq!(transpose.full_blocks(), nrows / GROUP, "{}", context);
1283        assert_eq!(
1284            transpose.num_blocks(),
1285            div_round_up(nrows, GROUP),
1286            "{}",
1287            context,
1288        );
1289        assert_eq!(transpose.remainder(), expected_remainder, "{}", context);
1290        assert_eq!(transpose.padded_ncols(), expected_padded, "{}", context);
1291
1292        // ── Element access (owned) ───────────────────────────────
1293
1294        for row in 0..nrows {
1295            for col in 0..ncols {
1296                assert_eq!(
1297                    data[(row, col)],
1298                    transpose[(row, col)],
1299                    "Index at ({}, {}) -- {}",
1300                    row,
1301                    col,
1302                    context,
1303                );
1304                assert_eq!(
1305                    data[(row, col)],
1306                    transpose.get_element(row, col),
1307                    "get_element at ({}, {}) -- {}",
1308                    row,
1309                    col,
1310                    context,
1311                );
1312            }
1313        }
1314
1315        // ── Immutable row views (owned) ──────────────────────────
1316
1317        let view = transpose.as_view();
1318        for row in 0..nrows {
1319            let row_view = view.get_row(row).unwrap();
1320            assert_eq!(row_view.len(), ncols, "{}", context);
1321            assert_eq!(row_view.is_empty(), ncols == 0, "{}", context);
1322            for col in 0..ncols {
1323                assert_eq!(
1324                    data[(row, col)],
1325                    row_view[col],
1326                    "row view at ({}, {}) -- {}",
1327                    row,
1328                    col,
1329                    context,
1330                );
1331            }
1332            // Row::get — in-bounds + OOB.
1333            if ncols > 0 {
1334                assert_eq!(row_view.get(0), Some(&data[(row, 0)]), "{}", context);
1335            }
1336            assert_eq!(row_view.get(ncols), None, "{}", context);
1337
1338            // Iterator + ExactSizeIterator.
1339            let iter = row_view.iter();
1340            assert_eq!(iter.len(), ncols, "{}", context);
1341            let (lo, hi) = iter.size_hint();
1342            assert_eq!(lo, ncols, "{}", context);
1343            assert_eq!(hi, Some(ncols), "{}", context);
1344
1345            let collected: Vec<T> = row_view.iter().collect();
1346            assert_eq!(collected.len(), ncols, "{}", context);
1347            for col in 0..ncols {
1348                assert_eq!(data[(row, col)], collected[col], "{}", context);
1349            }
1350        }
1351        // OOB row returns None.
1352        assert!(view.get_row(nrows).is_none(), "{}", context);
1353        let _ = view;
1354
1355        // ── BlockTransposedRef API ───────────────────────────────
1356
1357        {
1358            let view = transpose.as_view();
1359            assert_eq!(view.nrows(), nrows, "{}", context);
1360            assert_eq!(view.ncols(), ncols, "{}", context);
1361            assert_eq!(view.padded_ncols(), expected_padded, "{}", context);
1362            assert_eq!(view.group_size(), GROUP, "{}", context);
1363            assert_eq!(
1364                BlockTransposedRef::<T, GROUP, PACK>::const_group_size(),
1365                GROUP,
1366            );
1367            assert_eq!(view.pack_size(), PACK, "{}", context);
1368            assert_eq!(view.full_blocks(), nrows / GROUP, "{}", context);
1369            assert_eq!(view.num_blocks(), div_round_up(nrows, GROUP), "{}", context,);
1370            assert_eq!(view.remainder(), expected_remainder, "{}", context);
1371            assert_eq!(view.as_ptr(), transpose.as_ptr(), "{}", context);
1372            assert_eq!(view.as_slice(), transpose.as_slice(), "{}", context);
1373
1374            for row in 0..nrows {
1375                for col in 0..ncols {
1376                    assert_eq!(
1377                        data[(row, col)],
1378                        view.get_element(row, col),
1379                        "Ref get_element at ({}, {}) -- {}",
1380                        row,
1381                        col,
1382                        context,
1383                    );
1384                }
1385                let row_view = view.get_row(row).unwrap();
1386                for col in 0..ncols {
1387                    assert_eq!(data[(row, col)], row_view[col], "{}", context);
1388                }
1389            }
1390            assert!(view.get_row(nrows).is_none(), "{}", context);
1391        }
1392
1393        // ── BlockTransposedMut read API ──────────────────────────
1394
1395        let expected_ptr = transpose.as_ptr();
1396        {
1397            let mut_view = transpose.as_view_mut();
1398            assert_eq!(mut_view.nrows(), nrows, "{}", context);
1399            assert_eq!(mut_view.ncols(), ncols, "{}", context);
1400            assert_eq!(mut_view.padded_ncols(), expected_padded, "{}", context);
1401            assert_eq!(mut_view.group_size(), GROUP, "{}", context);
1402            assert_eq!(
1403                BlockTransposedMut::<T, GROUP, PACK>::const_group_size(),
1404                GROUP,
1405            );
1406            assert_eq!(mut_view.pack_size(), PACK, "{}", context);
1407            assert_eq!(mut_view.full_blocks(), nrows / GROUP, "{}", context);
1408            assert_eq!(
1409                mut_view.num_blocks(),
1410                div_round_up(nrows, GROUP),
1411                "{}",
1412                context,
1413            );
1414            assert_eq!(mut_view.remainder(), expected_remainder, "{}", context);
1415            assert_eq!(mut_view.as_ptr(), expected_ptr, "{}", context);
1416            assert_eq!(mut_view.as_slice().len(), storage_len, "{}", context);
1417
1418            for row in 0..nrows {
1419                for col in 0..ncols {
1420                    assert_eq!(
1421                        data[(row, col)],
1422                        mut_view.get_element(row, col),
1423                        "Mut get_element at ({}, {}) -- {}",
1424                        row,
1425                        col,
1426                        context,
1427                    );
1428                }
1429                let row_view = mut_view.get_row(row).unwrap();
1430                for col in 0..ncols {
1431                    assert_eq!(data[(row, col)], row_view[col], "{}", context);
1432                }
1433            }
1434            assert!(mut_view.get_row(nrows).is_none(), "{}", context);
1435        }
1436
1437        // ── BlockTransposedMut::as_view() ────────────────────────
1438
1439        {
1440            let mut_view = transpose.as_view_mut();
1441            let ref_from_mut = mut_view.as_view();
1442            assert_eq!(ref_from_mut.nrows(), nrows, "{}", context);
1443            for row in 0..nrows {
1444                for col in 0..ncols {
1445                    assert_eq!(
1446                        data[(row, col)],
1447                        ref_from_mut.get_element(row, col),
1448                        "{}",
1449                        context,
1450                    );
1451                }
1452            }
1453        }
1454
1455        // ── as_mut_slice ─────────────────────────────────────────
1456
1457        // Through BlockTransposedMut.
1458        {
1459            let mut mut_view = transpose.as_view_mut();
1460            assert_eq!(mut_view.as_mut_slice().len(), storage_len, "{}", context);
1461        }
1462        // Through BlockTransposed (owned).
1463        assert_eq!(transpose.as_mut_slice().len(), storage_len, "{}", context);
1464
1465        // ── Immutable block views on all three types ─────────────
1466
1467        let expected_block_nrows = expected_padded / PACK;
1468        let expected_block_ncols = GROUP * PACK;
1469
1470        for b in 0..transpose.full_blocks() {
1471            let block_data: Vec<T>;
1472            let ptr: *const T;
1473            {
1474                let block = transpose.block(b);
1475                assert_eq!(block.nrows(), expected_block_nrows, "{}", context);
1476                assert_eq!(block.ncols(), expected_block_ncols, "{}", context);
1477
1478                // SAFETY: b < full_blocks <= num_blocks.
1479                ptr = unsafe { transpose.block_ptr_unchecked(b) };
1480                assert_eq!(ptr, block.as_slice().as_ptr(), "{}", context);
1481
1482                block_data = block.as_slice().to_vec();
1483            }
1484
1485            // Same block via Ref.
1486            {
1487                let view = transpose.as_view();
1488                assert_eq!(view.block(b).as_slice(), &block_data[..], "{}", context);
1489                // SAFETY: `b` is in range `0..num_blocks` by the loop bound.
1490                assert_eq!(unsafe { view.block_ptr_unchecked(b) }, ptr, "{}", context);
1491            }
1492
1493            // Same block via Mut (read path).
1494            {
1495                let mut_view = transpose.as_view_mut();
1496                assert_eq!(mut_view.block(b).as_slice(), &block_data[..], "{}", context);
1497                assert_eq!(
1498                    // SAFETY: `b` is in range `0..num_blocks` by the loop bound.
1499                    unsafe { mut_view.block_ptr_unchecked(b) },
1500                    ptr,
1501                    "{}",
1502                    context,
1503                );
1504            }
1505        }
1506
1507        // Remainder block (immutable, all three types).
1508        if expected_remainder != 0 {
1509            let remainder_data: Vec<T>;
1510            let ptr: *const T;
1511            let fb = transpose.full_blocks();
1512            {
1513                let block = transpose.remainder_block().unwrap();
1514                assert_eq!(block.nrows(), expected_block_nrows, "{}", context);
1515                assert_eq!(block.ncols(), expected_block_ncols, "{}", context);
1516
1517                // SAFETY: fb < num_blocks (remainder exists).
1518                ptr = unsafe { transpose.block_ptr_unchecked(fb) };
1519                assert_eq!(ptr, block.as_slice().as_ptr(), "{}", context);
1520
1521                remainder_data = block.as_slice().to_vec();
1522            }
1523
1524            // Via Ref.
1525            {
1526                let view = transpose.as_view();
1527                let ref_block = view.remainder_block().unwrap();
1528                assert_eq!(ref_block.as_slice(), &remainder_data[..], "{}", context);
1529            }
1530            // Via Mut (read path).
1531            {
1532                let mut_view = transpose.as_view_mut();
1533                let mut_block = mut_view.remainder_block().unwrap();
1534                assert_eq!(mut_block.as_slice(), &remainder_data[..], "{}", context);
1535            }
1536        } else {
1537            assert!(transpose.remainder_block().is_none(), "{}", context);
1538            {
1539                let view = transpose.as_view();
1540                assert!(view.remainder_block().is_none(), "{}", context);
1541            }
1542            {
1543                let mut_view = transpose.as_view_mut();
1544                assert!(mut_view.remainder_block().is_none(), "{}", context);
1545            }
1546        }
1547
1548        // ── Mutable block views via BlockTransposedMut ───────────
1549
1550        {
1551            let mut mut_view = transpose.as_view_mut();
1552            for b in 0..mut_view.full_blocks() {
1553                let block_mut = mut_view.block_mut(b);
1554                assert_eq!(block_mut.nrows(), expected_block_nrows, "{}", context);
1555                assert_eq!(block_mut.ncols(), expected_block_ncols, "{}", context);
1556            }
1557            if expected_remainder != 0 {
1558                let rem = mut_view.remainder_block_mut().unwrap();
1559                assert_eq!(rem.nrows(), expected_block_nrows, "{}", context);
1560                assert_eq!(rem.ncols(), expected_block_ncols, "{}", context);
1561            } else {
1562                assert!(mut_view.remainder_block_mut().is_none(), "{}", context);
1563            }
1564        }
1565
1566        // Mutable block views via owned BlockTransposed.
1567        for b in 0..transpose.full_blocks() {
1568            let block_mut = transpose.block_mut(b);
1569            assert_eq!(block_mut.nrows(), expected_block_nrows, "{}", context);
1570            assert_eq!(block_mut.ncols(), expected_block_ncols, "{}", context);
1571        }
1572        if expected_remainder != 0 {
1573            let rem = transpose.remainder_block_mut().unwrap();
1574            assert_eq!(rem.nrows(), expected_block_nrows, "{}", context);
1575            assert_eq!(rem.ncols(), expected_block_ncols, "{}", context);
1576        } else {
1577            assert!(transpose.remainder_block_mut().is_none(), "{}", context);
1578        }
1579
1580        // ── Mutable row views via BlockTransposedMut ─────────────
1581
1582        {
1583            let mut mut_view = transpose.as_view_mut();
1584            for row in 0..nrows {
1585                let row_view = mut_view.get_row_mut(row).unwrap();
1586                assert_eq!(row_view.len(), ncols, "{}", context);
1587                assert_eq!(row_view.is_empty(), ncols == 0, "{}", context);
1588                for col in 0..ncols {
1589                    assert_eq!(data[(row, col)], row_view[col], "{}", context);
1590                }
1591            }
1592            assert!(mut_view.get_row_mut(nrows).is_none(), "{}", context);
1593        }
1594
1595        // ── Row::get, RowMut::get, RowMut::get_mut ──────────────
1596
1597        if nrows > 0 && ncols > 0 {
1598            // Row::get OOB.
1599            {
1600                let view = transpose.as_view();
1601                let row = view.get_row(0).unwrap();
1602                assert_eq!(row.get(ncols), None, "{}", context);
1603                assert_eq!(row.get(usize::MAX), None, "{}", context);
1604            }
1605
1606            // RowMut::get OOB.
1607            let row = transpose.get_row_mut(0).unwrap();
1608            assert_eq!(row.get(ncols), None, "{}", context);
1609
1610            // RowMut::get_mut — mutate and verify.
1611            let mut row = transpose.get_row_mut(0).unwrap();
1612            let sentinel = gen_element(usize::MAX / 2);
1613            let original = row[0];
1614            if let Some(v) = row.get_mut(0) {
1615                *v = sentinel;
1616            }
1617            assert_eq!(row.get_mut(ncols), None, "{}", context);
1618            // Explicit scope end so the mutable borrow is released before the next access.
1619            let _ = row;
1620            assert_eq!(transpose.get_element(0, 0), sentinel, "{}", context);
1621            // Restore original.
1622            transpose.get_row_mut(0).unwrap().set(0, original);
1623        }
1624
1625        // ── Zero out via block_mut / remainder_block_mut ─────────
1626
1627        for b in 0..transpose.full_blocks() {
1628            transpose.block_mut(b).as_mut_slice().fill(T::default());
1629        }
1630        if transpose.remainder() != 0 {
1631            transpose
1632                .remainder_block_mut()
1633                .unwrap()
1634                .as_mut_slice()
1635                .fill(T::default());
1636        }
1637        assert!(
1638            transpose.as_slice().iter().all(|v| *v == T::default()),
1639            "not fully zeroed -- {}",
1640            context,
1641        );
1642
1643        // ── Padding verification (fresh construction) ────────────
1644
1645        let transpose = BlockTransposed::<T, GROUP, PACK>::from_strided(data.as_view().into());
1646        let raw = transpose.as_slice();
1647
1648        // Column padding.
1649        for row in 0..nrows {
1650            for col in ncols..expected_padded {
1651                let idx = linear_index::<GROUP, PACK>(row, col, ncols);
1652                assert_eq!(
1653                    raw[idx],
1654                    T::default(),
1655                    "col padding at ({}, {}) -- {}",
1656                    row,
1657                    col,
1658                    context,
1659                );
1660            }
1661        }
1662
1663        // Row padding (within partial blocks).
1664        let padded_nrows = nrows.next_multiple_of(GROUP);
1665        for row in nrows..padded_nrows {
1666            for col in 0..expected_padded {
1667                let idx = linear_index::<GROUP, PACK>(row, col, ncols);
1668                assert_eq!(
1669                    raw[idx],
1670                    T::default(),
1671                    "row padding at ({}, {}) -- {}",
1672                    row,
1673                    col,
1674                    context,
1675                );
1676            }
1677        }
1678
1679        // ── from_matrix_view produces identical results ──────────
1680
1681        if nrows > 0 && ncols > 0 {
1682            let via_matrix = BlockTransposed::<T, GROUP, PACK>::from_matrix_view(data.as_view());
1683            assert_eq!(via_matrix.as_slice(), transpose.as_slice(), "{}", context);
1684        }
1685    }
1686
1687    // ════════════════════════════════════════════════════════════════
1688    // Test runners — each combination gets the full API surface.
1689    // ════════════════════════════════════════════════════════════════
1690
1691    #[test]
1692    fn test_api_pack1_group16() {
1693        // Miri: boundary rows around GROUP=16 block transitions;
1694        // full run: exhaustive sweep.
1695        let rows: Vec<usize> = if cfg!(miri) {
1696            vec![0, 1, 15, 16, 17, 33]
1697        } else {
1698            (0..128).collect()
1699        };
1700        let cols: Vec<usize> = if cfg!(miri) {
1701            vec![0, 1, 2]
1702        } else {
1703            (0..5).collect()
1704        };
1705        for &nrows in &rows {
1706            for &ncols in &cols {
1707                test_full_api::<f32, 16, 1>(nrows, ncols, gen_f32);
1708            }
1709        }
1710    }
1711
1712    #[test]
1713    fn test_api_pack1_group8() {
1714        // Miri: boundary rows around GROUP=8 block transitions;
1715        // full run: exhaustive sweep.
1716        let rows: Vec<usize> = if cfg!(miri) {
1717            vec![0, 1, 7, 8, 9, 17]
1718        } else {
1719            (0..128).collect()
1720        };
1721        let cols: Vec<usize> = if cfg!(miri) {
1722            vec![0, 1, 2]
1723        } else {
1724            (0..5).collect()
1725        };
1726        for &nrows in &rows {
1727            for &ncols in &cols {
1728                test_full_api::<f32, 8, 1>(nrows, ncols, gen_f32);
1729            }
1730        }
1731    }
1732
1733    #[test]
1734    fn test_api_pack2() {
1735        // Miri: boundary rows around GROUP=4/8/16 transitions;
1736        // cols hit PACK=2 boundary (even/odd). Full run: exhaustive.
1737        let rows: Vec<usize> = if cfg!(miri) {
1738            vec![0, 1, 3, 4, 5, 7, 8, 9, 15, 16, 17]
1739        } else {
1740            (0..48).collect()
1741        };
1742        let cols: Vec<usize> = if cfg!(miri) {
1743            vec![0, 1, 2, 3, 4, 5]
1744        } else {
1745            (0..9).collect()
1746        };
1747        for &nrows in &rows {
1748            for &ncols in &cols {
1749                test_full_api::<f32, 4, 2>(nrows, ncols, gen_f32);
1750                test_full_api::<f32, 8, 2>(nrows, ncols, gen_f32);
1751                test_full_api::<f32, 16, 2>(nrows, ncols, gen_f32);
1752            }
1753        }
1754    }
1755
1756    #[test]
1757    fn test_api_pack4() {
1758        // Miri: boundary rows around GROUP=4/8/16 transitions;
1759        // cols hit PACK=4 boundary (0,1,3,4,5,8). Full run: exhaustive.
1760        let rows: Vec<usize> = if cfg!(miri) {
1761            vec![0, 1, 3, 4, 5, 7, 8, 9, 15, 16, 17]
1762        } else {
1763            (0..48).collect()
1764        };
1765        let cols: Vec<usize> = if cfg!(miri) {
1766            vec![0, 1, 3, 4, 5, 8]
1767        } else {
1768            (0..9).collect()
1769        };
1770        for &nrows in &rows {
1771            for &ncols in &cols {
1772                test_full_api::<f32, 4, 4>(nrows, ncols, gen_f32);
1773                test_full_api::<f32, 8, 4>(nrows, ncols, gen_f32);
1774                test_full_api::<f32, 16, 4>(nrows, ncols, gen_f32);
1775            }
1776        }
1777    }
1778
1779    /// Exercise the unified test with non-`f32` element types.
1780    #[test]
1781    fn test_api_non_f32() {
1782        // i32:  PACK=1 and PACK=2
1783        test_full_api::<i32, 4, 1>(10, 7, gen_i32);
1784        test_full_api::<i32, 8, 2>(12, 5, gen_i32);
1785
1786        // u8:   PACK=1 and PACK=2
1787        test_full_api::<u8, 4, 2>(12, 5, gen_u8);
1788        test_full_api::<u8, 8, 1>(10, 7, gen_u8);
1789    }
1790
1791    // ════════════════════════════════════════════════════════════════
1792    // Block layout verification (PACK=1 only)
1793    // ════════════════════════════════════════════════════════════════
1794
1795    /// Verify that for PACK=1, each block is the standard row-to-column
1796    /// transposition of a GROUP-row slice of the source matrix.
1797    fn test_block_layout_pack1<
1798        T: Copy + Default + PartialEq + std::fmt::Debug + 'static,
1799        const GROUP: usize,
1800    >(
1801        nrows: usize,
1802        ncols: usize,
1803        gen_element: fn(usize) -> T,
1804    ) {
1805        let mut data = Matrix::new(T::default(), nrows, ncols);
1806        data.as_mut_slice()
1807            .iter_mut()
1808            .enumerate()
1809            .for_each(|(i, d)| *d = gen_element(i));
1810
1811        let transpose = BlockTransposed::<T, GROUP, 1>::from_strided(data.as_view().into());
1812
1813        // Full blocks.
1814        for b in 0..transpose.full_blocks() {
1815            let block = transpose.block(b);
1816            for i in 0..block.nrows() {
1817                for j in 0..block.ncols() {
1818                    assert_eq!(
1819                        block[(i, j)],
1820                        data[(GROUP * b + j, i)],
1821                        "block {} at ({}, {}) -- GROUP={}, nrows={}, ncols={}",
1822                        b,
1823                        i,
1824                        j,
1825                        GROUP,
1826                        nrows,
1827                        ncols,
1828                    );
1829                }
1830            }
1831        }
1832
1833        // Remainder block.
1834        if transpose.remainder() != 0 {
1835            let fb = transpose.full_blocks();
1836            let block = transpose.remainder_block().unwrap();
1837            for i in 0..block.nrows() {
1838                for j in 0..transpose.remainder() {
1839                    assert_eq!(
1840                        block[(i, j)],
1841                        data[(GROUP * fb + j, i)],
1842                        "remainder at ({}, {}) -- GROUP={}, nrows={}, ncols={}",
1843                        i,
1844                        j,
1845                        GROUP,
1846                        nrows,
1847                        ncols,
1848                    );
1849                }
1850            }
1851        }
1852    }
1853
1854    #[test]
1855    fn test_block_layout_pack1_group16() {
1856        let rows: Vec<usize> = if cfg!(miri) {
1857            vec![0, 1, 15, 16, 17, 33]
1858        } else {
1859            (0..128).collect()
1860        };
1861        let cols: Vec<usize> = if cfg!(miri) {
1862            vec![0, 1, 2]
1863        } else {
1864            (0..5).collect()
1865        };
1866        for &nrows in &rows {
1867            for &ncols in &cols {
1868                test_block_layout_pack1::<f32, 16>(nrows, ncols, gen_f32);
1869            }
1870        }
1871    }
1872
1873    #[test]
1874    fn test_block_layout_pack1_group8() {
1875        let rows: Vec<usize> = if cfg!(miri) {
1876            vec![0, 1, 7, 8, 9, 17]
1877        } else {
1878            (0..128).collect()
1879        };
1880        let cols: Vec<usize> = if cfg!(miri) {
1881            vec![0, 1, 2]
1882        } else {
1883            (0..5).collect()
1884        };
1885        for &nrows in &rows {
1886            for &ncols in &cols {
1887                test_block_layout_pack1::<f32, 8>(nrows, ncols, gen_f32);
1888            }
1889        }
1890    }
1891
1892    // ════════════════════════════════════════════════════════════════
1893    // Focused tests (not part of the unified parameterized test)
1894    // ════════════════════════════════════════════════════════════════
1895
1896    // ── Send / Sync static assertions ───────────────────────────────
1897
1898    #[test]
1899    fn test_row_view_send_sync() {
1900        fn assert_send<T: Send>() {}
1901        fn assert_sync<T: Sync>() {}
1902
1903        assert_send::<Row<'_, f32, 16>>();
1904        assert_sync::<Row<'_, f32, 16>>();
1905        assert_send::<Row<'_, u8, 8, 2>>();
1906        assert_sync::<Row<'_, u8, 8, 2>>();
1907
1908        assert_send::<RowMut<'_, f32, 16>>();
1909        assert_sync::<RowMut<'_, f32, 16>>();
1910        assert_send::<RowMut<'_, i32, 4, 4>>();
1911        assert_sync::<RowMut<'_, i32, 4, 4>>();
1912    }
1913
1914    // ── NewRef / NewMut from raw slices ─────────────────────────────
1915
1916    #[test]
1917    fn test_new_ref_and_new_mut() {
1918        let nrows = 5;
1919        let ncols = 3;
1920        let repr = BlockTransposedRepr::<f32, 4>::new(nrows, ncols).unwrap();
1921
1922        let mat = BlockTransposed::<f32, 4>::new(nrows, ncols);
1923        let raw: &[f32] = mat.as_slice();
1924
1925        let mat_ref = BlockTransposedRef {
1926            data: repr.new_ref(raw).unwrap(),
1927        };
1928        assert_eq!(mat_ref.nrows(), nrows);
1929        assert_eq!(mat_ref.ncols(), ncols);
1930        for row in 0..nrows {
1931            for col in 0..ncols {
1932                assert_eq!(mat_ref.get_element(row, col), mat.get_element(row, col));
1933            }
1934        }
1935
1936        let mut buf = raw.to_vec();
1937        let mat_mut = BlockTransposedMut {
1938            data: repr.new_mut(&mut buf).unwrap(),
1939        };
1940        assert_eq!(mat_mut.nrows(), nrows);
1941        assert_eq!(mat_mut.ncols(), ncols);
1942
1943        // Wrong-length slice should fail.
1944        let mut short = vec![0.0_f32; 2];
1945        assert!(repr.new_ref(&short).is_err());
1946        assert!(repr.new_mut(&mut short).is_err());
1947    }
1948
1949    // ── Row view edge cases ─────────────────────────────────────────
1950
1951    #[test]
1952    fn test_row_view_empty() {
1953        /// Verify that immutable and mutable empty-row views are sound for a
1954        /// given `GROUP`/`PACK` combination.
1955        fn check_empty<const GROUP: usize, const PACK: usize>() {
1956            let mut mat = BlockTransposed::<f32, GROUP, PACK>::new(4, 0);
1957
1958            // Immutable views.
1959            let view = mat.as_view();
1960            for i in 0..4 {
1961                let row = view.get_row(i).unwrap();
1962                assert!(row.is_empty());
1963                assert_eq!(row.len(), 0);
1964                assert_eq!(row.iter().count(), 0);
1965            }
1966
1967            // Mutable views.
1968            for i in 0..4 {
1969                let row = mat.get_row_mut(i).unwrap();
1970                assert!(row.is_empty());
1971                assert_eq!(row.len(), 0);
1972            }
1973        }
1974
1975        check_empty::<16, 1>(); // default PACK
1976        check_empty::<4, 2>(); // PACK > 1
1977        check_empty::<4, 4>(); // PACK == GROUP
1978    }
1979
1980    // ── Bounds-checking panic tests ─────────────────────────────────
1981
1982    #[test]
1983    #[should_panic(expected = "column index 3 out of bounds")]
1984    fn test_row_view_index_oob() {
1985        let mat = BlockTransposed::<f32, 4>::new(4, 3);
1986        let view = mat.as_view();
1987        let row = view.get_row(0).unwrap();
1988        let _ = row[3];
1989    }
1990
1991    #[test]
1992    #[should_panic(expected = "column index 3 out of bounds")]
1993    fn test_row_view_mut_index_oob() {
1994        let mut mat = BlockTransposed::<f32, 4>::new(4, 3);
1995        let row = mat.get_row_mut(0).unwrap();
1996        let _ = row[3];
1997    }
1998
1999    #[test]
2000    #[should_panic(expected = "column index 3 out of bounds")]
2001    fn test_row_view_mut_index_mut_oob() {
2002        let mut mat = BlockTransposed::<f32, 4>::new(4, 3);
2003        let mut row = mat.get_row_mut(0).unwrap();
2004        row[3] = 1.0;
2005    }
2006
2007    #[test]
2008    #[should_panic(expected = "column index 3 out of bounds")]
2009    fn test_row_view_set_oob() {
2010        let mut mat = BlockTransposed::<f32, 4>::new(4, 3);
2011        let mut row = mat.get_row_mut(0).unwrap();
2012        row.set(3, 1.0);
2013    }
2014
2015    #[test]
2016    #[should_panic(expected = "row 4 out of bounds")]
2017    fn test_get_element_row_oob() {
2018        let mat = BlockTransposed::<f32, 4>::new(4, 3);
2019        mat.get_element(4, 0);
2020    }
2021
2022    #[test]
2023    #[should_panic(expected = "col 3 out of bounds")]
2024    fn test_get_element_col_oob() {
2025        let mat = BlockTransposed::<f32, 4>::new(4, 3);
2026        mat.get_element(0, 3);
2027    }
2028
2029    #[test]
2030    #[should_panic(expected = "assertion failed")]
2031    fn test_index_tuple_row_oob() {
2032        let mat = BlockTransposed::<f32, 4>::new(4, 3);
2033        let _ = mat[(4, 0)];
2034    }
2035
2036    #[test]
2037    #[should_panic(expected = "assertion failed")]
2038    fn test_index_tuple_col_oob() {
2039        let mat = BlockTransposed::<f32, 4>::new(4, 3);
2040        let _ = mat[(0, 3)];
2041    }
2042
2043    #[test]
2044    #[should_panic]
2045    fn test_block_oob() {
2046        let mat = BlockTransposed::<f32, 4>::new(4, 3);
2047        let _ = mat.block(1);
2048    }
2049
2050    #[test]
2051    #[should_panic]
2052    fn test_block_mut_oob() {
2053        let mut mat = BlockTransposed::<f32, 4>::new(4, 3);
2054        let _ = mat.block_mut(1);
2055    }
2056
2057    // ── from_strided with non-unit stride ───────────────────────────
2058
2059    #[test]
2060    fn test_from_strided_nonunit_stride() {
2061        use diskann_utils::strided::StridedView;
2062
2063        const GROUP: usize = 4;
2064        const PACK: usize = 2;
2065        let nrows = 5;
2066        let ncols = 3;
2067        let cstride = 8;
2068
2069        let required_len = (nrows - 1) * cstride + ncols;
2070        let mut flat = vec![0.0_f32; required_len];
2071        for row in 0..nrows {
2072            for col in 0..ncols {
2073                flat[row * cstride + col] = (row * 100 + col + 1) as f32;
2074            }
2075        }
2076
2077        let strided = StridedView::try_shrink_from(&flat, nrows, ncols, cstride)
2078            .expect("should construct strided view");
2079        let transpose = BlockTransposed::<f32, GROUP, PACK>::from_strided(strided);
2080
2081        assert_eq!(transpose.nrows(), nrows);
2082        assert_eq!(transpose.ncols(), ncols);
2083
2084        for row in 0..nrows {
2085            for col in 0..ncols {
2086                let expected = (row * 100 + col + 1) as f32;
2087                assert_eq!(
2088                    transpose[(row, col)],
2089                    expected,
2090                    "mismatch at ({}, {})",
2091                    row,
2092                    col,
2093                );
2094            }
2095        }
2096
2097        let padded_ncols = ncols.next_multiple_of(PACK);
2098        let raw: &[f32] = transpose.as_slice();
2099        for row in 0..nrows {
2100            for col in ncols..padded_ncols {
2101                let idx = linear_index::<GROUP, PACK>(row, col, ncols);
2102                assert_eq!(
2103                    raw[idx], 0.0,
2104                    "column-padding at ({}, {}) should be zero",
2105                    row, col,
2106                );
2107            }
2108        }
2109    }
2110
2111    // ── Concurrent multi-row mutation ───────────────────────────────
2112
2113    #[test]
2114    fn test_concurrent_row_mutation() {
2115        const GROUP: usize = 8;
2116        const PACK: usize = 2;
2117
2118        let (nrows, ncols, num_threads) = if cfg!(miri) { (8, 4, 2) } else { (64, 16, 4) };
2119
2120        let mut mat = BlockTransposed::<f32, GROUP, PACK>::new(nrows, ncols);
2121        let rows: Vec<RowMut<'_, f32, GROUP, PACK>> = mat.data.rows_mut().collect();
2122        let rows_per_thread = nrows / num_threads;
2123        let mut rows = rows.into_boxed_slice();
2124
2125        std::thread::scope(|s| {
2126            let mut remaining = &mut rows[..];
2127            for thread_id in 0..num_threads {
2128                let chunk_len = if thread_id == num_threads - 1 {
2129                    remaining.len()
2130                } else {
2131                    rows_per_thread
2132                };
2133                let (chunk, rest) = remaining.split_at_mut(chunk_len);
2134                remaining = rest;
2135                let start_row = thread_id * rows_per_thread;
2136
2137                s.spawn(move || {
2138                    for (offset, row_view) in chunk.iter_mut().enumerate() {
2139                        let row = start_row + offset;
2140                        for col in 0..ncols {
2141                            let value = (thread_id * 10000 + row * 100 + col) as f32;
2142                            row_view.set(col, value);
2143                        }
2144                    }
2145                });
2146            }
2147        });
2148
2149        for row in 0..nrows {
2150            let thread_id = (row / rows_per_thread).min(num_threads - 1);
2151            for col in 0..ncols {
2152                let expected = (thread_id * 10000 + row * 100 + col) as f32;
2153                assert_eq!(
2154                    mat.get_element(row, col),
2155                    expected,
2156                    "mismatch at ({}, {})",
2157                    row,
2158                    col,
2159                );
2160            }
2161        }
2162    }
2163}