1use std::collections::{BTreeMap, HashSet};
4use std::fmt;
5
6use chacha20poly1305::aead::{Aead, AeadCore, KeyInit, OsRng};
7use chacha20poly1305::{Key, XChaCha20Poly1305, XNonce};
8use sha2::{Digest, Sha256};
9
10use super::envelope::{format_envelope, parse_envelope};
11use crate::error::OpenAuthError;
12
/// Built-in default secret value.
///
/// `build_secret_config` refuses to adopt a `legacy_secret` equal to this value, so a
/// configuration built there never decrypts bare payloads with it.
const DEFAULT_SECRET: &str = "better-auth-secret-12345678901234567890";
14
/// A single versioned secret, such as one `<version>:<secret>` item produced by
/// [`parse_secrets_env`].
///
/// The `Debug` output never shows `value`; it prints `"<redacted>"` instead.
#[derive(Clone, PartialEq, Eq)]
pub struct SecretEntry {
    /// Version number identifying this secret.
    pub version: u32,
    /// The secret material itself.
    pub value: String,
}

impl SecretEntry {
    /// Creates an entry for `version` holding `value`.
    pub fn new(version: u32, value: impl Into<String>) -> Self {
        let value: String = value.into();
        SecretEntry { version, value }
    }
}

impl fmt::Debug for SecretEntry {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Redact the secret so it cannot leak through logs or panic messages.
        let mut builder = f.debug_struct("SecretEntry");
        builder.field("version", &self.version);
        builder.field("value", &"<redacted>");
        builder.finish()
    }
}
40
41#[derive(Clone, PartialEq, Eq)]
43pub struct SecretConfig {
44 pub keys: BTreeMap<u32, String>,
45 pub current_version: u32,
46 pub legacy_secret: Option<String>,
47}
48
49impl SecretConfig {
50 pub fn new<I, S>(entries: I) -> Self
51 where
52 I: IntoIterator<Item = (u32, S)>,
53 S: Into<String>,
54 {
55 let mut keys = BTreeMap::new();
56 let mut current_version = None;
57 for (version, value) in entries {
58 if current_version.is_none() {
59 current_version = Some(version);
60 }
61 keys.insert(version, value.into());
62 }
63
64 Self {
65 keys,
66 current_version: current_version.unwrap_or(0),
67 legacy_secret: None,
68 }
69 }
70
71 pub fn with_legacy_secret(mut self, secret: impl Into<String>) -> Self {
72 self.legacy_secret = Some(secret.into());
73 self
74 }
75
76 fn current_secret(&self) -> Result<&str, OpenAuthError> {
77 self.keys
78 .get(&self.current_version)
79 .map(String::as_str)
80 .ok_or_else(|| {
81 OpenAuthError::InvalidSecretConfig(format!(
82 "secret version {} not found in keys",
83 self.current_version
84 ))
85 })
86 }
87}
88
89impl fmt::Debug for SecretConfig {
90 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
91 formatter
92 .debug_struct("SecretConfig")
93 .field("key_versions", &self.keys.keys().collect::<Vec<_>>())
94 .field("current_version", &self.current_version)
95 .field(
96 "legacy_secret",
97 &self.legacy_secret.as_ref().map(|_| "<redacted>"),
98 )
99 .finish()
100 }
101}
102
103pub fn parse_secrets_env(value: Option<&str>) -> Result<Option<Vec<SecretEntry>>, OpenAuthError> {
105 let Some(value) = value else {
106 return Ok(None);
107 };
108 if value.trim().is_empty() {
109 return Ok(None);
110 }
111
112 let mut entries = Vec::new();
113 for entry in value.split(',') {
114 let entry = entry.trim();
115 let Some((version, secret)) = entry.split_once(':') else {
116 return Err(OpenAuthError::InvalidSecretConfig(format!(
117 "invalid secret entry `{entry}`; expected `<version>:<secret>`"
118 )));
119 };
120 let version = version.trim().parse::<u32>().map_err(|_| {
121 OpenAuthError::InvalidSecretConfig(format!(
122 "invalid version `{}`; version must be a non-negative integer",
123 version.trim()
124 ))
125 })?;
126 let secret = secret.trim();
127 if secret.is_empty() {
128 return Err(OpenAuthError::InvalidSecretConfig(format!(
129 "empty secret value for version {version}"
130 )));
131 }
132 entries.push(SecretEntry::new(version, secret));
133 }
134
135 Ok(Some(entries))
136}
137
138pub fn validate_secrets(secrets: &[SecretEntry]) -> Result<Vec<String>, OpenAuthError> {
140 if secrets.is_empty() {
141 return Err(OpenAuthError::InvalidSecretConfig(
142 "`secrets` must contain at least one entry".to_owned(),
143 ));
144 }
145
146 let mut seen = HashSet::new();
147 for secret in secrets {
148 if secret.value.is_empty() {
149 return Err(OpenAuthError::InvalidSecretConfig(format!(
150 "empty secret value for version {}",
151 secret.version
152 )));
153 }
154 if !seen.insert(secret.version) {
155 return Err(OpenAuthError::InvalidSecretConfig(format!(
156 "duplicate version {}",
157 secret.version
158 )));
159 }
160 }
161
162 let mut warnings = Vec::new();
163 let current = &secrets[0];
164 if current.value.len() < 32 {
165 warnings.push(format!(
166 "current secret version {} should be at least 32 characters long",
167 current.version
168 ));
169 }
170 if estimate_entropy(¤t.value) < 120.0 {
171 warnings.push("current secret appears low entropy".to_owned());
172 }
173
174 Ok(warnings)
175}
176
177pub fn build_secret_config(
179 secrets: &[SecretEntry],
180 legacy_secret: &str,
181) -> Result<SecretConfig, OpenAuthError> {
182 validate_secrets(secrets)?;
183 let mut config = SecretConfig::new(
184 secrets
185 .iter()
186 .map(|entry| (entry.version, entry.value.clone())),
187 );
188 if !legacy_secret.is_empty() && legacy_secret != DEFAULT_SECRET {
189 config.legacy_secret = Some(legacy_secret.to_owned());
190 }
191 Ok(config)
192}
193
/// Rough entropy estimate in bits: the character count multiplied by `log2` of the number
/// of distinct characters.
///
/// Returns `0.0` for an empty string; a string of one repeated character also scores `0.0`.
fn estimate_entropy(value: &str) -> f64 {
    let mut distinct = HashSet::new();
    let mut length = 0usize;
    for ch in value.chars() {
        distinct.insert(ch);
        length += 1;
    }

    if distinct.is_empty() {
        return 0.0;
    }
    length as f64 * (distinct.len() as f64).log2()
}
201
202fn derive_key(secret: &str) -> [u8; 32] {
203 Sha256::digest(secret.as_bytes()).into()
204}
205
206fn raw_encrypt(secret: &str, data: &str) -> Result<String, OpenAuthError> {
207 let key = derive_key(secret);
208 let cipher = XChaCha20Poly1305::new(Key::from_slice(&key));
209 let nonce = XChaCha20Poly1305::generate_nonce(&mut OsRng);
210 let ciphertext = cipher
211 .encrypt(&nonce, data.as_bytes())
212 .map_err(|error| OpenAuthError::Crypto(error.to_string()))?;
213
214 let mut payload = Vec::with_capacity(nonce.len() + ciphertext.len());
215 payload.extend_from_slice(&nonce);
216 payload.extend_from_slice(&ciphertext);
217 Ok(hex::encode(payload))
218}
219
220fn raw_decrypt(secret: &str, hex_payload: &str) -> Result<String, OpenAuthError> {
221 let payload =
222 hex::decode(hex_payload).map_err(|error| OpenAuthError::Crypto(error.to_string()))?;
223 if payload.len() <= 24 {
224 return Err(OpenAuthError::Crypto(
225 "encrypted payload is too short".to_owned(),
226 ));
227 }
228
229 let (nonce, ciphertext) = payload.split_at(24);
230 let key = derive_key(secret);
231 let cipher = XChaCha20Poly1305::new(Key::from_slice(&key));
232 let plaintext = cipher
233 .decrypt(XNonce::from_slice(nonce), ciphertext)
234 .map_err(|error| OpenAuthError::Crypto(error.to_string()))?;
235
236 String::from_utf8(plaintext).map_err(|error| OpenAuthError::Crypto(error.to_string()))
237}
238
/// Source of secret material accepted by [`symmetric_encrypt`] and [`symmetric_decrypt`].
///
/// This module implements it for `&str`, `String` (a single bare secret) and
/// `&SecretConfig` (versioned secrets with an optional legacy secret).
pub trait SecretSource {
    /// Encrypts `data` with this source's current secret and returns the encoded payload.
    fn encrypt_current(&self, data: &str) -> Result<String, OpenAuthError>;
    /// Decrypts a payload in the format this source understands and returns the plaintext.
    fn decrypt_payload(&self, data: &str) -> Result<String, OpenAuthError>;
}
244
245impl SecretSource for &str {
246 fn encrypt_current(&self, data: &str) -> Result<String, OpenAuthError> {
247 raw_encrypt(self, data)
248 }
249
250 fn decrypt_payload(&self, data: &str) -> Result<String, OpenAuthError> {
251 raw_decrypt(self, data)
252 }
253}
254
255impl SecretSource for String {
256 fn encrypt_current(&self, data: &str) -> Result<String, OpenAuthError> {
257 self.as_str().encrypt_current(data)
258 }
259
260 fn decrypt_payload(&self, data: &str) -> Result<String, OpenAuthError> {
261 self.as_str().decrypt_payload(data)
262 }
263}
264
265impl SecretSource for &SecretConfig {
266 fn encrypt_current(&self, data: &str) -> Result<String, OpenAuthError> {
267 let ciphertext = raw_encrypt(self.current_secret()?, data)?;
268 Ok(format_envelope(self.current_version, &ciphertext))
269 }
270
271 fn decrypt_payload(&self, data: &str) -> Result<String, OpenAuthError> {
272 if let Some(envelope) = parse_envelope(data) {
273 let secret = self.keys.get(&envelope.version).ok_or_else(|| {
274 OpenAuthError::InvalidSecretConfig(format!(
275 "secret version {} not found in keys; key may have been retired",
276 envelope.version
277 ))
278 })?;
279 return raw_decrypt(secret, &envelope.ciphertext);
280 }
281
282 if let Some(legacy_secret) = &self.legacy_secret {
283 return raw_decrypt(legacy_secret, data);
284 }
285
286 Err(OpenAuthError::InvalidSecretConfig(
287 "cannot decrypt legacy bare payload: no legacy secret available".to_owned(),
288 ))
289 }
290}
291
292pub fn symmetric_encrypt(key: impl SecretSource, data: &str) -> Result<String, OpenAuthError> {
294 key.encrypt_current(data)
295}
296
297pub fn symmetric_decrypt(key: impl SecretSource, data: &str) -> Result<String, OpenAuthError> {
299 key.decrypt_payload(data)
300}