use crate::{PoseidonGrainLFSR, PrimeField, serial_batch_inversion_and_mul};
use aleo_std::{end_timer, start_timer};
use itertools::Itertools;
use anyhow::{Result, bail};
/// Parameters of a Poseidon permutation over the prime field `F`.
///
/// `RATE` and `CAPACITY` are carried only at the type level; none of the fields depends on them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PoseidonParameters<F, const RATE: usize, const CAPACITY: usize>
where
    F: PrimeField,
{
    /// Number of full rounds.
    pub full_rounds: usize,
    /// Number of partial rounds.
    pub partial_rounds: usize,
    /// The `alpha` exponent; presumably the S-box power applied in each round (confirm against
    /// the permutation code).
    pub alpha: u64,
    /// Additive round constants (ARK). `default_poseidon_parameters` fills one row of
    /// `RATE + 1` elements per round.
    pub ark: Vec<Vec<F>>,
    /// The MDS matrix, stored as a vector of rows.
    pub mds: Vec<Vec<F>>,
}
pub trait PoseidonDefaultField {
    /// Returns Poseidon parameters with capacity 1 for the requested `RATE`.
    ///
    /// The round configuration (full/partial rounds, `alpha`, matrices to skip) is taken from the
    /// first entry of `Self::Parameters::PARAMS_OPT_FOR_CONSTRAINTS` whose `rate` equals `RATE`.
    /// The round constants and the MDS matrix are then sampled from a `PoseidonGrainLFSR`.
    ///
    /// # Errors
    ///
    /// Fails when no table entry matches `RATE`, or when sampling field elements from the LFSR
    /// returns an error.
    fn default_poseidon_parameters<const RATE: usize>() -> Result<PoseidonParameters<Self, RATE, 1>>
    where
        Self: PrimeField,
    {
        /// Samples the round constants (ARK) and the MDS matrix for a state of `RATE + 1`
        /// elements from a Grain LFSR, discarding `skip_matrices` candidate matrices first.
        // The nested `(Vec<Vec<F>>, Vec<Vec<F>>)` return type trips `clippy::type_complexity`.
        #[allow(clippy::type_complexity)]
        fn find_poseidon_ark_and_mds<F: PrimeField, const RATE: usize>(
            full_rounds: u64,
            partial_rounds: u64,
            skip_matrices: u64,
        ) -> Result<(Vec<Vec<F>>, Vec<Vec<F>>)> {
            // State width: `RATE` elements plus the single capacity element.
            let width = RATE + 1;

            let lfsr_time = start_timer!(|| "LFSR Init");
            let mut lfsr =
                PoseidonGrainLFSR::new(false, F::size_in_bits() as u64, width as u64, full_rounds, partial_rounds);
            end_timer!(lfsr_time);

            // One row of `width` constants per round, drawn by rejection sampling; stops at the
            // first sampling error.
            let ark_time = start_timer!(|| "Constructing ARK");
            let ark = (0..full_rounds + partial_rounds)
                .map(|_| -> Result<Vec<F>> { lfsr.get_field_elements_rejection_sampling(width) })
                .collect::<Result<Vec<Vec<F>>>>()?;
            end_timer!(ark_time);

            // Each discarded candidate matrix consumes `2 * width` elements (its xs followed by its ys).
            let skip_time = start_timer!(|| "Skipping matrices");
            for _ in 0..skip_matrices {
                lfsr.get_field_elements_mod_p::<F>(2 * width)?;
            }
            end_timer!(skip_time);

            // The xs must be drawn before the ys: the LFSR stream order defines the matrix.
            let xs = lfsr.get_field_elements_mod_p::<F>(width)?;
            let ys = lfsr.get_field_elements_mod_p::<F>(width)?;

            // Fill a row-major `width x width` buffer with x_i + y_j.
            // `serial_batch_inversion_and_mul` with coefficient one is assumed to replace each
            // entry by its inverse, giving the Cauchy matrix 1 / (x_i + y_j).
            let mds_time = start_timer!(|| "Construct MDS");
            let mut cauchy = vec![F::zero(); width * width];
            cauchy.chunks_mut(width).zip_eq(xs.iter().take(width)).for_each(|(row, x)| {
                row.iter_mut().zip_eq(ys.iter().take(width)).for_each(|(cell, y)| *cell = *x + y);
            });
            serial_batch_inversion_and_mul(&mut cauchy, &F::one());
            let mds = cauchy.chunks(width).map(<[F]>::to_vec).collect();
            end_timer!(mds_time);

            Ok((ark, mds))
        }

        // Bind the constant table to a local so the matched entry can keep borrowing from it.
        let table = Self::Parameters::PARAMS_OPT_FOR_CONSTRAINTS;
        let entry = match table.iter().find(|candidate| candidate.rate == RATE) {
            Some(entry) => entry,
            None => bail!("No Poseidon parameters were found for this rate"),
        };

        let (ark, mds) = find_poseidon_ark_and_mds::<Self, RATE>(
            entry.full_rounds as u64,
            entry.partial_rounds as u64,
            entry.skip_matrices as u64,
        )?;

        Ok(PoseidonParameters {
            full_rounds: entry.full_rounds,
            partial_rounds: entry.partial_rounds,
            alpha: entry.alpha as u64,
            ark,
            mds,
        })
    }
}
/// Provides a fixed table of default Poseidon round configurations.
///
/// `default_poseidon_parameters` reads this table through the field's `Parameters` type.
pub trait PoseidonDefaultParameters {
    /// Candidate configurations keyed by `rate`; the first entry whose `rate` matches the
    /// requested rate is the one used.
    const PARAMS_OPT_FOR_CONSTRAINTS: [PoseidonDefaultParametersEntry; 7];
}
/// One row of a default Poseidon parameter table: the round configuration for a given rate.
///
/// The struct is plain data, so it derives the common traits: entries can be printed, compared
/// and copied out of `const` tables.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PoseidonDefaultParametersEntry {
    /// The sponge rate this entry applies to.
    pub rate: usize,
    /// The `alpha` exponent to use with this configuration.
    pub alpha: usize,
    /// Number of full rounds.
    pub full_rounds: usize,
    /// Number of partial rounds.
    pub partial_rounds: usize,
    /// Number of candidate MDS matrices to draw and discard from the Grain LFSR before sampling
    /// the matrix that is actually used.
    pub skip_matrices: usize,
}

impl PoseidonDefaultParametersEntry {
    /// Creates an entry from its fields.
    ///
    /// This is a `const fn`, so entries can be built inside `const` tables.
    pub const fn new(
        rate: usize,
        alpha: usize,
        full_rounds: usize,
        partial_rounds: usize,
        skip_matrices: usize,
    ) -> Self {
        Self { rate, alpha, full_rounds, partial_rounds, skip_matrices }
    }
}