// primitives/random/prf/pcggm.rs

1use std::iter::once;
2
3use aes::{cipher::KeyInit, Aes128Enc};
4use ff::Field;
5
6use crate::{
7    algebra::field::{binary::Gf2_128, FieldExtension},
8    hashing::{hash_ccr, hash_into},
9    random::prg::Aes128Prng,
10    types::{HeapArray, Positive, SessionId},
11};
12
// Implementation of the pseudo-random correlated GGM (PCGGM) trees from Guo et al., "Half-Tree:
// Halving the Cost of Tree Expansion in COT and DPF" (2023).
15
16#[derive(Debug, Clone)]
17pub struct FullTree<F: FieldExtension, TreeDepth: Positive, TreeLeafCnt: Positive> {
18    pub leaves: HeapArray<F, TreeLeafCnt>,
19    pub keys: HeapArray<Gf2_128, TreeDepth>,
20}
21
22#[derive(Debug, Clone)]
23pub struct PuncturedTree<F: FieldExtension, TreeLeafCnt: Positive> {
24    pub leaves: HeapArray<F, TreeLeafCnt>,
25}
26
27#[derive(Debug, Clone)]
28pub struct PCGGM {
29    cipher: Aes128Enc,
30}
31
32impl PCGGM {
33    pub fn new(session_id: SessionId) -> Self {
34        Self {
35            cipher: Self::new_cipher(session_id),
36        }
37    }
38
39    fn new_cipher(session_id: SessionId) -> Aes128Enc {
40        let mut seed = [0; 16];
41        hash_into([b"0".as_slice(), session_id.as_ref()], &mut seed);
42        Aes128Enc::new_from_slice(&seed).unwrap()
43    }
44
45    pub fn expand_full_tree<F: FieldExtension, TreeDepth: Positive, TreeLeafCnt: Positive>(
46        &mut self,
47        root: Gf2_128,
48        delta: Gf2_128,
49    ) -> FullTree<F, TreeDepth, TreeLeafCnt> {
50        let (seeds, keys) = compute_full_tree(root, delta, &self.cipher);
51        let leaves = generate_from_seeds::<F, _>(&self.cipher, seeds);
52
53        FullTree { leaves, keys }
54    }
55
56    pub fn expand_punctured_tree<F: FieldExtension, TreeDepth: Positive, TreeLeafCnt: Positive>(
57        &mut self,
58        alpha: usize,
59        keys: HeapArray<Gf2_128, TreeDepth>, /* K^i_{~alpha_i}, where alpha_i is the i-th MSB
60                                              * of
61                                              * alpha. */
62    ) -> PuncturedTree<F, TreeLeafCnt> {
63        let seeds = compute_punctured_tree(alpha, keys, &self.cipher);
64
65        let leaves = generate_from_seeds::<F, _>(&self.cipher, seeds);
66
67        PuncturedTree { leaves }
68    }
69}
70
71fn generate_from_seeds<F, N: Positive>(
72    cipher: &Aes128Enc,
73    seeds: HeapArray<Gf2_128, N>,
74) -> HeapArray<F, N>
75where
76    F: FieldExtension,
77{
78    seeds
79        .into_iter()
80        .map(|s| F::random(Aes128Prng::new(cipher, s)))
81        .collect()
82}
83
84fn compute_full_tree<TreeDepth: Positive, TreeLeafCnt: Positive>(
85    root: Gf2_128,
86    delta: Gf2_128,
87    cipher: &Aes128Enc,
88) -> (
89    HeapArray<Gf2_128, TreeLeafCnt>, // leaves
90    HeapArray<Gf2_128, TreeDepth>,   // level keys
91) {
92    assert_eq!(1u64 << TreeDepth::USIZE, TreeLeafCnt::U64);
93
94    let mut leaves: HeapArray<Gf2_128, TreeLeafCnt> = Default::default();
95    leaves[0] = root;
96    leaves[TreeLeafCnt::USIZE >> 1] = delta + root;
97
98    let keys = once(root)
99        .chain((1..TreeDepth::USIZE).map(|level| ggm_level_expand(&mut leaves, level, cipher)[0]))
100        .collect();
101
102    (leaves, keys)
103}
104
105fn compute_punctured_tree<TreeDepth: Positive, TreeLeafCnt: Positive>(
106    alpha: usize,
107    keys_tilde: HeapArray<Gf2_128, TreeDepth>, /* K^i_{~alpha_i}, where alpha_i is the i-th
108                                                * MSB
109                                                * of
110                                                * alpha. */
111    cipher: &Aes128Enc,
112) -> HeapArray<Gf2_128, TreeLeafCnt> {
113    assert_eq!(1u64 << TreeDepth::USIZE, TreeLeafCnt::U64);
114
115    let mut leaves: HeapArray<Gf2_128, TreeLeafCnt> = Default::default();
116
117    // Sanitize input, zero unused bits
118    let alpha = alpha & ((1 << TreeDepth::USIZE) - 1);
119
120    // Start tree at first level with node at position `(~alpha_0) << (depth - 2)` set from OT
121    // result
122    let k: usize = TreeDepth::USIZE - 1;
123    let idx = (1 ^ (alpha >> k)) << k;
124    leaves[idx] = keys_tilde[0];
125
126    // For each level `lvl=1..depth-1` expand CorrelatedGGM tree as usual and let `k =
127    // depth - lvl`. The node at index `(alpha >> k) << (k - 1)` is unknown (gibberish in our
128    // implementation). This node is expanded into 2 siblings with wrong values also:
129    //    * node at index `(alpha >> (k - 1)) << (k - 2)` is the unknown node,
130    //    * its sibling at index `((alpha >> (k - 1)) ^ 1) << (k - 2)` is the node whose value
131    //      should be corrected using input `keys[lvl]`.
132    (1..TreeDepth::USIZE - 1).for_each(|level| {
133        let lvl_keys_tilde = ggm_level_expand(&mut leaves, level, cipher);
134        // lvl_keys_tilde contains one (the) unknown node and its sibling to be corrected
135        // correct node
136        let k: usize = TreeDepth::USIZE - level - 1;
137        let idx = 1 ^ (alpha >> k);
138        let alpha_k_neg = idx & 1;
139        let idx = idx << k;
140        leaves[idx] += keys_tilde[level] - lvl_keys_tilde[alpha_k_neg];
141    });
142
143    // Last level is specific
144    let level = TreeDepth::USIZE - 1;
145    let lvl_keys_tilde = ggm_level_expand(&mut leaves, level, cipher);
146    leaves[alpha] += keys_tilde[level] - lvl_keys_tilde[alpha & 1];
147    leaves[alpha ^ 1] += keys_tilde[level] - lvl_keys_tilde[(alpha ^ 1) & 1];
148
149    leaves
150}
151
152#[inline]
153fn ggm_level_expand(nodes: &mut [Gf2_128], level: usize, cipher: &Aes128Enc) -> [Gf2_128; 2] {
154    assert!(nodes.len() >= (1 << (level + 1)));
155    let mut k0 = Gf2_128::ZERO;
156    let mut k1 = Gf2_128::ZERO;
157
158    let n = nodes.len();
159    let step = n >> (level + 1);
160    for j in (0..n).step_by(step << 1) {
161        let s = &nodes[j];
162        let s0 = hash_ccr(cipher, s);
163        let s1 = s + s0;
164
165        k0 += &s0;
166        k1 += &s1;
167
168        nodes[j] = s0;
169        nodes[j + step] = s1;
170    }
171
172    [k0, k1]
173}
174
#[cfg(test)]
mod tests {
    use ff::Field;
    use rand::Rng;
    use typenum::U;

    use super::*;
    use crate::{
        algebra::{
            elliptic_curve::{Curve25519Ristretto as C, ScalarField},
            field::binary::Gf2_128,
        },
        izip_eq,
        random::{self, Random},
    };

    /// Emulates the correlated OT that hands the receiver `K^k_{~alpha_k}` for every level `k`:
    /// the full-tree key when the k-th MSB of `alpha` is 1, and that key plus `delta` otherwise.
    fn emulate_cot<TreeDepth: Positive>(
        alpha: usize,
        delta: Gf2_128,
        full_keys: &HeapArray<Gf2_128, TreeDepth>,
    ) -> HeapArray<Gf2_128, TreeDepth> {
        HeapArray::from_fn(|k| {
            let alpha_k = (alpha >> (TreeDepth::USIZE - k - 1)) & 1;
            if alpha_k == 1 {
                full_keys[k]
            } else {
                delta + full_keys[k]
            }
        })
    }

    fn cggm_tree<TreeDepth: Positive, TreeLeafCnt: Positive>() {
        let mut rng = random::test_rng();

        // Random punctured leaf index.
        let alpha = rng.gen::<usize>() % TreeLeafCnt::USIZE;
        let session_id = SessionId::random(&mut rng);

        let root = Gf2_128::random(&mut rng);
        let delta = Gf2_128::random(&mut rng);

        // Same cipher that `PCGGM::new` builds for this session.
        let cipher = PCGGM::new_cipher(session_id);

        let (full_leaves, full_keys) =
            compute_full_tree::<TreeDepth, TreeLeafCnt>(root, delta, &cipher);

        let keys = emulate_cot(alpha, delta, &full_keys);

        let punctured_leaves =
            compute_punctured_tree::<TreeDepth, TreeLeafCnt>(alpha, keys, &cipher);

        // Off the punctured index the seeds agree; at it they differ by exactly `delta`.
        izip_eq!(full_leaves, punctured_leaves)
            .enumerate()
            .for_each(|(k, (r, s))| {
                if k != alpha {
                    assert_eq!(r, s);
                } else {
                    assert_eq!(r + delta, s);
                }
            });
    }

    #[test]
    fn test_cggm_tree() {
        cggm_tree::<U<4>, U<16>>();
        cggm_tree::<U<7>, U<128>>();
        cggm_tree::<U<12>, U<4096>>();
    }

    fn pcggm<F: FieldExtension, TreeDepth: Positive, TreeLeafCnt: Positive>() {
        let mut rng = random::test_rng();

        // Random punctured leaf index.
        let alpha = rng.gen::<usize>() % TreeLeafCnt::USIZE;
        let session_id = SessionId::random(&mut rng);

        let root = Gf2_128::random(&mut rng);
        let delta = Gf2_128::random(&mut rng);

        let mut ggm = PCGGM::new(session_id);
        let FullTree {
            leaves: full_leaves,
            keys: full_keys,
        } = ggm.expand_full_tree::<F, TreeDepth, TreeLeafCnt>(root, delta);

        let keys = emulate_cot(alpha, delta, &full_keys);

        let PuncturedTree {
            leaves: punctured_leaves,
        } = ggm.expand_punctured_tree::<F, TreeDepth, TreeLeafCnt>(alpha, keys);

        // Only the punctured leaf may differ from the full tree.
        izip_eq!(full_leaves, punctured_leaves)
            .enumerate()
            .for_each(|(k, (r, s))| {
                if k != alpha {
                    assert_eq!(r, s);
                } else {
                    assert_ne!(r, s);
                }
            });
    }

    #[test]
    fn test_pcggm() {
        pcggm::<Gf2_128, U<4>, U<16>>();
        pcggm::<Gf2_128, U<7>, U<128>>();
        pcggm::<Gf2_128, U<12>, U<4096>>();

        pcggm::<ScalarField<C>, U<4>, U<16>>();
        pcggm::<ScalarField<C>, U<7>, U<128>>();
        pcggm::<ScalarField<C>, U<12>, U<4096>>();
    }

    fn pcggm_are_distinct<F: FieldExtension, TreeDepth: Positive, TreeLeafCnt: Positive>() {
        let mut rng = random::test_rng();
        let session_id = SessionId::random(&mut rng);

        // Two independent random roots expanded under the same session and delta.
        let root1 = Gf2_128::random(&mut rng);
        let root2 = Gf2_128::random(&mut rng);

        let delta = Gf2_128::random(&mut rng);

        let mut ggm = PCGGM::new(session_id);
        let FullTree {
            leaves: full_leaves1,
            ..
        } = ggm.expand_full_tree::<F, TreeDepth, TreeLeafCnt>(root1, delta);

        let FullTree {
            leaves: full_leaves2,
            ..
        } = ggm.expand_full_tree::<F, TreeDepth, TreeLeafCnt>(root2, delta);

        izip_eq!(full_leaves1, full_leaves2).for_each(|(e1, e2)| assert_ne!(e1, e2));
    }

    #[test]
    fn test_pcggm_are_distinct() {
        pcggm_are_distinct::<Gf2_128, U<4>, U<16>>();
        pcggm_are_distinct::<Gf2_128, U<7>, U<128>>();
        pcggm_are_distinct::<Gf2_128, U<12>, U<4096>>();

        pcggm_are_distinct::<ScalarField<C>, U<4>, U<16>>();
        pcggm_are_distinct::<ScalarField<C>, U<7>, U<128>>();
        pcggm_are_distinct::<ScalarField<C>, U<12>, U<4096>>();
    }
}
327}