use super::*;
use crate::{
assert, debug_assert,
diag::DiagRef,
iter::{self, chunks::ChunkPolicy},
row::RowRef,
Idx, IdxInc, Unbind,
};
/// Immutable, borrowed view over a column vector with elements of type `E` and `R` rows.
///
/// Consecutive rows are `row_stride` units apart in memory; the stride may be negative
/// (see `reverse_rows`) or zero (see `from_repeated_ref_generic`).
#[repr(C)]
pub struct ColRef<'a, E: Entity, R: Shape = usize> {
    /// Base pointer, row count (`len`) and row stride of the view.
    // NOTE(review): `#[repr(C)]` fixes this field order; presumably relied upon by
    // layout-sensitive code elsewhere — confirm before reordering.
    pub(super) inner: VecImpl<E, R>,
    /// Ties the view to the lifetime `'a` of the borrowed elements.
    pub(super) __marker: PhantomData<&'a E>,
}
// A `ColRef` only holds a pointer, a length and a stride, so it is trivially copyable.
impl<E: Entity, R: Shape> Copy for ColRef<'_, E, R> {}

impl<E: Entity, R: Shape> Clone for ColRef<'_, E, R> {
    /// Returns a bitwise copy of the view (the canonical `Clone` for a `Copy` type).
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}
impl<E: Entity> Default for ColRef<'_, E> {
    /// Returns an empty column (zero rows) backed by an empty slice.
    #[inline]
    fn default() -> Self {
        // Map the entity's unit group to empty slices, then view them as a column.
        let empty = E::faer_map(E::UNIT, |()| &[] as &[E::Unit]);
        from_slice_generic::<E>(empty)
    }
}
impl<'short, E: Entity, R: Shape> Reborrow<'short> for ColRef<'_, E, R> {
type Target = ColRef<'short, E, R>;
#[inline]
fn rb(&'short self) -> Self::Target {
*self
}
}
impl<'short, E: Entity, R: Shape> ReborrowMut<'short> for ColRef<'_, E, R> {
type Target = ColRef<'short, E, R>;
#[inline]
fn rb_mut(&'short mut self) -> Self::Target {
*self
}
}
impl<E: Entity, R: Shape> IntoConst for ColRef<'_, E, R> {
type Target = Self;
#[inline]
fn into_const(self) -> Self::Target {
self
}
}
impl<'a, E: Entity, R: Shape> ColRef<'a, E, R> {
    /// Builds a column view directly from its raw parts.
    ///
    /// # Safety
    ///
    /// Every pointer in `ptr` must be non-null, since it is wrapped with
    /// `NonNull::new_unchecked`. For each `i < nrows`, the element at offset
    /// `i * row_stride` from `ptr` must be valid for reads for the lifetime `'a`.
    #[inline]
    pub(crate) unsafe fn __from_raw_parts(ptr: PtrConst<E>, nrows: R, row_stride: isize) -> Self {
        Self {
            inner: VecImpl {
                ptr: into_copy::<E, _>(E::faer_map(
                    ptr,
                    #[inline]
                    |ptr| NonNull::new_unchecked(ptr as *mut E::Unit),
                )),
                len: nrows,
                stride: row_stride,
            },
            __marker: PhantomData,
        }
    }

    /// Returns the number of rows of the column.
    #[inline(always)]
    pub fn nrows(&self) -> R {
        self.inner.len
    }

    /// Returns the number of columns, which is always `1`.
    #[inline(always)]
    pub fn ncols(&self) -> usize {
        1
    }

    /// Returns a pointer to the first element of the column.
    #[inline(always)]
    pub fn as_ptr(self) -> PtrConst<E> {
        E::faer_map(
            from_copy::<E, _>(self.inner.ptr),
            #[inline(always)]
            |ptr| ptr.as_ptr() as *const E::Unit,
        )
    }

    /// Returns the distance, in elements, between two consecutive rows.
    #[inline(always)]
    pub fn row_stride(&self) -> isize {
        self.inner.stride
    }

    /// Views the column as a matrix with the same rows and a single column.
    ///
    /// The column stride of the resulting matrix is the placeholder `isize::MAX`, since
    /// there is only one column to address.
    #[inline(always)]
    pub fn as_2d(self) -> MatRef<'a, E, R, usize> {
        let nrows = self.nrows();
        let row_stride = self.row_stride();
        // SAFETY: the matrix covers exactly the elements of this column view.
        unsafe { crate::mat::from_raw_parts(self.as_ptr(), nrows, 1, row_stride, isize::MAX) }
    }

    /// Returns a pointer to the element at `row`.
    ///
    /// The offset is computed with wrapping arithmetic, so this is always safe to call,
    /// but the result is only dereferenceable when `row` is in bounds.
    #[inline(always)]
    pub fn ptr_at(self, row: usize) -> PtrConst<E> {
        let offset = (row as isize).wrapping_mul(self.inner.stride);
        E::faer_map(
            self.as_ptr(),
            #[inline(always)]
            |ptr| ptr.wrapping_offset(offset),
        )
    }

    /// Returns a pointer to the element at `row`, using non-wrapping pointer arithmetic.
    ///
    /// # Safety
    ///
    /// `row * row_stride` must not overflow `isize`. The resulting pointer must stay
    /// within the allocation backing the column (or one past its end), as required by
    /// `<*const T>::offset`.
    #[inline(always)]
    #[doc(hidden)]
    pub unsafe fn ptr_at_unchecked(self, row: usize) -> PtrConst<E> {
        let offset = crate::utils::unchecked_mul(row, self.inner.stride);
        E::faer_map(
            self.as_ptr(),
            #[inline(always)]
            |ptr| ptr.offset(offset),
        )
    }

    /// Returns a pointer to the element at `row`. When `row == nrows`, it returns the
    /// base pointer instead of a pointer one stride past the end.
    ///
    /// # Safety
    ///
    /// `row` must be at most `self.nrows()`. For `row < nrows`, the offset
    /// `row * row_stride` must satisfy the requirements of `<*const T>::offset`.
    #[inline(always)]
    #[doc(hidden)]
    pub unsafe fn overflowing_ptr_at(self, row: IdxInc<R>) -> PtrConst<E> {
        unsafe {
            // `cond` is false only when `row == nrows`. Negating it yields an all-ones
            // mask (keep the offset) or zero (use the base pointer), presumably to avoid
            // a branch.
            let cond = row != self.nrows();
            let offset = (cond as usize).wrapping_neg() as isize
                & (row.unbound() as isize).wrapping_mul(self.inner.stride);
            E::faer_map(
                self.as_ptr(),
                #[inline(always)]
                |ptr| ptr.offset(offset),
            )
        }
    }

    /// Returns a pointer to the element at `row`.
    ///
    /// # Safety
    ///
    /// `row` must be less than `self.nrows()`. This is only checked in debug builds.
    #[inline(always)]
    #[track_caller]
    pub unsafe fn ptr_inbounds_at(self, row: Idx<R>) -> PtrConst<E> {
        debug_assert!(row < self.nrows());
        self.ptr_at_unchecked(row.unbound())
    }

    /// Returns the same view with its row dimension converted to a plain `usize`.
    #[inline]
    pub fn as_dyn(self) -> ColRef<'a, E> {
        let nrows = self.nrows().unbound();
        let row_stride = self.row_stride();
        // SAFETY: same pointer, length and stride as `self`.
        unsafe { from_raw_parts(self.as_ptr(), nrows, row_stride) }
    }

    /// Returns the same view with its row dimension reinterpreted as the shape `nrows`.
    ///
    /// # Panics
    ///
    /// Panics if `nrows` does not describe the same number of rows as `self.nrows()`.
    #[inline]
    #[track_caller]
    pub fn as_shape<V: Shape>(self, nrows: V) -> ColRef<'a, E, V> {
        assert!(nrows.unbound() == self.nrows().unbound());
        // SAFETY: the row count was just checked to be unchanged.
        unsafe { from_raw_parts(self.as_ptr(), nrows, self.row_stride()) }
    }

    /// Reinterprets this immutable view as a mutable one over the same elements.
    ///
    /// # Safety
    ///
    /// The underlying memory must be writable. Writing through the returned view must
    /// not violate Rust's aliasing rules: no other view may access the same elements
    /// while the mutable view is in use.
    #[doc(hidden)]
    #[inline(always)]
    pub unsafe fn const_cast(self) -> ColMut<'a, E, R> {
        ColMut {
            inner: self.inner,
            __marker: PhantomData,
        }
    }

    /// Returns the column as a contiguous slice.
    ///
    /// # Panics
    ///
    /// Panics if the row stride is not `1`.
    #[track_caller]
    #[inline(always)]
    #[doc(hidden)]
    pub fn try_get_contiguous_col(self) -> Slice<'a, E> {
        assert!(self.row_stride() == 1);
        let m = self.nrows().unbound();
        E::faer_map(
            self.as_ptr(),
            // SAFETY: a unit stride means the `m` elements are stored contiguously.
            #[inline(always)]
            |ptr| unsafe { core::slice::from_raw_parts(ptr, m) },
        )
    }

    /// Splits the column into its first `row` rows and the remaining `nrows - row` rows.
    ///
    /// # Safety
    ///
    /// `row` must be at most `self.nrows()`. This is only checked in debug builds.
    #[inline(always)]
    #[track_caller]
    pub unsafe fn split_at_unchecked(
        self,
        row: IdxInc<R>,
    ) -> (ColRef<'a, E, usize>, ColRef<'a, E, usize>) {
        debug_assert!(row <= self.nrows());
        let row_stride = self.row_stride();
        unsafe {
            let top = self.as_ptr();
            // For `row == nrows` this is the base pointer, not one-past-the-end.
            let bot = self.overflowing_ptr_at(row);
            let row = row.unbound();
            let nrows = self.nrows().unbound();
            (
                ColRef::__from_raw_parts(top, row, row_stride),
                ColRef::__from_raw_parts(bot, nrows - row, row_stride),
            )
        }
    }

    /// Splits the column into its first `row` rows and the remaining `nrows - row` rows.
    ///
    /// # Panics
    ///
    /// Panics if `row > self.nrows()`.
    #[inline(always)]
    #[track_caller]
    pub fn split_at(self, row: IdxInc<R>) -> (ColRef<'a, E, usize>, ColRef<'a, E, usize>) {
        assert!(row <= self.nrows());
        // SAFETY: the bound was checked above.
        unsafe { self.split_at_unchecked(row) }
    }

    /// Returns the element or subcolumn selected by `row`, without bounds checks.
    ///
    /// # Safety
    ///
    /// `row` must satisfy the requirements of the matching `ColIndex::get_unchecked`
    /// implementation (presumably: every selected row is in bounds).
    #[inline(always)]
    #[track_caller]
    pub unsafe fn get_unchecked<RowRange>(
        self,
        row: RowRange,
    ) -> <Self as ColIndex<RowRange>>::Target
    where
        Self: ColIndex<RowRange>,
    {
        <Self as ColIndex<RowRange>>::get_unchecked(self, row)
    }

    /// Returns the element or subcolumn selected by `row`. The selection is defined by
    /// the matching `ColIndex::get` implementation.
    #[inline(always)]
    #[track_caller]
    pub fn get<RowRange>(self, row: RowRange) -> <Self as ColIndex<RowRange>>::Target
    where
        Self: ColIndex<RowRange>,
    {
        <Self as ColIndex<RowRange>>::get(self, row)
    }

    /// Returns a reference to the element at `row`, without bounds checks.
    ///
    /// # Safety
    ///
    /// `row` must be less than `self.nrows()`. This is only checked in debug builds.
    #[inline(always)]
    #[track_caller]
    pub unsafe fn at_unchecked(self, row: Idx<R>) -> Ref<'a, E> {
        E::faer_map(
            self.ptr_inbounds_at(row),
            #[inline(always)]
            |ptr| &*ptr,
        )
    }

    /// Returns a reference to the element at `row`.
    ///
    /// # Panics
    ///
    /// Panics if `row >= self.nrows()`.
    #[inline(always)]
    #[track_caller]
    pub fn at(self, row: Idx<R>) -> Ref<'a, E> {
        assert!(row < self.nrows());
        // SAFETY: the bound was checked above.
        unsafe { self.at_unchecked(row) }
    }

    /// Reads the value of the element at `row`, without bounds checks.
    ///
    /// # Safety
    ///
    /// `row` must be less than `self.nrows()`. This is only checked in debug builds.
    #[inline(always)]
    #[track_caller]
    pub unsafe fn read_unchecked(&self, row: Idx<R>) -> E {
        E::faer_from_units(E::faer_map(
            self.at_unchecked(row),
            #[inline(always)]
            |ptr| *ptr,
        ))
    }

    /// Reads the value of the element at `row`.
    ///
    /// # Panics
    ///
    /// Panics if `row >= self.nrows()`.
    #[inline(always)]
    #[track_caller]
    pub fn read(&self, row: Idx<R>) -> E {
        E::faer_from_units(E::faer_map(
            self.at(row),
            #[inline(always)]
            |ptr| *ptr,
        ))
    }

    /// Returns a row view over the same elements, using the row stride as column stride.
    #[inline(always)]
    #[must_use]
    pub fn transpose(self) -> RowRef<'a, E, R> {
        // SAFETY: same pointer, length and stride as `self`.
        unsafe { crate::row::from_raw_parts(self.as_ptr(), self.nrows(), self.row_stride()) }
    }

    /// Returns a view of the same memory reinterpreted with the conjugate element type
    /// `E::Conj`. No element is read or modified.
    #[inline(always)]
    #[must_use]
    pub fn conjugate(self) -> ColRef<'a, E::Conj, R>
    where
        E: Conjugate,
    {
        // SAFETY: relies on the unit pointers of `E` and `E::Conj` being layout
        // compatible, presumably guaranteed by the `Conjugate` trait contract.
        unsafe {
            super::from_raw_parts(
                transmute_unchecked::<
                    GroupFor<E, *const UnitFor<E>>,
                    GroupFor<E::Conj, *const UnitFor<E::Conj>>,
                >(self.as_ptr()),
                self.nrows(),
                self.row_stride(),
            )
        }
    }

    /// Returns the conjugate transpose of the column, as a row view.
    #[inline(always)]
    #[must_use]
    pub fn adjoint(self) -> RowRef<'a, E::Conj, R>
    where
        E: Conjugate,
    {
        self.conjugate().transpose()
    }

    /// Returns the view reinterpreted with the canonical element type, together with a
    /// flag. The flag is `Conj::No` if `E` already is its canonical type and `Conj::Yes`
    /// otherwise.
    #[inline(always)]
    #[must_use]
    pub fn canonicalize(self) -> (ColRef<'a, E::Canonical, R>, Conj)
    where
        E: Conjugate,
    {
        (
            // SAFETY: relies on the unit pointers of `E` and `E::Canonical` being layout
            // compatible, presumably guaranteed by the `Conjugate` trait contract.
            unsafe {
                super::from_raw_parts(
                    transmute_unchecked::<
                        PtrConst<E>,
                        GroupFor<E::Canonical, *const UnitFor<E::Canonical>>,
                    >(self.as_ptr()),
                    self.nrows(),
                    self.row_stride(),
                )
            },
            if coe::is_same::<E, E::Canonical>() {
                Conj::No
            } else {
                Conj::Yes
            },
        )
    }

    /// Returns a view of the column with its rows in reverse order.
    ///
    /// The base pointer moves to the last row and the stride is negated. An empty column
    /// keeps its base pointer.
    #[inline(always)]
    #[must_use]
    pub fn reverse_rows(self) -> Self {
        let nrows = self.nrows();
        let row_stride = self.row_stride().wrapping_neg();
        // SAFETY: `saturating_sub` keeps the offset at 0 for an empty column; otherwise
        // it addresses the last row, which is in bounds.
        let ptr = unsafe { self.ptr_at_unchecked(nrows.unbound().saturating_sub(1)) };
        // SAFETY: the same elements as `self`, visited from the last row backwards.
        unsafe { Self::__from_raw_parts(ptr, nrows, row_stride) }
    }

    /// Returns the view of `nrows` rows starting at `row_start`, without bounds checks.
    ///
    /// # Safety
    ///
    /// `row_start <= self.nrows()` and `nrows <= self.nrows() - row_start` must hold.
    /// Both are only checked in debug builds.
    #[track_caller]
    #[inline(always)]
    pub unsafe fn subrows_unchecked<V: Shape>(
        self,
        row_start: IdxInc<R>,
        nrows: V,
    ) -> ColRef<'a, E, V> {
        debug_assert!(all(row_start <= self.nrows()));
        {
            let nrows = nrows.unbound();
            let row_start = row_start.unbound();
            debug_assert!(all(nrows <= self.nrows().unbound() - row_start));
        }
        let row_stride = self.row_stride();
        unsafe { ColRef::__from_raw_parts(self.overflowing_ptr_at(row_start), nrows, row_stride) }
    }

    /// Returns the view of `nrows` rows starting at `row_start`.
    ///
    /// # Panics
    ///
    /// Panics if `row_start > self.nrows()` or `nrows > self.nrows() - row_start`.
    #[track_caller]
    #[inline(always)]
    pub fn subrows<V: Shape>(self, row_start: IdxInc<R>, nrows: V) -> ColRef<'a, E, V> {
        assert!(all(row_start <= self.nrows()));
        {
            let nrows = nrows.unbound();
            let row_start = row_start.unbound();
            assert!(all(nrows <= self.nrows().unbound() - row_start));
        }
        // SAFETY: both bounds were checked above.
        unsafe { self.subrows_unchecked(row_start, nrows) }
    }

    /// Returns a diagonal matrix view whose diagonal is this column.
    #[track_caller]
    #[inline(always)]
    pub fn column_vector_as_diagonal(self) -> DiagRef<'a, E, R> {
        DiagRef { inner: self }
    }

    /// Copies the column into a newly allocated `Col`, converting every element to the
    /// canonical type.
    #[inline]
    pub fn to_owned(&self) -> Col<E::Canonical, R>
    where
        E: Conjugate,
    {
        Col::from_fn(
            self.nrows(),
            // SAFETY: `Col::from_fn` is assumed to only pass row indices below `nrows`.
            #[inline(always)]
            |row| unsafe { self.read_unchecked(row).canonicalize() },
        )
    }

    /// Returns `MatRef::has_nan` of the single-column matrix view of this column.
    #[inline]
    pub fn has_nan(&self) -> bool
    where
        E: ComplexField,
    {
        (*self).as_2d().has_nan()
    }

    /// Returns `MatRef::is_all_finite` of the single-column matrix view of this column.
    #[inline]
    pub fn is_all_finite(&self) -> bool
    where
        E: ComplexField,
    {
        (*self).as_2d().is_all_finite()
    }

    /// Returns `MatRef::norm_max` of the single-column matrix view of this column.
    #[inline]
    pub fn norm_max(&self) -> E::Real
    where
        E: ComplexField,
    {
        (*self).as_2d().norm_max()
    }

    /// Returns `MatRef::norm_l1` of the single-column matrix view of this column.
    #[inline]
    pub fn norm_l1(&self) -> E::Real
    where
        E: ComplexField,
    {
        (*self).as_2d().norm_l1()
    }

    /// Returns `MatRef::norm_l2` of the single-column matrix view of this column.
    #[inline]
    pub fn norm_l2(&self) -> E::Real
    where
        E: ComplexField,
    {
        (*self).as_2d().norm_l2()
    }

    /// Returns `MatRef::squared_norm_l2` of the single-column matrix view of this column.
    #[inline]
    pub fn squared_norm_l2(&self) -> E::Real
    where
        E: ComplexField,
    {
        (*self).as_2d().squared_norm_l2()
    }

    /// Returns `MatRef::sum` of the single-column matrix view of this column.
    #[inline]
    pub fn sum(&self) -> E
    where
        E: ComplexField,
    {
        (*self).as_2d().sum()
    }

    /// Returns `MatRef::kron` of the single-column matrix view of this column and `rhs`.
    #[inline]
    #[track_caller]
    pub fn kron(&self, rhs: impl As2D<E>) -> Mat<E>
    where
        E: ComplexField,
    {
        (*self).as_2d().kron(rhs)
    }

    /// Returns the column as a slice if its row stride is `1`, and `None` otherwise.
    #[inline]
    pub fn try_as_slice(self) -> Option<Slice<'a, E>> {
        if self.row_stride() == 1 {
            let len = self.nrows().unbound();
            Some(E::faer_map(
                self.as_ptr(),
                // SAFETY: a unit stride means the `len` elements are stored contiguously.
                #[inline(always)]
                |ptr| unsafe { core::slice::from_raw_parts(ptr, len) },
            ))
        } else {
            None
        }
    }

    /// Returns a copy of this view, reborrowed for the lifetime of `&self`.
    #[inline]
    pub fn as_ref(&self) -> ColRef<'_, E, R> {
        *self
    }

    /// Returns the first element and a view over the remaining rows, or `None` if the
    /// column is empty.
    #[inline]
    pub fn split_first(self) -> Option<(Ref<'a, E>, ColRef<'a, E>)> {
        let this = self.as_dyn();
        if this.nrows() == 0 {
            None
        } else {
            // SAFETY: the column is non-empty, so splitting at 1 is in bounds and the
            // head has exactly one row.
            unsafe {
                let (head, tail) = this.split_at_unchecked(1);
                Some((head.get_unchecked(0), tail))
            }
        }
    }

    /// Returns the last element and a view over the preceding rows, or `None` if the
    /// column is empty.
    #[inline]
    pub fn split_last(self) -> Option<(Ref<'a, E>, ColRef<'a, E>)> {
        let this = self.as_dyn();
        if this.nrows() == 0 {
            None
        } else {
            // SAFETY: the column is non-empty, so splitting at `nrows - 1` is in bounds and
            // the tail has exactly one row.
            unsafe {
                let (head, tail) = this.split_at_unchecked(this.nrows() - 1);
                Some((tail.get_unchecked(0), head))
            }
        }
    }

    /// Returns an iterator over the elements of the column.
    #[inline]
    pub fn iter(self) -> iter::ElemIter<'a, E> {
        iter::ElemIter {
            inner: self.as_dyn(),
        }
    }

    /// Returns an iterator over consecutive subcolumns of (at most) `chunk_size` rows,
    /// as laid out by `ChunkSizePolicy`.
    ///
    /// # Panics
    ///
    /// Panics if `chunk_size == 0`.
    #[inline]
    #[track_caller]
    pub fn chunks(self, chunk_size: usize) -> iter::ColElemChunks<'a, E> {
        assert!(chunk_size > 0);
        iter::ColElemChunks {
            inner: self.as_dyn(),
            policy: iter::chunks::ChunkSizePolicy::new(
                self.nrows().unbound(),
                iter::chunks::ChunkSize(chunk_size),
            ),
        }
    }

    /// Returns an iterator over `count` consecutive subcolumns, as laid out by
    /// `PartitionCountPolicy`.
    ///
    /// # Panics
    ///
    /// Panics if `count == 0`.
    #[inline]
    #[track_caller]
    pub fn partition(self, count: usize) -> iter::ColElemPartition<'a, E> {
        assert!(count > 0);
        iter::ColElemPartition {
            inner: self.as_dyn(),
            policy: iter::chunks::PartitionCountPolicy::new(
                self.nrows().unbound(),
                iter::chunks::PartitionCount(count),
            ),
        }
    }

    /// Parallel counterpart of [`Self::chunks`], built on `MatRef::par_row_chunks`.
    ///
    /// # Panics
    ///
    /// Panics if `chunk_size == 0`.
    #[cfg(feature = "rayon")]
    #[cfg_attr(docsrs, doc(cfg(feature = "rayon")))]
    #[inline]
    #[track_caller]
    pub fn par_chunks(
        self,
        chunk_size: usize,
    ) -> impl 'a + rayon::iter::IndexedParallelIterator<Item = ColRef<'a, E>> {
        use rayon::prelude::*;
        // Same precondition as the sequential `chunks`.
        assert!(chunk_size > 0);
        self.as_2d()
            .par_row_chunks(chunk_size)
            .map(|chunk| chunk.col(0))
    }

    /// Parallel counterpart of [`Self::partition`], built on `MatRef::par_row_partition`.
    ///
    /// # Panics
    ///
    /// Panics if `count == 0`.
    #[cfg(feature = "rayon")]
    #[cfg_attr(docsrs, doc(cfg(feature = "rayon")))]
    #[inline]
    #[track_caller]
    pub fn par_partition(
        self,
        count: usize,
    ) -> impl 'a + rayon::iter::IndexedParallelIterator<Item = ColRef<'a, E>> {
        use rayon::prelude::*;
        // Same precondition as the sequential `partition`.
        assert!(count > 0);
        self.as_2d()
            .par_row_partition(count)
            .map(|chunk| chunk.col(0))
    }
}
/// Creates a `ColRef` from a pointer to its first element, its number of rows and its
/// row stride (in elements).
///
/// # Safety
///
/// The behavior is undefined if any of the following conditions are violated:
/// * every pointer in `ptr` is non-null (they are wrapped with `NonNull::new_unchecked`);
/// * for each `i < nrows`, `i * row_stride` does not overflow `isize`, and the element at
///   that offset from `ptr` is valid for reads for the lifetime `'a`;
/// * the referenced elements are not mutated while the returned view is alive.
#[inline(always)]
pub unsafe fn from_raw_parts<'a, E: Entity, R: Shape>(
    ptr: PtrConst<E>,
    nrows: R,
    row_stride: isize,
) -> ColRef<'a, E, R> {
    // SAFETY: the caller upholds the contract documented above.
    unsafe { ColRef::__from_raw_parts(ptr, nrows, row_stride) }
}
/// Creates a `ColRef` viewing the elements of `slice` with a unit row stride.
#[inline(always)]
pub fn from_slice_generic<E: Entity>(slice: Slice<'_, E>) -> ColRef<'_, E> {
    // The length is read from a copy, since `slice` itself is consumed below.
    let len = SliceGroup::<'_, E>::new(E::faer_copy(&slice)).len();
    let base = E::faer_map(
        slice,
        #[inline(always)]
        |units| units.as_ptr(),
    );
    // SAFETY: `base` points to `len` contiguous elements borrowed for the lifetime of
    // `slice`.
    unsafe { from_raw_parts(base, len, 1) }
}
/// Creates a `ColRef` viewing the elements of `slice` with a unit row stride.
#[inline(always)]
pub fn from_slice<E: SimpleEntity>(slice: &[E]) -> ColRef<'_, E> {
    from_slice_generic::<E>(slice)
}
impl<E: Entity, R: Shape> As2D<E> for ColRef<'_, E, R> {
    /// Returns the column as a one-column matrix with `usize` dimensions.
    #[inline]
    fn as_2d_ref(&self) -> MatRef<'_, E> {
        let this: ColRef<'_, E, R> = *self;
        this.as_2d().as_dyn()
    }
}

impl<E: Entity, R: Shape> AsColRef<E> for ColRef<'_, E, R> {
    type R = R;

    /// Returns a copy of this view, reborrowed for the lifetime of `&self`.
    #[inline]
    fn as_col_ref(&self) -> ColRef<'_, E, R> {
        self.as_ref()
    }
}
impl<'a, E: Entity, R: Shape> core::fmt::Debug for ColRef<'a, E, R> {
    /// Formats the column as a list of its element values, in row order.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut list = f.debug_list();
        for elem in self.iter() {
            list.entry(&E::faer_from_units(E::faer_deref(elem)));
        }
        list.finish()
    }
}
impl<E: SimpleEntity> core::ops::Index<usize> for ColRef<'_, E> {
    type Output = E;

    /// Returns a reference to the element at `row`, via `ColRef::get` (presumably
    /// panicking when `row` is out of bounds).
    #[inline]
    #[track_caller]
    fn index(&self, row: usize) -> &E {
        let this: Self = *self;
        this.get(row)
    }
}
impl<E: Conjugate> ColBatch<E> for ColRef<'_, E> {
    type Owned = Col<E::Canonical>;

    /// Returns an owned copy of `src`, with every element converted to the canonical type.
    #[inline]
    fn new_owned_copied(src: &Self) -> Self::Owned {
        src.to_owned()
    }

    /// Allocates an owned column of `nrows` rows via `Col::zeros`.
    ///
    /// # Panics
    ///
    /// Panics if `ncols != 1`, since a column batch has exactly one column.
    #[inline]
    #[track_caller]
    fn new_owned_zeros(nrows: usize, ncols: usize) -> Self::Owned {
        assert!(ncols == 1);
        Col::zeros(nrows)
    }

    /// Resizes `owned` by delegating to the `ColBatch` implementation of the owned type.
    #[inline]
    #[track_caller]
    fn resize_owned(owned: &mut Self::Owned, nrows: usize, ncols: usize) {
        <Col<E::Canonical> as ColBatch<E::Canonical>>::resize_owned(owned, nrows, ncols)
    }
}
/// Creates a column view of `nrows` rows that all alias the single element `value`.
///
/// The view uses a row stride of `0`, so every row reads the same memory location.
#[doc(alias = "broadcast")]
pub fn from_repeated_ref_generic<E: Entity, R: Shape>(
    value: Ref<'_, E>,
    nrows: R,
) -> ColRef<'_, E, R> {
    let base = E::faer_map(value, |unit| unit as *const E::Unit);
    // SAFETY: with a zero stride every row offset is 0, so only `base` is ever read.
    // It comes from a valid reference that outlives the returned view.
    unsafe { from_raw_parts(base, nrows, 0) }
}

/// Creates a single-row column view over `value`.
pub fn from_ref_generic<E: Entity>(value: Ref<'_, E>) -> ColRef<'_, E> {
    let nrows: usize = 1;
    from_repeated_ref_generic(value, nrows)
}
/// Creates a column view of `nrows` rows that all alias the single element `value`
/// (row stride `0`).
#[doc(alias = "broadcast")]
pub fn from_repeated_ref<E: SimpleEntity, R: Shape>(value: &E, nrows: R) -> ColRef<'_, E, R> {
    from_repeated_ref_generic::<E, R>(value, nrows)
}

/// Creates a single-row column view over `value`.
pub fn from_ref<E: SimpleEntity>(value: &E) -> ColRef<'_, E> {
    from_ref_generic::<E>(value)
}