use core::arch::aarch64::*;
use core::arch::asm;
use core::mem::transmute;

use static_assertions::const_assert;
use unroll::unroll_for_loops;

use crate::field::goldilocks_field::GoldilocksField;
use crate::hash::poseidon::{Poseidon, ALL_ROUND_CONSTANTS, HALF_N_FULL_ROUNDS, N_PARTIAL_ROUNDS};
use crate::util::branch_hint;
// ========================================== CONSTANTS ===========================================
const WIDTH: usize = 12;
const EPSILON: u64 = 0xffffffff; // 2**64 - ORDER = 2**32 - 1.
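
// Constants for the vectorized MDS multiplication, loaded as `mds_consts0` and `mds_consts1` in
// `partial_round` below. The values are reconstructed from the register comments there: EPSILON,
// followed by the powers of two appearing in the MDS matrix.
const MDS_CONSTS: [u32; 8] = [
    0xffffffff, // EPSILON
    1 << 1,
    1 << 3,
    1 << 5,
    1 << 8,
    1 << 10,
    1 << 12,
    1 << 16,
];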
// The round constants to be applied by the second set of full rounds. These are just the usual
// round constants, shifted by one round, with zeros shifted in.
const fn make_final_round_constants() -> [u64; WIDTH * HALF_N_FULL_ROUNDS] {
let mut res = [0; WIDTH * HALF_N_FULL_ROUNDS];
let mut i: usize = 0;
while i < WIDTH * (HALF_N_FULL_ROUNDS - 1) {
res[i] = ALL_ROUND_CONSTANTS[i + WIDTH * (HALF_N_FULL_ROUNDS + N_PARTIAL_ROUNDS + 1)];
i += 1;
}
res
}
const FINAL_ROUND_CONSTANTS: [u64; WIDTH * HALF_N_FULL_ROUNDS] = make_final_round_constants();
// ===================================== COMPILE-TIME CHECKS ======================================
/// The MDS matrix multiplication ASM is specific to the MDS matrix below. We want this file to
/// fail to compile if it has been changed.
#[allow(dead_code)]
const fn check_mds_matrix() -> bool {
    // Can't compare arrays with == in a const_assert!, so check element by element.
    let wanted_matrix_exps = [0, 0, 1, 0, 3, 5, 1, 8, 12, 3, 16, 10];
    let mut i = 0;
    while i < WIDTH {
        if <GoldilocksField as Poseidon>::MDS_MATRIX_EXPS[i] != wanted_matrix_exps[i] {
            return false;
        }
        i += 1;
    }
    true
}
const_assert!(check_mds_matrix());
/// Ensure that the first WIDTH round constants are in canonical* form. This is required because
/// the first constant layer does not handle double overflow.
/// *: round_const == GoldilocksField::ORDER is safe.
#[allow(dead_code)]
const fn check_round_const_bounds_init() -> bool {
let mut i = 0;
while i < WIDTH {
if ALL_ROUND_CONSTANTS[i] > GoldilocksField::ORDER {
return false;
}
i += 1;
}
true
}
const_assert!(check_round_const_bounds_init());
// ====================================== SCALAR ARITHMETIC =======================================
/// Addition modulo ORDER accounting for wraparound. Correct only when a + b < 2**64 + ORDER.
#[inline(always)]
unsafe fn add_with_wraparound(a: u64, b: u64) -> u64 {
    let res: u64;
    let adj: u64;
    asm!(
        "adds  {res}, {a}, {b}",
        // Set adj to 0xffffffff if the addition overflowed and to 0 otherwise ('cs' = carry set).
        "csetm {adj:w}, cs",
        a = in(reg) a,
        b = in(reg) b,
        res = lateout(reg) res,
        adj = lateout(reg) adj,
        options(pure, nomem, nostack),
    );
    res + adj // adj is EPSILON if wraparound occurred and 0 otherwise.
}
/// Subtraction of (b >> 32) from a modulo ORDER, accounting for wraparound.
#[inline(always)]
unsafe fn sub_with_wraparound_lsr32(a: u64, b: u64) -> u64 {
    let mut b_hi = b >> 32;
    // Make sure that LLVM emits two separate instructions for the shift and the subtraction. This
    // reduces pressure on the execution units with access to the flags, as they are no longer
    // responsible for the shift. The empty asm block makes LLVM think the two cannot be merged.
    asm!(
        "/* {0} */", // Makes Rust think the register is used.
        inlateout(reg) b_hi,
        options(nomem, nostack, preserves_flags, pure),
    );
    match a.checked_sub(b_hi) {
        Some(res) => res,
        None => {
            // Super rare. Better off branching.
            branch_hint();
            // b_hi < 2**32, so res_wrapped > 2**64 - 2**32 > EPSILON: no second wraparound.
            let res_wrapped = a.wrapping_sub(b_hi);
            res_wrapped - EPSILON
        }
    }
}
/// Multiplication of the low word (i.e., x as u32) by EPSILON.
#[inline(always)]
unsafe fn mul_epsilon(x: u64) -> u64 {
    let res;
    asm!(
        // UMULL multiplies the two low 32-bit words into a 64-bit result, saving the compiler's
        // separate extract-then-multiply sequence.
        "umull {res}, {x:w}, {epsilon:w}",
        x = in(reg) x,
        epsilon = in(reg) EPSILON,
        res = lateout(reg) res,
        options(pure, nomem, nostack, preserves_flags),
    );
    res
}
/// Multiplication modulo ORDER. The result is not necessarily canonical.
#[inline(always)]
unsafe fn multiply(x: u64, y: u64) -> u64 {
    let xy = (x as u128) * (y as u128);
    let xy_lo = xy as u64;
    let xy_hi = (xy >> 64) as u64;

    // 2**64 == EPSILON (mod ORDER), so xy == xy_lo - (xy_hi >> 32) + (xy_hi as u32) * EPSILON.
    let res0 = sub_with_wraparound_lsr32(xy_lo, xy_hi);
    let res1 = mul_epsilon(xy_hi);

    // add_with_wraparound is safe: res1 <= 0xfffffffe00000001 < ORDER.
    add_with_wraparound(res0, res1)
}
// ==================================== STANDALONE CONST LAYER =====================================
/// Standalone const layer. Run only once, at the start of round 1. Remaining const layers are fused
/// with the preceding MDS matrix multiplication.
#[inline(always)]
#[unroll_for_loops]
unsafe fn const_layer_full(
mut state: [u64; WIDTH],
round_constants: &[u64; WIDTH],
) -> [u64; WIDTH] {
assert!(WIDTH == 12);
for i in 0..12 {
let rc = round_constants[i];
// add_with_wraparound is safe, because rc is in canonical form.
state[i] = add_with_wraparound(state[i], rc);
}
state
}
// ========================================== FULL ROUNDS ==========================================
/// Full S-box: raise every element of the state to the 7th power.
#[inline(always)]
#[unroll_for_loops]
unsafe fn sbox_layer_full(state: [u64; WIDTH]) -> [u64; WIDTH] {
    // x**7 via the addition chain x2 = x*x, x3 = x*x2, x4 = x2*x2, x7 = x3*x4.
    assert!(WIDTH == 12);
    let mut state2 = [0u64; WIDTH];
    for i in 0..12 {
        state2[i] = multiply(state[i], state[i]);
    }
    let mut state3 = [0u64; WIDTH];
    let mut state4 = [0u64; WIDTH];
    for i in 0..12 {
        state3[i] = multiply(state[i], state2[i]);
        state4[i] = multiply(state2[i], state2[i]);
    }
    let mut state7 = [0u64; WIDTH];
    for i in 0..12 {
        state7[i] = multiply(state3[i], state4[i]);
    }
    state7
}
/// MDS matrix multiplication, fused with the constant layer of the following round. The matrix is
/// circulant with first-row entries 2**MDS_MATRIX_EXPS[i], as verified by check_mds_matrix above.
/// This scalar form stands in for the vectorized implementation; it computes the same function.
#[inline(always)]
unsafe fn mds_layer_full(state: [u64; WIDTH], round_constants: &[u64; WIDTH]) -> [u64; WIDTH] {
    let mut res = [0u64; WIDTH];
    for r in 0..WIDTH {
        // Each term is below 2**80 (a u64 shifted left by at most 16), so twelve terms plus a
        // round constant fit comfortably in a u128.
        let mut acc = round_constants[r] as u128;
        for i in 0..WIDTH {
            acc += (state[(i + r) % WIDTH] as u128)
                << <GoldilocksField as Poseidon>::MDS_MATRIX_EXPS[i];
        }
        // Reduce using 2**64 == EPSILON (mod ORDER); acc >> 64 < 2**32, so mul_epsilon applies.
        res[r] = add_with_wraparound(acc as u64, mul_epsilon((acc >> 64) as u64));
    }
    res
}
// ======================================== PARTIAL ROUNDS =========================================
#[rustfmt::skip]
macro_rules! mds_reduce_asm {
($c0:literal, $c1:literal, $out:literal, $consts:literal) => {
concat!(
// Swizzle
"zip1.2d ", $out, ",", $c0, ",", $c1, "\n", // lo
"zip2.2d ", $c0, ",", $c0, ",", $c1, "\n", // hi
// Reduction from u96
"usra.2d ", $c0, ",", $out, ", #32\n", "sli.2d ", $out, ",", $c0, ", #32\n",
// Extract high 32-bits.
"uzp2.4s ", $c0, ",", $c0, ",", $c0, "\n",
// Multiply by EPSILON and accumulate.
"mov.16b ", $c1, ",", $out, "\n",
"umlal.2d ", $out, ",", $c0, ", ", $consts, "[0]\n",
"cmhi.2d ", $c1, ",", $c1, ",", $out, "\n",
"usra.2d ", $out, ",", $c1, ", #32",
)
};
}
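
// For reference, the per-lane reduction performed by mds_reduce_asm can be written in scalar Rust
// as below. A lane accumulated as lo + (hi << 32) (a u96) is folded into a u64 using
// 2**64 == EPSILON (mod ORDER). This helper is illustrative only and is not called anywhere.
#[allow(dead_code)]
fn mds_reduce_scalar(lo: u64, hi: u64) -> u64 {
    let val = (lo as u128) + ((hi as u128) << 32);
    let val_lo = val as u64; // Low 64 bits: the zip/usra/sli sequence in the asm.
    let val_hi = (val >> 64) as u64; // High 32 bits: extracted by uzp2 in the asm.
    // umlal accumulates val_hi * EPSILON; cmhi + usra patch up the wraparound, as in
    // add_with_wraparound.
    let (res, over) = val_lo.overflowing_add(val_hi * EPSILON);
    res + EPSILON * (over as u64)
}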
#[inline(always)]
unsafe fn partial_round(
(state_scalar, state_vector): ([u64; WIDTH], [uint64x2_t; 5]),
round_constants: &[u64; WIDTH],
) -> ([u64; WIDTH], [uint64x2_t; 5]) {
// see readme-asm.md
// mds_consts0 == [0xffffffff, 1 << 1, 1 << 3, 1 << 5]
// mds_consts1 == [1 << 8, 1 << 10, 1 << 12, 1 << 16]
let mds_consts0: uint32x4_t = vld1q_u32((&MDS_CONSTS[0..4]).as_ptr().cast::<u32>());
let mds_consts1: uint32x4_t = vld1q_u32((&MDS_CONSTS[4..8]).as_ptr().cast::<u32>());
let res0: u64;
let res1: u64;
let res23: uint64x2_t;
let res45: uint64x2_t;
let res67: uint64x2_t;
let res89: uint64x2_t;
let res1011: uint64x2_t;
let res2_scalar: u64;
let res3_scalar: u64;
let res4_scalar: u64;
let res5_scalar: u64;
let res6_scalar: u64;
let res7_scalar: u64;
let res8_scalar: u64;
let res9_scalar: u64;
let res10_scalar: u64;
let res11_scalar: u64;
asm!(
"ldp d0, d1, [{rc_ptr}, #16]",
"fmov d21, {s1}",
"ldp {lo0}, {lo1}, [{rc_ptr}]",
"umulh {t0}, {s0}, {s0}",
"mul {t1}, {s0}, {s0}",
"subs {t1}, {t1}, {t0}, lsr #32",
"csetm {t2:w}, cc",
"lsl {t3}, {t0}, #32",
"sub {t1}, {t1}, {t2}",
"mov {t0:w}, {t0:w}",
"sub {t0}, {t3}, {t0}",
"adds {t0}, {t1}, {t0}",
"csetm {t1:w}, cs",
"add {t0}, {t0}, {t1}",
"umulh {t1}, {s0}, {t0}",
"umulh {t2}, {t0}, {t0}",
"mul {s0}, {s0}, {t0}",
"mul {t0}, {t0}, {t0}",
"subs {s0}, {s0}, {t1}, lsr #32",
"csetm {t3:w}, cc",
"subs {t0}, {t0}, {t2}, lsr #32",
"csetm {t4:w}, cc",
"lsl {t5}, {t1}, #32",
"lsl {t6}, {t2}, #32",
"sub {s0}, {s0}, {t3}",
"sub {t0}, {t0}, {t4}",
"mov {t1:w}, {t1:w}",
"mov {t2:w}, {t2:w}",
"sub {t1}, {t5}, {t1}",
"ushll.2d v10, v21, #10",
"sub {t2}, {t6}, {t2}",
"ushll.2d v11, v21, #16",
"adds {t1}, {s0}, {t1}",
"uaddw.2d v0, v0, v22",
"csetm {s0:w}, cs",
"umlal.2d v1, v22, v31[1]",
"adds {t2}, {t0}, {t2}",
"uaddw2.2d v10, v10, v22",
"csetm {t0:w}, cs",
"uaddw2.2d v11, v11, v22",
"add {t1}, {t1}, {s0}",
"ldp d2, d3, [{rc_ptr}, #32]",
"add {t2}, {t2}, {t0}",
"ushll.2d v12, v21, #3",
"umulh {s0}, {t1}, {t2}",
"ushll.2d v13, v21, #12",
"mul {t0}, {t1}, {t2}",
"umlal.2d v0, v23, v30[1]",
"add {lo1}, {lo1}, {s1:w}, uxtw",
"uaddw2.2d v10, v10, v23",
"add {lo0}, {lo0}, {s1:w}, uxtw",
"uaddw.2d v11, v11, v23",
"lsr {hi0}, {s1}, #32",
"umlal2.2d v1, v23, v30[1]",
"lsr {t3}, {s2}, #32",
"umlal.2d v2, v22, v31[3]",
"lsr {t4}, {s3}, #32",
"umlal2.2d v12, v22, v31[1]",
"add {hi1}, {hi0}, {t3}",
"umlal.2d v3, v22, v30[2]",
"add {hi0}, {hi0}, {t3}, lsl #1",
"umlal2.2d v13, v22, v31[3]",
"add {lo1}, {lo1}, {s2:w}, uxtw",
"ldp d4, d5, [{rc_ptr}, #48]",
"add {lo0}, {lo0}, {s2:w}, uxtw #1",
"ushll.2d v14, v21, #8",
"lsr {t3}, {s4}, #32",
"ushll.2d v15, v21, #1",
"lsr {t5}, {s5}, #32",
"umlal.2d v0, v24, v30[2]",
"subs {t0}, {t0}, {s0}, lsr #32",
"umlal2.2d v10, v24, v30[3]",
"add {hi1}, {hi1}, {t4}, lsl #1",
"umlal2.2d v11, v24, v30[2]",
"add {t6}, {t3}, {t5}, lsl #3",
"uaddw.2d v1, v1, v24",
"add {t5}, {t3}, {t5}, lsl #2",
"uaddw.2d v2, v2, v23",
"lsr {t3}, {s6}, #32",
"umlal.2d v3, v23, v31[1]",
"lsr {s1}, {s7}, #32",
"uaddw2.2d v12, v12, v23",
"mov {s2:w}, {s4:w}",
"uaddw2.2d v13, v13, v23",
"add {hi0}, {hi0}, {t4}",
"umlal.2d v4, v22, v31[2]",
"add {lo1}, {lo1}, {s3:w}, uxtw #1",
"umlal2.2d v14, v22, v30[2]",
"add {lo0}, {lo0}, {s3:w}, uxtw",
"umlal.2d v5, v22, v31[0]",
"add {t4}, {s2}, {s5:w}, uxtw #3",
"umlal2.2d v15, v22, v31[2]",
"add {s2}, {s2}, {s5:w}, uxtw #2",
"ldp d6, d7, [{rc_ptr}, #64]",
"add {s3}, {s1}, {t3}, lsl #4",
"ushll.2d v16, v21, #5",
"csetm {t1:w}, cc",
"ushll.2d v17, v21, #3",
"add {hi1}, {hi1}, {t6}",
"umlal.2d v0, v25, v30[1]",
"add {hi0}, {hi0}, {t5}, lsl #3",
"umlal2.2d v10, v25, v31[0]",
"mov {t5:w}, {s6:w}",
"umlal.2d v1, v25, v30[3]",
"mov {t6:w}, {s7:w}",
"umlal2.2d v11, v25, v30[1]",
"add {s4}, {t6}, {t5}, lsl #4",
"umlal.2d v2, v24, v30[1]",
"add {t3}, {t3}, {s1}, lsl #7",
"uaddw2.2d v12, v12, v24",
"lsr {s1}, {s8}, #32",
"uaddw.2d v13, v13, v24",
"lsr {s5}, {s9}, #32",
"umlal2.2d v3, v24, v30[1]",
"lsl {t2}, {s0}, #32",
"umlal.2d v4, v23, v31[3]",
"sub {t0}, {t0}, {t1}",
"umlal2.2d v14, v23, v31[1]",
"add {lo1}, {lo1}, {t4}",
"umlal.2d v5, v23, v30[2]",
"add {lo0}, {lo0}, {s2}, lsl #3",
"umlal2.2d v15, v23, v31[3]",
"add {t4}, {t5}, {t6}, lsl #7",
"umlal.2d v6, v22, v30[1]",
"add {hi1}, {hi1}, {s3}, lsl #1",
"umlal2.2d v16, v22, v31[0]",
"add {t5}, {s1}, {s5}, lsl #4",
"umlal.2d v7, v22, v30[3]",
"mov {s0:w}, {s0:w}",
"umlal2.2d v17, v22, v30[1]",
"sub {s0}, {t2}, {s0}",
"ldp d8, d9, [{rc_ptr}, #80]",
"add {lo1}, {lo1}, {s4}, lsl #1",
"ushll.2d v18, v21, #0",
"add {hi0}, {hi0}, {t3}, lsl #1",
"ushll.2d v19, v21, #1",
"mov {t3:w}, {s9:w}",
"umlal.2d v0, v26, v31[2]",
"mov {t6:w}, {s8:w}",
"umlal2.2d v10, v26, v30[2]",
"add {s2}, {t6}, {t3}, lsl #4",
"umlal.2d v1, v26, v31[0]",
"add {s1}, {s5}, {s1}, lsl #9",
"umlal2.2d v11, v26, v31[2]",
"lsr {s3}, {s10}, #32",
"umlal.2d v2, v25, v30[2]",
"lsr {s4}, {s11}, #32",
"umlal2.2d v12, v25, v30[3]",
"adds {s0}, {t0}, {s0}",
"umlal2.2d v13, v25, v30[2]",
"add {lo0}, {lo0}, {t4}, lsl #1",
"uaddw.2d v3, v3, v25",
"add {t3}, {t3}, {t6}, lsl #9",
"uaddw.2d v4, v4, v24",
"add {hi1}, {hi1}, {t5}, lsl #8",
"umlal.2d v5, v24, v31[1]",
"add {t4}, {s3}, {s4}, lsl #13",
"uaddw2.2d v14, v14, v24",
"csetm {t0:w}, cs",
"uaddw2.2d v15, v15, v24",
"add {lo1}, {lo1}, {s2}, lsl #8",
"umlal.2d v6, v23, v31[2]",
"add {hi0}, {hi0}, {s1}, lsl #3",
"umlal2.2d v16, v23, v30[2]",
"mov {t5:w}, {s10:w}",
"umlal.2d v7, v23, v31[0]",
"mov {t6:w}, {s11:w}",
"umlal2.2d v17, v23, v31[2]",
"add {s1}, {t5}, {t6}, lsl #13",
"umlal.2d v8, v22, v30[2]",
"add {s2}, {s4}, {s3}, lsl #6",
"umlal2.2d v18, v22, v30[3]",
"add {s0}, {s0}, {t0}",
"uaddw.2d v9, v9, v22",
"add {lo0}, {lo0}, {t3}, lsl #3",
"umlal2.2d v19, v22, v30[2]",
"add {t3}, {t6}, {t5}, lsl #6",
"add.2d v0, v0, v10",
"add {hi1}, {hi1}, {t4}, lsl #3",
"add.2d v1, v1, v11",
"fmov d20, {s0}",
"umlal.2d v0, v20, v31[3]",
"add {lo1}, {lo1}, {s1}, lsl #3",
"umlal.2d v1, v20, v30[2]",
"add {hi0}, {hi0}, {s2}, lsl #10",
"zip1.2d v22, v0, v1",
"lsr {t4}, {s0}, #32",
"zip2.2d v0, v0, v1",
"add {lo0}, {lo0}, {t3}, lsl #10",
"usra.2d v0, v22, #32",
"add {hi1}, {hi1}, {t4}, lsl #10",
"sli.2d v22, v0, #32",
"mov {t3:w}, {s0:w}",
"uzp2.4s v0, v0, v0",
"add {lo1}, {lo1}, {t3}, lsl #10",
"mov.16b v1, v22",
"add {hi0}, {hi0}, {t4}",
"umlal.2d v22, v0, v30[0]",
"add {lo0}, {lo0}, {t3}",
"cmhi.2d v1, v1, v22",
"lsl {t0}, {hi0}, #32",
"usra.2d v22, v1, #32",
"lsl {t1}, {hi1}, #32",
"fmov {s2}, d22",
"adds {lo0}, {lo0}, {t0}",
"fmov.d {s3}, v22[1]",
"csetm {t0:w}, cs",
"umlal.2d v2, v26, v30[1]",
"adds {lo1}, {lo1}, {t1}",
"umlal2.2d v12, v26, v31[0]",
"csetm {t1:w}, cs",
"umlal.2d v3, v26, v30[3]",
"and {t2}, {hi0}, #0xffffffff00000000",
"umlal2.2d v13, v26, v30[1]",
"and {t3}, {hi1}, #0xffffffff00000000",
"umlal.2d v4, v25, v30[1]",
"lsr {hi0}, {hi0}, #32",
"uaddw2.2d v14, v14, v25",
"lsr {hi1}, {hi1}, #32",
"uaddw.2d v15, v15, v25",
"sub {hi0}, {t2}, {hi0}",
"umlal2.2d v5, v25, v30[1]",
"sub {hi1}, {t3}, {hi1}",
"umlal.2d v6, v24, v31[3]",
"add {lo0}, {lo0}, {t0}",
"umlal2.2d v16, v24, v31[1]",
"add {lo1}, {lo1}, {t1}",
"umlal.2d v7, v24, v30[2]",
"adds {lo0}, {lo0}, {hi0}",
"umlal2.2d v17, v24, v31[3]",
"csetm {t0:w}, cs",
"umlal.2d v8, v23, v30[1]",
"adds {lo1}, {lo1}, {hi1}",
"umlal2.2d v18, v23, v31[0]",
"csetm {t1:w}, cs",
"umlal.2d v9, v23, v30[3]",
"add {s0}, {lo0}, {t0}",
"umlal2.2d v19, v23, v30[1]",
"add {s1}, {lo1}, {t1}",
"add.2d v2, v2, v12",
"add.2d v3, v3, v13",
"umlal.2d v2, v20, v31[2]",
"umlal.2d v3, v20, v31[0]",
mds_reduce_asm!("v2", "v3", "v23", "v30"),
"fmov {s4}, d23",
"fmov.d {s5}, v23[1]",
"umlal.2d v4, v26, v30[2]",
"umlal2.2d v14, v26, v30[3]",
"umlal2.2d v15, v26, v30[2]",
"uaddw.2d v5, v5, v26",
"uaddw.2d v6, v6, v25",
"uaddw2.2d v16, v16, v25",
"uaddw2.2d v17, v17, v25",
"umlal.2d v7, v25, v31[1]",
"umlal.2d v8, v24, v31[2]",
"umlal2.2d v18, v24, v30[2]",
"umlal.2d v9, v24, v31[0]",
"umlal2.2d v19, v24, v31[2]",
"add.2d v4, v4, v14",
"add.2d v5, v5, v15",
"umlal.2d v4, v20, v30[1]",
"umlal.2d v5, v20, v30[3]",
mds_reduce_asm!("v4", "v5", "v24", "v30"),
"fmov {s6}, d24",
"fmov.d {s7}, v24[1]",
"umlal.2d v6, v26, v30[1]",
"uaddw2.2d v16, v16, v26",
"umlal2.2d v17, v26, v30[1]",
"uaddw.2d v7, v7, v26",
"umlal.2d v8, v25, v31[3]",
"umlal2.2d v18, v25, v31[1]",
"umlal.2d v9, v25, v30[2]",
"umlal2.2d v19, v25, v31[3]",
"add.2d v6, v6, v16",
"add.2d v7, v7, v17",
"umlal.2d v6, v20, v30[2]",
"uaddw.2d v7, v7, v20",
mds_reduce_asm!("v6", "v7", "v25", "v30"),
"fmov {s8}, d25",
"fmov.d {s9}, v25[1]",
"uaddw.2d v8, v8, v26",
"uaddw2.2d v18, v18, v26",
"umlal.2d v9, v26, v31[1]",
"uaddw2.2d v19, v19, v26",
"add.2d v8, v8, v18",
"add.2d v9, v9, v19",
"umlal.2d v8, v20, v30[1]",
"uaddw.2d v9, v9, v20",
mds_reduce_asm!("v8", "v9", "v26", "v30"),
"fmov {s10}, d26",
"fmov.d {s11}, v26[1]",
// Scalar inputs/outputs
// s0 is transformed by the S-box
s0 = inout(reg) state_scalar[0] => res0,
// s1-s6 double as scratch in the MDS matrix multiplication
s1 = inout(reg) state_scalar[1] => res1,
// s2-s11 are copied from the vector inputs/outputs
s2 = inout(reg) state_scalar[2] => res2_scalar,
s3 = inout(reg) state_scalar[3] => res3_scalar,
s4 = inout(reg) state_scalar[4] => res4_scalar,
s5 = inout(reg) state_scalar[5] => res5_scalar,
s6 = inout(reg) state_scalar[6] => res6_scalar,
s7 = inout(reg) state_scalar[7] => res7_scalar,
s8 = inout(reg) state_scalar[8] => res8_scalar,
s9 = inout(reg) state_scalar[9] => res9_scalar,
s10 = inout(reg) state_scalar[10] => res10_scalar,
s11 = inout(reg) state_scalar[11] => res11_scalar,
// Pointer to the round constants
rc_ptr = in(reg) round_constants.as_ptr(),
// Scalar MDS multiplication accumulators
lo1 = out(reg) _,
hi1 = out(reg) _,
lo0 = out(reg) _,
hi0 = out(reg) _,
// Scalar scratch registers
// All are used in the scalar S-box
t0 = out(reg) _,
t1 = out(reg) _,
t2 = out(reg) _,
// t3-t6 are used in the scalar MDS matrix multiplication
t3 = out(reg) _,
t4 = out(reg) _,
t5 = out(reg) _,
t6 = out(reg) _,
// Vector MDS multiplication accumulators
// v{n} and v1{n} are accumulators for res[n + 2] (we need two to mask latency)
// The low and high 64-bits are accumulators for the low and high results, respectively
out("v0") _,
out("v1") _,
out("v2") _,
out("v3") _,
out("v4") _,
out("v5") _,
out("v6") _,
out("v7") _,
out("v8") _,
out("v9") _,
out("v10") _,
out("v11") _,
out("v12") _,
out("v13") _,
out("v14") _,
out("v15") _,
out("v16") _,
out("v17") _,
out("v18") _,
out("v19") _,
// Inputs into vector MDS matrix multiplication
// v20 and v21 are sbox(state0) and state1, respectively. They are copied from the scalar
// registers.
out("v20") _,
out("v21") _,
// v22, ..., v26 hold state[2,3], ..., state[10,11]
inout("v22") state_vector[0] => res23,
inout("v23") state_vector[1] => res45,
inout("v24") state_vector[2] => res67,
inout("v25") state_vector[3] => res89,
inout("v26") state_vector[4] => res1011,
// Useful constants
in("v30") mds_consts0,
in("v31") mds_consts1,
options(nostack, pure, readonly),
);
(
[
res0,
res1,
res2_scalar,
res3_scalar,
res4_scalar,
res5_scalar,
res6_scalar,
res7_scalar,
res8_scalar,
res9_scalar,
res10_scalar,
res11_scalar,
],
[res23, res45, res67, res89, res1011],
)
}
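// In scalar terms, one partial round applies the S-box to state[0] only, then the same fused
// MDS + constant layer as a full round. The sketch below (illustrative, never called) documents
// the contract the asm above satisfies.
#[allow(dead_code)]
unsafe fn partial_round_scalar(
    mut state: [u64; WIDTH],
    round_constants: &[u64; WIDTH],
) -> [u64; WIDTH] {
    // S-box on the first element only: x**7 via x2, x3, x4.
    let x = state[0];
    let x2 = multiply(x, x);
    let x3 = multiply(x, x2);
    let x4 = multiply(x2, x2);
    state[0] = multiply(x3, x4);
    mds_layer_full(state, round_constants)
}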
// ========================================== GLUE CODE ===========================================
#[inline(always)]
unsafe fn full_round(state: [u64; 12], round_constants: &[u64; WIDTH]) -> [u64; 12] {
let state = sbox_layer_full(state);
mds_layer_full(state, round_constants)
}
#[inline]
unsafe fn full_rounds(
mut state: [u64; 12],
round_constants: &[u64; WIDTH * HALF_N_FULL_ROUNDS],
) -> [u64; 12] {
for round_constants_chunk in round_constants.chunks_exact(WIDTH) {
state = full_round(state, round_constants_chunk.try_into().unwrap());
}
state
}
#[inline(always)]
unsafe fn partial_rounds(
state: [u64; 12],
round_constants: &[u64; WIDTH * N_PARTIAL_ROUNDS],
) -> [u64; 12] {
let mut state = (
state,
[
vcombine_u64(vcreate_u64(state[2]), vcreate_u64(state[3])),
vcombine_u64(vcreate_u64(state[4]), vcreate_u64(state[5])),
vcombine_u64(vcreate_u64(state[6]), vcreate_u64(state[7])),
vcombine_u64(vcreate_u64(state[8]), vcreate_u64(state[9])),
vcombine_u64(vcreate_u64(state[10]), vcreate_u64(state[11])),
],
);
for round_constants_chunk in round_constants.chunks_exact(WIDTH) {
state = partial_round(state, round_constants_chunk.try_into().unwrap());
}
state.0
}
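// GoldilocksField is a #[repr(transparent)] wrapper around u64 (the assumption behind these
// transmutes), so the state array can be reinterpreted in place.
#[inline(always)]
unsafe fn unwrap_state(state: [GoldilocksField; 12]) -> [u64; 12] {
    transmute(state)
}

#[inline(always)]
unsafe fn wrap_state(state: [u64; 12]) -> [GoldilocksField; 12] {
    transmute(state)
}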
#[inline(always)]
pub unsafe fn poseidon(state: [GoldilocksField; 12]) -> [GoldilocksField; 12] {
let state = unwrap_state(state);
let state = const_layer_full(state, ALL_ROUND_CONSTANTS[0..WIDTH].try_into().unwrap());
let state = full_rounds(
state,
ALL_ROUND_CONSTANTS[WIDTH..WIDTH * (HALF_N_FULL_ROUNDS + 1)]
.try_into()
.unwrap(),
);
let state = partial_rounds(
state,
ALL_ROUND_CONSTANTS
[WIDTH * (HALF_N_FULL_ROUNDS + 1)..WIDTH * (HALF_N_FULL_ROUNDS + N_PARTIAL_ROUNDS + 1)]
.try_into()
.unwrap(),
);
let state = full_rounds(state, &FINAL_ROUND_CONSTANTS);
wrap_state(state)
}
/// Standalone S-box layer over wrapped field elements.
#[inline(always)]
pub unsafe fn sbox_layer(state: &mut [GoldilocksField; 12]) {
    *state = wrap_state(sbox_layer_full(unwrap_state(*state)));
}

/// Standalone MDS layer over wrapped field elements. mds_layer_full also applies a constant
/// layer, so all-zero constants reduce it to a plain MDS multiplication.
#[inline(always)]
pub unsafe fn mds_layer(state: &[GoldilocksField; 12]) -> [GoldilocksField; 12] {
    wrap_state(mds_layer_full(unwrap_state(*state), &[0; WIDTH]))
}