better_auth_api/plugins/oauth/
encryption.rs1use aes_gcm::aead::{Aead, KeyInit, OsRng};
8use aes_gcm::{AeadCore, Aes256Gcm, Key, Nonce};
9use base64::Engine;
10use hkdf::Hkdf;
11use sha2::Sha256;
12
13use better_auth_core::AuthError;
14
15const HKDF_INFO: &[u8] = b"better-auth-oauth-token-encryption";
17
18fn derive_key(secret: &str) -> Key<Aes256Gcm> {
24 let hk = Hkdf::<Sha256>::new(None, secret.as_bytes());
25 let mut okm = [0u8; 32];
26 hk.expand(HKDF_INFO, &mut okm)
28 .expect("32 bytes is a valid length for HKDF-SHA256");
29 *Key::<Aes256Gcm>::from_slice(&okm)
30}
31
32pub fn encrypt_token(plaintext: &str, secret: &str) -> Result<String, AuthError> {
36 let key = derive_key(secret);
37 let cipher = Aes256Gcm::new(&key);
38 let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
39
40 let ciphertext = cipher
41 .encrypt(&nonce, plaintext.as_bytes())
42 .map_err(|e| AuthError::internal(format!("Token encryption failed: {}", e)))?;
43
44 let mut combined = nonce.to_vec();
46 combined.extend_from_slice(&ciphertext);
47
48 Ok(base64::engine::general_purpose::STANDARD.encode(&combined))
49}
50
51pub fn decrypt_token(encrypted: &str, secret: &str) -> Result<String, AuthError> {
53 let key = derive_key(secret);
54 let cipher = Aes256Gcm::new(&key);
55
56 let combined = base64::engine::general_purpose::STANDARD
57 .decode(encrypted)
58 .map_err(|e| AuthError::internal(format!("Token decryption base64 error: {}", e)))?;
59
60 if combined.len() < 12 {
61 return Err(AuthError::internal(
62 "Encrypted token too short (missing nonce)",
63 ));
64 }
65
66 let (nonce_bytes, ciphertext) = combined.split_at(12);
67 let nonce = Nonce::from_slice(nonce_bytes);
68
69 let plaintext = cipher
70 .decrypt(nonce, ciphertext)
71 .map_err(|e| AuthError::internal(format!("Token decryption failed: {}", e)))?;
72
73 String::from_utf8(plaintext)
74 .map_err(|e| AuthError::internal(format!("Decrypted token is not valid UTF-8: {}", e)))
75}
76
77pub fn maybe_encrypt(
80 value: Option<String>,
81 encrypt: bool,
82 secret: &str,
83) -> Result<Option<String>, AuthError> {
84 match (value, encrypt) {
85 (Some(v), true) => Ok(Some(encrypt_token(&v, secret)?)),
86 (v, _) => Ok(v),
87 }
88}
89
90pub fn maybe_decrypt(
99 value: Option<&str>,
100 encrypt: bool,
101 secret: &str,
102) -> Result<Option<String>, AuthError> {
103 match (value, encrypt) {
104 (Some(v), true) => match decrypt_token(v, secret) {
105 Ok(decrypted) => Ok(Some(decrypted)),
106 Err(_) => Ok(Some(v.to_string())),
107 },
108 (Some(v), false) => Ok(Some(v.to_string())),
109 (None, _) => Ok(None),
110 }
111}
112
/// The three OAuth tokens after [`encrypt_token_set`] has passed each one
/// through [`maybe_encrypt`].
///
/// Each field is either the AES-GCM ciphertext (when token encryption is
/// enabled) or the original value. `None` inputs stay `None`.
///
/// NOTE(review): `Debug` is intentionally not derived. Keep it that way, or
/// write a redacting impl, so token values cannot leak into logs.
#[derive(Clone, Default, PartialEq, Eq)]
pub struct EncryptedTokenSet {
    pub access_token: Option<String>,
    pub refresh_token: Option<String>,
    pub id_token: Option<String>,
}

120pub fn encrypt_token_set<DB: better_auth_core::DatabaseAdapter>(
123 ctx: &better_auth_core::AuthContext<DB>,
124 access_token: Option<String>,
125 refresh_token: Option<String>,
126 id_token: Option<String>,
127) -> Result<EncryptedTokenSet, AuthError> {
128 let encrypt = ctx.config.account.encrypt_oauth_tokens;
129 let secret = &ctx.config.secret;
130 Ok(EncryptedTokenSet {
131 access_token: maybe_encrypt(access_token, encrypt, secret)?,
132 refresh_token: maybe_encrypt(refresh_token, encrypt, secret)?,
133 id_token: maybe_encrypt(id_token, encrypt, secret)?,
134 })
135}
136
#[cfg(test)]
mod tests {
    use super::*;
    use base64::engine::general_purpose::STANDARD;
    use base64::Engine;

    /// Shared secret for the tests that do not need a specific one.
    const SECRET: &str = "a]vt!MFX8H-e!4igKa5)Tu.{ec:2$z%n";

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let secret = "a]vt!MFX8H-e!4igKa5)Tu.{ec:2$z%n";
        let plaintext = "ya29.a0AfH6SMBx-some-access-token";

        let encrypted = encrypt_token(plaintext, secret).unwrap();
        assert_ne!(encrypted, plaintext);

        let decrypted = decrypt_token(&encrypted, secret).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_encrypt_decrypt_empty_plaintext() {
        let encrypted = encrypt_token("", SECRET).unwrap();
        assert_eq!(decrypt_token(&encrypted, SECRET).unwrap(), "");
    }

    #[test]
    fn test_encrypt_uses_fresh_nonce() {
        // A random nonce per call means identical inputs give different outputs.
        let first = encrypt_token("same-token", SECRET).unwrap();
        let second = encrypt_token("same-token", SECRET).unwrap();
        assert_ne!(first, second);
    }

    #[test]
    fn test_decrypt_wrong_secret_fails() {
        let encrypted = encrypt_token("token", SECRET).unwrap();
        assert!(decrypt_token(&encrypted, "a-different-secret").is_err());
    }

    #[test]
    fn test_decrypt_tampered_ciphertext_fails() {
        let mut bytes = STANDARD
            .decode(encrypt_token("token", SECRET).unwrap())
            .unwrap();
        if let Some(last) = bytes.last_mut() {
            *last ^= 0x01;
        }
        assert!(decrypt_token(&STANDARD.encode(&bytes), SECRET).is_err());
    }

    #[test]
    fn test_decrypt_rejects_truncated_or_invalid_input() {
        // Shorter than the nonce.
        assert!(decrypt_token(&STANDARD.encode([0u8; 5]), SECRET).is_err());
        // Has a nonce but is too short to hold an authentication tag.
        assert!(decrypt_token(&STANDARD.encode([0u8; 20]), SECRET).is_err());
        // Not base64 at all.
        assert!(decrypt_token("not base64!!", SECRET).is_err());
    }

    #[test]
    fn test_maybe_encrypt_none() {
        let result = maybe_encrypt(None, true, "secret-key-that-is-32-chars-long").unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_maybe_encrypt_disabled() {
        let token = "plain-token".to_string();
        let result = maybe_encrypt(Some(token.clone()), false, "secret").unwrap();
        assert_eq!(result, Some(token));
    }

    #[test]
    fn test_maybe_encrypt_then_decrypt_roundtrip() {
        let token = "refresh-token-value".to_string();
        let stored = maybe_encrypt(Some(token.clone()), true, SECRET)
            .unwrap()
            .unwrap();
        assert_ne!(stored, token);
        let restored = maybe_decrypt(Some(stored.as_str()), true, SECRET).unwrap();
        assert_eq!(restored, Some(token));
    }

    #[test]
    fn test_maybe_decrypt_none() {
        let result = maybe_decrypt(None, true, "secret-key-that-is-32-chars-long").unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_maybe_decrypt_disabled_passthrough() {
        let result = maybe_decrypt(Some("plain-token"), false, SECRET).unwrap();
        assert_eq!(result, Some("plain-token".to_string()));
    }

    #[test]
    fn test_maybe_decrypt_plaintext_fallback() {
        let plaintext = "ya29.a0AfH6SMBx-some-access-token";
        let result = maybe_decrypt(Some(plaintext), true, "some-secret").unwrap();
        assert_eq!(result, Some(plaintext.to_string()));
    }

    #[test]
    fn test_maybe_decrypt_wrong_secret_falls_back_to_stored_value() {
        // Pins the current fallback: a wrong secret yields the stored ciphertext.
        // Update this test if that behaviour is ever tightened.
        let stored = encrypt_token("token", SECRET).unwrap();
        let result = maybe_decrypt(Some(stored.as_str()), true, "a-different-secret").unwrap();
        assert_eq!(result, Some(stored));
    }
}