// arcium-primitives 0.4.2
//
// Arcium primitives
// Documentation
use std::ops::Mul;

use aes::{cipher::KeyInit, Aes128Enc};
use ff::Field;
use itertools::enumerate;

use crate::{
    algebra::field::{binary::Gf2_128, FieldExtension},
    hashing::{hash_cr, hash_into},
    izip_eq,
    random::prg::Aes128Prng,
    types::{HeapArray, HeapMatrix, Positive, SessionId},
};

/// Result of expanding a batch of complete GGM trees, one tree per batch entry.
///
/// Produced by [`GGM::expand_full_tree`].
#[derive(Debug, Clone)]
pub struct FullTrees<
    F: FieldExtension,
    TreeDepth: Positive,
    TreeLeafCnt: Positive,
    BatchSize: Positive,
> {
    /// Field elements derived from the leaf seeds (`TreeLeafCnt` entries for each of the
    /// `BatchSize` trees).
    pub leaves: HeapMatrix<F, TreeLeafCnt, BatchSize>,
    /// Per-level key pairs for each tree: element `0` is the sum of the nodes written to the
    /// left-child positions of that level, element `1` the sum of the right-child ones.
    pub keys: HeapMatrix<[Gf2_128; 2], TreeDepth, BatchSize>,
}

/// Result of expanding a batch of punctured GGM trees, one tree per batch entry.
///
/// Produced by [`GGM::expand_punctured_tree`].
#[derive(Debug, Clone)]
pub struct PuncturedTrees<F: FieldExtension, TreeLeafCnt: Positive, BatchSize: Positive> {
    /// Field elements derived from the reconstructed leaf seeds. The entry at each tree's
    /// punctured index is derived from an unknown ("gibberish") seed and is not meaningful.
    pub leaves: HeapMatrix<F, TreeLeafCnt, BatchSize>,
}

/// GGM tree expander holding the two AES-128 encryptors used throughout the expansion.
#[derive(Debug, Clone)]
pub struct GGM {
    /// `.0` is used to derive the node stored at the left-child position, `.1` the node stored
    /// at the right-child position; both are derived from the session id in `GGM::new`.
    ciphers: (Aes128Enc, Aes128Enc),
}

impl GGM {
    /// Creates an expander whose two AES-128 ciphers are derived from `session_id`.
    pub fn new(session_id: SessionId) -> Self {
        Self {
            ciphers: Self::new_ciphers(session_id),
        }
    }

    /// Derives the (left, right) pair of AES-128 encryptors for a session.
    ///
    /// Each key is a 16-byte hash of a one-byte domain label (`b"0"` for the first cipher,
    /// `b"1"` for the second) followed by the session id bytes.
    fn new_ciphers(session_id: SessionId) -> (Aes128Enc, Aes128Enc) {
        let derive_cipher = |label: &[u8]| {
            let mut seed = [0u8; 16];
            hash_into([label, session_id.as_ref()], &mut seed);

            // `new_from_slice` only fails on a wrong key length, and the seed is always
            // exactly 16 bytes long.
            Aes128Enc::new_from_slice(&seed).expect("a 16-byte seed is a valid AES-128 key")
        };

        (
            derive_cipher(b"0".as_slice()),
            derive_cipher(b"1".as_slice()),
        )
    }

    /// Expands one complete GGM tree per root in `root_batches`.
    ///
    /// Returns the field-element leaves together with the per-level key pairs (sums of the
    /// left-child and right-child nodes of each level) of every tree.
    ///
    /// # Panics
    ///
    /// Panics if `TreeLeafCnt != 2^TreeDepth`.
    pub fn expand_full_tree<F, TreeDepth, TreeLeafCnt, BatchSize>(
        &mut self,
        root_batches: &HeapArray<Gf2_128, BatchSize>,
    ) -> FullTrees<F, TreeDepth, TreeLeafCnt, BatchSize>
    where
        F: FieldExtension,
        TreeDepth: Positive,
        TreeLeafCnt: Positive + Mul<BatchSize, Output: Positive>,
        BatchSize: Positive,
    {
        let (seeds, keys) = compute_full_tree(root_batches, &self.ciphers.0, &self.ciphers.1);
        let leaves =
            generate_from_seeds::<F, _, BatchSize>(&self.ciphers.0, &self.ciphers.1, seeds);

        FullTrees { leaves, keys }
    }

    /// Expands one punctured GGM tree per entry of `alpha_batches`.
    ///
    /// * `alpha_batches` - punctured leaf index of each tree; only the low `TreeDepth` bits
    ///   are used.
    /// * `keys_batches` - for each tree and each level `i`, the key `K^i_{~alpha_i}`, where
    ///   `alpha_i` is the i-th most significant bit of that tree's `alpha`.
    ///
    /// Every leaf except the one at index `alpha` is reconstructed; the leaf at `alpha` is
    /// derived from an unknown seed and carries no meaning.
    ///
    /// # Panics
    ///
    /// Panics if `TreeLeafCnt != 2^TreeDepth`.
    pub fn expand_punctured_tree<F, TreeDepth, TreeLeafCnt, BatchSize>(
        &mut self,
        alpha_batches: &HeapArray<usize, BatchSize>,
        keys_batches: &HeapMatrix<Gf2_128, TreeDepth, BatchSize>,
    ) -> PuncturedTrees<F, TreeLeafCnt, BatchSize>
    where
        F: FieldExtension,
        TreeDepth: Positive,
        TreeLeafCnt: Positive + Mul<BatchSize, Output: Positive>,
        BatchSize: Positive,
    {
        let seeds = compute_punctured_tree(
            alpha_batches,
            keys_batches,
            &self.ciphers.0,
            &self.ciphers.1,
        );

        let leaves = generate_from_seeds::<F, _, _>(&self.ciphers.0, &self.ciphers.1, seeds);

        PuncturedTrees { leaves }
    }
}

/// Turns every leaf seed into a field element of `F`.
///
/// Each seed keys an [`Aes128Prng`] from which one random `F` is sampled. The cipher backing
/// the PRNG alternates with the parity of the seed's position in the flat iteration order of
/// `seeds`: `enc_left` for even positions, `enc_right` for odd ones.
fn generate_from_seeds<F, TreeLeafCnt, BatchSize>(
    enc_left: &Aes128Enc,
    enc_right: &Aes128Enc,
    seeds: HeapMatrix<Gf2_128, TreeLeafCnt, BatchSize>,
) -> HeapMatrix<F, TreeLeafCnt, BatchSize>
where
    F: FieldExtension,
    TreeLeafCnt: Positive + Mul<BatchSize, Output: Positive>,
    BatchSize: Positive,
{
    let cipher_for = |position: usize| {
        if position % 2 == 0 {
            enc_left
        } else {
            enc_right
        }
    };

    let elements: Vec<F> = enumerate(seeds.into_flat_iter())
        .map(|(position, seed)| {
            let prng = Aes128Prng::new(cipher_for(position), seed);
            <F as Field>::random(prng)
        })
        .collect();

    // One element was produced per seed, so the flat vector is converted back into a matrix
    // of the same dimensions as `seeds`.
    HeapMatrix::<_, TreeLeafCnt, BatchSize>::try_from(elements).unwrap()
}

/// Expands one complete GGM tree per root, level by level and in place.
///
/// Returns `(leaf seeds, level keys)`. For every tree, the level keys hold, per level, the
/// pair of sums returned by [`ggm_level_expand`].
///
/// # Panics
///
/// Panics if `TreeLeafCnt != 2^TreeDepth`.
#[allow(clippy::type_complexity)]
fn compute_full_tree<TreeDepth: Positive, TreeLeafCnt: Positive, BatchSize: Positive>(
    root_batches: &HeapArray<Gf2_128, BatchSize>,
    cipher0: &Aes128Enc,
    cipher1: &Aes128Enc,
) -> (
    HeapMatrix<Gf2_128, TreeLeafCnt, BatchSize>,    // leaves
    HeapMatrix<[Gf2_128; 2], TreeDepth, BatchSize>, // level keys
) {
    assert_eq!(1u64 << TreeDepth::USIZE, TreeLeafCnt::U64);
    assert!(BatchSize::USIZE > 0);

    let mut leaves_batches: HeapMatrix<Gf2_128, TreeLeafCnt, BatchSize> = Default::default();
    let mut keys_batches: HeapMatrix<[Gf2_128; 2], TreeDepth, BatchSize> = Default::default();

    // Each column of the matrices holds one tree of the batch.
    for (root, tree_nodes, tree_keys) in izip_eq!(
        root_batches,
        leaves_batches.col_iter_mut(),
        keys_batches.col_iter_mut()
    ) {
        // The tree is grown in place: the root starts in slot 0 and every level expansion
        // spreads the nodes further over the column until all leaf slots are filled.
        tree_nodes[0] = *root;

        for (level, level_keys) in tree_keys.iter_mut().enumerate() {
            *level_keys = ggm_level_expand(tree_nodes, level, cipher0, cipher1);
        }
    }

    (leaves_batches, keys_batches)
}

/// Reconstructs, for every tree of the batch, all leaf seeds except the one at index `alpha`.
///
/// * `alpha_batches` - punctured leaf index of each tree; bits above `TreeDepth` are ignored.
/// * `keys_tilde_batches` - for tree `j` and level `i`, the key `K^ij_{~alpha_ij}`, where
///   `alpha_ij` is the i-th most significant bit of `alpha_j`.
///
/// The slot at index `alpha` ends up holding an unknown ("gibberish") value.
///
/// # Panics
///
/// Panics if `TreeLeafCnt != 2^TreeDepth`.
fn compute_punctured_tree<TreeDepth: Positive, TreeLeafCnt: Positive, BatchSize: Positive>(
    alpha_batches: &HeapArray<usize, BatchSize>,
    keys_tilde_batches: &HeapMatrix<Gf2_128, TreeDepth, BatchSize>,
    cipher0: &Aes128Enc,
    cipher1: &Aes128Enc,
) -> HeapMatrix<Gf2_128, TreeLeafCnt, BatchSize> {
    assert_eq!(1u64 << TreeDepth::USIZE, TreeLeafCnt::U64);
    assert!(BatchSize::USIZE > 0);

    let depth = TreeDepth::USIZE;
    let mut leaves_batches: HeapMatrix<Gf2_128, TreeLeafCnt, BatchSize> = Default::default();

    for (leaves, alpha, keys_tilde) in izip_eq!(
        leaves_batches.col_iter_mut(),
        alpha_batches,
        keys_tilde_batches.col_iter()
    ) {
        // Sanitize input: keep only the `depth` low bits of alpha.
        let alpha = alpha & ((1 << depth) - 1);

        // Level 0 is not expanded locally: the only known node of that level is the sibling of
        // alpha's ancestor, and it is taken directly from the first key. Nodes of level `lvl`
        // live at multiples of `2^shift` with `shift = depth - lvl - 1`.
        let shift: usize = depth - 1;
        let known_slot = (1 ^ (alpha >> shift)) << shift;
        leaves[known_slot] = keys_tilde[0];

        // Intermediate levels. Expanding a level also expands the unknown node on alpha's
        // path into two unknown children:
        //    * the child at index `(alpha >> shift) << shift` stays unknown,
        //    * its sibling at index `((alpha >> shift) ^ 1) << shift` is repaired with
        //      `keys_tilde[level]`: the bogus value sits both in that slot and in the locally
        //      computed sum of its side, so adding `key - local_sum` replaces it by the true one.
        for level in 1..depth - 1 {
            let local_sums = ggm_level_expand(leaves, level, cipher0, cipher1);

            let shift: usize = depth - level - 1;
            let sibling_prefix = 1 ^ (alpha >> shift);
            let sibling_side = sibling_prefix & 1;
            let sibling_slot = sibling_prefix << shift;

            let correction = keys_tilde[level] - local_sums[sibling_side];
            leaves[sibling_slot] += correction;
        }

        // Last level: same repair with a shift of zero, i.e. directly on the leaf `alpha ^ 1`.
        let last_level = depth - 1;
        let local_sums = ggm_level_expand(leaves, last_level, cipher0, cipher1);
        let sibling_leaf = alpha ^ 1;
        let correction = keys_tilde[last_level] - local_sums[sibling_leaf & 1];
        leaves[sibling_leaf] += correction;
    }

    leaves_batches
}

/// Expands level `level` of a GGM tree stored in place in `nodes`.
///
/// Before the call, the parents of this level sit at multiples of `2 * stride`, where
/// `stride = nodes.len() >> (level + 1)`. Each parent is replaced by its left child (hashed
/// with `cipher0`) and its right child (hashed with `cipher1`) is written `stride` slots
/// further.
///
/// Returns `[sum of left children, sum of right children]` for this level.
///
/// # Panics
///
/// Panics if `nodes` has fewer than `2^(level + 1)` entries.
#[inline]
fn ggm_level_expand(
    nodes: &mut [Gf2_128],
    level: usize,
    cipher0: &Aes128Enc,
    cipher1: &Aes128Enc,
) -> [Gf2_128; 2] {
    let total = nodes.len();
    assert!(total >= (1 << (level + 1)));

    let stride = total >> (level + 1);
    let mut side_sums = [Gf2_128::ZERO; 2];

    for parent in (0..total).step_by(2 * stride) {
        // Both children are derived from the parent before its slot is overwritten.
        let left = hash_cr(cipher0, &nodes[parent]);
        let right = hash_cr(cipher1, &nodes[parent]);

        side_sums[0] += &left;
        side_sums[1] += &right;

        nodes[parent] = left;
        nodes[parent + stride] = right;
    }

    side_sums
}

#[cfg(test)]
mod tests {
    use std::fmt::Debug;

    use rand::Rng;
    use typenum::U;

    use super::*;
    use crate::{
        algebra::{
            elliptic_curve::{Curve25519Ristretto, ScalarField},
            field::binary::Gf2_128,
        },
        izip_eq,
        izip_eq_lazy,
        random::Random,
    };

    /// Emulates the OT between the full tree's level keys and the negated bits of alpha: for
    /// every tree and level `k`, selects the key on the side opposite to the k-th most
    /// significant bit of that tree's alpha.
    fn emulate_ot<TreeDepth, BatchSize>(
        alpha_batches: &HeapArray<usize, BatchSize>,
        full_keys_batches: &HeapMatrix<[Gf2_128; 2], TreeDepth, BatchSize>,
    ) -> HeapMatrix<Gf2_128, TreeDepth, BatchSize>
    where
        TreeDepth: Positive + Mul<BatchSize, Output: Positive>,
        BatchSize: Positive,
    {
        let tmp = izip_eq!(alpha_batches, full_keys_batches.col_iter())
            .map(|(alpha, full_keys)| {
                HeapArray::from_fn(|k| {
                    let alpha_k = (alpha >> (TreeDepth::USIZE - k - 1)) & 1;
                    full_keys[k][1 ^ alpha_k]
                })
            })
            .collect();

        HeapMatrix::from_cols(tmp)
    }

    /// Asserts that, for every tree, the punctured leaves equal the full leaves everywhere
    /// except at the punctured index `alpha`, where they must differ.
    fn assert_leaves_match_except_at_alpha<F, TreeLeafCnt, BatchSize>(
        alpha_batches: HeapArray<usize, BatchSize>,
        full_leaves_batches: &HeapMatrix<F, TreeLeafCnt, BatchSize>,
        punctured_leaves_batches: &HeapMatrix<F, TreeLeafCnt, BatchSize>,
    ) where
        F: FieldExtension,
        TreeLeafCnt: Positive,
        BatchSize: Positive,
    {
        izip_eq!(
            alpha_batches,
            full_leaves_batches.col_iter(),
            punctured_leaves_batches.col_iter()
        )
        .for_each(|(alpha, full_leaves, punctured_leaves)| {
            izip_eq_lazy!(full_leaves, punctured_leaves)
                .enumerate()
                .for_each(|(k, (full, punctured))| {
                    if k != alpha {
                        assert_eq!(full, punctured);
                    } else {
                        assert_ne!(full, punctured);
                    }
                });
        })
    }

    /// Checks the seed-level expansion (`compute_full_tree` / `compute_punctured_tree`).
    fn ggm_tree<TreeDepth, TreeLeafCnt, BatchSize>()
    where
        TreeDepth: Debug + Positive + Mul<BatchSize, Output: Positive>,
        TreeLeafCnt: Debug + Positive,
        BatchSize: Debug + Positive,
    {
        let mut rng = crate::random::test_rng();
        let session_id = SessionId::random(&mut rng);

        // One random punctured leaf index per tree of the batch.
        let alpha_batches = HeapArray::from_fn(|_| rng.gen::<usize>() % TreeLeafCnt::USIZE);

        let root_batches = Gf2_128::random_array(&mut rng);

        // Use the production key derivation so the test cannot drift away from it.
        let (cipher0, cipher1) = GGM::new_ciphers(session_id);

        let (full_leaves_batches, full_keys_batches) = compute_full_tree::<
            TreeDepth,
            TreeLeafCnt,
            BatchSize,
        >(&root_batches, &cipher0, &cipher1);

        let keys_batches = emulate_ot(&alpha_batches, &full_keys_batches);

        let punctured_leaves_batches = compute_punctured_tree::<TreeDepth, TreeLeafCnt, BatchSize>(
            &alpha_batches,
            &keys_batches,
            &cipher0,
            &cipher1,
        );

        assert_leaves_match_except_at_alpha(
            alpha_batches,
            &full_leaves_batches,
            &punctured_leaves_batches,
        );
    }

    #[test]
    fn test_ggm_tree() {
        ggm_tree::<U<2>, U<4>, U<17>>();
        ggm_tree::<U<7>, U<128>, U<4>>();
        ggm_tree::<U<12>, U<4096>, U<1>>();
    }

    /// Checks the public `GGM` API end to end, including the conversion of seeds into `F`.
    fn ggm_classic<F, TreeDepth, TreeLeafCnt, BatchSize>()
    where
        F: FieldExtension,
        TreeDepth: Positive + Mul<BatchSize, Output: Positive>,
        TreeLeafCnt: Positive + Mul<BatchSize, Output: Positive>,
        BatchSize: Positive,
    {
        let mut rng = crate::random::test_rng();

        // One random punctured leaf index per tree of the batch.
        let alpha_batches = HeapArray::from_fn(|_| rng.gen::<usize>() % TreeLeafCnt::USIZE);

        let root_batches = Gf2_128::random_array(&mut rng);

        let mut ggm = GGM::new(SessionId::random(&mut rng));
        let FullTrees {
            leaves: full_leaves_batches,
            keys: full_keys_batches,
        } = ggm.expand_full_tree::<F, TreeDepth, TreeLeafCnt, BatchSize>(&root_batches);

        let keys_batches = emulate_ot(&alpha_batches, &full_keys_batches);

        let PuncturedTrees {
            leaves: punctured_leaves_batches,
        } = ggm.expand_punctured_tree::<F, TreeDepth, TreeLeafCnt, BatchSize>(
            &alpha_batches,
            &keys_batches,
        );

        assert_leaves_match_except_at_alpha(
            alpha_batches,
            &full_leaves_batches,
            &punctured_leaves_batches,
        );
    }

    #[test]
    fn test_ggm_classic() {
        ggm_classic::<Gf2_128, U<4>, U<16>, U<17>>();
        ggm_classic::<Gf2_128, U<7>, U<128>, U<4>>();
        ggm_classic::<Gf2_128, U<12>, U<4096>, U<1>>();

        type Fq = ScalarField<Curve25519Ristretto>;
        ggm_classic::<Fq, U<4>, U<16>, U<17>>();
        ggm_classic::<Fq, U<7>, U<128>, U<4>>();
        ggm_classic::<Fq, U<12>, U<4096>, U<1>>();
    }
}