Skip to main content

diskann_quantization/multi_vector/
matrix.rs

/*
 * Copyright (c) Microsoft Corporation.
 * Licensed under the MIT license.
 */

//! Row-major matrix types for multi-vector representations.
//!
//! This module provides flexible matrix abstractions that support different underlying
//! storage formats through the [`Repr`] trait. The primary types are:
//!
//! - [`Mat`]: An owning matrix that manages its own memory.
//! - [`MatRef`]: An immutable borrowed view of matrix data.
//! - [`MatMut`]: A mutable borrowed view of matrix data.
//!
//! # Representations
//!
//! Representation types interact with the [`Mat`] family of types using the following traits:
//!
//! - [`Repr`]: Read-only matrix representation.
//! - [`ReprMut`]: Mutable matrix representation.
//! - [`ReprOwned`]: Owning matrix representation.
//!
//! Each trait refinement has a corresponding constructor:
//!
//! - [`NewRef`]: Construct a read-only [`MatRef`] view over a slice.
//! - [`NewMut`]: Construct a mutable [`MatMut`] view over a slice.
//! - [`NewOwned`]: Construct a new owning [`Mat`].
//!

30use std::{alloc::Layout, iter::FusedIterator, marker::PhantomData, ptr::NonNull};
31
32use diskann_utils::{Reborrow, ReborrowMut, views::MatrixView};
33use thiserror::Error;
34
35use crate::utils;
36
37/// Representation trait describing the layout and access patterns for a matrix.
38///
39/// Implementations define how raw bytes are interpreted as typed rows. This enables
40/// matrices over different storage formats (dense, quantized, etc.) using a single
41/// generic [`Mat`] type.
42///
43/// # Associated Types
44///
45/// - `Row<'a>`: The immutable row type (e.g., `&[f32]`, `&[f16]`).
46///
47/// # Safety
48///
49/// Implementations must ensure:
50///
51/// - [`get_row`](Self::get_row) returns valid references for the given row index.
52///   This call **must** be memory safe for `i < self.nrows()`, provided the caller upholds
53///   the contract for the raw pointer.
54///
55/// - The objects implicitly managed by this representation inherit the `Send` and `Sync`
56///   attributes of `Repr`. That is, `Repr: Send` implies that the objects in backing memory
57///   are [`Send`], and likewise with `Sync`. This is necessary to apply [`Send`] and [`Sync`]
58///   bounds to [`Mat`], [`MatRef`], and [`MatMut`].
59pub unsafe trait Repr: Copy {
60    /// Immutable row reference type.
61    type Row<'a>
62    where
63        Self: 'a;
64
65    /// Returns the number of rows in the matrix.
66    ///
67    /// # Safety Contract
68    ///
69    /// This function must be loosely pure in the sense that for any given instance of
70    /// `self`, `self.nrows()` must return the same value.
71    fn nrows(&self) -> usize;
72
73    /// Returns the memory layout for a memory allocation containing [`Repr::nrows`] vectors
74    /// each with vector dimension [`Repr::ncols`].
75    ///
76    /// # Safety Contract
77    ///
78    /// The [`Layout`] returned from this method must be consistent with the contract of
79    /// [`Repr::get_row`].
80    fn layout(&self) -> Result<Layout, LayoutError>;
81
82    /// Returns an immutable reference to the `i`-th row.
83    ///
84    /// # Safety
85    ///
86    /// - `ptr` must point to a slice with a layout compatible with [`Repr::layout`].
87    /// - The entire range for this slice must be within a single allocation.
88    /// - `i` must be less than [`Repr::nrows`].
89    /// - The memory referenced by the returned [`Repr::Row`] must not be mutated for the
90    ///   duration of lifetime `'a`.
91    /// - The lifetime for the returned [`Repr::Row`] is inferred from its usage. Correct
92    ///   usage must properly tie the lifetime to a source.
93    unsafe fn get_row<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::Row<'a>;
94}
95
96/// Extension of [`Repr`] that supports mutable row access.
97///
98/// # Associated Types
99///
100/// - `RowMut<'a>`: The mutable row type (e.g., `&mut [f32]`).
101///
102/// # Safety
103///
104/// Implementors must ensure:
105///
106/// - [`get_row_mut`](Self::get_row_mut) returns valid references for the given row index.
107///   This call **must** be memory safe for `i < self.nrows()`, provided the caller upholds
108///   the contract for the raw pointer.
109///
110///   Additionally, since the implementation of the [`RowsMut`] iterator can give out rows
111///   for all `i` in `0..self.nrows()`, the implementation of [`Self::get_row_mut`] must be
112///   such that the result for disjoint `i` must not interfere with one another.
113pub unsafe trait ReprMut: Repr {
114    /// Mutable row reference type.
115    type RowMut<'a>
116    where
117        Self: 'a;
118
119    /// Returns a mutable reference to the i-th row.
120    ///
121    /// # Safety
122    /// - `ptr` must point to a slice with a layout compatible with [`Repr::layout`].
123    /// - The entire range for this slice must be within a single allocation.
124    /// - `i` must be less than `self.nrows()`.
125    /// - The memory referenced by the returned [`ReprMut::RowMut`] must not be accessed
126    ///   through any other reference for the duration of lifetime `'a`.
127    /// - The lifetime for the returned [`ReprMut::RowMut`] is inferred from its usage.
128    ///   Correct usage must properly tie the lifetime to a source.
129    unsafe fn get_row_mut<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::RowMut<'a>;
130}
131
132/// Extension trait for [`Repr`] that supports deallocation of owned matrices. This is used
133/// in conjunction with [`NewOwned`] to create matrices.
134///
135/// Requires [`ReprMut`] since owned matrices should support mutation.
136///
137/// # Safety
138///
139/// Implementors must ensure that `drop` properly deallocates the memory in a way compatible
140/// with all [`NewOwned`] implementations.
141pub unsafe trait ReprOwned: ReprMut {
142    /// Deallocates memory at `ptr` and drops `self`.
143    ///
144    /// # Safety
145    ///
146    /// - `ptr` must have been obtained via [`NewOwned`] with the same value of `self`.
147    /// - This method may only be called once for such a pointer.
148    /// - After calling this method, the memory behind `ptr` may not be dereferenced at all.
149    unsafe fn drop(self, ptr: NonNull<u8>);
150}
151
/// A new-type counterpart of [`std::alloc::LayoutError`] that keeps error handling uniform.
///
/// It carries the same information as [`std::alloc::LayoutError`] (none), but it can be
/// constructed in user code. Implementors of [`Repr::layout`] can therefore report failures
/// that do not originate from `std::alloc::Layout`'s own methods.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub struct LayoutError;

impl LayoutError {
    /// Construct a new opaque [`LayoutError`].
    pub fn new() -> Self {
        LayoutError
    }
}

impl Default for LayoutError {
    /// Equivalent to [`LayoutError::new`].
    fn default() -> Self {
        LayoutError
    }
}

impl std::fmt::Display for LayoutError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("LayoutError")
    }
}

impl std::error::Error for LayoutError {}

impl From<std::alloc::LayoutError> for LayoutError {
    /// Discards the (opaque) standard-library error and yields a [`LayoutError`].
    fn from(_err: std::alloc::LayoutError) -> Self {
        Self::new()
    }
}
187
188//////////////////
189// Constructors //
190//////////////////
191
192/// Create a new [`MatRef`] over a slice.
193///
194/// # Safety
195///
196/// Implementations must validate the length (and any other requirements) of the provided
197/// slice to ensure it is compatible with the implementation of [`Repr`].
198pub unsafe trait NewRef<T>: Repr {
199    /// Errors that can occur when initializing.
200    type Error;
201
202    /// Create a new [`MatRef`] over `slice`.
203    fn new_ref(self, slice: &[T]) -> Result<MatRef<'_, Self>, Self::Error>;
204}
205
206/// Create a new [`MatMut`] over a slice.
207///
208/// # Safety
209///
210/// Implementations must validate the length (and any other requirements) of the provided
211/// slice to ensure it is compatible with the implementation of [`ReprMut`].
212pub unsafe trait NewMut<T>: ReprMut {
213    /// Errors that can occur when initializing.
214    type Error;
215
216    /// Create a new [`MatMut`] over `slice`.
217    fn new_mut(self, slice: &mut [T]) -> Result<MatMut<'_, Self>, Self::Error>;
218}
219
220/// Create a new [`Mat`] from an initializer.
221///
222/// # Safety
223///
224/// Implementations must ensure that the returned [`Mat`] is compatible with
225/// `Self`'s implementation of [`ReprOwned`].
226pub unsafe trait NewOwned<T>: ReprOwned {
227    /// Errors that can occur when initializing.
228    type Error;
229
230    /// Create a new [`Mat`] initialized with `init`.
231    fn new_owned(self, init: T) -> Result<Mat<Self>, Self::Error>;
232}
233
/// Initializer for [`NewOwned`] that fills a matrix using the element type's [`Default`]
/// implementation.
///
/// ```rust
/// use diskann_quantization::multi_vector::{Mat, Standard, Defaulted};
/// let mat = Mat::new(Standard::<f32>::new(4, 3).unwrap(), Defaulted).unwrap();
/// for i in 0..4 {
///     assert!(mat.get_row(i).unwrap().iter().all(|&x| x == 0.0f32));
/// }
/// ```
//
// NOTE: `Defaulted` deliberately does not implement `Default`. Otherwise `Standard<Defaulted>`
// would match both `NewOwned<T> for Standard<T>` and `NewOwned<Defaulted> for Standard<T>`,
// and the two impls would overlap.
#[derive(Debug, Clone, Copy)]
pub struct Defaulted;
246
247/// Create a new [`Mat`] cloned from a view.
248pub trait NewCloned: ReprOwned {
249    /// Clone the contents behind `v`, returning a new owning [`Mat`].
250    ///
251    /// Implementations should ensure the returned [`Mat`] is "semantically the same" as `v`.
252    fn new_cloned(v: MatRef<'_, Self>) -> Mat<Self>;
253}
254
255//////////////
256// Standard //
257//////////////
258
/// Shape metadata for a dense, row-major matrix whose elements are `Copy`.
///
/// Every row is stored contiguously and exposed as a `&[T]` slice. This is the default
/// representation for ordinary floating-point multi-vectors.
///
/// # Row Types
///
/// - `Row<'a>`: `&'a [T]`
/// - `RowMut<'a>`: `&'a mut [T]`
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Standard<T> {
    /// Number of rows (vectors).
    nrows: usize,
    /// Number of elements in each row.
    ncols: usize,
    /// Records the element type without storing one.
    _elem: PhantomData<T>,
}
274
275impl<T: Copy> Standard<T> {
276    /// Create a new `Standard` for data of type `T`.
277    ///
278    /// Successful construction requires:
279    ///
280    /// * The total number of elements determined by `nrows * ncols` does not exceed
281    ///   `usize::MAX`.
282    /// * The total memory footprint defined by `ncols * nrows * size_of::<T>()` does not
283    ///   exceed `isize::MAX`.
284    pub fn new(nrows: usize, ncols: usize) -> Result<Self, Overflow> {
285        Overflow::check::<T>(nrows, ncols)?;
286        Ok(Self {
287            nrows,
288            ncols,
289            _elem: PhantomData,
290        })
291    }
292
293    /// Returns the number of total elements (`rows x cols`) in this matrix.
294    pub fn num_elements(&self) -> usize {
295        // Since we've constructed `self` - we know we cannot overflow.
296        self.nrows() * self.ncols()
297    }
298
299    /// Returns `rows`, the number of rows in this matrix.
300    fn nrows(&self) -> usize {
301        self.nrows
302    }
303
304    /// Returns `ncols`, the number of elements in a row of this matrix.
305    fn ncols(&self) -> usize {
306        self.ncols
307    }
308
309    /// Checks the following:
310    ///
311    /// 1. Computation of the number of elements in `self` does not overflow.
312    /// 2. Argument `slice` has the expected number of elements.
313    fn check_slice(&self, slice: &[T]) -> Result<(), SliceError> {
314        let len = self.num_elements();
315
316        if slice.len() != len {
317            Err(SliceError::LengthMismatch {
318                expected: len,
319                found: slice.len(),
320            })
321        } else {
322            Ok(())
323        }
324    }
325
326    /// Create a new [`Mat`] around the contents of `b` **without** any checks.
327    ///
328    /// # Safety
329    ///
330    /// The length of `b` must be exactly [`Standard::num_elements`].
331    unsafe fn box_to_mat(self, b: Box<[T]>) -> Mat<Self> {
332        debug_assert_eq!(b.len(), self.num_elements(), "safety contract violated");
333
334        let ptr = utils::box_into_nonnull(b).cast::<u8>();
335
336        // SAFETY: `ptr` is properly aligned and points to a slice of the required length.
337        // Additionally, it is dropped via `Box::from_raw`, which is compatible with obtaining
338        // it from `Box::into_raw`.
339        unsafe { Mat::from_raw_parts(self, ptr) }
340    }
341}
342
/// Error returned by [`Standard::new`] when the requested shape is too large.
#[derive(Debug, Clone, Copy)]
pub struct Overflow {
    nrows: usize,
    ncols: usize,
    elsize: usize,
}

impl Overflow {
    /// Builds an `Overflow` for an `nrows x ncols` matrix whose elements have type `T`.
    pub(crate) fn for_type<T>(nrows: usize, ncols: usize) -> Self {
        let elsize = std::mem::size_of::<T>();
        Self {
            nrows,
            ncols,
            elsize,
        }
    }

    /// Confirms that `capacity` values of type `T` occupy at most `isize::MAX` bytes, the
    /// ceiling imposed by Rust's allocation APIs.
    ///
    /// On failure, the error carries the caller's original `(nrows, ncols)` rather than the
    /// (possibly padded) `capacity`.
    pub(crate) fn check_byte_budget<T>(
        capacity: usize,
        nrows: usize,
        ncols: usize,
    ) -> Result<(), Self> {
        const BUDGET: usize = isize::MAX as usize;

        // A multiplication that overflows `usize` necessarily exceeds the budget as well.
        match std::mem::size_of::<T>().checked_mul(capacity) {
            Some(bytes) if bytes <= BUDGET => Ok(()),
            _ => Err(Self::for_type::<T>(nrows, ncols)),
        }
    }

    /// Checks that an `nrows x ncols` matrix of `T` has a representable element count and
    /// fits within the byte budget.
    pub(crate) fn check<T>(nrows: usize, ncols: usize) -> Result<(), Self> {
        // Rejecting an overflowing element count here keeps `num_elements()` overflow-free.
        let Some(capacity) = nrows.checked_mul(ncols) else {
            return Err(Self::for_type::<T>(nrows, ncols));
        };

        Self::check_byte_budget::<T>(capacity, nrows, ncols)
    }
}

impl std::fmt::Display for Overflow {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            nrows,
            ncols,
            elsize,
        } = *self;

        // Zero-sized types never exceed the byte budget, so for them only the element count
        // can be at fault.
        match elsize {
            0 => write!(
                f,
                "ZST matrix with dimensions {} x {} has more than `usize::MAX` elements",
                nrows, ncols,
            ),
            _ => write!(
                f,
                "a matrix of size {} x {} with element size {} would exceed isize::MAX bytes",
                nrows, ncols, elsize,
            ),
        }
    }
}

impl std::error::Error for Overflow {}
408
409/// Error types for [`Standard`].
410#[derive(Debug, Clone, Copy, Error)]
411#[non_exhaustive]
412pub enum SliceError {
413    #[error("Length mismatch: expected {expected}, found {found}")]
414    LengthMismatch { expected: usize, found: usize },
415}
416
417// SAFETY: The implementation correctly computes row offsets as `i * ncols` and
418// constructs valid slices of the appropriate length. The `layout` method correctly
419// reports the memory layout requirements.
420unsafe impl<T: Copy> Repr for Standard<T> {
421    type Row<'a>
422        = &'a [T]
423    where
424        T: 'a;
425
426    fn nrows(&self) -> usize {
427        self.nrows
428    }
429
430    fn layout(&self) -> Result<Layout, LayoutError> {
431        Ok(Layout::array::<T>(self.num_elements())?)
432    }
433
434    unsafe fn get_row<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::Row<'a> {
435        debug_assert!(ptr.cast::<T>().is_aligned());
436        debug_assert!(i < self.nrows);
437
438        // SAFETY: The caller asserts that `i` is less than `self.nrows()`. Since this type
439        // audits the constructors for `Mat` and friends, we know that there is room for at
440        // least `self.num_elements()` elements from the base pointer, so this access is safe.
441        let row_ptr = unsafe { ptr.as_ptr().cast::<T>().add(i * self.ncols) };
442
443        // SAFETY: The logic is the same as the previous `unsafe` block.
444        unsafe { std::slice::from_raw_parts(row_ptr, self.ncols) }
445    }
446}
447
448// SAFETY: The implementation correctly computes row offsets and constructs valid mutable
449// slices.
450unsafe impl<T: Copy> ReprMut for Standard<T> {
451    type RowMut<'a>
452        = &'a mut [T]
453    where
454        T: 'a;
455
456    unsafe fn get_row_mut<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::RowMut<'a> {
457        debug_assert!(ptr.cast::<T>().is_aligned());
458        debug_assert!(i < self.nrows);
459
460        // SAFETY: The caller asserts that `i` is less than `self.nrows()`. Since this type
461        // audits the constructors for `Mat` and friends, we know that there is room for at
462        // least `self.num_elements()` elements from the base pointer, so this access is safe.
463        let row_ptr = unsafe { ptr.as_ptr().cast::<T>().add(i * self.ncols) };
464
465        // SAFETY: The logic is the same as the previous `unsafe` block. Further, the caller
466        // attests that creating a mutable reference is safe.
467        unsafe { std::slice::from_raw_parts_mut(row_ptr, self.ncols) }
468    }
469}
470
471// SAFETY: The drop implementation correctly reconstructs a Box from the raw pointer
472// using the same length (nrows * ncols) that was used for allocation, allowing Box
473// to properly deallocate the memory.
474unsafe impl<T: Copy> ReprOwned for Standard<T> {
475    unsafe fn drop(self, ptr: NonNull<u8>) {
476        // SAFETY: The caller guarantees that `ptr` was obtained from an implementation of
477        // `NewOwned` for an equivalent instance of `self`.
478        //
479        // We ensure that `NewOwned` goes through boxes, so here we reconstruct a Box to
480        // let it handle deallocation.
481        unsafe {
482            let slice_ptr = std::ptr::slice_from_raw_parts_mut(
483                ptr.cast::<T>().as_ptr(),
484                self.nrows * self.ncols,
485            );
486            let _ = Box::from_raw(slice_ptr);
487        }
488    }
489}
490
491// SAFETY: The implementation uses guarantees from `Box` to ensure that the pointer
492// initialized by it is non-null and properly aligned to the underlying type.
493unsafe impl<T> NewOwned<T> for Standard<T>
494where
495    T: Copy,
496{
497    type Error = crate::error::Infallible;
498    fn new_owned(self, value: T) -> Result<Mat<Self>, Self::Error> {
499        let b: Box<[T]> = (0..self.num_elements()).map(|_| value).collect();
500
501        // SAFETY: By construction, `b` has length `self.num_elements()`.
502        Ok(unsafe { self.box_to_mat(b) })
503    }
504}
505
506// SAFETY: This safely reuses `<Self as NewOwned<T>>`.
507unsafe impl<T> NewOwned<Defaulted> for Standard<T>
508where
509    T: Copy + Default,
510{
511    type Error = crate::error::Infallible;
512    fn new_owned(self, _: Defaulted) -> Result<Mat<Self>, Self::Error> {
513        self.new_owned(T::default())
514    }
515}
516
517// SAFETY: This checks that the slice has the correct length, which is all that is
518// required for [`Repr`].
519unsafe impl<T> NewRef<T> for Standard<T>
520where
521    T: Copy,
522{
523    type Error = SliceError;
524    fn new_ref(self, data: &[T]) -> Result<MatRef<'_, Self>, Self::Error> {
525        self.check_slice(data)?;
526
527        // SAFETY: The function `check_slice` verifies that `data` is compatible with
528        // the layout requirement of `Standard`.
529        //
530        // We've properly checked that the underlying pointer is okay.
531        Ok(unsafe { MatRef::from_raw_parts(self, utils::as_nonnull(data).cast::<u8>()) })
532    }
533}
534
535// SAFETY: This checks that the slice has the correct length, which is all that is
536// required for [`ReprMut`].
537unsafe impl<T> NewMut<T> for Standard<T>
538where
539    T: Copy,
540{
541    type Error = SliceError;
542    fn new_mut(self, data: &mut [T]) -> Result<MatMut<'_, Self>, Self::Error> {
543        self.check_slice(data)?;
544
545        // SAFETY: The function `check_slice` verifies that `data` is compatible with
546        // the layout requirement of `Standard`.
547        //
548        // We've properly checked that the underlying pointer is okay.
549        Ok(unsafe { MatMut::from_raw_parts(self, utils::as_nonnull_mut(data).cast::<u8>()) })
550    }
551}
552
553impl<T> NewCloned for Standard<T>
554where
555    T: Copy,
556{
557    fn new_cloned(v: MatRef<'_, Self>) -> Mat<Self> {
558        let b: Box<[T]> = v.rows().flatten().copied().collect();
559
560        // SAFETY: By construction, `b` has length `v.repr().num_elements()`.
561        unsafe { v.repr().box_to_mat(b) }
562    }
563}
564
565/////////
566// Mat //
567/////////
568
569/// An owning matrix that manages its own memory.
570///
571/// The matrix stores raw bytes interpreted according to representation type `T`.
572/// Memory is automatically deallocated when the matrix is dropped.
573#[derive(Debug)]
574pub struct Mat<T: ReprOwned> {
575    ptr: NonNull<u8>,
576    repr: T,
577    _invariant: PhantomData<fn(T) -> T>,
578}
579
580// SAFETY: [`Repr`] is required to propagate its `Send` bound.
581unsafe impl<T> Send for Mat<T> where T: ReprOwned + Send {}
582
583// SAFETY: [`Repr`] is required to propagate its `Sync` bound.
584unsafe impl<T> Sync for Mat<T> where T: ReprOwned + Sync {}
585
586impl<T: ReprOwned> Mat<T> {
587    /// Create a new matrix using `init` as the initializer.
588    pub fn new<U>(repr: T, init: U) -> Result<Self, <T as NewOwned<U>>::Error>
589    where
590        T: NewOwned<U>,
591    {
592        repr.new_owned(init)
593    }
594
595    /// Returns the number of rows (vectors) in the matrix.
596    #[inline]
597    pub fn num_vectors(&self) -> usize {
598        self.repr.nrows()
599    }
600
601    /// Returns a reference to the underlying representation.
602    pub fn repr(&self) -> &T {
603        &self.repr
604    }
605
606    /// Returns the `i`th row if `i < self.num_vectors()`.
607    #[must_use]
608    pub fn get_row(&self, i: usize) -> Option<T::Row<'_>> {
609        if i < self.num_vectors() {
610            // SAFETY: Bounds check passed, and the Mat was constructed
611            // with valid representation and pointer.
612            let row = unsafe { self.get_row_unchecked(i) };
613            Some(row)
614        } else {
615            None
616        }
617    }
618
619    pub(crate) unsafe fn get_row_unchecked(&self, i: usize) -> T::Row<'_> {
620        // SAFETY: Caller must ensure i < self.num_vectors(). The constructors for this type
621        // ensure that `ptr` is compatible with `T`.
622        unsafe { self.repr.get_row(self.ptr, i) }
623    }
624
625    /// Returns the `i`th mutable row if `i < self.num_vectors()`.
626    #[must_use]
627    pub fn get_row_mut(&mut self, i: usize) -> Option<T::RowMut<'_>> {
628        if i < self.num_vectors() {
629            // SAFETY: Bounds check passed, and we have exclusive access via &mut self.
630            Some(unsafe { self.get_row_mut_unchecked(i) })
631        } else {
632            None
633        }
634    }
635
636    pub(crate) unsafe fn get_row_mut_unchecked(&mut self, i: usize) -> T::RowMut<'_> {
637        // SAFETY: Caller asserts that `i < self.num_vectors()`. The constructors for this
638        // type ensure that `ptr` is compatible with `T`.
639        unsafe { self.repr.get_row_mut(self.ptr, i) }
640    }
641
642    /// Returns an immutable view of the matrix.
643    #[inline]
644    pub fn as_view(&self) -> MatRef<'_, T> {
645        MatRef {
646            ptr: self.ptr,
647            repr: self.repr,
648            _lifetime: PhantomData,
649        }
650    }
651
652    /// Returns a mutable view of the matrix.
653    #[inline]
654    pub fn as_view_mut(&mut self) -> MatMut<'_, T> {
655        MatMut {
656            ptr: self.ptr,
657            repr: self.repr,
658            _lifetime: PhantomData,
659        }
660    }
661
662    /// Returns an iterator over immutable row references.
663    pub fn rows(&self) -> Rows<'_, T> {
664        Rows::new(self.reborrow())
665    }
666
667    /// Returns an iterator over mutable row references.
668    pub fn rows_mut(&mut self) -> RowsMut<'_, T> {
669        RowsMut::new(self.reborrow_mut())
670    }
671
672    /// Construct a new [`Mat`] over the raw pointer and representation without performing
673    /// any validity checks.
674    ///
675    /// # Safety
676    ///
677    /// Argument `ptr` must be:
678    ///
679    /// 1. Point to memory compatible with [`Repr::layout`].
680    /// 2. Be compatible with the drop logic in [`ReprOwned`].
681    pub(crate) unsafe fn from_raw_parts(repr: T, ptr: NonNull<u8>) -> Self {
682        Self {
683            ptr,
684            repr,
685            _invariant: PhantomData,
686        }
687    }
688
689    /// Return the base pointer for the [`Mat`].
690    pub fn as_raw_ptr(&self) -> *const u8 {
691        self.ptr.as_ptr()
692    }
693
694    /// Return a mutable base pointer for the [`Mat`].
695    pub(crate) fn as_raw_mut_ptr(&mut self) -> *mut u8 {
696        self.ptr.as_ptr()
697    }
698}
699
700impl<T: ReprOwned> Drop for Mat<T> {
701    fn drop(&mut self) {
702        // SAFETY: `ptr` was correctly initialized according to `layout`
703        // and we are guaranteed exclusive access to the data due to Rust borrow rules.
704        unsafe { self.repr.drop(self.ptr) };
705    }
706}
707
708impl<T: NewCloned> Clone for Mat<T> {
709    fn clone(&self) -> Self {
710        T::new_cloned(self.as_view())
711    }
712}
713
714impl<T: Copy> Mat<Standard<T>> {
715    /// Returns the raw dimension (columns) of the vectors in the matrix.
716    #[inline]
717    pub fn vector_dim(&self) -> usize {
718        self.repr.ncols()
719    }
720
721    /// Return the backing data as a contiguous slice of `T`.
722    ///
723    /// The returned slice has `num_vectors() * vector_dim()` elements in row-major order.
724    #[inline]
725    pub fn as_slice(&self) -> &[T] {
726        self.as_view().as_slice()
727    }
728
729    /// Return a [`MatrixView`] over the backing data.
730    #[inline]
731    pub fn as_matrix_view(&self) -> MatrixView<'_, T> {
732        self.as_view().as_matrix_view()
733    }
734}
735
736////////////
737// MatRef //
738////////////
739
740/// An immutable borrowed view of a matrix.
741///
742/// Provides read-only access to matrix data without ownership. Implements [`Copy`]
743/// and can be freely cloned.
744///
745/// # Type Parameter
746/// - `T`: A [`Repr`] implementation defining the row layout.
747///
748/// # Access
749/// - [`get_row`](Self::get_row): Get an immutable row by index.
750/// - [`rows`](Self::rows): Iterate over all rows.
751#[derive(Debug, Clone, Copy)]
752pub struct MatRef<'a, T: Repr> {
753    ptr: NonNull<u8>,
754    repr: T,
755    /// Marker to tie the lifetime to the borrowed data.
756    _lifetime: PhantomData<&'a T>,
757}
758
759// SAFETY: [`Repr`] is required to propagate its `Send` bound.
760unsafe impl<T> Send for MatRef<'_, T> where T: Repr + Send {}
761
762// SAFETY: [`Repr`] is required to propagate its `Sync` bound.
763unsafe impl<T> Sync for MatRef<'_, T> where T: Repr + Sync {}
764
765impl<'a, T: Repr> MatRef<'a, T> {
766    /// Construct a new [`MatRef`] over `data`.
767    pub fn new<U>(repr: T, data: &'a [U]) -> Result<Self, T::Error>
768    where
769        T: NewRef<U>,
770    {
771        repr.new_ref(data)
772    }
773
774    /// Returns the number of rows (vectors) in the matrix.
775    #[inline]
776    pub fn num_vectors(&self) -> usize {
777        self.repr.nrows()
778    }
779
780    /// Returns a reference to the underlying representation.
781    pub fn repr(&self) -> &T {
782        &self.repr
783    }
784
785    /// Returns an immutable reference to the i-th row, or `None` if out of bounds.
786    #[must_use]
787    pub fn get_row(&self, i: usize) -> Option<T::Row<'_>> {
788        if i < self.num_vectors() {
789            // SAFETY: Bounds check passed, and the MatRef was constructed
790            // with valid representation and pointer.
791            let row = unsafe { self.get_row_unchecked(i) };
792            Some(row)
793        } else {
794            None
795        }
796    }
797
798    /// Returns the i-th row without bounds checking.
799    ///
800    /// # Safety
801    ///
802    /// `i` must be less than `self.num_vectors()`.
803    #[inline]
804    pub(crate) unsafe fn get_row_unchecked(&self, i: usize) -> T::Row<'_> {
805        // SAFETY: Caller must ensure i < self.num_vectors().
806        unsafe { self.repr.get_row(self.ptr, i) }
807    }
808
809    /// Returns an iterator over immutable row references.
810    pub fn rows(&self) -> Rows<'_, T> {
811        Rows::new(*self)
812    }
813
814    /// Return a [`Mat`] with the same contents as `self`.
815    pub fn to_owned(&self) -> Mat<T>
816    where
817        T: NewCloned,
818    {
819        T::new_cloned(*self)
820    }
821
822    /// Construct a new [`MatRef`] over the raw pointer and representation without performing
823    /// any validity checks.
824    ///
825    /// # Safety
826    ///
827    /// Argument `ptr` must point to memory compatible with [`Repr::layout`] and pass any
828    /// validity checks required by `T`.
829    pub unsafe fn from_raw_parts(repr: T, ptr: NonNull<u8>) -> Self {
830        Self {
831            ptr,
832            repr,
833            _lifetime: PhantomData,
834        }
835    }
836
837    /// Return the base pointer for the [`MatRef`].
838    pub fn as_raw_ptr(&self) -> *const u8 {
839        self.ptr.as_ptr()
840    }
841}
842
843impl<'a, T: Copy> MatRef<'a, Standard<T>> {
844    /// Returns the raw dimension (columns) of the vectors in the matrix.
845    #[inline]
846    pub fn vector_dim(&self) -> usize {
847        self.repr.ncols()
848    }
849
850    /// Return the backing data as a contiguous slice of `T`.
851    ///
852    /// The returned slice has `num_vectors() * vector_dim()` elements in row-major order.
853    #[inline]
854    pub fn as_slice(&self) -> &'a [T] {
855        let len = self.repr.num_elements();
856        // SAFETY: `Standard<T>` guarantees `nrows * ncols` contiguous `T` elements
857        // starting at `self.ptr`. The lifetime `'a` is tied to the original data.
858        unsafe { std::slice::from_raw_parts(self.ptr.as_ptr().cast::<T>(), len) }
859    }
860
861    /// Return a [`MatrixView`] over the backing data.
862    #[allow(clippy::expect_used)]
863    #[inline]
864    pub fn as_matrix_view(&self) -> MatrixView<'a, T> {
865        // `Standard::new` validates that `nrows * ncols` does not overflow,
866        // so `try_from` is infallible here.
867        MatrixView::try_from(self.as_slice(), self.num_vectors(), self.vector_dim())
868            .expect("Standard<T> has valid dimensions")
869    }
870}
871
872// Reborrow: Mat -> MatRef
873impl<'this, T: ReprOwned> Reborrow<'this> for Mat<T> {
874    type Target = MatRef<'this, T>;
875
876    fn reborrow(&'this self) -> Self::Target {
877        self.as_view()
878    }
879}
880
881// ReborrowMut: Mat -> MatMut
882impl<'this, T: ReprOwned> ReborrowMut<'this> for Mat<T> {
883    type Target = MatMut<'this, T>;
884
885    fn reborrow_mut(&'this mut self) -> Self::Target {
886        self.as_view_mut()
887    }
888}
889
890// Reborrow: MatRef -> MatRef (with shorter lifetime)
891impl<'this, 'a, T: Repr> Reborrow<'this> for MatRef<'a, T> {
892    type Target = MatRef<'this, T>;
893
894    fn reborrow(&'this self) -> Self::Target {
895        MatRef {
896            ptr: self.ptr,
897            repr: self.repr,
898            _lifetime: PhantomData,
899        }
900    }
901}
902
903////////////
904// MatMut //
905////////////
906
/// A mutable borrowed view of a matrix.
///
/// Provides read-write access to matrix data without ownership.
///
/// # Type Parameter
/// - `T`: A [`ReprMut`] implementation defining the row layout.
///
/// # Access
/// - [`get_row`](Self::get_row): Get an immutable row by index.
/// - [`get_row_mut`](Self::get_row_mut): Get a mutable row by index.
/// - [`as_view`](Self::as_view): Reborrow as immutable [`MatRef`].
/// - [`rows`](Self::rows), [`rows_mut`](Self::rows_mut): Iterate over rows.
#[derive(Debug)]
pub struct MatMut<'a, T: ReprMut> {
    /// Base pointer of the viewed data; rows are located by passing it to `repr`.
    ptr: NonNull<u8>,
    /// Representation describing how the bytes behind `ptr` are split into rows.
    repr: T,
    /// Marker to tie the lifetime to the mutably borrowed data.
    ///
    /// `&'a mut T` makes the view covariant in `'a` but invariant in `T`.
    _lifetime: PhantomData<&'a mut T>,
}
926
927// SAFETY: [`ReprMut`] is required to propagate its `Send` bound.
928unsafe impl<T> Send for MatMut<'_, T> where T: ReprMut + Send {}
929
930// SAFETY: [`ReprMut`] is required to propagate its `Sync` bound.
931unsafe impl<T> Sync for MatMut<'_, T> where T: ReprMut + Sync {}
932
933impl<'a, T: ReprMut> MatMut<'a, T> {
934    /// Construct a new [`MatMut`] over `data`.
935    pub fn new<U>(repr: T, data: &'a mut [U]) -> Result<Self, T::Error>
936    where
937        T: NewMut<U>,
938    {
939        repr.new_mut(data)
940    }
941
942    /// Returns the number of rows (vectors) in the matrix.
943    #[inline]
944    pub fn num_vectors(&self) -> usize {
945        self.repr.nrows()
946    }
947
948    /// Returns a reference to the underlying representation.
949    pub fn repr(&self) -> &T {
950        &self.repr
951    }
952
953    /// Returns an immutable reference to the i-th row, or `None` if out of bounds.
954    #[inline]
955    #[must_use]
956    pub fn get_row(&self, i: usize) -> Option<T::Row<'_>> {
957        if i < self.num_vectors() {
958            // SAFETY: Bounds check passed.
959            Some(unsafe { self.get_row_unchecked(i) })
960        } else {
961            None
962        }
963    }
964
965    /// Returns the i-th row without bounds checking.
966    ///
967    /// # Safety
968    ///
969    /// `i` must be less than `self.num_vectors()`.
970    #[inline]
971    pub(crate) unsafe fn get_row_unchecked(&self, i: usize) -> T::Row<'_> {
972        // SAFETY: Caller must ensure i < self.num_vectors().
973        unsafe { self.repr.get_row(self.ptr, i) }
974    }
975
976    /// Returns a mutable reference to the `i`-th row, or `None` if out of bounds.
977    #[inline]
978    #[must_use]
979    pub fn get_row_mut(&mut self, i: usize) -> Option<T::RowMut<'_>> {
980        if i < self.num_vectors() {
981            // SAFETY: Bounds check passed.
982            Some(unsafe { self.get_row_mut_unchecked(i) })
983        } else {
984            None
985        }
986    }
987
988    /// Returns a mutable reference to the i-th row without bounds checking.
989    ///
990    /// # Safety
991    ///
992    /// `i` must be less than [`num_vectors()`](Self::num_vectors).
993    #[inline]
994    pub(crate) unsafe fn get_row_mut_unchecked(&mut self, i: usize) -> T::RowMut<'_> {
995        // SAFETY: Caller asserts that `i < self.num_vectors()`. The constructors for this
996        // type ensure that `ptr` is compatible with `T`.
997        unsafe { self.repr.get_row_mut(self.ptr, i) }
998    }
999
1000    /// Reborrows as an immutable [`MatRef`].
1001    pub fn as_view(&self) -> MatRef<'_, T> {
1002        MatRef {
1003            ptr: self.ptr,
1004            repr: self.repr,
1005            _lifetime: PhantomData,
1006        }
1007    }
1008
1009    /// Returns an iterator over immutable row references.
1010    pub fn rows(&self) -> Rows<'_, T> {
1011        Rows::new(self.reborrow())
1012    }
1013
1014    /// Returns an iterator over mutable row references.
1015    pub fn rows_mut(&mut self) -> RowsMut<'_, T> {
1016        RowsMut::new(self.reborrow_mut())
1017    }
1018
1019    /// Return a [`Mat`] with the same contents as `self`.
1020    pub fn to_owned(&self) -> Mat<T>
1021    where
1022        T: NewCloned,
1023    {
1024        T::new_cloned(self.as_view())
1025    }
1026
1027    /// Construct a new [`MatMut`] over the raw pointer and representation without performing
1028    /// any validity checks.
1029    ///
1030    /// # Safety
1031    ///
1032    /// Argument `ptr` must point to memory compatible with [`Repr::layout`].
1033    pub unsafe fn from_raw_parts(repr: T, ptr: NonNull<u8>) -> Self {
1034        Self {
1035            ptr,
1036            repr,
1037            _lifetime: PhantomData,
1038        }
1039    }
1040
1041    /// Return the base pointer for the [`MatMut`].
1042    pub fn as_raw_ptr(&self) -> *const u8 {
1043        self.ptr.as_ptr()
1044    }
1045
1046    /// Return a mutable base pointer for the [`MatMut`].
1047    pub(crate) fn as_raw_mut_ptr(&mut self) -> *mut u8 {
1048        self.ptr.as_ptr()
1049    }
1050}
1051
1052// Reborrow: MatMut -> MatRef
1053impl<'this, 'a, T: ReprMut> Reborrow<'this> for MatMut<'a, T> {
1054    type Target = MatRef<'this, T>;
1055
1056    fn reborrow(&'this self) -> Self::Target {
1057        self.as_view()
1058    }
1059}
1060
1061// ReborrowMut: MatMut -> MatMut (with shorter lifetime)
1062impl<'this, 'a, T: ReprMut> ReborrowMut<'this> for MatMut<'a, T> {
1063    type Target = MatMut<'this, T>;
1064
1065    fn reborrow_mut(&'this mut self) -> Self::Target {
1066        MatMut {
1067            ptr: self.ptr,
1068            repr: self.repr,
1069            _lifetime: PhantomData,
1070        }
1071    }
1072}
1073
1074impl<'a, T: Copy> MatMut<'a, Standard<T>> {
1075    /// Returns the raw dimension (columns) of the vectors in the matrix.
1076    #[inline]
1077    pub fn vector_dim(&self) -> usize {
1078        self.repr.ncols()
1079    }
1080
1081    /// Return the backing data as a contiguous slice of `T`.
1082    ///
1083    /// The returned slice has `num_vectors() * vector_dim()` elements in row-major order.
1084    #[inline]
1085    pub fn as_slice(&self) -> &[T] {
1086        self.as_view().as_slice()
1087    }
1088
1089    /// Return a [`MatrixView`] over the backing data.
1090    #[inline]
1091    pub fn as_matrix_view(&self) -> MatrixView<'_, T> {
1092        self.as_view().as_matrix_view()
1093    }
1094}
1095
1096//////////
1097// Rows //
1098//////////
1099
/// Iterator over immutable row references of a matrix.
///
/// Created by [`Mat::rows`], [`MatRef::rows`], or [`MatMut::rows`].
///
/// Yields rows `0..num_vectors()` in index order. Each row carries the full lifetime
/// `'a` of the underlying view.
#[derive(Debug)]
pub struct Rows<'a, T: Repr> {
    /// The view whose rows are being yielded.
    matrix: MatRef<'a, T>,
    /// Index of the next row to yield.
    current: usize,
}
1108
1109impl<'a, T> Rows<'a, T>
1110where
1111    T: Repr,
1112{
1113    fn new(matrix: MatRef<'a, T>) -> Self {
1114        Self { matrix, current: 0 }
1115    }
1116}
1117
1118impl<'a, T> Iterator for Rows<'a, T>
1119where
1120    T: Repr + 'a,
1121{
1122    type Item = T::Row<'a>;
1123
1124    fn next(&mut self) -> Option<Self::Item> {
1125        let current = self.current;
1126        if current >= self.matrix.num_vectors() {
1127            None
1128        } else {
1129            self.current += 1;
1130            // SAFETY: We make sure through the above check that
1131            // the access is within bounds.
1132            //
1133            // Extending the lifetime to `'a` is safe because the underlying
1134            // MatRef has lifetime `'a`.
1135            Some(unsafe { self.matrix.repr.get_row(self.matrix.ptr, current) })
1136        }
1137    }
1138
1139    fn size_hint(&self) -> (usize, Option<usize>) {
1140        let remaining = self.matrix.num_vectors() - self.current;
1141        (remaining, Some(remaining))
1142    }
1143}
1144
1145impl<'a, T> ExactSizeIterator for Rows<'a, T> where T: Repr + 'a {}
1146impl<'a, T> FusedIterator for Rows<'a, T> where T: Repr + 'a {}
1147
1148/////////////
1149// RowsMut //
1150/////////////
1151
/// Iterator over mutable row references of a matrix.
///
/// Created by [`Mat::rows_mut`] or [`MatMut::rows_mut`].
///
/// Yields each row index `0..num_vectors()` exactly once, in order. Because of that, the
/// mutable rows it hands out may all be held at the same time.
#[derive(Debug)]
pub struct RowsMut<'a, T: ReprMut> {
    /// The mutable view whose rows are being yielded.
    matrix: MatMut<'a, T>,
    /// Index of the next row to yield.
    current: usize,
}
1160
1161impl<'a, T> RowsMut<'a, T>
1162where
1163    T: ReprMut,
1164{
1165    fn new(matrix: MatMut<'a, T>) -> Self {
1166        Self { matrix, current: 0 }
1167    }
1168}
1169
1170impl<'a, T> Iterator for RowsMut<'a, T>
1171where
1172    T: ReprMut + 'a,
1173{
1174    type Item = T::RowMut<'a>;
1175
1176    fn next(&mut self) -> Option<Self::Item> {
1177        let current = self.current;
1178        if current >= self.matrix.num_vectors() {
1179            None
1180        } else {
1181            self.current += 1;
1182            // SAFETY: We make sure through the above check that
1183            // the access is within bounds.
1184            //
1185            // Extending the lifetime to `'a` is safe because:
1186            // 1. The underlying MatMut has lifetime `'a`.
1187            // 2. The iterator ensures that the mutable row indices are disjoint, so
1188            //    there is no aliasing as long as the implementation of `ReprMut` ensures
1189            //    there is not mutable sharing of the `RowMut` types.
1190            Some(unsafe { self.matrix.repr.get_row_mut(self.matrix.ptr, current) })
1191        }
1192    }
1193
1194    fn size_hint(&self) -> (usize, Option<usize>) {
1195        let remaining = self.matrix.num_vectors() - self.current;
1196        (remaining, Some(remaining))
1197    }
1198}
1199
1200impl<'a, T> ExactSizeIterator for RowsMut<'a, T> where T: ReprMut + 'a {}
1201impl<'a, T> FusedIterator for RowsMut<'a, T> where T: ReprMut + 'a {}
1202
1203///////////
1204// Tests //
1205///////////
1206
1207#[cfg(test)]
1208mod tests {
1209    use super::*;
1210
1211    use std::fmt::Display;
1212
1213    use diskann_utils::lazy_format;
1214
1215    /// Helper to assert a type is Copy.
1216    fn assert_copy<T: Copy>(_: &T) {}
1217
1218    // ── Variance assertions ──────────────────────────────────────
1219    //
1220    // These functions are never called. The test is that they compile:
1221    // covariant positions must accept subtype coercions.
1222    //
1223    // The negative (invariance) counterparts live in
1224    // `tests/compile-fail/multi/{mat,matmut}_invariant.rs`.
1225
1226    /// `MatRef` is covariant in `'a`: a longer borrow can shorten.
1227    fn _assert_matref_covariant_lifetime<'long: 'short, 'short, T: Repr>(
1228        v: MatRef<'long, T>,
1229    ) -> MatRef<'short, T> {
1230        v
1231    }
1232
1233    /// `MatRef` is covariant in `T`: `Standard<&'long u8>` → `Standard<&'short u8>`.
1234    fn _assert_matref_covariant_repr<'long: 'short, 'short, 'a>(
1235        v: MatRef<'a, Standard<&'long u8>>,
1236    ) -> MatRef<'a, Standard<&'short u8>> {
1237        v
1238    }
1239
1240    /// `MatMut` is covariant in `'a`: a longer borrow can shorten.
1241    fn _assert_matmut_covariant_lifetime<'long: 'short, 'short, T: ReprMut>(
1242        v: MatMut<'long, T>,
1243    ) -> MatMut<'short, T> {
1244        v
1245    }
1246
1247    fn edge_cases(nrows: usize) -> Vec<usize> {
1248        let max = usize::MAX;
1249
1250        vec![
1251            nrows,
1252            nrows + 1,
1253            nrows + 11,
1254            nrows + 20,
1255            max / 2,
1256            max.div_ceil(2),
1257            max - 1,
1258            max,
1259        ]
1260    }
1261
1262    fn fill_mat(x: &mut Mat<Standard<usize>>, repr: Standard<usize>) {
1263        assert_eq!(x.repr(), &repr);
1264        assert_eq!(x.num_vectors(), repr.nrows());
1265        assert_eq!(x.vector_dim(), repr.ncols());
1266
1267        for i in 0..x.num_vectors() {
1268            let row = x.get_row_mut(i).unwrap();
1269            assert_eq!(row.len(), repr.ncols());
1270            row.iter_mut()
1271                .enumerate()
1272                .for_each(|(j, r)| *r = 10 * i + j);
1273        }
1274
1275        for i in edge_cases(repr.nrows()).into_iter() {
1276            assert!(x.get_row_mut(i).is_none());
1277        }
1278    }
1279
1280    fn fill_mat_mut(mut x: MatMut<'_, Standard<usize>>, repr: Standard<usize>) {
1281        assert_eq!(x.repr(), &repr);
1282        assert_eq!(x.num_vectors(), repr.nrows());
1283        assert_eq!(x.vector_dim(), repr.ncols());
1284
1285        for i in 0..x.num_vectors() {
1286            let row = x.get_row_mut(i).unwrap();
1287            assert_eq!(row.len(), repr.ncols());
1288
1289            row.iter_mut()
1290                .enumerate()
1291                .for_each(|(j, r)| *r = 10 * i + j);
1292        }
1293
1294        for i in edge_cases(repr.nrows()).into_iter() {
1295            assert!(x.get_row_mut(i).is_none());
1296        }
1297    }
1298
1299    fn fill_rows_mut(x: RowsMut<'_, Standard<usize>>, repr: Standard<usize>) {
1300        assert_eq!(x.len(), repr.nrows());
1301        // Materialize all rows at once.
1302        let mut all_rows: Vec<_> = x.collect();
1303        assert_eq!(all_rows.len(), repr.nrows());
1304        for (i, row) in all_rows.iter_mut().enumerate() {
1305            assert_eq!(row.len(), repr.ncols());
1306            row.iter_mut()
1307                .enumerate()
1308                .for_each(|(j, r)| *r = 10 * i + j);
1309        }
1310    }
1311
1312    fn check_mat(x: &Mat<Standard<usize>>, repr: Standard<usize>, ctx: &dyn Display) {
1313        assert_eq!(x.repr(), &repr);
1314        assert_eq!(x.num_vectors(), repr.nrows());
1315        assert_eq!(x.vector_dim(), repr.ncols());
1316
1317        for i in 0..x.num_vectors() {
1318            let row = x.get_row(i).unwrap();
1319
1320            assert_eq!(row.len(), repr.ncols(), "ctx: {ctx}");
1321            row.iter().enumerate().for_each(|(j, r)| {
1322                assert_eq!(
1323                    *r,
1324                    10 * i + j,
1325                    "mismatched entry at row {}, col {} -- ctx: {}",
1326                    i,
1327                    j,
1328                    ctx
1329                )
1330            });
1331        }
1332
1333        for i in edge_cases(repr.nrows()).into_iter() {
1334            assert!(x.get_row(i).is_none(), "ctx: {ctx}");
1335        }
1336    }
1337
1338    fn check_mat_ref(x: MatRef<'_, Standard<usize>>, repr: Standard<usize>, ctx: &dyn Display) {
1339        assert_eq!(x.repr(), &repr);
1340        assert_eq!(x.num_vectors(), repr.nrows());
1341        assert_eq!(x.vector_dim(), repr.ncols());
1342
1343        assert_copy(&x);
1344        for i in 0..x.num_vectors() {
1345            let row = x.get_row(i).unwrap();
1346            assert_eq!(row.len(), repr.ncols(), "ctx: {ctx}");
1347
1348            row.iter().enumerate().for_each(|(j, r)| {
1349                assert_eq!(
1350                    *r,
1351                    10 * i + j,
1352                    "mismatched entry at row {}, col {} -- ctx: {}",
1353                    i,
1354                    j,
1355                    ctx
1356                )
1357            });
1358        }
1359
1360        for i in edge_cases(repr.nrows()).into_iter() {
1361            assert!(x.get_row(i).is_none(), "ctx: {ctx}");
1362        }
1363    }
1364
1365    fn check_mat_mut(x: MatMut<'_, Standard<usize>>, repr: Standard<usize>, ctx: &dyn Display) {
1366        assert_eq!(x.repr(), &repr);
1367        assert_eq!(x.num_vectors(), repr.nrows());
1368        assert_eq!(x.vector_dim(), repr.ncols());
1369
1370        for i in 0..x.num_vectors() {
1371            let row = x.get_row(i).unwrap();
1372            assert_eq!(row.len(), repr.ncols(), "ctx: {ctx}");
1373
1374            row.iter().enumerate().for_each(|(j, r)| {
1375                assert_eq!(
1376                    *r,
1377                    10 * i + j,
1378                    "mismatched entry at row {}, col {} -- ctx: {}",
1379                    i,
1380                    j,
1381                    ctx
1382                )
1383            });
1384        }
1385
1386        for i in edge_cases(repr.nrows()).into_iter() {
1387            assert!(x.get_row(i).is_none(), "ctx: {ctx}");
1388        }
1389    }
1390
1391    fn check_rows(x: Rows<'_, Standard<usize>>, repr: Standard<usize>, ctx: &dyn Display) {
1392        assert_eq!(x.len(), repr.nrows(), "ctx: {ctx}");
1393        let all_rows: Vec<_> = x.collect();
1394        assert_eq!(all_rows.len(), repr.nrows(), "ctx: {ctx}");
1395        for (i, row) in all_rows.iter().enumerate() {
1396            assert_eq!(row.len(), repr.ncols(), "ctx: {ctx}");
1397            row.iter().enumerate().for_each(|(j, r)| {
1398                assert_eq!(
1399                    *r,
1400                    10 * i + j,
1401                    "mismatched entry at row {}, col {} -- ctx: {}",
1402                    i,
1403                    j,
1404                    ctx
1405                )
1406            });
1407        }
1408    }
1409
1410    //////////////
1411    // Standard //
1412    //////////////
1413
1414    #[test]
1415    fn standard_representation() {
1416        let repr = Standard::<f32>::new(4, 3).unwrap();
1417        assert_eq!(repr.nrows(), 4);
1418        assert_eq!(repr.ncols(), 3);
1419
1420        let layout = repr.layout().unwrap();
1421        assert_eq!(layout.size(), 4 * 3 * std::mem::size_of::<f32>());
1422        assert_eq!(layout.align(), std::mem::align_of::<f32>());
1423    }
1424
1425    #[test]
1426    fn standard_zero_dimensions() {
1427        for (nrows, ncols) in [(0, 0), (0, 5), (5, 0)] {
1428            let repr = Standard::<u8>::new(nrows, ncols).unwrap();
1429            assert_eq!(repr.nrows(), nrows);
1430            assert_eq!(repr.ncols(), ncols);
1431            let layout = repr.layout().unwrap();
1432            assert_eq!(layout.size(), 0);
1433        }
1434    }
1435
1436    #[test]
1437    fn standard_check_slice() {
1438        let repr = Standard::<u32>::new(3, 4).unwrap();
1439
1440        // Correct length succeeds
1441        let data = vec![0u32; 12];
1442        assert!(repr.check_slice(&data).is_ok());
1443
1444        // Too short fails
1445        let short = vec![0u32; 11];
1446        assert!(matches!(
1447            repr.check_slice(&short),
1448            Err(SliceError::LengthMismatch {
1449                expected: 12,
1450                found: 11
1451            })
1452        ));
1453
1454        // Too long fails
1455        let long = vec![0u32; 13];
1456        assert!(matches!(
1457            repr.check_slice(&long),
1458            Err(SliceError::LengthMismatch {
1459                expected: 12,
1460                found: 13
1461            })
1462        ));
1463
1464        // Overflow case
1465        let overflow_repr = Standard::<u8>::new(usize::MAX, 2).unwrap_err();
1466        assert!(matches!(overflow_repr, Overflow { .. }));
1467    }
1468
1469    #[test]
1470    fn standard_new_rejects_element_count_overflow() {
1471        // nrows * ncols overflows usize even though per-element size is small.
1472        assert!(Standard::<u8>::new(usize::MAX, 2).is_err());
1473        assert!(Standard::<u8>::new(2, usize::MAX).is_err());
1474        assert!(Standard::<u8>::new(usize::MAX, usize::MAX).is_err());
1475    }
1476
1477    #[test]
1478    fn standard_new_rejects_byte_count_exceeding_isize_max() {
1479        // Element count fits in usize, but total bytes exceed isize::MAX.
1480        let half = (isize::MAX as usize / std::mem::size_of::<u64>()) + 1;
1481        assert!(Standard::<u64>::new(half, 1).is_err());
1482        assert!(Standard::<u64>::new(1, half).is_err());
1483    }
1484
1485    #[test]
1486    fn standard_new_accepts_boundary_below_isize_max() {
1487        // Largest allocation that still fits in isize::MAX bytes.
1488        let max_elems = isize::MAX as usize / std::mem::size_of::<u64>();
1489        let repr = Standard::<u64>::new(max_elems, 1).unwrap();
1490        assert_eq!(repr.num_elements(), max_elems);
1491    }
1492
1493    #[test]
1494    fn standard_new_zst_rejects_element_count_overflow() {
1495        // For ZSTs the byte count is always 0, but element-count overflow
1496        // must still be caught so that `num_elements()` never wraps.
1497        assert!(Standard::<()>::new(usize::MAX, 2).is_err());
1498        assert!(Standard::<()>::new(usize::MAX / 2 + 1, 3).is_err());
1499    }
1500
1501    #[test]
1502    fn standard_new_zst_accepts_large_non_overflowing() {
1503        // Large-but-valid ZST matrix: element count fits in usize.
1504        let repr = Standard::<()>::new(usize::MAX, 1).unwrap();
1505        assert_eq!(repr.num_elements(), usize::MAX);
1506        assert_eq!(repr.layout().unwrap().size(), 0);
1507    }
1508
1509    #[test]
1510    fn standard_new_overflow_error_display() {
1511        let err = Standard::<u32>::new(usize::MAX, 2).unwrap_err();
1512        let msg = err.to_string();
1513        assert!(msg.contains("would exceed isize::MAX bytes"), "{msg}");
1514
1515        let zst_err = Standard::<()>::new(usize::MAX, 2).unwrap_err();
1516        let zst_msg = zst_err.to_string();
1517        assert!(zst_msg.contains("ZST matrix"), "{zst_msg}");
1518        assert!(zst_msg.contains("usize::MAX"), "{zst_msg}");
1519    }
1520
1521    /////////
1522    // Mat //
1523    /////////
1524
1525    #[test]
1526    fn mat_new_and_basic_accessors() {
1527        let mat = Mat::new(Standard::<usize>::new(3, 4).unwrap(), 42usize).unwrap();
1528        let base: *const u8 = mat.as_raw_ptr();
1529
1530        assert_eq!(mat.num_vectors(), 3);
1531        assert_eq!(mat.vector_dim(), 4);
1532
1533        let repr = mat.repr();
1534        assert_eq!(repr.nrows(), 3);
1535        assert_eq!(repr.ncols(), 4);
1536
1537        for (i, r) in mat.rows().enumerate() {
1538            assert_eq!(r, &[42, 42, 42, 42]);
1539            let ptr = r.as_ptr().cast::<u8>();
1540            assert_eq!(
1541                ptr,
1542                base.wrapping_add(std::mem::size_of::<usize>() * mat.repr().ncols() * i),
1543            );
1544        }
1545    }
1546
1547    #[test]
1548    fn mat_new_with_default() {
1549        let mat = Mat::new(Standard::<usize>::new(2, 3).unwrap(), Defaulted).unwrap();
1550        let base: *const u8 = mat.as_raw_ptr();
1551
1552        assert_eq!(mat.num_vectors(), 2);
1553        for (i, row) in mat.rows().enumerate() {
1554            assert!(row.iter().all(|&v| v == 0));
1555
1556            let ptr = row.as_ptr().cast::<u8>();
1557            assert_eq!(
1558                ptr,
1559                base.wrapping_add(std::mem::size_of::<usize>() * mat.repr().ncols() * i),
1560            );
1561        }
1562    }
1563
1564    const ROWS: &[usize] = &[0, 1, 2, 3, 5, 10];
1565    const COLS: &[usize] = &[0, 1, 2, 3, 5, 10];
1566
1567    #[test]
1568    fn test_mat() {
1569        for nrows in ROWS {
1570            for ncols in COLS {
1571                let repr = Standard::<usize>::new(*nrows, *ncols).unwrap();
1572                let ctx = &lazy_format!("nrows = {}, ncols = {}", nrows, ncols);
1573
1574                // Populate the matrix using `&mut Mat`
1575                {
1576                    let ctx = &lazy_format!("{ctx} - direct");
1577                    let mut mat = Mat::new(repr, Defaulted).unwrap();
1578
1579                    assert_eq!(mat.num_vectors(), *nrows);
1580                    assert_eq!(mat.vector_dim(), *ncols);
1581
1582                    fill_mat(&mut mat, repr);
1583
1584                    check_mat(&mat, repr, ctx);
1585                    check_mat_ref(mat.reborrow(), repr, ctx);
1586                    check_mat_mut(mat.reborrow_mut(), repr, ctx);
1587                    check_rows(mat.rows(), repr, ctx);
1588
1589                    // Check reborrow preserves pointers.
1590                    assert_eq!(mat.as_raw_ptr(), mat.reborrow().as_raw_ptr());
1591                    assert_eq!(mat.as_raw_ptr(), mat.reborrow_mut().as_raw_ptr());
1592                }
1593
1594                // Populate the matrix using `MatMut`
1595                {
1596                    let ctx = &lazy_format!("{ctx} - matmut");
1597                    let mut mat = Mat::new(repr, Defaulted).unwrap();
1598                    let matmut = mat.reborrow_mut();
1599
1600                    assert_eq!(matmut.num_vectors(), *nrows);
1601                    assert_eq!(matmut.vector_dim(), *ncols);
1602
1603                    fill_mat_mut(matmut, repr);
1604
1605                    check_mat(&mat, repr, ctx);
1606                    check_mat_ref(mat.reborrow(), repr, ctx);
1607                    check_mat_mut(mat.reborrow_mut(), repr, ctx);
1608                    check_rows(mat.rows(), repr, ctx);
1609                }
1610
1611                // Populate the matrix using `RowsMut`
1612                {
1613                    let ctx = &lazy_format!("{ctx} - rows_mut");
1614                    let mut mat = Mat::new(repr, Defaulted).unwrap();
1615                    fill_rows_mut(mat.rows_mut(), repr);
1616
1617                    check_mat(&mat, repr, ctx);
1618                    check_mat_ref(mat.reborrow(), repr, ctx);
1619                    check_mat_mut(mat.reborrow_mut(), repr, ctx);
1620                    check_rows(mat.rows(), repr, ctx);
1621                }
1622            }
1623        }
1624    }
1625
1626    #[test]
1627    fn test_mat_clone() {
1628        for nrows in ROWS {
1629            for ncols in COLS {
1630                let repr = Standard::<usize>::new(*nrows, *ncols).unwrap();
1631                let ctx = &lazy_format!("nrows = {}, ncols = {}", nrows, ncols);
1632
1633                let mut mat = Mat::new(repr, Defaulted).unwrap();
1634                fill_mat(&mut mat, repr);
1635
1636                // Clone via Mat::clone
1637                {
1638                    let ctx = &lazy_format!("{ctx} - Mat::clone");
1639                    let cloned = mat.clone();
1640
1641                    assert_eq!(cloned.num_vectors(), *nrows);
1642                    assert_eq!(cloned.vector_dim(), *ncols);
1643
1644                    check_mat(&cloned, repr, ctx);
1645                    check_mat_ref(cloned.reborrow(), repr, ctx);
1646                    check_rows(cloned.rows(), repr, ctx);
1647
1648                    // Cloned allocation is independent.
1649                    if repr.num_elements() > 0 {
1650                        assert_ne!(mat.as_raw_ptr(), cloned.as_raw_ptr());
1651                    }
1652                }
1653
1654                // Clone via MatRef::to_owned
1655                {
1656                    let ctx = &lazy_format!("{ctx} - MatRef::to_owned");
1657                    let owned = mat.as_view().to_owned();
1658
1659                    check_mat(&owned, repr, ctx);
1660                    check_mat_ref(owned.reborrow(), repr, ctx);
1661                    check_rows(owned.rows(), repr, ctx);
1662
1663                    if repr.num_elements() > 0 {
1664                        assert_ne!(mat.as_raw_ptr(), owned.as_raw_ptr());
1665                    }
1666                }
1667
1668                // Clone via MatMut::to_owned
1669                {
1670                    let ctx = &lazy_format!("{ctx} - MatMut::to_owned");
1671                    let owned = mat.as_view_mut().to_owned();
1672
1673                    check_mat(&owned, repr, ctx);
1674                    check_mat_ref(owned.reborrow(), repr, ctx);
1675                    check_rows(owned.rows(), repr, ctx);
1676
1677                    if repr.num_elements() > 0 {
1678                        assert_ne!(mat.as_raw_ptr(), owned.as_raw_ptr());
1679                    }
1680                }
1681            }
1682        }
1683    }
1684
1685    #[test]
1686    fn test_mat_refmut() {
1687        for nrows in ROWS {
1688            for ncols in COLS {
1689                let repr = Standard::<usize>::new(*nrows, *ncols).unwrap();
1690                let ctx = &lazy_format!("nrows = {}, ncols = {}", nrows, ncols);
1691
1692                // Populate the matrix using `&mut Mat`
1693                {
1694                    let ctx = &lazy_format!("{ctx} - by matmut");
1695                    let mut b: Box<[_]> = (0..repr.num_elements()).map(|_| 0usize).collect();
1696                    let ptr = b.as_ptr().cast::<u8>();
1697                    let mut matmut = MatMut::new(repr, &mut b).unwrap();
1698
1699                    assert_eq!(
1700                        ptr,
1701                        matmut.as_raw_ptr(),
1702                        "underlying memory should be preserved",
1703                    );
1704
1705                    fill_mat_mut(matmut.reborrow_mut(), repr);
1706
1707                    check_mat_mut(matmut.reborrow_mut(), repr, ctx);
1708                    check_mat_ref(matmut.reborrow(), repr, ctx);
1709                    check_rows(matmut.rows(), repr, ctx);
1710                    check_rows(matmut.reborrow().rows(), repr, ctx);
1711
1712                    let matref = MatRef::new(repr, &b).unwrap();
1713                    check_mat_ref(matref, repr, ctx);
1714                    check_mat_ref(matref.reborrow(), repr, ctx);
1715                    check_rows(matref.rows(), repr, ctx);
1716                }
1717
1718                // Populate the matrix using `RowsMut`
1719                {
1720                    let ctx = &lazy_format!("{ctx} - by rows");
1721                    let mut b: Box<[_]> = (0..repr.num_elements()).map(|_| 0usize).collect();
1722                    let ptr = b.as_ptr().cast::<u8>();
1723                    let mut matmut = MatMut::new(repr, &mut b).unwrap();
1724
1725                    assert_eq!(
1726                        ptr,
1727                        matmut.as_raw_ptr(),
1728                        "underlying memory should be preserved",
1729                    );
1730
1731                    fill_rows_mut(matmut.rows_mut(), repr);
1732
1733                    check_mat_mut(matmut.reborrow_mut(), repr, ctx);
1734                    check_mat_ref(matmut.reborrow(), repr, ctx);
1735                    check_rows(matmut.rows(), repr, ctx);
1736                    check_rows(matmut.reborrow().rows(), repr, ctx);
1737
1738                    let matref = MatRef::new(repr, &b).unwrap();
1739                    check_mat_ref(matref, repr, ctx);
1740                    check_mat_ref(matref.reborrow(), repr, ctx);
1741                    check_rows(matref.rows(), repr, ctx);
1742                }
1743            }
1744        }
1745    }
1746
1747    //////////////////
1748    // Constructors //
1749    //////////////////
1750
1751    #[test]
1752    fn test_standard_new_owned() {
1753        let rows = [0, 1, 2, 3, 5, 10];
1754        let cols = [0, 1, 2, 3, 5, 10];
1755
1756        for nrows in rows {
1757            for ncols in cols {
1758                let m = Mat::new(Standard::new(nrows, ncols).unwrap(), 1usize).unwrap();
1759                let rows_iter = m.rows();
1760                let len = <_ as ExactSizeIterator>::len(&rows_iter);
1761                assert_eq!(len, nrows);
1762                for r in rows_iter {
1763                    assert_eq!(r.len(), ncols);
1764                    assert!(r.iter().all(|i| *i == 1usize));
1765                }
1766            }
1767        }
1768    }
1769
1770    #[test]
1771    fn matref_new_slice_length_error() {
1772        let repr = Standard::<u32>::new(3, 4).unwrap();
1773
1774        // Correct length succeeds
1775        let data = vec![0u32; 12];
1776        assert!(MatRef::new(repr, &data).is_ok());
1777
1778        // Too short fails
1779        let short = vec![0u32; 11];
1780        assert!(matches!(
1781            MatRef::new(repr, &short),
1782            Err(SliceError::LengthMismatch {
1783                expected: 12,
1784                found: 11
1785            })
1786        ));
1787
1788        // Too long fails
1789        let long = vec![0u32; 13];
1790        assert!(matches!(
1791            MatRef::new(repr, &long),
1792            Err(SliceError::LengthMismatch {
1793                expected: 12,
1794                found: 13
1795            })
1796        ));
1797    }
1798
1799    #[test]
1800    fn matmut_new_slice_length_error() {
1801        let repr = Standard::<u32>::new(3, 4).unwrap();
1802
1803        // Correct length succeeds
1804        let mut data = vec![0u32; 12];
1805        assert!(MatMut::new(repr, &mut data).is_ok());
1806
1807        // Too short fails
1808        let mut short = vec![0u32; 11];
1809        assert!(matches!(
1810            MatMut::new(repr, &mut short),
1811            Err(SliceError::LengthMismatch {
1812                expected: 12,
1813                found: 11
1814            })
1815        ));
1816
1817        // Too long fails
1818        let mut long = vec![0u32; 13];
1819        assert!(matches!(
1820            MatMut::new(repr, &mut long),
1821            Err(SliceError::LengthMismatch {
1822                expected: 12,
1823                found: 13
1824            })
1825        ));
1826    }
1827
1828    #[test]
1829    fn as_matrix_view_roundtrip() {
1830        let data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
1831
1832        // MatRef
1833        let matref = MatRef::new(Standard::new(2, 3).unwrap(), &data).unwrap();
1834        let view = matref.as_matrix_view();
1835        assert_eq!(view.nrows(), 2);
1836        assert_eq!(view.ncols(), 3);
1837        for row in 0..2 {
1838            for col in 0..3 {
1839                assert_eq!(view[(row, col)], data[row * 3 + col]);
1840            }
1841        }
1842        assert_eq!(matref.as_slice(), &data);
1843
1844        // Mat
1845        let mut mat = Mat::new(Standard::<f32>::new(2, 3).unwrap(), 0.0f32).unwrap();
1846        for i in 0..2 {
1847            let r = mat.get_row_mut(i).unwrap();
1848            for j in 0..3 {
1849                r[j] = data[i * 3 + j];
1850            }
1851        }
1852        let view = mat.as_matrix_view();
1853        assert_eq!(view.nrows(), 2);
1854        assert_eq!(view.ncols(), 3);
1855        for row in 0..2 {
1856            for col in 0..3 {
1857                assert_eq!(view[(row, col)], data[row * 3 + col]);
1858            }
1859        }
1860        assert_eq!(mat.as_slice(), &data);
1861
1862        // MatMut
1863        let mut buf = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
1864        let matmut = MatMut::new(Standard::new(2, 3).unwrap(), &mut buf).unwrap();
1865        let view = matmut.as_matrix_view();
1866        assert_eq!(view.nrows(), 2);
1867        assert_eq!(view.ncols(), 3);
1868        for row in 0..2 {
1869            for col in 0..3 {
1870                assert_eq!(view[(row, col)], data[row * 3 + col]);
1871            }
1872        }
1873        assert_eq!(matmut.as_slice(), &data);
1874    }
1875}