#![feature(trusted_len, generator_trait, never_type,)]
use hkdf::Hkdf;
use digest::{Input, BlockInput, FixedOutput, Reset,};
use rand_core::{RngCore, SeedableRng, CryptoRng, Error,};
use clear_on_drop::ClearOnDrop;
use std::{ops, pin::Pin, iter::TrustedLen, marker::{PhantomData, Unpin,},};
pub use digest;
pub use digest::generic_array;
pub use generic_array::typenum;
use typenum::{Unsigned, Add1, Diff, NonZero, bit::B1, consts,};
use generic_array::{GenericArray, ArrayLength,};
#[cfg(feature = "serde")]
mod serde;
/// Byte generator that repeatedly re-derives a secret state with HKDF.
///
/// - `Digest`: hash function HKDF is instantiated with in [`Ratchet::next`].
/// - `State`: number of secret state bytes (a `typenum` length).
/// - `Rounds`: HKDF steps performed per output byte; defaults to one.
pub struct Ratchet<Digest, State, Rounds = consts::U1,>
where State: ArrayLength<u8>, {
/// Secret state bytes, heap-allocated and wrapped in `ClearOnDrop` so they are
/// wiped when the ratchet is dropped.
/// NOTE(review): boxing presumably keeps moves of a `Ratchet` from leaving stray
/// copies of the state behind on the stack — confirm intent.
state: ClearOnDrop<Box<GenericArray<u8, State>>>,
/// Records the `Digest` and `Rounds` type parameters, which carry no runtime data.
_data: PhantomData<(Digest, Rounds,)>,
}
impl<D, S, R> Ratchet<D, S, R>
where
    S: ArrayLength<u8>,
{
    /// Creates a ratchet whose entire state is filled from `rand`.
    #[inline]
    pub fn new<Rand>(rand: &mut Rand) -> Self
    where
        Rand: RngCore + CryptoRng,
    {
        let mut ratchet: Self = Default::default();
        ratchet.reseed(rand);
        ratchet
    }

    /// Overwrites every state byte with fresh output from `rand`.
    ///
    /// The previous state is replaced, not mixed with the new randomness.
    #[inline]
    pub fn reseed<Rand>(&mut self, rand: &mut Rand)
    where
        Rand: RngCore + CryptoRng,
    {
        let buffer: &mut [u8] = &mut self.state;
        rand.fill_bytes(buffer)
    }
}
impl<D, S, R> Ratchet<D, S, R>
where
    D: Input + BlockInput + FixedOutput + Reset + Default + Clone,
    S: ArrayLength<u8> + ops::Sub<D::BlockSize> + ops::Add<B1> + ops::Sub<B1>,
    R: Unsigned + NonZero,
    D::BlockSize: Clone,
    <S as ops::Sub<D::BlockSize>>::Output: Unsigned,
    <S as ops::Add<B1>>::Output: ArrayLength<u8>,
    // Not used by `next` itself; kept so these bounds match the `Iterator`
    // impl, which delegates to this method.
    <S as ops::Sub<B1>>::Output: Unsigned,
{
    /// Advances the ratchet `R` times and returns one output byte.
    ///
    /// Each round splits the state into a prefix of `S - D::BlockSize` bytes and
    /// a suffix of `D::BlockSize` bytes. It runs HKDF-extract (no salt) on the
    /// suffix and HKDF-expands `S + 1` bytes using the prefix as `info`. The first
    /// `S` expanded bytes become the new state.
    ///
    /// The returned byte is the extra expanded byte at index `S` from the final
    /// round. That byte is never written into the state.
    ///
    /// # Panics
    ///
    /// Panics with "Failed to expand data" if HKDF-expand returns an error,
    /// e.g. when `S + 1` exceeds the output length HKDF permits for `D`.
    pub fn next(&mut self) -> u8 {
        let mut okm = GenericArray::<u8, Add1<S>>::default();
        // Wipe the expanded key material (it contains the new state) on return.
        let mut okm = ClearOnDrop::new(okm.as_mut());
        for _ in 0..R::USIZE {
            let (info, ikm) = self.state.split_at(Diff::<S, D::BlockSize>::USIZE);
            Hkdf::<D>::extract(None, ikm)
                .expand(info, &mut okm)
                .expect("Failed to expand data");
            self.state.copy_from_slice(&okm[..S::USIZE]);
        }
        // Index `S` is the only expanded byte not copied into the state. Any
        // lower index would disclose a byte of the new state, which feeds the
        // next round's HKDF-extract.
        okm[S::USIZE]
    }
}
impl<D, S, R, Rand> From<&mut Rand> for Ratchet<D, S, R>
where
    S: ArrayLength<u8>,
    Rand: RngCore + CryptoRng,
{
    /// Same as [`Ratchet::new`]: builds a ratchet seeded from `rand`.
    #[inline]
    fn from(rand: &mut Rand) -> Self {
        Self::new(rand)
    }
}
impl<'a, D, S, R> From<&'a mut [u8]> for Ratchet<D, S, R>
where
    S: ArrayLength<u8>,
{
    /// Builds a ratchet whose state is copied from `state`, then wipes `state`.
    ///
    /// Only the first `min(state.len(), S)` bytes are copied. If `state` is
    /// shorter than `S`, the remaining state bytes stay zero. The whole input
    /// slice is cleared when this function returns.
    fn from(state: &'a mut [u8]) -> Self {
        // Wrap the caller's buffer so it is wiped when `source` goes out of scope.
        let source = ClearOnDrop::new(state);
        let mut ratchet = Self::default();
        {
            let dst: &mut [u8] = &mut ratchet.state;
            let src: &[u8] = &source;
            let len = dst.len().min(src.len());
            dst[..len].copy_from_slice(&src[..len]);
        }
        ratchet
    }
}
impl<D, S, R> Default for Ratchet<D, S, R>
where
    S: ArrayLength<u8>,
{
    /// Returns a ratchet whose state is all zeros.
    ///
    /// An all-zero state is not secret; seed it (for example with
    /// [`Ratchet::reseed`]) before relying on its output.
    #[inline]
    fn default() -> Self {
        Self {
            state: ClearOnDrop::new(Box::default()),
            _data: PhantomData,
        }
    }
}
impl<D, S, R> Clone for Ratchet<D, S, R>
where
    S: ArrayLength<u8>,
{
    /// Produces an independent ratchet holding a copy of this state.
    ///
    /// The copy lives in its own cleared-on-drop heap buffer.
    #[inline]
    fn clone(&self) -> Self {
        let mut copy = Self::default();
        let src: &[u8] = &self.state;
        let dst: &mut [u8] = &mut copy.state;
        dst.copy_from_slice(src);
        copy
    }
}
impl<D, S, R> Iterator for Ratchet<D, S, R>
where
    D: Input + BlockInput + FixedOutput + Reset + Default + Clone,
    S: ArrayLength<u8> + ops::Sub<D::BlockSize> + ops::Add<B1> + ops::Sub<B1>,
    R: Unsigned + NonZero,
    D::BlockSize: Clone,
    <S as ops::Sub<D::BlockSize>>::Output: Unsigned,
    <S as ops::Add<B1>>::Output: ArrayLength<u8>,
    <S as ops::Sub<B1>>::Output: Unsigned,
{
    type Item = u8;

    /// Reports an unbounded iterator: the lower bound is saturated at
    /// `usize::MAX` and there is no upper bound.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let lower = std::usize::MAX;
        let upper = None;
        (lower, upper)
    }

    /// Produces the next output byte. The stream never ends, so this is always `Some`.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        // Inherent methods win over trait methods in method resolution, so this
        // calls the inherent `Ratchet::next` (which returns a plain `u8`) rather
        // than recursing. The `u8` annotation would fail to compile otherwise.
        let byte: u8 = self.next();
        Some(byte)
    }
}
// SAFETY: `TrustedLen` requires `size_hint` to be exact, or to report
// `(usize::MAX, None)` when the iterator yields more than `usize::MAX` items.
// The `Iterator` impl in this file reports `(usize::MAX, None)` and its `next`
// always returns `Some`, so the iterator is infinite and the hint is accurate.
unsafe impl<D, S, R,> TrustedLen for Ratchet<D, S, R,>
where S: ArrayLength<u8>,
Self: Iterator<Item = u8>, {}
impl<D, S, R,> RngCore for Ratchet<D, S, R,>
where S: ArrayLength<u8>,
Self: TrustedLen<Item = u8>, {
#[inline]
fn next_u32(&mut self,) -> u32 { self.next_u64() as u32 }
#[inline]
fn next_u64(&mut self,) -> u64 {
let mut bytes = [0; 8];
self.fill_bytes(bytes.as_mut(),);
u64::from_ne_bytes(bytes,)
}
#[inline]
fn fill_bytes(&mut self, dest: &mut [u8],) {
for (a, b,) in dest.iter_mut().zip(self,) { *a = b }
}
#[inline]
fn try_fill_bytes(&mut self, dest: &mut [u8],) -> Result<(), Error> {
Ok(self.fill_bytes(dest,))
}
}
impl<D, S, R> SeedableRng for Ratchet<D, S, R>
where
    S: ArrayLength<u8>,
    Self: TrustedLen<Item = u8>,
{
    /// A seed is exactly one full ratchet state.
    type Seed = GenericArray<u8, S>;

    /// Uses `seed` verbatim as the initial state.
    ///
    /// This function's by-value copy of the seed is wiped afterwards by the
    /// `From<&mut [u8]>` conversion.
    #[inline]
    fn from_seed(mut seed: Self::Seed) -> Self {
        let bytes: &mut [u8] = &mut seed;
        Self::from(bytes)
    }
}
// Marker impl declaring the ratchet's output suitable for cryptographic use.
// NOTE(review): this only holds for a secretly seeded state; a ratchet built
// via `Default` starts from an all-zero, publicly known state.
impl<D, S, R,> CryptoRng for Ratchet<D, S, R,>
where S: ArrayLength<u8>,
Self: TrustedLen<Item = u8>, {}
impl<D, S, R> ops::Generator for Ratchet<D, S, R>
where
    S: ArrayLength<u8>,
    Self: TrustedLen<Item = u8> + Unpin,
{
    type Yield = u8;
    /// The ratchet never completes.
    type Return = !;

    /// Yields the next output byte; never returns `Complete`.
    #[inline]
    fn resume(self: Pin<&mut Self>) -> ops::GeneratorState<Self::Yield, Self::Return> {
        let mut buf = [0u8; 1];
        let this: &mut Self = Pin::get_mut(self);
        this.fill_bytes(&mut buf);
        ops::GeneratorState::Yielded(buf[0])
    }
}
#[cfg(test)]
impl<D, S, R> PartialEq for Ratchet<D, S, R>
where
    S: ArrayLength<u8>,
{
    /// Test-only: two ratchets are equal when their state bytes are equal.
    #[inline]
    fn eq(&self, rhs: &Self) -> bool {
        let lhs_bytes: &[u8] = &self.state;
        let rhs_bytes: &[u8] = &rhs.state;
        lhs_bytes == rhs_bytes
    }
}
#[cfg(test)]
impl<D, S, R> Eq for Ratchet<D, S, R> where S: ArrayLength<u8> {}
// Intentionally empty destructor. The secret state is wiped by the
// `ClearOnDrop` wrapper on the `state` field, not here.
// NOTE(review): presumably kept so the fields cannot be moved out of a
// `Ratchet` by destructuring (Rust forbids moving out of types that implement
// `Drop`) — confirm intent.
impl<D, S, R,> Drop for Ratchet<D, S, R,>
where S: ArrayLength<u8>, {
#[inline]
fn drop(&mut self,) {}
}
// Unit tests; they use the external `sha1` and `rand` crates.
#[cfg(test,)]
mod tests {
use super::*;
use sha1::Sha1;
// Checks that building from a byte slice wipes the caller's buffer, and that
// dropping the ratchet wipes its internal state.
#[test]
fn test_ratchet_drop() {
use std::{slice, mem,};
let mut bytes = {
let mut bytes = GenericArray::<u8, consts::U500>::default();
rand::thread_rng().fill_bytes(&mut bytes,);
bytes
};
// `From<&mut [u8]>` copies the first 200 of the 500 bytes, then clears all 500.
let ratchet = Ratchet::<Sha1, consts::U200,>::from(bytes.as_mut(),);
assert_eq!(vec![0; bytes.len()], bytes.as_ref(), "Input bytes was not cleared",);
let slice = {
let ptr = &*ratchet.state as *const _ as *const u8;
let size = ratchet.state.len();
unsafe { slice::from_raw_parts(ptr, size,) }
};
mem::drop(ratchet,);
// NOTE(review): `mem::drop` frees the boxed state, so `slice` now points at
// deallocated memory and reading it is undefined behavior (use-after-free).
// The assertion may pass or fail depending on whether the allocator has
// reused that memory; this check should be replaced by one that does not
// read after free.
assert_eq!(vec![0; slice.len()], slice, "Inner state was not cleared",);
}
// Statistical and determinism checks on the output byte stream.
#[test]
fn test_ratchet_output() {
use std::collections::HashSet;
const ROUNDS: usize = 3000;
// NOTE(review): assuming uniform output, the chance that at least one of the
// 256 byte values is absent from 3000 samples is roughly
// 256 * (255/256)^3000 ≈ 0.2%, so this assertion can fail spuriously.
let bytes = Ratchet::<Sha1, consts::U64,>::new(&mut rand::thread_rng(),)
.take(ROUNDS,)
.collect::<HashSet<_>>();
assert_eq!(256, bytes.len(), "Ratchet is not random",);
// A ratchet built from a copy of another's state must emit the same stream
// and end in the same state.
let mut ratchet1 = Ratchet::<Sha1, consts::U64,>::new(&mut rand::thread_rng(),);
let mut ratchet2 = Ratchet::<Sha1, consts::U64,>::from((*ratchet1.state.clone()).as_mut(),);
let out1 = (&mut ratchet1).take(ROUNDS,).collect::<Box<_>>();
let out2 = (&mut ratchet2).take(ROUNDS,).collect::<Box<_>>();
assert_eq!(&ratchet1.state, &ratchet2.state, "States are not the same",);
assert_eq!(out1, out2, "Ratchets gave different output.",);
// An independently seeded ratchet should produce a different stream.
let ratchet2 = Ratchet::<Sha1, consts::U64,>::new(&mut rand::thread_rng(),);
let out1 = (&mut ratchet1).take(ROUNDS,).collect::<Box<_>>();
let out2 = ratchet2.take(ROUNDS,).collect::<Box<_>>();
assert_ne!(out1, out2, "Ratchets gave same output.",);
}
}