use ff::PrimeField;
/// Generates round constants by rejection-sampling field elements from a `Grain` LFSR seeded
/// with an 80-bit encoding of the parameters.
///
/// NOTE(review): the seed layout and LFSR appear to follow the round-constant procedure of the
/// Poseidon paper's reference script — confirm against the spec before changing anything here.
///
/// # Parameters
/// - `field`: field-type tag, encoded in 2 seed bits. Only `1` is accepted.
/// - `sbox`: S-box tag, encoded in 4 seed bits.
/// - `field_size`: bit length of the field encoding (12 seed bits); also sizes each sample.
/// - `t`, `r_f`, `r_p`: presumably state width and full/partial round counts (12, 10 and 10
///   seed bits). Besides seeding, they only determine the output length `(r_f + r_p) * t`.
///
/// # Returns
/// `(r_f + r_p) * t` field elements. Each one is drawn by filling an `F::Repr` from the LFSR,
/// byte-reversing it, and retrying until `F::from_repr_vartime` accepts it.
///
/// # Panics
/// - If `F::Repr` is not 32 bytes long (`unimplemented!`).
/// - If `ceil(field_size / 8)` differs from that byte length.
/// - If `field != 1`.
pub(crate) fn generate_constants<F: PrimeField>(
    field: u8,
    sbox: u8,
    field_size: u16,
    t: u16,
    r_f: u16,
    r_p: u16,
) -> Vec<F> {
    let n_bytes = F::Repr::default().as_ref().len();
    if n_bytes != 32 {
        unimplemented!("neptune currently supports 32-byte fields exclusively");
    }
    assert_eq!((f32::from(field_size) / 8.0).ceil() as usize, n_bytes);

    // Computed in `usize`: the seed admits t < 2^12 and r_f, r_p < 2^10, so the product can
    // exceed `u16::MAX` (an overflow panic in debug builds, silent wrap-around in release).
    let num_constants = (usize::from(r_f) + usize::from(r_p)) * usize::from(t);

    // 80-bit seed: 2 + 4 + 12 + 12 + 10 + 10 + 30 (all-ones padding); `Grain::new` checks 80.
    let mut init_sequence: Vec<bool> = Vec::with_capacity(80);
    append_bits(&mut init_sequence, 2, field);
    append_bits(&mut init_sequence, 4, sbox);
    append_bits(&mut init_sequence, 12, field_size);
    append_bits(&mut init_sequence, 12, t);
    append_bits(&mut init_sequence, 10, r_f);
    append_bits(&mut init_sequence, 10, r_p);
    append_bits(&mut init_sequence, 30, 0b111111111111111111111111111111u128);

    let mut grain = Grain::new(init_sequence, field_size);
    match field {
        1 => {
            let mut round_constants: Vec<F> = Vec::with_capacity(num_constants);
            for _ in 0..num_constants {
                // Rejection sampling: keep drawing until the bytes decode to a valid element.
                loop {
                    let mut repr = F::Repr::default();
                    grain.get_next_bytes(repr.as_mut());
                    // Grain fills the buffer most-significant byte first; reversing it assumes
                    // `F::Repr` is little-endian. NOTE(review): confirm for every supported field.
                    repr.as_mut().reverse();
                    if let Some(f) = F::from_repr_vartime(repr) {
                        round_constants.push(f);
                        break;
                    }
                }
            }
            round_constants
        }
        _ => panic!("Only prime fields are supported."),
    }
}
/// Appends the `n` least-significant bits of `from` to `vec`, most-significant bit first.
///
/// Bits of `from` at positions `>= n` are ignored. If `n` exceeds the 128-bit width of the
/// intermediate `u128`, the extra leading positions are `false` (the value is zero-extended)
/// rather than triggering a shift overflow.
fn append_bits<T: Into<u128>>(vec: &mut Vec<bool>, n: usize, from: T) {
    let val: u128 = from.into();
    vec.extend((0..n).rev().map(|i| {
        // `checked_shr` returns `None` for shifts >= 128, i.e. bits beyond the value's width.
        let shift = u32::try_from(i).unwrap_or(u32::MAX);
        val.checked_shr(shift).unwrap_or(0) & 1 == 1
    }));
}
/// LFSR bit generator over an 80-bit state.
///
/// Each step computes the XOR of state positions 62, 51, 38, 23, 13 and 0, drops the oldest bit
/// (position 0) and appends the computed bit. The `Iterator` impl filters these raw bits with
/// a pairwise selection rule (see `next`).
///
/// NOTE(review): this appears to be the Grain LFSR used by the Poseidon reference
/// round-constant script — confirm against the spec.
struct Grain {
    /// The 80 LFSR bits; index 0 is the oldest. A `VecDeque` makes each shift O(1) instead of
    /// the O(n) `Vec::remove(0)`.
    state: std::collections::VecDeque<bool>,
    /// Bit length that decides how many bits the first byte of `get_next_bytes` carries.
    field_size: u16,
}

impl Grain {
    /// Builds the generator from an 80-bit seed and discards the first 160 raw LFSR bits.
    ///
    /// # Panics
    /// If `init_sequence` does not contain exactly 80 bits.
    fn new(init_sequence: Vec<bool>, field_size: u16) -> Self {
        assert_eq!(80, init_sequence.len());
        let mut g = Grain {
            state: init_sequence.into(),
            field_size,
        };
        // Warm-up: the first 160 raw bits are thrown away.
        for _ in 0..160 {
            g.generate_new_bit();
        }
        assert_eq!(80, g.state.len());
        g
    }

    /// Advances the LFSR by one step and returns the newly shifted-in (unfiltered) bit.
    fn generate_new_bit(&mut self) -> bool {
        let new_bit =
            self.bit(62) ^ self.bit(51) ^ self.bit(38) ^ self.bit(23) ^ self.bit(13) ^ self.bit(0);
        self.state.pop_front();
        self.state.push_back(new_bit);
        new_bit
    }

    /// Returns the state bit at `index` (0 = oldest); panics if `index` is out of range.
    fn bit(&self, index: usize) -> bool {
        self.state[index]
    }

    /// Packs the next `bit_count` filtered output bits into a byte, most-significant first.
    ///
    /// `bit_count` must be at most 8; with fewer bits the result occupies the low bits.
    fn next_byte(&mut self, bit_count: usize) -> u8 {
        debug_assert!(bit_count <= 8, "a byte holds at most 8 bits");
        self.by_ref()
            .take(bit_count)
            .fold(0u8, |acc, bit| (acc << 1) | u8::from(bit))
    }

    /// Fills `result` with filtered output bytes, most-significant byte first.
    ///
    /// The first byte carries only `field_size % 8` bits when that is non-zero. Every other
    /// byte carries 8 bits. So a buffer of `ceil(field_size / 8)` bytes holds exactly
    /// `field_size` bits. An empty `result` consumes nothing.
    fn get_next_bytes(&mut self, result: &mut [u8]) {
        if let Some((first, rest)) = result.split_first_mut() {
            let remainder_bits = usize::from(self.field_size) % 8;
            let first_bits = if remainder_bits == 0 { 8 } else { remainder_bits };
            *first = self.next_byte(first_bits);
            for byte in rest {
                *byte = self.next_byte(8);
            }
        }
    }
}

impl Iterator for Grain {
    type Item = bool;

    /// Returns the next filtered output bit; the stream is infinite, so this is never `None`.
    ///
    /// Raw LFSR bits are consumed in pairs `(select, candidate)`. When `select` is 1 the
    /// candidate is emitted; otherwise both are discarded and the next pair is examined.
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            let select = self.generate_new_bit();
            let candidate = self.generate_new_bit();
            if select {
                return Some(candidate);
            }
        }
    }
}