use anyhow::{anyhow, Result};
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use base64::Engine;
use chacha20poly1305::aead::{Aead, KeyInit};
use chacha20poly1305::{XChaCha20Poly1305, XNonce};
use rand::RngCore;
use zeroize::{Zeroize, ZeroizeOnDrop};
/// Length in bytes of a [`TunnelKey`] (a 256-bit XChaCha20-Poly1305 key).
pub const KEY_SIZE: usize = 32;

/// Length in bytes of the random XChaCha20 nonce prefixed to every ciphertext
/// produced by [`TunnelCrypto::encrypt`].
pub const NONCE_SIZE: usize = 24;
/// A symmetric key of [`KEY_SIZE`] bytes used to construct a [`TunnelCrypto`].
///
/// The key bytes are wiped when the value is dropped (via the derived
/// `ZeroizeOnDrop`).
#[derive(Clone, Zeroize, ZeroizeOnDrop)]
pub struct TunnelKey {
    /// Raw secret key material.
    key: [u8; KEY_SIZE],
}
impl TunnelKey {
pub fn generate() -> Self {
let mut key = [0u8; KEY_SIZE];
rand::rngs::OsRng.fill_bytes(&mut key);
Self { key }
}
pub fn from_bytes(bytes: [u8; KEY_SIZE]) -> Self {
Self { key: bytes }
}
pub fn from_base64(encoded: &str) -> Result<Self> {
let bytes = URL_SAFE_NO_PAD
.decode(encoded)
.map_err(|e| anyhow!("Invalid base64: {}", e))?;
if bytes.len() != KEY_SIZE {
return Err(anyhow!(
"Invalid key length: expected {}, got {}",
KEY_SIZE,
bytes.len()
));
}
let mut key = [0u8; KEY_SIZE];
key.copy_from_slice(&bytes);
Ok(Self { key })
}
pub fn to_base64(&self) -> String {
URL_SAFE_NO_PAD.encode(self.key)
}
pub fn as_bytes(&self) -> &[u8; KEY_SIZE] {
&self.key
}
}
impl std::fmt::Debug for TunnelKey {
    /// Renders the key with its bytes replaced by a placeholder, so the
    /// secret never ends up in logs or panic messages.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("TunnelKey");
        builder.field("key", &"[REDACTED]");
        builder.finish()
    }
}
/// Authenticated encryption of tunnel payloads with XChaCha20-Poly1305.
///
/// Construct one from a [`TunnelKey`] with [`TunnelCrypto::new`].
pub struct TunnelCrypto {
    /// AEAD cipher keyed with the tunnel key.
    cipher: XChaCha20Poly1305,
}
impl TunnelCrypto {
    /// Length in bytes of the Poly1305 authentication tag the cipher appends.
    const TAG_SIZE: usize = 16;

    /// Smallest well-formed message: a nonce plus the tag over an empty plaintext.
    const MIN_MESSAGE_LEN: usize = NONCE_SIZE + Self::TAG_SIZE;

    /// Creates a cipher keyed with `key`.
    pub fn new(key: &TunnelKey) -> Self {
        Self {
            cipher: XChaCha20Poly1305::new(key.as_bytes().into()),
        }
    }

    /// Encrypts `plaintext` under a freshly generated random nonce.
    ///
    /// Output layout is `nonce (NONCE_SIZE bytes) || ciphertext || tag
    /// (16 bytes)`, which is exactly what [`TunnelCrypto::decrypt`] consumes.
    ///
    /// # Errors
    ///
    /// Returns an error if the OS random number generator fails, or if the
    /// underlying AEAD encryption fails.
    pub fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>> {
        let nonce = Self::generate_nonce()?;
        let ciphertext = self
            .cipher
            .encrypt(XNonce::from_slice(&nonce), plaintext)
            .map_err(|_| anyhow!("Encryption failed"))?;
        let mut result = Vec::with_capacity(NONCE_SIZE + ciphertext.len());
        result.extend_from_slice(&nonce);
        result.extend_from_slice(&ciphertext);
        Ok(result)
    }

    /// Decrypts a message produced by [`TunnelCrypto::encrypt`].
    ///
    /// # Errors
    ///
    /// - Returns an error if `data` is shorter than a nonce plus a tag.
    /// - Returns an error if authentication fails: wrong key, or tampered or
    ///   truncated data. All such failures share one message, deliberately,
    ///   so the cause is not revealed.
    pub fn decrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
        if data.len() < Self::MIN_MESSAGE_LEN {
            return Err(anyhow!(
                "Ciphertext too short: need at least {} bytes, got {}",
                Self::MIN_MESSAGE_LEN,
                data.len()
            ));
        }
        let (nonce, ciphertext) = data.split_at(NONCE_SIZE);
        self.cipher
            .decrypt(XNonce::from_slice(nonce), ciphertext)
            .map_err(|_| anyhow!("Decryption failed: authentication tag mismatch"))
    }

    /// Draws a random nonce from the OS RNG.
    ///
    /// Uses the fallible `try_fill_bytes`, so an RNG failure is reported
    /// through `encrypt`'s `Result` instead of panicking.
    fn generate_nonce() -> Result<[u8; NONCE_SIZE]> {
        let mut nonce = [0u8; NONCE_SIZE];
        rand::rngs::OsRng
            .try_fill_bytes(&mut nonce)
            .map_err(|e| anyhow!("Failed to generate nonce: {}", e))?;
        Ok(nonce)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_key_generation() {
        let key1 = TunnelKey::generate();
        let key2 = TunnelKey::generate();
        assert_ne!(key1.as_bytes(), key2.as_bytes());
    }

    #[test]
    fn test_key_from_bytes_roundtrip() {
        let bytes = [7u8; KEY_SIZE];
        let key = TunnelKey::from_bytes(bytes);
        assert_eq!(key.as_bytes(), &bytes);
    }

    #[test]
    fn test_key_base64_roundtrip() {
        let key = TunnelKey::generate();
        let encoded = key.to_base64();
        let decoded = TunnelKey::from_base64(&encoded).unwrap();
        assert_eq!(key.as_bytes(), decoded.as_bytes());
    }

    #[test]
    fn test_key_base64_length() {
        let key = TunnelKey::generate();
        let encoded = key.to_base64();
        assert_eq!(encoded.len(), 43);
    }

    #[test]
    fn test_from_base64_rejects_invalid_characters() {
        let err = TunnelKey::from_base64("not valid base64!").unwrap_err();
        assert!(err.to_string().contains("Invalid base64"));
    }

    #[test]
    fn test_from_base64_rejects_wrong_length() {
        let short = URL_SAFE_NO_PAD.encode([0u8; KEY_SIZE - 1]);
        let err = TunnelKey::from_base64(&short).unwrap_err();
        assert!(err.to_string().contains("Invalid key length"));

        let long = URL_SAFE_NO_PAD.encode([0u8; KEY_SIZE + 1]);
        assert!(TunnelKey::from_base64(&long).is_err());
    }

    #[test]
    fn test_from_base64_rejects_empty_input() {
        assert!(TunnelKey::from_base64("").is_err());
    }

    #[test]
    fn test_debug_redacts_key() {
        let key = TunnelKey::from_bytes([0x42; KEY_SIZE]);
        let rendered = format!("{:?}", key);
        assert!(rendered.contains("[REDACTED]"));
        assert!(!rendered.contains(&key.to_base64()));
        // 0x42 == 66; the raw byte values must not appear in the output.
        assert!(!rendered.contains("66"));
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let key = TunnelKey::generate();
        let crypto = TunnelCrypto::new(&key);
        let plaintext = b"Hello, World! This is a test message.";
        let ciphertext = crypto.encrypt(plaintext).unwrap();
        let decrypted = crypto.decrypt(&ciphertext).unwrap();
        assert_eq!(plaintext.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_ciphertext_length() {
        let crypto = TunnelCrypto::new(&TunnelKey::generate());
        let plaintext = b"length check";
        let ciphertext = crypto.encrypt(plaintext).unwrap();
        // nonce || ciphertext || 16-byte Poly1305 tag
        assert_eq!(ciphertext.len(), NONCE_SIZE + plaintext.len() + 16);
    }

    #[test]
    fn test_encrypt_produces_different_ciphertext() {
        let key = TunnelKey::generate();
        let crypto = TunnelCrypto::new(&key);
        let plaintext = b"Same message";
        let ciphertext1 = crypto.encrypt(plaintext).unwrap();
        let ciphertext2 = crypto.encrypt(plaintext).unwrap();
        assert_ne!(ciphertext1, ciphertext2);
    }

    #[test]
    fn test_cloned_key_decrypts() {
        let key = TunnelKey::generate();
        let cloned = key.clone();
        let ciphertext = TunnelCrypto::new(&key).encrypt(b"clone").unwrap();
        let decrypted = TunnelCrypto::new(&cloned).decrypt(&ciphertext).unwrap();
        assert_eq!(decrypted.as_slice(), b"clone");
    }

    #[test]
    fn test_decrypt_wrong_key_fails() {
        let key1 = TunnelKey::generate();
        let key2 = TunnelKey::generate();
        let crypto1 = TunnelCrypto::new(&key1);
        let crypto2 = TunnelCrypto::new(&key2);
        let plaintext = b"Secret message";
        let ciphertext = crypto1.encrypt(plaintext).unwrap();
        let result = crypto2.decrypt(&ciphertext);
        assert!(result.is_err());
    }

    #[test]
    fn test_decrypt_tampered_data_fails() {
        let key = TunnelKey::generate();
        let crypto = TunnelCrypto::new(&key);
        let plaintext = b"Original message";
        let mut ciphertext = crypto.encrypt(plaintext).unwrap();
        if let Some(byte) = ciphertext.get_mut(NONCE_SIZE + 5) {
            *byte ^= 0xFF;
        }
        let result = crypto.decrypt(&ciphertext);
        assert!(result.is_err());
    }

    #[test]
    fn test_decrypt_tampered_nonce_fails() {
        let crypto = TunnelCrypto::new(&TunnelKey::generate());
        let mut ciphertext = crypto.encrypt(b"nonce tamper").unwrap();
        ciphertext[0] ^= 0x01;
        assert!(crypto.decrypt(&ciphertext).is_err());
    }

    #[test]
    fn test_decrypt_truncated_fails() {
        let crypto = TunnelCrypto::new(&TunnelKey::generate());
        let mut ciphertext = crypto.encrypt(b"truncate me").unwrap();
        ciphertext.pop();
        assert!(crypto.decrypt(&ciphertext).is_err());
    }

    #[test]
    fn test_decrypt_too_short_fails() {
        let key = TunnelKey::generate();
        let crypto = TunnelCrypto::new(&key);
        let short_data = vec![0u8; 30];
        let result = crypto.decrypt(&short_data);
        assert!(result.is_err());
    }

    #[test]
    fn test_decrypt_minimum_length_boundary() {
        let crypto = TunnelCrypto::new(&TunnelKey::generate());

        // One byte below nonce + tag: rejected by the length check.
        let err = crypto.decrypt(&[0u8; NONCE_SIZE + 15]).unwrap_err();
        assert!(err.to_string().contains("too short"));

        // Exactly nonce + tag: passes the length check, fails authentication.
        let err = crypto.decrypt(&[0u8; NONCE_SIZE + 16]).unwrap_err();
        assert!(err.to_string().contains("authentication"));
    }

    #[test]
    fn test_empty_plaintext() {
        let key = TunnelKey::generate();
        let crypto = TunnelCrypto::new(&key);
        let plaintext = b"";
        let ciphertext = crypto.encrypt(plaintext).unwrap();
        let decrypted = crypto.decrypt(&ciphertext).unwrap();
        assert_eq!(plaintext.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_large_plaintext() {
        let key = TunnelKey::generate();
        let crypto = TunnelCrypto::new(&key);
        let plaintext = vec![0xAB; 1024 * 1024];
        let ciphertext = crypto.encrypt(&plaintext).unwrap();
        let decrypted = crypto.decrypt(&ciphertext).unwrap();
        assert_eq!(plaintext, decrypted);
    }
}