use crate::{assert, col::*, linalg::temp_mat_uninit, mat::*, row::*, *};
use dyn_stack::{PodStack, SizeOverflow, StackReq};
use reborrow::*;
#[track_caller]
#[inline]
pub fn swap_cols<E: ComplexField>(a: ColMut<'_, E>, b: ColMut<'_, E>) {
zipped_rw!(a, b).for_each(|unzipped!(mut a, mut b)| {
let (a_read, b_read) = (a.read(), b.read());
a.write(b_read);
b.write(a_read);
});
}
/// Swaps the contents of two rows element by element.
///
/// Both rows are viewed as columns via `transpose_mut`, and the work is
/// delegated to [`swap_cols`].
#[track_caller]
#[inline]
pub fn swap_rows<E: ComplexField>(a: RowMut<'_, E>, b: RowMut<'_, E>) {
    let a_as_col = a.transpose_mut();
    let b_as_col = b.transpose_mut();
    swap_cols(a_as_col, b_as_col);
}
/// Swaps rows `a` and `b` of `mat` in place.
///
/// Does nothing when `a == b`.
///
/// NOTE(review): out-of-range indices are presumably rejected by
/// `two_rows_mut` — confirm.
#[track_caller]
#[inline]
pub fn swap_rows_idx<E: ComplexField>(mat: MatMut<'_, E>, a: usize, b: usize) {
    if a == b {
        // Swapping a row with itself is a no-op; `two_rows_mut` is only asked
        // for two distinct rows.
        return;
    }
    let (row_a, row_b) = mat.two_rows_mut(a, b);
    swap_rows(row_a, row_b);
}
/// Swaps columns `a` and `b` of `mat` in place.
///
/// Does nothing when `a == b`.
///
/// NOTE(review): out-of-range indices are presumably rejected by
/// `two_cols_mut` — confirm.
#[track_caller]
#[inline]
pub fn swap_cols_idx<E: ComplexField>(mat: MatMut<'_, E>, a: usize, b: usize) {
    if a == b {
        // Swapping a column with itself is a no-op; `two_cols_mut` is only
        // asked for two distinct columns.
        return;
    }
    let (col_a, col_b) = mat.two_cols_mut(a, b);
    swap_cols(col_a, col_b);
}
mod permown;
mod permref;
pub use permown::Perm;
pub use permref::PermRef;
use self::linalg::temp_mat_req;
/// Copies `src` into `dst` with its columns permuted by `perm_indices`.
///
/// Column `j` of `dst` receives column `perm[j]` of `src`. Here `perm` is the
/// first index array of the canonicalized permutation, as used by
/// [`permute_rows`]. The work is done by transposing both matrices and
/// delegating to [`permute_rows`].
///
/// # Panics
///
/// Panics if `src` and `dst` have different shapes, or if the permutation's
/// length differs from the number of columns of `src`.
#[inline]
#[track_caller]
pub fn permute_cols<I: Index, E: ComplexField>(
    dst: MatMut<'_, E>,
    src: MatRef<'_, E>,
    perm_indices: PermRef<'_, I>,
) {
    // Checked up front, in column terms, before the views are transposed.
    // `perm_indices.len()` matches the spelling of the corresponding check in
    // `permute_rows`.
    assert!(all(
        src.nrows() == dst.nrows(),
        src.ncols() == dst.ncols(),
        perm_indices.len() == src.ncols(),
    ));
    permute_rows(
        dst.transpose_mut(),
        src.transpose(),
        perm_indices.canonicalized(),
    );
}
/// Copies `src` into `dst` with its rows permuted by `perm_indices`.
///
/// Row `i` of `dst` receives row `perm[i]` of `src`, where `perm` is the first
/// index array returned by `bound_arrays()` on the canonicalized permutation.
///
/// # Panics
///
/// Panics if `src` and `dst` have different shapes, or if the permutation's
/// length differs from the number of rows of `src`.
#[inline]
#[track_caller]
pub fn permute_rows<I: Index, E: ComplexField>(
    dst: MatMut<'_, E>,
    src: MatRef<'_, E>,
    perm_indices: PermRef<'_, I>,
) {
    // The body is instantiated with the index type produced by `canonicalized()`.
    #[track_caller]
    fn implementation<I: Index, E: ComplexField>(
        dst: MatMut<'_, E>,
        src: MatRef<'_, E>,
        perm_indices: PermRef<'_, I>,
    ) {
        assert!(all(
            src.nrows() == dst.nrows(),
            src.ncols() == dst.ncols(),
            perm_indices.len() == src.nrows(),
        ));
        with_dim!(m, src.nrows());
        with_dim!(n, src.ncols());
        let mut dst = dst.as_shape_mut(m, n);
        let src = src.as_shape(m, n);
        let perm = perm_indices.as_shape(m).bound_arrays().0;

        let row_stride = dst.rb().row_stride().unsigned_abs();
        let col_stride = dst.rb().col_stride().unsigned_abs();

        if row_stride >= col_stride {
            // Moving along a row is no farther in memory than moving down a
            // column, so copy whole rows at a time.
            for i in m.indices() {
                dst.rb_mut().row_mut(i).copy_from(src.row(perm[i].zx()));
            }
        } else {
            // Moving down a column is the shorter step in `dst`, so fill it
            // column by column, gathering one element at a time.
            for j in n.indices() {
                for i in m.indices() {
                    dst.rb_mut().write(i, j, src.read(perm[i].zx(), j));
                }
            }
        }
    }

    implementation(dst, src, perm_indices.canonicalized())
}
/// Returns the scratch-space requirement of [`permute_rows_in_place`] for a
/// matrix with `nrows` rows and `ncols` columns.
///
/// The requirement is that of a single temporary matrix of the same shape.
///
/// # Errors
///
/// Propagates the `SizeOverflow` reported by `temp_mat_req`.
pub fn permute_rows_in_place_req<I: Index, E: Entity>(
    nrows: usize,
    ncols: usize,
) -> Result<StackReq, SizeOverflow> {
    let req = temp_mat_req::<E>(nrows, ncols)?;
    Ok(req)
}
pub fn permute_cols_in_place_req<I: Index, E: Entity>(
nrows: usize,
ncols: usize,
) -> Result<StackReq, SizeOverflow> {
temp_mat_req::<E>(nrows, ncols)
}
/// Permutes the rows of `matrix` in place according to `perm_indices`.
///
/// A temporary copy of `matrix` is carved out of `stack`. The rows are then
/// gathered from that copy back into `matrix` with [`permute_rows`]. The space
/// needed is what `permute_rows_in_place_req::<I, E>(nrows, ncols)` reports.
///
/// # Panics
///
/// Panics if the permutation's length differs from `matrix.nrows()`, since
/// this is checked by [`permute_rows`].
///
/// NOTE(review): an undersized `stack` is presumably rejected by
/// `temp_mat_uninit` — confirm.
#[inline]
#[track_caller]
pub fn permute_rows_in_place<I: Index, E: ComplexField>(
    matrix: MatMut<'_, E>,
    perm_indices: PermRef<'_, I>,
    stack: &mut PodStack,
) {
    #[inline]
    #[track_caller]
    fn implementation<I: Index, E: ComplexField>(
        mut matrix: MatMut<'_, E>,
        perm_indices: PermRef<'_, I>,
        stack: &mut PodStack,
    ) {
        let (nrows, ncols) = (matrix.nrows(), matrix.ncols());
        let (mut scratch, _) = temp_mat_uninit::<E>(nrows, ncols, stack);
        // Snapshot the current contents so `permute_rows` reads from a buffer
        // that is separate from the destination.
        scratch.rb_mut().copy_from(matrix.rb());
        permute_rows(matrix.rb_mut(), scratch.rb(), perm_indices);
    }

    implementation(matrix, perm_indices.canonicalized(), stack)
}
/// Permutes the columns of `matrix` in place according to `perm_indices`.
///
/// A temporary copy of `matrix` is carved out of `stack`. The columns are then
/// gathered from that copy back into `matrix` with [`permute_cols`]. The space
/// needed is what `permute_cols_in_place_req::<I, E>(nrows, ncols)` reports.
///
/// # Panics
///
/// Panics if the permutation's length differs from `matrix.ncols()`, since
/// this is checked by [`permute_cols`].
///
/// NOTE(review): an undersized `stack` is presumably rejected by
/// `temp_mat_uninit` — confirm.
#[inline]
#[track_caller]
pub fn permute_cols_in_place<I: Index, E: ComplexField>(
    matrix: MatMut<'_, E>,
    perm_indices: PermRef<'_, I>,
    stack: &mut PodStack,
) {
    #[inline]
    #[track_caller]
    fn implementation<I: Index, E: ComplexField>(
        mut matrix: MatMut<'_, E>,
        perm_indices: PermRef<'_, I>,
        stack: &mut PodStack,
    ) {
        let (nrows, ncols) = (matrix.nrows(), matrix.ncols());
        let (mut scratch, _) = temp_mat_uninit::<E>(nrows, ncols, stack);
        // Snapshot the current contents so `permute_cols` reads from a buffer
        // that is separate from the destination.
        scratch.rb_mut().copy_from(matrix.rb());
        permute_cols(matrix.rb_mut(), scratch.rb(), perm_indices);
    }

    implementation(matrix, perm_indices.canonicalized(), stack)
}