/// Updates `dst` in place so that `dst[i]` becomes `dst[i] - alpha * src[i]`.
///
/// The scalar is negated once up front and every element is then updated with
/// pulp's `mul_add_f64s(-alpha, src[i], dst[i])`. The kernel runs on the SIMD
/// level chosen by `pulp::Arch::new()`.
///
/// # Panics
///
/// Panics if `dst` and `src` have different lengths.
#[allow(dead_code)]
pub fn axpy_minus(dst: &mut [f64], src: &[f64], alpha: f64) {
    assert_eq!(
        dst.len(),
        src.len(),
        "axpy_minus: dst and src length mismatch"
    );

    /// Arguments bundled for the SIMD-generic kernel body.
    struct AxpyKernel<'a> {
        scale: f64,
        input: &'a [f64],
        output: &'a mut [f64],
    }

    impl pulp::WithSimd for AxpyKernel<'_> {
        type Output = ();

        #[inline(always)]
        fn with_simd<S: pulp::Simd>(self, simd: S) {
            let scale_v = simd.splat_f64s(self.scale);
            let (in_vecs, in_rest) = S::as_simd_f64s(self.input);
            let (out_vecs, out_rest) = S::as_mut_simd_f64s(self.output);

            // Whole SIMD vectors.
            for (acc, x) in out_vecs.iter_mut().zip(in_vecs) {
                *acc = simd.mul_add_f64s(scale_v, *x, *acc);
            }

            // Elements that do not fill a whole vector go through the
            // partial load/store helpers.
            if in_rest.is_empty() {
                return;
            }
            let x = simd.partial_load_f64s(in_rest);
            let acc = simd.partial_load_f64s(out_rest);
            let updated = simd.mul_add_f64s(scale_v, x, acc);
            simd.partial_store_f64s(out_rest, updated);
        }
    }

    pulp::Arch::new().dispatch(AxpyKernel {
        scale: -alpha,
        input: src,
        output: dst,
    });
}
/// Updates `dst` in place with two scaled subtractions:
/// `dst[i]` becomes `dst[i] - alpha0 * src0[i] - alpha1 * src1[i]`.
///
/// Both scalars are negated once, then each element is updated with two
/// chained `mul_add_f64s` calls: first the `src0` term, then the `src1` term.
/// The kernel runs on the SIMD level chosen by `pulp::Arch::new()`.
///
/// # Panics
///
/// Panics if `src0` or `src1` differs in length from `dst`.
#[allow(dead_code)]
pub fn axpy2_minus(dst: &mut [f64], src0: &[f64], alpha0: f64, src1: &[f64], alpha1: f64) {
    assert_eq!(
        dst.len(),
        src0.len(),
        "axpy2_minus: dst and src0 length mismatch"
    );
    assert_eq!(
        dst.len(),
        src1.len(),
        "axpy2_minus: dst and src1 length mismatch"
    );

    /// Arguments bundled for the SIMD-generic kernel body.
    struct Axpy2Kernel<'a> {
        scale0: f64,
        scale1: f64,
        input0: &'a [f64],
        input1: &'a [f64],
        output: &'a mut [f64],
    }

    impl pulp::WithSimd for Axpy2Kernel<'_> {
        type Output = ();

        #[inline(always)]
        fn with_simd<S: pulp::Simd>(self, simd: S) {
            let scale0_v = simd.splat_f64s(self.scale0);
            let scale1_v = simd.splat_f64s(self.scale1);
            let (in0_vecs, in0_rest) = S::as_simd_f64s(self.input0);
            let (in1_vecs, in1_rest) = S::as_simd_f64s(self.input1);
            let (out_vecs, out_rest) = S::as_mut_simd_f64s(self.output);

            // Whole SIMD vectors: apply the src0 term, then the src1 term.
            let inputs = in0_vecs.iter().zip(in1_vecs);
            for (acc, (x0, x1)) in out_vecs.iter_mut().zip(inputs) {
                let after_first = simd.mul_add_f64s(scale0_v, *x0, *acc);
                *acc = simd.mul_add_f64s(scale1_v, *x1, after_first);
            }

            // Elements that do not fill a whole vector.
            if in0_rest.is_empty() {
                return;
            }
            let x0 = simd.partial_load_f64s(in0_rest);
            let x1 = simd.partial_load_f64s(in1_rest);
            let acc = simd.partial_load_f64s(out_rest);
            let after_first = simd.mul_add_f64s(scale0_v, x0, acc);
            let updated = simd.mul_add_f64s(scale1_v, x1, after_first);
            simd.partial_store_f64s(out_rest, updated);
        }
    }

    pulp::Arch::new().dispatch(Axpy2Kernel {
        scale0: -alpha0,
        scale1: -alpha1,
        input0: src0,
        input1: src1,
        output: dst,
    });
}
/// aarch64-only variant of the `dst[i] -= alpha * src[i]` update that invokes
/// the kernel directly with a NEON token instead of going through
/// `pulp::Arch` dispatch.
///
/// Each element is updated with `mul_add_f64s(-alpha, src[i], dst[i])`.
///
/// # Panics
///
/// Panics if `dst` and `src` have different lengths.
#[cfg(target_arch = "aarch64")]
#[allow(dead_code)]
pub fn axpy_minus_direct(dst: &mut [f64], src: &[f64], alpha: f64) {
    use pulp::WithSimd;

    assert_eq!(
        dst.len(),
        src.len(),
        "axpy_minus_direct: dst and src length mismatch"
    );

    /// Arguments bundled for the SIMD-generic kernel body.
    struct AxpyKernel<'a> {
        scale: f64,
        input: &'a [f64],
        output: &'a mut [f64],
    }

    impl pulp::WithSimd for AxpyKernel<'_> {
        type Output = ();

        #[inline(always)]
        fn with_simd<S: pulp::Simd>(self, simd: S) {
            let scale_v = simd.splat_f64s(self.scale);
            let (in_vecs, in_rest) = S::as_simd_f64s(self.input);
            let (out_vecs, out_rest) = S::as_mut_simd_f64s(self.output);

            // Whole SIMD vectors.
            for (acc, x) in out_vecs.iter_mut().zip(in_vecs) {
                *acc = simd.mul_add_f64s(scale_v, *x, *acc);
            }

            // Elements that do not fill a whole vector.
            if in_rest.is_empty() {
                return;
            }
            let x = simd.partial_load_f64s(in_rest);
            let acc = simd.partial_load_f64s(out_rest);
            let updated = simd.mul_add_f64s(scale_v, x, acc);
            simd.partial_store_f64s(out_rest, updated);
        }
    }

    // SAFETY: NOTE(review): this relies on NEON being available on every
    // aarch64 target this crate is built for — confirm for custom targets.
    const NEON: pulp::aarch64::Neon = unsafe { pulp::aarch64::Neon::new_unchecked() };

    let kernel = AxpyKernel {
        scale: -alpha,
        input: src,
        output: dst,
    };
    kernel.with_simd(NEON);
}
/// aarch64-only variant of the two-term update
/// `dst[i] = dst[i] - alpha0 * src0[i] - alpha1 * src1[i]` that invokes the
/// kernel directly with a NEON token instead of going through `pulp::Arch`
/// dispatch.
///
/// Each element is updated with two chained `mul_add_f64s` calls: first the
/// `src0` term, then the `src1` term.
///
/// # Panics
///
/// Panics if `src0` or `src1` differs in length from `dst`.
#[cfg(target_arch = "aarch64")]
#[allow(dead_code)]
pub fn axpy2_minus_direct(dst: &mut [f64], src0: &[f64], alpha0: f64, src1: &[f64], alpha1: f64) {
    use pulp::WithSimd;

    assert_eq!(
        dst.len(),
        src0.len(),
        "axpy2_minus_direct: dst and src0 length mismatch"
    );
    assert_eq!(
        dst.len(),
        src1.len(),
        "axpy2_minus_direct: dst and src1 length mismatch"
    );

    /// Arguments bundled for the SIMD-generic kernel body.
    struct Axpy2Kernel<'a> {
        scale0: f64,
        scale1: f64,
        input0: &'a [f64],
        input1: &'a [f64],
        output: &'a mut [f64],
    }

    impl pulp::WithSimd for Axpy2Kernel<'_> {
        type Output = ();

        #[inline(always)]
        fn with_simd<S: pulp::Simd>(self, simd: S) {
            let scale0_v = simd.splat_f64s(self.scale0);
            let scale1_v = simd.splat_f64s(self.scale1);
            let (in0_vecs, in0_rest) = S::as_simd_f64s(self.input0);
            let (in1_vecs, in1_rest) = S::as_simd_f64s(self.input1);
            let (out_vecs, out_rest) = S::as_mut_simd_f64s(self.output);

            // Whole SIMD vectors: apply the src0 term, then the src1 term.
            let inputs = in0_vecs.iter().zip(in1_vecs);
            for (acc, (x0, x1)) in out_vecs.iter_mut().zip(inputs) {
                let after_first = simd.mul_add_f64s(scale0_v, *x0, *acc);
                *acc = simd.mul_add_f64s(scale1_v, *x1, after_first);
            }

            // Elements that do not fill a whole vector.
            if in0_rest.is_empty() {
                return;
            }
            let x0 = simd.partial_load_f64s(in0_rest);
            let x1 = simd.partial_load_f64s(in1_rest);
            let acc = simd.partial_load_f64s(out_rest);
            let after_first = simd.mul_add_f64s(scale0_v, x0, acc);
            let updated = simd.mul_add_f64s(scale1_v, x1, after_first);
            simd.partial_store_f64s(out_rest, updated);
        }
    }

    // SAFETY: NOTE(review): this relies on NEON being available on every
    // aarch64 target this crate is built for — confirm for custom targets.
    const NEON: pulp::aarch64::Neon = unsafe { pulp::aarch64::Neon::new_unchecked() };

    let kernel = Axpy2Kernel {
        scale0: -alpha0,
        scale1: -alpha1,
        input0: src0,
        input1: src1,
        output: dst,
    };
    kernel.with_simd(NEON);
}
/// aarch64-only `dst[i] -= alpha * src[i]` update whose main loop is manually
/// unrolled to process four SIMD vectors per iteration.
///
/// The kernel is invoked directly with a NEON token. Each element is updated
/// with `mul_add_f64s(-alpha, src[i], dst[i])`. Vectors left over after the
/// groups of four are handled one at a time, and the final sub-vector
/// remainder goes through the partial load/store helpers.
///
/// # Panics
///
/// Panics if `dst` and `src` have different lengths.
#[cfg(target_arch = "aarch64")]
#[allow(dead_code)]
pub fn axpy_minus_unroll4(dst: &mut [f64], src: &[f64], alpha: f64) {
    use pulp::WithSimd;

    assert_eq!(
        dst.len(),
        src.len(),
        "axpy_minus_unroll4: dst and src length mismatch"
    );

    /// Arguments bundled for the SIMD-generic kernel body.
    struct UnrolledAxpy<'a> {
        scale: f64,
        input: &'a [f64],
        output: &'a mut [f64],
    }

    impl pulp::WithSimd for UnrolledAxpy<'_> {
        type Output = ();

        #[inline(always)]
        fn with_simd<S: pulp::Simd>(self, simd: S) {
            let scale_v = simd.splat_f64s(self.scale);
            let (in_vecs, in_rest) = S::as_simd_f64s(self.input);
            let (out_vecs, out_rest) = S::as_mut_simd_f64s(self.output);

            // Groups of four vectors: compute all four results before
            // writing any of them back.
            let mut out_quads = out_vecs.chunks_exact_mut(4);
            let mut in_quads = in_vecs.chunks_exact(4);
            for (acc, x) in out_quads.by_ref().zip(in_quads.by_ref()) {
                let u0 = simd.mul_add_f64s(scale_v, x[0], acc[0]);
                let u1 = simd.mul_add_f64s(scale_v, x[1], acc[1]);
                let u2 = simd.mul_add_f64s(scale_v, x[2], acc[2]);
                let u3 = simd.mul_add_f64s(scale_v, x[3], acc[3]);
                acc[0] = u0;
                acc[1] = u1;
                acc[2] = u2;
                acc[3] = u3;
            }

            // Up to three whole vectors that did not form a group of four.
            let out_left = out_quads.into_remainder();
            let in_left = in_quads.remainder();
            for (acc, x) in out_left.iter_mut().zip(in_left) {
                *acc = simd.mul_add_f64s(scale_v, *x, *acc);
            }

            // Elements that do not fill a whole vector.
            if in_rest.is_empty() {
                return;
            }
            let x = simd.partial_load_f64s(in_rest);
            let acc = simd.partial_load_f64s(out_rest);
            let updated = simd.mul_add_f64s(scale_v, x, acc);
            simd.partial_store_f64s(out_rest, updated);
        }
    }

    // SAFETY: NOTE(review): this relies on NEON being available on every
    // aarch64 target this crate is built for — confirm for custom targets.
    const NEON: pulp::aarch64::Neon = unsafe { pulp::aarch64::Neon::new_unchecked() };

    let kernel = UnrolledAxpy {
        scale: -alpha,
        input: src,
        output: dst,
    };
    kernel.with_simd(NEON);
}
/// aarch64-only two-term update
/// `dst[i] = dst[i] - alpha0 * src0[i] - alpha1 * src1[i]` whose main loop is
/// manually unrolled to process four SIMD vectors per iteration.
///
/// The kernel is invoked directly with a NEON token. Each element is updated
/// with two chained `mul_add_f64s` calls: first the `src0` term, then the
/// `src1` term.
///
/// # Panics
///
/// Panics if `src0` or `src1` differs in length from `dst`.
#[cfg(target_arch = "aarch64")]
#[allow(dead_code)]
pub fn axpy2_minus_unroll4(dst: &mut [f64], src0: &[f64], alpha0: f64, src1: &[f64], alpha1: f64) {
    use pulp::WithSimd;

    assert_eq!(
        dst.len(),
        src0.len(),
        "axpy2_minus_unroll4: dst and src0 length mismatch"
    );
    assert_eq!(
        dst.len(),
        src1.len(),
        "axpy2_minus_unroll4: dst and src1 length mismatch"
    );

    /// Arguments bundled for the SIMD-generic kernel body.
    struct UnrolledAxpy2<'a> {
        scale0: f64,
        scale1: f64,
        input0: &'a [f64],
        input1: &'a [f64],
        output: &'a mut [f64],
    }

    impl pulp::WithSimd for UnrolledAxpy2<'_> {
        type Output = ();

        #[inline(always)]
        fn with_simd<S: pulp::Simd>(self, simd: S) {
            let scale0_v = simd.splat_f64s(self.scale0);
            let scale1_v = simd.splat_f64s(self.scale1);
            let (in0_vecs, in0_rest) = S::as_simd_f64s(self.input0);
            let (in1_vecs, in1_rest) = S::as_simd_f64s(self.input1);
            let (out_vecs, out_rest) = S::as_mut_simd_f64s(self.output);

            // Groups of four vectors: all src0 terms first, then all src1
            // terms, then the four stores.
            let mut out_quads = out_vecs.chunks_exact_mut(4);
            let mut in0_quads = in0_vecs.chunks_exact(4);
            let mut in1_quads = in1_vecs.chunks_exact(4);
            let quads = out_quads
                .by_ref()
                .zip(in0_quads.by_ref())
                .zip(in1_quads.by_ref());
            for ((acc, x0), x1) in quads {
                let p0 = simd.mul_add_f64s(scale0_v, x0[0], acc[0]);
                let p1 = simd.mul_add_f64s(scale0_v, x0[1], acc[1]);
                let p2 = simd.mul_add_f64s(scale0_v, x0[2], acc[2]);
                let p3 = simd.mul_add_f64s(scale0_v, x0[3], acc[3]);
                let u0 = simd.mul_add_f64s(scale1_v, x1[0], p0);
                let u1 = simd.mul_add_f64s(scale1_v, x1[1], p1);
                let u2 = simd.mul_add_f64s(scale1_v, x1[2], p2);
                let u3 = simd.mul_add_f64s(scale1_v, x1[3], p3);
                acc[0] = u0;
                acc[1] = u1;
                acc[2] = u2;
                acc[3] = u3;
            }

            // Up to three whole vectors that did not form a group of four.
            let out_left = out_quads.into_remainder();
            let in0_left = in0_quads.remainder();
            let in1_left = in1_quads.remainder();
            for ((acc, x0), x1) in out_left.iter_mut().zip(in0_left).zip(in1_left) {
                let after_first = simd.mul_add_f64s(scale0_v, *x0, *acc);
                *acc = simd.mul_add_f64s(scale1_v, *x1, after_first);
            }

            // Elements that do not fill a whole vector.
            if in0_rest.is_empty() {
                return;
            }
            let x0 = simd.partial_load_f64s(in0_rest);
            let x1 = simd.partial_load_f64s(in1_rest);
            let acc = simd.partial_load_f64s(out_rest);
            let after_first = simd.mul_add_f64s(scale0_v, x0, acc);
            let updated = simd.mul_add_f64s(scale1_v, x1, after_first);
            simd.partial_store_f64s(out_rest, updated);
        }
    }

    // SAFETY: NOTE(review): this relies on NEON being available on every
    // aarch64 target this crate is built for — confirm for custom targets.
    const NEON: pulp::aarch64::Neon = unsafe { pulp::aarch64::Neon::new_unchecked() };

    let kernel = UnrolledAxpy2 {
        scale0: -alpha0,
        scale1: -alpha1,
        input0: src0,
        input1: src1,
        output: dst,
    };
    kernel.with_simd(NEON);
}
/// Runs the kernel `k` on a per-architecture SIMD level and returns its output.
///
/// * aarch64: the kernel is called directly with a NEON token.
/// * x86_64: the kernel runs on `pulp::x86::V3` when the CPU supports it,
///   otherwise it falls back to `pulp::Arch` dispatch.
/// * any other architecture: `pulp::Arch` dispatch.
///
/// On x86_64 the kernel is entered through `pulp::Simd::vectorize` rather than
/// by calling `with_simd` directly. `vectorize` runs the kernel inside a
/// function compiled with the V3 target features, which lets the SIMD
/// intrinsics inline into the kernel body. Calling `with_simd` from a context
/// without those features leaves each intrinsic as an out-of-line call. The
/// arithmetic performed is the same either way.
#[inline(always)]
fn dispatch_nofma<K: pulp::WithSimd>(k: K) -> K::Output {
    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NOTE(review): this relies on NEON being available on every
        // aarch64 target this crate is built for — confirm for custom targets.
        const NEON: pulp::aarch64::Neon = unsafe { pulp::aarch64::Neon::new_unchecked() };
        k.with_simd(NEON)
    }
    #[cfg(target_arch = "x86_64")]
    {
        match pulp::x86::V3::try_new() {
            // Enter through `vectorize` so the kernel is compiled with the
            // V3 feature set enabled.
            Some(v3) => pulp::Simd::vectorize(v3, k),
            None => pulp::Arch::new().dispatch(k),
        }
    }
    #[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
    {
        pulp::Arch::new().dispatch(k)
    }
}
/// Updates `dst` in place so that `dst[i]` becomes `dst[i] - alpha * src[i]`,
/// using a separate SIMD multiply followed by a SIMD subtract.
///
/// The main loop is manually unrolled to process four SIMD vectors per
/// iteration. The kernel is launched through `dispatch_nofma`.
///
/// # Panics
///
/// Panics if `dst` and `src` have different lengths.
#[allow(dead_code)]
pub fn axpy_minus_unroll4_nofma(dst: &mut [f64], src: &[f64], alpha: f64) {
    assert_eq!(
        dst.len(),
        src.len(),
        "axpy_minus_unroll4_nofma: dst and src length mismatch"
    );

    /// Arguments bundled for the SIMD-generic kernel body.
    struct MulSubKernel<'a> {
        factor: f64,
        input: &'a [f64],
        output: &'a mut [f64],
    }

    impl pulp::WithSimd for MulSubKernel<'_> {
        type Output = ();

        #[inline(always)]
        fn with_simd<S: pulp::Simd>(self, simd: S) {
            let factor_v = simd.splat_f64s(self.factor);
            let (in_vecs, in_rest) = S::as_simd_f64s(self.input);
            let (out_vecs, out_rest) = S::as_mut_simd_f64s(self.output);

            // Groups of four vectors: four products, four differences, then
            // four stores.
            let mut out_quads = out_vecs.chunks_exact_mut(4);
            let mut in_quads = in_vecs.chunks_exact(4);
            for (acc, x) in out_quads.by_ref().zip(in_quads.by_ref()) {
                let p0 = simd.mul_f64s(factor_v, x[0]);
                let p1 = simd.mul_f64s(factor_v, x[1]);
                let p2 = simd.mul_f64s(factor_v, x[2]);
                let p3 = simd.mul_f64s(factor_v, x[3]);
                let u0 = simd.sub_f64s(acc[0], p0);
                let u1 = simd.sub_f64s(acc[1], p1);
                let u2 = simd.sub_f64s(acc[2], p2);
                let u3 = simd.sub_f64s(acc[3], p3);
                acc[0] = u0;
                acc[1] = u1;
                acc[2] = u2;
                acc[3] = u3;
            }

            // Up to three whole vectors that did not form a group of four.
            let out_left = out_quads.into_remainder();
            let in_left = in_quads.remainder();
            for (acc, x) in out_left.iter_mut().zip(in_left) {
                let product = simd.mul_f64s(factor_v, *x);
                *acc = simd.sub_f64s(*acc, product);
            }

            // Elements that do not fill a whole vector.
            if in_rest.is_empty() {
                return;
            }
            let x = simd.partial_load_f64s(in_rest);
            let acc = simd.partial_load_f64s(out_rest);
            let product = simd.mul_f64s(factor_v, x);
            simd.partial_store_f64s(out_rest, simd.sub_f64s(acc, product));
        }
    }

    dispatch_nofma(MulSubKernel {
        factor: alpha,
        input: src,
        output: dst,
    });
}
/// Updates `dst` in place so that `dst[i]` becomes
/// `dst[i] - (alpha0 * src0[i] + alpha1 * src1[i])`, using separate SIMD
/// multiply, add and subtract operations.
///
/// The two products are summed first and the sum is then subtracted from
/// `dst`. The main loop is manually unrolled to process four SIMD vectors per
/// iteration. The kernel is launched through `dispatch_nofma`.
///
/// # Panics
///
/// Panics if `src0` or `src1` differs in length from `dst`.
#[allow(dead_code)]
pub fn axpy2_minus_unroll4_nofma(
    dst: &mut [f64],
    src0: &[f64],
    alpha0: f64,
    src1: &[f64],
    alpha1: f64,
) {
    assert_eq!(
        dst.len(),
        src0.len(),
        "axpy2_minus_unroll4_nofma: dst and src0 length mismatch"
    );
    assert_eq!(
        dst.len(),
        src1.len(),
        "axpy2_minus_unroll4_nofma: dst and src1 length mismatch"
    );

    /// Arguments bundled for the SIMD-generic kernel body.
    struct MulSub2Kernel<'a> {
        factor0: f64,
        factor1: f64,
        input0: &'a [f64],
        input1: &'a [f64],
        output: &'a mut [f64],
    }

    impl pulp::WithSimd for MulSub2Kernel<'_> {
        type Output = ();

        #[inline(always)]
        fn with_simd<S: pulp::Simd>(self, simd: S) {
            let f0 = simd.splat_f64s(self.factor0);
            let f1 = simd.splat_f64s(self.factor1);
            let (in0_vecs, in0_rest) = S::as_simd_f64s(self.input0);
            let (in1_vecs, in1_rest) = S::as_simd_f64s(self.input1);
            let (out_vecs, out_rest) = S::as_mut_simd_f64s(self.output);

            // Groups of four vectors: src0 products, src1 products, their
            // sums, then the subtraction from the destination.
            let mut out_quads = out_vecs.chunks_exact_mut(4);
            let mut in0_quads = in0_vecs.chunks_exact(4);
            let mut in1_quads = in1_vecs.chunks_exact(4);
            let quads = out_quads
                .by_ref()
                .zip(in0_quads.by_ref())
                .zip(in1_quads.by_ref());
            for ((acc, x0), x1) in quads {
                let p00 = simd.mul_f64s(f0, x0[0]);
                let p01 = simd.mul_f64s(f0, x0[1]);
                let p02 = simd.mul_f64s(f0, x0[2]);
                let p03 = simd.mul_f64s(f0, x0[3]);
                let p10 = simd.mul_f64s(f1, x1[0]);
                let p11 = simd.mul_f64s(f1, x1[1]);
                let p12 = simd.mul_f64s(f1, x1[2]);
                let p13 = simd.mul_f64s(f1, x1[3]);
                let sum0 = simd.add_f64s(p00, p10);
                let sum1 = simd.add_f64s(p01, p11);
                let sum2 = simd.add_f64s(p02, p12);
                let sum3 = simd.add_f64s(p03, p13);
                acc[0] = simd.sub_f64s(acc[0], sum0);
                acc[1] = simd.sub_f64s(acc[1], sum1);
                acc[2] = simd.sub_f64s(acc[2], sum2);
                acc[3] = simd.sub_f64s(acc[3], sum3);
            }

            // Up to three whole vectors that did not form a group of four.
            let out_left = out_quads.into_remainder();
            let in0_left = in0_quads.remainder();
            let in1_left = in1_quads.remainder();
            for ((acc, x0), x1) in out_left.iter_mut().zip(in0_left).zip(in1_left) {
                let p0 = simd.mul_f64s(f0, *x0);
                let p1 = simd.mul_f64s(f1, *x1);
                let sum = simd.add_f64s(p0, p1);
                *acc = simd.sub_f64s(*acc, sum);
            }

            // Elements that do not fill a whole vector.
            if in0_rest.is_empty() {
                return;
            }
            let x0 = simd.partial_load_f64s(in0_rest);
            let x1 = simd.partial_load_f64s(in1_rest);
            let acc = simd.partial_load_f64s(out_rest);
            let p0 = simd.mul_f64s(f0, x0);
            let p1 = simd.mul_f64s(f1, x1);
            let sum = simd.add_f64s(p0, p1);
            simd.partial_store_f64s(out_rest, simd.sub_f64s(acc, sum));
        }
    }

    dispatch_nofma(MulSub2Kernel {
        factor0: alpha0,
        factor1: alpha1,
        input0: src0,
        input1: src1,
        output: dst,
    });
}
/// Subtracts `n_elim` scaled source columns from `dst` in a single pass.
///
/// For each `q` in `0..n_elim`, in increasing order, and every `i` in
/// `0..len`, the update is
///
/// ```text
/// dst[i] -= alphas[q] * src_block[(src_first_col + q) * col_stride + src_row_offset + i]
/// ```
///
/// performed as a separate multiply and subtract. Columns whose `alphas[q]`
/// compares equal to `0.0` are skipped entirely. Each group of destination
/// vectors is loaded once, receives every column's contribution, and is then
/// stored once.
///
/// # Parameters
///
/// * `dst` - destination slice, updated in place; must have `len` elements.
/// * `src_block` - backing storage holding the source columns.
/// * `src_first_col` - index of the first source column.
/// * `n_elim` - number of consecutive source columns to apply.
/// * `col_stride` - distance in elements between the starts of consecutive columns.
/// * `src_row_offset` - offset in elements of the first used row inside each column.
/// * `len` - number of rows to update.
/// * `alphas` - one scale factor per source column; must have `n_elim` elements.
///
/// # Panics
///
/// Panics if `dst.len() != len` or `alphas.len() != n_elim`. When
/// `n_elim != 0` and `len != 0`, also panics if the furthest source index
/// does not fit in `usize`, or if `src_block` is too short to hold it.
#[allow(dead_code, clippy::too_many_arguments)]
pub fn schur_panel_minus_nofma_strided(
    dst: &mut [f64],
    src_block: &[f64],
    src_first_col: usize,
    n_elim: usize,
    col_stride: usize,
    src_row_offset: usize,
    len: usize,
    alphas: &[f64],
) {
    assert_eq!(
        dst.len(),
        len,
        "schur_panel_minus_nofma_strided: dst.len() must equal len"
    );
    assert_eq!(
        alphas.len(),
        n_elim,
        "schur_panel_minus_nofma_strided: alphas.len() must equal n_elim"
    );
    if n_elim == 0 || len == 0 {
        return;
    }
    let last_q = n_elim - 1;
    // One past the last source element read. Checked arithmetic makes a
    // wrapped value fail loudly in release builds instead of letting the
    // bounds assertion below pass on a bogus extent. Because this is the
    // largest offset any column reaches, the unchecked per-column offsets
    // computed inside the kernel cannot overflow once this succeeds.
    let max_idx = src_first_col
        .checked_add(last_q)
        .and_then(|col| col.checked_mul(col_stride))
        .and_then(|off| off.checked_add(src_row_offset))
        .and_then(|off| off.checked_add(len))
        .expect("schur_panel_minus_nofma_strided: src_block extent overflows usize");
    assert!(
        src_block.len() >= max_idx,
        "schur_panel_minus_nofma_strided: src_block too short ({} < {})",
        src_block.len(),
        max_idx
    );
    /// Arguments bundled for the SIMD-generic kernel body.
    struct K<'a> {
        dst: &'a mut [f64],
        src_block: &'a [f64],
        src_first_col: usize,
        n_elim: usize,
        col_stride: usize,
        src_row_offset: usize,
        len: usize,
        alphas: &'a [f64],
    }
    impl pulp::WithSimd for K<'_> {
        type Output = ();
        #[allow(clippy::needless_range_loop)]
        #[inline(always)]
        fn with_simd<S: pulp::Simd>(self, simd: S) {
            let Self {
                dst,
                src_block,
                src_first_col,
                n_elim,
                col_stride,
                src_row_offset,
                len,
                alphas,
            } = self;
            let (dst_body, dst_tail) = S::as_mut_simd_f64s(dst);
            let body_len = dst_body.len();
            // Scalar index at which the partial tail starts.
            let tail_off = body_len * S::F64_LANES;
            let chunks = body_len / 4;
            // Main loop: four destination vectors are kept in locals while
            // every column's contribution is subtracted from them.
            for chunk_idx in 0..chunks {
                let base = chunk_idx * 4;
                let mut a0 = dst_body[base];
                let mut a1 = dst_body[base + 1];
                let mut a2 = dst_body[base + 2];
                let mut a3 = dst_body[base + 3];
                for q in 0..n_elim {
                    let alpha_q = alphas[q];
                    if alpha_q == 0.0 {
                        continue;
                    }
                    let av = simd.splat_f64s(alpha_q);
                    let col_off = (src_first_col + q) * col_stride + src_row_offset;
                    let src_q = &src_block[col_off..col_off + len];
                    let (sb, _st) = S::as_simd_f64s(src_q);
                    let m0 = simd.mul_f64s(av, sb[base]);
                    let m1 = simd.mul_f64s(av, sb[base + 1]);
                    let m2 = simd.mul_f64s(av, sb[base + 2]);
                    let m3 = simd.mul_f64s(av, sb[base + 3]);
                    a0 = simd.sub_f64s(a0, m0);
                    a1 = simd.sub_f64s(a1, m1);
                    a2 = simd.sub_f64s(a2, m2);
                    a3 = simd.sub_f64s(a3, m3);
                }
                dst_body[base] = a0;
                dst_body[base + 1] = a1;
                dst_body[base + 2] = a2;
                dst_body[base + 3] = a3;
            }
            // Up to three whole vectors that did not form a group of four.
            let tail_chunks_start = chunks * 4;
            for body_idx in tail_chunks_start..body_len {
                let mut acc = dst_body[body_idx];
                for q in 0..n_elim {
                    let alpha_q = alphas[q];
                    if alpha_q == 0.0 {
                        continue;
                    }
                    let av = simd.splat_f64s(alpha_q);
                    let col_off = (src_first_col + q) * col_stride + src_row_offset;
                    let src_q = &src_block[col_off..col_off + len];
                    let (sb, _st) = S::as_simd_f64s(src_q);
                    let m = simd.mul_f64s(av, sb[body_idx]);
                    acc = simd.sub_f64s(acc, m);
                }
                dst_body[body_idx] = acc;
            }
            // Elements that do not fill a whole vector, handled with the
            // partial load/store helpers.
            if !dst_tail.is_empty() {
                let mut acc = simd.partial_load_f64s(dst_tail);
                for q in 0..n_elim {
                    let alpha_q = alphas[q];
                    if alpha_q == 0.0 {
                        continue;
                    }
                    let av = simd.splat_f64s(alpha_q);
                    let col_off = (src_first_col + q) * col_stride + src_row_offset;
                    let src_q = &src_block[col_off..col_off + len];
                    let src_q_tail = &src_q[tail_off..];
                    let s = simd.partial_load_f64s(src_q_tail);
                    let m = simd.mul_f64s(av, s);
                    acc = simd.sub_f64s(acc, m);
                }
                simd.partial_store_f64s(dst_tail, acc);
            }
        }
    }
    dispatch_nofma(K {
        dst,
        src_block,
        src_first_col,
        n_elim,
        col_stride,
        src_row_offset,
        len,
        alphas,
    });
}
/// Applies the strided multi-column subtraction to two destinations at once,
/// where `dst1` is exactly one element shorter than `dst0`.
///
/// With `base(q) = (src_first_col + q) * col_stride + src_row_offset`, for
/// each `q` in `0..n_elim` in increasing order:
///
/// ```text
/// dst0[i] -= alphas0[q] * src_block[base(q) + i]        for i in 0..dst0.len()
/// dst1[j] -= alphas1[q] * src_block[base(q) + 1 + j]    for j in 0..dst1.len()
/// ```
///
/// Every update is a separate multiply and subtract. A term is skipped when
/// its own alpha compares equal to `0.0`; the two destinations are skipped
/// independently of each other.
///
/// `dst0[0]` is updated first with scalar arithmetic. The remaining
/// `dst0[1..]` and all of `dst1` then read the same source rows
/// (`base(q) + 1 ..`), so the SIMD kernel loads each source vector once and
/// uses it for both destinations.
///
/// # Panics
///
/// Panics if `dst1.len() + 1 != dst0.len()`, or if `alphas0` or `alphas1`
/// does not have `n_elim` elements. When `n_elim != 0`, also panics if the
/// furthest source index does not fit in `usize`, or if `src_block` is too
/// short to hold it.
#[allow(dead_code, clippy::too_many_arguments)]
pub fn schur_panel_minus_nofma_strided_dual(
    dst0: &mut [f64],
    dst1: &mut [f64],
    src_block: &[f64],
    src_first_col: usize,
    n_elim: usize,
    col_stride: usize,
    src_row_offset: usize,
    alphas0: &[f64],
    alphas1: &[f64],
) {
    let len0 = dst0.len();
    let len1 = dst1.len();
    assert_eq!(
        len1 + 1,
        len0,
        "schur_panel_minus_nofma_strided_dual: dst1 must be exactly one shorter than dst0 \
(len0={}, len1={})",
        len0,
        len1
    );
    assert_eq!(
        alphas0.len(),
        n_elim,
        "schur_panel_minus_nofma_strided_dual: alphas0.len() must equal n_elim"
    );
    assert_eq!(
        alphas1.len(),
        n_elim,
        "schur_panel_minus_nofma_strided_dual: alphas1.len() must equal n_elim"
    );
    if n_elim == 0 || len0 == 0 {
        return;
    }
    let last_q = n_elim - 1;
    // One past the last source element read. Checked arithmetic makes a
    // wrapped value fail loudly in release builds instead of letting the
    // bounds assertion below pass on a bogus extent. Because this is the
    // largest offset any column reaches, the unchecked per-column offsets
    // computed below and inside the kernel cannot overflow once this succeeds.
    let max_idx = src_first_col
        .checked_add(last_q)
        .and_then(|col| col.checked_mul(col_stride))
        .and_then(|off| off.checked_add(src_row_offset))
        .and_then(|off| off.checked_add(len0))
        .expect("schur_panel_minus_nofma_strided_dual: src_block extent overflows usize");
    assert!(
        src_block.len() >= max_idx,
        "schur_panel_minus_nofma_strided_dual: src_block too short ({} < {})",
        src_block.len(),
        max_idx
    );
    // Leading element of dst0: it has no counterpart in dst1, so it is
    // updated on its own with scalar arithmetic.
    for (q, &alpha_q) in alphas0.iter().enumerate() {
        if alpha_q == 0.0 {
            continue;
        }
        let col_off = (src_first_col + q) * col_stride + src_row_offset;
        let s = src_block[col_off];
        dst0[0] -= alpha_q * s;
    }
    if len1 == 0 {
        return;
    }
    /// Arguments bundled for the SIMD-generic kernel body. `dst0_bulk` and
    /// `dst1` have the same length `len`.
    struct K<'a> {
        dst0_bulk: &'a mut [f64],
        dst1: &'a mut [f64],
        src_block: &'a [f64],
        src_first_col: usize,
        n_elim: usize,
        col_stride: usize,
        src_row_offset_bulk: usize,
        len: usize,
        alphas0: &'a [f64],
        alphas1: &'a [f64],
    }
    impl pulp::WithSimd for K<'_> {
        type Output = ();
        #[allow(clippy::needless_range_loop)]
        #[inline(always)]
        fn with_simd<S: pulp::Simd>(self, simd: S) {
            let Self {
                dst0_bulk,
                dst1,
                src_block,
                src_first_col,
                n_elim,
                col_stride,
                src_row_offset_bulk,
                len,
                alphas0,
                alphas1,
            } = self;
            let (d0_body, d0_tail) = S::as_mut_simd_f64s(dst0_bulk);
            let (d1_body, d1_tail) = S::as_mut_simd_f64s(dst1);
            let body_len = d0_body.len();
            debug_assert_eq!(body_len, d1_body.len());
            // Scalar index at which the partial tail starts.
            let tail_off = body_len * S::F64_LANES;
            let chunks = body_len / 4;
            // Main loop: four vectors of each destination are kept in locals
            // while every column's contribution is subtracted from them.
            for chunk_idx in 0..chunks {
                let base = chunk_idx * 4;
                let mut a00 = d0_body[base];
                let mut a01 = d0_body[base + 1];
                let mut a02 = d0_body[base + 2];
                let mut a03 = d0_body[base + 3];
                let mut a10 = d1_body[base];
                let mut a11 = d1_body[base + 1];
                let mut a12 = d1_body[base + 2];
                let mut a13 = d1_body[base + 3];
                for q in 0..n_elim {
                    let a0q = alphas0[q];
                    let a1q = alphas1[q];
                    // Nothing to do for this column if both factors are zero.
                    if a0q == 0.0 && a1q == 0.0 {
                        continue;
                    }
                    let col_off = (src_first_col + q) * col_stride + src_row_offset_bulk;
                    let src_q = &src_block[col_off..col_off + len];
                    let (sb, _st) = S::as_simd_f64s(src_q);
                    // Source vectors shared by both destinations.
                    let s0v = sb[base];
                    let s1v = sb[base + 1];
                    let s2v = sb[base + 2];
                    let s3v = sb[base + 3];
                    if a0q != 0.0 {
                        let av0 = simd.splat_f64s(a0q);
                        a00 = simd.sub_f64s(a00, simd.mul_f64s(av0, s0v));
                        a01 = simd.sub_f64s(a01, simd.mul_f64s(av0, s1v));
                        a02 = simd.sub_f64s(a02, simd.mul_f64s(av0, s2v));
                        a03 = simd.sub_f64s(a03, simd.mul_f64s(av0, s3v));
                    }
                    if a1q != 0.0 {
                        let av1 = simd.splat_f64s(a1q);
                        a10 = simd.sub_f64s(a10, simd.mul_f64s(av1, s0v));
                        a11 = simd.sub_f64s(a11, simd.mul_f64s(av1, s1v));
                        a12 = simd.sub_f64s(a12, simd.mul_f64s(av1, s2v));
                        a13 = simd.sub_f64s(a13, simd.mul_f64s(av1, s3v));
                    }
                }
                d0_body[base] = a00;
                d0_body[base + 1] = a01;
                d0_body[base + 2] = a02;
                d0_body[base + 3] = a03;
                d1_body[base] = a10;
                d1_body[base + 1] = a11;
                d1_body[base + 2] = a12;
                d1_body[base + 3] = a13;
            }
            // Up to three whole vectors that did not form a group of four.
            let tail_chunks_start = chunks * 4;
            for body_idx in tail_chunks_start..body_len {
                let mut acc0 = d0_body[body_idx];
                let mut acc1 = d1_body[body_idx];
                for q in 0..n_elim {
                    let a0q = alphas0[q];
                    let a1q = alphas1[q];
                    if a0q == 0.0 && a1q == 0.0 {
                        continue;
                    }
                    let col_off = (src_first_col + q) * col_stride + src_row_offset_bulk;
                    let src_q = &src_block[col_off..col_off + len];
                    let (sb, _st) = S::as_simd_f64s(src_q);
                    let s = sb[body_idx];
                    if a0q != 0.0 {
                        let av0 = simd.splat_f64s(a0q);
                        acc0 = simd.sub_f64s(acc0, simd.mul_f64s(av0, s));
                    }
                    if a1q != 0.0 {
                        let av1 = simd.splat_f64s(a1q);
                        acc1 = simd.sub_f64s(acc1, simd.mul_f64s(av1, s));
                    }
                }
                d0_body[body_idx] = acc0;
                d1_body[body_idx] = acc1;
            }
            // Elements that do not fill a whole vector, handled with the
            // partial load/store helpers.
            if !d0_tail.is_empty() {
                let mut acc0 = simd.partial_load_f64s(d0_tail);
                let mut acc1 = simd.partial_load_f64s(d1_tail);
                for q in 0..n_elim {
                    let a0q = alphas0[q];
                    let a1q = alphas1[q];
                    if a0q == 0.0 && a1q == 0.0 {
                        continue;
                    }
                    let col_off = (src_first_col + q) * col_stride + src_row_offset_bulk;
                    let src_q = &src_block[col_off..col_off + len];
                    let src_q_tail = &src_q[tail_off..];
                    let s = simd.partial_load_f64s(src_q_tail);
                    if a0q != 0.0 {
                        let av0 = simd.splat_f64s(a0q);
                        acc0 = simd.sub_f64s(acc0, simd.mul_f64s(av0, s));
                    }
                    if a1q != 0.0 {
                        let av1 = simd.splat_f64s(a1q);
                        acc1 = simd.sub_f64s(acc1, simd.mul_f64s(av1, s));
                    }
                }
                simd.partial_store_f64s(d0_tail, acc0);
                simd.partial_store_f64s(d1_tail, acc1);
            }
        }
    }
    // Everything after the leading element lines up row-for-row with dst1.
    let dst0_bulk = &mut dst0[1..];
    dispatch_nofma(K {
        dst0_bulk,
        dst1,
        src_block,
        src_first_col,
        n_elim,
        col_stride,
        src_row_offset_bulk: src_row_offset + 1,
        len: len1,
        alphas0,
        alphas1,
    });
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Small xorshift64 generator so the tests get reproducible data.
    struct Xorshift64(u64);

    impl Xorshift64 {
        /// Creates a generator; a zero seed is replaced by a fixed non-zero
        /// constant.
        fn new(seed: u64) -> Self {
            let state = match seed {
                0 => 0x9E37_79B9_7F4A_7C15,
                nonzero => nonzero,
            };
            Self(state)
        }

        /// Advances the state with the 13 / 7 / 17 shift triple and returns it.
        fn next_u64(&mut self) -> u64 {
            self.0 ^= self.0 << 13;
            self.0 ^= self.0 >> 7;
            self.0 ^= self.0 << 17;
            self.0
        }

        /// Returns a value in `[-1, 1)` built from the top 52 bits of the
        /// next state.
        fn next_f64(&mut self) -> f64 {
            let mantissa = self.next_u64() >> 12;
            let unit = f64::from_bits(mantissa | 0x3FF0_0000_0000_0000) - 1.0;
            2.0 * unit - 1.0
        }
    }

    /// Draws `len` values from `rng`.
    fn random_vec(rng: &mut Xorshift64, len: usize) -> Vec<f64> {
        (0..len).map(|_| rng.next_f64()).collect()
    }

    /// Draws `len` values from `rng`, each scaled by 1.5.
    fn random_alphas(rng: &mut Xorshift64, len: usize) -> Vec<f64> {
        (0..len).map(|_| rng.next_f64() * 1.5).collect()
    }

    /// Scalar reference for the single-source kernels: multiply, then subtract.
    fn naive_axpy_minus(dst: &mut [f64], src: &[f64], alpha: f64) {
        for (i, d) in dst.iter_mut().enumerate() {
            let product = alpha * src[i];
            *d -= product;
        }
    }

    /// Scalar reference for the two-source kernels: both products are summed
    /// before the subtraction.
    fn naive_axpy2_minus(dst: &mut [f64], src0: &[f64], alpha0: f64, src1: &[f64], alpha1: f64) {
        for (i, d) in dst.iter_mut().enumerate() {
            let first = alpha0 * src0[i];
            let second = alpha1 * src1[i];
            *d -= first + second;
        }
    }

    const LENGTH_SWEEP: &[usize] = &[
        0, 1, 2, 3, 4, 5, 7, 8, 9, 15, 16, 17, 31, 32, 33, 63, 64, 65, 127, 128, 129, 255, 256,
        257, 511, 512, 513, 1023, 1024,
    ];

    /// Absolute tolerance used when comparing FMA kernels with the scalar
    /// reference.
    const ULP4: f64 = 4.0 * f64::EPSILON * 2.0;

    /// Asserts that `a` and `b` agree element-wise within `tol`.
    fn assert_close(a: &[f64], b: &[f64], tol: f64) {
        assert_eq!(a.len(), b.len(), "length mismatch in assert_close");
        for (i, (&x, &y)) in a.iter().zip(b).enumerate() {
            let diff = (x - y).abs();
            assert!(
                diff <= tol,
                "element {}: {} vs {}, diff {:.3e} > {:.3e}",
                i,
                x,
                y,
                diff,
                tol
            );
        }
    }

    /// Runs a single-source kernel against the scalar reference for every
    /// length in `LENGTH_SWEEP`; `check` receives (kernel result, reference, len).
    fn sweep_axpy(
        seed: u64,
        kernel: fn(&mut [f64], &[f64], f64),
        check: impl Fn(&[f64], &[f64], usize),
    ) {
        let mut rng = Xorshift64::new(seed);
        for &len in LENGTH_SWEEP {
            let src = random_vec(&mut rng, len);
            let dst_init = random_vec(&mut rng, len);
            let alpha = rng.next_f64() * 1.5;
            let mut dst_kernel = dst_init.clone();
            let mut dst_ref = dst_init;
            kernel(&mut dst_kernel, &src, alpha);
            naive_axpy_minus(&mut dst_ref, &src, alpha);
            check(&dst_kernel, &dst_ref, len);
        }
    }

    /// Two-source counterpart of `sweep_axpy`.
    fn sweep_axpy2(
        seed: u64,
        kernel: fn(&mut [f64], &[f64], f64, &[f64], f64),
        check: impl Fn(&[f64], &[f64], usize),
    ) {
        let mut rng = Xorshift64::new(seed);
        for &len in LENGTH_SWEEP {
            let src0 = random_vec(&mut rng, len);
            let src1 = random_vec(&mut rng, len);
            let dst_init = random_vec(&mut rng, len);
            let alpha0 = rng.next_f64() * 1.5;
            let alpha1 = rng.next_f64() * 1.5;
            let mut dst_kernel = dst_init.clone();
            let mut dst_ref = dst_init;
            kernel(&mut dst_kernel, &src0, alpha0, &src1, alpha1);
            naive_axpy2_minus(&mut dst_ref, &src0, alpha0, &src1, alpha1);
            check(&dst_kernel, &dst_ref, len);
        }
    }

    /// Reference for the panel kernel: one `axpy_minus_unroll4_nofma` call per
    /// non-zero alpha, reading column `q` at `q * col_stride`.
    fn rank1_reference(dst: &mut [f64], src_block: &[f64], col_stride: usize, alphas: &[f64]) {
        let len = dst.len();
        for (q, &alpha) in alphas.iter().enumerate() {
            if alpha == 0.0 {
                continue;
            }
            let col_off = q * col_stride;
            axpy_minus_unroll4_nofma(dst, &src_block[col_off..col_off + len], alpha);
        }
    }

    /// Calls the single-destination panel kernel starting at column 0 with
    /// `len` taken from `dst`.
    fn single_panel(
        dst: &mut [f64],
        src_block: &[f64],
        col_stride: usize,
        src_row_offset: usize,
        alphas: &[f64],
    ) {
        let len = dst.len();
        schur_panel_minus_nofma_strided(
            dst,
            src_block,
            0,
            alphas.len(),
            col_stride,
            src_row_offset,
            len,
            alphas,
        );
    }

    #[test]
    fn axpy_minus_zero_length() {
        let mut dst: Vec<f64> = Vec::new();
        let src: Vec<f64> = Vec::new();
        axpy_minus(&mut dst, &src, 1.5);
        assert!(dst.is_empty());
    }

    #[test]
    fn axpy_minus_length_one() {
        let src = vec![2.0];
        let mut dst = vec![5.0];
        axpy_minus(&mut dst, &src, 0.5);
        assert_eq!(dst[0], 4.0);
    }

    #[test]
    fn axpy_minus_matches_reference_across_length_sweep() {
        sweep_axpy(0xFE27_A100_0042_BEEFu64, axpy_minus, |got, want, _len| {
            assert_close(got, want, ULP4)
        });
    }

    #[test]
    #[should_panic(expected = "length mismatch")]
    fn axpy_minus_length_mismatch_panics() {
        let src = vec![0.0; 3];
        let mut dst = vec![0.0; 4];
        axpy_minus(&mut dst, &src, 1.0);
    }

    #[test]
    fn axpy2_minus_zero_length() {
        let mut dst: Vec<f64> = Vec::new();
        let src0: Vec<f64> = Vec::new();
        let src1: Vec<f64> = Vec::new();
        axpy2_minus(&mut dst, &src0, 1.0, &src1, 2.0);
        assert!(dst.is_empty());
    }

    #[test]
    fn axpy2_minus_length_one() {
        let src0 = vec![2.0];
        let src1 = vec![3.0];
        let mut dst = vec![10.0];
        axpy2_minus(&mut dst, &src0, 0.5, &src1, 1.0);
        assert_eq!(dst[0], 6.0);
    }

    #[test]
    fn axpy2_minus_matches_reference_across_length_sweep() {
        sweep_axpy2(0xC0FF_EE00_BAAD_F00Du64, axpy2_minus, |got, want, _len| {
            assert_close(got, want, ULP4)
        });
    }

    #[test]
    fn axpy_minus_alpha_zero_is_noop() {
        let src = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let dst_init = vec![-3.0, 0.5, 100.0, -7.25, 1e-10, 1e10, -0.0, 42.0];
        let mut dst = dst_init.clone();
        axpy_minus(&mut dst, &src, 0.0);
        assert_eq!(dst, dst_init);
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    fn axpy_minus_unroll4_matches_reference_across_length_sweep() {
        sweep_axpy(
            0x4E27_A101_00FE_BEEFu64,
            axpy_minus_unroll4,
            |got, want, _len| assert_close(got, want, ULP4),
        );
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    fn axpy2_minus_unroll4_matches_reference_across_length_sweep() {
        sweep_axpy2(
            0xC1FF_EE00_BAAD_F00Du64,
            axpy2_minus_unroll4,
            |got, want, _len| assert_close(got, want, ULP4),
        );
    }

    #[test]
    fn axpy2_minus_alphas_zero_is_noop() {
        let src0 = vec![1.0, 2.0, 3.0, 4.0];
        let src1 = vec![5.0, 6.0, 7.0, 8.0];
        let dst_init = vec![-1.0, 2.5, 3.0, -4.5];
        let mut dst = dst_init.clone();
        axpy2_minus(&mut dst, &src0, 0.0, &src1, 0.0);
        assert_eq!(dst, dst_init);
    }

    #[test]
    fn axpy_minus_unroll4_nofma_is_bit_exact_vs_scalar() {
        sweep_axpy(
            0xB17_EAC70_0042_F00Du64,
            axpy_minus_unroll4_nofma,
            |got, want, len| {
                assert_eq!(
                    got, want,
                    "non-FMA unroll4 must be bit-exact vs scalar at len={}",
                    len
                )
            },
        );
    }

    #[test]
    fn schur_panel_minus_nofma_strided_is_bit_exact_vs_rank1_reference() {
        let mut rng = Xorshift64::new(0x517E_3D5C_4242_5042);
        let n_elim_sweep = [1usize, 2, 4, 7, 8, 16, 31, 32];
        let len_sweep: &[usize] = &[1, 3, 7, 8, 9, 15, 16, 17, 31, 32, 33, 63, 64, 65, 256, 257];
        for &n_elim in &n_elim_sweep {
            for &len in len_sweep {
                let col_stride = len + 7 + n_elim;
                let src_block = random_vec(&mut rng, n_elim * col_stride);
                let dst_init = random_vec(&mut rng, len);
                let alphas = random_alphas(&mut rng, n_elim);

                let mut dst_ref = dst_init.clone();
                rank1_reference(&mut dst_ref, &src_block, col_stride, &alphas);

                let mut dst_kernel = dst_init;
                single_panel(&mut dst_kernel, &src_block, col_stride, 0, &alphas);

                assert_eq!(
                    dst_kernel, dst_ref,
                    "rank-{} accumulator must be bit-exact vs n_elim*rank-1 \
at len={}, col_stride={}",
                    n_elim, len, col_stride
                );
            }
        }
    }

    #[test]
    fn schur_panel_minus_nofma_strided_skips_zero_alphas() {
        let mut rng = Xorshift64::new(0xABCD_5050_0042u64);
        let n_elim = 4;
        let len = 17;
        let col_stride = len + 5;
        let src_block = random_vec(&mut rng, n_elim * col_stride);
        let dst_init = random_vec(&mut rng, len);
        let alphas = vec![0.5, 0.0, -0.25, 0.0];

        let mut dst_ref = dst_init.clone();
        rank1_reference(&mut dst_ref, &src_block, col_stride, &alphas);

        let mut dst_kernel = dst_init;
        single_panel(&mut dst_kernel, &src_block, col_stride, 0, &alphas);

        assert_eq!(dst_kernel, dst_ref);
    }

    #[test]
    fn schur_panel_minus_nofma_strided_dual_is_bit_exact_vs_two_singles() {
        let mut rng = Xorshift64::new(0xD5A1_C0F1_DEAD_BEEFu64);
        let n_elim_sweep = [1usize, 2, 4, 7, 8, 16, 31, 32];
        let len0_sweep: &[usize] = &[
            1, 2, 3, 4, 5, 7, 8, 9, 15, 16, 17, 31, 32, 33, 63, 64, 65, 257,
        ];
        for &n_elim in &n_elim_sweep {
            for &len0 in len0_sweep {
                let src_row_offset = 3usize;
                let col_stride = src_row_offset + len0 + 5 + n_elim;
                let src_block = random_vec(&mut rng, n_elim * col_stride);
                let dst0_init = random_vec(&mut rng, len0);
                let len1 = len0 - 1;
                let dst1_init = random_vec(&mut rng, len1);
                let alphas0 = random_alphas(&mut rng, n_elim);
                let alphas1 = random_alphas(&mut rng, n_elim);

                // Reference: two independent single-destination calls.
                let mut dst0_ref = dst0_init.clone();
                let mut dst1_ref = dst1_init.clone();
                single_panel(
                    &mut dst0_ref,
                    &src_block,
                    col_stride,
                    src_row_offset,
                    &alphas0,
                );
                if len1 > 0 {
                    single_panel(
                        &mut dst1_ref,
                        &src_block,
                        col_stride,
                        src_row_offset + 1,
                        &alphas1,
                    );
                }

                let mut dst0_kernel = dst0_init;
                let mut dst1_kernel = dst1_init;
                schur_panel_minus_nofma_strided_dual(
                    &mut dst0_kernel,
                    &mut dst1_kernel,
                    &src_block,
                    0,
                    n_elim,
                    col_stride,
                    src_row_offset,
                    &alphas0,
                    &alphas1,
                );

                assert_eq!(
                    dst0_kernel, dst0_ref,
                    "dst0 mismatch at n_elim={}, len0={}",
                    n_elim, len0
                );
                assert_eq!(
                    dst1_kernel, dst1_ref,
                    "dst1 mismatch at n_elim={}, len0={}",
                    n_elim, len0
                );
            }
        }
    }

    #[test]
    fn schur_panel_minus_nofma_strided_dual_skips_zero_alphas_independently() {
        let mut rng = Xorshift64::new(0xFEED_CAFE_0042_BABEu64);
        let n_elim = 5;
        let len0 = 33;
        let len1 = len0 - 1;
        let src_row_offset = 2usize;
        let col_stride = src_row_offset + len0 + 4 + n_elim;
        let src_block = random_vec(&mut rng, n_elim * col_stride);
        let dst0_init = random_vec(&mut rng, len0);
        let dst1_init = random_vec(&mut rng, len1);
        let alphas0 = vec![0.5, 0.0, 0.0, -0.25, 0.75];
        let alphas1 = vec![0.0, 0.4, 0.0, 0.6, 0.0];

        // Reference: two independent single-destination calls.
        let mut dst0_ref = dst0_init.clone();
        let mut dst1_ref = dst1_init.clone();
        single_panel(
            &mut dst0_ref,
            &src_block,
            col_stride,
            src_row_offset,
            &alphas0,
        );
        single_panel(
            &mut dst1_ref,
            &src_block,
            col_stride,
            src_row_offset + 1,
            &alphas1,
        );

        let mut dst0_kernel = dst0_init;
        let mut dst1_kernel = dst1_init;
        schur_panel_minus_nofma_strided_dual(
            &mut dst0_kernel,
            &mut dst1_kernel,
            &src_block,
            0,
            n_elim,
            col_stride,
            src_row_offset,
            &alphas0,
            &alphas1,
        );

        assert_eq!(dst0_kernel, dst0_ref);
        assert_eq!(dst1_kernel, dst1_ref);
    }

    #[test]
    fn axpy2_minus_unroll4_nofma_is_bit_exact_vs_scalar() {
        sweep_axpy2(
            0xB17_EAC70_BAAD_F00Du64,
            axpy2_minus_unroll4_nofma,
            |got, want, len| {
                assert_eq!(
                    got, want,
                    "non-FMA unroll4 must be bit-exact vs scalar at len={}",
                    len
                )
            },
        );
    }
}