// diskann_quantization/multi_vector/block_transposed.rs
1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6//! Block-transposed matrix types with configurable packing.
7//!
8//! This module provides block-transposed matrix types — [`BlockTransposed`] (owned),
9//! [`BlockTransposedRef`] (shared view), and [`BlockTransposedMut`] (mutable view) —
10//! where groups of `GROUP` rows are stored in transposed form to enable efficient SIMD
11//! processing. An optional packing factor `PACK` interleaves adjacent columns within
12//! each group, which can be used to feed SIMD instructions that operate on packed pairs
13//! (e.g. `vpmaddwd` with `PACK = 2`).
14//!
15//! # Layout
16//!
17//! ## `PACK = 1` (standard block-transpose)
18//!
19//! Given a logical matrix with rows `a`, `b`, `c`, `d`, `e` (each with `K` columns)
20//! and `GROUP = 3`:
21//!
22//! ```text
23//!            Group Size (3)
24//!            <---------->
25//!
26//!            +----------+    ^
27//!            | a0 b0 c0 |    |
28//!            | a1 b1 c1 |    |
29//!            | a2 b2 c2 |    | Block Size (K)
30//!  Block 0   | ...      |    |
31//!  (Full)    | aK bK cK |    |
32//!            +----------+    v
33//!            +----------+
34//!            | d0 e0 XX |
35//!  Block 1   | d1 e1 XX |
36//!  (Partial) | ...      |
37//!            | dK eK XX |
38//!            +----------+
39//! ```
40//!
41//! ## `PACK = 2` (super-packed)
42//!
43//! With `GROUP = 4`, `PACK = 2`, and a logical matrix with rows `a`, `b`, `c`, `d`,
44//! `e`, `f` (each with **5** columns — odd, to show padding), adjacent column-pairs
45//! are interleaved per row within each group panel:
46//!
47//! ```text
48//!              GROUP × PACK (4 × 2 = 8)
49//!              <----------------------------->
50//!
51//!              +-----------------------------+    ^
52//!              | a0 a1  b0 b1  c0 c1  d0 d1  |    |  col-pair (0, 1)
53//!              | a2 a3  b2 b3  c2 c3  d2 d3  |    |  col-pair (2, 3)
54//!    Block 0   | a4 __  b4 __  c4 __  d4 __  |    |  col-pair (4, pad)
55//!    (Full)    +-----------------------------+    v
56//!              +-----------------------------+
57//!              | e0 e1  f0 f1  XX XX  XX XX  |       col-pair (0, 1)
58//!    Block 1   | e2 e3  f2 f3  XX XX  XX XX  |       col-pair (2, 3)
59//!    (Partial) | e4 __  f4 __  XX XX  XX XX  |       col-pair (4, pad)
60//!              +-----------------------------+
61//!
62//!    __ = zero (column padding)    XX = zero (row padding)
63//!    padded_ncols = 6  (5 rounded up to next multiple of PACK)
64//!    Block Size  = padded_ncols / PACK = 3 physical rows per block
65//! ```
66//!
67//! Each physical row of a block holds one column-pair across all `GROUP` rows.
68//! For example, the first physical row stores columns `(0, 1)` for rows
69//! `a, b, c, d` interleaved as `[a0, a1, b0, b1, c0, c1, d0, d1]`.
70//!
71//! Because `ncols = 5` is odd (not a multiple of `PACK = 2`), the last
72//! column-pair `(4, pad)` is zero-padded: `[a4, 0, b4, 0, c4, 0, d4, 0]`.
73//!
74//! # Constraints
75//!
76//! - `GROUP > 0`
77//! - `PACK > 0`
78//! - `GROUP % PACK == 0`
79
80use std::{alloc::Layout, marker::PhantomData, ptr::NonNull};
81
82use diskann_utils::{
83    Reborrow, ReborrowMut,
84    strided::StridedView,
85    views::{MatrixView, MutMatrixView},
86};
87
88use super::matrix::{
89    Defaulted, LayoutError, Mat, MatMut, MatRef, NewMut, NewOwned, NewRef, Overflow, Repr, ReprMut,
90    ReprOwned, SliceError,
91};
92use crate::bits::{AsMutPtr, AsPtr, MutSlicePtr, SlicePtr};
93use crate::utils;
94
/// Smallest multiple of `PACK` that is `>= n` — i.e. the physical column count
/// backing `n` logical columns.
#[inline]
fn padded_ncols<const PACK: usize>(n: usize) -> usize {
    n.next_multiple_of(PACK)
}
100
/// Total number of `T` slots required to store an `nrows x ncols` matrix in
/// block-transposed form with group size `GROUP` and packing factor `PACK`:
/// rows are padded up to a multiple of `GROUP`, columns up to a multiple of `PACK`.
///
/// This is the **unchecked** flavor — it assumes the dimensions were already
/// validated (e.g. after construction) and will panic on overflow. Use
/// [`checked_compute_capacity`] inside constructors.
///
/// Compile-time constraints (`GROUP > 0`, `PACK > 0`, `GROUP % PACK == 0`) are
/// enforced by [`BlockTransposedRepr::_ASSERTIONS`] and are not re-checked here.
#[inline]
fn compute_capacity<const GROUP: usize, const PACK: usize>(nrows: usize, ncols: usize) -> usize {
    let padded_rows = nrows.next_multiple_of(GROUP);
    let padded_cols = ncols.next_multiple_of(PACK);
    padded_rows * padded_cols
}
114
/// Overflow-aware variant of [`compute_capacity`].
///
/// Returns `None` whenever rounding either dimension up or taking their product
/// would overflow `usize`, letting the constructor reject impossibly large
/// dimensions before any allocation happens.
#[inline]
fn checked_compute_capacity<const GROUP: usize, const PACK: usize>(
    nrows: usize,
    ncols: usize,
) -> Option<usize> {
    let padded_rows = nrows.checked_next_multiple_of(GROUP)?;
    let padded_cols = ncols.checked_next_multiple_of(PACK)?;
    padded_rows.checked_mul(padded_cols)
}
127
/// Linear index of the element at logical `(row, col)` in a block-transposed
/// layout with group size `GROUP`, packing factor `PACK`, and `ncols` logical columns.
///
/// Decomposition: skip whole blocks of `GROUP * padded_ncols` elements, then the
/// column-pair panel `(col / PACK)` of `GROUP * PACK` elements, then the row's
/// `PACK`-wide lane inside that panel, then the position within the pack.
#[inline]
fn linear_index<const GROUP: usize, const PACK: usize>(
    row: usize,
    col: usize,
    ncols: usize,
) -> usize {
    let physical_cols = ncols.next_multiple_of(PACK);
    let block = row / GROUP;
    let lane = row % GROUP;
    let block_base = block * GROUP * physical_cols;
    let panel_base = (col / PACK) * (GROUP * PACK);
    block_base + panel_base + lane * PACK + col % PACK
}
141
/// Element offset of column `col` relative to a row's base pointer (`col = 0`).
///
/// Depends only on the column index and the const layout parameters — never on a
/// particular matrix's dimensions — so row views can use it without carrying `ncols`
/// stride information.
#[inline]
fn col_offset<const GROUP: usize, const PACK: usize>(col: usize) -> usize {
    let panel = col / PACK;
    let within_pack = col % PACK;
    panel * (GROUP * PACK) + within_pack
}
150
151/// Internal layout descriptor for block-transposed matrices.
152///
153/// This is not part of the public API — use [`BlockTransposed`], [`BlockTransposedRef`],
154/// or [`BlockTransposedMut`] instead.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct BlockTransposedRepr<T, const GROUP: usize, const PACK: usize = 1> {
    // Number of logical rows (before GROUP padding).
    nrows: usize,
    // Number of logical columns (before PACK padding).
    ncols: usize,
    // Marker for the element type; no `T` values are stored in the descriptor itself.
    _elem: PhantomData<T>,
}
161
162impl<T: Copy, const GROUP: usize, const PACK: usize> BlockTransposedRepr<T, GROUP, PACK> {
163    // Compile-time assertions — evaluated whenever any method references this constant.
164    const _ASSERTIONS: () = {
165        assert!(GROUP > 0, "group size GROUP must be positive");
166        assert!(PACK > 0, "packing factor PACK must be positive");
167        assert!(
168            GROUP.is_multiple_of(PACK),
169            "GROUP must be divisible by PACK"
170        );
171    };
172
173    /// Create a new `BlockTransposedRepr` descriptor.
174    ///
175    /// Successful construction requires that the total memory for the backing allocation
176    /// does not exceed `isize::MAX`.
177    pub fn new(nrows: usize, ncols: usize) -> Result<Self, Overflow> {
178        let () = Self::_ASSERTIONS;
179        let capacity = checked_compute_capacity::<GROUP, PACK>(nrows, ncols)
180            .ok_or_else(|| Overflow::for_type::<T>(nrows, ncols))?;
181        Overflow::check_byte_budget::<T>(capacity, nrows, ncols)?;
182        Ok(Self {
183            nrows,
184            ncols,
185            _elem: PhantomData,
186        })
187    }
188
189    // ── Query helpers ────────────────────────────────────────────────
190
191    /// The total number of `T` elements in the backing allocation (including padding).
192    #[inline]
193    fn storage_len(&self) -> usize {
194        compute_capacity::<GROUP, PACK>(self.nrows, self.ncols)
195    }
196
197    /// Number of logical rows.
198    #[inline]
199    fn nrows(&self) -> usize {
200        self.nrows
201    }
202
203    /// Number of logical columns (dimensionality).
204    #[inline]
205    pub fn ncols(&self) -> usize {
206        self.ncols
207    }
208
209    /// Number of physical (padded) columns — logical columns rounded up to
210    /// the next multiple of `PACK`.
211    #[inline]
212    pub fn padded_ncols(&self) -> usize {
213        padded_ncols::<PACK>(self.ncols)
214    }
215
216    /// Number of completely full blocks.
217    #[inline]
218    pub fn full_blocks(&self) -> usize {
219        self.nrows / GROUP
220    }
221
222    /// Total number of blocks including a possible partially-filled tail.
223    #[inline]
224    pub fn num_blocks(&self) -> usize {
225        self.nrows.div_ceil(GROUP)
226    }
227
228    /// Number of valid elements in the last block, or 0 if all blocks are full.
229    #[inline]
230    pub fn remainder(&self) -> usize {
231        self.nrows % GROUP
232    }
233
234    /// Total number of logical rows rounded up to the next multiple of `GROUP`.
235    ///
236    /// This is the number of "available" row slots in the backing allocation,
237    /// including zero-padded rows in the last (possibly partial) block.
238    #[inline]
239    pub fn padded_nrows(&self) -> usize {
240        self.num_blocks() * GROUP
241    }
242
243    /// The stride (in elements) between the start of consecutive blocks.
244    #[inline]
245    fn block_stride(&self) -> usize {
246        GROUP * self.padded_ncols()
247    }
248
249    /// The linear offset of the start of `block`.
250    #[inline]
251    fn block_offset(&self, block: usize) -> usize {
252        block * self.block_stride()
253    }
254
255    /// Verify that `slice` has exactly `self.storage_len()` elements.
256    fn check_slice(&self, slice: &[T]) -> Result<(), SliceError> {
257        let cap = self.storage_len();
258        if slice.len() != cap {
259            Err(SliceError::LengthMismatch {
260                expected: cap,
261                found: slice.len(),
262            })
263        } else {
264            Ok(())
265        }
266    }
267
268    /// Helper: wrap a `Box<[T]>` into a [`Mat`] without any further checks.
269    ///
270    /// # Safety
271    ///
272    /// `b.len()` must equal `self.storage_len()`.
273    unsafe fn box_to_mat(self, b: Box<[T]>) -> Mat<Self> {
274        debug_assert_eq!(b.len(), self.storage_len(), "safety contract violated");
275
276        let ptr = utils::box_into_nonnull(b).cast::<u8>();
277
278        // SAFETY: `ptr` is properly aligned and compatible with our layout.
279        unsafe { Mat::from_raw_parts(self, ptr) }
280    }
281}
282
283// ════════════════════════════════════════════════════════════════════
284// Row view types
285// ════════════════════════════════════════════════════════════════════
286
287/// An immutable view of a single logical row in a block-transposed matrix.
288///
289/// Because the elements of a logical row are strided (not contiguous), this struct
290/// provides indexed access and iteration over the row's elements.
#[derive(Debug, Clone, Copy)]
pub struct Row<'a, T, const GROUP: usize, const PACK: usize = 1> {
    /// Pointer to the element at `(row, col=0)` in the backing allocation.
    base: SlicePtr<'a, T>,
    // Number of logical columns in this row; column `c` lives at
    // `base + col_offset::<GROUP, PACK>(c)` for `c < ncols`.
    ncols: usize,
}
297
298impl<T: Copy, const GROUP: usize, const PACK: usize> Row<'_, T, GROUP, PACK> {
299    /// Number of elements (columns) in this row.
300    #[inline]
301    pub fn len(&self) -> usize {
302        self.ncols
303    }
304
305    /// Whether the row is empty.
306    #[inline]
307    pub fn is_empty(&self) -> bool {
308        self.ncols == 0
309    }
310
311    /// Get a reference to the element at column `col`, or `None` if out of bounds.
312    #[inline]
313    pub fn get(&self, col: usize) -> Option<&T> {
314        if col < self.ncols {
315            // SAFETY: bounds checked, offset computed from validated layout.
316            Some(unsafe { &*self.base.as_ptr().add(col_offset::<GROUP, PACK>(col)) })
317        } else {
318            None
319        }
320    }
321
322    /// Return an iterator over the elements of this row.
323    #[inline]
324    pub fn iter(&self) -> RowIter<'_, T, GROUP, PACK> {
325        RowIter {
326            base: self.base,
327            col: 0,
328            ncols: self.ncols,
329        }
330    }
331}
332
333impl<T: Copy, const GROUP: usize, const PACK: usize> std::ops::Index<usize>
334    for Row<'_, T, GROUP, PACK>
335{
336    type Output = T;
337
338    #[inline]
339    #[allow(clippy::panic)] // Index is expected to panic on OOB
340    fn index(&self, col: usize) -> &Self::Output {
341        self.get(col)
342            .unwrap_or_else(|| panic!("column index {col} out of bounds (ncols = {})", self.ncols))
343    }
344}
345
346/// Iterator over the elements of a [`Row`].
#[derive(Debug, Clone)]
pub struct RowIter<'a, T, const GROUP: usize, const PACK: usize = 1> {
    // Pointer to the row's first element (col = 0).
    base: SlicePtr<'a, T>,
    // Next column to yield; invariant: `col <= ncols`.
    col: usize,
    // Total logical columns in the row being iterated.
    ncols: usize,
}
353
354impl<T: Copy, const GROUP: usize, const PACK: usize> Iterator for RowIter<'_, T, GROUP, PACK> {
355    type Item = T;
356
357    #[inline]
358    fn next(&mut self) -> Option<Self::Item> {
359        if self.col >= self.ncols {
360            return None;
361        }
362        // SAFETY: col < ncols means the offset is within the backing allocation.
363        let val = unsafe { *self.base.as_ptr().add(col_offset::<GROUP, PACK>(self.col)) };
364        self.col += 1;
365        Some(val)
366    }
367
368    #[inline]
369    fn size_hint(&self) -> (usize, Option<usize>) {
370        let remaining = self.ncols - self.col;
371        (remaining, Some(remaining))
372    }
373}
374
// Exact: `size_hint` always reports the precise remaining count (`ncols - col`).
impl<T: Copy, const GROUP: usize, const PACK: usize> ExactSizeIterator
    for RowIter<'_, T, GROUP, PACK>
{
}
// Fused: once `col` reaches `ncols`, `next` keeps returning `None` forever.
impl<T: Copy, const GROUP: usize, const PACK: usize> std::iter::FusedIterator
    for RowIter<'_, T, GROUP, PACK>
{
}
383
384/// A mutable view of a single logical row in a block-transposed matrix.
#[derive(Debug)]
pub struct RowMut<'a, T, const GROUP: usize, const PACK: usize = 1> {
    // Mutable pointer to the element at `(row, col=0)` in the backing allocation.
    base: MutSlicePtr<'a, T>,
    // Number of logical columns in this row.
    ncols: usize,
}
390
391impl<T: Copy, const GROUP: usize, const PACK: usize> RowMut<'_, T, GROUP, PACK> {
392    /// Number of elements (columns) in this row.
393    #[inline]
394    pub fn len(&self) -> usize {
395        self.ncols
396    }
397
398    /// Whether the row is empty.
399    #[inline]
400    pub fn is_empty(&self) -> bool {
401        self.ncols == 0
402    }
403
404    /// Get a reference to the element at column `col`, or `None` if out of bounds.
405    #[inline]
406    pub fn get(&self, col: usize) -> Option<&T> {
407        if col < self.ncols {
408            // SAFETY: bounds checked.
409            Some(unsafe { &*self.base.as_ptr().add(col_offset::<GROUP, PACK>(col)) })
410        } else {
411            None
412        }
413    }
414
415    /// Get a mutable reference to the element at column `col`, or `None` if out of bounds.
416    #[inline]
417    pub fn get_mut(&mut self, col: usize) -> Option<&mut T> {
418        if col < self.ncols {
419            // SAFETY: bounds checked.
420            Some(unsafe { &mut *self.base.as_mut_ptr().add(col_offset::<GROUP, PACK>(col)) })
421        } else {
422            None
423        }
424    }
425
426    /// Set the element at column `col`.
427    ///
428    /// # Panics
429    ///
430    /// Panics if `col >= self.len()`.
431    #[inline]
432    pub fn set(&mut self, col: usize, value: T) {
433        assert!(
434            col < self.ncols,
435            "column index {col} out of bounds (ncols = {})",
436            self.ncols
437        );
438        // SAFETY: bounds checked.
439        unsafe { *self.base.as_mut_ptr().add(col_offset::<GROUP, PACK>(col)) = value };
440    }
441}
442
443impl<T: Copy, const GROUP: usize, const PACK: usize> std::ops::Index<usize>
444    for RowMut<'_, T, GROUP, PACK>
445{
446    type Output = T;
447
448    #[inline]
449    #[allow(clippy::panic)] // Index is expected to panic on OOB
450    fn index(&self, col: usize) -> &Self::Output {
451        self.get(col)
452            .unwrap_or_else(|| panic!("column index {col} out of bounds (ncols = {})", self.ncols))
453    }
454}
455
456impl<T: Copy, const GROUP: usize, const PACK: usize> std::ops::IndexMut<usize>
457    for RowMut<'_, T, GROUP, PACK>
458{
459    #[inline]
460    #[allow(clippy::panic)] // IndexMut is expected to panic on OOB
461    fn index_mut(&mut self, col: usize) -> &mut Self::Output {
462        let ncols = self.ncols;
463        self.get_mut(col)
464            .unwrap_or_else(|| panic!("column index {col} out of bounds (ncols = {ncols})"))
465    }
466}
467
468// ════════════════════════════════════════════════════════════════════
469// Repr / ReprMut / ReprOwned
470// ════════════════════════════════════════════════════════════════════
471
// SAFETY: `get_row` produces a valid `Row` for valid indices. The layout
// reports the correct capacity for the block-transposed backing allocation.
unsafe impl<T: Copy, const GROUP: usize, const PACK: usize> Repr
    for BlockTransposedRepr<T, GROUP, PACK>
{
    type Row<'a>
        = Row<'a, T, GROUP, PACK>
    where
        Self: 'a;

    fn nrows(&self) -> usize {
        self.nrows
    }

    // Allocation layout: a flat array of `storage_len()` elements of `T`
    // (padding included). Fails only if the array layout itself is invalid.
    fn layout(&self) -> Result<Layout, LayoutError> {
        Ok(Layout::array::<T>(self.storage_len())?)
    }

    unsafe fn get_row<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::Row<'a> {
        debug_assert!(i < self.nrows);

        // When ncols == 0 the backing allocation is zero-sized, so we must not
        // compute any pointer offset.  Return a dangling base instead.
        if self.ncols == 0 {
            return Row {
                // SAFETY: The row is empty (ncols == 0) so the pointer will never be
                // dereferenced. A dangling `NonNull` satisfies the non-null invariant.
                base: unsafe { SlicePtr::new_unchecked(NonNull::dangling()) },
                ncols: 0,
            };
        }

        let base_ptr = ptr.as_ptr().cast::<T>();
        // `linear_index(i, 0, ..)` is the position of the row's first column;
        // `Row` then reaches the remaining columns via `col_offset`.
        let offset = linear_index::<GROUP, PACK>(i, 0, self.ncols);

        // SAFETY: The caller asserts `i < self.nrows()`. The backing allocation has at
        // least `self.storage_len()` elements, so the computed offset is in bounds.
        let row_base = unsafe { base_ptr.add(offset) };

        Row {
            // SAFETY: `row_base` is derived from a `NonNull<u8>` with a valid offset,
            // so it is non-null. The lifetime is tied to the caller's `'a`.
            base: unsafe { SlicePtr::new_unchecked(NonNull::new_unchecked(row_base)) },
            ncols: self.ncols,
        }
    }
}
519
// SAFETY: `get_row_mut` produces a valid `RowMut`. Disjoint row indices
// produce disjoint base pointers because each row within a block starts at a unique
// offset modulo GROUP.
unsafe impl<T: Copy, const GROUP: usize, const PACK: usize> ReprMut
    for BlockTransposedRepr<T, GROUP, PACK>
{
    type RowMut<'a>
        = RowMut<'a, T, GROUP, PACK>
    where
        Self: 'a;

    // Mirrors `Repr::get_row` exactly, but yields a mutable row handle.
    unsafe fn get_row_mut<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::RowMut<'a> {
        debug_assert!(i < self.nrows);

        // When ncols == 0 the backing allocation is zero-sized, so we must not
        // compute any pointer offset.  Return a dangling base instead.
        if self.ncols == 0 {
            return RowMut {
                // SAFETY: The row is empty (ncols == 0) so the pointer will never be
                // dereferenced. A dangling `NonNull` satisfies the non-null invariant.
                base: unsafe { MutSlicePtr::new_unchecked(NonNull::dangling()) },
                ncols: 0,
            };
        }

        let base_ptr = ptr.as_ptr().cast::<T>();
        // Position of the row's first column; other columns are reached via
        // `col_offset` from this base.
        let offset = linear_index::<GROUP, PACK>(i, 0, self.ncols);

        // SAFETY: `i < self.nrows` (debug-asserted) guarantees the offset is within
        // the backing allocation. Same reasoning as `get_row`.
        let row_base = unsafe { base_ptr.add(offset) };

        RowMut {
            // SAFETY: `row_base` is derived from a `NonNull<u8>` with a valid offset,
            // so it is non-null. The lifetime is tied to the caller's `'a`.
            base: unsafe { MutSlicePtr::new_unchecked(NonNull::new_unchecked(row_base)) },
            ncols: self.ncols,
        }
    }
}
560
// SAFETY: Memory is deallocated by reconstructing the `Box<[T]>` that was created during
// `NewOwned`.
unsafe impl<T: Copy, const GROUP: usize, const PACK: usize> ReprOwned
    for BlockTransposedRepr<T, GROUP, PACK>
{
    unsafe fn drop(self, ptr: NonNull<u8>) {
        // SAFETY: `ptr` was obtained from `Box::into_raw` with length `self.storage_len()`.
        unsafe {
            let slice_ptr =
                std::ptr::slice_from_raw_parts_mut(ptr.cast::<T>().as_ptr(), self.storage_len());
            // Rebuilding the box and letting it drop frees the allocation.
            // `T: Copy` means there are no element destructors to run.
            let _ = Box::from_raw(slice_ptr);
        }
    }
}
575
576// ════════════════════════════════════════════════════════════════════
577// Constructors
578// ════════════════════════════════════════════════════════════════════
579
580// SAFETY: The returned `Mat` contains a `Box` with exactly `self.storage_len()` elements.
581unsafe impl<T: Copy, const GROUP: usize, const PACK: usize> NewOwned<T>
582    for BlockTransposedRepr<T, GROUP, PACK>
583{
584    type Error = crate::error::Infallible;
585
586    fn new_owned(self, value: T) -> Result<Mat<Self>, Self::Error> {
587        let b: Box<[T]> = vec![value; self.storage_len()].into_boxed_slice();
588
589        // SAFETY: By construction, `b.len() == self.storage_len()`.
590        Ok(unsafe { self.box_to_mat(b) })
591    }
592}
593
594// SAFETY: This safely re-uses `<Self as NewOwned<T>>`.
595unsafe impl<T: Copy + Default, const GROUP: usize, const PACK: usize> NewOwned<Defaulted>
596    for BlockTransposedRepr<T, GROUP, PACK>
597{
598    type Error = crate::error::Infallible;
599
600    fn new_owned(self, _: Defaulted) -> Result<Mat<Self>, Self::Error> {
601        self.new_owned(T::default())
602    }
603}
604
605// SAFETY: This checks slice length against storage_len.
606unsafe impl<T: Copy, const GROUP: usize, const PACK: usize> NewRef<T>
607    for BlockTransposedRepr<T, GROUP, PACK>
608{
609    type Error = SliceError;
610
611    fn new_ref(self, data: &[T]) -> Result<MatRef<'_, Self>, Self::Error> {
612        self.check_slice(data)?;
613
614        // SAFETY: `check_slice` verified the length.
615        Ok(unsafe { MatRef::from_raw_parts(self, utils::as_nonnull(data).cast::<u8>()) })
616    }
617}
618
619// SAFETY: This checks slice length against storage_len.
620unsafe impl<T: Copy, const GROUP: usize, const PACK: usize> NewMut<T>
621    for BlockTransposedRepr<T, GROUP, PACK>
622{
623    type Error = SliceError;
624
625    fn new_mut(self, data: &mut [T]) -> Result<MatMut<'_, Self>, Self::Error> {
626        self.check_slice(data)?;
627
628        // SAFETY: `check_slice` verified the length.
629        Ok(unsafe { MatMut::from_raw_parts(self, utils::as_nonnull_mut(data).cast::<u8>()) })
630    }
631}
632
633// ════════════════════════════════════════════════════════════════════
634// Delegation macro
635// ════════════════════════════════════════════════════════════════════
636
/// Generates a forwarding method that delegates to `self.as_view().$name(...)`.
///
/// The generated doc-comment links back to the canonical implementation on
/// [`BlockTransposedRef`], so documentation stays in sync automatically.
///
/// Two arms: one for safe methods and one for `unsafe` methods (the latter keeps
/// the `unsafe` qualifier and wraps the forwarded call in an `unsafe` block).
/// `$(, $a:ident: $t:ty)*` captures zero or more extra by-value parameters.
macro_rules! delegate_to_ref {
    // Safe function.
    ($(#[$m:meta])* $vis:vis fn $name:ident(&self $(, $a:ident: $t:ty)*) $(-> $r:ty)?) => {
        #[doc = concat!("See [`BlockTransposedRef::", stringify!($name), "`].")]
        $(#[$m])*
        #[inline]
        $vis fn $name(&self $(, $a: $t)*) $(-> $r)? {
            self.as_view().$name($($a),*)
        }
    };
    // Unsafe function.
    ($(#[$m:meta])* unsafe $vis:vis fn $name:ident(&self $(, $a:ident: $t:ty)*) $(-> $r:ty)?) => {
        #[doc = concat!("See [`BlockTransposedRef::", stringify!($name), "`].")]
        $(#[$m])*
        #[inline]
        $vis unsafe fn $name(&self $(, $a: $t)*) $(-> $r)? {
            // SAFETY: Caller upholds the safety contract of the delegated method.
            unsafe { self.as_view().$name($($a),*) }
        }
    };
}
662
663// ════════════════════════════════════════════════════════════════════
664// Public wrapper types
665// ════════════════════════════════════════════════════════════════════
666
667/// An owning block-transposed matrix.
668///
669/// Wraps an owned allocation of `T` elements laid out in block-transposed order.
670/// See the [module-level documentation](self) for layout details.
671///
672/// For shared and mutable views, see [`BlockTransposedRef`] and [`BlockTransposedMut`].
673///
674/// # Row Types
675///
676/// Because rows are not contiguous in memory, the row types are view structs:
677///
678/// - [`Row`] — a `Copy` handle supporting `Index<usize>` and `.iter()`.
679/// - [`RowMut`] — a mutable handle supporting `IndexMut<usize>`.
#[derive(Debug)]
pub struct BlockTransposed<T: Copy, const GROUP: usize, const PACK: usize = 1> {
    // Owned backing storage plus the block-transposed layout descriptor.
    data: Mat<BlockTransposedRepr<T, GROUP, PACK>>,
}
684
685/// A shared (immutable) view of a block-transposed matrix.
686///
687/// Created by [`BlockTransposed::as_view`].
#[derive(Debug, Clone, Copy)]
pub struct BlockTransposedRef<'a, T: Copy, const GROUP: usize, const PACK: usize = 1> {
    // Shared borrow of the backing storage plus the layout descriptor.
    data: MatRef<'a, BlockTransposedRepr<T, GROUP, PACK>>,
}
692
693/// A mutable view of a block-transposed matrix.
694///
695/// Created by [`BlockTransposed::as_view_mut`].
pub struct BlockTransposedMut<'a, T: Copy, const GROUP: usize, const PACK: usize = 1> {
    // Exclusive borrow of the backing storage plus the layout descriptor.
    data: MatMut<'a, BlockTransposedRepr<T, GROUP, PACK>>,
}
699
700// ── BlockTransposedRef (core read implementations) ───────────────
701
702impl<'a, T: Copy, const GROUP: usize, const PACK: usize> BlockTransposedRef<'a, T, GROUP, PACK> {
703    /// Returns the number of logical rows.
704    #[inline]
705    pub fn nrows(&self) -> usize {
706        self.data.repr().nrows()
707    }
708
709    /// Returns the number of logical columns (dimensionality).
710    #[inline]
711    pub fn ncols(&self) -> usize {
712        self.data.repr().ncols()
713    }
714
715    /// Returns the number of physical (padded) columns.
716    #[inline]
717    pub fn padded_ncols(&self) -> usize {
718        self.data.repr().padded_ncols()
719    }
720
721    /// Group size (blocking factor `GROUP`).
722    pub const fn group_size(&self) -> usize {
723        GROUP
724    }
725
726    /// Group size (blocking factor `GROUP`) as a `const` function on the *type*.
727    pub const fn const_group_size() -> usize {
728        GROUP
729    }
730
731    /// Packing factor `PACK`.
732    pub const fn pack_size(&self) -> usize {
733        PACK
734    }
735
736    /// Number of completely full blocks.
737    #[inline]
738    pub fn full_blocks(&self) -> usize {
739        self.data.repr().full_blocks()
740    }
741
742    /// Total number of blocks including any partially-filled tail.
743    #[inline]
744    pub fn num_blocks(&self) -> usize {
745        self.data.repr().num_blocks()
746    }
747
748    /// Number of valid elements in the last partially-full block, or 0 if all
749    /// blocks are full.
750    #[inline]
751    pub fn remainder(&self) -> usize {
752        self.data.repr().remainder()
753    }
754
755    /// Total number of logical rows rounded up to the next multiple of `GROUP`.
756    ///
757    /// This is the number of "available" row slots in the backing allocation,
758    /// including zero-padded rows in the last (possibly partial) block.
759    #[inline]
760    pub fn padded_nrows(&self) -> usize {
761        self.data.repr().padded_nrows()
762    }
763
764    /// Return a raw typed pointer to the start of the backing data.
765    #[inline]
766    pub fn as_ptr(&self) -> *const T {
767        self.data.as_raw_ptr().cast::<T>()
768    }
769
770    /// Return the backing data as a shared slice.
771    ///
772    /// The returned slice has `storage_len()` elements — this includes all padding
773    /// for partial blocks and column-group alignment.
774    #[inline]
775    pub fn as_slice(&self) -> &'a [T] {
776        let len = self.data.repr().storage_len();
777        // SAFETY: The backing allocation has exactly `storage_len()` elements of type T.
778        unsafe { std::slice::from_raw_parts(self.as_ptr(), len) }
779    }
780
781    /// Return a pointer to the start of the given block.
782    ///
783    /// The caller may assume that for the returned pointer `ptr`,
784    /// `[ptr, ptr + GROUP * padded_ncols)` points to valid memory, even for the
785    /// remainder block.
786    ///
787    /// # Safety
788    ///
789    /// `block` must be less than `self.num_blocks()`. No bounds check is
790    /// performed in release builds; callers must verify the index themselves
791    /// (e.g. by iterating `0..self.num_blocks()`).
792    #[inline]
793    pub unsafe fn block_ptr_unchecked(&self, block: usize) -> *const T {
794        debug_assert!(block < self.num_blocks());
795        // SAFETY: Caller asserts `block < self.num_blocks()`.
796        unsafe { self.as_ptr().add(self.data.repr().block_offset(block)) }
797    }
798
799    /// Return a view over a full block as a [`MatrixView`].
800    ///
801    /// The returned view has `padded_ncols / PACK` rows and `GROUP * PACK`
802    /// columns. For `PACK == 1` this simplifies to `ncols` rows and `GROUP`
803    /// columns (the standard transposed interpretation).
804    ///
805    /// # Panics
806    ///
807    /// Panics if `block >= self.full_blocks()`.
808    #[allow(clippy::expect_used)]
809    pub fn block(&self, block: usize) -> MatrixView<'a, T> {
810        assert!(block < self.full_blocks());
811        let offset = self.data.repr().block_offset(block);
812        let stride = self.data.repr().block_stride();
813        // SAFETY: `block < full_blocks()` (asserted above) guarantees
814        // `offset + stride` is within the backing allocation.
815        let data: &[T] = unsafe { std::slice::from_raw_parts(self.as_ptr().add(offset), stride) };
816        MatrixView::try_from(data, self.padded_ncols() / PACK, GROUP * PACK)
817            .expect("base data should have been sized correctly")
818    }
819
820    /// Return a view over the remainder block, or `None` if there is no
821    /// remainder.
822    ///
823    /// The returned view has the same dimensions as [`block()`](Self::block):
824    /// `padded_ncols / PACK` rows and `GROUP * PACK` columns.
825    #[allow(clippy::expect_used)]
826    pub fn remainder_block(&self) -> Option<MatrixView<'a, T>> {
827        if self.remainder() == 0 {
828            None
829        } else {
830            let offset = self.data.repr().block_offset(self.full_blocks());
831            let stride = self.data.repr().block_stride();
832            // SAFETY: The remainder block exists (`remainder() != 0`),
833            // so `offset + stride` is within the backing allocation.
834            let data: &[T] =
835                unsafe { std::slice::from_raw_parts(self.as_ptr().add(offset), stride) };
836            Some(
837                MatrixView::try_from(data, self.padded_ncols() / PACK, GROUP * PACK)
838                    .expect("base data should have been sized correctly"),
839            )
840        }
841    }
842
843    /// Retrieve the value at the logical `(row, col)`.
844    ///
845    /// # Panics
846    ///
847    /// Panics if `row >= self.nrows()` or `col >= self.ncols()`.
848    #[inline]
849    pub fn get_element(&self, row: usize, col: usize) -> T {
850        assert!(
851            row < self.nrows(),
852            "row {row} out of bounds (nrows = {})",
853            self.nrows()
854        );
855        assert!(
856            col < self.ncols(),
857            "col {col} out of bounds (ncols = {})",
858            self.ncols()
859        );
860        let idx = linear_index::<GROUP, PACK>(row, col, self.ncols());
861        // SAFETY: bounds checked above.
862        unsafe { *self.as_ptr().add(idx) }
863    }
864
    /// Get an immutable row view, or `None` if `i` is out of bounds.
    ///
    /// Valid indices are `0..self.nrows()`; the returned [`Row`] presents the
    /// logical row `i` in column order, hiding the block-transposed layout.
    #[inline]
    pub fn get_row(&self, i: usize) -> Option<Row<'_, T, GROUP, PACK>> {
        self.data.get_row(i)
    }
870}
871
872// ── BlockTransposedMut ───────────────────────────────────────────
873
impl<'a, T: Copy, const GROUP: usize, const PACK: usize> BlockTransposedMut<'a, T, GROUP, PACK> {
    /// Borrow as an immutable [`BlockTransposedRef`].
    #[inline]
    pub fn as_view(&self) -> BlockTransposedRef<'_, T, GROUP, PACK> {
        BlockTransposedRef {
            data: self.data.as_view(),
        }
    }

    // ── Delegated read methods ───────────────────────────────────
    //
    // All read-only queries are forwarded to `BlockTransposedRef` so there
    // is a single source of truth for their behavior.

    delegate_to_ref!(pub fn nrows(&self) -> usize);
    delegate_to_ref!(pub fn ncols(&self) -> usize);
    delegate_to_ref!(pub fn padded_ncols(&self) -> usize);
    delegate_to_ref!(pub fn full_blocks(&self) -> usize);
    delegate_to_ref!(pub fn num_blocks(&self) -> usize);
    delegate_to_ref!(pub fn remainder(&self) -> usize);
    delegate_to_ref!(pub fn padded_nrows(&self) -> usize);
    delegate_to_ref!(pub fn as_ptr(&self) -> *const T);
    delegate_to_ref!(pub fn as_slice(&self) -> &[T]);
    delegate_to_ref!(#[allow(clippy::missing_safety_doc)] unsafe pub fn block_ptr_unchecked(&self, block: usize) -> *const T);
    delegate_to_ref!(#[allow(clippy::expect_used)] pub fn block(&self, block: usize) -> MatrixView<'_, T>);
    delegate_to_ref!(#[allow(clippy::expect_used)] pub fn remainder_block(&self) -> Option<MatrixView<'_, T>>);
    delegate_to_ref!(pub fn get_element(&self, row: usize, col: usize) -> T);

    /// Group size (blocking factor `GROUP`).
    pub const fn group_size(&self) -> usize {
        GROUP
    }

    /// Group size as `const` function on the *type*.
    pub const fn const_group_size() -> usize {
        GROUP
    }

    /// Packing factor `PACK`.
    pub const fn pack_size(&self) -> usize {
        PACK
    }

    /// Get an immutable row view, or `None` if `i` is out of bounds.
    #[inline]
    pub fn get_row(&self, i: usize) -> Option<Row<'_, T, GROUP, PACK>> {
        self.data.get_row(i)
    }

    // ── Mutable methods ──────────────────────────────────────────
    //
    // The `_inner` variants consume `self` by value so that the lifetime of
    // the returned view is tied to `'a` (the underlying allocation), not to
    // a temporary reborrow. Public `&mut self` methods reborrow into a
    // short-lived `BlockTransposedMut` and then call the inner variant.

    /// Return the backing data as a mutable slice.
    ///
    /// The returned slice has `storage_len()` elements (including all padding).
    #[inline]
    pub fn as_mut_slice(&mut self) -> &mut [T] {
        self.reborrow_mut().mut_slice_inner()
    }

    /// By-value core of [`as_mut_slice`](Self::as_mut_slice); consuming
    /// `self` lets the returned slice borrow for the full lifetime `'a`.
    fn mut_slice_inner(mut self) -> &'a mut [T] {
        let len = self.data.repr().storage_len();
        // SAFETY: We own exclusive access through `self`.
        unsafe { std::slice::from_raw_parts_mut(self.data.as_raw_mut_ptr().cast::<T>(), len) }
    }

    /// Return a mutable view over a full block.
    ///
    /// # Panics
    ///
    /// Panics if `block >= self.full_blocks()`.
    #[allow(clippy::expect_used)]
    pub fn block_mut(&mut self, block: usize) -> MutMatrixView<'_, T> {
        self.reborrow_mut().block_mut_inner(block)
    }

    /// By-value core of [`block_mut`](Self::block_mut); ties the returned
    /// view's lifetime to `'a` rather than to a `&mut self` borrow.
    #[allow(clippy::expect_used)]
    fn block_mut_inner(mut self, block: usize) -> MutMatrixView<'a, T> {
        // Copy the repr out first so the raw-pointer borrow below does not
        // overlap a shared borrow of `self.data`.
        let repr = *self.data.repr();
        assert!(block < repr.full_blocks());
        let offset = repr.block_offset(block);
        let stride = repr.block_stride();
        let pncols = repr.padded_ncols();
        // SAFETY: `block < full_blocks()`, so the range is within the allocation.
        let data: &mut [T] = unsafe {
            std::slice::from_raw_parts_mut(
                self.data.as_raw_mut_ptr().cast::<T>().add(offset),
                stride,
            )
        };
        MutMatrixView::try_from(data, pncols / PACK, GROUP * PACK)
            .expect("base data should have been sized correctly")
    }

    /// Return a mutable view over the remainder block, or `None` if there is no
    /// remainder.
    #[allow(clippy::expect_used)]
    pub fn remainder_block_mut(&mut self) -> Option<MutMatrixView<'_, T>> {
        self.reborrow_mut().remainder_block_mut_inner()
    }

    /// By-value core of [`remainder_block_mut`](Self::remainder_block_mut);
    /// the remainder panel sits immediately after the last full block.
    #[allow(clippy::expect_used)]
    fn remainder_block_mut_inner(mut self) -> Option<MutMatrixView<'a, T>> {
        let repr = *self.data.repr();
        if repr.remainder() == 0 {
            None
        } else {
            let offset = repr.block_offset(repr.full_blocks());
            let stride = repr.block_stride();
            let pncols = repr.padded_ncols();
            // SAFETY: Remainder block exists, so the range is within the allocation.
            let data: &mut [T] = unsafe {
                std::slice::from_raw_parts_mut(
                    self.data.as_raw_mut_ptr().cast::<T>().add(offset),
                    stride,
                )
            };
            Some(
                MutMatrixView::try_from(data, pncols / PACK, GROUP * PACK)
                    .expect("base data should have been sized correctly"),
            )
        }
    }

    /// Get a mutable row view, or `None` if `i` is out of bounds.
    #[inline]
    pub fn get_row_mut(&mut self, i: usize) -> Option<RowMut<'_, T, GROUP, PACK>> {
        self.data.get_row_mut(i)
    }

    // ── Private helpers ──────────────────────────────────────────

    /// Reborrow `self` as a shorter-lived mutable view, so the by-value
    /// `_inner` methods can consume it without giving up `self`.
    fn reborrow_mut(&mut self) -> BlockTransposedMut<'_, T, GROUP, PACK> {
        BlockTransposedMut {
            data: self.data.reborrow_mut(),
        }
    }
}
1013
1014// ── BlockTransposed (owned) ──────────────────────────────────────
1015
impl<T: Copy, const GROUP: usize, const PACK: usize> BlockTransposed<T, GROUP, PACK> {
    /// Borrow as an immutable [`BlockTransposedRef`].
    pub fn as_view(&self) -> BlockTransposedRef<'_, T, GROUP, PACK> {
        BlockTransposedRef {
            data: self.data.as_view(),
        }
    }

    /// Borrow as a mutable [`BlockTransposedMut`].
    pub fn as_view_mut(&mut self) -> BlockTransposedMut<'_, T, GROUP, PACK> {
        BlockTransposedMut {
            data: self.data.as_view_mut(),
        }
    }

    // ── Delegated read methods ───────────────────────────────────
    //
    // All read-only queries are forwarded to `BlockTransposedRef` so there
    // is a single source of truth for their behavior.

    delegate_to_ref!(pub fn nrows(&self) -> usize);
    delegate_to_ref!(pub fn ncols(&self) -> usize);
    delegate_to_ref!(pub fn padded_ncols(&self) -> usize);
    delegate_to_ref!(pub fn full_blocks(&self) -> usize);
    delegate_to_ref!(pub fn num_blocks(&self) -> usize);
    delegate_to_ref!(pub fn remainder(&self) -> usize);
    delegate_to_ref!(pub fn padded_nrows(&self) -> usize);
    delegate_to_ref!(pub fn as_ptr(&self) -> *const T);
    delegate_to_ref!(pub fn as_slice(&self) -> &[T]);
    delegate_to_ref!(#[allow(clippy::missing_safety_doc)] unsafe pub fn block_ptr_unchecked(&self, block: usize) -> *const T);
    delegate_to_ref!(#[allow(clippy::expect_used)] pub fn block(&self, block: usize) -> MatrixView<'_, T>);
    delegate_to_ref!(#[allow(clippy::expect_used)] pub fn remainder_block(&self) -> Option<MatrixView<'_, T>>);
    delegate_to_ref!(pub fn get_element(&self, row: usize, col: usize) -> T);

    /// Group size (blocking factor `GROUP`).
    pub const fn group_size(&self) -> usize {
        GROUP
    }

    /// Group size (blocking factor `GROUP`) as a `const` function on the *type*.
    pub const fn const_group_size() -> usize {
        GROUP
    }

    /// Packing factor `PACK`.
    pub const fn pack_size(&self) -> usize {
        PACK
    }

    /// Get an immutable row view, or `None` if `i` is out of bounds.
    #[inline]
    pub fn get_row(&self, i: usize) -> Option<Row<'_, T, GROUP, PACK>> {
        self.data.get_row(i)
    }

    // ── Mutable methods (delegated to BlockTransposedMut) ────────
    //
    // Each method creates a short-lived mutable view and calls the
    // lifetime-preserving `_inner` variant on it.

    /// See [`BlockTransposedMut::as_mut_slice`].
    #[inline]
    pub fn as_mut_slice(&mut self) -> &mut [T] {
        self.as_view_mut().mut_slice_inner()
    }

    /// See [`BlockTransposedMut::block_mut`].
    #[allow(clippy::expect_used)]
    pub fn block_mut(&mut self, block: usize) -> MutMatrixView<'_, T> {
        self.as_view_mut().block_mut_inner(block)
    }

    /// See [`BlockTransposedMut::remainder_block_mut`].
    #[allow(clippy::expect_used)]
    pub fn remainder_block_mut(&mut self) -> Option<MutMatrixView<'_, T>> {
        self.as_view_mut().remainder_block_mut_inner()
    }

    /// Get a mutable row view, or `None` if `i` is out of bounds.
    #[inline]
    pub fn get_row_mut(&mut self, i: usize) -> Option<RowMut<'_, T, GROUP, PACK>> {
        self.data.get_row_mut(i)
    }
}
1094
1095// ── Reborrow ─────────────────────────────────────────────────────
1096
impl<'this, T: Copy, const GROUP: usize, const PACK: usize> Reborrow<'this>
    for BlockTransposed<T, GROUP, PACK>
{
    type Target = BlockTransposedRef<'this, T, GROUP, PACK>;

    /// Reborrow the owned matrix as a shared [`BlockTransposedRef`] view.
    #[inline]
    fn reborrow(&'this self) -> Self::Target {
        self.as_view()
    }
}
1107
1108// ── Factory methods ──────────────────────────────────────────────
1109
impl<T: Copy + Default, const GROUP: usize, const PACK: usize> BlockTransposed<T, GROUP, PACK> {
    /// Construct a default-initialized block-transposed matrix from dimensions.
    ///
    /// # Panics
    ///
    /// Panics if the dimensions overflow the allocation budget.
    #[allow(clippy::expect_used)]
    pub fn new(nrows: usize, ncols: usize) -> Self {
        let repr = BlockTransposedRepr::<T, GROUP, PACK>::new(nrows, ncols)
            .expect("dimensions should not overflow");
        Self {
            data: Mat::new(repr, Defaulted).expect("infallible"),
        }
    }

    /// Fallible variant of [`new`](Self::new).
    pub fn try_new(nrows: usize, ncols: usize) -> Result<Self, Overflow> {
        let repr = BlockTransposedRepr::<T, GROUP, PACK>::new(nrows, ncols)?;
        Ok(Self {
            data: Mat::new(repr, Defaulted).expect("infallible"),
        })
    }

    /// Construct a block-transposed matrix by copying data from a [`StridedView`].
    ///
    /// Each source element at `(row, col)` is placed at the correct offset in the
    /// block-transposed layout. Padding positions (both partial-block rows and
    /// column-group padding when `ncols % PACK != 0`) are filled with
    /// `T::default()`.
    ///
    /// The loop iterates in physical (block-transposed) order — block, column-group,
    /// row-within-block, pack-lane — so that writes to the backing allocation are
    /// sequential. Source reads stride across rows of the [`StridedView`], which is
    /// acceptable because read-side prefetch is more effective than write-side.
    pub fn from_strided(v: StridedView<'_, T>) -> Self {
        let nrows = v.nrows();
        let ncols = v.ncols();
        // `Self::new` default-initializes the whole allocation, so any
        // position the loop below skips already holds `T::default()`.
        let mut mat = Self::new(nrows, ncols);

        let repr = *mat.data.repr();
        // `num_blocks` includes the partial remainder block, if any.
        let num_blocks = repr.num_blocks();
        let pncols = repr.padded_ncols();
        let num_col_groups = pncols / PACK;

        // Walk the backing allocation in physical order so that writes are
        // sequential. The allocation is default-initialized, so padding positions
        // already hold `T::default()` and can be skipped.
        let mut dst = mat.data.as_raw_mut_ptr().cast::<T>();
        for block in 0..num_blocks {
            // First logical row covered by this block's panel.
            let row_base = block * GROUP;
            for cg in 0..num_col_groups {
                // First logical column in this group of PACK adjacent columns.
                let col_base = cg * PACK;
                for rib in 0..GROUP {
                    let row = row_base + rib;
                    if row < nrows {
                        // SAFETY: row < nrows is checked by the enclosing `if` condition.
                        let src_row = unsafe { v.get_row_unchecked(row) };
                        for p in 0..PACK {
                            let col = col_base + p;
                            if col < ncols {
                                // SAFETY: dst advances sequentially through the
                                // backing allocation which has exactly `storage_len`
                                // elements, and our loop visits each position once.
                                // `col < ncols` is checked above, and `src_row` has
                                // exactly `ncols` elements.
                                unsafe { *dst = *src_row.get_unchecked(col) };
                            }
                            // SAFETY: dst advances sequentially through the
                            // backing allocation which has exactly `storage_len`
                            // elements, and our loop visits each position once.
                            dst = unsafe { dst.add(1) };
                        }
                    } else {
                        // SAFETY: Entire row is padding — skip PACK positions.
                        // dst remains within the allocation.
                        dst = unsafe { dst.add(PACK) };
                    }
                }
            }
        }

        mat
    }

    /// Construct a block-transposed matrix by copying data from a [`MatrixView`].
    pub fn from_matrix_view(v: MatrixView<'_, T>) -> Self {
        Self::from_strided(v.into())
    }
}
1199
1200// ════════════════════════════════════════════════════════════════════
1201// Index<(usize, usize)> for BlockTransposed
1202// ════════════════════════════════════════════════════════════════════
1203
1204impl<T: Copy, const GROUP: usize, const PACK: usize> std::ops::Index<(usize, usize)>
1205    for BlockTransposed<T, GROUP, PACK>
1206{
1207    type Output = T;
1208
1209    #[inline]
1210    fn index(&self, (row, col): (usize, usize)) -> &Self::Output {
1211        assert!(row < self.nrows());
1212        assert!(col < self.ncols());
1213        let idx = linear_index::<GROUP, PACK>(row, col, self.ncols());
1214        // SAFETY: bounds checked above and the backing allocation has `storage_len()` elements.
1215        unsafe { &*self.as_ptr().add(idx) }
1216    }
1217}
1218
1219// ════════════════════════════════════════════════════════════════════
1220// Tests
1221// ════════════════════════════════════════════════════════════════════
1222
1223#[cfg(test)]
1224mod tests {
1225    //! Test organisation:
1226    //!
1227    //!  1. **Helper functions** — `gen_*` element generators.
1228    //!  2. [`test_full_api`] — single parameterized function that exhaustively
1229    //!     exercises the full read + write API on all three wrapper types
1230    //!     (`BlockTransposed`, `BlockTransposedRef`, `BlockTransposedMut`).
1231    //!  3. **Test runners** — `#[test]` functions that call `test_full_api`
1232    //!     with various `(T, GROUP, PACK, nrows, ncols)` combinations.
1233    //!  4. [`test_block_layout_pack1`] — verifies that `PACK=1` blocks are
1234    //!     the standard row-to-column transposition.
1235    //!  5. **Focused tests** — edge cases that cannot be expressed as
1236    //!     parameters to `test_full_api` (`Send`/`Sync`, panic paths,
1237    //!     non-unit strides, concurrent mutation, etc.).
1238
1239    use diskann_utils::{lazy_format, views::Matrix};
1240
1241    use super::*;
1242    use crate::utils::div_round_up;
1243
1244    // ── Per-type element generators ──────────────────────────────────
1245    //
1246    // Each generator maps a flat index to a non-zero `T` value so that
1247    // `T::default()` (zero) can be used unambiguously to verify padding.
1248
1249    fn gen_f32(i: usize) -> f32 {
1250        (i + 1) as f32
1251    }
1252    fn gen_i32(i: usize) -> i32 {
1253        (i + 1) as i32
1254    }
1255    fn gen_u8(i: usize) -> u8 {
1256        ((i % 255) + 1) as u8
1257    }
1258
1259    // ── Unified parameterized test ──────────────────────────────────
1260
1261    /// Exhaustive test for the full `BlockTransposed` / `BlockTransposedRef` /
1262    /// `BlockTransposedMut` API surface, parameterized over element type `T`,
1263    /// group size `GROUP`, and packing factor `PACK`.
1264    ///
1265    /// Exercises: construction, query helpers, `Index` / `get_element`,
1266    /// immutable row views (`Row`), mutable row views (`RowMut`),
1267    /// `as_slice` / `as_mut_slice`, block views (immutable and mutable),
1268    /// `remainder_block` / `remainder_block_mut`, `block_ptr_unchecked`,
1269    /// `from_matrix_view`, OOB `get_row` returns, and both column and row
1270    /// padding verification.
1271    fn test_full_api<
1272        T: Copy + Default + PartialEq + std::fmt::Debug + 'static,
1273        const GROUP: usize,
1274        const PACK: usize,
1275    >(
1276        nrows: usize,
1277        ncols: usize,
1278        gen_element: fn(usize) -> T,
1279    ) {
1280        let context = lazy_format!(
1281            "T={}, GROUP={}, PACK={}, nrows={}, ncols={}",
1282            std::any::type_name::<T>(),
1283            GROUP,
1284            PACK,
1285            nrows,
1286            ncols,
1287        );
1288
1289        // ── Construction ─────────────────────────────────────────
1290
1291        let mut data = Matrix::new(T::default(), nrows, ncols);
1292        data.as_mut_slice()
1293            .iter_mut()
1294            .enumerate()
1295            .for_each(|(i, d)| *d = gen_element(i));
1296
1297        let mut transpose = BlockTransposed::<T, GROUP, PACK>::from_strided(data.as_view().into());
1298
1299        let expected_padded = div_round_up(ncols, PACK) * PACK;
1300        let expected_remainder = nrows % GROUP;
1301        let storage_len = transpose.as_slice().len();
1302
1303        // ── Query methods on owned type ──────────────────────────
1304
1305        assert_eq!(transpose.nrows(), nrows, "{}", context);
1306        assert_eq!(transpose.ncols(), ncols, "{}", context);
1307        assert_eq!(transpose.group_size(), GROUP, "{}", context);
1308        assert_eq!(
1309            BlockTransposed::<T, GROUP, PACK>::const_group_size(),
1310            GROUP,
1311            "{}",
1312            context
1313        );
1314        assert_eq!(transpose.pack_size(), PACK, "{}", context);
1315        assert_eq!(transpose.full_blocks(), nrows / GROUP, "{}", context);
1316        assert_eq!(
1317            transpose.num_blocks(),
1318            div_round_up(nrows, GROUP),
1319            "{}",
1320            context,
1321        );
1322        assert_eq!(transpose.remainder(), expected_remainder, "{}", context);
1323        assert_eq!(transpose.padded_ncols(), expected_padded, "{}", context);
1324
1325        // ── Element access (owned) ───────────────────────────────
1326
1327        for row in 0..nrows {
1328            for col in 0..ncols {
1329                assert_eq!(
1330                    data[(row, col)],
1331                    transpose[(row, col)],
1332                    "Index at ({}, {}) -- {}",
1333                    row,
1334                    col,
1335                    context,
1336                );
1337                assert_eq!(
1338                    data[(row, col)],
1339                    transpose.get_element(row, col),
1340                    "get_element at ({}, {}) -- {}",
1341                    row,
1342                    col,
1343                    context,
1344                );
1345            }
1346        }
1347
1348        // ── Immutable row views (owned) ──────────────────────────
1349
1350        let view = transpose.as_view();
1351        for row in 0..nrows {
1352            let row_view = view.get_row(row).unwrap();
1353            assert_eq!(row_view.len(), ncols, "{}", context);
1354            assert_eq!(row_view.is_empty(), ncols == 0, "{}", context);
1355            for col in 0..ncols {
1356                assert_eq!(
1357                    data[(row, col)],
1358                    row_view[col],
1359                    "row view at ({}, {}) -- {}",
1360                    row,
1361                    col,
1362                    context,
1363                );
1364            }
1365            // Row::get — in-bounds + OOB.
1366            if ncols > 0 {
1367                assert_eq!(row_view.get(0), Some(&data[(row, 0)]), "{}", context);
1368            }
1369            assert_eq!(row_view.get(ncols), None, "{}", context);
1370
1371            // Iterator + ExactSizeIterator.
1372            let iter = row_view.iter();
1373            assert_eq!(iter.len(), ncols, "{}", context);
1374            let (lo, hi) = iter.size_hint();
1375            assert_eq!(lo, ncols, "{}", context);
1376            assert_eq!(hi, Some(ncols), "{}", context);
1377
1378            let collected: Vec<T> = row_view.iter().collect();
1379            assert_eq!(collected.len(), ncols, "{}", context);
1380            for col in 0..ncols {
1381                assert_eq!(data[(row, col)], collected[col], "{}", context);
1382            }
1383        }
1384        // OOB row returns None.
1385        assert!(view.get_row(nrows).is_none(), "{}", context);
1386        let _ = view;
1387
1388        // ── BlockTransposedRef API ───────────────────────────────
1389
1390        {
1391            let view = transpose.as_view();
1392            assert_eq!(view.nrows(), nrows, "{}", context);
1393            assert_eq!(view.ncols(), ncols, "{}", context);
1394            assert_eq!(view.padded_ncols(), expected_padded, "{}", context);
1395            assert_eq!(view.group_size(), GROUP, "{}", context);
1396            assert_eq!(
1397                BlockTransposedRef::<T, GROUP, PACK>::const_group_size(),
1398                GROUP,
1399            );
1400            assert_eq!(view.pack_size(), PACK, "{}", context);
1401            assert_eq!(view.full_blocks(), nrows / GROUP, "{}", context);
1402            assert_eq!(view.num_blocks(), div_round_up(nrows, GROUP), "{}", context,);
1403            assert_eq!(view.remainder(), expected_remainder, "{}", context);
1404            assert_eq!(view.as_ptr(), transpose.as_ptr(), "{}", context);
1405            assert_eq!(view.as_slice(), transpose.as_slice(), "{}", context);
1406
1407            for row in 0..nrows {
1408                for col in 0..ncols {
1409                    assert_eq!(
1410                        data[(row, col)],
1411                        view.get_element(row, col),
1412                        "Ref get_element at ({}, {}) -- {}",
1413                        row,
1414                        col,
1415                        context,
1416                    );
1417                }
1418                let row_view = view.get_row(row).unwrap();
1419                for col in 0..ncols {
1420                    assert_eq!(data[(row, col)], row_view[col], "{}", context);
1421                }
1422            }
1423            assert!(view.get_row(nrows).is_none(), "{}", context);
1424        }
1425
1426        // ── BlockTransposedMut read API ──────────────────────────
1427
1428        let expected_ptr = transpose.as_ptr();
1429        {
1430            let mut_view = transpose.as_view_mut();
1431            assert_eq!(mut_view.nrows(), nrows, "{}", context);
1432            assert_eq!(mut_view.ncols(), ncols, "{}", context);
1433            assert_eq!(mut_view.padded_ncols(), expected_padded, "{}", context);
1434            assert_eq!(mut_view.group_size(), GROUP, "{}", context);
1435            assert_eq!(
1436                BlockTransposedMut::<T, GROUP, PACK>::const_group_size(),
1437                GROUP,
1438            );
1439            assert_eq!(mut_view.pack_size(), PACK, "{}", context);
1440            assert_eq!(mut_view.full_blocks(), nrows / GROUP, "{}", context);
1441            assert_eq!(
1442                mut_view.num_blocks(),
1443                div_round_up(nrows, GROUP),
1444                "{}",
1445                context,
1446            );
1447            assert_eq!(mut_view.remainder(), expected_remainder, "{}", context);
1448            assert_eq!(mut_view.as_ptr(), expected_ptr, "{}", context);
1449            assert_eq!(mut_view.as_slice().len(), storage_len, "{}", context);
1450
1451            for row in 0..nrows {
1452                for col in 0..ncols {
1453                    assert_eq!(
1454                        data[(row, col)],
1455                        mut_view.get_element(row, col),
1456                        "Mut get_element at ({}, {}) -- {}",
1457                        row,
1458                        col,
1459                        context,
1460                    );
1461                }
1462                let row_view = mut_view.get_row(row).unwrap();
1463                for col in 0..ncols {
1464                    assert_eq!(data[(row, col)], row_view[col], "{}", context);
1465                }
1466            }
1467            assert!(mut_view.get_row(nrows).is_none(), "{}", context);
1468        }
1469
1470        // ── BlockTransposedMut::as_view() ────────────────────────
1471
1472        {
1473            let mut_view = transpose.as_view_mut();
1474            let ref_from_mut = mut_view.as_view();
1475            assert_eq!(ref_from_mut.nrows(), nrows, "{}", context);
1476            for row in 0..nrows {
1477                for col in 0..ncols {
1478                    assert_eq!(
1479                        data[(row, col)],
1480                        ref_from_mut.get_element(row, col),
1481                        "{}",
1482                        context,
1483                    );
1484                }
1485            }
1486        }
1487
1488        // ── as_mut_slice ─────────────────────────────────────────
1489
1490        // Through BlockTransposedMut.
1491        {
1492            let mut mut_view = transpose.as_view_mut();
1493            assert_eq!(mut_view.as_mut_slice().len(), storage_len, "{}", context);
1494        }
1495        // Through BlockTransposed (owned).
1496        assert_eq!(transpose.as_mut_slice().len(), storage_len, "{}", context);
1497
1498        // ── Immutable block views on all three types ─────────────
1499
1500        let expected_block_nrows = expected_padded / PACK;
1501        let expected_block_ncols = GROUP * PACK;
1502
1503        for b in 0..transpose.full_blocks() {
1504            let block_data: Vec<T>;
1505            let ptr: *const T;
1506            {
1507                let block = transpose.block(b);
1508                assert_eq!(block.nrows(), expected_block_nrows, "{}", context);
1509                assert_eq!(block.ncols(), expected_block_ncols, "{}", context);
1510
1511                // SAFETY: b < full_blocks <= num_blocks.
1512                ptr = unsafe { transpose.block_ptr_unchecked(b) };
1513                assert_eq!(ptr, block.as_slice().as_ptr(), "{}", context);
1514
1515                block_data = block.as_slice().to_vec();
1516            }
1517
1518            // Same block via Ref.
1519            {
1520                let view = transpose.as_view();
1521                assert_eq!(view.block(b).as_slice(), &block_data[..], "{}", context);
1522                // SAFETY: `b` is in range `0..num_blocks` by the loop bound.
1523                assert_eq!(unsafe { view.block_ptr_unchecked(b) }, ptr, "{}", context);
1524            }
1525
1526            // Same block via Mut (read path).
1527            {
1528                let mut_view = transpose.as_view_mut();
1529                assert_eq!(mut_view.block(b).as_slice(), &block_data[..], "{}", context);
1530                assert_eq!(
1531                    // SAFETY: `b` is in range `0..num_blocks` by the loop bound.
1532                    unsafe { mut_view.block_ptr_unchecked(b) },
1533                    ptr,
1534                    "{}",
1535                    context,
1536                );
1537            }
1538        }
1539
1540        // Remainder block (immutable, all three types).
1541        if expected_remainder != 0 {
1542            let remainder_data: Vec<T>;
1543            let ptr: *const T;
1544            let fb = transpose.full_blocks();
1545            {
1546                let block = transpose.remainder_block().unwrap();
1547                assert_eq!(block.nrows(), expected_block_nrows, "{}", context);
1548                assert_eq!(block.ncols(), expected_block_ncols, "{}", context);
1549
1550                // SAFETY: fb < num_blocks (remainder exists).
1551                ptr = unsafe { transpose.block_ptr_unchecked(fb) };
1552                assert_eq!(ptr, block.as_slice().as_ptr(), "{}", context);
1553
1554                remainder_data = block.as_slice().to_vec();
1555            }
1556
1557            // Via Ref.
1558            {
1559                let view = transpose.as_view();
1560                let ref_block = view.remainder_block().unwrap();
1561                assert_eq!(ref_block.as_slice(), &remainder_data[..], "{}", context);
1562            }
1563            // Via Mut (read path).
1564            {
1565                let mut_view = transpose.as_view_mut();
1566                let mut_block = mut_view.remainder_block().unwrap();
1567                assert_eq!(mut_block.as_slice(), &remainder_data[..], "{}", context);
1568            }
1569        } else {
1570            assert!(transpose.remainder_block().is_none(), "{}", context);
1571            {
1572                let view = transpose.as_view();
1573                assert!(view.remainder_block().is_none(), "{}", context);
1574            }
1575            {
1576                let mut_view = transpose.as_view_mut();
1577                assert!(mut_view.remainder_block().is_none(), "{}", context);
1578            }
1579        }
1580
1581        // ── Mutable block views via BlockTransposedMut ───────────
1582
1583        {
1584            let mut mut_view = transpose.as_view_mut();
1585            for b in 0..mut_view.full_blocks() {
1586                let block_mut = mut_view.block_mut(b);
1587                assert_eq!(block_mut.nrows(), expected_block_nrows, "{}", context);
1588                assert_eq!(block_mut.ncols(), expected_block_ncols, "{}", context);
1589            }
1590            if expected_remainder != 0 {
1591                let rem = mut_view.remainder_block_mut().unwrap();
1592                assert_eq!(rem.nrows(), expected_block_nrows, "{}", context);
1593                assert_eq!(rem.ncols(), expected_block_ncols, "{}", context);
1594            } else {
1595                assert!(mut_view.remainder_block_mut().is_none(), "{}", context);
1596            }
1597        }
1598
1599        // Mutable block views via owned BlockTransposed.
1600        for b in 0..transpose.full_blocks() {
1601            let block_mut = transpose.block_mut(b);
1602            assert_eq!(block_mut.nrows(), expected_block_nrows, "{}", context);
1603            assert_eq!(block_mut.ncols(), expected_block_ncols, "{}", context);
1604        }
1605        if expected_remainder != 0 {
1606            let rem = transpose.remainder_block_mut().unwrap();
1607            assert_eq!(rem.nrows(), expected_block_nrows, "{}", context);
1608            assert_eq!(rem.ncols(), expected_block_ncols, "{}", context);
1609        } else {
1610            assert!(transpose.remainder_block_mut().is_none(), "{}", context);
1611        }
1612
1613        // ── Mutable row views via BlockTransposedMut ─────────────
1614
1615        {
1616            let mut mut_view = transpose.as_view_mut();
1617            for row in 0..nrows {
1618                let row_view = mut_view.get_row_mut(row).unwrap();
1619                assert_eq!(row_view.len(), ncols, "{}", context);
1620                assert_eq!(row_view.is_empty(), ncols == 0, "{}", context);
1621                for col in 0..ncols {
1622                    assert_eq!(data[(row, col)], row_view[col], "{}", context);
1623                }
1624            }
1625            assert!(mut_view.get_row_mut(nrows).is_none(), "{}", context);
1626        }
1627
1628        // ── Row::get, RowMut::get, RowMut::get_mut ──────────────
1629
1630        if nrows > 0 && ncols > 0 {
1631            // Row::get OOB.
1632            {
1633                let view = transpose.as_view();
1634                let row = view.get_row(0).unwrap();
1635                assert_eq!(row.get(ncols), None, "{}", context);
1636                assert_eq!(row.get(usize::MAX), None, "{}", context);
1637            }
1638
1639            // RowMut::get OOB.
1640            let row = transpose.get_row_mut(0).unwrap();
1641            assert_eq!(row.get(ncols), None, "{}", context);
1642
1643            // RowMut::get_mut — mutate and verify.
1644            let mut row = transpose.get_row_mut(0).unwrap();
1645            let sentinel = gen_element(usize::MAX / 2);
1646            let original = row[0];
1647            if let Some(v) = row.get_mut(0) {
1648                *v = sentinel;
1649            }
1650            assert_eq!(row.get_mut(ncols), None, "{}", context);
1651            // Explicit scope end so the mutable borrow is released before the next access.
1652            let _ = row;
1653            assert_eq!(transpose.get_element(0, 0), sentinel, "{}", context);
1654            // Restore original.
1655            transpose.get_row_mut(0).unwrap().set(0, original);
1656        }
1657
1658        // ── Zero out via block_mut / remainder_block_mut ─────────
1659
1660        for b in 0..transpose.full_blocks() {
1661            transpose.block_mut(b).as_mut_slice().fill(T::default());
1662        }
1663        if transpose.remainder() != 0 {
1664            transpose
1665                .remainder_block_mut()
1666                .unwrap()
1667                .as_mut_slice()
1668                .fill(T::default());
1669        }
1670        assert!(
1671            transpose.as_slice().iter().all(|v| *v == T::default()),
1672            "not fully zeroed -- {}",
1673            context,
1674        );
1675
1676        // ── Padding verification (fresh construction) ────────────
1677
1678        let transpose = BlockTransposed::<T, GROUP, PACK>::from_strided(data.as_view().into());
1679        let raw = transpose.as_slice();
1680
1681        // Column padding.
1682        for row in 0..nrows {
1683            for col in ncols..expected_padded {
1684                let idx = linear_index::<GROUP, PACK>(row, col, ncols);
1685                assert_eq!(
1686                    raw[idx],
1687                    T::default(),
1688                    "col padding at ({}, {}) -- {}",
1689                    row,
1690                    col,
1691                    context,
1692                );
1693            }
1694        }
1695
1696        // Row padding (within partial blocks).
1697        let padded_nrows = nrows.next_multiple_of(GROUP);
1698        for row in nrows..padded_nrows {
1699            for col in 0..expected_padded {
1700                let idx = linear_index::<GROUP, PACK>(row, col, ncols);
1701                assert_eq!(
1702                    raw[idx],
1703                    T::default(),
1704                    "row padding at ({}, {}) -- {}",
1705                    row,
1706                    col,
1707                    context,
1708                );
1709            }
1710        }
1711
1712        // ── padded_nrows() returns padded row count ──────────────
1713
1714        assert_eq!(
1715            transpose.as_view().padded_nrows(),
1716            padded_nrows,
1717            "padded_nrows() mismatch -- {}",
1718            context,
1719        );
1720
1721        // ── from_matrix_view produces identical results ──────────
1722
1723        if nrows > 0 && ncols > 0 {
1724            let via_matrix = BlockTransposed::<T, GROUP, PACK>::from_matrix_view(data.as_view());
1725            assert_eq!(via_matrix.as_slice(), transpose.as_slice(), "{}", context);
1726        }
1727    }
1728
1729    // ════════════════════════════════════════════════════════════════
1730    // Test runners — each combination gets the full API surface.
1731    // ════════════════════════════════════════════════════════════════
1732
1733    #[test]
1734    fn test_api_pack1_group16() {
1735        // Miri: boundary rows around GROUP=16 block transitions;
1736        // full run: exhaustive sweep.
1737        let rows: Vec<usize> = if cfg!(miri) {
1738            vec![0, 1, 15, 16, 17, 33]
1739        } else {
1740            (0..128).collect()
1741        };
1742        let cols: Vec<usize> = if cfg!(miri) {
1743            vec![0, 1, 2]
1744        } else {
1745            (0..5).collect()
1746        };
1747        for &nrows in &rows {
1748            for &ncols in &cols {
1749                test_full_api::<f32, 16, 1>(nrows, ncols, gen_f32);
1750            }
1751        }
1752    }
1753
1754    #[test]
1755    fn test_api_pack1_group8() {
1756        // Miri: boundary rows around GROUP=8 block transitions;
1757        // full run: exhaustive sweep.
1758        let rows: Vec<usize> = if cfg!(miri) {
1759            vec![0, 1, 7, 8, 9, 17]
1760        } else {
1761            (0..128).collect()
1762        };
1763        let cols: Vec<usize> = if cfg!(miri) {
1764            vec![0, 1, 2]
1765        } else {
1766            (0..5).collect()
1767        };
1768        for &nrows in &rows {
1769            for &ncols in &cols {
1770                test_full_api::<f32, 8, 1>(nrows, ncols, gen_f32);
1771            }
1772        }
1773    }
1774
1775    #[test]
1776    fn test_api_pack2() {
1777        // Miri: boundary rows around GROUP=4/8/16 transitions;
1778        // cols hit PACK=2 boundary (even/odd). Full run: exhaustive.
1779        let rows: Vec<usize> = if cfg!(miri) {
1780            vec![0, 1, 3, 4, 5, 7, 8, 9, 15, 16, 17]
1781        } else {
1782            (0..48).collect()
1783        };
1784        let cols: Vec<usize> = if cfg!(miri) {
1785            vec![0, 1, 2, 3, 4, 5]
1786        } else {
1787            (0..9).collect()
1788        };
1789        for &nrows in &rows {
1790            for &ncols in &cols {
1791                test_full_api::<f32, 4, 2>(nrows, ncols, gen_f32);
1792                test_full_api::<f32, 8, 2>(nrows, ncols, gen_f32);
1793                test_full_api::<f32, 16, 2>(nrows, ncols, gen_f32);
1794            }
1795        }
1796    }
1797
1798    #[test]
1799    fn test_api_pack4() {
1800        // Miri: boundary rows around GROUP=4/8/16 transitions;
1801        // cols hit PACK=4 boundary (0,1,3,4,5,8). Full run: exhaustive.
1802        let rows: Vec<usize> = if cfg!(miri) {
1803            vec![0, 1, 3, 4, 5, 7, 8, 9, 15, 16, 17]
1804        } else {
1805            (0..48).collect()
1806        };
1807        let cols: Vec<usize> = if cfg!(miri) {
1808            vec![0, 1, 3, 4, 5, 8]
1809        } else {
1810            (0..9).collect()
1811        };
1812        for &nrows in &rows {
1813            for &ncols in &cols {
1814                test_full_api::<f32, 4, 4>(nrows, ncols, gen_f32);
1815                test_full_api::<f32, 8, 4>(nrows, ncols, gen_f32);
1816                test_full_api::<f32, 16, 4>(nrows, ncols, gen_f32);
1817            }
1818        }
1819    }
1820
1821    /// Exercise the unified test with non-`f32` element types.
1822    #[test]
1823    fn test_api_non_f32() {
1824        // i32:  PACK=1 and PACK=2
1825        test_full_api::<i32, 4, 1>(10, 7, gen_i32);
1826        test_full_api::<i32, 8, 2>(12, 5, gen_i32);
1827
1828        // u8:   PACK=1 and PACK=2
1829        test_full_api::<u8, 4, 2>(12, 5, gen_u8);
1830        test_full_api::<u8, 8, 1>(10, 7, gen_u8);
1831    }
1832
1833    // ════════════════════════════════════════════════════════════════
1834    // Block layout verification (PACK=1 only)
1835    // ════════════════════════════════════════════════════════════════
1836
1837    /// Verify that for PACK=1, each block is the standard row-to-column
1838    /// transposition of a GROUP-row slice of the source matrix.
1839    fn test_block_layout_pack1<
1840        T: Copy + Default + PartialEq + std::fmt::Debug + 'static,
1841        const GROUP: usize,
1842    >(
1843        nrows: usize,
1844        ncols: usize,
1845        gen_element: fn(usize) -> T,
1846    ) {
1847        let mut data = Matrix::new(T::default(), nrows, ncols);
1848        data.as_mut_slice()
1849            .iter_mut()
1850            .enumerate()
1851            .for_each(|(i, d)| *d = gen_element(i));
1852
1853        let transpose = BlockTransposed::<T, GROUP, 1>::from_strided(data.as_view().into());
1854
1855        // Full blocks.
1856        for b in 0..transpose.full_blocks() {
1857            let block = transpose.block(b);
1858            for i in 0..block.nrows() {
1859                for j in 0..block.ncols() {
1860                    assert_eq!(
1861                        block[(i, j)],
1862                        data[(GROUP * b + j, i)],
1863                        "block {} at ({}, {}) -- GROUP={}, nrows={}, ncols={}",
1864                        b,
1865                        i,
1866                        j,
1867                        GROUP,
1868                        nrows,
1869                        ncols,
1870                    );
1871                }
1872            }
1873        }
1874
1875        // Remainder block.
1876        if transpose.remainder() != 0 {
1877            let fb = transpose.full_blocks();
1878            let block = transpose.remainder_block().unwrap();
1879            for i in 0..block.nrows() {
1880                for j in 0..transpose.remainder() {
1881                    assert_eq!(
1882                        block[(i, j)],
1883                        data[(GROUP * fb + j, i)],
1884                        "remainder at ({}, {}) -- GROUP={}, nrows={}, ncols={}",
1885                        i,
1886                        j,
1887                        GROUP,
1888                        nrows,
1889                        ncols,
1890                    );
1891                }
1892            }
1893        }
1894    }
1895
1896    #[test]
1897    fn test_block_layout_pack1_group16() {
1898        let rows: Vec<usize> = if cfg!(miri) {
1899            vec![0, 1, 15, 16, 17, 33]
1900        } else {
1901            (0..128).collect()
1902        };
1903        let cols: Vec<usize> = if cfg!(miri) {
1904            vec![0, 1, 2]
1905        } else {
1906            (0..5).collect()
1907        };
1908        for &nrows in &rows {
1909            for &ncols in &cols {
1910                test_block_layout_pack1::<f32, 16>(nrows, ncols, gen_f32);
1911            }
1912        }
1913    }
1914
1915    #[test]
1916    fn test_block_layout_pack1_group8() {
1917        let rows: Vec<usize> = if cfg!(miri) {
1918            vec![0, 1, 7, 8, 9, 17]
1919        } else {
1920            (0..128).collect()
1921        };
1922        let cols: Vec<usize> = if cfg!(miri) {
1923            vec![0, 1, 2]
1924        } else {
1925            (0..5).collect()
1926        };
1927        for &nrows in &rows {
1928            for &ncols in &cols {
1929                test_block_layout_pack1::<f32, 8>(nrows, ncols, gen_f32);
1930            }
1931        }
1932    }
1933
1934    // ════════════════════════════════════════════════════════════════
1935    // Focused tests (not part of the unified parameterized test)
1936    // ════════════════════════════════════════════════════════════════
1937
1938    // ── Send / Sync static assertions ───────────────────────────────
1939
1940    #[test]
1941    fn test_row_view_send_sync() {
1942        fn assert_send<T: Send>() {}
1943        fn assert_sync<T: Sync>() {}
1944
1945        assert_send::<Row<'_, f32, 16>>();
1946        assert_sync::<Row<'_, f32, 16>>();
1947        assert_send::<Row<'_, u8, 8, 2>>();
1948        assert_sync::<Row<'_, u8, 8, 2>>();
1949
1950        assert_send::<RowMut<'_, f32, 16>>();
1951        assert_sync::<RowMut<'_, f32, 16>>();
1952        assert_send::<RowMut<'_, i32, 4, 4>>();
1953        assert_sync::<RowMut<'_, i32, 4, 4>>();
1954    }
1955
1956    // ── NewRef / NewMut from raw slices ─────────────────────────────
1957
1958    #[test]
1959    fn test_new_ref_and_new_mut() {
1960        let nrows = 5;
1961        let ncols = 3;
1962        let repr = BlockTransposedRepr::<f32, 4>::new(nrows, ncols).unwrap();
1963
1964        let mat = BlockTransposed::<f32, 4>::new(nrows, ncols);
1965        let raw: &[f32] = mat.as_slice();
1966
1967        let mat_ref = BlockTransposedRef {
1968            data: repr.new_ref(raw).unwrap(),
1969        };
1970        assert_eq!(mat_ref.nrows(), nrows);
1971        assert_eq!(mat_ref.ncols(), ncols);
1972        for row in 0..nrows {
1973            for col in 0..ncols {
1974                assert_eq!(mat_ref.get_element(row, col), mat.get_element(row, col));
1975            }
1976        }
1977
1978        let mut buf = raw.to_vec();
1979        let mat_mut = BlockTransposedMut {
1980            data: repr.new_mut(&mut buf).unwrap(),
1981        };
1982        assert_eq!(mat_mut.nrows(), nrows);
1983        assert_eq!(mat_mut.ncols(), ncols);
1984
1985        // Wrong-length slice should fail.
1986        let mut short = vec![0.0_f32; 2];
1987        assert!(repr.new_ref(&short).is_err());
1988        assert!(repr.new_mut(&mut short).is_err());
1989    }
1990
1991    // ── Row view edge cases ─────────────────────────────────────────
1992
1993    #[test]
1994    fn test_row_view_empty() {
1995        /// Verify that immutable and mutable empty-row views are sound for a
1996        /// given `GROUP`/`PACK` combination.
1997        fn check_empty<const GROUP: usize, const PACK: usize>() {
1998            let mut mat = BlockTransposed::<f32, GROUP, PACK>::new(4, 0);
1999
2000            // Immutable views.
2001            let view = mat.as_view();
2002            for i in 0..4 {
2003                let row = view.get_row(i).unwrap();
2004                assert!(row.is_empty());
2005                assert_eq!(row.len(), 0);
2006                assert_eq!(row.iter().count(), 0);
2007            }
2008
2009            // Mutable views.
2010            for i in 0..4 {
2011                let row = mat.get_row_mut(i).unwrap();
2012                assert!(row.is_empty());
2013                assert_eq!(row.len(), 0);
2014            }
2015        }
2016
2017        check_empty::<16, 1>(); // default PACK
2018        check_empty::<4, 2>(); // PACK > 1
2019        check_empty::<4, 4>(); // PACK == GROUP
2020    }
2021
2022    // ── Bounds-checking panic tests ─────────────────────────────────
2023
2024    #[test]
2025    #[should_panic(expected = "column index 3 out of bounds")]
2026    fn test_row_view_index_oob() {
2027        let mat = BlockTransposed::<f32, 4>::new(4, 3);
2028        let view = mat.as_view();
2029        let row = view.get_row(0).unwrap();
2030        let _ = row[3];
2031    }
2032
2033    #[test]
2034    #[should_panic(expected = "column index 3 out of bounds")]
2035    fn test_row_view_mut_index_oob() {
2036        let mut mat = BlockTransposed::<f32, 4>::new(4, 3);
2037        let row = mat.get_row_mut(0).unwrap();
2038        let _ = row[3];
2039    }
2040
2041    #[test]
2042    #[should_panic(expected = "column index 3 out of bounds")]
2043    fn test_row_view_mut_index_mut_oob() {
2044        let mut mat = BlockTransposed::<f32, 4>::new(4, 3);
2045        let mut row = mat.get_row_mut(0).unwrap();
2046        row[3] = 1.0;
2047    }
2048
2049    #[test]
2050    #[should_panic(expected = "column index 3 out of bounds")]
2051    fn test_row_view_set_oob() {
2052        let mut mat = BlockTransposed::<f32, 4>::new(4, 3);
2053        let mut row = mat.get_row_mut(0).unwrap();
2054        row.set(3, 1.0);
2055    }
2056
2057    #[test]
2058    #[should_panic(expected = "row 4 out of bounds")]
2059    fn test_get_element_row_oob() {
2060        let mat = BlockTransposed::<f32, 4>::new(4, 3);
2061        mat.get_element(4, 0);
2062    }
2063
2064    #[test]
2065    #[should_panic(expected = "col 3 out of bounds")]
2066    fn test_get_element_col_oob() {
2067        let mat = BlockTransposed::<f32, 4>::new(4, 3);
2068        mat.get_element(0, 3);
2069    }
2070
2071    #[test]
2072    #[should_panic(expected = "assertion failed")]
2073    fn test_index_tuple_row_oob() {
2074        let mat = BlockTransposed::<f32, 4>::new(4, 3);
2075        let _ = mat[(4, 0)];
2076    }
2077
2078    #[test]
2079    #[should_panic(expected = "assertion failed")]
2080    fn test_index_tuple_col_oob() {
2081        let mat = BlockTransposed::<f32, 4>::new(4, 3);
2082        let _ = mat[(0, 3)];
2083    }
2084
2085    #[test]
2086    #[should_panic]
2087    fn test_block_oob() {
2088        let mat = BlockTransposed::<f32, 4>::new(4, 3);
2089        let _ = mat.block(1);
2090    }
2091
2092    #[test]
2093    #[should_panic]
2094    fn test_block_mut_oob() {
2095        let mut mat = BlockTransposed::<f32, 4>::new(4, 3);
2096        let _ = mat.block_mut(1);
2097    }
2098
2099    // ── from_strided with non-unit stride ───────────────────────────
2100
2101    #[test]
2102    fn test_from_strided_nonunit_stride() {
2103        use diskann_utils::strided::StridedView;
2104
2105        const GROUP: usize = 4;
2106        const PACK: usize = 2;
2107        let nrows = 5;
2108        let ncols = 3;
2109        let cstride = 8;
2110
2111        let required_len = (nrows - 1) * cstride + ncols;
2112        let mut flat = vec![0.0_f32; required_len];
2113        for row in 0..nrows {
2114            for col in 0..ncols {
2115                flat[row * cstride + col] = (row * 100 + col + 1) as f32;
2116            }
2117        }
2118
2119        let strided = StridedView::try_shrink_from(&flat, nrows, ncols, cstride)
2120            .expect("should construct strided view");
2121        let transpose = BlockTransposed::<f32, GROUP, PACK>::from_strided(strided);
2122
2123        assert_eq!(transpose.nrows(), nrows);
2124        assert_eq!(transpose.ncols(), ncols);
2125
2126        for row in 0..nrows {
2127            for col in 0..ncols {
2128                let expected = (row * 100 + col + 1) as f32;
2129                assert_eq!(
2130                    transpose[(row, col)],
2131                    expected,
2132                    "mismatch at ({}, {})",
2133                    row,
2134                    col,
2135                );
2136            }
2137        }
2138
2139        let padded_ncols = ncols.next_multiple_of(PACK);
2140        let raw: &[f32] = transpose.as_slice();
2141        for row in 0..nrows {
2142            for col in ncols..padded_ncols {
2143                let idx = linear_index::<GROUP, PACK>(row, col, ncols);
2144                assert_eq!(
2145                    raw[idx], 0.0,
2146                    "column-padding at ({}, {}) should be zero",
2147                    row, col,
2148                );
2149            }
2150        }
2151    }
2152
2153    // ── Concurrent multi-row mutation ───────────────────────────────
2154
    #[test]
    fn test_concurrent_row_mutation() {
        const GROUP: usize = 8;
        const PACK: usize = 2;

        // Smaller problem under Miri to keep the interpreted run tractable.
        let (nrows, ncols, num_threads) = if cfg!(miri) { (8, 4, 2) } else { (64, 16, 4) };

        let mut mat = BlockTransposed::<f32, GROUP, PACK>::new(nrows, ncols);
        // Collect all disjoint mutable row views up front so each can be moved
        // to a different thread without aliasing any other.
        let rows: Vec<RowMut<'_, f32, GROUP, PACK>> = mat.data.rows_mut().collect();
        let rows_per_thread = nrows / num_threads;
        let mut rows = rows.into_boxed_slice();

        std::thread::scope(|s| {
            let mut remaining = &mut rows[..];
            for thread_id in 0..num_threads {
                // The last thread takes whatever is left so every row is
                // covered even if `nrows` is not a multiple of `num_threads`.
                let chunk_len = if thread_id == num_threads - 1 {
                    remaining.len()
                } else {
                    rows_per_thread
                };
                // split_at_mut hands each thread a non-overlapping sub-slice
                // of row views; `remaining` shrinks as chunks are carved off.
                let (chunk, rest) = remaining.split_at_mut(chunk_len);
                remaining = rest;
                let start_row = thread_id * rows_per_thread;

                s.spawn(move || {
                    // Write a value encoding (thread, row, col) so any
                    // cross-thread interference is detectable afterwards.
                    for (offset, row_view) in chunk.iter_mut().enumerate() {
                        let row = start_row + offset;
                        for col in 0..ncols {
                            let value = (thread_id * 10000 + row * 100 + col) as f32;
                            row_view.set(col, value);
                        }
                    }
                });
            }
        });

        // After the scope joins all threads, every element must hold exactly
        // the value its owning thread wrote.
        for row in 0..nrows {
            let thread_id = (row / rows_per_thread).min(num_threads - 1);
            for col in 0..ncols {
                let expected = (thread_id * 10000 + row * 100 + col) as f32;
                assert_eq!(
                    mat.get_element(row, col),
                    expected,
                    "mismatch at ({}, {})",
                    row,
                    col,
                );
            }
        }
    }
2205}