etf_crypto_primitives/encryption/
aes.rs

1use aes_gcm::{
2    aead::{Aead, AeadCore, AeadInPlace, KeyInit},
3    Aes256Gcm, Nonce, // Or `Aes128Gcm`
4};
5use ark_std::rand::Rng;
6use ark_bls12_381::Fr;
7use ark_ff::{Zero, One, Field, UniformRand};
8use ark_poly::{
9    polynomial::univariate::DensePolynomial,
10    DenseUVPolynomial, Polynomial,
11};
12use serde::{Deserialize, Serialize};
13// use alloc::vec::Vec;
14
15use ark_std::rand::CryptoRng;
16use ark_std::vec::Vec;
/// The output of AES-GCM encryption, bundled so it can be serialized as one value.
///
/// NOTE(review): besides the ciphertext and nonce this struct carries the AES key itself,
/// so serializing it or printing it via the derived `Debug` impl exposes the secret key.
/// Confirm that callers never persist or log an `AESOutput` in the clear.
#[derive(Serialize, Deserialize, Debug)]
pub struct AESOutput {
    /// the AES ciphertext (the AES-GCM authentication tag is appended to it)
    pub ciphertext: Vec<u8>,
    /// the AES nonce
    pub nonce: Vec<u8>,
    /// the AES key the ciphertext was produced with
    pub key: Vec<u8>,
}
26
/// Errors returned by the AES-GCM helpers in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// The AEAD encryption operation failed.
    EncryptionError,
    /// Decryption failed, e.g. the ciphertext did not authenticate under the given key and nonce.
    DecryptionError,
    /// The supplied key could not be used to build the cipher (e.g. it has the wrong length).
    InvalidKey,
}
33
34/// AES-GCM encryption of the message using an ephemeral keypair
35/// basically a wrapper around the AEADs library to handle serialization
36///
37/// * `message`: The message to encrypt
38///
39pub fn encrypt<R: Rng + CryptoRng + Sized>(message: &[u8], key: [u8;32], mut rng: R) -> Result<AESOutput, Error> {
40    let cipher = Aes256Gcm::new(generic_array::GenericArray::from_slice(&key));
41    let nonce = Aes256Gcm::generate_nonce(&mut rng); // 96-bits; unique per message
42
43    let mut buffer: Vec<u8> = Vec::new(); // Note: buffer needs 16-bytes overhead for auth tag
44    buffer.extend_from_slice(message);
45    // Encrypt `buffer` in-place, replacing the plaintext contents with ciphertext
46    // will this error ever be thrown here? nonces should always be valid as well as buffer
47    cipher.encrypt_in_place(&nonce, b"", &mut buffer)
48        .map_err(|_| Error::EncryptionError)?;
49    Ok(AESOutput{
50        ciphertext: buffer,
51        nonce: nonce.to_vec(),
52        key: key.to_vec(),
53    })
54}
55
56pub fn decrypt(ciphertext: Vec<u8>, nonce_slice: &[u8], key: &[u8]) -> Result<Vec<u8>, Error> {
57    // not sure about that...
58    let cipher = Aes256Gcm::new_from_slice(key)
59        .map_err(|_| Error::InvalidKey)?;
60    let nonce = Nonce::from_slice(nonce_slice);
61    let plaintext = cipher.decrypt(nonce, ciphertext.as_ref())
62        .map_err(|_| Error::DecryptionError)?;
63    Ok(plaintext)
64}
65
66/// Generate a random polynomial f and return evalulations (f(0), (1, f(1), ..., n, f(n)))
67/// f(0) is the 'secret' and the shares can be used to recover the secret with `let s = interpolate(shares);`
68///
69/// * `n`: The number of shares to generate
70/// * `t`: The degree of the polynomial
71/// * `rng`: A random number generator
72///
73pub fn generate_secrets<R: Rng + Sized>(
74    n: u8, t: u8, mut rng: R) -> (Fr, Vec<(Fr, Fr)>) {
75    
76    if n == 1 {
77        let r = Fr::rand(&mut rng);
78        return (r, vec![(Fr::zero(), r)]);
79    }
80
81    let f = DensePolynomial::<Fr>::rand(t as usize, &mut rng);
82    let msk = f.evaluate(&Fr::zero());
83    let evals: Vec<(Fr, Fr)> = (1..n+1)
84        .map(|i| {
85            let e = Fr::from(i);
86            (e, f.evaluate(&e))
87        }).collect();
88    (msk, evals)
89}
90
91/// interpolate a polynomial from the input and evaluate it at 0
92///
93/// * `evalulation`: a vec of (x, f(x)) pairs
94///
95pub fn interpolate(evaluations: Vec<(Fr, Fr)>) -> Fr {
96    let n = evaluations.len();
97
98    // Calculate the Lagrange basis polynomials evaluated at 0
99    let mut lagrange_at_zero: Vec<Fr> = Vec::with_capacity(n);
100    for i in 0..n {
101        let mut basis_value = Fr::one();
102        for j in 0..n {
103            if i != j {
104                let denominator = evaluations[i].0 - evaluations[j].0;
105                // todo: handle unwrap?
106                basis_value *= denominator.inverse().unwrap() * evaluations[j].0;
107            }
108        }
109        lagrange_at_zero.push(basis_value);
110    }
111
112    // Interpolate the value at 0
113    let mut interpolated_value = Fr::zero();
114    for i in 0..n {
115        interpolated_value += evaluations[i].1 * lagrange_at_zero[i];
116    }
117
118    interpolated_value
119}
120
#[cfg(test)]
mod test {
    use super::*;
    use ark_std::rand::SeedableRng;
    use rand_chacha::ChaCha20Rng;

    /// Encrypts `msg` under the fixed test key `[2; 32]` with an RNG seeded by `seed`.
    /// Panics if encryption itself fails.
    fn encrypt_with_seed(msg: &[u8], seed: [u8; 32]) -> AESOutput {
        encrypt(msg, [2; 32], ChaCha20Rng::from_seed(seed))
            .unwrap_or_else(|_| panic!("test should pass"))
    }

    #[test]
    pub fn aes_encrypt_decrypt_works() {
        let msg = b"test";
        let out = encrypt_with_seed(msg, [2; 32]);
        let plaintext = decrypt(out.ciphertext, &out.nonce, &out.key)
            .unwrap_or_else(|_| panic!("test should pass"));
        assert_eq!(msg.to_vec(), plaintext);
    }

    #[test]
    pub fn aes_encrypt_decrypt_fails_with_bad_key() {
        let out = encrypt_with_seed(b"test", [1; 32]);
        // A 2-byte key cannot build an AES-256 cipher.
        let err = decrypt(out.ciphertext, &out.nonce, b"hi")
            .err()
            .unwrap_or_else(|| panic!("should be an error"));
        assert_eq!(err, Error::InvalidKey);
    }

    #[test]
    pub fn aes_encrypt_decrypt_fails_with_bad_nonce() {
        let out = encrypt_with_seed(b"test", [3; 32]);
        // Correct length, wrong value: authentication must fail.
        let wrong_nonce = [0u8; 12];
        let err = decrypt(out.ciphertext, &wrong_nonce, &out.key)
            .err()
            .unwrap_or_else(|| panic!("should be an error"));
        assert_eq!(err, Error::DecryptionError);
    }

    #[test]
    fn secrets_interpolation() {
        // 5 participants, threshold polynomial of degree 3.
        let (n, t) = (5, 3);
        let (msk, shares) = generate_secrets(n, t, ChaCha20Rng::from_seed([4; 32]));
        // Lagrange interpolation at 0 must recover the master secret.
        assert_eq!(msk, interpolate(shares));
    }
}