arcis-compiler 0.9.6

A framework for writing secure multi-party computation (MPC) circuits to be executed on the Arcium network.
Documentation
//! Arcis implementation of <https://github.com/solana-program/zk-elgamal-proof/blob/main/zk-sdk/src/range_proof/inner_product.rs>

use crate::{
    core::{
        circuits::boolean::{boolean_value::BooleanValue, byte::Byte},
        global_value::{
            curve_value::{CompressedCurveValue, CurveValue},
            value::FieldValue,
        },
    },
    traits::{Invert, Reveal, ToLeBytes},
    utils::{
        field::ScalarField,
        zkp::{transcript::Transcript, util::UNIT_LEN},
    },
};
use std::iter;

/// Bulletproofs inner-product argument produced by [`InnerProductProof::new`].
///
/// `L_vec` and `R_vec` hold one compressed point per folding round (in round order),
/// while `a` and `b` are the final, fully folded scalars.
#[allow(dead_code)]
#[allow(non_snake_case)]
#[derive(Clone, Debug)]
pub struct InnerProductProof {
    /// Compressed `L` point of every round, in the order the rounds were run.
    pub(crate) L_vec: Vec<CompressedCurveValue>,
    /// Compressed `R` point of every round, in the order the rounds were run.
    pub(crate) R_vec: Vec<CompressedCurveValue>,
    /// Last remaining entry of the folded `a` vector.
    pub(crate) a: FieldValue<ScalarField>,
    /// Last remaining entry of the folded `b` vector.
    pub(crate) b: FieldValue<ScalarField>,
}

#[allow(non_snake_case)]
impl InnerProductProof {
    /// Creates an inner-product proof.
    ///
    /// This function implements Protocol 2 from the Bulletproofs paper, a recursive
    /// argument to prove knowledge of two vectors `a` and `b` such that `<a,b> = c`.
    /// The length of the vectors must be a power of two.
    ///
    /// Every round halves the vectors:
    /// 1. It computes the cross terms `c_L = <a_L, b_R>` and `c_R = <a_R, b_L>`.
    /// 2. It commits to them in the points `L` and `R`. These are revealed, compressed,
    ///    stored in the proof and appended to `transcript`.
    /// 3. It draws the challenge `u` from the transcript.
    /// 4. It folds `a`, `b`, `G` and `H` with `u` and `u^{-1}`.
    ///
    /// When a single entry is left, the final `a` and `b` scalars are revealed and
    /// stored in the proof.
    ///
    /// `G_factors` / `H_factors` scale the generators. They are only applied in the first
    /// round, where they are folded into the new generators.
    ///
    /// # Panics
    ///
    /// Panics if `G_vec`, `H_vec`, `a_vec`, `b_vec`, `G_factors` and `H_factors` do not
    /// all have the same length, or if that length is not a power of two (so in
    /// particular if it is zero).
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        Q: &CurveValue,
        G_factors: &[FieldValue<ScalarField>],
        H_factors: &[FieldValue<ScalarField>],
        mut G_vec: Vec<CurveValue>,
        mut H_vec: Vec<CurveValue>,
        mut a_vec: Vec<FieldValue<ScalarField>>,
        mut b_vec: Vec<FieldValue<ScalarField>>,
        transcript: &mut Transcript<BooleanValue>,
    ) -> Self {
        // Create slices G, H, a, b backed by their respective
        // vectors. This lets us reslice as we compress the lengths
        // of the vectors in the main loop below.
        let mut G = &mut G_vec[..];
        let mut H = &mut H_vec[..];
        let mut a = &mut a_vec[..];
        let mut b = &mut b_vec[..];

        let mut n = G.len();

        // All of the input vectors must have the same length.
        assert!(
            H.len() == n
                && a.len() == n
                && b.len() == n
                && G_factors.len() == n
                && H_factors.len() == n,
            "inner product proof: all input vectors must have the same length"
        );

        // All of the input vectors must have a length that is a power of two
        // (`is_power_of_two` is false for 0, so empty inputs are rejected too).
        assert!(
            n.is_power_of_two(),
            "inner product proof: vector length must be a power of two, got {n}"
        );

        transcript.inner_product_proof_domain_separator(n as u64);

        // `n` is a power of two, so log2(n) is its number of trailing zeros.
        // Each round halves `n` and pushes exactly one L and one R point.
        let lg_n = n.trailing_zeros() as usize;
        let mut L_vec = Vec::with_capacity(lg_n);
        let mut R_vec = Vec::with_capacity(lg_n);

        // This is an optimization: the first round of the protocol is unrolled from the
        // main loop to handle the `G_factors` and `H_factors` more efficiently using
        // a single multiscalar multiplication. Subsequent rounds use a simplified loop.
        if n != 1 {
            n = n.checked_div(2).unwrap();
            let (a_L, a_R) = a.split_at_mut(n);
            let (b_L, b_R) = b.split_at_mut(n);
            let (G_L, G_R) = G.split_at_mut(n);
            let (H_L, H_R) = H.split_at_mut(n);
            // The factor slices still have the original length `2 * n`. Split them the
            // same way as the generators they scale.
            let (G_factors_L, G_factors_R) = G_factors.split_at(n);
            let (H_factors_L, H_factors_R) = H_factors.split_at(n);

            // Compute the cross terms c_L and c_R
            let c_L = Self::inner_product(a_L, b_R);
            let c_R = Self::inner_product(a_R, b_L);

            // Compute L and R points for this round. They are revealed because they are
            // stored in the proof and appended to the transcript.
            // L = <a_L, G_R> + <b_R, H_L> + c_L * Q   (G_R, H_L scaled by their factors)
            let L = CurveValue::multiscalar_mul(
                a_L.iter()
                    .zip(G_factors_R.iter())
                    .map(|(a_L_i, g)| *a_L_i * *g)
                    .chain(
                        b_R.iter()
                            .zip(H_factors_L.iter())
                            .map(|(b_R_i, h)| *b_R_i * *h),
                    )
                    .chain(iter::once(c_L))
                    .collect::<Vec<FieldValue<ScalarField>>>(),
                G_R.iter()
                    .chain(H_L.iter())
                    .chain(iter::once(Q))
                    .copied()
                    .collect::<Vec<CurveValue>>(),
            )
            .reveal()
            .compress();

            // R = <a_R, G_L> + <b_L, H_R> + c_R * Q   (G_L, H_R scaled by their factors)
            let R = CurveValue::multiscalar_mul(
                a_R.iter()
                    .zip(G_factors_L.iter())
                    .map(|(a_R_i, g)| *a_R_i * *g)
                    .chain(
                        b_L.iter()
                            .zip(H_factors_R.iter())
                            .map(|(b_L_i, h)| *b_L_i * *h),
                    )
                    .chain(iter::once(c_R))
                    .collect::<Vec<FieldValue<ScalarField>>>(),
                G_L.iter()
                    .chain(H_R.iter())
                    .chain(iter::once(Q))
                    .copied()
                    .collect::<Vec<CurveValue>>(),
            )
            .reveal()
            .compress();

            L_vec.push(L);
            R_vec.push(R);

            transcript.append_point(b"L", &L);
            transcript.append_point(b"R", &R);

            let u = transcript.challenge_scalar(b"u");
            // on plaintext values we simply set is_expected_non_zero = true
            let u_inv = u.invert(true);

            // Fold into the left halves. The generator factors are applied here, so the
            // folded generators already include them in later rounds.
            for i in 0..n {
                a_L[i] = a_L[i] * u + u_inv * a_R[i];
                b_L[i] = b_L[i] * u_inv + u * b_R[i];
                G_L[i] = CurveValue::multiscalar_mul(
                    vec![u_inv * G_factors_L[i], u * G_factors_R[i]],
                    vec![G_L[i], G_R[i]],
                );
                H_L[i] = CurveValue::multiscalar_mul(
                    vec![u * H_factors_L[i], u_inv * H_factors_R[i]],
                    vec![H_L[i], H_R[i]],
                );
            }

            a = a_L;
            b = b_L;
            G = G_L;
            H = H_L;
        }

        // Main recursive loop
        while n != 1 {
            n = n.checked_div(2).unwrap();
            let (a_L, a_R) = a.split_at_mut(n);
            let (b_L, b_R) = b.split_at_mut(n);
            let (G_L, G_R) = G.split_at_mut(n);
            let (H_L, H_R) = H.split_at_mut(n);

            // Compute the cross terms c_L and c_R
            let c_L = Self::inner_product(a_L, b_R);
            let c_R = Self::inner_product(a_R, b_L);

            // Compute L and R points for this round
            // L = <a_L, G_R> + <b_R, H_L> + c_L * Q
            let L = CurveValue::multiscalar_mul(
                a_L.iter()
                    .chain(b_R.iter())
                    .chain(iter::once(&c_L))
                    .copied()
                    .collect::<Vec<FieldValue<ScalarField>>>(),
                G_R.iter()
                    .chain(H_L.iter())
                    .chain(iter::once(Q))
                    .copied()
                    .collect::<Vec<CurveValue>>(),
            )
            .reveal()
            .compress();

            // R = <a_R, G_L> + <b_L, H_R> + c_R * Q
            let R = CurveValue::multiscalar_mul(
                a_R.iter()
                    .chain(b_L.iter())
                    .chain(iter::once(&c_R))
                    .copied()
                    .collect::<Vec<FieldValue<ScalarField>>>(),
                G_L.iter()
                    .chain(H_R.iter())
                    .chain(iter::once(Q))
                    .copied()
                    .collect::<Vec<CurveValue>>(),
            )
            .reveal()
            .compress();

            L_vec.push(L);
            R_vec.push(R);

            transcript.append_point(b"L", &L);
            transcript.append_point(b"R", &R);

            let u = transcript.challenge_scalar(b"u");
            // on plaintext values we simply set is_expected_non_zero = true
            let u_inv = u.invert(true);

            for i in 0..n {
                a_L[i] = a_L[i] * u + u_inv * a_R[i];
                b_L[i] = b_L[i] * u_inv + u * b_R[i];
                G_L[i] = CurveValue::multiscalar_mul(vec![u_inv, u], vec![G_L[i], G_R[i]]);
                H_L[i] = CurveValue::multiscalar_mul(vec![u, u_inv], vec![H_L[i], H_R[i]]);
            }

            a = a_L;
            b = b_L;
            G = G_L;
            H = H_L;
        }

        // Only one entry of each vector is left; the folded scalars are revealed and
        // stored in the proof.
        InnerProductProof {
            L_vec,
            R_vec,
            a: a[0].reveal(),
            b: b[0].reveal(),
        }
    }

    /// Computes the inner product `<lhs, rhs>` of two equal-length scalar slices.
    ///
    /// The products are summed in element order.
    ///
    /// # Panics
    ///
    /// Panics if the slices are empty.
    fn inner_product(
        lhs: &[FieldValue<ScalarField>],
        rhs: &[FieldValue<ScalarField>],
    ) -> FieldValue<ScalarField> {
        debug_assert_eq!(lhs.len(), rhs.len());
        lhs.iter()
            .zip(rhs.iter())
            .map(|(l, r)| *l * *r)
            .reduce(|acc, term| acc + term)
            .expect("inner product of empty vectors")
    }

    /// Returns the size in bytes required to serialize the inner product proof.
    ///
    /// The proof has one L and one R point per round plus the two scalars `a` and `b`.
    /// Each of these takes `UNIT_LEN` bytes, so vectors of length `n` give
    /// `(2 * log2(n) + 2) * UNIT_LEN` bytes.
    pub fn serialized_size(&self) -> usize {
        (self.L_vec.len() * 2 + 2) * UNIT_LEN
    }

    /// Serializes the proof into a byte vector.
    ///
    /// The layout of the inner product proof is:
    /// - `log2(n)` pairs of compressed Ristretto points, interleaved round by round as
    ///   `L_0, R_0, L_1, R_1, ..., L_{k-1}, R_{k-1}` (with `k = log2(n)`),
    /// - a scalar `a`,
    /// - a scalar `b`.
    pub fn to_bytes(&self) -> Vec<Byte<BooleanValue>> {
        // `new` pushes one L and one R per round; a mismatch would make `zip` silently
        // drop points from the serialization.
        debug_assert_eq!(self.L_vec.len(), self.R_vec.len());
        let mut buf = Vec::with_capacity(self.serialized_size());
        for (l, r) in self.L_vec.iter().zip(self.R_vec.iter()) {
            buf.extend_from_slice(&l.to_bytes());
            buf.extend_from_slice(&r.to_bytes());
        }
        buf.extend_from_slice(&self.a.to_le_bytes());
        buf.extend_from_slice(&self.b.to_le_bytes());
        buf
    }
}