use aes_gcm::{
aead::{Aead, AeadCore, AeadInPlace, KeyInit},
Aes256Gcm, Nonce,
};
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
use ark_std::rand::Rng;
use serde::{Deserialize, Serialize};
use ark_std::{rand::CryptoRng, vec::Vec};
/// Ciphertext produced by the AES-GCM `BlockCipherProvider` implementation,
/// bundled with the nonce required to decrypt it.
///
/// The derived serde / ark-serialize implementations encode fields in declaration
/// order, so reordering the fields below would change the serialized format.
#[derive(Clone, Serialize, Deserialize, Debug, CanonicalSerialize, CanonicalDeserialize)]
pub struct AESOutput {
    /// Encrypted message bytes as written by `encrypt_in_place` (for AES-GCM this
    /// includes the authentication tag — per the `aead` crate's in-place contract).
    pub ciphertext: Vec<u8>,
    /// Nonce used for encryption; decryption rejects any value whose length is not
    /// `AES_GCM_NONCE_LEN` (12) bytes.
    pub nonce: Vec<u8>,
}
/// Byte length of an AES-GCM nonce (96 bits); `decrypt` rejects any other length.
const AES_GCM_NONCE_LEN: usize = 96 / 8;
/// Errors reported by `BlockCipherProvider` implementations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// Encryption failed (the AES-GCM provider reports any in-place encryption
    /// failure with this variant).
    CiphertextTooLarge,
    /// Decryption failed. For AES-GCM this covers every authentication failure:
    /// a wrong key, a wrong nonce, or a modified/truncated ciphertext.
    InvalidKey,
    /// The supplied nonce does not have the length the cipher requires.
    BadNonce,
}

impl core::fmt::Display for Error {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let msg = match self {
            Error::CiphertextTooLarge => "ciphertext too large",
            Error::InvalidKey => "invalid key or unauthenticated ciphertext",
            Error::BadNonce => "bad nonce",
        };
        f.write_str(msg)
    }
}
pub trait BlockCipherProvider<const N: usize> {
const CIPHER_SUITE: &'static [u8];
type Ciphertext: CanonicalDeserialize + CanonicalSerialize;
fn encrypt<R: Rng + CryptoRng + Sized>(
message: &[u8],
key: [u8; N],
rng: R,
) -> Result<Self::Ciphertext, Error>;
fn decrypt(ciphertext: Self::Ciphertext, key: [u8; N]) -> Result<Vec<u8>, Error>;
}
/// `BlockCipherProvider` backed by AES-256 in Galois/Counter Mode.
///
/// A field-less marker type: all key material is passed per call, so it is freely
/// copyable and default-constructible.
#[derive(Debug, Clone, Copy, Default)]
pub struct AESGCMBlockCipherProvider;
impl BlockCipherProvider<32> for AESGCMBlockCipherProvider {
const CIPHER_SUITE: &'static [u8] = b"AES_GCM_";
type Ciphertext = AESOutput;
fn encrypt<R: Rng + CryptoRng + Sized>(
message: &[u8],
key: [u8; 32],
mut rng: R,
) -> Result<Self::Ciphertext, Error> {
let cipher = Aes256Gcm::new(generic_array::GenericArray::from_slice(&key));
let nonce = Aes256Gcm::generate_nonce(&mut rng);
let mut buffer: Vec<u8> = Vec::new(); buffer.extend_from_slice(message);
cipher
.encrypt_in_place(&nonce, b"", &mut buffer)
.map_err(|_| Error::CiphertextTooLarge)?;
Ok(Self::Ciphertext { ciphertext: buffer, nonce: nonce.to_vec() })
}
fn decrypt(ct: Self::Ciphertext, key: [u8; 32]) -> Result<Vec<u8>, Error> {
let cipher = Aes256Gcm::new_from_slice(&key).map_err(|_| Error::InvalidKey)?;
if ct.nonce.len() != AES_GCM_NONCE_LEN {
return Err(Error::BadNonce);
}
let nonce = Nonce::from_slice(&ct.nonce);
let plaintext =
cipher.decrypt(nonce, ct.ciphertext.as_ref()).map_err(|_| Error::InvalidKey)?;
Ok(plaintext)
}
}
#[cfg(test)]
mod test {
    use super::*;
    use alloc::vec;
    use ark_std::rand::rngs::OsRng;

    /// Encrypts `msg` under `key` with a random nonce, failing the test on error.
    fn encrypt_ok(msg: &[u8], key: [u8; 32]) -> AESOutput {
        AESGCMBlockCipherProvider::encrypt(msg, key, OsRng).expect("test should pass")
    }

    #[test]
    fn aes_encrypt_decrypt_works() {
        let msg = b"test";
        let esk = [2; 32];
        let aes_out = encrypt_ok(msg, esk);
        assert_eq!(AESGCMBlockCipherProvider::decrypt(aes_out, esk), Ok(msg.to_vec()));
    }

    #[test]
    fn aes_encrypt_decrypt_works_with_empty_message() {
        let esk = [2; 32];
        let aes_out = encrypt_ok(b"", esk);
        assert_eq!(AESGCMBlockCipherProvider::decrypt(aes_out, esk), Ok(Vec::new()));
    }

    #[test]
    fn aes_output_has_expected_lengths() {
        let msg = b"test";
        let aes_out = encrypt_ok(msg, [2; 32]);
        assert_eq!(aes_out.nonce.len(), AES_GCM_NONCE_LEN);
        // AES-GCM appends a 16-byte authentication tag to the encrypted message.
        assert_eq!(aes_out.ciphertext.len(), msg.len() + 16);
    }

    #[test]
    fn aes_encrypt_decrypt_fails_with_bad_key() {
        let aes_out = encrypt_ok(b"test", [2; 32]);
        assert_eq!(
            AESGCMBlockCipherProvider::decrypt(aes_out, [4; 32]),
            Err(Error::InvalidKey)
        );
    }

    #[test]
    fn aes_encrypt_decrypt_fails_with_invalid_nonce() {
        let esk = [2; 32];
        let aes_out = encrypt_ok(b"test", esk);
        // Correct length but (overwhelmingly likely) not the nonce used to encrypt.
        let bad = AESOutput { ciphertext: aes_out.ciphertext, nonce: vec![0; AES_GCM_NONCE_LEN] };
        assert_eq!(AESGCMBlockCipherProvider::decrypt(bad, esk), Err(Error::InvalidKey));
    }

    #[test]
    fn aes_encrypt_decrypt_fails_with_tampered_ciphertext() {
        let esk = [2; 32];
        let mut aes_out = encrypt_ok(b"test", esk);
        aes_out.ciphertext[0] ^= 1;
        assert_eq!(AESGCMBlockCipherProvider::decrypt(aes_out, esk), Err(Error::InvalidKey));
    }

    #[test]
    fn aes_encrypt_decrypt_fails_with_truncated_ciphertext() {
        let esk = [2; 32];
        let mut aes_out = encrypt_ok(b"test", esk);
        // Shorter than the authentication tag alone.
        aes_out.ciphertext.truncate(8);
        assert_eq!(AESGCMBlockCipherProvider::decrypt(aes_out, esk), Err(Error::InvalidKey));
    }

    #[test]
    fn aes_encrypt_decrypt_fails_with_bad_length_nonce() {
        let esk = [2; 32];
        let aes_out = encrypt_ok(b"test", esk);
        for len in [0, AES_GCM_NONCE_LEN - 1, AES_GCM_NONCE_LEN + 1, 26] {
            let bad = AESOutput { ciphertext: aes_out.ciphertext.clone(), nonce: vec![0; len] };
            assert_eq!(AESGCMBlockCipherProvider::decrypt(bad, esk), Err(Error::BadNonce));
        }
    }
}