use core::marker::PhantomData;
use bitvec::prelude::*;
use group::ff::{Field, FromUniformBytes, PrimeField};
/// Width of the LFSR register in bits.
///
/// This value sizes the register, and `Grain::next_bit == STATE` marks "no
/// buffered bits left".
const STATE: usize = 80;
/// Kind of field the generator is parameterised for.
///
/// Encoded as the 2-bit tag at the start of the Grain initial state.
#[derive(Debug, Clone, Copy)]
pub(super) enum FieldType {
    /// Binary field (presumably GF(2^n)); not currently constructed.
    #[allow(dead_code)]
    Binary = 0,
    /// Prime-order field.
    PrimeOrder = 1,
}

impl FieldType {
    /// Numeric tag written into the initial state; equals the explicit
    /// discriminant declared on each variant.
    fn tag(&self) -> u8 {
        *self as u8
    }
}
/// Kind of S-box the generator is parameterised for.
///
/// Encoded as a 4-bit tag in the Grain initial state.
#[derive(Debug, Clone, Copy)]
pub(super) enum SboxType {
    /// Power-map S-box.
    Pow = 0,
    /// Inversion S-box; not currently constructed.
    #[allow(dead_code)]
    Inv = 1,
}

impl SboxType {
    /// Numeric tag written into the initial state; equals the explicit
    /// discriminant declared on each variant.
    fn tag(&self) -> u8 {
        *self as u8
    }
}
/// Bit generator built on a `STATE`-bit LFSR.
///
/// Raw register output is filtered by this type's `Iterator` implementation
/// before being turned into field elements.
pub(super) struct Grain<F: Field> {
    /// The LFSR register. It is sized by `STATE` so it cannot drift from the
    /// `bitarr!` the constructor builds.
    state: BitArr!(for STATE, in u8, Msb0),
    /// Index in `state` of the next buffered bit to hand out.
    ///
    /// Equal to `STATE` once every buffered bit has been consumed and the
    /// register must be clocked again.
    next_bit: usize,
    /// Ties the generator to the field whose elements it produces.
    _field: PhantomData<F>,
}
impl<F: PrimeField> Grain<F> {
    /// Creates a generator seeded from the given parameters.
    ///
    /// The `STATE`-bit register starts as all ones. The fields below are then
    /// written over its prefix, each most-significant bit first. Only the low
    /// `len` bits of each value are used.
    ///
    /// | bits    | contents                              |
    /// |---------|---------------------------------------|
    /// | 0..2    | field-type tag (always prime order)   |
    /// | 2..6    | `sbox` tag                            |
    /// | 6..18   | `F::NUM_BITS`                         |
    /// | 18..30  | `t`                                   |
    /// | 30..40  | `r_f`                                 |
    /// | 40..50  | `r_p`                                 |
    /// | 50..80  | left as ones                          |
    ///
    /// The register is then clocked 160 times and that output is discarded.
    ///
    /// NOTE(review): this layout presumably mirrors the Poseidon reference
    /// parameter generator. Every derived constant depends on it, so confirm
    /// against that reference before changing anything here.
    pub(super) fn new(sbox: SboxType, t: u16, r_f: u16, r_p: u16) -> Self {
        let mut state = bitarr![u8, Msb0; 1; STATE];
        // Writes the low `len` bits of `value` into `state[offset..offset + len]`,
        // with the most-significant of those bits at `offset`.
        let mut set_bits = |offset: usize, len, value| {
            for i in 0..len {
                state.set(offset + len - 1 - i, (value >> i) & 1 != 0);
            }
        };
        set_bits(0, 2, FieldType::PrimeOrder.tag() as u16);
        set_bits(2, 4, sbox.tag() as u16);
        // `NUM_BITS` is cast to `u16`; only its low 12 bits are encoded.
        set_bits(6, 12, F::NUM_BITS as u16);
        set_bits(18, 12, t);
        set_bits(30, 10, r_f);
        set_bits(40, 10, r_p);

        let mut grain = Grain {
            state,
            next_bit: STATE,
            _field: PhantomData,
        };

        // Warm-up: clock 20 * 8 = 160 bits and throw them away. Resetting
        // `next_bit` marks each freshly loaded byte as already consumed.
        for _ in 0..20 {
            grain.load_next_8_bits();
            grain.next_bit = STATE;
        }

        grain
    }

    /// Clocks the LFSR eight times.
    ///
    /// The register is shifted left by eight and the eight new bits are written
    /// at `next_bit - 8 .. next_bit`. `next_bit` is then moved back by eight so
    /// those bits are read next.
    ///
    /// Must only be called with `next_bit >= 8`; in practice it is called only
    /// when `next_bit == STATE`. Otherwise the subtraction underflows.
    fn load_next_8_bits(&mut self) {
        // New bit `i` is the XOR of the taps at offsets 0, 13, 23, 38, 51 and 62
        // from position `i`. The highest tap read is 7 + 62 = 69 < STATE, so no
        // new bit depends on another new bit. All eight can therefore be
        // computed from the current register before it is shifted.
        let mut new_bits = 0u8;
        for i in 0..8 {
            new_bits |= ((self.state[i + 62]
                ^ self.state[i + 51]
                ^ self.state[i + 38]
                ^ self.state[i + 23]
                ^ self.state[i + 13]
                ^ self.state[i]) as u8)
                << i;
        }
        // Shift left by eight. The bits that wrap around to the tail are
        // overwritten with the new bits just below.
        self.state.rotate_left(8);
        self.next_bit -= 8;
        for i in 0..8 {
            self.state.set(self.next_bit + i, (new_bits >> i) & 1 != 0);
        }
    }

    /// Returns the next raw (unfiltered) LFSR bit.
    ///
    /// Clocks the register first when the buffered bits are exhausted.
    fn get_next_bit(&mut self) -> bool {
        if self.next_bit == STATE {
            self.load_next_8_bits();
        }
        let ret = self.state[self.next_bit];
        self.next_bit += 1;
        ret
    }

    /// Returns the next field element, using rejection sampling.
    ///
    /// Draws `F::NUM_BITS` bits from the filtered bit stream and interprets
    /// them as an integer whose first drawn bit is the most significant. It
    /// retries until `F::from_repr_vartime` accepts the resulting encoding.
    pub(super) fn next_field_element(&mut self) -> F {
        loop {
            let mut bytes = F::Repr::default();
            // Integer bit `p` is stored at bit `p % 8` of byte `p / 8`.
            // NOTE(review): this assumes `F::Repr` is a little-endian encoding;
            // confirm for each field this is used with.
            let view = bytes.as_mut();
            for (i, bit) in self.take(F::NUM_BITS as usize).enumerate() {
                // Mirror the index so the first bit drawn lands in the top position.
                let i = F::NUM_BITS as usize - 1 - i;
                view[i / 8] |= u8::from(bit) << (i % 8);
            }
            if let Some(f) = F::from_repr_vartime(bytes) {
                break f;
            }
        }
    }
}
impl<F: FromUniformBytes<64>> Grain<F> {
    /// Returns the next field element without rejection sampling.
    ///
    /// The procedure is:
    /// - Read the next `F::NUM_BITS` filtered bits as an integer whose first
    ///   bit is the most significant.
    /// - Store it little-endian in a zero-padded 64-byte buffer.
    /// - Pass the buffer to `F::from_uniform_bytes`. Per the `ff` docs, that
    ///   call maps the wide value into the field.
    ///
    /// # Panics
    ///
    /// Panics if `F::NUM_BITS` exceeds 512, because the bits would not fit in
    /// the 64-byte buffer.
    pub(super) fn next_field_element_without_rejection(&mut self) -> F {
        let width = F::NUM_BITS as usize;
        let mut wide = [0u8; 64];
        for (pos, bit) in self.take(width).enumerate() {
            // Mirror the position so the first bit drawn becomes the top bit.
            let target = width - 1 - pos;
            wide[target / 8] |= u8::from(bit) << (target % 8);
        }
        F::from_uniform_bytes(&wide)
    }
}
impl<F: PrimeField> Iterator for Grain<F> {
    type Item = bool;

    /// Produces the next bit of the self-shrinking output stream.
    ///
    /// Raw LFSR bits are consumed in `(selector, candidate)` pairs. The
    /// candidate is emitted when the selector is set and discarded otherwise.
    /// The stream never ends, so this always returns `Some`.
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            let selector = self.get_next_bit();
            let candidate = self.get_next_bit();
            if selector {
                return Some(candidate);
            }
        }
    }
}