use aes_gcm::{
aead::{Aead, AeadCore, KeyInit, OsRng},
Aes256Gcm, Nonce,
};
use anyhow::{bail, Context, Result};
use async_trait::async_trait;
use base64::{engine::general_purpose::STANDARD as BASE64, Engine};
use crate::server::encryption::EncryptionProvider;
/// Length in bytes of an AES-256 key.
const KEY_LEN: usize = 32;

/// [`EncryptionProvider`] backed by an AES-256-GCM key held in process memory.
///
/// Each ciphertext produced by this provider carries its own randomly
/// generated nonce, so no per-message state is stored on the provider.
pub struct LocalEncryptionProvider {
    /// AES-256-GCM cipher keyed once at construction time.
    cipher: Aes256Gcm,
}

impl LocalEncryptionProvider {
    /// Builds a provider from a 256-bit key encoded as standard (padded) base64.
    ///
    /// Leading and trailing whitespace around `key_base64` is ignored. This
    /// covers the trailing newline that is typically present when the key is
    /// read from a file or an environment variable. The strict base64 decoder
    /// would otherwise reject it.
    ///
    /// # Errors
    ///
    /// Returns an error if `key_base64` is not valid standard base64, or if it
    /// does not decode to exactly 32 bytes.
    pub fn new(key_base64: &str) -> Result<Self> {
        let key_bytes = BASE64
            .decode(key_base64.trim())
            .context("Failed to decode encryption key from base64")?;

        if key_bytes.len() != KEY_LEN {
            bail!(
                "Encryption key must be 32 bytes (256 bits) for AES-256-GCM, got {} bytes",
                key_bytes.len()
            );
        }

        // The length was validated above, so this cannot fail in practice.
        // Keep the context anyway in case the cipher's key size ever changes.
        let cipher =
            Aes256Gcm::new_from_slice(&key_bytes).context("Failed to create AES-256-GCM cipher")?;

        Ok(Self { cipher })
    }
}
/// Length in bytes of the AES-GCM nonce that is prepended to every ciphertext.
const NONCE_LEN: usize = 12;
/// Length in bytes of the GCM authentication tag that `Aead::encrypt` appends.
const TAG_LEN: usize = 16;

#[async_trait]
impl EncryptionProvider for LocalEncryptionProvider {
    /// Encrypts `plaintext` under a freshly generated random nonce.
    ///
    /// The result is standard base64 of `nonce || ciphertext-with-tag`, so it
    /// can be passed directly to [`decrypt`](Self::decrypt).
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying AEAD encryption fails.
    async fn encrypt(&self, plaintext: &str) -> Result<String> {
        // GCM requires a unique nonce per message under a given key. A reused
        // nonce breaks both confidentiality and authenticity, so draw a fresh
        // one from the OS RNG every time.
        let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
        debug_assert_eq!(nonce.len(), NONCE_LEN);

        let ciphertext = self
            .cipher
            .encrypt(&nonce, plaintext.as_bytes())
            .map_err(|e| anyhow::anyhow!("Encryption failed: {}", e))?;

        let mut combined = Vec::with_capacity(NONCE_LEN + ciphertext.len());
        combined.extend_from_slice(&nonce);
        combined.extend_from_slice(&ciphertext);
        Ok(BASE64.encode(&combined))
    }

    /// Decrypts a value produced by [`encrypt`](Self::encrypt).
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - the input is not valid standard base64;
    /// - the decoded input is too short to hold a nonce and a tag;
    /// - authentication fails (wrong key or tampered data);
    /// - the decrypted bytes are not valid UTF-8.
    async fn decrypt(&self, ciphertext_base64: &str) -> Result<String> {
        let combined = BASE64
            .decode(ciphertext_base64)
            .context("Failed to decode ciphertext from base64")?;

        // A valid payload always contains the nonce plus the authentication
        // tag, even for an empty plaintext. Reject anything shorter up front
        // instead of letting it surface as an opaque AEAD failure.
        if combined.len() < NONCE_LEN + TAG_LEN {
            bail!("Invalid ciphertext: too short");
        }

        let (nonce_bytes, ciphertext) = combined.split_at(NONCE_LEN);
        let nonce = Nonce::from_slice(nonce_bytes);

        let plaintext_bytes = self
            .cipher
            .decrypt(nonce, ciphertext)
            .map_err(|e| anyhow::anyhow!("Decryption failed: {}", e))?;

        let plaintext =
            String::from_utf8(plaintext_bytes).context("Decrypted data is not valid UTF-8")?;
        Ok(plaintext)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::RngCore;

    /// Returns a fresh random 256-bit key, base64-encoded as `new` expects.
    fn random_key_base64() -> String {
        let mut key = [0u8; 32];
        rand::thread_rng().fill_bytes(&mut key);
        BASE64.encode(key)
    }

    /// Builds a provider keyed with a fresh random key.
    fn new_provider() -> LocalEncryptionProvider {
        LocalEncryptionProvider::new(&random_key_base64())
            .expect("a random 32-byte key is always a valid AES-256 key")
    }

    #[tokio::test]
    async fn test_encrypt_decrypt_roundtrip() {
        let provider = new_provider();
        let plaintext = "my secret password";
        let ciphertext = provider.encrypt(plaintext).await.unwrap();
        let decrypted = provider.decrypt(&ciphertext).await.unwrap();
        assert_eq!(plaintext, decrypted);
    }

    #[tokio::test]
    async fn test_empty_plaintext_roundtrip() {
        let provider = new_provider();
        let ciphertext = provider.encrypt("").await.unwrap();
        assert_eq!(provider.decrypt(&ciphertext).await.unwrap(), "");
    }

    #[tokio::test]
    async fn test_different_nonces() {
        let provider = new_provider();
        let plaintext = "same message";
        let ciphertext1 = provider.encrypt(plaintext).await.unwrap();
        let ciphertext2 = provider.encrypt(plaintext).await.unwrap();
        assert_ne!(ciphertext1, ciphertext2);
        assert_eq!(provider.decrypt(&ciphertext1).await.unwrap(), plaintext);
        assert_eq!(provider.decrypt(&ciphertext2).await.unwrap(), plaintext);
    }

    #[test]
    fn test_invalid_key_length() {
        let short_key = BASE64.encode(b"tooshort");
        let result = LocalEncryptionProvider::new(&short_key);
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_tampered_ciphertext_rejected() {
        let provider = new_provider();
        let ciphertext = provider.encrypt("integrity matters").await.unwrap();
        let mut raw = BASE64.decode(&ciphertext).unwrap();
        // Flip one bit in the authentication tag; GCM must reject it.
        let last = raw.len() - 1;
        raw[last] ^= 0x01;
        assert!(provider.decrypt(&BASE64.encode(&raw)).await.is_err());
    }

    #[tokio::test]
    async fn test_wrong_key_rejected() {
        let ciphertext = new_provider().encrypt("secret").await.unwrap();
        assert!(new_provider().decrypt(&ciphertext).await.is_err());
    }

    #[tokio::test]
    async fn test_truncated_ciphertext_rejected() {
        let provider = new_provider();
        assert!(provider.decrypt(&BASE64.encode([0u8; 5])).await.is_err());
        assert!(provider.decrypt(&BASE64.encode([0u8; 12])).await.is_err());
        assert!(provider.decrypt("").await.is_err());
    }

    #[tokio::test]
    async fn test_invalid_base64_rejected() {
        assert!(new_provider().decrypt("not base64!!").await.is_err());
    }
}