Skip to main content

diskann_quantization/multi_vector/
matrix.rs

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT license.
3
4//! Row-major matrix types for multi-vector representations.
5//!
6//! This module provides flexible matrix abstractions that support different underlying
7//! storage formats through the [`Repr`] trait. The primary types are:
8//!
9//! - [`Mat`]: An owning matrix that manages its own memory.
10//! - [`MatRef`]: An immutable borrowed view of matrix data.
11//! - [`MatMut`]: A mutable borrowed view of matrix data.
12//!
13//! # Representations
14//!
15//! Representation types interact with the [`Mat`] family of types using the following traits:
16//!
17//! - [`Repr`]: Read-only matrix representation.
18//! - [`ReprMut`]: Mutable matrix representation.
19//! - [`ReprOwned`]: Owning matrix representation.
20//!
21//! Each trait refinement has a corresponding constructor:
22//!
23//! - [`NewRef`]: Construct a read-only [`MatRef`] view over a slice.
24//! - [`NewMut`]: Construct a mutable [`MatMut`] matrix view over a slice.
25//! - [`NewOwned`]: Construct a new owning [`Mat`].
26//!
27
28use std::{alloc::Layout, iter::FusedIterator, marker::PhantomData, ptr::NonNull};
29
30use diskann_utils::{Reborrow, ReborrowMut};
31use thiserror::Error;
32
33use crate::utils;
34
35/// Representation trait describing the layout and access patterns for a matrix.
36///
37/// Implementations define how raw bytes are interpreted as typed rows. This enables
38/// matrices over different storage formats (dense, quantized, etc.) using a single
39/// generic [`Mat`] type.
40///
41/// # Associated Types
42///
43/// - `Row<'a>`: The immutable row type (e.g., `&[f32]`, `&[f16]`).
44///
45/// # Safety
46///
47/// Implementations must ensure:
48///
49/// - [`get_row`](Self::get_row) returns valid references for the given row index.
50///   This call **must** be memory safe for `i < self.nrows()`, provided the caller upholds
51///   the contract for the raw pointer.
52///
53/// - The objects implicitly managed by this representation inherit the `Send` and `Sync`
54///   attributes of `Repr`. That is, `Repr: Send` implies that the objects in backing memory
55///   are [`Send`], and likewise with `Sync`. This is necessary to apply [`Send`] and [`Sync`]
56///   bounds to [`Mat`], [`MatRef`], and [`MatMut`].
57pub unsafe trait Repr: Copy {
58    /// Immutable row reference type.
59    type Row<'a>
60    where
61        Self: 'a;
62
63    /// Returns the number of rows in the matrix.
64    ///
65    /// # Safety Contract
66    ///
67    /// This function must be loosely pure in the sense that for any given instance of
68    /// `self`, `self.nrows()` must return the same value.
69    fn nrows(&self) -> usize;
70
71    /// Returns the memory layout for a memory allocation containing [`Repr::nrows`] vectors
72    /// each with vector dimension [`Repr::ncols`].
73    ///
74    /// # Safety Contract
75    ///
76    /// The [`Layout`] returned from this method must be consistent with the contract of
77    /// [`Repr::get_row`].
78    fn layout(&self) -> Result<Layout, LayoutError>;
79
80    /// Returns an immutable reference to the `i`-th row.
81    ///
82    /// # Safety
83    ///
84    /// - `ptr` must point to a slice with a layout compatible with [`Repr::layout`].
85    /// - The entire range for this slice must be within a single allocation.
86    /// - `i` must be less than [`Repr::nrows`].
87    /// - The memory referenced by the returned [`Repr::Row`] must not be mutated for the
88    ///   duration of lifetime `'a`.
89    /// - The lifetime for the returned [`Repr::Row`] is inferred from its usage. Correct
90    ///   usage must properly tie the lifetime to a source.
91    unsafe fn get_row<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::Row<'a>;
92}
93
94/// Extension of [`Repr`] that supports mutable row access.
95///
96/// # Associated Types
97///
98/// - `RowMut<'a>`: The mutable row type (e.g., `&mut [f32]`).
99///
100/// # Safety
101///
102/// Implementors must ensure:
103///
104/// - [`get_row_mut`](Self::get_row_mut) returns valid references for the given row index.
105///   This call **must** be memory safe for `i < self.nrows()`, provided the caller upholds
106///   the contract for the raw pointer.
107///
108///   Additionally, since the implementation of the [`RowsMut`] iterator can give out rows
109///   for all `i` in `0..self.nrows()`, the implementation of [`Self::get_row_mut`] must be
110///   such that the result for disjoint `i` must not interfere with one another.
111pub unsafe trait ReprMut: Repr {
112    /// Mutable row reference type.
113    type RowMut<'a>
114    where
115        Self: 'a;
116
117    /// Returns a mutable reference to the i-th row.
118    ///
119    /// # Safety
120    /// - `ptr` must point to a slice with a layout compatible with [`Repr::layout`].
121    /// - The entire range for this slice must be within a single allocation.
122    /// - `i` must be less than `self.nrows()`.
123    /// - The memory referenced by the returned [`ReprMut::RowMut`] must not be accessed
124    ///   through any other reference for the duration of lifetime `'a`.
125    /// - The lifetime for the returned [`ReprMut::RowMut`] is inferred from its usage.
126    ///   Correct usage must properly tie the lifetime to a source.
127    unsafe fn get_row_mut<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::RowMut<'a>;
128}
129
130/// Extension trait for [`Repr`] that supports deallocation of owned matrices. This is used
131/// in conjunction with [`NewOwned`] to create matrices.
132///
133/// Requires [`ReprMut`] since owned matrices should support mutation.
134///
135/// # Safety
136///
137/// Implementors must ensure that `drop` properly deallocates the memory in a way compatible
138/// with all [`NewOwned`] implementations.
139pub unsafe trait ReprOwned: ReprMut {
140    /// Deallocates memory at `ptr` and drops `self`.
141    ///
142    /// # Safety
143    ///
144    /// - `ptr` must have been obtained via [`NewOwned`] with the same value of `self`.
145    /// - This method may only be called once for such a pointer.
146    /// - After calling this method, the memory behind `ptr` may not be dereferenced at all.
147    unsafe fn drop(self, ptr: NonNull<u8>);
148}
149
/// A new-type version of `std::alloc::LayoutError` for cleaner error handling.
///
/// This is basically the same as [`std::alloc::LayoutError`], but constructible in
/// user code. That lets implementors of [`Repr::layout`] return it for reasons other than
/// those derived from `std::alloc::Layout`'s methods.
#[derive(Debug, Clone, Copy, Default)]
#[non_exhaustive]
pub struct LayoutError;

impl LayoutError {
    /// Construct a new opaque [`LayoutError`].
    ///
    /// This is a `const fn`, so the error can also be built in constant contexts.
    pub const fn new() -> Self {
        Self
    }
}

impl std::fmt::Display for LayoutError {
    /// Writes the fixed text `LayoutError`; the error carries no further detail.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("LayoutError")
    }
}

impl std::error::Error for LayoutError {}

impl From<std::alloc::LayoutError> for LayoutError {
    /// Converts the standard library error, discarding it (both types are opaque).
    fn from(_: std::alloc::LayoutError) -> Self {
        Self::new()
    }
}
185
186//////////////////
187// Constructors //
188//////////////////
189
190/// Create a new [`MatRef`] over a slice.
191///
192/// # Safety
193///
194/// Implementations must validate the length (and any other requirements) of the provided
195/// slice to ensure it is compatible with the implementation of [`Repr`].
196pub unsafe trait NewRef<T>: Repr {
197    /// Errors that can occur when initializing.
198    type Error;
199
200    /// Create a new [`MatRef`] over `slice`.
201    fn new_ref(self, slice: &[T]) -> Result<MatRef<'_, Self>, Self::Error>;
202}
203
204/// Create a new [`MatMut`] over a slice.
205///
206/// # Safety
207///
208/// Implementations must validate the length (and any other requirements) of the provided
209/// slice to ensure it is compatible with the implementation of [`ReprMut`].
210pub unsafe trait NewMut<T>: ReprMut {
211    /// Errors that can occur when initializing.
212    type Error;
213
214    /// Create a new [`MatMut`] over `slice`.
215    fn new_mut(self, slice: &mut [T]) -> Result<MatMut<'_, Self>, Self::Error>;
216}
217
218/// Create a new [`Mat`] from an initializer.
219///
220/// # Safety
221///
222/// Implementations must ensure that the returned [`Mat`] is compatible with
223/// `Self`'s implementation of [`ReprOwned`].
224pub unsafe trait NewOwned<T>: ReprOwned {
225    /// Errors that can occur when initializing.
226    type Error;
227
228    /// Create a new [`Mat`] initialized with `init`.
229    fn new_owned(self, init: T) -> Result<Mat<Self>, Self::Error>;
230}
231
/// Marker initializer for [`NewOwned`]: fill the matrix using the element type's
/// [`Default`] implementation.
///
/// ```rust
/// use diskann_quantization::multi_vector::{Mat, Standard, Defaulted};
/// let mat = Mat::new(Standard::<f32>::new(4, 3).unwrap(), Defaulted).unwrap();
/// for i in 0..4 {
///     assert!(mat.get_row(i).unwrap().iter().all(|&x| x == 0.0f32));
/// }
/// ```
#[derive(Debug, Clone, Copy)]
pub struct Defaulted;
244
245/// Create a new [`Mat`] cloned from a view.
246pub trait NewCloned: ReprOwned {
247    /// Clone the contents behind `v`, returning a new owning [`Mat`].
248    ///
249    /// Implementations should ensure the returned [`Mat`] is "semantically the same" as `v`.
250    fn new_cloned(v: MatRef<'_, Self>) -> Mat<Self>;
251}
252
253//////////////
254// Standard //
255//////////////
256
/// Metadata for dense row-major matrices of `Copy` types.
///
/// Each row is a contiguous `&[T]` slice of `ncols` elements. This is the default
/// representation type for standard floating-point multi-vectors.
///
/// # Row Types
///
/// - `Row<'a>`: `&'a [T]`
/// - `RowMut<'a>`: `&'a mut [T]`
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Standard<T> {
    /// Number of rows (vectors).
    nrows: usize,
    /// Number of elements in each row.
    ncols: usize,
    /// Records the element type without storing a value of it.
    _elem: PhantomData<T>,
}
272
273impl<T: Copy> Standard<T> {
274    /// Create a new `Standard` for data of type `T`.
275    ///
276    /// Successful construction requires:
277    ///
278    /// * The total number of elements determined by `nrows * ncols` does not exceed
279    ///   `usize::MAX`.
280    /// * The total memory footprint defined by `ncols * nrows * size_of::<T>()` does not
281    ///   exceed `isize::MAX`.
282    pub fn new(nrows: usize, ncols: usize) -> Result<Self, Overflow> {
283        Overflow::check::<T>(nrows, ncols)?;
284        Ok(Self {
285            nrows,
286            ncols,
287            _elem: PhantomData,
288        })
289    }
290
291    /// Returns the number of total elements (`rows x cols`) in this matrix.
292    pub fn num_elements(&self) -> usize {
293        // Since we've constructed `self` - we know we cannot overflow.
294        self.nrows() * self.ncols()
295    }
296
297    /// Returns `rows`, the number of rows in this matrix.
298    fn nrows(&self) -> usize {
299        self.nrows
300    }
301
302    /// Returns `ncols`, the number of elements in a row of this matrix.
303    fn ncols(&self) -> usize {
304        self.ncols
305    }
306
307    /// Checks the following:
308    ///
309    /// 1. Computation of the number of elements in `self` does not overflow.
310    /// 2. Argument `slice` has the expected number of elements.
311    fn check_slice(&self, slice: &[T]) -> Result<(), SliceError> {
312        let len = self.num_elements();
313
314        if slice.len() != len {
315            Err(SliceError::LengthMismatch {
316                expected: len,
317                found: slice.len(),
318            })
319        } else {
320            Ok(())
321        }
322    }
323
324    /// Create a new [`Mat`] around the contents of `b` **without** any checks.
325    ///
326    /// # Safety
327    ///
328    /// The length of `b` must be exactly [`Standard::num_elements`].
329    unsafe fn box_to_mat(self, b: Box<[T]>) -> Mat<Self> {
330        debug_assert_eq!(b.len(), self.num_elements(), "safety contract violated");
331
332        // SAFETY: Box [guarantees](https://doc.rust-lang.org/std/boxed/struct.Box.html#method.into_raw)
333        // the returned pointer is non-null.
334        let ptr = unsafe { NonNull::new_unchecked(Box::into_raw(b)) }.cast::<u8>();
335
336        // SAFETY: `ptr` is properly aligned and points to a slice of the required length.
337        // Additionally, it is dropped via `Box::from_raw`, which is compatible with obtaining
338        // it from `Box::into_raw`.
339        unsafe { Mat::from_raw_parts(self, ptr) }
340    }
341}
342
/// Error for [`Standard::new`].
///
/// Records the requested dimensions and the element size that were rejected.
#[derive(Debug, Clone, Copy)]
pub struct Overflow {
    nrows: usize,
    ncols: usize,
    elsize: usize,
}

impl Overflow {
    /// Validates that an `nrows x ncols` matrix of `T` is representable.
    ///
    /// Succeeds when `nrows * ncols` fits in a `usize` and the byte footprint
    /// (`size_of::<T>()` times the element count) is at most `isize::MAX`.
    fn check<T>(nrows: usize, ncols: usize) -> Result<(), Self> {
        let elsize = std::mem::size_of::<T>();
        let failure = Self {
            nrows,
            ncols,
            elsize,
        };

        match nrows.checked_mul(ncols) {
            // A saturated product is `usize::MAX`, which is above `isize::MAX` and so
            // lands in the error arm.
            Some(elements) if elsize.saturating_mul(elements) <= isize::MAX as usize => Ok(()),
            _ => Err(failure),
        }
    }
}

impl std::fmt::Display for Overflow {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.elsize {
            // Zero-sized elements take no bytes, so only the element count can overflow.
            0 => write!(
                f,
                "ZST matrix with dimensions {} x {} has more than `usize::MAX` elements",
                self.nrows, self.ncols,
            ),
            elsize => write!(
                f,
                "a matrix of size {} x {} with element size {} would exceed isize::MAX bytes",
                self.nrows, self.ncols, elsize,
            ),
        }
    }
}

impl std::error::Error for Overflow {}
393
394/// Error types for [`Standard`].
395#[derive(Debug, Clone, Copy, Error)]
396#[non_exhaustive]
397pub enum SliceError {
398    #[error("Length mismatch: expected {expected}, found {found}")]
399    LengthMismatch { expected: usize, found: usize },
400}
401
402// SAFETY: The implementation correctly computes row offsets as `i * ncols` and
403// constructs valid slices of the appropriate length. The `layout` method correctly
404// reports the memory layout requirements.
405unsafe impl<T: Copy> Repr for Standard<T> {
406    type Row<'a>
407        = &'a [T]
408    where
409        T: 'a;
410
411    fn nrows(&self) -> usize {
412        self.nrows
413    }
414
415    fn layout(&self) -> Result<Layout, LayoutError> {
416        Ok(Layout::array::<T>(self.num_elements())?)
417    }
418
419    unsafe fn get_row<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::Row<'a> {
420        debug_assert!(ptr.cast::<T>().is_aligned());
421        debug_assert!(i < self.nrows);
422
423        // SAFETY: The caller asserts that `i` is less than `self.nrows()`. Since this type
424        // audits the constructors for `Mat` and friends, we know that there is room for at
425        // least `self.num_elements()` elements from the base pointer, so this access is safe.
426        let row_ptr = unsafe { ptr.as_ptr().cast::<T>().add(i * self.ncols) };
427
428        // SAFETY: The logic is the same as the previous `unsafe` block.
429        unsafe { std::slice::from_raw_parts(row_ptr, self.ncols) }
430    }
431}
432
433// SAFETY: The implementation correctly computes row offsets and constructs valid mutable
434// slices.
435unsafe impl<T: Copy> ReprMut for Standard<T> {
436    type RowMut<'a>
437        = &'a mut [T]
438    where
439        T: 'a;
440
441    unsafe fn get_row_mut<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::RowMut<'a> {
442        debug_assert!(ptr.cast::<T>().is_aligned());
443        debug_assert!(i < self.nrows);
444
445        // SAFETY: The caller asserts that `i` is less than `self.nrows()`. Since this type
446        // audits the constructors for `Mat` and friends, we know that there is room for at
447        // least `self.num_elements()` elements from the base pointer, so this access is safe.
448        let row_ptr = unsafe { ptr.as_ptr().cast::<T>().add(i * self.ncols) };
449
450        // SAFETY: The logic is the same as the previous `unsafe` block. Further, the caller
451        // attests that creating a mutable reference is safe.
452        unsafe { std::slice::from_raw_parts_mut(row_ptr, self.ncols) }
453    }
454}
455
456// SAFETY: The drop implementation correctly reconstructs a Box from the raw pointer
457// using the same length (nrows * ncols) that was used for allocation, allowing Box
458// to properly deallocate the memory.
459unsafe impl<T: Copy> ReprOwned for Standard<T> {
460    unsafe fn drop(self, ptr: NonNull<u8>) {
461        // SAFETY: The caller guarantees that `ptr` was obtained from an implementation of
462        // `NewOwned` for an equivalent instance of `self`.
463        //
464        // We ensure that `NewOwned` goes through boxes, so here we reconstruct a Box to
465        // let it handle deallocation.
466        unsafe {
467            let slice_ptr = std::ptr::slice_from_raw_parts_mut(
468                ptr.cast::<T>().as_ptr(),
469                self.nrows * self.ncols,
470            );
471            let _ = Box::from_raw(slice_ptr);
472        }
473    }
474}
475
476// SAFETY: The implementation uses guarantees from `Box` to ensure that the pointer
477// initialized by it is non-null and properly aligned to the underlying type.
478unsafe impl<T> NewOwned<T> for Standard<T>
479where
480    T: Copy,
481{
482    type Error = crate::error::Infallible;
483    fn new_owned(self, value: T) -> Result<Mat<Self>, Self::Error> {
484        let b: Box<[T]> = (0..self.num_elements()).map(|_| value).collect();
485
486        // SAFETY: By construction, `b` has length `self.num_elements()`.
487        Ok(unsafe { self.box_to_mat(b) })
488    }
489}
490
491// SAFETY: This safely reuses `<Self as NewOwned<T>>`.
492unsafe impl<T> NewOwned<Defaulted> for Standard<T>
493where
494    T: Copy + Default,
495{
496    type Error = crate::error::Infallible;
497    fn new_owned(self, _: Defaulted) -> Result<Mat<Self>, Self::Error> {
498        self.new_owned(T::default())
499    }
500}
501
502// SAFETY: This checks that the slice has the correct length, which is all that is
503// required for [`Repr`].
504unsafe impl<T> NewRef<T> for Standard<T>
505where
506    T: Copy,
507{
508    type Error = SliceError;
509    fn new_ref(self, data: &[T]) -> Result<MatRef<'_, Self>, Self::Error> {
510        self.check_slice(data)?;
511
512        // SAFETY: The function `check_slice` verifies that `data` is compatible with
513        // the layout requirement of `Standard`.
514        //
515        // We've properly checked that the underlying pointer is okay.
516        Ok(unsafe { MatRef::from_raw_parts(self, utils::as_nonnull(data).cast::<u8>()) })
517    }
518}
519
520// SAFETY: This checks that the slice has the correct length, which is all that is
521// required for [`ReprMut`].
522unsafe impl<T> NewMut<T> for Standard<T>
523where
524    T: Copy,
525{
526    type Error = SliceError;
527    fn new_mut(self, data: &mut [T]) -> Result<MatMut<'_, Self>, Self::Error> {
528        self.check_slice(data)?;
529
530        // SAFETY: The function `check_slice` verifies that `data` is compatible with
531        // the layout requirement of `Standard`.
532        //
533        // We've properly checked that the underlying pointer is okay.
534        Ok(unsafe { MatMut::from_raw_parts(self, utils::as_nonnull_mut(data).cast::<u8>()) })
535    }
536}
537
538impl<T> NewCloned for Standard<T>
539where
540    T: Copy,
541{
542    fn new_cloned(v: MatRef<'_, Self>) -> Mat<Self> {
543        let b: Box<[T]> = v.rows().flatten().copied().collect();
544
545        // SAFETY: By construction, `b` has length `v.repr().num_elements()`.
546        unsafe { v.repr().box_to_mat(b) }
547    }
548}
549
550/////////
551// Mat //
552/////////
553
554/// An owning matrix that manages its own memory.
555///
556/// The matrix stores raw bytes interpreted according to representation type `T`.
557/// Memory is automatically deallocated when the matrix is dropped.
558#[derive(Debug)]
559pub struct Mat<T: ReprOwned> {
560    ptr: NonNull<u8>,
561    repr: T,
562}
563
564// SAFETY: [`Repr`] is required to propagate its `Send` bound.
565unsafe impl<T> Send for Mat<T> where T: ReprOwned + Send {}
566
567// SAFETY: [`Repr`] is required to propagate its `Sync` bound.
568unsafe impl<T> Sync for Mat<T> where T: ReprOwned + Sync {}
569
570impl<T: ReprOwned> Mat<T> {
571    /// Create a new matrix using `init` as the initializer.
572    pub fn new<U>(repr: T, init: U) -> Result<Self, <T as NewOwned<U>>::Error>
573    where
574        T: NewOwned<U>,
575    {
576        repr.new_owned(init)
577    }
578
579    /// Returns the number of rows (vectors) in the matrix.
580    #[inline]
581    pub fn num_vectors(&self) -> usize {
582        self.repr.nrows()
583    }
584
585    /// Returns a reference to the underlying representation.
586    pub fn repr(&self) -> &T {
587        &self.repr
588    }
589
590    /// Returns the `i`th row if `i < self.num_vectors()`.
591    #[must_use]
592    pub fn get_row(&self, i: usize) -> Option<T::Row<'_>> {
593        if i < self.num_vectors() {
594            // SAFETY: Bounds check passed, and the Mat was constructed
595            // with valid representation and pointer.
596            let row = unsafe { self.get_row_unchecked(i) };
597            Some(row)
598        } else {
599            None
600        }
601    }
602
603    pub(crate) unsafe fn get_row_unchecked(&self, i: usize) -> T::Row<'_> {
604        // SAFETY: Caller must ensure i < self.num_vectors(). The constructors for this type
605        // ensure that `ptr` is compatible with `T`.
606        unsafe { self.repr.get_row(self.ptr, i) }
607    }
608
609    /// Returns the `i`th mutable row if `i < self.num_vectors()`.
610    #[must_use]
611    pub fn get_row_mut(&mut self, i: usize) -> Option<T::RowMut<'_>> {
612        if i < self.num_vectors() {
613            // SAFETY: Bounds check passed, and we have exclusive access via &mut self.
614            Some(unsafe { self.get_row_mut_unchecked(i) })
615        } else {
616            None
617        }
618    }
619
620    pub(crate) unsafe fn get_row_mut_unchecked(&mut self, i: usize) -> T::RowMut<'_> {
621        // SAFETY: Caller asserts that `i < self.num_vectors()`. The constructors for this
622        // type ensure that `ptr` is compatible with `T`.
623        unsafe { self.repr.get_row_mut(self.ptr, i) }
624    }
625
626    /// Returns an immutable view of the matrix.
627    #[inline]
628    pub fn as_view(&self) -> MatRef<'_, T> {
629        MatRef {
630            ptr: self.ptr,
631            repr: self.repr,
632            _lifetime: PhantomData,
633        }
634    }
635
636    /// Returns a mutable view of the matrix.
637    #[inline]
638    pub fn as_view_mut(&mut self) -> MatMut<'_, T> {
639        MatMut {
640            ptr: self.ptr,
641            repr: self.repr,
642            _lifetime: PhantomData,
643        }
644    }
645
646    /// Returns an iterator over immutable row references.
647    pub fn rows(&self) -> Rows<'_, T> {
648        Rows::new(self.reborrow())
649    }
650
651    /// Returns an iterator over mutable row references.
652    pub fn rows_mut(&mut self) -> RowsMut<'_, T> {
653        RowsMut::new(self.reborrow_mut())
654    }
655
656    /// Construct a new [`Mat`] over the raw pointer and representation without performing
657    /// any validity checks.
658    ///
659    /// # Safety
660    ///
661    /// Argument `ptr` must be:
662    ///
663    /// 1. Point to memory compatible with [`Repr::layout`].
664    /// 2. Be compatible with the drop logic in [`ReprOwned`].
665    pub(crate) unsafe fn from_raw_parts(repr: T, ptr: NonNull<u8>) -> Self {
666        Self { ptr, repr }
667    }
668
669    /// Return the base pointer for the [`Mat`].
670    pub fn as_raw_ptr(&self) -> *const u8 {
671        self.ptr.as_ptr()
672    }
673}
674
675impl<T: ReprOwned> Drop for Mat<T> {
676    fn drop(&mut self) {
677        // SAFETY: `ptr` was correctly initialized according to `layout`
678        // and we are guaranteed exclusive access to the data due to Rust borrow rules.
679        unsafe { self.repr.drop(self.ptr) };
680    }
681}
682
683impl<T: NewCloned> Clone for Mat<T> {
684    fn clone(&self) -> Self {
685        T::new_cloned(self.as_view())
686    }
687}
688
689impl<T: Copy> Mat<Standard<T>> {
690    /// Returns the raw dimension (columns) of the vectors in the matrix.
691    #[inline]
692    pub fn vector_dim(&self) -> usize {
693        self.repr.ncols()
694    }
695}
696
697////////////
698// MatRef //
699////////////
700
701/// An immutable borrowed view of a matrix.
702///
703/// Provides read-only access to matrix data without ownership. Implements [`Copy`]
704/// and can be freely cloned.
705///
706/// # Type Parameter
707/// - `T`: A [`Repr`] implementation defining the row layout.
708///
709/// # Access
710/// - [`get_row`](Self::get_row): Get an immutable row by index.
711/// - [`rows`](Self::rows): Iterate over all rows.
712#[derive(Debug, Clone, Copy)]
713pub struct MatRef<'a, T: Repr> {
714    pub(crate) ptr: NonNull<u8>,
715    pub(crate) repr: T,
716    /// Marker to tie the lifetime to the borrowed data.
717    pub(crate) _lifetime: PhantomData<&'a [u8]>,
718}
719
720// SAFETY: [`Repr`] is required to propagate its `Send` bound.
721unsafe impl<T> Send for MatRef<'_, T> where T: Repr + Send {}
722
723// SAFETY: [`Repr`] is required to propagate its `Sync` bound.
724unsafe impl<T> Sync for MatRef<'_, T> where T: Repr + Sync {}
725
726impl<'a, T: Repr> MatRef<'a, T> {
727    /// Construct a new [`MatRef`] over `data`.
728    pub fn new<U>(repr: T, data: &'a [U]) -> Result<Self, T::Error>
729    where
730        T: NewRef<U>,
731    {
732        repr.new_ref(data)
733    }
734
735    /// Returns the number of rows (vectors) in the matrix.
736    #[inline]
737    pub fn num_vectors(&self) -> usize {
738        self.repr.nrows()
739    }
740
741    /// Returns a reference to the underlying representation.
742    pub fn repr(&self) -> &T {
743        &self.repr
744    }
745
746    /// Returns an immutable reference to the i-th row, or `None` if out of bounds.
747    #[must_use]
748    pub fn get_row(&self, i: usize) -> Option<T::Row<'_>> {
749        if i < self.num_vectors() {
750            // SAFETY: Bounds check passed, and the MatRef was constructed
751            // with valid representation and pointer.
752            let row = unsafe { self.get_row_unchecked(i) };
753            Some(row)
754        } else {
755            None
756        }
757    }
758
759    /// Returns the i-th row without bounds checking.
760    ///
761    /// # Safety
762    ///
763    /// `i` must be less than `self.num_vectors()`.
764    #[inline]
765    pub(crate) unsafe fn get_row_unchecked(&self, i: usize) -> T::Row<'_> {
766        // SAFETY: Caller must ensure i < self.num_vectors().
767        unsafe { self.repr.get_row(self.ptr, i) }
768    }
769
770    /// Returns an iterator over immutable row references.
771    pub fn rows(&self) -> Rows<'_, T> {
772        Rows::new(*self)
773    }
774
775    /// Return a [`Mat`] with the same contents as `self`.
776    pub fn to_owned(&self) -> Mat<T>
777    where
778        T: NewCloned,
779    {
780        T::new_cloned(*self)
781    }
782
783    /// Construct a new [`MatRef`] over the raw pointer and representation without performing
784    /// any validity checks.
785    ///
786    /// # Safety
787    ///
788    /// Argument `ptr` must point to memory compatible with [`Repr::layout`] and pass any
789    /// validity checks required by `T`.
790    pub unsafe fn from_raw_parts(repr: T, ptr: NonNull<u8>) -> Self {
791        Self {
792            ptr,
793            repr,
794            _lifetime: PhantomData,
795        }
796    }
797
798    /// Return the base pointer for the [`MatRef`].
799    pub fn as_raw_ptr(&self) -> *const u8 {
800        self.ptr.as_ptr()
801    }
802}
803
804impl<'a, T: Copy> MatRef<'a, Standard<T>> {
805    /// Returns the raw dimension (columns) of the vectors in the matrix.
806    #[inline]
807    pub fn vector_dim(&self) -> usize {
808        self.repr.ncols()
809    }
810}
811
812// Reborrow: Mat -> MatRef
813impl<'this, T: ReprOwned> Reborrow<'this> for Mat<T> {
814    type Target = MatRef<'this, T>;
815
816    fn reborrow(&'this self) -> Self::Target {
817        self.as_view()
818    }
819}
820
821// ReborrowMut: Mat -> MatMut
822impl<'this, T: ReprOwned> ReborrowMut<'this> for Mat<T> {
823    type Target = MatMut<'this, T>;
824
825    fn reborrow_mut(&'this mut self) -> Self::Target {
826        self.as_view_mut()
827    }
828}
829
830// Reborrow: MatRef -> MatRef (with shorter lifetime)
831impl<'this, 'a, T: Repr> Reborrow<'this> for MatRef<'a, T> {
832    type Target = MatRef<'this, T>;
833
834    fn reborrow(&'this self) -> Self::Target {
835        MatRef {
836            ptr: self.ptr,
837            repr: self.repr,
838            _lifetime: PhantomData,
839        }
840    }
841}
842
843////////////
844// MatMut //
845////////////
846
847/// A mutable borrowed view of a matrix.
848///
849/// Provides read-write access to matrix data without ownership.
850///
851/// # Type Parameter
852/// - `T`: A [`ReprMut`] implementation defining the row layout.
853///
854/// # Access
855/// - [`get_row`](Self::get_row): Get an immutable row by index.
856/// - [`get_row_mut`](Self::get_row_mut): Get a mutable row by index.
857/// - [`as_view`](Self::as_view): Reborrow as immutable [`MatRef`].
858/// - [`rows`](Self::rows), [`rows_mut`](Self::rows_mut): Iterate over rows.
859#[derive(Debug)]
860pub struct MatMut<'a, T: ReprMut> {
861    pub(crate) ptr: NonNull<u8>,
862    pub(crate) repr: T,
863    /// Marker to tie the lifetime to the mutably borrowed data.
864    pub(crate) _lifetime: PhantomData<&'a mut [u8]>,
865}
866
867// SAFETY: [`ReprMut`] is required to propagate its `Send` bound.
868unsafe impl<T> Send for MatMut<'_, T> where T: ReprMut + Send {}
869
870// SAFETY: [`ReprMut`] is required to propagate its `Sync` bound.
871unsafe impl<T> Sync for MatMut<'_, T> where T: ReprMut + Sync {}
872
873impl<'a, T: ReprMut> MatMut<'a, T> {
874    /// Construct a new [`MatMut`] over `data`.
875    pub fn new<U>(repr: T, data: &'a mut [U]) -> Result<Self, T::Error>
876    where
877        T: NewMut<U>,
878    {
879        repr.new_mut(data)
880    }
881
882    /// Returns the number of rows (vectors) in the matrix.
883    #[inline]
884    pub fn num_vectors(&self) -> usize {
885        self.repr.nrows()
886    }
887
888    /// Returns a reference to the underlying representation.
889    pub fn repr(&self) -> &T {
890        &self.repr
891    }
892
893    /// Returns an immutable reference to the i-th row, or `None` if out of bounds.
894    #[inline]
895    #[must_use]
896    pub fn get_row(&self, i: usize) -> Option<T::Row<'_>> {
897        if i < self.num_vectors() {
898            // SAFETY: Bounds check passed.
899            Some(unsafe { self.get_row_unchecked(i) })
900        } else {
901            None
902        }
903    }
904
905    /// Returns the i-th row without bounds checking.
906    ///
907    /// # Safety
908    ///
909    /// `i` must be less than `self.num_vectors()`.
910    #[inline]
911    pub(crate) unsafe fn get_row_unchecked(&self, i: usize) -> T::Row<'_> {
912        // SAFETY: Caller must ensure i < self.num_vectors().
913        unsafe { self.repr.get_row(self.ptr, i) }
914    }
915
916    /// Returns a mutable reference to the `i`-th row, or `None` if out of bounds.
917    #[inline]
918    #[must_use]
919    pub fn get_row_mut(&mut self, i: usize) -> Option<T::RowMut<'_>> {
920        if i < self.num_vectors() {
921            // SAFETY: Bounds check passed.
922            Some(unsafe { self.get_row_mut_unchecked(i) })
923        } else {
924            None
925        }
926    }
927
928    /// Returns a mutable reference to the i-th row without bounds checking.
929    ///
930    /// # Safety
931    ///
932    /// `i` must be less than [`num_vectors()`](Self::num_vectors).
933    #[inline]
934    pub(crate) unsafe fn get_row_mut_unchecked(&mut self, i: usize) -> T::RowMut<'_> {
935        // SAFETY: Caller asserts that `i < self.num_vectors()`. The constructors for this
936        // type ensure that `ptr` is compatible with `T`.
937        unsafe { self.repr.get_row_mut(self.ptr, i) }
938    }
939
940    /// Reborrows as an immutable [`MatRef`].
941    pub fn as_view(&self) -> MatRef<'_, T> {
942        MatRef {
943            ptr: self.ptr,
944            repr: self.repr,
945            _lifetime: PhantomData,
946        }
947    }
948
949    /// Returns an iterator over immutable row references.
950    pub fn rows(&self) -> Rows<'_, T> {
951        Rows::new(self.reborrow())
952    }
953
954    /// Returns an iterator over mutable row references.
955    pub fn rows_mut(&mut self) -> RowsMut<'_, T> {
956        RowsMut::new(self.reborrow_mut())
957    }
958
959    /// Return a [`Mat`] with the same contents as `self`.
960    pub fn to_owned(&self) -> Mat<T>
961    where
962        T: NewCloned,
963    {
964        T::new_cloned(self.as_view())
965    }
966
967    /// Construct a new [`MatMut`] over the raw pointer and representation without performing
968    /// any validity checks.
969    ///
970    /// # Safety
971    ///
972    /// Argument `ptr` must point to memory compatible with [`Repr::layout`].
973    pub unsafe fn from_raw_parts(repr: T, ptr: NonNull<u8>) -> Self {
974        Self {
975            ptr,
976            repr,
977            _lifetime: PhantomData,
978        }
979    }
980
981    /// Return the base pointer for the [`MatMut`].
982    pub fn as_raw_ptr(&self) -> *const u8 {
983        self.ptr.as_ptr()
984    }
985}
986
987// Reborrow: MatMut -> MatRef
988impl<'this, 'a, T: ReprMut> Reborrow<'this> for MatMut<'a, T> {
989    type Target = MatRef<'this, T>;
990
991    fn reborrow(&'this self) -> Self::Target {
992        self.as_view()
993    }
994}
995
996// ReborrowMut: MatMut -> MatMut (with shorter lifetime)
997impl<'this, 'a, T: ReprMut> ReborrowMut<'this> for MatMut<'a, T> {
998    type Target = MatMut<'this, T>;
999
1000    fn reborrow_mut(&'this mut self) -> Self::Target {
1001        MatMut {
1002            ptr: self.ptr,
1003            repr: self.repr,
1004            _lifetime: PhantomData,
1005        }
1006    }
1007}
1008
1009impl<'a, T: Copy> MatMut<'a, Standard<T>> {
1010    /// Returns the raw dimension (columns) of the vectors in the matrix.
1011    #[inline]
1012    pub fn vector_dim(&self) -> usize {
1013        self.repr.ncols()
1014    }
1015}
1016
1017//////////
1018// Rows //
1019//////////
1020
/// Iterator over immutable row references of a matrix.
///
/// Created by [`Mat::rows`], [`MatRef::rows`], or [`MatMut::rows`].
#[derive(Debug)]
pub struct Rows<'a, T: Repr> {
    /// View of the matrix whose rows are being yielded.
    matrix: MatRef<'a, T>,
    /// Index of the next row to yield; starts at 0 and advances by one per yielded row.
    current: usize,
}
1029
1030impl<'a, T> Rows<'a, T>
1031where
1032    T: Repr,
1033{
1034    fn new(matrix: MatRef<'a, T>) -> Self {
1035        Self { matrix, current: 0 }
1036    }
1037}
1038
1039impl<'a, T> Iterator for Rows<'a, T>
1040where
1041    T: Repr + 'a,
1042{
1043    type Item = T::Row<'a>;
1044
1045    fn next(&mut self) -> Option<Self::Item> {
1046        let current = self.current;
1047        if current >= self.matrix.num_vectors() {
1048            None
1049        } else {
1050            self.current += 1;
1051            // SAFETY: We make sure through the above check that
1052            // the access is within bounds.
1053            //
1054            // Extending the lifetime to `'a` is safe because the underlying
1055            // MatRef has lifetime `'a`.
1056            Some(unsafe { self.matrix.repr.get_row(self.matrix.ptr, current) })
1057        }
1058    }
1059
1060    fn size_hint(&self) -> (usize, Option<usize>) {
1061        let remaining = self.matrix.num_vectors() - self.current;
1062        (remaining, Some(remaining))
1063    }
1064}
1065
// `size_hint` returns identical lower and upper bounds (the rows not yet yielded), which
// is the contract `ExactSizeIterator` relies on.
impl<'a, T> ExactSizeIterator for Rows<'a, T> where T: Repr + 'a {}
// `next` stops advancing `current` once it reaches the row count, so after the first
// `None` every later call takes the same branch and returns `None` again.
impl<'a, T> FusedIterator for Rows<'a, T> where T: Repr + 'a {}
1068
1069/////////////
1070// RowsMut //
1071/////////////
1072
/// Iterator over mutable row references of a matrix.
///
/// Created by [`Mat::rows_mut`] or [`MatMut::rows_mut`].
#[derive(Debug)]
pub struct RowsMut<'a, T: ReprMut> {
    /// Mutable view of the matrix whose rows are being yielded.
    matrix: MatMut<'a, T>,
    /// Index of the next row to yield; starts at 0 and advances by one per yielded row,
    /// so no row index is handed out twice.
    current: usize,
}
1081
1082impl<'a, T> RowsMut<'a, T>
1083where
1084    T: ReprMut,
1085{
1086    fn new(matrix: MatMut<'a, T>) -> Self {
1087        Self { matrix, current: 0 }
1088    }
1089}
1090
1091impl<'a, T> Iterator for RowsMut<'a, T>
1092where
1093    T: ReprMut + 'a,
1094{
1095    type Item = T::RowMut<'a>;
1096
1097    fn next(&mut self) -> Option<Self::Item> {
1098        let current = self.current;
1099        if current >= self.matrix.num_vectors() {
1100            None
1101        } else {
1102            self.current += 1;
1103            // SAFETY: We make sure through the above check that
1104            // the access is within bounds.
1105            //
1106            // Extending the lifetime to `'a` is safe because:
1107            // 1. The underlying MatMut has lifetime `'a`.
1108            // 2. The iterator ensures that the mutable row indices are disjoint, so
1109            //    there is no aliasing as long as the implementation of `ReprMut` ensures
1110            //    there is not mutable sharing of the `RowMut` types.
1111            Some(unsafe { self.matrix.repr.get_row_mut(self.matrix.ptr, current) })
1112        }
1113    }
1114
1115    fn size_hint(&self) -> (usize, Option<usize>) {
1116        let remaining = self.matrix.num_vectors() - self.current;
1117        (remaining, Some(remaining))
1118    }
1119}
1120
// `size_hint` returns identical lower and upper bounds (the rows not yet yielded), which
// is the contract `ExactSizeIterator` relies on.
impl<'a, T> ExactSizeIterator for RowsMut<'a, T> where T: ReprMut + 'a {}
// `next` stops advancing `current` once it reaches the row count, so after the first
// `None` every later call takes the same branch and returns `None` again.
impl<'a, T> FusedIterator for RowsMut<'a, T> where T: ReprMut + 'a {}
1123
1124///////////
1125// Tests //
1126///////////
1127
1128#[cfg(test)]
1129mod tests {
1130    use super::*;
1131
1132    use std::fmt::Display;
1133
1134    use diskann_utils::lazy_format;
1135
1136    /// Helper to assert a type is Copy.
1137    fn assert_copy<T: Copy>(_: &T) {}
1138
1139    fn edge_cases(nrows: usize) -> Vec<usize> {
1140        let max = usize::MAX;
1141
1142        vec![
1143            nrows,
1144            nrows + 1,
1145            nrows + 11,
1146            nrows + 20,
1147            max / 2,
1148            max.div_ceil(2),
1149            max - 1,
1150            max,
1151        ]
1152    }
1153
1154    fn fill_mat(x: &mut Mat<Standard<usize>>, repr: Standard<usize>) {
1155        assert_eq!(x.repr(), &repr);
1156        assert_eq!(x.num_vectors(), repr.nrows());
1157        assert_eq!(x.vector_dim(), repr.ncols());
1158
1159        for i in 0..x.num_vectors() {
1160            let row = x.get_row_mut(i).unwrap();
1161            assert_eq!(row.len(), repr.ncols());
1162            row.iter_mut()
1163                .enumerate()
1164                .for_each(|(j, r)| *r = 10 * i + j);
1165        }
1166
1167        for i in edge_cases(repr.nrows()).into_iter() {
1168            assert!(x.get_row_mut(i).is_none());
1169        }
1170    }
1171
1172    fn fill_mat_mut(mut x: MatMut<'_, Standard<usize>>, repr: Standard<usize>) {
1173        assert_eq!(x.repr(), &repr);
1174        assert_eq!(x.num_vectors(), repr.nrows());
1175        assert_eq!(x.vector_dim(), repr.ncols());
1176
1177        for i in 0..x.num_vectors() {
1178            let row = x.get_row_mut(i).unwrap();
1179            assert_eq!(row.len(), repr.ncols());
1180
1181            row.iter_mut()
1182                .enumerate()
1183                .for_each(|(j, r)| *r = 10 * i + j);
1184        }
1185
1186        for i in edge_cases(repr.nrows()).into_iter() {
1187            assert!(x.get_row_mut(i).is_none());
1188        }
1189    }
1190
1191    fn fill_rows_mut(x: RowsMut<'_, Standard<usize>>, repr: Standard<usize>) {
1192        assert_eq!(x.len(), repr.nrows());
1193        // Materialize all rows at once.
1194        let mut all_rows: Vec<_> = x.collect();
1195        assert_eq!(all_rows.len(), repr.nrows());
1196        for (i, row) in all_rows.iter_mut().enumerate() {
1197            assert_eq!(row.len(), repr.ncols());
1198            row.iter_mut()
1199                .enumerate()
1200                .for_each(|(j, r)| *r = 10 * i + j);
1201        }
1202    }
1203
1204    fn check_mat(x: &Mat<Standard<usize>>, repr: Standard<usize>, ctx: &dyn Display) {
1205        assert_eq!(x.repr(), &repr);
1206        assert_eq!(x.num_vectors(), repr.nrows());
1207        assert_eq!(x.vector_dim(), repr.ncols());
1208
1209        for i in 0..x.num_vectors() {
1210            let row = x.get_row(i).unwrap();
1211
1212            assert_eq!(row.len(), repr.ncols(), "ctx: {ctx}");
1213            row.iter().enumerate().for_each(|(j, r)| {
1214                assert_eq!(
1215                    *r,
1216                    10 * i + j,
1217                    "mismatched entry at row {}, col {} -- ctx: {}",
1218                    i,
1219                    j,
1220                    ctx
1221                )
1222            });
1223        }
1224
1225        for i in edge_cases(repr.nrows()).into_iter() {
1226            assert!(x.get_row(i).is_none(), "ctx: {ctx}");
1227        }
1228    }
1229
1230    fn check_mat_ref(x: MatRef<'_, Standard<usize>>, repr: Standard<usize>, ctx: &dyn Display) {
1231        assert_eq!(x.repr(), &repr);
1232        assert_eq!(x.num_vectors(), repr.nrows());
1233        assert_eq!(x.vector_dim(), repr.ncols());
1234
1235        assert_copy(&x);
1236        for i in 0..x.num_vectors() {
1237            let row = x.get_row(i).unwrap();
1238            assert_eq!(row.len(), repr.ncols(), "ctx: {ctx}");
1239
1240            row.iter().enumerate().for_each(|(j, r)| {
1241                assert_eq!(
1242                    *r,
1243                    10 * i + j,
1244                    "mismatched entry at row {}, col {} -- ctx: {}",
1245                    i,
1246                    j,
1247                    ctx
1248                )
1249            });
1250        }
1251
1252        for i in edge_cases(repr.nrows()).into_iter() {
1253            assert!(x.get_row(i).is_none(), "ctx: {ctx}");
1254        }
1255    }
1256
1257    fn check_mat_mut(x: MatMut<'_, Standard<usize>>, repr: Standard<usize>, ctx: &dyn Display) {
1258        assert_eq!(x.repr(), &repr);
1259        assert_eq!(x.num_vectors(), repr.nrows());
1260        assert_eq!(x.vector_dim(), repr.ncols());
1261
1262        for i in 0..x.num_vectors() {
1263            let row = x.get_row(i).unwrap();
1264            assert_eq!(row.len(), repr.ncols(), "ctx: {ctx}");
1265
1266            row.iter().enumerate().for_each(|(j, r)| {
1267                assert_eq!(
1268                    *r,
1269                    10 * i + j,
1270                    "mismatched entry at row {}, col {} -- ctx: {}",
1271                    i,
1272                    j,
1273                    ctx
1274                )
1275            });
1276        }
1277
1278        for i in edge_cases(repr.nrows()).into_iter() {
1279            assert!(x.get_row(i).is_none(), "ctx: {ctx}");
1280        }
1281    }
1282
1283    fn check_rows(x: Rows<'_, Standard<usize>>, repr: Standard<usize>, ctx: &dyn Display) {
1284        assert_eq!(x.len(), repr.nrows(), "ctx: {ctx}");
1285        let all_rows: Vec<_> = x.collect();
1286        assert_eq!(all_rows.len(), repr.nrows(), "ctx: {ctx}");
1287        for (i, row) in all_rows.iter().enumerate() {
1288            assert_eq!(row.len(), repr.ncols(), "ctx: {ctx}");
1289            row.iter().enumerate().for_each(|(j, r)| {
1290                assert_eq!(
1291                    *r,
1292                    10 * i + j,
1293                    "mismatched entry at row {}, col {} -- ctx: {}",
1294                    i,
1295                    j,
1296                    ctx
1297                )
1298            });
1299        }
1300    }
1301
1302    //////////////
1303    // Standard //
1304    //////////////
1305
1306    #[test]
1307    fn standard_representation() {
1308        let repr = Standard::<f32>::new(4, 3).unwrap();
1309        assert_eq!(repr.nrows(), 4);
1310        assert_eq!(repr.ncols(), 3);
1311
1312        let layout = repr.layout().unwrap();
1313        assert_eq!(layout.size(), 4 * 3 * std::mem::size_of::<f32>());
1314        assert_eq!(layout.align(), std::mem::align_of::<f32>());
1315    }
1316
1317    #[test]
1318    fn standard_zero_dimensions() {
1319        for (nrows, ncols) in [(0, 0), (0, 5), (5, 0)] {
1320            let repr = Standard::<u8>::new(nrows, ncols).unwrap();
1321            assert_eq!(repr.nrows(), nrows);
1322            assert_eq!(repr.ncols(), ncols);
1323            let layout = repr.layout().unwrap();
1324            assert_eq!(layout.size(), 0);
1325        }
1326    }
1327
1328    #[test]
1329    fn standard_check_slice() {
1330        let repr = Standard::<u32>::new(3, 4).unwrap();
1331
1332        // Correct length succeeds
1333        let data = vec![0u32; 12];
1334        assert!(repr.check_slice(&data).is_ok());
1335
1336        // Too short fails
1337        let short = vec![0u32; 11];
1338        assert!(matches!(
1339            repr.check_slice(&short),
1340            Err(SliceError::LengthMismatch {
1341                expected: 12,
1342                found: 11
1343            })
1344        ));
1345
1346        // Too long fails
1347        let long = vec![0u32; 13];
1348        assert!(matches!(
1349            repr.check_slice(&long),
1350            Err(SliceError::LengthMismatch {
1351                expected: 12,
1352                found: 13
1353            })
1354        ));
1355
1356        // Overflow case
1357        let overflow_repr = Standard::<u8>::new(usize::MAX, 2).unwrap_err();
1358        assert!(matches!(overflow_repr, Overflow { .. }));
1359    }
1360
1361    #[test]
1362    fn standard_new_rejects_element_count_overflow() {
1363        // nrows * ncols overflows usize even though per-element size is small.
1364        assert!(Standard::<u8>::new(usize::MAX, 2).is_err());
1365        assert!(Standard::<u8>::new(2, usize::MAX).is_err());
1366        assert!(Standard::<u8>::new(usize::MAX, usize::MAX).is_err());
1367    }
1368
1369    #[test]
1370    fn standard_new_rejects_byte_count_exceeding_isize_max() {
1371        // Element count fits in usize, but total bytes exceed isize::MAX.
1372        let half = (isize::MAX as usize / std::mem::size_of::<u64>()) + 1;
1373        assert!(Standard::<u64>::new(half, 1).is_err());
1374        assert!(Standard::<u64>::new(1, half).is_err());
1375    }
1376
1377    #[test]
1378    fn standard_new_accepts_boundary_below_isize_max() {
1379        // Largest allocation that still fits in isize::MAX bytes.
1380        let max_elems = isize::MAX as usize / std::mem::size_of::<u64>();
1381        let repr = Standard::<u64>::new(max_elems, 1).unwrap();
1382        assert_eq!(repr.num_elements(), max_elems);
1383    }
1384
1385    #[test]
1386    fn standard_new_zst_rejects_element_count_overflow() {
1387        // For ZSTs the byte count is always 0, but element-count overflow
1388        // must still be caught so that `num_elements()` never wraps.
1389        assert!(Standard::<()>::new(usize::MAX, 2).is_err());
1390        assert!(Standard::<()>::new(usize::MAX / 2 + 1, 3).is_err());
1391    }
1392
1393    #[test]
1394    fn standard_new_zst_accepts_large_non_overflowing() {
1395        // Large-but-valid ZST matrix: element count fits in usize.
1396        let repr = Standard::<()>::new(usize::MAX, 1).unwrap();
1397        assert_eq!(repr.num_elements(), usize::MAX);
1398        assert_eq!(repr.layout().unwrap().size(), 0);
1399    }
1400
1401    #[test]
1402    fn standard_new_overflow_error_display() {
1403        let err = Standard::<u32>::new(usize::MAX, 2).unwrap_err();
1404        let msg = err.to_string();
1405        assert!(msg.contains("would exceed isize::MAX bytes"), "{msg}");
1406
1407        let zst_err = Standard::<()>::new(usize::MAX, 2).unwrap_err();
1408        let zst_msg = zst_err.to_string();
1409        assert!(zst_msg.contains("ZST matrix"), "{zst_msg}");
1410        assert!(zst_msg.contains("usize::MAX"), "{zst_msg}");
1411    }
1412
1413    /////////
1414    // Mat //
1415    /////////
1416
1417    #[test]
1418    fn mat_new_and_basic_accessors() {
1419        let mat = Mat::new(Standard::<usize>::new(3, 4).unwrap(), 42usize).unwrap();
1420        let base: *const u8 = mat.as_raw_ptr();
1421
1422        assert_eq!(mat.num_vectors(), 3);
1423        assert_eq!(mat.vector_dim(), 4);
1424
1425        let repr = mat.repr();
1426        assert_eq!(repr.nrows(), 3);
1427        assert_eq!(repr.ncols(), 4);
1428
1429        for (i, r) in mat.rows().enumerate() {
1430            assert_eq!(r, &[42, 42, 42, 42]);
1431            let ptr = r.as_ptr().cast::<u8>();
1432            assert_eq!(
1433                ptr,
1434                base.wrapping_add(std::mem::size_of::<usize>() * mat.repr().ncols() * i),
1435            );
1436        }
1437    }
1438
1439    #[test]
1440    fn mat_new_with_default() {
1441        let mat = Mat::new(Standard::<usize>::new(2, 3).unwrap(), Defaulted).unwrap();
1442        let base: *const u8 = mat.as_raw_ptr();
1443
1444        assert_eq!(mat.num_vectors(), 2);
1445        for (i, row) in mat.rows().enumerate() {
1446            assert!(row.iter().all(|&v| v == 0));
1447
1448            let ptr = row.as_ptr().cast::<u8>();
1449            assert_eq!(
1450                ptr,
1451                base.wrapping_add(std::mem::size_of::<usize>() * mat.repr().ncols() * i),
1452            );
1453        }
1454    }
1455
    // Row and column counts swept by the matrix tests below; both lists include 0 so that
    // empty shapes are exercised as well.
    const ROWS: &[usize] = &[0, 1, 2, 3, 5, 10];
    const COLS: &[usize] = &[0, 1, 2, 3, 5, 10];
1458
    /// For every shape in `ROWS` x `COLS`, fill an owning `Mat` in three different ways
    /// (directly, through a `MatMut` reborrow, and through `RowsMut`) and verify the
    /// contents through every read path (`&Mat`, `MatRef`, `MatMut`, `Rows`).
    #[test]
    fn test_mat() {
        for nrows in ROWS {
            for ncols in COLS {
                let repr = Standard::<usize>::new(*nrows, *ncols).unwrap();
                // Context string included in assertion messages to identify the shape.
                let ctx = &lazy_format!("nrows = {}, ncols = {}", nrows, ncols);

                // Populate the matrix using `&mut Mat`
                {
                    let ctx = &lazy_format!("{ctx} - direct");
                    let mut mat = Mat::new(repr, Defaulted).unwrap();

                    assert_eq!(mat.num_vectors(), *nrows);
                    assert_eq!(mat.vector_dim(), *ncols);

                    fill_mat(&mut mat, repr);

                    check_mat(&mat, repr, ctx);
                    check_mat_ref(mat.reborrow(), repr, ctx);
                    check_mat_mut(mat.reborrow_mut(), repr, ctx);
                    check_rows(mat.rows(), repr, ctx);

                    // Check reborrow preserves pointers.
                    assert_eq!(mat.as_raw_ptr(), mat.reborrow().as_raw_ptr());
                    assert_eq!(mat.as_raw_ptr(), mat.reborrow_mut().as_raw_ptr());
                }

                // Populate the matrix using `MatMut`
                {
                    let ctx = &lazy_format!("{ctx} - matmut");
                    let mut mat = Mat::new(repr, Defaulted).unwrap();
                    let matmut = mat.reborrow_mut();

                    assert_eq!(matmut.num_vectors(), *nrows);
                    assert_eq!(matmut.vector_dim(), *ncols);

                    // `fill_mat_mut` takes the view by value, which ends the mutable
                    // borrow of `mat` before the checks below read from it.
                    fill_mat_mut(matmut, repr);

                    check_mat(&mat, repr, ctx);
                    check_mat_ref(mat.reborrow(), repr, ctx);
                    check_mat_mut(mat.reborrow_mut(), repr, ctx);
                    check_rows(mat.rows(), repr, ctx);
                }

                // Populate the matrix using `RowsMut`
                {
                    let ctx = &lazy_format!("{ctx} - rows_mut");
                    let mut mat = Mat::new(repr, Defaulted).unwrap();
                    fill_rows_mut(mat.rows_mut(), repr);

                    check_mat(&mat, repr, ctx);
                    check_mat_ref(mat.reborrow(), repr, ctx);
                    check_mat_mut(mat.reborrow_mut(), repr, ctx);
                    check_rows(mat.rows(), repr, ctx);
                }
            }
        }
    }
1517
    /// For every shape in `ROWS` x `COLS`, copy a filled matrix in three ways
    /// (`Mat::clone`, `MatRef::to_owned`, `MatMut::to_owned`) and verify that each copy has
    /// the same contents and, when non-empty, a different base pointer than the source.
    #[test]
    fn test_mat_clone() {
        for nrows in ROWS {
            for ncols in COLS {
                let repr = Standard::<usize>::new(*nrows, *ncols).unwrap();
                // Context string included in assertion messages to identify the shape.
                let ctx = &lazy_format!("nrows = {}, ncols = {}", nrows, ncols);

                let mut mat = Mat::new(repr, Defaulted).unwrap();
                fill_mat(&mut mat, repr);

                // Clone via Mat::clone
                {
                    let ctx = &lazy_format!("{ctx} - Mat::clone");
                    let cloned = mat.clone();

                    assert_eq!(cloned.num_vectors(), *nrows);
                    assert_eq!(cloned.vector_dim(), *ncols);

                    check_mat(&cloned, repr, ctx);
                    check_mat_ref(cloned.reborrow(), repr, ctx);
                    check_rows(cloned.rows(), repr, ctx);

                    // Cloned allocation is independent.
                    // The pointer comparison is skipped for shapes with no elements.
                    if repr.num_elements() > 0 {
                        assert_ne!(mat.as_raw_ptr(), cloned.as_raw_ptr());
                    }
                }

                // Clone via MatRef::to_owned
                {
                    let ctx = &lazy_format!("{ctx} - MatRef::to_owned");
                    let owned = mat.as_view().to_owned();

                    check_mat(&owned, repr, ctx);
                    check_mat_ref(owned.reborrow(), repr, ctx);
                    check_rows(owned.rows(), repr, ctx);

                    // The pointer comparison is skipped for shapes with no elements.
                    if repr.num_elements() > 0 {
                        assert_ne!(mat.as_raw_ptr(), owned.as_raw_ptr());
                    }
                }

                // Clone via MatMut::to_owned
                {
                    let ctx = &lazy_format!("{ctx} - MatMut::to_owned");
                    let owned = mat.as_view_mut().to_owned();

                    check_mat(&owned, repr, ctx);
                    check_mat_ref(owned.reborrow(), repr, ctx);
                    check_rows(owned.rows(), repr, ctx);

                    // The pointer comparison is skipped for shapes with no elements.
                    if repr.num_elements() > 0 {
                        assert_ne!(mat.as_raw_ptr(), owned.as_raw_ptr());
                    }
                }
            }
        }
    }
1576
    /// For every shape in `ROWS` x `COLS`, build `MatMut` and `MatRef` views over a boxed
    /// slice, fill the data through the mutable view (row by row, then through `RowsMut`),
    /// and verify the contents through every read path.
    #[test]
    fn test_mat_refmut() {
        for nrows in ROWS {
            for ncols in COLS {
                let repr = Standard::<usize>::new(*nrows, *ncols).unwrap();
                // Context string included in assertion messages to identify the shape.
                let ctx = &lazy_format!("nrows = {}, ncols = {}", nrows, ncols);

                // Populate the matrix using a `MatMut` view over a boxed slice
                {
                    let ctx = &lazy_format!("{ctx} - by matmut");
                    let mut b: Box<[_]> = (0..repr.num_elements()).map(|_| 0usize).collect();
                    // Captured before the mutable borrow so it can be compared with the view.
                    let ptr = b.as_ptr().cast::<u8>();
                    let mut matmut = MatMut::new(repr, &mut b).unwrap();

                    assert_eq!(
                        ptr,
                        matmut.as_raw_ptr(),
                        "underlying memory should be preserved",
                    );

                    fill_mat_mut(matmut.reborrow_mut(), repr);

                    check_mat_mut(matmut.reborrow_mut(), repr, ctx);
                    check_mat_ref(matmut.reborrow(), repr, ctx);
                    check_rows(matmut.rows(), repr, ctx);
                    check_rows(matmut.reborrow().rows(), repr, ctx);

                    // `matmut` is not used past this point, so `b` can be borrowed
                    // immutably to read the data back through a fresh `MatRef`.
                    let matref = MatRef::new(repr, &b).unwrap();
                    check_mat_ref(matref, repr, ctx);
                    check_mat_ref(matref.reborrow(), repr, ctx);
                    check_rows(matref.rows(), repr, ctx);
                }

                // Populate the matrix using `RowsMut`
                {
                    let ctx = &lazy_format!("{ctx} - by rows");
                    let mut b: Box<[_]> = (0..repr.num_elements()).map(|_| 0usize).collect();
                    // Captured before the mutable borrow so it can be compared with the view.
                    let ptr = b.as_ptr().cast::<u8>();
                    let mut matmut = MatMut::new(repr, &mut b).unwrap();

                    assert_eq!(
                        ptr,
                        matmut.as_raw_ptr(),
                        "underlying memory should be preserved",
                    );

                    fill_rows_mut(matmut.rows_mut(), repr);

                    check_mat_mut(matmut.reborrow_mut(), repr, ctx);
                    check_mat_ref(matmut.reborrow(), repr, ctx);
                    check_rows(matmut.rows(), repr, ctx);
                    check_rows(matmut.reborrow().rows(), repr, ctx);

                    // `matmut` is not used past this point, so `b` can be borrowed
                    // immutably to read the data back through a fresh `MatRef`.
                    let matref = MatRef::new(repr, &b).unwrap();
                    check_mat_ref(matref, repr, ctx);
                    check_mat_ref(matref.reborrow(), repr, ctx);
                    check_rows(matref.rows(), repr, ctx);
                }
            }
        }
    }
1638
1639    //////////////////
1640    // Constructors //
1641    //////////////////
1642
1643    #[test]
1644    fn test_standard_new_owned() {
1645        let rows = [0, 1, 2, 3, 5, 10];
1646        let cols = [0, 1, 2, 3, 5, 10];
1647
1648        for nrows in rows {
1649            for ncols in cols {
1650                let m = Mat::new(Standard::new(nrows, ncols).unwrap(), 1usize).unwrap();
1651                let rows_iter = m.rows();
1652                let len = <_ as ExactSizeIterator>::len(&rows_iter);
1653                assert_eq!(len, nrows);
1654                for r in rows_iter {
1655                    assert_eq!(r.len(), ncols);
1656                    assert!(r.iter().all(|i| *i == 1usize));
1657                }
1658            }
1659        }
1660    }
1661
1662    #[test]
1663    fn matref_new_slice_length_error() {
1664        let repr = Standard::<u32>::new(3, 4).unwrap();
1665
1666        // Correct length succeeds
1667        let data = vec![0u32; 12];
1668        assert!(MatRef::new(repr, &data).is_ok());
1669
1670        // Too short fails
1671        let short = vec![0u32; 11];
1672        assert!(matches!(
1673            MatRef::new(repr, &short),
1674            Err(SliceError::LengthMismatch {
1675                expected: 12,
1676                found: 11
1677            })
1678        ));
1679
1680        // Too long fails
1681        let long = vec![0u32; 13];
1682        assert!(matches!(
1683            MatRef::new(repr, &long),
1684            Err(SliceError::LengthMismatch {
1685                expected: 12,
1686                found: 13
1687            })
1688        ));
1689    }
1690
1691    #[test]
1692    fn matmut_new_slice_length_error() {
1693        let repr = Standard::<u32>::new(3, 4).unwrap();
1694
1695        // Correct length succeeds
1696        let mut data = vec![0u32; 12];
1697        assert!(MatMut::new(repr, &mut data).is_ok());
1698
1699        // Too short fails
1700        let mut short = vec![0u32; 11];
1701        assert!(matches!(
1702            MatMut::new(repr, &mut short),
1703            Err(SliceError::LengthMismatch {
1704                expected: 12,
1705                found: 11
1706            })
1707        ));
1708
1709        // Too long fails
1710        let mut long = vec![0u32; 13];
1711        assert!(matches!(
1712            MatMut::new(repr, &mut long),
1713            Err(SliceError::LengthMismatch {
1714                expected: 12,
1715                found: 13
1716            })
1717        ));
1718    }
1719}