use crate::{FBError, FBKey, FBObj, FBObjTrait, FieldOps, Packing};
use crypto_bigint::{NonZero, RandomMod};
use rand::{rngs::ThreadRng, seq::index, Rng};
use rayon::iter::*;
use std::marker::Send;
use std::sync::RwLock;
/// High-level operations on an [`FBObj`]: creating a fresh object, adding a
/// byte message to it, and recovering a message from its key.
///
/// The per-element work is delegated to [`BlockOps::add_block`] and
/// [`BlockOps::decrypt_block`]; this trait handles packing bytes into
/// elements of `T` and unpacking them again.
pub trait FBAlgo<T>
where
    Self: BlockOps<T> + Sync + Send,
    T: FieldOps + Packing + RandomMod + Send + Sync,
{
    /// Modulus handed to `T::random_mod` when sampling the initial elements.
    const MODULUS: NonZero<T>;

    /// Builds a new [`FBObj`] holding `keybase_len` random keybase elements
    /// and `cipher_len` random cipher elements.
    ///
    /// # Panics
    ///
    /// Panics with the `FBError::InvalidParams` message when
    /// `cipher_len < keybase_len` or `keybase_len < 2`.
    fn init(cipher_len: usize, keybase_len: usize) -> FBObj<T> {
        assert!(
            keybase_len >= 2 && cipher_len >= keybase_len,
            "{}",
            FBError::InvalidParams
        );

        let mut rng = rand::thread_rng();
        // The keybase is drawn first, then the cipher, both from the same RNG.
        let r = std::iter::repeat_with(|| T::random_mod(&mut rng, &Self::MODULUS))
            .take(keybase_len)
            .collect();
        let cipher = std::iter::repeat_with(|| T::random_mod(&mut rng, &Self::MODULUS))
            .take(cipher_len)
            .collect();

        FBObj {
            c: RwLock::new(cipher),
            r,
        }
    }

    /// Packs `msg` into elements of `T` and adds each one through
    /// [`BlockOps::add_block`], running the elements through rayon.
    ///
    /// Returns an [`FBKey`] whose `indices` hold one row of
    /// `(cipher index, keybase index)` pairs per packed element.
    fn add(&mut self, msg: &[u8]) -> FBKey {
        // Only shared access is needed here; the cipher itself is modified
        // behind its lock inside `add_block`.
        let this: &Self = self;
        let indices = T::pack(msg)
            .into_par_iter()
            // The RNG handle is created by rayon via `thread_rng` and reused
            // across the items it hands to the same closure state.
            .map_init(rand::thread_rng, |rng, block| this.add_block(rng, &block))
            .collect();
        FBKey { indices }
    }

    /// Recovers the bytes described by `key`.
    ///
    /// Every row of `key.indices` is decrypted in order with
    /// [`BlockOps::decrypt_block`]; the resulting elements are passed to
    /// `T::unpack`.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `decrypt_block`, or the error
    /// produced by `T::unpack`.
    fn decrypt(&self, key: &FBKey) -> Result<Vec<u8>, FBError> {
        let mut blocks = Vec::with_capacity(key.indices.len());
        for row in key.indices.iter() {
            blocks.push(self.decrypt_block(row)?);
        }

        let mut bytes = T::unpack(blocks)?;
        // Release any spare capacity left over from unpacking.
        bytes.shrink_to_fit();
        Ok(bytes)
    }
}
/// Single-element operations on an object exposing a keybase and a
/// lock-protected cipher through [`FBObjTrait`].
pub trait BlockOps<T>
where
    Self: FBObjTrait<T>,
    T: FieldOps + RandomMod + Send + Sync,
{
    /// Adds one element `msg_uint` to the cipher and returns the index pairs
    /// needed to recover it.
    ///
    /// A random count `n` in `2..=keybase.len()` is chosen, then `n` distinct
    /// keybase indices and `n - 1` distinct existing cipher indices are
    /// sampled. One new cipher element is appended, computed as
    /// `(msg_uint - sum(c[ci] * r[ri])) * r[last]^-1` using the `FieldOps`
    /// methods, so that `decrypt_block` over the returned pairs sums back to
    /// `msg_uint` (assuming `FieldOps` implements field arithmetic).
    ///
    /// The returned vector has `n` `(cipher index, keybase index)` pairs; the
    /// last pair refers to the newly appended cipher element.
    ///
    /// # Panics
    ///
    /// Panics if the cipher lock is poisoned. The sampling calls are also
    /// expected to panic when the keybase has fewer than 2 elements or the
    /// cipher has fewer than `n - 1` elements (see the `rand` documentation).
    fn add_block(&self, rng: &mut ThreadRng, msg_uint: &T) -> Vec<(usize, usize)> {
        let keybase = self.keybase();
        let term_count = rng.gen_range(2..=keybase.len());
        let key_picks = index::sample(rng, keybase.len(), term_count);

        // The last sampled keybase element is paired with the new cipher
        // element, so its inverse is needed to solve for that element.
        let closing_key = key_picks
            .iter()
            .last()
            .expect("r_i will contain at least 2 elements");
        let closing_key_inv = keybase[closing_key].field_inv();

        // The write lock is held from sampling until after the push so the
        // index of the appended element is known exactly.
        let (cipher_picks, appended_at) = {
            let mut cipher = self.cipher().write().unwrap();
            let picks = index::sample(rng, cipher.len(), term_count - 1);

            // `zip` stops after `term_count - 1` pairs, leaving the last
            // keybase pick for the new element.
            let partial = picks
                .iter()
                .zip(key_picks.iter())
                .map(|(ci, ri)| cipher[ci].field_mul(&keybase[ri]))
                .reduce(|acc, term| acc.field_add(&term))
                .unwrap();

            cipher.push(msg_uint.field_sub(&partial).field_mul(&closing_key_inv));
            let appended_at = cipher.len() - 1;
            (picks, appended_at)
        };

        cipher_picks
            .iter()
            .chain(std::iter::once(appended_at))
            .zip(key_picks.iter())
            .collect()
    }

    /// Recomputes one element from its `(cipher index, keybase index)` pairs
    /// by summing `c[ci] * r[ri]` over all pairs, starting from `T::ZERO`.
    ///
    /// # Errors
    ///
    /// Returns `FBError::InvalidKey` when `indices` has more pairs than the
    /// keybase has elements, or when any index is out of range.
    ///
    /// # Panics
    ///
    /// Panics if the cipher lock is poisoned.
    fn decrypt_block(&self, indices: &[(usize, usize)]) -> Result<T, FBError> {
        let cipher = self.cipher().read().unwrap();
        let keybase = self.keybase();

        if indices.len() > keybase.len() {
            return Err(FBError::InvalidKey);
        }

        indices
            .iter()
            .try_fold(T::ZERO, |acc, &(ci, ri)| -> Result<T, FBError> {
                let c_el = cipher.get(ci).ok_or(FBError::InvalidKey)?;
                let r_el = keybase.get(ri).ok_or(FBError::InvalidKey)?;
                Ok(acc.field_add(&c_el.field_mul(r_el)))
            })
    }
}
/// Round-trips a single `U128` element through `add_block` and
/// `decrypt_block`.
#[test]
fn encrypt_u128() {
    use crypto_bigint::U128;

    let plain = U128::from_u32(100);
    let store = FBObj::<U128>::init(18, 12);
    let mut rng = rand::thread_rng();

    let key_row = store.add_block(&mut rng, &plain);
    let recovered = store.decrypt_block(&key_row).unwrap();

    assert_eq!(plain, recovered);
}
/// Round-trips two byte messages of different lengths through `add` and
/// `decrypt` on the same object.
#[test]
fn encrypt_bytes() {
    use crypto_bigint::U128;

    let all_ones = vec![u8::MAX; 33];
    let all_zeros = vec![0_u8; 102];
    let mut store = FBObj::<U128>::init(21, 9);

    // Both messages are added before either is read back, so the second add
    // must not disturb the first key.
    let ones_key = store.add(&all_ones);
    let zeros_key = store.add(&all_zeros);

    let ones_back = store.decrypt(&ones_key).unwrap();
    let zeros_back = store.decrypt(&zeros_key).unwrap();

    assert_eq!(all_ones, ones_back);
    assert_eq!(all_zeros, zeros_back);
}