use std::ops::{Add, AddAssign, Mul, MulAssign};
use std::simd::{Simd, SimdElement};
use std::simd::num::SimdFloat;
use crate::fused::ftrmsv::ftrmsv_n;
use crate::traits::Fma;
use crate::assert_length_eq_n_cols;
use crate::fused::ftrmsv::N_ROWS_PER_CHUNK;
use crate::l1::axpy::axpy;
use crate::l1::dot::{dot, LANES};
use crate::types::{MatRef, VecRef, VecMut};
use crate::types::{Triangular, Transpose};
pub(crate) const N_COLS_PER_CHUNK: usize = 4;
pub(crate) fn trmv_ln<T>(
a: MatRef<'_, T>,
mut x: VecMut<'_, T>,
)
where
T: Copy
+ AddAssign
+ MulAssign
+ Mul<Output=T>
+ Add<Output=T>
+ Fma
+ SimdElement,
Simd<T, N_ROWS_PER_CHUNK>: SimdFloat<Scalar=T>
+ AddAssign
+ Mul<Output = Simd<T, N_ROWS_PER_CHUNK>>
+ Fma,
{
let (n_rows, n_cols) = a.dimension();
assert_eq!(n_rows, n_cols, "matrix a must be square nxn to be triangular");
assert_length_eq_n_cols!(a, x);
let n = n_rows;
let a_slice = a.as_slice();
let x_slice = x.as_slice_mut();
for (cols, a_panel) in a.col_panels(N_COLS_PER_CHUNK).rev() {
if cols.end - cols.start != N_COLS_PER_CHUNK {
for j in cols.rev() {
let xj = x_slice[j];
let a_col = VecRef::new(&a_slice[j * n + j + 1..(j + 1) * n]);
let x_col = VecMut::new(&mut x_slice[j + 1..n]);
axpy(xj, a_col, x_col);
x_slice[j] *= a_slice[j * (n + 1)];
}
} else {
let j0 = cols.start;
let j1 = cols.end;
let col0 = a_panel.col(0);
let col1 = a_panel.col(1);
let col2 = a_panel.col(2);
let col3 = a_panel.col(3);
let c0 = &col0.as_slice()[j0 + N_COLS_PER_CHUNK..];
let c1 = &col1.as_slice()[j0 + N_COLS_PER_CHUNK..];
let c2 = &col2.as_slice()[j0 + N_COLS_PER_CHUNK..];
let c3 = &col3.as_slice()[j0 + N_COLS_PER_CHUNK..];
let x0 = x_slice[j0];
let x1 = x_slice[j0 + 1];
let x2 = x_slice[j0 + 2];
let x3 = x_slice[j0 + 3];
ftrmsv_n(
c0, c1, c2, c3,
x0, x1, x2, x3,
&mut x_slice[j1..n],
);
for j in cols.rev() {
let xj = x_slice[j];
let a_col = VecRef::new(&a_slice[j * n + j + 1..j * n + j1]);
let x_col = VecMut::new(&mut x_slice[j + 1..j1]);
axpy(xj, a_col, x_col);
x_slice[j] *= a_slice[j * (n + 1)];
}
}
}
}
pub(crate) fn trmv_un<T>(
a: MatRef<'_, T>,
mut x: VecMut<'_, T>,
)
where
T: Copy
+ AddAssign
+ MulAssign
+ Mul<Output=T>
+ Add<Output=T>
+ Fma
+ SimdElement,
Simd<T, N_ROWS_PER_CHUNK>: SimdFloat<Scalar=T>
+ AddAssign
+ Mul<Output = Simd<T, N_ROWS_PER_CHUNK>>
+ Fma,
{
let (n_rows, n_cols) = a.dimension();
assert_eq!(n_rows, n_cols, "matrix a must be square nxn to be triangular");
assert_length_eq_n_cols!(a, x);
let n = n_cols;
let a_slice = a.as_slice();
let x_slice = x.as_slice_mut();
for (cols, a_panel) in a.col_panels(N_COLS_PER_CHUNK) {
if cols.end - cols.start != N_COLS_PER_CHUNK {
for j in cols {
let xj = x_slice[j];
let acol = &a_slice[j * n..j * (n + 1)];
let xcol = &mut x_slice[..j];
axpy(
xj,
VecRef::new(acol),
VecMut::new(xcol),
);
x_slice[j] *= a_slice[j * (n + 1)];
}
} else {
let j0 = cols.start;
let col0 = a_panel.col(0);
let col1 = a_panel.col(1);
let col2 = a_panel.col(2);
let col3 = a_panel.col(3);
let c0 = &col0.as_slice()[..j0];
let c1 = &col1.as_slice()[..j0];
let c2 = &col2.as_slice()[..j0];
let c3 = &col3.as_slice()[..j0];
let x0 = x_slice[j0];
let x1 = x_slice[j0 + 1];
let x2 = x_slice[j0 + 2];
let x3 = x_slice[j0 + 3];
ftrmsv_n(
c0, c1, c2, c3,
x0, x1, x2, x3,
&mut x_slice[..j0],
);
for j in cols {
let xj = x_slice[j];
let acol = &a.as_slice()[j * n + j0..j * (n + 1)];
let xcol = &mut x_slice[j0..j];
axpy(
xj,
VecRef::new(acol),
VecMut::new(xcol),
);
x_slice[j] *= a_slice[j * (n + 1)];
}
}
}
}
pub(crate) fn trmv_lt<T>(
a: MatRef<'_, T>,
mut x: VecMut<'_, T>,
)
where
T: SimdElement
+ Copy
+ Default
+ AddAssign
+ Mul<Output=T>
+ Add<Output=T>
+ Fma,
Simd<T, LANES>: SimdFloat<Scalar=T>
+ AddAssign
+ Fma,
{
let (n_rows, n_cols) = a.dimension();
assert_eq!(n_rows, n_cols, "matrix a must be square nxn to be triangular");
assert_length_eq_n_cols!(a, x);
let n = n_cols;
let a_slice = a.as_slice();
let x_slice = x.as_slice_mut();
for j in 0..n {
let xv = &x_slice[j..];
let av = &a_slice[j * (n + 1)..(j + 1) * n];
x_slice[j] = dot(
VecRef::new(xv),
VecRef::new(av),
);
}
}
pub(crate) fn trmv_ut<T>(
a: MatRef<'_, T>,
mut x: VecMut<'_, T>,
)
where
T: SimdElement
+ Copy
+ Default
+ AddAssign
+ Mul<Output=T>
+ Add<Output=T>
+ Fma,
Simd<T, LANES>: SimdFloat<Scalar=T>
+ AddAssign
+ Fma,
{
let (n_rows, n_cols) = a.dimension();
assert_eq!(n_rows, n_cols, "matrix a must be square nxn to be triangular");
assert_length_eq_n_cols!(a, x);
let n = n_rows;
let a_slice = a.as_slice();
let x_slice = x.as_slice_mut();
for j in (0..n).rev() {
let xv = &x_slice[..j + 1];
let av = &a_slice[j * n..j * (n + 1) + 1];
x_slice[j] = dot(
VecRef::new(xv),
VecRef::new(av),
);
}
}
pub fn trmv<T>(
uplo: Triangular,
trans: Transpose,
a: MatRef<'_, T>,
x: VecMut<'_, T>,
)
where
T: Copy
+ AddAssign
+ MulAssign
+ Mul<Output=T>
+ Add<Output=T>
+ Fma
+ SimdElement
+ Default,
Simd<T, N_ROWS_PER_CHUNK>: SimdFloat<Scalar=T>
+ AddAssign
+ Mul<Output = Simd<T, N_ROWS_PER_CHUNK>>
+ Fma,
Simd<T, LANES>: SimdFloat<Scalar=T>
+ AddAssign
+ Fma,
{
match uplo {
Triangular::Upper => {
match trans {
Transpose::NoTranspose => trmv_un(a, x),
Transpose::Transpose => trmv_ut(a, x),
}
},
Triangular::Lower => {
match trans {
Transpose::NoTranspose => trmv_ln(a, x),
Transpose::Transpose => trmv_lt(a, x),
}
}
}
}