use aes_gcm::{
    aead::{Aead, AeadCore, KeyInit, OsRng},
    Aes256Gcm, Nonce,
};
use base64::Engine;
use crate::error::{BatataError, Result};
/// Prefix that [`is_encrypted_data_id`] treats as marking an encrypted data ID.
pub const CIPHER_PREFIX: &str = "cipher-";

/// Returns `true` when `data_id` begins with [`CIPHER_PREFIX`] (case-sensitive).
pub fn is_encrypted_data_id(data_id: &str) -> bool {
    data_id.strip_prefix(CIPHER_PREFIX).is_some()
}
/// Removes one leading [`CIPHER_PREFIX`] from `data_id`.
///
/// IDs that do not start with the prefix are returned unchanged. Only a single
/// occurrence is stripped, so `"cipher-cipher-x"` becomes `"cipher-x"`.
pub fn strip_cipher_prefix(data_id: &str) -> &str {
    if let Some(remainder) = data_id.strip_prefix(CIPHER_PREFIX) {
        remainder
    } else {
        data_id
    }
}
/// Decrypts a base64-encoded AES-256-GCM payload back into a UTF-8 string.
///
/// After standard base64 decoding, the payload is expected to be the 12-byte nonce
/// followed by the AEAD output. This is the layout [`encrypt_content`] produces.
///
/// # Errors
///
/// Returns [`BatataError::EncryptionError`] when:
/// - `data_key` is not exactly 32 bytes long;
/// - `ciphertext` is not valid standard base64;
/// - the decoded payload is shorter than the 12-byte nonce;
/// - AES-GCM decryption/authentication fails (wrong key, tampered data, ...);
/// - the recovered plaintext is not valid UTF-8.
pub fn decrypt_content(ciphertext: &str, data_key: &[u8]) -> Result<String> {
    // AES-GCM nonce size used by this module's payload format.
    const NONCE_LEN: usize = 12;
    let fail = |message: String| BatataError::EncryptionError { message };

    if data_key.len() != 32 {
        return Err(fail(format!(
            "Invalid key length: expected 32 bytes, got {}",
            data_key.len()
        )));
    }

    let payload = base64::engine::general_purpose::STANDARD
        .decode(ciphertext)
        .map_err(|err| fail(format!("Failed to decode ciphertext: {}", err)))?;

    if payload.len() < NONCE_LEN {
        return Err(fail("Ciphertext too short: missing nonce".to_string()));
    }
    let nonce = Nonce::from_slice(&payload[..NONCE_LEN]);
    let sealed = &payload[NONCE_LEN..];

    let cipher = Aes256Gcm::new_from_slice(data_key)
        .map_err(|err| fail(format!("Failed to create cipher: {}", err)))?;

    let opened = cipher
        .decrypt(nonce, sealed)
        .map_err(|err| fail(format!("Decryption failed: {}", err)))?;

    String::from_utf8(opened)
        .map_err(|err| fail(format!("Invalid UTF-8 in decrypted content: {}", err)))
}
/// Encrypts `plaintext` with AES-256-GCM and returns the result as standard base64.
///
/// A new nonce is obtained from `rand_nonce` on every call. The encoded bytes are
/// that 12-byte nonce followed by the AEAD output (ciphertext plus authentication
/// tag). This is the layout [`decrypt_content`] expects.
///
/// # Errors
///
/// Returns [`BatataError::EncryptionError`] when:
/// - `data_key` is not exactly 32 bytes long;
/// - the cipher cannot be constructed from the key;
/// - AES-GCM encryption fails.
pub fn encrypt_content(plaintext: &str, data_key: &[u8]) -> Result<String> {
    let fail = |message: String| BatataError::EncryptionError { message };

    if data_key.len() != 32 {
        return Err(fail(format!(
            "Invalid key length: expected 32 bytes, got {}",
            data_key.len()
        )));
    }

    let cipher = Aes256Gcm::new_from_slice(data_key)
        .map_err(|err| fail(format!("Failed to create cipher: {}", err)))?;

    let nonce_bytes: [u8; 12] = rand_nonce();
    let sealed = cipher
        .encrypt(Nonce::from_slice(&nonce_bytes), plaintext.as_bytes())
        .map_err(|err| fail(format!("Encryption failed: {}", err)))?;

    // Wire format: nonce || sealed (ciphertext + tag).
    let payload: Vec<u8> = nonce_bytes.iter().copied().chain(sealed).collect();
    Ok(base64::engine::general_purpose::STANDARD.encode(payload))
}
fn rand_nonce() -> [u8; 12] {
use std::time::{SystemTime, UNIX_EPOCH};
let nanos = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_nanos();
let mut nonce = [0u8; 12];
let bytes = nanos.to_le_bytes();
nonce[..8].copy_from_slice(&bytes[..8]);
let ptr = &nonce as *const _ as usize;
let extra = (ptr ^ (nanos as usize)).to_le_bytes();
nonce[8..12].copy_from_slice(&extra[..4]);
nonce
}
#[cfg(test)]
mod tests {
    use super::*;
    use base64::Engine;

    #[test]
    fn test_is_encrypted_data_id() {
        assert!(is_encrypted_data_id("cipher-test-config"));
        assert!(!is_encrypted_data_id("test-config"));
        assert!(!is_encrypted_data_id("test-cipher-config"));
    }

    #[test]
    fn test_strip_cipher_prefix() {
        assert_eq!(strip_cipher_prefix("cipher-test-config"), "test-config");
        assert_eq!(strip_cipher_prefix("test-config"), "test-config");
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let key = [0u8; 32];
        let plaintext = "Hello, World! This is a test message.";
        let ciphertext = encrypt_content(plaintext, &key).unwrap();
        let decrypted = decrypt_content(&ciphertext, &key).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip_empty_and_non_ascii() {
        let key = [7u8; 32];
        for plaintext in ["", "héllo wörld ✓", "line1\nline2"] {
            let ciphertext = encrypt_content(plaintext, &key).unwrap();
            assert_eq!(decrypt_content(&ciphertext, &key).unwrap(), plaintext);
        }
    }

    #[test]
    fn test_ciphertext_layout() {
        let key = [0u8; 32];
        let plaintext = "abc";
        let ciphertext = encrypt_content(plaintext, &key).unwrap();
        let raw = base64::engine::general_purpose::STANDARD
            .decode(&ciphertext)
            .unwrap();
        // 12-byte nonce + ciphertext (same length as plaintext) + 16-byte GCM tag.
        assert_eq!(raw.len(), 12 + plaintext.len() + 16);
    }

    #[test]
    fn test_encrypt_uses_fresh_nonce_each_call() {
        // Reusing a nonce under one key breaks AES-GCM; identical inputs must
        // therefore never produce identical outputs.
        let key = [0u8; 32];
        let first = encrypt_content("same input", &key).unwrap();
        let second = encrypt_content("same input", &key).unwrap();
        assert_ne!(first, second);
    }

    #[test]
    fn test_encrypt_invalid_key_length() {
        assert!(encrypt_content("data", &[0u8; 16]).is_err());
    }

    #[test]
    fn test_decrypt_invalid_key_length() {
        let key = [0u8; 16];
        let ciphertext = "dGVzdA==";
        let result = decrypt_content(ciphertext, &key);
        assert!(result.is_err());
    }

    #[test]
    fn test_decrypt_invalid_ciphertext() {
        let key = [0u8; 32];
        let ciphertext = "!!!invalid!!!";
        let result = decrypt_content(ciphertext, &key);
        assert!(result.is_err());
    }

    #[test]
    fn test_decrypt_too_short_for_nonce() {
        // "AAAA" decodes to 3 bytes, fewer than the 12-byte nonce.
        assert!(decrypt_content("AAAA", &[0u8; 32]).is_err());
    }

    #[test]
    fn test_decrypt_rejects_tampered_ciphertext() {
        let key = [0u8; 32];
        let ciphertext = encrypt_content("sensitive", &key).unwrap();
        let mut raw = base64::engine::general_purpose::STANDARD
            .decode(&ciphertext)
            .unwrap();
        let last = raw.len() - 1;
        raw[last] ^= 0x01;
        let tampered = base64::engine::general_purpose::STANDARD.encode(&raw);
        assert!(decrypt_content(&tampered, &key).is_err());
    }

    #[test]
    fn test_decrypt_rejects_wrong_key() {
        let ciphertext = encrypt_content("sensitive", &[0u8; 32]).unwrap();
        assert!(decrypt_content(&ciphertext, &[1u8; 32]).is_err());
    }
}