use crate::{
ciphertext::*,
primitives::{
hash::Aes128Z2Hash, prf::Aes128Prf, prp::KnuthShufflePRP, AesBlock, Hash, HashKey, Prf,
Prp, NONCE_SIZE,
},
OreCipher, OreError, PlainText,
};
use aes::cipher::generic_array::GenericArray;
use lazy_static::lazy_static;
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha20Rng;
use std::cell::RefCell;
use std::cmp::Ordering;
use subtle_ng::{Choice, ConditionallySelectable, ConstantTimeEq};
use zeroize::ZeroizeOnDrop;
pub mod block_types;
pub use self::block_types::*;
#[derive(Debug, ZeroizeOnDrop)]
pub struct OreAes128<R: Rng + SeedableRng> {
prf1: Aes128Prf,
prf2: Aes128Prf,
#[zeroize(skip)]
rng: RefCell<R>,
}
pub type OreAes128ChaCha20 = OreAes128<ChaCha20Rng>;
type EncryptLeftResult<R, const N: usize> = Result<Left<OreAes128<R>, N>, OreError>;
type EncryptResult<R, const N: usize> = Result<CipherText<OreAes128<R>, N>, OreError>;
fn cmp(a: u8, b: u8) -> u8 {
u8::from(a > b)
}
impl<R: Rng + SeedableRng> OreCipher for OreAes128<R> {
type LeftBlockType = LeftBlock16;
type RightBlockType = RightBlock32;
fn init(k1: &[u8; 16], k2: &[u8; 16]) -> Result<Self, OreError> {
let rng: R = SeedableRng::from_entropy();
Ok(OreAes128 {
prf1: Prf::new(GenericArray::from_slice(k1)),
prf2: Prf::new(GenericArray::from_slice(k2)),
rng: RefCell::new(rng),
})
}
fn encrypt_left<const N: usize>(&self, x: &PlainText<N>) -> EncryptLeftResult<R, N> {
let mut output = Left::<Self, N>::init();
output.f.iter_mut().enumerate().for_each(|(n, block)| {
block[0..n].clone_from_slice(&x[0..n]);
});
self.prf2.encrypt_all(&mut output.f);
for (n, xn) in x.iter().enumerate().take(N) {
let prp: KnuthShufflePRP<u8, 256> = Prp::new(&output.f[n])?;
output.xt[n] = prp.permute(*xn)?;
}
output.f = [Default::default(); N];
for n in 0..N {
output.f[n][0..n].clone_from_slice(&x[0..n]);
output.f[n][n] = output.xt[n];
output.f[n][N] = n as u8;
}
self.prf1.encrypt_all(&mut output.f);
Ok(output)
}
fn encrypt<const N: usize>(&self, x: &PlainText<N>) -> EncryptResult<R, N> {
let mut left = Left::<Self, N>::init();
let mut right = Right::<Self, N>::init();
self.rng.borrow_mut().try_fill(&mut right.nonce)?;
left.f.iter_mut().enumerate().for_each(|(n, block)| {
block[0..n].clone_from_slice(&x[0..n]);
});
self.prf2.encrypt_all(&mut left.f);
lazy_static! {
static ref ZEROED_RO_KEYS: [AesBlock; 256] = [Default::default(); 256];
}
let mut ro_keys = *ZEROED_RO_KEYS;
for n in 0..N {
let prp: KnuthShufflePRP<u8, 256> = Prp::new(&left.f[n])?;
left.xt[n] = prp.permute(x[n])?;
left.f[n].default_in_place();
left.f[n][0..n].clone_from_slice(&x[0..n]);
left.f[n][n] = left.xt[n];
left.f[n][N] = n as u8;
for (j, ro_key) in ro_keys.iter_mut().enumerate() {
ro_key[0..n].clone_from_slice(&x[0..n]);
ro_key[n] = j as u8;
ro_key[N] = n as u8;
}
self.prf1.encrypt_all(&mut ro_keys);
let hasher: Aes128Z2Hash = Hash::new(AesBlock::from_slice(&right.nonce));
let hashes = hasher.hash_all(&mut ro_keys);
for (j, h) in hashes.iter().enumerate() {
let jstar = prp.invert(j as u8)?;
let indicator = cmp(jstar, x[n]);
right.data[n].set_bit(j, indicator ^ h);
}
ro_keys.clone_from_slice(&*ZEROED_RO_KEYS);
}
self.prf1.encrypt_all(&mut left.f);
Ok(CipherText { left, right })
}
fn compare_raw_slices(a: &[u8], b: &[u8]) -> Option<Ordering> {
if a.len() != b.len() {
return None;
};
let left_size = Self::LeftBlockType::BLOCK_SIZE;
let right_size = Self::RightBlockType::BLOCK_SIZE;
let num_blocks = (a.len() - NONCE_SIZE) / (left_size + right_size + 1);
let mut is_equal = Choice::from(1);
let mut l: u64 = 0;
let a_f = &a[num_blocks..];
let b_f = &b[num_blocks..];
for n in 0..num_blocks {
let prp_eq: Choice = !a[n].ct_eq(&b[n]);
let left_block_comparison: Choice = !left_block(a_f, n).ct_eq(left_block(b_f, n));
let condition: Choice = prp_eq | left_block_comparison;
l.conditional_assign(&(n as u64), is_equal & condition);
is_equal.conditional_assign(&Choice::from(0), is_equal & condition);
}
let l: usize = l as usize;
if bool::from(is_equal) {
return Some(Ordering::Equal);
}
let b_right = &b[num_blocks * (left_size + 1)..];
let hash_key = HashKey::from_slice(&b_right[0..NONCE_SIZE]);
let hash: Aes128Z2Hash = Hash::new(hash_key);
let h = hash.hash(left_block(a_f, l));
let target_block = right_block(&b_right[NONCE_SIZE..], l);
let test = get_bit(target_block, a[l] as usize) ^ h;
if test == 1 {
return Some(Ordering::Greater);
}
Some(Ordering::Less)
}
}
#[inline]
fn left_block(input: &[u8], n: usize) -> &[u8] {
let f_pos = n * LeftBlock16::BLOCK_SIZE;
&input[f_pos..(f_pos + LeftBlock16::BLOCK_SIZE)]
}
#[inline]
fn right_block(input: &[u8], n: usize) -> &[u8] {
let f_pos = n * RightBlock32::BLOCK_SIZE;
&input[f_pos..(f_pos + RightBlock32::BLOCK_SIZE)]
}
#[inline]
fn get_bit(block: &[u8], bit: usize) -> u8 {
debug_assert!(block.len() == RightBlock32::BLOCK_SIZE);
debug_assert!(bit < 256);
let byte_index = bit / 8;
let position = bit % 8;
let v = 1 << position;
(block[byte_index] & v) >> position
}
impl<const N: usize> PartialEq for CipherText<OreAes128ChaCha20, N> {
fn eq(&self, b: &Self) -> bool {
matches!(self.cmp(b), Ordering::Equal)
}
}
impl<const N: usize> Ord for CipherText<OreAes128ChaCha20, N> {
fn cmp(&self, b: &Self) -> Ordering {
let mut is_equal = Choice::from(1);
let mut l: u64 = 0;
for n in 0..N {
let condition: Choice =
!(self.left.xt[n].ct_eq(&b.left.xt[n])) | !(self.left.f[n].ct_eq(&b.left.f[n]));
l.conditional_assign(&(n as u64), is_equal & condition);
is_equal.conditional_assign(&Choice::from(0), is_equal & condition);
}
let l: usize = l as usize;
if bool::from(is_equal) {
return Ordering::Equal;
}
let hash: Aes128Z2Hash = Hash::new(AesBlock::from_slice(&b.right.nonce));
let h = hash.hash(&self.left.f[l]);
let test = b.right.data[l].get_bit(self.left.xt[l] as usize) ^ h;
if test == 1 {
return Ordering::Greater;
}
Ordering::Less
}
}
impl<const N: usize> PartialOrd for CipherText<OreAes128ChaCha20, N> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl<const N: usize> Eq for CipherText<OreAes128ChaCha20, N> {}
#[cfg(test)]
mod tests {
use super::*;
use crate::encrypt::OreEncrypt;
use quickcheck::TestResult;
type Ore = OreAes128ChaCha20;
fn init_ore() -> Ore {
let mut k1: [u8; 16] = Default::default();
let mut k2: [u8; 16] = Default::default();
let mut rng = ChaCha20Rng::from_entropy();
rng.fill(&mut k1);
rng.fill(&mut k2);
OreCipher::init(&k1, &k2).unwrap()
}
quickcheck! {
fn compare_u64(x: u64, y: u64) -> bool {
let ore = init_ore();
let a = x.encrypt(&ore).unwrap();
let b = y.encrypt(&ore).unwrap();
match x.cmp(&y) {
Ordering::Greater => a > b,
Ordering::Less => a < b,
Ordering::Equal => a == b
}
}
fn compare_u64_raw_slices(x: u64, y: u64) -> bool {
let ore = init_ore();
let a = x.encrypt(&ore).unwrap().to_bytes();
let b = y.encrypt(&ore).unwrap().to_bytes();
match Ore::compare_raw_slices(&a, &b) {
Some(Ordering::Greater) => x > y,
Some(Ordering::Less) => x < y,
Some(Ordering::Equal) => x == y,
None => false
}
}
fn equality_u64(x: u64) -> bool {
let ore = init_ore();
let a = x.encrypt(&ore).unwrap();
let b = x.encrypt(&ore).unwrap();
a == b
}
fn equality_u64_raw_slices(x: u64) -> bool {
let ore = init_ore();
let a = x.encrypt(&ore).unwrap().to_bytes();
let b = x.encrypt(&ore).unwrap().to_bytes();
matches!(Ore::compare_raw_slices(&a, &b), Some(Ordering::Equal))
}
fn compare_u32(x: u32, y: u32) -> bool {
let ore = init_ore();
let a = x.encrypt(&ore).unwrap();
let b = y.encrypt(&ore).unwrap();
match x.cmp(&y) {
Ordering::Greater => a > b,
Ordering::Less => a < b,
Ordering::Equal => a == b
}
}
fn equality_u32(x: u64) -> bool {
let ore = init_ore();
let a = x.encrypt(&ore).unwrap();
let b = x.encrypt(&ore).unwrap();
a == b
}
fn compare_f64(x: f64, y: f64) -> TestResult {
if x.is_nan() || x.is_infinite() || y.is_nan() || y.is_infinite() {
return TestResult::discard();
}
let ore = init_ore();
let a = x.encrypt(&ore).unwrap();
let b = y.encrypt(&ore).unwrap();
match x.partial_cmp(&y) {
Some(Ordering::Greater) => TestResult::from_bool(a > b),
Some(Ordering::Less) => TestResult::from_bool(a < b),
Some(Ordering::Equal) => TestResult::from_bool(a == b),
None => TestResult::failed()
}
}
fn equality_f64(x: f64) -> bool {
let ore = init_ore();
let a = x.encrypt(&ore).unwrap();
let b = x.encrypt(&ore).unwrap();
a == b
}
fn compare_plaintext(x: u64, y: u64) -> bool {
let ore = init_ore();
let a = x.to_be_bytes().encrypt(&ore).unwrap();
let b = y.to_be_bytes().encrypt(&ore).unwrap();
match x.cmp(&y) {
Ordering::Greater => a > b,
Ordering::Less => a < b,
Ordering::Equal => a == b
}
}
fn equality_plaintext(x: f64) -> bool {
let ore = init_ore();
let a = x.to_be_bytes().encrypt(&ore).unwrap();
let b = x.to_be_bytes().encrypt(&ore).unwrap();
a == b
}
}
#[test]
fn smallest_to_largest() {
let ore = init_ore();
let a = 0u64.encrypt(&ore).unwrap();
let b = 18446744073709551615u64.encrypt(&ore).unwrap();
assert!(a < b);
}
#[test]
fn largest_to_smallest() {
let ore = init_ore();
let a = 18446744073709551615u64.encrypt(&ore).unwrap();
let b = 0u64.encrypt(&ore).unwrap();
assert!(a > b);
}
#[test]
fn smallest_to_smallest() {
let ore = init_ore();
let a = 0u64.encrypt(&ore).unwrap();
let b = 0u64.encrypt(&ore).unwrap();
assert!(a == b);
}
#[test]
fn largest_to_largest() {
let ore = init_ore();
let a = 18446744073709551615u64.encrypt(&ore).unwrap();
let b = 18446744073709551615u64.encrypt(&ore).unwrap();
assert!(a == b);
}
#[test]
fn signed_zeros_compare_equal() {
let ore = init_ore();
let pos_zero = 0.0_f64.encrypt(&ore).unwrap();
let neg_zero = (-0.0_f64).encrypt(&ore).unwrap();
assert_eq!(0.0_f64.partial_cmp(&-0.0_f64), Some(Ordering::Equal));
assert!(pos_zero == neg_zero);
}
#[test]
fn comparisons_in_first_block() {
let ore = init_ore();
let a = 18446744073709551615u64.encrypt(&ore).unwrap();
let b = 18446744073709551612u64.encrypt(&ore).unwrap();
assert!(a > b);
assert!(b < a);
}
#[test]
fn comparisons_in_last_block() {
let ore = init_ore();
let a = 10u64.encrypt(&ore).unwrap();
let b = 73u64.encrypt(&ore).unwrap();
assert!(a < b);
assert!(b > a);
}
#[test]
fn compare_raw_slices_mismatched_lengths() {
let ore = init_ore();
let a_64 = 10u64.encrypt(&ore).unwrap().to_bytes();
let a_32 = 10u32.encrypt(&ore).unwrap().to_bytes();
assert_eq!(Ore::compare_raw_slices(&a_64, &a_32), Option::None);
}
#[test]
fn binary_encoding() {
let ore = init_ore();
let a = 10u64.encrypt(&ore).unwrap();
let bin = a.to_bytes();
assert_eq!(
a,
CipherText::<OreAes128ChaCha20, 8>::from_slice(&bin).unwrap()
);
}
#[test]
#[should_panic(expected = "ParseError")]
fn binary_encoding_invalid_length() {
let bin = vec![0, 1, 2, 3];
CipherText::<OreAes128ChaCha20, 8>::from_slice(&bin).unwrap();
}
#[test]
fn test_different_prf_keys() {
let k1: [u8; 16] = [
97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112,
];
let k2: [u8; 16] = [
129, 4, 114, 186, 102, 145, 225, 73, 166, 57, 244, 251, 56, 92, 188, 36,
];
let k3: [u8; 16] = [
49, 50, 51, 52, 53, 54, 55, 56, 57, 48, 97, 98, 99, 100, 101, 102,
];
let ore1: OreAes128ChaCha20 = OreCipher::init(&k1, &k2).unwrap();
let ore2: OreAes128ChaCha20 = OreCipher::init(&k3, &k2).unwrap();
let a = 1000u32.encrypt(&ore1).unwrap().to_bytes();
let b = 1000u32.encrypt(&ore2).unwrap().to_bytes();
assert_ne!(Some(Ordering::Equal), Ore::compare_raw_slices(&a, &b));
}
#[test]
fn test_different_prp_keys() {
let k1: [u8; 16] = [
97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112,
];
let k2: [u8; 16] = [
129, 4, 114, 186, 102, 145, 225, 73, 166, 57, 244, 251, 56, 92, 188, 36,
];
let k3: [u8; 16] = [
49, 50, 51, 52, 53, 54, 55, 56, 57, 48, 97, 98, 99, 100, 101, 102,
];
let ore1: OreAes128ChaCha20 = OreCipher::init(&k1, &k2).unwrap();
let ore2: OreAes128ChaCha20 = OreCipher::init(&k1, &k3).unwrap();
let a = 1000u32.encrypt(&ore1).unwrap().to_bytes();
let b = 1000u32.encrypt(&ore2).unwrap().to_bytes();
assert_ne!(Some(Ordering::Equal), Ore::compare_raw_slices(&a, &b));
}
}