use super::utils::{add_asm, mul_add_asm, mul_asm};
/// Replaces `state[0]` with its seventh power; all other elements are untouched.
///
/// The power is assembled as `x^3 * x^4` from four `mul_asm` calls. The tests in
/// this file compare the result only after canonical reduction, so the stored
/// value should not be assumed to be canonical.
///
/// # Panics
///
/// Panics if `state` is empty.
///
/// # Safety
///
/// NOTE(review): the exact preconditions are those of `mul_asm` in
/// `super::utils`, which is not visible from this file — confirm there.
#[inline(always)]
pub unsafe fn sbox_s0_asm(state: &mut [u64]) {
    unsafe {
        let base = state[0];
        let square = mul_asm(base, base);
        // x^7 = (x^2 * x) * (x^2 * x^2); the cube is evaluated before the fourth power.
        state[0] = mul_asm(mul_asm(square, base), mul_asm(square, square));
    }
}
/// Applies the seventh-power map to the first element of two states at once.
///
/// `state0[0]` and `state1[0]` are each replaced by their seventh power
/// (computed as `x^3 * x^4`); the two multiplication chains are independent and
/// are issued in interleaved order. No other element of either slice is touched.
///
/// # Panics
///
/// Panics if either slice is empty.
///
/// # Safety
///
/// NOTE(review): the exact preconditions are those of `mul_asm` in
/// `super::utils`, which is not visible from this file — confirm there.
#[inline(always)]
pub unsafe fn sbox_s0_dual_asm(state0: &mut [u64], state1: &mut [u64]) {
    unsafe {
        let (x, y) = (state0[0], state1[0]);
        let (x_sq, y_sq) = (mul_asm(x, x), mul_asm(y, y));
        let (x_cube, y_cube) = (mul_asm(x_sq, x), mul_asm(y_sq, y));
        let (x_quad, y_quad) = (mul_asm(x_sq, x_sq), mul_asm(y_sq, y_sq));
        state0[0] = mul_asm(x_cube, x_quad);
        state1[0] = mul_asm(y_cube, y_quad);
    }
}
/// Width-8 sparse matrix step, applied to `state` in place.
///
/// With `s` denoting the incoming state:
/// * `state[0]` becomes the dot product of `s` with `first_row`;
/// * `state[i]` for `i` in `1..8` becomes `s[0] * v[i - 1] + s[i]`.
///
/// Only `v[0..7]` is read; `v[7]` is ignored. The dot product is accumulated
/// from index 0 upward, one `mul_add_asm` per element.
///
/// # Safety
///
/// NOTE(review): the exact preconditions are those of `mul_asm` /
/// `mul_add_asm` in `super::utils`, which is not visible from this file.
#[inline(always)]
pub unsafe fn cheap_matmul_asm_w8(state: &mut [u64; 8], first_row: &[u64; 8], v: &[u64; 8]) {
    unsafe {
        // Keep the original first element: the lane updates below all depend on it.
        let pivot = state[0];

        let mut dot = mul_asm(pivot, first_row[0]);
        for (&lane, &weight) in state.iter().zip(first_row.iter()).skip(1) {
            dot = mul_add_asm(lane, weight, dot);
        }

        // `state[1..]` has seven lanes, so the zip stops before `v[7]`.
        for (lane, &coeff) in state[1..].iter_mut().zip(v.iter()) {
            *lane = mul_add_asm(pivot, coeff, *lane);
        }

        state[0] = dot;
    }
}
/// Width-12 sparse matrix step, applied to `state` in place.
///
/// With `s` denoting the incoming state:
/// * `state[0]` becomes the dot product of `s` with `first_row`;
/// * `state[i]` for `i` in `1..12` becomes `s[0] * v[i - 1] + s[i]`.
///
/// Only `v[0..11]` is read; `v[11]` is ignored. The dot product is accumulated
/// from index 0 upward, one `mul_add_asm` per element.
///
/// # Safety
///
/// NOTE(review): the exact preconditions are those of `mul_asm` /
/// `mul_add_asm` in `super::utils`, which is not visible from this file.
#[inline(always)]
pub unsafe fn cheap_matmul_asm_w12(state: &mut [u64; 12], first_row: &[u64; 12], v: &[u64; 12]) {
    unsafe {
        // Keep the original first element: the lane updates below all depend on it.
        let pivot = state[0];

        let mut dot = mul_asm(pivot, first_row[0]);
        for i in 1..12 {
            dot = mul_add_asm(state[i], first_row[i], dot);
        }

        // Lane `i` pairs with coefficient `v[i - 1]`.
        for i in 1..12 {
            state[i] = mul_add_asm(pivot, v[i - 1], state[i]);
        }

        state[0] = dot;
    }
}
/// Width-8 sparse matrix step applied to two states with shared coefficients.
///
/// Each of `s0` and `s1` is transformed exactly as the single-state width-8
/// kernel would transform it: element 0 becomes the dot product of the incoming
/// state with `first_row`, and element `i` (for `i` in `1..8`) becomes
/// `old[0] * v[i - 1] + old[i]`. The two independent chains are issued in
/// alternating order. `v[7]` is never read.
///
/// # Safety
///
/// NOTE(review): the exact preconditions are those of `mul_asm` /
/// `mul_add_asm` in `super::utils`, which is not visible from this file.
#[inline(always)]
pub unsafe fn cheap_matmul_dual_asm_w8(
    s0: &mut [u64; 8],
    s1: &mut [u64; 8],
    first_row: &[u64; 8],
    v: &[u64; 8],
) {
    unsafe {
        // Original first elements, needed after the lanes start being overwritten.
        let (pivot0, pivot1) = (s0[0], s1[0]);

        let mut dot0 = mul_asm(pivot0, first_row[0]);
        let mut dot1 = mul_asm(pivot1, first_row[0]);
        for i in 1..8 {
            let weight = first_row[i];
            dot0 = mul_add_asm(s0[i], weight, dot0);
            dot1 = mul_add_asm(s1[i], weight, dot1);
        }

        for i in 1..8 {
            let coeff = v[i - 1];
            s0[i] = mul_add_asm(pivot0, coeff, s0[i]);
            s1[i] = mul_add_asm(pivot1, coeff, s1[i]);
        }

        s0[0] = dot0;
        s1[0] = dot1;
    }
}
/// Width-12 sparse matrix step applied to two states with shared coefficients.
///
/// Each of `s0` and `s1` is transformed independently: element 0 becomes the
/// dot product of the incoming state with `first_row`, and element `i` (for `i`
/// in `1..12`) becomes `old[0] * v[i - 1] + old[i]`. The two chains are issued
/// in alternating order. `v[11]` is never read.
///
/// # Safety
///
/// NOTE(review): the exact preconditions are those of `mul_asm` /
/// `mul_add_asm` in `super::utils`, which is not visible from this file.
#[inline(always)]
pub unsafe fn cheap_matmul_dual_asm_w12(
    s0: &mut [u64; 12],
    s1: &mut [u64; 12],
    first_row: &[u64; 12],
    v: &[u64; 12],
) {
    unsafe {
        // Original first elements, needed after the lanes start being overwritten.
        let (pivot0, pivot1) = (s0[0], s1[0]);

        let mut dot0 = mul_asm(pivot0, first_row[0]);
        let mut dot1 = mul_asm(pivot1, first_row[0]);
        for ((&lane0, &lane1), &weight) in s0.iter().zip(s1.iter()).zip(first_row.iter()).skip(1) {
            dot0 = mul_add_asm(lane0, weight, dot0);
            dot1 = mul_add_asm(lane1, weight, dot1);
        }

        // The tails hold eleven lanes each, so the zip stops before `v[11]`.
        let tails = s0[1..].iter_mut().zip(s1[1..].iter_mut());
        for ((lane0, lane1), &coeff) in tails.zip(v.iter()) {
            *lane0 = mul_add_asm(pivot0, coeff, *lane0);
            *lane1 = mul_add_asm(pivot1, coeff, *lane1);
        }

        s0[0] = dot0;
        s1[0] = dot1;
    }
}
/// Multiplies `state` by the dense 8x8 matrix `m`, in place.
///
/// Output element `i` is the dot product of row `m[i]` with the incoming state,
/// accumulated from column 0 upward. The incoming state is copied first, so
/// every row sees the original values.
pub fn dense_matmul_asm_w8(state: &mut [u64; 8], m: &[[u64; 8]; 8]) {
    // SAFETY: NOTE(review): soundness rests on the contract of `mul_asm` /
    // `mul_add_asm` in `super::utils`, which is not visible from this file.
    unsafe {
        let snapshot = *state;
        for (out, row) in state.iter_mut().zip(m.iter()) {
            let mut dot = mul_asm(snapshot[0], row[0]);
            for (&x, &coeff) in snapshot.iter().zip(row.iter()).skip(1) {
                dot = mul_add_asm(x, coeff, dot);
            }
            *out = dot;
        }
    }
}
/// Multiplies `state` by the dense 12x12 matrix `m`, in place.
///
/// Output element `i` is the dot product of row `m[i]` with the incoming state,
/// accumulated from column 0 upward. The incoming state is copied first, so
/// every row sees the original values.
pub fn dense_matmul_asm_w12(state: &mut [u64; 12], m: &[[u64; 12]; 12]) {
    // SAFETY: NOTE(review): soundness rests on the contract of `mul_asm` /
    // `mul_add_asm` in `super::utils`, which is not visible from this file.
    unsafe {
        let snapshot = *state;
        for (i, row) in m.iter().enumerate() {
            let mut dot = mul_asm(snapshot[0], row[0]);
            for j in 1..12 {
                dot = mul_add_asm(snapshot[j], row[j], dot);
            }
            state[i] = dot;
        }
    }
}
/// Multiplies two width-8 states by the same dense 8x8 matrix `m`, in place.
///
/// For each row `m[i]`, element `i` of `s0` and of `s1` becomes the dot product
/// of that row with the respective incoming state. Both incoming states are
/// copied up front, and the two accumulation chains are issued in alternating
/// order.
pub fn dense_matmul_dual_asm_w8(s0: &mut [u64; 8], s1: &mut [u64; 8], m: &[[u64; 8]; 8]) {
    // SAFETY: NOTE(review): soundness rests on the contract of `mul_asm` /
    // `mul_add_asm` in `super::utils`, which is not visible from this file.
    unsafe {
        let (snap0, snap1) = (*s0, *s1);
        for ((out0, out1), row) in s0.iter_mut().zip(s1.iter_mut()).zip(m.iter()) {
            let mut dot0 = mul_asm(snap0[0], row[0]);
            let mut dot1 = mul_asm(snap1[0], row[0]);
            for ((&x0, &x1), &coeff) in snap0.iter().zip(snap1.iter()).zip(row.iter()).skip(1) {
                dot0 = mul_add_asm(x0, coeff, dot0);
                dot1 = mul_add_asm(x1, coeff, dot1);
            }
            *out0 = dot0;
            *out1 = dot1;
        }
    }
}
/// Multiplies two width-12 states by the same dense 12x12 matrix `m`, in place.
///
/// For each row `m[i]`, element `i` of `s0` and of `s1` becomes the dot product
/// of that row with the respective incoming state. Both incoming states are
/// copied up front, and the two accumulation chains are issued in alternating
/// order.
pub fn dense_matmul_dual_asm_w12(s0: &mut [u64; 12], s1: &mut [u64; 12], m: &[[u64; 12]; 12]) {
    // SAFETY: NOTE(review): soundness rests on the contract of `mul_asm` /
    // `mul_add_asm` in `super::utils`, which is not visible from this file.
    unsafe {
        let (snap0, snap1) = (*s0, *s1);
        for (i, row) in m.iter().enumerate() {
            let mut dot0 = mul_asm(snap0[0], row[0]);
            let mut dot1 = mul_asm(snap1[0], row[0]);
            for j in 1..12 {
                let coeff = row[j];
                dot0 = mul_add_asm(snap0[j], coeff, dot0);
                dot1 = mul_add_asm(snap1[j], coeff, dot1);
            }
            s0[i] = dot0;
            s1[i] = dot1;
        }
    }
}
/// Adds the constants `rc` element-wise into `state` using `add_asm`.
///
/// Element `i` of `state` becomes `add_asm(state[i], rc[i])` for every
/// `i < WIDTH`.
///
/// # Safety
///
/// NOTE(review): the exact preconditions are those of `add_asm` in
/// `super::utils`, which is not visible from this file — confirm there.
#[inline(always)]
pub unsafe fn add_rc_asm<const WIDTH: usize>(state: &mut [u64; WIDTH], rc: &[u64; WIDTH]) {
    unsafe {
        for (lane, &constant) in state.iter_mut().zip(rc.iter()) {
            *lane = add_asm(*lane, constant);
        }
    }
}
/// Adds the same constants `rc` element-wise into two states.
///
/// For every `i < WIDTH`, `s0[i]` becomes `add_asm(s0[i], rc[i])` and `s1[i]`
/// becomes `add_asm(s1[i], rc[i])`, with the two updates issued back to back.
///
/// # Safety
///
/// NOTE(review): the exact preconditions are those of `add_asm` in
/// `super::utils`, which is not visible from this file — confirm there.
#[inline(always)]
pub unsafe fn add_rc_dual_asm<const WIDTH: usize>(
    s0: &mut [u64; WIDTH],
    s1: &mut [u64; WIDTH],
    rc: &[u64; WIDTH],
) {
    unsafe {
        for ((lane0, lane1), &constant) in s0.iter_mut().zip(s1.iter_mut()).zip(rc.iter()) {
            *lane0 = add_asm(*lane0, constant);
            *lane1 = add_asm(*lane1, constant);
        }
    }
}
/// Adds the scalar `rc` to `state[0]` using `add_asm`; other elements are untouched.
///
/// # Panics
///
/// Panics if `state` is empty.
///
/// # Safety
///
/// NOTE(review): the exact preconditions are those of `add_asm` in
/// `super::utils`, which is not visible from this file — confirm there.
#[inline(always)]
pub unsafe fn add_scalar_s0_asm(state: &mut [u64], rc: u64) {
    unsafe {
        let updated = add_asm(state[0], rc);
        state[0] = updated;
    }
}
// Tests for the kernels in this file.
//
// Every kernel output is compared against a reference only after being passed
// through `canon`, so the tests pin the field value of each result rather than
// its exact `u64` bit pattern. Single-state kernels are checked against
// arithmetic on `Goldilocks` elements; the dual kernels are checked against two
// calls to the corresponding single-state kernel.
#[cfg(test)]
mod tests {
use p3_field::PrimeField64;
use proptest::prelude::*;
use rand::SeedableRng;
use rand::rngs::SmallRng;
use super::*;
use crate::aarch64_neon::danger_array;
use crate::{Goldilocks, P};
type F = Goldilocks;
// Maps a raw `u64` to the canonical representative of the field element that
// `F::new` builds from it.
fn canon(x: u64) -> u64 {
F::new(x).as_canonical_u64()
}
// Properties checked on inputs drawn from `any::<u64>()`.
proptest! {
// `sbox_s0_asm` must turn element 0 into x^3 * x^4 and leave elements 1..8
// bit-for-bit unchanged.
#[test]
fn test_sbox_s0_asm(vals in prop::array::uniform8(any::<u64>())) {
let x = F::new(vals[0]);
let x2 = x * x;
let x3 = x2 * x;
let x4 = x2 * x2;
let expected_s0 = (x3 * x4).as_canonical_u64();
let mut state = vals;
unsafe { sbox_s0_asm(&mut state); }
prop_assert_eq!(canon(state[0]), expected_s0);
for i in 1..8 {
prop_assert_eq!(state[i], vals[i]);
}
}
// The dual S-box must agree with two independent single-state calls and must
// not touch elements 1..8 of either state.
#[test]
fn test_sbox_s0_dual_asm(
vals0 in prop::array::uniform8(any::<u64>()),
vals1 in prop::array::uniform8(any::<u64>()),
) {
let mut ref0 = vals0;
let mut ref1 = vals1;
unsafe {
sbox_s0_asm(&mut ref0);
sbox_s0_asm(&mut ref1);
}
let mut s0 = vals0;
let mut s1 = vals1;
unsafe { sbox_s0_dual_asm(&mut s0, &mut s1); }
prop_assert_eq!(canon(s0[0]), canon(ref0[0]));
prop_assert_eq!(canon(s1[0]), canon(ref1[0]));
for i in 1..8 {
prop_assert_eq!(s0[i], vals0[i]);
prop_assert_eq!(s1[i], vals1[i]);
}
}
// `add_rc_asm` at width 8 must match element-wise field addition.
#[test]
fn test_add_rc_asm_w8(
vals in prop::array::uniform8(any::<u64>()),
rc in prop::array::uniform8(any::<u64>()),
) {
let expected: [u64; 8] = core::array::from_fn(|i| {
(F::new(vals[i]) + F::new(rc[i])).as_canonical_u64()
});
let mut state = vals;
unsafe { add_rc_asm(&mut state, &rc); }
for i in 0..8 {
prop_assert_eq!(canon(state[i]), expected[i]);
}
}
// `add_rc_asm` at width 12 must match element-wise field addition.
#[test]
fn test_add_rc_asm_w12(
vals in prop::array::uniform12(any::<u64>()),
rc in prop::array::uniform12(any::<u64>()),
) {
let expected: [u64; 12] = core::array::from_fn(|i| {
(F::new(vals[i]) + F::new(rc[i])).as_canonical_u64()
});
let mut state = vals;
unsafe { add_rc_asm(&mut state, &rc); }
for i in 0..12 {
prop_assert_eq!(canon(state[i]), expected[i]);
}
}
// The dual width-8 add must agree with two single-state `add_rc_asm` calls.
#[test]
fn test_add_rc_dual_asm_w8(
vals0 in prop::array::uniform8(any::<u64>()),
vals1 in prop::array::uniform8(any::<u64>()),
rc in prop::array::uniform8(any::<u64>()),
) {
let mut ref0 = vals0;
let mut ref1 = vals1;
unsafe {
add_rc_asm(&mut ref0, &rc);
add_rc_asm(&mut ref1, &rc);
}
let mut s0 = vals0;
let mut s1 = vals1;
unsafe { add_rc_dual_asm(&mut s0, &mut s1, &rc); }
for i in 0..8 {
prop_assert_eq!(canon(s0[i]), canon(ref0[i]));
prop_assert_eq!(canon(s1[i]), canon(ref1[i]));
}
}
// The dual width-12 add must agree with two single-state `add_rc_asm` calls.
#[test]
fn test_add_rc_dual_asm_w12(
vals0 in prop::array::uniform12(any::<u64>()),
vals1 in prop::array::uniform12(any::<u64>()),
rc in prop::array::uniform12(any::<u64>()),
) {
let mut ref0 = vals0;
let mut ref1 = vals1;
unsafe {
add_rc_asm(&mut ref0, &rc);
add_rc_asm(&mut ref1, &rc);
}
let mut s0 = vals0;
let mut s1 = vals1;
unsafe { add_rc_dual_asm(&mut s0, &mut s1, &rc); }
for i in 0..12 {
prop_assert_eq!(canon(s0[i]), canon(ref0[i]));
prop_assert_eq!(canon(s1[i]), canon(ref1[i]));
}
}
// `add_scalar_s0_asm` must add `rc` to element 0 and leave elements 1..8
// bit-for-bit unchanged.
#[test]
fn test_add_scalar_s0_asm(vals in prop::array::uniform8(any::<u64>()), rc: u64) {
let expected_s0 = (F::new(vals[0]) + F::new(rc)).as_canonical_u64();
let mut state = vals;
unsafe { add_scalar_s0_asm(&mut state, rc); }
prop_assert_eq!(canon(state[0]), expected_s0);
for i in 1..8 {
prop_assert_eq!(state[i], vals[i]);
}
}
// Width-8 sparse step: element 0 is the dot product with `first_row`, and
// element i (i >= 1) is f[i] + f[0] * v[i - 1].
#[test]
fn test_cheap_matmul_asm_w8(
vals in prop::array::uniform8(any::<u64>()),
first_row in prop::array::uniform8(any::<u64>()),
v in prop::array::uniform8(any::<u64>()),
) {
let f: [F; 8] = vals.map(F::new);
let fr: [F; 8] = first_row.map(F::new);
let fv: [F; 8] = v.map(F::new);
let old_s0 = f[0];
let new_s0: F = (0..8).map(|i| f[i] * fr[i]).sum();
let mut expected = f;
for i in 1..8 {
expected[i] = f[i] + old_s0 * fv[i - 1];
}
expected[0] = new_s0;
let mut state = vals;
unsafe { cheap_matmul_asm_w8(&mut state, &first_row, &v); }
for i in 0..8 {
prop_assert_eq!(canon(state[i]), expected[i].as_canonical_u64());
}
}
// Width-12 sparse step, same reference formula as the width-8 test.
#[test]
fn test_cheap_matmul_asm_w12(
vals in prop::array::uniform12(any::<u64>()),
first_row in prop::array::uniform12(any::<u64>()),
v in prop::array::uniform12(any::<u64>()),
) {
let f: [F; 12] = vals.map(F::new);
let fr: [F; 12] = first_row.map(F::new);
let fv: [F; 12] = v.map(F::new);
let old_s0 = f[0];
let new_s0: F = (0..12).map(|i| f[i] * fr[i]).sum();
let mut expected = f;
for i in 1..12 {
expected[i] = f[i] + old_s0 * fv[i - 1];
}
expected[0] = new_s0;
let mut state = vals;
unsafe { cheap_matmul_asm_w12(&mut state, &first_row, &v); }
for i in 0..12 {
prop_assert_eq!(canon(state[i]), expected[i].as_canonical_u64());
}
}
// The dual width-8 sparse step must agree with two single-state calls.
#[test]
fn test_cheap_matmul_dual_asm_w8(
vals0 in prop::array::uniform8(any::<u64>()),
vals1 in prop::array::uniform8(any::<u64>()),
first_row in prop::array::uniform8(any::<u64>()),
v in prop::array::uniform8(any::<u64>()),
) {
let mut ref0 = vals0;
let mut ref1 = vals1;
unsafe {
cheap_matmul_asm_w8(&mut ref0, &first_row, &v);
cheap_matmul_asm_w8(&mut ref1, &first_row, &v);
}
let mut s0 = vals0;
let mut s1 = vals1;
unsafe { cheap_matmul_dual_asm_w8(&mut s0, &mut s1, &first_row, &v); }
for i in 0..8 {
prop_assert_eq!(canon(s0[i]), canon(ref0[i]));
prop_assert_eq!(canon(s1[i]), canon(ref1[i]));
}
}
// The dual width-12 sparse step must agree with two single-state calls.
#[test]
fn test_cheap_matmul_dual_asm_w12(
vals0 in prop::array::uniform12(any::<u64>()),
vals1 in prop::array::uniform12(any::<u64>()),
first_row in prop::array::uniform12(any::<u64>()),
v in prop::array::uniform12(any::<u64>()),
) {
let mut ref0 = vals0;
let mut ref1 = vals1;
unsafe {
cheap_matmul_asm_w12(&mut ref0, &first_row, &v);
cheap_matmul_asm_w12(&mut ref1, &first_row, &v);
}
let mut s0 = vals0;
let mut s1 = vals1;
unsafe { cheap_matmul_dual_asm_w12(&mut s0, &mut s1, &first_row, &v); }
for i in 0..12 {
prop_assert_eq!(canon(s0[i]), canon(ref0[i]));
prop_assert_eq!(canon(s1[i]), canon(ref1[i]));
}
}
// Dense width-8 product against a row-by-row field dot product. The matrix
// comes from a fixed-seed RNG, so it is identical for every generated case.
#[test]
fn test_dense_matmul_asm_w8(vals in prop::array::uniform8(any::<u64>())) {
let mut rng = SmallRng::seed_from_u64(42);
let m: [[u64; 8]; 8] = rand::RngExt::random(&mut rng);
let f: [F; 8] = vals.map(F::new);
let expected: [F; 8] = core::array::from_fn(|i| {
(0..8).map(|j| f[j] * F::new(m[i][j])).sum()
});
let mut state = vals;
dense_matmul_asm_w8(&mut state, &m);
for i in 0..8 {
prop_assert_eq!(canon(state[i]), expected[i].as_canonical_u64());
}
}
// Dense width-12 product against a row-by-row field dot product, again with a
// fixed-seed matrix.
#[test]
fn test_dense_matmul_asm_w12(vals in prop::array::uniform12(any::<u64>())) {
let mut rng = SmallRng::seed_from_u64(43);
let m: [[u64; 12]; 12] = rand::RngExt::random(&mut rng);
let f: [F; 12] = vals.map(F::new);
let expected: [F; 12] = core::array::from_fn(|i| {
(0..12).map(|j| f[j] * F::new(m[i][j])).sum()
});
let mut state = vals;
dense_matmul_asm_w12(&mut state, &m);
for i in 0..12 {
prop_assert_eq!(canon(state[i]), expected[i].as_canonical_u64());
}
}
// The dual dense width-8 product must agree with two single-state calls.
#[test]
fn test_dense_matmul_dual_asm_w8(
vals0 in prop::array::uniform8(any::<u64>()),
vals1 in prop::array::uniform8(any::<u64>()),
) {
let mut rng = SmallRng::seed_from_u64(44);
let m: [[u64; 8]; 8] = rand::RngExt::random(&mut rng);
let mut ref0 = vals0;
let mut ref1 = vals1;
dense_matmul_asm_w8(&mut ref0, &m);
dense_matmul_asm_w8(&mut ref1, &m);
let mut s0 = vals0;
let mut s1 = vals1;
dense_matmul_dual_asm_w8(&mut s0, &mut s1, &m);
for i in 0..8 {
prop_assert_eq!(canon(s0[i]), canon(ref0[i]));
prop_assert_eq!(canon(s1[i]), canon(ref1[i]));
}
}
// The dual dense width-12 product must agree with two single-state calls.
#[test]
fn test_dense_matmul_dual_asm_w12(
vals0 in prop::array::uniform12(any::<u64>()),
vals1 in prop::array::uniform12(any::<u64>()),
) {
let mut rng = SmallRng::seed_from_u64(45);
let m: [[u64; 12]; 12] = rand::RngExt::random(&mut rng);
let mut ref0 = vals0;
let mut ref1 = vals1;
dense_matmul_asm_w12(&mut ref0, &m);
dense_matmul_asm_w12(&mut ref1, &m);
let mut s0 = vals0;
let mut s1 = vals1;
dense_matmul_dual_asm_w12(&mut s0, &mut s1, &m);
for i in 0..12 {
prop_assert_eq!(canon(s0[i]), canon(ref0[i]));
prop_assert_eq!(canon(s1[i]), canon(ref1[i]));
}
}
}
// The same properties, with inputs drawn from the `danger_array` strategy
// defined in `crate::aarch64_neon`.
// NOTE(review): presumably that strategy is biased towards boundary values;
// its definition is not visible from this file — confirm there.
proptest! {
// S-box on `danger_array` inputs; only element 0 is checked here.
#[test]
fn test_sbox_s0_asm_danger(vals in danger_array::<8>()) {
let x = F::new(vals[0]);
let x2 = x * x;
let x3 = x2 * x;
let x4 = x2 * x2;
let expected = x3 * x4;
let mut state = vals;
unsafe { sbox_s0_asm(&mut state); }
prop_assert_eq!(canon(state[0]), expected.as_canonical_u64());
}
// Width-8 constant addition on `danger_array` inputs.
#[test]
fn test_add_rc_w8_danger(
vals in danger_array::<8>(),
rc in danger_array::<8>(),
) {
let expected: [u64; 8] = core::array::from_fn(|i| {
(F::new(vals[i]) + F::new(rc[i])).as_canonical_u64()
});
let mut state = vals;
unsafe { add_rc_asm(&mut state, &rc); }
for i in 0..8 {
prop_assert_eq!(canon(state[i]), expected[i]);
}
}
// Width-8 sparse step on `danger_array` inputs.
#[test]
fn test_cheap_matmul_w8_danger(
vals in danger_array::<8>(),
first_row in danger_array::<8>(),
v in danger_array::<8>(),
) {
let f: [F; 8] = vals.map(F::new);
let fr: [F; 8] = first_row.map(F::new);
let fv: [F; 8] = v.map(F::new);
let old = f[0];
let new0: F = (0..8).map(|i| f[i] * fr[i]).sum();
let mut expected = f;
for i in 1..8 {
expected[i] = f[i] + old * fv[i - 1];
}
expected[0] = new0;
let mut state = vals;
unsafe { cheap_matmul_asm_w8(&mut state, &first_row, &v); }
for i in 0..8 {
prop_assert_eq!(canon(state[i]), expected[i].as_canonical_u64());
}
}
// Width-12 sparse step on `danger_array` inputs.
#[test]
fn test_cheap_matmul_w12_danger(
vals in danger_array::<12>(),
first_row in danger_array::<12>(),
v in danger_array::<12>(),
) {
let f: [F; 12] = vals.map(F::new);
let fr: [F; 12] = first_row.map(F::new);
let fv: [F; 12] = v.map(F::new);
let old = f[0];
let new0: F = (0..12).map(|i| f[i] * fr[i]).sum();
let mut expected = f;
for i in 1..12 {
expected[i] = f[i] + old * fv[i - 1];
}
expected[0] = new0;
let mut state = vals;
unsafe { cheap_matmul_asm_w12(&mut state, &first_row, &v); }
for i in 0..12 {
prop_assert_eq!(canon(state[i]), expected[i].as_canonical_u64());
}
}
// Dense width-8 product on `danger_array` states with a fixed-seed matrix.
#[test]
fn test_dense_matmul_w8_danger(vals in danger_array::<8>()) {
let mut rng = SmallRng::seed_from_u64(101);
let m: [[u64; 8]; 8] = rand::RngExt::random(&mut rng);
let f: [F; 8] = vals.map(F::new);
let expected: [F; 8] = core::array::from_fn(|i| {
(0..8).map(|j| f[j] * F::new(m[i][j])).sum()
});
let mut state = vals;
dense_matmul_asm_w8(&mut state, &m);
for i in 0..8 {
prop_assert_eq!(canon(state[i]), expected[i].as_canonical_u64());
}
}
}
// Five fixed width-8 states: all `P - 1`, all `u64::MAX`, the two alternating,
// `P + i` per element, and all zero.
fn adversarial_states_w8() -> [[u64; 8]; 5] {
[
[P - 1; 8],
[u64::MAX; 8],
core::array::from_fn(|i| if i % 2 == 0 { P - 1 } else { u64::MAX }),
core::array::from_fn(|i| P + (i as u64)),
[0; 8],
]
}
// Width-12 counterpart of `adversarial_states_w8`, built from the same five
// patterns.
fn adversarial_states_w12() -> [[u64; 12]; 5] {
[
[P - 1; 12],
[u64::MAX; 12],
core::array::from_fn(|i| if i % 2 == 0 { P - 1 } else { u64::MAX }),
core::array::from_fn(|i| P + (i as u64)),
[0; 12],
]
}
// Width-8 sparse step on every fixed state, with `first_row` all `P - 1` and
// `v` all `u64::MAX`.
#[test]
fn test_cheap_matmul_w8_stress() {
let first_row = [P - 1; 8];
let v = [u64::MAX; 8];
for state in adversarial_states_w8() {
let f: [F; 8] = state.map(F::new);
let fr: [F; 8] = first_row.map(F::new);
let fv: [F; 8] = v.map(F::new);
let old = f[0];
let new0: F = (0..8).map(|i| f[i] * fr[i]).sum();
let mut expected = f;
for i in 1..8 {
expected[i] = f[i] + old * fv[i - 1];
}
expected[0] = new0;
let mut got = state;
unsafe {
cheap_matmul_asm_w8(&mut got, &first_row, &v);
}
for i in 0..8 {
assert_eq!(canon(got[i]), expected[i].as_canonical_u64(), "i={i}");
}
}
}
// Width-12 sparse step on every fixed state, with `first_row` all `P - 1` and
// `v` all `u64::MAX`.
#[test]
fn test_cheap_matmul_w12_stress() {
let first_row = [P - 1; 12];
let v = [u64::MAX; 12];
for state in adversarial_states_w12() {
let f: [F; 12] = state.map(F::new);
let fr: [F; 12] = first_row.map(F::new);
let fv: [F; 12] = v.map(F::new);
let old = f[0];
let new0: F = (0..12).map(|i| f[i] * fr[i]).sum();
let mut expected = f;
for i in 1..12 {
expected[i] = f[i] + old * fv[i - 1];
}
expected[0] = new0;
let mut got = state;
unsafe {
cheap_matmul_asm_w12(&mut got, &first_row, &v);
}
for i in 0..12 {
assert_eq!(canon(got[i]), expected[i].as_canonical_u64(), "i={i}");
}
}
}
// Width-8 constant addition on every fixed state, with every constant set to
// `u64::MAX`.
#[test]
fn test_add_rc_w8_stress() {
let rc = [u64::MAX; 8];
for state in adversarial_states_w8() {
let expected: [u64; 8] =
core::array::from_fn(|i| (F::new(state[i]) + F::new(rc[i])).as_canonical_u64());
let mut got = state;
unsafe {
add_rc_asm(&mut got, &rc);
}
for i in 0..8 {
assert_eq!(canon(got[i]), expected[i], "i={i}");
}
}
}
}