use aes_gcm::aead::{Aead, KeyInit};
use aes_gcm::{Aes256Gcm, Nonce};
use base64::Engine as _;
use hmac::{Hmac, Mac};
use rand::rngs::OsRng;
use rand::RngCore;
use sha2::{Digest, Sha256};
use subtle::ConstantTimeEq;
use crate::errors::AppError;
/// Context label mixed into the SHA-256 key derivation in `hash_refresh_token`.
/// Changing it changes every computed hash, so previously stored hashes would
/// no longer verify.
const REFRESH_TOKEN_HMAC_INFO: &[u8] = b"cedros:refresh_token_hash:v1";
/// Context label that `TokenCipher::new` authenticates with HMAC-SHA256 (keyed
/// by the secret) to obtain the current encryption key. Changing it changes the
/// derived key, so existing ciphertexts would only decrypt via the legacy key.
const TOKEN_CIPHER_KEY_INFO: &[u8] = b"cedros:token_cipher_aes256:v1";
pub fn hash_refresh_token(token: &str, secret: &str) -> String {
let mut key_hasher = Sha256::new();
key_hasher.update(secret.as_bytes());
key_hasher.update(REFRESH_TOKEN_HMAC_INFO);
let hmac_key = key_hasher.finalize();
let mut mac: Hmac<Sha256> =
Mac::new_from_slice(&hmac_key).expect("HMAC can take key of any size");
mac.update(token.as_bytes());
let result = mac.finalize();
hex::encode(result.into_bytes())
}
/// Checks a presented refresh token against a previously stored hash.
///
/// The hash of `token` is recomputed with `hash_refresh_token` and compared to
/// `stored_hash` byte-by-byte using a constant-time equality check, so the
/// comparison does not stop early at the first differing byte.
///
/// # Returns
///
/// `true` when the recomputed hash equals `stored_hash`, `false` otherwise
/// (including when `stored_hash` is not a valid hash at all).
pub fn verify_refresh_token(token: &str, secret: &str, stored_hash: &str) -> bool {
    let expected = hash_refresh_token(token, secret);
    let is_match = expected.as_bytes().ct_eq(stored_hash.as_bytes());
    bool::from(is_match)
}
/// Symmetric cipher for tokens, holding two 32-byte keys derived from one secret.
///
/// The `Zeroize` / `ZeroizeOnDrop` derives are there so the key bytes get wiped
/// when a value is dropped. `Debug` is implemented by hand below so the keys are
/// never printed.
#[derive(Clone, zeroize::Zeroize, zeroize::ZeroizeOnDrop)]
pub struct TokenCipher {
/// Current key: used by `encrypt` and tried first by `decrypt`.
key: [u8; 32],
/// Older key (plain SHA-256 of the secret): only used by `decrypt` as a
/// fallback when the current key fails.
legacy_key: [u8; 32],
}
/// Debug output that names the fields but never reveals key material.
impl std::fmt::Debug for TokenCipher {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same placeholder for both keys; the real bytes are never formatted.
        const PLACEHOLDER: &str = "[REDACTED]";

        let mut builder = f.debug_struct("TokenCipher");
        builder.field("key", &PLACEHOLDER);
        builder.field("legacy_key", &PLACEHOLDER);
        builder.finish()
    }
}
impl TokenCipher {
    /// Builds a cipher from a shared secret.
    ///
    /// Two keys are derived:
    /// * the current key, `HMAC-SHA256(key = secret, msg = TOKEN_CIPHER_KEY_INFO)`;
    /// * the legacy key, `SHA-256(secret)`, kept so that older ciphertexts can
    ///   still be decrypted.
    pub fn new(secret: &str) -> Self {
        let secret_bytes = secret.as_bytes();

        let mut kdf: Hmac<Sha256> =
            Mac::new_from_slice(secret_bytes).expect("HMAC can take key of any size");
        kdf.update(TOKEN_CIPHER_KEY_INFO);
        let mut key = [0u8; 32];
        key.copy_from_slice(&kdf.finalize().into_bytes());

        let mut legacy_key = [0u8; 32];
        legacy_key.copy_from_slice(&Sha256::digest(secret_bytes));

        Self { key, legacy_key }
    }

    /// Encrypts `token` with AES-256-GCM under the current key.
    ///
    /// A fresh 12-byte nonce is drawn from the OS random source for every call.
    ///
    /// # Returns
    ///
    /// Standard base64 of `nonce || ciphertext`.
    ///
    /// # Errors
    ///
    /// `AppError::Internal` if the AEAD reports an encryption failure.
    pub fn encrypt(&self, token: &str) -> Result<String, AppError> {
        let mut nonce_bytes = [0u8; 12];
        OsRng.fill_bytes(&mut nonce_bytes);
        // NOTE(review): `from_slice` appears to be flagged as deprecated by the
        // array crate in use; the lint is silenced rather than changing the API.
        #[allow(deprecated)]
        let nonce = Nonce::from_slice(&nonce_bytes);

        let aead = Aes256Gcm::new_from_slice(&self.key)
            .expect("TokenCipher key length is fixed at 32 bytes");
        let sealed = aead
            .encrypt(nonce, token.as_bytes())
            .map_err(|_| AppError::Internal(anyhow::anyhow!("Failed to encrypt token")))?;

        // The nonce travels in front of the ciphertext so `decrypt` can recover it.
        let payload: Vec<u8> = [nonce_bytes.as_slice(), sealed.as_slice()].concat();
        Ok(base64::engine::general_purpose::STANDARD.encode(payload))
    }

    /// Decrypts a value produced by [`TokenCipher::encrypt`].
    ///
    /// The payload is base64-decoded, split into a 12-byte nonce and the
    /// ciphertext, and decrypted with the current key. If that fails, the
    /// legacy key is tried with the same nonce and ciphertext.
    ///
    /// # Errors
    ///
    /// `AppError::Internal` when the input is not valid base64, is shorter than
    /// the nonce, cannot be decrypted with either key, or decrypts to bytes
    /// that are not valid UTF-8.
    pub fn decrypt(&self, encoded: &str) -> Result<String, AppError> {
        let payload = base64::engine::general_purpose::STANDARD
            .decode(encoded)
            .map_err(|_| AppError::Internal(anyhow::anyhow!("Failed to decode token payload")))?;
        if payload.len() < 12 {
            return Err(AppError::Internal(anyhow::anyhow!(
                "Invalid token payload length"
            )));
        }

        let (nonce_bytes, sealed) = payload.split_at(12);
        #[allow(deprecated)]
        let nonce = Nonce::from_slice(nonce_bytes);

        let primary = Aes256Gcm::new_from_slice(&self.key)
            .expect("TokenCipher key length is fixed at 32 bytes");
        let plaintext = match primary.decrypt(nonce, sealed) {
            Ok(bytes) => bytes,
            // Current key rejected the payload: fall back to the legacy key.
            Err(_) => {
                let legacy = Aes256Gcm::new_from_slice(&self.legacy_key)
                    .expect("TokenCipher legacy key length is fixed at 32 bytes");
                legacy
                    .decrypt(nonce, sealed)
                    .map_err(|_| AppError::Internal(anyhow::anyhow!("Failed to decrypt token")))?
            }
        };

        String::from_utf8(plaintext)
            .map_err(|_| AppError::Internal(anyhow::anyhow!("Invalid token payload encoding")))
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for refresh-token hashing/verification and for `TokenCipher`.
    use super::*;

    const TEST_SECRET: &str = "test-secret-key-for-hmac-operations";

    // ---- hash_refresh_token -------------------------------------------------

    #[test]
    fn test_hash_refresh_token_deterministic() {
        let token = "test-refresh-token-123";
        let hash1 = hash_refresh_token(token, TEST_SECRET);
        let hash2 = hash_refresh_token(token, TEST_SECRET);
        assert_eq!(hash1, hash2);
    }

    #[test]
    fn test_hash_refresh_token_different_inputs() {
        let hash1 = hash_refresh_token("token1", TEST_SECRET);
        let hash2 = hash_refresh_token("token2", TEST_SECRET);
        assert_ne!(hash1, hash2);
    }

    #[test]
    fn test_hash_refresh_token_different_secrets() {
        let token = "same-token";
        let hash1 = hash_refresh_token(token, "secret1");
        let hash2 = hash_refresh_token(token, "secret2");
        assert_ne!(
            hash1, hash2,
            "Different secrets should produce different hashes"
        );
    }

    #[test]
    fn test_hash_refresh_token_length() {
        // 32-byte HMAC-SHA256 tag, hex-encoded.
        let hash = hash_refresh_token("any-token", TEST_SECRET);
        assert_eq!(hash.len(), 64);
    }

    #[test]
    fn test_hash_refresh_token_hex_format() {
        let hash = hash_refresh_token("test", TEST_SECRET);
        assert!(hash.chars().all(|c| c.is_ascii_hexdigit()));
    }

    // ---- verify_refresh_token -----------------------------------------------

    #[test]
    fn test_verify_refresh_token_valid() {
        let token = "my-refresh-token";
        let hash = hash_refresh_token(token, TEST_SECRET);
        assert!(verify_refresh_token(token, TEST_SECRET, &hash));
    }

    #[test]
    fn test_verify_refresh_token_invalid_token() {
        let token = "my-refresh-token";
        let hash = hash_refresh_token(token, TEST_SECRET);
        assert!(!verify_refresh_token("wrong-token", TEST_SECRET, &hash));
    }

    #[test]
    fn test_verify_refresh_token_invalid_secret() {
        let token = "my-refresh-token";
        let hash = hash_refresh_token(token, TEST_SECRET);
        assert!(!verify_refresh_token(token, "wrong-secret", &hash));
    }

    #[test]
    fn test_verify_refresh_token_invalid_hash() {
        let token = "my-refresh-token";
        assert!(!verify_refresh_token(token, TEST_SECRET, "invalid-hash"));
    }

    // ---- TokenCipher --------------------------------------------------------

    #[test]
    fn test_token_cipher_roundtrip() {
        let cipher = TokenCipher::new("test-secret");
        let token = "token-value-123";
        let encrypted = cipher.encrypt(token).unwrap();
        let decrypted = cipher.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, token);
    }

    #[test]
    fn test_token_cipher_rejects_invalid_payload() {
        let cipher = TokenCipher::new("test-secret");
        let result = cipher.decrypt("not-base64");
        assert!(result.is_err());
    }

    #[test]
    fn test_token_cipher_legacy_decryption() {
        let secret = "test-secret";
        let token = "legacy-token-123";

        // Build a payload the way the legacy scheme did: key = SHA-256(secret).
        let legacy_digest = Sha256::digest(secret.as_bytes());
        let legacy_cipher =
            Aes256Gcm::new_from_slice(&legacy_digest).expect("SHA-256 digest is 32 bytes");
        let mut nonce_bytes = [0u8; 12];
        OsRng.fill_bytes(&mut nonce_bytes);
        #[allow(deprecated)]
        let nonce = Nonce::from_slice(&nonce_bytes);
        let ciphertext = legacy_cipher.encrypt(nonce, token.as_bytes()).unwrap();
        let mut combined = Vec::with_capacity(nonce_bytes.len() + ciphertext.len());
        combined.extend_from_slice(&nonce_bytes);
        combined.extend_from_slice(&ciphertext);
        let legacy_encrypted = base64::engine::general_purpose::STANDARD.encode(combined);

        let cipher = TokenCipher::new(secret);
        let decrypted = cipher.decrypt(&legacy_encrypted).unwrap();
        assert_eq!(decrypted, token);
    }

    #[test]
    fn test_token_cipher_uses_new_key_for_encryption() {
        let secret = "test-secret";
        let cipher = TokenCipher::new(secret);
        let token = "new-token-123";
        let encrypted = cipher.encrypt(token).unwrap();

        // A plain round trip cannot tell the keys apart because `decrypt` falls
        // back to the legacy key. Open the payload with the legacy key directly:
        // that must fail if `encrypt` really used the new key.
        let decoded = base64::engine::general_purpose::STANDARD
            .decode(&encrypted)
            .unwrap();
        let (nonce_bytes, ciphertext) = decoded.split_at(12);
        #[allow(deprecated)]
        let nonce = Nonce::from_slice(nonce_bytes);
        let legacy_digest = Sha256::digest(secret.as_bytes());
        let legacy_cipher =
            Aes256Gcm::new_from_slice(&legacy_digest).expect("SHA-256 digest is 32 bytes");
        assert!(
            legacy_cipher.decrypt(nonce, ciphertext).is_err(),
            "Freshly encrypted tokens must not be readable with the legacy key"
        );

        let decrypted = cipher.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, token);
    }

    #[test]
    fn test_token_cipher_key_depends_on_secret() {
        let a = TokenCipher::new("secret-a");
        let b = TokenCipher::new("secret-b");

        // Comparing ciphertexts proves nothing: random nonces make them differ
        // even under the same key. Check instead that a cipher built from a
        // different secret cannot open the payload.
        let encrypted = a.encrypt("token").unwrap();
        assert_eq!(a.decrypt(&encrypted).unwrap(), "token");
        assert!(
            b.decrypt(&encrypted).is_err(),
            "A cipher built from a different secret must not decrypt the token"
        );
    }
}