1use std::iter::once;
2
3use aes::{cipher::KeyInit, Aes128Enc};
4use ff::Field;
5
6use crate::{
7 algebra::field::{binary::Gf2_128, FieldExtension},
8 hashing::{hash_ccr, hash_into},
9 random::prg::Aes128Prng,
10 types::{HeapArray, Positive, SessionId},
11};
12
13#[derive(Debug, Clone)]
17pub struct FullTree<F: FieldExtension, TreeDepth: Positive, TreeLeafCnt: Positive> {
18 pub leaves: HeapArray<F, TreeLeafCnt>,
19 pub keys: HeapArray<Gf2_128, TreeDepth>,
20}
21
22#[derive(Debug, Clone)]
23pub struct PuncturedTree<F: FieldExtension, TreeLeafCnt: Positive> {
24 pub leaves: HeapArray<F, TreeLeafCnt>,
25}
26
27#[derive(Debug, Clone)]
28pub struct PCGGM {
29 cipher: Aes128Enc,
30}
31
32impl PCGGM {
33 pub fn new(session_id: SessionId) -> Self {
34 Self {
35 cipher: Self::new_cipher(session_id),
36 }
37 }
38
39 fn new_cipher(session_id: SessionId) -> Aes128Enc {
40 let mut seed = [0; 16];
41 hash_into([b"0".as_slice(), session_id.as_ref()], &mut seed);
42 Aes128Enc::new_from_slice(&seed).unwrap()
43 }
44
45 pub fn expand_full_tree<F: FieldExtension, TreeDepth: Positive, TreeLeafCnt: Positive>(
46 &mut self,
47 root: Gf2_128,
48 delta: Gf2_128,
49 ) -> FullTree<F, TreeDepth, TreeLeafCnt> {
50 let (seeds, keys) = compute_full_tree(root, delta, &self.cipher);
51 let leaves = generate_from_seeds::<F, _>(&self.cipher, seeds);
52
53 FullTree { leaves, keys }
54 }
55
56 pub fn expand_punctured_tree<F: FieldExtension, TreeDepth: Positive, TreeLeafCnt: Positive>(
57 &mut self,
58 alpha: usize,
59 keys: HeapArray<Gf2_128, TreeDepth>, ) -> PuncturedTree<F, TreeLeafCnt> {
63 let seeds = compute_punctured_tree(alpha, keys, &self.cipher);
64
65 let leaves = generate_from_seeds::<F, _>(&self.cipher, seeds);
66
67 PuncturedTree { leaves }
68 }
69}
70
71fn generate_from_seeds<F, N: Positive>(
72 cipher: &Aes128Enc,
73 seeds: HeapArray<Gf2_128, N>,
74) -> HeapArray<F, N>
75where
76 F: FieldExtension,
77{
78 seeds
79 .into_iter()
80 .map(|s| <F as Field>::random(Aes128Prng::new(cipher, s)))
81 .collect()
82}
83
84fn compute_full_tree<TreeDepth: Positive, TreeLeafCnt: Positive>(
85 root: Gf2_128,
86 delta: Gf2_128,
87 cipher: &Aes128Enc,
88) -> (
89 HeapArray<Gf2_128, TreeLeafCnt>, HeapArray<Gf2_128, TreeDepth>, ) {
92 assert_eq!(1u64 << TreeDepth::USIZE, TreeLeafCnt::U64);
93
94 let mut leaves: HeapArray<Gf2_128, TreeLeafCnt> = Default::default();
95 leaves[0] = root;
96 leaves[TreeLeafCnt::USIZE >> 1] = delta + root;
97
98 let keys = once(root)
99 .chain((1..TreeDepth::USIZE).map(|level| ggm_level_expand(&mut leaves, level, cipher)[0]))
100 .collect();
101
102 (leaves, keys)
103}
104
105fn compute_punctured_tree<TreeDepth: Positive, TreeLeafCnt: Positive>(
106 alpha: usize,
107 keys_tilde: HeapArray<Gf2_128, TreeDepth>, cipher: &Aes128Enc,
112) -> HeapArray<Gf2_128, TreeLeafCnt> {
113 assert_eq!(1u64 << TreeDepth::USIZE, TreeLeafCnt::U64);
114
115 let mut leaves: HeapArray<Gf2_128, TreeLeafCnt> = Default::default();
116
117 let alpha = alpha & ((1 << TreeDepth::USIZE) - 1);
119
120 let k: usize = TreeDepth::USIZE - 1;
123 let idx = (1 ^ (alpha >> k)) << k;
124 leaves[idx] = keys_tilde[0];
125
126 (1..TreeDepth::USIZE - 1).for_each(|level| {
133 let lvl_keys_tilde = ggm_level_expand(&mut leaves, level, cipher);
134 let k: usize = TreeDepth::USIZE - level - 1;
137 let idx = 1 ^ (alpha >> k);
138 let alpha_k_neg = idx & 1;
139 let idx = idx << k;
140 leaves[idx] += keys_tilde[level] - lvl_keys_tilde[alpha_k_neg];
141 });
142
143 let level = TreeDepth::USIZE - 1;
145 let lvl_keys_tilde = ggm_level_expand(&mut leaves, level, cipher);
146 leaves[alpha] += keys_tilde[level] - lvl_keys_tilde[alpha & 1];
147 leaves[alpha ^ 1] += keys_tilde[level] - lvl_keys_tilde[(alpha ^ 1) & 1];
148
149 leaves
150}
151
152#[inline]
153fn ggm_level_expand(nodes: &mut [Gf2_128], level: usize, cipher: &Aes128Enc) -> [Gf2_128; 2] {
154 assert!(nodes.len() >= (1 << (level + 1)));
155 let mut k0 = Gf2_128::ZERO;
156 let mut k1 = Gf2_128::ZERO;
157
158 let n = nodes.len();
159 let step = n >> (level + 1);
160 for j in (0..n).step_by(step << 1) {
161 let s = &nodes[j];
162 let s0 = hash_ccr(cipher, s);
163 let s1 = s + s0;
164
165 k0 += &s0;
166 k1 += &s1;
167
168 nodes[j] = s0;
169 nodes[j + step] = s1;
170 }
171
172 [k0, k1]
173}
174
#[cfg(test)]
mod tests {
    use rand::Rng;
    use typenum::U;

    use super::*;
    use crate::{
        algebra::{
            elliptic_curve::{Curve25519Ristretto as C, ScalarField},
            field::binary::Gf2_128,
        },
        izip_eq,
        random::{self, Random},
    };

    /// Builds the per-level keys handed to the punctured party.
    ///
    /// For each level `k`, bit `TreeDepth - k - 1` of `alpha` is inspected:
    /// - If it is 1, the full-tree key `full_keys[k]` is used.
    /// - Otherwise, `delta + full_keys[k]` is used.
    fn punctured_keys<TreeDepth: Positive>(
        alpha: usize,
        delta: Gf2_128,
        full_keys: &HeapArray<Gf2_128, TreeDepth>,
    ) -> HeapArray<Gf2_128, TreeDepth> {
        HeapArray::from_fn(|k| {
            let alpha_k = (alpha >> (TreeDepth::USIZE - k - 1)) & 1;
            if alpha_k == 1 {
                full_keys[k]
            } else {
                delta + full_keys[k]
            }
        })
    }

    /// Checks that the punctured tree matches the full tree on raw `Gf2_128` seeds.
    ///
    /// Every leaf except `alpha` must be equal; the leaf at `alpha` must be off by exactly
    /// `delta`.
    fn cggm_tree<TreeDepth: Positive, TreeLeafCnt: Positive>() {
        let mut rng = random::test_rng();

        let alpha = rng.gen::<usize>() % TreeLeafCnt::USIZE;
        let session_id = SessionId::random(&mut rng);

        let root: Gf2_128 = Random::random(&mut rng);
        let delta: Gf2_128 = Random::random(&mut rng);

        // Use the production key derivation so the test cannot drift from `PCGGM::new`.
        let cipher = PCGGM::new_cipher(session_id);

        let (full_leaves, full_keys) =
            compute_full_tree::<TreeDepth, TreeLeafCnt>(root, delta, &cipher);

        let keys = punctured_keys(alpha, delta, &full_keys);

        let punctured_leaves =
            compute_punctured_tree::<TreeDepth, TreeLeafCnt>(alpha, keys, &cipher);

        izip_eq!(full_leaves, punctured_leaves)
            .enumerate()
            .for_each(|(k, (r, s))| {
                if k != alpha {
                    assert_eq!(r, s);
                } else {
                    assert_eq!(r + delta, s);
                }
            });
    }

    #[test]
    fn test_cggm_tree() {
        cggm_tree::<U<4>, U<16>>();
        cggm_tree::<U<7>, U<128>>();
        cggm_tree::<U<12>, U<4096>>();
    }

    /// Checks the public `PCGGM` API: field-element leaves agree everywhere except `alpha`.
    fn pcggm<F: FieldExtension, TreeDepth: Positive, TreeLeafCnt: Positive>() {
        let mut rng = random::test_rng();

        let alpha = rng.gen::<usize>() % TreeLeafCnt::USIZE;
        let session_id = SessionId::random(&mut rng);

        let root: Gf2_128 = Random::random(&mut rng);
        let delta: Gf2_128 = Random::random(&mut rng);

        let mut ggm = PCGGM::new(session_id);
        let FullTree {
            leaves: full_leaves,
            keys: full_keys,
        } = ggm.expand_full_tree::<F, TreeDepth, TreeLeafCnt>(root, delta);

        let keys = punctured_keys(alpha, delta, &full_keys);

        let PuncturedTree {
            leaves: punctured_leaves,
        } = ggm.expand_punctured_tree::<F, TreeDepth, TreeLeafCnt>(alpha, keys);

        izip_eq!(full_leaves, punctured_leaves)
            .enumerate()
            .for_each(|(k, (r, s))| {
                if k != alpha {
                    assert_eq!(r, s);
                } else {
                    assert_ne!(r, s);
                }
            });
    }

    #[test]
    fn test_pcggm() {
        pcggm::<Gf2_128, U<4>, U<16>>();
        pcggm::<Gf2_128, U<7>, U<128>>();
        pcggm::<Gf2_128, U<12>, U<4096>>();

        pcggm::<ScalarField<C>, U<4>, U<16>>();
        pcggm::<ScalarField<C>, U<7>, U<128>>();
        pcggm::<ScalarField<C>, U<12>, U<4096>>();
    }

    /// Checks that two full trees expanded from different roots share no leaf value.
    fn pcggm_are_distinct<F: FieldExtension, TreeDepth: Positive, TreeLeafCnt: Positive>() {
        let mut rng = random::test_rng();
        let session_id = SessionId::random(&mut rng);

        let root1: Gf2_128 = Random::random(&mut rng);
        let root2: Gf2_128 = Random::random(&mut rng);

        let delta: Gf2_128 = Random::random(&mut rng);

        let mut ggm = PCGGM::new(session_id);
        let FullTree {
            leaves: full_leaves1,
            ..
        } = ggm.expand_full_tree::<F, TreeDepth, TreeLeafCnt>(root1, delta);

        let FullTree {
            leaves: full_leaves2,
            ..
        } = ggm.expand_full_tree::<F, TreeDepth, TreeLeafCnt>(root2, delta);

        izip_eq!(full_leaves1, full_leaves2).for_each(|(e1, e2)| assert_ne!(e1, e2));
    }

    #[test]
    fn test_pcggm_are_distinct() {
        pcggm_are_distinct::<Gf2_128, U<4>, U<16>>();
        pcggm_are_distinct::<Gf2_128, U<7>, U<128>>();
        pcggm_are_distinct::<Gf2_128, U<12>, U<4096>>();

        pcggm_are_distinct::<ScalarField<C>, U<4>, U<16>>();
        pcggm_are_distinct::<ScalarField<C>, U<7>, U<128>>();
        pcggm_are_distinct::<ScalarField<C>, U<12>, U<4096>>();
    }
}