use crate::{FbError, PrimeField};
use crypto_bigint::{U128, U256, U512, Uint};
use rand::{Rng, rngs::ThreadRng, seq::index};
use std::sync::RwLock;
/// Storage shared by the block-level encrypt/decrypt routines.
///
/// `T` is the element type (a fixed-width unsigned integer in the aliases
/// below). `Debug` is intentionally not derived: `r` holds secret material
/// that should not end up in logs.
pub struct FbObj<T> {
    /// Growable vector of cipher elements; `add_block` appends to it under
    /// the write lock, `decrypt_block` reads it under the read lock.
    pub(crate) c: RwLock<Vec<T>>,
    /// Fixed vector of secret coefficients; only read by the methods in this
    /// file (NOTE(review): presumably populated once at construction — confirm).
    pub(crate) r: Vec<T>,
}
/// [`FbObj`] over 128-bit unsigned integers ([`U128`]).
pub type Fb128 = FbObj<U128>;
/// [`FbObj`] over 256-bit unsigned integers ([`U256`]).
pub type Fb256 = FbObj<U256>;
/// [`FbObj`] over 512-bit unsigned integers ([`U512`]).
pub type Fb512 = FbObj<U512>;
impl<const LIMBS: usize> FbObj<Uint<LIMBS>>
where
    Uint<LIMBS>: PrimeField,
{
    /// Encodes one message element by appending a new cipher element to `c`.
    ///
    /// A random count `n` in `2..=r.len()` is drawn. Then `n` distinct key
    /// indices are sampled from `r` and `n - 1` distinct existing cipher
    /// indices are sampled from `c`. The new cipher element is computed as
    /// `(msg - Σ c[ci] * r[ri]) * r[last]^-1`, using the `field_*` operations.
    /// This makes the full sum over all returned pairs, as computed by
    /// [`decrypt_block`](Self::decrypt_block), reproduce `msg_uint`.
    ///
    /// Returns the `(cipher_index, key_index)` pairs. The last pair always
    /// points at the newly pushed cipher element.
    ///
    /// # Panics
    ///
    /// - if `r.len() < 2` (the range for `n` is empty);
    /// - if `c.len() < n - 1` (not enough cipher elements to sample from);
    /// - if the `c` lock is poisoned.
    ///
    /// NOTE(review): assumes every element of `r` is invertible under
    /// `field_inv` (i.e. non-zero) — confirm against `PrimeField`.
    pub(crate) fn add_block(
        &self,
        rng: &mut ThreadRng,
        msg_uint: &Uint<LIMBS>,
    ) -> Vec<(usize, usize)> {
        let r = &self.r;
        let n = rng.gen_range(2..=r.len());
        let r_i = index::sample(rng, r.len(), n);
        // The last sampled key index pairs with the new cipher element, so its
        // inverse is what solves for that element. `r_i` has exactly `n` entries.
        let ri_last_inv = r[r_i.index(n - 1)].field_inv();

        // Sampling from `c`, pushing, and reading the new length all happen
        // under one write guard. This keeps the returned indices consistent
        // even if other writers append concurrently.
        let (c_i, c_len) = {
            let mut c = self.c.write().unwrap();
            let c_i = index::sample(rng, c.len(), n - 1);
            // `zip` stops after the `n - 1` cipher indices, leaving the last
            // key index for the new element.
            let sum = c_i
                .iter()
                .zip(r_i.iter())
                .map(|(ci, ri)| c[ci].field_mul(&r[ri]))
                .reduce(|acc, term| acc.field_add(&term))
                .expect("n >= 2, so c_i contains at least 1 element");
            let c_new_el = msg_uint.field_sub(&sum).field_mul(&ri_last_inv);
            c.push(c_new_el);
            (c_i, c.len())
        };

        c_i.into_iter()
            .chain([c_len - 1])
            .zip(r_i.into_iter())
            .collect()
    }

    /// Reconstructs a message element from `(cipher_index, key_index)` pairs.
    ///
    /// Computes `Σ c[ci] * r[ri]` over `indices` with the `field_*`
    /// operations. An empty `indices` slice yields `Uint::ZERO`.
    ///
    /// # Errors
    ///
    /// Returns [`FbError::InvalidKey`] in either of these cases:
    /// - `indices` has more pairs than `r` has elements;
    /// - any cipher or key index is out of range.
    ///
    /// # Panics
    ///
    /// Panics if the `c` lock is poisoned.
    pub(crate) fn decrypt_block(
        &self,
        indices: &[(usize, usize)],
    ) -> Result<Uint<LIMBS>, FbError> {
        let (c, r) = (self.c.read().unwrap(), &self.r);
        if indices.len() > r.len() {
            return Err(FbError::InvalidKey);
        }
        let mut msg = Uint::<LIMBS>::ZERO;
        for &(ci, ri) in indices {
            let c_el = c.get(ci).ok_or(FbError::InvalidKey)?;
            let r_el = r.get(ri).ok_or(FbError::InvalidKey)?;
            msg = msg.field_add(&c_el.field_mul(r_el));
        }
        Ok(msg)
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::FalseBottom;

    /// A single block round-trips through add/decrypt.
    #[test]
    fn test_block_operations() {
        let msg = U512::from_u16(369);
        let fb = Fb512::init(12, 12);
        let key = fb.add_block(&mut rand::thread_rng(), &msg);
        let decrypted = fb.decrypt_block(&key).unwrap();
        assert_eq!(msg, decrypted);
    }

    /// `c` is append-only, so blocks added later must not disturb the
    /// decryption of earlier ones.
    #[test]
    fn test_multiple_blocks_decrypt_independently() {
        let fb = Fb512::init(12, 12);
        let mut rng = rand::thread_rng();
        let msgs = [U512::from_u16(1), U512::from_u16(369), U512::from_u16(65535)];
        let keys: Vec<_> = msgs.iter().map(|m| fb.add_block(&mut rng, m)).collect();
        for (msg, key) in msgs.iter().zip(&keys) {
            assert_eq!(*msg, fb.decrypt_block(key).unwrap());
        }
    }

    /// An out-of-range cipher index is reported as `InvalidKey`, not a panic.
    #[test]
    fn test_decrypt_rejects_out_of_range_index() {
        let fb = Fb512::init(12, 12);
        assert!(matches!(
            fb.decrypt_block(&[(usize::MAX, 0)]),
            Err(FbError::InvalidKey)
        ));
    }
}