mod aes_gcm;
mod xchacha20;
mod key;
pub use aes_gcm::{encrypt_aes_gcm, decrypt_aes_gcm};
pub use xchacha20::{encrypt_xchacha20, decrypt_xchacha20};
pub use key::{generate_key, derive_key_hkdf, derive_key_pbkdf2, Key};
use crate::{Error, Result, MAGIC_ENCRYPTED, FORMAT_VERSION};
use alloc::vec::Vec;
use serde::{Deserialize, Serialize};
use zeroize::Zeroize;
/// Authenticated-encryption algorithms supported by this module.
///
/// The `#[repr(u8)]` discriminant is the algorithm id byte written into the
/// binary header by `EncryptionResult::to_bytes` and parsed back by
/// `Algorithm::from_byte`, so existing values must never be renumbered.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[repr(u8)]
pub enum Algorithm {
/// AES-256-GCM; uses a 12-byte nonce (see `Algorithm::nonce_len`).
Aes256Gcm = 0x01,
/// XChaCha20-Poly1305; uses a 24-byte nonce (see `Algorithm::nonce_len`).
XChaCha20Poly1305 = 0x02,
}
impl Algorithm {
pub fn from_byte(byte: u8) -> Result<Self> {
match byte {
0x01 => Ok(Algorithm::Aes256Gcm),
0x02 => Ok(Algorithm::XChaCha20Poly1305),
_ => Err(Error::UnsupportedAlgorithm(byte)),
}
}
pub fn nonce_len(&self) -> usize {
match self {
Algorithm::Aes256Gcm => 12,
Algorithm::XChaCha20Poly1305 => 24,
}
}
pub fn name(&self) -> &'static str {
match self {
Algorithm::Aes256Gcm => "aes-256-gcm",
Algorithm::XChaCha20Poly1305 => "xchacha20-poly1305",
}
}
}
/// The output of an authenticated encryption: the ciphertext together with the
/// algorithm, nonce, and authentication tag needed to decrypt it.
///
/// Can be serialized to a compact binary container (`to_bytes` / `from_bytes`)
/// or to JSON with base64-encoded fields (`to_json` / `from_json`). All byte
/// buffers are zeroized when the value is dropped.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EncryptionResult {
/// Encrypted payload, stored without the authentication tag.
pub ciphertext: Vec<u8>,
/// Algorithm that produced the ciphertext; selects the decryption backend.
pub algorithm: Algorithm,
/// Nonce used for encryption; its length should equal `algorithm.nonce_len()`.
pub nonce: Vec<u8>,
/// Authentication tag; the binary container expects exactly 16 bytes.
pub tag: Vec<u8>,
}
impl EncryptionResult {
pub fn to_bytes(&self) -> Vec<u8> {
let nonce_len = self.nonce.len();
let total_len = 8 + nonce_len + self.ciphertext.len() + self.tag.len();
let mut buf = Vec::with_capacity(total_len);
buf.extend_from_slice(MAGIC_ENCRYPTED);
buf.push(FORMAT_VERSION);
buf.push(self.algorithm as u8);
buf.push(nonce_len as u8);
buf.push(0x00);
buf.extend_from_slice(&self.nonce);
buf.extend_from_slice(&self.ciphertext);
buf.extend_from_slice(&self.tag);
buf
}
pub fn from_bytes(data: &[u8]) -> Result<Self> {
if data.len() < 8 {
return Err(Error::TruncatedPayload {
expected: 8,
actual: data.len(),
});
}
if &data[0..4] != MAGIC_ENCRYPTED {
return Err(Error::InvalidFormat);
}
let version = data[4];
if version != FORMAT_VERSION {
return Err(Error::UnsupportedVersion(version));
}
let algorithm = Algorithm::from_byte(data[5])?;
let nonce_len = data[6] as usize;
if nonce_len != algorithm.nonce_len() {
return Err(Error::InvalidNonceLength {
expected: algorithm.nonce_len(),
actual: nonce_len,
});
}
let min_size = 8 + nonce_len + 16; if data.len() < min_size {
return Err(Error::TruncatedPayload {
expected: min_size,
actual: data.len(),
});
}
let nonce = data[8..8 + nonce_len].to_vec();
let tag = data[data.len() - 16..].to_vec();
let ciphertext = data[8 + nonce_len..data.len() - 16].to_vec();
Ok(EncryptionResult {
ciphertext,
algorithm,
nonce,
tag,
})
}
pub fn to_json(&self) -> Result<String> {
#[derive(Serialize)]
struct JsonFormat<'a> {
v: &'static str,
alg: &'a str,
nonce: String,
ct: String,
tag: String,
}
use base64::{Engine, engine::general_purpose::STANDARD};
let json = JsonFormat {
v: "1.0",
alg: self.algorithm.name(),
nonce: STANDARD.encode(&self.nonce),
ct: STANDARD.encode(&self.ciphertext),
tag: STANDARD.encode(&self.tag),
};
serde_json::to_string(&json).map_err(|e| Error::SerializationError(e.to_string()))
}
pub fn from_json(json: &str) -> Result<Self> {
#[derive(Deserialize)]
struct JsonFormat {
v: String,
alg: String,
nonce: String,
ct: String,
tag: String,
}
let parsed: JsonFormat = serde_json::from_str(json)?;
if parsed.v != "1.0" {
return Err(Error::UnsupportedVersion(0));
}
let algorithm = match parsed.alg.as_str() {
"aes-256-gcm" => Algorithm::Aes256Gcm,
"xchacha20-poly1305" => Algorithm::XChaCha20Poly1305,
_ => return Err(Error::UnsupportedAlgorithm(0)),
};
use base64::{Engine, engine::general_purpose::STANDARD};
Ok(EncryptionResult {
ciphertext: STANDARD.decode(&parsed.ct)?,
algorithm,
nonce: STANDARD.decode(&parsed.nonce)?,
tag: STANDARD.decode(&parsed.tag)?,
})
}
}
impl Drop for EncryptionResult {
    /// Wipes the ciphertext, nonce and tag buffers (in that order) with
    /// `Zeroize` before their memory is released.
    fn drop(&mut self) {
        let buffers: [&mut Vec<u8>; 3] = [&mut self.ciphertext, &mut self.nonce, &mut self.tag];
        for buffer in buffers {
            buffer.zeroize();
        }
    }
}
/// Optional parameters for `encrypt`.
#[derive(Debug, Clone, Default)]
pub struct EncryptOptions {
/// Cipher to use; `None` selects AES-256-GCM.
pub algorithm: Option<Algorithm>,
/// Additional authenticated data passed to the cipher; `None` means empty.
/// The same bytes must be supplied to `decrypt_with_aad` when decrypting.
pub aad: Option<Vec<u8>>,
}
pub fn encrypt(plaintext: &[u8], key: &Key, options: Option<EncryptOptions>) -> Result<EncryptionResult> {
let opts = options.unwrap_or_default();
let algorithm = opts.algorithm.unwrap_or(Algorithm::Aes256Gcm);
let aad = opts.aad.as_deref().unwrap_or(&[]);
match algorithm {
Algorithm::Aes256Gcm => encrypt_aes_gcm(plaintext, key, aad),
Algorithm::XChaCha20Poly1305 => encrypt_xchacha20(plaintext, key, aad),
}
}
/// Decrypts `encrypted` with `key`, authenticating empty associated data.
///
/// Shorthand for `decrypt_with_aad` with an empty AAD slice.
pub fn decrypt(encrypted: &EncryptionResult, key: &Key) -> Result<Vec<u8>> {
    const NO_AAD: &[u8] = &[];
    decrypt_with_aad(encrypted, key, NO_AAD)
}
/// Decrypts `encrypted` with `key`, authenticating `aad` as associated data.
///
/// The backend is chosen from `encrypted.algorithm`.
///
/// # Errors
/// Propagates any error returned by the selected backend.
pub fn decrypt_with_aad(encrypted: &EncryptionResult, key: &Key, aad: &[u8]) -> Result<Vec<u8>> {
    let algorithm = encrypted.algorithm;
    match algorithm {
        Algorithm::XChaCha20Poly1305 => decrypt_xchacha20(encrypted, key, aad),
        Algorithm::Aes256Gcm => decrypt_aes_gcm(encrypted, key, aad),
    }
}
/// Parses a binary container (see `EncryptionResult::from_bytes`) and decrypts
/// it with `key` using empty associated data.
///
/// # Errors
/// Returns the parse error if `data` is malformed, otherwise any decryption error.
pub fn decrypt_bytes(data: &[u8], key: &Key) -> Result<Vec<u8>> {
    EncryptionResult::from_bytes(data).and_then(|parsed| decrypt(&parsed, key))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Options selecting XChaCha20-Poly1305 with no associated data.
    fn xchacha_options() -> Option<EncryptOptions> {
        Some(EncryptOptions {
            algorithm: Some(Algorithm::XChaCha20Poly1305),
            aad: None,
        })
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let key = generate_key();
        let plaintext = b"Hello, World!";
        let encrypted = encrypt(plaintext, &key, None).unwrap();
        let decrypted = decrypt(&encrypted, &key).unwrap();
        assert_eq!(plaintext, &decrypted[..]);
    }

    #[test]
    fn test_xchacha20_roundtrip() {
        let key = generate_key();
        let plaintext = b"Hello, XChaCha!";
        let encrypted = encrypt(plaintext, &key, xchacha_options()).unwrap();
        assert_eq!(encrypted.algorithm, Algorithm::XChaCha20Poly1305);
        assert_eq!(encrypted.nonce.len(), Algorithm::XChaCha20Poly1305.nonce_len());
        let decrypted = decrypt(&encrypted, &key).unwrap();
        assert_eq!(plaintext, &decrypted[..]);
    }

    #[test]
    fn test_aad_roundtrip_and_mismatch() {
        let key = generate_key();
        let aad = b"header".to_vec();
        let options = EncryptOptions {
            algorithm: None,
            aad: Some(aad.clone()),
        };
        let encrypted = encrypt(b"payload", &key, Some(options)).unwrap();
        let decrypted = decrypt_with_aad(&encrypted, &key, &aad).unwrap();
        assert_eq!(&decrypted[..], b"payload");
        // Authentication must fail when the associated data differs.
        assert!(decrypt(&encrypted, &key).is_err());
        assert!(decrypt_with_aad(&encrypted, &key, b"other").is_err());
    }

    #[test]
    fn test_empty_plaintext_via_bytes() {
        let key = generate_key();
        let encrypted = encrypt(b"", &key, None).unwrap();
        let decrypted = decrypt_bytes(&encrypted.to_bytes(), &key).unwrap();
        assert!(decrypted.is_empty());
    }

    #[test]
    fn test_binary_serialization() {
        let key = generate_key();
        let plaintext = b"Test data for serialization";
        let encrypted = encrypt(plaintext, &key, None).unwrap();
        let bytes = encrypted.to_bytes();
        let restored = EncryptionResult::from_bytes(&bytes).unwrap();
        assert_eq!(encrypted.algorithm, restored.algorithm);
        assert_eq!(encrypted.nonce, restored.nonce);
        assert_eq!(encrypted.ciphertext, restored.ciphertext);
        assert_eq!(encrypted.tag, restored.tag);
    }

    #[test]
    fn test_from_bytes_rejects_malformed_headers() {
        let key = generate_key();
        let bytes = encrypt(b"header checks", &key, None).unwrap().to_bytes();

        let err = EncryptionResult::from_bytes(&bytes[..4]).unwrap_err();
        assert!(matches!(err, Error::TruncatedPayload { expected: 8, actual: 4 }));

        let mut bad_magic = bytes.clone();
        bad_magic[0] ^= 0xFF;
        assert!(matches!(
            EncryptionResult::from_bytes(&bad_magic),
            Err(Error::InvalidFormat)
        ));

        let mut bad_version = bytes.clone();
        bad_version[4] = FORMAT_VERSION.wrapping_add(1);
        assert!(matches!(
            EncryptionResult::from_bytes(&bad_version),
            Err(Error::UnsupportedVersion(_))
        ));

        let mut bad_algorithm = bytes.clone();
        bad_algorithm[5] = 0x7F;
        assert!(matches!(
            EncryptionResult::from_bytes(&bad_algorithm),
            Err(Error::UnsupportedAlgorithm(0x7F))
        ));

        // The default algorithm is AES-256-GCM, which requires a 12-byte nonce.
        let mut bad_nonce_len = bytes.clone();
        bad_nonce_len[6] = 24;
        assert!(matches!(
            EncryptionResult::from_bytes(&bad_nonce_len),
            Err(Error::InvalidNonceLength { expected: 12, actual: 24 })
        ));

        // Header intact, but one byte short of header + nonce + tag.
        let err = EncryptionResult::from_bytes(&bytes[..8 + 12 + 15]).unwrap_err();
        assert!(matches!(err, Error::TruncatedPayload { expected: 36, actual: 35 }));
    }

    #[test]
    fn test_json_serialization() {
        let key = generate_key();
        let plaintext = b"Test data for JSON";
        let encrypted = encrypt(plaintext, &key, None).unwrap();
        let json = encrypted.to_json().unwrap();
        let restored = EncryptionResult::from_json(&json).unwrap();
        assert_eq!(encrypted.algorithm, restored.algorithm);
        assert_eq!(encrypted.nonce, restored.nonce);
        assert_eq!(encrypted.ciphertext, restored.ciphertext);
        assert_eq!(encrypted.tag, restored.tag);
    }

    #[test]
    fn test_from_json_rejects_unknown_version_and_algorithm() {
        let bad_version = r#"{"v":"2.0","alg":"aes-256-gcm","nonce":"","ct":"","tag":""}"#;
        assert!(matches!(
            EncryptionResult::from_json(bad_version),
            Err(Error::UnsupportedVersion(_))
        ));
        let bad_algorithm = r#"{"v":"1.0","alg":"rot13","nonce":"","ct":"","tag":""}"#;
        assert!(matches!(
            EncryptionResult::from_json(bad_algorithm),
            Err(Error::UnsupportedAlgorithm(_))
        ));
    }
}