use super::*;
use crate::{
assert, debug_assert, diag::DiagRef, iter, iter::chunks::ChunkPolicy, unzipped,
utils::bound::*, zipped_rw, Idx, IdxInc, Shape, Unbind,
};
use core::ops::Range;
use generativity::make_guard;
/// Immutable view over an `nrows × ncols` matrix described by a base pointer and a pair of
/// (possibly negative or zero) strides.
///
/// The view is `Copy` and borrows the underlying data for the lifetime `'a`. `R` and `C`
/// are the row/column shape types; both default to plain `usize`.
#[repr(C)]
pub struct MatRef<'a, E: Entity, R: Shape = usize, C: Shape = usize> {
// Base pointer, dimensions and row/column strides of the view.
pub(super) inner: MatImpl<E, R, C>,
// Ties the view to the lifetime `'a` of the borrowed elements.
pub(super) __marker: PhantomData<&'a E>,
}
/// Cloning a view copies the view itself; no element data is duplicated.
impl<E: Entity, R: Shape, C: Shape> Clone for MatRef<'_, E, R, C> {
#[inline]
fn clone(&self) -> Self {
// `MatRef` is `Copy`, so the canonical clone is a plain copy.
*self
}
}
// A view is just (pointer, dimensions, strides), so it can be copied freely.
impl<E: Entity, R: Shape, C: Shape> Copy for MatRef<'_, E, R, C> {}
impl<E: Entity> Default for MatRef<'_, E> {
    /// Returns an empty `0 × 0` view backed by an empty slice.
    #[inline]
    fn default() -> Self {
        let empty = map!(E, E::UNIT, |(())| { &[] as &[E::Unit] });
        from_column_major_slice_generic(empty, 0, 0)
    }
}
// Reborrowing a shared view copies it, shortening its lifetime to `'short`.
impl<'short, E: Entity, R: Shape, C: Shape> Reborrow<'short> for MatRef<'_, E, R, C> {
type Target = MatRef<'short, E, R, C>;
#[inline]
fn rb(&'short self) -> Self::Target {
*self
}
}
// A shared view grants no exclusive access, so the "mutable" reborrow is also a plain
// copy with the shorter lifetime `'short`.
impl<'short, E: Entity, R: Shape, C: Shape> ReborrowMut<'short> for MatRef<'_, E, R, C> {
type Target = MatRef<'short, E, R, C>;
#[inline]
fn rb_mut(&'short mut self) -> Self::Target {
*self
}
}
// A `MatRef` is already the immutable flavour, so conversion is the identity.
impl<E: Entity, R: Shape, C: Shape> IntoConst for MatRef<'_, E, R, C> {
type Target = Self;
#[inline]
fn into_const(self) -> Self::Target {
self
}
}
impl<'a, E: Entity, R: Shape, C: Shape> MatRef<'a, E, R, C> {
/// Builds a view directly from a base pointer, a shape and a pair of strides.
///
/// # Safety
/// `ptr` must be non-null, because it is wrapped with `NonNull::new_unchecked`. The caller
/// must also guarantee that the pointer, dimensions and strides describe elements that are
/// valid to read for the whole lifetime `'a`.
#[inline]
pub(crate) unsafe fn __from_raw_parts(
    ptr: PtrConst<E>,
    nrows: R,
    ncols: C,
    row_stride: isize,
    col_stride: isize,
) -> Self {
    let ptr = into_copy::<E, _>(map!(E, ptr, |(raw)| {
        NonNull::new_unchecked(raw as *mut E::Unit)
    },));
    Self {
        inner: MatImpl {
            ptr,
            nrows,
            ncols,
            row_stride,
            col_stride,
        },
        __marker: PhantomData,
    }
}
#[inline(always)]
pub fn as_ptr(self) -> PtrConst<E> {
map!(E, from_copy::<E, _>(self.inner.ptr), |(ptr)| {
ptr.as_ptr() as *const E::Unit
},)
}
#[inline]
pub fn nrows(&self) -> R {
self.inner.nrows
}
#[inline]
pub fn ncols(&self) -> C {
self.inner.ncols
}
/// Returns the dimensions of the view as `(nrows, ncols)`.
#[inline]
pub fn shape(&self) -> (R, C) {
    let rows = self.nrows();
    let cols = self.ncols();
    (rows, cols)
}
#[inline]
pub fn row_stride(&self) -> isize {
self.inner.row_stride
}
#[inline]
pub fn col_stride(&self) -> isize {
self.inner.col_stride
}
/// Returns a pointer to the element at `(row, col)`.
///
/// The offset is computed with wrapping arithmetic and applied with `wrapping_offset`, so
/// calling this is always safe. The resulting pointer may lie outside the view, and
/// dereferencing it is only sound when `(row, col)` is in bounds.
#[inline(always)]
pub fn ptr_at(self, row: usize, col: usize) -> PtrConst<E> {
    let row_part = (row as isize).wrapping_mul(self.inner.row_stride);
    let col_part = (col as isize).wrapping_mul(self.inner.col_stride);
    let offset = row_part.wrapping_add(col_part);
    map!(E, self.as_ptr(), |(ptr)| { ptr.wrapping_offset(offset) },)
}
/// Returns a pointer to the element at `(row, col)` using unchecked arithmetic.
///
/// # Safety
/// The offset `row * row_stride + col * col_stride` must not overflow, and because it is
/// applied with `pointer::offset`, the result must stay within the same allocation as the
/// base pointer.
#[inline(always)]
#[doc(hidden)]
pub unsafe fn ptr_at_unchecked(self, row: usize, col: usize) -> PtrConst<E> {
    let row_offset = crate::utils::unchecked_mul(row, self.inner.row_stride);
    let col_offset = crate::utils::unchecked_mul(col, self.inner.col_stride);
    let offset = crate::utils::unchecked_add(row_offset, col_offset);
    map!(E, self.as_ptr(), |(ptr)| { ptr.offset(offset) },)
}
/// Returns a pointer to `(row, col)`, where `row`/`col` may be one past the end.
///
/// If `row == nrows` or `col == ncols`, the offset is forced to zero and the base pointer
/// is returned unchanged. This avoids applying an offset that could leave the allocation
/// (e.g. for empty sub-views).
///
/// # Safety
/// For positions strictly inside the view, the computed offset is applied with
/// `pointer::offset`, so it must stay within the same allocation.
#[inline(always)]
#[doc(hidden)]
pub unsafe fn overflowing_ptr_at(self, row: IdxInc<R>, col: IdxInc<C>) -> PtrConst<E> {
unsafe {
// `cond` is true only when both indices are strictly inside the view.
let cond = (row != self.nrows()) & (col != self.ncols());
// Branch-free select: `(cond as usize).wrapping_neg()` is all ones when `cond` holds and
// zero otherwise, so the mask keeps the real offset or zeroes it out.
let offset = (cond as usize).wrapping_neg() as isize
& (isize::wrapping_add(
(row.unbound() as isize).wrapping_mul(self.inner.row_stride),
(col.unbound() as isize).wrapping_mul(self.inner.col_stride),
));
map!(E, self.as_ptr(), |(ptr)| { ptr.offset(offset) },)
}
}
/// Returns a pointer to the element at `(row, col)`, which must be inside the view.
///
/// # Safety
/// `row < nrows` and `col < ncols` must hold; this is only checked in debug builds.
#[inline(always)]
#[track_caller]
pub unsafe fn ptr_inbounds_at(self, row: Idx<R>, col: Idx<C>) -> PtrConst<E> {
    debug_assert!(all(row < self.nrows(), col < self.ncols()));
    let (i, j) = (row.unbound(), col.unbound());
    self.ptr_at_unchecked(i, j)
}
/// Splits the view into four quadrants at `(row, col)` without bounds checks.
///
/// Returns `(top_left, top_right, bottom_left, bottom_right)`. The top/left parts have
/// `row` rows / `col` columns, and the remaining parts get the rest. All quadrants keep the
/// original strides.
///
/// # Safety
/// `row <= nrows` and `col <= ncols` must hold; this is only checked in debug builds.
#[inline(always)]
#[track_caller]
pub unsafe fn split_at_unchecked(
self,
row: IdxInc<R>,
col: IdxInc<C>,
) -> (
MatRef<'a, E, usize, usize>,
MatRef<'a, E, usize, usize>,
MatRef<'a, E, usize, usize>,
MatRef<'a, E, usize, usize>,
) {
debug_assert!(all(row <= self.nrows(), col <= self.ncols()));
let row_stride = self.row_stride();
let col_stride = self.col_stride();
let nrows = self.nrows();
let ncols = self.ncols();
unsafe {
// `overflowing_ptr_at` tolerates one-past-the-end split points (empty quadrants).
let top_left = self.overflowing_ptr_at(R::start(), C::start());
let top_right = self.overflowing_ptr_at(R::start(), col);
let bot_left = self.overflowing_ptr_at(row, C::start());
let bot_right = self.overflowing_ptr_at(row, col);
// From here on, work with plain `usize` dimensions.
let row = row.unbound();
let nrows = nrows.unbound();
let col = col.unbound();
let ncols = ncols.unbound();
(
MatRef::__from_raw_parts(top_left, row, col, row_stride, col_stride),
MatRef::__from_raw_parts(top_right, row, ncols - col, row_stride, col_stride),
MatRef::__from_raw_parts(bot_left, nrows - row, col, row_stride, col_stride),
MatRef::__from_raw_parts(
bot_right,
nrows - row,
ncols - col,
row_stride,
col_stride,
),
)
}
}
/// Splits the view into four quadrants at `(row, col)`.
///
/// Returns `(top_left, top_right, bottom_left, bottom_right)`.
///
/// # Panics
/// Panics if `row > nrows` or `col > ncols`.
#[inline(always)]
#[track_caller]
pub fn split_at(
    self,
    row: IdxInc<R>,
    col: IdxInc<C>,
) -> (
    MatRef<'a, E, usize, usize>,
    MatRef<'a, E, usize, usize>,
    MatRef<'a, E, usize, usize>,
    MatRef<'a, E, usize, usize>,
) {
    assert!(all(row <= self.nrows(), col <= self.ncols()));
    // SAFETY: both split points were validated by the assertion above.
    unsafe { self.split_at_unchecked(row, col) }
}
/// Splits the view horizontally into `(top, bottom)` at `row`, without bounds checks.
///
/// `top` has `row` rows and `bottom` has the remaining `nrows - row` rows. Both keep all
/// columns and the original strides.
///
/// # Safety
/// `row <= nrows` must hold; this is only checked in debug builds.
#[inline(always)]
#[track_caller]
pub unsafe fn split_at_row_unchecked(
    self,
    row: IdxInc<R>,
) -> (MatRef<'a, E, usize, C>, MatRef<'a, E, usize, C>) {
    debug_assert!(row <= self.nrows());
    let (rs, cs) = (self.row_stride(), self.col_stride());
    let total_rows = self.nrows().unbound();
    let ncols = self.ncols();
    unsafe {
        let top_ptr = self.overflowing_ptr_at(R::start(), C::start());
        let bottom_ptr = self.overflowing_ptr_at(row, C::start());
        let split = row.unbound();
        (
            MatRef::__from_raw_parts(top_ptr, split, ncols, rs, cs),
            MatRef::__from_raw_parts(bottom_ptr, total_rows - split, ncols, rs, cs),
        )
    }
}
/// Splits the view horizontally into `(top, bottom)` at `row`.
///
/// # Panics
/// Panics if `row > nrows`.
#[inline(always)]
#[track_caller]
pub fn split_at_row(
    self,
    row: IdxInc<R>,
) -> (MatRef<'a, E, usize, C>, MatRef<'a, E, usize, C>) {
    assert!(row <= self.nrows());
    // SAFETY: the split point was validated by the assertion above.
    unsafe { self.split_at_row_unchecked(row) }
}
/// Splits the view vertically into `(left, right)` at `col`, without bounds checks.
///
/// `left` has `col` columns and `right` has the remaining `ncols - col` columns. Both keep
/// all rows and the original strides.
///
/// # Safety
/// `col <= ncols` must hold; this is only checked in debug builds.
#[inline(always)]
#[track_caller]
pub unsafe fn split_at_col_unchecked(
    self,
    col: IdxInc<C>,
) -> (MatRef<'a, E, R, usize>, MatRef<'a, E, R, usize>) {
    debug_assert!(col <= self.ncols());
    let (rs, cs) = (self.row_stride(), self.col_stride());
    let nrows = self.nrows();
    let total_cols = self.ncols().unbound();
    unsafe {
        let left_ptr = self.overflowing_ptr_at(R::start(), C::start());
        let right_ptr = self.overflowing_ptr_at(R::start(), col);
        let split = col.unbound();
        (
            MatRef::__from_raw_parts(left_ptr, nrows, split, rs, cs),
            MatRef::__from_raw_parts(right_ptr, nrows, total_cols - split, rs, cs),
        )
    }
}
/// Splits the view vertically into `(left, right)` at `col`.
///
/// # Panics
/// Panics if `col > ncols`.
#[inline(always)]
#[track_caller]
pub fn split_at_col(
    self,
    col: IdxInc<C>,
) -> (MatRef<'a, E, R, usize>, MatRef<'a, E, R, usize>) {
    assert!(col <= self.ncols());
    // SAFETY: the split point was validated by the assertion above.
    unsafe { self.split_at_col_unchecked(col) }
}
/// Returns the transposed view.
///
/// No data moves: the dimensions and the strides are swapped.
#[inline(always)]
#[must_use]
pub fn transpose(self) -> MatRef<'a, E, C, R> {
    let (nrows, ncols) = self.shape();
    let (row_stride, col_stride) = (self.row_stride(), self.col_stride());
    // SAFETY: the same elements are addressed, just with rows and columns exchanged.
    unsafe { MatRef::__from_raw_parts(self.as_ptr(), ncols, nrows, col_stride, row_stride) }
}
/// Returns a view of the same memory, reinterpreted as the conjugate entity `E::Conj`.
///
/// Shape and strides are unchanged and no data is copied; only the pointer group is
/// reinterpreted.
#[inline(always)]
#[must_use]
pub fn conjugate(self) -> MatRef<'a, E::Conj, R, C>
where
    E: Conjugate,
{
    let (nrows, ncols) = self.shape();
    let (row_stride, col_stride) = (self.row_stride(), self.col_stride());
    // SAFETY: relies on the `Conjugate` trait's contract (defined elsewhere) that `E` and
    // `E::Conj` pointer groups may be reinterpreted as each other.
    unsafe {
        let ptr = transmute_unchecked::<
            GroupFor<E, *const UnitFor<E>>,
            GroupFor<E::Conj, *const UnitFor<E::Conj>>,
        >(self.as_ptr());
        MatRef::__from_raw_parts(ptr, nrows, ncols, row_stride, col_stride)
    }
}
/// Returns the conjugate transpose of the view.
///
/// Both steps only reinterpret the view, so the order in which they are applied does not
/// matter.
#[inline(always)]
#[must_use]
pub fn adjoint(self) -> MatRef<'a, E::Conj, C, R>
where
    E: Conjugate,
{
    self.conjugate().transpose()
}
/// Reinterprets the view as the canonical entity type, and reports whether an implicit
/// conjugation must be applied.
///
/// The returned flag is `Conj::No` when `E::IS_CANONICAL` and `Conj::Yes` otherwise.
#[inline(always)]
#[must_use]
pub fn canonicalize(self) -> (MatRef<'a, E::Canonical, R, C>, Conj)
where
    E: Conjugate,
{
    let conj = if E::IS_CANONICAL { Conj::No } else { Conj::Yes };
    // SAFETY: relies on the `Conjugate` trait's contract (defined elsewhere) that the
    // pointer group of `E` may be reinterpreted as that of `E::Canonical`.
    let canonical: MatRef<'a, E::Canonical, R, C> = unsafe {
        MatRef::__from_raw_parts(
            transmute_unchecked::<
                PtrConst<E>,
                GroupFor<E::Canonical, *const UnitFor<E::Canonical>>,
            >(self.as_ptr()),
            self.nrows(),
            self.ncols(),
            self.row_stride(),
            self.col_stride(),
        )
    };
    (canonical, conj)
}
/// Returns a reference (group) to the element at `(row, col)` without bounds checks.
///
/// # Safety
/// `row < nrows` and `col < ncols` must hold; this is only checked in debug builds.
#[inline(always)]
#[track_caller]
pub unsafe fn at_unchecked(self, row: Idx<R>, col: Idx<C>) -> Ref<'a, E> {
    unsafe {
        let ptr = self.ptr_inbounds_at(row, col);
        map!(E, ptr, |(elem)| &*elem)
    }
}
/// Returns a reference (group) to the element at `(row, col)`.
///
/// # Panics
/// Panics if `row >= nrows` or `col >= ncols`.
#[inline(always)]
#[track_caller]
pub fn at(self, row: Idx<R>, col: Idx<C>) -> Ref<'a, E> {
    assert!(all(row < self.nrows(), col < self.ncols()));
    // SAFETY: both indices were validated by the assertion above.
    unsafe { self.at_unchecked(row, col) }
}
/// Reads the element at `(row, col)` by value, without bounds checks.
///
/// # Safety
/// `row < nrows` and `col < ncols` must hold; this is only checked in debug builds.
#[inline(always)]
#[track_caller]
pub unsafe fn read_unchecked(&self, row: Idx<R>, col: Idx<C>) -> E {
    let units = map!(E, self.at_unchecked(row, col), |(unit)| { *unit },);
    E::faer_from_units(units)
}
/// Reads the element at `(row, col)` by value.
///
/// # Panics
/// Panics if `row >= nrows` or `col >= ncols` (checked by [`MatRef::at`]).
#[inline(always)]
#[track_caller]
pub fn read(&self, row: Idx<R>, col: Idx<C>) -> E {
    let units = map!(E, self.at(row, col), |(unit)| { *unit },);
    E::faer_from_units(units)
}
/// Returns a view whose rows are in reverse order.
///
/// The new view starts at the last row and walks upwards, using the negated row stride.
#[inline(always)]
#[must_use]
pub fn reverse_rows(self) -> Self {
    let (nrows, ncols) = self.shape();
    let last_row = nrows.unbound().saturating_sub(1);
    // NOTE(review): for `ncols == 0 && nrows > 0` this offset is not obviously backed by an
    // allocation; confirm `ptr_at_unchecked`'s requirements hold in that case.
    let ptr = unsafe { self.ptr_at_unchecked(last_row, 0) };
    let row_stride = self.row_stride().wrapping_neg();
    let col_stride = self.col_stride();
    unsafe { Self::__from_raw_parts(ptr, nrows, ncols, row_stride, col_stride) }
}
/// Returns a view whose columns are in reverse order.
///
/// The new view starts at the last column and walks leftwards, using the negated column
/// stride.
#[inline(always)]
#[must_use]
pub fn reverse_cols(self) -> Self {
    let (nrows, ncols) = self.shape();
    let last_col = ncols.unbound().saturating_sub(1);
    // NOTE(review): for `nrows == 0 && ncols > 0` this offset is not obviously backed by an
    // allocation; confirm `ptr_at_unchecked`'s requirements hold in that case.
    let ptr = unsafe { self.ptr_at_unchecked(0, last_col) };
    let row_stride = self.row_stride();
    let col_stride = self.col_stride().wrapping_neg();
    unsafe { Self::__from_raw_parts(ptr, nrows, ncols, row_stride, col_stride) }
}
/// Returns a view with both the row order and the column order reversed.
///
/// The new view starts at the bottom-right element and both strides are negated. The
/// strides are negated with `wrapping_neg`, like in [`MatRef::reverse_rows`] and
/// [`MatRef::reverse_cols`], so an `isize::MIN` stride does not cause a debug-mode
/// "attempt to negate with overflow" panic that the sibling methods do not have.
#[inline(always)]
#[must_use]
pub fn reverse_rows_and_cols(self) -> Self {
    let nrows = self.nrows();
    let ncols = self.ncols();
    let row_stride = self.row_stride().wrapping_neg();
    let col_stride = self.col_stride().wrapping_neg();
    let ptr = unsafe {
        self.ptr_at_unchecked(
            nrows.unbound().saturating_sub(1),
            ncols.unbound().saturating_sub(1),
        )
    };
    unsafe { Self::__from_raw_parts(ptr, nrows, ncols, row_stride, col_stride) }
}
/// Returns the `nrows × ncols` sub-view whose top-left corner is `(row_start, col_start)`,
/// without bounds checks.
///
/// # Safety
/// `row_start <= self.nrows()`, `col_start <= self.ncols()`,
/// `nrows <= self.nrows() - row_start` and `ncols <= self.ncols() - col_start` must hold.
/// These are only checked in debug builds.
#[track_caller]
#[inline(always)]
pub unsafe fn submatrix_unchecked<V: Shape, H: Shape>(
    self,
    row_start: IdxInc<R>,
    col_start: IdxInc<C>,
    nrows: V,
    ncols: H,
) -> MatRef<'a, E, V, H> {
    debug_assert!(all(row_start <= self.nrows(), col_start <= self.ncols()));
    {
        let (height, top) = (nrows.unbound(), row_start.unbound());
        let (width, left) = (ncols.unbound(), col_start.unbound());
        debug_assert!(all(
            height <= self.nrows().unbound() - top,
            width <= self.ncols().unbound() - left,
        ));
    }
    unsafe {
        let origin = self.overflowing_ptr_at(row_start, col_start);
        MatRef::__from_raw_parts(origin, nrows, ncols, self.row_stride(), self.col_stride())
    }
}
/// Returns the `nrows × ncols` sub-view whose top-left corner is `(row_start, col_start)`.
///
/// # Panics
/// Panics if the start position is outside the view or if the requested block does not
/// fit inside it.
#[track_caller]
#[inline(always)]
pub fn submatrix<V: Shape, H: Shape>(
    self,
    row_start: IdxInc<R>,
    col_start: IdxInc<C>,
    nrows: V,
    ncols: H,
) -> MatRef<'a, E, V, H> {
    assert!(all(row_start <= self.nrows(), col_start <= self.ncols()));
    {
        // The subtractions cannot underflow thanks to the assertion above.
        let (height, top) = (nrows.unbound(), row_start.unbound());
        let (width, left) = (ncols.unbound(), col_start.unbound());
        assert!(all(
            height <= self.nrows().unbound() - top,
            width <= self.ncols().unbound() - left,
        ));
    }
    // SAFETY: all preconditions of `submatrix_unchecked` were asserted above.
    unsafe { self.submatrix_unchecked(row_start, col_start, nrows, ncols) }
}
/// Returns the `nrows` consecutive rows starting at `row_start`, without bounds checks.
///
/// # Safety
/// `row_start <= self.nrows()` and `nrows <= self.nrows() - row_start` must hold; these
/// are only checked in debug builds.
#[track_caller]
#[inline(always)]
pub unsafe fn subrows_unchecked<V: Shape>(
    self,
    row_start: IdxInc<R>,
    nrows: V,
) -> MatRef<'a, E, V, C> {
    debug_assert!(row_start <= self.nrows());
    {
        let (height, top) = (nrows.unbound(), row_start.unbound());
        debug_assert!(height <= self.nrows().unbound() - top);
    }
    unsafe {
        let origin = self.overflowing_ptr_at(row_start, C::start());
        MatRef::__from_raw_parts(
            origin,
            nrows,
            self.ncols(),
            self.row_stride(),
            self.col_stride(),
        )
    }
}
/// Returns the `nrows` consecutive rows starting at `row_start`.
///
/// # Panics
/// Panics if `row_start > self.nrows()` or the requested rows do not fit in the view.
#[track_caller]
#[inline(always)]
pub fn subrows<V: Shape>(self, row_start: IdxInc<R>, nrows: V) -> MatRef<'a, E, V, C> {
    assert!(row_start <= self.nrows());
    {
        let (height, top) = (nrows.unbound(), row_start.unbound());
        assert!(height <= self.nrows().unbound() - top);
    }
    // SAFETY: both preconditions were asserted above.
    unsafe { self.subrows_unchecked(row_start, nrows) }
}
/// Returns the `ncols` consecutive columns starting at `col_start`, without bounds checks.
///
/// # Safety
/// `col_start <= self.ncols()` and `ncols <= self.ncols() - col_start` must hold; these
/// are only checked in debug builds.
#[track_caller]
#[inline(always)]
pub unsafe fn subcols_unchecked<H: Shape>(
    self,
    col_start: IdxInc<C>,
    ncols: H,
) -> MatRef<'a, E, R, H> {
    debug_assert!(col_start <= self.ncols());
    {
        let (width, left) = (ncols.unbound(), col_start.unbound());
        debug_assert!(width <= self.ncols().unbound() - left);
    }
    unsafe {
        let origin = self.overflowing_ptr_at(R::start(), col_start);
        MatRef::__from_raw_parts(
            origin,
            self.nrows(),
            ncols,
            self.row_stride(),
            self.col_stride(),
        )
    }
}
/// Returns the `ncols` consecutive columns starting at `col_start`.
///
/// # Panics
/// Panics if `col_start > self.ncols()` or the requested columns do not fit in the view.
#[track_caller]
#[inline(always)]
pub fn subcols<H: Shape>(self, col_start: IdxInc<C>, ncols: H) -> MatRef<'a, E, R, H> {
    assert!(col_start <= self.ncols());
    {
        let (width, left) = (ncols.unbound(), col_start.unbound());
        assert!(width <= self.ncols().unbound() - left);
    }
    // SAFETY: both preconditions were asserted above.
    unsafe { self.subcols_unchecked(col_start, ncols) }
}
/// Returns the columns in the half-open range `cols`.
///
/// Both ends must lie within `0..=ncols`. A range with `end < start` yields an empty
/// (zero-column) view rather than panicking, because the width saturates at zero.
///
/// # Panics
/// Panics if either end of the range exceeds `self.ncols()`.
#[track_caller]
#[inline(always)]
#[doc(hidden)]
pub fn subcols_range(self, cols: Range<IdxInc<C>>) -> MatRef<'a, E, R, usize> {
    let Range { start, end } = cols;
    assert!(all(start <= self.ncols(), end <= self.ncols()));
    let width = end.unbound().saturating_sub(start.unbound());
    // SAFETY: `start <= ncols` was asserted, and `width <= end - start <= ncols - start`.
    unsafe { self.subcols_unchecked(start, width) }
}
/// Returns row `row_idx` as a row view, without bounds checks.
///
/// # Safety
/// `row_idx < self.nrows()` must hold; this is only checked in debug builds.
#[track_caller]
#[inline(always)]
pub unsafe fn row_unchecked(self, row_idx: Idx<R>) -> RowRef<'a, E, C> {
    debug_assert!(row_idx < self.nrows());
    unsafe {
        let start = self.overflowing_ptr_at(row_idx.into(), C::start());
        RowRef::__from_raw_parts(start, self.ncols(), self.col_stride())
    }
}
/// Returns row `row_idx` as a row view.
///
/// # Panics
/// Panics if `row_idx >= self.nrows()`.
#[track_caller]
#[inline(always)]
pub fn row(self, row_idx: Idx<R>) -> RowRef<'a, E, C> {
    assert!(row_idx < self.nrows());
    // SAFETY: the index was validated by the assertion above.
    unsafe { self.row_unchecked(row_idx) }
}
/// Returns column `col_idx` as a column view, without bounds checks.
///
/// # Safety
/// `col_idx < self.ncols()` must hold; this is only checked in debug builds.
#[track_caller]
#[inline(always)]
pub unsafe fn col_unchecked(self, col_idx: Idx<C>) -> ColRef<'a, E, R> {
    debug_assert!(col_idx < self.ncols());
    unsafe {
        let start = self.overflowing_ptr_at(R::start(), col_idx.into());
        ColRef::__from_raw_parts(start, self.nrows(), self.row_stride())
    }
}
/// Returns column `col_idx` as a column view.
///
/// # Panics
/// Panics if `col_idx >= self.ncols()`.
#[track_caller]
#[inline(always)]
pub fn col(self, col_idx: Idx<C>) -> ColRef<'a, E, R> {
    assert!(col_idx < self.ncols());
    // SAFETY: the index was validated by the assertion above.
    unsafe { self.col_unchecked(col_idx) }
}
/// Copies the view into a newly allocated matrix of canonical entities.
///
/// Every element is read and converted with `Conjugate::canonicalize`.
#[inline]
pub fn to_owned(&self) -> Mat<E::Canonical>
where
    E: Conjugate,
{
    let (nrows, ncols) = (self.nrows().unbound(), self.ncols().unbound());
    let mut owned = Mat::new();
    owned.resize_with(
        nrows,
        ncols,
        #[inline(always)]
        // SAFETY: presumably `resize_with` only calls the closure with indices inside the
        // requested `nrows × ncols` shape, which is the shape of `self`.
        |i, j| unsafe {
            self.read_unchecked(Idx::<R>::new_unbound(i), Idx::<C>::new_unbound(j))
                .canonicalize()
        },
    );
    owned
}
#[doc(hidden)]
#[inline(always)]
pub unsafe fn const_cast(self) -> MatMut<'a, E, R, C> {
MatMut {
inner: self.inner,
__marker: PhantomData,
}
}
/// Erases the shape types, returning a view with plain `usize` dimensions.
#[inline]
pub fn as_dyn(self) -> MatRef<'a, E> {
    let nrows = self.nrows().unbound();
    let ncols = self.ncols().unbound();
    // SAFETY: same pointer, dimensions and strides as `self`.
    unsafe { from_raw_parts(self.as_ptr(), nrows, ncols, self.row_stride(), self.col_stride()) }
}
/// Re-types the view with the shape types `V` and `H`.
///
/// The numeric dimensions must match the current ones exactly. Only the type-level shape
/// changes; pointer and strides are kept.
///
/// # Panics
/// Panics if `nrows` or `ncols` differ from the view's dimensions. The panic is reported
/// at the caller's location (`#[track_caller]`), like the other asserting accessors.
#[inline]
#[track_caller]
pub fn as_shape<V: Shape, H: Shape>(self, nrows: V, ncols: H) -> MatRef<'a, E, V, H> {
    assert!(all(
        nrows.unbound() == self.nrows().unbound(),
        ncols.unbound() == self.ncols().unbound(),
    ));
    // SAFETY: same pointer and strides, and the dimensions were checked to be equal.
    unsafe {
        from_raw_parts(
            self.as_ptr(),
            nrows,
            ncols,
            self.row_stride(),
            self.col_stride(),
        )
    }
}
/// Returns column `j` as a slice (group), which requires a unit row stride.
///
/// An empty column yields an empty slice without touching the pointer.
///
/// # Panics
/// Panics if `self.row_stride() != 1` or if `j` is out of bounds.
#[track_caller]
#[inline(always)]
#[doc(hidden)]
pub fn try_get_contiguous_col(self, j: Idx<C>) -> Slice<'a, E> {
    assert!(self.row_stride() == 1);
    let col = self.col(j);
    let len = col.nrows().unbound();
    if len == 0 {
        return map!(E, E::UNIT, |(())| { &[] as &[E::Unit] },);
    }
    // SAFETY: the column has `len` elements laid out contiguously (row stride is 1).
    map!(E, col.as_ptr(), |(ptr)| {
        unsafe { core::slice::from_raw_parts(ptr, len) }
    },)
}
#[inline(always)]
#[track_caller]
pub unsafe fn get_unchecked<RowRange, ColRange>(
self,
row: RowRange,
col: ColRange,
) -> <Self as MatIndex<RowRange, ColRange>>::Target
where
Self: MatIndex<RowRange, ColRange>,
{
<Self as MatIndex<RowRange, ColRange>>::get_unchecked(self, row, col)
}
#[inline(always)]
#[track_caller]
pub fn get<RowRange, ColRange>(
self,
row: RowRange,
col: ColRange,
) -> <Self as MatIndex<RowRange, ColRange>>::Target
where
Self: MatIndex<RowRange, ColRange>,
{
<Self as MatIndex<RowRange, ColRange>>::get(self, row, col)
}
/// Returns `true` if any element of the view is NaN.
///
/// Every element is visited, even after a NaN has been found.
#[inline]
pub fn has_nan(&self) -> bool
where
    E: ComplexField,
{
    let mut saw_nan = false;
    zipped_rw!(*self).for_each(|unzipped!(x)| {
        if x.read().faer_is_nan() {
            saw_nan = true;
        }
    });
    saw_nan
}
/// Returns `true` if every element of the view is finite. An empty view counts as finite.
#[inline]
pub fn is_all_finite(&self) -> bool
where
    E: ComplexField,
{
    let mut finite = true;
    zipped_rw!(*self).for_each(|unzipped!(x)| {
        if !x.read().faer_is_finite() {
            finite = false;
        }
    });
    finite
}
/// Returns the maximum norm of the view (computed by `linalg::reductions::norm_max`).
#[inline]
pub fn norm_max(&self) -> E::Real
where
    E: ComplexField,
{
    use crate::linalg::reductions::norm_max::norm_max;
    norm_max(self.as_dyn())
}
/// Returns the L1 norm of the view (computed by `linalg::reductions::norm_l1`).
#[inline]
pub fn norm_l1(&self) -> E::Real
where
    E: ComplexField,
{
    use crate::linalg::reductions::norm_l1::norm_l1;
    norm_l1(self.as_dyn())
}
/// Returns the L2 (Frobenius) norm of the view (computed by `linalg::reductions::norm_l2`).
#[inline]
pub fn norm_l2(&self) -> E::Real
where
    E: ComplexField,
{
    use crate::linalg::reductions::norm_l2::norm_l2;
    norm_l2(self.as_dyn())
}
/// Returns the squared L2 norm, computed as `norm_l2() * norm_l2()`.
#[inline]
pub fn squared_norm_l2(&self) -> E::Real
where
    E: ComplexField,
{
    let norm = self.norm_l2();
    norm.faer_mul(norm)
}
/// Returns the sum of all elements (computed by `linalg::reductions::sum`).
#[inline]
pub fn sum(&self) -> E
where
    E: ComplexField,
{
    use crate::linalg::reductions::sum::sum;
    sum(self.as_dyn())
}
/// Returns the Kronecker product `self ⊗ rhs` as a new matrix.
///
/// The output has shape `(self.nrows * rhs.nrows) × (self.ncols * rhs.ncols)`. It is
/// zero-initialised and then filled by `linalg::kron`.
#[inline]
#[track_caller]
pub fn kron(&self, rhs: impl As2D<E>) -> Mat<E>
where
    E: ComplexField,
{
    let lhs = self.as_dyn();
    let rhs = rhs.as_2d_ref();
    let out_rows = lhs.nrows() * rhs.nrows();
    let out_cols = lhs.ncols() * rhs.ncols();
    let mut dst = Mat::new();
    dst.resize_with(out_rows, out_cols, |_, _| E::zeroed());
    crate::linalg::kron(dst.as_mut(), lhs, rhs);
    dst
}
/// Returns a copy of the view whose lifetime is tied to the borrow of `self`.
#[inline]
pub fn as_ref(&self) -> MatRef<'_, E, R, C> {
*self
}
/// Splits off the first column: returns `(first_column, remaining_columns)`, or `None` if
/// the view has no columns.
#[inline]
pub fn split_first_col(self) -> Option<(ColRef<'a, E, R>, MatRef<'a, E, R, usize>)> {
    if self.ncols().unbound() == 0 {
        return None;
    }
    // SAFETY: there is at least one column, so 1 is a valid split point and the one-column
    // head has a column 0.
    unsafe {
        let split = self.ncols().unchecked_idx_inc(1);
        let (first, rest) = self.split_at_col_unchecked(split);
        Some((first.col_unchecked(0), rest))
    }
}
/// Splits off the last column: returns `(last_column, preceding_columns)`, or `None` if
/// the view has no columns.
#[inline]
pub fn split_last_col(self) -> Option<(ColRef<'a, E, R>, MatRef<'a, E, R, usize>)> {
    let ncols = self.ncols().unbound();
    if ncols == 0 {
        return None;
    }
    // SAFETY: `ncols - 1` is a valid split point and the one-column tail has a column 0.
    unsafe {
        let split = self.ncols().unchecked_idx_inc(ncols - 1);
        let (rest, last) = self.split_at_col_unchecked(split);
        Some((last.col_unchecked(0), rest))
    }
}
/// Splits off the first row: returns `(first_row, remaining_rows)`, or `None` if the view
/// has no rows.
#[inline]
pub fn split_first_row(self) -> Option<(RowRef<'a, E, C>, MatRef<'a, E, usize, C>)> {
    if self.nrows().unbound() == 0 {
        return None;
    }
    // SAFETY: there is at least one row, so 1 is a valid split point and the one-row head
    // has a row 0.
    unsafe {
        let split = self.nrows().unchecked_idx_inc(1);
        let (first, rest) = self.split_at_row_unchecked(split);
        Some((first.row_unchecked(0), rest))
    }
}
/// Splits off the last row: returns `(last_row, preceding_rows)`, or `None` if the view has
/// no rows.
#[inline]
pub fn split_last_row(self) -> Option<(RowRef<'a, E, C>, MatRef<'a, E, usize, C>)> {
    let nrows = self.nrows().unbound();
    if nrows == 0 {
        return None;
    }
    // SAFETY: `nrows - 1` is a valid split point and the one-row tail has a row 0.
    unsafe {
        let split = self.nrows().unchecked_idx_inc(nrows - 1);
        let (rest, last) = self.split_at_row_unchecked(split);
        Some((last.row_unchecked(0), rest))
    }
}
/// Returns an iterator over the columns of the view.
///
/// The row shape type `R` is preserved and the column count becomes a plain `usize`.
#[inline]
pub fn col_iter(self) -> iter::ColIter<'a, E, R> {
    let (nrows, ncols) = self.shape();
    let inner = self.as_shape(nrows, ncols.unbound());
    iter::ColIter { inner }
}
/// Returns an iterator over the rows of the view, with shape types erased to `usize`.
#[inline]
pub fn row_iter(self) -> iter::RowIter<'a, E> {
    let inner = self.as_dyn();
    iter::RowIter { inner }
}
/// Returns an iterator over column blocks of (at most) `chunk_size` columns each.
///
/// # Panics
/// Panics if `chunk_size == 0`.
#[inline]
#[track_caller]
pub fn col_chunks(self, chunk_size: usize) -> iter::ColChunks<'a, E> {
    assert!(chunk_size > 0);
    let inner = self.as_dyn();
    let policy =
        iter::chunks::ChunkSizePolicy::new(inner.ncols(), iter::chunks::ChunkSize(chunk_size));
    iter::ColChunks { inner, policy }
}
/// Returns an iterator that partitions the columns into `count` blocks.
///
/// # Panics
/// Panics if `count == 0`.
#[inline]
#[track_caller]
pub fn col_partition(self, count: usize) -> iter::ColPartition<'a, E> {
    assert!(count > 0);
    let inner = self.as_dyn();
    let policy = iter::chunks::PartitionCountPolicy::new(
        inner.ncols(),
        iter::chunks::PartitionCount(count),
    );
    iter::ColPartition { inner, policy }
}
/// Returns an iterator over row blocks of (at most) `chunk_size` rows each.
///
/// # Panics
/// Panics if `chunk_size == 0`.
#[inline]
#[track_caller]
pub fn row_chunks(self, chunk_size: usize) -> iter::RowChunks<'a, E> {
    assert!(chunk_size > 0);
    let inner = self.as_dyn();
    let policy =
        iter::chunks::ChunkSizePolicy::new(inner.nrows(), iter::chunks::ChunkSize(chunk_size));
    iter::RowChunks { inner, policy }
}
/// Returns an iterator that partitions the rows into `count` blocks.
///
/// # Panics
/// Panics if `count == 0`.
#[inline]
#[track_caller]
pub fn row_partition(self, count: usize) -> iter::RowPartition<'a, E> {
    assert!(count > 0);
    let inner = self.as_dyn();
    let policy = iter::chunks::PartitionCountPolicy::new(
        inner.nrows(),
        iter::chunks::PartitionCount(count),
    );
    iter::RowPartition { inner, policy }
}
/// Returns a parallel iterator over column blocks of (at most) `chunk_size` columns.
///
/// There are `ceil(ncols / chunk_size)` blocks; the last one may be narrower.
///
/// # Panics
/// Panics if `chunk_size == 0`.
#[cfg(feature = "rayon")]
#[cfg_attr(docsrs, doc(cfg(feature = "rayon")))]
#[inline]
#[track_caller]
pub fn par_col_chunks(
    self,
    chunk_size: usize,
) -> impl 'a + rayon::iter::IndexedParallelIterator<Item = MatRef<'a, E, R, usize>> {
    use rayon::prelude::*;
    let this = self.as_dyn();
    assert!(chunk_size > 0);
    let ncols = this.ncols();
    let chunk_count = ncols.div_ceil(chunk_size);
    (0..chunk_count).into_par_iter().map(move |chunk_idx| {
        let start = chunk_size * chunk_idx;
        let width = Ord::min(chunk_size, ncols - start);
        let block = this.subcols(start, width);
        // Restore the caller's row shape type `R`; the row count is unchanged.
        block.submatrix(0, 0, unsafe { R::new_unbound(block.nrows()) }, block.ncols())
    })
}
/// Returns a parallel iterator that partitions the columns into `count` blocks, using
/// `utils::thread::par_split_indices` to compute each block's range.
///
/// # Panics
/// Panics if `count == 0`.
#[cfg(feature = "rayon")]
#[cfg_attr(docsrs, doc(cfg(feature = "rayon")))]
#[inline]
#[track_caller]
pub fn par_col_partition(
    self,
    count: usize,
) -> impl 'a + rayon::iter::IndexedParallelIterator<Item = MatRef<'a, E, R, usize>> {
    use rayon::prelude::*;
    let this = self.as_dyn();
    assert!(count > 0);
    let ncols = this.ncols();
    (0..count).into_par_iter().map(move |part| {
        let (start, len) = crate::utils::thread::par_split_indices(ncols, part, count);
        let block = this.subcols(start, len);
        // Restore the caller's row shape type `R`; the row count is unchanged.
        block.submatrix(0, 0, unsafe { R::new_unbound(block.nrows()) }, block.ncols())
    })
}
/// Returns a parallel iterator over row blocks of (at most) `chunk_size` rows.
///
/// This is implemented as column chunking of the transpose, with each block transposed
/// back.
///
/// # Panics
/// Panics if `chunk_size == 0`.
#[cfg(feature = "rayon")]
#[cfg_attr(docsrs, doc(cfg(feature = "rayon")))]
#[inline]
#[track_caller]
pub fn par_row_chunks(
    self,
    chunk_size: usize,
) -> impl 'a + rayon::iter::IndexedParallelIterator<Item = MatRef<'a, E, usize, C>> {
    use rayon::prelude::*;
    let transposed = self.transpose();
    transposed
        .par_col_chunks(chunk_size)
        .map(|block| block.transpose())
}
/// Returns a parallel iterator that partitions the rows into `count` blocks.
///
/// This is implemented as a column partition of the transpose, with each block transposed
/// back.
///
/// # Panics
/// Panics if `count == 0`.
#[cfg(feature = "rayon")]
#[cfg_attr(docsrs, doc(cfg(feature = "rayon")))]
#[inline]
#[track_caller]
pub fn par_row_partition(
    self,
    count: usize,
) -> impl 'a + rayon::iter::IndexedParallelIterator<Item = MatRef<'a, E, usize, C>> {
    use rayon::prelude::*;
    let transposed = self.transpose();
    transposed
        .par_col_partition(count)
        .map(|block| block.transpose())
}
/// Interprets a single-column view as a diagonal.
///
/// # Panics
/// Panics if the view does not have exactly one column.
#[track_caller]
#[inline(always)]
pub fn column_vector_as_diagonal(self) -> DiagRef<'a, E, R> {
    assert!(self.ncols().unbound() == 1);
    // SAFETY: the view has exactly one column, so index 0 is valid for `C`.
    let first_col = unsafe { Idx::<C>::new_unbound(0) };
    DiagRef {
        inner: self.col(first_col),
    }
}
}
impl<'a, E: Entity, N: Shape> MatRef<'a, E, N, N> {
    /// Returns the main diagonal of a square-shaped view.
    ///
    /// The diagonal has `min(nrows, ncols)` entries. Consecutive entries are
    /// `row_stride + col_stride` elements apart.
    #[inline(always)]
    pub fn diagonal(self) -> DiagRef<'a, E, N> {
        let len = self.nrows().min(self.ncols());
        let diag_stride = self.row_stride() + self.col_stride();
        // SAFETY: every diagonal entry `(k, k)` with `k < len` lies inside the view.
        let inner = unsafe { crate::col::from_raw_parts(self.as_ptr(), len, diag_stride) };
        DiagRef { inner }
    }
}
impl<'a, E: RealField, R: Shape, C: Shape> MatRef<'a, num_complex::Complex<E>, R, C> {
    /// Splits a complex view into views over its real and imaginary parts.
    ///
    /// The complex pointer group is a `Complex { re, im }` pair of pointers. Each half is
    /// wrapped in a view with the same shape and strides as `self`.
    #[inline(always)]
    pub fn real_imag(self) -> num_complex::Complex<MatRef<'a, E, R, C>> {
        let (nrows, ncols) = self.shape();
        let (row_stride, col_stride) = (self.row_stride(), self.col_stride());
        let num_complex::Complex {
            re: re_ptr,
            im: im_ptr,
        } = self.as_ptr();
        // SAFETY: each component pointer spans the same layout as the complex view.
        unsafe {
            num_complex::Complex {
                re: super::from_raw_parts(re_ptr, nrows, ncols, row_stride, col_stride),
                im: super::from_raw_parts(im_ptr, nrows, ncols, row_stride, col_stride),
            }
        }
    }
}
// Exposes the view through the generic `AsMatRef` interface, keeping its shape types.
impl<E: Entity, R: Shape, C: Shape> AsMatRef<E> for MatRef<'_, E, R, C> {
type R = R;
type C = C;
#[inline]
fn as_mat_ref(&self) -> MatRef<'_, E, R, C> {
*self
}
}
// Dynamically-shaped views can be passed wherever an `As2D` argument is accepted.
impl<E: Entity> As2D<E> for MatRef<'_, E> {
#[inline]
fn as_2d_ref(&self) -> MatRef<'_, E> {
*self
}
}
/// Creates a matrix view from a base pointer, a shape and a pair of strides.
///
/// # Safety
/// `ptr` must be non-null. Together with the dimensions and strides, it must describe
/// elements that are valid to read for the lifetime `'a`.
#[inline(always)]
pub unsafe fn from_raw_parts<'a, E: Entity, R: Shape, C: Shape>(
    ptr: PtrConst<E>,
    nrows: R,
    ncols: C,
    row_stride: isize,
    col_stride: isize,
) -> MatRef<'a, E, R, C> {
    MatRef::<'a, E, R, C>::__from_raw_parts(ptr, nrows, ncols, row_stride, col_stride)
}
/// Creates a column-major `nrows × ncols` view over `slice`.
///
/// Elements within a column are contiguous (row stride `1`), and consecutive columns are
/// `nrows` elements apart.
///
/// # Panics
/// Panics if `from_slice_assert` rejects the shape for the slice length. Presumably this
/// happens when `nrows * ncols` does not match the length; see that helper.
#[track_caller]
#[inline(always)]
pub fn from_column_major_slice_generic<E: Entity, R: Shape, C: Shape>(
    slice: Slice<'_, E>,
    nrows: R,
    ncols: C,
) -> MatRef<'_, E, R, C> {
    let (m, n) = (nrows.unbound(), ncols.unbound());
    let len = SliceGroup::<'_, E>::new(E::faer_copy(&slice)).len();
    from_slice_assert(m, n, len);
    let col_stride = m as isize;
    let base = map!(E, slice, |(s)| { s.as_ptr() },);
    // SAFETY: the shape was validated against the slice length above.
    unsafe { from_raw_parts(base, nrows, ncols, 1, col_stride) }
}
/// Creates a column-major `nrows × ncols` view over a plain slice of simple entities.
///
/// # Panics
/// Panics under the same conditions as [`from_column_major_slice_generic`].
#[track_caller]
#[inline(always)]
pub fn from_column_major_slice<E: SimpleEntity, R: Shape, C: Shape>(
    slice: &[E],
    nrows: R,
    ncols: C,
) -> MatRef<'_, E, R, C> {
    from_column_major_slice_generic::<E, R, C>(slice, nrows, ncols)
}
/// Creates a row-major `nrows × ncols` view over `slice`.
///
/// This is built as the transpose of a column-major `ncols × nrows` view.
///
/// # Panics
/// Panics under the same conditions as [`from_column_major_slice_generic`].
#[track_caller]
#[inline(always)]
pub fn from_row_major_slice_generic<E: Entity, R: Shape, C: Shape>(
    slice: Slice<'_, E>,
    nrows: R,
    ncols: C,
) -> MatRef<'_, E, R, C> {
    let column_major = from_column_major_slice_generic(slice, ncols, nrows);
    column_major.transpose()
}
#[track_caller]
#[inline(always)]
pub fn from_row_major_slice<E: SimpleEntity, R: Shape, C: Shape>(
slice: &[E],
nrows: R,
ncols: C,
) -> MatRef<'_, E, R, C> {
from_column_major_slice_generic(slice, ncols, nrows).transpose()
}
/// Creates a column-major `nrows × ncols` view over `slice`, with consecutive columns
/// `col_stride` elements apart.
///
/// # Panics
/// Panics if `from_strided_column_major_slice_assert` rejects the shape/stride for the
/// slice length. That helper is not visible here.
#[track_caller]
pub fn from_column_major_slice_with_stride_generic<E: Entity, R: Shape, C: Shape>(
    slice: Slice<'_, E>,
    nrows: R,
    ncols: C,
    col_stride: usize,
) -> MatRef<'_, E, R, C> {
    let (m, n) = (nrows.unbound(), ncols.unbound());
    let len = SliceGroup::<'_, E>::new(E::faer_copy(&slice)).len();
    from_strided_column_major_slice_assert(m, n, col_stride, len);
    let stride = col_stride.unbound() as isize;
    let base = map!(E, slice, |(s)| { s.as_ptr() },);
    // SAFETY: the shape and stride were validated against the slice length above.
    unsafe { from_raw_parts(base, nrows, ncols, 1, stride) }
}
/// Creates a row-major `nrows × ncols` view over `slice`, with consecutive rows
/// `row_stride` elements apart.
///
/// This is built as the transpose of a strided column-major `ncols × nrows` view.
#[track_caller]
pub fn from_row_major_slice_with_stride_generic<E: Entity, R: Shape, C: Shape>(
    slice: Slice<'_, E>,
    nrows: R,
    ncols: C,
    row_stride: usize,
) -> MatRef<'_, E, R, C> {
    let column_major =
        from_column_major_slice_with_stride_generic(slice, ncols, nrows, row_stride);
    column_major.transpose()
}
/// Creates a strided column-major view over a plain slice of simple entities.
///
/// # Panics
/// Panics under the same conditions as [`from_column_major_slice_with_stride_generic`].
#[track_caller]
pub fn from_column_major_slice_with_stride<E: SimpleEntity, R: Shape, C: Shape>(
    slice: &[E],
    nrows: R,
    ncols: C,
    col_stride: usize,
) -> MatRef<'_, E, R, C> {
    from_column_major_slice_with_stride_generic::<E, R, C>(slice, nrows, ncols, col_stride)
}
/// Creates a strided row-major view over a plain slice of simple entities.
///
/// # Panics
/// Panics under the same conditions as [`from_row_major_slice_with_stride_generic`].
#[track_caller]
pub fn from_row_major_slice_with_stride<E: SimpleEntity, R: Shape, C: Shape>(
    slice: &[E],
    nrows: R,
    ncols: C,
    row_stride: usize,
) -> MatRef<'_, E, R, C> {
    from_row_major_slice_with_stride_generic::<E, R, C>(slice, nrows, ncols, row_stride)
}
// Formats the matrix as a bracketed list of rows:
// "[\n", then each row as a debug list followed by ",\n", then "]".
impl<'a, T: Entity, R: Shape, C: Shape> core::fmt::Debug for MatRef<'a, T, R, C> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
// Bind the dimensions to branded `Dim` types (via `generativity::make_guard!`) so the
// row/column indices produced below are statically tied to this view's shape.
make_guard!(M);
make_guard!(N);
let M = self.nrows().bind(M);
let N = self.ncols().bind(N);
let this = self.as_shape(M, N);
fn imp<'M, 'N, T: Entity>(
this: MatRef<'_, T, Dim<'M>, Dim<'N>>,
f: &mut core::fmt::Formatter<'_>,
) -> core::fmt::Result {
// Helper that prints one row as a debug list of its elements.
struct DebugRow<'a, 'N, T: Entity>(RowRef<'a, T, Dim<'N>>);
impl<'N, T: Entity> core::fmt::Debug for DebugRow<'_, 'N, T> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.debug_list()
.entries(
self.0
.ncols()
.indices()
.map(|j| (T::faer_from_units(T::faer_deref(self.0.at(j))))),
)
.finish()
}
}
writeln!(f, "[")?;
for i in this.nrows().indices() {
let row = this.row(i);
DebugRow(row).fmt(f)?;
f.write_str(",\n")?;
}
write!(f, "]")
}
imp(this, f)
}
}
impl<E: SimpleEntity> core::ops::Index<(usize, usize)> for MatRef<'_, E> {
    type Output = E;

    /// Returns a reference to the element at `(row, col)`.
    ///
    /// Bounds handling is delegated to [`MatRef::get`], which is presumably a panicking
    /// check in `MatIndex` (not visible here).
    #[inline]
    #[track_caller]
    fn index(&self, index: (usize, usize)) -> &E {
        let (row, col) = index;
        self.get(row, col)
    }
}
// Lets dense `MatRef`s be compared with the `matrixcompare` assertion helpers.
#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
impl<E: Entity> matrixcompare_core::Matrix<E> for MatRef<'_, E> {
    #[inline]
    fn rows(&self) -> usize {
        let (rows, _) = self.shape();
        rows
    }

    #[inline]
    fn cols(&self) -> usize {
        let (_, cols) = self.shape();
        cols
    }

    #[inline]
    fn access(&self) -> matrixcompare_core::Access<'_, E> {
        matrixcompare_core::Access::Dense(self)
    }
}
// Element access used by `matrixcompare` for dense comparisons.
#[cfg(feature = "std")]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
impl<E: Entity> matrixcompare_core::DenseAccess<E> for MatRef<'_, E> {
    /// Reads the element at `(row, col)`; panics if out of bounds (via [`MatRef::read`]).
    #[inline]
    fn fetch_single(&self, row: usize, col: usize) -> E {
        MatRef::read(self, row, col)
    }
}
impl<E: Conjugate> ColBatch<E> for MatRef<'_, E> {
type Owned = Mat<E::Canonical>;
#[inline]
#[track_caller]
fn new_owned_zeros(nrows: usize, ncols: usize) -> Self::Owned {
Mat::zeros(nrows, ncols)
}
#[inline]
fn new_owned_copied(src: &Self) -> Self::Owned {
src.to_owned()
}
#[inline]
#[track_caller]
fn resize_owned(owned: &mut Self::Owned, nrows: usize, ncols: usize) {
<Self::Owned as ColBatch<E::Canonical>>::resize_owned(owned, nrows, ncols)
}
}
impl<E: Conjugate> RowBatch<E> for MatRef<'_, E> {
type Owned = Mat<E::Canonical>;
#[inline]
#[track_caller]
fn new_owned_zeros(nrows: usize, ncols: usize) -> Self::Owned {
Mat::zeros(nrows, ncols)
}
#[inline]
fn new_owned_copied(src: &Self) -> Self::Owned {
src.to_owned()
}
#[inline]
#[track_caller]
fn resize_owned(owned: &mut Self::Owned, nrows: usize, ncols: usize) {
<Self::Owned as RowBatch<E::Canonical>>::resize_owned(owned, nrows, ncols)
}
}
/// Creates an `nrows × ncols` view in which every position refers to the single `value`.
///
/// Both strides are zero, so all positions alias the same element.
#[doc(alias = "broadcast")]
pub fn from_repeated_ref_generic<E: Entity, R: Shape, C: Shape>(
    value: Ref<'_, E>,
    nrows: R,
    ncols: C,
) -> MatRef<'_, E, R, C> {
    let ptr = map!(E, value, |(unit)| { unit as *const E::Unit });
    // SAFETY: with zero strides every position maps to the valid, borrowed `value`.
    unsafe { from_raw_parts(ptr, nrows, ncols, 0, 0) }
}
/// Creates an `nrows × ncols` view that repeats the single simple-entity `value`.
#[doc(alias = "broadcast")]
pub fn from_repeated_ref<E: SimpleEntity, R: Shape, C: Shape>(
    value: &E,
    nrows: R,
    ncols: C,
) -> MatRef<'_, E, R, C> {
    from_repeated_ref_generic::<E, R, C>(value, nrows, ncols)
}
/// Creates a view that repeats the column `col` `ncols` times (column stride `0`).
#[doc(alias = "broadcast")]
pub fn from_repeated_col<E: Entity, C: Shape>(
    col: ColRef<'_, E>,
    ncols: C,
) -> MatRef<'_, E, usize, C> {
    let (ptr, nrows, row_stride) = (col.as_ptr(), col.nrows(), col.row_stride());
    // SAFETY: every column aliases the valid, borrowed `col`.
    unsafe { from_raw_parts(ptr, nrows, ncols, row_stride, 0) }
}
/// Creates a view that repeats the row `row` `nrows` times (row stride `0`).
#[doc(alias = "broadcast")]
pub fn from_repeated_row<E: Entity, R: Shape>(
    row: RowRef<'_, E>,
    nrows: R,
) -> MatRef<'_, E, R, usize> {
    let (ptr, ncols, col_stride) = (row.as_ptr(), row.ncols(), row.col_stride());
    // SAFETY: every row aliases the valid, borrowed `row`.
    unsafe { from_raw_parts(ptr, nrows, ncols, 0, col_stride) }
}
/// Creates a `1 × 1` view over the single simple-entity `value`.
pub fn from_ref<E: SimpleEntity>(value: &E) -> MatRef<'_, E> {
    from_ref_generic::<E>(value)
}
/// Creates a `1 × 1` view over the single element `value`.
pub fn from_ref_generic<E: Entity>(value: Ref<'_, E>) -> MatRef<'_, E> {
    from_repeated_ref_generic::<E, usize, usize>(value, 1, 1)
}