use aes_gcm::aead::Aead;
use aes_gcm::{Aes256Gcm, Key, KeyInit, Nonce};
use base64::{engine::general_purpose::STANDARD as B64, Engine};
use pbkdf2::pbkdf2_hmac_array;
use rand::RngCore;
use serde::{Deserialize, Serialize};
use sha2::Sha256;
use dynamic_waas_sdk_core::{Error, Result};
/// Version tag for the legacy scheme (PBKDF2-HMAC-SHA256 with 100 000 iterations).
pub const V1: &str = "v1";
/// Version tag for the current scheme (PBKDF2-HMAC-SHA256 with 1 000 000 iterations);
/// this is the version `encrypt` uses.
pub const V2: &str = "v2";
/// PBKDF2 iteration count for [`V1`]. `decrypt` selects the count from the stored version
/// tag, so changing this value makes existing `v1` payloads undecryptable.
const ITERATIONS_V1: u32 = 100_000;
/// PBKDF2 iteration count for [`V2`]. `decrypt` selects the count from the stored version
/// tag, so changing this value makes existing `v2` payloads undecryptable.
const ITERATIONS_V2: u32 = 1_000_000;
/// Number of random salt bytes generated per encryption. `decrypt` does not check the
/// salt length; it uses whatever the decoded salt is.
const SALT_LEN: usize = 16;
/// Number of random IV (AES-GCM nonce) bytes generated per encryption; `decrypt` rejects
/// payloads whose decoded IV has any other length.
const IV_LEN: usize = 12;
/// Password-encrypted payload produced by `encrypt` / `encrypt_versioned` and consumed by
/// `decrypt`.
///
/// Every binary field is encoded with the standard base64 alphabet. No serde renames are
/// applied, so the serialized keys are exactly `salt`, `iv`, `cipher` and `version`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct EncryptedData {
/// Base64 of the PBKDF2 salt (16 random bytes when produced by `encrypt_versioned`).
pub salt: String,
/// Base64 of the AES-GCM IV / nonce (must decode to exactly 12 bytes for `decrypt`).
pub iv: String,
/// Base64 of the bytes returned by AES-256-GCM encryption of the UTF-8 plaintext.
pub cipher: String,
/// Scheme tag (`"v1"` or `"v2"`) that selects the PBKDF2 iteration count.
pub version: String,
}
/// Encrypts `data` under `password` using the default scheme, [`V2`].
///
/// This is shorthand for [`encrypt_versioned`] with [`V2`].
///
/// # Errors
/// Propagates any error returned by [`encrypt_versioned`].
pub fn encrypt(data: &str, password: &str) -> Result<EncryptedData> {
    // Bump this (together with the decrypt-side table) when a newer scheme becomes default.
    const DEFAULT_VERSION: &str = V2;
    encrypt_versioned(data, password, DEFAULT_VERSION)
}
/// Encrypts `data` under `password` using the scheme identified by `version`.
///
/// Every call draws a fresh random salt (`SALT_LEN` bytes) and IV (`IV_LEN` bytes) from
/// `rand::rng()`. It then derives a 32-byte AES-256-GCM key with PBKDF2-HMAC-SHA256, using
/// the iteration count registered for `version`. The salt, IV and ciphertext in the result
/// are standard base64.
///
/// # Errors
/// Returns [`Error::Encryption`] if `version` is neither [`V1`] nor [`V2`], or if the
/// AES-GCM encryption call fails.
pub fn encrypt_versioned(data: &str, password: &str, version: &str) -> Result<EncryptedData> {
    // Reject unknown versions before spending any randomness or KDF work.
    let rounds = iterations_for(version)?;

    let mut rng = rand::rng();
    let mut salt = [0u8; SALT_LEN];
    rng.fill_bytes(&mut salt);
    let mut iv = [0u8; IV_LEN];
    rng.fill_bytes(&mut iv);

    let derived_key = pbkdf2_hmac_array::<Sha256, 32>(password.as_bytes(), &salt, rounds);
    let aead = Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(&derived_key));

    let sealed = aead
        .encrypt(Nonce::from_slice(&iv), data.as_bytes())
        .map_err(|e| Error::Encryption(format!("AES-GCM encrypt failed: {e}")))?;

    let (salt, iv, cipher) = (B64.encode(salt), B64.encode(iv), B64.encode(sealed));
    Ok(EncryptedData {
        salt,
        iv,
        cipher,
        version: version.to_owned(),
    })
}
/// Decrypts an [`EncryptedData`] payload back into its UTF-8 plaintext.
///
/// The PBKDF2 iteration count comes from `data.version`. The decoded salt is used as-is,
/// without a length check. The decoded IV must be exactly `IV_LEN` bytes.
///
/// # Errors
/// Returns [`Error::Encryption`] in any of these cases:
/// - the version is unsupported;
/// - a field is not valid standard base64;
/// - the IV has the wrong length;
/// - AES-GCM decryption fails (for example, because the password is wrong);
/// - the recovered bytes are not valid UTF-8.
pub fn decrypt(data: &EncryptedData, password: &str) -> Result<String> {
    let rounds = iterations_for(&data.version)?;

    // Decode all fields up front; failures are reported in salt -> iv -> cipher order.
    let salt = B64
        .decode(&data.salt)
        .map_err(|e| Error::Encryption(format!("salt b64 decode: {e}")))?;
    let iv = B64
        .decode(&data.iv)
        .map_err(|e| Error::Encryption(format!("iv b64 decode: {e}")))?;
    let sealed = B64
        .decode(&data.cipher)
        .map_err(|e| Error::Encryption(format!("cipher b64 decode: {e}")))?;

    // `Nonce::from_slice` asserts on the slice length, so validate it first
    // (which also avoids paying for key derivation on malformed input).
    let nonce = match iv.len() {
        IV_LEN => Nonce::from_slice(&iv),
        got => {
            return Err(Error::Encryption(format!(
                "iv must be {IV_LEN} bytes, got {got}"
            )));
        }
    };

    let derived_key = pbkdf2_hmac_array::<Sha256, 32>(password.as_bytes(), &salt, rounds);
    let opened = Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(&derived_key))
        .decrypt(nonce, sealed.as_slice())
        .map_err(|e| Error::Encryption(format!("AES-GCM decrypt failed: {e}")))?;

    String::from_utf8(opened)
        .map_err(|e| Error::Encryption(format!("decrypted bytes not UTF-8: {e}")))
}
fn iterations_for(version: &str) -> Result<u32> {
match version {
V1 => Ok(ITERATIONS_V1),
V2 => Ok(ITERATIONS_V2),
other => Err(Error::Encryption(format!(
"unsupported encryption version: {other}"
))),
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a payload directly from raw byte fields, so malformed inputs can be fed to
    /// `decrypt` without running PBKDF2 to produce them.
    fn envelope(salt: &[u8], iv: &[u8], cipher: &[u8], version: &str) -> EncryptedData {
        EncryptedData {
            salt: B64.encode(salt),
            iv: B64.encode(iv),
            cipher: B64.encode(cipher),
            version: version.to_owned(),
        }
    }

    #[test]
    fn roundtrip_v2() {
        let data = r#"{"keyShareId":"abc","secretShare":"def"}"#;
        let enc = encrypt(data, "mypassword").unwrap();
        assert_eq!(enc.version, V2);
        let dec = decrypt(&enc, "mypassword").unwrap();
        assert_eq!(dec, data);
    }

    #[test]
    fn roundtrip_v1() {
        let data = "hello";
        let enc = encrypt_versioned(data, "pw", V1).unwrap();
        assert_eq!(enc.version, V1);
        assert_eq!(decrypt(&enc, "pw").unwrap(), data);
    }

    #[test]
    fn wrong_password_fails() {
        let enc = encrypt("data", "right").unwrap();
        assert!(decrypt(&enc, "wrong").is_err());
    }

    // The remaining tests use V1 or fail before key derivation, which keeps them fast.

    #[test]
    fn unsupported_version_is_rejected() {
        assert!(encrypt_versioned("data", "pw", "v3").is_err());
        let enc = envelope(&[0u8; SALT_LEN], &[0u8; IV_LEN], &[0u8; 32], "v0");
        assert!(decrypt(&enc, "pw").is_err());
    }

    #[test]
    fn malformed_iv_is_rejected_without_panicking() {
        let enc = envelope(&[0u8; SALT_LEN], &[0u8; IV_LEN - 1], &[0u8; 32], V1);
        let err = decrypt(&enc, "pw").unwrap_err();
        assert!(matches!(err, Error::Encryption(_)));
    }

    #[test]
    fn invalid_base64_is_rejected() {
        let mut enc = envelope(&[0u8; SALT_LEN], &[0u8; IV_LEN], &[0u8; 32], V1);
        enc.salt = "not base64!".to_owned();
        assert!(decrypt(&enc, "pw").is_err());
    }

    #[test]
    fn tampered_ciphertext_fails() {
        let mut enc = encrypt_versioned("hello", "pw", V1).unwrap();
        let mut bytes = B64.decode(&enc.cipher).unwrap();
        bytes[0] ^= 0x01;
        enc.cipher = B64.encode(&bytes);
        assert!(decrypt(&enc, "pw").is_err());
    }

    #[test]
    fn each_encryption_uses_fresh_salt_and_iv() {
        let a = encrypt_versioned("same", "pw", V1).unwrap();
        let b = encrypt_versioned("same", "pw", V1).unwrap();
        assert_ne!(a.salt, b.salt);
        assert_ne!(a.iv, b.iv);
        assert_eq!(B64.decode(&a.salt).unwrap().len(), SALT_LEN);
        assert_eq!(B64.decode(&a.iv).unwrap().len(), IV_LEN);
    }
}