use std::{alloc::Layout, iter::FusedIterator, marker::PhantomData, ptr::NonNull};
use diskann_utils::{Reborrow, ReborrowMut, views::MatrixView};
use thiserror::Error;
use crate::utils;
/// A representation describes how a matrix is stored in one untyped
/// allocation and how to obtain a view of one of its rows.
///
/// Representations are small `Copy` descriptors (for example, [`Standard`]
/// stores only its dimensions). The backing memory is held separately by
/// [`Mat`], [`MatRef`], or [`MatMut`] as a `NonNull<u8>` and is interpreted
/// through the representation.
///
/// # Safety
///
/// The matrix types in this module hand raw pointers to implementations and
/// trust what comes back. In this module, [`Repr::get_row`] is only called with
/// `i < self.nrows()` and with the pointer stored next to this representation.
/// NOTE(review): the full implementor contract (for example, that
/// [`Repr::layout`] describes the allocation those pointers address) is not
/// written down here — confirm against the crate's design notes.
pub unsafe trait Repr: Copy {
/// Borrowed view of a single row.
type Row<'a>
where
Self: 'a;
/// Number of rows described by this representation.
fn nrows(&self) -> usize;
/// Memory layout of the whole backing allocation, or [`LayoutError`] when it
/// cannot be expressed as a [`Layout`].
fn layout(&self) -> Result<Layout, LayoutError>;
/// Returns a view of row `i` of the matrix whose data starts at `ptr`.
///
/// # Safety
///
/// `ptr` must address memory laid out according to this representation, and
/// `i` must be less than [`Repr::nrows`]. The lifetime `'a` is chosen by the
/// caller, who is responsible for the memory staying valid for it.
unsafe fn get_row<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::Row<'a>;
}
/// A [`Repr`] that can also produce mutable row views.
///
/// # Safety
///
/// Same considerations as [`Repr`]. In this module, [`ReprMut::get_row_mut`]
/// is only called with `i < self.nrows()` on a pointer held by a [`Mat`] or
/// [`MatMut`] built for this representation.
pub unsafe trait ReprMut: Repr {
/// Mutable view of a single row.
type RowMut<'a>
where
Self: 'a;
/// Returns a mutable view of row `i` of the matrix whose data starts at `ptr`.
///
/// # Safety
///
/// `ptr` must address memory laid out according to this representation, and
/// `i` must be less than [`Repr::nrows`]. For [`Standard`] the returned row is
/// a `&mut [T]`, so it must not alias any other live view for the
/// caller-chosen lifetime `'a`.
unsafe fn get_row_mut<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::RowMut<'a>;
}
/// A [`ReprMut`] whose allocations can be owned by a [`Mat`].
///
/// # Safety
///
/// [`Mat`]'s `Drop` impl calls [`ReprOwned::drop`] with the matrix's pointer,
/// so implementations must release that allocation in a way that matches how
/// it was created (for [`Standard`], as a `Box<[T]>`).
pub unsafe trait ReprOwned: ReprMut {
/// Releases the allocation at `ptr`.
///
/// # Safety
///
/// `ptr` must be an allocation owned by a matrix with this representation,
/// and must not be used again after this call.
unsafe fn drop(self, ptr: NonNull<u8>);
}
/// Error returned by [`Repr::layout`] when a representation's memory layout
/// cannot be expressed as a [`std::alloc::Layout`].
///
/// The type carries no payload: it only signals that computing the layout
/// failed (for example because a size computation overflowed).
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub struct LayoutError;

impl LayoutError {
    /// Creates a new `LayoutError`.
    pub fn new() -> Self {
        LayoutError
    }
}

impl Default for LayoutError {
    fn default() -> Self {
        LayoutError::new()
    }
}

impl std::fmt::Display for LayoutError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("LayoutError")
    }
}

impl std::error::Error for LayoutError {}

/// Collapses the standard library's layout error into this opaque one.
impl From<std::alloc::LayoutError> for LayoutError {
    fn from(_err: std::alloc::LayoutError) -> Self {
        Self::new()
    }
}
/// Constructs a borrowed [`MatRef`] over a caller-provided slice of `T`.
///
/// # Safety
///
/// The returned view is built from the slice's raw pointer, so implementations
/// must only succeed when the slice actually fits the representation (the
/// [`Standard`] implementation checks the slice length first).
pub unsafe trait NewRef<T>: Repr {
/// Error returned when the slice does not fit the representation.
type Error;
/// Wraps `slice` in a read-only matrix view.
fn new_ref(self, slice: &[T]) -> Result<MatRef<'_, Self>, Self::Error>;
}
/// Constructs a mutable [`MatMut`] over a caller-provided slice of `T`.
///
/// # Safety
///
/// As for [`NewRef`]: implementations must only succeed when the slice fits
/// the representation.
pub unsafe trait NewMut<T>: ReprMut {
/// Error returned when the slice does not fit the representation.
type Error;
/// Wraps `slice` in a mutable matrix view.
fn new_mut(self, slice: &mut [T]) -> Result<MatMut<'_, Self>, Self::Error>;
}
/// Allocates an owned [`Mat`] for this representation, initialized from a `T`.
///
/// # Safety
///
/// The returned matrix is later freed by [`ReprOwned::drop`], so the
/// allocation must be created in a way that implementation can release.
pub unsafe trait NewOwned<T>: ReprOwned {
/// Error returned when construction fails.
type Error;
/// Allocates and initializes a new matrix; `init` is interpreted by the
/// implementation (e.g. a fill value, or [`Defaulted`]).
fn new_owned(self, init: T) -> Result<Mat<Self>, Self::Error>;
}
/// Initializer for [`NewOwned`] requesting that every element be set to its
/// `Default` value (see the `NewOwned<Defaulted>` impl for [`Standard`]).
#[derive(Debug, Clone, Copy)]
pub struct Defaulted;
/// Representations whose matrices can be deep-copied into a new [`Mat`].
///
/// Used by `Clone for Mat` and by `to_owned` on [`MatRef`] and [`MatMut`].
pub trait NewCloned: ReprOwned {
/// Copies the contents viewed by `v` into a freshly allocated matrix.
fn new_cloned(v: MatRef<'_, Self>) -> Mat<Self>;
}
/// Dense, row-major matrix of `T` with `nrows` rows and `ncols` columns.
///
/// Row `i` starts at element `i * ncols` of one contiguous buffer. Values are
/// created with [`Standard::new`], which rejects dimensions whose total size
/// would overflow.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Standard<T> {
/// Number of rows.
nrows: usize,
/// Number of elements in each row.
ncols: usize,
/// Ties the representation to its element type without storing one.
_elem: PhantomData<T>,
}
impl<T: Copy> Standard<T> {
    /// Creates a representation for an `nrows x ncols` row-major matrix of `T`.
    ///
    /// # Errors
    ///
    /// Returns [`Overflow`] if `nrows * ncols` overflows `usize`, or if the
    /// total byte size of the elements would exceed `isize::MAX`.
    pub fn new(nrows: usize, ncols: usize) -> Result<Self, Overflow> {
        Overflow::check::<T>(nrows, ncols).map(|()| Self {
            nrows,
            ncols,
            _elem: PhantomData,
        })
    }

    /// Total number of elements, `nrows * ncols`.
    ///
    /// The product cannot overflow because [`Standard::new`] rejects such
    /// dimensions.
    pub fn num_elements(&self) -> usize {
        let (rows, cols) = (self.nrows(), self.ncols());
        rows * cols
    }

    /// Number of rows.
    fn nrows(&self) -> usize {
        self.nrows
    }

    /// Number of elements per row.
    fn ncols(&self) -> usize {
        self.ncols
    }

    /// Checks that `slice` holds exactly [`Self::num_elements`] elements.
    fn check_slice(&self, slice: &[T]) -> Result<(), SliceError> {
        let expected = self.num_elements();
        let found = slice.len();
        if found == expected {
            return Ok(());
        }
        Err(SliceError::LengthMismatch { expected, found })
    }

    /// Hands ownership of `b` to a new [`Mat`] described by `self`.
    ///
    /// # Safety
    ///
    /// `b.len()` must equal `self.num_elements()`: the matrix later frees the
    /// buffer through [`ReprOwned::drop`], which rebuilds a `Box<[T]>` of that
    /// length.
    unsafe fn box_to_mat(self, b: Box<[T]>) -> Mat<Self> {
        debug_assert_eq!(b.len(), self.num_elements(), "safety contract violated");
        let raw = utils::box_into_nonnull(b).cast::<u8>();
        // SAFETY: `raw` came from a `Box<[T]>` of exactly `num_elements()`
        // elements, which is what `Standard<T>`'s `ReprOwned::drop` reconstructs.
        unsafe { Mat::from_raw_parts(self, raw) }
    }
}
/// Error returned by [`Standard::new`] when the requested dimensions are too
/// large: the element count overflows `usize`, or the byte size would exceed
/// `isize::MAX`.
#[derive(Debug, Clone, Copy)]
pub struct Overflow {
    nrows: usize,
    ncols: usize,
    elsize: usize,
}

impl Overflow {
    /// Records the offending dimensions together with `T`'s element size.
    pub(crate) fn for_type<T>(nrows: usize, ncols: usize) -> Self {
        let elsize = std::mem::size_of::<T>();
        Self {
            nrows,
            ncols,
            elsize,
        }
    }

    /// Succeeds when `capacity` elements of `T` fit within `isize::MAX` bytes.
    ///
    /// `nrows`/`ncols` are only used to build the error value.
    pub(crate) fn check_byte_budget<T>(
        capacity: usize,
        nrows: usize,
        ncols: usize,
    ) -> Result<(), Self> {
        let limit = isize::MAX as usize;
        match std::mem::size_of::<T>().checked_mul(capacity) {
            Some(bytes) if bytes <= limit => Ok(()),
            // Either the multiplication overflowed or the product is too big.
            _ => Err(Self::for_type::<T>(nrows, ncols)),
        }
    }

    /// Validates an `nrows x ncols` matrix of `T`: the element count must fit
    /// in `usize` and the byte size must fit in `isize::MAX`.
    pub(crate) fn check<T>(nrows: usize, ncols: usize) -> Result<(), Self> {
        match nrows.checked_mul(ncols) {
            Some(capacity) => Self::check_byte_budget::<T>(capacity, nrows, ncols),
            None => Err(Self::for_type::<T>(nrows, ncols)),
        }
    }
}

impl std::fmt::Display for Overflow {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            nrows,
            ncols,
            elsize,
        } = *self;
        // Zero-sized elements can only fail on the element count itself.
        if elsize != 0 {
            write!(
                f,
                "a matrix of size {} x {} with element size {} would exceed isize::MAX bytes",
                nrows, ncols, elsize,
            )
        } else {
            write!(
                f,
                "ZST matrix with dimensions {} x {} has more than `usize::MAX` elements",
                nrows, ncols,
            )
        }
    }
}

impl std::error::Error for Overflow {}
/// Error returned when a slice cannot back a matrix view.
#[derive(Debug, Clone, Copy, Error)]
#[non_exhaustive]
pub enum SliceError {
/// The slice length differs from the number of elements the representation
/// requires.
#[error("Length mismatch: expected {expected}, found {found}")]
LengthMismatch { expected: usize, found: usize },
}
unsafe impl<T: Copy> Repr for Standard<T> {
type Row<'a>
= &'a [T]
where
T: 'a;
fn nrows(&self) -> usize {
self.nrows
}
fn layout(&self) -> Result<Layout, LayoutError> {
Ok(Layout::array::<T>(self.num_elements())?)
}
unsafe fn get_row<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::Row<'a> {
debug_assert!(ptr.cast::<T>().is_aligned());
debug_assert!(i < self.nrows);
let row_ptr = unsafe { ptr.as_ptr().cast::<T>().add(i * self.ncols) };
unsafe { std::slice::from_raw_parts(row_ptr, self.ncols) }
}
}
// SAFETY: each row is a disjoint `ncols`-element window of the buffer, so
// distinct indices never produce overlapping mutable slices.
unsafe impl<T: Copy> ReprMut for Standard<T> {
    type RowMut<'a>
        = &'a mut [T]
    where
        T: 'a;

    unsafe fn get_row_mut<'a>(self, ptr: NonNull<u8>, i: usize) -> Self::RowMut<'a> {
        let base = ptr.cast::<T>();
        debug_assert!(base.is_aligned());
        debug_assert!(i < self.nrows);
        let offset = i * self.ncols;
        // SAFETY: the caller guarantees `i < nrows`, a pointer to `nrows * ncols`
        // initialized `T`s, and exclusive access for `'a`.
        unsafe { std::slice::from_raw_parts_mut(base.as_ptr().add(offset), self.ncols) }
    }
}
// SAFETY: owned `Standard<T>` matrices are always created from a `Box<[T]>` of
// `num_elements()` elements (see `box_to_mat`), which `drop` reconstructs.
unsafe impl<T: Copy> ReprOwned for Standard<T> {
    unsafe fn drop(self, ptr: NonNull<u8>) {
        let len = self.num_elements();
        let raw: *mut [T] = std::ptr::slice_from_raw_parts_mut(ptr.cast::<T>().as_ptr(), len);
        // SAFETY: `raw` is exactly the boxed slice that was leaked when the
        // matrix was created, and the caller promises not to reuse `ptr`.
        let owned: Box<[T]> = unsafe { Box::from_raw(raw) };
        std::mem::drop(owned);
    }
}
// SAFETY: the buffer is a `Box<[T]>` with exactly `num_elements()` elements,
// which is what `ReprOwned::drop` for `Standard<T>` expects.
unsafe impl<T> NewOwned<T> for Standard<T>
where
    T: Copy,
{
    type Error = crate::error::Infallible;

    /// Allocates a matrix with every element set to `value`.
    fn new_owned(self, value: T) -> Result<Mat<Self>, Self::Error> {
        let filled = vec![value; self.num_elements()].into_boxed_slice();
        // SAFETY: `filled` holds exactly `num_elements()` elements.
        let mat = unsafe { self.box_to_mat(filled) };
        Ok(mat)
    }
}
// SAFETY: delegates to the `NewOwned<T>` impl, which upholds the contract.
unsafe impl<T> NewOwned<Defaulted> for Standard<T>
where
    T: Copy + Default,
{
    type Error = crate::error::Infallible;

    /// Allocates a matrix with every element set to `T::default()`.
    fn new_owned(self, _: Defaulted) -> Result<Mat<Self>, Self::Error> {
        let fill = T::default();
        <Self as NewOwned<T>>::new_owned(self, fill)
    }
}
// SAFETY: a view is only produced after `check_slice` confirms the length.
unsafe impl<T> NewRef<T> for Standard<T>
where
    T: Copy,
{
    type Error = SliceError;

    /// Wraps `data` as a read-only matrix view after checking its length.
    fn new_ref(self, data: &[T]) -> Result<MatRef<'_, Self>, Self::Error> {
        self.check_slice(data).map(|()| {
            let base = utils::as_nonnull(data).cast::<u8>();
            // SAFETY: `data` holds exactly `nrows * ncols` elements and stays
            // borrowed for the lifetime of the returned view.
            unsafe { MatRef::from_raw_parts(self, base) }
        })
    }
}
// SAFETY: a view is only produced after `check_slice` confirms the length.
unsafe impl<T> NewMut<T> for Standard<T>
where
    T: Copy,
{
    type Error = SliceError;

    /// Wraps `data` as a mutable matrix view after checking its length.
    fn new_mut(self, data: &mut [T]) -> Result<MatMut<'_, Self>, Self::Error> {
        self.check_slice(data).map(|()| {
            let base = utils::as_nonnull_mut(data).cast::<u8>();
            // SAFETY: `data` holds exactly `nrows * ncols` elements and stays
            // exclusively borrowed for the lifetime of the returned view.
            unsafe { MatMut::from_raw_parts(self, base) }
        })
    }
}
impl<T> NewCloned for Standard<T>
where
    T: Copy,
{
    /// Deep-copies the matrix viewed by `v` into a fresh owned allocation.
    ///
    /// A `Standard` view is one contiguous buffer, so the data is copied in a
    /// single pass from [`MatRef::as_slice`] into an exactly-sized `Box<[T]>`.
    /// Collecting row-by-row through a flattening iterator (whose length is
    /// unknown up front) would force repeated reallocation plus a final shrink.
    fn new_cloned(v: MatRef<'_, Self>) -> Mat<Self> {
        let copy: Box<[T]> = Box::from(v.as_slice());
        // SAFETY: `as_slice` yields exactly `v.repr().num_elements()` elements,
        // so `copy` has the length `box_to_mat` requires.
        unsafe { v.repr().box_to_mat(copy) }
    }
}
/// An owned matrix whose backing allocation is described by the
/// representation `T`.
///
/// Dropping a `Mat` releases the allocation through [`ReprOwned::drop`].
#[derive(Debug)]
pub struct Mat<T: ReprOwned> {
/// Start of the owned allocation.
ptr: NonNull<u8>,
/// Describes how the data behind `ptr` is laid out.
repr: T,
// `fn(T) -> T` makes `Mat` invariant in `T` without affecting auto traits.
_invariant: PhantomData<fn(T) -> T>,
}
// SAFETY: `NonNull` is neither `Send` nor `Sync`, so these are provided
// manually. `Mat` exclusively owns its allocation (it frees it in `Drop`), so
// thread-safety is delegated to the representation; for `Standard<E>` this
// means `E: Send` / `E: Sync`, matching `Box<[E]>`.
unsafe impl<T> Send for Mat<T> where T: ReprOwned + Send {}
unsafe impl<T> Sync for Mat<T> where T: ReprOwned + Sync {}
impl<T: ReprOwned> Mat<T> {
pub fn new<U>(repr: T, init: U) -> Result<Self, <T as NewOwned<U>>::Error>
where
T: NewOwned<U>,
{
repr.new_owned(init)
}
#[inline]
pub fn num_vectors(&self) -> usize {
self.repr.nrows()
}
pub fn repr(&self) -> &T {
&self.repr
}
#[must_use]
pub fn get_row(&self, i: usize) -> Option<T::Row<'_>> {
if i < self.num_vectors() {
let row = unsafe { self.get_row_unchecked(i) };
Some(row)
} else {
None
}
}
pub(crate) unsafe fn get_row_unchecked(&self, i: usize) -> T::Row<'_> {
unsafe { self.repr.get_row(self.ptr, i) }
}
#[must_use]
pub fn get_row_mut(&mut self, i: usize) -> Option<T::RowMut<'_>> {
if i < self.num_vectors() {
Some(unsafe { self.get_row_mut_unchecked(i) })
} else {
None
}
}
pub(crate) unsafe fn get_row_mut_unchecked(&mut self, i: usize) -> T::RowMut<'_> {
unsafe { self.repr.get_row_mut(self.ptr, i) }
}
#[inline]
pub fn as_view(&self) -> MatRef<'_, T> {
MatRef {
ptr: self.ptr,
repr: self.repr,
_lifetime: PhantomData,
}
}
#[inline]
pub fn as_view_mut(&mut self) -> MatMut<'_, T> {
MatMut {
ptr: self.ptr,
repr: self.repr,
_lifetime: PhantomData,
}
}
pub fn rows(&self) -> Rows<'_, T> {
Rows::new(self.reborrow())
}
pub fn rows_mut(&mut self) -> RowsMut<'_, T> {
RowsMut::new(self.reborrow_mut())
}
pub(crate) unsafe fn from_raw_parts(repr: T, ptr: NonNull<u8>) -> Self {
Self {
ptr,
repr,
_invariant: PhantomData,
}
}
pub fn as_raw_ptr(&self) -> *const u8 {
self.ptr.as_ptr()
}
pub(crate) fn as_raw_mut_ptr(&mut self) -> *mut u8 {
self.ptr.as_ptr()
}
}
impl<T: ReprOwned> Drop for Mat<T> {
    /// Releases the backing allocation through the representation.
    fn drop(&mut self) {
        let (repr, ptr) = (self.repr, self.ptr);
        // SAFETY: `ptr` was created for `repr` when this matrix was built and
        // is released exactly once, here.
        unsafe { <T as ReprOwned>::drop(repr, ptr) }
    }
}

impl<T: NewCloned> Clone for Mat<T> {
    /// Deep-copies the matrix via [`NewCloned::new_cloned`].
    fn clone(&self) -> Self {
        let view = self.as_view();
        <T as NewCloned>::new_cloned(view)
    }
}
impl<T: Copy> Mat<Standard<T>> {
    /// Number of elements in each row.
    #[inline]
    pub fn vector_dim(&self) -> usize {
        self.as_view().vector_dim()
    }

    /// The matrix contents as one contiguous row-major slice.
    #[inline]
    pub fn as_slice(&self) -> &[T] {
        let view = self.as_view();
        view.as_slice()
    }

    /// The matrix contents as a [`MatrixView`].
    #[inline]
    pub fn as_matrix_view(&self) -> MatrixView<'_, T> {
        let view = self.as_view();
        view.as_matrix_view()
    }
}
/// A read-only view of a matrix whose memory is borrowed for `'a`.
///
/// Copying a `MatRef` copies only the pointer and the representation.
#[derive(Debug, Clone, Copy)]
pub struct MatRef<'a, T: Repr> {
/// Start of the viewed data.
ptr: NonNull<u8>,
/// Describes how the data behind `ptr` is laid out.
repr: T,
// Behaves like `&'a T` for variance: covariant in both `'a` and `T`.
_lifetime: PhantomData<&'a T>,
}
// SAFETY: `NonNull` is neither `Send` nor `Sync`, so these are provided
// manually and keyed on the representation.
// NOTE(review): a shared view like `&'a [E]` is `Send` only when `E: Sync`,
// but here `Send` is keyed on `T: Send`. For `Standard<E>` that requires
// `E: Send` rather than `E: Sync` — confirm this is sound for every `Repr`.
unsafe impl<T> Send for MatRef<'_, T> where T: Repr + Send {}
unsafe impl<T> Sync for MatRef<'_, T> where T: Repr + Sync {}
impl<'a, T: Repr> MatRef<'a, T> {
pub fn new<U>(repr: T, data: &'a [U]) -> Result<Self, T::Error>
where
T: NewRef<U>,
{
repr.new_ref(data)
}
#[inline]
pub fn num_vectors(&self) -> usize {
self.repr.nrows()
}
pub fn repr(&self) -> &T {
&self.repr
}
#[must_use]
pub fn get_row(&self, i: usize) -> Option<T::Row<'_>> {
if i < self.num_vectors() {
let row = unsafe { self.get_row_unchecked(i) };
Some(row)
} else {
None
}
}
#[inline]
pub(crate) unsafe fn get_row_unchecked(&self, i: usize) -> T::Row<'_> {
unsafe { self.repr.get_row(self.ptr, i) }
}
pub fn rows(&self) -> Rows<'_, T> {
Rows::new(*self)
}
pub fn to_owned(&self) -> Mat<T>
where
T: NewCloned,
{
T::new_cloned(*self)
}
pub unsafe fn from_raw_parts(repr: T, ptr: NonNull<u8>) -> Self {
Self {
ptr,
repr,
_lifetime: PhantomData,
}
}
pub fn as_raw_ptr(&self) -> *const u8 {
self.ptr.as_ptr()
}
}
impl<'a, T: Copy> MatRef<'a, Standard<T>> {
#[inline]
pub fn vector_dim(&self) -> usize {
self.repr.ncols()
}
#[inline]
pub fn as_slice(&self) -> &'a [T] {
let len = self.repr.num_elements();
unsafe { std::slice::from_raw_parts(self.ptr.as_ptr().cast::<T>(), len) }
}
#[allow(clippy::expect_used)]
#[inline]
pub fn as_matrix_view(&self) -> MatrixView<'a, T> {
MatrixView::try_from(self.as_slice(), self.num_vectors(), self.vector_dim())
.expect("Standard<T> has valid dimensions")
}
}
impl<'this, T: ReprOwned> Reborrow<'this> for Mat<T> {
type Target = MatRef<'this, T>;
fn reborrow(&'this self) -> Self::Target {
self.as_view()
}
}
impl<'this, T: ReprOwned> ReborrowMut<'this> for Mat<T> {
type Target = MatMut<'this, T>;
fn reborrow_mut(&'this mut self) -> Self::Target {
self.as_view_mut()
}
}
impl<'this, 'a, T: Repr> Reborrow<'this> for MatRef<'a, T> {
type Target = MatRef<'this, T>;
fn reborrow(&'this self) -> Self::Target {
MatRef {
ptr: self.ptr,
repr: self.repr,
_lifetime: PhantomData,
}
}
}
/// A mutable view of a matrix whose memory is exclusively borrowed for `'a`.
#[derive(Debug)]
pub struct MatMut<'a, T: ReprMut> {
/// Start of the viewed data.
ptr: NonNull<u8>,
/// Describes how the data behind `ptr` is laid out.
repr: T,
// Behaves like `&'a mut T`: covariant in `'a`, invariant in `T`.
_lifetime: PhantomData<&'a mut T>,
}
// SAFETY: `NonNull` is neither `Send` nor `Sync`, so these are provided
// manually and keyed on the representation (compare `&mut [E]`, which is
// `Send` when `E: Send` and `Sync` when `E: Sync`).
unsafe impl<T> Send for MatMut<'_, T> where T: ReprMut + Send {}
unsafe impl<T> Sync for MatMut<'_, T> where T: ReprMut + Sync {}
impl<'a, T: ReprMut> MatMut<'a, T> {
pub fn new<U>(repr: T, data: &'a mut [U]) -> Result<Self, T::Error>
where
T: NewMut<U>,
{
repr.new_mut(data)
}
#[inline]
pub fn num_vectors(&self) -> usize {
self.repr.nrows()
}
pub fn repr(&self) -> &T {
&self.repr
}
#[inline]
#[must_use]
pub fn get_row(&self, i: usize) -> Option<T::Row<'_>> {
if i < self.num_vectors() {
Some(unsafe { self.get_row_unchecked(i) })
} else {
None
}
}
#[inline]
pub(crate) unsafe fn get_row_unchecked(&self, i: usize) -> T::Row<'_> {
unsafe { self.repr.get_row(self.ptr, i) }
}
#[inline]
#[must_use]
pub fn get_row_mut(&mut self, i: usize) -> Option<T::RowMut<'_>> {
if i < self.num_vectors() {
Some(unsafe { self.get_row_mut_unchecked(i) })
} else {
None
}
}
#[inline]
pub(crate) unsafe fn get_row_mut_unchecked(&mut self, i: usize) -> T::RowMut<'_> {
unsafe { self.repr.get_row_mut(self.ptr, i) }
}
pub fn as_view(&self) -> MatRef<'_, T> {
MatRef {
ptr: self.ptr,
repr: self.repr,
_lifetime: PhantomData,
}
}
pub fn rows(&self) -> Rows<'_, T> {
Rows::new(self.reborrow())
}
pub fn rows_mut(&mut self) -> RowsMut<'_, T> {
RowsMut::new(self.reborrow_mut())
}
pub fn to_owned(&self) -> Mat<T>
where
T: NewCloned,
{
T::new_cloned(self.as_view())
}
pub unsafe fn from_raw_parts(repr: T, ptr: NonNull<u8>) -> Self {
Self {
ptr,
repr,
_lifetime: PhantomData,
}
}
pub fn as_raw_ptr(&self) -> *const u8 {
self.ptr.as_ptr()
}
pub(crate) fn as_raw_mut_ptr(&mut self) -> *mut u8 {
self.ptr.as_ptr()
}
}
impl<'this, 'a, T: ReprMut> Reborrow<'this> for MatMut<'a, T> {
type Target = MatRef<'this, T>;
fn reborrow(&'this self) -> Self::Target {
self.as_view()
}
}
impl<'this, 'a, T: ReprMut> ReborrowMut<'this> for MatMut<'a, T> {
type Target = MatMut<'this, T>;
fn reborrow_mut(&'this mut self) -> Self::Target {
MatMut {
ptr: self.ptr,
repr: self.repr,
_lifetime: PhantomData,
}
}
}
impl<'a, T: Copy> MatMut<'a, Standard<T>> {
    /// Number of elements in each row.
    #[inline]
    pub fn vector_dim(&self) -> usize {
        self.as_view().vector_dim()
    }

    /// The viewed data as one contiguous row-major slice.
    #[inline]
    pub fn as_slice(&self) -> &[T] {
        let view = self.as_view();
        view.as_slice()
    }

    /// The viewed data as a [`MatrixView`].
    #[inline]
    pub fn as_matrix_view(&self) -> MatrixView<'_, T> {
        let view = self.as_view();
        view.as_matrix_view()
    }
}
/// Iterator over the rows of a matrix, yielding one [`Repr::Row`] per row in
/// index order.
#[derive(Debug)]
pub struct Rows<'a, T: Repr> {
    matrix: MatRef<'a, T>,
    current: usize,
}

impl<'a, T> Rows<'a, T>
where
    T: Repr,
{
    /// Starts iterating at row 0 of `matrix`.
    fn new(matrix: MatRef<'a, T>) -> Self {
        Rows { current: 0, matrix }
    }
}

impl<'a, T> Iterator for Rows<'a, T>
where
    T: Repr + 'a,
{
    type Item = T::Row<'a>;

    fn next(&mut self) -> Option<Self::Item> {
        let index = self.current;
        if index < self.matrix.num_vectors() {
            self.current = index + 1;
            let (ptr, repr) = (self.matrix.ptr, self.matrix.repr);
            // SAFETY: `index < num_vectors()`, and the row may live for `'a`
            // because the underlying data is borrowed for `'a`.
            Some(unsafe { repr.get_row(ptr, index) })
        } else {
            None
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let left = self.matrix.num_vectors() - self.current;
        (left, Some(left))
    }
}

impl<'a, T: Repr + 'a> ExactSizeIterator for Rows<'a, T> {}
impl<'a, T: Repr + 'a> FusedIterator for Rows<'a, T> {}
/// Iterator over the rows of a matrix, yielding one [`ReprMut::RowMut`] per
/// row in index order.
#[derive(Debug)]
pub struct RowsMut<'a, T: ReprMut> {
    matrix: MatMut<'a, T>,
    current: usize,
}

impl<'a, T> RowsMut<'a, T>
where
    T: ReprMut,
{
    /// Starts iterating at row 0 of `matrix`.
    fn new(matrix: MatMut<'a, T>) -> Self {
        RowsMut { current: 0, matrix }
    }
}

impl<'a, T> Iterator for RowsMut<'a, T>
where
    T: ReprMut + 'a,
{
    type Item = T::RowMut<'a>;

    fn next(&mut self) -> Option<Self::Item> {
        let index = self.current;
        if index < self.matrix.num_vectors() {
            self.current = index + 1;
            let (ptr, repr) = (self.matrix.ptr, self.matrix.repr);
            // SAFETY: `index < num_vectors()`, and each index is yielded at most
            // once, so the mutable rows handed out never alias one another.
            Some(unsafe { repr.get_row_mut(ptr, index) })
        } else {
            None
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let left = self.matrix.num_vectors() - self.current;
        (left, Some(left))
    }
}

impl<'a, T: ReprMut + 'a> ExactSizeIterator for RowsMut<'a, T> {}
impl<'a, T: ReprMut + 'a> FusedIterator for RowsMut<'a, T> {}
// Unit tests for the `Standard` representation and the `Mat`/`MatRef`/`MatMut`
// views and row iterators defined above.
#[cfg(test)]
mod tests {
use super::*;
use std::fmt::Display;
use diskann_utils::lazy_format;
// Compiles only when the argument's type implements `Copy`.
fn assert_copy<T: Copy>(_: &T) {}
// Compile-only variance checks: each function type-checks only if the
// returned type is a subtype of the argument type.
fn _assert_matref_covariant_lifetime<'long: 'short, 'short, T: Repr>(
v: MatRef<'long, T>,
) -> MatRef<'short, T> {
v
}
fn _assert_matref_covariant_repr<'long: 'short, 'short, 'a>(
v: MatRef<'a, Standard<&'long u8>>,
) -> MatRef<'a, Standard<&'short u8>> {
v
}
fn _assert_matmut_covariant_lifetime<'long: 'short, 'short, T: ReprMut>(
v: MatMut<'long, T>,
) -> MatMut<'short, T> {
v
}
// Row indices at or beyond `nrows` (including values near `usize::MAX`) that
// every accessor must reject as out of bounds.
fn edge_cases(nrows: usize) -> Vec<usize> {
let max = usize::MAX;
vec![
nrows,
nrows + 1,
nrows + 11,
nrows + 20,
max / 2,
max.div_ceil(2),
max - 1,
max,
]
}
// Writes `10 * row + col` into every element of an owned matrix via
// `get_row_mut`, then checks out-of-range rows yield `None`. (The pattern is
// unique per element only while `ncols <= 10`, which holds for `COLS`.)
fn fill_mat(x: &mut Mat<Standard<usize>>, repr: Standard<usize>) {
assert_eq!(x.repr(), &repr);
assert_eq!(x.num_vectors(), repr.nrows());
assert_eq!(x.vector_dim(), repr.ncols());
for i in 0..x.num_vectors() {
let row = x.get_row_mut(i).unwrap();
assert_eq!(row.len(), repr.ncols());
row.iter_mut()
.enumerate()
.for_each(|(j, r)| *r = 10 * i + j);
}
for i in edge_cases(repr.nrows()).into_iter() {
assert!(x.get_row_mut(i).is_none());
}
}
// Same as `fill_mat`, but through a `MatMut` view.
fn fill_mat_mut(mut x: MatMut<'_, Standard<usize>>, repr: Standard<usize>) {
assert_eq!(x.repr(), &repr);
assert_eq!(x.num_vectors(), repr.nrows());
assert_eq!(x.vector_dim(), repr.ncols());
for i in 0..x.num_vectors() {
let row = x.get_row_mut(i).unwrap();
assert_eq!(row.len(), repr.ncols());
row.iter_mut()
.enumerate()
.for_each(|(j, r)| *r = 10 * i + j);
}
for i in edge_cases(repr.nrows()).into_iter() {
assert!(x.get_row_mut(i).is_none());
}
}
// Same pattern, written through rows collected from a `RowsMut` iterator
// (all mutable rows are alive at once, exercising non-aliasing).
fn fill_rows_mut(x: RowsMut<'_, Standard<usize>>, repr: Standard<usize>) {
assert_eq!(x.len(), repr.nrows());
let mut all_rows: Vec<_> = x.collect();
assert_eq!(all_rows.len(), repr.nrows());
for (i, row) in all_rows.iter_mut().enumerate() {
assert_eq!(row.len(), repr.ncols());
row.iter_mut()
.enumerate()
.for_each(|(j, r)| *r = 10 * i + j);
}
}
// Verifies an owned matrix holds the `10 * row + col` pattern and rejects
// out-of-range rows.
fn check_mat(x: &Mat<Standard<usize>>, repr: Standard<usize>, ctx: &dyn Display) {
assert_eq!(x.repr(), &repr);
assert_eq!(x.num_vectors(), repr.nrows());
assert_eq!(x.vector_dim(), repr.ncols());
for i in 0..x.num_vectors() {
let row = x.get_row(i).unwrap();
assert_eq!(row.len(), repr.ncols(), "ctx: {ctx}");
row.iter().enumerate().for_each(|(j, r)| {
assert_eq!(
*r,
10 * i + j,
"mismatched entry at row {}, col {} -- ctx: {}",
i,
j,
ctx
)
});
}
for i in edge_cases(repr.nrows()).into_iter() {
assert!(x.get_row(i).is_none(), "ctx: {ctx}");
}
}
// Same checks through a `MatRef`, which must also be `Copy`.
fn check_mat_ref(x: MatRef<'_, Standard<usize>>, repr: Standard<usize>, ctx: &dyn Display) {
assert_eq!(x.repr(), &repr);
assert_eq!(x.num_vectors(), repr.nrows());
assert_eq!(x.vector_dim(), repr.ncols());
assert_copy(&x);
for i in 0..x.num_vectors() {
let row = x.get_row(i).unwrap();
assert_eq!(row.len(), repr.ncols(), "ctx: {ctx}");
row.iter().enumerate().for_each(|(j, r)| {
assert_eq!(
*r,
10 * i + j,
"mismatched entry at row {}, col {} -- ctx: {}",
i,
j,
ctx
)
});
}
for i in edge_cases(repr.nrows()).into_iter() {
assert!(x.get_row(i).is_none(), "ctx: {ctx}");
}
}
// Same checks through a `MatMut`, using its read-only `get_row`.
fn check_mat_mut(x: MatMut<'_, Standard<usize>>, repr: Standard<usize>, ctx: &dyn Display) {
assert_eq!(x.repr(), &repr);
assert_eq!(x.num_vectors(), repr.nrows());
assert_eq!(x.vector_dim(), repr.ncols());
for i in 0..x.num_vectors() {
let row = x.get_row(i).unwrap();
assert_eq!(row.len(), repr.ncols(), "ctx: {ctx}");
row.iter().enumerate().for_each(|(j, r)| {
assert_eq!(
*r,
10 * i + j,
"mismatched entry at row {}, col {} -- ctx: {}",
i,
j,
ctx
)
});
}
for i in edge_cases(repr.nrows()).into_iter() {
assert!(x.get_row(i).is_none(), "ctx: {ctx}");
}
}
// Verifies a `Rows` iterator reports its exact length and yields the pattern.
fn check_rows(x: Rows<'_, Standard<usize>>, repr: Standard<usize>, ctx: &dyn Display) {
assert_eq!(x.len(), repr.nrows(), "ctx: {ctx}");
let all_rows: Vec<_> = x.collect();
assert_eq!(all_rows.len(), repr.nrows(), "ctx: {ctx}");
for (i, row) in all_rows.iter().enumerate() {
assert_eq!(row.len(), repr.ncols(), "ctx: {ctx}");
row.iter().enumerate().for_each(|(j, r)| {
assert_eq!(
*r,
10 * i + j,
"mismatched entry at row {}, col {} -- ctx: {}",
i,
j,
ctx
)
});
}
}
// Dimensions and layout of a small `Standard` representation.
#[test]
fn standard_representation() {
let repr = Standard::<f32>::new(4, 3).unwrap();
assert_eq!(repr.nrows(), 4);
assert_eq!(repr.ncols(), 3);
let layout = repr.layout().unwrap();
assert_eq!(layout.size(), 4 * 3 * std::mem::size_of::<f32>());
assert_eq!(layout.align(), std::mem::align_of::<f32>());
}
// Zero rows and/or columns are valid and produce an empty layout.
#[test]
fn standard_zero_dimensions() {
for (nrows, ncols) in [(0, 0), (0, 5), (5, 0)] {
let repr = Standard::<u8>::new(nrows, ncols).unwrap();
assert_eq!(repr.nrows(), nrows);
assert_eq!(repr.ncols(), ncols);
let layout = repr.layout().unwrap();
assert_eq!(layout.size(), 0);
}
}
// `check_slice` accepts only slices of exactly `nrows * ncols` elements.
#[test]
fn standard_check_slice() {
let repr = Standard::<u32>::new(3, 4).unwrap();
let data = vec![0u32; 12];
assert!(repr.check_slice(&data).is_ok());
let short = vec![0u32; 11];
assert!(matches!(
repr.check_slice(&short),
Err(SliceError::LengthMismatch {
expected: 12,
found: 11
})
));
let long = vec![0u32; 13];
assert!(matches!(
repr.check_slice(&long),
Err(SliceError::LengthMismatch {
expected: 12,
found: 13
})
));
// NOTE(review): these last lines exercise `Standard::new` overflow rather
// than `check_slice`; the `matches!` always holds for an `Overflow` value.
let overflow_repr = Standard::<u8>::new(usize::MAX, 2).unwrap_err();
assert!(matches!(overflow_repr, Overflow { .. }));
}
// `nrows * ncols` overflowing `usize` is rejected.
#[test]
fn standard_new_rejects_element_count_overflow() {
assert!(Standard::<u8>::new(usize::MAX, 2).is_err());
assert!(Standard::<u8>::new(2, usize::MAX).is_err());
assert!(Standard::<u8>::new(usize::MAX, usize::MAX).is_err());
}
// One element past the `isize::MAX`-byte budget is rejected in either axis.
#[test]
fn standard_new_rejects_byte_count_exceeding_isize_max() {
let half = (isize::MAX as usize / std::mem::size_of::<u64>()) + 1;
assert!(Standard::<u64>::new(half, 1).is_err());
assert!(Standard::<u64>::new(1, half).is_err());
}
// The largest element count that fits in `isize::MAX` bytes is accepted.
#[test]
fn standard_new_accepts_boundary_below_isize_max() {
let max_elems = isize::MAX as usize / std::mem::size_of::<u64>();
let repr = Standard::<u64>::new(max_elems, 1).unwrap();
assert_eq!(repr.num_elements(), max_elems);
}
// Zero-sized types can only fail on the element count overflowing `usize`.
#[test]
fn standard_new_zst_rejects_element_count_overflow() {
assert!(Standard::<()>::new(usize::MAX, 2).is_err());
assert!(Standard::<()>::new(usize::MAX / 2 + 1, 3).is_err());
}
#[test]
fn standard_new_zst_accepts_large_non_overflowing() {
let repr = Standard::<()>::new(usize::MAX, 1).unwrap();
assert_eq!(repr.num_elements(), usize::MAX);
assert_eq!(repr.layout().unwrap().size(), 0);
}
// `Overflow`'s message distinguishes byte-size overflow from ZST count overflow.
#[test]
fn standard_new_overflow_error_display() {
let err = Standard::<u32>::new(usize::MAX, 2).unwrap_err();
let msg = err.to_string();
assert!(msg.contains("would exceed isize::MAX bytes"), "{msg}");
let zst_err = Standard::<()>::new(usize::MAX, 2).unwrap_err();
let zst_msg = zst_err.to_string();
assert!(zst_msg.contains("ZST matrix"), "{zst_msg}");
assert!(zst_msg.contains("usize::MAX"), "{zst_msg}");
}
// A filled matrix reports its shape and lays rows out contiguously.
#[test]
fn mat_new_and_basic_accessors() {
let mat = Mat::new(Standard::<usize>::new(3, 4).unwrap(), 42usize).unwrap();
let base: *const u8 = mat.as_raw_ptr();
assert_eq!(mat.num_vectors(), 3);
assert_eq!(mat.vector_dim(), 4);
let repr = mat.repr();
assert_eq!(repr.nrows(), 3);
assert_eq!(repr.ncols(), 4);
for (i, r) in mat.rows().enumerate() {
assert_eq!(r, &[42, 42, 42, 42]);
let ptr = r.as_ptr().cast::<u8>();
assert_eq!(
ptr,
base.wrapping_add(std::mem::size_of::<usize>() * mat.repr().ncols() * i),
);
}
}
// `Defaulted` fills with `T::default()` and rows are contiguous.
#[test]
fn mat_new_with_default() {
let mat = Mat::new(Standard::<usize>::new(2, 3).unwrap(), Defaulted).unwrap();
let base: *const u8 = mat.as_raw_ptr();
assert_eq!(mat.num_vectors(), 2);
for (i, row) in mat.rows().enumerate() {
assert!(row.iter().all(|&v| v == 0));
let ptr = row.as_ptr().cast::<u8>();
assert_eq!(
ptr,
base.wrapping_add(std::mem::size_of::<usize>() * mat.repr().ncols() * i),
);
}
}
// Dimension grid shared by the exhaustive tests below (includes empty cases).
const ROWS: &[usize] = &[0, 1, 2, 3, 5, 10];
const COLS: &[usize] = &[0, 1, 2, 3, 5, 10];
// Fills owned matrices three ways (direct, via `MatMut`, via `RowsMut`) and
// checks every read path agrees.
#[test]
fn test_mat() {
for nrows in ROWS {
for ncols in COLS {
let repr = Standard::<usize>::new(*nrows, *ncols).unwrap();
let ctx = &lazy_format!("nrows = {}, ncols = {}", nrows, ncols);
{
let ctx = &lazy_format!("{ctx} - direct");
let mut mat = Mat::new(repr, Defaulted).unwrap();
assert_eq!(mat.num_vectors(), *nrows);
assert_eq!(mat.vector_dim(), *ncols);
fill_mat(&mut mat, repr);
check_mat(&mat, repr, ctx);
check_mat_ref(mat.reborrow(), repr, ctx);
check_mat_mut(mat.reborrow_mut(), repr, ctx);
check_rows(mat.rows(), repr, ctx);
assert_eq!(mat.as_raw_ptr(), mat.reborrow().as_raw_ptr());
assert_eq!(mat.as_raw_ptr(), mat.reborrow_mut().as_raw_ptr());
}
{
let ctx = &lazy_format!("{ctx} - matmut");
let mut mat = Mat::new(repr, Defaulted).unwrap();
let matmut = mat.reborrow_mut();
assert_eq!(matmut.num_vectors(), *nrows);
assert_eq!(matmut.vector_dim(), *ncols);
fill_mat_mut(matmut, repr);
check_mat(&mat, repr, ctx);
check_mat_ref(mat.reborrow(), repr, ctx);
check_mat_mut(mat.reborrow_mut(), repr, ctx);
check_rows(mat.rows(), repr, ctx);
}
{
let ctx = &lazy_format!("{ctx} - rows_mut");
let mut mat = Mat::new(repr, Defaulted).unwrap();
fill_rows_mut(mat.rows_mut(), repr);
check_mat(&mat, repr, ctx);
check_mat_ref(mat.reborrow(), repr, ctx);
check_mat_mut(mat.reborrow_mut(), repr, ctx);
check_rows(mat.rows(), repr, ctx);
}
}
}
}
// `Clone` and both `to_owned` paths produce equal data in a distinct
// allocation (non-empty matrices only; empty ones may share a dangling pointer).
#[test]
fn test_mat_clone() {
for nrows in ROWS {
for ncols in COLS {
let repr = Standard::<usize>::new(*nrows, *ncols).unwrap();
let ctx = &lazy_format!("nrows = {}, ncols = {}", nrows, ncols);
let mut mat = Mat::new(repr, Defaulted).unwrap();
fill_mat(&mut mat, repr);
{
let ctx = &lazy_format!("{ctx} - Mat::clone");
let cloned = mat.clone();
assert_eq!(cloned.num_vectors(), *nrows);
assert_eq!(cloned.vector_dim(), *ncols);
check_mat(&cloned, repr, ctx);
check_mat_ref(cloned.reborrow(), repr, ctx);
check_rows(cloned.rows(), repr, ctx);
if repr.num_elements() > 0 {
assert_ne!(mat.as_raw_ptr(), cloned.as_raw_ptr());
}
}
{
let ctx = &lazy_format!("{ctx} - MatRef::to_owned");
let owned = mat.as_view().to_owned();
check_mat(&owned, repr, ctx);
check_mat_ref(owned.reborrow(), repr, ctx);
check_rows(owned.rows(), repr, ctx);
if repr.num_elements() > 0 {
assert_ne!(mat.as_raw_ptr(), owned.as_raw_ptr());
}
}
{
let ctx = &lazy_format!("{ctx} - MatMut::to_owned");
let owned = mat.as_view_mut().to_owned();
check_mat(&owned, repr, ctx);
check_mat_ref(owned.reborrow(), repr, ctx);
check_rows(owned.rows(), repr, ctx);
if repr.num_elements() > 0 {
assert_ne!(mat.as_raw_ptr(), owned.as_raw_ptr());
}
}
}
}
}
// Views over caller-owned buffers keep the original pointer and see writes
// made through `MatMut` or `RowsMut`.
#[test]
fn test_mat_refmut() {
for nrows in ROWS {
for ncols in COLS {
let repr = Standard::<usize>::new(*nrows, *ncols).unwrap();
let ctx = &lazy_format!("nrows = {}, ncols = {}", nrows, ncols);
{
let ctx = &lazy_format!("{ctx} - by matmut");
let mut b: Box<[_]> = (0..repr.num_elements()).map(|_| 0usize).collect();
let ptr = b.as_ptr().cast::<u8>();
let mut matmut = MatMut::new(repr, &mut b).unwrap();
assert_eq!(
ptr,
matmut.as_raw_ptr(),
"underlying memory should be preserved",
);
fill_mat_mut(matmut.reborrow_mut(), repr);
check_mat_mut(matmut.reborrow_mut(), repr, ctx);
check_mat_ref(matmut.reborrow(), repr, ctx);
check_rows(matmut.rows(), repr, ctx);
check_rows(matmut.reborrow().rows(), repr, ctx);
let matref = MatRef::new(repr, &b).unwrap();
check_mat_ref(matref, repr, ctx);
check_mat_ref(matref.reborrow(), repr, ctx);
check_rows(matref.rows(), repr, ctx);
}
{
let ctx = &lazy_format!("{ctx} - by rows");
let mut b: Box<[_]> = (0..repr.num_elements()).map(|_| 0usize).collect();
let ptr = b.as_ptr().cast::<u8>();
let mut matmut = MatMut::new(repr, &mut b).unwrap();
assert_eq!(
ptr,
matmut.as_raw_ptr(),
"underlying memory should be preserved",
);
fill_rows_mut(matmut.rows_mut(), repr);
check_mat_mut(matmut.reborrow_mut(), repr, ctx);
check_mat_ref(matmut.reborrow(), repr, ctx);
check_rows(matmut.rows(), repr, ctx);
check_rows(matmut.reborrow().rows(), repr, ctx);
let matref = MatRef::new(repr, &b).unwrap();
check_mat_ref(matref, repr, ctx);
check_mat_ref(matref.reborrow(), repr, ctx);
check_rows(matref.rows(), repr, ctx);
}
}
}
}
// Fill-value construction across the dimension grid; `Rows` reports an exact length.
#[test]
fn test_standard_new_owned() {
let rows = [0, 1, 2, 3, 5, 10];
let cols = [0, 1, 2, 3, 5, 10];
for nrows in rows {
for ncols in cols {
let m = Mat::new(Standard::new(nrows, ncols).unwrap(), 1usize).unwrap();
let rows_iter = m.rows();
let len = <_ as ExactSizeIterator>::len(&rows_iter);
assert_eq!(len, nrows);
for r in rows_iter {
assert_eq!(r.len(), ncols);
assert!(r.iter().all(|i| *i == 1usize));
}
}
}
}
// `MatRef::new` surfaces the slice-length mismatch error.
#[test]
fn matref_new_slice_length_error() {
let repr = Standard::<u32>::new(3, 4).unwrap();
let data = vec![0u32; 12];
assert!(MatRef::new(repr, &data).is_ok());
let short = vec![0u32; 11];
assert!(matches!(
MatRef::new(repr, &short),
Err(SliceError::LengthMismatch {
expected: 12,
found: 11
})
));
let long = vec![0u32; 13];
assert!(matches!(
MatRef::new(repr, &long),
Err(SliceError::LengthMismatch {
expected: 12,
found: 13
})
));
}
// `MatMut::new` surfaces the slice-length mismatch error.
#[test]
fn matmut_new_slice_length_error() {
let repr = Standard::<u32>::new(3, 4).unwrap();
let mut data = vec![0u32; 12];
assert!(MatMut::new(repr, &mut data).is_ok());
let mut short = vec![0u32; 11];
assert!(matches!(
MatMut::new(repr, &mut short),
Err(SliceError::LengthMismatch {
expected: 12,
found: 11
})
));
let mut long = vec![0u32; 13];
assert!(matches!(
MatMut::new(repr, &mut long),
Err(SliceError::LengthMismatch {
expected: 12,
found: 13
})
));
}
// `as_matrix_view` and `as_slice` agree with the row-major source data for
// `MatRef`, `Mat`, and `MatMut`.
#[test]
fn as_matrix_view_roundtrip() {
let data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let matref = MatRef::new(Standard::new(2, 3).unwrap(), &data).unwrap();
let view = matref.as_matrix_view();
assert_eq!(view.nrows(), 2);
assert_eq!(view.ncols(), 3);
for row in 0..2 {
for col in 0..3 {
assert_eq!(view[(row, col)], data[row * 3 + col]);
}
}
assert_eq!(matref.as_slice(), &data);
let mut mat = Mat::new(Standard::<f32>::new(2, 3).unwrap(), 0.0f32).unwrap();
for i in 0..2 {
let r = mat.get_row_mut(i).unwrap();
for j in 0..3 {
r[j] = data[i * 3 + j];
}
}
let view = mat.as_matrix_view();
assert_eq!(view.nrows(), 2);
assert_eq!(view.ncols(), 3);
for row in 0..2 {
for col in 0..3 {
assert_eq!(view[(row, col)], data[row * 3 + col]);
}
}
assert_eq!(mat.as_slice(), &data);
let mut buf = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let matmut = MatMut::new(Standard::new(2, 3).unwrap(), &mut buf).unwrap();
let view = matmut.as_matrix_view();
assert_eq!(view.nrows(), 2);
assert_eq!(view.ncols(), 3);
for row in 0..2 {
for col in 0..3 {
assert_eq!(view[(row, col)], data[row * 3 + col]);
}
}
assert_eq!(matmut.as_slice(), &data);
}
}