#![allow(clippy::len_without_is_empty)]
use crate::{
constrained,
inner::{PermMut, PermOwn, PermRef},
seal::Seal,
temp_mat_req, temp_mat_uninit, zipped, ComplexField, Entity, MatMut, MatRef, Matrix,
};
#[cfg(feature = "std")]
use assert2::{assert, debug_assert};
use bytemuck::Pod;
use core::fmt::Debug;
use dyn_stack::{PodStack, SizeOverflow, StackReq};
use reborrow::*;
impl Seal for i32 {}
impl Seal for i64 {}
impl Seal for i128 {}
impl Seal for isize {}
impl Seal for u32 {}
impl Seal for u64 {}
impl Seal for u128 {}
impl Seal for usize {}
pub trait Index:
Seal
+ core::fmt::Debug
+ core::ops::Not<Output = Self>
+ core::ops::Add<Output = Self>
+ core::ops::Sub<Output = Self>
+ core::ops::AddAssign
+ core::ops::SubAssign
+ Pod
+ Eq
+ Ord
+ Send
+ Sync
{
type FixedWidth: Index;
type Signed: SignedIndex;
#[must_use]
#[inline(always)]
fn truncate(value: usize) -> Self {
Self::from_signed(<Self::Signed as SignedIndex>::truncate(value))
}
#[must_use]
#[inline(always)]
fn zx(self) -> usize {
self.to_signed().zx()
}
#[inline(always)]
fn canonicalize(slice: &[Self]) -> &[Self::FixedWidth] {
bytemuck::cast_slice(slice)
}
#[inline(always)]
fn canonicalize_mut(slice: &mut [Self]) -> &mut [Self::FixedWidth] {
bytemuck::cast_slice_mut(slice)
}
#[inline(always)]
fn from_signed(value: Self::Signed) -> Self {
pulp::cast(value)
}
#[inline(always)]
fn to_signed(self) -> Self::Signed {
pulp::cast(self)
}
#[inline]
fn sum_nonnegative(slice: &[Self]) -> Option<Self> {
Self::Signed::sum_nonnegative(bytemuck::cast_slice(slice)).map(Self::from_signed)
}
}
/// Signed integer type paired with an [`Index`] type.
pub trait SignedIndex:
    Seal
    + core::fmt::Debug
    + core::ops::Neg<Output = Self>
    + core::ops::Add<Output = Self>
    + core::ops::Sub<Output = Self>
    + core::ops::AddAssign
    + core::ops::SubAssign
    + Pod
    + Eq
    + Ord
    + Send
    + Sync
{
    /// Largest value representable by this type.
    const MAX: Self;

    /// Converts a `usize` into `Self`.
    #[must_use]
    fn truncate(value: usize) -> Self;

    /// Converts `self` to `usize` (the implementations in this file
    /// zero-extend).
    #[must_use]
    fn zx(self) -> usize;

    /// Converts `self` to `usize` (the implementations in this file
    /// sign-extend).
    #[must_use]
    fn sx(self) -> usize;

    /// Adds up the elements of `slice`, returning `None` as soon as the
    /// running total would exceed [`SignedIndex::MAX`].
    ///
    /// The overflow test computes `MAX - i`, so it is only meaningful when
    /// every element is nonnegative, as the name indicates.
    fn sum_nonnegative(slice: &[Self]) -> Option<Self> {
        slice.iter().try_fold(Self::zeroed(), |mut acc, &i| {
            if Self::MAX - i < acc {
                return None;
            }
            acc += i;
            Some(acc)
        })
    }
}
/// `u32` indices, paired with `i32`; compiled for 32-, 64- and 128-bit
/// pointer widths.
#[cfg(any(
    target_pointer_width = "32",
    target_pointer_width = "64",
    target_pointer_width = "128",
))]
impl Index for u32 {
    type Signed = i32;
    type FixedWidth = Self;
}
/// `u64` indices, paired with `i64`; compiled for 64- and 128-bit pointer
/// widths.
#[cfg(any(target_pointer_width = "64", target_pointer_width = "128"))]
impl Index for u64 {
    type Signed = i64;
    type FixedWidth = Self;
}
/// `u128` indices, paired with `i128`; compiled only for 128-bit pointer
/// widths.
#[cfg(target_pointer_width = "128")]
impl Index for u128 {
    type Signed = i128;
    type FixedWidth = Self;
}
/// `usize` indices, paired with `isize`.
///
/// `FixedWidth` is the fixed-size unsigned integer whose width equals the
/// target's pointer width.
impl Index for usize {
    type Signed = isize;

    #[cfg(target_pointer_width = "32")]
    type FixedWidth = u32;
    #[cfg(target_pointer_width = "64")]
    type FixedWidth = u64;
    #[cfg(target_pointer_width = "128")]
    type FixedWidth = u128;
}
/// `i32` as a signed index; compiled for 32-, 64- and 128-bit pointer widths.
#[cfg(any(
    target_pointer_width = "32",
    target_pointer_width = "64",
    target_pointer_width = "128",
))]
impl SignedIndex for i32 {
    const MAX: Self = i32::MAX;

    /// Keeps the low 32 bits of `value`.
    #[inline(always)]
    fn truncate(value: usize) -> Self {
        // Compile-time guard: `i32` must not be wider than `usize`.
        #[allow(clippy::assertions_on_constants)]
        const _: () = core::assert!(i32::BITS <= usize::BITS);
        (value as isize) as i32
    }

    /// Zero-extends: the bits are reinterpreted as `u32` before widening.
    #[inline(always)]
    fn zx(self) -> usize {
        (self as u32) as usize
    }

    /// Sign-extends: the value is widened to `isize` before the final cast.
    #[inline(always)]
    fn sx(self) -> usize {
        (self as isize) as usize
    }
}
/// `i64` as a signed index; compiled for 64- and 128-bit pointer widths.
#[cfg(any(target_pointer_width = "64", target_pointer_width = "128"))]
impl SignedIndex for i64 {
    const MAX: Self = i64::MAX;

    /// Keeps the low 64 bits of `value`.
    #[inline(always)]
    fn truncate(value: usize) -> Self {
        // Compile-time guard: `i64` must not be wider than `usize`.
        #[allow(clippy::assertions_on_constants)]
        const _: () = core::assert!(i64::BITS <= usize::BITS);
        (value as isize) as i64
    }

    /// Zero-extends: the bits are reinterpreted as `u64` before widening.
    #[inline(always)]
    fn zx(self) -> usize {
        (self as u64) as usize
    }

    /// Sign-extends: the value is widened to `isize` before the final cast.
    #[inline(always)]
    fn sx(self) -> usize {
        (self as isize) as usize
    }
}
/// `i128` as a signed index; compiled only for 128-bit pointer widths.
#[cfg(target_pointer_width = "128")]
impl SignedIndex for i128 {
    const MAX: Self = i128::MAX;

    /// Casts `value` through `isize` into `i128`.
    #[inline(always)]
    fn truncate(value: usize) -> Self {
        // Compile-time guard: `i128` must not be wider than `usize`.
        #[allow(clippy::assertions_on_constants)]
        const _: () = core::assert!(i128::BITS <= usize::BITS);
        (value as isize) as i128
    }

    /// Zero-extends: the bits are reinterpreted as `u128` before the final
    /// cast.
    #[inline(always)]
    fn zx(self) -> usize {
        (self as u128) as usize
    }

    /// Sign-extends: the value goes through `isize` before the final cast.
    #[inline(always)]
    fn sx(self) -> usize {
        (self as isize) as usize
    }
}
/// `isize` as a signed index.
///
/// `isize` and `usize` have the same width, so every conversion here is a
/// plain bit-reinterpreting cast and `zx` / `sx` coincide.
impl SignedIndex for isize {
    const MAX: Self = isize::MAX;

    /// Reinterprets `value` as `isize`.
    #[inline(always)]
    fn truncate(value: usize) -> Self {
        value as Self
    }

    /// Reinterprets `self` as `usize`.
    #[inline(always)]
    fn zx(self) -> usize {
        self as usize
    }

    /// Reinterprets `self` as `usize`.
    #[inline(always)]
    fn sx(self) -> usize {
        self as usize
    }
}
/// Swaps columns `a` and `b` of `mat` in place.
///
/// # Panics
///
/// Panics if `a` or `b` is not smaller than `mat.ncols()`.
#[track_caller]
#[inline]
pub fn swap_cols<E: ComplexField>(mat: MatMut<'_, E>, a: usize, b: usize) {
    assert!(a < mat.ncols());
    assert!(b < mat.ncols());
    // Swapping a column with itself is a no-op.
    if a != b {
        let view = mat.into_const();
        let (col_a, col_b) = (view.col(a), view.col(b));
        // NOTE(review): both columns come from a matrix that was received as
        // `MatMut`, and `a != b` at this point; the `const_cast` presumably
        // relies on those two facts to hand out two mutable column views
        // without aliasing. Confirm against `const_cast`'s safety contract.
        unsafe { zipped!(col_a.const_cast(), col_b.const_cast()) }.for_each(|mut lhs, mut rhs| {
            let (old_lhs, old_rhs) = (lhs.read(), rhs.read());
            lhs.write(old_rhs);
            rhs.write(old_lhs);
        });
    }
}
/// Swaps rows `a` and `b` of `mat` in place.
///
/// Implemented as a column swap on the transposed view, so it panics under
/// the same conditions as [`swap_cols`] applied to that view.
#[track_caller]
#[inline]
pub fn swap_rows<E: ComplexField>(mat: MatMut<'_, E>, a: usize, b: usize) {
    let transposed = mat.transpose();
    swap_cols(transposed, a, b);
}
/// Immutable permutation view: a [`Matrix`] wrapping a [`PermRef`], which
/// borrows a `forward` and an `inverse` index slice.
pub type PermutationRef<'a, I, E> = Matrix<PermRef<'a, I, E>>;
/// Mutable permutation view: a [`Matrix`] wrapping a [`PermMut`], which
/// mutably borrows a `forward` and an `inverse` index slice.
pub type PermutationMut<'a, I, E> = Matrix<PermMut<'a, I, E>>;
/// Owning permutation: a [`Matrix`] wrapping a [`PermOwn`], which owns the
/// `forward` and `inverse` index arrays.
pub type Permutation<I, E> = Matrix<PermOwn<I, E>>;
impl<I, E: Entity> Permutation<I, E> {
    /// Borrows the owned index arrays as an immutable permutation view.
    #[inline]
    pub fn as_ref(&self) -> PermutationRef<'_, I, E> {
        let forward: &[I] = &self.inner.forward;
        let inverse: &[I] = &self.inner.inverse;
        PermutationRef {
            inner: PermRef {
                forward,
                inverse,
                __marker: core::marker::PhantomData,
            },
        }
    }

    /// Borrows the owned index arrays as a mutable permutation view.
    #[inline]
    pub fn as_mut(&mut self) -> PermutationMut<'_, I, E> {
        let forward: &mut [I] = &mut self.inner.forward;
        let inverse: &mut [I] = &mut self.inner.inverse;
        PermutationMut {
            inner: PermMut {
                forward,
                inverse,
                __marker: core::marker::PhantomData,
            },
        }
    }
}
impl<I: Index, E: Entity> Permutation<I, E> {
    /// Builds an owning permutation from its forward and inverse index
    /// arrays, after validating them with [`PermutationRef::new_checked`].
    ///
    /// # Panics
    ///
    /// Panics whenever [`PermutationRef::new_checked`] panics on the same
    /// arrays.
    #[inline]
    #[track_caller]
    pub fn new_checked(forward: Box<[I]>, inverse: Box<[I]>) -> Self {
        // Called only for its validation; the returned view is discarded.
        PermutationRef::<'_, I, E>::new_checked(&forward, &inverse);
        let inner = PermOwn {
            forward,
            inverse,
            __marker: core::marker::PhantomData,
        };
        Self { inner }
    }

    /// Builds an owning permutation without the per-element validation done
    /// by [`Permutation::new_checked`].
    ///
    /// # Safety
    ///
    /// Only the array lengths are checked here. The caller is presumably
    /// expected to guarantee what `new_checked` would otherwise verify: that
    /// `forward` and `inverse` describe mutually inverse permutations.
    ///
    /// # Panics
    ///
    /// Panics if the arrays differ in length, or if the length exceeds
    /// `I::Signed::MAX`.
    #[inline]
    #[track_caller]
    pub unsafe fn new_unchecked(forward: Box<[I]>, inverse: Box<[I]>) -> Self {
        let n = forward.len();
        assert!(forward.len() == inverse.len());
        assert!(n <= I::Signed::MAX.zx());
        let inner = PermOwn {
            forward,
            inverse,
            __marker: core::marker::PhantomData,
        };
        Self { inner }
    }

    /// Consumes the permutation and returns `(forward, inverse)`.
    #[inline]
    pub fn into_arrays(self) -> (Box<[I]>, Box<[I]>) {
        let PermOwn {
            forward, inverse, ..
        } = self.inner;
        (forward, inverse)
    }

    /// Number of indices in the permutation (the length of the forward
    /// array).
    #[inline]
    pub fn len(&self) -> usize {
        debug_assert!(self.inner.inverse.len() == self.inner.forward.len());
        self.inner.forward.len()
    }

    /// Returns the inverse permutation by exchanging the forward and inverse
    /// arrays.
    #[inline]
    pub fn inverse(self) -> Self {
        let (forward, inverse) = self.into_arrays();
        Self {
            inner: PermOwn {
                forward: inverse,
                inverse: forward,
                __marker: core::marker::PhantomData,
            },
        }
    }

    /// Re-tags the permutation with a different entity type `T`; the index
    /// arrays are moved over unchanged.
    #[inline]
    pub fn cast<T: Entity>(self) -> Permutation<I, T> {
        let (forward, inverse) = self.into_arrays();
        Permutation {
            inner: PermOwn {
                forward,
                inverse,
                __marker: core::marker::PhantomData,
            },
        }
    }
}
impl<'a, I: Index, E: Entity> PermutationRef<'a, I, E> {
    /// Builds a permutation view from its forward and inverse index slices,
    /// validating them first.
    ///
    /// # Panics
    ///
    /// Panics if the slices differ in length, if the length exceeds the
    /// maximum of the signed counterpart of the canonical index type, if any
    /// forward index is out of bounds, or if `inverse[forward[i]] != i` for
    /// some `i`.
    #[inline]
    #[track_caller]
    pub fn new_checked(forward: &'a [I], inverse: &'a [I]) -> Self {
        // Runs on the canonical (fixed-width) index type.
        #[track_caller]
        fn check<I: Index>(forward: &[I], inverse: &[I]) {
            let n = forward.len();
            assert!(forward.len() == inverse.len());
            assert!(n <= I::Signed::MAX.zx());
            for (i, p) in forward.iter().copied().enumerate() {
                let p = p.zx();
                assert!(p < n);
                assert!(inverse[p].to_signed().zx() == i);
            }
        }

        let canonical_forward = I::canonicalize(forward);
        let canonical_inverse = I::canonicalize(inverse);
        check(canonical_forward, canonical_inverse);

        let inner = PermRef {
            forward,
            inverse,
            __marker: core::marker::PhantomData,
        };
        Self { inner }
    }

    /// Builds a permutation view without the per-element validation done by
    /// [`PermutationRef::new_checked`].
    ///
    /// # Safety
    ///
    /// Only the slice lengths are checked here. The caller is presumably
    /// expected to guarantee what `new_checked` would otherwise verify: that
    /// `forward` and `inverse` describe mutually inverse permutations.
    ///
    /// # Panics
    ///
    /// Panics if the slices differ in length, or if the length exceeds
    /// `I::Signed::MAX`.
    #[inline]
    #[track_caller]
    pub unsafe fn new_unchecked(forward: &'a [I], inverse: &'a [I]) -> Self {
        let n = forward.len();
        assert!(forward.len() == inverse.len());
        assert!(n <= I::Signed::MAX.zx());
        let inner = PermRef {
            forward,
            inverse,
            __marker: core::marker::PhantomData,
        };
        Self { inner }
    }

    /// Returns the borrowed `(forward, inverse)` slices.
    #[inline]
    pub fn into_arrays(self) -> (&'a [I], &'a [I]) {
        let PermRef {
            forward, inverse, ..
        } = self.inner;
        (forward, inverse)
    }

    /// Number of indices in the permutation (the length of the forward
    /// slice).
    #[inline]
    pub fn len(&self) -> usize {
        debug_assert!(self.inner.inverse.len() == self.inner.forward.len());
        self.inner.forward.len()
    }

    /// Returns the inverse permutation by exchanging the forward and inverse
    /// slices.
    #[inline]
    pub fn inverse(self) -> Self {
        let (forward, inverse) = self.into_arrays();
        Self {
            inner: PermRef {
                forward: inverse,
                inverse: forward,
                __marker: core::marker::PhantomData,
            },
        }
    }

    /// Re-tags the view with a different entity type `T`; the index slices
    /// are unchanged.
    #[inline]
    pub fn cast<T: Entity>(self) -> PermutationRef<'a, I, T> {
        let (forward, inverse) = self.into_arrays();
        PermutationRef {
            inner: PermRef {
                forward,
                inverse,
                __marker: core::marker::PhantomData,
            },
        }
    }

    /// Reinterprets the index slices as `I::FixedWidth` via
    /// [`Index::canonicalize`].
    #[inline(always)]
    pub fn canonicalize(self) -> PermutationRef<'a, I::FixedWidth, E> {
        let (forward, inverse) = self.into_arrays();
        PermutationRef {
            inner: PermRef {
                forward: I::canonicalize(forward),
                inverse: I::canonicalize(inverse),
                __marker: core::marker::PhantomData,
            },
        }
    }

    /// Reinterprets the index slices as another index type `J` of the same
    /// size.
    ///
    /// # Panics
    ///
    /// Panics if `J` and `I` differ in size.
    #[inline(always)]
    pub fn uncanonicalize<J: Index>(self) -> PermutationRef<'a, J, E> {
        assert!(core::mem::size_of::<J>() == core::mem::size_of::<I>());
        let (forward, inverse) = self.into_arrays();
        PermutationRef {
            inner: PermRef {
                forward: bytemuck::cast_slice(forward),
                inverse: bytemuck::cast_slice(inverse),
                __marker: core::marker::PhantomData,
            },
        }
    }
}
impl<'a, I: Index, E: Entity> PermutationMut<'a, I, E> {
    /// Builds a mutable permutation view from its forward and inverse index
    /// slices, after validating them with [`PermutationRef::new_checked`].
    ///
    /// # Panics
    ///
    /// Panics whenever [`PermutationRef::new_checked`] panics on the same
    /// slices.
    #[inline]
    #[track_caller]
    pub fn new_checked(forward: &'a mut [I], inverse: &'a mut [I]) -> Self {
        // Called only for its validation; the temporary shared view is
        // discarded before the mutable view is assembled.
        PermutationRef::<'_, I, E>::new_checked(forward, inverse);
        let inner = PermMut {
            forward,
            inverse,
            __marker: core::marker::PhantomData,
        };
        Self { inner }
    }

    /// Builds a mutable permutation view without the per-element validation
    /// done by [`PermutationMut::new_checked`].
    ///
    /// # Safety
    ///
    /// Only the slice lengths are checked here. The caller is presumably
    /// expected to guarantee what `new_checked` would otherwise verify: that
    /// `forward` and `inverse` describe mutually inverse permutations.
    ///
    /// # Panics
    ///
    /// Panics if the slices differ in length, or if the length exceeds
    /// `I::Signed::MAX`.
    #[inline]
    #[track_caller]
    pub unsafe fn new_unchecked(forward: &'a mut [I], inverse: &'a mut [I]) -> Self {
        let n = forward.len();
        assert!(forward.len() == inverse.len());
        assert!(n <= I::Signed::MAX.zx());
        let inner = PermMut {
            forward,
            inverse,
            __marker: core::marker::PhantomData,
        };
        Self { inner }
    }

    /// Returns the mutably borrowed `(forward, inverse)` slices.
    ///
    /// # Safety
    ///
    /// NOTE(review): presumably `unsafe` because writing through the
    /// returned slices can break the forward/inverse relationship that
    /// `new_checked` validates. Confirm the intended contract.
    #[inline]
    pub unsafe fn into_arrays(self) -> (&'a mut [I], &'a mut [I]) {
        let PermMut {
            forward, inverse, ..
        } = self.inner;
        (forward, inverse)
    }

    /// Number of indices in the permutation (the length of the forward
    /// slice).
    #[inline]
    pub fn len(&self) -> usize {
        debug_assert!(self.inner.inverse.len() == self.inner.forward.len());
        self.inner.forward.len()
    }

    /// Returns the inverse permutation by exchanging the forward and inverse
    /// slices.
    #[inline]
    pub fn inverse(self) -> Self {
        let PermMut {
            forward, inverse, ..
        } = self.inner;
        Self {
            inner: PermMut {
                forward: inverse,
                inverse: forward,
                __marker: core::marker::PhantomData,
            },
        }
    }

    /// Re-tags the view with a different entity type `T`; the index slices
    /// are unchanged.
    #[inline]
    pub fn cast<T: Entity>(self) -> PermutationMut<'a, I, T> {
        let PermMut {
            forward, inverse, ..
        } = self.inner;
        PermutationMut {
            inner: PermMut {
                forward,
                inverse,
                __marker: core::marker::PhantomData,
            },
        }
    }

    /// Reinterprets the index slices as `I::FixedWidth` via
    /// [`Index::canonicalize_mut`].
    #[inline(always)]
    pub fn canonicalize(self) -> PermutationMut<'a, I::FixedWidth, E> {
        let PermMut {
            forward, inverse, ..
        } = self.inner;
        PermutationMut {
            inner: PermMut {
                forward: I::canonicalize_mut(forward),
                inverse: I::canonicalize_mut(inverse),
                __marker: core::marker::PhantomData,
            },
        }
    }

    /// Reinterprets the index slices as another index type `J` of the same
    /// size.
    ///
    /// # Panics
    ///
    /// Panics if `J` and `I` differ in size.
    #[inline(always)]
    pub fn uncanonicalize<J: Index>(self) -> PermutationMut<'a, J, E> {
        assert!(core::mem::size_of::<J>() == core::mem::size_of::<I>());
        let PermMut {
            forward, inverse, ..
        } = self.inner;
        PermutationMut {
            inner: PermMut {
                forward: bytemuck::cast_slice_mut(forward),
                inverse: bytemuck::cast_slice_mut(inverse),
                __marker: core::marker::PhantomData,
            },
        }
    }
}
/// Reborrowing an immutable permutation view yields another immutable view
/// with the shorter lifetime `'short`.
impl<'short, 'a, I, E: Entity> Reborrow<'short> for PermutationRef<'a, I, E> {
type Target = PermutationRef<'short, I, E>;
#[inline]
fn rb(&'short self) -> Self::Target {
// The view is returned by value out of a shared reference, i.e. copied.
*self
}
}
/// Mutably reborrowing an immutable permutation view still yields an
/// immutable view, with the shorter lifetime `'short`.
impl<'short, 'a, I, E: Entity> ReborrowMut<'short> for PermutationRef<'a, I, E> {
type Target = PermutationRef<'short, I, E>;
#[inline]
fn rb_mut(&'short mut self) -> Self::Target {
// The view is returned by value out of the reference, i.e. copied.
*self
}
}
/// Reborrowing a mutable permutation view through a shared reference yields
/// an immutable view over the same index slices.
impl<'short, 'a, I, E: Entity> Reborrow<'short> for PermutationMut<'a, I, E> {
    type Target = PermutationRef<'short, I, E>;

    #[inline]
    fn rb(&'short self) -> Self::Target {
        // Shared reborrows of the two mutably held slices.
        let forward: &'short [I] = &*self.inner.forward;
        let inverse: &'short [I] = &*self.inner.inverse;
        PermutationRef {
            inner: PermRef {
                forward,
                inverse,
                __marker: core::marker::PhantomData,
            },
        }
    }
}
/// Mutably reborrowing a mutable permutation view yields a mutable view with
/// the shorter lifetime `'short`.
impl<'short, 'a, I, E: Entity> ReborrowMut<'short> for PermutationMut<'a, I, E> {
    type Target = PermutationMut<'short, I, E>;

    #[inline]
    fn rb_mut(&'short mut self) -> Self::Target {
        // Mutable reborrows of the two slices (disjoint fields).
        let forward: &'short mut [I] = &mut *self.inner.forward;
        let inverse: &'short mut [I] = &mut *self.inner.inverse;
        PermutationMut {
            inner: PermMut {
                forward,
                inverse,
                __marker: core::marker::PhantomData,
            },
        }
    }
}
impl<'a, I: Debug, E: Entity> Debug for PermutationRef<'a, I, E> {
#[inline]
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
self.inner.fmt(f)
}
}
impl<'a, I: Debug, E: Entity> Debug for PermutationMut<'a, I, E> {
#[inline]
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
self.rb().fmt(f)
}
}
/// Formats the owning permutation exactly like its immutable view
/// ([`Permutation::as_ref`]).
// The impl previously declared a lifetime parameter `'a` that appeared nowhere
// in the implemented type or the trait; it has been dropped
// (clippy: `extra_unused_lifetimes`).
impl<I: Debug, E: Entity> Debug for Permutation<I, E> {
    #[inline]
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        self.as_ref().fmt(f)
    }
}
/// Writes a column-permuted copy of `src` into `dst`: column `j` of `dst`
/// receives column `perm_indices.into_arrays().0[j]` (the forward index) of
/// `src`.
///
/// Implemented as [`permute_rows`] on the transposed views.
///
/// # Panics
///
/// Panics if `src` and `dst` have different dimensions, or if the
/// permutation length differs from `src.ncols()`.
#[inline]
#[track_caller]
pub fn permute_cols<I: Index, E: ComplexField>(
    dst: MatMut<'_, E>,
    src: MatRef<'_, E>,
    perm_indices: PermutationRef<'_, I, E>,
) {
    assert!((src.nrows(), src.ncols()) == (dst.nrows(), dst.ncols()));
    assert!(perm_indices.into_arrays().0.len() == src.ncols());
    let (dst_t, src_t) = (dst.transpose(), src.transpose());
    let canonical = perm_indices.canonicalize();
    permute_rows(dst_t, src_t, canonical);
}
/// Writes a row-permuted copy of `src` into `dst`: row `i` of `dst` receives
/// row `perm_indices.into_arrays().0[i]` (the forward index) of `src`.
///
/// # Panics
///
/// Panics if `src` and `dst` have different dimensions, or if the
/// permutation length differs from `src.nrows()`.
#[inline]
#[track_caller]
pub fn permute_rows<I: Index, E: ComplexField>(
    dst: MatMut<'_, E>,
    src: MatRef<'_, E>,
    perm_indices: PermutationRef<'_, I, E>,
) {
    // Instantiated only for the canonical (fixed-width) index type.
    #[track_caller]
    fn implementation<I: Index, E: ComplexField>(
        dst: MatMut<'_, E>,
        src: MatRef<'_, E>,
        perm_indices: PermutationRef<'_, I, E>,
    ) {
        assert!((src.nrows(), src.ncols()) == (dst.nrows(), dst.ncols()));
        assert!(perm_indices.into_arrays().0.len() == src.nrows());

        constrained::Size::with2(src.nrows(), src.ncols(), |rows, cols| {
            let mut dst = constrained::MatMut::new(dst, rows, cols);
            let src = constrained::MatRef::new(src, rows, cols);
            // Only the forward indices are needed.
            let (forward, _) =
                constrained::permutation::PermutationRef::new(perm_indices, rows).into_arrays();

            let row_stride = dst.rb().into_inner().row_stride().unsigned_abs();
            let col_stride = dst.rb().into_inner().col_stride().unsigned_abs();

            if row_stride >= col_stride {
                // Copy one whole row at a time.
                for i in rows.indices() {
                    let src_row = src.into_inner().row(forward[i].zx().into_inner());
                    let mut dst_row = dst.rb_mut().into_inner().row(i.into_inner());
                    dst_row.clone_from(src_row);
                }
            } else {
                // Destination row stride is the smaller one: walk each column
                // top to bottom, element by element.
                for j in cols.indices() {
                    for i in rows.indices() {
                        dst.rb_mut().write(i, j, src.read(forward[i].zx(), j));
                    }
                }
            }
        });
    }

    let canonical = perm_indices.canonicalize();
    implementation(dst, src, canonical)
}
/// Workspace requirement for [`permute_rows_in_place`] on an
/// `nrows`×`ncols` matrix.
///
/// Forwards to `temp_mat_req::<E>(nrows, ncols)`, i.e. the requirement of one
/// temporary matrix of the same dimensions; any error it reports is returned
/// unchanged. The index type `I` does not influence the result.
pub fn permute_rows_in_place_req<I: Index, E: Entity>(
nrows: usize,
ncols: usize,
) -> Result<StackReq, SizeOverflow> {
temp_mat_req::<E>(nrows, ncols)
}
/// Workspace requirement for [`permute_cols_in_place`] on an
/// `nrows`×`ncols` matrix.
///
/// Forwards to `temp_mat_req::<E>(nrows, ncols)`, i.e. the requirement of one
/// temporary matrix of the same dimensions; any error it reports is returned
/// unchanged. The index type `I` does not influence the result.
pub fn permute_cols_in_place_req<I: Index, E: Entity>(
nrows: usize,
ncols: usize,
) -> Result<StackReq, SizeOverflow> {
temp_mat_req::<E>(nrows, ncols)
}
/// Permutes the rows of `matrix` in place.
///
/// The matrix is first copied into a temporary taken from `stack`, then
/// [`permute_rows`] writes the permuted rows back into `matrix`; see that
/// function for the index convention and panic conditions. The workspace
/// size is given by [`permute_rows_in_place_req`].
#[inline]
#[track_caller]
pub fn permute_rows_in_place<I: Index, E: ComplexField>(
    matrix: MatMut<'_, E>,
    perm_indices: PermutationRef<'_, I, E>,
    stack: PodStack<'_>,
) {
    // Instantiated only for the canonical (fixed-width) index type.
    #[inline]
    #[track_caller]
    fn implementation<E: ComplexField, I: Index>(
        mut matrix: MatMut<'_, E>,
        perm_indices: PermutationRef<'_, I, E>,
        stack: PodStack<'_>,
    ) {
        let (nrows, ncols) = (matrix.nrows(), matrix.ncols());
        let (mut scratch, _) = temp_mat_uninit::<E>(nrows, ncols, stack);
        // Snapshot the input so the permuted write-back reads unmodified data.
        scratch.rb_mut().clone_from(matrix.rb());
        permute_rows(matrix.rb_mut(), scratch.rb(), perm_indices);
    }

    let canonical = perm_indices.canonicalize();
    implementation(matrix, canonical, stack)
}
/// Permutes the columns of `matrix` in place.
///
/// The matrix is first copied into a temporary taken from `stack`, then
/// [`permute_cols`] writes the permuted columns back into `matrix`; see that
/// function for the index convention and panic conditions. The workspace
/// size is given by [`permute_cols_in_place_req`].
#[inline]
#[track_caller]
pub fn permute_cols_in_place<I: Index, E: ComplexField>(
    matrix: MatMut<'_, E>,
    perm_indices: PermutationRef<'_, I, E>,
    stack: PodStack<'_>,
) {
    // Instantiated only for the canonical (fixed-width) index type.
    #[inline]
    #[track_caller]
    fn implementation<I: Index, E: ComplexField>(
        mut matrix: MatMut<'_, E>,
        perm_indices: PermutationRef<'_, I, E>,
        stack: PodStack<'_>,
    ) {
        let (nrows, ncols) = (matrix.nrows(), matrix.ncols());
        let (mut scratch, _) = temp_mat_uninit::<E>(nrows, ncols, stack);
        // Snapshot the input so the permuted write-back reads unmodified data.
        scratch.rb_mut().clone_from(matrix.rb());
        permute_cols(matrix.rb_mut(), scratch.rb(), perm_indices);
    }

    let canonical = perm_indices.canonicalize();
    implementation(matrix, canonical, stack)
}