Skip to main content

diskann_quantization/multi_vector/
matrix.rs

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT license.
3
4//! Row-major matrix types for multi-vector representations.
5//!
6//! This module provides flexible matrix abstractions that support different underlying
7//! storage formats through the [`Repr`] trait. The primary types are:
8//!
9//! - [`Mat`]: An owning matrix that manages its own memory.
10//! - [`MatRef`]: An immutable borrowed view of matrix data.
11//! - [`MatMut`]: A mutable borrowed view of matrix data.
12//!
13//! # Representations
14//!
15//! Representation types interact with the [`Mat`] family of types using the following traits:
16//!
17//! - [`Repr`]: Read-only matrix representation.
18//! - [`ReprMut`]: Mutable matrix representation.
19//! - [`ReprOwned`]: Owning matrix representation.
20//!
21//! Each trait refinement has a corresponding constructor:
22//!
23//! - [`NewRef`]: Construct a read-only [`MatRef`] view over a slice.
24//! - [`NewMut`]: Construct a mutable [`MatMut`] matrix view over a slice.
25//! - [`NewOwned`]: Construct a new owning [`Mat`].
26//!
27
28use std::{alloc::Layout, iter::FusedIterator, marker::PhantomData, ptr::NonNull};
29
30use diskann_utils::{Reborrow, ReborrowMut};
31use thiserror::Error;
32
33use crate::utils;
34
/// Representation trait describing the layout and access patterns for a matrix.
///
/// Implementations define how raw bytes are interpreted as typed rows. This enables
/// matrices over different storage formats (dense, quantized, etc.) using a single
/// generic [`Mat`] type.
///
/// Representations are required to be [`Copy`]: [`Repr::get_row`] consumes `self` by
/// value, so views hand out copies of the representation on every row access.
///
/// # Associated Types
///
/// - `Row<'a>`: The immutable row type (e.g., `&[f32]`, `&[f16]`).
///
/// # Safety
///
/// Implementations must ensure:
///
/// - [`get_row`](Self::get_row) returns valid references for the given row index.
///   This call **must** be memory safe for `i < self.nrows()`, provided the caller upholds
///   the contract for the raw pointer.
///
/// - The objects implicitly managed by this representation inherit the `Send` and `Sync`
///   attributes of `Repr`. That is, `Repr: Send` implies that the objects in backing memory
///   are [`Send`], and likewise with `Sync`. This is necessary to apply [`Send`] and [`Sync`]
///   bounds to [`Mat`], [`MatRef`], and [`MatMut`].
pub unsafe trait Repr: Copy {
    /// Immutable row reference type.
    type Row<'a>
    where
        Self: 'a;

    /// Returns the number of rows in the matrix.
    ///
    /// # Safety Contract
    ///
    /// This function must be loosely pure in the sense that for any given instance of
    /// `self`, `self.nrows()` must return the same value.
    fn nrows(&self) -> usize;

    /// Returns the memory layout for a memory allocation containing all
    /// [`Repr::nrows`] rows described by this representation.
    ///
    /// # Errors
    ///
    /// Returns a [`LayoutError`] when no valid layout can be computed for `self`.
    ///
    /// # Safety Contract
    ///
    /// The [`Layout`] returned from this method must be consistent with the contract of
    /// [`Repr::get_row`].
    fn layout(&self) -> Result<Layout, LayoutError>;

    /// Returns an immutable reference to the `i`-th row.
    ///
    /// # Safety
    ///
    /// - `ptr` must point to a slice with a layout compatible with [`Repr::layout`].
    /// - The entire range for this slice must be within a single allocation.
    /// - `i` must be less than [`Repr::nrows`].
    /// - The memory referenced by the returned [`Repr::Row`] must not be mutated for the
    ///   duration of lifetime `'a`.
    /// - The lifetime for the returned [`Repr::Row`] is inferred from its usage. Correct
    ///   usage must properly tie the lifetime to a source.
    unsafe fn get_row<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::Row<'a>;
}
93
/// Extension of [`Repr`] that supports mutable row access.
///
/// # Associated Types
///
/// - `RowMut<'a>`: The mutable row type (e.g., `&mut [f32]`).
///
/// # Safety
///
/// Implementors must ensure:
///
/// - [`get_row_mut`](Self::get_row_mut) returns valid references for the given row index.
///   This call **must** be memory safe for `i < self.nrows()`, provided the caller upholds
///   the contract for the raw pointer.
///
///   Additionally, since the implementation of the [`RowsMut`] iterator can give out rows
///   for all `i` in `0..self.nrows()`, the implementation of [`Self::get_row_mut`] must be
///   such that the results for distinct `i` do not interfere with one another.
pub unsafe trait ReprMut: Repr {
    /// Mutable row reference type.
    type RowMut<'a>
    where
        Self: 'a;

    /// Returns a mutable reference to the `i`-th row.
    ///
    /// # Safety
    ///
    /// - `ptr` must point to a slice with a layout compatible with [`Repr::layout`].
    /// - The entire range for this slice must be within a single allocation.
    /// - `i` must be less than `self.nrows()`.
    /// - The memory referenced by the returned [`ReprMut::RowMut`] must not be accessed
    ///   through any other reference for the duration of lifetime `'a`.
    /// - The lifetime for the returned [`ReprMut::RowMut`] is inferred from its usage.
    ///   Correct usage must properly tie the lifetime to a source.
    unsafe fn get_row_mut<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::RowMut<'a>;
}
129
/// Extension trait for [`Repr`] that supports deallocation of owned matrices. This is used
/// in conjunction with [`NewOwned`] to create matrices.
///
/// Requires [`ReprMut`] since owned matrices should support mutation.
///
/// # Safety
///
/// Implementors must ensure that `drop` properly deallocates the memory in a way compatible
/// with all [`NewOwned`] implementations.
pub unsafe trait ReprOwned: ReprMut {
    /// Deallocates memory at `ptr` and drops `self`.
    ///
    /// This is the hook invoked by the [`Drop`] implementation of [`Mat`].
    ///
    /// # Safety
    ///
    /// - `ptr` must have been obtained via [`NewOwned`] with the same value of `self`.
    /// - This method may only be called once for such a pointer.
    /// - After calling this method, the memory behind `ptr` may not be dereferenced at all.
    unsafe fn drop(self, ptr: NonNull<u8>);
}
149
/// An opaque stand-in for `std::alloc::LayoutError`.
///
/// It carries the same (lack of) information as [`std::alloc::LayoutError`], but can be
/// constructed in user code. This lets implementors of [`Repr::layout`] report failures
/// that do not originate from one of `std::alloc::Layout`'s own methods.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub struct LayoutError;

impl LayoutError {
    /// Construct a new opaque [`LayoutError`].
    pub fn new() -> Self {
        LayoutError
    }
}

impl Default for LayoutError {
    /// Equivalent to [`LayoutError::new`].
    fn default() -> Self {
        LayoutError
    }
}

impl std::fmt::Display for LayoutError {
    /// Writes the fixed text `LayoutError`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("LayoutError")
    }
}

impl std::error::Error for LayoutError {}

impl From<std::alloc::LayoutError> for LayoutError {
    /// Discards the standard-library error and yields the opaque equivalent.
    fn from(_source: std::alloc::LayoutError) -> Self {
        Self::new()
    }
}
185
186//////////////////
187// Constructors //
188//////////////////
189
/// Create a new [`MatRef`] over a slice.
///
/// This is the constructor trait backing [`MatRef::new`].
///
/// # Safety
///
/// Implementations must validate the length (and any other requirements) of the provided
/// slice to ensure it is compatible with the implementation of [`Repr`].
pub unsafe trait NewRef<T>: Repr {
    /// Errors that can occur when initializing.
    type Error;

    /// Create a new [`MatRef`] over `slice`.
    ///
    /// The returned view borrows `slice` for its whole lifetime.
    fn new_ref(self, slice: &[T]) -> Result<MatRef<'_, Self>, Self::Error>;
}
203
/// Create a new [`MatMut`] over a slice.
///
/// This is the constructor trait backing [`MatMut::new`].
///
/// # Safety
///
/// Implementations must validate the length (and any other requirements) of the provided
/// slice to ensure it is compatible with the implementation of [`ReprMut`].
pub unsafe trait NewMut<T>: ReprMut {
    /// Errors that can occur when initializing.
    type Error;

    /// Create a new [`MatMut`] over `slice`.
    ///
    /// The returned view mutably borrows `slice` for its whole lifetime.
    fn new_mut(self, slice: &mut [T]) -> Result<MatMut<'_, Self>, Self::Error>;
}
217
/// Create a new [`Mat`] from an initializer.
///
/// This is the constructor trait backing [`Mat::new`]. The type parameter `T` is the
/// initializer type (for example a fill value, or the [`Defaulted`] marker).
///
/// # Safety
///
/// Implementations must ensure that the returned [`Mat`] is compatible with
/// `Self`'s implementation of [`ReprOwned`].
pub unsafe trait NewOwned<T>: ReprOwned {
    /// Errors that can occur when initializing.
    type Error;

    /// Create a new [`Mat`] initialized with `init`.
    fn new_owned(self, init: T) -> Result<Mat<Self>, Self::Error>;
}
231
/// An initializer argument to [`NewOwned`] that uses a type's [`Default`] implementation
/// to initialize a matrix.
///
/// # Examples
///
/// ```rust
/// use diskann_quantization::multi_vector::{Mat, Standard, Defaulted};
/// let mat = Mat::new(Standard::<f32>::new(4, 3).unwrap(), Defaulted).unwrap();
/// for i in 0..4 {
///     assert!(mat.get_row(i).unwrap().iter().all(|&x| x == 0.0f32));
/// }
/// ```
#[derive(Debug, Clone, Copy)]
pub struct Defaulted;
244
/// Create a new [`Mat`] cloned from a view.
///
/// This trait backs the [`Clone`] implementation of [`Mat`] as well as the `to_owned`
/// methods on the borrowed views.
pub trait NewCloned: ReprOwned {
    /// Clone the contents behind `v`, returning a new owning [`Mat`].
    ///
    /// Implementations should ensure the returned [`Mat`] is "semantically the same" as `v`.
    fn new_cloned(v: MatRef<'_, Self>) -> Mat<Self>;
}
252
253//////////////
254// Standard //
255//////////////
256
/// Metadata for dense row-major matrices of `Copy` types.
///
/// Rows are stored contiguously as `&[T]` slices. This is the default representation
/// type for standard floating-point multi-vectors.
///
/// The struct holds only the matrix dimensions; the element data lives behind the
/// pointer owned or borrowed by [`Mat`], [`MatRef`], or [`MatMut`].
///
/// # Row Types
///
/// - `Row<'a>`: `&'a [T]`
/// - `RowMut<'a>`: `&'a mut [T]`
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Standard<T> {
    /// Number of rows (vectors).
    nrows: usize,
    /// Number of elements in each row.
    ncols: usize,
    /// Records the element type without storing a value of it.
    _elem: PhantomData<T>,
}
272
273impl<T: Copy> Standard<T> {
274    /// Create a new `Standard` for data of type `T`.
275    ///
276    /// Successful construction requires:
277    ///
278    /// * The total number of elements determined by `nrows * ncols` does not exceed
279    ///   `usize::MAX`.
280    /// * The total memory footprint defined by `ncols * nrows * size_of::<T>()` does not
281    ///   exceed `isize::MAX`.
282    pub fn new(nrows: usize, ncols: usize) -> Result<Self, Overflow> {
283        Overflow::check::<T>(nrows, ncols)?;
284        Ok(Self {
285            nrows,
286            ncols,
287            _elem: PhantomData,
288        })
289    }
290
291    /// Returns the number of total elements (`rows x cols`) in this matrix.
292    pub fn num_elements(&self) -> usize {
293        // Since we've constructed `self` - we know we cannot overflow.
294        self.nrows() * self.ncols()
295    }
296
297    /// Returns `rows`, the number of rows in this matrix.
298    fn nrows(&self) -> usize {
299        self.nrows
300    }
301
302    /// Returns `ncols`, the number of elements in a row of this matrix.
303    fn ncols(&self) -> usize {
304        self.ncols
305    }
306
307    /// Checks the following:
308    ///
309    /// 1. Computation of the number of elements in `self` does not overflow.
310    /// 2. Argument `slice` has the expected number of elements.
311    fn check_slice(&self, slice: &[T]) -> Result<(), SliceError> {
312        let len = self.num_elements();
313
314        if slice.len() != len {
315            Err(SliceError::LengthMismatch {
316                expected: len,
317                found: slice.len(),
318            })
319        } else {
320            Ok(())
321        }
322    }
323
324    /// Create a new [`Mat`] around the contents of `b` **without** any checks.
325    ///
326    /// # Safety
327    ///
328    /// The length of `b` must be exactly [`Standard::num_elements`].
329    unsafe fn box_to_mat(self, b: Box<[T]>) -> Mat<Self> {
330        debug_assert_eq!(b.len(), self.num_elements(), "safety contract violated");
331
332        // SAFETY: Box [guarantees](https://doc.rust-lang.org/std/boxed/struct.Box.html#method.into_raw)
333        // the returned pointer is non-null.
334        let ptr = unsafe { NonNull::new_unchecked(Box::into_raw(b)) }.cast::<u8>();
335
336        // SAFETY: `ptr` is properly aligned and points to a slice of the required length.
337        // Additionally, it is dropped via `Box::from_raw`, which is compatible with obtaining
338        // it from `Box::into_raw`.
339        unsafe { Mat::from_raw_parts(self, ptr) }
340    }
341}
342
/// Error for [`Standard::new`].
///
/// Records the rejected dimensions and the element size so the message can explain
/// which limit was exceeded.
#[derive(Debug, Clone, Copy)]
pub struct Overflow {
    nrows: usize,
    ncols: usize,
    elsize: usize,
}

impl Overflow {
    /// Validates that an `nrows x ncols` matrix of `T` is representable.
    ///
    /// Succeeds only when `nrows * ncols` fits in a `usize` **and** the total byte size
    /// is at most `isize::MAX`; otherwise returns an [`Overflow`] describing the request.
    fn check<T>(nrows: usize, ncols: usize) -> Result<(), Self> {
        let elsize = std::mem::size_of::<T>();
        let limit = isize::MAX as usize;
        let rejected = Self {
            nrows,
            ncols,
            elsize,
        };

        match nrows.checked_mul(ncols) {
            // The element count fits; a saturated byte count is necessarily above `limit`.
            Some(elements) if elsize.saturating_mul(elements) <= limit => Ok(()),
            // Either the element count overflowed or the footprint is too large.
            _ => Err(rejected),
        }
    }
}

impl std::fmt::Display for Overflow {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            nrows,
            ncols,
            elsize,
        } = *self;

        if elsize != 0 {
            write!(
                f,
                "a matrix of size {} x {} with element size {} would exceed isize::MAX bytes",
                nrows, ncols, elsize,
            )
        } else {
            // Zero-sized elements can only fail through the element count itself.
            write!(
                f,
                "ZST matrix with dimensions {} x {} has more than `usize::MAX` elements",
                nrows, ncols,
            )
        }
    }
}

impl std::error::Error for Overflow {}
393
/// Errors returned when constructing a view over a slice with a [`Standard`]
/// representation.
#[derive(Debug, Clone, Copy, Error)]
#[non_exhaustive]
pub enum SliceError {
    /// The slice does not contain exactly the number of elements the representation
    /// requires (`expected`); `found` is the actual slice length.
    #[error("Length mismatch: expected {expected}, found {found}")]
    LengthMismatch { expected: usize, found: usize },
}
401
402// SAFETY: The implementation correctly computes row offsets as `i * ncols` and
403// constructs valid slices of the appropriate length. The `layout` method correctly
404// reports the memory layout requirements.
405unsafe impl<T: Copy> Repr for Standard<T> {
406    type Row<'a>
407        = &'a [T]
408    where
409        T: 'a;
410
411    fn nrows(&self) -> usize {
412        self.nrows
413    }
414
415    fn layout(&self) -> Result<Layout, LayoutError> {
416        Ok(Layout::array::<T>(self.num_elements())?)
417    }
418
419    unsafe fn get_row<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::Row<'a> {
420        debug_assert!(ptr.cast::<T>().is_aligned());
421        debug_assert!(i < self.nrows);
422
423        // SAFETY: The caller asserts that `i` is less than `self.nrows()`. Since this type
424        // audits the constructors for `Mat` and friends, we know that there is room for at
425        // least `self.num_elements()` elements from the base pointer, so this access is safe.
426        let row_ptr = unsafe { ptr.as_ptr().cast::<T>().add(i * self.ncols) };
427
428        // SAFETY: The logic is the same as the previous `unsafe` block.
429        unsafe { std::slice::from_raw_parts(row_ptr, self.ncols) }
430    }
431}
432
433// SAFETY: The implementation correctly computes row offsets and constructs valid mutable
434// slices.
435unsafe impl<T: Copy> ReprMut for Standard<T> {
436    type RowMut<'a>
437        = &'a mut [T]
438    where
439        T: 'a;
440
441    unsafe fn get_row_mut<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::RowMut<'a> {
442        debug_assert!(ptr.cast::<T>().is_aligned());
443        debug_assert!(i < self.nrows);
444
445        // SAFETY: The caller asserts that `i` is less than `self.nrows()`. Since this type
446        // audits the constructors for `Mat` and friends, we know that there is room for at
447        // least `self.num_elements()` elements from the base pointer, so this access is safe.
448        let row_ptr = unsafe { ptr.as_ptr().cast::<T>().add(i * self.ncols) };
449
450        // SAFETY: The logic is the same as the previous `unsafe` block. Further, the caller
451        // attests that creating a mutable reference is safe.
452        unsafe { std::slice::from_raw_parts_mut(row_ptr, self.ncols) }
453    }
454}
455
456// SAFETY: The drop implementation correctly reconstructs a Box from the raw pointer
457// using the same length (nrows * ncols) that was used for allocation, allowing Box
458// to properly deallocate the memory.
459unsafe impl<T: Copy> ReprOwned for Standard<T> {
460    unsafe fn drop(self, ptr: NonNull<u8>) {
461        // SAFETY: The caller guarantees that `ptr` was obtained from an implementation of
462        // `NewOwned` for an equivalent instance of `self`.
463        //
464        // We ensure that `NewOwned` goes through boxes, so here we reconstruct a Box to
465        // let it handle deallocation.
466        unsafe {
467            let slice_ptr = std::ptr::slice_from_raw_parts_mut(
468                ptr.cast::<T>().as_ptr(),
469                self.nrows * self.ncols,
470            );
471            let _ = Box::from_raw(slice_ptr);
472        }
473    }
474}
475
476// SAFETY: The implementation uses guarantees from `Box` to ensure that the pointer
477// initialized by it is non-null and properly aligned to the underlying type.
478unsafe impl<T> NewOwned<T> for Standard<T>
479where
480    T: Copy,
481{
482    type Error = crate::error::Infallible;
483    fn new_owned(self, value: T) -> Result<Mat<Self>, Self::Error> {
484        let b: Box<[T]> = (0..self.num_elements()).map(|_| value).collect();
485
486        // SAFETY: By construction, `b` has length `self.num_elements()`.
487        Ok(unsafe { self.box_to_mat(b) })
488    }
489}
490
491// SAFETY: This safely reuses `<Self as NewOwned<T>>`.
492unsafe impl<T> NewOwned<Defaulted> for Standard<T>
493where
494    T: Copy + Default,
495{
496    type Error = crate::error::Infallible;
497    fn new_owned(self, _: Defaulted) -> Result<Mat<Self>, Self::Error> {
498        self.new_owned(T::default())
499    }
500}
501
502// SAFETY: This checks that the slice has the correct length, which is all that is
503// required for [`Repr`].
504unsafe impl<T> NewRef<T> for Standard<T>
505where
506    T: Copy,
507{
508    type Error = SliceError;
509    fn new_ref(self, data: &[T]) -> Result<MatRef<'_, Self>, Self::Error> {
510        self.check_slice(data)?;
511
512        // SAFETY: The function `check_slice` verifies that `data` is compatible with
513        // the layout requirement of `Standard`.
514        //
515        // We've properly checked that the underlying pointer is okay.
516        Ok(unsafe { MatRef::from_raw_parts(self, utils::as_nonnull(data).cast::<u8>()) })
517    }
518}
519
520// SAFETY: This checks that the slice has the correct length, which is all that is
521// required for [`ReprMut`].
522unsafe impl<T> NewMut<T> for Standard<T>
523where
524    T: Copy,
525{
526    type Error = SliceError;
527    fn new_mut(self, data: &mut [T]) -> Result<MatMut<'_, Self>, Self::Error> {
528        self.check_slice(data)?;
529
530        // SAFETY: The function `check_slice` verifies that `data` is compatible with
531        // the layout requirement of `Standard`.
532        //
533        // We've properly checked that the underlying pointer is okay.
534        Ok(unsafe { MatMut::from_raw_parts(self, utils::as_nonnull_mut(data).cast::<u8>()) })
535    }
536}
537
538impl<T> NewCloned for Standard<T>
539where
540    T: Copy,
541{
542    fn new_cloned(v: MatRef<'_, Self>) -> Mat<Self> {
543        let b: Box<[T]> = v.rows().flatten().copied().collect();
544
545        // SAFETY: By construction, `b` has length `v.repr().num_elements()`.
546        unsafe { v.repr().box_to_mat(b) }
547    }
548}
549
550/////////
551// Mat //
552/////////
553
/// An owning matrix that manages its own memory.
///
/// The matrix stores raw bytes interpreted according to representation type `T`.
/// Memory is automatically deallocated when the matrix is dropped, by delegating to
/// [`ReprOwned::drop`].
#[derive(Debug)]
pub struct Mat<T: ReprOwned> {
    /// Base pointer of the owned backing memory.
    ptr: NonNull<u8>,
    /// Representation describing how the memory behind `ptr` is laid out.
    repr: T,
    /// Zero-sized marker; `fn(T) -> T` makes `Mat` invariant in `T`.
    _invariant: PhantomData<fn(T) -> T>,
}
564
// SAFETY: The safety contract of [`Repr`] requires that the objects in the backing memory
// are `Send` whenever the representation itself is `Send`, so the bound propagates.
unsafe impl<T> Send for Mat<T> where T: ReprOwned + Send {}

// SAFETY: The safety contract of [`Repr`] requires that the objects in the backing memory
// are `Sync` whenever the representation itself is `Sync`, so the bound propagates.
unsafe impl<T> Sync for Mat<T> where T: ReprOwned + Sync {}
570
571impl<T: ReprOwned> Mat<T> {
572    /// Create a new matrix using `init` as the initializer.
573    pub fn new<U>(repr: T, init: U) -> Result<Self, <T as NewOwned<U>>::Error>
574    where
575        T: NewOwned<U>,
576    {
577        repr.new_owned(init)
578    }
579
580    /// Returns the number of rows (vectors) in the matrix.
581    #[inline]
582    pub fn num_vectors(&self) -> usize {
583        self.repr.nrows()
584    }
585
586    /// Returns a reference to the underlying representation.
587    pub fn repr(&self) -> &T {
588        &self.repr
589    }
590
591    /// Returns the `i`th row if `i < self.num_vectors()`.
592    #[must_use]
593    pub fn get_row(&self, i: usize) -> Option<T::Row<'_>> {
594        if i < self.num_vectors() {
595            // SAFETY: Bounds check passed, and the Mat was constructed
596            // with valid representation and pointer.
597            let row = unsafe { self.get_row_unchecked(i) };
598            Some(row)
599        } else {
600            None
601        }
602    }
603
604    pub(crate) unsafe fn get_row_unchecked(&self, i: usize) -> T::Row<'_> {
605        // SAFETY: Caller must ensure i < self.num_vectors(). The constructors for this type
606        // ensure that `ptr` is compatible with `T`.
607        unsafe { self.repr.get_row(self.ptr, i) }
608    }
609
610    /// Returns the `i`th mutable row if `i < self.num_vectors()`.
611    #[must_use]
612    pub fn get_row_mut(&mut self, i: usize) -> Option<T::RowMut<'_>> {
613        if i < self.num_vectors() {
614            // SAFETY: Bounds check passed, and we have exclusive access via &mut self.
615            Some(unsafe { self.get_row_mut_unchecked(i) })
616        } else {
617            None
618        }
619    }
620
621    pub(crate) unsafe fn get_row_mut_unchecked(&mut self, i: usize) -> T::RowMut<'_> {
622        // SAFETY: Caller asserts that `i < self.num_vectors()`. The constructors for this
623        // type ensure that `ptr` is compatible with `T`.
624        unsafe { self.repr.get_row_mut(self.ptr, i) }
625    }
626
627    /// Returns an immutable view of the matrix.
628    #[inline]
629    pub fn as_view(&self) -> MatRef<'_, T> {
630        MatRef {
631            ptr: self.ptr,
632            repr: self.repr,
633            _lifetime: PhantomData,
634        }
635    }
636
637    /// Returns a mutable view of the matrix.
638    #[inline]
639    pub fn as_view_mut(&mut self) -> MatMut<'_, T> {
640        MatMut {
641            ptr: self.ptr,
642            repr: self.repr,
643            _lifetime: PhantomData,
644        }
645    }
646
647    /// Returns an iterator over immutable row references.
648    pub fn rows(&self) -> Rows<'_, T> {
649        Rows::new(self.reborrow())
650    }
651
652    /// Returns an iterator over mutable row references.
653    pub fn rows_mut(&mut self) -> RowsMut<'_, T> {
654        RowsMut::new(self.reborrow_mut())
655    }
656
657    /// Construct a new [`Mat`] over the raw pointer and representation without performing
658    /// any validity checks.
659    ///
660    /// # Safety
661    ///
662    /// Argument `ptr` must be:
663    ///
664    /// 1. Point to memory compatible with [`Repr::layout`].
665    /// 2. Be compatible with the drop logic in [`ReprOwned`].
666    pub(crate) unsafe fn from_raw_parts(repr: T, ptr: NonNull<u8>) -> Self {
667        Self {
668            ptr,
669            repr,
670            _invariant: PhantomData,
671        }
672    }
673
674    /// Return the base pointer for the [`Mat`].
675    pub fn as_raw_ptr(&self) -> *const u8 {
676        self.ptr.as_ptr()
677    }
678}
679
impl<T: ReprOwned> Drop for Mat<T> {
    /// Releases the backing memory by delegating to [`ReprOwned::drop`].
    fn drop(&mut self) {
        // SAFETY: `ptr` was correctly initialized according to `layout`
        // and we are guaranteed exclusive access to the data due to Rust borrow rules.
        // `Drop::drop` runs at most once, so the pointer is released exactly once.
        unsafe { self.repr.drop(self.ptr) };
    }
}
687
688impl<T: NewCloned> Clone for Mat<T> {
689    fn clone(&self) -> Self {
690        T::new_cloned(self.as_view())
691    }
692}
693
694impl<T: Copy> Mat<Standard<T>> {
695    /// Returns the raw dimension (columns) of the vectors in the matrix.
696    #[inline]
697    pub fn vector_dim(&self) -> usize {
698        self.repr.ncols()
699    }
700}
701
702////////////
703// MatRef //
704////////////
705
/// An immutable borrowed view of a matrix.
///
/// Provides read-only access to matrix data without ownership. Implements [`Copy`]
/// and can be freely cloned.
///
/// # Type Parameter
///
/// - `T`: A [`Repr`] implementation defining the row layout.
///
/// # Access
///
/// - [`get_row`](Self::get_row): Get an immutable row by index.
/// - [`rows`](Self::rows): Iterate over all rows.
#[derive(Debug, Clone, Copy)]
pub struct MatRef<'a, T: Repr> {
    /// Base pointer of the borrowed backing memory.
    pub(crate) ptr: NonNull<u8>,
    /// Representation describing how the memory behind `ptr` is laid out.
    pub(crate) repr: T,
    /// Marker to tie the lifetime to the borrowed data.
    pub(crate) _lifetime: PhantomData<&'a T>,
}
724
// SAFETY: The safety contract of [`Repr`] requires that the objects in the backing memory
// are `Send` whenever the representation itself is `Send`, so the bound propagates.
unsafe impl<T> Send for MatRef<'_, T> where T: Repr + Send {}

// SAFETY: The safety contract of [`Repr`] requires that the objects in the backing memory
// are `Sync` whenever the representation itself is `Sync`, so the bound propagates.
unsafe impl<T> Sync for MatRef<'_, T> where T: Repr + Sync {}
730
731impl<'a, T: Repr> MatRef<'a, T> {
732    /// Construct a new [`MatRef`] over `data`.
733    pub fn new<U>(repr: T, data: &'a [U]) -> Result<Self, T::Error>
734    where
735        T: NewRef<U>,
736    {
737        repr.new_ref(data)
738    }
739
740    /// Returns the number of rows (vectors) in the matrix.
741    #[inline]
742    pub fn num_vectors(&self) -> usize {
743        self.repr.nrows()
744    }
745
746    /// Returns a reference to the underlying representation.
747    pub fn repr(&self) -> &T {
748        &self.repr
749    }
750
751    /// Returns an immutable reference to the i-th row, or `None` if out of bounds.
752    #[must_use]
753    pub fn get_row(&self, i: usize) -> Option<T::Row<'_>> {
754        if i < self.num_vectors() {
755            // SAFETY: Bounds check passed, and the MatRef was constructed
756            // with valid representation and pointer.
757            let row = unsafe { self.get_row_unchecked(i) };
758            Some(row)
759        } else {
760            None
761        }
762    }
763
764    /// Returns the i-th row without bounds checking.
765    ///
766    /// # Safety
767    ///
768    /// `i` must be less than `self.num_vectors()`.
769    #[inline]
770    pub(crate) unsafe fn get_row_unchecked(&self, i: usize) -> T::Row<'_> {
771        // SAFETY: Caller must ensure i < self.num_vectors().
772        unsafe { self.repr.get_row(self.ptr, i) }
773    }
774
775    /// Returns an iterator over immutable row references.
776    pub fn rows(&self) -> Rows<'_, T> {
777        Rows::new(*self)
778    }
779
780    /// Return a [`Mat`] with the same contents as `self`.
781    pub fn to_owned(&self) -> Mat<T>
782    where
783        T: NewCloned,
784    {
785        T::new_cloned(*self)
786    }
787
788    /// Construct a new [`MatRef`] over the raw pointer and representation without performing
789    /// any validity checks.
790    ///
791    /// # Safety
792    ///
793    /// Argument `ptr` must point to memory compatible with [`Repr::layout`] and pass any
794    /// validity checks required by `T`.
795    pub unsafe fn from_raw_parts(repr: T, ptr: NonNull<u8>) -> Self {
796        Self {
797            ptr,
798            repr,
799            _lifetime: PhantomData,
800        }
801    }
802
803    /// Return the base pointer for the [`MatRef`].
804    pub fn as_raw_ptr(&self) -> *const u8 {
805        self.ptr.as_ptr()
806    }
807}
808
809impl<'a, T: Copy> MatRef<'a, Standard<T>> {
810    /// Returns the raw dimension (columns) of the vectors in the matrix.
811    #[inline]
812    pub fn vector_dim(&self) -> usize {
813        self.repr.ncols()
814    }
815}
816
817// Reborrow: Mat -> MatRef
818impl<'this, T: ReprOwned> Reborrow<'this> for Mat<T> {
819    type Target = MatRef<'this, T>;
820
821    fn reborrow(&'this self) -> Self::Target {
822        self.as_view()
823    }
824}
825
826// ReborrowMut: Mat -> MatMut
827impl<'this, T: ReprOwned> ReborrowMut<'this> for Mat<T> {
828    type Target = MatMut<'this, T>;
829
830    fn reborrow_mut(&'this mut self) -> Self::Target {
831        self.as_view_mut()
832    }
833}
834
835// Reborrow: MatRef -> MatRef (with shorter lifetime)
836impl<'this, 'a, T: Repr> Reborrow<'this> for MatRef<'a, T> {
837    type Target = MatRef<'this, T>;
838
839    fn reborrow(&'this self) -> Self::Target {
840        MatRef {
841            ptr: self.ptr,
842            repr: self.repr,
843            _lifetime: PhantomData,
844        }
845    }
846}
847
848////////////
849// MatMut //
850////////////
851
/// A mutable borrowed view of a matrix.
///
/// Provides read-write access to matrix data without ownership.
///
/// # Type Parameter
///
/// - `T`: A [`ReprMut`] implementation defining the row layout.
///
/// # Access
///
/// - [`get_row`](Self::get_row): Get an immutable row by index.
/// - [`get_row_mut`](Self::get_row_mut): Get a mutable row by index.
/// - [`as_view`](Self::as_view): Reborrow as immutable [`MatRef`].
/// - [`rows`](Self::rows), [`rows_mut`](Self::rows_mut): Iterate over rows.
#[derive(Debug)]
pub struct MatMut<'a, T: ReprMut> {
    /// Base pointer of the mutably borrowed backing memory.
    pub(crate) ptr: NonNull<u8>,
    /// Representation describing how the memory behind `ptr` is laid out.
    pub(crate) repr: T,
    /// Marker to tie the lifetime to the mutably borrowed data.
    pub(crate) _lifetime: PhantomData<&'a mut T>,
}
871
// SAFETY: [`ReprMut`] inherits the safety contract of its supertrait [`Repr`], which
// requires that the objects in the backing memory are `Send` whenever the representation
// itself is `Send`.
unsafe impl<T> Send for MatMut<'_, T> where T: ReprMut + Send {}

// SAFETY: [`ReprMut`] inherits the safety contract of its supertrait [`Repr`], which
// requires that the objects in the backing memory are `Sync` whenever the representation
// itself is `Sync`.
unsafe impl<T> Sync for MatMut<'_, T> where T: ReprMut + Sync {}
877
878impl<'a, T: ReprMut> MatMut<'a, T> {
879    /// Construct a new [`MatMut`] over `data`.
880    pub fn new<U>(repr: T, data: &'a mut [U]) -> Result<Self, T::Error>
881    where
882        T: NewMut<U>,
883    {
884        repr.new_mut(data)
885    }
886
887    /// Returns the number of rows (vectors) in the matrix.
888    #[inline]
889    pub fn num_vectors(&self) -> usize {
890        self.repr.nrows()
891    }
892
    /// Returns a shared reference to the representation that describes how this view's
    /// memory is laid out.
    pub fn repr(&self) -> &T {
        &self.repr
    }
897
898    /// Returns an immutable reference to the i-th row, or `None` if out of bounds.
899    #[inline]
900    #[must_use]
901    pub fn get_row(&self, i: usize) -> Option<T::Row<'_>> {
902        if i < self.num_vectors() {
903            // SAFETY: Bounds check passed.
904            Some(unsafe { self.get_row_unchecked(i) })
905        } else {
906            None
907        }
908    }
909
910    /// Returns the i-th row without bounds checking.
911    ///
912    /// # Safety
913    ///
914    /// `i` must be less than `self.num_vectors()`.
915    #[inline]
916    pub(crate) unsafe fn get_row_unchecked(&self, i: usize) -> T::Row<'_> {
917        // SAFETY: Caller must ensure i < self.num_vectors().
918        unsafe { self.repr.get_row(self.ptr, i) }
919    }
920
921    /// Returns a mutable reference to the `i`-th row, or `None` if out of bounds.
922    #[inline]
923    #[must_use]
924    pub fn get_row_mut(&mut self, i: usize) -> Option<T::RowMut<'_>> {
925        if i < self.num_vectors() {
926            // SAFETY: Bounds check passed.
927            Some(unsafe { self.get_row_mut_unchecked(i) })
928        } else {
929            None
930        }
931    }
932
933    /// Returns a mutable reference to the i-th row without bounds checking.
934    ///
935    /// # Safety
936    ///
937    /// `i` must be less than [`num_vectors()`](Self::num_vectors).
938    #[inline]
939    pub(crate) unsafe fn get_row_mut_unchecked(&mut self, i: usize) -> T::RowMut<'_> {
940        // SAFETY: Caller asserts that `i < self.num_vectors()`. The constructors for this
941        // type ensure that `ptr` is compatible with `T`.
942        unsafe { self.repr.get_row_mut(self.ptr, i) }
943    }
944
945    /// Reborrows as an immutable [`MatRef`].
946    pub fn as_view(&self) -> MatRef<'_, T> {
947        MatRef {
948            ptr: self.ptr,
949            repr: self.repr,
950            _lifetime: PhantomData,
951        }
952    }
953
954    /// Returns an iterator over immutable row references.
955    pub fn rows(&self) -> Rows<'_, T> {
956        Rows::new(self.reborrow())
957    }
958
959    /// Returns an iterator over mutable row references.
960    pub fn rows_mut(&mut self) -> RowsMut<'_, T> {
961        RowsMut::new(self.reborrow_mut())
962    }
963
964    /// Return a [`Mat`] with the same contents as `self`.
965    pub fn to_owned(&self) -> Mat<T>
966    where
967        T: NewCloned,
968    {
969        T::new_cloned(self.as_view())
970    }
971
972    /// Construct a new [`MatMut`] over the raw pointer and representation without performing
973    /// any validity checks.
974    ///
975    /// # Safety
976    ///
977    /// Argument `ptr` must point to memory compatible with [`Repr::layout`].
978    pub unsafe fn from_raw_parts(repr: T, ptr: NonNull<u8>) -> Self {
979        Self {
980            ptr,
981            repr,
982            _lifetime: PhantomData,
983        }
984    }
985
986    /// Return the base pointer for the [`MatMut`].
987    pub fn as_raw_ptr(&self) -> *const u8 {
988        self.ptr.as_ptr()
989    }
990}
991
992// Reborrow: MatMut -> MatRef
993impl<'this, 'a, T: ReprMut> Reborrow<'this> for MatMut<'a, T> {
994    type Target = MatRef<'this, T>;
995
996    fn reborrow(&'this self) -> Self::Target {
997        self.as_view()
998    }
999}
1000
1001// ReborrowMut: MatMut -> MatMut (with shorter lifetime)
1002impl<'this, 'a, T: ReprMut> ReborrowMut<'this> for MatMut<'a, T> {
1003    type Target = MatMut<'this, T>;
1004
1005    fn reborrow_mut(&'this mut self) -> Self::Target {
1006        MatMut {
1007            ptr: self.ptr,
1008            repr: self.repr,
1009            _lifetime: PhantomData,
1010        }
1011    }
1012}
1013
1014impl<'a, T: Copy> MatMut<'a, Standard<T>> {
1015    /// Returns the raw dimension (columns) of the vectors in the matrix.
1016    #[inline]
1017    pub fn vector_dim(&self) -> usize {
1018        self.repr.ncols()
1019    }
1020}
1021
1022//////////
1023// Rows //
1024//////////
1025
1026/// Iterator over immutable row references of a matrix.
1027///
1028/// Created by [`Mat::rows`], [`MatRef::rows`], or [`MatMut::rows`].
1029#[derive(Debug)]
1030pub struct Rows<'a, T: Repr> {
1031    matrix: MatRef<'a, T>,
1032    current: usize,
1033}
1034
1035impl<'a, T> Rows<'a, T>
1036where
1037    T: Repr,
1038{
1039    fn new(matrix: MatRef<'a, T>) -> Self {
1040        Self { matrix, current: 0 }
1041    }
1042}
1043
1044impl<'a, T> Iterator for Rows<'a, T>
1045where
1046    T: Repr + 'a,
1047{
1048    type Item = T::Row<'a>;
1049
1050    fn next(&mut self) -> Option<Self::Item> {
1051        let current = self.current;
1052        if current >= self.matrix.num_vectors() {
1053            None
1054        } else {
1055            self.current += 1;
1056            // SAFETY: We make sure through the above check that
1057            // the access is within bounds.
1058            //
1059            // Extending the lifetime to `'a` is safe because the underlying
1060            // MatRef has lifetime `'a`.
1061            Some(unsafe { self.matrix.repr.get_row(self.matrix.ptr, current) })
1062        }
1063    }
1064
1065    fn size_hint(&self) -> (usize, Option<usize>) {
1066        let remaining = self.matrix.num_vectors() - self.current;
1067        (remaining, Some(remaining))
1068    }
1069}
1070
1071impl<'a, T> ExactSizeIterator for Rows<'a, T> where T: Repr + 'a {}
1072impl<'a, T> FusedIterator for Rows<'a, T> where T: Repr + 'a {}
1073
1074/////////////
1075// RowsMut //
1076/////////////
1077
1078/// Iterator over mutable row references of a matrix.
1079///
1080/// Created by [`Mat::rows_mut`] or [`MatMut::rows_mut`].
1081#[derive(Debug)]
1082pub struct RowsMut<'a, T: ReprMut> {
1083    matrix: MatMut<'a, T>,
1084    current: usize,
1085}
1086
1087impl<'a, T> RowsMut<'a, T>
1088where
1089    T: ReprMut,
1090{
1091    fn new(matrix: MatMut<'a, T>) -> Self {
1092        Self { matrix, current: 0 }
1093    }
1094}
1095
1096impl<'a, T> Iterator for RowsMut<'a, T>
1097where
1098    T: ReprMut + 'a,
1099{
1100    type Item = T::RowMut<'a>;
1101
1102    fn next(&mut self) -> Option<Self::Item> {
1103        let current = self.current;
1104        if current >= self.matrix.num_vectors() {
1105            None
1106        } else {
1107            self.current += 1;
1108            // SAFETY: We make sure through the above check that
1109            // the access is within bounds.
1110            //
1111            // Extending the lifetime to `'a` is safe because:
1112            // 1. The underlying MatMut has lifetime `'a`.
1113            // 2. The iterator ensures that the mutable row indices are disjoint, so
1114            //    there is no aliasing as long as the implementation of `ReprMut` ensures
1115            //    there is not mutable sharing of the `RowMut` types.
1116            Some(unsafe { self.matrix.repr.get_row_mut(self.matrix.ptr, current) })
1117        }
1118    }
1119
1120    fn size_hint(&self) -> (usize, Option<usize>) {
1121        let remaining = self.matrix.num_vectors() - self.current;
1122        (remaining, Some(remaining))
1123    }
1124}
1125
1126impl<'a, T> ExactSizeIterator for RowsMut<'a, T> where T: ReprMut + 'a {}
1127impl<'a, T> FusedIterator for RowsMut<'a, T> where T: ReprMut + 'a {}
1128
1129///////////
1130// Tests //
1131///////////
1132
1133#[cfg(test)]
1134mod tests {
1135    use super::*;
1136
1137    use std::fmt::Display;
1138
1139    use diskann_utils::lazy_format;
1140
1141    /// Helper to assert a type is Copy.
1142    fn assert_copy<T: Copy>(_: &T) {}
1143
1144    // ── Variance assertions ──────────────────────────────────────
1145    //
1146    // These functions are never called. The test is that they compile:
1147    // covariant positions must accept subtype coercions.
1148    //
1149    // The negative (invariance) counterparts live in
1150    // `tests/compile-fail/multi/{mat,matmut}_invariant.rs`.
1151
1152    /// `MatRef` is covariant in `'a`: a longer borrow can shorten.
1153    fn _assert_matref_covariant_lifetime<'long: 'short, 'short, T: Repr>(
1154        v: MatRef<'long, T>,
1155    ) -> MatRef<'short, T> {
1156        v
1157    }
1158
1159    /// `MatRef` is covariant in `T`: `Standard<&'long u8>` → `Standard<&'short u8>`.
1160    fn _assert_matref_covariant_repr<'long: 'short, 'short, 'a>(
1161        v: MatRef<'a, Standard<&'long u8>>,
1162    ) -> MatRef<'a, Standard<&'short u8>> {
1163        v
1164    }
1165
1166    /// `MatMut` is covariant in `'a`: a longer borrow can shorten.
1167    fn _assert_matmut_covariant_lifetime<'long: 'short, 'short, T: ReprMut>(
1168        v: MatMut<'long, T>,
1169    ) -> MatMut<'short, T> {
1170        v
1171    }
1172
1173    fn edge_cases(nrows: usize) -> Vec<usize> {
1174        let max = usize::MAX;
1175
1176        vec![
1177            nrows,
1178            nrows + 1,
1179            nrows + 11,
1180            nrows + 20,
1181            max / 2,
1182            max.div_ceil(2),
1183            max - 1,
1184            max,
1185        ]
1186    }
1187
1188    fn fill_mat(x: &mut Mat<Standard<usize>>, repr: Standard<usize>) {
1189        assert_eq!(x.repr(), &repr);
1190        assert_eq!(x.num_vectors(), repr.nrows());
1191        assert_eq!(x.vector_dim(), repr.ncols());
1192
1193        for i in 0..x.num_vectors() {
1194            let row = x.get_row_mut(i).unwrap();
1195            assert_eq!(row.len(), repr.ncols());
1196            row.iter_mut()
1197                .enumerate()
1198                .for_each(|(j, r)| *r = 10 * i + j);
1199        }
1200
1201        for i in edge_cases(repr.nrows()).into_iter() {
1202            assert!(x.get_row_mut(i).is_none());
1203        }
1204    }
1205
1206    fn fill_mat_mut(mut x: MatMut<'_, Standard<usize>>, repr: Standard<usize>) {
1207        assert_eq!(x.repr(), &repr);
1208        assert_eq!(x.num_vectors(), repr.nrows());
1209        assert_eq!(x.vector_dim(), repr.ncols());
1210
1211        for i in 0..x.num_vectors() {
1212            let row = x.get_row_mut(i).unwrap();
1213            assert_eq!(row.len(), repr.ncols());
1214
1215            row.iter_mut()
1216                .enumerate()
1217                .for_each(|(j, r)| *r = 10 * i + j);
1218        }
1219
1220        for i in edge_cases(repr.nrows()).into_iter() {
1221            assert!(x.get_row_mut(i).is_none());
1222        }
1223    }
1224
1225    fn fill_rows_mut(x: RowsMut<'_, Standard<usize>>, repr: Standard<usize>) {
1226        assert_eq!(x.len(), repr.nrows());
1227        // Materialize all rows at once.
1228        let mut all_rows: Vec<_> = x.collect();
1229        assert_eq!(all_rows.len(), repr.nrows());
1230        for (i, row) in all_rows.iter_mut().enumerate() {
1231            assert_eq!(row.len(), repr.ncols());
1232            row.iter_mut()
1233                .enumerate()
1234                .for_each(|(j, r)| *r = 10 * i + j);
1235        }
1236    }
1237
1238    fn check_mat(x: &Mat<Standard<usize>>, repr: Standard<usize>, ctx: &dyn Display) {
1239        assert_eq!(x.repr(), &repr);
1240        assert_eq!(x.num_vectors(), repr.nrows());
1241        assert_eq!(x.vector_dim(), repr.ncols());
1242
1243        for i in 0..x.num_vectors() {
1244            let row = x.get_row(i).unwrap();
1245
1246            assert_eq!(row.len(), repr.ncols(), "ctx: {ctx}");
1247            row.iter().enumerate().for_each(|(j, r)| {
1248                assert_eq!(
1249                    *r,
1250                    10 * i + j,
1251                    "mismatched entry at row {}, col {} -- ctx: {}",
1252                    i,
1253                    j,
1254                    ctx
1255                )
1256            });
1257        }
1258
1259        for i in edge_cases(repr.nrows()).into_iter() {
1260            assert!(x.get_row(i).is_none(), "ctx: {ctx}");
1261        }
1262    }
1263
1264    fn check_mat_ref(x: MatRef<'_, Standard<usize>>, repr: Standard<usize>, ctx: &dyn Display) {
1265        assert_eq!(x.repr(), &repr);
1266        assert_eq!(x.num_vectors(), repr.nrows());
1267        assert_eq!(x.vector_dim(), repr.ncols());
1268
1269        assert_copy(&x);
1270        for i in 0..x.num_vectors() {
1271            let row = x.get_row(i).unwrap();
1272            assert_eq!(row.len(), repr.ncols(), "ctx: {ctx}");
1273
1274            row.iter().enumerate().for_each(|(j, r)| {
1275                assert_eq!(
1276                    *r,
1277                    10 * i + j,
1278                    "mismatched entry at row {}, col {} -- ctx: {}",
1279                    i,
1280                    j,
1281                    ctx
1282                )
1283            });
1284        }
1285
1286        for i in edge_cases(repr.nrows()).into_iter() {
1287            assert!(x.get_row(i).is_none(), "ctx: {ctx}");
1288        }
1289    }
1290
1291    fn check_mat_mut(x: MatMut<'_, Standard<usize>>, repr: Standard<usize>, ctx: &dyn Display) {
1292        assert_eq!(x.repr(), &repr);
1293        assert_eq!(x.num_vectors(), repr.nrows());
1294        assert_eq!(x.vector_dim(), repr.ncols());
1295
1296        for i in 0..x.num_vectors() {
1297            let row = x.get_row(i).unwrap();
1298            assert_eq!(row.len(), repr.ncols(), "ctx: {ctx}");
1299
1300            row.iter().enumerate().for_each(|(j, r)| {
1301                assert_eq!(
1302                    *r,
1303                    10 * i + j,
1304                    "mismatched entry at row {}, col {} -- ctx: {}",
1305                    i,
1306                    j,
1307                    ctx
1308                )
1309            });
1310        }
1311
1312        for i in edge_cases(repr.nrows()).into_iter() {
1313            assert!(x.get_row(i).is_none(), "ctx: {ctx}");
1314        }
1315    }
1316
1317    fn check_rows(x: Rows<'_, Standard<usize>>, repr: Standard<usize>, ctx: &dyn Display) {
1318        assert_eq!(x.len(), repr.nrows(), "ctx: {ctx}");
1319        let all_rows: Vec<_> = x.collect();
1320        assert_eq!(all_rows.len(), repr.nrows(), "ctx: {ctx}");
1321        for (i, row) in all_rows.iter().enumerate() {
1322            assert_eq!(row.len(), repr.ncols(), "ctx: {ctx}");
1323            row.iter().enumerate().for_each(|(j, r)| {
1324                assert_eq!(
1325                    *r,
1326                    10 * i + j,
1327                    "mismatched entry at row {}, col {} -- ctx: {}",
1328                    i,
1329                    j,
1330                    ctx
1331                )
1332            });
1333        }
1334    }
1335
1336    //////////////
1337    // Standard //
1338    //////////////
1339
1340    #[test]
1341    fn standard_representation() {
1342        let repr = Standard::<f32>::new(4, 3).unwrap();
1343        assert_eq!(repr.nrows(), 4);
1344        assert_eq!(repr.ncols(), 3);
1345
1346        let layout = repr.layout().unwrap();
1347        assert_eq!(layout.size(), 4 * 3 * std::mem::size_of::<f32>());
1348        assert_eq!(layout.align(), std::mem::align_of::<f32>());
1349    }
1350
1351    #[test]
1352    fn standard_zero_dimensions() {
1353        for (nrows, ncols) in [(0, 0), (0, 5), (5, 0)] {
1354            let repr = Standard::<u8>::new(nrows, ncols).unwrap();
1355            assert_eq!(repr.nrows(), nrows);
1356            assert_eq!(repr.ncols(), ncols);
1357            let layout = repr.layout().unwrap();
1358            assert_eq!(layout.size(), 0);
1359        }
1360    }
1361
1362    #[test]
1363    fn standard_check_slice() {
1364        let repr = Standard::<u32>::new(3, 4).unwrap();
1365
1366        // Correct length succeeds
1367        let data = vec![0u32; 12];
1368        assert!(repr.check_slice(&data).is_ok());
1369
1370        // Too short fails
1371        let short = vec![0u32; 11];
1372        assert!(matches!(
1373            repr.check_slice(&short),
1374            Err(SliceError::LengthMismatch {
1375                expected: 12,
1376                found: 11
1377            })
1378        ));
1379
1380        // Too long fails
1381        let long = vec![0u32; 13];
1382        assert!(matches!(
1383            repr.check_slice(&long),
1384            Err(SliceError::LengthMismatch {
1385                expected: 12,
1386                found: 13
1387            })
1388        ));
1389
1390        // Overflow case
1391        let overflow_repr = Standard::<u8>::new(usize::MAX, 2).unwrap_err();
1392        assert!(matches!(overflow_repr, Overflow { .. }));
1393    }
1394
1395    #[test]
1396    fn standard_new_rejects_element_count_overflow() {
1397        // nrows * ncols overflows usize even though per-element size is small.
1398        assert!(Standard::<u8>::new(usize::MAX, 2).is_err());
1399        assert!(Standard::<u8>::new(2, usize::MAX).is_err());
1400        assert!(Standard::<u8>::new(usize::MAX, usize::MAX).is_err());
1401    }
1402
1403    #[test]
1404    fn standard_new_rejects_byte_count_exceeding_isize_max() {
1405        // Element count fits in usize, but total bytes exceed isize::MAX.
1406        let half = (isize::MAX as usize / std::mem::size_of::<u64>()) + 1;
1407        assert!(Standard::<u64>::new(half, 1).is_err());
1408        assert!(Standard::<u64>::new(1, half).is_err());
1409    }
1410
1411    #[test]
1412    fn standard_new_accepts_boundary_below_isize_max() {
1413        // Largest allocation that still fits in isize::MAX bytes.
1414        let max_elems = isize::MAX as usize / std::mem::size_of::<u64>();
1415        let repr = Standard::<u64>::new(max_elems, 1).unwrap();
1416        assert_eq!(repr.num_elements(), max_elems);
1417    }
1418
1419    #[test]
1420    fn standard_new_zst_rejects_element_count_overflow() {
1421        // For ZSTs the byte count is always 0, but element-count overflow
1422        // must still be caught so that `num_elements()` never wraps.
1423        assert!(Standard::<()>::new(usize::MAX, 2).is_err());
1424        assert!(Standard::<()>::new(usize::MAX / 2 + 1, 3).is_err());
1425    }
1426
1427    #[test]
1428    fn standard_new_zst_accepts_large_non_overflowing() {
1429        // Large-but-valid ZST matrix: element count fits in usize.
1430        let repr = Standard::<()>::new(usize::MAX, 1).unwrap();
1431        assert_eq!(repr.num_elements(), usize::MAX);
1432        assert_eq!(repr.layout().unwrap().size(), 0);
1433    }
1434
1435    #[test]
1436    fn standard_new_overflow_error_display() {
1437        let err = Standard::<u32>::new(usize::MAX, 2).unwrap_err();
1438        let msg = err.to_string();
1439        assert!(msg.contains("would exceed isize::MAX bytes"), "{msg}");
1440
1441        let zst_err = Standard::<()>::new(usize::MAX, 2).unwrap_err();
1442        let zst_msg = zst_err.to_string();
1443        assert!(zst_msg.contains("ZST matrix"), "{zst_msg}");
1444        assert!(zst_msg.contains("usize::MAX"), "{zst_msg}");
1445    }
1446
1447    /////////
1448    // Mat //
1449    /////////
1450
1451    #[test]
1452    fn mat_new_and_basic_accessors() {
1453        let mat = Mat::new(Standard::<usize>::new(3, 4).unwrap(), 42usize).unwrap();
1454        let base: *const u8 = mat.as_raw_ptr();
1455
1456        assert_eq!(mat.num_vectors(), 3);
1457        assert_eq!(mat.vector_dim(), 4);
1458
1459        let repr = mat.repr();
1460        assert_eq!(repr.nrows(), 3);
1461        assert_eq!(repr.ncols(), 4);
1462
1463        for (i, r) in mat.rows().enumerate() {
1464            assert_eq!(r, &[42, 42, 42, 42]);
1465            let ptr = r.as_ptr().cast::<u8>();
1466            assert_eq!(
1467                ptr,
1468                base.wrapping_add(std::mem::size_of::<usize>() * mat.repr().ncols() * i),
1469            );
1470        }
1471    }
1472
1473    #[test]
1474    fn mat_new_with_default() {
1475        let mat = Mat::new(Standard::<usize>::new(2, 3).unwrap(), Defaulted).unwrap();
1476        let base: *const u8 = mat.as_raw_ptr();
1477
1478        assert_eq!(mat.num_vectors(), 2);
1479        for (i, row) in mat.rows().enumerate() {
1480            assert!(row.iter().all(|&v| v == 0));
1481
1482            let ptr = row.as_ptr().cast::<u8>();
1483            assert_eq!(
1484                ptr,
1485                base.wrapping_add(std::mem::size_of::<usize>() * mat.repr().ncols() * i),
1486            );
1487        }
1488    }
1489
1490    const ROWS: &[usize] = &[0, 1, 2, 3, 5, 10];
1491    const COLS: &[usize] = &[0, 1, 2, 3, 5, 10];
1492
1493    #[test]
1494    fn test_mat() {
1495        for nrows in ROWS {
1496            for ncols in COLS {
1497                let repr = Standard::<usize>::new(*nrows, *ncols).unwrap();
1498                let ctx = &lazy_format!("nrows = {}, ncols = {}", nrows, ncols);
1499
1500                // Populate the matrix using `&mut Mat`
1501                {
1502                    let ctx = &lazy_format!("{ctx} - direct");
1503                    let mut mat = Mat::new(repr, Defaulted).unwrap();
1504
1505                    assert_eq!(mat.num_vectors(), *nrows);
1506                    assert_eq!(mat.vector_dim(), *ncols);
1507
1508                    fill_mat(&mut mat, repr);
1509
1510                    check_mat(&mat, repr, ctx);
1511                    check_mat_ref(mat.reborrow(), repr, ctx);
1512                    check_mat_mut(mat.reborrow_mut(), repr, ctx);
1513                    check_rows(mat.rows(), repr, ctx);
1514
1515                    // Check reborrow preserves pointers.
1516                    assert_eq!(mat.as_raw_ptr(), mat.reborrow().as_raw_ptr());
1517                    assert_eq!(mat.as_raw_ptr(), mat.reborrow_mut().as_raw_ptr());
1518                }
1519
1520                // Populate the matrix using `MatMut`
1521                {
1522                    let ctx = &lazy_format!("{ctx} - matmut");
1523                    let mut mat = Mat::new(repr, Defaulted).unwrap();
1524                    let matmut = mat.reborrow_mut();
1525
1526                    assert_eq!(matmut.num_vectors(), *nrows);
1527                    assert_eq!(matmut.vector_dim(), *ncols);
1528
1529                    fill_mat_mut(matmut, repr);
1530
1531                    check_mat(&mat, repr, ctx);
1532                    check_mat_ref(mat.reborrow(), repr, ctx);
1533                    check_mat_mut(mat.reborrow_mut(), repr, ctx);
1534                    check_rows(mat.rows(), repr, ctx);
1535                }
1536
1537                // Populate the matrix using `RowsMut`
1538                {
1539                    let ctx = &lazy_format!("{ctx} - rows_mut");
1540                    let mut mat = Mat::new(repr, Defaulted).unwrap();
1541                    fill_rows_mut(mat.rows_mut(), repr);
1542
1543                    check_mat(&mat, repr, ctx);
1544                    check_mat_ref(mat.reborrow(), repr, ctx);
1545                    check_mat_mut(mat.reborrow_mut(), repr, ctx);
1546                    check_rows(mat.rows(), repr, ctx);
1547                }
1548            }
1549        }
1550    }
1551
1552    #[test]
1553    fn test_mat_clone() {
1554        for nrows in ROWS {
1555            for ncols in COLS {
1556                let repr = Standard::<usize>::new(*nrows, *ncols).unwrap();
1557                let ctx = &lazy_format!("nrows = {}, ncols = {}", nrows, ncols);
1558
1559                let mut mat = Mat::new(repr, Defaulted).unwrap();
1560                fill_mat(&mut mat, repr);
1561
1562                // Clone via Mat::clone
1563                {
1564                    let ctx = &lazy_format!("{ctx} - Mat::clone");
1565                    let cloned = mat.clone();
1566
1567                    assert_eq!(cloned.num_vectors(), *nrows);
1568                    assert_eq!(cloned.vector_dim(), *ncols);
1569
1570                    check_mat(&cloned, repr, ctx);
1571                    check_mat_ref(cloned.reborrow(), repr, ctx);
1572                    check_rows(cloned.rows(), repr, ctx);
1573
1574                    // Cloned allocation is independent.
1575                    if repr.num_elements() > 0 {
1576                        assert_ne!(mat.as_raw_ptr(), cloned.as_raw_ptr());
1577                    }
1578                }
1579
1580                // Clone via MatRef::to_owned
1581                {
1582                    let ctx = &lazy_format!("{ctx} - MatRef::to_owned");
1583                    let owned = mat.as_view().to_owned();
1584
1585                    check_mat(&owned, repr, ctx);
1586                    check_mat_ref(owned.reborrow(), repr, ctx);
1587                    check_rows(owned.rows(), repr, ctx);
1588
1589                    if repr.num_elements() > 0 {
1590                        assert_ne!(mat.as_raw_ptr(), owned.as_raw_ptr());
1591                    }
1592                }
1593
1594                // Clone via MatMut::to_owned
1595                {
1596                    let ctx = &lazy_format!("{ctx} - MatMut::to_owned");
1597                    let owned = mat.as_view_mut().to_owned();
1598
1599                    check_mat(&owned, repr, ctx);
1600                    check_mat_ref(owned.reborrow(), repr, ctx);
1601                    check_rows(owned.rows(), repr, ctx);
1602
1603                    if repr.num_elements() > 0 {
1604                        assert_ne!(mat.as_raw_ptr(), owned.as_raw_ptr());
1605                    }
1606                }
1607            }
1608        }
1609    }
1610
1611    #[test]
1612    fn test_mat_refmut() {
1613        for nrows in ROWS {
1614            for ncols in COLS {
1615                let repr = Standard::<usize>::new(*nrows, *ncols).unwrap();
1616                let ctx = &lazy_format!("nrows = {}, ncols = {}", nrows, ncols);
1617
1618                // Populate the matrix using `&mut Mat`
1619                {
1620                    let ctx = &lazy_format!("{ctx} - by matmut");
1621                    let mut b: Box<[_]> = (0..repr.num_elements()).map(|_| 0usize).collect();
1622                    let ptr = b.as_ptr().cast::<u8>();
1623                    let mut matmut = MatMut::new(repr, &mut b).unwrap();
1624
1625                    assert_eq!(
1626                        ptr,
1627                        matmut.as_raw_ptr(),
1628                        "underlying memory should be preserved",
1629                    );
1630
1631                    fill_mat_mut(matmut.reborrow_mut(), repr);
1632
1633                    check_mat_mut(matmut.reborrow_mut(), repr, ctx);
1634                    check_mat_ref(matmut.reborrow(), repr, ctx);
1635                    check_rows(matmut.rows(), repr, ctx);
1636                    check_rows(matmut.reborrow().rows(), repr, ctx);
1637
1638                    let matref = MatRef::new(repr, &b).unwrap();
1639                    check_mat_ref(matref, repr, ctx);
1640                    check_mat_ref(matref.reborrow(), repr, ctx);
1641                    check_rows(matref.rows(), repr, ctx);
1642                }
1643
1644                // Populate the matrix using `RowsMut`
1645                {
1646                    let ctx = &lazy_format!("{ctx} - by rows");
1647                    let mut b: Box<[_]> = (0..repr.num_elements()).map(|_| 0usize).collect();
1648                    let ptr = b.as_ptr().cast::<u8>();
1649                    let mut matmut = MatMut::new(repr, &mut b).unwrap();
1650
1651                    assert_eq!(
1652                        ptr,
1653                        matmut.as_raw_ptr(),
1654                        "underlying memory should be preserved",
1655                    );
1656
1657                    fill_rows_mut(matmut.rows_mut(), repr);
1658
1659                    check_mat_mut(matmut.reborrow_mut(), repr, ctx);
1660                    check_mat_ref(matmut.reborrow(), repr, ctx);
1661                    check_rows(matmut.rows(), repr, ctx);
1662                    check_rows(matmut.reborrow().rows(), repr, ctx);
1663
1664                    let matref = MatRef::new(repr, &b).unwrap();
1665                    check_mat_ref(matref, repr, ctx);
1666                    check_mat_ref(matref.reborrow(), repr, ctx);
1667                    check_rows(matref.rows(), repr, ctx);
1668                }
1669            }
1670        }
1671    }
1672
1673    //////////////////
1674    // Constructors //
1675    //////////////////
1676
1677    #[test]
1678    fn test_standard_new_owned() {
1679        let rows = [0, 1, 2, 3, 5, 10];
1680        let cols = [0, 1, 2, 3, 5, 10];
1681
1682        for nrows in rows {
1683            for ncols in cols {
1684                let m = Mat::new(Standard::new(nrows, ncols).unwrap(), 1usize).unwrap();
1685                let rows_iter = m.rows();
1686                let len = <_ as ExactSizeIterator>::len(&rows_iter);
1687                assert_eq!(len, nrows);
1688                for r in rows_iter {
1689                    assert_eq!(r.len(), ncols);
1690                    assert!(r.iter().all(|i| *i == 1usize));
1691                }
1692            }
1693        }
1694    }
1695
1696    #[test]
1697    fn matref_new_slice_length_error() {
1698        let repr = Standard::<u32>::new(3, 4).unwrap();
1699
1700        // Correct length succeeds
1701        let data = vec![0u32; 12];
1702        assert!(MatRef::new(repr, &data).is_ok());
1703
1704        // Too short fails
1705        let short = vec![0u32; 11];
1706        assert!(matches!(
1707            MatRef::new(repr, &short),
1708            Err(SliceError::LengthMismatch {
1709                expected: 12,
1710                found: 11
1711            })
1712        ));
1713
1714        // Too long fails
1715        let long = vec![0u32; 13];
1716        assert!(matches!(
1717            MatRef::new(repr, &long),
1718            Err(SliceError::LengthMismatch {
1719                expected: 12,
1720                found: 13
1721            })
1722        ));
1723    }
1724
1725    #[test]
1726    fn matmut_new_slice_length_error() {
1727        let repr = Standard::<u32>::new(3, 4).unwrap();
1728
1729        // Correct length succeeds
1730        let mut data = vec![0u32; 12];
1731        assert!(MatMut::new(repr, &mut data).is_ok());
1732
1733        // Too short fails
1734        let mut short = vec![0u32; 11];
1735        assert!(matches!(
1736            MatMut::new(repr, &mut short),
1737            Err(SliceError::LengthMismatch {
1738                expected: 12,
1739                found: 11
1740            })
1741        ));
1742
1743        // Too long fails
1744        let mut long = vec![0u32; 13];
1745        assert!(matches!(
1746            MatMut::new(repr, &mut long),
1747            Err(SliceError::LengthMismatch {
1748                expected: 12,
1749                found: 13
1750            })
1751        ));
1752    }
1753}