1use std::iter::once;
2
3use aes::{cipher::KeyInit, Aes128Enc};
4use ff::Field;
5
6use crate::{
7 algebra::field::{binary::Gf2_128, FieldExtension},
8 hashing::{hash_ccr, hash_into},
9 random::prg::Aes128Prng,
10 types::{HeapArray, Positive, SessionId},
11};
12
13#[derive(Debug, Clone)]
17pub struct FullTree<F: FieldExtension, TreeDepth: Positive, TreeLeafCnt: Positive> {
18 pub leaves: HeapArray<F, TreeLeafCnt>,
19 pub keys: HeapArray<Gf2_128, TreeDepth>,
20}
21
22#[derive(Debug, Clone)]
23pub struct PuncturedTree<F: FieldExtension, TreeLeafCnt: Positive> {
24 pub leaves: HeapArray<F, TreeLeafCnt>,
25}
26
27#[derive(Debug, Clone)]
28pub struct PCGGM {
29 cipher: Aes128Enc,
30}
31
32impl PCGGM {
33 pub fn new(session_id: SessionId) -> Self {
34 Self {
35 cipher: Self::new_cipher(session_id),
36 }
37 }
38
39 fn new_cipher(session_id: SessionId) -> Aes128Enc {
40 let mut seed = [0; 16];
41 hash_into([b"0".as_slice(), session_id.as_ref()], &mut seed);
42 Aes128Enc::new_from_slice(&seed).unwrap()
43 }
44
45 pub fn expand_full_tree<F: FieldExtension, TreeDepth: Positive, TreeLeafCnt: Positive>(
46 &mut self,
47 root: Gf2_128,
48 delta: Gf2_128,
49 ) -> FullTree<F, TreeDepth, TreeLeafCnt> {
50 let (seeds, keys) = compute_full_tree(root, delta, &self.cipher);
51 let leaves = generate_from_seeds::<F, _>(&self.cipher, seeds);
52
53 FullTree { leaves, keys }
54 }
55
56 pub fn expand_punctured_tree<F: FieldExtension, TreeDepth: Positive, TreeLeafCnt: Positive>(
57 &mut self,
58 alpha: usize,
59 keys: HeapArray<Gf2_128, TreeDepth>, ) -> PuncturedTree<F, TreeLeafCnt> {
63 let seeds = compute_punctured_tree(alpha, keys, &self.cipher);
64
65 let leaves = generate_from_seeds::<F, _>(&self.cipher, seeds);
66
67 PuncturedTree { leaves }
68 }
69}
70
71fn generate_from_seeds<F, N: Positive>(
72 cipher: &Aes128Enc,
73 seeds: HeapArray<Gf2_128, N>,
74) -> HeapArray<F, N>
75where
76 F: FieldExtension,
77{
78 seeds
79 .into_iter()
80 .map(|s| F::random(Aes128Prng::new(cipher, s)))
81 .collect()
82}
83
84fn compute_full_tree<TreeDepth: Positive, TreeLeafCnt: Positive>(
85 root: Gf2_128,
86 delta: Gf2_128,
87 cipher: &Aes128Enc,
88) -> (
89 HeapArray<Gf2_128, TreeLeafCnt>, HeapArray<Gf2_128, TreeDepth>, ) {
92 assert_eq!(1u64 << TreeDepth::USIZE, TreeLeafCnt::U64);
93
94 let mut leaves: HeapArray<Gf2_128, TreeLeafCnt> = Default::default();
95 leaves[0] = root;
96 leaves[TreeLeafCnt::USIZE >> 1] = delta + root;
97
98 let keys = once(root)
99 .chain((1..TreeDepth::USIZE).map(|level| ggm_level_expand(&mut leaves, level, cipher)[0]))
100 .collect();
101
102 (leaves, keys)
103}
104
105fn compute_punctured_tree<TreeDepth: Positive, TreeLeafCnt: Positive>(
106 alpha: usize,
107 keys_tilde: HeapArray<Gf2_128, TreeDepth>, cipher: &Aes128Enc,
112) -> HeapArray<Gf2_128, TreeLeafCnt> {
113 assert_eq!(1u64 << TreeDepth::USIZE, TreeLeafCnt::U64);
114
115 let mut leaves: HeapArray<Gf2_128, TreeLeafCnt> = Default::default();
116
117 let alpha = alpha & ((1 << TreeDepth::USIZE) - 1);
119
120 let k: usize = TreeDepth::USIZE - 1;
123 let idx = (1 ^ (alpha >> k)) << k;
124 leaves[idx] = keys_tilde[0];
125
126 (1..TreeDepth::USIZE - 1).for_each(|level| {
133 let lvl_keys_tilde = ggm_level_expand(&mut leaves, level, cipher);
134 let k: usize = TreeDepth::USIZE - level - 1;
137 let idx = 1 ^ (alpha >> k);
138 let alpha_k_neg = idx & 1;
139 let idx = idx << k;
140 leaves[idx] += keys_tilde[level] - lvl_keys_tilde[alpha_k_neg];
141 });
142
143 let level = TreeDepth::USIZE - 1;
145 let lvl_keys_tilde = ggm_level_expand(&mut leaves, level, cipher);
146 leaves[alpha] += keys_tilde[level] - lvl_keys_tilde[alpha & 1];
147 leaves[alpha ^ 1] += keys_tilde[level] - lvl_keys_tilde[(alpha ^ 1) & 1];
148
149 leaves
150}
151
152#[inline]
153fn ggm_level_expand(nodes: &mut [Gf2_128], level: usize, cipher: &Aes128Enc) -> [Gf2_128; 2] {
154 assert!(nodes.len() >= (1 << (level + 1)));
155 let mut k0 = Gf2_128::ZERO;
156 let mut k1 = Gf2_128::ZERO;
157
158 let n = nodes.len();
159 let step = n >> (level + 1);
160 for j in (0..n).step_by(step << 1) {
161 let s = &nodes[j];
162 let s0 = hash_ccr(cipher, s);
163 let s1 = s + s0;
164
165 k0 += &s0;
166 k1 += &s1;
167
168 nodes[j] = s0;
169 nodes[j + step] = s1;
170 }
171
172 [k0, k1]
173}
174
#[cfg(test)]
mod tests {
    use ff::Field;
    use rand::Rng;
    use typenum::U;

    use super::*;
    use crate::{
        algebra::{
            elliptic_curve::{Curve25519Ristretto as C, ScalarField},
            field::binary::Gf2_128,
        },
        izip_eq,
        random::{self, Random},
    };

    /// Builds the per-level keys handed to the puncturing side.
    ///
    /// For level `k`, looks at bit `k` of `alpha` (most significant of the
    /// `TreeDepth` bits first). If that bit is 1, it returns `full_keys[k]`;
    /// otherwise it returns `delta + full_keys[k]`.
    fn punctured_keys<TreeDepth: Positive>(
        alpha: usize,
        delta: Gf2_128,
        full_keys: &HeapArray<Gf2_128, TreeDepth>,
    ) -> HeapArray<Gf2_128, TreeDepth> {
        HeapArray::from_fn(|k| {
            let alpha_k = (alpha >> (TreeDepth::USIZE - k - 1)) & 1;
            if alpha_k == 1 {
                full_keys[k]
            } else {
                delta + full_keys[k]
            }
        })
    }

    /// Checks at the seed level that the punctured tree matches the full tree
    /// everywhere except `alpha`, where it is offset by `delta`.
    fn cggm_tree<TreeDepth: Positive, TreeLeafCnt: Positive>() {
        let mut rng = random::test_rng();

        let alpha = rng.gen::<usize>() % TreeLeafCnt::USIZE;
        let session_id = SessionId::random(&mut rng);

        let root = Gf2_128::random(&mut rng);
        let delta = Gf2_128::random(&mut rng);

        // Reuse the production key derivation so this test cannot drift from it.
        let cipher = PCGGM::new_cipher(session_id);

        let (full_leaves, full_keys) =
            compute_full_tree::<TreeDepth, TreeLeafCnt>(root, delta, &cipher);

        let keys = punctured_keys(alpha, delta, &full_keys);

        let punctured_leaves =
            compute_punctured_tree::<TreeDepth, TreeLeafCnt>(alpha, keys, &cipher);

        izip_eq!(full_leaves, punctured_leaves)
            .enumerate()
            .for_each(|(k, (r, s))| {
                if k != alpha {
                    assert_eq!(r, s);
                } else {
                    assert_eq!(r + delta, s);
                }
            });
    }

    #[test]
    fn test_cggm_tree() {
        cggm_tree::<U<4>, U<16>>();
        cggm_tree::<U<7>, U<128>>();
        cggm_tree::<U<12>, U<4096>>();
    }

    /// Checks through the public `PCGGM` API that the punctured leaves match
    /// the full ones everywhere except `alpha`.
    fn pcggm<F: FieldExtension, TreeDepth: Positive, TreeLeafCnt: Positive>() {
        let mut rng = random::test_rng();

        let alpha = rng.gen::<usize>() % TreeLeafCnt::USIZE;
        let session_id = SessionId::random(&mut rng);

        let root = Gf2_128::random(&mut rng);
        let delta = Gf2_128::random(&mut rng);

        let mut ggm = PCGGM::new(session_id);
        let FullTree {
            leaves: full_leaves,
            keys: full_keys,
        } = ggm.expand_full_tree::<F, TreeDepth, TreeLeafCnt>(root, delta);

        let keys = punctured_keys(alpha, delta, &full_keys);

        let PuncturedTree {
            leaves: punctured_leaves,
        } = ggm.expand_punctured_tree::<F, TreeDepth, TreeLeafCnt>(alpha, keys);

        izip_eq!(full_leaves, punctured_leaves)
            .enumerate()
            .for_each(|(k, (r, s))| {
                if k != alpha {
                    assert_eq!(r, s);
                } else {
                    assert_ne!(r, s);
                }
            });
    }

    #[test]
    fn test_pcggm() {
        pcggm::<Gf2_128, U<4>, U<16>>();
        pcggm::<Gf2_128, U<7>, U<128>>();
        pcggm::<Gf2_128, U<12>, U<4096>>();

        pcggm::<ScalarField<C>, U<4>, U<16>>();
        pcggm::<ScalarField<C>, U<7>, U<128>>();
        pcggm::<ScalarField<C>, U<12>, U<4096>>();
    }

    /// Checks that two different roots yield leaf-wise distinct full trees.
    fn pcggm_are_distinct<F: FieldExtension, TreeDepth: Positive, TreeLeafCnt: Positive>() {
        let mut rng = random::test_rng();
        let session_id = SessionId::random(&mut rng);

        let root1 = Gf2_128::random(&mut rng);
        let root2 = Gf2_128::random(&mut rng);

        let delta = Gf2_128::random(&mut rng);

        let mut ggm = PCGGM::new(session_id);
        let FullTree {
            leaves: full_leaves1,
            ..
        } = ggm.expand_full_tree::<F, TreeDepth, TreeLeafCnt>(root1, delta);

        let FullTree {
            leaves: full_leaves2,
            ..
        } = ggm.expand_full_tree::<F, TreeDepth, TreeLeafCnt>(root2, delta);

        izip_eq!(full_leaves1, full_leaves2).for_each(|(e1, e2)| assert_ne!(e1, e2));
    }

    #[test]
    fn test_pcggm_are_distinct() {
        pcggm_are_distinct::<Gf2_128, U<4>, U<16>>();
        pcggm_are_distinct::<Gf2_128, U<7>, U<128>>();
        pcggm_are_distinct::<Gf2_128, U<12>, U<4096>>();

        pcggm_are_distinct::<ScalarField<C>, U<4>, U<16>>();
        pcggm_are_distinct::<ScalarField<C>, U<7>, U<128>>();
        pcggm_are_distinct::<ScalarField<C>, U<12>, U<4096>>();
    }
}