Skip to main content

diskann_quantization/multi_vector/
matrix.rs

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT license.
3
4//! Row-major matrix types for multi-vector representations.
5//!
6//! This module provides flexible matrix abstractions that support different underlying
7//! storage formats through the [`Repr`] trait. The primary types are:
8//!
9//! - [`Mat`]: An owning matrix that manages its own memory.
10//! - [`MatRef`]: An immutable borrowed view of matrix data.
11//! - [`MatMut`]: A mutable borrowed view of matrix data.
12//!
13//! # Representations
14//!
15//! Representation types interact with the [`Mat`] family of types using the following traits:
16//!
17//! - [`Repr`]: Read-only matrix representation.
18//! - [`ReprMut`]: Mutable matrix representation.
19//! - [`ReprOwned`]: Owning matrix representation.
20//!
21//! Each trait refinement has a corresponding constructor:
22//!
23//! - [`NewRef`]: Construct a read-only [`MatRef`] view over a slice.
24//! - [`NewMut`]: Construct a mutable [`MatMut`] matrix view over a slice.
25//! - [`NewOwned`]: Construct a new owning [`Mat`].
26//!
27
28use std::{alloc::Layout, iter::FusedIterator, marker::PhantomData, ptr::NonNull};
29
30use diskann_utils::{Reborrow, ReborrowMut};
31use thiserror::Error;
32
33use crate::utils;
34
/// Representation trait describing the layout and access patterns for a matrix.
///
/// Implementations define how raw bytes are interpreted as typed rows. This enables
/// matrices over different storage formats (dense, quantized, etc.) using a single
/// generic [`Mat`] type.
///
/// A representation is a small `Copy` value: it is taken by value in
/// [`get_row`](Self::get_row) and copied into every view created from a matrix.
///
/// # Associated Types
///
/// - `Row<'a>`: The immutable row type (e.g., `&[f32]`, `&[f16]`).
///
/// # Safety
///
/// Implementations must ensure:
///
/// - [`get_row`](Self::get_row) returns valid references for the given row index.
///   This call **must** be memory safe for `i < self.nrows()`, provided the caller upholds
///   the contract for the raw pointer.
///
/// - The objects implicitly managed by this representation inherit the `Send` and `Sync`
///   attributes of `Repr`. That is, `Repr: Send` implies that the objects in backing memory
///   are [`Send`], and likewise with `Sync`. This is necessary to apply [`Send`] and [`Sync`]
///   bounds to [`Mat`], [`MatRef`], and [`MatMut`].
pub unsafe trait Repr: Copy {
    /// Immutable row reference type.
    type Row<'a>
    where
        Self: 'a;

    /// Returns the number of rows in the matrix.
    ///
    /// # Safety Contract
    ///
    /// This function must be loosely pure in the sense that for any given instance of
    /// `self`, `self.nrows()` must return the same value.
    fn nrows(&self) -> usize;

    /// Returns the memory layout for a memory allocation containing [`Repr::nrows`] rows
    /// stored in this representation's format.
    ///
    /// # Errors
    ///
    /// Returns a [`LayoutError`] when no valid layout exists for this instance. For
    /// example, the [`Standard`] implementation fails when the total element count
    /// overflows `usize`.
    ///
    /// # Safety Contract
    ///
    /// The [`Layout`] returned from this method must be consistent with the contract of
    /// [`Repr::get_row`].
    fn layout(&self) -> Result<Layout, LayoutError>;

    /// Returns an immutable reference to the `i`-th row.
    ///
    /// # Safety
    ///
    /// - `ptr` must point to a slice with a layout compatible with [`Repr::layout`].
    /// - The entire range for this slice must be within a single allocation.
    /// - `i` must be less than [`Repr::nrows`].
    /// - The memory referenced by the returned [`Repr::Row`] must not be mutated for the
    ///   duration of lifetime `'a`.
    /// - The lifetime for the returned [`Repr::Row`] is inferred from its usage. Correct
    ///   usage must properly tie the lifetime to a source.
    unsafe fn get_row<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::Row<'a>;
}
93
/// Extension of [`Repr`] that supports mutable row access.
///
/// # Associated Types
///
/// - `RowMut<'a>`: The mutable row type (e.g., `&mut [f32]`).
///
/// # Safety
///
/// Implementors must ensure:
///
/// - [`get_row_mut`](Self::get_row_mut) returns valid references for the given row index.
///   This call **must** be memory safe for `i < self.nrows()`, provided the caller upholds
///   the contract for the raw pointer.
///
///   Additionally, since the implementation of the [`RowsMut`] iterator can give out rows
///   for all `i` in `0..self.nrows()`, the implementation of [`Self::get_row_mut`] must be
///   such that the results for disjoint `i` do not interfere with one another.
pub unsafe trait ReprMut: Repr {
    /// Mutable row reference type.
    type RowMut<'a>
    where
        Self: 'a;

    /// Returns a mutable reference to the `i`-th row.
    ///
    /// # Safety
    ///
    /// - `ptr` must point to a slice with a layout compatible with [`Repr::layout`].
    /// - The entire range for this slice must be within a single allocation.
    /// - `i` must be less than `self.nrows()`.
    /// - The memory referenced by the returned [`ReprMut::RowMut`] must not be accessed
    ///   through any other reference for the duration of lifetime `'a`.
    /// - The lifetime for the returned [`ReprMut::RowMut`] is inferred from its usage.
    ///   Correct usage must properly tie the lifetime to a source.
    unsafe fn get_row_mut<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::RowMut<'a>;
}
129
/// Extension trait for [`Repr`] that supports deallocation of owned matrices. This is used
/// in conjunction with [`NewOwned`] to create matrices.
///
/// Requires [`ReprMut`] since owned matrices should support mutation.
///
/// Note that [`ReprOwned::drop`] is a method of this trait and is unrelated to
/// [`Drop::drop`]; it is invoked by the [`Drop`] implementation of [`Mat`].
///
/// # Safety
///
/// Implementors must ensure that `drop` properly deallocates the memory in a way compatible
/// with all [`NewOwned`] implementations.
pub unsafe trait ReprOwned: ReprMut {
    /// Deallocates memory at `ptr` and drops `self`.
    ///
    /// # Safety
    ///
    /// - `ptr` must have been obtained via [`NewOwned`] with the same value of `self`.
    /// - This method may only be called once for such a pointer.
    /// - After calling this method, the memory behind `ptr` may not be dereferenced at all.
    unsafe fn drop(self, ptr: NonNull<u8>);
}
149
/// An opaque, new-type counterpart of `std::alloc::LayoutError`.
///
/// It carries no more information than [`std::alloc::LayoutError`], but it can be
/// constructed in user code. This lets implementors of [`Repr::layout`] report failures
/// that do not originate from one of `std::alloc::Layout`'s own methods.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub struct LayoutError;

impl LayoutError {
    /// Creates a new opaque [`LayoutError`].
    pub fn new() -> Self {
        LayoutError
    }
}

impl Default for LayoutError {
    /// Equivalent to [`LayoutError::new`].
    fn default() -> Self {
        LayoutError
    }
}

impl std::fmt::Display for LayoutError {
    /// Writes the fixed text `LayoutError`; formatter flags such as width are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("LayoutError")
    }
}

impl std::error::Error for LayoutError {}

impl From<std::alloc::LayoutError> for LayoutError {
    /// Discards the details of the standard-library error and yields the opaque error.
    fn from(_source: std::alloc::LayoutError) -> Self {
        Self::new()
    }
}
185
186//////////////////
187// Constructors //
188//////////////////
189
/// Create a new [`MatRef`] over a slice.
///
/// The type parameter `T` is the element type of the slice being viewed.
///
/// # Safety
///
/// Implementations must validate the length (and any other requirements) of the provided
/// slice to ensure it is compatible with the implementation of [`Repr`].
pub unsafe trait NewRef<T>: Repr {
    /// Errors that can occur when initializing.
    type Error;

    /// Create a new [`MatRef`] over `slice`.
    ///
    /// # Errors
    ///
    /// Returns [`Self::Error`] when `slice` does not pass the implementation's validation
    /// (for example, [`Standard`] rejects slices of the wrong length).
    fn new_ref(self, slice: &[T]) -> Result<MatRef<'_, Self>, Self::Error>;
}
203
/// Create a new [`MatMut`] over a slice.
///
/// The type parameter `T` is the element type of the slice being viewed.
///
/// # Safety
///
/// Implementations must validate the length (and any other requirements) of the provided
/// slice to ensure it is compatible with the implementation of [`ReprMut`].
pub unsafe trait NewMut<T>: ReprMut {
    /// Errors that can occur when initializing.
    type Error;

    /// Create a new [`MatMut`] over `slice`.
    ///
    /// # Errors
    ///
    /// Returns [`Self::Error`] when `slice` does not pass the implementation's validation
    /// (for example, [`Standard`] rejects slices of the wrong length).
    fn new_mut(self, slice: &mut [T]) -> Result<MatMut<'_, Self>, Self::Error>;
}
217
/// Create a new [`Mat`] from an initializer.
///
/// The type parameter `T` is the initializer type (for example a fill value, or the
/// [`Defaulted`] marker).
///
/// # Safety
///
/// Implementations must ensure that the returned [`Mat`] is compatible with
/// `Self`'s implementation of [`ReprOwned`].
pub unsafe trait NewOwned<T>: ReprOwned {
    /// Errors that can occur when initializing.
    type Error;

    /// Create a new [`Mat`] initialized with `init`.
    ///
    /// The memory owned by the returned matrix is later released through
    /// [`ReprOwned::drop`], so it must be obtained in a way that method can undo.
    fn new_owned(self, init: T) -> Result<Mat<Self>, Self::Error>;
}
231
/// An initializer argument to [`NewOwned`] that uses a type's [`Default`] implementation
/// to initialize a matrix.
///
/// This is a zero-sized marker: it carries no data and only selects the
/// `NewOwned<Defaulted>` implementation of a representation.
///
/// ```rust
/// use diskann_quantization::multi_vector::{Mat, Standard, Defaulted};
/// let mat = Mat::new(Standard::<f32>::new(4, 3), Defaulted).unwrap();
/// for i in 0..4 {
///     assert!(mat.get_row(i).unwrap().iter().all(|&x| x == 0.0f32));
/// }
/// ```
#[derive(Debug, Clone, Copy)]
pub struct Defaulted;
244
245//////////////
246// Standard //
247//////////////
248
/// Metadata for dense row-major matrices of `Copy` types.
///
/// Rows are stored contiguously as `&[T]` slices. This is the default representation
/// type for standard floating-point multi-vectors.
///
/// The value itself only records the matrix shape; the element data lives behind the
/// pointer held by [`Mat`], [`MatRef`], or [`MatMut`].
///
/// # Row Types
///
/// - `Row<'a>`: `&'a [T]`
/// - `RowMut<'a>`: `&'a mut [T]`
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Standard<T> {
    /// Number of rows (vectors) in the matrix.
    nrows: usize,
    /// Number of elements in each row.
    ncols: usize,
    /// Records the element type without storing a value of it.
    _elem: PhantomData<T>,
}
264
265impl<T: Copy> Standard<T> {
266    /// Create a new `Standard` for data of type `T`.
267    pub fn new(nrows: usize, ncols: usize) -> Self {
268        Self {
269            nrows,
270            ncols,
271            _elem: PhantomData,
272        }
273    }
274
275    /// Returns the number of total elements (`rows x cols`) in this matrix, returning `None`
276    /// if this computation overflows.
277    pub fn num_elements(&self) -> Option<usize> {
278        self.nrows.checked_mul(self.ncols())
279    }
280
281    /// Returns `ncols`, the number of elements in a row of this matrix.
282    fn ncols(&self) -> usize {
283        self.ncols
284    }
285
286    /// Checks the following:
287    ///
288    /// 1. Computation of the number of elements in `self` does not overflow.
289    /// 2. Argument `slice` has the expected number of elements.
290    fn check_slice(&self, slice: &[T]) -> Result<(), SliceError> {
291        let len = self.num_elements().ok_or(SliceError::Overflow)?;
292
293        if slice.len() != len {
294            Err(SliceError::LengthMismatch {
295                expected: len,
296                found: slice.len(),
297            })
298        } else {
299            Ok(())
300        }
301    }
302}
303
/// Error types for [`Standard`].
///
/// Returned when a slice cannot be viewed as a matrix of the requested shape.
#[derive(Debug, Clone, Copy, Error)]
#[non_exhaustive]
pub enum SliceError {
    /// The slice does not contain exactly `nrows * ncols` elements.
    #[error("Length mismatch: expected {expected}, found {found}")]
    LengthMismatch { expected: usize, found: usize },
    /// The product `nrows * ncols` does not fit in a `usize`.
    #[error("Computing slice length overflowed.")]
    Overflow,
}
313
314// SAFETY: The implementation correctly computes row offsets as `i * ncols` and
315// constructs valid slices of the appropriate length. The `layout` method correctly
316// reports the memory layout requirements.
317unsafe impl<T: Copy> Repr for Standard<T> {
318    type Row<'a>
319        = &'a [T]
320    where
321        T: 'a;
322
323    fn nrows(&self) -> usize {
324        self.nrows
325    }
326
327    fn layout(&self) -> Result<Layout, LayoutError> {
328        let elements = self.num_elements().ok_or(LayoutError::new())?;
329        Ok(Layout::array::<T>(elements)?)
330    }
331
332    unsafe fn get_row<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::Row<'a> {
333        debug_assert!(ptr.cast::<T>().is_aligned());
334        debug_assert!(i < self.nrows);
335
336        let row_ptr = ptr.as_ptr().cast::<T>().add(i * self.ncols);
337        std::slice::from_raw_parts(row_ptr, self.ncols)
338    }
339}
340
341// SAFETY: The implementation correctly computes row offsets and constructs valid mutable
342// slices.
343unsafe impl<T: Copy> ReprMut for Standard<T> {
344    type RowMut<'a>
345        = &'a mut [T]
346    where
347        T: 'a;
348
349    unsafe fn get_row_mut<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::RowMut<'a> {
350        debug_assert!(ptr.cast::<T>().is_aligned());
351        debug_assert!(i < self.nrows);
352
353        let row_ptr = ptr.as_ptr().cast::<T>().add(i * self.ncols);
354        std::slice::from_raw_parts_mut(row_ptr, self.ncols)
355    }
356}
357
358// SAFETY: The drop implementation correctly reconstructs a Box from the raw pointer
359// using the same length (nrows * ncols) that was used for allocation, allowing Box
360// to properly deallocate the memory.
361unsafe impl<T: Copy> ReprOwned for Standard<T> {
362    unsafe fn drop(self, ptr: NonNull<u8>) {
363        // SAFETY: The caller guarantees that `ptr` was obtained from an implementation of
364        // `NewOwned` for an equivalent instance of `self`.
365        //
366        // We ensure that `NewOwned` goes through boxes, so here we reconstruct a Box to
367        // let it handle deallocation.
368        unsafe {
369            let slice_ptr = std::ptr::slice_from_raw_parts_mut(
370                ptr.cast::<T>().as_ptr(),
371                self.nrows * self.ncols,
372            );
373            let _ = Box::from_raw(slice_ptr);
374        }
375    }
376}
377
378// SAFETY: The implementation uses guarantees from `Box` to ensure that the pointer
379// initialized by it is non-null and properly aligned to the underlying type.
380unsafe impl<T> NewOwned<T> for Standard<T>
381where
382    T: Copy,
383{
384    type Error = crate::error::Infallible;
385    fn new_owned(self, value: T) -> Result<Mat<Self>, Self::Error> {
386        let b: Box<[T]> = (0..self.nrows() * self.ncols()).map(|_| value).collect();
387        // SAFETY: Box [guarantees](https://doc.rust-lang.org/std/boxed/struct.Box.html#method.into_raw)
388        // the returned pointer is non-null.
389        let ptr = unsafe { NonNull::new_unchecked(Box::into_raw(b)) }.cast::<u8>();
390
391        // SAFETY: `ptr` is properly aligned and points to a slice of the required length.
392        // Additionally, it is dropped via `Box::from_raw`, which is compatible with obtaining
393        // it from `Box::into_raw`.
394        Ok(unsafe { Mat::from_raw_parts(self, ptr) })
395    }
396}
397
398// SAFETY: This safely reuses `<Self as NewOwned<T>>`.
399unsafe impl<T> NewOwned<Defaulted> for Standard<T>
400where
401    T: Copy + Default,
402{
403    type Error = crate::error::Infallible;
404    fn new_owned(self, _: Defaulted) -> Result<Mat<Self>, Self::Error> {
405        self.new_owned(T::default())
406    }
407}
408
409// SAFETY: This checks that the slice has the correct length, which is all that is
410// required for [`Repr`].
411unsafe impl<T> NewRef<T> for Standard<T>
412where
413    T: Copy,
414{
415    type Error = SliceError;
416    fn new_ref(self, data: &[T]) -> Result<MatRef<'_, Self>, Self::Error> {
417        self.check_slice(data)?;
418
419        // SAFETY: The function `check_slice` verifies that `data` is compatible with
420        // the layout requirement of `Standard`.
421        //
422        // We've properly checked that the underlying pointer is okay.
423        Ok(unsafe { MatRef::from_raw_parts(self, utils::as_nonnull(data).cast::<u8>()) })
424    }
425}
426
427// SAFETY: This checks that the slice has the correct length, which is all that is
428// required for [`ReprMut`].
429unsafe impl<T> NewMut<T> for Standard<T>
430where
431    T: Copy,
432{
433    type Error = SliceError;
434    fn new_mut(self, data: &mut [T]) -> Result<MatMut<'_, Self>, Self::Error> {
435        self.check_slice(data)?;
436
437        // SAFETY: The function `check_slice` verifies that `data` is compatible with
438        // the layout requirement of `Standard`.
439        //
440        // We've properly checked that the underlying pointer is okay.
441        Ok(unsafe { MatMut::from_raw_parts(self, utils::as_nonnull_mut(data).cast::<u8>()) })
442    }
443}
444
445/////////
446// Mat //
447/////////
448
/// An owning matrix that manages its own memory.
///
/// The matrix stores raw bytes interpreted according to representation type `T`.
/// Memory is automatically deallocated when the matrix is dropped.
///
/// Instances are created through [`Mat::new`], which delegates to the [`NewOwned`]
/// implementation of `T`.
#[derive(Debug)]
pub struct Mat<T: ReprOwned> {
    /// Base pointer of the owned allocation.
    ptr: NonNull<u8>,
    /// Representation describing how the bytes behind `ptr` are laid out.
    repr: T,
}

// SAFETY: The safety contract of [`Repr`] requires the objects in backing memory to be
// `Send` whenever `T` is `Send`.
unsafe impl<T> Send for Mat<T> where T: ReprOwned + Send {}

// SAFETY: The safety contract of [`Repr`] requires the objects in backing memory to be
// `Sync` whenever `T` is `Sync`.
unsafe impl<T> Sync for Mat<T> where T: ReprOwned + Sync {}
464
465impl<T: ReprOwned> Mat<T> {
466    /// Create a new matrix using `init` as the initializer.
467    pub fn new<U>(repr: T, init: U) -> Result<Self, <T as NewOwned<U>>::Error>
468    where
469        T: NewOwned<U>,
470    {
471        repr.new_owned(init)
472    }
473
474    /// Returns the number of rows (vectors) in the matrix.
475    #[inline]
476    pub fn num_vectors(&self) -> usize {
477        self.repr.nrows()
478    }
479
480    /// Returns a reference to the underlying representation.
481    pub fn repr(&self) -> &T {
482        &self.repr
483    }
484
485    /// Returns the `i`th row if `i < self.num_vectors()`.
486    #[must_use]
487    pub fn get_row(&self, i: usize) -> Option<T::Row<'_>> {
488        if i < self.num_vectors() {
489            // SAFETY: Bounds check passed, and the Mat was constructed
490            // with valid representation and pointer.
491            let row = unsafe { self.get_row_unchecked(i) };
492            Some(row)
493        } else {
494            None
495        }
496    }
497
498    pub(crate) unsafe fn get_row_unchecked(&self, i: usize) -> T::Row<'_> {
499        // SAFETY: Caller must ensure i < self.num_vectors(). The constructors for this type
500        // ensure that `ptr` is compatible with `T`.
501        unsafe { self.repr.get_row(self.ptr, i) }
502    }
503
504    /// Returns the `i`th mutable row if `i < self.num_vectors()`.
505    #[must_use]
506    pub fn get_row_mut(&mut self, i: usize) -> Option<T::RowMut<'_>> {
507        if i < self.num_vectors() {
508            // SAFETY: Bounds check passed, and we have exclusive access via &mut self.
509            Some(unsafe { self.get_row_mut_unchecked(i) })
510        } else {
511            None
512        }
513    }
514
515    pub(crate) unsafe fn get_row_mut_unchecked(&mut self, i: usize) -> T::RowMut<'_> {
516        // SAFETY: Caller asserts that `i < self.num_vectors()`. The constructors for this
517        // type ensure that `ptr` is compatible with `T`.
518        unsafe { self.repr.get_row_mut(self.ptr, i) }
519    }
520
521    /// Returns an immutable view of the matrix.
522    #[inline]
523    pub fn as_view(&self) -> MatRef<'_, T> {
524        MatRef {
525            ptr: self.ptr,
526            repr: self.repr,
527            _lifetime: PhantomData,
528        }
529    }
530
531    /// Returns a mutable view of the matrix.
532    #[inline]
533    pub fn as_view_mut(&mut self) -> MatMut<'_, T> {
534        MatMut {
535            ptr: self.ptr,
536            repr: self.repr,
537            _lifetime: PhantomData,
538        }
539    }
540
541    /// Returns an iterator over immutable row references.
542    pub fn rows(&self) -> Rows<'_, T> {
543        Rows::new(self.reborrow())
544    }
545
546    /// Returns an iterator over mutable row references.
547    pub fn rows_mut(&mut self) -> RowsMut<'_, T> {
548        RowsMut::new(self.reborrow_mut())
549    }
550
551    /// Construct a new [`Mat`] over the raw pointer and representation without performing
552    /// any validity checks.
553    ///
554    /// # Safety
555    ///
556    /// Argument `ptr` must be:
557    ///
558    /// 1. Point to memory compatible with [`Repr::layout`].
559    /// 2. Be compatible with the drop logic in [`ReprOwned`].
560    pub(crate) unsafe fn from_raw_parts(repr: T, ptr: NonNull<u8>) -> Self {
561        Self { ptr, repr }
562    }
563
564    #[cfg(test)]
565    fn as_ptr(&self) -> NonNull<u8> {
566        self.ptr
567    }
568}
569
impl<T: ReprOwned> Drop for Mat<T> {
    /// Releases the owned allocation through [`ReprOwned::drop`].
    fn drop(&mut self) {
        // SAFETY: `ptr` was correctly initialized according to `layout`
        // and we are guaranteed exclusive access to the data due to Rust borrow rules.
        //
        // `self.repr` is the same representation value the matrix was constructed with,
        // as required by the contract of `ReprOwned::drop`.
        unsafe { self.repr.drop(self.ptr) };
    }
}
577
578impl<T: Copy> Mat<Standard<T>> {
579    /// Returns the raw dimension (columns) of the vectors in the matrix.
580    #[inline]
581    pub fn vector_dim(&self) -> usize {
582        self.repr.ncols()
583    }
584}
585
586////////////
587// MatRef //
588////////////
589
/// An immutable borrowed view of a matrix.
///
/// Provides read-only access to matrix data without ownership. Implements [`Copy`]
/// and can be freely cloned.
///
/// # Type Parameter
/// - `T`: A [`Repr`] implementation defining the row layout.
///
/// # Access
/// - [`get_row`](Self::get_row): Get an immutable row by index.
/// - [`rows`](Self::rows): Iterate over all rows.
#[derive(Debug, Clone, Copy)]
pub struct MatRef<'a, T: Repr> {
    /// Base pointer of the viewed data.
    pub(crate) ptr: NonNull<u8>,
    /// Representation describing how the bytes behind `ptr` are laid out.
    pub(crate) repr: T,
    /// Marker to tie the lifetime to the borrowed data.
    pub(crate) _lifetime: PhantomData<&'a [u8]>,
}

// SAFETY: The safety contract of [`Repr`] requires the objects in backing memory to be
// `Send` whenever `T` is `Send`.
unsafe impl<T> Send for MatRef<'_, T> where T: Repr + Send {}

// SAFETY: The safety contract of [`Repr`] requires the objects in backing memory to be
// `Sync` whenever `T` is `Sync`.
unsafe impl<T> Sync for MatRef<'_, T> where T: Repr + Sync {}
614
615impl<'a, T: Repr> MatRef<'a, T> {
616    /// Construct a new [`MatRef`] over `data`.
617    pub fn new<U>(repr: T, data: &'a [U]) -> Result<Self, T::Error>
618    where
619        T: NewRef<U>,
620    {
621        repr.new_ref(data)
622    }
623
624    /// Returns the number of rows (vectors) in the matrix.
625    #[inline]
626    pub fn num_vectors(&self) -> usize {
627        self.repr.nrows()
628    }
629
630    /// Returns a reference to the underlying representation.
631    pub fn repr(&self) -> &T {
632        &self.repr
633    }
634
635    /// Returns an immutable reference to the i-th row, or `None` if out of bounds.
636    #[must_use]
637    pub fn get_row(&self, i: usize) -> Option<T::Row<'_>> {
638        if i < self.num_vectors() {
639            // SAFETY: Bounds check passed, and the MatRef was constructed
640            // with valid representation and pointer.
641            let row = unsafe { self.get_row_unchecked(i) };
642            Some(row)
643        } else {
644            None
645        }
646    }
647
648    /// Returns the i-th row without bounds checking.
649    ///
650    /// # Safety
651    ///
652    /// `i` must be less than `self.num_vectors()`.
653    #[inline]
654    pub(crate) unsafe fn get_row_unchecked(&self, i: usize) -> T::Row<'_> {
655        // SAFETY: Caller must ensure i < self.num_vectors().
656        unsafe { self.repr.get_row(self.ptr, i) }
657    }
658
659    /// Returns an iterator over immutable row references.
660    pub fn rows(&self) -> Rows<'_, T> {
661        Rows::new(*self)
662    }
663
664    /// Construct a new [`MatRef`] over the raw pointer and representation without performing
665    /// any validity checks.
666    ///
667    /// # Safety
668    ///
669    /// Argument `ptr` must point to memory compatible with [`Repr::layout`] and pass any
670    /// validity checks required by `T`.
671    pub unsafe fn from_raw_parts(repr: T, ptr: NonNull<u8>) -> Self {
672        Self {
673            ptr,
674            repr,
675            _lifetime: PhantomData,
676        }
677    }
678}
679
680impl<'a, T: Copy> MatRef<'a, Standard<T>> {
681    /// Returns the raw dimension (columns) of the vectors in the matrix.
682    #[inline]
683    pub fn vector_dim(&self) -> usize {
684        self.repr.ncols()
685    }
686}
687
688// Reborrow: Mat -> MatRef
689impl<'this, T: ReprOwned> Reborrow<'this> for Mat<T> {
690    type Target = MatRef<'this, T>;
691
692    fn reborrow(&'this self) -> Self::Target {
693        self.as_view()
694    }
695}
696
697// ReborrowMut: Mat -> MatMut
698impl<'this, T: ReprOwned> ReborrowMut<'this> for Mat<T> {
699    type Target = MatMut<'this, T>;
700
701    fn reborrow_mut(&'this mut self) -> Self::Target {
702        self.as_view_mut()
703    }
704}
705
706// Reborrow: MatRef -> MatRef (with shorter lifetime)
707impl<'this, 'a, T: Repr> Reborrow<'this> for MatRef<'a, T> {
708    type Target = MatRef<'this, T>;
709
710    fn reborrow(&'this self) -> Self::Target {
711        MatRef {
712            ptr: self.ptr,
713            repr: self.repr,
714            _lifetime: PhantomData,
715        }
716    }
717}
718
719////////////
720// MatMut //
721////////////
722
/// A mutable borrowed view of a matrix.
///
/// Provides read-write access to matrix data without ownership.
///
/// # Type Parameter
/// - `T`: A [`ReprMut`] implementation defining the row layout.
///
/// # Access
/// - [`get_row`](Self::get_row): Get an immutable row by index.
/// - [`get_row_mut`](Self::get_row_mut): Get a mutable row by index.
/// - [`as_view`](Self::as_view): Reborrow as immutable [`MatRef`].
/// - [`rows`](Self::rows), [`rows_mut`](Self::rows_mut): Iterate over rows.
#[derive(Debug)]
pub struct MatMut<'a, T: ReprMut> {
    /// Base pointer of the viewed data.
    pub(crate) ptr: NonNull<u8>,
    /// Representation describing how the bytes behind `ptr` are laid out.
    pub(crate) repr: T,
    /// Marker to tie the lifetime to the mutably borrowed data.
    pub(crate) _lifetime: PhantomData<&'a mut [u8]>,
}

// SAFETY: The safety contract of [`Repr`] (a supertrait of [`ReprMut`]) requires the
// objects in backing memory to be `Send` whenever `T` is `Send`.
unsafe impl<T> Send for MatMut<'_, T> where T: ReprMut + Send {}

// SAFETY: The safety contract of [`Repr`] (a supertrait of [`ReprMut`]) requires the
// objects in backing memory to be `Sync` whenever `T` is `Sync`.
unsafe impl<T> Sync for MatMut<'_, T> where T: ReprMut + Sync {}
748
749impl<'a, T: ReprMut> MatMut<'a, T> {
750    /// Construct a new [`MatMut`] over `data`.
751    pub fn new<U>(repr: T, data: &'a mut [U]) -> Result<Self, T::Error>
752    where
753        T: NewMut<U>,
754    {
755        repr.new_mut(data)
756    }
757
758    /// Returns the number of rows (vectors) in the matrix.
759    #[inline]
760    pub fn num_vectors(&self) -> usize {
761        self.repr.nrows()
762    }
763
764    /// Returns a reference to the underlying representation.
765    pub fn repr(&self) -> &T {
766        &self.repr
767    }
768
769    /// Returns an immutable reference to the i-th row, or `None` if out of bounds.
770    #[inline]
771    #[must_use]
772    pub fn get_row(&self, i: usize) -> Option<T::Row<'_>> {
773        if i < self.num_vectors() {
774            // SAFETY: Bounds check passed.
775            Some(unsafe { self.get_row_unchecked(i) })
776        } else {
777            None
778        }
779    }
780
781    /// Returns the i-th row without bounds checking.
782    ///
783    /// # Safety
784    ///
785    /// `i` must be less than `self.num_vectors()`.
786    #[inline]
787    pub(crate) unsafe fn get_row_unchecked(&self, i: usize) -> T::Row<'_> {
788        // SAFETY: Caller must ensure i < self.num_vectors().
789        unsafe { self.repr.get_row(self.ptr, i) }
790    }
791
792    /// Returns a mutable reference to the `i`-th row, or `None` if out of bounds.
793    #[inline]
794    #[must_use]
795    pub fn get_row_mut(&mut self, i: usize) -> Option<T::RowMut<'_>> {
796        if i < self.num_vectors() {
797            // SAFETY: Bounds check passed.
798            Some(unsafe { self.get_row_mut_unchecked(i) })
799        } else {
800            None
801        }
802    }
803
804    /// Returns a mutable reference to the i-th row without bounds checking.
805    ///
806    /// # Safety
807    ///
808    /// `i` must be less than [`num_vectors()`](Self::num_vectors).
809    #[inline]
810    pub(crate) unsafe fn get_row_mut_unchecked(&mut self, i: usize) -> T::RowMut<'_> {
811        // SAFETY: Caller asserts that `i < self.num_vectors()`. The constructors for this
812        // type ensure that `ptr` is compatible with `T`.
813        unsafe { self.repr.get_row_mut(self.ptr, i) }
814    }
815
816    /// Reborrows as an immutable [`MatRef`].
817    pub fn as_view(&self) -> MatRef<'_, T> {
818        MatRef {
819            ptr: self.ptr,
820            repr: self.repr,
821            _lifetime: PhantomData,
822        }
823    }
824
825    /// Returns an iterator over immutable row references.
826    pub fn rows(&self) -> Rows<'_, T> {
827        Rows::new(self.reborrow())
828    }
829
830    /// Returns an iterator over mutable row references.
831    pub fn rows_mut(&mut self) -> RowsMut<'_, T> {
832        RowsMut::new(self.reborrow_mut())
833    }
834
835    /// Construct a new [`MatMut`] over the raw pointer and representation without performing
836    /// any validity checks.
837    ///
838    /// # Safety
839    ///
840    /// Argument `ptr` must point to memory compatible with [`Repr::layout`].
841    pub unsafe fn from_raw_parts(repr: T, ptr: NonNull<u8>) -> Self {
842        Self {
843            ptr,
844            repr,
845            _lifetime: PhantomData,
846        }
847    }
848}
849
850// Reborrow: MatMut -> MatRef
851impl<'this, 'a, T: ReprMut> Reborrow<'this> for MatMut<'a, T> {
852    type Target = MatRef<'this, T>;
853
854    fn reborrow(&'this self) -> Self::Target {
855        self.as_view()
856    }
857}
858
859// ReborrowMut: MatMut -> MatMut (with shorter lifetime)
860impl<'this, 'a, T: ReprMut> ReborrowMut<'this> for MatMut<'a, T> {
861    type Target = MatMut<'this, T>;
862
863    fn reborrow_mut(&'this mut self) -> Self::Target {
864        MatMut {
865            ptr: self.ptr,
866            repr: self.repr,
867            _lifetime: PhantomData,
868        }
869    }
870}
871
872impl<'a, T: Copy> MatMut<'a, Standard<T>> {
873    /// Returns the raw dimension (columns) of the vectors in the matrix.
874    #[inline]
875    pub fn vector_dim(&self) -> usize {
876        self.repr.ncols()
877    }
878}
879
880//////////
881// Rows //
882//////////
883
884/// Iterator over immutable row references of a matrix.
885///
886/// Created by [`Mat::rows`], [`MatRef::rows`], or [`MatMut::rows`].
887#[derive(Debug)]
888pub struct Rows<'a, T: Repr> {
889    matrix: MatRef<'a, T>,
890    current: usize,
891}
892
893impl<'a, T> Rows<'a, T>
894where
895    T: Repr,
896{
897    fn new(matrix: MatRef<'a, T>) -> Self {
898        Self { matrix, current: 0 }
899    }
900}
901
902impl<'a, T> Iterator for Rows<'a, T>
903where
904    T: Repr + 'a,
905{
906    type Item = T::Row<'a>;
907
908    fn next(&mut self) -> Option<Self::Item> {
909        let current = self.current;
910        if current >= self.matrix.num_vectors() {
911            None
912        } else {
913            self.current += 1;
914            // SAFETY: We make sure through the above check that
915            // the access is within bounds.
916            //
917            // Extending the lifetime to `'a` is safe because the underlying
918            // MatRef has lifetime `'a`.
919            Some(unsafe { self.matrix.repr.get_row(self.matrix.ptr, current) })
920        }
921    }
922
923    fn size_hint(&self) -> (usize, Option<usize>) {
924        let remaining = self.matrix.num_vectors() - self.current;
925        (remaining, Some(remaining))
926    }
927}
928
929impl<'a, T> ExactSizeIterator for Rows<'a, T> where T: Repr + 'a {}
930impl<'a, T> FusedIterator for Rows<'a, T> where T: Repr + 'a {}
931
932/////////////
933// RowsMut //
934/////////////
935
936/// Iterator over mutable row references of a matrix.
937///
938/// Created by [`Mat::rows_mut`] or [`MatMut::rows_mut`].
939#[derive(Debug)]
940pub struct RowsMut<'a, T: ReprMut> {
941    matrix: MatMut<'a, T>,
942    current: usize,
943}
944
945impl<'a, T> RowsMut<'a, T>
946where
947    T: ReprMut,
948{
949    fn new(matrix: MatMut<'a, T>) -> Self {
950        Self { matrix, current: 0 }
951    }
952}
953
954impl<'a, T> Iterator for RowsMut<'a, T>
955where
956    T: ReprMut + 'a,
957{
958    type Item = T::RowMut<'a>;
959
960    fn next(&mut self) -> Option<Self::Item> {
961        let current = self.current;
962        if current >= self.matrix.num_vectors() {
963            None
964        } else {
965            self.current += 1;
966            // SAFETY: We make sure through the above check that
967            // the access is within bounds.
968            //
969            // Extending the lifetime to `'a` is safe because:
970            // 1. The underlying MatMut has lifetime `'a`.
971            // 2. The iterator ensures that the mutable row indices are disjoint, so
972            //    there is no aliasing as long as the implementation of `ReprMut` ensures
973            //    there is not mutable sharing of the `RowMut` types.
974            Some(unsafe { self.matrix.repr.get_row_mut(self.matrix.ptr, current) })
975        }
976    }
977
978    fn size_hint(&self) -> (usize, Option<usize>) {
979        let remaining = self.matrix.num_vectors() - self.current;
980        (remaining, Some(remaining))
981    }
982}
983
984impl<'a, T> ExactSizeIterator for RowsMut<'a, T> where T: ReprMut + 'a {}
985impl<'a, T> FusedIterator for RowsMut<'a, T> where T: ReprMut + 'a {}
986
987///////////
988// Tests //
989///////////
990
991#[cfg(test)]
992mod tests {
993    use super::*;
994
995    use std::fmt::Display;
996
997    use diskann_utils::lazy_format;
998
999    /// Helper to assert a type is Copy.
1000    fn assert_copy<T: Copy>(_: &T) {}
1001
1002    fn edge_cases(nrows: usize) -> Vec<usize> {
1003        let max = usize::MAX;
1004
1005        vec![
1006            nrows,
1007            nrows + 1,
1008            nrows + 11,
1009            nrows + 20,
1010            max / 2,
1011            max.div_ceil(2),
1012            max - 1,
1013            max,
1014        ]
1015    }
1016
1017    fn fill_mat(x: &mut Mat<Standard<usize>>, repr: Standard<usize>) {
1018        assert_eq!(x.repr(), &repr);
1019        assert_eq!(x.num_vectors(), repr.nrows());
1020        assert_eq!(x.vector_dim(), repr.ncols());
1021
1022        for i in 0..x.num_vectors() {
1023            let row = x.get_row_mut(i).unwrap();
1024            assert_eq!(row.len(), repr.ncols());
1025            row.iter_mut()
1026                .enumerate()
1027                .for_each(|(j, r)| *r = 10 * i + j);
1028        }
1029
1030        for i in edge_cases(repr.nrows()).into_iter() {
1031            assert!(x.get_row_mut(i).is_none());
1032        }
1033    }
1034
1035    fn fill_mat_mut(mut x: MatMut<'_, Standard<usize>>, repr: Standard<usize>) {
1036        assert_eq!(x.repr(), &repr);
1037        assert_eq!(x.num_vectors(), repr.nrows());
1038        assert_eq!(x.vector_dim(), repr.ncols());
1039
1040        for i in 0..x.num_vectors() {
1041            let row = x.get_row_mut(i).unwrap();
1042            assert_eq!(row.len(), repr.ncols());
1043
1044            row.iter_mut()
1045                .enumerate()
1046                .for_each(|(j, r)| *r = 10 * i + j);
1047        }
1048
1049        for i in edge_cases(repr.nrows()).into_iter() {
1050            assert!(x.get_row_mut(i).is_none());
1051        }
1052    }
1053
1054    fn fill_rows_mut(x: RowsMut<'_, Standard<usize>>, repr: Standard<usize>) {
1055        assert_eq!(x.len(), repr.nrows());
1056        // Materialize all rows at once.
1057        let mut all_rows: Vec<_> = x.collect();
1058        assert_eq!(all_rows.len(), repr.nrows());
1059        for (i, row) in all_rows.iter_mut().enumerate() {
1060            assert_eq!(row.len(), repr.ncols());
1061            row.iter_mut()
1062                .enumerate()
1063                .for_each(|(j, r)| *r = 10 * i + j);
1064        }
1065    }
1066
1067    fn check_mat(x: &Mat<Standard<usize>>, repr: Standard<usize>, ctx: &dyn Display) {
1068        assert_eq!(x.repr(), &repr);
1069        assert_eq!(x.num_vectors(), repr.nrows());
1070        assert_eq!(x.vector_dim(), repr.ncols());
1071
1072        for i in 0..x.num_vectors() {
1073            let row = x.get_row(i).unwrap();
1074
1075            assert_eq!(row.len(), repr.ncols(), "ctx: {ctx}");
1076            row.iter().enumerate().for_each(|(j, r)| {
1077                assert_eq!(
1078                    *r,
1079                    10 * i + j,
1080                    "mismatched entry at row {}, col {} -- ctx: {}",
1081                    i,
1082                    j,
1083                    ctx
1084                )
1085            });
1086        }
1087
1088        for i in edge_cases(repr.nrows()).into_iter() {
1089            assert!(x.get_row(i).is_none(), "ctx: {ctx}");
1090        }
1091    }
1092
1093    fn check_mat_ref(x: MatRef<'_, Standard<usize>>, repr: Standard<usize>, ctx: &dyn Display) {
1094        assert_eq!(x.repr(), &repr);
1095        assert_eq!(x.num_vectors(), repr.nrows());
1096        assert_eq!(x.vector_dim(), repr.ncols());
1097
1098        assert_copy(&x);
1099        for i in 0..x.num_vectors() {
1100            let row = x.get_row(i).unwrap();
1101            assert_eq!(row.len(), repr.ncols(), "ctx: {ctx}");
1102
1103            row.iter().enumerate().for_each(|(j, r)| {
1104                assert_eq!(
1105                    *r,
1106                    10 * i + j,
1107                    "mismatched entry at row {}, col {} -- ctx: {}",
1108                    i,
1109                    j,
1110                    ctx
1111                )
1112            });
1113        }
1114
1115        for i in edge_cases(repr.nrows()).into_iter() {
1116            assert!(x.get_row(i).is_none(), "ctx: {ctx}");
1117        }
1118    }
1119
1120    fn check_mat_mut(x: MatMut<'_, Standard<usize>>, repr: Standard<usize>, ctx: &dyn Display) {
1121        assert_eq!(x.repr(), &repr);
1122        assert_eq!(x.num_vectors(), repr.nrows());
1123        assert_eq!(x.vector_dim(), repr.ncols());
1124
1125        for i in 0..x.num_vectors() {
1126            let row = x.get_row(i).unwrap();
1127            assert_eq!(row.len(), repr.ncols(), "ctx: {ctx}");
1128
1129            row.iter().enumerate().for_each(|(j, r)| {
1130                assert_eq!(
1131                    *r,
1132                    10 * i + j,
1133                    "mismatched entry at row {}, col {} -- ctx: {}",
1134                    i,
1135                    j,
1136                    ctx
1137                )
1138            });
1139        }
1140
1141        for i in edge_cases(repr.nrows()).into_iter() {
1142            assert!(x.get_row(i).is_none(), "ctx: {ctx}");
1143        }
1144    }
1145
1146    fn check_rows(x: Rows<'_, Standard<usize>>, repr: Standard<usize>, ctx: &dyn Display) {
1147        assert_eq!(x.len(), repr.nrows(), "ctx: {ctx}");
1148        let all_rows: Vec<_> = x.collect();
1149        assert_eq!(all_rows.len(), repr.nrows(), "ctx: {ctx}");
1150        for (i, row) in all_rows.iter().enumerate() {
1151            assert_eq!(row.len(), repr.ncols(), "ctx: {ctx}");
1152            row.iter().enumerate().for_each(|(j, r)| {
1153                assert_eq!(
1154                    *r,
1155                    10 * i + j,
1156                    "mismatched entry at row {}, col {} -- ctx: {}",
1157                    i,
1158                    j,
1159                    ctx
1160                )
1161            });
1162        }
1163    }
1164
1165    //////////////
1166    // Standard //
1167    //////////////
1168
1169    #[test]
1170    fn standard_representation() {
1171        let repr = Standard::<f32>::new(4, 3);
1172        assert_eq!(repr.nrows(), 4);
1173        assert_eq!(repr.ncols(), 3);
1174
1175        let layout = repr.layout().unwrap();
1176        assert_eq!(layout.size(), 4 * 3 * std::mem::size_of::<f32>());
1177        assert_eq!(layout.align(), std::mem::align_of::<f32>());
1178    }
1179
1180    #[test]
1181    fn standard_zero_dimensions() {
1182        for (nrows, ncols) in [(0, 0), (0, 5), (5, 0)] {
1183            let repr = Standard::<u8>::new(nrows, ncols);
1184            assert_eq!(repr.nrows(), nrows);
1185            assert_eq!(repr.ncols(), ncols);
1186            let layout = repr.layout().unwrap();
1187            assert_eq!(layout.size(), 0);
1188        }
1189    }
1190
1191    #[test]
1192    fn standard_check_slice() {
1193        let repr = Standard::<u32>::new(3, 4);
1194
1195        // Correct length succeeds
1196        let data = vec![0u32; 12];
1197        assert!(repr.check_slice(&data).is_ok());
1198
1199        // Too short fails
1200        let short = vec![0u32; 11];
1201        assert!(matches!(
1202            repr.check_slice(&short),
1203            Err(SliceError::LengthMismatch {
1204                expected: 12,
1205                found: 11
1206            })
1207        ));
1208
1209        // Too long fails
1210        let long = vec![0u32; 13];
1211        assert!(matches!(
1212            repr.check_slice(&long),
1213            Err(SliceError::LengthMismatch {
1214                expected: 12,
1215                found: 13
1216            })
1217        ));
1218
1219        // Overflow case
1220        let overflow_repr = Standard::<u8>::new(usize::MAX, 2);
1221        assert!(matches!(
1222            overflow_repr.check_slice(&[]),
1223            Err(SliceError::Overflow)
1224        ));
1225    }
1226
1227    #[test]
1228    fn standard_layout_errors() {
1229        // Error path 1: num_elements() overflows (nrows * ncols > usize::MAX)
1230        let overflow_repr = Standard::<u8>::new(usize::MAX, 2);
1231        assert!(overflow_repr.layout().is_err());
1232
1233        // Error path 2: Layout::array fails (total byte size overflows)
1234        // For a u64, we need elements * 8 > isize::MAX to trigger Layout::array error
1235        // Using isize::MAX / 4 elements of u64 (8 bytes each) will overflow
1236        let large_repr = Standard::<u64>::new(isize::MAX as usize / 4, 2);
1237        assert!(large_repr.layout().is_err());
1238    }
1239
1240    /////////
1241    // Mat //
1242    /////////
1243
1244    #[test]
1245    fn mat_new_and_basic_accessors() {
1246        let mat = Mat::new(Standard::<usize>::new(3, 4), 42usize).unwrap();
1247        let base: *const u8 = mat.as_ptr().as_ptr();
1248
1249        assert_eq!(mat.num_vectors(), 3);
1250        assert_eq!(mat.vector_dim(), 4);
1251
1252        let repr = mat.repr();
1253        assert_eq!(repr.nrows(), 3);
1254        assert_eq!(repr.ncols(), 4);
1255
1256        for (i, r) in mat.rows().enumerate() {
1257            assert_eq!(r, &[42, 42, 42, 42]);
1258            let ptr = r.as_ptr().cast::<u8>();
1259            assert_eq!(
1260                ptr,
1261                base.wrapping_add(std::mem::size_of::<usize>() * mat.repr().ncols() * i),
1262            );
1263        }
1264    }
1265
1266    #[test]
1267    fn mat_new_with_default() {
1268        let mat = Mat::new(Standard::<usize>::new(2, 3), Defaulted).unwrap();
1269        let base: *const u8 = mat.as_ptr().as_ptr();
1270
1271        assert_eq!(mat.num_vectors(), 2);
1272        for (i, row) in mat.rows().enumerate() {
1273            assert!(row.iter().all(|&v| v == 0));
1274
1275            let ptr = row.as_ptr().cast::<u8>();
1276            assert_eq!(
1277                ptr,
1278                base.wrapping_add(std::mem::size_of::<usize>() * mat.repr().ncols() * i),
1279            );
1280        }
1281    }
1282
1283    const ROWS: &[usize] = &[0, 1, 2, 3, 5, 10];
1284    const COLS: &[usize] = &[0, 1, 2, 3, 5, 10];
1285
1286    #[test]
1287    fn test_mat() {
1288        for nrows in ROWS {
1289            for ncols in COLS {
1290                let repr = Standard::<usize>::new(*nrows, *ncols);
1291                let ctx = &lazy_format!("nrows = {}, ncols = {}", nrows, ncols);
1292
1293                // Populate the matrix using `&mut Mat`
1294                {
1295                    let ctx = &lazy_format!("{ctx} - direct");
1296                    let mut mat = Mat::new(repr, Defaulted).unwrap();
1297
1298                    assert_eq!(mat.num_vectors(), *nrows);
1299                    assert_eq!(mat.vector_dim(), *ncols);
1300
1301                    fill_mat(&mut mat, repr);
1302
1303                    check_mat(&mat, repr, ctx);
1304                    check_mat_ref(mat.reborrow(), repr, ctx);
1305                    check_mat_mut(mat.reborrow_mut(), repr, ctx);
1306                    check_rows(mat.rows(), repr, ctx);
1307                }
1308
1309                // Populate the matrix using `MatMut`
1310                {
1311                    let ctx = &lazy_format!("{ctx} - matmut");
1312                    let mut mat = Mat::new(repr, Defaulted).unwrap();
1313                    let matmut = mat.reborrow_mut();
1314
1315                    assert_eq!(matmut.num_vectors(), *nrows);
1316                    assert_eq!(matmut.vector_dim(), *ncols);
1317
1318                    fill_mat_mut(matmut, repr);
1319
1320                    check_mat(&mat, repr, ctx);
1321                    check_mat_ref(mat.reborrow(), repr, ctx);
1322                    check_mat_mut(mat.reborrow_mut(), repr, ctx);
1323                    check_rows(mat.rows(), repr, ctx);
1324                }
1325
1326                // Populate the matrix using `RowsMut`
1327                {
1328                    let ctx = &lazy_format!("{ctx} - rows_mut");
1329                    let mut mat = Mat::new(repr, Defaulted).unwrap();
1330                    fill_rows_mut(mat.rows_mut(), repr);
1331
1332                    check_mat(&mat, repr, ctx);
1333                    check_mat_ref(mat.reborrow(), repr, ctx);
1334                    check_mat_mut(mat.reborrow_mut(), repr, ctx);
1335                    check_rows(mat.rows(), repr, ctx);
1336                }
1337            }
1338        }
1339    }
1340
1341    #[test]
1342    fn test_mat_refmut() {
1343        for nrows in ROWS {
1344            for ncols in COLS {
1345                let repr = Standard::<usize>::new(*nrows, *ncols);
1346                let ctx = &lazy_format!("nrows = {}, ncols = {}", nrows, ncols);
1347
1348                // Populate the matrix using `&mut Mat`
1349                {
1350                    let ctx = &lazy_format!("{ctx} - by matmut");
1351                    let mut b: Box<[_]> =
1352                        (0..repr.num_elements().unwrap()).map(|_| 0usize).collect();
1353                    let mut matmut = MatMut::new(repr, &mut b).unwrap();
1354
1355                    fill_mat_mut(matmut.reborrow_mut(), repr);
1356
1357                    check_mat_mut(matmut.reborrow_mut(), repr, ctx);
1358                    check_mat_ref(matmut.reborrow(), repr, ctx);
1359                    check_rows(matmut.rows(), repr, ctx);
1360                    check_rows(matmut.reborrow().rows(), repr, ctx);
1361
1362                    let matref = MatRef::new(repr, &b).unwrap();
1363                    check_mat_ref(matref, repr, ctx);
1364                    check_mat_ref(matref.reborrow(), repr, ctx);
1365                    check_rows(matref.rows(), repr, ctx);
1366                }
1367
1368                // Populate the matrix using `RowsMut`
1369                {
1370                    let ctx = &lazy_format!("{ctx} - by rows");
1371                    let mut b: Box<[_]> =
1372                        (0..repr.num_elements().unwrap()).map(|_| 0usize).collect();
1373                    let mut matmut = MatMut::new(repr, &mut b).unwrap();
1374
1375                    fill_rows_mut(matmut.rows_mut(), repr);
1376
1377                    check_mat_mut(matmut.reborrow_mut(), repr, ctx);
1378                    check_mat_ref(matmut.reborrow(), repr, ctx);
1379                    check_rows(matmut.rows(), repr, ctx);
1380                    check_rows(matmut.reborrow().rows(), repr, ctx);
1381
1382                    let matref = MatRef::new(repr, &b).unwrap();
1383                    check_mat_ref(matref, repr, ctx);
1384                    check_mat_ref(matref.reborrow(), repr, ctx);
1385                    check_rows(matref.rows(), repr, ctx);
1386                }
1387            }
1388        }
1389    }
1390
1391    //////////////////
1392    // Constructors //
1393    //////////////////
1394
1395    #[test]
1396    fn test_standard_new_owned() {
1397        let rows = [0, 1, 2, 3, 5, 10];
1398        let cols = [0, 1, 2, 3, 5, 10];
1399
1400        for nrows in rows {
1401            for ncols in cols {
1402                let m = Mat::new(Standard::new(nrows, ncols), 1usize).unwrap();
1403                let rows_iter = m.rows();
1404                let len = <_ as ExactSizeIterator>::len(&rows_iter);
1405                assert_eq!(len, nrows);
1406                for r in rows_iter {
1407                    assert_eq!(r.len(), ncols);
1408                    assert!(r.iter().all(|i| *i == 1usize));
1409                }
1410            }
1411        }
1412    }
1413
1414    #[test]
1415    fn matref_new_slice_length_error() {
1416        let repr = Standard::<u32>::new(3, 4);
1417
1418        // Correct length succeeds
1419        let data = vec![0u32; 12];
1420        assert!(MatRef::new(repr, &data).is_ok());
1421
1422        // Too short fails
1423        let short = vec![0u32; 11];
1424        assert!(matches!(
1425            MatRef::new(repr, &short),
1426            Err(SliceError::LengthMismatch {
1427                expected: 12,
1428                found: 11
1429            })
1430        ));
1431
1432        // Too long fails
1433        let long = vec![0u32; 13];
1434        assert!(matches!(
1435            MatRef::new(repr, &long),
1436            Err(SliceError::LengthMismatch {
1437                expected: 12,
1438                found: 13
1439            })
1440        ));
1441    }
1442
1443    #[test]
1444    fn matmut_new_slice_length_error() {
1445        let repr = Standard::<u32>::new(3, 4);
1446
1447        // Correct length succeeds
1448        let mut data = vec![0u32; 12];
1449        assert!(MatMut::new(repr, &mut data).is_ok());
1450
1451        // Too short fails
1452        let mut short = vec![0u32; 11];
1453        assert!(matches!(
1454            MatMut::new(repr, &mut short),
1455            Err(SliceError::LengthMismatch {
1456                expected: 12,
1457                found: 11
1458            })
1459        ));
1460
1461        // Too long fails
1462        let mut long = vec![0u32; 13];
1463        assert!(matches!(
1464            MatMut::new(repr, &mut long),
1465            Err(SliceError::LengthMismatch {
1466                expected: 12,
1467                found: 13
1468            })
1469        ));
1470    }
1471}