use chacha20poly1305::{
ChaCha20Poly1305, Nonce,
aead::{Aead, KeyInit, Payload},
};
use rand::RngCore;
use crate::auth::{AuthError, error::Result};
/// An encrypted state value together with the nonce it was sealed under.
///
/// Both fields are non-secret: the nonce is random but public, and the ciphertext is
/// authenticated output of the AEAD. Equality is therefore a plain byte comparison.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EncryptedState {
    /// AEAD output: encrypted bytes followed by the authentication tag.
    pub ciphertext: Vec<u8>,
    /// 96-bit nonce used for this particular encryption.
    pub nonce: [u8; 12],
}
impl EncryptedState {
    /// Length in bytes of the nonce that prefixes the serialized form.
    const NONCE_LEN: usize = 12;
    /// Length in bytes of the Poly1305 tag that ends every ChaCha20-Poly1305 ciphertext.
    const TAG_LEN: usize = 16;

    /// Wraps an already-encrypted payload and the nonce it was sealed with.
    pub fn new(ciphertext: Vec<u8>, nonce: [u8; 12]) -> Self {
        Self { ciphertext, nonce }
    }

    /// Serializes the state as `nonce || ciphertext`.
    ///
    /// This is the wire format accepted by [`EncryptedState::from_bytes`].
    pub fn to_bytes(&self) -> Vec<u8> {
        let mut bytes = Vec::with_capacity(Self::NONCE_LEN + self.ciphertext.len());
        bytes.extend_from_slice(&self.nonce);
        bytes.extend_from_slice(&self.ciphertext);
        bytes
    }

    /// Parses the `nonce || ciphertext` layout produced by [`EncryptedState::to_bytes`].
    ///
    /// Only the structure is checked here; authenticity is verified later on decryption.
    ///
    /// # Errors
    ///
    /// Returns [`AuthError::InvalidState`] when `bytes` is shorter than a nonce plus an
    /// authentication tag (28 bytes). Such input can never decrypt successfully, so it is
    /// rejected at the parsing boundary rather than carried forward.
    pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
        if bytes.len() < Self::NONCE_LEN + Self::TAG_LEN {
            return Err(AuthError::InvalidState);
        }
        let (nonce_bytes, ciphertext) = bytes.split_at(Self::NONCE_LEN);
        let mut nonce = [0u8; 12];
        // Infallible: `split_at` above guarantees `nonce_bytes` is exactly NONCE_LEN long.
        nonce.copy_from_slice(nonce_bytes);
        Ok(Self::new(ciphertext.to_vec(), nonce))
    }
}
/// Seals and opens state strings with ChaCha20-Poly1305 under a single 256-bit key.
///
/// Each call to [`StateEncryption::encrypt`] draws a fresh random 96-bit nonce from the
/// operating-system RNG, so encrypting the same string twice yields different output.
pub struct StateEncryption {
    /// AEAD instance keyed once at construction and reused for every operation.
    cipher: ChaCha20Poly1305,
}

impl StateEncryption {
    /// Builds an encryptor from a raw 32-byte key.
    ///
    /// # Errors
    ///
    /// Returns [`AuthError::ConfigError`] if the cipher rejects the key. The argument is a
    /// fixed `[u8; 32]`, which matches the ChaCha20-Poly1305 key size, so this branch is
    /// not expected to trigger in practice.
    pub fn new(key_bytes: &[u8; 32]) -> Result<Self> {
        match ChaCha20Poly1305::new_from_slice(key_bytes) {
            Ok(cipher) => Ok(Self { cipher }),
            Err(_) => Err(AuthError::ConfigError {
                message: "Invalid state encryption key".to_string(),
            }),
        }
    }

    /// Encrypts `state` under a freshly generated random nonce.
    ///
    /// # Errors
    ///
    /// Returns [`AuthError::Internal`] if the underlying AEAD encryption reports failure.
    pub fn encrypt(&self, state: &str) -> Result<EncryptedState> {
        let mut nonce = [0u8; 12];
        rand::rngs::OsRng.fill_bytes(&mut nonce);
        self.cipher
            .encrypt(&Nonce::from(nonce), Payload::from(state.as_bytes()))
            .map(|ciphertext| EncryptedState::new(ciphertext, nonce))
            .map_err(|_| AuthError::Internal {
                message: "State encryption failed".to_string(),
            })
    }

    /// Authenticates and decrypts `encrypted`, returning the original string.
    ///
    /// # Errors
    ///
    /// Returns [`AuthError::InvalidState`] on authentication failure (wrong key, tampered
    /// nonce or ciphertext) and also when the recovered bytes are not valid UTF-8.
    pub fn decrypt(&self, encrypted: &EncryptedState) -> Result<String> {
        let nonce = Nonce::from(encrypted.nonce);
        let payload = Payload::from(encrypted.ciphertext.as_slice());
        self.cipher
            .decrypt(&nonce, payload)
            .ok()
            .and_then(|plaintext| String::from_utf8(plaintext).ok())
            .ok_or(AuthError::InvalidState)
    }

    /// Encrypts `state` and serializes the result with [`EncryptedState::to_bytes`].
    ///
    /// # Errors
    ///
    /// Propagates any error from [`StateEncryption::encrypt`].
    pub fn encrypt_to_bytes(&self, state: &str) -> Result<Vec<u8>> {
        self.encrypt(state).map(|sealed| sealed.to_bytes())
    }

    /// Parses `bytes` with [`EncryptedState::from_bytes`] and decrypts the result.
    ///
    /// # Errors
    ///
    /// Propagates any error from parsing or from [`StateEncryption::decrypt`].
    pub fn decrypt_from_bytes(&self, bytes: &[u8]) -> Result<String> {
        EncryptedState::from_bytes(bytes).and_then(|parsed| self.decrypt(&parsed))
    }
}
/// Produces a new random 32-byte key suitable for [`StateEncryption::new`].
///
/// The bytes come from `rand::rngs::OsRng`. NOTE(review): `RngCore::fill_bytes` on `OsRng`
/// panics rather than returning an error if the OS source fails — confirm against the
/// `rand` version in use.
pub fn generate_state_encryption_key() -> [u8; 32] {
    let mut fresh = [0u8; 32];
    rand::rngs::OsRng.fill_bytes(&mut fresh[..]);
    fresh
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed key so failures are reproducible; never use a constant key outside tests.
    fn test_key() -> [u8; 32] {
        [42u8; 32]
    }

    /// Builds an encryptor keyed with [`test_key`].
    fn encryptor() -> StateEncryption {
        StateEncryption::new(&test_key()).expect("Init failed")
    }

    /// Encrypts then decrypts `state` with the test key and returns the recovered text.
    fn round_trip(state: &str) -> String {
        let encryption = encryptor();
        let encrypted = encryption.encrypt(state).expect("Encryption failed");
        encryption.decrypt(&encrypted).expect("Decryption failed")
    }

    #[test]
    fn test_encrypt_decrypt() {
        let state = "oauth_state_test_value";
        assert_eq!(round_trip(state), state);
    }

    #[test]
    fn test_encrypt_produces_ciphertext() {
        let state = "test_state";
        let encrypted = encryptor().encrypt(state).expect("Encryption failed");
        assert_ne!(encrypted.ciphertext, state.as_bytes());
        // ChaCha20-Poly1305 output is the plaintext length plus a 16-byte tag.
        assert_eq!(encrypted.ciphertext.len(), state.len() + 16);
    }

    #[test]
    fn test_empty_state() {
        assert_eq!(round_trip(""), "");
    }

    #[test]
    fn test_different_keys_fail_decryption() {
        let key1 = [42u8; 32];
        let key2 = [99u8; 32];
        let state = "secret_state";
        let encryption1 = StateEncryption::new(&key1).expect("Init 1 failed");
        let encrypted = encryption1.encrypt(state).expect("Encryption failed");
        let encryption2 = StateEncryption::new(&key2).expect("Init 2 failed");
        assert!(encryption2.decrypt(&encrypted).is_err());
    }

    #[test]
    fn test_tampered_ciphertext_fails() {
        let encryption = encryptor();
        let mut encrypted = encryption.encrypt("tamper_test").expect("Encryption failed");
        // Assert rather than skip: an empty ciphertext would make this test vacuous.
        assert!(!encrypted.ciphertext.is_empty());
        encrypted.ciphertext[0] ^= 0xFF;
        assert!(encryption.decrypt(&encrypted).is_err());
    }

    #[test]
    fn test_tampered_tag_fails() {
        let encryption = encryptor();
        let mut encrypted = encryption.encrypt("tag_tamper").expect("Encryption failed");
        let last = encrypted.ciphertext.last_mut().expect("ciphertext carries a tag");
        *last ^= 0x01;
        assert!(encryption.decrypt(&encrypted).is_err());
    }

    #[test]
    fn test_tampered_nonce_fails() {
        let encryption = encryptor();
        let mut encrypted = encryption.encrypt("nonce_tamper").expect("Encryption failed");
        encrypted.nonce[0] ^= 0xFF;
        assert!(encryption.decrypt(&encrypted).is_err());
    }

    #[test]
    fn test_truncated_ciphertext_fails() {
        let encryption = encryptor();
        let mut encrypted = encryption.encrypt("truncation_test").expect("Encryption failed");
        assert!(encrypted.ciphertext.pop().is_some(), "ciphertext unexpectedly empty");
        assert!(encryption.decrypt(&encrypted).is_err());
    }

    #[test]
    fn test_serialization() {
        let encryption = encryptor();
        let state = "serialization_test";
        let bytes = encryption.encrypt_to_bytes(state).expect("Encryption failed");
        let decrypted = encryption.decrypt_from_bytes(&bytes).expect("Decryption failed");
        assert_eq!(decrypted, state);
    }

    #[test]
    fn test_to_bytes_layout_and_parse_round_trip() {
        let encrypted = encryptor().encrypt("layout_test").expect("Encryption failed");
        let bytes = encrypted.to_bytes();
        assert_eq!(bytes.len(), 12 + encrypted.ciphertext.len());
        assert_eq!(&bytes[..12], &encrypted.nonce[..]);
        assert_eq!(&bytes[12..], encrypted.ciphertext.as_slice());
        let parsed = EncryptedState::from_bytes(&bytes).expect("Parse failed");
        assert_eq!(parsed.nonce, encrypted.nonce);
        assert_eq!(parsed.ciphertext, encrypted.ciphertext);
    }

    #[test]
    fn test_from_bytes_rejects_short_input() {
        assert!(matches!(EncryptedState::from_bytes(&[]), Err(AuthError::InvalidState)));
        assert!(matches!(
            EncryptedState::from_bytes(&[0u8; 11]),
            Err(AuthError::InvalidState)
        ));
    }

    #[test]
    fn test_tampered_serialized_bytes_fail() {
        let encryption = encryptor();
        let mut bytes = encryption.encrypt_to_bytes("wire_tamper").expect("Encryption failed");
        let last = bytes.last_mut().expect("serialized state is non-empty");
        *last ^= 0x01;
        assert!(encryption.decrypt_from_bytes(&bytes).is_err());
    }

    #[test]
    fn test_random_nonces() {
        let encryption = encryptor();
        let state = "random_nonce_test";
        let encrypted1 = encryption.encrypt(state).expect("Encryption 1 failed");
        let encrypted2 = encryption.encrypt(state).expect("Encryption 2 failed");
        assert_ne!(encrypted1.nonce, encrypted2.nonce);
        let decrypted1 = encryption.decrypt(&encrypted1).expect("Decryption 1 failed");
        let decrypted2 = encryption.decrypt(&encrypted2).expect("Decryption 2 failed");
        assert_eq!(decrypted1, state);
        assert_eq!(decrypted2, state);
    }

    #[test]
    fn test_long_state() {
        let state = "a".repeat(10_000);
        assert_eq!(round_trip(&state), state);
    }

    #[test]
    fn test_special_characters() {
        let state = "state:with-special_chars.and/symbols!@#$%^&*()";
        assert_eq!(round_trip(state), state);
    }

    #[test]
    fn test_unicode_state() {
        let state = "state_with_emoji_🔐_🔒_🔓_and_emoji";
        assert_eq!(round_trip(state), state);
    }

    #[test]
    fn test_null_bytes_in_state() {
        let state = "state_with\x00null\x00bytes\x00";
        assert_eq!(round_trip(state), state);
    }

    #[test]
    fn test_key_generation() {
        let key1 = generate_state_encryption_key();
        let key2 = generate_state_encryption_key();
        assert_ne!(key1, key2);
        let enc1 = StateEncryption::new(&key1).expect("Init 1 failed");
        let enc2 = StateEncryption::new(&key2).expect("Init 2 failed");
        let state = "test";
        let encrypted1 = enc1.encrypt(state).expect("Encryption 1 failed");
        let encrypted2 = enc2.encrypt(state).expect("Encryption 2 failed");
        assert_eq!(enc1.decrypt(&encrypted1).expect("Decryption 1 failed"), state);
        assert_eq!(enc2.decrypt(&encrypted2).expect("Decryption 2 failed"), state);
        // Keys are independent: neither can open the other's ciphertext.
        assert!(enc1.decrypt(&encrypted2).is_err());
        assert!(enc2.decrypt(&encrypted1).is_err());
    }

    #[test]
    fn test_large_ciphertext() {
        let state = "x".repeat(100_000);
        assert_eq!(round_trip(&state), state);
    }
}