use aes_gcm::aead::{Aead, AeadCore, KeyInit, OsRng};
use aes_gcm::{Aes256Gcm, Nonce};
use base64::{Engine, engine::general_purpose::STANDARD};
/// A refresh token that has already been sealed by [`TokenCipher`].
///
/// The wrapped string is the base64 text produced by [`TokenCipher::encrypt`]
/// (via [`TokenCipher::encrypt_to_token`]). Keeping it in a dedicated type
/// stops sealed and plaintext tokens from being confused at compile time.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EncryptedRefreshToken(String);

impl EncryptedRefreshToken {
    /// Borrows the ciphertext text.
    #[must_use]
    pub fn as_str(&self) -> &str {
        let Self(inner) = self;
        inner.as_str()
    }

    /// Consumes the wrapper and returns the ciphertext string.
    #[must_use]
    pub fn into_inner(self) -> String {
        let Self(inner) = self;
        inner
    }

    /// Re-wraps ciphertext that was previously read back from storage.
    ///
    /// No validation is performed here; a malformed value only surfaces as an
    /// error once it is handed to [`TokenCipher::decrypt`].
    #[must_use]
    pub fn from_stored(ciphertext: String) -> Self {
        Self(ciphertext)
    }
}

impl std::fmt::Display for EncryptedRefreshToken {
    /// Emits the raw ciphertext text; width/fill flags are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
/// Byte length of the AES-GCM nonce that prefixes every encoded ciphertext
/// (the standard 96-bit GCM nonce used by `Aes256Gcm`).
const NONCE_SIZE: usize = 96 / 8;
/// Failures from loading a key or from sealing/opening a token with [`TokenCipher`].
///
/// Variants carrying a `String` hold the underlying library error's `to_string()` text.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum CipherError {
    /// The key string was not valid standard base64.
    #[error("key decode failed: {0}")]
    KeyDecode(String),
    /// The decoded key was not 32 bytes long; the payload is the actual length.
    #[error("key must be 32 bytes after base64 decode, got {0}")]
    KeyWrongLength(usize),
    /// The AEAD encryption step reported an error.
    #[error("encrypt failed: {0}")]
    Encrypt(String),
    /// The ciphertext string was not valid standard base64.
    #[error("ciphertext decode failed: {0}")]
    CiphertextDecode(String),
    /// The decoded ciphertext contained no bytes beyond the nonce prefix.
    #[error("ciphertext shorter than nonce")]
    CiphertextTruncated,
    /// The nonce prefix could not be converted into a fixed-size nonce.
    #[error("nonce size mismatch")]
    NonceSize,
    /// GCM decryption/authentication failed (e.g. wrong key or modified data).
    #[error("decrypt failed: {0}")]
    Decrypt(String),
    /// Decryption succeeded but the recovered bytes were not valid UTF-8.
    #[error("plaintext is not valid utf-8: {0}")]
    PlaintextUtf8(String),
}
/// AES-256-GCM sealer/opener for refresh tokens.
///
/// [`TokenCipher::encrypt`] draws a fresh random 96-bit nonce from the OS RNG
/// for every message and returns `base64(nonce || aead_output)`, where
/// `aead_output` is the GCM ciphertext with its authentication tag. The
/// encoding is the standard base64 alphabet with padding.
/// [`TokenCipher::decrypt`] reverses exactly that layout.
///
/// Because nonces are random, NIST SP 800-38D limits a single key to 2^32
/// encryptions. Rotate the key well before reaching that volume.
#[derive(Clone)]
pub struct TokenCipher {
    cipher: Aes256Gcm,
}

impl TokenCipher {
    /// Builds a cipher from a base64-encoded 256-bit key.
    ///
    /// Leading and trailing whitespace (e.g. a trailing newline from an env
    /// file) is trimmed before decoding.
    ///
    /// # Errors
    ///
    /// - [`CipherError::KeyDecode`] if the trimmed input is not valid standard base64.
    /// - [`CipherError::KeyWrongLength`] if the decoded key is not exactly 32 bytes.
    pub fn from_base64_key(key_b64: &str) -> Result<Self, CipherError> {
        // NOTE(review): the base64 error text can echo the offending input byte,
        // and `bytes` (raw key material) is dropped without being wiped. Consider
        // redacting the message / using `zeroize` if key hygiene matters here.
        let bytes = STANDARD
            .decode(key_b64.trim())
            .map_err(|e| CipherError::KeyDecode(e.to_string()))?;
        let cipher = Aes256Gcm::new_from_slice(&bytes)
            .map_err(|_| CipherError::KeyWrongLength(bytes.len()))?;
        Ok(Self { cipher })
    }

    /// Encrypts `plaintext` and returns `base64(nonce || ciphertext_and_tag)`.
    ///
    /// Output is randomized: encrypting the same input twice yields different text.
    ///
    /// # Errors
    ///
    /// [`CipherError::Encrypt`] if the AEAD encryption step fails.
    pub fn encrypt(&self, plaintext: &str) -> Result<String, CipherError> {
        // A fresh nonce per message: reusing a nonce under one key breaks GCM's
        // confidentiality and authenticity guarantees.
        let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
        let ciphertext = self
            .cipher
            .encrypt(&nonce, plaintext.as_bytes())
            .map_err(|e| CipherError::Encrypt(e.to_string()))?;
        let nonce_bytes: &[u8] = nonce.as_ref();
        let mut combined = Vec::with_capacity(NONCE_SIZE + ciphertext.len());
        combined.extend_from_slice(nonce_bytes);
        combined.extend_from_slice(&ciphertext);
        Ok(STANDARD.encode(&combined))
    }

    /// Like [`TokenCipher::encrypt`], but wraps the result in an [`EncryptedRefreshToken`].
    ///
    /// # Errors
    ///
    /// Same as [`TokenCipher::encrypt`].
    pub fn encrypt_to_token(&self, plaintext: &str) -> Result<EncryptedRefreshToken, CipherError> {
        Ok(EncryptedRefreshToken(self.encrypt(plaintext)?))
    }

    /// Decrypts text produced by [`TokenCipher::encrypt`] back into the original string.
    ///
    /// Surrounding whitespace in `ciphertext_b64` is trimmed before decoding.
    ///
    /// # Errors
    ///
    /// - [`CipherError::CiphertextDecode`] if the trimmed input is not valid standard base64.
    /// - [`CipherError::CiphertextTruncated`] if the decoded bytes are no longer than the
    ///   nonce.
    /// - [`CipherError::NonceSize`] if the nonce prefix cannot be converted. This is
    ///   unreachable in practice because the input is split at exactly the nonce length,
    ///   and is kept only as a defensive path.
    /// - [`CipherError::Decrypt`] if GCM authentication fails, e.g. because of a wrong
    ///   key or a modified ciphertext.
    /// - [`CipherError::PlaintextUtf8`] if the authenticated plaintext is not UTF-8.
    pub fn decrypt(&self, ciphertext_b64: &str) -> Result<String, CipherError> {
        let combined = STANDARD
            .decode(ciphertext_b64.trim())
            .map_err(|e| CipherError::CiphertextDecode(e.to_string()))?;
        if combined.len() <= NONCE_SIZE {
            return Err(CipherError::CiphertextTruncated);
        }
        let (nonce_bytes, ct) = combined.split_at(NONCE_SIZE);
        let nonce_array: [u8; NONCE_SIZE] =
            nonce_bytes.try_into().map_err(|_| CipherError::NonceSize)?;
        let nonce = Nonce::from(nonce_array);
        let plaintext = self
            .cipher
            .decrypt(&nonce, ct)
            .map_err(|e| CipherError::Decrypt(e.to_string()))?;
        String::from_utf8(plaintext).map_err(|e| CipherError::PlaintextUtf8(e.to_string()))
    }

    /// Opens an [`EncryptedRefreshToken`]; the typed counterpart of
    /// [`TokenCipher::encrypt_to_token`].
    ///
    /// # Errors
    ///
    /// Same as [`TokenCipher::decrypt`].
    pub fn decrypt_token(&self, token: &EncryptedRefreshToken) -> Result<String, CipherError> {
        self.decrypt(token.as_str())
    }
}
#[cfg(test)]
// Tests unwrap freely: a panic is exactly the failure signal we want here.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    fn zero_key() -> String {
        STANDARD.encode([0u8; 32])
    }

    fn other_key() -> String {
        let mut key = [0u8; 32];
        key[0] = 1;
        STANDARD.encode(key)
    }

    #[test]
    fn roundtrip_preserves_plaintext() {
        let cipher = TokenCipher::from_base64_key(&zero_key()).unwrap();
        let plain = "rt_live_abc123def456";
        let encrypted = cipher.encrypt(plain).unwrap();
        assert_eq!(cipher.decrypt(&encrypted).unwrap(), plain);
    }

    #[test]
    fn empty_plaintext_roundtrips() {
        let cipher = TokenCipher::from_base64_key(&zero_key()).unwrap();
        let encrypted = cipher.encrypt("").unwrap();
        assert_eq!(cipher.decrypt(&encrypted).unwrap(), "");
    }

    #[test]
    fn encryption_is_randomized() {
        let cipher = TokenCipher::from_base64_key(&zero_key()).unwrap();
        let plain = "same input";
        let a = cipher.encrypt(plain).unwrap();
        let b = cipher.encrypt(plain).unwrap();
        assert_ne!(a, b);
        assert_eq!(cipher.decrypt(&a).unwrap(), plain);
        assert_eq!(cipher.decrypt(&b).unwrap(), plain);
    }

    #[test]
    fn wrong_key_fails_authentication() {
        let c1 = TokenCipher::from_base64_key(&zero_key()).unwrap();
        let c2 = TokenCipher::from_base64_key(&other_key()).unwrap();
        let ct = c1.encrypt("secret").unwrap();
        assert!(matches!(c2.decrypt(&ct), Err(CipherError::Decrypt(_))));
    }

    #[test]
    fn tampered_ciphertext_fails_authentication() {
        let cipher = TokenCipher::from_base64_key(&zero_key()).unwrap();
        let ct = cipher.encrypt("secret").unwrap();
        let mut bytes = STANDARD.decode(&ct).unwrap();
        let last = bytes.len() - 1;
        bytes[last] ^= 0x01;
        let tampered = STANDARD.encode(&bytes);
        assert!(matches!(
            cipher.decrypt(&tampered),
            Err(CipherError::Decrypt(_))
        ));
    }

    #[test]
    fn tampered_nonce_fails_authentication() {
        let cipher = TokenCipher::from_base64_key(&zero_key()).unwrap();
        let ct = cipher.encrypt("secret").unwrap();
        let mut bytes = STANDARD.decode(&ct).unwrap();
        bytes[0] ^= 0x01;
        let tampered = STANDARD.encode(&bytes);
        assert!(matches!(
            cipher.decrypt(&tampered),
            Err(CipherError::Decrypt(_))
        ));
    }

    #[test]
    fn short_input_is_rejected() {
        let cipher = TokenCipher::from_base64_key(&zero_key()).unwrap();
        let too_short = STANDARD.encode([0u8; 8]);
        assert!(matches!(
            cipher.decrypt(&too_short),
            Err(CipherError::CiphertextTruncated)
        ));
    }

    #[test]
    fn nonce_only_input_is_rejected() {
        // Exactly NONCE_SIZE bytes: the boundary of the `<=` length check.
        let cipher = TokenCipher::from_base64_key(&zero_key()).unwrap();
        let nonce_only = STANDARD.encode([0u8; NONCE_SIZE]);
        assert!(matches!(
            cipher.decrypt(&nonce_only),
            Err(CipherError::CiphertextTruncated)
        ));
    }

    #[test]
    fn invalid_base64_ciphertext_is_rejected() {
        let cipher = TokenCipher::from_base64_key(&zero_key()).unwrap();
        assert!(matches!(
            cipher.decrypt("not base64!!!"),
            Err(CipherError::CiphertextDecode(_))
        ));
    }

    #[test]
    fn ciphertext_with_surrounding_whitespace_still_decrypts() {
        let cipher = TokenCipher::from_base64_key(&zero_key()).unwrap();
        let encrypted = cipher.encrypt("secret").unwrap();
        let padded = format!("  {encrypted}\n");
        assert_eq!(cipher.decrypt(&padded).unwrap(), "secret");
    }

    #[test]
    fn non_utf8_plaintext_is_reported() {
        // `encrypt` only accepts &str, so seal raw bytes through the inner AEAD.
        let tc = TokenCipher::from_base64_key(&zero_key()).unwrap();
        let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
        let sealed = tc
            .cipher
            .encrypt(&nonce, [0xff_u8, 0xfe].as_slice())
            .unwrap();
        let nonce_bytes: &[u8] = nonce.as_ref();
        let mut combined = nonce_bytes.to_vec();
        combined.extend_from_slice(&sealed);
        let encoded = STANDARD.encode(&combined);
        assert!(matches!(
            tc.decrypt(&encoded),
            Err(CipherError::PlaintextUtf8(_))
        ));
    }

    #[test]
    fn invalid_base64_key_is_rejected() {
        assert!(matches!(
            TokenCipher::from_base64_key("not base64!!!"),
            Err(CipherError::KeyDecode(_))
        ));
    }

    #[test]
    fn wrong_length_key_is_rejected() {
        let too_short = STANDARD.encode([0u8; 16]);
        assert!(matches!(
            TokenCipher::from_base64_key(&too_short),
            Err(CipherError::KeyWrongLength(16))
        ));
    }

    #[test]
    fn key_with_trailing_whitespace_still_parses() {
        let mut key = zero_key();
        key.push('\n');
        key.push(' ');
        assert!(TokenCipher::from_base64_key(&key).is_ok());
    }

    #[test]
    fn encrypt_to_token_roundtrips_via_from_stored() {
        let cipher = TokenCipher::from_base64_key(&zero_key()).unwrap();
        let plaintext = "rt_live_xyz789";
        let token = cipher.encrypt_to_token(plaintext).unwrap();
        let persisted: String = token.as_str().to_string();
        let restored = EncryptedRefreshToken::from_stored(persisted);
        assert_eq!(cipher.decrypt(restored.as_str()).unwrap(), plaintext);
    }

    #[test]
    fn encrypted_refresh_token_display_matches_inner() {
        let token = EncryptedRefreshToken::from_stored("base64-cipher".to_string());
        assert_eq!(token.to_string(), "base64-cipher");
        assert_eq!(token.as_str(), "base64-cipher");
    }
}