use core::{cmp, convert::TryFrom};

use block_modes::{block_padding::NoPadding, BlockMode, Cbc};
use cipher::{
    generic_array::GenericArray, Block, BlockCipher, BlockDecrypt, BlockEncrypt, NewBlockCipher,
};
#[cfg(feature = "alloc")]
mod alloc;
#[cfg(feature = "alloc")]
pub use self::alloc::{BinaryNumeralString, FlexibleNumeralString};
/// A validated FF1 radix, classified so that power-of-two radices can use
/// exact integer arithmetic when sizing `NUM_radix` byte strings.
#[derive(Debug, PartialEq)]
enum Radix {
    /// A radix in `2..=2^16` that is not a power of two.
    Any(u32),
    /// A radix equal to `2^log_radix`.
    PowerTwo { radix: u32, log_radix: u8 },
}

impl Radix {
    /// Validates `radix` and classifies it.
    ///
    /// Returns `Err(())` unless `2 <= radix <= 2^16` (65536).
    pub fn from(radix: u32) -> Result<Self, ()> {
        match radix {
            // A power of two has exactly one set bit, and its index is log2(radix).
            2..=0x1_0000 if radix.is_power_of_two() => Ok(Radix::PowerTwo {
                radix,
                log_radix: radix.trailing_zeros() as u8,
            }),
            2..=0x1_0000 => Ok(Radix::Any(radix)),
            _ => Err(()),
        }
    }

    /// Returns `b`, the number of bytes used to encode `NUM_radix` of a
    /// `v`-numeral string: `ceil(v * log2(radix) / 8)`.
    ///
    /// Power-of-two radices are computed exactly with integers. Other radices
    /// go through `f64` (via `libm`).
    fn calculate_b(&self, v: usize) -> usize {
        use libm::{ceil, log2};
        match *self {
            Radix::PowerTwo { log_radix, .. } => {
                let bits = v * usize::from(log_radix);
                (bits + 7) / 8
            }
            Radix::Any(r) => {
                let bits = v as f64 * log2(f64::from(r));
                ceil(bits / 8f64) as usize
            }
        }
    }

    /// Returns the radix as a plain integer.
    fn to_u32(&self) -> u32 {
        match self {
            Radix::Any(radix) | Radix::PowerTwo { radix, .. } => *radix,
        }
    }
}
/// An unsigned integer large enough to hold FF1's intermediate values
/// (`NUM_radix(X)`, `y` and `c`).
pub trait Numeral {
    /// Byte container returned by [`Numeral::to_bytes`].
    type Bytes: AsRef<[u8]>;

    /// Builds a numeral from a stream of bytes.
    ///
    /// FF1 passes the first `d` bytes of `S` here to obtain `y = NUM(S)`.
    /// NOTE(review): FF1 needs a big-endian interpretation; confirm that
    /// implementations agree.
    fn from_bytes(bytes: impl Iterator<Item = u8>) -> Self;

    /// Encodes the value as `len` bytes (`[x]^b` in FF1 notation).
    ///
    /// FF1's block-alignment padding assumes exactly `len` bytes are produced.
    fn to_bytes(&self, len: usize) -> Self::Bytes;

    /// Modular addition modulo `radix^width`. FF1 encryption uses it to
    /// compute `c = (NUM_radix(A) + y) mod radix^m`.
    fn add_mod_exp(self, rhs: Self, radix: u32, width: usize) -> Self;

    /// Modular subtraction modulo `radix^width`. FF1 decryption uses it to
    /// compute `c = (NUM_radix(B) - y) mod radix^m`.
    fn sub_mod_exp(self, rhs: Self, radix: u32, width: usize) -> Self;
}
/// A string of numerals in some radix, which FF1 encrypts in place of raw bytes.
pub trait NumeralString: Sized {
    /// Integer type produced by [`NumeralString::num_radix`].
    type Num: Numeral;

    /// Returns whether this string may be processed with `radix`.
    ///
    /// NOTE(review): presumably every numeral must be `< radix`; confirm
    /// against the implementations.
    fn is_valid(&self, radix: u32) -> bool;

    /// Number of numerals (`n` in FF1).
    fn len(&self) -> usize;

    /// Splits the string at numeral index `mid`.
    ///
    /// FF1 uses the result as `(A, B)` with `mid = floor(n / 2)`.
    fn split(&self, mid: usize) -> (Self, Self);

    /// Joins two numeral strings, `left` followed by `right`.
    fn concat(left: Self, right: Self) -> Self;

    /// `NUM_radix(self)`: the integer whose base-`radix` digits are these numerals.
    fn num_radix(&self, radix: u32) -> Self::Num;

    /// `STR^width_radix(value)`: renders `value` as `width` base-`radix` numerals.
    fn str_radix(value: Self::Num, radix: u32, width: usize) -> Self;
}
/// Streaming CBC-MAC over the block cipher with an all-zero 16-byte IV,
/// used as the FF1 PRF.
///
/// Input is staged in a one-block buffer. Each time the buffer fills, it is
/// encrypted through the CBC chain, so after whole blocks the buffer holds the MAC.
#[derive(Clone)]
struct Prf<CIPH: BlockEncrypt + BlockDecrypt> {
    /// CBC chaining state from `block_modes`.
    state: Cbc<CIPH, NoPadding>,
    /// Staging buffer for the block currently being filled.
    buf: [Block<CIPH>; 1],
    /// How many bytes of the current block have been written into `buf`.
    offset: usize,
}

impl<CIPH: BlockEncrypt + BlockDecrypt + Clone> Prf<CIPH> {
    /// Starts a fresh MAC computation using a private copy of `ciph`.
    fn new(ciph: &CIPH) -> Self {
        Prf {
            state: Cbc::new(ciph.clone(), GenericArray::from_slice(&[0; 16])),
            buf: [Block::<CIPH>::default()],
            offset: 0,
        }
    }

    /// Absorbs `data`, encrypting every block as soon as it is complete.
    fn update(&mut self, data: &[u8]) {
        let block_len = self.buf[0].len();
        let mut pending = data;
        while !pending.is_empty() {
            let room = block_len - self.offset;
            let (chunk, rest) = pending.split_at(cmp::min(room, pending.len()));
            let end = self.offset + chunk.len();
            self.buf[0][self.offset..end].copy_from_slice(chunk);
            self.offset = end;
            pending = rest;
            if self.offset == block_len {
                self.state.encrypt_blocks(&mut self.buf);
                self.offset = 0;
            }
        }
    }

    /// Returns the MAC of everything absorbed so far.
    ///
    /// # Panics
    ///
    /// Panics if the absorbed input is not a whole number of blocks.
    fn output(&self) -> &Block<CIPH> {
        assert_eq!(self.offset, 0);
        &self.buf[0]
    }
}
/// Lazily yields the first `d` bytes of
/// `S = R || CIPH(R ^ [1]^16) || CIPH(R ^ [2]^16) || ...`.
///
/// Each counter `j` is XORed into `R` as a 16-byte big-endian integer. Extension
/// blocks are only encrypted when the iterator actually reaches them. The
/// `16` constants assume a 128-bit block cipher.
fn generate_s<'a, CIPH: BlockEncrypt>(
    ciph: &'a CIPH,
    r: &'a Block<CIPH>,
    d: usize,
) -> impl Iterator<Item = u8> + 'a {
    let block_count = ((d + 15) / 16) as u128;
    let extension = (1..block_count).flat_map(move |counter| {
        let mut block = r.clone();
        block
            .iter_mut()
            .zip(counter.to_be_bytes().iter())
            .for_each(|(byte, mask)| *byte ^= mask);
        ciph.encrypt_block(&mut block);
        block.into_iter()
    });
    r.clone().into_iter().chain(extension).take(d)
}
/// An FF1 format-preserving encryption context: a keyed block cipher plus the
/// radix of the numeral strings it encrypts.
pub struct FF1<CIPH>
where
    CIPH: BlockCipher,
{
    /// Cipher keyed with the FF1 key. It drives both the PRF and the `S` extension.
    ciph: CIPH,
    /// Validated radix of the numeral strings handled by this instance.
    radix: Radix,
}
impl<CIPH: NewBlockCipher + BlockEncrypt + BlockDecrypt + Clone> FF1<CIPH> {
pub fn new(key: &[u8], radix: u32) -> Result<Self, ()> {
let ciph = CIPH::new(GenericArray::from_slice(key));
let radix = Radix::from(radix)?;
Ok(FF1 { ciph, radix })
}
#[allow(clippy::many_single_char_names)]
pub fn encrypt<NS: NumeralString>(&self, tweak: &[u8], x: &NS) -> Result<NS, ()> {
if !x.is_valid(self.radix.to_u32()) {
return Err(());
}
let n = x.len();
let t = tweak.len();
let u = n / 2;
let v = n - u;
let (mut x_a, mut x_b) = x.split(u);
let b = self.radix.calculate_b(v);
let d = 4 * ((b + 3) / 4) + 4;
let mut p = [1, 2, 1, 0, 0, 0, 10, u as u8, 0, 0, 0, 0, 0, 0, 0, 0];
p[3..6].copy_from_slice(&self.radix.to_u32().to_be_bytes()[1..]);
p[8..12].copy_from_slice(&(n as u32).to_be_bytes());
p[12..16].copy_from_slice(&(t as u32).to_be_bytes());
let mut prf = Prf::new(&self.ciph);
prf.update(&p);
prf.update(tweak);
for _ in 0..((((-(t as i32) - (b as i32) - 1) % 16) + 16) % 16) {
prf.update(&[0]);
}
for i in 0..10 {
let mut prf = prf.clone();
prf.update(&[i]);
prf.update(x_b.num_radix(self.radix.to_u32()).to_bytes(b).as_ref());
let r = prf.output();
let s = generate_s(&self.ciph, r, d);
let y = NS::Num::from_bytes(s);
let m = if i % 2 == 0 { u } else { v };
let c = x_a
.num_radix(self.radix.to_u32())
.add_mod_exp(y, self.radix.to_u32(), m);
let x_c = NS::str_radix(c, self.radix.to_u32(), m);
x_a = x_b;
x_b = x_c;
}
Ok(NS::concat(x_a, x_b))
}
#[allow(clippy::many_single_char_names)]
pub fn decrypt<NS: NumeralString>(&self, tweak: &[u8], x: &NS) -> Result<NS, ()> {
if !x.is_valid(self.radix.to_u32()) {
return Err(());
}
let n = x.len();
let t = tweak.len();
let u = n / 2;
let v = n - u;
let (mut x_a, mut x_b) = x.split(u);
let b = self.radix.calculate_b(v);
let d = 4 * ((b + 3) / 4) + 4;
let mut p = [1, 2, 1, 0, 0, 0, 10, u as u8, 0, 0, 0, 0, 0, 0, 0, 0];
p[3..6].copy_from_slice(&self.radix.to_u32().to_be_bytes()[1..]);
p[8..12].copy_from_slice(&(n as u32).to_be_bytes());
p[12..16].copy_from_slice(&(t as u32).to_be_bytes());
let mut prf = Prf::new(&self.ciph);
prf.update(&p);
prf.update(tweak);
for _ in 0..((((-(t as i32) - (b as i32) - 1) % 16) + 16) % 16) {
prf.update(&[0]);
}
for i in 0..10 {
let i = 9 - i;
let mut prf = prf.clone();
prf.update(&[i]);
prf.update(x_a.num_radix(self.radix.to_u32()).to_bytes(b).as_ref());
let r = prf.output();
let s = generate_s(&self.ciph, r, d);
let y = NS::Num::from_bytes(s);
let m = if i % 2 == 0 { u } else { v };
let c = x_b
.num_radix(self.radix.to_u32())
.sub_mod_exp(y, self.radix.to_u32(), m);
let x_c = NS::str_radix(c, self.radix.to_u32(), m);
x_b = x_a;
x_a = x_c;
}
Ok(NS::concat(x_a, x_b))
}
}
#[cfg(test)]
mod tests {
    use super::Radix;

    /// Expected classification of a power-of-two radix.
    fn pow2(radix: u32, log_radix: u8) -> Result<Radix, ()> {
        Ok(Radix::PowerTwo { radix, log_radix })
    }

    #[test]
    fn radix() {
        let cases = [
            (1, Err(())),
            (2, pow2(2, 1)),
            (3, Ok(Radix::Any(3))),
            (4, pow2(4, 2)),
            (5, Ok(Radix::Any(5))),
            (6, Ok(Radix::Any(6))),
            (7, Ok(Radix::Any(7))),
            (8, pow2(8, 3)),
            (32768, pow2(32768, 15)),
            (65535, Ok(Radix::Any(65535))),
            (65536, pow2(65536, 16)),
            (65537, Err(())),
        ];
        for (input, expected) in cases.iter() {
            assert_eq!(&Radix::from(*input), expected, "radix {}", input);
        }
    }
}