use aes_gcm::{
aead::{Aead, KeyInit, Payload},
Aes256Gcm, Key, Nonce,
};
use base64::{engine::general_purpose::STANDARD as B64, Engine};
use chacha20poly1305::XChaCha20Poly1305;
use rand::RngCore;
use serde::{Deserialize, Serialize};
use crate::algorithm::{AlgorithmPolicy, CryptoAlgorithm};
use crate::error::DataError;
use crate::kms::KeyProvider;
/// Envelope format version written into every new [`EnvelopeEncrypted`] and
/// mixed into its associated data via `build_aad`.
const ENVELOPE_VERSION: &str = "1";
/// Self-describing result of envelope encryption: the AEAD ciphertext plus the
/// metadata needed (together with a `KeyProvider`) to decrypt it.
///
/// Produced by `encrypt_with_policy` and consumed by `decrypt_for_use`.
/// NOTE(review): it derives `Serialize`/`Deserialize`, so it is presumably
/// persisted; treat field names as a storage format and confirm before renaming.
// `#[must_use]`: dropping a freshly produced envelope discards the ciphertext.
#[must_use]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnvelopeEncrypted {
/// Envelope format version (`ENVELOPE_VERSION` at encryption time); part of the AAD.
pub version: String,
/// Algorithm identifier from `CryptoAlgorithm::as_str`, parsed back with
/// `CryptoAlgorithm::from_envelope_str` on decrypt; part of the AAD.
pub algorithm: String,
/// Key alias the data key was generated under; passed to the provider on unwrap; part of the AAD.
pub key_alias: String,
/// Key version returned by `KeyProvider::generate_data_key`; passed back on unwrap; part of the AAD.
pub key_version: String,
/// Data key in wrapped form as returned by the provider; the usable key is
/// recovered only through `KeyProvider::unwrap_data_key`.
pub wrapped_data_key: Vec<u8>,
/// Random nonce filled from `OsRng`, `CryptoAlgorithm::nonce_len()` bytes long.
pub nonce: Vec<u8>,
/// Output of `Aead::encrypt` (ciphertext with the authentication tag appended).
pub ciphertext: Vec<u8>,
/// Associated data authenticated by the AEAD, built by `build_aad` from the four
/// metadata fields above; `decrypt_for_use` rejects the envelope if it differs
/// from a fresh recomputation.
pub aad: Vec<u8>,
}
/// Envelope-encrypts `plaintext` under a fresh data key for `key_alias`, using
/// the default [`AlgorithmPolicy`].
///
/// Equivalent to calling [`encrypt_with_policy`] with `AlgorithmPolicy::default()`.
///
/// # Errors
///
/// Propagates every error from [`encrypt_with_policy`].
pub async fn encrypt_for_storage<P: KeyProvider>(
    plaintext: &[u8],
    key_alias: &str,
    provider: &P,
) -> Result<EnvelopeEncrypted, DataError> {
    let default_policy = AlgorithmPolicy::default();
    encrypt_with_policy(plaintext, key_alias, provider, &default_policy).await
}
/// Envelope-encrypts `plaintext` with the algorithm preferred by `policy`.
///
/// Steps: validate the policy, request a new data key for `key_alias` from
/// `provider`, draw a random nonce from `OsRng`, and seal the plaintext with
/// AAD that binds the envelope version, algorithm, alias and key version.
///
/// # Errors
///
/// - Any error from `AlgorithmPolicy::validate`.
/// - Any error from `KeyProvider::generate_data_key`.
/// - Any error from the AEAD step (e.g. `DataError::WrappedKeyLengthMismatch`
///   when the provider's data key is not 32 bytes).
pub async fn encrypt_with_policy<P: KeyProvider>(
    plaintext: &[u8],
    key_alias: &str,
    provider: &P,
    policy: &AlgorithmPolicy,
) -> Result<EnvelopeEncrypted, DataError> {
    policy.validate()?;
    let alg = policy.preferred();
    let alg_name = alg.as_str();

    let (data_key, wrapped_data_key, key_version) =
        provider.generate_data_key(key_alias).await?;

    // Fresh random nonce sized for the chosen cipher.
    let mut nonce = vec![0u8; alg.nonce_len()];
    rand::rngs::OsRng.fill_bytes(&mut nonce);

    let aad = build_aad(ENVELOPE_VERSION, alg_name, key_alias, &key_version);
    let ciphertext = encrypt_with_algorithm(alg, &data_key, &nonce, plaintext, &aad)?;

    Ok(EnvelopeEncrypted {
        version: ENVELOPE_VERSION.to_owned(),
        algorithm: alg_name.to_owned(),
        key_alias: key_alias.to_owned(),
        key_version,
        wrapped_data_key,
        nonce,
        ciphertext,
        aad,
    })
}
/// Decrypts an [`EnvelopeEncrypted`] back to its plaintext.
///
/// The checks that need only the envelope itself run before the wrapped data
/// key is sent to `provider`:
/// - the algorithm name,
/// - the nonce length,
/// - that the stored AAD is consistent with the envelope metadata.
///
/// A malformed or tampered envelope is therefore rejected locally. It costs no
/// key-provider round-trip and no data key is unwrapped for it.
///
/// # Errors
///
/// - Any error from `CryptoAlgorithm::from_envelope_str` for an unrecognised
///   `algorithm` string.
/// - [`DataError::InvalidNonce`] if the nonce length does not match the algorithm.
/// - [`DataError::AuthenticationFailure`] in either of these cases:
///   - the stored `aad` differs from the AAD rebuilt from the envelope fields;
///   - the AEAD tag does not verify.
/// - Any error from `KeyProvider::unwrap_data_key`.
/// - [`DataError::WrappedKeyLengthMismatch`] if the unwrapped key is not 32 bytes.
pub async fn decrypt_for_use<P: KeyProvider>(
    envelope: &EnvelopeEncrypted,
    provider: &P,
) -> Result<Vec<u8>, DataError> {
    let algorithm = CryptoAlgorithm::from_envelope_str(&envelope.algorithm)?;

    // Cheap structural checks first: they only read the envelope.
    let expected_nonce_len = algorithm.nonce_len();
    if envelope.nonce.len() != expected_nonce_len {
        return Err(DataError::InvalidNonce {
            expected: expected_nonce_len,
            actual: envelope.nonce.len(),
        });
    }

    // The stored AAD must equal what the encrypt path builds from these metadata
    // fields. The AEAD tag check below then binds that metadata to the ciphertext.
    let recomputed_aad = build_aad(
        &envelope.version,
        &envelope.algorithm,
        &envelope.key_alias,
        &envelope.key_version,
    );
    if recomputed_aad != envelope.aad {
        return Err(DataError::AuthenticationFailure);
    }

    // Only now involve the key provider.
    let dek = provider
        .unwrap_data_key(
            &envelope.wrapped_data_key,
            &envelope.key_alias,
            &envelope.key_version,
        )
        .await?;

    // `recomputed_aad == envelope.aad` at this point, so authenticating either is equivalent.
    decrypt_with_algorithm(
        algorithm,
        &dek,
        &envelope.nonce,
        &envelope.ciphertext,
        &recomputed_aad,
    )
}
/// Builds an AES-256-GCM cipher from a raw 32-byte data key.
///
/// # Errors
///
/// [`DataError::WrappedKeyLengthMismatch`] if `dek` is not exactly 32 bytes.
/// The length is checked first because `Key::from_slice` would panic on a mismatch.
fn aes_cipher_from_dek(dek: &[u8]) -> Result<Aes256Gcm, DataError> {
    match dek.len() {
        32 => Ok(Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(dek))),
        _ => Err(DataError::WrappedKeyLengthMismatch),
    }
}
/// Builds an XChaCha20-Poly1305 cipher from a raw 32-byte data key.
///
/// # Errors
///
/// [`DataError::WrappedKeyLengthMismatch`] if `dek` is not exactly 32 bytes.
/// The length is checked first because `Key::from_slice` would panic on a mismatch.
fn xchacha_cipher_from_dek(dek: &[u8]) -> Result<XChaCha20Poly1305, DataError> {
    match dek.len() {
        32 => Ok(XChaCha20Poly1305::new(chacha20poly1305::Key::from_slice(dek))),
        _ => Err(DataError::WrappedKeyLengthMismatch),
    }
}
/// Seals `plaintext` with the AEAD selected by `algorithm`, authenticating `aad`.
///
/// Returns the output of `Aead::encrypt` (ciphertext with the tag appended).
///
/// # Errors
///
/// - [`DataError::WrappedKeyLengthMismatch`] if `dek` is not 32 bytes.
/// - [`DataError::EncryptionFailed`] if the underlying AEAD reports an error.
///
/// # Panics
///
/// Panics if `nonce_bytes` is not exactly the cipher's nonce size: 12 bytes for
/// AES-256-GCM, 24 for XChaCha20-Poly1305. The cause is that `from_slice` asserts
/// the length. The visible caller sizes the nonce with `CryptoAlgorithm::nonce_len`.
fn encrypt_with_algorithm(
    algorithm: CryptoAlgorithm,
    dek: &[u8],
    nonce_bytes: &[u8],
    plaintext: &[u8],
    aad: &[u8],
) -> Result<Vec<u8>, DataError> {
    let payload = Payload { msg: plaintext, aad };
    match algorithm {
        CryptoAlgorithm::Aes256Gcm => aes_cipher_from_dek(dek)?
            .encrypt(Nonce::from_slice(nonce_bytes), payload)
            .map_err(|_| DataError::EncryptionFailed {
                reason: String::from("AES-256-GCM encryption failed"),
            }),
        CryptoAlgorithm::XChaCha20Poly1305 => xchacha_cipher_from_dek(dek)?
            .encrypt(chacha20poly1305::XNonce::from_slice(nonce_bytes), payload)
            .map_err(|_| DataError::EncryptionFailed {
                reason: String::from("XChaCha20-Poly1305 encryption failed"),
            }),
    }
}
/// Opens `ciphertext` with the AEAD selected by `algorithm`, verifying the tag
/// over both the ciphertext and `aad`.
///
/// # Errors
///
/// - [`DataError::WrappedKeyLengthMismatch`] if `dek` is not 32 bytes.
/// - [`DataError::AuthenticationFailure`] if decryption or tag verification fails.
///
/// # Panics
///
/// Panics if `nonce_bytes` is not exactly the cipher's nonce size: 12 bytes for
/// AES-256-GCM, 24 for XChaCha20-Poly1305. `decrypt_for_use` checks the length
/// against `CryptoAlgorithm::nonce_len` first.
/// NOTE(review): this assumes `nonce_len` matches these cipher sizes — confirm
/// in `crate::algorithm`.
fn decrypt_with_algorithm(
    algorithm: CryptoAlgorithm,
    dek: &[u8],
    nonce_bytes: &[u8],
    ciphertext: &[u8],
    aad: &[u8],
) -> Result<Vec<u8>, DataError> {
    let payload = Payload { msg: ciphertext, aad };
    match algorithm {
        CryptoAlgorithm::Aes256Gcm => aes_cipher_from_dek(dek)?
            .decrypt(Nonce::from_slice(nonce_bytes), payload)
            .map_err(|_| DataError::AuthenticationFailure),
        CryptoAlgorithm::XChaCha20Poly1305 => xchacha_cipher_from_dek(dek)?
            .decrypt(chacha20poly1305::XNonce::from_slice(nonce_bytes), payload)
            .map_err(|_| DataError::AuthenticationFailure),
    }
}
/// Builds the associated data bound into every envelope.
///
/// The result is the UTF-8 bytes of
/// `v=<version>;alg=<algorithm>;alias=<key_alias>;kver=<key_version>`.
///
/// NOTE(review): this encoding is not injective. A `;` or `=` inside a field can
/// make two different field tuples produce the same bytes. For example, alias
/// `"a;kver=b"` with version `"c"` collides with alias `"a"` and version
/// `"b;kver=c"`. The bytes are stored in the envelope and recomputed on decrypt,
/// so any change to the format needs a new envelope version.
fn build_aad(version: &str, algorithm: &str, key_alias: &str, key_version: &str) -> Vec<u8> {
    let fields = [
        ("v", version),
        ("alg", algorithm),
        ("alias", key_alias),
        ("kver", key_version),
    ];
    let mut aad = String::new();
    for (idx, (label, value)) in fields.iter().enumerate() {
        if idx > 0 {
            aad.push(';');
        }
        aad.push_str(label);
        aad.push('=');
        aad.push_str(value);
    }
    aad.into_bytes()
}
#[allow(dead_code)]
fn to_b64(bytes: &[u8]) -> String {
B64.encode(bytes)
}