arcium-primitives 0.4.2

Arcium primitives
Documentation
use std::iter::once;

use aes::{cipher::KeyInit, Aes128Enc};
use ff::Field;

use crate::{
    algebra::field::{binary::Gf2_128, FieldExtension},
    hashing::{hash_ccr, hash_into},
    random::prg::Aes128Prng,
    types::{HeapArray, Positive, SessionId},
};

// Implementation of the pseudo-random correlated GGM trees from Guo et al. "Half-Tree: Halving the
// Cost of Tree Expansion in COT and DPF" 2023

/// Result of [`PCGGM::expand_full_tree`]: a complete correlated GGM tree.
#[derive(Debug, Clone)]
pub struct FullTree<F: FieldExtension, TreeDepth: Positive, TreeLeafCnt: Positive> {
    /// All `TreeLeafCnt` leaves, each obtained by seeding an `Aes128Prng` with the leaf seed and
    /// sampling an element of `F` from it.
    pub leaves: HeapArray<F, TreeLeafCnt>,
    /// One key per level: `keys[i]` is `K^i_0`, the sum of all left-hand nodes at level `i`
    /// (for level 0 this is the root seed itself).
    pub keys: HeapArray<Gf2_128, TreeDepth>,
}

/// Result of [`PCGGM::expand_punctured_tree`]: a correlated GGM tree rebuilt from per-level keys
/// without knowledge of the leaf at the punctured index `alpha`.
#[derive(Debug, Clone)]
pub struct PuncturedTree<F: FieldExtension, TreeLeafCnt: Positive> {
    /// All `TreeLeafCnt` leaves mapped into `F`. Given honest keys, every leaf except the one at
    /// `alpha` equals the corresponding leaf of the full tree.
    pub leaves: HeapArray<F, TreeLeafCnt>,
}

/// Pseudo-random correlated GGM tree expander (Half-Tree construction).
///
/// Holds an AES-128 cipher, keyed from the session id. It is used both by `hash_ccr` to expand
/// tree nodes and by `Aes128Prng` to map leaf seeds into field elements.
#[derive(Debug, Clone)]
pub struct PCGGM {
    cipher: Aes128Enc,
}

impl PCGGM {
    /// Creates an expander whose AES-128 key is derived deterministically from `session_id`.
    ///
    /// Two expanders built from the same session id produce identical expansions.
    pub fn new(session_id: SessionId) -> Self {
        Self {
            cipher: Self::new_cipher(session_id),
        }
    }

    /// Derives the AES-128 key by hashing the prefix `b"0"` together with `session_id` into
    /// 16 bytes.
    fn new_cipher(session_id: SessionId) -> Aes128Enc {
        let mut seed = [0u8; 16];
        hash_into([b"0".as_slice(), session_id.as_ref()], &mut seed);
        // A `[u8; 16]` is exactly an AES-128 key, so construction is infallible; no need to go
        // through the length-checked `new_from_slice` and unwrap its result.
        Aes128Enc::new(&seed.into())
    }

    /// Expands a full correlated GGM tree whose level-0 nodes are `root` and `root + delta`.
    ///
    /// Returns the `TreeLeafCnt` leaves mapped into `F` together with one key per level:
    /// `keys[i]` is `K^i_0`, the sum of all left-hand nodes at level `i` (`keys[0] == root`).
    ///
    /// # Panics
    /// If `2^TreeDepth != TreeLeafCnt`.
    pub fn expand_full_tree<F: FieldExtension, TreeDepth: Positive, TreeLeafCnt: Positive>(
        &mut self,
        root: Gf2_128,
        delta: Gf2_128,
    ) -> FullTree<F, TreeDepth, TreeLeafCnt> {
        let (seeds, keys) = compute_full_tree(root, delta, &self.cipher);
        let leaves = generate_from_seeds::<F, _>(&self.cipher, seeds);

        FullTree { leaves, keys }
    }

    /// Rebuilds a correlated GGM tree punctured at leaf `alpha`.
    ///
    /// * `alpha` - index of the punctured leaf; only its low `TreeDepth` bits are used.
    /// * `keys` - `keys[i] = K^i_{~alpha_i}`, where `alpha_i` is the `i`-th most significant of
    ///   the low `TreeDepth` bits of `alpha`, i.e. the level-`i` key on the side opposite to the
    ///   path towards `alpha`.
    ///
    /// Given honest keys, every returned leaf except the one at `alpha` equals the corresponding
    /// leaf of [`PCGGM::expand_full_tree`].
    ///
    /// # Panics
    /// If `2^TreeDepth != TreeLeafCnt`.
    pub fn expand_punctured_tree<F: FieldExtension, TreeDepth: Positive, TreeLeafCnt: Positive>(
        &mut self,
        alpha: usize,
        keys: HeapArray<Gf2_128, TreeDepth>,
    ) -> PuncturedTree<F, TreeLeafCnt> {
        let seeds = compute_punctured_tree(alpha, keys, &self.cipher);
        let leaves = generate_from_seeds::<F, _>(&self.cipher, seeds);

        PuncturedTree { leaves }
    }
}

/// Maps every seed to an element of `F`, sampled from a fresh `Aes128Prng` built from `cipher`
/// and that seed. The output preserves the order of `seeds`.
fn generate_from_seeds<F: FieldExtension, N: Positive>(
    cipher: &Aes128Enc,
    seeds: HeapArray<Gf2_128, N>,
) -> HeapArray<F, N> {
    seeds
        .into_iter()
        .map(|seed| {
            let prng = Aes128Prng::new(cipher, seed);
            // Fully qualified: `F` may expose several traits with a `random` method.
            <F as Field>::random(prng)
        })
        .collect()
}

/// Expands the correlated GGM tree whose two level-0 nodes are `root` and `root + delta`.
///
/// Nodes are stored in place: at level `i`, node `p` lives at index `p << (depth - 1 - i)`, so the
/// level-0 nodes occupy the first slot and the middle slot of the array.
///
/// Returns `(leaves, keys)`, where `keys[i]` is `K^i_0`, the sum of all left-hand nodes at level
/// `i` (`keys[0] == root`).
///
/// # Panics
/// If `2^TreeDepth != TreeLeafCnt`.
fn compute_full_tree<TreeDepth: Positive, TreeLeafCnt: Positive>(
    root: Gf2_128,
    delta: Gf2_128,
    cipher: &Aes128Enc,
) -> (
    HeapArray<Gf2_128, TreeLeafCnt>, // leaves
    HeapArray<Gf2_128, TreeDepth>,   // level keys
) {
    assert_eq!(1u64 << TreeDepth::USIZE, TreeLeafCnt::U64);

    let middle = TreeLeafCnt::USIZE / 2;
    let mut leaves: HeapArray<Gf2_128, TreeLeafCnt> = Default::default();
    leaves[0] = root;
    leaves[middle] = delta + root;

    // Each expansion rewrites `leaves` one level deeper and yields that level's left-hand sum.
    let deeper_keys = (1..TreeDepth::USIZE).map(|level| {
        let [left_sum, _right_sum] = ggm_level_expand(&mut leaves, level, cipher);
        left_sum
    });
    let keys: HeapArray<Gf2_128, TreeDepth> = once(root).chain(deeper_keys).collect();

    (leaves, keys)
}

/// Rebuilds every leaf seed of a correlated GGM tree except the one at `alpha`, using the
/// per-level keys learned through (C)OT.
///
/// * `alpha` - index of the punctured leaf; bits above `TreeDepth` are ignored.
/// * `keys_tilde` - `keys_tilde[i] = K^i_{~alpha_i}`, where `alpha_i` is the `i`-th most
///   significant of the low `TreeDepth` bits of `alpha`: the sum of all level-`i` nodes on the
///   side opposite to the path towards `alpha`.
/// * `cipher` - key for `hash_ccr`, used to expand each level.
///
/// With honest keys from a full tree with correlation `delta` (so `K^i_0 + K^i_1 = delta` at
/// every level), every returned leaf except `alpha` equals the full-tree leaf. The leaf at
/// `alpha` equals the full-tree leaf plus `delta`.
///
/// # Panics
/// If `2^TreeDepth != TreeLeafCnt`.
fn compute_punctured_tree<TreeDepth: Positive, TreeLeafCnt: Positive>(
    alpha: usize,
    keys_tilde: HeapArray<Gf2_128, TreeDepth>,
    cipher: &Aes128Enc,
) -> HeapArray<Gf2_128, TreeLeafCnt> {
    assert_eq!(1u64 << TreeDepth::USIZE, TreeLeafCnt::U64);
    let depth = TreeDepth::USIZE;

    // Nodes on the path to `alpha` are unknown. They start at zero and are expanded like any
    // other node, so they hold garbage. The algorithm only relies on level sums in which that
    // garbage cancels out.
    let mut leaves: HeapArray<Gf2_128, TreeLeafCnt> = Default::default();

    // Sanitize input, zero unused bits.
    let alpha = alpha & ((1 << depth) - 1);

    // Level 0 has two nodes, stored at indices `0` and `1 << (depth - 1)`. The one off the path
    // sits at `(~alpha_0) << (depth - 1)` and is exactly the level-0 key.
    let k = depth - 1;
    let idx = (1 ^ (alpha >> k)) << k;
    leaves[idx] = keys_tilde[0];

    // Intermediate levels `level = 1..=depth-2`. Let `k = depth - level - 1`, so `alpha >> k` is
    // the `(level + 1)`-bit prefix of `alpha`. That prefix names the level-`level` node on the
    // path, stored at `(alpha >> k) << k`.
    //
    // After expansion, both that node and its sibling at `((alpha >> k) ^ 1) << k` are garbage,
    // since both descend from the unknown parent. The sibling's correct value is
    // `keys_tilde[level]` minus the sum of the other (correct) nodes on its side:
    // `garbage + keys_tilde[level] - lvl_keys_tilde[~alpha_level]`.
    for level in 1..depth - 1 {
        let lvl_keys_tilde = ggm_level_expand(&mut leaves, level, cipher);
        let k = depth - level - 1;
        let sibling_prefix = 1 ^ (alpha >> k);
        let alpha_level_neg = sibling_prefix & 1;
        leaves[sibling_prefix << k] += keys_tilde[level] - lvl_keys_tilde[alpha_level_neg];
    }

    // Last level. The sibling `alpha ^ 1` is corrected exactly as above. The punctured leaf
    // `alpha` receives the same kind of correction using its own side's sum. Because
    // `K_0 + K_1 = delta`, it ends up as the full-tree leaf plus `delta`.
    //
    // Both corrections use sums taken before either update, so their order does not matter.
    let level = depth - 1;
    let lvl_keys_tilde = ggm_level_expand(&mut leaves, level, cipher);
    leaves[alpha] += keys_tilde[level] - lvl_keys_tilde[alpha & 1];
    leaves[alpha ^ 1] += keys_tilde[level] - lvl_keys_tilde[(alpha ^ 1) & 1];

    leaves
}

/// Expands every node of level `level` into its two children, in place.
///
/// With `half = nodes.len() >> (level + 1)`, parents sit at multiples of `2 * half`. A parent `s`
/// at index `j` is overwritten by its left child `H(s)`, and its right child `s + H(s)` is
/// written to `j + half`, where `H` is `hash_ccr` keyed by `cipher`.
///
/// Returns `[sum of left children, sum of right children]` for the new level.
///
/// # Panics
/// If `nodes.len() < 2^(level + 1)`.
#[inline]
fn ggm_level_expand(nodes: &mut [Gf2_128], level: usize, cipher: &Aes128Enc) -> [Gf2_128; 2] {
    assert!(nodes.len() >= (1 << (level + 1)));

    let half = nodes.len() >> (level + 1);
    let mut left_sum = Gf2_128::ZERO;
    let mut right_sum = Gf2_128::ZERO;

    // Each chunk holds one parent at its start and receives that parent's right child at `half`.
    for block in nodes.chunks_mut(half << 1) {
        let parent = &block[0];
        let left = hash_ccr(cipher, parent);
        let right = parent + left;

        left_sum += &left;
        right_sum += &right;

        block[0] = left;
        block[half] = right;
    }

    [left_sum, right_sum]
}

#[cfg(test)]
mod tests {
    use rand::Rng;
    use typenum::U;

    use super::*;
    use crate::{
        algebra::{
            elliptic_curve::{Curve25519Ristretto as C, ScalarField},
            field::binary::Gf2_128,
        },
        izip_eq,
        random::{self, Random},
    };

    /// Emulates the (C)OT step: for each level `k`, the receiver learns `K^k_{~alpha_k}`, where
    /// `alpha_k` is the `k`-th MSB of `alpha`. `full_keys[k]` is `K^k_0`, and
    /// `K^k_1 = K^k_0 + delta`.
    fn emulate_cot<TreeDepth: Positive>(
        alpha: usize,
        delta: Gf2_128,
        full_keys: &HeapArray<Gf2_128, TreeDepth>,
    ) -> HeapArray<Gf2_128, TreeDepth> {
        HeapArray::from_fn(|k| {
            let alpha_k = (alpha >> (TreeDepth::USIZE - k - 1)) & 1;
            if alpha_k == 1 {
                full_keys[k]
            } else {
                delta + full_keys[k]
            }
        })
    }

    /// Checks the seed-level correlation: the punctured tree matches the full tree everywhere
    /// except at `alpha`, where it is offset by `delta`.
    fn cggm_tree<TreeDepth: Positive, TreeLeafCnt: Positive>() {
        let mut rng = random::test_rng();

        // Random punctured leaf index.
        let alpha = rng.gen::<usize>() % TreeLeafCnt::USIZE;
        let session_id = SessionId::random(&mut rng);

        let root: Gf2_128 = Random::random(&mut rng);
        let delta: Gf2_128 = Random::random(&mut rng);

        // Same key derivation as `PCGGM::new`.
        let cipher = PCGGM::new_cipher(session_id);

        let (full_leaves, full_keys) =
            compute_full_tree::<TreeDepth, TreeLeafCnt>(root, delta, &cipher);

        let keys = emulate_cot::<TreeDepth>(alpha, delta, &full_keys);

        let punctured_leaves =
            compute_punctured_tree::<TreeDepth, TreeLeafCnt>(alpha, keys, &cipher);

        izip_eq!(full_leaves, punctured_leaves)
            .enumerate()
            .for_each(|(k, (r, s))| {
                if k != alpha {
                    assert_eq!(r, s);
                } else {
                    assert_eq!(r + delta, s);
                }
            });
    }

    #[test]
    fn test_cggm_tree() {
        // Smallest depths exercise the special-cased first/last levels.
        cggm_tree::<U<1>, U<2>>();
        cggm_tree::<U<2>, U<4>>();
        cggm_tree::<U<4>, U<16>>();
        cggm_tree::<U<7>, U<128>>();
        cggm_tree::<U<12>, U<4096>>();
    }

    /// Checks, via the public API, that field leaves agree everywhere except at `alpha`.
    fn pcggm<F: FieldExtension, TreeDepth: Positive, TreeLeafCnt: Positive>() {
        let mut rng = random::test_rng();

        // Random punctured leaf index.
        let alpha = rng.gen::<usize>() % TreeLeafCnt::USIZE;
        let session_id = SessionId::random(&mut rng);

        let root: Gf2_128 = Random::random(&mut rng);
        let delta: Gf2_128 = Random::random(&mut rng);

        let mut ggm = PCGGM::new(session_id);
        let FullTree {
            leaves: full_leaves,
            keys: full_keys,
        } = ggm.expand_full_tree::<F, TreeDepth, TreeLeafCnt>(root, delta);

        let keys = emulate_cot::<TreeDepth>(alpha, delta, &full_keys);

        let PuncturedTree {
            leaves: punctured_leaves,
        } = ggm.expand_punctured_tree::<F, TreeDepth, TreeLeafCnt>(alpha, keys);

        izip_eq!(full_leaves, punctured_leaves)
            .enumerate()
            .for_each(|(k, (r, s))| {
                if k != alpha {
                    assert_eq!(r, s);
                } else {
                    assert_ne!(r, s);
                }
            });
    }

    #[test]
    fn test_pcggm() {
        pcggm::<Gf2_128, U<1>, U<2>>();
        pcggm::<Gf2_128, U<2>, U<4>>();
        pcggm::<Gf2_128, U<4>, U<16>>();
        pcggm::<Gf2_128, U<7>, U<128>>();
        pcggm::<Gf2_128, U<12>, U<4096>>();

        pcggm::<ScalarField<C>, U<1>, U<2>>();
        pcggm::<ScalarField<C>, U<2>, U<4>>();
        pcggm::<ScalarField<C>, U<4>, U<16>>();
        pcggm::<ScalarField<C>, U<7>, U<128>>();
        pcggm::<ScalarField<C>, U<12>, U<4096>>();
    }

    /// Checks that trees expanded from different roots share no leaf.
    fn pcggm_are_distinct<F: FieldExtension, TreeDepth: Positive, TreeLeafCnt: Positive>() {
        let mut rng = random::test_rng();
        let session_id = SessionId::random(&mut rng);

        // Two independent random roots sharing one `delta`.
        let root1: Gf2_128 = Random::random(&mut rng);
        let root2: Gf2_128 = Random::random(&mut rng);

        let delta: Gf2_128 = Random::random(&mut rng);

        let mut ggm = PCGGM::new(session_id);
        let FullTree {
            leaves: full_leaves1,
            ..
        } = ggm.expand_full_tree::<F, TreeDepth, TreeLeafCnt>(root1, delta);

        let FullTree {
            leaves: full_leaves2,
            ..
        } = ggm.expand_full_tree::<F, TreeDepth, TreeLeafCnt>(root2, delta);

        izip_eq!(full_leaves1, full_leaves2).for_each(|(e1, e2)| assert_ne!(e1, e2));
    }

    #[test]
    fn test_pcggm_are_distinct() {
        pcggm_are_distinct::<Gf2_128, U<4>, U<16>>();
        pcggm_are_distinct::<Gf2_128, U<7>, U<128>>();
        pcggm_are_distinct::<Gf2_128, U<12>, U<4096>>();

        pcggm_are_distinct::<ScalarField<C>, U<4>, U<16>>();
        pcggm_are_distinct::<ScalarField<C>, U<7>, U<128>>();
        pcggm_are_distinct::<ScalarField<C>, U<12>, U<4096>>();
    }
}