1use crate::{FbError, PrimeField};
3use crypto_bigint::{U128, U256, U512, Uint};
4use rand::{Rng, rngs::ThreadRng, seq::index};
5use std::sync::RwLock;
6
/// Generic container for the two element vectors that the block operations
/// in this file read and update.
///
/// `T` is the element type; the type aliases in this file instantiate it with
/// fixed-width unsigned integers.
pub struct FbObj<T> {
    /// Growable element vector. It sits behind an `RwLock` so that new
    /// elements can be appended through a shared `&self` reference.
    /// NOTE(review): presumably the ciphertext -- confirm against the crate docs.
    pub(crate) c: RwLock<Vec<T>>,
    /// Element vector that the block operations only ever read.
    /// NOTE(review): presumably the key base -- confirm against the crate docs.
    pub(crate) r: Vec<T>,
}
12
13pub type Fb128 = FbObj<U128>;
15pub type Fb256 = FbObj<U256>;
17pub type Fb512 = FbObj<U512>;
19
20impl<const LIMBS: usize> FbObj<Uint<LIMBS>>
21where
22 Uint<LIMBS>: PrimeField
23{
24 pub(crate) fn add_block(
25 &self,
26 rng: &mut ThreadRng,
27 msg_uint: &Uint<LIMBS>
28 ) -> Vec<(usize, usize)> {
29 let r = &self.r;
30 let n = rng.gen_range(2..=r.len());
31 let r_i = index::sample(rng, r.len(), n);
32 let ri_last = r_i.iter().last()
33 .expect("r_i will contain at least 2 elements");
34 let ri_last_inv = r[ri_last].field_inv();
35 let c_i;
36 let c_len;
37 {
38 let mut c = self.c.write().unwrap();
39 c_i = index::sample(rng, c.len(), n - 1);
40 let sum = c_i.iter()
41 .zip(r_i.iter())
42 .map(|(ci, ri)| c[ci].field_mul(&r[ri]))
43 .reduce(|acc, i| acc.field_add(&i))
44 .unwrap();
45 let c_new_el = msg_uint.field_sub(&sum).field_mul(&ri_last_inv);
46 c.push(c_new_el);
47 c_len = c.len();
48 }
49 let indices = c_i.into_iter()
50 .chain([c_len - 1].into_iter())
51 .zip(r_i.into_iter())
52 .collect();
53
54 indices
55 }
56
57 pub(crate) fn decrypt_block(
58 &self,
59 indices: &[(usize, usize)]
60 ) -> Result<Uint<LIMBS>, FbError> {
61 let (c, r) = (self.c.read().unwrap(), &self.r);
62 if indices.len() > r.len() {
63 return Err(FbError::InvalidKey);
64 }
65 let mut msg = Uint::<LIMBS>::ZERO;
66 for &(ci, ri) in indices {
67 let c_el = c.get(ci).ok_or(FbError::InvalidKey)?;
68 let r_el = r.get(ri).ok_or(FbError::InvalidKey)?;
69 msg = msg.field_add(&c_el.field_mul(&r_el));
70 }
71
72 Ok(msg)
73 }
74}
75
#[cfg(test)]
mod test {
    use super::*;
    use crate::FalseBottom;

    /// Adds a single block to a freshly initialised `Fb512` and checks that
    /// decrypting with the returned index pairs gives back the same value.
    #[test]
    fn test_block_operations() {
        let fb = Fb512::init(12, 12);
        let plaintext = U512::from_u16(369);
        let mut rng = rand::thread_rng();

        let key = fb.add_block(&mut rng, &plaintext);
        let recovered = fb.decrypt_block(&key).unwrap();

        assert_eq!(plaintext, recovered);
    }
}