use crate::vdaf::{CodecError, Decode, Encode};
#[cfg(feature = "crypto-dependencies")]
use aes::{
cipher::{KeyIvInit, StreamCipher},
Aes128,
};
#[cfg(feature = "crypto-dependencies")]
use cmac::{Cmac, Mac};
#[cfg(feature = "crypto-dependencies")]
use ctr::Ctr64BE;
#[cfg(feature = "crypto-dependencies")]
use std::fmt::Formatter;
use std::{
fmt::Debug,
io::{Cursor, Read},
};
/// Signature of a randomness source: a function that fills the given byte buffer in place,
/// returning a `getrandom::Error` on failure (e.g. `getrandom::getrandom`).
pub(crate) type RandSource = fn(&mut [u8]) -> Result<(), getrandom::Error>;
/// A seed consisting of exactly `L` bytes.
///
/// NOTE(review): the derived `Debug` prints the raw seed bytes; if seeds are secret key
/// material, consider a redacting `Debug` impl — confirm how callers log seeds.
#[derive(Clone, Debug, Eq)]
pub struct Seed<const L: usize>(pub(crate) [u8; L]);
impl<const L: usize> Seed<L> {
    /// Generate a fresh seed by filling `L` bytes with `getrandom::getrandom`.
    ///
    /// # Errors
    ///
    /// Returns the `getrandom::Error` reported by the randomness source.
    pub fn generate() -> Result<Self, getrandom::Error> {
        Self::from_rand_source(getrandom::getrandom)
    }

    /// Build a seed by asking `rand_source` to fill an `L`-byte buffer.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `rand_source`.
    pub(crate) fn from_rand_source(rand_source: RandSource) -> Result<Self, getrandom::Error> {
        let mut bytes = [0u8; L];
        rand_source(&mut bytes).map(|()| Self(bytes))
    }
}
impl<const L: usize> AsRef<[u8; L]> for Seed<L> {
fn as_ref(&self) -> &[u8; L] {
&self.0
}
}
impl<const L: usize> PartialEq for Seed<L> {
    /// Compare two seeds byte by byte.
    ///
    /// Every byte pair is folded into one accumulator, and there is no early exit on the
    /// first mismatch. NOTE(review): the compiler does not guarantee this is constant-time.
    /// If a hard guarantee is needed, confirm the generated code or use a dedicated
    /// constant-time comparison.
    fn eq(&self, other: &Self) -> bool {
        let diff = self
            .0
            .iter()
            .zip(other.0.iter())
            .fold(0u8, |acc, (a, b)| acc | (a ^ b));
        diff == 0
    }
}
impl<const L: usize> Encode for Seed<L> {
fn encode(&self, bytes: &mut Vec<u8>) {
bytes.extend_from_slice(&self.0[..]);
}
}
impl<const L: usize> Decode for Seed<L> {
fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> {
let mut seed = [0; L];
bytes.read_exact(&mut seed)?;
Ok(Seed(seed))
}
}
/// A source of pseudorandom output bytes produced by a [`Prg`].
pub trait SeedStream {
/// Fill the whole of `buf` with output bytes from the stream.
fn fill(&mut self, buf: &mut [u8]);
}
/// A pseudorandom generator keyed by an `L`-byte seed.
///
/// The usage pattern is: `init` with a seed, feed any number of inputs with `update`, and
/// then convert into a [`SeedStream`] with `into_seed_stream`.
pub trait Prg<const L: usize>: Clone + Debug {
    /// The stream type produced once all input has been fed in.
    type SeedStream: SeedStream;

    /// Create a PRG keyed with `seed_bytes`.
    fn init(seed_bytes: &[u8; L]) -> Self;

    /// Feed `data` into the PRG state.
    fn update(&mut self, data: &[u8]);

    /// Finish feeding input and return the output stream.
    fn into_seed_stream(self) -> Self::SeedStream;

    /// Derive a new seed from the first `L` bytes of this PRG's output stream.
    fn into_seed(self) -> Seed<L> {
        let mut stream = self.into_seed_stream();
        let mut bytes = [0u8; L];
        stream.fill(&mut bytes[..]);
        Seed(bytes)
    }

    /// Shorthand for `init(seed)`, then `update(info)`, then `into_seed_stream()`.
    fn seed_stream(seed: &Seed<L>, info: &[u8]) -> Self::SeedStream {
        let mut state = Self::init(seed.as_ref());
        state.update(info);
        state.into_seed_stream()
    }
}
/// A [`Prg`] built from AES-128.
///
/// An AES-128 CMAC keyed with the seed takes in the input. Its final tag then keys the
/// AES-128-CTR [`SeedStreamAes128`].
#[cfg(feature = "crypto-dependencies")]
#[derive(Clone, Debug)]
pub struct PrgAes128(Cmac<Aes128>);
#[cfg(feature = "crypto-dependencies")]
impl Prg<16> for PrgAes128 {
    type SeedStream = SeedStreamAes128;

    /// Key an AES-128 CMAC with the 16-byte seed.
    fn init(seed_bytes: &[u8; 16]) -> Self {
        // The seed is statically 16 bytes, exactly the AES-128 key length, so key setup
        // cannot fail. `expect` records that invariant instead of a bare `unwrap`.
        Self(Cmac::new_from_slice(seed_bytes).expect("AES-128 CMAC key must be 16 bytes"))
    }

    /// Feed `data` into the CMAC state.
    fn update(&mut self, data: &[u8]) {
        self.0.update(data);
    }

    /// Finalize the CMAC and use its tag as the AES-128-CTR key, with an all-zero IV.
    fn into_seed_stream(self) -> SeedStreamAes128 {
        let key = self.0.finalize().into_bytes();
        SeedStreamAes128::new(&key, &[0; 16])
    }
}
#[cfg(feature = "crypto-dependencies")]
/// A [`SeedStream`] whose output is the keystream of AES-128 in CTR mode (`Ctr64BE<Aes128>`).
pub struct SeedStreamAes128(Ctr64BE<Aes128>);
#[cfg(feature = "crypto-dependencies")]
impl SeedStreamAes128 {
    /// Build an AES-128-CTR (`Ctr64BE`) keystream from `key` and `iv`.
    ///
    /// NOTE(review): `key` and `iv` are converted to fixed-size arrays with `into()`.
    /// Presumably this panics unless both are exactly 16 bytes — confirm against the
    /// slice-to-array conversion used by the cipher crate.
    pub(crate) fn new(key: &[u8], iv: &[u8]) -> Self {
        let cipher = Ctr64BE::<Aes128>::new(key.into(), iv.into());
        Self(cipher)
    }
}
#[cfg(feature = "crypto-dependencies")]
impl SeedStream for SeedStreamAes128 {
    /// Overwrite `buf` with the next `buf.len()` keystream bytes.
    fn fill(&mut self, buf: &mut [u8]) {
        // `apply_keystream` XORs the keystream into the buffer, so clearing it first makes
        // the output the raw keystream, independent of what `buf` held before.
        buf.iter_mut().for_each(|byte| *byte = 0);
        self.0.apply_keystream(buf);
    }
}
#[cfg(feature = "crypto-dependencies")]
impl Debug for SeedStreamAes128 {
    /// Delegate to the `Debug` output of the underlying CTR cipher core.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let core = self.0.get_core();
        Debug::fmt(core, f)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{field::Field128, prng::Prng};
    use serde::{Deserialize, Serialize};
    use std::convert::TryInto;

    /// One PRG known-answer test vector as stored in the JSON fixture (byte fields hex-encoded).
    #[derive(Deserialize, Serialize)]
    struct PrgTestVector {
        #[serde(with = "hex")]
        seed: Vec<u8>,
        #[serde(with = "hex")]
        info: Vec<u8>,
        length: usize,
        #[serde(with = "hex")]
        derived_seed: Vec<u8>,
        #[serde(with = "hex")]
        expanded_vec_field128: Vec<u8>,
    }

    /// Check that the provided `Prg` methods agree with driving the PRG by hand.
    fn test_prg<P, const L: usize>()
    where
        P: Prg<L>,
    {
        let seed = Seed::generate().unwrap();
        let info = b"info string";

        let mut prg = P::init(seed.as_ref());
        prg.update(info);

        // `into_seed` must equal the first L bytes of the seed stream.
        let mut expected_seed = Seed([0; L]);
        prg.clone().into_seed_stream().fill(&mut expected_seed.0[..]);
        assert_eq!(prg.clone().into_seed(), expected_seed);

        // `seed_stream` must match init + update + into_seed_stream.
        let mut expected_stream = [0; 45];
        prg.clone().into_seed_stream().fill(&mut expected_stream);
        let mut actual_stream = [0; 45];
        P::seed_stream(&seed, info).fill(&mut actual_stream);
        assert_eq!(actual_stream, expected_stream);
    }

    #[test]
    fn prg_aes128() {
        let t: PrgTestVector =
            serde_json::from_str(include_str!("test_vec/03/PrgAes128.json")).unwrap();
        let mut prg = PrgAes128::init(&t.seed.try_into().unwrap());
        prg.update(&t.info);

        assert_eq!(
            prg.clone().into_seed(),
            Seed(t.derived_seed.try_into().unwrap())
        );

        // The fixture stores the expected field elements back to back; decode until exhausted.
        let mut reader = Cursor::new(t.expanded_vec_field128.as_slice());
        let mut want = Vec::with_capacity(t.length);
        while (reader.position() as usize) < t.expanded_vec_field128.len() {
            want.push(Field128::decode(&mut reader).unwrap());
        }

        let got: Vec<Field128> = Prng::from_seed_stream(prg.clone().into_seed_stream())
            .take(t.length)
            .collect();
        assert_eq!(got, want);

        test_prg::<PrgAes128, 16>();
    }
}