use core::hash::Hasher;
use rustc_hash::FxHasher;
/// A keyed permutation of the integers in `0..max`, usable both as a random
/// access mapping (`forward` / `backward`) and as an iterator that yields each
/// value of the range once.
#[derive(Debug, Clone)]
pub struct Permutor {
    /// Feistel network that performs the keyed mixing of each value.
    feistel: FeistelNetwork,
    /// Exclusive upper bound of the permuted range.
    pub max: u128,
    /// Number of values the iterator has produced so far; it is also the next
    /// input handed to the permutation by `Iterator::next`.
    pub values_returned: u128,
}
impl Permutor {
pub fn new_with_u64_key(max: u128, key: u64) -> Permutor {
let key = key ^ 0xDEADBEEF_FEE1DEAD;
let key = u64_to_32slice(key);
Permutor {
feistel: FeistelNetwork::new_with_slice_key(max, key),
max,
values_returned: 0,
}
}
pub fn new_with_slice_key(max: u128, key: [u8; 32]) -> Permutor {
Permutor {
feistel: FeistelNetwork::new_with_slice_key(max, key),
max,
values_returned: 0,
}
}
pub fn forward(&self, plaintext: u128) -> u128 {
let mut result = self.feistel.permute(plaintext);
while result >= self.max {
result = self.feistel.permute(result);
}
result
}
pub fn backward(&self, ciphertext: u128) -> u128 {
let mut result = self.feistel.invert(ciphertext);
while result >= self.max {
result = self.feistel.invert(result);
}
result
}
}
impl Iterator for Permutor {
    type Item = u128;

    /// Returns the permuted image of the next counter value, or `None` once
    /// `max` values have been produced.
    fn next(&mut self) -> Option<Self::Item> {
        if self.values_returned >= self.max {
            return None;
        }
        let value = self.forward(self.values_returned);
        self.values_returned += 1;
        Some(value)
    }

    /// Reports how many values are still to come.
    ///
    /// The count is exact whenever it fits in `usize`. When it does not, the
    /// hint is `(usize::MAX, None)`, the standard-library convention for
    /// iterators longer than `usize` can express.
    fn size_hint(&self) -> (usize, Option<usize>) {
        // `values_returned` is a public field, so guard against it exceeding `max`.
        let remaining = self.max.saturating_sub(self.values_returned);
        if remaining > usize::MAX as u128 {
            (usize::MAX, None)
        } else {
            let exact = remaining as usize;
            (exact, Some(exact))
        }
    }
}
impl ExactSizeIterator for Permutor {
    /// Number of values the iterator has yet to produce.
    ///
    /// The remainder is a `u128` and may not fit in `usize`; in that case the
    /// result saturates at `usize::MAX` instead of silently dropping the high
    /// bits.
    fn len(&self) -> usize {
        // `values_returned` is a public field, so it may have been set past `max`.
        let remaining = self.max.saturating_sub(self.values_returned);
        remaining.min(usize::MAX as u128) as usize
    }
}
/// Balanced Feistel network over a domain made of two equally wide bit halves.
#[derive(Debug, Clone)]
pub struct FeistelNetwork {
    /// Width, in bits, of each half of the domain.
    pub half_width: u128,
    /// Bit mask selecting the low (right) half of a value.
    pub right_mask: u128,
    /// Bit mask selecting the high (left) half of a value.
    pub left_mask: u128,
    /// Key bytes mixed into every round.
    key: [u8; 32],
    /// Number of rounds applied by `permute` and `invert`.
    rounds: u8,
}
impl FeistelNetwork {
    /// Builds a network whose domain is wide enough to hold `max_value`.
    ///
    /// The domain width is the bit length of `max_value`, rounded up to an
    /// even number so it splits into two equal halves. The round count is
    /// `8 + 60 / max(bit length, 4)`, capped at 32, so narrower domains get
    /// more rounds.
    pub fn new_with_slice_key(max_value: u128, key: [u8; 32]) -> FeistelNetwork {
        let bits = integer_log2(max_value).unwrap_or(0);
        // Round an odd bit length up to the next even one.
        let width = bits + bits % 2;
        let half_width = width / 2;

        // `half_width` is at most 64, so the shift cannot overflow a u128.
        let right_mask: u128 = (1u128 << half_width) - 1;
        let left_mask = right_mask << half_width;

        let rounds = (8 + 60 / bits.max(4)).min(32);

        FeistelNetwork {
            half_width: u128::from(half_width),
            right_mask,
            left_mask,
            key,
            rounds: rounds as u8,
        }
    }

    /// Applies the network to `input` and returns the permuted value.
    ///
    /// Bits of `input` outside the two halves are ignored.
    pub fn permute(&self, input: u128) -> u128 {
        let start = (
            (input & self.left_mask) >> self.half_width,
            input & self.right_mask,
        );
        // Each round moves the right half to the left and mixes the old left
        // half with the round function of the right half.
        let (left, right) = (0..self.rounds).fold(start, |(left, right), round| {
            let mixed = left ^ self.round_function(right, round, self.key, self.right_mask);
            (right, mixed)
        });
        ((left << self.half_width) | right) & (self.left_mask | self.right_mask)
    }

    /// Undoes [`FeistelNetwork::permute`] by running the rounds in reverse order.
    ///
    /// Bits of `input` outside the two halves are ignored.
    pub fn invert(&self, input: u128) -> u128 {
        let start = (
            (input & self.left_mask) >> self.half_width,
            input & self.right_mask,
        );
        // Mirror image of a forward round: the left half moves to the right
        // and the old right half is un-mixed with the same round function.
        let (left, right) = (0..self.rounds).rev().fold(start, |(left, right), round| {
            let unmixed = right ^ self.round_function(left, round, self.key, self.right_mask);
            (unmixed, left)
        });
        ((left << self.half_width) | right) & (self.left_mask | self.right_mask)
    }

    /// Keyed round function: hashes key, half value, round index and key again
    /// with `FxHasher`, then truncates the digest to `mask`.
    fn round_function(&self, right: u128, round: u8, key: [u8; 32], mask: u128) -> u128 {
        let half_bytes = right.to_le_bytes();
        let round_byte = [round];
        let mut state = FxHasher::default();
        // The chunks are fed as separate writes, in this exact order.
        for chunk in [&key[..], &half_bytes[..], &round_byte[..], &key[..]].iter() {
            state.write(chunk);
        }
        u128::from(state.finish()) & mask
    }
}
/// Expands a 64-bit key into a 32-byte key: the big-endian bytes of `input`
/// occupy the first eight positions and the remaining bytes are zero.
fn u64_to_32slice(input: u64) -> [u8; 32] {
    let mut expanded = [0u8; 32];
    expanded[..8].copy_from_slice(&input.to_be_bytes());
    expanded
}
/// Returns the number of bits needed to represent `input`, or `None` for zero.
///
/// Despite the name, this is the bit length (`floor(log2(input)) + 1`): for
/// example `1` gives `Some(1)` and `8` gives `Some(4)`.
pub fn integer_log2(input: u128) -> Option<u32> {
    match input {
        0 => None,
        value => Some(128 - value.leading_zeros()),
    }
}
#[cfg(test)]
mod tests {
    use ahash::AHashSet;
    use quickcheck::TestResult;
    use quickcheck_macros::quickcheck;

    use super::*;

    /// Inverting a permuted value gives back the original input.
    #[quickcheck]
    fn test_invert_roundtrip(u: u128, v: u128, key: u64) -> TestResult {
        // The larger draw is the upper bound and the smaller one the input, so
        // only the degenerate `u == v` case (which includes 0/0) is discarded.
        let (ub, input) = if u > v { (u, v) } else { (v, u) };
        if input >= ub {
            return TestResult::discard();
        }
        let feistel = FeistelNetwork::new_with_slice_key(ub, u64_to_32slice(key));
        let output = feistel.permute(input);
        let inverted = feistel.invert(output);
        assert_eq!(
            inverted, input,
            "Inversion does not produce the original input"
        );
        TestResult::passed()
    }

    /// Every value produced by the iterator lies below the upper bound.
    ///
    /// The bound is drawn as `u16` so that walking the whole range stays cheap.
    #[quickcheck]
    fn test_permuter_bounded(ub: u16, key: u64) -> TestResult {
        if ub == 0 {
            return TestResult::discard();
        }
        let ub = ub as u128;
        let permutor = Permutor::new_with_u64_key(ub, key);
        for value in permutor {
            assert!(value < ub, "Value returned is not within the bounds");
        }
        TestResult::passed()
    }

    /// `backward` undoes `forward` for every value of the range.
    ///
    /// The bound is drawn as `u16` so that walking the whole range stays cheap.
    #[quickcheck]
    fn test_permuter_roundtrip(ub: u16, key: u64) -> TestResult {
        if ub == 0 {
            return TestResult::discard();
        }
        let ub = ub as u128;
        let permutor = Permutor::new_with_u64_key(ub, key);
        for value in 0..ub {
            let forward = permutor.forward(value);
            let backward = permutor.backward(forward);
            assert_eq!(backward, value, "Forward and backward do not match");
        }
        TestResult::passed()
    }

    /// The iterator yields every value of the range exactly once.
    #[quickcheck]
    fn test_permuter_covers_all(ub: u16, key: u64) -> TestResult {
        if ub == 0 {
            return TestResult::discard();
        }
        let permutor = Permutor::new_with_u64_key(ub as u128, key);
        let seen = AHashSet::from_iter(permutor);
        assert_eq!(seen.len(), ub as usize, "Not all values were returned");
        for value in 0..(ub as u128) {
            assert!(seen.contains(&value), "Value was not returned");
        }
        TestResult::passed()
    }

    /// Pins `integer_log2` for zero and the first few positive integers.
    #[test]
    fn test_integer_log() {
        assert_eq!(None, integer_log2(0), "failed for {}", 0);
        assert_eq!(Some(1), integer_log2(1), "failed for {}", 1);
        assert_eq!(Some(2), integer_log2(2), "failed for {}", 2);
        assert_eq!(Some(2), integer_log2(3), "failed for {}", 3);
        assert_eq!(Some(3), integer_log2(4), "failed for {}", 4);
        assert_eq!(Some(3), integer_log2(5), "failed for {}", 5);
        assert_eq!(Some(3), integer_log2(6), "failed for {}", 6);
        assert_eq!(Some(3), integer_log2(7), "failed for {}", 7);
        assert_eq!(Some(4), integer_log2(8), "failed for {}", 8);
        assert_eq!(Some(4), integer_log2(9), "failed for {}", 9);
        assert_eq!(Some(4), integer_log2(10), "failed for {}", 10);
    }
}