1use std::ops::Mul;
2
3use aes::{cipher::KeyInit, Aes128Enc};
4use ff::Field;
5use itertools::enumerate;
6
7use crate::{
8 algebra::field::{binary::Gf2_128, FieldExtension},
9 hashing::{hash_cr, hash_into},
10 izip_eq,
11 random::prg::Aes128Prng,
12 types::{HeapArray, HeapMatrix, Positive, SessionId},
13};
14
15#[derive(Debug, Clone)]
16pub struct FullTrees<
17 F: FieldExtension,
18 TreeDepth: Positive,
19 TreeLeafCnt: Positive,
20 BatchSize: Positive,
21> {
22 pub leaves: HeapMatrix<F, TreeLeafCnt, BatchSize>,
23 pub keys: HeapMatrix<[Gf2_128; 2], TreeDepth, BatchSize>,
24}
25
26#[derive(Debug, Clone)]
27pub struct PuncturedTrees<F: FieldExtension, TreeLeafCnt: Positive, BatchSize: Positive> {
28 pub leaves: HeapMatrix<F, TreeLeafCnt, BatchSize>,
29}
30
31#[derive(Debug, Clone)]
32pub struct GGM {
33 ciphers: (Aes128Enc, Aes128Enc),
34}
35
36impl GGM {
37 pub fn new(session_id: SessionId) -> Self {
38 Self {
39 ciphers: Self::new_ciphers(session_id),
40 }
41 }
42
43 fn new_ciphers(session_id: SessionId) -> (Aes128Enc, Aes128Enc) {
44 let (mut seed0, mut seed1) = ([0; 16], [0; 16]);
45
46 hash_into([b"0".as_slice(), session_id.as_ref()], &mut seed0);
47 hash_into([b"1".as_slice(), session_id.as_ref()], &mut seed1);
48
49 (
50 Aes128Enc::new_from_slice(&seed0).unwrap(),
51 Aes128Enc::new_from_slice(&seed1).unwrap(),
52 )
53 }
54
55 pub fn expand_full_tree<F, TreeDepth, TreeLeafCnt, BatchSize>(
56 &mut self,
57 root_batches: &HeapArray<Gf2_128, BatchSize>,
58 ) -> FullTrees<F, TreeDepth, TreeLeafCnt, BatchSize>
59 where
60 F: FieldExtension,
61 TreeDepth: Positive,
62 TreeLeafCnt: Positive + Mul<BatchSize, Output: Positive>,
63 BatchSize: Positive,
64 {
65 let (seeds, keys) = compute_full_tree(root_batches, &self.ciphers.0, &self.ciphers.1);
66 let leaves =
67 generate_from_seeds::<F, _, BatchSize>(&self.ciphers.0, &self.ciphers.1, seeds);
68
69 FullTrees { leaves, keys }
70 }
71
72 pub fn expand_punctured_tree<F, TreeDepth, TreeLeafCnt, BatchSize>(
73 &mut self,
74 alpha_batches: &HeapArray<usize, BatchSize>,
75 keys_batches: &HeapMatrix<Gf2_128, TreeDepth, BatchSize>, ) -> PuncturedTrees<F, TreeLeafCnt, BatchSize>
80 where
81 F: FieldExtension,
82 TreeDepth: Positive,
83 TreeLeafCnt: Positive + Mul<BatchSize, Output: Positive>,
84 BatchSize: Positive,
85 {
86 let seeds = compute_punctured_tree(
87 alpha_batches,
88 keys_batches,
89 &self.ciphers.0,
90 &self.ciphers.1,
91 );
92
93 let leaves = generate_from_seeds::<F, _, _>(&self.ciphers.0, &self.ciphers.1, seeds);
94
95 PuncturedTrees { leaves }
96 }
97}
98
99fn generate_from_seeds<F, TreeLeafCnt, BatchSize>(
100 enc_left: &Aes128Enc,
101 enc_right: &Aes128Enc,
102 seeds: HeapMatrix<Gf2_128, TreeLeafCnt, BatchSize>,
103) -> HeapMatrix<F, TreeLeafCnt, BatchSize>
104where
105 F: FieldExtension,
106 TreeLeafCnt: Positive + Mul<BatchSize, Output: Positive>,
107 BatchSize: Positive,
108{
109 let ciphers = [enc_left, enc_right];
110
111 let tmp: Vec<_> = enumerate(seeds.into_flat_iter())
112 .map(|(i, s)| F::random(Aes128Prng::new(ciphers[i % 2], s)))
113 .collect();
114
115 HeapMatrix::<_, TreeLeafCnt, BatchSize>::try_from(tmp).unwrap()
116}
117
118#[allow(clippy::type_complexity)]
119fn compute_full_tree<TreeDepth: Positive, TreeLeafCnt: Positive, BatchSize: Positive>(
120 root_batches: &HeapArray<Gf2_128, BatchSize>,
121 cipher0: &Aes128Enc,
122 cipher1: &Aes128Enc,
123) -> (
124 HeapMatrix<Gf2_128, TreeLeafCnt, BatchSize>, HeapMatrix<[Gf2_128; 2], TreeDepth, BatchSize>, ) {
127 assert_eq!(1u64 << TreeDepth::USIZE, TreeLeafCnt::U64);
128 assert!(BatchSize::USIZE > 0);
129
130 let mut leaves_batches: HeapMatrix<Gf2_128, TreeLeafCnt, BatchSize> = Default::default();
131 let mut keys_batches: HeapMatrix<[Gf2_128; 2], TreeDepth, BatchSize> = Default::default();
132
133 izip_eq!(
134 root_batches,
135 leaves_batches.col_iter_mut(),
136 keys_batches.col_iter_mut()
137 )
138 .for_each(|(root, leaves, keys)| {
139 leaves[0] = *root;
140
141 keys.iter_mut().enumerate().for_each(|(level, keys)| {
142 *keys = ggm_level_expand(leaves, level, cipher0, cipher1);
143 });
144 });
145
146 (leaves_batches, keys_batches)
147}
148
149fn compute_punctured_tree<TreeDepth: Positive, TreeLeafCnt: Positive, BatchSize: Positive>(
150 alpha_batches: &HeapArray<usize, BatchSize>,
151 keys_tilde_batches: &HeapMatrix<Gf2_128, TreeDepth, BatchSize>, cipher0: &Aes128Enc,
158 cipher1: &Aes128Enc,
159) -> HeapMatrix<Gf2_128, TreeLeafCnt, BatchSize> {
160 assert_eq!(1u64 << TreeDepth::USIZE, TreeLeafCnt::U64);
161 assert!(BatchSize::USIZE > 0);
162
163 let mut leaves_batches: HeapMatrix<Gf2_128, TreeLeafCnt, BatchSize> = Default::default();
164
165 izip_eq!(
166 leaves_batches.col_iter_mut(),
167 alpha_batches,
168 keys_tilde_batches.col_iter()
169 )
170 .for_each(|(leaves, alpha, keys_tilde)| {
171 let alpha = alpha & ((1 << TreeDepth::USIZE) - 1);
173
174 let k: usize = TreeDepth::USIZE - 1;
177 let idx = (1 ^ (alpha >> k)) << k;
178 leaves[idx] = keys_tilde[0];
179
180 (1..TreeDepth::USIZE - 1).for_each(|level| {
188 let lvl_keys_tilde = ggm_level_expand(leaves, level, cipher0, cipher1);
189 let k: usize = TreeDepth::USIZE - level - 1;
192 let idx = 1 ^ (alpha >> k);
193 let alpha_k_neg = idx & 1;
194 let idx = idx << k;
195 leaves[idx] += keys_tilde[level] - lvl_keys_tilde[alpha_k_neg];
196 });
197
198 let level = TreeDepth::USIZE - 1;
200 let lvl_keys_tilde = ggm_level_expand(leaves, level, cipher0, cipher1);
201 leaves[alpha ^ 1] += keys_tilde[level] - lvl_keys_tilde[(alpha ^ 1) & 1];
202 });
203
204 leaves_batches
205}
206
207#[inline]
208fn ggm_level_expand(
209 nodes: &mut [Gf2_128],
210 level: usize,
211 cipher0: &Aes128Enc,
212 cipher1: &Aes128Enc,
213) -> [Gf2_128; 2] {
214 assert!(nodes.len() >= (1 << (level + 1)));
215 let mut k0 = Gf2_128::ZERO;
216 let mut k1 = Gf2_128::ZERO;
217
218 let n = nodes.len();
219 let step = n >> (level + 1);
220 for j in (0..n).step_by(step << 1) {
221 let s0 = hash_cr(cipher0, &nodes[j]);
222 let s1 = hash_cr(cipher1, &nodes[j]);
223
224 k0 += &s0;
225 k1 += &s1;
226
227 nodes[j] = s0;
228 nodes[j + step] = s1;
229 }
230
231 [k0, k1]
232}
233
#[cfg(test)]
mod tests {
    use std::fmt::Debug;

    use rand::Rng;
    use typenum::U;

    use super::*;
    use crate::{
        algebra::{
            elliptic_curve::{Curve25519Ristretto, ScalarField},
            field::binary::Gf2_128,
        },
        izip_eq,
        izip_eq_lazy,
        random::Random,
    };

    /// Selects the per-level values a receiver holds for a tree punctured at `alpha`.
    ///
    /// At level `k` it keeps `full_keys[k][1 ^ alpha_k]`, where `alpha_k` is the bit of
    /// `alpha` consumed at that level (most-significant bit first).
    fn punctured_keys<TreeDepth, BatchSize>(
        alpha_batches: &HeapArray<usize, BatchSize>,
        full_keys_batches: &HeapMatrix<[Gf2_128; 2], TreeDepth, BatchSize>,
    ) -> HeapMatrix<Gf2_128, TreeDepth, BatchSize>
    where
        TreeDepth: Positive + Mul<BatchSize, Output: Positive>,
        BatchSize: Positive,
    {
        let cols = izip_eq!(alpha_batches, full_keys_batches.col_iter())
            .map(|(alpha, full_keys)| {
                HeapArray::from_fn(|k| {
                    let alpha_k = (alpha >> (TreeDepth::USIZE - k - 1)) & 1;
                    full_keys[k][1 ^ alpha_k]
                })
            })
            .collect();

        HeapMatrix::from_cols(cols)
    }

    /// Asserts that, for every batch entry, the punctured tree agrees with the full tree on
    /// every leaf except `alpha`, and differs from it at `alpha`.
    fn assert_agree_except_alpha<F, TreeLeafCnt, BatchSize>(
        alpha_batches: HeapArray<usize, BatchSize>,
        full_leaves_batches: &HeapMatrix<F, TreeLeafCnt, BatchSize>,
        punctured_leaves_batches: &HeapMatrix<F, TreeLeafCnt, BatchSize>,
    ) where
        F: FieldExtension,
        TreeLeafCnt: Positive,
        BatchSize: Positive,
    {
        izip_eq!(
            alpha_batches,
            full_leaves_batches.col_iter(),
            punctured_leaves_batches.col_iter()
        )
        .for_each(|(alpha, full_leaves, punctured_leaves)| {
            izip_eq_lazy!(full_leaves, punctured_leaves)
                .enumerate()
                .for_each(|(k, (r, s))| {
                    if k != alpha {
                        assert_eq!(r, s, "leaf {k} differs although alpha is {alpha}");
                    } else {
                        assert_ne!(r, s, "punctured leaf {k} matches the full tree");
                    }
                });
        })
    }

    fn ggm_tree<TreeDepth, TreeLeafCnt, BatchSize>()
    where
        TreeDepth: Debug + Positive + Mul<BatchSize, Output: Positive>,
        TreeLeafCnt: Debug + Positive,
        BatchSize: Debug + Positive,
    {
        let mut rng = crate::random::test_rng();
        let session_id = SessionId::random(&mut rng);

        let alpha_batches = HeapArray::from_fn(|_| rng.gen::<usize>() % TreeLeafCnt::USIZE);

        let root_batches = Gf2_128::random_elements(&mut rng);

        // Reuse the production key derivation so the test cannot drift from `GGM`.
        let (cipher0, cipher1) = GGM::new_ciphers(session_id);

        let (full_leaves_batches, full_keys_batches) =
            compute_full_tree::<TreeDepth, TreeLeafCnt, BatchSize>(
                &root_batches,
                &cipher0,
                &cipher1,
            );

        let keys_batches = punctured_keys(&alpha_batches, &full_keys_batches);

        let punctured_leaves_batches = compute_punctured_tree::<TreeDepth, TreeLeafCnt, BatchSize>(
            &alpha_batches,
            &keys_batches,
            &cipher0,
            &cipher1,
        );

        assert_agree_except_alpha(alpha_batches, &full_leaves_batches, &punctured_leaves_batches);
    }

    #[test]
    fn test_ggm_tree() {
        ggm_tree::<U<2>, U<4>, U<17>>();
        ggm_tree::<U<7>, U<128>, U<4>>();
        ggm_tree::<U<12>, U<4096>, U<1>>();
    }

    fn ggm_classic<F, TreeDepth, TreeLeafCnt, BatchSize>()
    where
        F: FieldExtension,
        TreeDepth: Positive + Mul<BatchSize, Output: Positive>,
        TreeLeafCnt: Positive + Mul<BatchSize, Output: Positive>,
        BatchSize: Positive,
    {
        let mut rng = crate::random::test_rng();

        let alpha_batches = HeapArray::from_fn(|_| rng.gen::<usize>() % TreeLeafCnt::USIZE);

        let root_batches = Gf2_128::random_elements(&mut rng);

        let mut ggm = GGM::new(SessionId::random(&mut rng));
        let FullTrees {
            leaves: full_leaves_batches,
            keys: full_keys_batches,
        } = ggm.expand_full_tree::<F, TreeDepth, TreeLeafCnt, BatchSize>(&root_batches);

        let keys_batches = punctured_keys(&alpha_batches, &full_keys_batches);

        let PuncturedTrees {
            leaves: punctured_leaves_batches,
        } = ggm.expand_punctured_tree::<F, TreeDepth, TreeLeafCnt, BatchSize>(
            &alpha_batches,
            &keys_batches,
        );

        assert_agree_except_alpha(alpha_batches, &full_leaves_batches, &punctured_leaves_batches);
    }

    #[test]
    fn test_ggm_classic() {
        ggm_classic::<Gf2_128, U<4>, U<16>, U<17>>();
        ggm_classic::<Gf2_128, U<7>, U<128>, U<4>>();
        ggm_classic::<Gf2_128, U<12>, U<4096>, U<1>>();

        type Fq = ScalarField<Curve25519Ristretto>;
        ggm_classic::<Fq, U<4>, U<16>, U<17>>();
        ggm_classic::<Fq, U<7>, U<128>, U<4>>();
        ggm_classic::<Fq, U<12>, U<4096>, U<1>>();
    }
}
396}