use math::{
fft::real_u64::{fft4_real, ifft4_real_unreduced},
fields::f64::BaseElement,
FieldElement,
};
// Fixed second operands passed by `mds_multiply_freq` to `block1`, `block2` and `block3`
// respectively.
// NOTE(review): judging by the names, these are the MDS matrix expressed in the frequency
// domain — confirm against the matrix definition used elsewhere in the crate.
const MDS_FREQ_BLOCK_ONE: [i64; 2] = [16, 8];
// Each pair is consumed by `block2` as the (real, imaginary) parts of one operand.
const MDS_FREQ_BLOCK_TWO: [(i64, i64); 2] = [(8, -4), (-1, 1)];
const MDS_FREQ_BLOCK_THREE: [i64; 2] = [-1, 1];
/// Applies the MDS multiplication to `state` in place.
///
/// Every element's raw `u64` (as returned by `inner()`) is split into its upper and lower
/// 32-bit halves. Each set of halves is pushed through [`mds_multiply_freq`] independently,
/// then the two results are recombined as `low + (high << 32)` in 128-bit arithmetic and
/// folded back into 64 bits before being wrapped with `BaseElement::from_mont`.
///
/// NOTE(review): the folding step uses `2^64 ≡ 2^32 - 1`, which is consistent with reduction
/// modulo `2^64 - 2^32 + 1` — presumably the modulus of the `f64` field; confirm.
pub(crate) fn mds_multiply(state: &mut [BaseElement; 8]) {
    // Decompose every element into 32-bit limbs.
    let mut high_limbs = [0u64; 8];
    let mut low_limbs = [0u64; 8];
    for (idx, element) in state.iter().enumerate() {
        let raw = element.inner();
        high_limbs[idx] = raw >> 32;
        low_limbs[idx] = (raw as u32) as u64;
    }

    // Transform the two limb vectors separately.
    let high_limbs = mds_multiply_freq(high_limbs);
    let low_limbs = mds_multiply_freq(low_limbs);

    let mut output = [BaseElement::ZERO; 8];
    for (idx, slot) in output.iter_mut().enumerate() {
        // Recombine the limbs into a single wide value.
        let wide = low_limbs[idx] as u128 + ((high_limbs[idx] as u128) << 32);
        let upper = (wide >> 64) as u64;
        let lower = wide as u64;

        // upper * (2^32 - 1): the contribution of the bits above 2^64.
        let folded = (upper << 32) - upper;
        let (sum, carried) = lower.overflowing_add(folded);

        // If the addition wrapped, compensate by adding 2^32 - 1 (0 otherwise).
        let correction = 0u32.wrapping_sub(carried as u32) as u64;
        *slot = BaseElement::from_mont(sum.wrapping_add(correction));
    }

    *state = output;
}
/// Transforms an 8-element vector of `u64` values through the split 4-point real FFT pipeline.
///
/// The even-indexed and odd-indexed entries are each fed to `fft4_real`. The three outputs of
/// the two transforms are paired up component-wise and combined with the `MDS_FREQ_BLOCK_*`
/// constants via `block1`, `block2` and `block3`. The results are sent back through
/// `ifft4_real_unreduced` and re-interleaved into their original even/odd positions.
///
/// NOTE(review): presumably this realizes multiplication by a circulant MDS matrix as a
/// product in the frequency domain — confirm against the crate-level documentation.
#[inline(always)]
pub(crate) fn mds_multiply_freq(state: [u64; 8]) -> [u64; 8] {
    // Gather the even and odd lanes.
    let even_lanes = [state[0], state[2], state[4], state[6]];
    let odd_lanes = [state[1], state[3], state[5], state[7]];

    // Forward transforms.
    let (even_a, even_b, even_c) = fft4_real(even_lanes);
    let (odd_a, odd_b, odd_c) = fft4_real(odd_lanes);

    // Combine matching components of the two transforms with the constant blocks.
    let [mixed_even_a, mixed_odd_a] = block1([even_a, odd_a], MDS_FREQ_BLOCK_ONE);
    let [mixed_even_b, mixed_odd_b] = block2([even_b, odd_b], MDS_FREQ_BLOCK_TWO);
    let [mixed_even_c, mixed_odd_c] = block3([even_c, odd_c], MDS_FREQ_BLOCK_THREE);

    // Inverse transforms, then interleave even/odd results.
    let [r0, r2, r4, r6] = ifft4_real_unreduced((mixed_even_a, mixed_even_b, mixed_even_c));
    let [r1, r3, r5, r7] = ifft4_real_unreduced((mixed_odd_a, mixed_odd_b, mixed_odd_c));

    [r0, r1, r2, r3, r4, r5, r6, r7]
}
/// Length-2 cyclic convolution of `x` and `y`.
///
/// Returns `[x0*y0 + x1*y1, x0*y1 + x1*y0]`.
#[inline(always)]
fn block1(x: [i64; 2], y: [i64; 2]) -> [i64; 2] {
    let (a0, a1) = (x[0], x[1]);
    let (b0, b1) = (y[0], y[1]);

    [a0 * b0 + a1 * b1, a0 * b1 + a1 * b0]
}
/// Combines two pairs of complex values, each given as `(real, imaginary)`.
///
/// With `x = [x0, x1]` and `y = [y0, y1]` read as complex numbers, the result is
/// `[x0*y0 - i*x1*y1, x0*y1 + x1*y0]`.
///
/// Every complex product uses the three-multiplication (Karatsuba-style) trick: the imaginary
/// part is obtained as `(re + im) * (re' + im') - re*re' - im*im'`.
#[inline(always)]
fn block2(x: [(i64, i64); 2], y: [(i64, i64); 2]) -> [(i64, i64); 2] {
    let [(a0_re, a0_im), (a1_re, a1_im)] = x;
    let [(b0_re, b0_im), (b1_re, b1_im)] = y;

    // Real + imaginary sums, shared by all cross terms below.
    let a0_sum = a0_re + a0_im;
    let a1_sum = a1_re + a1_im;
    let b0_sum = b0_re + b0_im;
    let b1_sum = b1_re + b1_im;

    // First output: x0*y0 - i*x1*y1.
    // Multiplying by -i maps (re, im) to (im, -re), hence the swapped roles of the x1*y1 terms.
    let (p_rr, p_ii) = (a0_re * b0_re, a0_im * b0_im);
    let (q_rr, q_ii) = (a1_re * b1_re, a1_im * b1_im);
    let first = (
        (p_rr - p_ii) + (a1_sum * b1_sum - q_rr - q_ii),
        (a0_sum * b0_sum - p_rr - p_ii) + (-q_rr + q_ii),
    );

    // Second output: x0*y1 + x1*y0.
    let (p_rr, p_ii) = (a0_re * b1_re, a0_im * b1_im);
    let (q_rr, q_ii) = (a1_re * b0_re, a1_im * b0_im);
    let second = (
        (p_rr - p_ii) + (q_rr - q_ii),
        (a0_sum * b1_sum - p_rr - p_ii) + (a1_sum * b0_sum - q_rr - q_ii),
    );

    [first, second]
}
/// Length-2 negacyclic convolution of `x` and `y`.
///
/// Returns `[x0*y0 - x1*y1, x0*y1 + x1*y0]`.
#[inline(always)]
fn block3(x: [i64; 2], y: [i64; 2]) -> [i64; 2] {
    let (a0, a1) = (x[0], x[1]);
    let (b0, b1) = (y[0], y[1]);

    [a0 * b0 - a1 * b1, a0 * b1 + a1 * b0]
}