use super::*;
use crate::{
diag::{DiagMut, DiagRef},
iter,
iter::chunks::ChunkPolicy,
row::{RowMut, RowRef},
unzipped, zipped_rw, Idx, IdxInc, Unbind,
};
/// Mutable view over a column vector with `R` rows (`usize` by default).
///
/// The view stores a base pointer group, a row count and a row stride inside a `VecImpl`
/// (see `ColMut::__from_raw_parts`). The `'a` lifetime ties the view to the borrowed
/// elements through the `PhantomData` marker.
#[repr(C)]
pub struct ColMut<'a, E: Entity, R: Shape = usize> {
    // Base pointer (`ptr`), row count (`len`) and row stride (`stride`) of the column.
    pub(super) inner: VecImpl<E, R>,
    // Zero-sized marker carrying the `'a` lifetime of the borrowed data.
    pub(super) __marker: PhantomData<&'a E>,
}
impl<E: Entity> Default for ColMut<'_, E> {
    /// Returns an empty (zero-row) mutable column backed by empty unit slices.
    #[inline]
    fn default() -> Self {
        // One empty slice per unit of the entity `E`.
        let empty_units = E::faer_map(E::UNIT, |()| &mut [] as &mut [E::Unit]);
        from_slice_mut_generic::<E>(empty_units)
    }
}
impl<'short, E: Entity, R: Shape> Reborrow<'short> for ColMut<'_, E, R> {
    type Target = ColRef<'short, E, R>;

    /// Returns a read-only view of the same column, borrowed for `'short`.
    #[inline]
    fn rb(&'short self) -> Self::Target {
        // Same pointer, length and stride; only the access mode and lifetime change.
        let inner = self.inner;
        ColRef {
            inner,
            __marker: PhantomData,
        }
    }
}
impl<'short, E: Entity, R: Shape> ReborrowMut<'short> for ColMut<'_, E, R> {
    type Target = ColMut<'short, E, R>;

    /// Returns a shorter-lived mutable view of the same column.
    ///
    /// The exclusive borrow of `self` for `'short` keeps the original view unusable while the
    /// reborrow is alive.
    #[inline]
    fn rb_mut(&'short mut self) -> Self::Target {
        let inner = self.inner;
        ColMut {
            inner,
            __marker: PhantomData,
        }
    }
}
impl<'a, E: Entity, R: Shape> IntoConst for ColMut<'a, E, R> {
    type Target = ColRef<'a, E, R>;

    /// Consumes the mutable view and returns a read-only view with the same lifetime.
    #[inline]
    fn into_const(self) -> Self::Target {
        let inner = self.inner;
        ColRef {
            inner,
            __marker: PhantomData,
        }
    }
}
/// Inherent methods of the mutable column view.
///
/// Many `*_mut` methods consume `self`, delegate to the corresponding read-only `ColRef`
/// method through `into_const`, and turn the result back into a mutable view with
/// `const_cast`. Because `self` is taken by value, the mutable view produced this way is the
/// only one derived from the original borrow.
impl<'a, E: Entity, R: Shape> ColMut<'a, E, R> {
    /// Builds a column view from raw parts without validation.
    ///
    /// `ptr` is the base pointer group (address of row `0`), `nrows` the row count and
    /// `row_stride` the distance, in `E::Unit` elements, between consecutive rows.
    ///
    /// # Safety
    ///
    /// Every pointer in `ptr` must be non-null: each one is wrapped with
    /// `NonNull::new_unchecked` without a check.
    /// NOTE(review): callers presumably must also guarantee that the memory at every
    /// `row * row_stride` offset is valid and exclusively borrowed for `'a` — confirm.
    #[inline]
    pub(crate) unsafe fn __from_raw_parts(ptr: PtrMut<E>, nrows: R, row_stride: isize) -> Self {
        Self {
            inner: VecImpl {
                ptr: into_copy::<E, _>(E::faer_map(
                    ptr,
                    #[inline]
                    |ptr| NonNull::new_unchecked(ptr),
                )),
                len: nrows,
                stride: row_stride,
            },
            __marker: PhantomData,
        }
    }
    /// Returns the number of rows.
    #[inline(always)]
    pub fn nrows(&self) -> R {
        self.inner.len
    }
    /// Returns the number of columns, which is always `1`.
    #[inline(always)]
    pub fn ncols(&self) -> usize {
        1
    }
    /// Returns the read-only base pointer group (address of row `0`).
    #[inline(always)]
    pub fn as_ptr(self) -> PtrConst<E> {
        self.into_const().as_ptr()
    }
    /// Returns the mutable base pointer group (address of row `0`).
    #[inline(always)]
    pub fn as_ptr_mut(self) -> PtrMut<E> {
        E::faer_map(
            from_copy::<E, _>(self.inner.ptr),
            #[inline(always)]
            |ptr| ptr.as_ptr() as *mut E::Unit,
        )
    }
    /// Returns the distance, in `E::Unit` elements, between two consecutive rows.
    #[inline(always)]
    pub fn row_stride(&self) -> isize {
        self.inner.stride
    }
    /// Returns the column as a read-only 2D matrix view.
    #[inline(always)]
    pub fn as_2d(self) -> MatRef<'a, E, R> {
        self.into_const().as_2d()
    }
    /// Returns the column as a mutable 2D matrix view.
    #[inline(always)]
    pub fn as_2d_mut(self) -> MatMut<'a, E, R> {
        unsafe { self.into_const().as_2d().const_cast() }
    }
    /// Returns a read-only pointer to row `row` (delegates to `ColRef::ptr_at`).
    #[inline(always)]
    pub fn ptr_at(self, row: usize) -> PtrConst<E> {
        self.into_const().ptr_at(row)
    }
    /// Returns a mutable pointer to row `row`.
    ///
    /// The offset is computed with `wrapping_mul` and applied with `wrapping_offset`, so this
    /// never panics for any `row`. Dereferencing the result is only meaningful for in-bounds
    /// rows.
    #[inline(always)]
    pub fn ptr_at_mut(self, row: usize) -> PtrMut<E> {
        let offset = (row as isize).wrapping_mul(self.inner.stride);
        E::faer_map(
            self.as_ptr_mut(),
            #[inline(always)]
            |ptr| ptr.wrapping_offset(offset),
        )
    }
    /// Unchecked variant of [`Self::ptr_at`] (delegates to `ColRef::ptr_at_unchecked`).
    ///
    /// # Safety
    ///
    /// NOTE(review): contract defined by `ColRef::ptr_at_unchecked`; presumably the same as
    /// [`Self::ptr_at_mut_unchecked`] — confirm.
    #[inline(always)]
    #[doc(hidden)]
    pub unsafe fn ptr_at_unchecked(self, row: usize) -> PtrConst<E> {
        self.into_const().ptr_at_unchecked(row)
    }
    /// Returns a mutable pointer to row `row` using unchecked arithmetic.
    ///
    /// # Safety
    ///
    /// The product `row * row_stride` is computed with `crate::utils::unchecked_mul`
    /// (presumably UB on overflow — confirm) and applied with `pointer::offset`. The product
    /// must therefore not overflow, and the resulting pointer must stay within the same
    /// allocation.
    #[inline(always)]
    #[doc(hidden)]
    pub unsafe fn ptr_at_mut_unchecked(self, row: usize) -> PtrMut<E> {
        let offset = crate::utils::unchecked_mul(row, self.inner.stride);
        E::faer_map(
            self.as_ptr_mut(),
            #[inline(always)]
            |ptr| ptr.offset(offset),
        )
    }
    /// Read-only variant of [`Self::overflowing_ptr_at_mut`].
    ///
    /// # Safety
    ///
    /// NOTE(review): contract defined by `ColRef::overflowing_ptr_at` — confirm.
    #[inline(always)]
    #[doc(hidden)]
    pub unsafe fn overflowing_ptr_at(self, row: IdxInc<R>) -> PtrConst<E> {
        self.into_const().overflowing_ptr_at(row)
    }
    /// Returns a mutable pointer to row `row`, where `row` may also equal `nrows`.
    ///
    /// When `row == nrows` the offset is masked to zero and the base pointer is returned
    /// instead of a one-past-the-end pointer. Otherwise the offset is `row * row_stride`.
    ///
    /// # Safety
    ///
    /// The offset is applied with `pointer::offset`, so for `row < nrows` the result must
    /// stay within the column's allocation.
    #[inline(always)]
    #[doc(hidden)]
    pub unsafe fn overflowing_ptr_at_mut(self, row: IdxInc<R>) -> PtrMut<E> {
        unsafe {
            let cond = row != self.nrows();
            // `wrapping_neg` of `cond as usize` is all ones when `row != nrows` and zero
            // otherwise, so the AND keeps or discards the offset without a branch.
            let offset = (cond as usize).wrapping_neg() as isize
                & (row.unbound() as isize).wrapping_mul(self.inner.stride);
            E::faer_map(
                self.as_ptr_mut(),
                #[inline(always)]
                |ptr| ptr.offset(offset),
            )
        }
    }
    /// Read-only variant of [`Self::ptr_inbounds_at_mut`].
    ///
    /// # Safety
    ///
    /// NOTE(review): contract defined by `ColRef::ptr_inbounds_at`; presumably `row < nrows`.
    #[inline(always)]
    #[track_caller]
    pub unsafe fn ptr_inbounds_at(self, row: Idx<R>) -> PtrConst<E> {
        self.into_const().ptr_inbounds_at(row)
    }
    /// Returns a mutable pointer to the in-bounds row `row`.
    ///
    /// # Safety
    ///
    /// `row` must be less than `nrows`. This is only checked with `debug_assert!`.
    #[inline(always)]
    #[track_caller]
    pub unsafe fn ptr_inbounds_at_mut(self, row: Idx<R>) -> PtrMut<E> {
        debug_assert!(row < self.nrows());
        self.ptr_at_mut_unchecked(row.unbound())
    }
    /// Erases the shape type, returning a read-only view with a plain `usize` row count and
    /// the same base pointer and stride.
    #[inline]
    pub fn as_dyn(self) -> ColRef<'a, E> {
        let nrows = self.nrows().unbound();
        let row_stride = self.row_stride();
        unsafe { from_raw_parts(self.as_ptr(), nrows, row_stride) }
    }
    /// Erases the shape type, returning a mutable view with a plain `usize` row count and
    /// the same base pointer and stride.
    #[inline]
    pub fn as_dyn_mut(self) -> ColMut<'a, E> {
        let nrows = self.nrows().unbound();
        let row_stride = self.row_stride();
        unsafe { from_raw_parts_mut(self.as_ptr_mut(), nrows, row_stride) }
    }
    /// Reinterprets the row count with shape type `V`. Any checks are done by
    /// `ColRef::as_shape`.
    #[inline]
    pub fn as_shape<V: Shape>(self, nrows: V) -> ColRef<'a, E, V> {
        self.into_const().as_shape(nrows)
    }
    /// Mutable variant of [`Self::as_shape`].
    #[inline]
    pub fn as_shape_mut<V: Shape>(self, nrows: V) -> ColMut<'a, E, V> {
        unsafe { self.into_const().as_shape(nrows).const_cast() }
    }
    /// Returns the column as a read-only slice group (delegates to
    /// `ColRef::try_get_contiguous_col`).
    #[track_caller]
    #[inline(always)]
    #[doc(hidden)]
    pub fn try_get_contiguous_col(self) -> Slice<'a, E> {
        self.into_const().try_get_contiguous_col()
    }
    /// Returns the column as a mutable slice group of length `nrows`.
    ///
    /// # Panics
    ///
    /// Panics if the row stride is not `1`.
    #[track_caller]
    #[inline(always)]
    #[doc(hidden)]
    pub fn try_get_contiguous_col_mut(self) -> SliceMut<'a, E> {
        assert!(self.row_stride() == 1);
        let m = self.nrows().unbound();
        E::faer_map(
            self.as_ptr_mut(),
            #[inline(always)]
            |ptr| unsafe { core::slice::from_raw_parts_mut(ptr, m) },
        )
    }
    /// Splits into read-only views of the rows before `row` and from `row` onward, without
    /// bounds checks.
    ///
    /// # Safety
    ///
    /// NOTE(review): presumably requires `row <= nrows` (the check done by [`Self::split_at_mut`]).
    #[inline(always)]
    #[track_caller]
    pub unsafe fn split_at_unchecked(self, row: IdxInc<R>) -> (ColRef<'a, E>, ColRef<'a, E>) {
        self.into_const().split_at_unchecked(row)
    }
    /// Mutable variant of [`Self::split_at_unchecked`]. The two halves do not overlap, so
    /// both can be mutable.
    ///
    /// # Safety
    ///
    /// Same requirement as [`Self::split_at_unchecked`].
    #[inline(always)]
    #[track_caller]
    pub unsafe fn split_at_mut_unchecked(self, row: IdxInc<R>) -> (ColMut<'a, E>, ColMut<'a, E>) {
        let (top, bot) = self.into_const().split_at_unchecked(row);
        unsafe { (top.const_cast(), bot.const_cast()) }
    }
    /// Splits into read-only views of the rows before `row` and from `row` onward.
    #[inline(always)]
    #[track_caller]
    pub fn split_at(self, row: IdxInc<R>) -> (ColRef<'a, E>, ColRef<'a, E>) {
        self.into_const().split_at(row)
    }
    /// Splits into mutable views of the rows before `row` and from `row` onward.
    ///
    /// # Panics
    ///
    /// Panics if `row > nrows`.
    #[inline(always)]
    #[track_caller]
    pub fn split_at_mut(self, row: IdxInc<R>) -> (ColMut<'a, E>, ColMut<'a, E>) {
        assert!(row <= self.nrows());
        unsafe { self.split_at_mut_unchecked(row) }
    }
    /// Indexes a row or row range through `ColIndex` on the read-only view, without bounds
    /// checks.
    ///
    /// # Safety
    ///
    /// NOTE(review): contract defined by the `ColIndex` impl; presumably `row` must be in bounds.
    #[inline(always)]
    #[track_caller]
    pub unsafe fn get_unchecked<RowRange>(
        self,
        row: RowRange,
    ) -> <ColRef<'a, E, R> as ColIndex<RowRange>>::Target
    where
        ColRef<'a, E, R>: ColIndex<RowRange>,
    {
        self.into_const().get_unchecked(row)
    }
    /// Mutable variant of [`Self::get_unchecked`].
    ///
    /// # Safety
    ///
    /// Same requirement as [`Self::get_unchecked`].
    #[inline(always)]
    #[track_caller]
    pub unsafe fn get_mut_unchecked<RowRange>(
        self,
        row: RowRange,
    ) -> <Self as ColIndex<RowRange>>::Target
    where
        Self: ColIndex<RowRange>,
    {
        <Self as ColIndex<RowRange>>::get_unchecked(self, row)
    }
    /// Indexes a row or row range through `ColIndex` on the read-only view.
    #[inline(always)]
    #[track_caller]
    pub fn get<RowRange>(self, row: RowRange) -> <ColRef<'a, E, R> as ColIndex<RowRange>>::Target
    where
        ColRef<'a, E, R>: ColIndex<RowRange>,
    {
        self.into_const().get(row)
    }
    /// Mutable variant of [`Self::get`].
    #[inline(always)]
    #[track_caller]
    pub fn get_mut<RowRange>(self, row: RowRange) -> <Self as ColIndex<RowRange>>::Target
    where
        Self: ColIndex<RowRange>,
    {
        <Self as ColIndex<RowRange>>::get(self, row)
    }
    /// Returns a shared reference group to the element at `row` (delegates to `ColRef::at`).
    #[inline(always)]
    #[track_caller]
    pub fn at(self, row: Idx<R>) -> Ref<'a, E> {
        self.into_const().at(row)
    }
    /// Returns a mutable reference group to the element at `row`.
    ///
    /// # Panics
    ///
    /// Panics if `row >= nrows`.
    #[inline(always)]
    #[track_caller]
    pub fn at_mut(self, row: Idx<R>) -> Mut<'a, E> {
        assert!(row < self.nrows());
        unsafe {
            E::faer_map(
                self.ptr_inbounds_at_mut(row),
                #[inline(always)]
                |ptr| &mut *ptr,
            )
        }
    }
    /// Unchecked variant of [`Self::at`].
    ///
    /// # Safety
    ///
    /// NOTE(review): presumably requires `row < nrows` — confirm against `ColRef::at_unchecked`.
    #[inline(always)]
    #[track_caller]
    pub unsafe fn at_unchecked(self, row: Idx<R>) -> Ref<'a, E> {
        self.into_const().at_unchecked(row)
    }
    /// Unchecked variant of [`Self::at_mut`].
    ///
    /// # Safety
    ///
    /// `row` must be less than `nrows`. This is only checked with `debug_assert!` in
    /// [`Self::ptr_inbounds_at_mut`].
    #[inline(always)]
    #[track_caller]
    pub unsafe fn at_mut_unchecked(self, row: Idx<R>) -> Mut<'a, E> {
        unsafe {
            E::faer_map(
                self.ptr_inbounds_at_mut(row),
                #[inline(always)]
                |ptr| &mut *ptr,
            )
        }
    }
    /// Reads the element at `row` by value without bounds checks.
    ///
    /// # Safety
    ///
    /// NOTE(review): contract defined by `ColRef::read_unchecked`; presumably `row < nrows`.
    #[inline(always)]
    #[track_caller]
    pub unsafe fn read_unchecked(&self, row: Idx<R>) -> E {
        self.rb().read_unchecked(row)
    }
    /// Reads the element at `row` by value through the read-only view.
    #[inline(always)]
    #[track_caller]
    pub fn read(&self, row: Idx<R>) -> E {
        self.rb().read(row)
    }
    /// Stores `value` at `row` without bounds checks.
    ///
    /// The value is split into its units, and each unit is written through the matching
    /// pointer of the row.
    ///
    /// # Safety
    ///
    /// `row` must be less than `nrows`. This is only debug-asserted.
    #[inline(always)]
    #[track_caller]
    pub unsafe fn write_unchecked(&mut self, row: Idx<R>, value: E) {
        let units = value.faer_into_units();
        let zipped = E::faer_zip(units, (*self).rb_mut().ptr_inbounds_at_mut(row));
        E::faer_map(
            zipped,
            #[inline(always)]
            |(unit, ptr)| *ptr = unit,
        );
    }
    /// Stores `value` at `row`.
    ///
    /// # Panics
    ///
    /// Panics if `row >= nrows`.
    #[inline(always)]
    #[track_caller]
    pub fn write(&mut self, row: Idx<R>, value: E) {
        assert!(row < self.nrows());
        unsafe { self.write_unchecked(row, value) };
    }
    /// Copies every element of `other` into `self`.
    ///
    /// Each source element is read and passed through `canonicalize()` before being written,
    /// so `other` may use any entity type whose canonical type is `E`.
    /// NOTE(review): agreement of the two shapes is presumably asserted by `zipped_rw!` —
    /// confirm.
    #[track_caller]
    pub fn copy_from<ViewE: Conjugate<Canonical = E>>(
        &mut self,
        other: impl AsColRef<ViewE, R = R>,
    ) {
        // Works on concrete views. It is presumably split out so the element-wise loop is not
        // re-instantiated for every `impl AsColRef` argument type.
        #[track_caller]
        #[inline(always)]
        fn implementation<R: Shape, E: Entity, ViewE: Conjugate<Canonical = E>>(
            this: ColMut<'_, E, R>,
            other: ColRef<'_, ViewE, R>,
        ) {
            zipped_rw!(this, other)
                .for_each(|unzipped!(mut dst, src)| dst.write(src.read().canonicalize()));
        }
        implementation(self.rb_mut(), other.as_col_ref())
    }
    /// Writes `E::faer_zero()` into every element.
    #[track_caller]
    pub fn fill_zero(&mut self)
    where
        E: ComplexField,
    {
        zipped_rw!(self.rb_mut()).for_each(
            #[inline(always)]
            |unzipped!(mut x)| x.write(E::faer_zero()),
        );
    }
    /// Writes `constant` into every element.
    #[track_caller]
    pub fn fill(&mut self, constant: E) {
        zipped_rw!((*self).rb_mut()).for_each(
            #[inline(always)]
            |unzipped!(mut x)| x.write(constant),
        );
    }
    /// Returns the column as a read-only row view (see `ColRef::transpose`).
    #[inline(always)]
    #[must_use]
    pub fn transpose(self) -> RowRef<'a, E, R> {
        self.into_const().transpose()
    }
    /// Returns the column as a mutable row view.
    #[inline(always)]
    #[must_use]
    pub fn transpose_mut(self) -> RowMut<'a, E, R> {
        unsafe { self.into_const().transpose().const_cast() }
    }
    /// Returns a read-only view with entity type `E::Conj` (see `ColRef::conjugate`).
    #[inline(always)]
    #[must_use]
    pub fn conjugate(self) -> ColRef<'a, E::Conj, R>
    where
        E: Conjugate,
    {
        self.into_const().conjugate()
    }
    /// Mutable variant of [`Self::conjugate`].
    #[inline(always)]
    #[must_use]
    pub fn conjugate_mut(self) -> ColMut<'a, E::Conj, R>
    where
        E: Conjugate,
    {
        unsafe { self.into_const().conjugate().const_cast() }
    }
    /// Returns the conjugate-transpose as a read-only row view (see `ColRef::adjoint`).
    #[inline(always)]
    pub fn adjoint(self) -> RowRef<'a, E::Conj, R>
    where
        E: Conjugate,
    {
        self.into_const().adjoint()
    }
    /// Mutable variant of [`Self::adjoint`]: [`Self::conjugate_mut`] followed by
    /// [`Self::transpose_mut`].
    #[inline(always)]
    pub fn adjoint_mut(self) -> RowMut<'a, E::Conj, R>
    where
        E: Conjugate,
    {
        self.conjugate_mut().transpose_mut()
    }
    /// Returns a read-only view with the canonical entity type, together with the `Conj`
    /// flag reported by `ColRef::canonicalize`.
    #[inline(always)]
    pub fn canonicalize(self) -> (ColRef<'a, E::Canonical, R>, Conj)
    where
        E: Conjugate,
    {
        self.into_const().canonicalize()
    }
    /// Mutable variant of [`Self::canonicalize`].
    #[inline(always)]
    pub fn canonicalize_mut(self) -> (ColMut<'a, E::Canonical, R>, Conj)
    where
        E: Conjugate,
    {
        let (canon, conj) = self.into_const().canonicalize();
        unsafe { (canon.const_cast(), conj) }
    }
    /// Returns a read-only view with the rows in reverse order (see `ColRef::reverse_rows`).
    #[inline(always)]
    #[must_use]
    pub fn reverse_rows(self) -> ColRef<'a, E, R> {
        self.into_const().reverse_rows()
    }
    /// Mutable variant of [`Self::reverse_rows`].
    #[inline(always)]
    #[must_use]
    pub fn reverse_rows_mut(self) -> Self {
        unsafe { self.into_const().reverse_rows().const_cast() }
    }
    /// Returns a read-only view of `nrows` rows starting at `row_start`, without bounds checks.
    ///
    /// # Safety
    ///
    /// NOTE(review): presumably requires `row_start + nrows <= self.nrows()` — confirm against
    /// `ColRef::subrows_unchecked`.
    #[track_caller]
    #[inline(always)]
    pub unsafe fn subrows_unchecked<V: Shape>(
        self,
        row_start: IdxInc<R>,
        nrows: V,
    ) -> ColRef<'a, E, V> {
        self.into_const().subrows_unchecked(row_start, nrows)
    }
    /// Mutable variant of [`Self::subrows_unchecked`].
    ///
    /// # Safety
    ///
    /// Same requirement as [`Self::subrows_unchecked`].
    #[track_caller]
    #[inline(always)]
    pub unsafe fn subrows_mut_unchecked<V: Shape>(
        self,
        row_start: IdxInc<R>,
        nrows: V,
    ) -> ColMut<'a, E, V> {
        self.into_const()
            .subrows_unchecked(row_start, nrows)
            .const_cast()
    }
    /// Returns a read-only view of `nrows` rows starting at `row_start`. Any bounds checks are
    /// done by `ColRef::subrows`.
    #[track_caller]
    #[inline(always)]
    pub fn subrows<V: Shape>(self, row_start: IdxInc<R>, nrows: V) -> ColRef<'a, E, V> {
        self.into_const().subrows(row_start, nrows)
    }
    /// Mutable variant of [`Self::subrows`].
    #[track_caller]
    #[inline(always)]
    pub fn subrows_mut<V: Shape>(self, row_start: IdxInc<R>, nrows: V) -> ColMut<'a, E, V> {
        unsafe { self.into_const().subrows(row_start, nrows).const_cast() }
    }
    /// Returns a read-only diagonal-matrix view whose diagonal is this column.
    #[track_caller]
    #[inline(always)]
    pub fn column_vector_as_diagonal(self) -> DiagRef<'a, E, R> {
        self.into_const().column_vector_as_diagonal()
    }
    /// Returns a mutable diagonal-matrix view that wraps this column directly.
    #[track_caller]
    #[inline(always)]
    pub fn column_vector_as_diagonal_mut(self) -> DiagMut<'a, E, R> {
        DiagMut { inner: self }
    }
    /// Copies the column into an owned `Col` with the canonical entity type.
    #[inline]
    pub fn to_owned(&self) -> Col<E::Canonical, R>
    where
        E: Conjugate,
    {
        (*self).rb().to_owned()
    }
    /// Returns whether any element is NaN (delegates to the 2D view).
    #[inline]
    pub fn has_nan(&self) -> bool
    where
        E: ComplexField,
    {
        (*self).rb().as_2d().has_nan()
    }
    /// Returns whether every element is finite (delegates to the 2D view).
    #[inline]
    pub fn is_all_finite(&self) -> bool
    where
        E: ComplexField,
    {
        (*self).rb().as_2d().is_all_finite()
    }
    /// Returns the max norm (delegates to the 2D view).
    #[inline]
    pub fn norm_max(&self) -> E::Real
    where
        E: ComplexField,
    {
        self.rb().as_2d().norm_max()
    }
    /// Returns the L1 norm (delegates to the 2D view).
    #[inline]
    pub fn norm_l1(&self) -> E::Real
    where
        E: ComplexField,
    {
        self.as_ref().as_2d().norm_l1()
    }
    /// Returns the L2 norm (delegates to the 2D view).
    #[inline]
    pub fn norm_l2(&self) -> E::Real
    where
        E: ComplexField,
    {
        self.as_ref().as_2d().norm_l2()
    }
    /// Returns the squared L2 norm (delegates to the 2D view).
    #[inline]
    pub fn squared_norm_l2(&self) -> E::Real
    where
        E: ComplexField,
    {
        self.as_ref().as_2d().squared_norm_l2()
    }
    /// Returns the sum of all elements (delegates to the 2D view).
    #[inline]
    pub fn sum(&self) -> E
    where
        E: ComplexField,
    {
        self.rb().as_2d().sum()
    }
    /// Computes `ColRef::kron` of this column with `rhs` (presumably the Kronecker product)
    /// and returns an owned matrix.
    #[inline]
    #[track_caller]
    pub fn kron(&self, rhs: impl As2D<E>) -> Mat<E>
    where
        E: ComplexField,
    {
        self.as_ref().kron(rhs)
    }
    /// Returns the column as a read-only slice group when it is contiguous (see
    /// `ColRef::try_as_slice`).
    #[inline]
    pub fn try_as_slice(self) -> Option<Slice<'a, E>> {
        self.into_const().try_as_slice()
    }
    /// Returns the column as a mutable slice group of length `nrows` when the row stride
    /// is `1`, and `None` otherwise.
    #[inline]
    pub fn try_as_slice_mut(self) -> Option<SliceMut<'a, E>> {
        if self.row_stride() == 1 {
            let len = self.nrows().unbound();
            Some(E::faer_map(
                self.as_ptr_mut(),
                #[inline(always)]
                |ptr| unsafe { core::slice::from_raw_parts_mut(ptr, len) },
            ))
        } else {
            None
        }
    }
    /// Like [`Self::try_as_slice_mut`], but returns the storage as an uninitialized slice
    /// group (the unit pointers are cast with `as _`).
    ///
    /// # Safety
    ///
    /// NOTE(review): the contract is not spelled out here. Presumably callers must not
    /// write uninitialized values through the returned slices, because the column is viewed
    /// as initialized elsewhere — confirm.
    pub unsafe fn try_as_uninit_slice_mut(self) -> Option<UninitSliceMut<'a, E>> {
        if self.row_stride() == 1 {
            let len = self.nrows().unbound();
            Some(E::faer_map(
                self.as_ptr_mut(),
                #[inline(always)]
                |ptr| unsafe { core::slice::from_raw_parts_mut(ptr as _, len) },
            ))
        } else {
            None
        }
    }
    /// Returns a read-only view borrowed from `self`.
    #[inline]
    pub fn as_ref(&self) -> ColRef<'_, E, R> {
        (*self).rb()
    }
    /// Returns a shorter-lived mutable view borrowed from `self`.
    #[inline]
    pub fn as_mut(&mut self) -> ColMut<'_, E, R> {
        (*self).rb_mut()
    }
    /// Splits off the first element (see `ColRef::split_first`).
    #[inline]
    pub fn split_first(self) -> Option<(Ref<'a, E>, ColRef<'a, E>)> {
        self.into_const().split_first()
    }
    /// Splits off the last element (see `ColRef::split_last`).
    #[inline]
    pub fn split_last(self) -> Option<(Ref<'a, E>, ColRef<'a, E>)> {
        self.into_const().split_last()
    }
    /// Returns the first element and a view of the remaining rows, or `None` if the column
    /// is empty.
    #[inline]
    pub fn split_first_mut(self) -> Option<(Mut<'a, E>, ColMut<'a, E>)> {
        let this = self.as_dyn_mut();
        if this.nrows() == 0 {
            None
        } else {
            // `nrows >= 1`, so splitting at 1 and reading row 0 of the head are in bounds.
            unsafe {
                let (head, tail) = { this.split_at_mut_unchecked(1) };
                Some((head.get_mut_unchecked(0), tail))
            }
        }
    }
    /// Returns the last element and a view of the preceding rows, or `None` if the column
    /// is empty.
    #[inline]
    pub fn split_last_mut(self) -> Option<(Mut<'a, E>, ColMut<'a, E>)> {
        let this = self.as_dyn_mut();
        if this.nrows() == 0 {
            None
        } else {
            let nrows = this.nrows();
            // `nrows >= 1`, so `nrows - 1` does not underflow and the tail has exactly one row.
            unsafe {
                let (head, tail) = { this.split_at_mut_unchecked(nrows - 1) };
                Some((tail.get_mut_unchecked(0), head))
            }
        }
    }
    /// Returns an iterator over shared references to the elements.
    #[inline]
    pub fn iter(self) -> iter::ElemIter<'a, E> {
        iter::ElemIter {
            inner: self.as_dyn(),
        }
    }
    /// Returns an iterator over mutable references to the elements.
    #[inline]
    pub fn iter_mut(self) -> iter::ElemIterMut<'a, E> {
        iter::ElemIterMut {
            inner: self.as_dyn_mut(),
        }
    }
    /// Returns an iterator over read-only chunks of `chunk_size` rows (see `ColRef::chunks`).
    #[inline]
    #[track_caller]
    pub fn chunks(self, chunk_size: usize) -> iter::ColElemChunks<'a, E> {
        self.into_const().chunks(chunk_size)
    }
    /// Returns an iterator over `count` read-only partitions (see `ColRef::partition`).
    #[inline]
    #[track_caller]
    pub fn partition(self, count: usize) -> iter::ColElemPartition<'a, E> {
        self.into_const().partition(count)
    }
    /// Parallel variant of [`Self::chunks`].
    #[cfg(feature = "rayon")]
    #[cfg_attr(docsrs, doc(cfg(feature = "rayon")))]
    #[inline]
    #[track_caller]
    pub fn par_chunks(
        self,
        chunk_size: usize,
    ) -> impl 'a + rayon::iter::IndexedParallelIterator<Item = ColRef<'a, E>> {
        self.into_const().par_chunks(chunk_size)
    }
    /// Parallel variant of [`Self::partition`].
    #[cfg(feature = "rayon")]
    #[cfg_attr(docsrs, doc(cfg(feature = "rayon")))]
    #[inline]
    #[track_caller]
    pub fn par_partition(
        self,
        count: usize,
    ) -> impl 'a + rayon::iter::IndexedParallelIterator<Item = ColRef<'a, E>> {
        self.into_const().par_partition(count)
    }
    /// Returns an iterator over mutable chunks of `chunk_size` rows.
    ///
    /// # Panics
    ///
    /// Panics if `chunk_size == 0`.
    #[inline]
    #[track_caller]
    pub fn chunks_mut(self, chunk_size: usize) -> iter::ColElemChunksMut<'a, E> {
        assert!(chunk_size > 0);
        let nrows = self.nrows().unbound();
        iter::ColElemChunksMut {
            inner: self.as_dyn_mut(),
            policy: iter::chunks::ChunkSizePolicy::new(nrows, iter::chunks::ChunkSize(chunk_size)),
        }
    }
    /// Returns an iterator over `count` mutable partitions of the rows.
    ///
    /// # Panics
    ///
    /// Panics if `count == 0`.
    #[inline]
    #[track_caller]
    pub fn partition_mut(self, count: usize) -> iter::ColElemPartitionMut<'a, E> {
        assert!(count > 0);
        let nrows = self.nrows();
        iter::ColElemPartitionMut {
            inner: self.as_dyn_mut(),
            policy: iter::chunks::PartitionCountPolicy::new(
                nrows.unbound(),
                iter::chunks::PartitionCount(count),
            ),
        }
    }
    /// Parallel variant of [`Self::chunks_mut`]. It is built by casting the read-only parallel
    /// chunks back to mutable views.
    ///
    /// NOTE(review): soundness relies on `par_chunks` yielding disjoint chunks — confirm.
    #[cfg(feature = "rayon")]
    #[cfg_attr(docsrs, doc(cfg(feature = "rayon")))]
    #[inline]
    #[track_caller]
    pub fn par_chunks_mut(
        self,
        chunk_size: usize,
    ) -> impl 'a + rayon::iter::IndexedParallelIterator<Item = ColMut<'a, E>> {
        use rayon::prelude::*;
        self.into_const()
            .par_chunks(chunk_size)
            .map(|x| unsafe { x.const_cast() })
    }
    /// Parallel variant of [`Self::partition_mut`]. It is built by casting the read-only
    /// parallel partitions back to mutable views.
    ///
    /// NOTE(review): soundness relies on `par_partition` yielding disjoint parts — confirm.
    #[cfg(feature = "rayon")]
    #[cfg_attr(docsrs, doc(cfg(feature = "rayon")))]
    #[inline]
    #[track_caller]
    pub fn par_partition_mut(
        self,
        count: usize,
    ) -> impl 'a + rayon::iter::IndexedParallelIterator<Item = ColMut<'a, E>> {
        use rayon::prelude::*;
        self.into_const()
            .par_partition(count)
            .map(|x| unsafe { x.const_cast() })
    }
    /// Identity conversion. It presumably exists so code calling `const_cast` on views also
    /// type-checks for `ColMut` — confirm.
    #[doc(hidden)]
    #[inline(always)]
    pub unsafe fn const_cast(self) -> ColMut<'a, E, R> {
        self
    }
}
/// Creates a mutable column view from a base pointer group, a row count and a row stride.
///
/// `row_stride` is the distance, in `E::Unit` elements, between consecutive rows.
///
/// # Safety
///
/// This forwards to `ColMut::__from_raw_parts`, which wraps every pointer with
/// `NonNull::new_unchecked`, so each pointer must be non-null.
/// NOTE(review): callers presumably must also guarantee that every addressed row is valid
/// and exclusively borrowed for `'a` — confirm.
#[inline(always)]
pub unsafe fn from_raw_parts_mut<'a, E: Entity, R: Shape>(
    ptr: PtrMut<E>,
    nrows: R,
    row_stride: isize,
) -> ColMut<'a, E, R> {
    // SAFETY: the caller upholds this function's contract, which is the contract of
    // `__from_raw_parts`.
    unsafe { <ColMut<'a, E, R>>::__from_raw_parts(ptr, nrows, row_stride) }
}
/// Creates a contiguous (row stride `1`) mutable column view covering an entire slice group.
///
/// The row count is the length of the slice group.
#[inline(always)]
pub fn from_slice_mut_generic<E: Entity>(slice: SliceMut<'_, E>) -> ColMut<'_, E> {
    // Measure the length before `slice` is consumed by `faer_map` below.
    let len = SliceGroup::<'_, E>::new(E::faer_rb(E::faer_as_ref(&slice))).len();
    let base = E::faer_map(
        slice,
        #[inline(always)]
        |units| units.as_mut_ptr(),
    );
    // SAFETY: `base` comes from a live, exclusively borrowed slice group of `len` units, so
    // rows `0..len` at stride 1 are in bounds. Slice pointers are never null.
    unsafe { from_raw_parts_mut(base, len, 1) }
}
/// Creates a contiguous mutable column view over a slice of a `SimpleEntity` element type.
#[inline(always)]
pub fn from_slice_mut<E: SimpleEntity>(slice: &mut [E]) -> ColMut<'_, E> {
    from_slice_mut_generic::<E>(slice)
}
impl<E: Entity, R: Shape> As2D<E> for ColMut<'_, E, R> {
    /// Returns a read-only, shape-erased 2D matrix view of the column.
    #[inline]
    fn as_2d_ref(&self) -> MatRef<'_, E> {
        let col = self.rb();
        col.as_2d().as_dyn()
    }
}
impl<E: Entity, R: Shape> As2DMut<E> for ColMut<'_, E, R> {
    /// Returns a mutable, shape-erased 2D matrix view of the column.
    #[inline]
    fn as_2d_mut(&mut self) -> MatMut<'_, E> {
        // `col` is a by-value `ColMut`, so `.as_2d_mut()` resolves to the inherent method.
        let col = self.rb_mut();
        col.as_2d_mut().as_dyn_mut()
    }
}
impl<E: Entity, R: Shape> AsColRef<E> for ColMut<'_, E, R> {
type R = R;
#[inline]
fn as_col_ref(&self) -> ColRef<'_, E, R> {
(*self).rb()
}
}
impl<E: Entity, R: Shape> AsColMut<E> for ColMut<'_, E, R> {
#[inline]
fn as_col_mut(&mut self) -> ColMut<'_, E, R> {
(*self).rb_mut()
}
}
impl<E: Entity, R: Shape> core::fmt::Debug for ColMut<'_, E, R> {
    /// Formats the column exactly like its read-only `ColRef` view.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let view = self.rb();
        view.fmt(f)
    }
}
impl<E: SimpleEntity> core::ops::Index<usize> for ColMut<'_, E> {
    type Output = E;

    /// Returns a shared reference to the element at `row` through the read-only view's
    /// `get`. NOTE(review): the out-of-bounds behavior comes from `ColRef::get`; presumably
    /// it panics.
    #[inline]
    #[track_caller]
    fn index(&self, row: usize) -> &E {
        let view = self.rb();
        view.get(row)
    }
}
impl<E: SimpleEntity> core::ops::IndexMut<usize> for ColMut<'_, E> {
    /// Returns a mutable reference to the element at `row` through `ColMut::get_mut`.
    /// NOTE(review): the out-of-bounds behavior comes from the `ColIndex<usize>` impl;
    /// presumably it panics.
    #[inline]
    #[track_caller]
    fn index_mut(&mut self, row: usize) -> &mut E {
        let view = self.rb_mut();
        view.get_mut(row)
    }
}
impl<E: Conjugate> ColBatch<E> for ColMut<'_, E> {
    type Owned = Col<E::Canonical>;

    /// Allocates a zero-filled owned column with `nrows` rows.
    ///
    /// # Panics
    ///
    /// Panics if `ncols != 1`, because a column batch has exactly one column.
    #[inline]
    #[track_caller]
    fn new_owned_zeros(nrows: usize, ncols: usize) -> Self::Owned {
        assert!(ncols == 1);
        Self::Owned::zeros(nrows)
    }

    /// Copies `src` into a newly owned column with the canonical entity type.
    #[inline]
    fn new_owned_copied(src: &Self) -> Self::Owned {
        Self::to_owned(src)
    }

    /// Resizes `owned` in place by forwarding to the owned type's `resize_owned`.
    #[inline]
    #[track_caller]
    fn resize_owned(owned: &mut Self::Owned, nrows: usize, ncols: usize) {
        <Self::Owned>::resize_owned(owned, nrows, ncols)
    }
}
// Marks `ColMut` as a mutable column batch. The body is empty, so no items of
// `ColBatchMut` are overridden here.
impl<E: Conjugate> ColBatchMut<E> for ColMut<'_, E> {}
/// Creates a one-row mutable column view over a single `SimpleEntity` value.
pub fn from_mut<E: SimpleEntity>(value: &mut E) -> ColMut<'_, E> {
    from_mut_generic::<E>(value)
}
/// Creates a one-row mutable column view (row stride `1`) over a single value given as a
/// group of mutable unit references.
pub fn from_mut_generic<E: Entity>(value: Mut<'_, E>) -> ColMut<'_, E> {
    let base = E::faer_map(value, |unit| unit as *mut E::Unit);
    // SAFETY: every pointer comes from a live `&mut` reference, so it is non-null and valid
    // for the single row addressed by a one-row column.
    unsafe { from_raw_parts_mut(base, 1, 1) }
}