1use base64::engine::general_purpose::URL_SAFE_NO_PAD;
4use base64::Engine;
5use hkdf::Hkdf;
6use josekit::jwe::{self, Dir, JweHeader};
7use rand::rngs::OsRng;
8use rand::RngCore;
9use serde::de::DeserializeOwned;
10use serde::Serialize;
11use serde_json::{Map, Number, Value};
12use sha2::{Digest, Sha256};
13
14use crate::crypto::SecretConfig;
15use crate::error::OpenAuthError;
16
/// Allowed clock skew, in seconds, when checking the `exp` claim of a decrypted token.
const CLOCK_TOLERANCE_SECONDS: i64 = 15;
/// HKDF `info` input used when deriving the 64-byte JWE key from a secret.
const HKDF_INFO: &[u8] = b"BetterAuth.js Generated Encryption Key";
/// Content-encryption algorithm written into, and required in, the JWE protected header.
const JWE_ENC: &str = "A256CBC-HS512";
/// Default HKDF salt used by `symmetric_encode_jwt` and `symmetric_decode_jwt`.
const JWE_SALT: &str = "better-auth-session";
21
/// One candidate secret from which a JWE encryption key can be derived.
///
/// `Debug` is implemented by hand rather than derived so that formatting a `JweSecret`
/// (in logs, error messages or panics) never prints the secret material itself.
#[derive(Clone, PartialEq, Eq)]
pub struct JweSecret {
    value: String,
}

impl std::fmt::Debug for JweSecret {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep the struct shape visible for diagnostics, but never the secret.
        f.debug_struct("JweSecret")
            .field("value", &"<redacted>")
            .finish()
    }
}
26
/// Supplies the secrets from which JWE encryption keys are derived.
///
/// This module implements it for `str`, `String` and `SecretConfig`.
pub trait JweSecretSource {
    /// Secret used to derive the key for encrypting new tokens.
    fn current_jwe_secret(&self) -> Result<String, OpenAuthError>;
    /// All secrets that may be tried when decrypting a token.
    ///
    /// The implementations in this module put the current secret first.
    fn all_jwe_secrets(&self) -> Result<Vec<JweSecret>, OpenAuthError>;
}
32
33impl JweSecretSource for str {
34 fn current_jwe_secret(&self) -> Result<String, OpenAuthError> {
35 Ok(self.to_owned())
36 }
37
38 fn all_jwe_secrets(&self) -> Result<Vec<JweSecret>, OpenAuthError> {
39 Ok(vec![JweSecret {
40 value: self.to_owned(),
41 }])
42 }
43}
44
45impl JweSecretSource for String {
46 fn current_jwe_secret(&self) -> Result<String, OpenAuthError> {
47 self.as_str().current_jwe_secret()
48 }
49
50 fn all_jwe_secrets(&self) -> Result<Vec<JweSecret>, OpenAuthError> {
51 self.as_str().all_jwe_secrets()
52 }
53}
54
55impl JweSecretSource for SecretConfig {
56 fn current_jwe_secret(&self) -> Result<String, OpenAuthError> {
57 self.keys
58 .get(&self.current_version)
59 .cloned()
60 .ok_or_else(|| {
61 OpenAuthError::InvalidSecretConfig(format!(
62 "secret version {} not found in keys",
63 self.current_version
64 ))
65 })
66 }
67
68 fn all_jwe_secrets(&self) -> Result<Vec<JweSecret>, OpenAuthError> {
69 let mut secrets = Vec::new();
70 secrets.push(JweSecret {
71 value: self.current_jwe_secret()?,
72 });
73 for (version, value) in &self.keys {
74 if *version != self.current_version {
75 secrets.push(JweSecret {
76 value: value.clone(),
77 });
78 }
79 }
80 if let Some(legacy_secret) = &self.legacy_secret {
81 if !secrets.iter().any(|secret| secret.value == *legacy_secret) {
82 secrets.push(JweSecret {
83 value: legacy_secret.clone(),
84 });
85 }
86 }
87 Ok(secrets)
88 }
89}
90
91pub fn symmetric_encode_jwt<T, K>(
92 payload: &T,
93 secret: &K,
94 expires_in: u64,
95) -> Result<String, OpenAuthError>
96where
97 T: Serialize,
98 K: JweSecretSource + ?Sized,
99{
100 symmetric_encode_jwt_with_salt(payload, secret, JWE_SALT, expires_in)
101}
102
103pub fn symmetric_decode_jwt<T, K>(token: &str, secret: &K) -> Result<Option<T>, OpenAuthError>
104where
105 T: DeserializeOwned,
106 K: JweSecretSource + ?Sized,
107{
108 symmetric_decode_jwt_with_salt(token, secret, JWE_SALT)
109}
110
111pub fn symmetric_encode_jwt_with_salt<T, K>(
112 payload: &T,
113 secret: &K,
114 salt: &str,
115 expires_in: u64,
116) -> Result<String, OpenAuthError>
117where
118 T: Serialize,
119 K: JweSecretSource + ?Sized,
120{
121 let current_secret = secret.current_jwe_secret()?;
122 let encryption_secret = derive_encryption_secret(¤t_secret, salt)?;
123 let kid = jwk_thumbprint(&encryption_secret);
124 let claims = claims_with_registered_fields(payload, expires_in)?;
125
126 let mut header = JweHeader::new();
127 header.set_content_encryption(JWE_ENC);
128 header.set_key_id(kid);
129
130 let encrypter = Dir
131 .encrypter_from_bytes(encryption_secret)
132 .map_err(|error| OpenAuthError::Crypto(error.to_string()))?;
133 jwe::serialize_compact(&claims, &header, &encrypter)
134 .map_err(|error| OpenAuthError::Crypto(error.to_string()))
135}
136
137pub fn symmetric_decode_jwt_with_salt<T, K>(
138 token: &str,
139 secret: &K,
140 salt: &str,
141) -> Result<Option<T>, OpenAuthError>
142where
143 T: DeserializeOwned,
144 K: JweSecretSource + ?Sized,
145{
146 if token.is_empty() {
147 return Ok(None);
148 }
149 let Some(kid) = protected_header_kid(token) else {
150 return Ok(None);
151 };
152 let secrets = secret.all_jwe_secrets()?;
153
154 if let Some(kid) = kid {
155 let Some(secret) = secrets
156 .iter()
157 .find(|secret| secret_kid(&secret.value, salt).is_ok_and(|value| value == kid))
158 else {
159 return Ok(None);
160 };
161 return decrypt_with_secret(token, &secret.value, salt);
162 }
163
164 for secret in secrets {
165 if let Some(payload) = decrypt_with_secret(token, &secret.value, salt)? {
166 return Ok(Some(payload));
167 }
168 }
169 Ok(None)
170}
171
172fn derive_encryption_secret(secret: &str, salt: &str) -> Result<[u8; 64], OpenAuthError> {
173 let hkdf = Hkdf::<Sha256>::new(Some(salt.as_bytes()), secret.as_bytes());
174 let mut key = [0_u8; 64];
175 hkdf.expand(HKDF_INFO, &mut key)
176 .map_err(|error| OpenAuthError::Crypto(error.to_string()))?;
177 Ok(key)
178}
179
180fn claims_with_registered_fields<T>(payload: &T, expires_in: u64) -> Result<Vec<u8>, OpenAuthError>
181where
182 T: Serialize,
183{
184 let mut claims = match serde_json::to_value(payload)
185 .map_err(|error| OpenAuthError::Crypto(error.to_string()))?
186 {
187 Value::Object(claims) => claims,
188 _ => {
189 return Err(OpenAuthError::Crypto(
190 "JWE payload must serialize to a JSON object".to_owned(),
191 ));
192 }
193 };
194
195 let now = time::OffsetDateTime::now_utc().unix_timestamp();
196 claims.insert("iat".to_owned(), Value::Number(Number::from(now)));
197 claims.insert(
198 "exp".to_owned(),
199 Value::Number(Number::from(now + expires_in as i64)),
200 );
201 claims.insert("jti".to_owned(), Value::String(random_jti()));
202
203 serde_json::to_vec(&Value::Object(claims))
204 .map_err(|error| OpenAuthError::Crypto(error.to_string()))
205}
206
207fn decrypt_with_secret<T>(token: &str, secret: &str, salt: &str) -> Result<Option<T>, OpenAuthError>
208where
209 T: DeserializeOwned,
210{
211 let encryption_secret = derive_encryption_secret(secret, salt)?;
212 let decrypter = Dir
213 .decrypter_from_bytes(encryption_secret)
214 .map_err(|error| OpenAuthError::Crypto(error.to_string()))?;
215 let Ok((payload, header)) = jwe::deserialize_compact(token, &decrypter) else {
216 return Ok(None);
217 };
218 if header.content_encryption() != Some(JWE_ENC) {
219 return Ok(None);
220 }
221
222 let value: Value = serde_json::from_slice(&payload)
223 .map_err(|error| OpenAuthError::Crypto(format!("could not parse JWE payload: {error}")))?;
224 if is_expired(&value) {
225 return Ok(None);
226 }
227 serde_json::from_value(value)
228 .map(Some)
229 .map_err(|error| OpenAuthError::Crypto(error.to_string()))
230}
231
232fn is_expired(value: &Value) -> bool {
233 let Some(exp) = value.get("exp").and_then(Value::as_i64) else {
234 return false;
235 };
236 exp + CLOCK_TOLERANCE_SECONDS < time::OffsetDateTime::now_utc().unix_timestamp()
237}
238
239fn protected_header_kid(token: &str) -> Option<Option<String>> {
240 let protected = token.split('.').next()?;
241 let decoded = URL_SAFE_NO_PAD.decode(protected).ok()?;
242 let header: Map<String, Value> = serde_json::from_slice(&decoded).ok()?;
243 Some(header.get("kid").and_then(Value::as_str).map(str::to_owned))
244}
245
246fn secret_kid(secret: &str, salt: &str) -> Result<String, OpenAuthError> {
247 let encryption_secret = derive_encryption_secret(secret, salt)?;
248 Ok(jwk_thumbprint(&encryption_secret))
249}
250
251fn jwk_thumbprint(key: &[u8; 64]) -> String {
252 let key = URL_SAFE_NO_PAD.encode(key);
253 let canonical = format!(r#"{{"k":"{key}","kty":"oct"}}"#);
254 URL_SAFE_NO_PAD.encode(Sha256::digest(canonical.as_bytes()))
255}
256
257fn random_jti() -> String {
258 let mut bytes = [0_u8; 16];
259 OsRng.fill_bytes(&mut bytes);
260 bytes[6] = (bytes[6] & 0x0f) | 0x40;
261 bytes[8] = (bytes[8] & 0x3f) | 0x80;
262 format!(
263 "{:08x}-{:04x}-{:04x}-{:04x}-{:012x}",
264 u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]),
265 u16::from_be_bytes([bytes[4], bytes[5]]),
266 u16::from_be_bytes([bytes[6], bytes[7]]),
267 u16::from_be_bytes([bytes[8], bytes[9]]),
268 u64::from_be_bytes([
269 0, 0, bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], bytes[15],
270 ])
271 )
272}