use alloc::vec::Vec;
use p3_field::{Field, PrimeCharacteristicRing};
use p3_mds::MdsPermutation;
use p3_symmetric::Permutation;
use rand::distr::{Distribution, StandardUniform};
use rand::{Rng, RngExt};
/// Multiplies `x` in place by the 4x4 matrix
///
/// ```text
/// [5 7 1 3]
/// [4 6 1 1]
/// [1 3 5 7]
/// [1 1 4 6]
/// ```
///
/// using only additions and doublings.
#[inline(always)]
fn apply_hl_mat4<R>(x: &mut [R; 4])
where
    R: PrimeCharacteristicRing,
{
    let [a, b, c, d] = &*x;
    let ab = a.clone() + b.clone();
    let cd = c.clone() + d.clone();
    // 2b + c + d
    let p = b.double() + cd.clone();
    // a + b + 2d
    let q = d.double() + ab.clone();
    // a + b + 4c + 6d
    let r = cd.double().double() + q.clone();
    // 4a + 6b + c + d
    let s = ab.double().double() + p.clone();
    // All outputs are built from the borrowed inputs before the state is overwritten.
    *x = [
        q + s.clone(), // 5a + 7b + c + 3d
        s,             // 4a + 6b + c + d
        p + r.clone(), // a + 3b + 5c + 7d
        r,             // a + b + 4c + 6d
    ];
}
/// Multiplies `x` in place by the circulant 4x4 matrix `circ(2, 3, 1, 1)`:
///
/// ```text
/// [2 3 1 1]
/// [1 2 3 1]
/// [1 1 2 3]
/// [3 1 1 2]
/// ```
///
/// using only additions and doublings.
#[inline(always)]
fn apply_mat4<R>(x: &mut [R; 4])
where
    R: PrimeCharacteristicRing,
{
    let [a, b, c, d] = &*x;
    let ab = a.clone() + b.clone();
    let cd = c.clone() + d.clone();
    let abcd = ab.clone() + cd.clone();
    // a + 2b + c + d
    let abcd_b = abcd.clone() + b.clone();
    // a + b + c + 2d
    let abcd_d = abcd + d.clone();
    // Every output is computed from the borrowed inputs before the state is overwritten,
    // so no particular write order is required.
    *x = [
        abcd_b.clone() + ab, // 2a + 3b + c + d
        abcd_b + c.double(), // a + 2b + 3c + d
        abcd_d.clone() + cd, // a + b + 2c + 3d
        abcd_d + a.double(), // 3a + b + c + 2d
    ];
}
/// A 4x4 MDS permutation that multiplies its input by
///
/// ```text
/// [5 7 1 3]
/// [4 6 1 1]
/// [1 3 5 7]
/// [1 1 4 6]
/// ```
#[derive(Clone, Debug, Default)]
pub struct HLMDSMat4;
impl<R: PrimeCharacteristicRing> Permutation<[R; 4]> for HLMDSMat4 {
    /// Multiplies `input` by the matrix above, in place.
    #[inline(always)]
    fn permute_mut(&self, input: &mut [R; 4]) {
        apply_hl_mat4(input);
    }
}
impl<R: PrimeCharacteristicRing> MdsPermutation<R, 4> for HLMDSMat4 {}
/// A 4x4 MDS permutation that multiplies its input by the circulant matrix
/// `circ(2, 3, 1, 1)`:
///
/// ```text
/// [2 3 1 1]
/// [1 2 3 1]
/// [1 1 2 3]
/// [3 1 1 2]
/// ```
#[derive(Clone, Debug, Default)]
pub struct MDSMat4;
impl<R: PrimeCharacteristicRing> Permutation<[R; 4]> for MDSMat4 {
    /// Multiplies `input` by the matrix above, in place.
    #[inline(always)]
    fn permute_mut(&self, input: &mut [R; 4]) {
        apply_mat4(input);
    }
}
impl<R: PrimeCharacteristicRing> MdsPermutation<R, 4> for MDSMat4 {}
/// Applies the "light" external linear layer to `state` in place.
///
/// * `WIDTH == 2`: multiplies by `[[2, 1], [1, 2]]`.
/// * `WIDTH == 3`: multiplies by `[[2, 1, 1], [1, 2, 1], [1, 1, 2]]`.
/// * `WIDTH` in `{4, 8, 12, 16, 20, 24, 32}`: first applies `mdsmat` (call its matrix `M`) to
///   each consecutive chunk of four elements, then multiplies by the block-circulant matrix
///   `circ(2I, I, ..., I)` of 4x4 blocks. The overall matrix is therefore
///   `circ(2M, M, ..., M)`; for `WIDTH == 4` this is simply `2M`.
///
/// # Panics
///
/// Panics with "Unsupported width" for any other `WIDTH`.
#[inline(always)]
pub fn mds_light_permutation<
    R: PrimeCharacteristicRing,
    MdsPerm4: MdsPermutation<R, 4>,
    const WIDTH: usize,
>(
    state: &mut [R; WIDTH],
    mdsmat: &MdsPerm4,
) {
    match WIDTH {
        2 => {
            // Adding the total to each entry gives 2*s_i + (sum of the others).
            let total = state[0].clone() + state[1].clone();
            state[0] += total.clone();
            state[1] += total;
        }
        3 => {
            let total = state[0].clone() + state[1].clone() + state[2].clone();
            state[0] += total.clone();
            state[1] += total.clone();
            state[2] += total;
        }
        4 | 8 | 12 | 16 | 20 | 24 | 32 => {
            // Step 1: multiply each 4-element chunk by the 4x4 matrix.
            for chunk in state.chunks_exact_mut(4) {
                let block: &mut [R; 4] = chunk.try_into().unwrap();
                mdsmat.permute_mut(block);
            }
            // Step 2: for each lane k in 0..4, sum the entries whose index is k mod 4.
            let lane_sums: [R; 4] =
                core::array::from_fn(|lane| state.iter().skip(lane).step_by(4).cloned().sum());
            // Step 3: adding its lane sum to each entry doubles the entry's own contribution and
            // adds one copy of every other chunk's entry in the same lane.
            for (idx, entry) in state.iter_mut().enumerate() {
                *entry += lane_sums[idx % 4].clone();
            }
        }
        _ => {
            panic!("Unsupported width");
        }
    }
}
/// Round constants for the external rounds, split into an initial and a terminal half.
///
/// Each round has one `[T; WIDTH]` constant vector. Every constructor enforces that
/// both halves contain the same number of rounds.
#[derive(Debug, Clone)]
pub struct ExternalLayerConstants<T, const WIDTH: usize> {
    /// One constant vector per initial external round.
    initial: Vec<[T; WIDTH]>,
    /// One constant vector per terminal external round.
    terminal: Vec<[T; WIDTH]>,
}
impl<T, const WIDTH: usize> ExternalLayerConstants<T, WIDTH> {
    /// Builds the constants from explicit initial and terminal round vectors.
    ///
    /// # Panics
    ///
    /// Panics if `initial` and `terminal` have different lengths.
    pub const fn new(initial: Vec<[T; WIDTH]>, terminal: Vec<[T; WIDTH]>) -> Self {
        assert!(
            initial.len() == terminal.len(),
            "The number of initial and terminal external rounds should be equal."
        );
        Self { initial, terminal }
    }

    /// Samples constants for `external_round_number` external rounds from `rng`.
    ///
    /// The first `external_round_number / 2` sampled vectors become the initial constants.
    /// The next `external_round_number / 2` become the terminal constants.
    ///
    /// # Panics
    ///
    /// Panics if `external_round_number` is odd.
    pub fn new_from_rng<R: Rng>(external_round_number: usize, rng: &mut R) -> Self
    where
        StandardUniform: Distribution<[T; WIDTH]>,
    {
        let half_f = external_round_number / 2;
        assert_eq!(
            2 * half_f,
            external_round_number,
            "The total number of external rounds should be even"
        );
        let initial_constants = rng.sample_iter(StandardUniform).take(half_f).collect();
        let terminal_constants = rng.sample_iter(StandardUniform).take(half_f).collect();
        Self::new(initial_constants, terminal_constants)
    }

    /// Builds the constants from a saved `[initial, terminal]` pair of fixed-size arrays.
    ///
    /// `conversion_fn` is applied to every round vector. Both halves have length `N`
    /// by construction, so this never panics.
    pub fn new_from_saved_array<U, const N: usize>(
        [initial, terminal]: [[[U; WIDTH]; N]; 2],
        conversion_fn: fn([U; WIDTH]) -> [T; WIDTH],
    ) -> Self {
        // `Vec::from` moves the converted array into the vector, so no constant is cloned
        // and `T` needs no `Clone` bound (previously required by `.to_vec()`).
        let initial_consts = Vec::from(initial.map(conversion_fn));
        let terminal_consts = Vec::from(terminal.map(conversion_fn));
        Self::new(initial_consts, terminal_consts)
    }

    /// Returns the constant vectors of the initial external rounds.
    pub const fn get_initial_constants(&self) -> &Vec<[T; WIDTH]> {
        &self.initial
    }

    /// Returns the constant vectors of the terminal external rounds.
    pub const fn get_terminal_constants(&self) -> &Vec<[T; WIDTH]> {
        &self.terminal
    }
}
/// Construction of an external layer from its round constants over the field `F`.
pub trait ExternalLayerConstructor<F: Field, const WIDTH: usize> {
    /// Builds the layer from the given initial and terminal external round constants.
    fn new_from_constants(external_constants: ExternalLayerConstants<F, WIDTH>) -> Self;
}
/// An external layer acting on states of `WIDTH` elements of the ring `R`.
///
/// NOTE(review): `D` is not referenced by either method here; presumably it is the
/// S-box exponent used by implementors — TODO confirm.
pub trait ExternalLayer<R: PrimeCharacteristicRing, const WIDTH: usize, const D: u64>:
    Sync + Clone
{
    /// Applies the initial external rounds to `state` in place.
    fn permute_state_initial(&self, state: &mut [R; WIDTH]);

    /// Applies the terminal external rounds to `state` in place.
    fn permute_state_terminal(&self, state: &mut [R; WIDTH]);
}
/// Runs one external round per entry of `terminal_external_constants` on `state`.
///
/// Each round has two steps:
/// 1. Every state element is passed to `add_rc_and_sbox` together with the round
///    constant at the same index.
/// 2. [`mds_light_permutation`] is applied with `mat4`.
#[inline]
pub fn external_terminal_permute_state<
    R: PrimeCharacteristicRing,
    CT: Copy,
    MdsPerm4: MdsPermutation<R, 4>,
    const WIDTH: usize,
>(
    state: &mut [R; WIDTH],
    terminal_external_constants: &[[CT; WIDTH]],
    add_rc_and_sbox: fn(&mut R, CT),
    mat4: &MdsPerm4,
) {
    for round_constants in terminal_external_constants {
        for (elem, &rc) in state.iter_mut().zip(round_constants) {
            add_rc_and_sbox(elem, rc);
        }
        mds_light_permutation(state, mat4);
    }
}
/// Runs the initial external rounds on `state`.
///
/// First applies [`mds_light_permutation`] once with `mat4`. Then it runs one round per
/// entry of `initial_external_constants`, exactly as [`external_terminal_permute_state`]
/// does: round constants plus S-box via `add_rc_and_sbox`, followed by the light MDS layer.
#[inline]
pub fn external_initial_permute_state<
    R: PrimeCharacteristicRing,
    CT: Copy,
    MdsPerm4: MdsPermutation<R, 4>,
    const WIDTH: usize,
>(
    state: &mut [R; WIDTH],
    initial_external_constants: &[[CT; WIDTH]],
    add_rc_and_sbox: fn(&mut R, CT),
    mat4: &MdsPerm4,
) {
    // The linear layer is applied once up front, before any round.
    mds_light_permutation(state, mat4);
    external_terminal_permute_state(state, initial_external_constants, add_rc_and_sbox, mat4);
}
#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;
    use rand::SeedableRng;
    use rand::rngs::SmallRng;

    use super::*;

    type F = BabyBear;

    #[test]
    fn test_apply_mat4() {
        let mut rng = SmallRng::seed_from_u64(12345678);
        let x0: F = rng.random();
        let x1: F = rng.random();
        let x2: F = rng.random();
        let x3: F = rng.random();
        let mut x = [x0, x1, x2, x3];
        apply_mat4(&mut x);
        let expected = [
            F::TWO * x0 + F::from_u8(3) * x1 + x2 + x3,
            x0 + F::TWO * x1 + F::from_u8(3) * x2 + x3,
            x0 + x1 + F::TWO * x2 + F::from_u8(3) * x3,
            F::from_u8(3) * x0 + x1 + x2 + F::TWO * x3,
        ];
        assert_eq!(x, expected, "apply_mat4 did not produce expected output");
    }

    #[test]
    fn test_apply_hl_mat4_with_manual_verification() {
        let mut rng = SmallRng::seed_from_u64(87654321);
        let x0: F = rng.random();
        let x1: F = rng.random();
        let x2: F = rng.random();
        let x3: F = rng.random();
        let mut x = [x0, x1, x2, x3];
        apply_hl_mat4(&mut x);
        let expected = [
            F::from_u8(5) * x0 + F::from_u8(7) * x1 + x2 + F::from_u8(3) * x3,
            F::from_u8(4) * x0 + F::from_u8(6) * x1 + x2 + x3,
            x0 + F::from_u8(3) * x1 + F::from_u8(5) * x2 + F::from_u8(7) * x3,
            x0 + x1 + F::from_u8(4) * x2 + F::from_u8(6) * x3,
        ];
        assert_eq!(x, expected, "apply_hl_mat4 did not produce expected output");
    }

    #[test]
    fn test_mds_light_permutation_width_2() {
        let mut rng = SmallRng::seed_from_u64(1);
        let [x0, x1]: [F; 2] = core::array::from_fn(|_| rng.random());
        let mut state = [x0, x1];
        mds_light_permutation(&mut state, &MDSMat4);
        assert_eq!(state, [F::TWO * x0 + x1, x0 + F::TWO * x1]);
    }

    #[test]
    fn test_mds_light_permutation_width_3() {
        let mut rng = SmallRng::seed_from_u64(2);
        let [x0, x1, x2]: [F; 3] = core::array::from_fn(|_| rng.random());
        let mut state = [x0, x1, x2];
        mds_light_permutation(&mut state, &MDSMat4);
        let expected = [
            F::TWO * x0 + x1 + x2,
            x0 + F::TWO * x1 + x2,
            x0 + x1 + F::TWO * x2,
        ];
        assert_eq!(state, expected);
    }

    #[test]
    fn test_mds_light_permutation_width_4_is_doubled_mat4() {
        // For a single 4-chunk the outer circulant step doubles every entry, so the
        // whole layer is 2 * M4. This pins that behavior.
        let mut rng = SmallRng::seed_from_u64(3);
        let x: [F; 4] = core::array::from_fn(|_| rng.random());
        let mut state = x;
        mds_light_permutation(&mut state, &MDSMat4);
        let mut m4x = x;
        apply_mat4(&mut m4x);
        assert_eq!(state, m4x.map(|v| F::TWO * v));
    }

    #[test]
    fn test_mds_light_permutation_width_8_is_block_circulant() {
        let mut rng = SmallRng::seed_from_u64(4);
        let x: [F; 8] = core::array::from_fn(|_| rng.random());
        let mut state = x;
        mds_light_permutation(&mut state, &MDSMat4);

        // Expected: circ(2*M4, M4) applied to x.
        let mut y = x;
        let (lo, hi) = y.split_at_mut(4);
        apply_mat4(lo.try_into().unwrap());
        apply_mat4(hi.try_into().unwrap());
        let expected: [F; 8] = core::array::from_fn(|i| {
            if i < 4 {
                F::TWO * y[i] + y[i + 4]
            } else {
                y[i - 4] + F::TWO * y[i]
            }
        });
        assert_eq!(state, expected);
    }

    #[test]
    fn test_external_initial_applies_mds_before_rounds() {
        fn add_rc(s: &mut F, c: F) {
            *s += c;
        }

        let mut rng = SmallRng::seed_from_u64(5);
        let x: [F; 8] = core::array::from_fn(|_| rng.random());
        let rc: [F; 8] = core::array::from_fn(|_| rng.random());

        let mut state = x;
        external_initial_permute_state(&mut state, &[rc], add_rc, &MDSMat4);

        let mut expected = x;
        mds_light_permutation(&mut expected, &MDSMat4);
        expected.iter_mut().zip(rc).for_each(|(e, c)| *e += c);
        mds_light_permutation(&mut expected, &MDSMat4);
        assert_eq!(state, expected);
    }
}