/// Subtracts a scaled vector in place: every `dst[i]` becomes
/// `mul_add(-alpha, src[i], dst[i])`.
///
/// The kernel is run through `pulp::Arch::new().dispatch`, so the SIMD
/// implementation is chosen by `pulp` at runtime.
///
/// # Panics
///
/// Panics if `dst.len() != src.len()`.
#[allow(dead_code)]
pub fn axpy_minus(dst: &mut [f64], src: &[f64], alpha: f64) {
    assert_eq!(
        dst.len(),
        src.len(),
        "axpy_minus: dst and src length mismatch"
    );

    // Kernel arguments: the pre-negated scale and the two operand slices.
    struct Kernel<'a> {
        scale: f64,
        input: &'a [f64],
        output: &'a mut [f64],
    }

    impl pulp::WithSimd for Kernel<'_> {
        type Output = ();

        #[inline(always)]
        fn with_simd<S: pulp::Simd>(self, simd: S) {
            let scale = simd.splat_f64s(self.scale);
            let (in_vecs, in_rest) = S::as_simd_f64s(self.input);
            let (out_vecs, out_rest) = S::as_mut_simd_f64s(self.output);

            // Whole SIMD vectors.
            for (x, y) in in_vecs.iter().zip(out_vecs.iter_mut()) {
                *y = simd.mul_add_f64s(scale, *x, *y);
            }

            // Leftover scalars go through a partial load / partial store pair.
            if in_rest.is_empty() {
                return;
            }
            let x = simd.partial_load_f64s(in_rest);
            let y = simd.partial_load_f64s(out_rest);
            simd.partial_store_f64s(out_rest, simd.mul_add_f64s(scale, x, y));
        }
    }

    pulp::Arch::new().dispatch(Kernel {
        scale: -alpha,
        input: src,
        output: dst,
    });
}
/// Subtracts two scaled vectors in place. For every index the new value is
/// `mul_add(-alpha1, src1[i], mul_add(-alpha0, src0[i], dst[i]))`, i.e. the
/// `src0` term is folded in first and the `src1` term second.
///
/// The kernel is run through `pulp::Arch::new().dispatch`.
///
/// # Panics
///
/// Panics if `src0` or `src1` differ in length from `dst`.
#[allow(dead_code)]
pub fn axpy2_minus(dst: &mut [f64], src0: &[f64], alpha0: f64, src1: &[f64], alpha1: f64) {
    assert_eq!(
        dst.len(),
        src0.len(),
        "axpy2_minus: dst and src0 length mismatch"
    );
    assert_eq!(
        dst.len(),
        src1.len(),
        "axpy2_minus: dst and src1 length mismatch"
    );

    // Kernel arguments: both pre-negated scales, both inputs, and the output.
    struct Kernel<'a> {
        scale0: f64,
        scale1: f64,
        input0: &'a [f64],
        input1: &'a [f64],
        output: &'a mut [f64],
    }

    impl pulp::WithSimd for Kernel<'_> {
        type Output = ();

        #[inline(always)]
        fn with_simd<S: pulp::Simd>(self, simd: S) {
            let w0 = simd.splat_f64s(self.scale0);
            let w1 = simd.splat_f64s(self.scale1);
            let (x0_vecs, x0_rest) = S::as_simd_f64s(self.input0);
            let (x1_vecs, x1_rest) = S::as_simd_f64s(self.input1);
            let (y_vecs, y_rest) = S::as_mut_simd_f64s(self.output);

            // Whole SIMD vectors: apply the src0 term, then the src1 term.
            for (y, (x0, x1)) in y_vecs.iter_mut().zip(x0_vecs.iter().zip(x1_vecs)) {
                let partial = simd.mul_add_f64s(w0, *x0, *y);
                *y = simd.mul_add_f64s(w1, *x1, partial);
            }

            // Leftover scalars, same two-step update via partial load/store.
            if x0_rest.is_empty() {
                return;
            }
            let x0 = simd.partial_load_f64s(x0_rest);
            let x1 = simd.partial_load_f64s(x1_rest);
            let y = simd.partial_load_f64s(y_rest);
            let partial = simd.mul_add_f64s(w0, x0, y);
            let updated = simd.mul_add_f64s(w1, x1, partial);
            simd.partial_store_f64s(y_rest, updated);
        }
    }

    pulp::Arch::new().dispatch(Kernel {
        scale0: -alpha0,
        scale1: -alpha1,
        input0: src0,
        input1: src1,
        output: dst,
    });
}
/// aarch64-only variant of the in-place update
/// `dst[i] = mul_add(-alpha, src[i], dst[i])` that invokes the kernel on a
/// compile-time NEON token instead of going through `pulp::Arch` dispatch.
///
/// # Panics
///
/// Panics if `dst.len() != src.len()`.
#[cfg(target_arch = "aarch64")]
#[allow(dead_code)]
pub fn axpy_minus_direct(dst: &mut [f64], src: &[f64], alpha: f64) {
    use pulp::WithSimd;

    assert_eq!(
        dst.len(),
        src.len(),
        "axpy_minus_direct: dst and src length mismatch"
    );

    // Kernel arguments: the pre-negated scale and the two operand slices.
    struct Kernel<'a> {
        scale: f64,
        input: &'a [f64],
        output: &'a mut [f64],
    }

    impl pulp::WithSimd for Kernel<'_> {
        type Output = ();

        #[inline(always)]
        fn with_simd<S: pulp::Simd>(self, simd: S) {
            let scale = simd.splat_f64s(self.scale);
            let (in_vecs, in_rest) = S::as_simd_f64s(self.input);
            let (out_vecs, out_rest) = S::as_mut_simd_f64s(self.output);

            // Whole SIMD vectors.
            for (x, y) in in_vecs.iter().zip(out_vecs.iter_mut()) {
                *y = simd.mul_add_f64s(scale, *x, *y);
            }

            // Leftover scalars go through a partial load / partial store pair.
            if in_rest.is_empty() {
                return;
            }
            let x = simd.partial_load_f64s(in_rest);
            let y = simd.partial_load_f64s(out_rest);
            simd.partial_store_f64s(out_rest, simd.mul_add_f64s(scale, x, y));
        }
    }

    // SAFETY: relies on NEON being available on every aarch64 target this code
    // is built for. NOTE(review): confirm no soft-float / no-NEON aarch64
    // target is in use.
    const NEON: pulp::aarch64::Neon = unsafe { pulp::aarch64::Neon::new_unchecked() };

    let kernel = Kernel {
        scale: -alpha,
        input: src,
        output: dst,
    };
    kernel.with_simd(NEON);
}
/// aarch64-only variant of the two-source in-place update
/// `dst[i] = mul_add(-alpha1, src1[i], mul_add(-alpha0, src0[i], dst[i]))`
/// that invokes the kernel on a compile-time NEON token instead of going
/// through `pulp::Arch` dispatch.
///
/// # Panics
///
/// Panics if `src0` or `src1` differ in length from `dst`.
#[cfg(target_arch = "aarch64")]
#[allow(dead_code)]
pub fn axpy2_minus_direct(dst: &mut [f64], src0: &[f64], alpha0: f64, src1: &[f64], alpha1: f64) {
    use pulp::WithSimd;

    assert_eq!(
        dst.len(),
        src0.len(),
        "axpy2_minus_direct: dst and src0 length mismatch"
    );
    assert_eq!(
        dst.len(),
        src1.len(),
        "axpy2_minus_direct: dst and src1 length mismatch"
    );

    // Kernel arguments: both pre-negated scales, both inputs, and the output.
    struct Kernel<'a> {
        scale0: f64,
        scale1: f64,
        input0: &'a [f64],
        input1: &'a [f64],
        output: &'a mut [f64],
    }

    impl pulp::WithSimd for Kernel<'_> {
        type Output = ();

        #[inline(always)]
        fn with_simd<S: pulp::Simd>(self, simd: S) {
            let w0 = simd.splat_f64s(self.scale0);
            let w1 = simd.splat_f64s(self.scale1);
            let (x0_vecs, x0_rest) = S::as_simd_f64s(self.input0);
            let (x1_vecs, x1_rest) = S::as_simd_f64s(self.input1);
            let (y_vecs, y_rest) = S::as_mut_simd_f64s(self.output);

            // Whole SIMD vectors: apply the src0 term, then the src1 term.
            for (y, (x0, x1)) in y_vecs.iter_mut().zip(x0_vecs.iter().zip(x1_vecs)) {
                let partial = simd.mul_add_f64s(w0, *x0, *y);
                *y = simd.mul_add_f64s(w1, *x1, partial);
            }

            // Leftover scalars, same two-step update via partial load/store.
            if x0_rest.is_empty() {
                return;
            }
            let x0 = simd.partial_load_f64s(x0_rest);
            let x1 = simd.partial_load_f64s(x1_rest);
            let y = simd.partial_load_f64s(y_rest);
            let partial = simd.mul_add_f64s(w0, x0, y);
            let updated = simd.mul_add_f64s(w1, x1, partial);
            simd.partial_store_f64s(y_rest, updated);
        }
    }

    // SAFETY: relies on NEON being available on every aarch64 target this code
    // is built for. NOTE(review): confirm no soft-float / no-NEON aarch64
    // target is in use.
    const NEON: pulp::aarch64::Neon = unsafe { pulp::aarch64::Neon::new_unchecked() };

    let kernel = Kernel {
        scale0: -alpha0,
        scale1: -alpha1,
        input0: src0,
        input1: src1,
        output: dst,
    };
    kernel.with_simd(NEON);
}
/// Runs the kernel `k` on a fixed SIMD level per architecture.
///
/// * `aarch64`: calls the kernel directly with a compile-time NEON token.
/// * `x86_64`: uses `pulp::x86::V3` when `V3::try_new()` returns a token,
///   otherwise falls back to `pulp::Arch` runtime dispatch.
/// * any other architecture: `pulp::Arch` runtime dispatch.
///
/// Returns whatever the kernel returns.
#[inline(always)]
fn dispatch_fma<K: pulp::WithSimd>(k: K) -> K::Output {
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: relies on NEON being available on every aarch64 target this
        // code is built for. NOTE(review): confirm no soft-float / no-NEON
        // aarch64 target is in use.
        const NEON: pulp::aarch64::Neon = unsafe { pulp::aarch64::Neon::new_unchecked() };
        k.with_simd(NEON)
    }
    #[cfg(target_arch = "x86_64")]
    {
        match pulp::x86::V3::try_new() {
            // Enter through `Simd::vectorize` instead of calling `with_simd`
            // directly: per pulp's documentation, `vectorize` is what runs the
            // kernel inside a function compiled with the token's target
            // features enabled, so the SIMD operations can be inlined even
            // when the crate itself is not built with those features. The
            // kernel and its inputs are unchanged. The fully qualified path is
            // used so the trait method is selected without importing `Simd`.
            Some(v3) => pulp::Simd::vectorize(v3, k),
            None => pulp::Arch::new().dispatch(k),
        }
    }
    #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
    {
        pulp::Arch::new().dispatch(k)
    }
}
/// In-place `dst[i] = mul_add(-alpha, src[i], dst[i])`, processing four SIMD
/// vectors per loop iteration and running the kernel through `dispatch_fma`.
///
/// # Panics
///
/// Panics if `dst.len() != src.len()`.
#[allow(dead_code)]
pub fn axpy_minus_unroll4(dst: &mut [f64], src: &[f64], alpha: f64) {
    assert_eq!(
        dst.len(),
        src.len(),
        "axpy_minus_unroll4: dst and src length mismatch"
    );

    // Kernel arguments: the pre-negated scale and the two operand slices.
    struct Kernel<'a> {
        scale: f64,
        input: &'a [f64],
        output: &'a mut [f64],
    }

    impl pulp::WithSimd for Kernel<'_> {
        type Output = ();

        #[inline(always)]
        fn with_simd<S: pulp::Simd>(self, simd: S) {
            let scale = simd.splat_f64s(self.scale);
            let (in_vecs, in_rest) = S::as_simd_f64s(self.input);
            let (out_vecs, out_rest) = S::as_mut_simd_f64s(self.output);

            // Split the vector bodies into a multiple-of-four part and the
            // up-to-three vectors left over. Both slices have equal length, so
            // one split point serves both.
            let unrolled = out_vecs.len() / 4 * 4;
            let (out_quads, out_singles) = out_vecs.split_at_mut(unrolled);
            let (in_quads, in_singles) = in_vecs.split_at(unrolled);

            // Four independent updates per iteration; all results are computed
            // before any is stored.
            for (y, x) in out_quads.chunks_exact_mut(4).zip(in_quads.chunks_exact(4)) {
                let v0 = simd.mul_add_f64s(scale, x[0], y[0]);
                let v1 = simd.mul_add_f64s(scale, x[1], y[1]);
                let v2 = simd.mul_add_f64s(scale, x[2], y[2]);
                let v3 = simd.mul_add_f64s(scale, x[3], y[3]);
                y[0] = v0;
                y[1] = v1;
                y[2] = v2;
                y[3] = v3;
            }

            // Remaining whole vectors, one at a time.
            for (y, x) in out_singles.iter_mut().zip(in_singles) {
                *y = simd.mul_add_f64s(scale, *x, *y);
            }

            // Leftover scalars via partial load/store.
            if !in_rest.is_empty() {
                let x = simd.partial_load_f64s(in_rest);
                let y = simd.partial_load_f64s(out_rest);
                simd.partial_store_f64s(out_rest, simd.mul_add_f64s(scale, x, y));
            }
        }
    }

    dispatch_fma(Kernel {
        scale: -alpha,
        input: src,
        output: dst,
    });
}
/// In-place two-source update
/// `dst[i] = mul_add(-alpha1, src1[i], mul_add(-alpha0, src0[i], dst[i]))`,
/// processing four SIMD vectors per loop iteration and running the kernel
/// through `dispatch_fma`.
///
/// # Panics
///
/// Panics if `src0` or `src1` differ in length from `dst`.
#[allow(dead_code)]
pub fn axpy2_minus_unroll4(dst: &mut [f64], src0: &[f64], alpha0: f64, src1: &[f64], alpha1: f64) {
    assert_eq!(
        dst.len(),
        src0.len(),
        "axpy2_minus_unroll4: dst and src0 length mismatch"
    );
    assert_eq!(
        dst.len(),
        src1.len(),
        "axpy2_minus_unroll4: dst and src1 length mismatch"
    );

    // Kernel arguments: both pre-negated scales, both inputs, and the output.
    struct Kernel<'a> {
        scale0: f64,
        scale1: f64,
        input0: &'a [f64],
        input1: &'a [f64],
        output: &'a mut [f64],
    }

    impl pulp::WithSimd for Kernel<'_> {
        type Output = ();

        #[inline(always)]
        fn with_simd<S: pulp::Simd>(self, simd: S) {
            let w0 = simd.splat_f64s(self.scale0);
            let w1 = simd.splat_f64s(self.scale1);
            let (x0_vecs, x0_rest) = S::as_simd_f64s(self.input0);
            let (x1_vecs, x1_rest) = S::as_simd_f64s(self.input1);
            let (y_vecs, y_rest) = S::as_mut_simd_f64s(self.output);

            // All three slices have equal length, so one split point separates
            // the multiple-of-four part from the leftover vectors in each.
            let unrolled = y_vecs.len() / 4 * 4;
            let (y_quads, y_singles) = y_vecs.split_at_mut(unrolled);
            let (x0_quads, x0_singles) = x0_vecs.split_at(unrolled);
            let (x1_quads, x1_singles) = x1_vecs.split_at(unrolled);

            let quads = y_quads
                .chunks_exact_mut(4)
                .zip(x0_quads.chunks_exact(4))
                .zip(x1_quads.chunks_exact(4));
            for ((y, x0), x1) in quads {
                // First fold in the src0 term for all four vectors...
                let p0 = simd.mul_add_f64s(w0, x0[0], y[0]);
                let p1 = simd.mul_add_f64s(w0, x0[1], y[1]);
                let p2 = simd.mul_add_f64s(w0, x0[2], y[2]);
                let p3 = simd.mul_add_f64s(w0, x0[3], y[3]);
                // ...then the src1 term.
                let v0 = simd.mul_add_f64s(w1, x1[0], p0);
                let v1 = simd.mul_add_f64s(w1, x1[1], p1);
                let v2 = simd.mul_add_f64s(w1, x1[2], p2);
                let v3 = simd.mul_add_f64s(w1, x1[3], p3);
                y[0] = v0;
                y[1] = v1;
                y[2] = v2;
                y[3] = v3;
            }

            // Remaining whole vectors, one at a time.
            for (y, (x0, x1)) in y_singles.iter_mut().zip(x0_singles.iter().zip(x1_singles)) {
                let partial = simd.mul_add_f64s(w0, *x0, *y);
                *y = simd.mul_add_f64s(w1, *x1, partial);
            }

            // Leftover scalars via partial load/store.
            if !x0_rest.is_empty() {
                let x0 = simd.partial_load_f64s(x0_rest);
                let x1 = simd.partial_load_f64s(x1_rest);
                let y = simd.partial_load_f64s(y_rest);
                let partial = simd.mul_add_f64s(w0, x0, y);
                let updated = simd.mul_add_f64s(w1, x1, partial);
                simd.partial_store_f64s(y_rest, updated);
            }
        }
    }

    dispatch_fma(Kernel {
        scale0: -alpha0,
        scale1: -alpha1,
        input0: src0,
        input1: src1,
        output: dst,
    });
}
/// Runs the kernel `k` on a fixed SIMD level per architecture; used by the
/// `*_nofma` kernels in this file.
///
/// * `aarch64`: calls the kernel directly with a compile-time NEON token.
/// * `x86_64`: uses `pulp::x86::V3` when `V3::try_new()` returns a token,
///   otherwise falls back to `pulp::Arch` runtime dispatch.
/// * any other architecture: `pulp::Arch` runtime dispatch.
///
/// Returns whatever the kernel returns.
#[inline(always)]
fn dispatch_nofma<K: pulp::WithSimd>(k: K) -> K::Output {
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: relies on NEON being available on every aarch64 target this
        // code is built for. NOTE(review): confirm no soft-float / no-NEON
        // aarch64 target is in use.
        const NEON: pulp::aarch64::Neon = unsafe { pulp::aarch64::Neon::new_unchecked() };
        k.with_simd(NEON)
    }
    #[cfg(target_arch = "x86_64")]
    {
        match pulp::x86::V3::try_new() {
            // Enter through `Simd::vectorize` instead of calling `with_simd`
            // directly: per pulp's documentation, `vectorize` is what runs the
            // kernel inside a function compiled with the token's target
            // features enabled, so the SIMD operations can be inlined even
            // when the crate itself is not built with those features. The
            // kernel and its inputs are unchanged. The fully qualified path is
            // used so the trait method is selected without importing `Simd`.
            Some(v3) => pulp::Simd::vectorize(v3, k),
            None => pulp::Arch::new().dispatch(k),
        }
    }
    #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
    {
        pulp::Arch::new().dispatch(k)
    }
}
/// In-place `dst[i] = dst[i] - alpha * src[i]`, computed as a separate
/// multiply followed by a subtract (no `mul_add`), four SIMD vectors per loop
/// iteration. The kernel runs through `dispatch_nofma`.
///
/// # Panics
///
/// Panics if `dst.len() != src.len()`.
#[allow(dead_code)]
pub fn axpy_minus_unroll4_nofma(dst: &mut [f64], src: &[f64], alpha: f64) {
    assert_eq!(
        dst.len(),
        src.len(),
        "axpy_minus_unroll4_nofma: dst and src length mismatch"
    );

    // Kernel arguments: the (un-negated) scale and the two operand slices.
    struct Kernel<'a> {
        scale: f64,
        input: &'a [f64],
        output: &'a mut [f64],
    }

    impl pulp::WithSimd for Kernel<'_> {
        type Output = ();

        #[inline(always)]
        fn with_simd<S: pulp::Simd>(self, simd: S) {
            let scale = simd.splat_f64s(self.scale);
            let (in_vecs, in_rest) = S::as_simd_f64s(self.input);
            let (out_vecs, out_rest) = S::as_mut_simd_f64s(self.output);

            // Both slices have equal length, so one split point separates the
            // multiple-of-four part from the leftover vectors in each.
            let unrolled = out_vecs.len() / 4 * 4;
            let (out_quads, out_singles) = out_vecs.split_at_mut(unrolled);
            let (in_quads, in_singles) = in_vecs.split_at(unrolled);

            for (y, x) in out_quads.chunks_exact_mut(4).zip(in_quads.chunks_exact(4)) {
                // Products first, then the subtractions, then the stores.
                let p0 = simd.mul_f64s(scale, x[0]);
                let p1 = simd.mul_f64s(scale, x[1]);
                let p2 = simd.mul_f64s(scale, x[2]);
                let p3 = simd.mul_f64s(scale, x[3]);
                let v0 = simd.sub_f64s(y[0], p0);
                let v1 = simd.sub_f64s(y[1], p1);
                let v2 = simd.sub_f64s(y[2], p2);
                let v3 = simd.sub_f64s(y[3], p3);
                y[0] = v0;
                y[1] = v1;
                y[2] = v2;
                y[3] = v3;
            }

            // Remaining whole vectors, one at a time.
            for (y, x) in out_singles.iter_mut().zip(in_singles) {
                let product = simd.mul_f64s(scale, *x);
                *y = simd.sub_f64s(*y, product);
            }

            // Leftover scalars via partial load/store.
            if !in_rest.is_empty() {
                let x = simd.partial_load_f64s(in_rest);
                let y = simd.partial_load_f64s(out_rest);
                let product = simd.mul_f64s(scale, x);
                simd.partial_store_f64s(out_rest, simd.sub_f64s(y, product));
            }
        }
    }

    dispatch_nofma(Kernel {
        scale: alpha,
        input: src,
        output: dst,
    });
}
/// In-place two-source update
/// `dst[i] = dst[i] - (alpha0 * src0[i] + alpha1 * src1[i])`, computed with
/// separate multiply / add / subtract operations (no `mul_add`), four SIMD
/// vectors per loop iteration. The kernel runs through `dispatch_nofma`.
///
/// # Panics
///
/// Panics if `src0` or `src1` differ in length from `dst`.
#[allow(dead_code)]
pub fn axpy2_minus_unroll4_nofma(
    dst: &mut [f64],
    src0: &[f64],
    alpha0: f64,
    src1: &[f64],
    alpha1: f64,
) {
    assert_eq!(
        dst.len(),
        src0.len(),
        "axpy2_minus_unroll4_nofma: dst and src0 length mismatch"
    );
    assert_eq!(
        dst.len(),
        src1.len(),
        "axpy2_minus_unroll4_nofma: dst and src1 length mismatch"
    );

    // Kernel arguments: both (un-negated) scales, both inputs, and the output.
    struct Kernel<'a> {
        scale0: f64,
        scale1: f64,
        input0: &'a [f64],
        input1: &'a [f64],
        output: &'a mut [f64],
    }

    impl pulp::WithSimd for Kernel<'_> {
        type Output = ();

        #[inline(always)]
        fn with_simd<S: pulp::Simd>(self, simd: S) {
            let w0 = simd.splat_f64s(self.scale0);
            let w1 = simd.splat_f64s(self.scale1);
            let (x0_vecs, x0_rest) = S::as_simd_f64s(self.input0);
            let (x1_vecs, x1_rest) = S::as_simd_f64s(self.input1);
            let (y_vecs, y_rest) = S::as_mut_simd_f64s(self.output);

            // All three slices have equal length, so one split point separates
            // the multiple-of-four part from the leftover vectors in each.
            let unrolled = y_vecs.len() / 4 * 4;
            let (y_quads, y_singles) = y_vecs.split_at_mut(unrolled);
            let (x0_quads, x0_singles) = x0_vecs.split_at(unrolled);
            let (x1_quads, x1_singles) = x1_vecs.split_at(unrolled);

            let quads = y_quads
                .chunks_exact_mut(4)
                .zip(x0_quads.chunks_exact(4))
                .zip(x1_quads.chunks_exact(4));
            for ((y, x0), x1) in quads {
                // Products of the src0 term.
                let p00 = simd.mul_f64s(w0, x0[0]);
                let p01 = simd.mul_f64s(w0, x0[1]);
                let p02 = simd.mul_f64s(w0, x0[2]);
                let p03 = simd.mul_f64s(w0, x0[3]);
                // Products of the src1 term.
                let p10 = simd.mul_f64s(w1, x1[0]);
                let p11 = simd.mul_f64s(w1, x1[1]);
                let p12 = simd.mul_f64s(w1, x1[2]);
                let p13 = simd.mul_f64s(w1, x1[3]);
                // Sum the two terms, then subtract the sum from dst.
                let sum0 = simd.add_f64s(p00, p10);
                let sum1 = simd.add_f64s(p01, p11);
                let sum2 = simd.add_f64s(p02, p12);
                let sum3 = simd.add_f64s(p03, p13);
                y[0] = simd.sub_f64s(y[0], sum0);
                y[1] = simd.sub_f64s(y[1], sum1);
                y[2] = simd.sub_f64s(y[2], sum2);
                y[3] = simd.sub_f64s(y[3], sum3);
            }

            // Remaining whole vectors, one at a time.
            for (y, (x0, x1)) in y_singles.iter_mut().zip(x0_singles.iter().zip(x1_singles)) {
                let p0 = simd.mul_f64s(w0, *x0);
                let p1 = simd.mul_f64s(w1, *x1);
                let sum = simd.add_f64s(p0, p1);
                *y = simd.sub_f64s(*y, sum);
            }

            // Leftover scalars via partial load/store.
            if !x0_rest.is_empty() {
                let x0 = simd.partial_load_f64s(x0_rest);
                let x1 = simd.partial_load_f64s(x1_rest);
                let y = simd.partial_load_f64s(y_rest);
                let p0 = simd.mul_f64s(w0, x0);
                let p1 = simd.mul_f64s(w1, x1);
                let updated = simd.sub_f64s(y, simd.add_f64s(p0, p1));
                simd.partial_store_f64s(y_rest, updated);
            }
        }
    }

    dispatch_nofma(Kernel {
        scale0: alpha0,
        scale1: alpha1,
        input0: src0,
        input1: src1,
        output: dst,
    });
}
/// Subtracts `n_elim` scaled source columns from `dst` in one pass.
///
/// For every `q` in `0..n_elim` with `alphas[q] != 0.0`, and every `i` in
/// `0..len`, performs
/// `dst[i] -= alphas[q] * src_block[(src_first_col + q) * col_stride + src_row_offset + i]`
/// using a separate multiply and subtract (no `mul_add`). Each `dst` element
/// is held in an accumulator while all columns are applied in ascending `q`,
/// then stored once. Columns whose alpha is exactly zero are skipped.
///
/// The kernel runs through `dispatch_nofma`.
///
/// # Panics
///
/// Panics if `dst.len() != len`, if `alphas.len() != n_elim`, or (when there
/// is work to do) if `src_block` is too short to contain the last column's
/// `len` elements.
#[allow(dead_code, clippy::too_many_arguments)]
pub fn schur_panel_minus_nofma_strided(
    dst: &mut [f64],
    src_block: &[f64],
    src_first_col: usize,
    n_elim: usize,
    col_stride: usize,
    src_row_offset: usize,
    len: usize,
    alphas: &[f64],
) {
    assert_eq!(
        dst.len(),
        len,
        "schur_panel_minus_nofma_strided: dst.len() must equal len"
    );
    assert_eq!(
        alphas.len(),
        n_elim,
        "schur_panel_minus_nofma_strided: alphas.len() must equal n_elim"
    );
    if n_elim == 0 || len == 0 {
        return;
    }

    // One-past-the-end index touched in the last source column.
    let required_len = (src_first_col + (n_elim - 1)) * col_stride + src_row_offset + len;
    assert!(
        src_block.len() >= required_len,
        "schur_panel_minus_nofma_strided: src_block too short ({} < {})",
        src_block.len(),
        required_len
    );

    // Everything the SIMD kernel needs, bundled for `pulp::WithSimd`.
    struct Kernel<'a> {
        dst: &'a mut [f64],
        src_block: &'a [f64],
        src_first_col: usize,
        n_elim: usize,
        col_stride: usize,
        src_row_offset: usize,
        len: usize,
        alphas: &'a [f64],
    }

    impl pulp::WithSimd for Kernel<'_> {
        type Output = ();

        #[inline(always)]
        fn with_simd<S: pulp::Simd>(self, simd: S) {
            let Kernel {
                dst,
                src_block,
                src_first_col,
                n_elim,
                col_stride,
                src_row_offset,
                len,
                alphas,
            } = self;

            let (dst_vecs, dst_rest) = S::as_mut_simd_f64s(dst);
            let vec_count = dst_vecs.len();
            // Scalar index at which the non-vector tail of each column starts.
            let rest_start = vec_count * S::F64_LANES;

            // Vectors handled four at a time, then the up-to-three left over.
            let unrolled = vec_count - vec_count % 4;
            let (quads, singles) = dst_vecs.split_at_mut(unrolled);

            for (quad_idx, quad) in quads.chunks_exact_mut(4).enumerate() {
                let first = quad_idx * 4;
                let mut acc0 = quad[0];
                let mut acc1 = quad[1];
                let mut acc2 = quad[2];
                let mut acc3 = quad[3];
                for (q, &alpha_q) in alphas.iter().enumerate().take(n_elim) {
                    if alpha_q == 0.0 {
                        continue;
                    }
                    let scale = simd.splat_f64s(alpha_q);
                    let col_start = (src_first_col + q) * col_stride + src_row_offset;
                    let column = &src_block[col_start..col_start + len];
                    let (col_vecs, _col_rest) = S::as_simd_f64s(column);
                    let p0 = simd.mul_f64s(scale, col_vecs[first]);
                    let p1 = simd.mul_f64s(scale, col_vecs[first + 1]);
                    let p2 = simd.mul_f64s(scale, col_vecs[first + 2]);
                    let p3 = simd.mul_f64s(scale, col_vecs[first + 3]);
                    acc0 = simd.sub_f64s(acc0, p0);
                    acc1 = simd.sub_f64s(acc1, p1);
                    acc2 = simd.sub_f64s(acc2, p2);
                    acc3 = simd.sub_f64s(acc3, p3);
                }
                quad[0] = acc0;
                quad[1] = acc1;
                quad[2] = acc2;
                quad[3] = acc3;
            }

            for (offset, slot) in singles.iter_mut().enumerate() {
                let vec_idx = unrolled + offset;
                let mut acc = *slot;
                for (q, &alpha_q) in alphas.iter().enumerate().take(n_elim) {
                    if alpha_q == 0.0 {
                        continue;
                    }
                    let scale = simd.splat_f64s(alpha_q);
                    let col_start = (src_first_col + q) * col_stride + src_row_offset;
                    let column = &src_block[col_start..col_start + len];
                    let (col_vecs, _col_rest) = S::as_simd_f64s(column);
                    let product = simd.mul_f64s(scale, col_vecs[vec_idx]);
                    acc = simd.sub_f64s(acc, product);
                }
                *slot = acc;
            }

            // Scalars that do not fill a whole vector.
            if dst_rest.is_empty() {
                return;
            }
            let mut acc = simd.partial_load_f64s(dst_rest);
            for (q, &alpha_q) in alphas.iter().enumerate().take(n_elim) {
                if alpha_q == 0.0 {
                    continue;
                }
                let scale = simd.splat_f64s(alpha_q);
                let col_start = (src_first_col + q) * col_stride + src_row_offset;
                let column = &src_block[col_start..col_start + len];
                let column_rest = &column[rest_start..];
                let x = simd.partial_load_f64s(column_rest);
                let product = simd.mul_f64s(scale, x);
                acc = simd.sub_f64s(acc, product);
            }
            simd.partial_store_f64s(dst_rest, acc);
        }
    }

    dispatch_nofma(Kernel {
        dst,
        src_block,
        src_first_col,
        n_elim,
        col_stride,
        src_row_offset,
        len,
        alphas,
    });
}
/// Subtracts `n_elim` scaled source columns from `dst` in one pass, using
/// `mul_add`.
///
/// For every `q` in `0..n_elim` with `alphas[q] != 0.0`, and every `i` in
/// `0..len`, performs
/// `dst[i] = mul_add(-alphas[q], src_block[(src_first_col + q) * col_stride + src_row_offset + i], dst[i])`.
/// Each `dst` element is held in an accumulator while all columns are applied
/// in ascending `q`, then stored once. Columns whose alpha is exactly zero are
/// skipped.
///
/// The kernel runs through `dispatch_fma`.
///
/// # Panics
///
/// Panics if `dst.len() != len`, if `alphas.len() != n_elim`, or (when there
/// is work to do) if `src_block` is too short to contain the last column's
/// `len` elements.
#[allow(dead_code, clippy::too_many_arguments)]
pub fn schur_panel_minus_fma_strided(
    dst: &mut [f64],
    src_block: &[f64],
    src_first_col: usize,
    n_elim: usize,
    col_stride: usize,
    src_row_offset: usize,
    len: usize,
    alphas: &[f64],
) {
    assert_eq!(
        dst.len(),
        len,
        "schur_panel_minus_fma_strided: dst.len() must equal len"
    );
    assert_eq!(
        alphas.len(),
        n_elim,
        "schur_panel_minus_fma_strided: alphas.len() must equal n_elim"
    );
    if n_elim == 0 || len == 0 {
        return;
    }

    // One-past-the-end index touched in the last source column.
    let required_len = (src_first_col + (n_elim - 1)) * col_stride + src_row_offset + len;
    assert!(
        src_block.len() >= required_len,
        "schur_panel_minus_fma_strided: src_block too short ({} < {})",
        src_block.len(),
        required_len
    );

    // Everything the SIMD kernel needs, bundled for `pulp::WithSimd`.
    struct Kernel<'a> {
        dst: &'a mut [f64],
        src_block: &'a [f64],
        src_first_col: usize,
        n_elim: usize,
        col_stride: usize,
        src_row_offset: usize,
        len: usize,
        alphas: &'a [f64],
    }

    impl pulp::WithSimd for Kernel<'_> {
        type Output = ();

        #[inline(always)]
        fn with_simd<S: pulp::Simd>(self, simd: S) {
            let Kernel {
                dst,
                src_block,
                src_first_col,
                n_elim,
                col_stride,
                src_row_offset,
                len,
                alphas,
            } = self;

            let (dst_vecs, dst_rest) = S::as_mut_simd_f64s(dst);
            let vec_count = dst_vecs.len();
            // Scalar index at which the non-vector tail of each column starts.
            let rest_start = vec_count * S::F64_LANES;

            // Vectors handled four at a time, then the up-to-three left over.
            let unrolled = vec_count - vec_count % 4;
            let (quads, singles) = dst_vecs.split_at_mut(unrolled);

            for (quad_idx, quad) in quads.chunks_exact_mut(4).enumerate() {
                let first = quad_idx * 4;
                let mut acc0 = quad[0];
                let mut acc1 = quad[1];
                let mut acc2 = quad[2];
                let mut acc3 = quad[3];
                for (q, &alpha_q) in alphas.iter().enumerate().take(n_elim) {
                    if alpha_q == 0.0 {
                        continue;
                    }
                    let neg_scale = simd.splat_f64s(-alpha_q);
                    let col_start = (src_first_col + q) * col_stride + src_row_offset;
                    let column = &src_block[col_start..col_start + len];
                    let (col_vecs, _col_rest) = S::as_simd_f64s(column);
                    acc0 = simd.mul_add_f64s(neg_scale, col_vecs[first], acc0);
                    acc1 = simd.mul_add_f64s(neg_scale, col_vecs[first + 1], acc1);
                    acc2 = simd.mul_add_f64s(neg_scale, col_vecs[first + 2], acc2);
                    acc3 = simd.mul_add_f64s(neg_scale, col_vecs[first + 3], acc3);
                }
                quad[0] = acc0;
                quad[1] = acc1;
                quad[2] = acc2;
                quad[3] = acc3;
            }

            for (offset, slot) in singles.iter_mut().enumerate() {
                let vec_idx = unrolled + offset;
                let mut acc = *slot;
                for (q, &alpha_q) in alphas.iter().enumerate().take(n_elim) {
                    if alpha_q == 0.0 {
                        continue;
                    }
                    let neg_scale = simd.splat_f64s(-alpha_q);
                    let col_start = (src_first_col + q) * col_stride + src_row_offset;
                    let column = &src_block[col_start..col_start + len];
                    let (col_vecs, _col_rest) = S::as_simd_f64s(column);
                    acc = simd.mul_add_f64s(neg_scale, col_vecs[vec_idx], acc);
                }
                *slot = acc;
            }

            // Scalars that do not fill a whole vector.
            if dst_rest.is_empty() {
                return;
            }
            let mut acc = simd.partial_load_f64s(dst_rest);
            for (q, &alpha_q) in alphas.iter().enumerate().take(n_elim) {
                if alpha_q == 0.0 {
                    continue;
                }
                let neg_scale = simd.splat_f64s(-alpha_q);
                let col_start = (src_first_col + q) * col_stride + src_row_offset;
                let column = &src_block[col_start..col_start + len];
                let column_rest = &column[rest_start..];
                let x = simd.partial_load_f64s(column_rest);
                acc = simd.mul_add_f64s(neg_scale, x, acc);
            }
            simd.partial_store_f64s(dst_rest, acc);
        }
    }

    dispatch_fma(Kernel {
        dst,
        src_block,
        src_first_col,
        n_elim,
        col_stride,
        src_row_offset,
        len,
        alphas,
    });
}
/// Applies the strided multi-column `mul_add` update to two destinations at
/// once, where `dst1` is exactly one element shorter than `dst0`.
///
/// Source column `q` starts at
/// `src_block[(src_first_col + q) * col_stride + src_row_offset]`.
///
/// * `dst0[0]` (the "cap" row) is updated with scalar `f64::mul_add` from the
///   first element of each source column, scaled by `-alphas0[q]`.
/// * `dst0[1 + i]` and `dst1[i]`, for `i` in `0..dst1.len()`, are then updated
///   together from source element `1 + i` of each column, scaled by
///   `-alphas0[q]` and `-alphas1[q]` respectively, so each source vector is
///   loaded once and used for both destinations.
///
/// Alphas that are exactly zero are skipped. The SIMD part runs through
/// `dispatch_fma`.
///
/// # Panics
///
/// Panics if `dst1.len() + 1 != dst0.len()`, if `alphas0.len()` or
/// `alphas1.len()` differs from `n_elim`, or (when `n_elim > 0`) if
/// `src_block` is too short to contain `dst0.len()` elements of the last
/// column.
#[allow(dead_code, clippy::too_many_arguments)]
pub fn schur_panel_minus_fma_strided_dual(
dst0: &mut [f64],
dst1: &mut [f64],
src_block: &[f64],
src_first_col: usize,
n_elim: usize,
col_stride: usize,
src_row_offset: usize,
alphas0: &[f64],
alphas1: &[f64],
) {
let len0 = dst0.len();
let len1 = dst1.len();
assert_eq!(
len1 + 1,
len0,
"schur_panel_minus_fma_strided_dual: dst1 must be exactly one shorter than dst0 \
(len0={}, len1={})",
len0,
len1
);
assert_eq!(
alphas0.len(),
n_elim,
"schur_panel_minus_fma_strided_dual: alphas0.len() must equal n_elim"
);
assert_eq!(
alphas1.len(),
n_elim,
"schur_panel_minus_fma_strided_dual: alphas1.len() must equal n_elim"
);
// `len0` is at least 1 here (it equals `len1 + 1`), so in practice only
// `n_elim == 0` triggers this early return.
if n_elim == 0 || len0 == 0 {
return;
}
// One-past-the-end index touched in the last source column.
let last_q = n_elim - 1;
let max_idx = (src_first_col + last_q) * col_stride + src_row_offset + len0;
assert!(
src_block.len() >= max_idx,
"schur_panel_minus_fma_strided_dual: src_block too short ({} < {})",
src_block.len(),
max_idx
);
// Cap row: only `dst0` has an element at this row, so it is handled with
// scalar arithmetic before the paired SIMD pass.
for (q, &alpha_q) in alphas0.iter().enumerate().take(n_elim) {
if alpha_q == 0.0 {
continue;
}
let col_off = (src_first_col + q) * col_stride + src_row_offset;
let s = src_block[col_off];
dst0[0] = (-alpha_q).mul_add(s, dst0[0]);
}
// Nothing left once the cap row is done if `dst1` is empty.
if len1 == 0 {
return;
}
// Kernel arguments. `dst0_bulk` is `dst0` without its cap element, so it has
// the same length (`len`) as `dst1`; `src_row_offset_bulk` is shifted by one
// row to match.
struct K<'a> {
dst0_bulk: &'a mut [f64],
dst1: &'a mut [f64],
src_block: &'a [f64],
src_first_col: usize,
n_elim: usize,
col_stride: usize,
src_row_offset_bulk: usize,
len: usize,
alphas0: &'a [f64],
alphas1: &'a [f64],
}
impl pulp::WithSimd for K<'_> {
type Output = ();
#[allow(clippy::needless_range_loop)]
#[inline(always)]
fn with_simd<S: pulp::Simd>(self, simd: S) {
let Self {
dst0_bulk,
dst1,
src_block,
src_first_col,
n_elim,
col_stride,
src_row_offset_bulk,
len,
alphas0,
alphas1,
} = self;
let (d0_body, d0_tail) = S::as_mut_simd_f64s(dst0_bulk);
let (d1_body, d1_tail) = S::as_mut_simd_f64s(dst1);
let body_len = d0_body.len();
debug_assert_eq!(body_len, d1_body.len());
// Scalar index at which the non-vector tail of each column starts.
let tail_off = body_len * S::F64_LANES;
// Main loop: four SIMD vectors per destination per iteration, with all
// eight accumulators carried across the whole `q` loop and stored once.
let chunks = body_len / 4;
for chunk_idx in 0..chunks {
let base = chunk_idx * 4;
let mut a00 = d0_body[base];
let mut a01 = d0_body[base + 1];
let mut a02 = d0_body[base + 2];
let mut a03 = d0_body[base + 3];
let mut a10 = d1_body[base];
let mut a11 = d1_body[base + 1];
let mut a12 = d1_body[base + 2];
let mut a13 = d1_body[base + 3];
for q in 0..n_elim {
let a0q = alphas0[q];
let a1q = alphas1[q];
// Skip the column entirely only when it contributes to neither destination.
if a0q == 0.0 && a1q == 0.0 {
continue;
}
let col_off = (src_first_col + q) * col_stride + src_row_offset_bulk;
let src_q = &src_block[col_off..col_off + len];
let (sb, _st) = S::as_simd_f64s(src_q);
// The four source vectors are shared by both destinations.
let s0v = sb[base];
let s1v = sb[base + 1];
let s2v = sb[base + 2];
let s3v = sb[base + 3];
if a0q != 0.0 {
let nav0 = simd.splat_f64s(-a0q);
a00 = simd.mul_add_f64s(nav0, s0v, a00);
a01 = simd.mul_add_f64s(nav0, s1v, a01);
a02 = simd.mul_add_f64s(nav0, s2v, a02);
a03 = simd.mul_add_f64s(nav0, s3v, a03);
}
if a1q != 0.0 {
let nav1 = simd.splat_f64s(-a1q);
a10 = simd.mul_add_f64s(nav1, s0v, a10);
a11 = simd.mul_add_f64s(nav1, s1v, a11);
a12 = simd.mul_add_f64s(nav1, s2v, a12);
a13 = simd.mul_add_f64s(nav1, s3v, a13);
}
}
d0_body[base] = a00;
d0_body[base + 1] = a01;
d0_body[base + 2] = a02;
d0_body[base + 3] = a03;
d1_body[base] = a10;
d1_body[base + 1] = a11;
d1_body[base + 2] = a12;
d1_body[base + 3] = a13;
}
// Up to three whole vectors left after the unrolled loop, one at a time.
let tail_chunks_start = chunks * 4;
for body_idx in tail_chunks_start..body_len {
let mut acc0 = d0_body[body_idx];
let mut acc1 = d1_body[body_idx];
for q in 0..n_elim {
let a0q = alphas0[q];
let a1q = alphas1[q];
if a0q == 0.0 && a1q == 0.0 {
continue;
}
let col_off = (src_first_col + q) * col_stride + src_row_offset_bulk;
let src_q = &src_block[col_off..col_off + len];
let (sb, _st) = S::as_simd_f64s(src_q);
let s = sb[body_idx];
if a0q != 0.0 {
let nav0 = simd.splat_f64s(-a0q);
acc0 = simd.mul_add_f64s(nav0, s, acc0);
}
if a1q != 0.0 {
let nav1 = simd.splat_f64s(-a1q);
acc1 = simd.mul_add_f64s(nav1, s, acc1);
}
}
d0_body[body_idx] = acc0;
d1_body[body_idx] = acc1;
}
// Scalars that do not fill a whole vector, handled with partial load/store.
// Both destinations have the same length here, so checking `d0_tail` covers both.
if !d0_tail.is_empty() {
let mut acc0 = simd.partial_load_f64s(d0_tail);
let mut acc1 = simd.partial_load_f64s(d1_tail);
for q in 0..n_elim {
let a0q = alphas0[q];
let a1q = alphas1[q];
if a0q == 0.0 && a1q == 0.0 {
continue;
}
let col_off = (src_first_col + q) * col_stride + src_row_offset_bulk;
let src_q = &src_block[col_off..col_off + len];
let src_q_tail = &src_q[tail_off..];
let s = simd.partial_load_f64s(src_q_tail);
if a0q != 0.0 {
let nav0 = simd.splat_f64s(-a0q);
acc0 = simd.mul_add_f64s(nav0, s, acc0);
}
if a1q != 0.0 {
let nav1 = simd.splat_f64s(-a1q);
acc1 = simd.mul_add_f64s(nav1, s, acc1);
}
}
simd.partial_store_f64s(d0_tail, acc0);
simd.partial_store_f64s(d1_tail, acc1);
}
}
}
// Drop the already-updated cap element so `dst0_bulk` lines up with `dst1`.
let (dst0_cap, dst0_bulk) = dst0.split_at_mut(1);
let _ = dst0_cap;
dispatch_fma(K {
dst0_bulk,
dst1,
src_block,
src_first_col,
n_elim,
col_stride,
src_row_offset_bulk: src_row_offset + 1,
len: len1,
alphas0,
alphas1,
});
}
/// Applies the strided multi-column update to two destinations at once using
/// separate multiply and subtract operations (no `mul_add`), where `dst1` is
/// exactly one element shorter than `dst0`.
///
/// Source column `q` starts at
/// `src_block[(src_first_col + q) * col_stride + src_row_offset]`.
///
/// * `dst0[0]` (the "cap" row) is updated with scalar arithmetic:
///   `dst0[0] -= alphas0[q] * s`, where `s` is the first element of column `q`.
/// * `dst0[1 + i]` and `dst1[i]`, for `i` in `0..dst1.len()`, are then updated
///   together from source element `1 + i` of each column, scaled by
///   `alphas0[q]` and `alphas1[q]` respectively, so each source vector is
///   loaded once and used for both destinations.
///
/// Alphas that are exactly zero are skipped. The SIMD part runs through
/// `dispatch_nofma`.
///
/// # Panics
///
/// Panics if `dst1.len() + 1 != dst0.len()`, if `alphas0.len()` or
/// `alphas1.len()` differs from `n_elim`, or (when `n_elim > 0`) if
/// `src_block` is too short to contain `dst0.len()` elements of the last
/// column.
#[allow(dead_code, clippy::too_many_arguments)]
pub fn schur_panel_minus_nofma_strided_dual(
dst0: &mut [f64],
dst1: &mut [f64],
src_block: &[f64],
src_first_col: usize,
n_elim: usize,
col_stride: usize,
src_row_offset: usize,
alphas0: &[f64],
alphas1: &[f64],
) {
let len0 = dst0.len();
let len1 = dst1.len();
assert_eq!(
len1 + 1,
len0,
"schur_panel_minus_nofma_strided_dual: dst1 must be exactly one shorter than dst0 \
(len0={}, len1={})",
len0,
len1
);
assert_eq!(
alphas0.len(),
n_elim,
"schur_panel_minus_nofma_strided_dual: alphas0.len() must equal n_elim"
);
assert_eq!(
alphas1.len(),
n_elim,
"schur_panel_minus_nofma_strided_dual: alphas1.len() must equal n_elim"
);
// `len0` is at least 1 here (it equals `len1 + 1`), so in practice only
// `n_elim == 0` triggers this early return.
if n_elim == 0 || len0 == 0 {
return;
}
// One-past-the-end index touched in the last source column.
let last_q = n_elim - 1;
let max_idx = (src_first_col + last_q) * col_stride + src_row_offset + len0;
assert!(
src_block.len() >= max_idx,
"schur_panel_minus_nofma_strided_dual: src_block too short ({} < {})",
src_block.len(),
max_idx
);
// Cap row: only `dst0` has an element at this row, so it is handled with
// scalar arithmetic before the paired SIMD pass.
for (q, &alpha_q) in alphas0.iter().enumerate().take(n_elim) {
if alpha_q == 0.0 {
continue;
}
let col_off = (src_first_col + q) * col_stride + src_row_offset;
let s = src_block[col_off];
dst0[0] -= alpha_q * s;
}
// Nothing left once the cap row is done if `dst1` is empty.
if len1 == 0 {
return;
}
// Kernel arguments. `dst0_bulk` is `dst0` without its cap element, so it has
// the same length (`len`) as `dst1`; `src_row_offset_bulk` is shifted by one
// row to match.
struct K<'a> {
dst0_bulk: &'a mut [f64],
dst1: &'a mut [f64],
src_block: &'a [f64],
src_first_col: usize,
n_elim: usize,
col_stride: usize,
src_row_offset_bulk: usize,
len: usize,
alphas0: &'a [f64],
alphas1: &'a [f64],
}
impl pulp::WithSimd for K<'_> {
type Output = ();
#[allow(clippy::needless_range_loop)]
#[inline(always)]
fn with_simd<S: pulp::Simd>(self, simd: S) {
let Self {
dst0_bulk,
dst1,
src_block,
src_first_col,
n_elim,
col_stride,
src_row_offset_bulk,
len,
alphas0,
alphas1,
} = self;
let (d0_body, d0_tail) = S::as_mut_simd_f64s(dst0_bulk);
let (d1_body, d1_tail) = S::as_mut_simd_f64s(dst1);
let body_len = d0_body.len();
debug_assert_eq!(body_len, d1_body.len());
// Scalar index at which the non-vector tail of each column starts.
let tail_off = body_len * S::F64_LANES;
// Main loop: four SIMD vectors per destination per iteration, with all
// eight accumulators carried across the whole `q` loop and stored once.
let chunks = body_len / 4;
for chunk_idx in 0..chunks {
let base = chunk_idx * 4;
let mut a00 = d0_body[base];
let mut a01 = d0_body[base + 1];
let mut a02 = d0_body[base + 2];
let mut a03 = d0_body[base + 3];
let mut a10 = d1_body[base];
let mut a11 = d1_body[base + 1];
let mut a12 = d1_body[base + 2];
let mut a13 = d1_body[base + 3];
for q in 0..n_elim {
let a0q = alphas0[q];
let a1q = alphas1[q];
// Skip the column entirely only when it contributes to neither destination.
if a0q == 0.0 && a1q == 0.0 {
continue;
}
let col_off = (src_first_col + q) * col_stride + src_row_offset_bulk;
let src_q = &src_block[col_off..col_off + len];
let (sb, _st) = S::as_simd_f64s(src_q);
// The four source vectors are shared by both destinations.
let s0v = sb[base];
let s1v = sb[base + 1];
let s2v = sb[base + 2];
let s3v = sb[base + 3];
if a0q != 0.0 {
let av0 = simd.splat_f64s(a0q);
a00 = simd.sub_f64s(a00, simd.mul_f64s(av0, s0v));
a01 = simd.sub_f64s(a01, simd.mul_f64s(av0, s1v));
a02 = simd.sub_f64s(a02, simd.mul_f64s(av0, s2v));
a03 = simd.sub_f64s(a03, simd.mul_f64s(av0, s3v));
}
if a1q != 0.0 {
let av1 = simd.splat_f64s(a1q);
a10 = simd.sub_f64s(a10, simd.mul_f64s(av1, s0v));
a11 = simd.sub_f64s(a11, simd.mul_f64s(av1, s1v));
a12 = simd.sub_f64s(a12, simd.mul_f64s(av1, s2v));
a13 = simd.sub_f64s(a13, simd.mul_f64s(av1, s3v));
}
}
d0_body[base] = a00;
d0_body[base + 1] = a01;
d0_body[base + 2] = a02;
d0_body[base + 3] = a03;
d1_body[base] = a10;
d1_body[base + 1] = a11;
d1_body[base + 2] = a12;
d1_body[base + 3] = a13;
}
// Up to three whole vectors left after the unrolled loop, one at a time.
let tail_chunks_start = chunks * 4;
for body_idx in tail_chunks_start..body_len {
let mut acc0 = d0_body[body_idx];
let mut acc1 = d1_body[body_idx];
for q in 0..n_elim {
let a0q = alphas0[q];
let a1q = alphas1[q];
if a0q == 0.0 && a1q == 0.0 {
continue;
}
let col_off = (src_first_col + q) * col_stride + src_row_offset_bulk;
let src_q = &src_block[col_off..col_off + len];
let (sb, _st) = S::as_simd_f64s(src_q);
let s = sb[body_idx];
if a0q != 0.0 {
let av0 = simd.splat_f64s(a0q);
acc0 = simd.sub_f64s(acc0, simd.mul_f64s(av0, s));
}
if a1q != 0.0 {
let av1 = simd.splat_f64s(a1q);
acc1 = simd.sub_f64s(acc1, simd.mul_f64s(av1, s));
}
}
d0_body[body_idx] = acc0;
d1_body[body_idx] = acc1;
}
// Scalars that do not fill a whole vector, handled with partial load/store.
// Both destinations have the same length here, so checking `d0_tail` covers both.
if !d0_tail.is_empty() {
let mut acc0 = simd.partial_load_f64s(d0_tail);
let mut acc1 = simd.partial_load_f64s(d1_tail);
for q in 0..n_elim {
let a0q = alphas0[q];
let a1q = alphas1[q];
if a0q == 0.0 && a1q == 0.0 {
continue;
}
let col_off = (src_first_col + q) * col_stride + src_row_offset_bulk;
let src_q = &src_block[col_off..col_off + len];
let src_q_tail = &src_q[tail_off..];
let s = simd.partial_load_f64s(src_q_tail);
if a0q != 0.0 {
let av0 = simd.splat_f64s(a0q);
acc0 = simd.sub_f64s(acc0, simd.mul_f64s(av0, s));
}
if a1q != 0.0 {
let av1 = simd.splat_f64s(a1q);
acc1 = simd.sub_f64s(acc1, simd.mul_f64s(av1, s));
}
}
simd.partial_store_f64s(d0_tail, acc0);
simd.partial_store_f64s(d1_tail, acc1);
}
}
}
// Drop the already-updated cap element so `dst0_bulk` lines up with `dst1`.
let (dst0_cap, dst0_bulk) = dst0.split_at_mut(1);
let _ = dst0_cap;
dispatch_nofma(K {
dst0_bulk,
dst1,
src_block,
src_first_col,
n_elim,
col_stride,
src_row_offset_bulk: src_row_offset + 1,
len: len1,
alphas0,
alphas1,
});
}
#[allow(dead_code, clippy::too_many_arguments)]
pub fn schur_panel_minus_nofma_strided_quad(
dst0: &mut [f64],
dst1: &mut [f64],
dst2: &mut [f64],
dst3: &mut [f64],
src_block: &[f64],
src_first_col: usize,
n_elim: usize,
col_stride: usize,
src_row_offset: usize,
alphas0: &[f64],
alphas1: &[f64],
alphas2: &[f64],
alphas3: &[f64],
) {
let len0 = dst0.len();
let len1 = dst1.len();
let len2 = dst2.len();
let len3 = dst3.len();
assert_eq!(
len1 + 1,
len0,
"schur_panel_minus_nofma_strided_quad: dst1 must be exactly one shorter than dst0 \
(len0={}, len1={})",
len0,
len1
);
assert_eq!(
len2 + 2,
len0,
"schur_panel_minus_nofma_strided_quad: dst2 must be exactly two shorter than dst0 \
(len0={}, len2={})",
len0,
len2
);
assert_eq!(
len3 + 3,
len0,
"schur_panel_minus_nofma_strided_quad: dst3 must be exactly three shorter than dst0 \
(len0={}, len3={})",
len0,
len3
);
assert_eq!(
alphas0.len(),
n_elim,
"schur_panel_minus_nofma_strided_quad: alphas0.len() must equal n_elim"
);
assert_eq!(
alphas1.len(),
n_elim,
"schur_panel_minus_nofma_strided_quad: alphas1.len() must equal n_elim"
);
assert_eq!(
alphas2.len(),
n_elim,
"schur_panel_minus_nofma_strided_quad: alphas2.len() must equal n_elim"
);
assert_eq!(
alphas3.len(),
n_elim,
"schur_panel_minus_nofma_strided_quad: alphas3.len() must equal n_elim"
);
if n_elim == 0 || len0 == 0 {
return;
}
let last_q = n_elim - 1;
let max_idx = (src_first_col + last_q) * col_stride + src_row_offset + len0;
assert!(
src_block.len() >= max_idx,
"schur_panel_minus_nofma_strided_quad: src_block too short ({} < {})",
src_block.len(),
max_idx
);
for (q, &alpha_q) in alphas0.iter().enumerate().take(n_elim) {
if alpha_q == 0.0 {
continue;
}
let col_off = (src_first_col + q) * col_stride + src_row_offset;
let s = src_block[col_off];
dst0[0] -= alpha_q * s;
}
if len0 == 1 {
return;
}
for q in 0..n_elim {
let col_off = (src_first_col + q) * col_stride + src_row_offset + 1;
let s = src_block[col_off];
let a0 = alphas0[q];
let a1 = alphas1[q];
if a0 != 0.0 {
dst0[1] -= a0 * s;
}
if a1 != 0.0 {
dst1[0] -= a1 * s;
}
}
if len0 == 2 {
return;
}
for q in 0..n_elim {
let col_off = (src_first_col + q) * col_stride + src_row_offset + 2;
let s = src_block[col_off];
let a0 = alphas0[q];
let a1 = alphas1[q];
let a2 = alphas2[q];
if a0 != 0.0 {
dst0[2] -= a0 * s;
}
if a1 != 0.0 {
dst1[1] -= a1 * s;
}
if a2 != 0.0 {
dst2[0] -= a2 * s;
}
}
if len0 == 3 {
return;
}
struct K<'a> {
dst0_bulk: &'a mut [f64],
dst1_bulk: &'a mut [f64],
dst2_bulk: &'a mut [f64],
dst3_bulk: &'a mut [f64],
src_block: &'a [f64],
src_first_col: usize,
n_elim: usize,
col_stride: usize,
src_row_offset_bulk: usize,
len: usize,
alphas0: &'a [f64],
alphas1: &'a [f64],
alphas2: &'a [f64],
alphas3: &'a [f64],
}
impl pulp::WithSimd for K<'_> {
type Output = ();
#[allow(clippy::needless_range_loop)]
#[inline(always)]
fn with_simd<S: pulp::Simd>(self, simd: S) {
let Self {
dst0_bulk,
dst1_bulk,
dst2_bulk,
dst3_bulk,
src_block,
src_first_col,
n_elim,
col_stride,
src_row_offset_bulk,
len,
alphas0,
alphas1,
alphas2,
alphas3,
} = self;
let (d0_body, d0_tail) = S::as_mut_simd_f64s(dst0_bulk);
let (d1_body, d1_tail) = S::as_mut_simd_f64s(dst1_bulk);
let (d2_body, d2_tail) = S::as_mut_simd_f64s(dst2_bulk);
let (d3_body, d3_tail) = S::as_mut_simd_f64s(dst3_bulk);
let body_len = d0_body.len();
debug_assert_eq!(body_len, d1_body.len());
debug_assert_eq!(body_len, d2_body.len());
debug_assert_eq!(body_len, d3_body.len());
let tail_off = body_len * S::F64_LANES;
let chunks = body_len / 2;
for chunk_idx in 0..chunks {
let base = chunk_idx * 2;
let mut a00 = d0_body[base];
let mut a01 = d0_body[base + 1];
let mut a10 = d1_body[base];
let mut a11 = d1_body[base + 1];
let mut a20 = d2_body[base];
let mut a21 = d2_body[base + 1];
let mut a30 = d3_body[base];
let mut a31 = d3_body[base + 1];
for q in 0..n_elim {
let a0q = alphas0[q];
let a1q = alphas1[q];
let a2q = alphas2[q];
let a3q = alphas3[q];
if a0q == 0.0 && a1q == 0.0 && a2q == 0.0 && a3q == 0.0 {
continue;
}
let col_off = (src_first_col + q) * col_stride + src_row_offset_bulk;
let src_q = &src_block[col_off..col_off + len];
let (sb, _st) = S::as_simd_f64s(src_q);
let s0 = sb[base];
let s1 = sb[base + 1];
if a0q != 0.0 {
let av0 = simd.splat_f64s(a0q);
a00 = simd.sub_f64s(a00, simd.mul_f64s(av0, s0));
a01 = simd.sub_f64s(a01, simd.mul_f64s(av0, s1));
}
if a1q != 0.0 {
let av1 = simd.splat_f64s(a1q);
a10 = simd.sub_f64s(a10, simd.mul_f64s(av1, s0));
a11 = simd.sub_f64s(a11, simd.mul_f64s(av1, s1));
}
if a2q != 0.0 {
let av2 = simd.splat_f64s(a2q);
a20 = simd.sub_f64s(a20, simd.mul_f64s(av2, s0));
a21 = simd.sub_f64s(a21, simd.mul_f64s(av2, s1));
}
if a3q != 0.0 {
let av3 = simd.splat_f64s(a3q);
a30 = simd.sub_f64s(a30, simd.mul_f64s(av3, s0));
a31 = simd.sub_f64s(a31, simd.mul_f64s(av3, s1));
}
}
d0_body[base] = a00;
d0_body[base + 1] = a01;
d1_body[base] = a10;
d1_body[base + 1] = a11;
d2_body[base] = a20;
d2_body[base + 1] = a21;
d3_body[base] = a30;
d3_body[base + 1] = a31;
}
let tail_chunks_start = chunks * 2;
for body_idx in tail_chunks_start..body_len {
let mut acc0 = d0_body[body_idx];
let mut acc1 = d1_body[body_idx];
let mut acc2 = d2_body[body_idx];
let mut acc3 = d3_body[body_idx];
for q in 0..n_elim {
let a0q = alphas0[q];
let a1q = alphas1[q];
let a2q = alphas2[q];
let a3q = alphas3[q];
if a0q == 0.0 && a1q == 0.0 && a2q == 0.0 && a3q == 0.0 {
continue;
}
let col_off = (src_first_col + q) * col_stride + src_row_offset_bulk;
let src_q = &src_block[col_off..col_off + len];
let (sb, _st) = S::as_simd_f64s(src_q);
let s = sb[body_idx];
if a0q != 0.0 {
let av0 = simd.splat_f64s(a0q);
acc0 = simd.sub_f64s(acc0, simd.mul_f64s(av0, s));
}
if a1q != 0.0 {
let av1 = simd.splat_f64s(a1q);
acc1 = simd.sub_f64s(acc1, simd.mul_f64s(av1, s));
}
if a2q != 0.0 {
let av2 = simd.splat_f64s(a2q);
acc2 = simd.sub_f64s(acc2, simd.mul_f64s(av2, s));
}
if a3q != 0.0 {
let av3 = simd.splat_f64s(a3q);
acc3 = simd.sub_f64s(acc3, simd.mul_f64s(av3, s));
}
}
d0_body[body_idx] = acc0;
d1_body[body_idx] = acc1;
d2_body[body_idx] = acc2;
d3_body[body_idx] = acc3;
}
if !d0_tail.is_empty() {
let mut acc0 = simd.partial_load_f64s(d0_tail);
let mut acc1 = simd.partial_load_f64s(d1_tail);
let mut acc2 = simd.partial_load_f64s(d2_tail);
let mut acc3 = simd.partial_load_f64s(d3_tail);
for q in 0..n_elim {
let a0q = alphas0[q];
let a1q = alphas1[q];
let a2q = alphas2[q];
let a3q = alphas3[q];
if a0q == 0.0 && a1q == 0.0 && a2q == 0.0 && a3q == 0.0 {
continue;
}
let col_off = (src_first_col + q) * col_stride + src_row_offset_bulk;
let src_q = &src_block[col_off..col_off + len];
let src_q_tail = &src_q[tail_off..];
let s = simd.partial_load_f64s(src_q_tail);
if a0q != 0.0 {
let av0 = simd.splat_f64s(a0q);
acc0 = simd.sub_f64s(acc0, simd.mul_f64s(av0, s));
}
if a1q != 0.0 {
let av1 = simd.splat_f64s(a1q);
acc1 = simd.sub_f64s(acc1, simd.mul_f64s(av1, s));
}
if a2q != 0.0 {
let av2 = simd.splat_f64s(a2q);
acc2 = simd.sub_f64s(acc2, simd.mul_f64s(av2, s));
}
if a3q != 0.0 {
let av3 = simd.splat_f64s(a3q);
acc3 = simd.sub_f64s(acc3, simd.mul_f64s(av3, s));
}
}
simd.partial_store_f64s(d0_tail, acc0);
simd.partial_store_f64s(d1_tail, acc1);
simd.partial_store_f64s(d2_tail, acc2);
simd.partial_store_f64s(d3_tail, acc3);
}
}
}
let (_d0_cap, d0_bulk) = dst0.split_at_mut(3);
let (_d1_cap, d1_bulk) = dst1.split_at_mut(2);
let (_d2_cap, d2_bulk) = dst2.split_at_mut(1);
dispatch_nofma(K {
dst0_bulk: d0_bulk,
dst1_bulk: d1_bulk,
dst2_bulk: d2_bulk,
dst3_bulk: dst3,
src_block,
src_first_col,
n_elim,
col_stride,
src_row_offset_bulk: src_row_offset + 3,
len: len3,
alphas0,
alphas1,
alphas2,
alphas3,
});
}
/// Fused-multiply-add variant of the four-destination strided panel update.
///
/// For each destination `dst_k` (`k` in `0..4`) and each `q` in `0..n_elim`, element `i` of
/// `dst_k` is updated as `dst_k[i] = (-alphas_k[q]) * s + dst_k[i]`, where
/// `s = src_block[(src_first_col + q) * col_stride + src_row_offset + k + i]`.
/// The updates of one element are applied in increasing `q`. A pair `(k, q)` whose alpha
/// compares equal to `0.0` is skipped, so it contributes nothing even when `s` is not finite.
///
/// The destinations are staggered: `dst1`, `dst2` and `dst3` must be exactly one, two and
/// three elements shorter than `dst0`. The first three source rows (which not every
/// destination covers) are handled with scalar `f64::mul_add`; the remaining `dst3.len()`
/// rows are common to all four destinations and are handed to the SIMD kernel `K`.
///
/// # Panics
///
/// Panics if the destination lengths are not staggered as described, if any `alphasN` does
/// not have exactly `n_elim` entries, or (when `n_elim > 0`) if `src_block` is shorter than
/// `(src_first_col + n_elim - 1) * col_stride + src_row_offset + dst0.len()`.
// NOTE(review): no reason for `dead_code` is visible here — presumably not every build
// configuration calls this kernel; confirm against the callers.
#[allow(dead_code, clippy::too_many_arguments)]
pub fn schur_panel_minus_fma_strided_quad(
dst0: &mut [f64],
dst1: &mut [f64],
dst2: &mut [f64],
dst3: &mut [f64],
src_block: &[f64],
src_first_col: usize,
n_elim: usize,
col_stride: usize,
src_row_offset: usize,
alphas0: &[f64],
alphas1: &[f64],
alphas2: &[f64],
alphas3: &[f64],
) {
let len0 = dst0.len();
let len1 = dst1.len();
let len2 = dst2.len();
let len3 = dst3.len();
// The three length assertions below together imply len0 >= 3.
assert_eq!(
len1 + 1,
len0,
"schur_panel_minus_fma_strided_quad: dst1 must be exactly one shorter than dst0 \
(len0={}, len1={})",
len0,
len1
);
assert_eq!(
len2 + 2,
len0,
"schur_panel_minus_fma_strided_quad: dst2 must be exactly two shorter than dst0 \
(len0={}, len2={})",
len0,
len2
);
assert_eq!(
len3 + 3,
len0,
"schur_panel_minus_fma_strided_quad: dst3 must be exactly three shorter than dst0 \
(len0={}, len3={})",
len0,
len3
);
assert_eq!(
alphas0.len(),
n_elim,
"schur_panel_minus_fma_strided_quad: alphas0.len() must equal n_elim"
);
assert_eq!(
alphas1.len(),
n_elim,
"schur_panel_minus_fma_strided_quad: alphas1.len() must equal n_elim"
);
assert_eq!(
alphas2.len(),
n_elim,
"schur_panel_minus_fma_strided_quad: alphas2.len() must equal n_elim"
);
assert_eq!(
alphas3.len(),
n_elim,
"schur_panel_minus_fma_strided_quad: alphas3.len() must equal n_elim"
);
// Nothing to eliminate. NOTE(review): `len0 == 0` cannot hold here because `len3 + 3 == len0`
// was asserted above; only the `n_elim == 0` half of this guard is reachable.
if n_elim == 0 || len0 == 0 {
return;
}
// One-past-the-end index of the last source element read (last column, last row of dst0).
// Checking it once up front turns an out-of-range access into a descriptive panic.
let last_q = n_elim - 1;
let max_idx = (src_first_col + last_q) * col_stride + src_row_offset + len0;
assert!(
src_block.len() >= max_idx,
"schur_panel_minus_fma_strided_quad: src_block too short ({} < {})",
src_block.len(),
max_idx
);
// Source row `src_row_offset`: only dst0 covers it (its element 0).
// `alphas0.len() == n_elim` was asserted above, so `take(n_elim)` does not shorten the loop.
for (q, &alpha_q) in alphas0.iter().enumerate().take(n_elim) {
if alpha_q == 0.0 {
continue;
}
let col_off = (src_first_col + q) * col_stride + src_row_offset;
let s = src_block[col_off];
dst0[0] = (-alpha_q).mul_add(s, dst0[0]);
}
// NOTE(review): unreachable given len0 >= 3 (see the length assertions).
if len0 == 1 {
return;
}
// Source row `src_row_offset + 1`: shared by dst0[1] and dst1[0].
for q in 0..n_elim {
let col_off = (src_first_col + q) * col_stride + src_row_offset + 1;
let s = src_block[col_off];
let a0 = alphas0[q];
let a1 = alphas1[q];
if a0 != 0.0 {
dst0[1] = (-a0).mul_add(s, dst0[1]);
}
if a1 != 0.0 {
dst1[0] = (-a1).mul_add(s, dst1[0]);
}
}
// NOTE(review): unreachable given len0 >= 3 (see the length assertions).
if len0 == 2 {
return;
}
// Source row `src_row_offset + 2`: shared by dst0[2], dst1[1] and dst2[0].
for q in 0..n_elim {
let col_off = (src_first_col + q) * col_stride + src_row_offset + 2;
let s = src_block[col_off];
let a0 = alphas0[q];
let a1 = alphas1[q];
let a2 = alphas2[q];
if a0 != 0.0 {
dst0[2] = (-a0).mul_add(s, dst0[2]);
}
if a1 != 0.0 {
dst1[1] = (-a1).mul_add(s, dst1[1]);
}
if a2 != 0.0 {
dst2[0] = (-a2).mul_add(s, dst2[0]);
}
}
// dst3 is empty in this case, so there is no common bulk region left to process.
if len0 == 3 {
return;
}
// SIMD kernel state for the bulk region, i.e. the rows all four destinations share.
// Every `*_bulk` slice has `len` elements and starts at source row `src_row_offset_bulk`.
struct K<'a> {
dst0_bulk: &'a mut [f64],
dst1_bulk: &'a mut [f64],
dst2_bulk: &'a mut [f64],
dst3_bulk: &'a mut [f64],
src_block: &'a [f64],
src_first_col: usize,
n_elim: usize,
col_stride: usize,
src_row_offset_bulk: usize,
len: usize,
alphas0: &'a [f64],
alphas1: &'a [f64],
alphas2: &'a [f64],
alphas3: &'a [f64],
}
impl pulp::WithSimd for K<'_> {
type Output = ();
#[allow(clippy::needless_range_loop)]
#[inline(always)]
fn with_simd<S: pulp::Simd>(self, simd: S) {
let Self {
dst0_bulk,
dst1_bulk,
dst2_bulk,
dst3_bulk,
src_block,
src_first_col,
n_elim,
col_stride,
src_row_offset_bulk,
len,
alphas0,
alphas1,
alphas2,
alphas3,
} = self;
// Split each destination into whole SIMD vectors (`body`) and a shorter scalar `tail`.
let (d0_body, d0_tail) = S::as_mut_simd_f64s(dst0_bulk);
let (d1_body, d1_tail) = S::as_mut_simd_f64s(dst1_bulk);
let (d2_body, d2_tail) = S::as_mut_simd_f64s(dst2_bulk);
let (d3_body, d3_tail) = S::as_mut_simd_f64s(dst3_bulk);
let body_len = d0_body.len();
// All four bulk slices have the same length, so their vector bodies should match.
debug_assert_eq!(body_len, d1_body.len());
debug_assert_eq!(body_len, d2_body.len());
debug_assert_eq!(body_len, d3_body.len());
// Scalar index at which the tail starts inside a bulk slice.
let tail_off = body_len * S::F64_LANES;
// Phase 1: two vectors per destination at a time (eight accumulators), sweeping all
// `n_elim` source columns before storing back.
let chunks = body_len / 2;
for chunk_idx in 0..chunks {
let base = chunk_idx * 2;
let mut a00 = d0_body[base];
let mut a01 = d0_body[base + 1];
let mut a10 = d1_body[base];
let mut a11 = d1_body[base + 1];
let mut a20 = d2_body[base];
let mut a21 = d2_body[base + 1];
let mut a30 = d3_body[base];
let mut a31 = d3_body[base + 1];
for q in 0..n_elim {
let a0q = alphas0[q];
let a1q = alphas1[q];
let a2q = alphas2[q];
let a3q = alphas3[q];
// Skip the source load entirely when no destination uses this column.
if a0q == 0.0 && a1q == 0.0 && a2q == 0.0 && a3q == 0.0 {
continue;
}
let col_off = (src_first_col + q) * col_stride + src_row_offset_bulk;
let src_q = &src_block[col_off..col_off + len];
// `src_q` has the same length as the destinations; vector `i` of `sb` is used as
// the counterpart of vector `i` of each destination body (assumes
// `as_simd_f64s` splits purely by length — confirm against pulp's docs).
let (sb, _st) = S::as_simd_f64s(src_q);
let s0 = sb[base];
let s1 = sb[base + 1];
if a0q != 0.0 {
let nav0 = simd.splat_f64s(-a0q);
a00 = simd.mul_add_f64s(nav0, s0, a00);
a01 = simd.mul_add_f64s(nav0, s1, a01);
}
if a1q != 0.0 {
let nav1 = simd.splat_f64s(-a1q);
a10 = simd.mul_add_f64s(nav1, s0, a10);
a11 = simd.mul_add_f64s(nav1, s1, a11);
}
if a2q != 0.0 {
let nav2 = simd.splat_f64s(-a2q);
a20 = simd.mul_add_f64s(nav2, s0, a20);
a21 = simd.mul_add_f64s(nav2, s1, a21);
}
if a3q != 0.0 {
let nav3 = simd.splat_f64s(-a3q);
a30 = simd.mul_add_f64s(nav3, s0, a30);
a31 = simd.mul_add_f64s(nav3, s1, a31);
}
}
d0_body[base] = a00;
d0_body[base + 1] = a01;
d1_body[base] = a10;
d1_body[base + 1] = a11;
d2_body[base] = a20;
d2_body[base + 1] = a21;
d3_body[base] = a30;
d3_body[base + 1] = a31;
}
// Phase 2: at most one leftover whole vector when `body_len` is odd.
let tail_chunks_start = chunks * 2;
for body_idx in tail_chunks_start..body_len {
let mut acc0 = d0_body[body_idx];
let mut acc1 = d1_body[body_idx];
let mut acc2 = d2_body[body_idx];
let mut acc3 = d3_body[body_idx];
for q in 0..n_elim {
let a0q = alphas0[q];
let a1q = alphas1[q];
let a2q = alphas2[q];
let a3q = alphas3[q];
if a0q == 0.0 && a1q == 0.0 && a2q == 0.0 && a3q == 0.0 {
continue;
}
let col_off = (src_first_col + q) * col_stride + src_row_offset_bulk;
let src_q = &src_block[col_off..col_off + len];
let (sb, _st) = S::as_simd_f64s(src_q);
let s = sb[body_idx];
if a0q != 0.0 {
let nav0 = simd.splat_f64s(-a0q);
acc0 = simd.mul_add_f64s(nav0, s, acc0);
}
if a1q != 0.0 {
let nav1 = simd.splat_f64s(-a1q);
acc1 = simd.mul_add_f64s(nav1, s, acc1);
}
if a2q != 0.0 {
let nav2 = simd.splat_f64s(-a2q);
acc2 = simd.mul_add_f64s(nav2, s, acc2);
}
if a3q != 0.0 {
let nav3 = simd.splat_f64s(-a3q);
acc3 = simd.mul_add_f64s(nav3, s, acc3);
}
}
d0_body[body_idx] = acc0;
d1_body[body_idx] = acc1;
d2_body[body_idx] = acc2;
d3_body[body_idx] = acc3;
}
// Phase 3: elements that do not fill a whole vector, handled with partial loads/stores.
if !d0_tail.is_empty() {
let mut acc0 = simd.partial_load_f64s(d0_tail);
let mut acc1 = simd.partial_load_f64s(d1_tail);
let mut acc2 = simd.partial_load_f64s(d2_tail);
let mut acc3 = simd.partial_load_f64s(d3_tail);
for q in 0..n_elim {
let a0q = alphas0[q];
let a1q = alphas1[q];
let a2q = alphas2[q];
let a3q = alphas3[q];
if a0q == 0.0 && a1q == 0.0 && a2q == 0.0 && a3q == 0.0 {
continue;
}
let col_off = (src_first_col + q) * col_stride + src_row_offset_bulk;
let src_q = &src_block[col_off..col_off + len];
// Source elements matching the destination tails.
let src_q_tail = &src_q[tail_off..];
let s = simd.partial_load_f64s(src_q_tail);
if a0q != 0.0 {
let nav0 = simd.splat_f64s(-a0q);
acc0 = simd.mul_add_f64s(nav0, s, acc0);
}
if a1q != 0.0 {
let nav1 = simd.splat_f64s(-a1q);
acc1 = simd.mul_add_f64s(nav1, s, acc1);
}
if a2q != 0.0 {
let nav2 = simd.splat_f64s(-a2q);
acc2 = simd.mul_add_f64s(nav2, s, acc2);
}
if a3q != 0.0 {
let nav3 = simd.splat_f64s(-a3q);
acc3 = simd.mul_add_f64s(nav3, s, acc3);
}
}
simd.partial_store_f64s(d0_tail, acc0);
simd.partial_store_f64s(d1_tail, acc1);
simd.partial_store_f64s(d2_tail, acc2);
simd.partial_store_f64s(d3_tail, acc3);
}
}
}
// Drop the scalar-handled prefixes so that all four bulk slices start at source row
// `src_row_offset + 3` and have `len3` elements.
let (_d0_cap, d0_bulk) = dst0.split_at_mut(3);
let (_d1_cap, d1_bulk) = dst1.split_at_mut(2);
let (_d2_cap, d2_bulk) = dst2.split_at_mut(1);
// `dispatch_fma` is defined elsewhere in this file; presumably it runs `K` under a SIMD
// implementation with hardware FMA — TODO confirm.
dispatch_fma(K {
dst0_bulk: d0_bulk,
dst1_bulk: d1_bulk,
dst2_bulk: d2_bulk,
dst3_bulk: dst3,
src_block,
src_first_col,
n_elim,
col_stride,
src_row_offset_bulk: src_row_offset + 3,
len: len3,
alphas0,
alphas1,
alphas2,
alphas3,
});
}
#[cfg(test)]
mod tests {
use super::*;
/// Small deterministic xorshift64 generator used to build reproducible test inputs.
struct Xorshift64(u64);

impl Xorshift64 {
    /// State substituted for a zero seed: with state `0` every xorshift step yields `0` again.
    const ZERO_SEED_REPLACEMENT: u64 = 0x9E37_79B9_7F4A_7C15;

    /// Creates a generator from `seed`, replacing a zero seed with a fixed non-zero state.
    fn new(seed: u64) -> Self {
        match seed {
            0 => Self(Self::ZERO_SEED_REPLACEMENT),
            nonzero => Self(nonzero),
        }
    }

    /// Advances the state with the 13 / 7 / 17 shift triple and returns the new state.
    fn next_u64(&mut self) -> u64 {
        let after_left = self.0 ^ (self.0 << 13);
        let after_right = after_left ^ (after_left >> 7);
        let advanced = after_right ^ (after_right << 17);
        self.0 = advanced;
        advanced
    }

    /// Returns a value in `[-1.0, 1.0)`.
    fn next_f64(&mut self) -> f64 {
        // Bit pattern of 1.0: placing 52 random bits in the mantissa gives a value in [1, 2).
        const ONE_BITS: u64 = 0x3FF0_0000_0000_0000;
        let mantissa = self.next_u64() >> 12;
        let unit = f64::from_bits(mantissa | ONE_BITS) - 1.0;
        2.0 * unit - 1.0
    }
}
/// Scalar reference for `dst[i] -= alpha * src[i]`.
///
/// The product is rounded on its own before the subtraction (no fused multiply-add).
/// Only the first `dst.len()` elements of `src` are read; a shorter `src` panics on the
/// first missing index.
fn naive_axpy_minus(dst: &mut [f64], src: &[f64], alpha: f64) {
    for (i, out) in dst.iter_mut().enumerate() {
        let scaled = alpha * src[i];
        *out -= scaled;
    }
}
/// Scalar reference for `dst[i] -= alpha0 * src0[i] + alpha1 * src1[i]`.
///
/// Both products are rounded separately, summed, and the sum is subtracted from `dst[i]`.
/// Only the first `dst.len()` elements of each source are read; a shorter source panics.
fn naive_axpy2_minus(dst: &mut [f64], src0: &[f64], alpha0: f64, src1: &[f64], alpha1: f64) {
    for (i, out) in dst.iter_mut().enumerate() {
        let first = alpha0 * src0[i];
        let second = alpha1 * src1[i];
        *out -= first + second;
    }
}
/// Slice lengths exercised by the sweep tests: the small cases `0..=5`, then values just
/// below, at and just above powers of two up to 1024.
const LENGTH_SWEEP: &[usize] = &[
    0, 1, 2, 3, 4, 5,
    7, 8, 9,
    15, 16, 17,
    31, 32, 33,
    63, 64, 65,
    127, 128, 129,
    255, 256, 257,
    511, 512, 513,
    1023, 1024,
];

/// Absolute tolerance used when a kernel is compared against the scalar reference.
// NOTE(review): the value is 8 * f64::EPSILON; the trailing `* 2.0` presumably scales a
// 4-epsilon budget to the magnitude of the test data — confirm the intent.
const ULP4: f64 = 4.0 * f64::EPSILON * 2.0;
/// Asserts that `a` and `b` have the same length and agree element-wise within the
/// absolute tolerance `tol`.
///
/// Elements that compare equal are accepted without looking at their difference. This
/// matters for matching infinities: `inf - inf` is NaN, which would fail `diff <= tol` even
/// though the two values are identical. A NaN on either side is still rejected.
///
/// # Panics
///
/// Panics with "length mismatch in assert_close" when the lengths differ, and with the
/// offending index, both values, the difference and the tolerance for the first element
/// pair that is neither equal nor within `tol`.
fn assert_close(a: &[f64], b: &[f64], tol: f64) {
    assert_eq!(a.len(), b.len(), "length mismatch in assert_close");
    for (i, (&lhs, &rhs)) in a.iter().zip(b.iter()).enumerate() {
        if lhs == rhs {
            continue;
        }
        let diff = (lhs - rhs).abs();
        assert!(
            diff <= tol,
            "element {}: {} vs {}, diff {:.3e} > {:.3e}",
            i,
            lhs,
            rhs,
            diff,
            tol
        );
    }
}
/// `axpy_minus` accepts empty slices and leaves the destination empty.
#[test]
fn axpy_minus_zero_length() {
    let mut out = Vec::<f64>::new();
    let input: [f64; 0] = [];
    axpy_minus(&mut out, &input, 1.5);
    assert!(out.is_empty());
}
/// Single-element case: `5.0 - 0.5 * 2.0` must be exactly `4.0`.
#[test]
fn axpy_minus_length_one() {
    let mut out = [5.0_f64];
    axpy_minus(&mut out, &[2.0], 0.5);
    assert_eq!(out[0], 4.0);
}
/// Compares `axpy_minus` with the scalar reference for every length in `LENGTH_SWEEP`,
/// allowing an absolute error of `ULP4` per element.
#[test]
fn axpy_minus_matches_reference_across_length_sweep() {
    let mut rng = Xorshift64::new(0xFE27_A100_0042_BEEFu64);
    for len in LENGTH_SWEEP.iter().copied() {
        // Draw order (source, destination, alpha) fixes the pseudo-random test data.
        let src: Vec<f64> = std::iter::repeat_with(|| rng.next_f64()).take(len).collect();
        let initial: Vec<f64> = std::iter::repeat_with(|| rng.next_f64()).take(len).collect();
        let alpha = rng.next_f64() * 1.5;

        let mut actual = initial.clone();
        axpy_minus(&mut actual, &src, alpha);

        let mut expected = initial;
        naive_axpy_minus(&mut expected, &src, alpha);

        assert_close(&actual, &expected, ULP4);
    }
}
/// A destination and source of different lengths must trigger the length-mismatch panic.
#[test]
#[should_panic(expected = "length mismatch")]
fn axpy_minus_length_mismatch_panics() {
    let mut four = [0.0_f64; 4];
    let three = [0.0_f64; 3];
    axpy_minus(&mut four, &three, 1.0);
}
/// `axpy2_minus` accepts empty slices and leaves the destination empty.
#[test]
fn axpy2_minus_zero_length() {
    let mut out = Vec::<f64>::new();
    let first: [f64; 0] = [];
    let second: [f64; 0] = [];
    axpy2_minus(&mut out, &first, 1.0, &second, 2.0);
    assert!(out.is_empty());
}
/// Single-element case: `10.0 - 0.5 * 2.0 - 1.0 * 3.0` must be exactly `6.0`.
#[test]
fn axpy2_minus_length_one() {
    let mut out = [10.0_f64];
    axpy2_minus(&mut out, &[2.0], 0.5, &[3.0], 1.0);
    assert_eq!(out[0], 6.0);
}
/// Compares `axpy2_minus` with the scalar reference for every length in `LENGTH_SWEEP`,
/// allowing an absolute error of `ULP4` per element.
#[test]
fn axpy2_minus_matches_reference_across_length_sweep() {
    let mut rng = Xorshift64::new(0xC0FF_EE00_BAAD_F00Du64);
    for len in LENGTH_SWEEP.iter().copied() {
        // Draw order (src0, src1, destination, alpha0, alpha1) fixes the test data.
        let src0: Vec<f64> = std::iter::repeat_with(|| rng.next_f64()).take(len).collect();
        let src1: Vec<f64> = std::iter::repeat_with(|| rng.next_f64()).take(len).collect();
        let initial: Vec<f64> = std::iter::repeat_with(|| rng.next_f64()).take(len).collect();
        let alpha0 = rng.next_f64() * 1.5;
        let alpha1 = rng.next_f64() * 1.5;

        let mut actual = initial.clone();
        axpy2_minus(&mut actual, &src0, alpha0, &src1, alpha1);

        let mut expected = initial;
        naive_axpy2_minus(&mut expected, &src0, alpha0, &src1, alpha1);

        assert_close(&actual, &expected, ULP4);
    }
}
/// With `alpha == 0.0` the destination must compare equal to its initial contents.
#[test]
fn axpy_minus_alpha_zero_is_noop() {
    let initial = [-3.0, 0.5, 100.0, -7.25, 1e-10, 1e10, -0.0, 42.0];
    let src = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let mut dst = initial;
    axpy_minus(&mut dst, &src, 0.0);
    assert_eq!(dst, initial);
}
/// Compares `axpy_minus_unroll4` with the scalar reference for every length in
/// `LENGTH_SWEEP`, allowing an absolute error of `ULP4`. Compiled only on aarch64.
#[cfg(target_arch = "aarch64")]
#[test]
fn axpy_minus_unroll4_matches_reference_across_length_sweep() {
    let mut rng = Xorshift64::new(0x4E27_A101_00FE_BEEFu64);
    for len in LENGTH_SWEEP.iter().copied() {
        // Draw order (source, destination, alpha) fixes the pseudo-random test data.
        let src: Vec<f64> = std::iter::repeat_with(|| rng.next_f64()).take(len).collect();
        let initial: Vec<f64> = std::iter::repeat_with(|| rng.next_f64()).take(len).collect();
        let alpha = rng.next_f64() * 1.5;

        let mut actual = initial.clone();
        axpy_minus_unroll4(&mut actual, &src, alpha);

        let mut expected = initial;
        naive_axpy_minus(&mut expected, &src, alpha);

        assert_close(&actual, &expected, ULP4);
    }
}
/// Compares `axpy2_minus_unroll4` with the scalar reference for every length in
/// `LENGTH_SWEEP`, allowing an absolute error of `ULP4`. Compiled only on aarch64.
#[cfg(target_arch = "aarch64")]
#[test]
fn axpy2_minus_unroll4_matches_reference_across_length_sweep() {
    let mut rng = Xorshift64::new(0xC1FF_EE00_BAAD_F00Du64);
    for len in LENGTH_SWEEP.iter().copied() {
        // Draw order (src0, src1, destination, alpha0, alpha1) fixes the test data.
        let src0: Vec<f64> = std::iter::repeat_with(|| rng.next_f64()).take(len).collect();
        let src1: Vec<f64> = std::iter::repeat_with(|| rng.next_f64()).take(len).collect();
        let initial: Vec<f64> = std::iter::repeat_with(|| rng.next_f64()).take(len).collect();
        let alpha0 = rng.next_f64() * 1.5;
        let alpha1 = rng.next_f64() * 1.5;

        let mut actual = initial.clone();
        axpy2_minus_unroll4(&mut actual, &src0, alpha0, &src1, alpha1);

        let mut expected = initial;
        naive_axpy2_minus(&mut expected, &src0, alpha0, &src1, alpha1);

        assert_close(&actual, &expected, ULP4);
    }
}
/// With both alphas equal to `0.0` the destination must compare equal to its initial
/// contents.
#[test]
fn axpy2_minus_alphas_zero_is_noop() {
    let initial = [-1.0, 2.5, 3.0, -4.5];
    let src0 = [1.0, 2.0, 3.0, 4.0];
    let src1 = [5.0, 6.0, 7.0, 8.0];
    let mut dst = initial;
    axpy2_minus(&mut dst, &src0, 0.0, &src1, 0.0);
    assert_eq!(dst, initial);
}
/// `axpy_minus_unroll4_nofma` must equal the scalar multiply-then-subtract reference exactly
/// (no tolerance) for every length in `LENGTH_SWEEP`.
#[test]
fn axpy_minus_unroll4_nofma_is_bit_exact_vs_scalar() {
    let mut rng = Xorshift64::new(0xB17E_AC70_0042_F00D_u64);
    for len in LENGTH_SWEEP.iter().copied() {
        // Draw order (source, destination, alpha) fixes the pseudo-random test data.
        let src: Vec<f64> = std::iter::repeat_with(|| rng.next_f64()).take(len).collect();
        let initial: Vec<f64> = std::iter::repeat_with(|| rng.next_f64()).take(len).collect();
        let alpha = rng.next_f64() * 1.5;

        let mut actual = initial.clone();
        axpy_minus_unroll4_nofma(&mut actual, &src, alpha);

        let mut expected = initial;
        naive_axpy_minus(&mut expected, &src, alpha);

        assert_eq!(
            actual, expected,
            "non-FMA unroll4 must be bit-exact vs scalar at len={}",
            len
        );
    }
}
/// The strided non-FMA panel update must equal, exactly, the result of applying
/// `axpy_minus_unroll4_nofma` once per source column in increasing column order.
#[test]
fn schur_panel_minus_nofma_strided_is_bit_exact_vs_rank1_reference() {
    let n_elim_sweep: &[usize] = &[1, 2, 4, 7, 8, 16, 31, 32];
    let len_sweep: &[usize] = &[1, 3, 7, 8, 9, 15, 16, 17, 31, 32, 33, 63, 64, 65, 256, 257];
    let mut rng = Xorshift64::new(0x517E_3D5C_4242_5042);
    for n_elim in n_elim_sweep.iter().copied() {
        for len in len_sweep.iter().copied() {
            // The stride exceeds `len`, so consecutive source columns are not contiguous.
            let col_stride = len + 7 + n_elim;
            let src_block: Vec<f64> = std::iter::repeat_with(|| rng.next_f64())
                .take(n_elim * col_stride)
                .collect();
            let initial: Vec<f64> = std::iter::repeat_with(|| rng.next_f64()).take(len).collect();
            let alphas: Vec<f64> = std::iter::repeat_with(|| rng.next_f64() * 1.5)
                .take(n_elim)
                .collect();

            // Reference: one rank-1 update per column; column `q` starts at `q * col_stride`.
            let mut expected = initial.clone();
            for (column, &alpha) in src_block.chunks_exact(col_stride).zip(alphas.iter()) {
                if alpha != 0.0 {
                    axpy_minus_unroll4_nofma(&mut expected, &column[..len], alpha);
                }
            }

            let mut actual = initial;
            schur_panel_minus_nofma_strided(
                &mut actual,
                &src_block,
                0,
                n_elim,
                col_stride,
                0,
                len,
                &alphas,
            );
            assert_eq!(
                actual, expected,
                "rank-{} accumulator must be bit-exact vs n_elim*rank-1 at len={}, col_stride={}",
                n_elim, len, col_stride
            );
        }
    }
}
/// The strided FMA panel update must equal, exactly, the result of applying
/// `axpy_minus_unroll4` once per source column in increasing column order.
#[test]
fn schur_panel_minus_fma_strided_is_bit_exact_vs_rank1_fma_reference() {
    let n_elim_sweep: &[usize] = &[1, 2, 4, 7, 8, 16, 31, 32];
    let len_sweep: &[usize] = &[1, 3, 7, 8, 9, 15, 16, 17, 31, 32, 33, 63, 64, 65, 256, 257];
    let mut rng = Xorshift64::new(0xF0AA_5C5C_BEEF_F00Du64);
    for n_elim in n_elim_sweep.iter().copied() {
        for len in len_sweep.iter().copied() {
            // The stride exceeds `len`, so consecutive source columns are not contiguous.
            let col_stride = len + 7 + n_elim;
            let src_block: Vec<f64> = std::iter::repeat_with(|| rng.next_f64())
                .take(n_elim * col_stride)
                .collect();
            let initial: Vec<f64> = std::iter::repeat_with(|| rng.next_f64()).take(len).collect();
            let alphas: Vec<f64> = std::iter::repeat_with(|| rng.next_f64() * 1.5)
                .take(n_elim)
                .collect();

            // Reference: one rank-1 update per column; column `q` starts at `q * col_stride`.
            let mut expected = initial.clone();
            for (column, &alpha) in src_block.chunks_exact(col_stride).zip(alphas.iter()) {
                if alpha != 0.0 {
                    axpy_minus_unroll4(&mut expected, &column[..len], alpha);
                }
            }

            let mut actual = initial;
            schur_panel_minus_fma_strided(
                &mut actual,
                &src_block,
                0,
                n_elim,
                col_stride,
                0,
                len,
                &alphas,
            );
            assert_eq!(
                actual, expected,
                "FMA rank-{} accumulator must be bit-exact vs n_elim*rank-1-FMA at len={}, col_stride={}",
                n_elim, len, col_stride
            );
        }
    }
}
/// Columns whose alpha is `0.0` must contribute nothing: the FMA panel update has to match
/// a reference that applies `axpy_minus_unroll4` for the non-zero alphas only.
#[test]
fn schur_panel_minus_fma_strided_skips_zero_alphas() {
    let n_elim: usize = 4;
    let len: usize = 17;
    let col_stride = len + 5;
    let alphas = [0.5, 0.0, -0.25, 0.0];

    let mut rng = Xorshift64::new(0xABCD_5050_AAAA_BBBBu64);
    let src_block: Vec<f64> = std::iter::repeat_with(|| rng.next_f64())
        .take(n_elim * col_stride)
        .collect();
    let initial: Vec<f64> = std::iter::repeat_with(|| rng.next_f64()).take(len).collect();

    // Reference: rank-1 updates for the columns with a non-zero alpha, in column order.
    let mut expected = initial.clone();
    for (column, &alpha) in src_block.chunks_exact(col_stride).zip(alphas.iter()) {
        if alpha != 0.0 {
            axpy_minus_unroll4(&mut expected, &column[..len], alpha);
        }
    }

    let mut actual = initial;
    schur_panel_minus_fma_strided(
        &mut actual,
        &src_block,
        0,
        n_elim,
        col_stride,
        0,
        len,
        &alphas,
    );
    assert_eq!(actual, expected);
}
/// Columns whose alpha is `0.0` must contribute nothing: the non-FMA panel update has to
/// match a reference that applies `axpy_minus_unroll4_nofma` for the non-zero alphas only.
#[test]
fn schur_panel_minus_nofma_strided_skips_zero_alphas() {
    let n_elim: usize = 4;
    let len: usize = 17;
    let col_stride = len + 5;
    let alphas = [0.5, 0.0, -0.25, 0.0];

    let mut rng = Xorshift64::new(0xABCD_5050_0042u64);
    let src_block: Vec<f64> = std::iter::repeat_with(|| rng.next_f64())
        .take(n_elim * col_stride)
        .collect();
    let initial: Vec<f64> = std::iter::repeat_with(|| rng.next_f64()).take(len).collect();

    // Reference: rank-1 updates for the columns with a non-zero alpha, in column order.
    let mut expected = initial.clone();
    for (column, &alpha) in src_block.chunks_exact(col_stride).zip(alphas.iter()) {
        if alpha != 0.0 {
            axpy_minus_unroll4_nofma(&mut expected, &column[..len], alpha);
        }
    }

    let mut actual = initial;
    schur_panel_minus_nofma_strided(
        &mut actual,
        &src_block,
        0,
        n_elim,
        col_stride,
        0,
        len,
        &alphas,
    );
    assert_eq!(actual, expected);
}
/// The dual non-FMA kernel must produce exactly what two single-destination calls produce,
/// the second one shifted down by one source row and one element shorter.
#[test]
fn schur_panel_minus_nofma_strided_dual_is_bit_exact_vs_two_singles() {
    let n_elim_sweep: &[usize] = &[1, 2, 4, 7, 8, 16, 31, 32];
    let len0_sweep: &[usize] = &[
        1, 2, 3, 4, 5, 7, 8, 9, 15, 16, 17, 31, 32, 33, 63, 64, 65, 257,
    ];
    let src_row_offset: usize = 3;
    let mut rng = Xorshift64::new(0xD5A1_C0F1_DEAD_BEEFu64);
    for n_elim in n_elim_sweep.iter().copied() {
        for len0 in len0_sweep.iter().copied() {
            let col_stride = src_row_offset + len0 + 5 + n_elim;
            // Draw order (source, dst0, dst1, alphas0, alphas1) fixes the test data.
            let src_block: Vec<f64> = std::iter::repeat_with(|| rng.next_f64())
                .take(n_elim * col_stride)
                .collect();
            let dst0_start: Vec<f64> =
                std::iter::repeat_with(|| rng.next_f64()).take(len0).collect();
            let dst1_start: Vec<f64> = std::iter::repeat_with(|| rng.next_f64())
                .take(len0 - 1)
                .collect();
            let alphas0: Vec<f64> = std::iter::repeat_with(|| rng.next_f64() * 1.5)
                .take(n_elim)
                .collect();
            let alphas1: Vec<f64> = std::iter::repeat_with(|| rng.next_f64() * 1.5)
                .take(n_elim)
                .collect();

            // Single-destination reference; an empty destination is left untouched.
            let single = |start: &[f64], row_shift: usize, alphas: &[f64]| -> Vec<f64> {
                let mut expected = start.to_vec();
                let rows = expected.len();
                if rows > 0 {
                    schur_panel_minus_nofma_strided(
                        &mut expected,
                        &src_block,
                        0,
                        n_elim,
                        col_stride,
                        src_row_offset + row_shift,
                        rows,
                        alphas,
                    );
                }
                expected
            };
            let dst0_expected = single(&dst0_start, 0, &alphas0);
            let dst1_expected = single(&dst1_start, 1, &alphas1);

            let mut dst0_actual = dst0_start;
            let mut dst1_actual = dst1_start;
            schur_panel_minus_nofma_strided_dual(
                &mut dst0_actual,
                &mut dst1_actual,
                &src_block,
                0,
                n_elim,
                col_stride,
                src_row_offset,
                &alphas0,
                &alphas1,
            );
            assert_eq!(
                dst0_actual, dst0_expected,
                "dst0 mismatch at n_elim={}, len0={}",
                n_elim, len0
            );
            assert_eq!(
                dst1_actual, dst1_expected,
                "dst1 mismatch at n_elim={}, len0={}",
                n_elim, len0
            );
        }
    }
}
/// Zero alphas must be skipped per destination: the two alpha vectors have zeros in
/// different positions, and the dual non-FMA kernel has to match two single calls.
#[test]
fn schur_panel_minus_nofma_strided_dual_skips_zero_alphas_independently() {
    let n_elim: usize = 5;
    let len0: usize = 33;
    let src_row_offset: usize = 2;
    let col_stride = src_row_offset + len0 + 4 + n_elim;
    let alphas0 = [0.5, 0.0, 0.0, -0.25, 0.75];
    let alphas1 = [0.0, 0.4, 0.0, 0.6, 0.0];

    let mut rng = Xorshift64::new(0xFEED_CAFE_0042_BABEu64);
    let src_block: Vec<f64> = std::iter::repeat_with(|| rng.next_f64())
        .take(n_elim * col_stride)
        .collect();
    let dst0_start: Vec<f64> = std::iter::repeat_with(|| rng.next_f64()).take(len0).collect();
    let dst1_start: Vec<f64> = std::iter::repeat_with(|| rng.next_f64())
        .take(len0 - 1)
        .collect();

    // Single-destination reference starting `row_shift` rows below `src_row_offset`.
    let single = |start: &[f64], row_shift: usize, alphas: &[f64]| -> Vec<f64> {
        let mut expected = start.to_vec();
        let rows = expected.len();
        schur_panel_minus_nofma_strided(
            &mut expected,
            &src_block,
            0,
            n_elim,
            col_stride,
            src_row_offset + row_shift,
            rows,
            alphas,
        );
        expected
    };
    let dst0_expected = single(&dst0_start, 0, &alphas0);
    let dst1_expected = single(&dst1_start, 1, &alphas1);

    let mut dst0_actual = dst0_start;
    let mut dst1_actual = dst1_start;
    schur_panel_minus_nofma_strided_dual(
        &mut dst0_actual,
        &mut dst1_actual,
        &src_block,
        0,
        n_elim,
        col_stride,
        src_row_offset,
        &alphas0,
        &alphas1,
    );
    assert_eq!(dst0_actual, dst0_expected);
    assert_eq!(dst1_actual, dst1_expected);
}
/// The quad non-FMA kernel must produce exactly what four single-destination calls produce,
/// destination `k` being shifted down by `k` source rows and `k` elements shorter.
#[test]
fn schur_panel_minus_nofma_strided_quad_is_bit_exact_vs_four_singles() {
    let n_elim_sweep: &[usize] = &[1, 2, 4, 7, 8, 16, 31, 32];
    let len0_sweep: &[usize] = &[
        3, 4, 5, 6, 7, 8, 9, 10, 15, 16, 17, 18, 19, 31, 32, 33, 63, 64, 65, 127, 128, 257,
    ];
    let src_row_offset: usize = 3;
    let mut rng = Xorshift64::new(0x9A77_E11E_2026_0512_u64);
    for n_elim in n_elim_sweep.iter().copied() {
        for len0 in len0_sweep.iter().copied() {
            let col_stride = src_row_offset + len0 + 5 + n_elim;
            // Draw order (source, dst0..dst3, alphas0..alphas3) fixes the test data.
            let src_block: Vec<f64> = std::iter::repeat_with(|| rng.next_f64())
                .take(n_elim * col_stride)
                .collect();
            let dst0_start: Vec<f64> =
                std::iter::repeat_with(|| rng.next_f64()).take(len0).collect();
            let dst1_start: Vec<f64> = std::iter::repeat_with(|| rng.next_f64())
                .take(len0 - 1)
                .collect();
            let dst2_start: Vec<f64> = std::iter::repeat_with(|| rng.next_f64())
                .take(len0 - 2)
                .collect();
            let dst3_start: Vec<f64> = std::iter::repeat_with(|| rng.next_f64())
                .take(len0 - 3)
                .collect();
            let alphas0: Vec<f64> = std::iter::repeat_with(|| rng.next_f64() * 1.5)
                .take(n_elim)
                .collect();
            let alphas1: Vec<f64> = std::iter::repeat_with(|| rng.next_f64() * 1.5)
                .take(n_elim)
                .collect();
            let alphas2: Vec<f64> = std::iter::repeat_with(|| rng.next_f64() * 1.5)
                .take(n_elim)
                .collect();
            let alphas3: Vec<f64> = std::iter::repeat_with(|| rng.next_f64() * 1.5)
                .take(n_elim)
                .collect();

            // Single-destination reference; an empty destination (dst3 at len0 == 3) is left
            // untouched.
            let single = |start: &[f64], row_shift: usize, alphas: &[f64]| -> Vec<f64> {
                let mut expected = start.to_vec();
                let rows = expected.len();
                if rows > 0 {
                    schur_panel_minus_nofma_strided(
                        &mut expected,
                        &src_block,
                        0,
                        n_elim,
                        col_stride,
                        src_row_offset + row_shift,
                        rows,
                        alphas,
                    );
                }
                expected
            };
            let dst0_expected = single(&dst0_start, 0, &alphas0);
            let dst1_expected = single(&dst1_start, 1, &alphas1);
            let dst2_expected = single(&dst2_start, 2, &alphas2);
            let dst3_expected = single(&dst3_start, 3, &alphas3);

            let mut dst0_actual = dst0_start;
            let mut dst1_actual = dst1_start;
            let mut dst2_actual = dst2_start;
            let mut dst3_actual = dst3_start;
            schur_panel_minus_nofma_strided_quad(
                &mut dst0_actual,
                &mut dst1_actual,
                &mut dst2_actual,
                &mut dst3_actual,
                &src_block,
                0,
                n_elim,
                col_stride,
                src_row_offset,
                &alphas0,
                &alphas1,
                &alphas2,
                &alphas3,
            );
            assert_eq!(
                dst0_actual, dst0_expected,
                "dst0 mismatch at n_elim={}, len0={}",
                n_elim, len0
            );
            assert_eq!(
                dst1_actual, dst1_expected,
                "dst1 mismatch at n_elim={}, len0={}",
                n_elim, len0
            );
            assert_eq!(
                dst2_actual, dst2_expected,
                "dst2 mismatch at n_elim={}, len0={}",
                n_elim, len0
            );
            assert_eq!(
                dst3_actual, dst3_expected,
                "dst3 mismatch at n_elim={}, len0={}",
                n_elim, len0
            );
        }
    }
}
/// Zero alphas must be skipped per destination: the four alpha vectors have zeros in
/// different positions (all four are zero only at index 2), and the quad non-FMA kernel has
/// to match four single calls.
#[test]
fn schur_panel_minus_nofma_strided_quad_skips_zero_alphas_independently() {
    let n_elim: usize = 6;
    let len0: usize = 67;
    let src_row_offset: usize = 2;
    let col_stride = src_row_offset + len0 + 4 + n_elim;
    let alphas0 = [0.0, 0.5, 0.0, -0.25, 0.75, 1.1];
    let alphas1 = [0.4, 0.0, 0.0, 0.6, 0.3, -0.9];
    let alphas2 = [-0.5, 0.0, 0.0, 0.2, 0.0, 0.7];
    let alphas3 = [0.8, -0.3, 0.0, 0.0, 0.55, 0.4];

    let mut rng = Xorshift64::new(0xCAFE_F00D_2026_0512_u64);
    let src_block: Vec<f64> = std::iter::repeat_with(|| rng.next_f64())
        .take(n_elim * col_stride)
        .collect();
    let dst0_start: Vec<f64> = std::iter::repeat_with(|| rng.next_f64()).take(len0).collect();
    let dst1_start: Vec<f64> = std::iter::repeat_with(|| rng.next_f64())
        .take(len0 - 1)
        .collect();
    let dst2_start: Vec<f64> = std::iter::repeat_with(|| rng.next_f64())
        .take(len0 - 2)
        .collect();
    let dst3_start: Vec<f64> = std::iter::repeat_with(|| rng.next_f64())
        .take(len0 - 3)
        .collect();

    // Single-destination reference starting `row_shift` rows below `src_row_offset`.
    let single = |start: &[f64], row_shift: usize, alphas: &[f64]| -> Vec<f64> {
        let mut expected = start.to_vec();
        let rows = expected.len();
        schur_panel_minus_nofma_strided(
            &mut expected,
            &src_block,
            0,
            n_elim,
            col_stride,
            src_row_offset + row_shift,
            rows,
            alphas,
        );
        expected
    };
    let dst0_expected = single(&dst0_start, 0, &alphas0);
    let dst1_expected = single(&dst1_start, 1, &alphas1);
    let dst2_expected = single(&dst2_start, 2, &alphas2);
    let dst3_expected = single(&dst3_start, 3, &alphas3);

    let mut dst0_actual = dst0_start;
    let mut dst1_actual = dst1_start;
    let mut dst2_actual = dst2_start;
    let mut dst3_actual = dst3_start;
    schur_panel_minus_nofma_strided_quad(
        &mut dst0_actual,
        &mut dst1_actual,
        &mut dst2_actual,
        &mut dst3_actual,
        &src_block,
        0,
        n_elim,
        col_stride,
        src_row_offset,
        &alphas0,
        &alphas1,
        &alphas2,
        &alphas3,
    );
    assert_eq!(dst0_actual, dst0_expected);
    assert_eq!(dst1_actual, dst1_expected);
    assert_eq!(dst2_actual, dst2_expected);
    assert_eq!(dst3_actual, dst3_expected);
}
/// The dual FMA kernel must produce exactly what two single-destination FMA calls produce,
/// the second one shifted down by one source row and one element shorter.
#[test]
fn schur_panel_minus_fma_strided_dual_is_bit_exact_vs_two_fma_singles() {
    let n_elim_sweep: &[usize] = &[1, 2, 4, 7, 8, 16, 31, 32];
    let len0_sweep: &[usize] = &[
        1, 2, 3, 4, 5, 7, 8, 9, 15, 16, 17, 31, 32, 33, 63, 64, 65, 257,
    ];
    let src_row_offset: usize = 3;
    let mut rng = Xorshift64::new(0x1234_5678_FACE_FEED_u64);
    for n_elim in n_elim_sweep.iter().copied() {
        for len0 in len0_sweep.iter().copied() {
            let col_stride = src_row_offset + len0 + 5 + n_elim;
            // Draw order (source, dst0, dst1, alphas0, alphas1) fixes the test data.
            let src_block: Vec<f64> = std::iter::repeat_with(|| rng.next_f64())
                .take(n_elim * col_stride)
                .collect();
            let dst0_start: Vec<f64> =
                std::iter::repeat_with(|| rng.next_f64()).take(len0).collect();
            let dst1_start: Vec<f64> = std::iter::repeat_with(|| rng.next_f64())
                .take(len0 - 1)
                .collect();
            let alphas0: Vec<f64> = std::iter::repeat_with(|| rng.next_f64() * 1.5)
                .take(n_elim)
                .collect();
            let alphas1: Vec<f64> = std::iter::repeat_with(|| rng.next_f64() * 1.5)
                .take(n_elim)
                .collect();

            // Single-destination FMA reference; an empty destination is left untouched.
            let single = |start: &[f64], row_shift: usize, alphas: &[f64]| -> Vec<f64> {
                let mut expected = start.to_vec();
                let rows = expected.len();
                if rows > 0 {
                    schur_panel_minus_fma_strided(
                        &mut expected,
                        &src_block,
                        0,
                        n_elim,
                        col_stride,
                        src_row_offset + row_shift,
                        rows,
                        alphas,
                    );
                }
                expected
            };
            let dst0_expected = single(&dst0_start, 0, &alphas0);
            let dst1_expected = single(&dst1_start, 1, &alphas1);

            let mut dst0_actual = dst0_start;
            let mut dst1_actual = dst1_start;
            schur_panel_minus_fma_strided_dual(
                &mut dst0_actual,
                &mut dst1_actual,
                &src_block,
                0,
                n_elim,
                col_stride,
                src_row_offset,
                &alphas0,
                &alphas1,
            );
            assert_eq!(
                dst0_actual, dst0_expected,
                "FMA dst0 mismatch at n_elim={}, len0={}",
                n_elim, len0
            );
            assert_eq!(
                dst1_actual, dst1_expected,
                "FMA dst1 mismatch at n_elim={}, len0={}",
                n_elim, len0
            );
        }
    }
}
/// Zero alphas must be skipped per destination: the two alpha vectors have zeros in
/// different positions, and the dual FMA kernel has to match two single FMA calls.
#[test]
fn schur_panel_minus_fma_strided_dual_skips_zero_alphas_independently() {
    let n_elim: usize = 5;
    let len0: usize = 33;
    let src_row_offset: usize = 2;
    let col_stride = src_row_offset + len0 + 4 + n_elim;
    let alphas0 = [0.5, 0.0, 0.0, -0.25, 0.75];
    let alphas1 = [0.0, 0.4, 0.0, 0.6, 0.0];

    let mut rng = Xorshift64::new(0xABCD_1234_F00D_BABE_u64);
    let src_block: Vec<f64> = std::iter::repeat_with(|| rng.next_f64())
        .take(n_elim * col_stride)
        .collect();
    let dst0_start: Vec<f64> = std::iter::repeat_with(|| rng.next_f64()).take(len0).collect();
    let dst1_start: Vec<f64> = std::iter::repeat_with(|| rng.next_f64())
        .take(len0 - 1)
        .collect();

    // Single-destination FMA reference starting `row_shift` rows below `src_row_offset`.
    let single = |start: &[f64], row_shift: usize, alphas: &[f64]| -> Vec<f64> {
        let mut expected = start.to_vec();
        let rows = expected.len();
        schur_panel_minus_fma_strided(
            &mut expected,
            &src_block,
            0,
            n_elim,
            col_stride,
            src_row_offset + row_shift,
            rows,
            alphas,
        );
        expected
    };
    let dst0_expected = single(&dst0_start, 0, &alphas0);
    let dst1_expected = single(&dst1_start, 1, &alphas1);

    let mut dst0_actual = dst0_start;
    let mut dst1_actual = dst1_start;
    schur_panel_minus_fma_strided_dual(
        &mut dst0_actual,
        &mut dst1_actual,
        &src_block,
        0,
        n_elim,
        col_stride,
        src_row_offset,
        &alphas0,
        &alphas1,
    );
    assert_eq!(dst0_actual, dst0_expected);
    assert_eq!(dst1_actual, dst1_expected);
}
// Sweeps elimination counts and row lengths and requires the four-row FMA
// panel kernel to reproduce, with exact equality, the output of four
// separate single-row FMA kernel calls.
#[test]
fn schur_panel_minus_fma_strided_quad_is_bit_exact_vs_four_fma_singles() {
    let mut rng = Xorshift64::new(0xDEAD_BEEF_2026_0513_u64);
    let n_elim_sweep = [1usize, 2, 4, 7, 8, 16, 31, 32];
    // Every length is at least 3, so `len0 - 3` below cannot underflow.
    let len0_sweep: &[usize] = &[
        3, 4, 5, 6, 7, 8, 9, 10, 15, 16, 17, 18, 19, 31, 32, 33, 63, 64, 65, 127, 128, 257,
    ];
    for n_elim in n_elim_sweep.iter().copied() {
        for len0 in len0_sweep.iter().copied() {
            let src_row_offset = 3usize;
            let col_stride = src_row_offset + len0 + 5 + n_elim;

            // Draw order matters for reproducibility: source block, the four
            // destination rows (each one element shorter than the previous),
            // then the four coefficient lists.
            let mut draw =
                |count: usize| -> Vec<f64> { (0..count).map(|_| rng.next_f64()).collect() };
            let src_block = draw(n_elim * col_stride);
            let dst0_init = draw(len0);
            let dst1_init = draw(len0 - 1);
            let dst2_init = draw(len0 - 2);
            let dst3_init = draw(len0 - 3);

            let mut draw_alphas =
                || -> Vec<f64> { (0..n_elim).map(|_| rng.next_f64() * 1.5).collect() };
            let alphas0 = draw_alphas();
            let alphas1 = draw_alphas();
            let alphas2 = draw_alphas();
            let alphas3 = draw_alphas();

            // Single-row reference for the row shifted `row_shift` rows into
            // the source. Row 0 is always processed; the shifted rows are
            // only processed when they are non-empty.
            let single_row_reference =
                |mut expected: Vec<f64>, row_shift: usize, alphas: &Vec<f64>| -> Vec<f64> {
                    let row_len = expected.len();
                    if row_shift == 0 || row_len > 0 {
                        schur_panel_minus_fma_strided(
                            &mut expected,
                            &src_block,
                            0,
                            n_elim,
                            col_stride,
                            src_row_offset + row_shift,
                            row_len,
                            alphas,
                        );
                    }
                    expected
                };
            let dst0_ref = single_row_reference(dst0_init.clone(), 0, &alphas0);
            let dst1_ref = single_row_reference(dst1_init.clone(), 1, &alphas1);
            let dst2_ref = single_row_reference(dst2_init.clone(), 2, &alphas2);
            let dst3_ref = single_row_reference(dst3_init.clone(), 3, &alphas3);

            // Kernel under test: all four rows in one call.
            let mut dst0_kernel = dst0_init;
            let mut dst1_kernel = dst1_init;
            let mut dst2_kernel = dst2_init;
            let mut dst3_kernel = dst3_init;
            schur_panel_minus_fma_strided_quad(
                &mut dst0_kernel,
                &mut dst1_kernel,
                &mut dst2_kernel,
                &mut dst3_kernel,
                &src_block,
                0,
                n_elim,
                col_stride,
                src_row_offset,
                &alphas0,
                &alphas1,
                &alphas2,
                &alphas3,
            );

            assert_eq!(
                dst0_kernel, dst0_ref,
                "FMA dst0 mismatch at n_elim={}, len0={}",
                n_elim, len0
            );
            assert_eq!(
                dst1_kernel, dst1_ref,
                "FMA dst1 mismatch at n_elim={}, len0={}",
                n_elim, len0
            );
            assert_eq!(
                dst2_kernel, dst2_ref,
                "FMA dst2 mismatch at n_elim={}, len0={}",
                n_elim, len0
            );
            assert_eq!(
                dst3_kernel, dst3_ref,
                "FMA dst3 mismatch at n_elim={}, len0={}",
                n_elim, len0
            );
        }
    }
}
// Compares the four-row FMA panel kernel against four independent single-row
// FMA kernel calls, with exact equality required on every destination row.
// Each coefficient list has zeros at different positions; judging by the
// test name this targets the kernel's per-row handling of zero alphas.
#[test]
fn schur_panel_minus_fma_strided_quad_skips_zero_alphas_independently() {
    // Fixed geometry: each successive row is one element shorter and starts
    // one row further into the source columns.
    let n_elim = 6;
    let len0 = 67;
    let len1 = len0 - 1;
    let len2 = len0 - 2;
    let len3 = len0 - 3;
    let src_row_offset = 2usize;
    let col_stride = src_row_offset + len0 + 4 + n_elim;
    let src_len = n_elim * col_stride;

    // One seeded stream, drawn in this order: source block, then rows 0..3.
    let mut rng = Xorshift64::new(0xF00D_FEED_2026_0513_u64);
    let mut draw = |count: usize| -> Vec<f64> { (0..count).map(|_| rng.next_f64()).collect() };
    let src_block = draw(src_len);
    let dst0_init = draw(len0);
    let dst1_init = draw(len1);
    let dst2_init = draw(len2);
    let dst3_init = draw(len3);

    // Index 2 is zero in every list; the remaining zeros differ per row.
    let alphas0 = vec![0.0, 0.5, 0.0, -0.25, 0.75, 1.1];
    let alphas1 = vec![0.4, 0.0, 0.0, 0.6, 0.3, -0.9];
    let alphas2 = vec![-0.5, 0.0, 0.0, 0.2, 0.0, 0.7];
    let alphas3 = vec![0.8, -0.3, 0.0, 0.0, 0.55, 0.4];

    // Single-row reference for the row shifted `row_shift` rows into the
    // source; the row length is taken from the buffer being updated.
    let single_row_reference =
        |mut expected: Vec<f64>, row_shift: usize, alphas: &Vec<f64>| -> Vec<f64> {
            let row_len = expected.len();
            schur_panel_minus_fma_strided(
                &mut expected,
                &src_block,
                0,
                n_elim,
                col_stride,
                src_row_offset + row_shift,
                row_len,
                alphas,
            );
            expected
        };
    let dst0_ref = single_row_reference(dst0_init.clone(), 0, &alphas0);
    let dst1_ref = single_row_reference(dst1_init.clone(), 1, &alphas1);
    let dst2_ref = single_row_reference(dst2_init.clone(), 2, &alphas2);
    let dst3_ref = single_row_reference(dst3_init.clone(), 3, &alphas3);

    // Kernel under test: all four rows in one call.
    let mut dst0_kernel = dst0_init;
    let mut dst1_kernel = dst1_init;
    let mut dst2_kernel = dst2_init;
    let mut dst3_kernel = dst3_init;
    schur_panel_minus_fma_strided_quad(
        &mut dst0_kernel,
        &mut dst1_kernel,
        &mut dst2_kernel,
        &mut dst3_kernel,
        &src_block,
        0,
        n_elim,
        col_stride,
        src_row_offset,
        &alphas0,
        &alphas1,
        &alphas2,
        &alphas3,
    );

    assert_eq!(dst0_kernel, dst0_ref);
    assert_eq!(dst1_kernel, dst1_ref);
    assert_eq!(dst2_kernel, dst2_ref);
    assert_eq!(dst3_kernel, dst3_ref);
}
// Cross-checks every FMA kernel against its non-FMA counterpart (single,
// dual and quad panel kernels plus the two unroll4 axpy kernels) and requires
// the results to agree element-wise within
// `4 * count * f64::EPSILON * max(|fma|, |nofma|, 1.0)`,
// where `count` is `n_elim` for the panel kernels and 1 / 2 for the axpy
// kernels.
#[test]
fn fma_vs_nofma_panel_kernels_within_n_elim_ulps() {
    // Element-wise closeness check; panics with the offending index and
    // values on the first element that exceeds the tolerance.
    fn assert_close(fma: &[f64], nofma: &[f64], n_elim: usize, label: &str) {
        assert_eq!(fma.len(), nofma.len(), "length mismatch in {}", label);
        let tol_factor = (n_elim as f64) * f64::EPSILON;
        for (idx, (&with_fma, &without_fma)) in fma.iter().zip(nofma.iter()).enumerate() {
            let diff = (with_fma - without_fma).abs();
            // The scale is clamped to at least 1.0, so small magnitudes get
            // an absolute rather than relative tolerance.
            let scale = with_fma.abs().max(without_fma.abs()).max(1.0);
            let tol = tol_factor * scale * 4.0;
            assert!(
                diff <= tol,
                "{}: idx {}: fma={} nofma={} diff={:.3e} > tol={:.3e} (n_elim={})",
                label,
                idx,
                with_fma,
                without_fma,
                diff,
                tol,
                n_elim
            );
        }
    }

    // Collects `count` values produced by successive calls to `next`.
    fn collect_values(count: usize, mut next: impl FnMut() -> f64) -> Vec<f64> {
        (0..count).map(|_| next()).collect()
    }

    let mut rng = Xorshift64::new(0xC0DE_BABE_A11A_FEED_u64);
    let n_elim_sweep = [1usize, 4, 8, 16, 32];
    let len0_sweep: &[usize] = &[1, 4, 8, 17, 32, 65, 257];

    for n_elim in n_elim_sweep.iter().copied() {
        for len0 in len0_sweep.iter().copied() {
            let src_row_offset = 3usize;
            let col_stride = src_row_offset + len0 + 5 + n_elim;
            let src_block = collect_values(n_elim * col_stride, || rng.next_f64());
            let dst0_init = collect_values(len0, || rng.next_f64());
            let alphas0 = collect_values(n_elim, || rng.next_f64() * 1.5);

            // --- Single-row panel kernels ---
            {
                let mut with_fma = dst0_init.clone();
                let mut without_fma = dst0_init.clone();
                schur_panel_minus_fma_strided(
                    &mut with_fma,
                    &src_block,
                    0,
                    n_elim,
                    col_stride,
                    src_row_offset,
                    len0,
                    &alphas0,
                );
                schur_panel_minus_nofma_strided(
                    &mut without_fma,
                    &src_block,
                    0,
                    n_elim,
                    col_stride,
                    src_row_offset,
                    len0,
                    &alphas0,
                );
                assert_close(&with_fma, &without_fma, n_elim, "strided");
            }

            if len0 >= 1 {
                // --- Dual-row panel kernels ---
                let dual_dst1_init = collect_values(len0 - 1, || rng.next_f64());
                let dual_alphas1 = collect_values(n_elim, || rng.next_f64() * 1.5);
                {
                    let mut row0_fma = dst0_init.clone();
                    let mut row1_fma = dual_dst1_init.clone();
                    let mut row0_nofma = dst0_init.clone();
                    let mut row1_nofma = dual_dst1_init.clone();
                    schur_panel_minus_fma_strided_dual(
                        &mut row0_fma,
                        &mut row1_fma,
                        &src_block,
                        0,
                        n_elim,
                        col_stride,
                        src_row_offset,
                        &alphas0,
                        &dual_alphas1,
                    );
                    schur_panel_minus_nofma_strided_dual(
                        &mut row0_nofma,
                        &mut row1_nofma,
                        &src_block,
                        0,
                        n_elim,
                        col_stride,
                        src_row_offset,
                        &alphas0,
                        &dual_alphas1,
                    );
                    assert_close(&row0_fma, &row0_nofma, n_elim, "dual.dst0");
                    assert_close(&row1_fma, &row1_nofma, n_elim, "dual.dst1");
                }

                // --- Quad-row panel kernels ---
                // The quad inputs are drawn unconditionally so the random
                // stream consumed by later checks does not depend on whether
                // the quad comparison actually runs.
                let quad_dst1_init = collect_values(len0 - 1, || rng.next_f64());
                let quad_dst2_init = collect_values(len0.saturating_sub(2), || rng.next_f64());
                let quad_dst3_init = collect_values(len0.saturating_sub(3), || rng.next_f64());
                let quad_alphas1 = collect_values(n_elim, || rng.next_f64() * 1.5);
                let quad_alphas2 = collect_values(n_elim, || rng.next_f64() * 1.5);
                let quad_alphas3 = collect_values(n_elim, || rng.next_f64() * 1.5);
                // Only run when all four rows have their exact staggered
                // lengths (len0, len0 - 1, len0 - 2, len0 - 3), i.e. when no
                // saturating subtraction clamped to zero.
                if len0 >= 3 {
                    let mut row0_fma = dst0_init.clone();
                    let mut row1_fma = quad_dst1_init.clone();
                    let mut row2_fma = quad_dst2_init.clone();
                    let mut row3_fma = quad_dst3_init.clone();
                    let mut row0_nofma = dst0_init.clone();
                    let mut row1_nofma = quad_dst1_init.clone();
                    let mut row2_nofma = quad_dst2_init.clone();
                    let mut row3_nofma = quad_dst3_init.clone();
                    schur_panel_minus_fma_strided_quad(
                        &mut row0_fma,
                        &mut row1_fma,
                        &mut row2_fma,
                        &mut row3_fma,
                        &src_block,
                        0,
                        n_elim,
                        col_stride,
                        src_row_offset,
                        &alphas0,
                        &quad_alphas1,
                        &quad_alphas2,
                        &quad_alphas3,
                    );
                    schur_panel_minus_nofma_strided_quad(
                        &mut row0_nofma,
                        &mut row1_nofma,
                        &mut row2_nofma,
                        &mut row3_nofma,
                        &src_block,
                        0,
                        n_elim,
                        col_stride,
                        src_row_offset,
                        &alphas0,
                        &quad_alphas1,
                        &quad_alphas2,
                        &quad_alphas3,
                    );
                    assert_close(&row0_fma, &row0_nofma, n_elim, "quad.dst0");
                    assert_close(&row1_fma, &row1_nofma, n_elim, "quad.dst1");
                    assert_close(&row2_fma, &row2_nofma, n_elim, "quad.dst2");
                    assert_close(&row3_fma, &row3_nofma, n_elim, "quad.dst3");
                }
            }

            // --- Single-source axpy kernels (tolerance count 1) ---
            let src = collect_values(len0, || rng.next_f64());
            let alpha = rng.next_f64() * 1.5;
            {
                let mut with_fma = dst0_init.clone();
                let mut without_fma = dst0_init.clone();
                axpy_minus_unroll4(&mut with_fma, &src, alpha);
                axpy_minus_unroll4_nofma(&mut without_fma, &src, alpha);
                assert_close(&with_fma, &without_fma, 1, "axpy_minus_unroll4");
            }

            // --- Two-source axpy kernels (tolerance count 2) ---
            let src_b = collect_values(len0, || rng.next_f64());
            let alpha_b = rng.next_f64() * 1.5;
            {
                let mut with_fma = dst0_init.clone();
                let mut without_fma = dst0_init.clone();
                axpy2_minus_unroll4(&mut with_fma, &src, alpha, &src_b, alpha_b);
                axpy2_minus_unroll4_nofma(&mut without_fma, &src, alpha, &src_b, alpha_b);
                assert_close(&with_fma, &without_fma, 2, "axpy2_minus_unroll4");
            }
        }
    }
}
// For every length in `LENGTH_SWEEP`, requires the non-FMA unroll4
// two-source axpy kernel to produce exactly the same values as
// `naive_axpy2_minus` on identical random inputs.
#[test]
fn axpy2_minus_unroll4_nofma_is_bit_exact_vs_scalar() {
    let mut rng = Xorshift64::new(0xB17E_AC70_BAAD_F00D_u64);
    for len in LENGTH_SWEEP.iter().copied() {
        // Draw order: both sources, the destination, then the two scalars.
        let mut draw = || -> Vec<f64> { (0..len).map(|_| rng.next_f64()).collect() };
        let src0 = draw();
        let src1 = draw();
        let dst_init = draw();
        let alpha0 = rng.next_f64() * 1.5;
        let alpha1 = rng.next_f64() * 1.5;

        // Both implementations start from the same destination contents.
        let mut dst_kernel = dst_init.clone();
        let mut dst_ref = dst_init;
        axpy2_minus_unroll4_nofma(&mut dst_kernel, &src0, alpha0, &src1, alpha1);
        naive_axpy2_minus(&mut dst_ref, &src0, alpha0, &src1, alpha1);

        assert_eq!(
            dst_kernel, dst_ref,
            "non-FMA unroll4 must be bit-exact vs scalar at len={}",
            len
        );
    }
}
}