use itertools::Itertools;
use num_traits::{One, Zero};
use serde::{Deserialize, Serialize};
use crate::shared_math::b_field_element::BFieldElement;
use crate::shared_math::traits::FiniteField;
use super::rescue_prime_digest::DIGEST_LENGTH;
/// Number of field elements in the permutation state.
pub const STATE_SIZE: usize = 16;
/// Number of state elements reserved as sponge capacity (never overwritten by input
/// during variable-length absorption).
pub const CAPACITY: usize = 6;
/// Number of input elements absorbed per permutation call.
pub const RATE: usize = 10;
/// Number of rounds applied by one full permutation.
pub const NUM_ROUNDS: usize = 8;
/// Exponent of the forward S-box, `x -> x^ALPHA`.
pub const ALPHA: u64 = 7;
/// Exponent of the inverse S-box: the inverse of [`ALPHA`] modulo `p - 1`.
pub const ALPHA_INV: u64 = 10540996611094048183;

// Capacity and rate must partition the state exactly: variable-length absorption
// writes `state[CAPACITY..]` and expects exactly `RATE` elements there.
const _: () = assert!(
    CAPACITY + RATE == STATE_SIZE,
    "CAPACITY + RATE must equal STATE_SIZE"
);

// ALPHA_INV must undo ALPHA. This assumes the base field modulus is
// p = 2^64 - 2^32 + 1 (consistent with the 16^-1 and NTT twiddle constants used
// by the MDS step); a mismatch fails the build instead of silently mis-hashing.
const _: () = assert!(
    (ALPHA as u128 * ALPHA_INV as u128) % ((1u128 << 64) - (1u128 << 32)) == 1,
    "ALPHA_INV must be the inverse of ALPHA modulo p - 1"
);
/// Round constants for the permutation, stored as raw `u64`s and converted with
/// `BFieldElement::from` at the point of use.
///
/// Layout: `2 * STATE_SIZE` constants per round. For round `r`, the entries
/// `[2 * STATE_SIZE * r, 2 * STATE_SIZE * r + STATE_SIZE)` are added after the
/// first MDS step, and the following `STATE_SIZE` entries are added after the
/// second MDS step (see `RescuePrimeOptimized::xlix_round`).
pub const ROUND_CONSTANTS: [u64; NUM_ROUNDS * STATE_SIZE * 2] = [
// round 0, first half (after first MDS)
3006656781416918236,
4369161505641058227,
6684374425476535479,
15779820574306927140,
9604497860052635077,
6451419160553310210,
16926195364602274076,
6738541355147603274,
13653823767463659393,
16331310420018519380,
10921208506902903237,
5856388654420905056,
180518533287168595,
6394055120127805757,
4624620449883041133,
4245779370310492662,
// round 0, second half (after second MDS)
11436753067664141475,
9565904130524743243,
1795462928700216574,
6069083569854718822,
16847768509740167846,
4958030292488314453,
6638656158077421079,
7387994719600814898,
1380138540257684527,
2756275326704598308,
6162254851582803897,
4357202747710082448,
12150731779910470904,
3121517886069239079,
14951334357190345445,
11174705360936334066,
// round 1, first half
17619090104023680035,
9879300494565649603,
6833140673689496042,
8026685634318089317,
6481786893261067369,
15148392398843394510,
11231860157121869734,
2645253741394956018,
15345701758979398253,
1715545688795694261,
3419893440622363282,
12314745080283886274,
16173382637268011204,
2012426895438224656,
6886681868854518019,
9323151312904004776,
// round 1, second half
14061124303940833928,
14720644192628944300,
3643016909963520634,
15164487940674916922,
18095609311840631082,
17450128049477479068,
13770238146408051799,
959547712344137104,
12896174981045071755,
15673600445734665670,
5421724936277706559,
15147580014608980436,
10475549030802107253,
9781768648599053415,
12208559126136453589,
14883846462224929329,
// round 2, first half
4104889747365723917,
748723978556009523,
1227256388689532469,
5479813539795083611,
8771502115864637772,
16732275956403307541,
4416407293527364014,
828170020209737786,
12657110237330569793,
6054985640939410036,
4339925773473390539,
12523290846763939879,
6515670251745069817,
3304839395869669984,
13139364704983394567,
7310284340158351735,
// round 2, second half
10864373318031796808,
17752126773383161797,
1934077736434853411,
12181011551355087129,
16512655861290250275,
17788869165454339633,
12226346139665475316,
521307319751404755,
18194723210928015140,
11017703779172233841,
15109417014344088693,
16118100307150379696,
16104548432406078622,
10637262801060241057,
10146828954247700859,
14927431817078997000,
// round 3, first half
8849391379213793752,
14873391436448856814,
15301636286727658488,
14600930856978269524,
14900320206081752612,
9439125422122803926,
17731778886181971775,
11364016993846997841,
11610707911054206249,
16438527050768899002,
1230592087960588528,
11390503834342845303,
10608561066917009324,
5454068995870010477,
13783920070953012756,
10807833173700567220,
// round 3, second half
8597517374132535250,
17631206339728520236,
8083932512125088346,
10460229397140806011,
16904442127403184100,
15806582425540851960,
8002674967888750145,
7088508235236416142,
2774873684607752403,
11519427263507311324,
14949623981479468161,
18169367272402768616,
13279771425489376175,
3437101568566296039,
11820510872362664493,
13649520728248893918,
// round 4, first half
13432595021904865723,
12153175375751103391,
16459175915481931891,
14698099486055505377,
14962427686967561007,
10825731681832829214,
12562849212348892143,
18054851842681741827,
16866664833727482321,
10485994783891875256,
8074668712578030015,
7502837771635714611,
8326381174040960025,
1299216707593490898,
12092900834113479279,
10147133736028577997,
// round 4, second half
12103660182675227350,
16088613802080804964,
10323305955081440356,
12814564542614394316,
9653856919559060601,
10390420172371317530,
7831993942325060892,
9568326819852151217,
6299791178740935792,
12692828392357621723,
10331476541693143830,
3115340436782501075,
17456578083689713056,
12924575652913558388,
14365487216177868031,
7211834371191912632,
// round 5, first half
17610068359394967554,
646302646073569086,
12437378932700222679,
2758591586601041336,
10952396165876183059,
8827205511644136726,
17572216767879446421,
12516044823385174395,
6380048472179557105,
1959389938825200414,
257915527015303758,
4942451629986849727,
1698530521870297461,
1802136667015215029,
6353258543636931941,
13791525219506237119,
// round 5, second half
7093082295632492630,
15409842367405634814,
2090232819855225051,
13926160661036606054,
389467431021126699,
4736917413147385608,
6217341363393311211,
4366302820407593918,
12748238635329332117,
7671680179984682360,
17998193362025085453,
432899318054332645,
1973816396170253277,
607886411884636526,
15080416519109365682,
13607062276466651973,
// round 6, first half
2458254972975404730,
15323169029557757131,
10953434699543086460,
13995946730291266219,
12803971247555868632,
3974568790603251423,
10629169239281589943,
2058261494620094806,
15905212873859894286,
11221574225004694137,
15430295276730781380,
10448646831319611878,
7559293484620816204,
15679753002507105741,
6043747003590355195,
3404573815097301491,
// round 6, second half
13392826344874185313,
6464466389567159772,
8932733991045074013,
6565970376680631168,
7050411859293315754,
9763347751680159247,
3140014248604700259,
5621238883761074228,
12664766603293629079,
6533276137502482405,
914829860407409680,
14599697497440353734,
16400390478099648992,
1619185634767959932,
16420198681440130663,
1331388886719756999,
// round 7, first half
1430143015191336857,
14618841684410509097,
1870494251298489312,
3783117677312763499,
16164771504475705474,
6996935044500625689,
4356994160244918010,
13579982029281680908,
8835524728424198741,
13281017722683773148,
2669924686363521592,
15020410046647566094,
9534143832529454683,
156263138519279564,
17421879327900831752,
9524879102847422379,
// round 7, second half
5120021146470638642,
9588770058331935449,
1501841070476096181,
5687728871183511192,
16091855309800405887,
17307425956518746505,
1162636238106302518,
8756478993690213481,
6898084027896327288,
8485261637658061794,
4169208979833913382,
7776158701576840241,
13861841831073878156,
4896983281306117497,
6056805506026814259,
15706891000994288769,
];
/// The permutation state of the sponge: exactly `STATE_SIZE` field elements.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct RescuePrimeOptimizedState {
    /// Raw state elements; capacity and rate regions are assigned by the
    /// hashing functions that use this state.
    pub state: [BFieldElement; STATE_SIZE],
}

impl RescuePrimeOptimizedState {
    /// Builds a fresh state in which every element is `BFieldElement::zero()`.
    fn new() -> Self {
        let zeroed = [BFieldElement::zero(); STATE_SIZE];
        Self { state: zeroed }
    }
}
/// Rescue-Prime Optimized hash over `BFieldElement`s.
///
/// This type carries no data; all functionality is exposed through associated
/// functions such as `hash_10`, `hash_varlen` and `trace`.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct RescuePrimeOptimized {}
impl RescuePrimeOptimized {
#[inline]
fn batch_square(array: &mut [BFieldElement; STATE_SIZE]) {
for a in array.iter_mut() {
*a = a.square();
}
}
#[inline]
fn batch_square_n<const N: usize>(array: &mut [BFieldElement; STATE_SIZE]) {
for _ in 0..N {
Self::batch_square(array);
}
}
#[inline]
fn batch_mul_into(
array: &mut [BFieldElement; STATE_SIZE],
operand: [BFieldElement; STATE_SIZE],
) {
for (a, b) in array.iter_mut().zip_eq(operand.iter()) {
*a *= *b;
}
}
#[inline]
fn batch_mod_pow_alpha_inv(array: [BFieldElement; STATE_SIZE]) -> [BFieldElement; STATE_SIZE] {
let mut p1 = array;
Self::batch_square(&mut p1);
let mut p2 = p1;
Self::batch_square(&mut p2);
let mut p3 = p2;
Self::batch_square_n::<3>(&mut p3);
Self::batch_mul_into(&mut p3, p2);
let mut p4 = p3;
Self::batch_square_n::<6>(&mut p4);
Self::batch_mul_into(&mut p4, p3);
let mut p5 = p4;
Self::batch_square_n::<12>(&mut p5);
Self::batch_mul_into(&mut p5, p4);
let mut p6 = p5;
Self::batch_square_n::<6>(&mut p6);
Self::batch_mul_into(&mut p6, p3);
let mut p7 = p6;
Self::batch_square_n::<31>(&mut p7);
Self::batch_mul_into(&mut p7, p6);
let mut result = p7;
Self::batch_square(&mut result);
Self::batch_mul_into(&mut result, p6);
Self::batch_square_n::<2>(&mut result);
Self::batch_mul_into(&mut result, p2);
Self::batch_mul_into(&mut result, p1);
Self::batch_mul_into(&mut result, array);
result
}
#[inline]
fn batch_mod_pow_alpha(array: [BFieldElement; STATE_SIZE]) -> [BFieldElement; STATE_SIZE] {
let mut result = array;
Self::batch_square(&mut result);
Self::batch_mul_into(&mut result, array);
Self::batch_square(&mut result);
Self::batch_mul_into(&mut result, array);
result
}
#[allow(dead_code)]
fn batch_mod_pow(
array: [BFieldElement; STATE_SIZE],
power: u64,
) -> [BFieldElement; STATE_SIZE] {
let mut acc = [BFieldElement::one(); STATE_SIZE];
for i in (0..64).rev() {
if i != 63 {
Self::batch_square(&mut acc);
}
if power & (1 << i) != 0 {
Self::batch_mul_into(&mut acc, array);
}
}
acc
}
#[allow(clippy::many_single_char_names)]
fn ntt_noswap(x: &mut [BFieldElement]) {
const POWERS_OF_OMEGA_BITREVERSED: [BFieldElement; 8] = [
BFieldElement::new(1),
BFieldElement::new(281474976710656),
BFieldElement::new(18446744069397807105),
BFieldElement::new(18446742969902956801),
BFieldElement::new(17293822564807737345),
BFieldElement::new(4096),
BFieldElement::new(4503599626321920),
BFieldElement::new(18446744000695107585),
];
for j in 0..8 {
let u = x[j];
let v = x[j + 8] * BFieldElement::one();
x[j] = u + v;
x[j + 8] = u - v;
}
for (i, zeta) in POWERS_OF_OMEGA_BITREVERSED.iter().enumerate().take(2) {
let s = i * 8;
for j in s..(s + 4) {
let u = x[j];
let v = x[j + 4] * *zeta;
x[j] = u + v;
x[j + 4] = u - v;
}
}
for (i, zeta) in POWERS_OF_OMEGA_BITREVERSED.iter().enumerate().take(4) {
let s = i * 4;
for j in s..(s + 2) {
let u = x[j];
let v = x[j + 2] * *zeta;
x[j] = u + v;
x[j + 2] = u - v;
}
}
for (i, zeta) in POWERS_OF_OMEGA_BITREVERSED.iter().enumerate().take(8) {
let s = i * 2;
let u = x[s];
let v = x[s + 1] * *zeta;
x[s] = u + v;
x[s + 1] = u - v;
}
}
#[allow(clippy::many_single_char_names)]
fn intt_noswap(x: &mut [BFieldElement]) {
const POWERS_OF_OMEGA_INVERSE: [BFieldElement; 8] = [
BFieldElement::new(1),
BFieldElement::new(68719476736),
BFieldElement::new(1099511627520),
BFieldElement::new(18446744069414580225),
BFieldElement::new(18446462594437873665),
BFieldElement::new(18442240469788262401),
BFieldElement::new(16777216),
BFieldElement::new(1152921504606846976),
];
{
{
let u = x[1];
let v = x[0];
x[1] = v - u;
x[0] = v + u;
}
{
let u = x[2 + 1];
let v = x[2];
x[2 + 1] = v - u;
x[2] = v + u;
}
{
let u = x[4 + 1];
let v = x[4];
x[4 + 1] = v - u;
x[4] = v + u;
}
{
let u = x[6 + 1];
let v = x[6];
x[6 + 1] = v - u;
x[6] = v + u;
}
{
let u = x[8 + 1];
let v = x[8];
x[8 + 1] = v - u;
x[8] = v + u;
}
{
let u = x[10 + 1];
let v = x[10];
x[10 + 1] = v - u;
x[10] = v + u;
}
{
let u = x[12 + 1];
let v = x[12];
x[12 + 1] = v - u;
x[12] = v + u;
}
{
let u = x[14 + 1];
let v = x[14];
x[14 + 1] = v - u;
x[14] = v + u;
}
}
{
{
for j in 0..2 {
let zeta = POWERS_OF_OMEGA_INVERSE[4 * j];
{
let u = x[j + 2] * zeta;
let v = x[j];
x[j + 2] = v - u;
x[j] = v + u;
}
{
let u = x[4 + j + 2] * zeta;
let v = x[4 + j];
x[4 + j + 2] = v - u;
x[4 + j] = v + u;
}
{
let u = x[8 + j + 2] * zeta;
let v = x[8 + j];
x[8 + j + 2] = v - u;
x[8 + j] = v + u;
}
{
let u = x[12 + j + 2] * zeta;
let v = x[12 + j];
x[12 + j + 2] = v - u;
x[12 + j] = v + u;
}
}
}
}
{
{
for j in 0..4 {
let zeta = POWERS_OF_OMEGA_INVERSE[2 * j];
{
let u = x[j + 4] * zeta;
let v = x[j];
x[j + 4] = v - u;
x[j] = v + u;
}
{
let u = x[8 + j + 4] * zeta;
let v = x[8 + j];
x[8 + j + 4] = v - u;
x[8 + j] = v + u;
}
}
}
}
{
for j in 0..8 {
let zeta = POWERS_OF_OMEGA_INVERSE[j];
let u = x[j + 8] * zeta;
let v = x[j];
x[j + 8] = v - u;
x[j] = v + u;
}
}
}
#[inline]
pub fn mds_noswap(state: &mut [BFieldElement; STATE_SIZE]) {
let mds: [BFieldElement; STATE_SIZE] = [
BFieldElement::new(1363685766),
BFieldElement::new(818401426),
BFieldElement::new(2843477530982740278),
BFieldElement::new(15603266536318963895),
BFieldElement::new(4617387998068915967),
BFieldElement::new(13834281883405632256),
BFieldElement::new(18438678032804473072),
BFieldElement::new(3140224485136655),
BFieldElement::new(3747273207304324287),
BFieldElement::new(14700029414217449666),
BFieldElement::new(9286765195715607938),
BFieldElement::new(9160541823450023167),
BFieldElement::new(18392355339471673798),
BFieldElement::new(89869970136635963),
BFieldElement::new(16012825548870059521),
BFieldElement::new(2397315778488370688),
];
Self::ntt_noswap(state);
for (i, m) in mds.iter().enumerate() {
state[i] *= *m;
}
Self::intt_noswap(state);
const NINV: BFieldElement = BFieldElement::new(17293822565076172801);
state.iter_mut().for_each(|s| *s *= NINV);
}
fn xlix_round(sponge: &mut RescuePrimeOptimizedState, round_index: usize) {
debug_assert!(
round_index < NUM_ROUNDS,
"Cannot apply {round_index}th round; only have {NUM_ROUNDS} in total."
);
Self::mds_noswap(&mut sponge.state);
for i in 0..STATE_SIZE {
sponge.state[i] +=
BFieldElement::from(ROUND_CONSTANTS[round_index * STATE_SIZE * 2 + i]);
}
sponge.state = Self::batch_mod_pow_alpha(sponge.state);
Self::mds_noswap(&mut sponge.state);
for i in 0..STATE_SIZE {
sponge.state[i] +=
BFieldElement::from(ROUND_CONSTANTS[round_index * STATE_SIZE * 2 + STATE_SIZE + i]);
}
sponge.state = Self::batch_mod_pow_alpha_inv(sponge.state);
}
fn xlix(sponge: &mut RescuePrimeOptimizedState) {
for i in 0..NUM_ROUNDS {
Self::xlix_round(sponge, i);
}
}
pub fn hash_10(input: &[BFieldElement; 10]) -> [BFieldElement; 5] {
let mut sponge = RescuePrimeOptimizedState::new();
sponge.state[..10].copy_from_slice(input);
sponge.state[10] = BFieldElement::one();
Self::xlix(&mut sponge);
sponge.state[..5].try_into().unwrap()
}
pub fn hash_varlen(input: &[BFieldElement]) -> [BFieldElement; 5] {
let mut sponge = RescuePrimeOptimizedState::new();
let mut padded_input = if input.len() % RATE != 0 {
let mut padded_input = input.to_vec();
padded_input.push(BFieldElement::one());
while padded_input.len() % RATE != 0 {
padded_input.push(BFieldElement::zero());
}
sponge.state[0] = BFieldElement::one();
padded_input
} else {
input.to_vec()
};
while !padded_input.is_empty() {
for (sponge_state_element, input_element) in sponge.state[CAPACITY..]
.iter_mut()
.take(RATE)
.zip_eq(padded_input.iter().take(RATE))
{
*sponge_state_element = input_element.to_owned();
}
padded_input.drain(..RATE);
Self::xlix(&mut sponge);
}
sponge.state[CAPACITY..(CAPACITY + DIGEST_LENGTH)]
.try_into()
.unwrap()
}
pub fn trace(input: &[BFieldElement; 10]) -> [[BFieldElement; STATE_SIZE]; 1 + NUM_ROUNDS] {
let mut trace = [[BFieldElement::zero(); STATE_SIZE]; 1 + NUM_ROUNDS];
let mut sponge = RescuePrimeOptimizedState::new();
sponge.state[0..RATE].copy_from_slice(input);
sponge.state[RATE] = BFieldElement::new(1);
trace[0] = sponge.state;
for round_index in 0..NUM_ROUNDS {
Self::xlix_round(&mut sponge, round_index);
trace[1 + round_index] = sponge.state;
}
trace
}
}