use super::*;
use crate::{
assert,
col::ColRef,
debug_assert,
iter::{self, chunks::ChunkPolicy},
Idx, IdxInc, Shape, Unbind,
};
/// Immutable, borrowed view over a single row of `E` values.
///
/// The row has `ncols` columns, whose count has shape type `C` (`usize` by default). The
/// columns are laid out with an arbitrary `isize` column stride (it may be negative or zero),
/// and the view borrows its data for the lifetime `'a`.
// NOTE(review): `#[repr(C)]` presumably keeps the layout stable / compatible with the mutable
// counterpart `RowMut`; confirm before touching field order.
#[repr(C)]
pub struct RowRef<'a, E: Entity, C: Shape = usize> {
    // Pointer to the first element (`ptr`), column count (`len`) and column stride (`stride`).
    pub(super) inner: VecImpl<E, C>,
    // Ties the view to the lifetime `'a` as if it held a `&'a E`.
    pub(super) __marker: PhantomData<&'a E>,
}
/// A `RowRef` is just a pointer/length/stride triple plus a lifetime marker, so copying it
/// never touches the viewed elements.
impl<E: Entity, C: Shape> Copy for RowRef<'_, E, C> {}

impl<E: Entity, C: Shape> Clone for RowRef<'_, E, C> {
    /// Returns a bitwise copy of the view (same data, same shape, same stride).
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}
impl<E: Entity> Default for RowRef<'_, E> {
    /// Returns an empty row (zero columns) backed by an empty slice for every unit.
    #[inline]
    fn default() -> Self {
        let empty_units = E::faer_map(E::UNIT, |()| &[] as &[E::Unit]);
        from_slice_generic::<E>(empty_units)
    }
}
impl<'short, E: Entity, C: Shape> Reborrow<'short> for RowRef<'_, E, C> {
type Target = RowRef<'short, E, C>;
#[inline]
fn rb(&'short self) -> Self::Target {
*self
}
}
impl<'short, E: Entity, C: Shape> ReborrowMut<'short> for RowRef<'_, E, C> {
type Target = RowRef<'short, E, C>;
#[inline]
fn rb_mut(&'short mut self) -> Self::Target {
*self
}
}
impl<E: Entity, C: Shape> IntoConst for RowRef<'_, E, C> {
type Target = Self;
#[inline]
fn into_const(self) -> Self::Target {
self
}
}
impl<'a, E: Entity, C: Shape> RowRef<'a, E, C> {
    /// Builds a row view from its raw components without any validation.
    ///
    /// # Safety
    ///
    /// Every pointer in `ptr` must be non-null (it is wrapped with `NonNull::new_unchecked`).
    /// `ptr`, `ncols` and `col_stride` must also describe elements that are valid for reads
    /// for the lifetime `'a`.
    pub(crate) unsafe fn __from_raw_parts(ptr: PtrConst<E>, ncols: C, col_stride: isize) -> Self {
        Self {
            inner: VecImpl {
                ptr: into_copy::<E, _>(E::faer_map(
                    ptr,
                    #[inline]
                    |ptr| NonNull::new_unchecked(ptr as *mut E::Unit),
                )),
                len: ncols,
                stride: col_stride,
            },
            __marker: PhantomData,
        }
    }

    /// Number of rows, which is always 1 for a row view.
    #[inline(always)]
    pub fn nrows(&self) -> usize {
        1
    }

    /// Number of columns.
    #[inline(always)]
    pub fn ncols(&self) -> C {
        self.inner.len
    }

    /// Pointer(s) to the first element (column 0).
    #[inline(always)]
    pub fn as_ptr(self) -> PtrConst<E> {
        E::faer_map(
            from_copy::<E, _>(self.inner.ptr),
            #[inline(always)]
            |ptr| ptr.as_ptr() as *const E::Unit,
        )
    }

    /// Offset, in units of `E::Unit`, between two consecutive columns.
    #[inline(always)]
    pub fn col_stride(&self) -> isize {
        self.inner.stride
    }

    /// Views the row as a `1 × ncols` matrix.
    #[inline(always)]
    pub fn as_2d(self) -> MatRef<'a, E, usize, C> {
        let ncols = self.ncols();
        let col_stride = self.col_stride();
        // The row stride is only ever paired with row index 0 in a one-row matrix, so its
        // value never contributes to an address; `isize::MAX` is kept as in the original.
        unsafe { crate::mat::from_raw_parts(self.as_ptr(), 1, ncols, isize::MAX, col_stride) }
    }

    /// Pointer(s) to column `col`, computed with wrapping arithmetic.
    ///
    /// No bounds check is performed. The result may point outside the row and must only be
    /// dereferenced when `col` is in bounds.
    #[inline(always)]
    pub fn ptr_at(self, col: usize) -> PtrConst<E> {
        let offset = (col as isize).wrapping_mul(self.inner.stride);
        E::faer_map(
            self.as_ptr(),
            #[inline(always)]
            |ptr| ptr.wrapping_offset(offset),
        )
    }

    /// Pointer(s) to column `col`, computed with non-wrapping pointer arithmetic.
    ///
    /// # Safety
    ///
    /// `col * col_stride` must not overflow. The resulting pointer must satisfy the
    /// requirements of `pointer::offset`, i.e. it must stay within the same allocation.
    #[inline(always)]
    #[doc(hidden)]
    pub unsafe fn ptr_at_unchecked(self, col: usize) -> PtrConst<E> {
        let offset = crate::utils::unchecked_mul(col, self.inner.stride);
        E::faer_map(
            self.as_ptr(),
            #[inline(always)]
            |ptr| ptr.offset(offset),
        )
    }

    /// Pointer(s) to column `col`, where `col` may equal `ncols`.
    ///
    /// When `col == ncols` the offset is masked to zero, so the base pointer is returned
    /// instead of a one-past-the-end pointer.
    ///
    /// # Safety
    ///
    /// `col` must be at most `ncols`.
    #[inline(always)]
    #[doc(hidden)]
    pub unsafe fn overflowing_ptr_at(self, col: IdxInc<C>) -> PtrConst<E> {
        unsafe {
            let cond = col != self.ncols();
            // All-ones mask when `col != ncols`, zero otherwise.
            let offset = (cond as usize).wrapping_neg() as isize
                & (col.unbound() as isize).wrapping_mul(self.inner.stride);
            E::faer_map(
                self.as_ptr(),
                #[inline(always)]
                |ptr| ptr.offset(offset),
            )
        }
    }

    /// Pointer(s) to the in-bounds column `col`.
    ///
    /// # Safety
    ///
    /// `col` must be less than `ncols`. This is only checked with a debug assertion.
    #[inline(always)]
    #[track_caller]
    pub unsafe fn ptr_inbounds_at(self, col: Idx<C>) -> PtrConst<E> {
        debug_assert!(col < self.ncols());
        self.ptr_at_unchecked(col.unbound())
    }

    /// Same view with the column count erased to a plain `usize`.
    #[inline]
    pub fn as_dyn(self) -> RowRef<'a, E> {
        let ncols = self.ncols().unbound();
        let col_stride = self.col_stride();
        unsafe { from_raw_parts(self.as_ptr(), ncols, col_stride) }
    }

    /// Same view with the column count re-expressed in shape type `H`.
    ///
    /// # Panics
    ///
    /// Panics if `ncols.unbound()` differs from the current number of columns.
    #[inline]
    #[track_caller]
    pub fn as_shape<H: Shape>(self, ncols: H) -> RowRef<'a, E, H> {
        assert!(ncols.unbound() == self.ncols().unbound());
        unsafe { from_raw_parts(self.as_ptr(), ncols, self.col_stride()) }
    }

    /// Reinterprets the view as a mutable `RowMut` over the same memory.
    ///
    /// # Safety
    ///
    /// The caller is responsible for ensuring that any write through the result is sound.
    /// In particular, the data must actually be writable and must not be observed through
    /// other live references while it is mutated.
    #[doc(hidden)]
    #[inline(always)]
    pub unsafe fn const_cast(self) -> RowMut<'a, E, C> {
        RowMut {
            inner: self.inner,
            __marker: PhantomData,
        }
    }

    /// Splits into columns `[0, col)` and `[col, ncols)` without a release-mode check.
    ///
    /// # Safety
    ///
    /// `col` must be at most `ncols`. This is only checked with a debug assertion.
    #[inline(always)]
    #[track_caller]
    pub unsafe fn split_at_unchecked(
        self,
        col: IdxInc<C>,
    ) -> (RowRef<'a, E, usize>, RowRef<'a, E, usize>) {
        debug_assert!(col <= self.ncols());
        let col_stride = self.col_stride();
        let ncols = self.ncols().unbound();
        unsafe {
            let top = self.as_ptr();
            let bot = self.overflowing_ptr_at(col);
            let col = col.unbound();
            (
                RowRef::__from_raw_parts(top, col, col_stride),
                RowRef::__from_raw_parts(bot, ncols - col, col_stride),
            )
        }
    }

    /// Splits into columns `[0, col)` and `[col, ncols)`.
    ///
    /// # Panics
    ///
    /// Panics if `col > ncols`.
    #[inline(always)]
    #[track_caller]
    pub fn split_at(self, col: IdxInc<C>) -> (RowRef<'a, E, usize>, RowRef<'a, E, usize>) {
        assert!(col <= self.ncols());
        unsafe { self.split_at_unchecked(col) }
    }

    /// Indexes the row with `col` (a single index or a range) without bounds checks.
    ///
    /// # Safety
    ///
    /// `col` must satisfy the requirements of the corresponding
    /// `RowIndex::get_unchecked` implementation.
    #[inline(always)]
    #[track_caller]
    pub unsafe fn get_unchecked<ColRange>(
        self,
        col: ColRange,
    ) -> <Self as RowIndex<ColRange>>::Target
    where
        Self: RowIndex<ColRange>,
    {
        <Self as RowIndex<ColRange>>::get_unchecked(self, col)
    }

    /// Indexes the row with `col` (a single index or a range), delegating to `RowIndex::get`.
    #[inline(always)]
    #[track_caller]
    pub fn get<ColRange>(self, col: ColRange) -> <Self as RowIndex<ColRange>>::Target
    where
        Self: RowIndex<ColRange>,
    {
        <Self as RowIndex<ColRange>>::get(self, col)
    }

    /// Reference to the element at column `col`, without a bounds check.
    ///
    /// # Safety
    ///
    /// `col` must be in bounds, as required by the column view's `at_unchecked`.
    #[inline(always)]
    #[track_caller]
    pub unsafe fn at_unchecked(self, col: Idx<C>) -> Ref<'a, E> {
        self.transpose().at_unchecked(col)
    }

    /// Reference to the element at column `col`; bounds are checked by the column view's `at`.
    #[inline(always)]
    #[track_caller]
    pub fn at(self, col: Idx<C>) -> Ref<'a, E> {
        self.transpose().at(col)
    }

    /// Reads the element at column `col` by value, without a bounds check.
    ///
    /// # Safety
    ///
    /// `col` must be in bounds.
    #[inline(always)]
    #[track_caller]
    pub unsafe fn read_unchecked(&self, col: Idx<C>) -> E {
        E::faer_from_units(E::faer_map(
            self.at_unchecked(col),
            #[inline(always)]
            |ptr| *ptr,
        ))
    }

    /// Reads the element at column `col` by value; bounds are checked by [`RowRef::at`].
    #[inline(always)]
    #[track_caller]
    pub fn read(&self, col: Idx<C>) -> E {
        E::faer_from_units(E::faer_map(
            self.at(col),
            #[inline(always)]
            |ptr| *ptr,
        ))
    }

    /// Views the same data as a column, keeping the length and stride.
    #[inline(always)]
    #[must_use]
    pub fn transpose(self) -> ColRef<'a, E, C> {
        unsafe { ColRef::__from_raw_parts(self.as_ptr(), self.ncols(), self.col_stride()) }
    }

    /// Reinterprets the elements as `E::Conj` without copying any data.
    #[inline(always)]
    #[must_use]
    pub fn conjugate(self) -> RowRef<'a, E::Conj, C>
    where
        E: Conjugate,
    {
        unsafe {
            super::from_raw_parts(
                transmute_unchecked::<
                    GroupFor<E, *const UnitFor<E>>,
                    GroupFor<E::Conj, *const UnitFor<E::Conj>>,
                >(self.as_ptr()),
                self.ncols(),
                self.col_stride(),
            )
        }
    }

    /// Conjugate-transpose: the conjugated view seen as a column.
    #[inline(always)]
    #[must_use]
    pub fn adjoint(self) -> ColRef<'a, E::Conj, C>
    where
        E: Conjugate,
    {
        self.conjugate().transpose()
    }

    /// Reinterprets the elements as `E::Canonical` and reports whether that implies conjugation.
    ///
    /// The flag is `Conj::No` when `E` and `E::Canonical` are the same type, and `Conj::Yes`
    /// otherwise.
    #[inline(always)]
    pub fn canonicalize(self) -> (RowRef<'a, E::Canonical, C>, Conj)
    where
        E: Conjugate,
    {
        (
            unsafe {
                super::from_raw_parts(
                    transmute_unchecked::<
                        PtrConst<E>,
                        GroupFor<E::Canonical, *const UnitFor<E::Canonical>>,
                    >(self.as_ptr()),
                    self.ncols(),
                    self.col_stride(),
                )
            },
            if coe::is_same::<E, E::Canonical>() {
                Conj::No
            } else {
                Conj::Yes
            },
        )
    }

    /// Same columns in reverse order: starts at the last column and negates the stride.
    #[inline(always)]
    #[must_use]
    pub fn reverse_cols(self) -> Self {
        let ncols = self.ncols();
        let col_stride = self.col_stride().wrapping_neg();
        // `saturating_sub` keeps the offset at 0 for an empty row.
        let ptr = unsafe { self.ptr_at_unchecked(ncols.unbound().saturating_sub(1)) };
        unsafe { Self::__from_raw_parts(ptr, ncols, col_stride) }
    }

    /// `ncols` consecutive columns starting at `col_start`, without release-mode checks.
    ///
    /// # Safety
    ///
    /// `col_start <= self.ncols()` and `col_start + ncols <= self.ncols()`. Both are only
    /// checked with debug assertions.
    #[track_caller]
    #[inline(always)]
    pub unsafe fn subcols_unchecked<H: Shape>(
        self,
        col_start: IdxInc<C>,
        ncols: H,
    ) -> RowRef<'a, E, H> {
        debug_assert!(col_start <= self.ncols());
        {
            let ncols = ncols.unbound();
            let col_start = col_start.unbound();
            debug_assert!(ncols <= self.ncols().unbound() - col_start);
        }
        let col_stride = self.col_stride();
        unsafe { RowRef::__from_raw_parts(self.overflowing_ptr_at(col_start), ncols, col_stride) }
    }

    /// `ncols` consecutive columns starting at `col_start`.
    ///
    /// # Panics
    ///
    /// Panics if `col_start > self.ncols()` or if the requested range runs past the end.
    #[track_caller]
    #[inline(always)]
    pub fn subcols<H: Shape>(self, col_start: IdxInc<C>, ncols: H) -> RowRef<'a, E, H> {
        assert!(col_start <= self.ncols());
        {
            let ncols = ncols.unbound();
            let col_start = col_start.unbound();
            assert!(ncols <= self.ncols().unbound() - col_start);
        }
        unsafe { self.subcols_unchecked(col_start, ncols) }
    }

    /// Allocates an owned row holding the canonicalized value of every element.
    #[inline]
    pub fn to_owned(&self) -> Row<E::Canonical, C>
    where
        E: Conjugate,
    {
        Row::from_fn(
            self.ncols(),
            // SAFETY: `from_fn` only passes indices below `self.ncols()`.
            #[inline(always)]
            |i| unsafe { self.read_unchecked(i) }.canonicalize(),
        )
    }

    /// Whether any element is NaN (delegates to the `1 × ncols` matrix view).
    #[inline]
    pub fn has_nan(&self) -> bool
    where
        E: ComplexField,
    {
        self.as_2d().has_nan()
    }

    /// Whether every element is finite (delegates to the `1 × ncols` matrix view).
    #[inline]
    pub fn is_all_finite(&self) -> bool
    where
        E: ComplexField,
    {
        self.as_2d().is_all_finite()
    }

    /// Max norm of the row (delegates to the matrix view).
    #[inline]
    pub fn norm_max(&self) -> E::Real
    where
        E: ComplexField,
    {
        self.as_2d().norm_max()
    }

    /// L1 norm of the row (delegates to the matrix view).
    #[inline]
    pub fn norm_l1(&self) -> E::Real
    where
        E: ComplexField,
    {
        self.as_2d().norm_l1()
    }

    /// L2 norm of the row (delegates to the matrix view).
    #[inline]
    pub fn norm_l2(&self) -> E::Real
    where
        E: ComplexField,
    {
        self.as_2d().norm_l2()
    }

    /// Squared L2 norm of the row (delegates to the matrix view).
    #[inline]
    pub fn squared_norm_l2(&self) -> E::Real
    where
        E: ComplexField,
    {
        self.as_2d().squared_norm_l2()
    }

    /// Sum of all elements (delegates to the matrix view).
    #[inline]
    pub fn sum(&self) -> E
    where
        E: ComplexField,
    {
        self.as_2d().sum()
    }

    /// Kronecker product of this row (as a `1 × ncols` matrix) with `rhs`.
    #[inline]
    #[track_caller]
    pub fn kron(&self, rhs: impl As2D<E>) -> Mat<E>
    where
        E: ComplexField,
    {
        self.as_2d().kron(rhs)
    }

    /// Borrows the elements as a slice group when the row is contiguous (column stride 1).
    ///
    /// Returns `None` for any other stride.
    #[inline]
    pub fn try_as_slice(self) -> Option<Slice<'a, E>> {
        if self.col_stride() == 1 {
            let len = self.ncols().unbound();
            // SAFETY: with a stride of 1, the `len` elements are contiguous starting at
            // `as_ptr()` and are borrowed for `'a`.
            Some(E::faer_map(
                self.as_ptr(),
                #[inline(always)]
                |ptr| unsafe { core::slice::from_raw_parts(ptr, len) },
            ))
        } else {
            None
        }
    }

    /// Copy of the view whose lifetime is tied to the borrow of `self`.
    #[inline]
    pub fn as_ref(&self) -> RowRef<'_, E, C> {
        *self
    }

    /// Splits off the first element, or returns `None` for an empty row.
    #[inline]
    pub fn split_first(self) -> Option<(Ref<'a, E>, RowRef<'a, E>)> {
        let this = self.as_dyn();
        if this.ncols() == 0 {
            return None;
        }
        // SAFETY: the row has at least one column, so splitting at 1 is in bounds and the
        // one-column head has a valid column 0.
        unsafe {
            let (head, tail) = this.split_at_unchecked(1);
            Some((head.get_unchecked(0), tail))
        }
    }

    /// Splits off the last element, or returns `None` for an empty row.
    #[inline]
    pub fn split_last(self) -> Option<(Ref<'a, E>, RowRef<'a, E>)> {
        let this = self.as_dyn();
        if this.ncols() == 0 {
            return None;
        }
        // SAFETY: the row has at least one column, so `ncols - 1` is a valid split point and
        // the one-column tail has a valid column 0.
        unsafe {
            let (head, tail) = this.split_at_unchecked(this.ncols() - 1);
            Some((tail.get_unchecked(0), head))
        }
    }

    /// Iterator over the row's elements.
    #[inline]
    pub fn iter(self) -> iter::ElemIter<'a, E> {
        iter::ElemIter {
            inner: self.transpose().as_dyn(),
        }
    }

    /// Iterator over sub-rows of (at most) `chunk_size` columns each.
    ///
    /// # Panics
    ///
    /// Panics if `chunk_size == 0`.
    #[inline]
    #[track_caller]
    pub fn chunks(self, chunk_size: usize) -> iter::RowElemChunks<'a, E> {
        assert!(chunk_size > 0);
        iter::RowElemChunks {
            inner: self.as_dyn(),
            policy: iter::chunks::ChunkSizePolicy::new(
                self.ncols().unbound(),
                iter::chunks::ChunkSize(chunk_size),
            ),
        }
    }

    /// Iterator that partitions the row into `count` sub-rows.
    ///
    /// # Panics
    ///
    /// Panics if `count == 0`.
    #[inline]
    #[track_caller]
    pub fn partition(self, count: usize) -> iter::RowElemPartition<'a, E> {
        assert!(count > 0);
        iter::RowElemPartition {
            inner: self.as_dyn(),
            policy: iter::chunks::PartitionCountPolicy::new(
                self.ncols().unbound(),
                iter::chunks::PartitionCount(count),
            ),
        }
    }

    /// Parallel counterpart of [`RowRef::chunks`], built on the column view's `par_chunks`.
    #[cfg(feature = "rayon")]
    #[cfg_attr(docsrs, doc(cfg(feature = "rayon")))]
    #[inline]
    #[track_caller]
    pub fn par_chunks(
        self,
        chunk_size: usize,
    ) -> impl 'a + rayon::iter::IndexedParallelIterator<Item = RowRef<'a, E>> {
        use rayon::prelude::*;
        self.transpose()
            .par_chunks(chunk_size)
            .map(|x| x.transpose())
    }

    /// Parallel counterpart of [`RowRef::partition`], built on the column view's `par_partition`.
    #[cfg(feature = "rayon")]
    #[cfg_attr(docsrs, doc(cfg(feature = "rayon")))]
    #[inline]
    #[track_caller]
    pub fn par_partition(
        self,
        count: usize,
    ) -> impl 'a + rayon::iter::IndexedParallelIterator<Item = RowRef<'a, E>> {
        use rayon::prelude::*;
        self.transpose().par_partition(count).map(|x| x.transpose())
    }
}
/// Creates a [`RowRef`] from its raw parts: the pointer(s) to column 0, the column count and
/// the offset (in units of `E::Unit`) between consecutive columns.
///
/// # Safety
///
/// The pointers must be non-null. Every element addressed by `ptr`, `ncols` and
/// `col_stride` must be valid for reads for the lifetime `'a`.
// NOTE(review): the aliasing requirements are presumably the same as for a shared reference;
// confirm against the crate-wide `from_raw_parts` contract.
#[inline(always)]
pub unsafe fn from_raw_parts<'a, E: Entity, C: Shape>(
    ptr: PtrConst<E>,
    ncols: C,
    col_stride: isize,
) -> RowRef<'a, E, C> {
    unsafe { RowRef::<'a, E, C>::__from_raw_parts(ptr, ncols, col_stride) }
}
/// Creates a contiguous (column stride 1) [`RowRef`] viewing every element of `slice`.
#[inline(always)]
pub fn from_slice_generic<E: Entity>(slice: Slice<'_, E>) -> RowRef<'_, E> {
    // The row has as many columns as the slice group reports elements.
    let ncols = SliceGroup::<'_, E>::new(E::faer_copy(&slice)).len();
    let first = E::faer_map(
        slice,
        #[inline(always)]
        |units| units.as_ptr(),
    );
    // SAFETY: `first` points at `ncols` consecutive elements borrowed from `slice`, so a
    // stride-1 view of that length stays in bounds for the slice's lifetime.
    unsafe { from_raw_parts(first, ncols, 1) }
}
/// Creates a contiguous [`RowRef`] viewing the elements of `slice`.
///
/// Convenience wrapper around [`from_slice_generic`] for simple (single-unit) entities.
#[inline(always)]
pub fn from_slice<E: SimpleEntity>(slice: &[E]) -> RowRef<'_, E> {
    from_slice_generic::<E>(slice)
}
impl<E: Entity, C: Shape> As2D<E> for RowRef<'_, E, C> {
    /// Views the row as a `1 × ncols` matrix with dynamic (`usize`) dimensions.
    #[inline]
    fn as_2d_ref(&self) -> MatRef<'_, E> {
        let row = *self;
        let shaped = row.as_2d();
        shaped.as_dyn()
    }
}
impl<E: Entity, C: Shape> AsRowRef<E> for RowRef<'_, E, C> {
type C = C;
#[inline]
fn as_row_ref(&self) -> RowRef<'_, E, C> {
*self
}
}
impl<'a, E: Entity, C: Shape> core::fmt::Debug for RowRef<'a, E, C> {
    /// Formats the row through the `Debug` impl of its transposed (column) view.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let as_column = self.transpose();
        core::fmt::Debug::fmt(&as_column, f)
    }
}
impl<E: SimpleEntity, C: Shape> core::ops::Index<Idx<C>> for RowRef<'_, E, C> {
    type Output = E;

    /// Reference to the element at column `col`.
    ///
    /// Bounds checking is done by [`RowRef::at`] (presumably by panicking when `col` is out of
    /// range); `#[track_caller]` reports such a panic at the indexing site.
    #[inline]
    #[track_caller]
    fn index(&self, col: Idx<C>) -> &E {
        let row: Self = *self;
        row.at(col)
    }
}
impl<E: Conjugate> RowBatch<E> for RowRef<'_, E> {
    type Owned = Row<E::Canonical>;

    /// Allocates a zero-filled owned row with `ncols` columns.
    ///
    /// # Panics
    ///
    /// Panics unless `nrows == 1`, since a row batch has exactly one row.
    #[inline]
    #[track_caller]
    fn new_owned_zeros(nrows: usize, ncols: usize) -> Self::Owned {
        assert!(nrows == 1);
        Row::<E::Canonical>::zeros(ncols)
    }

    /// Copies `src` (canonicalizing each element) into a newly allocated owned row.
    #[inline]
    fn new_owned_copied(src: &Self) -> Self::Owned {
        Self::to_owned(src)
    }

    /// Resizes `owned` by forwarding to the owned row type's own `RowBatch` implementation.
    #[inline]
    fn resize_owned(owned: &mut Self::Owned, nrows: usize, ncols: usize) {
        <Self::Owned as RowBatch<E::Canonical>>::resize_owned(owned, nrows, ncols)
    }
}
/// Creates a [`RowRef`] of `ncols` columns that all alias the single element `value`.
///
/// The column stride is 0, so every column reads the same memory.
#[doc(alias = "broadcast")]
pub fn from_repeated_ref<E: SimpleEntity>(value: &E, ncols: usize) -> RowRef<'_, E> {
    let base = E::faer_map(value, |elem| elem as *const E::Unit);
    // SAFETY: with a zero stride every column maps to `value`, which is valid for reads for
    // the lifetime of the borrow.
    unsafe { from_raw_parts(base, ncols, 0) }
}
pub fn from_ref<E: SimpleEntity>(value: &E) -> RowRef<'_, E> {
from_ref_generic(value)
}
/// Creates a [`RowRef`] of `ncols` columns that all alias the single element behind `value`.
///
/// The column stride is 0, so every column reads the same memory.
#[doc(alias = "broadcast")]
pub fn from_repeated_ref_generic<E: Entity>(value: Ref<'_, E>, ncols: usize) -> RowRef<'_, E> {
    let base = E::faer_map(value, |unit| unit as *const E::Unit);
    // SAFETY: with a zero stride every column maps to the referenced element, which is valid
    // for reads for the lifetime of the borrow.
    unsafe { from_raw_parts(base, ncols, 0) }
}
/// Creates a one-column [`RowRef`] viewing the element referenced by `value`.
pub fn from_ref_generic<E: Entity>(value: Ref<'_, E>) -> RowRef<'_, E> {
    from_repeated_ref_generic::<E>(value, 1)
}