use std::ops::Mul;
use aes::{cipher::KeyInit, Aes128Enc};
use ff::Field;
use itertools::enumerate;
use crate::{
algebra::field::{binary::Gf2_128, FieldExtension},
hashing::{hash_cr, hash_into},
izip_eq,
random::prg::Aes128Prng,
types::{HeapArray, HeapMatrix, Positive, SessionId},
};
/// Output of [`GGM::expand_full_tree`]: `BatchSize` fully expanded GGM trees, one per column.
#[derive(Debug, Clone)]
pub struct FullTrees<
    F: FieldExtension,
    TreeDepth: Positive,
    TreeLeafCnt: Positive,
    BatchSize: Positive,
> {
    /// Leaf values in `F` (`TreeLeafCnt` per tree, one column per tree), obtained by seeding
    /// an AES-based PRG with each leaf seed of the expanded tree.
    pub leaves: HeapMatrix<F, TreeLeafCnt, BatchSize>,
    /// Per-level key pairs (one column per tree): entry `level` holds the sum of all left
    /// children (`[0]`) and the sum of all right children (`[1]`) produced when that level
    /// of the tree was expanded.
    pub keys: HeapMatrix<[Gf2_128; 2], TreeDepth, BatchSize>,
}
/// Output of [`GGM::expand_punctured_tree`]: `BatchSize` GGM trees rebuilt from a punctured
/// index `alpha` and per-level sibling keys, one tree per column.
#[derive(Debug, Clone)]
pub struct PuncturedTrees<F: FieldExtension, TreeLeafCnt: Positive, BatchSize: Positive> {
    /// Leaf values in `F`; every leaf except the one at `alpha` is expected to equal the
    /// corresponding leaf of the full tree, while the leaf at `alpha` is not reconstructed.
    pub leaves: HeapMatrix<F, TreeLeafCnt, BatchSize>,
}
/// GGM tree expander.
///
/// Holds the two AES-128 encryptors used to derive children from a parent seed: `ciphers.0`
/// produces left children and `ciphers.1` produces right children. Both are keyed from the
/// session id in `GGM::new`.
#[derive(Debug, Clone)]
pub struct GGM {
    ciphers: (Aes128Enc, Aes128Enc),
}
impl GGM {
    /// Creates an expander whose left/right child ciphers are keyed from `session_id`.
    pub fn new(session_id: SessionId) -> Self {
        let ciphers = Self::new_ciphers(session_id);
        Self { ciphers }
    }

    /// Derives the (left, right) AES-128 encryptors.
    ///
    /// Each 16-byte key is produced by `hash_into` over a one-byte label (`"0"` for the left
    /// cipher, `"1"` for the right one) together with the session id.
    fn new_ciphers(session_id: SessionId) -> (Aes128Enc, Aes128Enc) {
        let keyed = |label: &[u8]| {
            let mut key = [0u8; 16];
            hash_into([label, session_id.as_ref()], &mut key);
            Aes128Enc::new_from_slice(&key).unwrap()
        };
        (keyed(b"0".as_slice()), keyed(b"1".as_slice()))
    }

    /// Fully expands one GGM tree per root in `root_batches`.
    ///
    /// Returns the leaf values (mapped into `F`) and, for every level, the pair of sums of
    /// all left/right children produced at that level.
    ///
    /// # Panics
    /// If `2^TreeDepth != TreeLeafCnt`.
    pub fn expand_full_tree<F, TreeDepth, TreeLeafCnt, BatchSize>(
        &mut self,
        root_batches: &HeapArray<Gf2_128, BatchSize>,
    ) -> FullTrees<F, TreeDepth, TreeLeafCnt, BatchSize>
    where
        F: FieldExtension,
        TreeDepth: Positive,
        TreeLeafCnt: Positive + Mul<BatchSize, Output: Positive>,
        BatchSize: Positive,
    {
        let (left, right) = &self.ciphers;
        let (seeds, keys) = compute_full_tree(root_batches, left, right);
        FullTrees {
            leaves: generate_from_seeds::<F, TreeLeafCnt, BatchSize>(left, right, seeds),
            keys,
        }
    }

    /// Rebuilds one punctured GGM tree per column from `alpha_batches` and `keys_batches`.
    ///
    /// `keys_batches[level]` is presumably the full tree's level key for the child *off*
    /// alpha's path (this is how the tests build it); every leaf except `alpha` is then
    /// recovered.
    ///
    /// # Panics
    /// If `2^TreeDepth != TreeLeafCnt`.
    pub fn expand_punctured_tree<F, TreeDepth, TreeLeafCnt, BatchSize>(
        &mut self,
        alpha_batches: &HeapArray<usize, BatchSize>,
        keys_batches: &HeapMatrix<Gf2_128, TreeDepth, BatchSize>,
    ) -> PuncturedTrees<F, TreeLeafCnt, BatchSize>
    where
        F: FieldExtension,
        TreeDepth: Positive,
        TreeLeafCnt: Positive + Mul<BatchSize, Output: Positive>,
        BatchSize: Positive,
    {
        let (left, right) = &self.ciphers;
        let seeds = compute_punctured_tree::<TreeDepth, TreeLeafCnt, BatchSize>(
            alpha_batches,
            keys_batches,
            left,
            right,
        );
        PuncturedTrees {
            leaves: generate_from_seeds::<F, TreeLeafCnt, BatchSize>(left, right, seeds),
        }
    }
}
/// Maps every `Gf2_128` seed to a pseudorandom element of `F`, keeping the matrix shape.
///
/// The seed at flat position `pos` (in `into_flat_iter` order) seeds an `Aes128Prng` built
/// on `enc_left` when `pos` is even and on `enc_right` when `pos` is odd.
fn generate_from_seeds<F, TreeLeafCnt, BatchSize>(
    enc_left: &Aes128Enc,
    enc_right: &Aes128Enc,
    seeds: HeapMatrix<Gf2_128, TreeLeafCnt, BatchSize>,
) -> HeapMatrix<F, TreeLeafCnt, BatchSize>
where
    F: FieldExtension,
    TreeLeafCnt: Positive + Mul<BatchSize, Output: Positive>,
    BatchSize: Positive,
{
    let elements: Vec<F> = enumerate(seeds.into_flat_iter())
        .map(|(pos, seed)| {
            let cipher = if pos % 2 == 0 { enc_left } else { enc_right };
            <F as Field>::random(Aes128Prng::new(cipher, seed))
        })
        .collect();
    // Same element count as the input matrix, so the conversion cannot fail.
    HeapMatrix::<F, TreeLeafCnt, BatchSize>::try_from(elements).unwrap()
}
/// Fully expands one GGM tree per root, returning the leaf seeds and per-level key pairs.
///
/// For each column, the root is placed at index 0 and each level `0..TreeDepth` is expanded
/// in place by `ggm_level_expand`; the `[left_sum, right_sum]` it returns is recorded as
/// that level's key pair.
///
/// # Panics
/// If `2^TreeDepth != TreeLeafCnt`.
#[allow(clippy::type_complexity)] // the (leaves, keys) tuple of two HeapMatrix types is long but clear
fn compute_full_tree<TreeDepth: Positive, TreeLeafCnt: Positive, BatchSize: Positive>(
    root_batches: &HeapArray<Gf2_128, BatchSize>,
    cipher0: &Aes128Enc,
    cipher1: &Aes128Enc,
) -> (
    HeapMatrix<Gf2_128, TreeLeafCnt, BatchSize>, // leaves
    HeapMatrix<[Gf2_128; 2], TreeDepth, BatchSize>, // level keys
) {
    assert_eq!(1u64 << TreeDepth::USIZE, TreeLeafCnt::U64);
    assert!(BatchSize::USIZE > 0);

    let mut leaf_seeds: HeapMatrix<Gf2_128, TreeLeafCnt, BatchSize> = Default::default();
    let mut level_keys: HeapMatrix<[Gf2_128; 2], TreeDepth, BatchSize> = Default::default();
    for (root, nodes, keys) in izip_eq!(
        root_batches,
        leaf_seeds.col_iter_mut(),
        level_keys.col_iter_mut()
    ) {
        nodes[0] = *root;
        for (level, pair) in keys.iter_mut().enumerate() {
            *pair = ggm_level_expand(nodes, level, cipher0, cipher1);
        }
    }
    (leaf_seeds, level_keys)
}
/// Rebuilds one GGM tree per column from a punctured index and per-level sibling keys.
///
/// For each column, `alpha` (masked to its low `TreeDepth` bits) is the punctured leaf, and
/// `keys_tilde[level]` is presumably the full tree's level key for the child *off* alpha's
/// path (the `1 ^ alpha_bit` entry of the key pair, as the tests construct it).
///
/// Level 0 writes the root's off-path child directly from `keys_tilde[0]`. Every later level
/// first expands all nodes currently held, then fixes the off-path sibling by adding
/// `keys_tilde[level] - computed_sum`. In the binary field `Gf2_128` this cancels the
/// contribution of the unknown on-path node. The leaf at `alpha` is not reconstructed.
///
/// # Panics
/// If `2^TreeDepth != TreeLeafCnt`.
fn compute_punctured_tree<TreeDepth: Positive, TreeLeafCnt: Positive, BatchSize: Positive>(
    alpha_batches: &HeapArray<usize, BatchSize>,
    keys_tilde_batches: &HeapMatrix<Gf2_128, TreeDepth, BatchSize>,
    cipher0: &Aes128Enc,
    cipher1: &Aes128Enc,
) -> HeapMatrix<Gf2_128, TreeLeafCnt, BatchSize> {
    assert_eq!(1u64 << TreeDepth::USIZE, TreeLeafCnt::U64);
    assert!(BatchSize::USIZE > 0);

    let depth = TreeDepth::USIZE;
    let last_level = depth - 1;
    let mut leaves_batches: HeapMatrix<Gf2_128, TreeLeafCnt, BatchSize> = Default::default();
    for (nodes, alpha, keys_tilde) in izip_eq!(
        leaves_batches.col_iter_mut(),
        alpha_batches,
        keys_tilde_batches.col_iter()
    ) {
        let alpha = alpha & ((1 << depth) - 1);

        // Level 0: the root's off-path child is exactly the first sibling key.
        let top_sibling = ((alpha >> last_level) ^ 1) << last_level;
        nodes[top_sibling] = keys_tilde[0];

        // Intermediate levels: expand, then repair the off-path child produced by this level.
        for level in 1..last_level {
            let sums = ggm_level_expand(nodes, level, cipher0, cipher1);
            let shift = depth - level - 1;
            let sibling_prefix = (alpha >> shift) ^ 1;
            nodes[sibling_prefix << shift] += keys_tilde[level] - sums[sibling_prefix & 1];
        }

        // Last level (kept separate on purpose; with depth 1 it re-expands level 0):
        // the off-path leaf is the neighbour `alpha ^ 1`.
        let sums = ggm_level_expand(nodes, last_level, cipher0, cipher1);
        nodes[alpha ^ 1] += keys_tilde[last_level] - sums[(alpha ^ 1) & 1];
    }
    leaves_batches
}
/// Expands every node of `level` in place and returns `[left_sum, right_sum]`.
///
/// With `half = nodes.len() >> (level + 1)`, each parent at index `parent` (stepping by
/// `2 * half`) is replaced by its left child `hash_cr(cipher0, seed)`, and index
/// `parent + half` receives its right child `hash_cr(cipher1, seed)`. The returned pair is
/// the sum of all left children and the sum of all right children written.
///
/// # Panics
/// If `nodes.len() < 2^(level + 1)`, or if the length is not laid out so that
/// `parent + half` is in bounds.
#[inline]
fn ggm_level_expand(
    nodes: &mut [Gf2_128],
    level: usize,
    cipher0: &Aes128Enc,
    cipher1: &Aes128Enc,
) -> [Gf2_128; 2] {
    assert!(nodes.len() >= (1 << (level + 1)));
    let half = nodes.len() >> (level + 1);
    let mut sums = [Gf2_128::ZERO, Gf2_128::ZERO];
    for parent in (0..nodes.len()).step_by(2 * half) {
        let seed = nodes[parent];
        let left = hash_cr(cipher0, &seed);
        let right = hash_cr(cipher1, &seed);
        sums[0] += &left;
        sums[1] += &right;
        nodes[parent] = left;
        nodes[parent + half] = right;
    }
    sums
}
#[cfg(test)]
mod tests {
    use std::fmt::Debug;

    use rand::Rng;
    use typenum::U;

    use super::*;
    use crate::{
        algebra::{
            elliptic_curve::{Curve25519Ristretto, ScalarField},
            field::binary::Gf2_128,
        },
        izip_eq,
        izip_eq_lazy,
        random::Random,
    };

    /// For every batch column, selects the level keys given to the holder of the punctured
    /// index `alpha`.
    ///
    /// At level `k` the selected key is `full_keys[k][1 ^ alpha_k]`, where `alpha_k` is bit
    /// `TreeDepth - k - 1` of `alpha`, i.e. the key of the child off alpha's path.
    fn sibling_keys<TreeDepth, BatchSize>(
        alpha_batches: &HeapArray<usize, BatchSize>,
        full_keys_batches: &HeapMatrix<[Gf2_128; 2], TreeDepth, BatchSize>,
    ) -> HeapMatrix<Gf2_128, TreeDepth, BatchSize>
    where
        TreeDepth: Positive + Mul<BatchSize, Output: Positive>,
        BatchSize: Positive,
    {
        let cols = izip_eq!(alpha_batches, full_keys_batches.col_iter())
            .map(|(alpha, full_keys)| {
                HeapArray::from_fn(|k| {
                    let alpha_k = (alpha >> (TreeDepth::USIZE - k - 1)) & 1;
                    full_keys[k][1 ^ alpha_k]
                })
            })
            .collect();
        HeapMatrix::from_cols(cols)
    }

    /// Low-level check: the punctured seed tree matches the full seed tree everywhere except
    /// at `alpha`.
    fn ggm_tree<TreeDepth, TreeLeafCnt, BatchSize>()
    where
        TreeDepth: Debug + Positive + Mul<BatchSize, Output: Positive>,
        TreeLeafCnt: Debug + Positive,
        BatchSize: Debug + Positive,
    {
        let mut rng = crate::random::test_rng();
        let session_id = SessionId::random(&mut rng);
        let alpha_batches = HeapArray::from_fn(|_| rng.gen::<usize>() % TreeLeafCnt::USIZE);
        let root_batches = Gf2_128::random_array(&mut rng);
        // Use the production key derivation so the test cannot drift from `GGM::new`.
        let (cipher0, cipher1) = GGM::new_ciphers(session_id);

        let (full_leaves_batches, full_keys_batches) = compute_full_tree::<
            TreeDepth,
            TreeLeafCnt,
            BatchSize,
        >(&root_batches, &cipher0, &cipher1);
        let keys_batches =
            sibling_keys::<TreeDepth, BatchSize>(&alpha_batches, &full_keys_batches);
        let punctured_leaves_batches = compute_punctured_tree::<TreeDepth, TreeLeafCnt, BatchSize>(
            &alpha_batches,
            &keys_batches,
            &cipher0,
            &cipher1,
        );
        izip_eq!(
            alpha_batches,
            full_leaves_batches.col_iter(),
            punctured_leaves_batches.col_iter()
        )
        .for_each(|(alpha, full_leaves, punctured_leaves)| {
            izip_eq_lazy!(full_leaves, punctured_leaves)
                .enumerate()
                .for_each(|(k, (r, s))| {
                    if k != alpha {
                        assert_eq!(r, s);
                    } else {
                        assert_ne!(r, s);
                    }
                });
        })
    }

    #[test]
    fn test_ggm_tree() {
        ggm_tree::<U<2>, U<4>, U<17>>();
        ggm_tree::<U<7>, U<128>, U<4>>();
        ggm_tree::<U<12>, U<4096>, U<1>>();
    }

    /// End-to-end check through the public `GGM` API: the punctured leaves in `F` match the
    /// full leaves everywhere except at `alpha`.
    fn ggm_classic<F, TreeDepth, TreeLeafCnt, BatchSize>()
    where
        F: FieldExtension,
        TreeDepth: Positive + Mul<BatchSize, Output: Positive>,
        TreeLeafCnt: Positive + Mul<BatchSize, Output: Positive>,
        BatchSize: Positive,
    {
        let mut rng = crate::random::test_rng();
        let alpha_batches = HeapArray::from_fn(|_| rng.gen::<usize>() % TreeLeafCnt::USIZE);
        let root_batches = Gf2_128::random_array(&mut rng);
        let mut ggm = GGM::new(SessionId::random(&mut rng));
        let FullTrees {
            leaves: full_leaves_batches,
            keys: full_keys_batches,
        } = ggm.expand_full_tree::<F, TreeDepth, TreeLeafCnt, BatchSize>(&root_batches);
        let keys_batches =
            sibling_keys::<TreeDepth, BatchSize>(&alpha_batches, &full_keys_batches);
        let PuncturedTrees {
            leaves: punctured_leaves_batches,
        } = ggm.expand_punctured_tree::<F, TreeDepth, TreeLeafCnt, BatchSize>(
            &alpha_batches,
            &keys_batches,
        );
        izip_eq!(
            alpha_batches,
            full_leaves_batches.col_iter(),
            punctured_leaves_batches.col_iter()
        )
        .for_each(|(alpha, full_leaves, punctured_leaves)| {
            izip_eq_lazy!(full_leaves, punctured_leaves)
                .enumerate()
                .for_each(|(k, (r, s))| {
                    if k != alpha {
                        assert_eq!(r, s);
                    } else {
                        assert_ne!(r, s);
                    }
                });
        })
    }

    #[test]
    fn test_ggm_classic() {
        ggm_classic::<Gf2_128, U<4>, U<16>, U<17>>();
        ggm_classic::<Gf2_128, U<7>, U<128>, U<4>>();
        ggm_classic::<Gf2_128, U<12>, U<4096>, U<1>>();
        type Fq = ScalarField<Curve25519Ristretto>;
        ggm_classic::<Fq, U<4>, U<16>, U<17>>();
        ggm_classic::<Fq, U<7>, U<128>, U<4>>();
        ggm_classic::<Fq, U<12>, U<4096>, U<1>>();
    }
}