arcis-compiler 0.9.1

A framework for writing secure multi-party computation (MPC) circuits to be executed on the Arcium network.
Documentation
use crate::{
    core::circuits::{
        boolean::{
            aes::{AES128, AES192, AES256},
            boolean_value::Boolean,
            byte::Byte,
            sha3::SHA3_256,
        },
        x25519::get_shared_secret::get_shared_secret,
    },
    traits::{MxeX25519PrivateKey, Reveal, ToLeBytes, ToMontgomery},
    utils::{
        crypto::key::{
            AES128Key,
            AES192Key,
            AES256Key,
            MxeAES128Key,
            MxeAES192Key,
            MxeAES256Key,
            X25519PrivateKey,
            X25519PublicKey,
        },
        curve_point::Curve,
        elliptic_curve::F25519,
        field::ScalarField,
    },
};
use std::ops::Mul;

/// Number of bytes in one AES block. This is also the size of one CTR counter block and of one
/// keystream chunk.
const AES_BLOCK_BYTES: usize = 16;

/// Builds the CTR counter block `nonce || block_index`, with `block_index` encoded as a
/// big-endian 8-byte unsigned integer (NIST SP 800-38A, Appendix B.2, second approach).
///
/// The index is a `u64` rather than a `usize` so that the counter block is exactly 16 bytes on
/// every target: `usize::to_be_bytes` yields only 4 bytes on 32-bit targets.
fn ctr_counter_block<T: Copy + From<u8>>(nonce: [T; 8], block_index: u64) -> [T; AES_BLOCK_BYTES] {
    // Seed every slot with a copy of an existing element, so only the eight index bytes are
    // freshly converted. Every slot is overwritten below.
    let mut block = [nonce[0]; AES_BLOCK_BYTES];
    let (nonce_part, index_part) = block.split_at_mut(nonce.len());
    nonce_part.copy_from_slice(&nonce);
    for (slot, &byte) in index_part.iter_mut().zip(block_index.to_be_bytes().iter()) {
        *slot = T::from(byte);
    }
    block
}

/// Encodes the KDF `FixedInfo` field used when deriving a key from a client key exchange: the
/// derived key length `L`, in bits, as a big-endian 2-byte unsigned integer.
///
/// # Panics
///
/// Panics if `key_len_bits` does not fit in a `u16`.
fn kdf_fixed_info(key_len_bits: usize) -> [u8; 2] {
    u16::try_from(key_len_bits)
        .expect("Expected a key_len that fits in a u16")
        .to_be_bytes()
}

macro_rules! impl_aes_cipher {
    ($t: ident, $block_cipher: ident, $key: ident, $key_func_trait: ident, $key_len: expr) => {
        /// The Arcis AES cipher. We use it in Counter (CTR) mode, see
        /// <https://nvlpubs.nist.gov/nistpubs/Legacy/SP/nistspecialpublication800-38a.pdf> (Section 6.5).
        // NOTE(review): the reason for this allow is not visible here. Presumably some of the
        // generated cipher types are unused in some builds; confirm before removing it.
        #[allow(dead_code)]
        pub struct $t<B: Boolean> {
            block_cipher: $block_cipher<B>,
        }

        impl<B: Boolean> $t<B> {
            fn new(key: $key<B>) -> Self {
                Self {
                    block_cipher: $block_cipher::new(key),
                }
            }

            /// Creates a cipher whose key is derived from an x25519 exchange with a client.
            ///
            /// Given a client public key:
            /// - perform the x25519 key exchange with the MXE private key;
            /// - perform a key derivation, following
            ///   [Section 4, Option 1.](https://nvlpubs.nist.gov/nistpubs/SpecialPublications/NIST.SP.800-56Cr2.pdf),
            ///   with hash = SHA3-256, a single repetition, and `FixedInfo` = the key length
            ///   in bits as a big-endian 2-byte integer.
            ///
            /// # Panics
            ///
            /// Panics if the SHA3-256 digest yields fewer than `key_len / 8` bytes, i.e. if the
            /// key is longer than one hash output.
            pub fn new_with_client<
                T: F25519 + ToLeBytes<BooleanOutput = B>,
                S: Clone + Copy + MxeX25519PrivateKey + Mul<C, Output = C> + From<ScalarField>,
                C: Curve + ToMontgomery<Output = T>,
            >(
                public_key: X25519PublicKey<C>,
            ) -> Self {
                let private_key = X25519PrivateKey::<S>::mxe_private_key();
                let shared_secret = get_shared_secret(private_key, public_key);
                let shared_secret_bytes = shared_secret.to_le_bytes().to_vec();
                let hasher = SHA3_256::new();
                // We follow [Section 4, Option 1.](https://nvlpubs.nist.gov/nistpubs/SpecialPublications/NIST.SP.800-56Cr2.pdf).
                // For our choice of hash function, we have:
                // - H_outputBits = 256
                // - max_H_inputBits = arbitrarily long, as SHA3 is built upon the sponge
                //   construction
                // - L = key_len.
                //
                // The hash input is `counter || Z || FixedInfo`. We only have i = 1 (reps = 1),
                // and the counter is a big-endian 4-byte unsigned integer.
                let mut kdf_input = 1u32
                    .to_be_bytes()
                    .iter()
                    .copied()
                    .map(Byte::from)
                    .collect::<Vec<Byte<B>>>();
                kdf_input.extend(shared_secret_bytes);
                // FixedInfo is simply L, as a big-endian 2-byte unsigned integer.
                kdf_input.extend(kdf_fixed_info($key_len).iter().copied().map(Byte::from));
                Self::new($key::new_from_inner(
                    hasher
                        .digest(kdf_input)
                        .into_iter()
                        .take($key_len / 8)
                        .collect::<Vec<Byte<B>>>()
                        .try_into()
                        .unwrap_or_else(|v: Vec<Byte<B>>| {
                            panic!(
                                "Expected a Vec of length {} (found {})",
                                $key_len / 8,
                                v.len()
                            )
                        }),
                ))
            }

            /// Creates a cipher keyed with the MXE AES key of this size, as returned by
            /// `mxe_aes_key`.
            pub fn new_for_mxe() -> Self
            where
                B: $key_func_trait,
            {
                Self::new($key::mxe_aes_key())
            }

            /// XORs `ptxt` (at most one AES block) with the keystream block `AES(counter)` and
            /// appends the result to `output`. Each output byte goes through `Reveal::reveal`
            /// when `reveal_output` is set.
            fn encrypt_batch(
                block_cipher: &$block_cipher<B>,
                ptxt: &[Byte<B>],
                counter: [Byte<B>; 16],
                output: &mut Vec<Byte<B>>,
                reveal_output: bool,
            ) {
                let keystream = block_cipher.encrypt_block(counter);
                // `zip` stops at the end of `ptxt`, so a short final chunk only uses the
                // keystream bytes it needs.
                output.extend(ptxt.iter().zip(keystream).map(|(&p, k)| {
                    let encrypted = p ^ k;
                    if reveal_output {
                        encrypted.reveal()
                    } else {
                        encrypted
                    }
                }));
            }

            /// Runs CTR mode over `data` and returns the resulting bytes (same length as `data`).
            ///
            /// `data` is processed in 16-byte chunks (the last may be shorter). Chunk `i` is
            /// XORed with `AES(nonce || i)`, with `i` a big-endian 8-byte integer. CTR mode is its
            /// own inverse, so this serves both `encrypt` and `decrypt`, which differ only in
            /// `reveal_output`.
            fn apply_ctr_keystream(
                &self,
                data: &[Byte<B>],
                nonce: [Byte<B>; 8],
                reveal_output: bool,
            ) -> Vec<Byte<B>> {
                let mut output = Vec::with_capacity(data.len());
                // we follow https://nvlpubs.nist.gov/nistpubs/Legacy/SP/nistspecialpublication800-38a.pdf
                // (Appendix B.2, second approach)
                for (block_index, chunk) in (0u64..).zip(data.chunks(AES_BLOCK_BYTES)) {
                    Self::encrypt_batch(
                        &self.block_cipher,
                        chunk,
                        ctr_counter_block(nonce, block_index),
                        &mut output,
                        reveal_output,
                    );
                }
                output
            }

            /// Encrypts the masked plaintext vector in Counter (CTR) mode under `nonce`.
            ///
            /// The counter block for the `i`-th 16-byte chunk is `nonce || i`, with `i` encoded as
            /// a big-endian 8-byte integer. Every ciphertext byte is passed through
            /// `Reveal::reveal`. The output has the same length as the input.
            pub fn encrypt(
                &self,
                masked_plaintext: Vec<Byte<B>>,
                nonce: [Byte<B>; 8],
            ) -> Vec<Byte<B>> {
                self.apply_ctr_keystream(&masked_plaintext, nonce, true)
            }

            /// Decrypts the ciphertext vector in Counter (CTR) mode under `nonce`.
            ///
            /// Uses the same counter-block layout as `encrypt`. The resulting plaintext bytes
            /// are not revealed. The output has the same length as the input.
            pub fn decrypt(&self, ciphertext: Vec<Byte<B>>, nonce: [Byte<B>; 8]) -> Vec<Byte<B>> {
                self.apply_ctr_keystream(&ciphertext, nonce, false)
            }
        }
    };
}

// One CTR cipher type per AES key size. Arguments are:
// (cipher type, block cipher, key type, MXE key trait, key length in bits).
impl_aes_cipher! { AES128Cipher, AES128, AES128Key, MxeAES128Key, 128_usize }
impl_aes_cipher! { AES192Cipher, AES192, AES192Key, MxeAES192Key, 192_usize }
impl_aes_cipher! { AES256Cipher, AES256, AES256Key, MxeAES256Key, 256_usize }