mod tag;
use crate::tag::{TagDecoder, TagEncoder};
use aes_gcm::aead::rand_core::RngCore;
use aes_gcm::aead::{Aead, KeyInit, OsRng, Payload};
use aes_gcm::{Aes256Gcm, Key, Nonce};
use thiserror::Error;
/// Failures that can occur while sealing or opening a vault payload.
#[derive(Debug, Error)]
pub enum Error {
    /// The underlying AES-GCM cipher reported a failure while encrypting.
    #[error("AES-GCM encrypt error")]
    Encrypt,

    /// The underlying AES-GCM cipher reported a failure while decrypting
    /// (for example, the authentication check did not pass).
    #[error("AES-GCM decrypt error")]
    Decrypt,

    /// The leading tag header of the payload could not be decoded.
    #[error("Unsupported version")]
    UnsupportedVersion,

    /// The payload's tag header decoded fine but does not match the tag the
    /// vault was configured with.
    #[error("Unsupported tag")]
    UnsupportedTag,

    /// The decrypted bytes are not valid UTF-8.
    #[error("UTF-8 error")]
    Utf8(#[from] std::string::FromUtf8Error),
}
/// Generates a fresh random AES-256-GCM key.
///
/// The key material is drawn from the operating system's random number
/// generator ([`OsRng`]) and returned as an owned byte vector, suitable for
/// passing to [`Vault::new`].
pub fn generate_key() -> Vec<u8> {
    let fresh_key = Aes256Gcm::generate_key(OsRng);
    fresh_key.to_vec()
}
/// An AES-256-GCM encryptor/decryptor bound to a single key and a textual tag.
///
/// The tag is written in encoded form at the front of every payload produced
/// by `encrypt` and is checked again by `decrypt`.
pub struct Vault {
    /// Cipher instance initialised with the caller-supplied key.
    cipher: Aes256Gcm,
    /// Tag identifying payloads that belong to this vault.
    tag: String,
}
impl Vault {
pub fn new(key: &[u8], tag: &str) -> Self {
Self {
cipher: Aes256Gcm::new(Key::<Aes256Gcm>::from_slice(key)),
tag: tag.to_string(),
}
}
pub fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>, Error> {
let mut iv = [0u8; 12];
OsRng.fill_bytes(&mut iv);
let nonce = Nonce::from_slice(&iv);
let aad = b"AES256GCM";
let ciphertext_with_tag = self
.cipher
.encrypt(
nonce,
Payload {
msg: plaintext,
aad,
},
)
.map_err(|_| Error::Encrypt)?;
let (ciphertext, ciphertag) = ciphertext_with_tag.split_at(ciphertext_with_tag.len() - 16);
let encoded_tag = TagEncoder::encode(self.tag.as_bytes());
let mut encoded = Vec::new();
encoded.extend_from_slice(&encoded_tag); encoded.extend_from_slice(&iv); encoded.extend_from_slice(ciphertag); encoded.extend_from_slice(ciphertext);
Ok(encoded)
}
pub fn decrypt(&self, ciphertext: &[u8]) -> Result<String, Error> {
let (tag, remainder) =
TagDecoder::decode(ciphertext).map_err(|_| Error::UnsupportedVersion)?;
if tag != self.tag.as_bytes() {
return Err(Error::UnsupportedTag);
}
let iv = &remainder[..12]; let ciphertag = &remainder[12..28]; let ciphertext = &remainder[28..];
let mut combined_ciphertext = Vec::new();
combined_ciphertext.extend_from_slice(ciphertext);
combined_ciphertext.extend_from_slice(ciphertag);
let nonce = Nonce::from_slice(&iv);
let aad = b"AES256GCM";
let plaintext = self
.cipher
.decrypt(
nonce,
Payload {
msg: &combined_ciphertext,
aad,
},
)
.map_err(|_| Error::Decrypt)?;
Ok(String::from_utf8(plaintext)?)
}
}