use std::iter::once;
use aes::{cipher::KeyInit, Aes128Enc};
use ff::Field;
use crate::{
algebra::field::{binary::Gf2_128, FieldExtension},
hashing::{hash_ccr, hash_into},
random::prg::Aes128Prng,
types::{HeapArray, Positive, SessionId},
};
/// Sender-side result of expanding a GGM tree with [`PCGGM::expand_full_tree`].
#[derive(Clone, Debug)]
pub struct FullTree<F, TreeDepth, TreeLeafCnt>
where
    F: FieldExtension,
    TreeDepth: Positive,
    TreeLeafCnt: Positive,
{
    /// One field element per leaf, sampled from a PRNG seeded with that leaf's seed.
    pub leaves: HeapArray<F, TreeLeafCnt>,
    /// One key per tree level: the root for level 0, then the sum of the left children
    /// produced at each deeper level.
    pub keys: HeapArray<Gf2_128, TreeDepth>,
}
/// Receiver-side result of [`PCGGM::expand_punctured_tree`].
#[derive(Clone, Debug)]
pub struct PuncturedTree<F, TreeLeafCnt>
where
    F: FieldExtension,
    TreeLeafCnt: Positive,
{
    /// One field element per leaf. Every entry except the punctured index is derived from
    /// the same seed as the sender's leaf; the punctured entry comes from a different seed.
    pub leaves: HeapArray<F, TreeLeafCnt>,
}
/// GGM tree expander whose AES-128 cipher is keyed from a session id.
#[derive(Clone, Debug)]
pub struct PCGGM {
    /// Cipher shared by the tree hash (`hash_ccr`) and the per-leaf PRNGs.
    cipher: Aes128Enc,
}
impl PCGGM {
    /// Creates an expander whose cipher key is derived from `session_id`.
    pub fn new(session_id: SessionId) -> Self {
        let cipher = Self::new_cipher(session_id);
        Self { cipher }
    }

    /// Hashes the domain tag `b"0"` together with `session_id` into a 16-byte AES-128 key.
    fn new_cipher(session_id: SessionId) -> Aes128Enc {
        let mut key = [0u8; 16];
        hash_into([b"0".as_slice(), session_id.as_ref()], &mut key);
        // The key buffer is exactly 16 bytes, the AES-128 key length.
        Aes128Enc::new_from_slice(&key).unwrap()
    }

    /// Expands `root` (with level-0 sibling `root + delta`) into a full tree and maps each
    /// leaf seed to an element of `F`. Also returns the per-level keys.
    ///
    /// # Panics
    /// If `2^TreeDepth != TreeLeafCnt`.
    pub fn expand_full_tree<F: FieldExtension, TreeDepth: Positive, TreeLeafCnt: Positive>(
        &mut self,
        root: Gf2_128,
        delta: Gf2_128,
    ) -> FullTree<F, TreeDepth, TreeLeafCnt> {
        let (leaf_seeds, keys) =
            compute_full_tree::<TreeDepth, TreeLeafCnt>(root, delta, &self.cipher);
        FullTree {
            leaves: generate_from_seeds::<F, TreeLeafCnt>(&self.cipher, leaf_seeds),
            keys,
        }
    }

    /// Rebuilds every leaf except the one at `alpha` from the receiver's per-level `keys`,
    /// then maps each leaf seed to an element of `F`.
    ///
    /// # Panics
    /// If `2^TreeDepth != TreeLeafCnt`.
    pub fn expand_punctured_tree<F: FieldExtension, TreeDepth: Positive, TreeLeafCnt: Positive>(
        &mut self,
        alpha: usize,
        keys: HeapArray<Gf2_128, TreeDepth>,
    ) -> PuncturedTree<F, TreeLeafCnt> {
        let leaf_seeds =
            compute_punctured_tree::<TreeDepth, TreeLeafCnt>(alpha, keys, &self.cipher);
        PuncturedTree {
            leaves: generate_from_seeds::<F, TreeLeafCnt>(&self.cipher, leaf_seeds),
        }
    }
}
/// Maps every seed to an element of `F` by sampling `F` from an AES-based PRNG seeded with it.
///
/// Produces one element per seed, following the seeds' iteration order.
fn generate_from_seeds<F, N>(cipher: &Aes128Enc, seeds: HeapArray<Gf2_128, N>) -> HeapArray<F, N>
where
    F: FieldExtension,
    N: Positive,
{
    let field_elements = seeds
        .into_iter()
        .map(|seed| <F as Field>::random(Aes128Prng::new(cipher, seed)));
    field_elements.collect()
}
/// Expands `root` into a full tree of `TreeLeafCnt = 2^TreeDepth` leaf seeds.
///
/// Level 0 consists of `root` (left) and `delta + root` (right); each deeper level is produced
/// in place by [`ggm_level_expand`]. Returns the leaf seeds and one key per level: `root` for
/// level 0, then the sum of the left children created at each deeper level.
///
/// # Panics
/// If `2^TreeDepth != TreeLeafCnt`.
fn compute_full_tree<TreeDepth: Positive, TreeLeafCnt: Positive>(
    root: Gf2_128,
    delta: Gf2_128,
    cipher: &Aes128Enc,
) -> (
    HeapArray<Gf2_128, TreeLeafCnt>, // leaf seeds
    HeapArray<Gf2_128, TreeDepth>,   // per-level keys
) {
    let depth = TreeDepth::USIZE;
    assert_eq!(1u64 << depth, TreeLeafCnt::U64);

    let mut nodes: HeapArray<Gf2_128, TreeLeafCnt> = Default::default();
    // The two level-0 nodes sit at the start of each half of the leaf array.
    nodes[0] = root;
    nodes[TreeLeafCnt::USIZE / 2] = delta + root;

    let deeper_level_keys = (1..depth).map(|level| {
        let [left_sum, _right_sum] = ggm_level_expand(&mut nodes, level, cipher);
        left_sum
    });
    let level_keys = once(root).chain(deeper_level_keys).collect();
    (nodes, level_keys)
}
/// Rebuilds every leaf seed of a tree produced by [`compute_full_tree`] except the one at
/// `alpha`, using only one key per level (the receiver's view).
///
/// `keys_tilde[l]` is expected to be the level-`l` key for the side *opposite* to alpha's bit
/// at that level (bit `TreeDepth - 1 - l` of `alpha`, most significant first). At each level
/// the receiver expands the nodes it has, which include placeholder values on alpha's path.
/// It then repairs the one wrong node on the opposite side by adding
/// `keys_tilde[l] - computed_sum`.
///
/// Only the low `TreeDepth` bits of `alpha` are used. All leaves other than `alpha` match the
/// sender's; the leaf at `alpha` receives the same correction applied against alpha's own side.
/// With keys from `compute_full_tree`, the tests check that this yields the sender's leaf
/// plus `delta`.
///
/// # Panics
/// If `2^TreeDepth != TreeLeafCnt`.
fn compute_punctured_tree<TreeDepth: Positive, TreeLeafCnt: Positive>(
    alpha: usize,
    keys_tilde: HeapArray<Gf2_128, TreeDepth>,
    cipher: &Aes128Enc,
) -> HeapArray<Gf2_128, TreeLeafCnt> {
    let depth = TreeDepth::USIZE;
    assert_eq!(1u64 << depth, TreeLeafCnt::U64);
    let mut leaves: HeapArray<Gf2_128, TreeLeafCnt> = Default::default();
    let alpha = alpha & ((1 << depth) - 1);
    let last_level = depth - 1;

    if last_level == 0 {
        // Depth 1: the two level-0 nodes already are the leaves. `ggm_level_expand` must not be
        // called with level 0 here, as it would expand slot 0 over both leaves. Applying the
        // last-level correction below to un-expanded nodes gives `keys_tilde[0]` for both the
        // known sibling and the punctured leaf.
        leaves[alpha ^ 1] = keys_tilde[0];
        leaves[alpha] = keys_tilde[0];
        return leaves;
    }

    // Level 0: the only node the receiver knows is the sibling of alpha's level-0 ancestor.
    leaves[(1 ^ (alpha >> last_level)) << last_level] = keys_tilde[0];

    for level in 1..last_level {
        let computed = ggm_level_expand(&mut leaves, level, cipher);
        let shift = depth - level - 1;
        // `alpha >> shift` indexes alpha's ancestor at this level; flipping its low bit gives
        // that ancestor's sibling. Both are children of a placeholder node, and the sibling is
        // the only wrong node on its side, so the key difference repairs it exactly.
        let sibling = 1 ^ (alpha >> shift);
        let side = sibling & 1;
        leaves[sibling << shift] += keys_tilde[level] - computed[side];
    }

    // Leaf level: repair the punctured leaf and its sibling from the same pre-repair sums.
    let computed = ggm_level_expand(&mut leaves, last_level, cipher);
    leaves[alpha] += keys_tilde[last_level] - computed[alpha & 1];
    leaves[alpha ^ 1] += keys_tilde[last_level] - computed[(alpha ^ 1) & 1];
    leaves
}
/// Expands one tree level in place and returns `[sum of left children, sum of right children]`.
///
/// The parents occupy every multiple of `nodes.len() >> level`. Each parent `p` is overwritten
/// with its left child `hash_ccr(cipher, p)`. Its right child `p + left` is written
/// `nodes.len() >> (level + 1)` slots further on.
///
/// # Panics
/// If `nodes` holds fewer than `2^(level + 1)` entries.
#[inline]
fn ggm_level_expand(nodes: &mut [Gf2_128], level: usize, cipher: &Aes128Enc) -> [Gf2_128; 2] {
    assert!(nodes.len() >= (1 << (level + 1)));
    let len = nodes.len();
    let half = len >> (level + 1);
    let mut sums = [Gf2_128::ZERO; 2];

    let mut parent = 0;
    while parent < len {
        let node = &nodes[parent];
        let left = hash_ccr(cipher, node);
        let right = node + left;
        sums[0] += &left;
        sums[1] += &right;
        nodes[parent] = left;
        nodes[parent + half] = right;
        parent += half << 1;
    }
    sums
}
#[cfg(test)]
mod tests {
    use rand::Rng;
    use typenum::U;

    use super::*;
    use crate::{
        algebra::{
            elliptic_curve::{Curve25519Ristretto as C, ScalarField},
            field::binary::Gf2_128,
        },
        izip_eq,
        random::{self, Random},
    };

    /// Builds the receiver's per-level keys. For level `k`, keep the sender's (left-side) key
    /// when alpha's bit at that level (most significant first) is 1; otherwise take
    /// `delta + key`.
    fn receiver_keys<TreeDepth: Positive>(
        alpha: usize,
        delta: Gf2_128,
        full_keys: &HeapArray<Gf2_128, TreeDepth>,
    ) -> HeapArray<Gf2_128, TreeDepth> {
        HeapArray::from_fn(|k| {
            let bit = (alpha >> (TreeDepth::USIZE - k - 1)) & 1;
            match bit {
                1 => full_keys[k],
                _ => delta + full_keys[k],
            }
        })
    }

    /// Checks that the raw seed trees agree everywhere except at `alpha`, where the
    /// punctured seed is the sender's seed plus `delta`.
    fn cggm_tree<TreeDepth: Positive, TreeLeafCnt: Positive>() {
        let mut rng = random::test_rng();
        let alpha = rng.gen::<usize>() % TreeLeafCnt::USIZE;
        let session_id = SessionId::random(&mut rng);
        let root: Gf2_128 = Random::random(&mut rng);
        let delta: Gf2_128 = Random::random(&mut rng);
        let cipher = PCGGM::new_cipher(session_id);

        let (full_leaves, full_keys) =
            compute_full_tree::<TreeDepth, TreeLeafCnt>(root, delta, &cipher);
        let keys = receiver_keys(alpha, delta, &full_keys);
        let punctured_leaves =
            compute_punctured_tree::<TreeDepth, TreeLeafCnt>(alpha, keys, &cipher);

        for (k, (r, s)) in izip_eq!(full_leaves, punctured_leaves).enumerate() {
            if k == alpha {
                assert_eq!(r + delta, s);
            } else {
                assert_eq!(r, s);
            }
        }
    }

    #[test]
    fn test_cggm_tree() {
        cggm_tree::<U<4>, U<16>>();
        cggm_tree::<U<7>, U<128>>();
        cggm_tree::<U<12>, U<4096>>();
    }

    /// Checks that the field-element leaves agree everywhere except at `alpha`.
    fn pcggm<F: FieldExtension, TreeDepth: Positive, TreeLeafCnt: Positive>() {
        let mut rng = random::test_rng();
        let alpha = rng.gen::<usize>() % TreeLeafCnt::USIZE;
        let session_id = SessionId::random(&mut rng);
        let root: Gf2_128 = Random::random(&mut rng);
        let delta: Gf2_128 = Random::random(&mut rng);
        let mut ggm = PCGGM::new(session_id);

        let FullTree {
            leaves: full_leaves,
            keys: full_keys,
        } = ggm.expand_full_tree::<F, TreeDepth, TreeLeafCnt>(root, delta);
        let keys = receiver_keys(alpha, delta, &full_keys);
        let PuncturedTree {
            leaves: punctured_leaves,
        } = ggm.expand_punctured_tree::<F, TreeDepth, TreeLeafCnt>(alpha, keys);

        for (k, (r, s)) in izip_eq!(full_leaves, punctured_leaves).enumerate() {
            if k == alpha {
                assert_ne!(r, s);
            } else {
                assert_eq!(r, s);
            }
        }
    }

    #[test]
    fn test_pcggm() {
        pcggm::<Gf2_128, U<4>, U<16>>();
        pcggm::<Gf2_128, U<7>, U<128>>();
        pcggm::<Gf2_128, U<12>, U<4096>>();
        pcggm::<ScalarField<C>, U<4>, U<16>>();
        pcggm::<ScalarField<C>, U<7>, U<128>>();
        pcggm::<ScalarField<C>, U<12>, U<4096>>();
    }

    /// Checks that two different roots yield pairwise-different leaves.
    fn pcggm_are_distinct<F: FieldExtension, TreeDepth: Positive, TreeLeafCnt: Positive>() {
        let mut rng = random::test_rng();
        let session_id = SessionId::random(&mut rng);
        let root1: Gf2_128 = Random::random(&mut rng);
        let root2: Gf2_128 = Random::random(&mut rng);
        let delta: Gf2_128 = Random::random(&mut rng);
        let mut ggm = PCGGM::new(session_id);

        let FullTree { leaves: first, .. } =
            ggm.expand_full_tree::<F, TreeDepth, TreeLeafCnt>(root1, delta);
        let FullTree { leaves: second, .. } =
            ggm.expand_full_tree::<F, TreeDepth, TreeLeafCnt>(root2, delta);

        for (a, b) in izip_eq!(first, second) {
            assert_ne!(a, b);
        }
    }

    #[test]
    fn test_pcggm_are_distinct() {
        pcggm_are_distinct::<Gf2_128, U<4>, U<16>>();
        pcggm_are_distinct::<Gf2_128, U<7>, U<128>>();
        pcggm_are_distinct::<Gf2_128, U<12>, U<4096>>();
        pcggm_are_distinct::<ScalarField<C>, U<4>, U<16>>();
        pcggm_are_distinct::<ScalarField<C>, U<7>, U<128>>();
        pcggm_are_distinct::<ScalarField<C>, U<12>, U<4096>>();
    }
}