use super::*;
/// Test round function: adds the round key to the right half, clamping at `u32::MAX`
/// instead of wrapping on overflow.
///
/// NOTE(review): because the sum saturates, any round in which `right + round_key`
/// overflows yields `u32::MAX` no matter what the key is, so this function only
/// weakly exercises the ordering of round keys.
pub struct SaturatingAddRF;

impl RoundFunction<u32> for SaturatingAddRF {
    fn f(&self, right: &u32, round_key: &u32) -> u32 {
        let (half, key) = (*right, *round_key);
        u32::saturating_add(half, key)
    }
}
/// Key schedule whose first round key is the master key itself; each later key is
/// the previous one rotated left by one bit.
pub struct RotatingKeySchedule;

impl KeySchedule<u32, u32> for RotatingKeySchedule {
    type State = u32;

    /// Seeds the schedule state with the master key.
    fn new_state(&self, master_key: &u32) -> Self::State {
        let seed = *master_key;
        seed
    }

    /// Hands out the current key and advances the state by a one-bit left rotation.
    fn next_key(&self, _master_key: &u32, state: &mut Self::State) -> u32 {
        let upcoming = state.rotate_left(1);
        std::mem::replace(state, upcoming)
    }
}
/// Test key schedule that ignores the master key and yields 1, 2, 3, ...
/// It sticks at `u32::MAX` once the counter saturates.
struct IncKeySchedule;

impl KeySchedule<u32, u32> for IncKeySchedule {
    type State = u32;

    /// The counter starts at zero so that the first emitted key is 1.
    fn new_state(&self, _master_key: &u32) -> Self::State {
        u32::MIN
    }

    /// Bumps the counter (saturating) and returns the new value.
    fn next_key(&self, _master_key: &u32, state: &mut Self::State) -> u32 {
        let bumped = state.saturating_add(1);
        *state = bumped;
        bumped
    }
}
/// `IncKeySchedule` must hand out 1, 2, 3 for its first three rounds.
#[test]
fn key_schedule_inc() {
    const ROUNDS: usize = 3;
    let schedule = IncKeySchedule;
    let mut state = schedule.new_state(&0);
    let keys: Vec<_> = (0..ROUNDS)
        .map(|_| schedule.next_key(&0, &mut state))
        .collect();
    assert_eq!(keys, vec![1, 2, 3], "Key seq is wrong");
}
/// Wrapping `IncKeySchedule` in `ReverseKeySchedule` for three rounds must yield the
/// same keys last-to-first: 3, 2, 1.
#[test]
fn key_schedule_rev() {
    const ROUNDS: usize = 3;
    let reversed = ReverseKeySchedule::from(IncKeySchedule, ROUNDS);
    let mut state = reversed.new_state(&0);
    let keys: Vec<_> = std::iter::repeat_with(|| reversed.next_key(&0, &mut state))
        .take(ROUNDS)
        .collect();
    assert_eq!(keys, vec![3, 2, 1], "Reversed Key seq is wrong");
}
/// Round-trip test.
///
/// A block is encrypted with `RotatingKeySchedule` + `SaturatingAddRF`. It is then
/// run through a second network with the same round function and the key schedule
/// wrapped in `ReverseKeySchedule`. The second pass must recover the original block.
#[test]
fn test_feistel_full_cycle() {
    let rounds = 16;
    let master_key = 0xDEAD_BEEF_u32;
    let plain = (0x1234_5678_u32, 0x8765_4321_u32);

    // Forward pass; encryption must actually change the block.
    let forward = FeistelNetwork::new(RotatingKeySchedule, SaturatingAddRF, rounds);
    let cipher = forward.process(plain.0, plain.1, &master_key);
    assert_ne!(plain, cipher);

    // Inverse pass: identical network shape, round keys replayed in reverse order.
    let backward = FeistelNetwork::new(
        ReverseKeySchedule::from(RotatingKeySchedule, rounds),
        SaturatingAddRF,
        rounds,
    );
    let (dec_l, dec_r) = backward.process(cipher.0, cipher.1, &master_key);

    println!("Original: ({:08x}, {:08x})", plain.0, plain.1);
    println!("Encrypted: ({:08x}, {:08x})", cipher.0, cipher.1);
    println!("Decrypted: ({:08x}, {:08x})", dec_l, dec_r);

    assert_eq!(dec_l, plain.0, "Left block is wrong");
    assert_eq!(dec_r, plain.1, "Right block is wrong");
}