use super::*;
use crate::{assert, debug_assert, linalg::zip::Diag, utils::thread::join_raw};
/// How the diagonal of a triangular operand is interpreted when it is materialized.
#[repr(u8)]
#[derive(Copy, Clone, Debug)]
pub(crate) enum DiagonalKind {
/// Every diagonal entry is taken to be zero; the stored diagonal is not read.
Zero,
/// Every diagonal entry is taken to be one; the stored diagonal is not read.
Unit,
/// The diagonal entries are read from the matrix itself.
Generic,
}
/// Materializes the lower triangular matrix described by `src` and `src_diag` into `dst`.
///
/// For every column `j` of `dst`:
/// - entries strictly above the diagonal are set to zero,
/// - the diagonal entry is zero, one, or `src[(j, j)]` depending on `src_diag`,
/// - entries strictly below the diagonal are copied from `src`.
///
/// The strictly upper part of `src` is never read.
///
/// # Safety
///
/// `dst` and `src` must both be `n × n`. The dimensions are only verified by debug
/// assertions and every element access is unchecked.
unsafe fn copy_lower<E: ComplexField>(
    mut dst: MatMut<'_, E>,
    src: MatRef<'_, E>,
    src_diag: DiagonalKind,
) {
    let n = dst.nrows();
    debug_assert!(n == dst.ncols());
    debug_assert!(n == src.nrows());
    debug_assert!(n == src.ncols());

    for j in 0..n {
        // Strictly upper part of column `j`.
        for i in 0..j {
            dst.write_unchecked(i, j, E::faer_zero());
        }

        // Diagonal entry. The unchecked read matches every other access in this kernel;
        // the dimensions were validated above.
        let diag = match src_diag {
            DiagonalKind::Zero => E::faer_zero(),
            DiagonalKind::Unit => E::faer_one(),
            DiagonalKind::Generic => src.read_unchecked(j, j),
        };
        dst.write_unchecked(j, j, diag);

        // Strictly lower part of column `j`.
        for i in j + 1..n {
            dst.write_unchecked(i, j, src.read_unchecked(i, j));
        }
    }
}
unsafe fn accum_lower<E: ComplexField>(
dst: MatMut<'_, E>,
src: MatRef<'_, E>,
skip_diag: bool,
alpha: Option<E>,
) {
let n = dst.nrows();
debug_assert!(n == dst.nrows());
debug_assert!(n == dst.ncols());
debug_assert!(n == src.nrows());
debug_assert!(n == src.ncols());
match alpha {
Some(alpha) => {
zipped_rw!(dst, src).for_each_triangular_lower(
if skip_diag { Diag::Skip } else { Diag::Include },
|unzipped!(mut dst, src)| {
dst.write(alpha.faer_mul(dst.read()).faer_add(src.read()))
},
);
}
None => {
zipped_rw!(dst, src).for_each_triangular_lower(
if skip_diag { Diag::Skip } else { Diag::Include },
|unzipped!(mut dst, src)| dst.write(src.read()),
);
}
}
}
/// Materializes the upper triangular matrix described by `src` and `src_diag` into `dst`.
///
/// Implemented by running [`copy_lower`] on the transposed views of both matrices.
///
/// # Safety
///
/// Same contract as [`copy_lower`]: `dst` and `src` must both be `n × n`.
#[inline]
unsafe fn copy_upper<E: ComplexField>(
    dst: MatMut<'_, E>,
    src: MatRef<'_, E>,
    src_diag: DiagonalKind,
) {
    let dst_t = dst.transpose_mut();
    let src_t = src.transpose();
    copy_lower(dst_t, src_t, src_diag);
}
/// Forwards to the parent module's `matmul_with_conj`, reordering the arguments.
///
/// This module keeps the operands together (`dst, lhs, rhs`) and the conjugation flags last,
/// whereas the parent function interleaves each operand with its conjugation flag.
#[inline]
unsafe fn mul<E: ComplexField>(
    dst: MatMut<'_, E>,
    lhs: MatRef<'_, E>,
    rhs: MatRef<'_, E>,
    alpha: Option<E>,
    beta: E,
    conj_lhs: Conj,
    conj_rhs: Conj,
    parallelism: Parallelism,
) {
    // Aliased so it cannot be confused with this module's own `matmul_with_conj`.
    use super::matmul_with_conj as parent_matmul_with_conj;

    parent_matmul_with_conj(
        dst,
        lhs,
        conj_lhs,
        rhs,
        conj_rhs,
        alpha,
        beta,
        parallelism,
    );
}
/// Updates the lower triangular part of the square `dst` with the product of the dense `lhs`
/// and the lower triangular `rhs` (whose diagonal is interpreted according to `rhs_diag`).
///
/// Only the lower triangle of `dst` is written; when `skip_diag` is true its diagonal is left
/// untouched as well. `alpha` scales the existing contents of `dst` (`None` overwrites them),
/// as implemented by `accum_lower` in the base case.
///
/// # Safety
///
/// `dst`, `lhs` and `rhs` must all be `n × n`; this is only verified by debug assertions.
unsafe fn mat_x_lower_into_lower_impl_unchecked<E: ComplexField>(
dst: MatMut<'_, E>,
skip_diag: bool,
lhs: MatRef<'_, E>,
rhs: MatRef<'_, E>,
rhs_diag: DiagonalKind,
alpha: Option<E>,
beta: E,
conj_lhs: Conj,
conj_rhs: Conj,
parallelism: Parallelism,
) {
let n = dst.nrows();
debug_assert!(n == dst.nrows());
debug_assert!(n == dst.ncols());
debug_assert!(n == lhs.nrows());
debug_assert!(n == lhs.ncols());
debug_assert!(n == rhs.nrows());
debug_assert!(n == rhs.ncols());
// Base case: small enough to go through dense temporaries of at most 16 × 16.
if n <= 16 {
// NOTE(review): the closure is presumably marked `#[inline(never)]` so the temporaries
// created by `stack_mat!` do not enlarge the frame of the recursive path — confirm.
let op = {
#[inline(never)]
|| {
// `stack_mat!` is defined elsewhere; it is given a `[16, 16]` capacity, the runtime
// size `n × n` and the strides of the matrix the temporary stands in for.
stack_mat!(
[16, 16],
temp_dst,
n,
n,
dst.row_stride(),
dst.col_stride(),
E
);
stack_mat!(
[16, 16],
temp_rhs,
n,
n,
rhs.row_stride(),
rhs.col_stride(),
E
);
// Expand the triangular `rhs` into a dense matrix with an explicit zero upper part.
copy_lower(temp_rhs.rb_mut(), rhs, rhs_diag);
// Dense product into the temporary with `alpha = None`
// (presumably `temp_dst = beta * lhs * temp_rhs` — confirm against the dense kernel).
mul(
temp_dst.rb_mut(),
lhs,
temp_rhs.rb(),
None,
beta,
conj_lhs,
conj_rhs,
parallelism,
);
// Fold only the lower triangle of the temporary into `dst`, applying `alpha` there.
accum_lower(dst, temp_dst.rb(), skip_diag, alpha);
}
};
op();
} else {
// Recursive case: split every matrix into 2 × 2 blocks at `bs = n / 2`.
//
// lhs_bot_right × rhs_bot_left        => dst_bot_left  (dense product, uses `alpha`)
// lhs_bot_right × lower(rhs_bot_right) => dst_bot_right (lower part, recursion)
// lhs_top_left  × lower(rhs_top_left)  => dst_top_left  (lower part, recursion)
// lhs_top_right × rhs_bot_left        => dst_top_left  (lower part, added with alpha = 1)
// lhs_bot_left  × lower(rhs_top_left)  => dst_bot_left  (dense, added with alpha = 1)
//
// The first update of each destination block uses the caller's `alpha`; the second one
// passes `Some(1)` so that it adds to the result of the first.
let bs = n / 2;
let (mut dst_top_left, _, mut dst_bot_left, dst_bot_right) = dst.split_at_mut(bs, bs);
let (lhs_top_left, lhs_top_right, lhs_bot_left, lhs_bot_right) = lhs.split_at(bs, bs);
let (rhs_top_left, _, rhs_bot_left, rhs_bot_right) = rhs.split_at(bs, bs);
mul(
dst_bot_left.rb_mut(),
lhs_bot_right,
rhs_bot_left,
alpha,
beta,
conj_lhs,
conj_rhs,
parallelism,
);
mat_x_lower_into_lower_impl_unchecked(
dst_bot_right,
skip_diag,
lhs_bot_right,
rhs_bot_right,
rhs_diag,
alpha,
beta,
conj_lhs,
conj_rhs,
parallelism,
);
mat_x_lower_into_lower_impl_unchecked(
dst_top_left.rb_mut(),
skip_diag,
lhs_top_left,
rhs_top_left,
rhs_diag,
alpha,
beta,
conj_lhs,
conj_rhs,
parallelism,
);
mat_x_mat_into_lower_impl_unchecked(
dst_top_left,
skip_diag,
lhs_top_right,
rhs_bot_left,
Some(E::faer_one()),
beta,
conj_lhs,
conj_rhs,
parallelism,
);
mat_x_lower_impl_unchecked(
dst_bot_left,
lhs_bot_left,
rhs_top_left,
rhs_diag,
Some(E::faer_one()),
beta,
conj_lhs,
conj_rhs,
parallelism,
);
}
}
/// Updates the dense `m × n` matrix `dst` with the product of the dense `m × n` matrix `lhs`
/// and the `n × n` lower triangular `rhs` (diagonal interpreted according to `rhs_diag`).
///
/// `alpha` and `beta` are forwarded to the dense kernel `mul`; nested updates of the same
/// destination block pass `Some(1)` as `alpha` so that they add to the earlier result.
///
/// # Safety
///
/// The dimensions described above must hold; they are only verified by debug assertions.
unsafe fn mat_x_lower_impl_unchecked<E: ComplexField>(
dst: MatMut<'_, E>,
lhs: MatRef<'_, E>,
rhs: MatRef<'_, E>,
rhs_diag: DiagonalKind,
alpha: Option<E>,
beta: E,
conj_lhs: Conj,
conj_rhs: Conj,
parallelism: Parallelism,
) {
let n = rhs.nrows();
let m = lhs.nrows();
debug_assert!(m == lhs.nrows());
debug_assert!(n == lhs.ncols());
debug_assert!(n == rhs.nrows());
debug_assert!(n == rhs.ncols());
debug_assert!(m == dst.nrows());
debug_assert!(n == dst.ncols());
// Below this work estimate the two recursive halves are joined with `Parallelism::None`.
let join_parallelism = if n * n * m < 128 * 128 * 64 {
Parallelism::None
} else {
parallelism
};
// Base case: expand `rhs` into a dense temporary and use the dense kernel directly.
if n <= 16 {
// NOTE(review): the closure is presumably marked `#[inline(never)]` so the temporary
// created by `stack_mat!` does not enlarge the frame of the recursive path — confirm.
let op = {
#[inline(never)]
|| {
stack_mat!(
[16, 16],
temp_rhs,
n,
n,
rhs.row_stride(),
rhs.col_stride(),
E
);
// Dense copy of the triangular `rhs` with an explicit zero upper part.
copy_lower(temp_rhs.rb_mut(), rhs, rhs_diag);
mul(
dst,
lhs,
temp_rhs.rb(),
alpha,
beta,
conj_lhs,
conj_rhs,
parallelism,
);
}
};
op();
} else {
// Recursive case: split `rhs` into 2 × 2 blocks at `bs = n / 2` (its top-right block is
// ignored) and split `lhs` and `dst` into left/right column panels.
//
// lhs_left  × lower(rhs_top_left)  => dst_left   (recursion, uses `alpha`)
// lhs_right × lower(rhs_bot_right) => dst_right  (recursion, uses `alpha`)
// lhs_right × rhs_bot_left        => dst_left   (dense product, added with alpha = 1)
let bs = n / 2;
let (rhs_top_left, _, rhs_bot_left, rhs_bot_right) = rhs.split_at(bs, bs);
let (lhs_left, lhs_right) = lhs.split_at_col(bs);
let (mut dst_left, mut dst_right) = dst.split_at_col_mut(bs);
// The two recursive calls write to disjoint column panels of `dst`. Each closure uses
// the parallelism value handed to it by `join_raw`.
join_raw(
|parallelism| {
mat_x_lower_impl_unchecked(
dst_left.rb_mut(),
lhs_left,
rhs_top_left,
rhs_diag,
alpha,
beta,
conj_lhs,
conj_rhs,
parallelism,
)
},
|parallelism| {
mat_x_lower_impl_unchecked(
dst_right.rb_mut(),
lhs_right,
rhs_bot_right,
rhs_diag,
alpha,
beta,
conj_lhs,
conj_rhs,
parallelism,
)
},
join_parallelism,
);
// Must run after the join: it adds to `dst_left`, which the first closure initialized.
mul(
dst_left,
lhs_right,
rhs_bot_left,
Some(E::faer_one()),
beta,
conj_lhs,
conj_rhs,
parallelism,
);
}
}
/// Updates the lower triangular part of the square `dst` with the product of two lower
/// triangular matrices `lhs` and `rhs` (diagonals interpreted according to `lhs_diag` and
/// `rhs_diag`).
///
/// Only the lower triangle of `dst` is written; when `skip_diag` is true its diagonal is left
/// untouched as well. `alpha` scales the existing contents of `dst` (`None` overwrites them),
/// as implemented by `accum_lower` in the base case.
///
/// # Safety
///
/// `dst`, `lhs` and `rhs` must all be `n × n`; this is only verified by debug assertions.
unsafe fn lower_x_lower_into_lower_impl_unchecked<E: ComplexField>(
dst: MatMut<'_, E>,
skip_diag: bool,
lhs: MatRef<'_, E>,
lhs_diag: DiagonalKind,
rhs: MatRef<'_, E>,
rhs_diag: DiagonalKind,
alpha: Option<E>,
beta: E,
conj_lhs: Conj,
conj_rhs: Conj,
parallelism: Parallelism,
) {
let n = dst.nrows();
debug_assert!(n == lhs.nrows());
debug_assert!(n == lhs.ncols());
debug_assert!(n == rhs.nrows());
debug_assert!(n == rhs.ncols());
debug_assert!(n == dst.nrows());
debug_assert!(n == dst.ncols());
// Base case: small enough to go through dense temporaries of at most 16 × 16.
if n <= 16 {
// NOTE(review): the closure is presumably marked `#[inline(never)]` so the temporaries
// created by `stack_mat!` do not enlarge the frame of the recursive path — confirm.
let op = {
#[inline(never)]
|| {
stack_mat!(
[16, 16],
temp_dst,
n,
n,
dst.row_stride(),
dst.col_stride(),
E
);
stack_mat!(
[16, 16],
temp_lhs,
n,
n,
lhs.row_stride(),
lhs.col_stride(),
E
);
stack_mat!(
[16, 16],
temp_rhs,
n,
n,
rhs.row_stride(),
rhs.col_stride(),
E
);
// Expand both triangular operands into dense matrices with explicit zeros.
copy_lower(temp_lhs.rb_mut(), lhs, lhs_diag);
copy_lower(temp_rhs.rb_mut(), rhs, rhs_diag);
// Dense product into the temporary with `alpha = None`.
mul(
temp_dst.rb_mut(),
temp_lhs.rb(),
temp_rhs.rb(),
None,
beta,
conj_lhs,
conj_rhs,
parallelism,
);
// Fold only the lower triangle of the temporary into `dst`, applying `alpha` there.
accum_lower(dst, temp_dst.rb(), skip_diag, alpha);
}
};
op();
} else {
// Recursive case: split every matrix into 2 × 2 blocks at `bs = n / 2`.
//
// lower(lhs_top_left)  × lower(rhs_top_left)  => dst_top_left  (lower part, recursion)
// lhs_bot_left         × lower(rhs_top_left)  => dst_bot_left  (dense, uses `alpha`)
// lower(lhs_bot_right) × rhs_bot_left         => dst_bot_left  (dense, added with alpha = 1)
// lower(lhs_bot_right) × lower(rhs_bot_right) => dst_bot_right (lower part, recursion)
let bs = n / 2;
let (dst_top_left, _, mut dst_bot_left, dst_bot_right) = dst.split_at_mut(bs, bs);
let (lhs_top_left, _, lhs_bot_left, lhs_bot_right) = lhs.split_at(bs, bs);
let (rhs_top_left, _, rhs_bot_left, rhs_bot_right) = rhs.split_at(bs, bs);
lower_x_lower_into_lower_impl_unchecked(
dst_top_left,
skip_diag,
lhs_top_left,
lhs_diag,
rhs_top_left,
rhs_diag,
alpha,
beta,
conj_lhs,
conj_rhs,
parallelism,
);
mat_x_lower_impl_unchecked(
dst_bot_left.rb_mut(),
lhs_bot_left,
rhs_top_left,
rhs_diag,
alpha,
beta,
conj_lhs,
conj_rhs,
parallelism,
);
// `lower × dense` is expressed through the `dense × lower` kernel: reversing rows and
// columns and then transposing all three matrices swaps the operand order while keeping
// the triangular factor lower triangular. The operands, their conjugation flags and the
// diagonal kind are swapped accordingly.
mat_x_lower_impl_unchecked(
dst_bot_left.reverse_rows_and_cols_mut().transpose_mut(),
rhs_bot_left.reverse_rows_and_cols().transpose(),
lhs_bot_right.reverse_rows_and_cols().transpose(),
lhs_diag,
Some(E::faer_one()),
beta,
conj_rhs,
conj_lhs,
parallelism,
);
lower_x_lower_into_lower_impl_unchecked(
dst_bot_right,
skip_diag,
lhs_bot_right,
lhs_diag,
rhs_bot_right,
rhs_diag,
alpha,
beta,
conj_lhs,
conj_rhs,
parallelism,
)
}
}
/// Updates the dense square `dst` with the product of the upper triangular `lhs` and the lower
/// triangular `rhs` (diagonals interpreted according to `lhs_diag` and `rhs_diag`).
///
/// `alpha` and `beta` are forwarded to the dense kernel `mul`; nested updates of the same
/// destination block pass `Some(1)` as `alpha` so that they add to the earlier result.
///
/// # Safety
///
/// `dst`, `lhs` and `rhs` must all be `n × n`; this is only verified by debug assertions.
unsafe fn upper_x_lower_impl_unchecked<E: ComplexField>(
dst: MatMut<'_, E>,
lhs: MatRef<'_, E>,
lhs_diag: DiagonalKind,
rhs: MatRef<'_, E>,
rhs_diag: DiagonalKind,
alpha: Option<E>,
beta: E,
conj_lhs: Conj,
conj_rhs: Conj,
parallelism: Parallelism,
) {
let n = dst.nrows();
debug_assert!(n == lhs.nrows());
debug_assert!(n == lhs.ncols());
debug_assert!(n == rhs.nrows());
debug_assert!(n == rhs.ncols());
debug_assert!(n == dst.nrows());
debug_assert!(n == dst.ncols());
// Base case: expand both operands into dense temporaries and use the dense kernel directly.
if n <= 16 {
// NOTE(review): the closure is presumably marked `#[inline(never)]` so the temporaries
// created by `stack_mat!` do not enlarge the frame of the recursive path — confirm.
let op = {
#[inline(never)]
|| {
stack_mat!(
[16, 16],
temp_lhs,
n,
n,
lhs.row_stride(),
lhs.col_stride(),
E
);
stack_mat!(
[16, 16],
temp_rhs,
n,
n,
rhs.row_stride(),
rhs.col_stride(),
E
);
// `lhs` is upper triangular, `rhs` is lower triangular.
copy_upper(temp_lhs.rb_mut(), lhs, lhs_diag);
copy_lower(temp_rhs.rb_mut(), rhs, rhs_diag);
mul(
dst,
temp_lhs.rb(),
temp_rhs.rb(),
alpha,
beta,
conj_lhs,
conj_rhs,
parallelism,
);
}
};
op();
} else {
// Recursive case: split every matrix into 2 × 2 blocks at `bs = n / 2`.
//
// lhs_top_right        × rhs_bot_left         => dst_top_left  (dense, uses `alpha`)
// upper(lhs_top_left)  × lower(rhs_top_left)  => dst_top_left  (recursion, alpha = 1)
// lhs_top_right        × lower(rhs_bot_right) => dst_top_right (dense × lower)
// upper(lhs_bot_right) × rhs_bot_left         => dst_bot_left  (upper × dense)
// upper(lhs_bot_right) × lower(rhs_bot_right) => dst_bot_right (recursion)
let bs = n / 2;
let (mut dst_top_left, dst_top_right, dst_bot_left, dst_bot_right) =
dst.split_at_mut(bs, bs);
let (lhs_top_left, lhs_top_right, _, lhs_bot_right) = lhs.split_at(bs, bs);
let (rhs_top_left, _, rhs_bot_left, rhs_bot_right) = rhs.split_at(bs, bs);
// The closures ignore the parallelism value handed to them by `join_raw` and use the
// caller's `parallelism` for the nested calls.
join_raw(
|_| {
// Both updates target `dst_top_left`, so they run sequentially in this order.
mul(
dst_top_left.rb_mut(),
lhs_top_right,
rhs_bot_left,
alpha,
beta,
conj_lhs,
conj_rhs,
parallelism,
);
upper_x_lower_impl_unchecked(
dst_top_left,
lhs_top_left,
lhs_diag,
rhs_top_left,
rhs_diag,
Some(E::faer_one()),
beta,
conj_lhs,
conj_rhs,
parallelism,
)
},
|_| {
join_raw(
|_| {
mat_x_lower_impl_unchecked(
dst_top_right,
lhs_top_right,
rhs_bot_right,
rhs_diag,
alpha,
beta,
conj_lhs,
conj_rhs,
parallelism,
)
},
|_| {
// `upper × dense` is expressed through the `dense × lower` kernel by
// transposing all three matrices, which swaps the operand order and turns
// the upper triangular block into a lower triangular one.
mat_x_lower_impl_unchecked(
dst_bot_left.transpose_mut(),
rhs_bot_left.transpose(),
lhs_bot_right.transpose(),
lhs_diag,
alpha,
beta,
conj_rhs,
conj_lhs,
parallelism,
)
},
parallelism,
);
upper_x_lower_impl_unchecked(
dst_bot_right,
lhs_bot_right,
lhs_diag,
rhs_bot_right,
rhs_diag,
alpha,
beta,
conj_lhs,
conj_rhs,
parallelism,
)
},
parallelism,
);
}
}
/// Updates the lower triangular part of the square `dst` with the product of the upper
/// triangular `lhs` and the lower triangular `rhs` (diagonals interpreted according to
/// `lhs_diag` and `rhs_diag`).
///
/// Only the lower triangle of `dst` is written; when `skip_diag` is true its diagonal is left
/// untouched as well. `alpha` scales the existing contents of `dst` (`None` overwrites them),
/// as implemented by `accum_lower` in the base case.
///
/// # Safety
///
/// `dst`, `lhs` and `rhs` must all be `n × n`; this is only verified by debug assertions.
unsafe fn upper_x_lower_into_lower_impl_unchecked<E: ComplexField>(
dst: MatMut<'_, E>,
skip_diag: bool,
lhs: MatRef<'_, E>,
lhs_diag: DiagonalKind,
rhs: MatRef<'_, E>,
rhs_diag: DiagonalKind,
alpha: Option<E>,
beta: E,
conj_lhs: Conj,
conj_rhs: Conj,
parallelism: Parallelism,
) {
let n = dst.nrows();
debug_assert!(n == lhs.nrows());
debug_assert!(n == lhs.ncols());
debug_assert!(n == rhs.nrows());
debug_assert!(n == rhs.ncols());
debug_assert!(n == dst.nrows());
debug_assert!(n == dst.ncols());
// Base case: small enough to go through dense temporaries of at most 16 × 16.
if n <= 16 {
// NOTE(review): the closure is presumably marked `#[inline(never)]` so the temporaries
// created by `stack_mat!` do not enlarge the frame of the recursive path — confirm.
let op = {
#[inline(never)]
|| {
stack_mat!(
[16, 16],
temp_dst,
n,
n,
dst.row_stride(),
dst.col_stride(),
E
);
stack_mat!(
[16, 16],
temp_lhs,
n,
n,
lhs.row_stride(),
lhs.col_stride(),
E
);
stack_mat!(
[16, 16],
temp_rhs,
n,
n,
rhs.row_stride(),
rhs.col_stride(),
E
);
// `lhs` is upper triangular, `rhs` is lower triangular.
copy_upper(temp_lhs.rb_mut(), lhs, lhs_diag);
copy_lower(temp_rhs.rb_mut(), rhs, rhs_diag);
// Dense product into the temporary with `alpha = None`.
mul(
temp_dst.rb_mut(),
temp_lhs.rb(),
temp_rhs.rb(),
None,
beta,
conj_lhs,
conj_rhs,
parallelism,
);
// Fold only the lower triangle of the temporary into `dst`, applying `alpha` there.
accum_lower(dst, temp_dst.rb(), skip_diag, alpha);
}
};
op();
} else {
// Recursive case: split every matrix into 2 × 2 blocks at `bs = n / 2`. The top-right
// block of `dst` lies strictly above the diagonal and is never touched.
//
// lhs_top_right        × rhs_bot_left         => dst_top_left  (lower part, uses `alpha`)
// upper(lhs_top_left)  × lower(rhs_top_left)  => dst_top_left  (lower part, alpha = 1)
// upper(lhs_bot_right) × rhs_bot_left         => dst_bot_left  (dense, upper × dense)
// upper(lhs_bot_right) × lower(rhs_bot_right) => dst_bot_right (lower part, recursion)
let bs = n / 2;
let (mut dst_top_left, _, dst_bot_left, dst_bot_right) = dst.split_at_mut(bs, bs);
let (lhs_top_left, lhs_top_right, _, lhs_bot_right) = lhs.split_at(bs, bs);
let (rhs_top_left, _, rhs_bot_left, rhs_bot_right) = rhs.split_at(bs, bs);
// The closures ignore the parallelism value handed to them by `join_raw` and use the
// caller's `parallelism` for the nested calls.
join_raw(
|_| {
// Both updates target `dst_top_left`, so they run sequentially in this order.
mat_x_mat_into_lower_impl_unchecked(
dst_top_left.rb_mut(),
skip_diag,
lhs_top_right,
rhs_bot_left,
alpha,
beta,
conj_lhs,
conj_rhs,
parallelism,
);
upper_x_lower_into_lower_impl_unchecked(
dst_top_left,
skip_diag,
lhs_top_left,
lhs_diag,
rhs_top_left,
rhs_diag,
Some(E::faer_one()),
beta,
conj_lhs,
conj_rhs,
parallelism,
)
},
|_| {
// `upper × dense` is expressed through the `dense × lower` kernel by transposing
// all three matrices, which swaps the operand order and turns the upper
// triangular block into a lower triangular one.
mat_x_lower_impl_unchecked(
dst_bot_left.transpose_mut(),
rhs_bot_left.transpose(),
lhs_bot_right.transpose(),
lhs_diag,
alpha,
beta,
conj_rhs,
conj_lhs,
parallelism,
);
upper_x_lower_into_lower_impl_unchecked(
dst_bot_right,
skip_diag,
lhs_bot_right,
lhs_diag,
rhs_bot_right,
rhs_diag,
alpha,
beta,
conj_lhs,
conj_rhs,
parallelism,
)
},
parallelism,
);
}
}
/// Updates the lower triangular part of the square `n × n` matrix `dst` with the product of
/// the dense `n × k` matrix `lhs` and the dense `k × n` matrix `rhs`.
///
/// Only the lower triangle of `dst` is written; when `skip_diag` is true its diagonal is left
/// untouched as well. `alpha` scales the existing contents of `dst` (`None` overwrites them),
/// as implemented by `accum_lower` in the base case.
///
/// # Safety
///
/// The dimensions described above must hold; they are only verified by debug assertions.
unsafe fn mat_x_mat_into_lower_impl_unchecked<E: ComplexField>(
dst: MatMut<'_, E>,
skip_diag: bool,
lhs: MatRef<'_, E>,
rhs: MatRef<'_, E>,
alpha: Option<E>,
beta: E,
conj_lhs: Conj,
conj_rhs: Conj,
parallelism: Parallelism,
) {
debug_assert!(dst.nrows() == dst.ncols());
debug_assert!(dst.nrows() == lhs.nrows());
debug_assert!(dst.ncols() == rhs.ncols());
debug_assert!(lhs.ncols() == rhs.nrows());
let n = dst.nrows();
let k = lhs.ncols();
// Below this work estimate the recursive branches are joined with `Parallelism::None`.
let join_parallelism = if n * n * k < 128 * 128 * 128 {
Parallelism::None
} else {
parallelism
};
// Base case: compute the full dense product into a temporary, then keep its lower part.
if n <= 16 {
// NOTE(review): the closure is presumably marked `#[inline(never)]` so the temporary
// created by `stack_mat!` does not enlarge the frame of the recursive path — confirm.
let op = {
#[inline(never)]
|| {
stack_mat!(
[16, 16],
temp_dst,
n,
n,
dst.row_stride(),
dst.col_stride(),
E
);
// Dense product into the temporary with `alpha = None`.
mul(
temp_dst.rb_mut(),
lhs,
rhs,
None,
beta,
conj_lhs,
conj_rhs,
parallelism,
);
// Fold only the lower triangle of the temporary into `dst`, applying `alpha` there.
accum_lower(dst, temp_dst.rb(), skip_diag, alpha);
}
};
op();
} else {
// Recursive case: split `dst` into 2 × 2 blocks at `bs = n / 2`, `lhs` into top/bottom
// row panels and `rhs` into left/right column panels. The top-right block of `dst` lies
// strictly above the diagonal and is never touched.
//
// lhs_bot × rhs_left  => dst_bot_left  (dense product)
// lhs_top × rhs_left  => dst_top_left  (lower part, recursion)
// lhs_bot × rhs_right => dst_bot_right (lower part, recursion)
let bs = n / 2;
let (dst_top_left, _, dst_bot_left, dst_bot_right) = dst.split_at_mut(bs, bs);
let (lhs_top, lhs_bot) = lhs.split_at_row(bs);
let (rhs_left, rhs_right) = rhs.split_at_col(bs);
// The three updates write to disjoint blocks of `dst`. The closures ignore the
// parallelism value handed to them by `join_raw` and use the caller's `parallelism`.
join_raw(
|_| {
mul(
dst_bot_left,
lhs_bot,
rhs_left,
alpha,
beta,
conj_lhs,
conj_rhs,
parallelism,
)
},
|_| {
join_raw(
|_| {
mat_x_mat_into_lower_impl_unchecked(
dst_top_left,
skip_diag,
lhs_top,
rhs_left,
alpha,
beta,
conj_lhs,
conj_rhs,
parallelism,
)
},
|_| {
mat_x_mat_into_lower_impl_unchecked(
dst_bot_right,
skip_diag,
lhs_bot,
rhs_right,
alpha,
beta,
conj_lhs,
conj_rhs,
parallelism,
)
},
join_parallelism,
)
},
join_parallelism,
);
}
}
/// Describes which part of a matrix operand or accumulator takes part in a product.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BlockStructure {
/// Dense matrix: no triangular structure is assumed.
Rectangular,
/// Lower triangular matrix whose diagonal is read from the matrix.
TriangularLower,
/// Lower triangular matrix whose diagonal is treated as zero.
StrictTriangularLower,
/// Lower triangular matrix whose diagonal is treated as one.
UnitTriangularLower,
/// Upper triangular matrix whose diagonal is read from the matrix.
TriangularUpper,
/// Upper triangular matrix whose diagonal is treated as zero.
StrictTriangularUpper,
/// Upper triangular matrix whose diagonal is treated as one.
UnitTriangularUpper,
}
impl BlockStructure {
#[inline]
pub fn is_dense(self) -> bool {
matches!(self, BlockStructure::Rectangular)
}
#[inline]
pub fn is_lower(self) -> bool {
use BlockStructure::*;
matches!(
self,
TriangularLower | StrictTriangularLower | UnitTriangularLower
)
}
#[inline]
pub fn is_upper(self) -> bool {
use BlockStructure::*;
matches!(
self,
TriangularUpper | StrictTriangularUpper | UnitTriangularUpper
)
}
#[inline]
pub fn transpose(self) -> Self {
use BlockStructure::*;
match self {
Rectangular => Rectangular,
TriangularLower => TriangularUpper,
StrictTriangularLower => StrictTriangularUpper,
UnitTriangularLower => UnitTriangularUpper,
TriangularUpper => TriangularLower,
StrictTriangularUpper => StrictTriangularLower,
UnitTriangularUpper => UnitTriangularLower,
}
}
#[inline]
pub(crate) fn diag_kind(self) -> DiagonalKind {
use BlockStructure::*;
match self {
Rectangular | TriangularLower | TriangularUpper => DiagonalKind::Generic,
StrictTriangularLower | StrictTriangularUpper => DiagonalKind::Zero,
UnitTriangularLower | UnitTriangularUpper => DiagonalKind::Unit,
}
}
}
/// Structured matrix product with explicit conjugation flags.
///
/// Validates the operand dimensions and then forwards everything to `matmul_unchecked`.
/// Each of `acc`, `lhs` and `rhs` carries a [`BlockStructure`] describing which part of the
/// matrix takes part in the operation.
///
/// # Panics
///
/// Panics if
/// - `acc.nrows() != lhs.nrows()`, `acc.ncols() != rhs.ncols()` or `lhs.ncols() != rhs.nrows()`,
/// - any operand whose structure is not `Rectangular` is not square.
#[track_caller]
#[inline]
pub fn matmul_with_conj<E: ComplexField>(
    acc: impl As2DMut<E>,
    acc_structure: BlockStructure,
    lhs: impl As2D<E>,
    lhs_structure: BlockStructure,
    conj_lhs: Conj,
    rhs: impl As2D<E>,
    rhs_structure: BlockStructure,
    conj_rhs: Conj,
    alpha: Option<E>,
    beta: E,
    parallelism: Parallelism,
) {
    // Keep the owning value alive under its own name and work on the 2D views.
    let mut acc_owner = acc;
    let acc = acc_owner.as_2d_mut();
    let lhs = lhs.as_2d_ref();
    let rhs = rhs.as_2d_ref();

    // Shapes must be compatible for a matrix product.
    assert!(all(
        acc.nrows() == lhs.nrows(),
        acc.ncols() == rhs.ncols(),
        lhs.ncols() == rhs.nrows(),
    ));

    // Triangular operands only make sense for square matrices.
    if !acc_structure.is_dense() {
        assert!(acc.nrows() == acc.ncols());
    }
    if !lhs_structure.is_dense() {
        assert!(lhs.nrows() == lhs.ncols());
    }
    if !rhs_structure.is_dense() {
        assert!(rhs.nrows() == rhs.ncols());
    }

    // The dimension requirements of `matmul_unchecked` were checked above.
    unsafe {
        matmul_unchecked(
            acc,
            acc_structure,
            lhs,
            lhs_structure,
            conj_lhs,
            rhs,
            rhs_structure,
            conj_rhs,
            alpha,
            beta,
            parallelism,
        );
    }
}
/// Structured matrix product for operands that may be stored in conjugated form.
///
/// Each operand is split into its canonical representation and a conjugation flag via
/// `canonicalize`, and the work is delegated to [`matmul_with_conj`], which performs the
/// dimension checks (and panics if they fail).
#[track_caller]
#[inline]
pub fn matmul<E: ComplexField, LhsE: Conjugate<Canonical = E>, RhsE: Conjugate<Canonical = E>>(
    acc: impl As2DMut<E>,
    acc_structure: BlockStructure,
    lhs: impl As2D<LhsE>,
    lhs_structure: BlockStructure,
    rhs: impl As2D<RhsE>,
    rhs_structure: BlockStructure,
    alpha: Option<E>,
    beta: E,
    parallelism: Parallelism,
) {
    let mut acc_owner = acc;
    let acc_view = acc_owner.as_2d_mut();

    // Strip the conjugation wrapper off each operand, remembering it as a flag.
    let (lhs_canonical, conj_lhs) = lhs.as_2d_ref().canonicalize();
    let (rhs_canonical, conj_rhs) = rhs.as_2d_ref().canonicalize();

    matmul_with_conj(
        acc_view,
        acc_structure,
        lhs_canonical,
        lhs_structure,
        conj_lhs,
        rhs_canonical,
        rhs_structure,
        conj_rhs,
        alpha,
        beta,
        parallelism,
    );
}
unsafe fn matmul_unchecked<E: ComplexField>(
acc: MatMut<'_, E>,
acc_structure: BlockStructure,
lhs: MatRef<'_, E>,
lhs_structure: BlockStructure,
conj_lhs: Conj,
rhs: MatRef<'_, E>,
rhs_structure: BlockStructure,
conj_rhs: Conj,
alpha: Option<E>,
beta: E,
parallelism: Parallelism,
) {
debug_assert!(acc.nrows() == lhs.nrows());
debug_assert!(acc.ncols() == rhs.ncols());
debug_assert!(lhs.ncols() == rhs.nrows());
if !acc_structure.is_dense() {
debug_assert!(acc.nrows() == acc.ncols());
}
if !lhs_structure.is_dense() {
debug_assert!(lhs.nrows() == lhs.ncols());
}
if !rhs_structure.is_dense() {
debug_assert!(rhs.nrows() == rhs.ncols());
}
let mut acc = acc;
let mut lhs = lhs;
let mut rhs = rhs;
let mut acc_structure = acc_structure;
let mut lhs_structure = lhs_structure;
let mut rhs_structure = rhs_structure;
let mut conj_lhs = conj_lhs;
let mut conj_rhs = conj_rhs;
if rhs_structure.is_lower() {
false
} else if rhs_structure.is_upper() {
acc = acc.reverse_rows_and_cols_mut();
lhs = lhs.reverse_rows_and_cols();
rhs = rhs.reverse_rows_and_cols();
acc_structure = acc_structure.transpose();
lhs_structure = lhs_structure.transpose();
rhs_structure = rhs_structure.transpose();
false
} else if lhs_structure.is_lower() {
acc = acc.reverse_rows_and_cols_mut().transpose_mut();
(lhs, rhs) = (
rhs.reverse_rows_and_cols().transpose(),
lhs.reverse_rows_and_cols().transpose(),
);
(conj_lhs, conj_rhs) = (conj_rhs, conj_lhs);
(lhs_structure, rhs_structure) = (rhs_structure, lhs_structure);
true
} else if lhs_structure.is_upper() {
acc_structure = acc_structure.transpose();
acc = acc.transpose_mut();
(lhs, rhs) = (rhs.transpose(), lhs.transpose());
(conj_lhs, conj_rhs) = (conj_rhs, conj_lhs);
(lhs_structure, rhs_structure) = (rhs_structure.transpose(), lhs_structure.transpose());
true
} else {
false
};
let clear_upper = |acc: MatMut<'_, E>, skip_diag: bool| match &alpha {
&Some(alpha) => zipped_rw!(acc).for_each_triangular_upper(
if skip_diag { Diag::Skip } else { Diag::Include },
|unzipped!(mut acc)| acc.write(alpha.faer_mul(acc.read())),
),
None => zipped_rw!(acc).for_each_triangular_upper(
if skip_diag { Diag::Skip } else { Diag::Include },
|unzipped!(mut acc)| acc.write(E::faer_zero()),
),
};
let skip_diag = matches!(
acc_structure,
BlockStructure::StrictTriangularLower
| BlockStructure::StrictTriangularUpper
| BlockStructure::UnitTriangularLower
| BlockStructure::UnitTriangularUpper
);
let lhs_diag = lhs_structure.diag_kind();
let rhs_diag = rhs_structure.diag_kind();
if acc_structure.is_dense() {
if lhs_structure.is_dense() && rhs_structure.is_dense() {
mul(acc, lhs, rhs, alpha, beta, conj_lhs, conj_rhs, parallelism);
} else {
debug_assert!(rhs_structure.is_lower());
if lhs_structure.is_dense() {
mat_x_lower_impl_unchecked(
acc,
lhs,
rhs,
rhs_diag,
alpha,
beta,
conj_lhs,
conj_rhs,
parallelism,
)
} else if lhs_structure.is_lower() {
clear_upper(acc.rb_mut(), true);
lower_x_lower_into_lower_impl_unchecked(
acc,
false,
lhs,
lhs_diag,
rhs,
rhs_diag,
alpha,
beta,
conj_lhs,
conj_rhs,
parallelism,
);
} else {
debug_assert!(lhs_structure.is_upper());
upper_x_lower_impl_unchecked(
acc,
lhs,
lhs_diag,
rhs,
rhs_diag,
alpha,
beta,
conj_lhs,
conj_rhs,
parallelism,
)
}
}
} else if acc_structure.is_lower() {
if lhs_structure.is_dense() && rhs_structure.is_dense() {
mat_x_mat_into_lower_impl_unchecked(
acc,
skip_diag,
lhs,
rhs,
alpha,
beta,
conj_lhs,
conj_rhs,
parallelism,
)
} else {
debug_assert!(rhs_structure.is_lower());
if lhs_structure.is_dense() {
mat_x_lower_into_lower_impl_unchecked(
acc,
skip_diag,
lhs,
rhs,
rhs_diag,
alpha,
beta,
conj_lhs,
conj_rhs,
parallelism,
);
} else if lhs_structure.is_lower() {
lower_x_lower_into_lower_impl_unchecked(
acc,
skip_diag,
lhs,
lhs_diag,
rhs,
rhs_diag,
alpha,
beta,
conj_lhs,
conj_rhs,
parallelism,
)
} else {
upper_x_lower_into_lower_impl_unchecked(
acc,
skip_diag,
lhs,
lhs_diag,
rhs,
rhs_diag,
alpha,
beta,
conj_lhs,
conj_rhs,
parallelism,
)
}
}
} else if lhs_structure.is_dense() && rhs_structure.is_dense() {
mat_x_mat_into_lower_impl_unchecked(
acc.transpose_mut(),
skip_diag,
rhs.transpose(),
lhs.transpose(),
alpha,
beta,
conj_rhs,
conj_lhs,
parallelism,
)
} else {
debug_assert!(rhs_structure.is_lower());
if lhs_structure.is_dense() {
upper_x_lower_into_lower_impl_unchecked(
acc.transpose_mut(),
skip_diag,
rhs.transpose(),
rhs_diag,
lhs.transpose(),
lhs_diag,
alpha,
beta,
conj_rhs,
conj_lhs,
parallelism,
)
} else if lhs_structure.is_lower() {
if !skip_diag {
match &alpha {
&Some(alpha) => {
zipped_rw!(
acc.rb_mut().diagonal_mut().column_vector_mut().as_2d_mut(),
lhs.diagonal().column_vector().as_2d(),
rhs.diagonal().column_vector().as_2d(),
)
.for_each(|unzipped!(mut acc, lhs, rhs)| {
acc.write(
(alpha.faer_mul(acc.read()))
.faer_add(beta.faer_mul(lhs.read().faer_mul(rhs.read()))),
)
});
}
None => {
zipped_rw!(
acc.rb_mut().diagonal_mut().column_vector_mut().as_2d_mut(),
lhs.diagonal().column_vector().as_2d(),
rhs.diagonal().column_vector().as_2d(),
)
.for_each(|unzipped!(mut acc, lhs, rhs)| {
acc.write(beta.faer_mul(lhs.read().faer_mul(rhs.read())))
});
}
}
}
clear_upper(acc.rb_mut(), true);
} else {
debug_assert!(lhs_structure.is_upper());
upper_x_lower_into_lower_impl_unchecked(
acc.transpose_mut(),
skip_diag,
rhs.transpose(),
rhs_diag,
lhs.transpose(),
lhs_diag,
alpha,
beta,
conj_rhs,
conj_lhs,
parallelism,
)
}
}
}