use rand::{thread_rng, Rng};
use crate::{bits::Bits, hashable::Hashable};
use super::Okvs;
/// log2(e), used to size the filter as `ceil(lambda * n * log2(e))` bins.
///
/// Taken from the standard library rather than hand-copied: the previous literal
/// carried far more digits than an `f64` can represent (clippy `excessive_precision`).
const LOG_2_E: f64 = std::f64::consts::LOG2_E;
/// Garbled Bloom filter: a table of `V`-valued shares together with the seeds
/// of the hash functions that map a key to its bins.
///
/// A key decodes to the XOR of the shares stored in the bins its hashes select.
pub struct GarbledBf<V>
where
    V: Bits,
{
    /// Shares, one per bin.
    bins: Vec<V>,
    /// One seed per hash function, passed to `Hashable::hash_to_index`.
    hash_seeds: Vec<u64>,
}
impl<V: Bits> Okvs<V> for GarbledBf<V> {
    /// Builds a garbled Bloom filter holding every `(key, value)` pair.
    ///
    /// `lambda` is used directly as the number of hash functions. The table has
    /// `ceil(lambda * n * log2(e))` bins for `n` pairs, and at least one bin so
    /// that an encoding of an empty input can still be decoded.
    ///
    /// For each key, the first empty bin it hashes to is reserved. Every other
    /// empty bin it touches is filled with a fresh random share. The reserved bin
    /// finally receives `value` XOR all the other shares the key hashes to, so
    /// that XOR-ing the key's bins reproduces `value`.
    ///
    /// Returns `None` if some key hashes to no empty bin, or hashes to its
    /// reserved bin more than once. In either case its value cannot be embedded.
    /// Hash seeds are drawn fresh on every call, so a retry may succeed.
    fn try_encode<K: Hashable>(key_value_pairs: &[(K, V)], lambda: usize) -> Option<Self> {
        let hash_count = lambda;
        // Multiply in floating point so `hash_count * len` cannot overflow `usize`
        // before the conversion. `max(1)` keeps the table non-empty: with zero
        // bins, every `decode` on an encoding of an empty input would index an
        // empty vector and panic.
        let scaled = hash_count as f64 * key_value_pairs.len() as f64 * LOG_2_E;
        let bin_count = (scaled.ceil() as usize).max(1);

        let mut rng = thread_rng();
        let hash_seeds: Vec<u64> = (0..hash_count).map(|_| rng.gen()).collect();

        let mut bins: Vec<Option<V>> = vec![None; bin_count];
        for (key, value) in key_value_pairs {
            // First empty bin this key hashes to. It will hold the share that
            // makes the XOR of all of the key's bins equal `value`.
            let mut empty_slot = None;
            let mut final_share = *value;
            for &seed in &hash_seeds {
                let index = key.hash_to_index(seed, bin_count);
                match (bins[index], empty_slot) {
                    // Bin already holds a share (from an earlier key, or filled
                    // earlier for this key): that share is fixed, so fold it in.
                    (Some(share), _) => final_share ^= share,
                    // First empty bin seen for this key: reserve it.
                    (None, None) => empty_slot = Some(index),
                    // The reserved bin is hit again. `decode` would XOR it more
                    // than once, which this construction does not account for.
                    (None, Some(slot)) if slot == index => return None,
                    // Any further empty bin gets a random share.
                    (None, Some(_)) => {
                        let random = V::random();
                        final_share ^= random;
                        bins[index] = Some(random);
                    }
                }
            }
            // No empty bin among this key's positions: the value cannot be embedded.
            let slot = empty_slot?;
            bins[slot] = Some(final_share);
        }

        // Bins that no key claimed are filled with fresh random shares.
        let bins = bins
            .into_iter()
            .map(|bin| bin.unwrap_or_else(V::random))
            .collect();
        Some(Self { bins, hash_seeds })
    }

    /// Decodes the value for `key` by XOR-ing together the bins selected by
    /// each hash seed.
    ///
    /// For a key that was passed to `try_encode`, this returns the value it was
    /// encoded with. For any other key, the result is the XOR of whichever bins
    /// it hashes to.
    ///
    /// # Panics
    ///
    /// Panics if the filter holds no hash seeds, i.e. it was built with
    /// `lambda == 0`.
    fn decode<K: Hashable>(&self, key: &K) -> V {
        let bin_count = self.bins.len();
        let mut shares = self
            .hash_seeds
            .iter()
            .map(|&seed| self.bins[key.hash_to_index(seed, bin_count)]);
        let first = shares
            .next()
            .expect("GarbledBf::decode requires at least one hash seed");
        shares.fold(first, |mut acc, share| {
            acc ^= share;
            acc
        })
    }

    /// Serialization is not implemented yet.
    ///
    /// # Panics
    ///
    /// Always panics (`todo!()`).
    fn to_bytes(self) -> Vec<u8> {
        todo!()
    }

    /// Deserialization is not implemented yet.
    ///
    /// # Panics
    ///
    /// Always panics (`todo!()`).
    fn from_bytes(_bytes: &[u8]) -> Self {
        todo!()
    }
}
#[cfg(test)]
mod tests {
    use super::GarbledBf;
    use crate::{bits::Bits, schemes::Okvs};

    // Round-trips three random `u64` values through a `GarbledBf` built with
    // `lambda = 40`. It then checks that keys which were never encoded do not
    // decode to any of the stored values.
    #[test]
    fn encode_decode_u64() {
        let pairs = [
            (1u64, u64::random()),
            (1000u64, u64::random()),
            (123u64, u64::random()),
        ];
        let okvs = GarbledBf::encode(&pairs, 40);

        for (key, value) in &pairs {
            assert_eq!(okvs.decode(key), *value);
        }

        for absent in &[0u64, 2, u64::MAX] {
            let decoded = okvs.decode(absent);
            for (_, value) in &pairs {
                assert_ne!(decoded, *value);
            }
        }
    }
}