use core::{
cmp::max,
marker::PhantomData,
ops::{Add, Div, Mul, Sub},
};
use cipher::{
array::{Array, ArraySize},
inout::InOut,
typenum::{Diff, IsLess, Le, NonZero, Sum, U1, U2, U4, U256, Unsigned},
};
use super::{
Block, BlockSize, ExpandedKeyTable, ExpandedKeyTableSize, Key, KeyAsWords, KeyAsWordsSize, Word,
};
/// An RC6 cipher instance over word type `W`, with `R` rounds and key length `B`.
///
/// Only the expanded round-key table is stored; the key length is carried
/// purely at the type level through `PhantomData`.
pub struct RC6<W, R, B>
where
    W: Word,
    R: Unsigned + IsLess<U256> + Add<U2>,
    Sum<R, U2>: Mul<U2>,
    ExpandedKeyTableSize<R>: ArraySize,
{
    /// Round keys produced by the key schedule; encryption/decryption index
    /// them from `0` up to `2 * R + 3`.
    key_table: ExpandedKeyTable<W, R>,
    /// Zero-sized marker tying the instance to its key length `B`.
    _key_size: PhantomData<B>,
}
impl<W, R, B> RC6<W, R, B>
where
    W: Word,
    W::Bytes: Mul<U4>,
    BlockSize<W>: ArraySize,
    R: Unsigned,
    R: IsLess<U256>,
    Le<R, U256>: NonZero,
    R: Add<U2>,
    Sum<R, U2>: Mul<U2>,
    ExpandedKeyTableSize<R>: ArraySize,
    B: ArraySize,
    B: IsLess<U256>,
    Le<B, U256>: NonZero,
    B: Add<W::Bytes>,
    Sum<B, W::Bytes>: Sub<U1>,
    Diff<Sum<B, W::Bytes>, U1>: Div<W::Bytes>,
    KeyAsWordsSize<W, B>: ArraySize,
{
    /// Builds a cipher instance by running the RC6 key schedule over `key`.
    pub fn new(key: &Key<B>) -> RC6<W, R, B> {
        Self {
            key_table: Self::substitute_key(key),
            _key_size: PhantomData,
        }
    }

    /// Runs the whole key schedule: pack the key into words, seed the
    /// round-key table with the `P`/`Q` progression, then mix the two.
    fn substitute_key(key: &Key<B>) -> ExpandedKeyTable<W, R> {
        let key_as_words = Self::key_into_words(key);
        let expanded_key_table = Self::initialize_expanded_key_table();
        Self::mix_in(expanded_key_table, key_as_words)
    }

    /// Packs the key into words in little-endian order: key element `i`
    /// goes into word `i / u` (`u` = bytes per word), with lower indices in
    /// the less significant bytes.
    ///
    /// The result may be empty when `B` is zero; `mix_in` handles that case.
    fn key_into_words(key: &Key<B>) -> KeyAsWords<W, B> {
        let mut key_as_words: Array<W, KeyAsWordsSize<W, B>> = Array::default();
        for i in (0..B::USIZE).rev() {
            // A word receives at most `u` elements, so before each update it
            // holds fewer than `u` bytes. The rotate by 8 therefore never wraps
            // and behaves like the spec's `<< 8`, leaving the low byte free for
            // the `+` (assumes `Key<B>` holds bytes).
            key_as_words[i / W::Bytes::USIZE] =
                key_as_words[i / W::Bytes::USIZE].rotate_left(W::EIGHT) + key[i].into();
        }
        key_as_words
    }

    /// Seeds the round-key table: `S[0] = P`, `S[i] = S[i - 1] + Q` (wrapping).
    fn initialize_expanded_key_table() -> ExpandedKeyTable<W, R> {
        let mut expanded_key_table: Array<W, ExpandedKeyTableSize<R>> = Array::from_fn(|_| W::ZERO);
        expanded_key_table[0] = W::P;
        for i in 1..expanded_key_table.len() {
            expanded_key_table[i] = expanded_key_table[i - 1].wrapping_add(W::Q);
        }
        expanded_key_table
    }

    /// Mixes the packed key words into the round-key table and returns it.
    ///
    /// Runs `3 * max(c, t)` iterations of
    /// `A = S[i] = (S[i] + A + B) <<< 3; B = L[j] = (L[j] + A + B) <<< (A + B)`,
    /// where `t` is the table length and `c` the number of key words, with
    /// `i` and `j` cycling independently.
    ///
    /// An empty key (`c == 0`) is treated as one zero word, following the
    /// RC5/RC6 rule `c = max(1, ceil(b / u))`. Without this, indexing the
    /// empty key array and the `% len()` below would panic.
    fn mix_in(
        mut key_table: ExpandedKeyTable<W, R>,
        mut key_as_words: KeyAsWords<W, B>,
    ) -> ExpandedKeyTable<W, R> {
        let mut zero_word = [W::ZERO];
        let key_words: &mut [W] = if key_as_words.is_empty() {
            &mut zero_word[..]
        } else {
            &mut key_as_words[..]
        };
        let (mut expanded_key_index, mut key_as_words_index) = (0, 0);
        let (mut a, mut b) = (W::ZERO, W::ZERO);
        for _ in 0..3 * max(key_words.len(), key_table.len()) {
            key_table[expanded_key_index] = key_table[expanded_key_index]
                .wrapping_add(a)
                .wrapping_add(b)
                .rotate_left(W::THREE);
            a = key_table[expanded_key_index];
            key_words[key_as_words_index] = key_words[key_as_words_index]
                .wrapping_add(a)
                .wrapping_add(b)
                .rotate_left(a.wrapping_add(b));
            b = key_words[key_as_words_index];
            expanded_key_index = (expanded_key_index + 1) % key_table.len();
            key_as_words_index = (key_as_words_index + 1) % key_words.len();
        }
        key_table
    }
}
impl<W, R, B> RC6<W, R, B>
where
    W: Word,
    W::Bytes: Mul<U4>,
    BlockSize<W>: ArraySize,
    R: Unsigned,
    R: IsLess<U256>,
    Le<R, U256>: NonZero,
    R: Add<U2>,
    Sum<R, U2>: Mul<U2>,
    ExpandedKeyTableSize<R>: ArraySize,
{
    /// Encrypts one block with `R` RC6 rounds.
    ///
    /// Reads four little-endian words `A, B, C, D` from `block.get_in()` and
    /// writes the ciphertext words to `block.get_out()`.
    pub fn encrypt(&self, mut block: InOut<'_, '_, Block<W>>) {
        let s = &self.key_table;
        let rounds = R::USIZE;
        // lg w: base-2 logarithm of the word width in bits.
        let lg_w = W::from((usize::BITS - 1 - (W::Bytes::USIZE * 8).leading_zeros()) as u8);
        // f(x) = (x * (2x + 1)) <<< lg w, all arithmetic wrapping.
        let f = |x: W| {
            let odd = x.wrapping_mul(W::from(2)).wrapping_add(W::from(1));
            x.wrapping_mul(odd).rotate_left(lg_w)
        };

        let (mut a, mut b, mut c, mut d) = Self::words_from_block(block.get_in());
        b = b.wrapping_add(s[0]);
        d = d.wrapping_add(s[1]);
        for i in 1..=rounds {
            let t = f(b);
            let u = f(d);
            a = a.bitxor(t).rotate_left(u).wrapping_add(s[2 * i]);
            c = c.bitxor(u).rotate_left(t).wrapping_add(s[2 * i + 1]);
            // (A, B, C, D) <- (B, C, D, A)
            (a, b, c, d) = (b, c, d, a);
        }
        a = a.wrapping_add(s[2 * rounds + 2]);
        c = c.wrapping_add(s[2 * rounds + 3]);
        Self::block_from_words(a, b, c, d, block.get_out());
    }

    /// Decrypts one block, undoing `encrypt` step by step in reverse order.
    ///
    /// Reads four little-endian words from `block.get_in()` and writes the
    /// plaintext words to `block.get_out()`.
    pub fn decrypt(&self, mut block: InOut<'_, '_, Block<W>>) {
        let s = &self.key_table;
        let rounds = R::USIZE;
        // lg w: base-2 logarithm of the word width in bits.
        let lg_w = W::from((usize::BITS - 1 - (W::Bytes::USIZE * 8).leading_zeros()) as u8);
        // f(x) = (x * (2x + 1)) <<< lg w, all arithmetic wrapping.
        let f = |x: W| {
            let odd = x.wrapping_mul(W::from(2)).wrapping_add(W::from(1));
            x.wrapping_mul(odd).rotate_left(lg_w)
        };

        let (mut a, mut b, mut c, mut d) = Self::words_from_block(block.get_in());
        c = c.wrapping_sub(s[2 * rounds + 3]);
        a = a.wrapping_sub(s[2 * rounds + 2]);
        for i in (1..=rounds).rev() {
            // (A, B, C, D) <- (D, A, B, C)
            (a, b, c, d) = (d, a, b, c);
            let u = f(d);
            let t = f(b);
            c = c.wrapping_sub(s[2 * i + 1]).rotate_right(t).bitxor(u);
            a = a.wrapping_sub(s[2 * i]).rotate_right(u).bitxor(t);
        }
        d = d.wrapping_sub(s[1]);
        b = b.wrapping_sub(s[0]);
        Self::block_from_words(a, b, c, d, block.get_out());
    }

    /// Splits a block into its four consecutive little-endian words.
    fn words_from_block(block: &Block<W>) -> (W, W, W, W) {
        let word_len = W::Bytes::USIZE;
        let word_at = |slot: usize| {
            let start = slot * word_len;
            W::from_le_bytes(block[start..start + word_len].try_into().unwrap())
        };
        (word_at(0), word_at(1), word_at(2), word_at(3))
    }

    /// Serialises four words, little-endian and in order, into `out_block`.
    fn block_from_words(a: W, b: W, c: W, d: W, out_block: &mut Block<W>) {
        let word_len = W::Bytes::USIZE;
        for (slot, word) in [a, b, c, d].into_iter().enumerate() {
            let start = slot * word_len;
            out_block[start..start + word_len].copy_from_slice(&word.to_le_bytes());
        }
    }
}