Skip to main content

better_auth_api/plugins/oauth/
encryption.rs

1//! AES-256-GCM encryption utilities for OAuth tokens.
2//!
3//! When `AccountConfig::encrypt_oauth_tokens` is `true`, access tokens,
4//! refresh tokens, and ID tokens are encrypted before being persisted and
5//! decrypted transparently on read.
6
7use aes_gcm::aead::{Aead, KeyInit, OsRng};
8use aes_gcm::{AeadCore, Aes256Gcm, Key, Nonce};
9use base64::Engine;
10use hkdf::Hkdf;
11use sha2::Sha256;
12
13use better_auth_core::AuthError;
14
15/// Domain separator used for HKDF key derivation.
16const HKDF_INFO: &[u8] = b"better-auth-oauth-token-encryption";
17
18/// Derive a 256-bit key from the auth secret using HKDF-SHA256.
19///
20/// Uses an empty salt (extraction still strengthens the key) and a
21/// domain-specific info string to ensure the derived key is isolated
22/// to OAuth token encryption.
23fn derive_key(secret: &str) -> Key<Aes256Gcm> {
24    let hk = Hkdf::<Sha256>::new(None, secret.as_bytes());
25    let mut okm = [0u8; 32];
26    // info is static so expand will never fail
27    hk.expand(HKDF_INFO, &mut okm)
28        .expect("32 bytes is a valid length for HKDF-SHA256");
29    *Key::<Aes256Gcm>::from_slice(&okm)
30}
31
32/// Encrypt a plaintext string using AES-256-GCM.
33///
34/// Returns a base64-encoded string of `nonce || ciphertext`.
35pub fn encrypt_token(plaintext: &str, secret: &str) -> Result<String, AuthError> {
36    let key = derive_key(secret);
37    let cipher = Aes256Gcm::new(&key);
38    let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
39
40    let ciphertext = cipher
41        .encrypt(&nonce, plaintext.as_bytes())
42        .map_err(|e| AuthError::internal(format!("Token encryption failed: {}", e)))?;
43
44    // Prepend nonce (12 bytes) to ciphertext
45    let mut combined = nonce.to_vec();
46    combined.extend_from_slice(&ciphertext);
47
48    Ok(base64::engine::general_purpose::STANDARD.encode(&combined))
49}
50
51/// Decrypt a base64-encoded `nonce || ciphertext` string using AES-256-GCM.
52pub fn decrypt_token(encrypted: &str, secret: &str) -> Result<String, AuthError> {
53    let key = derive_key(secret);
54    let cipher = Aes256Gcm::new(&key);
55
56    let combined = base64::engine::general_purpose::STANDARD
57        .decode(encrypted)
58        .map_err(|e| AuthError::internal(format!("Token decryption base64 error: {}", e)))?;
59
60    if combined.len() < 12 {
61        return Err(AuthError::internal(
62            "Encrypted token too short (missing nonce)",
63        ));
64    }
65
66    let (nonce_bytes, ciphertext) = combined.split_at(12);
67    let nonce = Nonce::from_slice(nonce_bytes);
68
69    let plaintext = cipher
70        .decrypt(nonce, ciphertext)
71        .map_err(|e| AuthError::internal(format!("Token decryption failed: {}", e)))?;
72
73    String::from_utf8(plaintext)
74        .map_err(|e| AuthError::internal(format!("Decrypted token is not valid UTF-8: {}", e)))
75}
76
77/// Conditionally encrypt a token value. Returns the original value when
78/// encryption is disabled, or the encrypted value when enabled.
79pub fn maybe_encrypt(
80    value: Option<String>,
81    encrypt: bool,
82    secret: &str,
83) -> Result<Option<String>, AuthError> {
84    match (value, encrypt) {
85        (Some(v), true) => Ok(Some(encrypt_token(&v, secret)?)),
86        (v, _) => Ok(v),
87    }
88}
89
90/// Conditionally decrypt a token value. Returns the original value when
91/// encryption is disabled, or the decrypted value when enabled.
92///
93/// When encryption is enabled and decryption fails (e.g. because the token
94/// was stored as plaintext before encryption was turned on), the original
95/// value is returned as-is. This graceful fallback allows enabling
96/// `encrypt_oauth_tokens` on an existing database without breaking reads
97/// for previously stored plaintext tokens.
98pub fn maybe_decrypt(
99    value: Option<&str>,
100    encrypt: bool,
101    secret: &str,
102) -> Result<Option<String>, AuthError> {
103    match (value, encrypt) {
104        (Some(v), true) => match decrypt_token(v, secret) {
105            Ok(decrypted) => Ok(Some(decrypted)),
106            Err(_) => Ok(Some(v.to_string())),
107        },
108        (Some(v), false) => Ok(Some(v.to_string())),
109        (None, _) => Ok(None),
110    }
111}
112
/// A set of OAuth tokens (access, refresh, id) after conditional encryption.
///
/// Each field holds one of:
/// - the encrypted token;
/// - the unmodified token, when encryption is disabled;
/// - `None`, when no token was supplied.
///
/// `Debug` is implemented by hand so that token values are never written to
/// logs; only whether each token is present is shown.
#[derive(Clone, Default)]
pub struct EncryptedTokenSet {
    pub access_token: Option<String>,
    pub refresh_token: Option<String>,
    pub id_token: Option<String>,
}

impl std::fmt::Debug for EncryptedTokenSet {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Replace any present token with a fixed placeholder so neither
        // ciphertext nor (encryption-disabled) plaintext tokens leak.
        fn redact(token: Option<&str>) -> Option<&'static str> {
            token.map(|_| "[REDACTED]")
        }
        f.debug_struct("EncryptedTokenSet")
            .field("access_token", &redact(self.access_token.as_deref()))
            .field("refresh_token", &redact(self.refresh_token.as_deref()))
            .field("id_token", &redact(self.id_token.as_deref()))
            .finish()
    }
}
119
120/// Read `encrypt_oauth_tokens` and `secret` from the auth context and
121/// conditionally encrypt a full set of OAuth tokens in one call.
122pub fn encrypt_token_set<DB: better_auth_core::DatabaseAdapter>(
123    ctx: &better_auth_core::AuthContext<DB>,
124    access_token: Option<String>,
125    refresh_token: Option<String>,
126    id_token: Option<String>,
127) -> Result<EncryptedTokenSet, AuthError> {
128    let encrypt = ctx.config.account.encrypt_oauth_tokens;
129    let secret = &ctx.config.secret;
130    Ok(EncryptedTokenSet {
131        access_token: maybe_encrypt(access_token, encrypt, secret)?,
132        refresh_token: maybe_encrypt(refresh_token, encrypt, secret)?,
133        id_token: maybe_encrypt(id_token, encrypt, secret)?,
134    })
135}
136
#[cfg(test)]
mod tests {
    use super::*;

    /// Secret shared by tests that just need some fixed, valid secret.
    const SECRET: &str = "secret-key-that-is-32-chars-long";

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let secret = "a]vt!MFX8H-e!4igKa5)Tu.{ec:2$z%n";
        let plaintext = "ya29.a0AfH6SMBx-some-access-token";

        let encrypted = encrypt_token(plaintext, secret).unwrap();
        assert_ne!(encrypted, plaintext);

        let decrypted = decrypt_token(&encrypted, secret).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_encrypt_decrypt_empty_plaintext() {
        let encrypted = encrypt_token("", SECRET).unwrap();
        assert_eq!(decrypt_token(&encrypted, SECRET).unwrap(), "");
    }

    #[test]
    fn test_encrypt_uses_fresh_nonce() {
        let plaintext = "same-token";
        let first = encrypt_token(plaintext, SECRET).unwrap();
        let second = encrypt_token(plaintext, SECRET).unwrap();
        // Random nonces make identical plaintexts encrypt differently.
        assert_ne!(first, second);
        assert_eq!(decrypt_token(&first, SECRET).unwrap(), plaintext);
        assert_eq!(decrypt_token(&second, SECRET).unwrap(), plaintext);
    }

    #[test]
    fn test_derive_key_is_deterministic_and_secret_dependent() {
        assert_eq!(derive_key(SECRET), derive_key(SECRET));
        assert_ne!(derive_key(SECRET), derive_key("another-secret"));
    }

    #[test]
    fn test_decrypt_with_wrong_secret_fails() {
        let encrypted = encrypt_token("token", SECRET).unwrap();
        assert!(decrypt_token(&encrypted, "a-different-secret").is_err());
    }

    #[test]
    fn test_decrypt_rejects_tampered_ciphertext() {
        let encrypted = encrypt_token("token", SECRET).unwrap();
        let mut bytes = base64::engine::general_purpose::STANDARD
            .decode(&encrypted)
            .unwrap();
        let last = bytes.len() - 1;
        bytes[last] ^= 0x01;
        let tampered = base64::engine::general_purpose::STANDARD.encode(&bytes);
        assert!(decrypt_token(&tampered, SECRET).is_err());
    }

    #[test]
    fn test_decrypt_rejects_invalid_base64() {
        assert!(decrypt_token("not base64!", SECRET).is_err());
    }

    #[test]
    fn test_decrypt_rejects_too_short_input() {
        // Valid base64, but far shorter than a nonce.
        let short = base64::engine::general_purpose::STANDARD.encode([0u8; 4]);
        assert!(decrypt_token(&short, SECRET).is_err());
    }

    #[test]
    fn test_maybe_encrypt_none() {
        let result = maybe_encrypt(None, true, "secret-key-that-is-32-chars-long").unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_maybe_encrypt_disabled() {
        let token = "plain-token".to_string();
        let result = maybe_encrypt(Some(token.clone()), false, "secret").unwrap();
        assert_eq!(result, Some(token));
    }

    #[test]
    fn test_maybe_decrypt_none() {
        let result = maybe_decrypt(None, true, "secret-key-that-is-32-chars-long").unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_maybe_decrypt_plaintext_fallback() {
        // Simulate a token that was stored as plaintext before encryption was enabled.
        // `maybe_decrypt` should gracefully fall back to returning the original value.
        let plaintext = "ya29.a0AfH6SMBx-some-access-token";
        let result = maybe_decrypt(Some(plaintext), true, "some-secret").unwrap();
        assert_eq!(result, Some(plaintext.to_string()));
    }

    #[test]
    fn test_maybe_encrypt_then_decrypt_roundtrip() {
        let token = "refresh-token-value".to_string();
        let stored = maybe_encrypt(Some(token.clone()), true, SECRET).unwrap();
        assert_ne!(stored.as_deref(), Some(token.as_str()));
        let restored = maybe_decrypt(stored.as_deref(), true, SECRET).unwrap();
        assert_eq!(restored, Some(token));
    }

    #[test]
    fn test_maybe_decrypt_disabled_passthrough() {
        let encrypted = encrypt_token("token", SECRET).unwrap();
        let result = maybe_decrypt(Some(&encrypted), false, SECRET).unwrap();
        assert_eq!(result, Some(encrypted));
    }

    #[test]
    fn test_maybe_decrypt_wrong_secret_falls_back_to_stored_value() {
        // Documents current behavior: an undecryptable value is returned as-is,
        // even when it really is ciphertext produced under another secret.
        let encrypted = encrypt_token("token", SECRET).unwrap();
        let result = maybe_decrypt(Some(&encrypted), true, "a-different-secret").unwrap();
        assert_eq!(result, Some(encrypted));
    }
}
180}