Skip to main content

diskann_quantization/multi_vector/
matrix.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6//! Row-major matrix types for multi-vector representations.
7//!
8//! This module provides flexible matrix abstractions that support different underlying
9//! storage formats through the [`Repr`] trait. The primary types are:
10//!
11//! - [`Mat`]: An owning matrix that manages its own memory.
12//! - [`MatRef`]: An immutable borrowed view of matrix data.
13//! - [`MatMut`]: A mutable borrowed view of matrix data.
14//!
15//! # Representations
16//!
17//! Representation types interact with the [`Mat`] family of types using the following traits:
18//!
19//! - [`Repr`]: Read-only matrix representation.
20//! - [`ReprMut`]: Mutable matrix representation.
21//! - [`ReprOwned`]: Owning matrix representation.
22//!
23//! Each trait refinement has a corresponding constructor:
24//!
25//! - [`NewRef`]: Construct a read-only [`MatRef`] view over a slice.
26//! - [`NewMut`]: Construct a mutable [`MatMut`] matrix view over a slice.
27//! - [`NewOwned`]: Construct a new owning [`Mat`].
28//!
29
30use std::{alloc::Layout, iter::FusedIterator, marker::PhantomData, ptr::NonNull};
31
32use diskann_utils::{Reborrow, ReborrowMut};
33use thiserror::Error;
34
35use crate::utils;
36
37/// Representation trait describing the layout and access patterns for a matrix.
38///
39/// Implementations define how raw bytes are interpreted as typed rows. This enables
40/// matrices over different storage formats (dense, quantized, etc.) using a single
41/// generic [`Mat`] type.
42///
43/// # Associated Types
44///
45/// - `Row<'a>`: The immutable row type (e.g., `&[f32]`, `&[f16]`).
46///
47/// # Safety
48///
49/// Implementations must ensure:
50///
51/// - [`get_row`](Self::get_row) returns valid references for the given row index.
52///   This call **must** be memory safe for `i < self.nrows()`, provided the caller upholds
53///   the contract for the raw pointer.
54///
55/// - The objects implicitly managed by this representation inherit the `Send` and `Sync`
56///   attributes of `Repr`. That is, `Repr: Send` implies that the objects in backing memory
57///   are [`Send`], and likewise with `Sync`. This is necessary to apply [`Send`] and [`Sync`]
58///   bounds to [`Mat`], [`MatRef`], and [`MatMut`].
59pub unsafe trait Repr: Copy {
60    /// Immutable row reference type.
61    type Row<'a>
62    where
63        Self: 'a;
64
65    /// Returns the number of rows in the matrix.
66    ///
67    /// # Safety Contract
68    ///
69    /// This function must be loosely pure in the sense that for any given instance of
70    /// `self`, `self.nrows()` must return the same value.
71    fn nrows(&self) -> usize;
72
73    /// Returns the memory layout for a memory allocation containing [`Repr::nrows`] vectors
74    /// each with vector dimension [`Repr::ncols`].
75    ///
76    /// # Safety Contract
77    ///
78    /// The [`Layout`] returned from this method must be consistent with the contract of
79    /// [`Repr::get_row`].
80    fn layout(&self) -> Result<Layout, LayoutError>;
81
82    /// Returns an immutable reference to the `i`-th row.
83    ///
84    /// # Safety
85    ///
86    /// - `ptr` must point to a slice with a layout compatible with [`Repr::layout`].
87    /// - The entire range for this slice must be within a single allocation.
88    /// - `i` must be less than [`Repr::nrows`].
89    /// - The memory referenced by the returned [`Repr::Row`] must not be mutated for the
90    ///   duration of lifetime `'a`.
91    /// - The lifetime for the returned [`Repr::Row`] is inferred from its usage. Correct
92    ///   usage must properly tie the lifetime to a source.
93    unsafe fn get_row<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::Row<'a>;
94}
95
96/// Extension of [`Repr`] that supports mutable row access.
97///
98/// # Associated Types
99///
100/// - `RowMut<'a>`: The mutable row type (e.g., `&mut [f32]`).
101///
102/// # Safety
103///
104/// Implementors must ensure:
105///
106/// - [`get_row_mut`](Self::get_row_mut) returns valid references for the given row index.
107///   This call **must** be memory safe for `i < self.nrows()`, provided the caller upholds
108///   the contract for the raw pointer.
109///
110///   Additionally, since the implementation of the [`RowsMut`] iterator can give out rows
111///   for all `i` in `0..self.nrows()`, the implementation of [`Self::get_row_mut`] must be
112///   such that the result for disjoint `i` must not interfere with one another.
113pub unsafe trait ReprMut: Repr {
114    /// Mutable row reference type.
115    type RowMut<'a>
116    where
117        Self: 'a;
118
119    /// Returns a mutable reference to the i-th row.
120    ///
121    /// # Safety
122    /// - `ptr` must point to a slice with a layout compatible with [`Repr::layout`].
123    /// - The entire range for this slice must be within a single allocation.
124    /// - `i` must be less than `self.nrows()`.
125    /// - The memory referenced by the returned [`ReprMut::RowMut`] must not be accessed
126    ///   through any other reference for the duration of lifetime `'a`.
127    /// - The lifetime for the returned [`ReprMut::RowMut`] is inferred from its usage.
128    ///   Correct usage must properly tie the lifetime to a source.
129    unsafe fn get_row_mut<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::RowMut<'a>;
130}
131
132/// Extension trait for [`Repr`] that supports deallocation of owned matrices. This is used
133/// in conjunction with [`NewOwned`] to create matrices.
134///
135/// Requires [`ReprMut`] since owned matrices should support mutation.
136///
137/// # Safety
138///
139/// Implementors must ensure that `drop` properly deallocates the memory in a way compatible
140/// with all [`NewOwned`] implementations.
141pub unsafe trait ReprOwned: ReprMut {
142    /// Deallocates memory at `ptr` and drops `self`.
143    ///
144    /// # Safety
145    ///
146    /// - `ptr` must have been obtained via [`NewOwned`] with the same value of `self`.
147    /// - This method may only be called once for such a pointer.
148    /// - After calling this method, the memory behind `ptr` may not be dereferenced at all.
149    unsafe fn drop(self, ptr: NonNull<u8>);
150}
151
/// A new-type version of `std::alloc::LayoutError` for cleaner error handling.
///
/// This is basically the same as [`std::alloc::LayoutError`], but it can be constructed in
/// user code. That allows implementors of [`Repr::layout`] to return it for reasons other
/// than those derived from `std::alloc::Layout`'s methods.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub struct LayoutError;

impl LayoutError {
    /// Construct a new opaque [`LayoutError`].
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}

impl Default for LayoutError {
    fn default() -> Self {
        Self::new()
    }
}

impl std::fmt::Display for LayoutError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "LayoutError")
    }
}

impl std::error::Error for LayoutError {}

/// Converts a standard-library layout error into the opaque [`LayoutError`].
///
/// The details carried by the original error are discarded.
impl From<std::alloc::LayoutError> for LayoutError {
    fn from(_: std::alloc::LayoutError) -> Self {
        LayoutError
    }
}
187
188//////////////////
189// Constructors //
190//////////////////
191
192/// Create a new [`MatRef`] over a slice.
193///
194/// # Safety
195///
196/// Implementations must validate the length (and any other requirements) of the provided
197/// slice to ensure it is compatible with the implementation of [`Repr`].
198pub unsafe trait NewRef<T>: Repr {
199    /// Errors that can occur when initializing.
200    type Error;
201
202    /// Create a new [`MatRef`] over `slice`.
203    fn new_ref(self, slice: &[T]) -> Result<MatRef<'_, Self>, Self::Error>;
204}
205
206/// Create a new [`MatMut`] over a slice.
207///
208/// # Safety
209///
210/// Implementations must validate the length (and any other requirements) of the provided
211/// slice to ensure it is compatible with the implementation of [`ReprMut`].
212pub unsafe trait NewMut<T>: ReprMut {
213    /// Errors that can occur when initializing.
214    type Error;
215
216    /// Create a new [`MatMut`] over `slice`.
217    fn new_mut(self, slice: &mut [T]) -> Result<MatMut<'_, Self>, Self::Error>;
218}
219
220/// Create a new [`Mat`] from an initializer.
221///
222/// # Safety
223///
224/// Implementations must ensure that the returned [`Mat`] is compatible with
225/// `Self`'s implementation of [`ReprOwned`].
226pub unsafe trait NewOwned<T>: ReprOwned {
227    /// Errors that can occur when initializing.
228    type Error;
229
230    /// Create a new [`Mat`] initialized with `init`.
231    fn new_owned(self, init: T) -> Result<Mat<Self>, Self::Error>;
232}
233
/// Initializer for [`NewOwned`] that fills a matrix using the element type's
/// [`Default`] implementation.
///
/// ```rust
/// use diskann_quantization::multi_vector::{Mat, Standard, Defaulted};
/// let mat = Mat::new(Standard::<f32>::new(4, 3).unwrap(), Defaulted).unwrap();
/// for i in 0..4 {
///     assert!(mat.get_row(i).unwrap().iter().all(|&x| x == 0.0f32));
/// }
/// ```
#[derive(Debug, Clone, Copy)]
pub struct Defaulted;
246
247/// Create a new [`Mat`] cloned from a view.
248pub trait NewCloned: ReprOwned {
249    /// Clone the contents behind `v`, returning a new owning [`Mat`].
250    ///
251    /// Implementations should ensure the returned [`Mat`] is "semantically the same" as `v`.
252    fn new_cloned(v: MatRef<'_, Self>) -> Mat<Self>;
253}
254
255//////////////
256// Standard //
257//////////////
258
/// Representation metadata for dense, row-major matrices whose elements are `Copy`.
///
/// Each row is stored contiguously and exposed as a `&[T]` slice. This is the default
/// representation type for standard floating-point multi-vectors.
///
/// # Row Types
///
/// - `Row<'a>`: `&'a [T]`
/// - `RowMut<'a>`: `&'a mut [T]`
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Standard<T> {
    /// Number of rows (vectors).
    nrows: usize,
    /// Number of elements in each row.
    ncols: usize,
    /// Records the element type without storing a value of it.
    _elem: PhantomData<T>,
}
274
275impl<T: Copy> Standard<T> {
276    /// Create a new `Standard` for data of type `T`.
277    ///
278    /// Successful construction requires:
279    ///
280    /// * The total number of elements determined by `nrows * ncols` does not exceed
281    ///   `usize::MAX`.
282    /// * The total memory footprint defined by `ncols * nrows * size_of::<T>()` does not
283    ///   exceed `isize::MAX`.
284    pub fn new(nrows: usize, ncols: usize) -> Result<Self, Overflow> {
285        Overflow::check::<T>(nrows, ncols)?;
286        Ok(Self {
287            nrows,
288            ncols,
289            _elem: PhantomData,
290        })
291    }
292
293    /// Returns the number of total elements (`rows x cols`) in this matrix.
294    pub fn num_elements(&self) -> usize {
295        // Since we've constructed `self` - we know we cannot overflow.
296        self.nrows() * self.ncols()
297    }
298
299    /// Returns `rows`, the number of rows in this matrix.
300    fn nrows(&self) -> usize {
301        self.nrows
302    }
303
304    /// Returns `ncols`, the number of elements in a row of this matrix.
305    fn ncols(&self) -> usize {
306        self.ncols
307    }
308
309    /// Checks the following:
310    ///
311    /// 1. Computation of the number of elements in `self` does not overflow.
312    /// 2. Argument `slice` has the expected number of elements.
313    fn check_slice(&self, slice: &[T]) -> Result<(), SliceError> {
314        let len = self.num_elements();
315
316        if slice.len() != len {
317            Err(SliceError::LengthMismatch {
318                expected: len,
319                found: slice.len(),
320            })
321        } else {
322            Ok(())
323        }
324    }
325
326    /// Create a new [`Mat`] around the contents of `b` **without** any checks.
327    ///
328    /// # Safety
329    ///
330    /// The length of `b` must be exactly [`Standard::num_elements`].
331    unsafe fn box_to_mat(self, b: Box<[T]>) -> Mat<Self> {
332        debug_assert_eq!(b.len(), self.num_elements(), "safety contract violated");
333
334        let ptr = utils::box_into_nonnull(b).cast::<u8>();
335
336        // SAFETY: `ptr` is properly aligned and points to a slice of the required length.
337        // Additionally, it is dropped via `Box::from_raw`, which is compatible with obtaining
338        // it from `Box::into_raw`.
339        unsafe { Mat::from_raw_parts(self, ptr) }
340    }
341}
342
/// Error for [`Standard::new`].
#[derive(Debug, Clone, Copy)]
pub struct Overflow {
    nrows: usize,
    ncols: usize,
    elsize: usize,
}

impl Overflow {
    /// Builds the error for a `nrows x ncols` matrix whose elements have type `T`.
    pub(crate) fn for_type<T>(nrows: usize, ncols: usize) -> Self {
        let elsize = std::mem::size_of::<T>();
        Self {
            nrows,
            ncols,
            elsize,
        }
    }

    /// Ensures that `capacity` values of type `T` occupy at most `isize::MAX` bytes, the
    /// ceiling Rust's allocation APIs impose.
    ///
    /// When the budget is exceeded, the error carries the caller's `(nrows, ncols)`
    /// rather than `capacity`, so any padding folded into `capacity` stays hidden.
    pub(crate) fn check_byte_budget<T>(
        capacity: usize,
        nrows: usize,
        ncols: usize,
    ) -> Result<(), Self> {
        // An overflowing product is necessarily larger than `isize::MAX`.
        match std::mem::size_of::<T>().checked_mul(capacity) {
            Some(bytes) if bytes <= isize::MAX as usize => Ok(()),
            _ => Err(Self::for_type::<T>(nrows, ncols)),
        }
    }

    /// Validates both the element count and the byte footprint of a `nrows x ncols`
    /// matrix of `T`.
    pub(crate) fn check<T>(nrows: usize, ncols: usize) -> Result<(), Self> {
        // Rejecting an overflowing element count up front keeps `num_elements()` safe.
        match nrows.checked_mul(ncols) {
            Some(capacity) => Self::check_byte_budget::<T>(capacity, nrows, ncols),
            None => Err(Self::for_type::<T>(nrows, ncols)),
        }
    }
}

impl std::fmt::Display for Overflow {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            nrows,
            ncols,
            elsize,
        } = *self;

        // A zero-sized element can only fail the element-count check, never the byte budget.
        match elsize {
            0 => write!(
                f,
                "ZST matrix with dimensions {} x {} has more than `usize::MAX` elements",
                nrows, ncols
            ),
            _ => write!(
                f,
                "a matrix of size {} x {} with element size {} would exceed isize::MAX bytes",
                nrows, ncols, elsize
            ),
        }
    }
}

impl std::error::Error for Overflow {}
408
409/// Error types for [`Standard`].
410#[derive(Debug, Clone, Copy, Error)]
411#[non_exhaustive]
412pub enum SliceError {
413    #[error("Length mismatch: expected {expected}, found {found}")]
414    LengthMismatch { expected: usize, found: usize },
415}
416
417// SAFETY: The implementation correctly computes row offsets as `i * ncols` and
418// constructs valid slices of the appropriate length. The `layout` method correctly
419// reports the memory layout requirements.
420unsafe impl<T: Copy> Repr for Standard<T> {
421    type Row<'a>
422        = &'a [T]
423    where
424        T: 'a;
425
426    fn nrows(&self) -> usize {
427        self.nrows
428    }
429
430    fn layout(&self) -> Result<Layout, LayoutError> {
431        Ok(Layout::array::<T>(self.num_elements())?)
432    }
433
434    unsafe fn get_row<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::Row<'a> {
435        debug_assert!(ptr.cast::<T>().is_aligned());
436        debug_assert!(i < self.nrows);
437
438        // SAFETY: The caller asserts that `i` is less than `self.nrows()`. Since this type
439        // audits the constructors for `Mat` and friends, we know that there is room for at
440        // least `self.num_elements()` elements from the base pointer, so this access is safe.
441        let row_ptr = unsafe { ptr.as_ptr().cast::<T>().add(i * self.ncols) };
442
443        // SAFETY: The logic is the same as the previous `unsafe` block.
444        unsafe { std::slice::from_raw_parts(row_ptr, self.ncols) }
445    }
446}
447
448// SAFETY: The implementation correctly computes row offsets and constructs valid mutable
449// slices.
450unsafe impl<T: Copy> ReprMut for Standard<T> {
451    type RowMut<'a>
452        = &'a mut [T]
453    where
454        T: 'a;
455
456    unsafe fn get_row_mut<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::RowMut<'a> {
457        debug_assert!(ptr.cast::<T>().is_aligned());
458        debug_assert!(i < self.nrows);
459
460        // SAFETY: The caller asserts that `i` is less than `self.nrows()`. Since this type
461        // audits the constructors for `Mat` and friends, we know that there is room for at
462        // least `self.num_elements()` elements from the base pointer, so this access is safe.
463        let row_ptr = unsafe { ptr.as_ptr().cast::<T>().add(i * self.ncols) };
464
465        // SAFETY: The logic is the same as the previous `unsafe` block. Further, the caller
466        // attests that creating a mutable reference is safe.
467        unsafe { std::slice::from_raw_parts_mut(row_ptr, self.ncols) }
468    }
469}
470
471// SAFETY: The drop implementation correctly reconstructs a Box from the raw pointer
472// using the same length (nrows * ncols) that was used for allocation, allowing Box
473// to properly deallocate the memory.
474unsafe impl<T: Copy> ReprOwned for Standard<T> {
475    unsafe fn drop(self, ptr: NonNull<u8>) {
476        // SAFETY: The caller guarantees that `ptr` was obtained from an implementation of
477        // `NewOwned` for an equivalent instance of `self`.
478        //
479        // We ensure that `NewOwned` goes through boxes, so here we reconstruct a Box to
480        // let it handle deallocation.
481        unsafe {
482            let slice_ptr = std::ptr::slice_from_raw_parts_mut(
483                ptr.cast::<T>().as_ptr(),
484                self.nrows * self.ncols,
485            );
486            let _ = Box::from_raw(slice_ptr);
487        }
488    }
489}
490
491// SAFETY: The implementation uses guarantees from `Box` to ensure that the pointer
492// initialized by it is non-null and properly aligned to the underlying type.
493unsafe impl<T> NewOwned<T> for Standard<T>
494where
495    T: Copy,
496{
497    type Error = crate::error::Infallible;
498    fn new_owned(self, value: T) -> Result<Mat<Self>, Self::Error> {
499        let b: Box<[T]> = (0..self.num_elements()).map(|_| value).collect();
500
501        // SAFETY: By construction, `b` has length `self.num_elements()`.
502        Ok(unsafe { self.box_to_mat(b) })
503    }
504}
505
506// SAFETY: This safely reuses `<Self as NewOwned<T>>`.
507unsafe impl<T> NewOwned<Defaulted> for Standard<T>
508where
509    T: Copy + Default,
510{
511    type Error = crate::error::Infallible;
512    fn new_owned(self, _: Defaulted) -> Result<Mat<Self>, Self::Error> {
513        self.new_owned(T::default())
514    }
515}
516
517// SAFETY: This checks that the slice has the correct length, which is all that is
518// required for [`Repr`].
519unsafe impl<T> NewRef<T> for Standard<T>
520where
521    T: Copy,
522{
523    type Error = SliceError;
524    fn new_ref(self, data: &[T]) -> Result<MatRef<'_, Self>, Self::Error> {
525        self.check_slice(data)?;
526
527        // SAFETY: The function `check_slice` verifies that `data` is compatible with
528        // the layout requirement of `Standard`.
529        //
530        // We've properly checked that the underlying pointer is okay.
531        Ok(unsafe { MatRef::from_raw_parts(self, utils::as_nonnull(data).cast::<u8>()) })
532    }
533}
534
535// SAFETY: This checks that the slice has the correct length, which is all that is
536// required for [`ReprMut`].
537unsafe impl<T> NewMut<T> for Standard<T>
538where
539    T: Copy,
540{
541    type Error = SliceError;
542    fn new_mut(self, data: &mut [T]) -> Result<MatMut<'_, Self>, Self::Error> {
543        self.check_slice(data)?;
544
545        // SAFETY: The function `check_slice` verifies that `data` is compatible with
546        // the layout requirement of `Standard`.
547        //
548        // We've properly checked that the underlying pointer is okay.
549        Ok(unsafe { MatMut::from_raw_parts(self, utils::as_nonnull_mut(data).cast::<u8>()) })
550    }
551}
552
553impl<T> NewCloned for Standard<T>
554where
555    T: Copy,
556{
557    fn new_cloned(v: MatRef<'_, Self>) -> Mat<Self> {
558        let b: Box<[T]> = v.rows().flatten().copied().collect();
559
560        // SAFETY: By construction, `b` has length `v.repr().num_elements()`.
561        unsafe { v.repr().box_to_mat(b) }
562    }
563}
564
565/////////
566// Mat //
567/////////
568
569/// An owning matrix that manages its own memory.
570///
571/// The matrix stores raw bytes interpreted according to representation type `T`.
572/// Memory is automatically deallocated when the matrix is dropped.
573#[derive(Debug)]
574pub struct Mat<T: ReprOwned> {
575    ptr: NonNull<u8>,
576    repr: T,
577    _invariant: PhantomData<fn(T) -> T>,
578}
579
580// SAFETY: [`Repr`] is required to propagate its `Send` bound.
581unsafe impl<T> Send for Mat<T> where T: ReprOwned + Send {}
582
583// SAFETY: [`Repr`] is required to propagate its `Sync` bound.
584unsafe impl<T> Sync for Mat<T> where T: ReprOwned + Sync {}
585
586impl<T: ReprOwned> Mat<T> {
587    /// Create a new matrix using `init` as the initializer.
588    pub fn new<U>(repr: T, init: U) -> Result<Self, <T as NewOwned<U>>::Error>
589    where
590        T: NewOwned<U>,
591    {
592        repr.new_owned(init)
593    }
594
595    /// Returns the number of rows (vectors) in the matrix.
596    #[inline]
597    pub fn num_vectors(&self) -> usize {
598        self.repr.nrows()
599    }
600
601    /// Returns a reference to the underlying representation.
602    pub fn repr(&self) -> &T {
603        &self.repr
604    }
605
606    /// Returns the `i`th row if `i < self.num_vectors()`.
607    #[must_use]
608    pub fn get_row(&self, i: usize) -> Option<T::Row<'_>> {
609        if i < self.num_vectors() {
610            // SAFETY: Bounds check passed, and the Mat was constructed
611            // with valid representation and pointer.
612            let row = unsafe { self.get_row_unchecked(i) };
613            Some(row)
614        } else {
615            None
616        }
617    }
618
619    pub(crate) unsafe fn get_row_unchecked(&self, i: usize) -> T::Row<'_> {
620        // SAFETY: Caller must ensure i < self.num_vectors(). The constructors for this type
621        // ensure that `ptr` is compatible with `T`.
622        unsafe { self.repr.get_row(self.ptr, i) }
623    }
624
625    /// Returns the `i`th mutable row if `i < self.num_vectors()`.
626    #[must_use]
627    pub fn get_row_mut(&mut self, i: usize) -> Option<T::RowMut<'_>> {
628        if i < self.num_vectors() {
629            // SAFETY: Bounds check passed, and we have exclusive access via &mut self.
630            Some(unsafe { self.get_row_mut_unchecked(i) })
631        } else {
632            None
633        }
634    }
635
636    pub(crate) unsafe fn get_row_mut_unchecked(&mut self, i: usize) -> T::RowMut<'_> {
637        // SAFETY: Caller asserts that `i < self.num_vectors()`. The constructors for this
638        // type ensure that `ptr` is compatible with `T`.
639        unsafe { self.repr.get_row_mut(self.ptr, i) }
640    }
641
642    /// Returns an immutable view of the matrix.
643    #[inline]
644    pub fn as_view(&self) -> MatRef<'_, T> {
645        MatRef {
646            ptr: self.ptr,
647            repr: self.repr,
648            _lifetime: PhantomData,
649        }
650    }
651
652    /// Returns a mutable view of the matrix.
653    #[inline]
654    pub fn as_view_mut(&mut self) -> MatMut<'_, T> {
655        MatMut {
656            ptr: self.ptr,
657            repr: self.repr,
658            _lifetime: PhantomData,
659        }
660    }
661
662    /// Returns an iterator over immutable row references.
663    pub fn rows(&self) -> Rows<'_, T> {
664        Rows::new(self.reborrow())
665    }
666
667    /// Returns an iterator over mutable row references.
668    pub fn rows_mut(&mut self) -> RowsMut<'_, T> {
669        RowsMut::new(self.reborrow_mut())
670    }
671
672    /// Construct a new [`Mat`] over the raw pointer and representation without performing
673    /// any validity checks.
674    ///
675    /// # Safety
676    ///
677    /// Argument `ptr` must be:
678    ///
679    /// 1. Point to memory compatible with [`Repr::layout`].
680    /// 2. Be compatible with the drop logic in [`ReprOwned`].
681    pub(crate) unsafe fn from_raw_parts(repr: T, ptr: NonNull<u8>) -> Self {
682        Self {
683            ptr,
684            repr,
685            _invariant: PhantomData,
686        }
687    }
688
689    /// Return the base pointer for the [`Mat`].
690    pub fn as_raw_ptr(&self) -> *const u8 {
691        self.ptr.as_ptr()
692    }
693
694    /// Return a mutable base pointer for the [`Mat`].
695    pub(crate) fn as_raw_mut_ptr(&mut self) -> *mut u8 {
696        self.ptr.as_ptr()
697    }
698}
699
700impl<T: ReprOwned> Drop for Mat<T> {
701    fn drop(&mut self) {
702        // SAFETY: `ptr` was correctly initialized according to `layout`
703        // and we are guaranteed exclusive access to the data due to Rust borrow rules.
704        unsafe { self.repr.drop(self.ptr) };
705    }
706}
707
708impl<T: NewCloned> Clone for Mat<T> {
709    fn clone(&self) -> Self {
710        T::new_cloned(self.as_view())
711    }
712}
713
714impl<T: Copy> Mat<Standard<T>> {
715    /// Returns the raw dimension (columns) of the vectors in the matrix.
716    #[inline]
717    pub fn vector_dim(&self) -> usize {
718        self.repr.ncols()
719    }
720}
721
722////////////
723// MatRef //
724////////////
725
726/// An immutable borrowed view of a matrix.
727///
728/// Provides read-only access to matrix data without ownership. Implements [`Copy`]
729/// and can be freely cloned.
730///
731/// # Type Parameter
732/// - `T`: A [`Repr`] implementation defining the row layout.
733///
734/// # Access
735/// - [`get_row`](Self::get_row): Get an immutable row by index.
736/// - [`rows`](Self::rows): Iterate over all rows.
737#[derive(Debug, Clone, Copy)]
738pub struct MatRef<'a, T: Repr> {
739    ptr: NonNull<u8>,
740    repr: T,
741    /// Marker to tie the lifetime to the borrowed data.
742    _lifetime: PhantomData<&'a T>,
743}
744
745// SAFETY: [`Repr`] is required to propagate its `Send` bound.
746unsafe impl<T> Send for MatRef<'_, T> where T: Repr + Send {}
747
748// SAFETY: [`Repr`] is required to propagate its `Sync` bound.
749unsafe impl<T> Sync for MatRef<'_, T> where T: Repr + Sync {}
750
751impl<'a, T: Repr> MatRef<'a, T> {
752    /// Construct a new [`MatRef`] over `data`.
753    pub fn new<U>(repr: T, data: &'a [U]) -> Result<Self, T::Error>
754    where
755        T: NewRef<U>,
756    {
757        repr.new_ref(data)
758    }
759
760    /// Returns the number of rows (vectors) in the matrix.
761    #[inline]
762    pub fn num_vectors(&self) -> usize {
763        self.repr.nrows()
764    }
765
766    /// Returns a reference to the underlying representation.
767    pub fn repr(&self) -> &T {
768        &self.repr
769    }
770
771    /// Returns an immutable reference to the i-th row, or `None` if out of bounds.
772    #[must_use]
773    pub fn get_row(&self, i: usize) -> Option<T::Row<'_>> {
774        if i < self.num_vectors() {
775            // SAFETY: Bounds check passed, and the MatRef was constructed
776            // with valid representation and pointer.
777            let row = unsafe { self.get_row_unchecked(i) };
778            Some(row)
779        } else {
780            None
781        }
782    }
783
784    /// Returns the i-th row without bounds checking.
785    ///
786    /// # Safety
787    ///
788    /// `i` must be less than `self.num_vectors()`.
789    #[inline]
790    pub(crate) unsafe fn get_row_unchecked(&self, i: usize) -> T::Row<'_> {
791        // SAFETY: Caller must ensure i < self.num_vectors().
792        unsafe { self.repr.get_row(self.ptr, i) }
793    }
794
795    /// Returns an iterator over immutable row references.
796    pub fn rows(&self) -> Rows<'_, T> {
797        Rows::new(*self)
798    }
799
800    /// Return a [`Mat`] with the same contents as `self`.
801    pub fn to_owned(&self) -> Mat<T>
802    where
803        T: NewCloned,
804    {
805        T::new_cloned(*self)
806    }
807
808    /// Construct a new [`MatRef`] over the raw pointer and representation without performing
809    /// any validity checks.
810    ///
811    /// # Safety
812    ///
813    /// Argument `ptr` must point to memory compatible with [`Repr::layout`] and pass any
814    /// validity checks required by `T`.
815    pub unsafe fn from_raw_parts(repr: T, ptr: NonNull<u8>) -> Self {
816        Self {
817            ptr,
818            repr,
819            _lifetime: PhantomData,
820        }
821    }
822
823    /// Return the base pointer for the [`MatRef`].
824    pub fn as_raw_ptr(&self) -> *const u8 {
825        self.ptr.as_ptr()
826    }
827}
828
829impl<'a, T: Copy> MatRef<'a, Standard<T>> {
830    /// Returns the raw dimension (columns) of the vectors in the matrix.
831    #[inline]
832    pub fn vector_dim(&self) -> usize {
833        self.repr.ncols()
834    }
835}
836
837// Reborrow: Mat -> MatRef
838impl<'this, T: ReprOwned> Reborrow<'this> for Mat<T> {
839    type Target = MatRef<'this, T>;
840
841    fn reborrow(&'this self) -> Self::Target {
842        self.as_view()
843    }
844}
845
846// ReborrowMut: Mat -> MatMut
847impl<'this, T: ReprOwned> ReborrowMut<'this> for Mat<T> {
848    type Target = MatMut<'this, T>;
849
850    fn reborrow_mut(&'this mut self) -> Self::Target {
851        self.as_view_mut()
852    }
853}
854
855// Reborrow: MatRef -> MatRef (with shorter lifetime)
856impl<'this, 'a, T: Repr> Reborrow<'this> for MatRef<'a, T> {
857    type Target = MatRef<'this, T>;
858
859    fn reborrow(&'this self) -> Self::Target {
860        MatRef {
861            ptr: self.ptr,
862            repr: self.repr,
863            _lifetime: PhantomData,
864        }
865    }
866}
867
868////////////
869// MatMut //
870////////////
871
872/// A mutable borrowed view of a matrix.
873///
874/// Provides read-write access to matrix data without ownership.
875///
876/// # Type Parameter
877/// - `T`: A [`ReprMut`] implementation defining the row layout.
878///
879/// # Access
880/// - [`get_row`](Self::get_row): Get an immutable row by index.
881/// - [`get_row_mut`](Self::get_row_mut): Get a mutable row by index.
882/// - [`as_view`](Self::as_view): Reborrow as immutable [`MatRef`].
883/// - [`rows`](Self::rows), [`rows_mut`](Self::rows_mut): Iterate over rows.
884#[derive(Debug)]
885pub struct MatMut<'a, T: ReprMut> {
886    ptr: NonNull<u8>,
887    repr: T,
888    /// Marker to tie the lifetime to the mutably borrowed data.
889    _lifetime: PhantomData<&'a mut T>,
890}
891
892// SAFETY: [`ReprMut`] is required to propagate its `Send` bound.
893unsafe impl<T> Send for MatMut<'_, T> where T: ReprMut + Send {}
894
895// SAFETY: [`ReprMut`] is required to propagate its `Sync` bound.
896unsafe impl<T> Sync for MatMut<'_, T> where T: ReprMut + Sync {}
897
898impl<'a, T: ReprMut> MatMut<'a, T> {
899    /// Construct a new [`MatMut`] over `data`.
900    pub fn new<U>(repr: T, data: &'a mut [U]) -> Result<Self, T::Error>
901    where
902        T: NewMut<U>,
903    {
904        repr.new_mut(data)
905    }
906
907    /// Returns the number of rows (vectors) in the matrix.
908    #[inline]
909    pub fn num_vectors(&self) -> usize {
910        self.repr.nrows()
911    }
912
913    /// Returns a reference to the underlying representation.
914    pub fn repr(&self) -> &T {
915        &self.repr
916    }
917
918    /// Returns an immutable reference to the i-th row, or `None` if out of bounds.
919    #[inline]
920    #[must_use]
921    pub fn get_row(&self, i: usize) -> Option<T::Row<'_>> {
922        if i < self.num_vectors() {
923            // SAFETY: Bounds check passed.
924            Some(unsafe { self.get_row_unchecked(i) })
925        } else {
926            None
927        }
928    }
929
930    /// Returns the i-th row without bounds checking.
931    ///
932    /// # Safety
933    ///
934    /// `i` must be less than `self.num_vectors()`.
935    #[inline]
936    pub(crate) unsafe fn get_row_unchecked(&self, i: usize) -> T::Row<'_> {
937        // SAFETY: Caller must ensure i < self.num_vectors().
938        unsafe { self.repr.get_row(self.ptr, i) }
939    }
940
941    /// Returns a mutable reference to the `i`-th row, or `None` if out of bounds.
942    #[inline]
943    #[must_use]
944    pub fn get_row_mut(&mut self, i: usize) -> Option<T::RowMut<'_>> {
945        if i < self.num_vectors() {
946            // SAFETY: Bounds check passed.
947            Some(unsafe { self.get_row_mut_unchecked(i) })
948        } else {
949            None
950        }
951    }
952
953    /// Returns a mutable reference to the i-th row without bounds checking.
954    ///
955    /// # Safety
956    ///
957    /// `i` must be less than [`num_vectors()`](Self::num_vectors).
958    #[inline]
959    pub(crate) unsafe fn get_row_mut_unchecked(&mut self, i: usize) -> T::RowMut<'_> {
960        // SAFETY: Caller asserts that `i < self.num_vectors()`. The constructors for this
961        // type ensure that `ptr` is compatible with `T`.
962        unsafe { self.repr.get_row_mut(self.ptr, i) }
963    }
964
965    /// Reborrows as an immutable [`MatRef`].
966    pub fn as_view(&self) -> MatRef<'_, T> {
967        MatRef {
968            ptr: self.ptr,
969            repr: self.repr,
970            _lifetime: PhantomData,
971        }
972    }
973
974    /// Returns an iterator over immutable row references.
975    pub fn rows(&self) -> Rows<'_, T> {
976        Rows::new(self.reborrow())
977    }
978
979    /// Returns an iterator over mutable row references.
980    pub fn rows_mut(&mut self) -> RowsMut<'_, T> {
981        RowsMut::new(self.reborrow_mut())
982    }
983
984    /// Return a [`Mat`] with the same contents as `self`.
985    pub fn to_owned(&self) -> Mat<T>
986    where
987        T: NewCloned,
988    {
989        T::new_cloned(self.as_view())
990    }
991
992    /// Construct a new [`MatMut`] over the raw pointer and representation without performing
993    /// any validity checks.
994    ///
995    /// # Safety
996    ///
997    /// Argument `ptr` must point to memory compatible with [`Repr::layout`].
998    pub unsafe fn from_raw_parts(repr: T, ptr: NonNull<u8>) -> Self {
999        Self {
1000            ptr,
1001            repr,
1002            _lifetime: PhantomData,
1003        }
1004    }
1005
1006    /// Return the base pointer for the [`MatMut`].
1007    pub fn as_raw_ptr(&self) -> *const u8 {
1008        self.ptr.as_ptr()
1009    }
1010
1011    /// Return a mutable base pointer for the [`MatMut`].
1012    pub(crate) fn as_raw_mut_ptr(&mut self) -> *mut u8 {
1013        self.ptr.as_ptr()
1014    }
1015}
1016
1017// Reborrow: MatMut -> MatRef
1018impl<'this, 'a, T: ReprMut> Reborrow<'this> for MatMut<'a, T> {
1019    type Target = MatRef<'this, T>;
1020
1021    fn reborrow(&'this self) -> Self::Target {
1022        self.as_view()
1023    }
1024}
1025
1026// ReborrowMut: MatMut -> MatMut (with shorter lifetime)
1027impl<'this, 'a, T: ReprMut> ReborrowMut<'this> for MatMut<'a, T> {
1028    type Target = MatMut<'this, T>;
1029
1030    fn reborrow_mut(&'this mut self) -> Self::Target {
1031        MatMut {
1032            ptr: self.ptr,
1033            repr: self.repr,
1034            _lifetime: PhantomData,
1035        }
1036    }
1037}
1038
1039impl<'a, T: Copy> MatMut<'a, Standard<T>> {
1040    /// Returns the raw dimension (columns) of the vectors in the matrix.
1041    #[inline]
1042    pub fn vector_dim(&self) -> usize {
1043        self.repr.ncols()
1044    }
1045}
1046
1047//////////
1048// Rows //
1049//////////
1050
/// Iterator over immutable row references of a matrix.
///
/// Created by [`Mat::rows`], [`MatRef::rows`], or [`MatMut::rows`]. Rows are yielded in
/// index order starting at row 0. The iterator implements `ExactSizeIterator` and
/// `FusedIterator`.
#[derive(Debug)]
pub struct Rows<'a, T: Repr> {
    /// The matrix view being iterated.
    matrix: MatRef<'a, T>,
    /// Index of the next row to yield; never exceeds `matrix.num_vectors()`.
    current: usize,
}
1059
1060impl<'a, T> Rows<'a, T>
1061where
1062    T: Repr,
1063{
1064    fn new(matrix: MatRef<'a, T>) -> Self {
1065        Self { matrix, current: 0 }
1066    }
1067}
1068
1069impl<'a, T> Iterator for Rows<'a, T>
1070where
1071    T: Repr + 'a,
1072{
1073    type Item = T::Row<'a>;
1074
1075    fn next(&mut self) -> Option<Self::Item> {
1076        let current = self.current;
1077        if current >= self.matrix.num_vectors() {
1078            None
1079        } else {
1080            self.current += 1;
1081            // SAFETY: We make sure through the above check that
1082            // the access is within bounds.
1083            //
1084            // Extending the lifetime to `'a` is safe because the underlying
1085            // MatRef has lifetime `'a`.
1086            Some(unsafe { self.matrix.repr.get_row(self.matrix.ptr, current) })
1087        }
1088    }
1089
1090    fn size_hint(&self) -> (usize, Option<usize>) {
1091        let remaining = self.matrix.num_vectors() - self.current;
1092        (remaining, Some(remaining))
1093    }
1094}
1095
1096impl<'a, T> ExactSizeIterator for Rows<'a, T> where T: Repr + 'a {}
1097impl<'a, T> FusedIterator for Rows<'a, T> where T: Repr + 'a {}
1098
1099/////////////
1100// RowsMut //
1101/////////////
1102
/// Iterator over mutable row references of a matrix.
///
/// Created by [`Mat::rows_mut`] or [`MatMut::rows_mut`]. Rows are yielded in index order
/// starting at row 0, and each row index is yielded at most once. The iterator implements
/// `ExactSizeIterator` and `FusedIterator`.
#[derive(Debug)]
pub struct RowsMut<'a, T: ReprMut> {
    /// The mutable matrix view being iterated.
    matrix: MatMut<'a, T>,
    /// Index of the next row to yield; never exceeds `matrix.num_vectors()`.
    current: usize,
}
1111
1112impl<'a, T> RowsMut<'a, T>
1113where
1114    T: ReprMut,
1115{
1116    fn new(matrix: MatMut<'a, T>) -> Self {
1117        Self { matrix, current: 0 }
1118    }
1119}
1120
1121impl<'a, T> Iterator for RowsMut<'a, T>
1122where
1123    T: ReprMut + 'a,
1124{
1125    type Item = T::RowMut<'a>;
1126
1127    fn next(&mut self) -> Option<Self::Item> {
1128        let current = self.current;
1129        if current >= self.matrix.num_vectors() {
1130            None
1131        } else {
1132            self.current += 1;
1133            // SAFETY: We make sure through the above check that
1134            // the access is within bounds.
1135            //
1136            // Extending the lifetime to `'a` is safe because:
1137            // 1. The underlying MatMut has lifetime `'a`.
1138            // 2. The iterator ensures that the mutable row indices are disjoint, so
1139            //    there is no aliasing as long as the implementation of `ReprMut` ensures
1140            //    there is not mutable sharing of the `RowMut` types.
1141            Some(unsafe { self.matrix.repr.get_row_mut(self.matrix.ptr, current) })
1142        }
1143    }
1144
1145    fn size_hint(&self) -> (usize, Option<usize>) {
1146        let remaining = self.matrix.num_vectors() - self.current;
1147        (remaining, Some(remaining))
1148    }
1149}
1150
1151impl<'a, T> ExactSizeIterator for RowsMut<'a, T> where T: ReprMut + 'a {}
1152impl<'a, T> FusedIterator for RowsMut<'a, T> where T: ReprMut + 'a {}
1153
1154///////////
1155// Tests //
1156///////////
1157
#[cfg(test)]
mod tests {
    use super::*;

    use std::fmt::Display;

    use diskann_utils::lazy_format;

    /// Helper to assert a type is Copy.
    fn assert_copy<T: Copy>(_: &T) {}

    // ── Variance assertions ──────────────────────────────────────
    //
    // These functions are never called. The test is that they compile:
    // covariant positions must accept subtype coercions.
    //
    // The negative (invariance) counterparts live in
    // `tests/compile-fail/multi/{mat,matmut}_invariant.rs`.

    /// `MatRef` is covariant in `'a`: a longer borrow can shorten.
    fn _assert_matref_covariant_lifetime<'long: 'short, 'short, T: Repr>(
        v: MatRef<'long, T>,
    ) -> MatRef<'short, T> {
        v
    }

    /// `MatRef` is covariant in `T`: `Standard<&'long u8>` → `Standard<&'short u8>`.
    fn _assert_matref_covariant_repr<'long: 'short, 'short, 'a>(
        v: MatRef<'a, Standard<&'long u8>>,
    ) -> MatRef<'a, Standard<&'short u8>> {
        v
    }

    /// `MatMut` is covariant in `'a`: a longer borrow can shorten.
    fn _assert_matmut_covariant_lifetime<'long: 'short, 'short, T: ReprMut>(
        v: MatMut<'long, T>,
    ) -> MatMut<'short, T> {
        v
    }

    /// Out-of-bounds row indices for a matrix with `nrows` rows, ranging from just past the
    /// end up to `usize::MAX`.
    fn edge_cases(nrows: usize) -> Vec<usize> {
        let max = usize::MAX;

        vec![
            nrows,
            nrows + 1,
            nrows + 11,
            nrows + 20,
            max / 2,
            max.div_ceil(2),
            max - 1,
            max,
        ]
    }

    /// Writes `10 * i + j` into entry `(i, j)` through `&mut Mat` and checks that
    /// out-of-bounds row indices are rejected.
    fn fill_mat(x: &mut Mat<Standard<usize>>, repr: Standard<usize>) {
        assert_eq!(x.repr(), &repr);
        assert_eq!(x.num_vectors(), repr.nrows());
        assert_eq!(x.vector_dim(), repr.ncols());

        for i in 0..x.num_vectors() {
            let row = x.get_row_mut(i).unwrap();
            assert_eq!(row.len(), repr.ncols());
            row.iter_mut()
                .enumerate()
                .for_each(|(j, r)| *r = 10 * i + j);
        }

        for i in edge_cases(repr.nrows()).into_iter() {
            assert!(x.get_row_mut(i).is_none());
        }
    }

    /// Writes `10 * i + j` into entry `(i, j)` through a `MatMut` and checks that
    /// out-of-bounds row indices are rejected.
    fn fill_mat_mut(mut x: MatMut<'_, Standard<usize>>, repr: Standard<usize>) {
        assert_eq!(x.repr(), &repr);
        assert_eq!(x.num_vectors(), repr.nrows());
        assert_eq!(x.vector_dim(), repr.ncols());

        for i in 0..x.num_vectors() {
            let row = x.get_row_mut(i).unwrap();
            assert_eq!(row.len(), repr.ncols());

            row.iter_mut()
                .enumerate()
                .for_each(|(j, r)| *r = 10 * i + j);
        }

        for i in edge_cases(repr.nrows()).into_iter() {
            assert!(x.get_row_mut(i).is_none());
        }
    }

    /// Writes `10 * i + j` into entry `(i, j)` through a `RowsMut`, holding all mutable rows
    /// at once to exercise their disjointness.
    fn fill_rows_mut(x: RowsMut<'_, Standard<usize>>, repr: Standard<usize>) {
        assert_eq!(x.len(), repr.nrows());
        // Materialize all rows at once.
        let mut all_rows: Vec<_> = x.collect();
        assert_eq!(all_rows.len(), repr.nrows());
        for (i, row) in all_rows.iter_mut().enumerate() {
            assert_eq!(row.len(), repr.ncols());
            row.iter_mut()
                .enumerate()
                .for_each(|(j, r)| *r = 10 * i + j);
        }
    }

    /// Checks the `10 * i + j` pattern through `&Mat`.
    fn check_mat(x: &Mat<Standard<usize>>, repr: Standard<usize>, ctx: &dyn Display) {
        assert_eq!(x.repr(), &repr);
        assert_eq!(x.num_vectors(), repr.nrows());
        assert_eq!(x.vector_dim(), repr.ncols());

        for i in 0..x.num_vectors() {
            let row = x.get_row(i).unwrap();

            assert_eq!(row.len(), repr.ncols(), "ctx: {ctx}");
            row.iter().enumerate().for_each(|(j, r)| {
                assert_eq!(
                    *r,
                    10 * i + j,
                    "mismatched entry at row {}, col {} -- ctx: {}",
                    i,
                    j,
                    ctx
                )
            });
        }

        for i in edge_cases(repr.nrows()).into_iter() {
            assert!(x.get_row(i).is_none(), "ctx: {ctx}");
        }
    }

    /// Checks the `10 * i + j` pattern through a `MatRef`, which must also be `Copy`.
    fn check_mat_ref(x: MatRef<'_, Standard<usize>>, repr: Standard<usize>, ctx: &dyn Display) {
        assert_eq!(x.repr(), &repr);
        assert_eq!(x.num_vectors(), repr.nrows());
        assert_eq!(x.vector_dim(), repr.ncols());

        assert_copy(&x);
        for i in 0..x.num_vectors() {
            let row = x.get_row(i).unwrap();
            assert_eq!(row.len(), repr.ncols(), "ctx: {ctx}");

            row.iter().enumerate().for_each(|(j, r)| {
                assert_eq!(
                    *r,
                    10 * i + j,
                    "mismatched entry at row {}, col {} -- ctx: {}",
                    i,
                    j,
                    ctx
                )
            });
        }

        for i in edge_cases(repr.nrows()).into_iter() {
            assert!(x.get_row(i).is_none(), "ctx: {ctx}");
        }
    }

    /// Checks the `10 * i + j` pattern through the immutable accessors of a `MatMut`.
    fn check_mat_mut(x: MatMut<'_, Standard<usize>>, repr: Standard<usize>, ctx: &dyn Display) {
        assert_eq!(x.repr(), &repr);
        assert_eq!(x.num_vectors(), repr.nrows());
        assert_eq!(x.vector_dim(), repr.ncols());

        for i in 0..x.num_vectors() {
            let row = x.get_row(i).unwrap();
            assert_eq!(row.len(), repr.ncols(), "ctx: {ctx}");

            row.iter().enumerate().for_each(|(j, r)| {
                assert_eq!(
                    *r,
                    10 * i + j,
                    "mismatched entry at row {}, col {} -- ctx: {}",
                    i,
                    j,
                    ctx
                )
            });
        }

        for i in edge_cases(repr.nrows()).into_iter() {
            assert!(x.get_row(i).is_none(), "ctx: {ctx}");
        }
    }

    /// Checks the `10 * i + j` pattern through a `Rows` iterator.
    fn check_rows(x: Rows<'_, Standard<usize>>, repr: Standard<usize>, ctx: &dyn Display) {
        assert_eq!(x.len(), repr.nrows(), "ctx: {ctx}");
        let all_rows: Vec<_> = x.collect();
        assert_eq!(all_rows.len(), repr.nrows(), "ctx: {ctx}");
        for (i, row) in all_rows.iter().enumerate() {
            assert_eq!(row.len(), repr.ncols(), "ctx: {ctx}");
            row.iter().enumerate().for_each(|(j, r)| {
                assert_eq!(
                    *r,
                    10 * i + j,
                    "mismatched entry at row {}, col {} -- ctx: {}",
                    i,
                    j,
                    ctx
                )
            });
        }
    }

    //////////////
    // Standard //
    //////////////

    #[test]
    fn standard_representation() {
        let repr = Standard::<f32>::new(4, 3).unwrap();
        assert_eq!(repr.nrows(), 4);
        assert_eq!(repr.ncols(), 3);

        let layout = repr.layout().unwrap();
        assert_eq!(layout.size(), 4 * 3 * std::mem::size_of::<f32>());
        assert_eq!(layout.align(), std::mem::align_of::<f32>());
    }

    #[test]
    fn standard_zero_dimensions() {
        for (nrows, ncols) in [(0, 0), (0, 5), (5, 0)] {
            let repr = Standard::<u8>::new(nrows, ncols).unwrap();
            assert_eq!(repr.nrows(), nrows);
            assert_eq!(repr.ncols(), ncols);
            let layout = repr.layout().unwrap();
            assert_eq!(layout.size(), 0);
        }
    }

    #[test]
    fn standard_check_slice() {
        let repr = Standard::<u32>::new(3, 4).unwrap();

        // Correct length succeeds
        let data = vec![0u32; 12];
        assert!(repr.check_slice(&data).is_ok());

        // Too short fails
        let short = vec![0u32; 11];
        assert!(matches!(
            repr.check_slice(&short),
            Err(SliceError::LengthMismatch {
                expected: 12,
                found: 11
            })
        ));

        // Too long fails
        let long = vec![0u32; 13];
        assert!(matches!(
            repr.check_slice(&long),
            Err(SliceError::LengthMismatch {
                expected: 12,
                found: 13
            })
        ));

        // Overflow case: the representation itself cannot be constructed.
        let overflow_err = Standard::<u8>::new(usize::MAX, 2).unwrap_err();
        assert!(matches!(overflow_err, Overflow { .. }));
    }

    #[test]
    fn standard_new_rejects_element_count_overflow() {
        // nrows * ncols overflows usize even though per-element size is small.
        assert!(Standard::<u8>::new(usize::MAX, 2).is_err());
        assert!(Standard::<u8>::new(2, usize::MAX).is_err());
        assert!(Standard::<u8>::new(usize::MAX, usize::MAX).is_err());
    }

    #[test]
    fn standard_new_rejects_byte_count_exceeding_isize_max() {
        // Element count fits in usize, but total bytes exceed isize::MAX.
        let half = (isize::MAX as usize / std::mem::size_of::<u64>()) + 1;
        assert!(Standard::<u64>::new(half, 1).is_err());
        assert!(Standard::<u64>::new(1, half).is_err());
    }

    #[test]
    fn standard_new_accepts_boundary_below_isize_max() {
        // Largest allocation that still fits in isize::MAX bytes.
        let max_elems = isize::MAX as usize / std::mem::size_of::<u64>();
        let repr = Standard::<u64>::new(max_elems, 1).unwrap();
        assert_eq!(repr.num_elements(), max_elems);
    }

    #[test]
    fn standard_new_zst_rejects_element_count_overflow() {
        // For ZSTs the byte count is always 0, but element-count overflow
        // must still be caught so that `num_elements()` never wraps.
        assert!(Standard::<()>::new(usize::MAX, 2).is_err());
        assert!(Standard::<()>::new(usize::MAX / 2 + 1, 3).is_err());
    }

    #[test]
    fn standard_new_zst_accepts_large_non_overflowing() {
        // Large-but-valid ZST matrix: element count fits in usize.
        let repr = Standard::<()>::new(usize::MAX, 1).unwrap();
        assert_eq!(repr.num_elements(), usize::MAX);
        assert_eq!(repr.layout().unwrap().size(), 0);
    }

    #[test]
    fn standard_new_overflow_error_display() {
        let err = Standard::<u32>::new(usize::MAX, 2).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("would exceed isize::MAX bytes"), "{msg}");

        let zst_err = Standard::<()>::new(usize::MAX, 2).unwrap_err();
        let zst_msg = zst_err.to_string();
        assert!(zst_msg.contains("ZST matrix"), "{zst_msg}");
        assert!(zst_msg.contains("usize::MAX"), "{zst_msg}");
    }

    /////////
    // Mat //
    /////////

    #[test]
    fn mat_new_and_basic_accessors() {
        let mat = Mat::new(Standard::<usize>::new(3, 4).unwrap(), 42usize).unwrap();
        let base: *const u8 = mat.as_raw_ptr();

        assert_eq!(mat.num_vectors(), 3);
        assert_eq!(mat.vector_dim(), 4);

        let repr = mat.repr();
        assert_eq!(repr.nrows(), 3);
        assert_eq!(repr.ncols(), 4);

        for (i, r) in mat.rows().enumerate() {
            assert_eq!(r, &[42, 42, 42, 42]);
            let ptr = r.as_ptr().cast::<u8>();
            assert_eq!(
                ptr,
                base.wrapping_add(std::mem::size_of::<usize>() * mat.repr().ncols() * i),
            );
        }
    }

    #[test]
    fn mat_new_with_default() {
        let mat = Mat::new(Standard::<usize>::new(2, 3).unwrap(), Defaulted).unwrap();
        let base: *const u8 = mat.as_raw_ptr();

        assert_eq!(mat.num_vectors(), 2);
        for (i, row) in mat.rows().enumerate() {
            assert!(row.iter().all(|&v| v == 0));

            let ptr = row.as_ptr().cast::<u8>();
            assert_eq!(
                ptr,
                base.wrapping_add(std::mem::size_of::<usize>() * mat.repr().ncols() * i),
            );
        }
    }

    const ROWS: &[usize] = &[0, 1, 2, 3, 5, 10];
    const COLS: &[usize] = &[0, 1, 2, 3, 5, 10];

    #[test]
    fn test_mat() {
        for nrows in ROWS {
            for ncols in COLS {
                let repr = Standard::<usize>::new(*nrows, *ncols).unwrap();
                let ctx = &lazy_format!("nrows = {}, ncols = {}", nrows, ncols);

                // Populate the matrix using `&mut Mat`
                {
                    let ctx = &lazy_format!("{ctx} - direct");
                    let mut mat = Mat::new(repr, Defaulted).unwrap();

                    assert_eq!(mat.num_vectors(), *nrows);
                    assert_eq!(mat.vector_dim(), *ncols);

                    fill_mat(&mut mat, repr);

                    check_mat(&mat, repr, ctx);
                    check_mat_ref(mat.reborrow(), repr, ctx);
                    check_mat_mut(mat.reborrow_mut(), repr, ctx);
                    check_rows(mat.rows(), repr, ctx);

                    // Check reborrow preserves pointers.
                    assert_eq!(mat.as_raw_ptr(), mat.reborrow().as_raw_ptr());
                    assert_eq!(mat.as_raw_ptr(), mat.reborrow_mut().as_raw_ptr());
                }

                // Populate the matrix using `MatMut`
                {
                    let ctx = &lazy_format!("{ctx} - matmut");
                    let mut mat = Mat::new(repr, Defaulted).unwrap();
                    let matmut = mat.reborrow_mut();

                    assert_eq!(matmut.num_vectors(), *nrows);
                    assert_eq!(matmut.vector_dim(), *ncols);

                    fill_mat_mut(matmut, repr);

                    check_mat(&mat, repr, ctx);
                    check_mat_ref(mat.reborrow(), repr, ctx);
                    check_mat_mut(mat.reborrow_mut(), repr, ctx);
                    check_rows(mat.rows(), repr, ctx);
                }

                // Populate the matrix using `RowsMut`
                {
                    let ctx = &lazy_format!("{ctx} - rows_mut");
                    let mut mat = Mat::new(repr, Defaulted).unwrap();
                    fill_rows_mut(mat.rows_mut(), repr);

                    check_mat(&mat, repr, ctx);
                    check_mat_ref(mat.reborrow(), repr, ctx);
                    check_mat_mut(mat.reborrow_mut(), repr, ctx);
                    check_rows(mat.rows(), repr, ctx);
                }
            }
        }
    }

    #[test]
    fn test_mat_clone() {
        for nrows in ROWS {
            for ncols in COLS {
                let repr = Standard::<usize>::new(*nrows, *ncols).unwrap();
                let ctx = &lazy_format!("nrows = {}, ncols = {}", nrows, ncols);

                let mut mat = Mat::new(repr, Defaulted).unwrap();
                fill_mat(&mut mat, repr);

                // Clone via Mat::clone
                {
                    let ctx = &lazy_format!("{ctx} - Mat::clone");
                    let cloned = mat.clone();

                    assert_eq!(cloned.num_vectors(), *nrows);
                    assert_eq!(cloned.vector_dim(), *ncols);

                    check_mat(&cloned, repr, ctx);
                    check_mat_ref(cloned.reborrow(), repr, ctx);
                    check_rows(cloned.rows(), repr, ctx);

                    // Cloned allocation is independent.
                    if repr.num_elements() > 0 {
                        assert_ne!(mat.as_raw_ptr(), cloned.as_raw_ptr());
                    }
                }

                // Clone via MatRef::to_owned
                {
                    let ctx = &lazy_format!("{ctx} - MatRef::to_owned");
                    let owned = mat.as_view().to_owned();

                    check_mat(&owned, repr, ctx);
                    check_mat_ref(owned.reborrow(), repr, ctx);
                    check_rows(owned.rows(), repr, ctx);

                    if repr.num_elements() > 0 {
                        assert_ne!(mat.as_raw_ptr(), owned.as_raw_ptr());
                    }
                }

                // Clone via MatMut::to_owned
                {
                    let ctx = &lazy_format!("{ctx} - MatMut::to_owned");
                    let owned = mat.as_view_mut().to_owned();

                    check_mat(&owned, repr, ctx);
                    check_mat_ref(owned.reborrow(), repr, ctx);
                    check_rows(owned.rows(), repr, ctx);

                    if repr.num_elements() > 0 {
                        assert_ne!(mat.as_raw_ptr(), owned.as_raw_ptr());
                    }
                }
            }
        }
    }

    #[test]
    fn test_mat_refmut() {
        for nrows in ROWS {
            for ncols in COLS {
                let repr = Standard::<usize>::new(*nrows, *ncols).unwrap();
                let ctx = &lazy_format!("nrows = {}, ncols = {}", nrows, ncols);

                // Populate the matrix using `MatMut` over a caller-owned buffer
                {
                    let ctx = &lazy_format!("{ctx} - by matmut");
                    let mut b: Box<[_]> = (0..repr.num_elements()).map(|_| 0usize).collect();
                    let ptr = b.as_ptr().cast::<u8>();
                    let mut matmut = MatMut::new(repr, &mut b).unwrap();

                    assert_eq!(
                        ptr,
                        matmut.as_raw_ptr(),
                        "underlying memory should be preserved",
                    );

                    fill_mat_mut(matmut.reborrow_mut(), repr);

                    check_mat_mut(matmut.reborrow_mut(), repr, ctx);
                    check_mat_ref(matmut.reborrow(), repr, ctx);
                    check_rows(matmut.rows(), repr, ctx);
                    check_rows(matmut.reborrow().rows(), repr, ctx);

                    let matref = MatRef::new(repr, &b).unwrap();
                    check_mat_ref(matref, repr, ctx);
                    check_mat_ref(matref.reborrow(), repr, ctx);
                    check_rows(matref.rows(), repr, ctx);
                }

                // Populate the matrix using `RowsMut`
                {
                    let ctx = &lazy_format!("{ctx} - by rows");
                    let mut b: Box<[_]> = (0..repr.num_elements()).map(|_| 0usize).collect();
                    let ptr = b.as_ptr().cast::<u8>();
                    let mut matmut = MatMut::new(repr, &mut b).unwrap();

                    assert_eq!(
                        ptr,
                        matmut.as_raw_ptr(),
                        "underlying memory should be preserved",
                    );

                    fill_rows_mut(matmut.rows_mut(), repr);

                    check_mat_mut(matmut.reborrow_mut(), repr, ctx);
                    check_mat_ref(matmut.reborrow(), repr, ctx);
                    check_rows(matmut.rows(), repr, ctx);
                    check_rows(matmut.reborrow().rows(), repr, ctx);

                    let matref = MatRef::new(repr, &b).unwrap();
                    check_mat_ref(matref, repr, ctx);
                    check_mat_ref(matref.reborrow(), repr, ctx);
                    check_rows(matref.rows(), repr, ctx);
                }
            }
        }
    }

    ///////////////
    // Iterators //
    ///////////////

    #[test]
    fn rows_iterators_track_remaining_length() {
        let repr = Standard::<usize>::new(3, 2).unwrap();
        let mut mat = Mat::new(repr, Defaulted).unwrap();

        // `Rows`: `len()` shrinks by one per yielded row, and the iterator stays exhausted.
        let mut rows = mat.rows();
        for remaining in (1..=repr.nrows()).rev() {
            assert_eq!(rows.len(), remaining);
            assert!(rows.next().is_some());
        }
        assert_eq!(rows.len(), 0);
        assert!(rows.next().is_none());
        assert!(rows.next().is_none());

        // `RowsMut`: same contract for the mutable iterator.
        let mut rows_mut = mat.rows_mut();
        for remaining in (1..=repr.nrows()).rev() {
            assert_eq!(rows_mut.len(), remaining);
            assert!(rows_mut.next().is_some());
        }
        assert_eq!(rows_mut.len(), 0);
        assert!(rows_mut.next().is_none());
        assert!(rows_mut.next().is_none());
    }

    //////////////////
    // Constructors //
    //////////////////

    #[test]
    fn test_standard_new_owned() {
        for &nrows in ROWS {
            for &ncols in COLS {
                let m = Mat::new(Standard::new(nrows, ncols).unwrap(), 1usize).unwrap();
                let rows_iter = m.rows();
                let len = <_ as ExactSizeIterator>::len(&rows_iter);
                assert_eq!(len, nrows);
                for r in rows_iter {
                    assert_eq!(r.len(), ncols);
                    assert!(r.iter().all(|i| *i == 1usize));
                }
            }
        }
    }

    #[test]
    fn matref_new_slice_length_error() {
        let repr = Standard::<u32>::new(3, 4).unwrap();

        // Correct length succeeds
        let data = vec![0u32; 12];
        assert!(MatRef::new(repr, &data).is_ok());

        // Too short fails
        let short = vec![0u32; 11];
        assert!(matches!(
            MatRef::new(repr, &short),
            Err(SliceError::LengthMismatch {
                expected: 12,
                found: 11
            })
        ));

        // Too long fails
        let long = vec![0u32; 13];
        assert!(matches!(
            MatRef::new(repr, &long),
            Err(SliceError::LengthMismatch {
                expected: 12,
                found: 13
            })
        ));
    }

    #[test]
    fn matmut_new_slice_length_error() {
        let repr = Standard::<u32>::new(3, 4).unwrap();

        // Correct length succeeds
        let mut data = vec![0u32; 12];
        assert!(MatMut::new(repr, &mut data).is_ok());

        // Too short fails
        let mut short = vec![0u32; 11];
        assert!(matches!(
            MatMut::new(repr, &mut short),
            Err(SliceError::LengthMismatch {
                expected: 12,
                found: 11
            })
        ));

        // Too long fails
        let mut long = vec![0u32; 13];
        assert!(matches!(
            MatMut::new(repr, &mut long),
            Err(SliceError::LengthMismatch {
                expected: 12,
                found: 13
            })
        ));
    }
}