// arcis_compiler/utils/crypto/aes_cipher.rs

1use crate::{
2    core::circuits::{
3        boolean::{
4            aes::{AES128, AES192, AES256},
5            boolean_value::Boolean,
6            byte::Byte,
7            sha3::SHA3_256,
8        },
9        x25519::get_shared_secret::get_shared_secret,
10    },
11    traits::{MxeX25519PrivateKey, Reveal, ToLeBytes, ToMontgomery},
12    utils::{
13        crypto::key::{
14            AES128Key,
15            AES192Key,
16            AES256Key,
17            MxeAES128Key,
18            MxeAES192Key,
19            MxeAES256Key,
20            X25519PrivateKey,
21            X25519PublicKey,
22        },
23        curve_point::Curve,
24        elliptic_curve::F25519,
25        field::ScalarField,
26    },
27};
28use std::ops::Mul;
29
macro_rules! impl_aes_cipher {
    ($t: ident, $block_cipher: ident, $key: ident, $key_func_trait: ident, $key_len: expr) => {
        /// The Arcis AES cipher. We use it in Counter (CTR) mode, see
        /// <https://nvlpubs.nist.gov/nistpubs/Legacy/SP/nistspecialpublication800-38a.pdf> (Section 6.5).
        ///
        /// The message is processed in 16-byte chunks; the counter block for the `i`-th chunk
        /// is `nonce (8 bytes) || i (8 bytes, big-endian)`.
        // NOTE(review): the reason for `dead_code` is not visible from this file; presumably
        // not every generated key size is used by callers. Confirm and document.
        #[allow(dead_code)]
        pub struct $t<B: Boolean> {
            block_cipher: $block_cipher<B>,
        }

        impl<B: Boolean> $t<B> {
            /// Wraps an already-derived AES key in the underlying block cipher.
            fn new(key: $key<B>) -> Self {
                Self {
                    block_cipher: $block_cipher::new(key),
                }
            }

            /// Given a client public key:
            /// - perform the x25519 key exchange with the MXE private key
            /// - perform a key derivation, following [Section 4, Option 1.](https://nvlpubs.nist.gov/nistpubs/SpecialPublications/NIST.SP.800-56Cr2.pdf),
            ///   with hash = SHA3-256.
            ///
            /// # Panics
            ///
            /// Panics if the digest yields fewer than `key_len / 8` bytes. SHA3-256 outputs
            /// 256 bits, so this cannot happen for the 128/192/256-bit keys generated here.
            pub fn new_with_client<
                T: F25519 + ToLeBytes<BooleanOutput = B>,
                S: Clone + Copy + MxeX25519PrivateKey + Mul<C, Output = C> + From<ScalarField>,
                C: Curve + ToMontgomery<Output = T>,
            >(
                public_key: X25519PublicKey<C>,
            ) -> Self {
                let private_key = X25519PrivateKey::<S>::mxe_private_key();
                let shared_secret = get_shared_secret(private_key, public_key);
                let shared_secret_bytes = shared_secret.to_le_bytes().to_vec();
                let hasher = SHA3_256::new();
                // We follow [Section 4, Option 1.](https://nvlpubs.nist.gov/nistpubs/SpecialPublications/NIST.SP.800-56Cr2.pdf).
                // For our choice of hash function, we have:
                // - H_outputBits = 256
                // - max_H_inputBits = arbitrarily long, as SHA3 is built upon the sponge
                //   construction
                // - L = key_len.
                // Since L <= H_outputBits, reps = 1 and the only counter value is i = 1.

                // Build `counter || Z || FixedInfo`, starting with the counter encoded as a
                // big-endian 4-byte unsigned integer.
                let mut kdf_input = 1u32
                    .to_be_bytes()
                    .into_iter()
                    .map(Byte::from)
                    .collect::<Vec<Byte<B>>>();
                // Z: the shared secret.
                kdf_input.extend(shared_secret_bytes);
                // FixedInfo: just L, as a big-endian 2-byte unsigned integer
                // (e.g. 128 -> [0x00, 0x80], 256 -> [0x01, 0x00]).
                let key_len_bits =
                    u16::try_from($key_len).expect("Expected a key_len less than 65536");
                kdf_input.extend(key_len_bits.to_be_bytes().into_iter().map(Byte::from));

                // The derived key is the leftmost L bits of the digest.
                Self::new($key::new_from_inner(
                    hasher
                        .digest(kdf_input)
                        .into_iter()
                        .take($key_len / 8)
                        .collect::<Vec<Byte<B>>>()
                        .try_into()
                        .unwrap_or_else(|v: Vec<Byte<B>>| {
                            panic!(
                                "Expected a Vec of length {} (found {})",
                                $key_len / 8,
                                v.len()
                            )
                        }),
                ))
            }

            /// Builds a cipher from the MXE's own AES key.
            pub fn new_for_mxe() -> Self
            where
                B: $key_func_trait,
            {
                Self::new($key::mxe_aes_key())
            }

            /// XORs one chunk (at most 16 bytes) with the encrypted `counter` block and appends
            /// the result to `output`, revealing each output byte when `reveal_output` is set.
            /// Keystream bytes beyond `ptxt.len()` are discarded.
            fn encrypt_batch(
                block_cipher: &$block_cipher<B>,
                ptxt: &[Byte<B>],
                counter: [Byte<B>; 16],
                output: &mut Vec<Byte<B>>,
                reveal_output: bool,
            ) {
                let keystream = block_cipher.encrypt_block(counter);
                output.extend(ptxt.iter().zip(keystream).map(|(p, k)| {
                    let encrypted = *p ^ k;
                    if reveal_output {
                        encrypted.reveal()
                    } else {
                        encrypted
                    }
                }));
            }

            /// Builds the CTR counter block `nonce || block_index`, with the index encoded as a
            /// big-endian 8-byte integer.
            /// We follow https://nvlpubs.nist.gov/nistpubs/Legacy/SP/nistspecialpublication800-38a.pdf
            /// (Appendix B.2, second approach).
            fn ctr_counter_block(nonce: &[Byte<B>; 8], block_index: usize) -> [Byte<B>; 16] {
                // `usize::to_be_bytes` is only 8 bytes wide on 64-bit targets (4 on 32-bit ones);
                // widening to `u64` keeps the counter block at exactly 16 bytes everywhere.
                let index = u64::try_from(block_index).expect("block index does not fit in u64");
                let mut counter = nonce.to_vec();
                counter.extend(index.to_be_bytes().into_iter().map(Byte::from));
                counter.try_into().unwrap_or_else(|v: Vec<Byte<B>>| {
                    panic!("Expected a Vec of length 16 (found {})", v.len())
                })
            }

            /// Applies the CTR keystream derived from `nonce` to `data` chunk by chunk.
            /// Encryption and decryption are the same operation in CTR mode; they differ only
            /// in whether the output bytes are revealed.
            fn apply_ctr_keystream(
                &self,
                data: &[Byte<B>],
                nonce: [Byte<B>; 8],
                reveal_output: bool,
            ) -> Vec<Byte<B>> {
                let mut result = Vec::with_capacity(data.len());
                for (i, chunk) in data.chunks(16).enumerate() {
                    Self::encrypt_batch(
                        &self.block_cipher,
                        chunk,
                        Self::ctr_counter_block(&nonce, i),
                        &mut result,
                        reveal_output,
                    );
                }
                result
            }

            /// Encrypt the masked plaintext vector in Counter (CTR) mode.
            /// The resulting ciphertext bytes are revealed.
            pub fn encrypt(
                &self,
                masked_plaintext: Vec<Byte<B>>,
                nonce: [Byte<B>; 8],
            ) -> Vec<Byte<B>> {
                self.apply_ctr_keystream(&masked_plaintext, nonce, true)
            }

            /// Decrypt the ciphertext vector in Counter (CTR) mode.
            /// The resulting plaintext bytes are not revealed.
            pub fn decrypt(&self, ciphertext: Vec<Byte<B>>, nonce: [Byte<B>; 8]) -> Vec<Byte<B>> {
                self.apply_ctr_keystream(&ciphertext, nonce, false)
            }
        }
    };
}
204
205impl_aes_cipher!(AES128Cipher, AES128, AES128Key, MxeAES128Key, 128usize);
206impl_aes_cipher!(AES192Cipher, AES192, AES192Key, MxeAES192Key, 192usize);
207impl_aes_cipher!(AES256Cipher, AES256, AES256Key, MxeAES256Key, 256usize);