Skip to main content

primitives/random/prf/
ggm.rs

1use std::ops::Mul;
2
3use aes::{cipher::KeyInit, Aes128Enc};
4use ff::Field;
5use itertools::enumerate;
6
7use crate::{
8    algebra::field::{binary::Gf2_128, FieldExtension},
9    hashing::{hash_cr, hash_into},
10    izip_eq,
11    random::prg::Aes128Prng,
12    types::{HeapArray, HeapMatrix, Positive, SessionId},
13};
14
15#[derive(Debug, Clone)]
16pub struct FullTrees<
17    F: FieldExtension,
18    TreeDepth: Positive,
19    TreeLeafCnt: Positive,
20    BatchSize: Positive,
21> {
22    pub leaves: HeapMatrix<F, TreeLeafCnt, BatchSize>,
23    pub keys: HeapMatrix<[Gf2_128; 2], TreeDepth, BatchSize>,
24}
25
26#[derive(Debug, Clone)]
27pub struct PuncturedTrees<F: FieldExtension, TreeLeafCnt: Positive, BatchSize: Positive> {
28    pub leaves: HeapMatrix<F, TreeLeafCnt, BatchSize>,
29}
30
31#[derive(Debug, Clone)]
32pub struct GGM {
33    ciphers: (Aes128Enc, Aes128Enc),
34}
35
36impl GGM {
37    pub fn new(session_id: SessionId) -> Self {
38        Self {
39            ciphers: Self::new_ciphers(session_id),
40        }
41    }
42
43    fn new_ciphers(session_id: SessionId) -> (Aes128Enc, Aes128Enc) {
44        let (mut seed0, mut seed1) = ([0; 16], [0; 16]);
45
46        hash_into([b"0".as_slice(), session_id.as_ref()], &mut seed0);
47        hash_into([b"1".as_slice(), session_id.as_ref()], &mut seed1);
48
49        (
50            Aes128Enc::new_from_slice(&seed0).unwrap(),
51            Aes128Enc::new_from_slice(&seed1).unwrap(),
52        )
53    }
54
55    pub fn expand_full_tree<F, TreeDepth, TreeLeafCnt, BatchSize>(
56        &mut self,
57        root_batches: &HeapArray<Gf2_128, BatchSize>,
58    ) -> FullTrees<F, TreeDepth, TreeLeafCnt, BatchSize>
59    where
60        F: FieldExtension,
61        TreeDepth: Positive,
62        TreeLeafCnt: Positive + Mul<BatchSize, Output: Positive>,
63        BatchSize: Positive,
64    {
65        let (seeds, keys) = compute_full_tree(root_batches, &self.ciphers.0, &self.ciphers.1);
66        let leaves =
67            generate_from_seeds::<F, _, BatchSize>(&self.ciphers.0, &self.ciphers.1, seeds);
68
69        FullTrees { leaves, keys }
70    }
71
72    pub fn expand_punctured_tree<F, TreeDepth, TreeLeafCnt, BatchSize>(
73        &mut self,
74        alpha_batches: &HeapArray<usize, BatchSize>,
75        keys_batches: &HeapMatrix<Gf2_128, TreeDepth, BatchSize>, /* K^i_{~alpha_i}, where
76                                                                   * alpha_i is the i-th MSB
77                                                                   * of
78                                                                   * alpha. */
79    ) -> PuncturedTrees<F, TreeLeafCnt, BatchSize>
80    where
81        F: FieldExtension,
82        TreeDepth: Positive,
83        TreeLeafCnt: Positive + Mul<BatchSize, Output: Positive>,
84        BatchSize: Positive,
85    {
86        let seeds = compute_punctured_tree(
87            alpha_batches,
88            keys_batches,
89            &self.ciphers.0,
90            &self.ciphers.1,
91        );
92
93        let leaves = generate_from_seeds::<F, _, _>(&self.ciphers.0, &self.ciphers.1, seeds);
94
95        PuncturedTrees { leaves }
96    }
97}
98
99fn generate_from_seeds<F, TreeLeafCnt, BatchSize>(
100    enc_left: &Aes128Enc,
101    enc_right: &Aes128Enc,
102    seeds: HeapMatrix<Gf2_128, TreeLeafCnt, BatchSize>,
103) -> HeapMatrix<F, TreeLeafCnt, BatchSize>
104where
105    F: FieldExtension,
106    TreeLeafCnt: Positive + Mul<BatchSize, Output: Positive>,
107    BatchSize: Positive,
108{
109    let ciphers = [enc_left, enc_right];
110
111    let tmp: Vec<_> = enumerate(seeds.into_flat_iter())
112        .map(|(i, s)| F::random(Aes128Prng::new(ciphers[i % 2], s)))
113        .collect();
114
115    HeapMatrix::<_, TreeLeafCnt, BatchSize>::try_from(tmp).unwrap()
116}
117
118#[allow(clippy::type_complexity)]
119fn compute_full_tree<TreeDepth: Positive, TreeLeafCnt: Positive, BatchSize: Positive>(
120    root_batches: &HeapArray<Gf2_128, BatchSize>,
121    cipher0: &Aes128Enc,
122    cipher1: &Aes128Enc,
123) -> (
124    HeapMatrix<Gf2_128, TreeLeafCnt, BatchSize>,    // leaves
125    HeapMatrix<[Gf2_128; 2], TreeDepth, BatchSize>, // level keys
126) {
127    assert_eq!(1u64 << TreeDepth::USIZE, TreeLeafCnt::U64);
128    assert!(BatchSize::USIZE > 0);
129
130    let mut leaves_batches: HeapMatrix<Gf2_128, TreeLeafCnt, BatchSize> = Default::default();
131    let mut keys_batches: HeapMatrix<[Gf2_128; 2], TreeDepth, BatchSize> = Default::default();
132
133    izip_eq!(
134        root_batches,
135        leaves_batches.col_iter_mut(),
136        keys_batches.col_iter_mut()
137    )
138    .for_each(|(root, leaves, keys)| {
139        leaves[0] = *root;
140
141        keys.iter_mut().enumerate().for_each(|(level, keys)| {
142            *keys = ggm_level_expand(leaves, level, cipher0, cipher1);
143        });
144    });
145
146    (leaves_batches, keys_batches)
147}
148
149fn compute_punctured_tree<TreeDepth: Positive, TreeLeafCnt: Positive, BatchSize: Positive>(
150    alpha_batches: &HeapArray<usize, BatchSize>,
151    keys_tilde_batches: &HeapMatrix<Gf2_128, TreeDepth, BatchSize>, /* K^ij_{~alpha_ij}, where
152                                                                     * alpha_ij
153                                                                     * is the i-th
154                                                                     * MSB
155                                                                     * of
156                                                                     * alpha_j. */
157    cipher0: &Aes128Enc,
158    cipher1: &Aes128Enc,
159) -> HeapMatrix<Gf2_128, TreeLeafCnt, BatchSize> {
160    assert_eq!(1u64 << TreeDepth::USIZE, TreeLeafCnt::U64);
161    assert!(BatchSize::USIZE > 0);
162
163    let mut leaves_batches: HeapMatrix<Gf2_128, TreeLeafCnt, BatchSize> = Default::default();
164
165    izip_eq!(
166        leaves_batches.col_iter_mut(),
167        alpha_batches,
168        keys_tilde_batches.col_iter()
169    )
170    .for_each(|(leaves, alpha, keys_tilde)| {
171        // Sanitize input, zero unused bits
172        let alpha = alpha & ((1 << TreeDepth::USIZE) - 1);
173
174        // Start tree at first level with node at position `(~alpha_0) << (depth - 2)` set from
175        // OT result
176        let k: usize = TreeDepth::USIZE - 1;
177        let idx = (1 ^ (alpha >> k)) << k;
178        leaves[idx] = keys_tilde[0];
179
180        // For each level `lvl=1..depth-1` expand GGM tree as usual and let `k = depth - lvl`.
181        // The node at index `(alpha >> k) << (k - 1)` is unknown (gibberish in our
182        // implementation). This node is expanded into 2 siblings with wrong values
183        // also:
184        //    * node at index `(alpha >> (k - 1)) << (k - 2)` is the unknown node,
185        //    * its sibling at index `((alpha >> (k - 1)) ^ 1) << (k - 2)` is the node whose value
186        //      should be corrected using input `keys[lvl]`.
187        (1..TreeDepth::USIZE - 1).for_each(|level| {
188            let lvl_keys_tilde = ggm_level_expand(leaves, level, cipher0, cipher1);
189            // lvl_keys_tilde contains one (the) unknown node and its sibling to be corrected
190            // correct node
191            let k: usize = TreeDepth::USIZE - level - 1;
192            let idx = 1 ^ (alpha >> k);
193            let alpha_k_neg = idx & 1;
194            let idx = idx << k;
195            leaves[idx] += keys_tilde[level] - lvl_keys_tilde[alpha_k_neg];
196        });
197
198        // Last level is specific
199        let level = TreeDepth::USIZE - 1;
200        let lvl_keys_tilde = ggm_level_expand(leaves, level, cipher0, cipher1);
201        leaves[alpha ^ 1] += keys_tilde[level] - lvl_keys_tilde[(alpha ^ 1) & 1];
202    });
203
204    leaves_batches
205}
206
207#[inline]
208fn ggm_level_expand(
209    nodes: &mut [Gf2_128],
210    level: usize,
211    cipher0: &Aes128Enc,
212    cipher1: &Aes128Enc,
213) -> [Gf2_128; 2] {
214    assert!(nodes.len() >= (1 << (level + 1)));
215    let mut k0 = Gf2_128::ZERO;
216    let mut k1 = Gf2_128::ZERO;
217
218    let n = nodes.len();
219    let step = n >> (level + 1);
220    for j in (0..n).step_by(step << 1) {
221        let s0 = hash_cr(cipher0, &nodes[j]);
222        let s1 = hash_cr(cipher1, &nodes[j]);
223
224        k0 += &s0;
225        k1 += &s1;
226
227        nodes[j] = s0;
228        nodes[j + step] = s1;
229    }
230
231    [k0, k1]
232}
233
#[cfg(test)]
mod tests {
    use std::fmt::Debug;

    use rand::Rng;
    use typenum::U;

    use super::*;
    use crate::{
        algebra::{
            elliptic_curve::{Curve25519Ristretto, ScalarField},
            field::binary::Gf2_128,
        },
        izip_eq,
        izip_eq_lazy,
        random::Random,
    };

    /// Expands full trees from random roots, emulates the OT of the per-level keys for random
    /// puncture points, and checks that the punctured trees agree with the full trees on every
    /// leaf except the punctured one.
    fn ggm_tree<TreeDepth, TreeLeafCnt, BatchSize>()
    where
        TreeDepth: Debug + Positive + Mul<BatchSize, Output: Positive>,
        TreeLeafCnt: Debug + Positive,
        BatchSize: Debug + Positive,
    {
        let mut rng = crate::random::test_rng();
        let session_id = SessionId::random(&mut rng);

        // Random puncture point (leaf index) for every tree of the batch
        let alpha_batches = HeapArray::from_fn(|_| rng.gen::<usize>() % TreeLeafCnt::USIZE);

        let root_batches = Gf2_128::random_elements(&mut rng);

        // Reuse the production key derivation so this test cannot drift from `GGM::new`.
        let (cipher0, cipher1) = GGM::new_ciphers(session_id);

        let (full_leaves_batches, full_keys_batches) = compute_full_tree::<
            TreeDepth,
            TreeLeafCnt,
            BatchSize,
        >(&root_batches, &cipher0, &cipher1);

        // Emulate OT between full_tree.keys and negated alpha bits: at each level the receiver
        // learns the key of the side opposite to alpha's bit.
        let tmp = izip_eq!(&alpha_batches, full_keys_batches.col_iter())
            .map(|(alpha, full_keys)| {
                HeapArray::from_fn(|k| {
                    let alpha_k = (alpha >> (TreeDepth::USIZE - k - 1)) & 1;
                    full_keys[k][1 ^ alpha_k]
                })
            })
            .collect();

        let keys_batches: HeapMatrix<Gf2_128, TreeDepth, BatchSize> = HeapMatrix::from_cols(tmp);

        let punctured_leaves_batches = compute_punctured_tree::<TreeDepth, TreeLeafCnt, BatchSize>(
            &alpha_batches,
            &keys_batches,
            &cipher0,
            &cipher1,
        );

        izip_eq!(
            alpha_batches,
            full_leaves_batches.col_iter(),
            punctured_leaves_batches.col_iter()
        )
        .for_each(|(alpha, full_leaves, punctured_leaves)| {
            izip_eq_lazy!(full_leaves, punctured_leaves)
                .enumerate()
                .for_each(|(k, (r, s))| {
                    if k != alpha {
                        assert_eq!(r, s);
                    } else {
                        assert_ne!(r, s);
                    }
                });
        })
    }

    #[test]
    fn test_ggm_tree() {
        ggm_tree::<U<2>, U<4>, U<17>>();
        ggm_tree::<U<7>, U<128>, U<4>>();
        ggm_tree::<U<12>, U<4096>, U<1>>();
        // Depth 1: the first and last punctured levels coincide
        ggm_tree::<U<1>, U<2>, U<3>>();
    }

    /// Same check as `ggm_tree`, but through the public `GGM` API so that the leaves are also
    /// mapped into the field `F`.
    fn ggm_classic<F, TreeDepth, TreeLeafCnt, BatchSize>()
    where
        F: FieldExtension,
        TreeDepth: Positive + Mul<BatchSize, Output: Positive>,
        TreeLeafCnt: Positive + Mul<BatchSize, Output: Positive>,
        BatchSize: Positive,
    {
        let mut rng = crate::random::test_rng();

        // Random puncture point (leaf index) for every tree of the batch
        let alpha_batches = HeapArray::from_fn(|_| rng.gen::<usize>() % TreeLeafCnt::USIZE);

        let root_batches = Gf2_128::random_elements(&mut rng);

        let mut ggm = GGM::new(SessionId::random(&mut rng));
        let FullTrees {
            leaves: full_leaves_batches,
            keys: full_keys_batches,
        } = ggm.expand_full_tree::<F, TreeDepth, TreeLeafCnt, BatchSize>(&root_batches);

        // Emulate OT between full_tree.keys and negated alpha bits: at each level the receiver
        // learns the key of the side opposite to alpha's bit.
        let tmp = izip_eq!(&alpha_batches, full_keys_batches.col_iter())
            .map(|(alpha, full_keys)| {
                HeapArray::from_fn(|k| {
                    let alpha_k = (alpha >> (TreeDepth::USIZE - k - 1)) & 1;
                    full_keys[k][1 ^ alpha_k]
                })
            })
            .collect();

        let keys_batches: HeapMatrix<Gf2_128, TreeDepth, BatchSize> = HeapMatrix::from_cols(tmp);

        let PuncturedTrees {
            leaves: punctured_leaves_batches,
        } = ggm.expand_punctured_tree::<F, TreeDepth, TreeLeafCnt, BatchSize>(
            &alpha_batches,
            &keys_batches,
        );

        izip_eq!(
            alpha_batches,
            full_leaves_batches.col_iter(),
            punctured_leaves_batches.col_iter()
        )
        .for_each(|(alpha, full_leaves, punctured_leaves)| {
            izip_eq_lazy!(full_leaves, punctured_leaves)
                .enumerate()
                .for_each(|(k, (r, s))| {
                    if k != alpha {
                        assert_eq!(r, s);
                    } else {
                        assert_ne!(r, s);
                    }
                });
        })
    }

    #[test]
    fn test_ggm_classic() {
        ggm_classic::<Gf2_128, U<4>, U<16>, U<17>>();
        ggm_classic::<Gf2_128, U<7>, U<128>, U<4>>();
        ggm_classic::<Gf2_128, U<12>, U<4096>, U<1>>();
        // Depth 1: the first and last punctured levels coincide
        ggm_classic::<Gf2_128, U<1>, U<2>, U<3>>();

        type Fq = ScalarField<Curve25519Ristretto>;
        ggm_classic::<Fq, U<4>, U<16>, U<17>>();
        ggm_classic::<Fq, U<7>, U<128>, U<4>>();
        ggm_classic::<Fq, U<12>, U<4096>, U<1>>();
    }
}