Skip to main content

openauth_core/crypto/
jwe.rs

//! JWE helpers for encrypting and decrypting JWT-style payloads (Better Auth compatible).
2
3use base64::engine::general_purpose::URL_SAFE_NO_PAD;
4use base64::Engine;
5use hkdf::Hkdf;
6use josekit::jwe::{self, Dir, JweHeader};
7use rand::rngs::OsRng;
8use rand::RngCore;
9use serde::de::DeserializeOwned;
10use serde::Serialize;
11use serde_json::{Map, Number, Value};
12use sha2::{Digest, Sha256};
13
14use crate::crypto::SecretConfig;
15use crate::error::OpenAuthError;
16
/// Content-encryption algorithm written to, and required in, the protected header.
const JWE_ENC: &str = "A256CBC-HS512";
/// Default HKDF salt used by `symmetric_encode_jwt` / `symmetric_decode_jwt`.
const JWE_SALT: &str = "better-auth-session";
/// HKDF `info` string used when deriving the 64-byte encryption key.
const HKDF_INFO: &[u8] = b"BetterAuth.js Generated Encryption Key";
/// Leeway (in seconds) granted to the `exp` claim when checking expiry.
const CLOCK_TOLERANCE_SECONDS: i64 = 15;
21
/// One candidate secret tried when decrypting a token.
///
/// `Debug` is implemented by hand so the raw secret never ends up in logs or panic messages;
/// equality and cloning are still derived.
#[derive(Clone, PartialEq, Eq)]
pub struct JweSecret {
    value: String,
}

impl std::fmt::Debug for JweSecret {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Deliberately omit `value`: it is key material.
        f.debug_struct("JweSecret").finish_non_exhaustive()
    }
}
26
27/// Secret material accepted by Better Auth-compatible JWE helpers.
28pub trait JweSecretSource {
29    fn current_jwe_secret(&self) -> Result<String, OpenAuthError>;
30    fn all_jwe_secrets(&self) -> Result<Vec<JweSecret>, OpenAuthError>;
31}
32
33impl JweSecretSource for str {
34    fn current_jwe_secret(&self) -> Result<String, OpenAuthError> {
35        Ok(self.to_owned())
36    }
37
38    fn all_jwe_secrets(&self) -> Result<Vec<JweSecret>, OpenAuthError> {
39        Ok(vec![JweSecret {
40            value: self.to_owned(),
41        }])
42    }
43}
44
45impl JweSecretSource for String {
46    fn current_jwe_secret(&self) -> Result<String, OpenAuthError> {
47        self.as_str().current_jwe_secret()
48    }
49
50    fn all_jwe_secrets(&self) -> Result<Vec<JweSecret>, OpenAuthError> {
51        self.as_str().all_jwe_secrets()
52    }
53}
54
55impl JweSecretSource for SecretConfig {
56    fn current_jwe_secret(&self) -> Result<String, OpenAuthError> {
57        self.keys
58            .get(&self.current_version)
59            .cloned()
60            .ok_or_else(|| {
61                OpenAuthError::InvalidSecretConfig(format!(
62                    "secret version {} not found in keys",
63                    self.current_version
64                ))
65            })
66    }
67
68    fn all_jwe_secrets(&self) -> Result<Vec<JweSecret>, OpenAuthError> {
69        let mut secrets = Vec::new();
70        secrets.push(JweSecret {
71            value: self.current_jwe_secret()?,
72        });
73        for (version, value) in &self.keys {
74            if *version != self.current_version {
75                secrets.push(JweSecret {
76                    value: value.clone(),
77                });
78            }
79        }
80        if let Some(legacy_secret) = &self.legacy_secret {
81            if !secrets.iter().any(|secret| secret.value == *legacy_secret) {
82                secrets.push(JweSecret {
83                    value: legacy_secret.clone(),
84                });
85            }
86        }
87        Ok(secrets)
88    }
89}
90
91pub fn symmetric_encode_jwt<T, K>(
92    payload: &T,
93    secret: &K,
94    expires_in: u64,
95) -> Result<String, OpenAuthError>
96where
97    T: Serialize,
98    K: JweSecretSource + ?Sized,
99{
100    symmetric_encode_jwt_with_salt(payload, secret, JWE_SALT, expires_in)
101}
102
103pub fn symmetric_decode_jwt<T, K>(token: &str, secret: &K) -> Result<Option<T>, OpenAuthError>
104where
105    T: DeserializeOwned,
106    K: JweSecretSource + ?Sized,
107{
108    symmetric_decode_jwt_with_salt(token, secret, JWE_SALT)
109}
110
111pub fn symmetric_encode_jwt_with_salt<T, K>(
112    payload: &T,
113    secret: &K,
114    salt: &str,
115    expires_in: u64,
116) -> Result<String, OpenAuthError>
117where
118    T: Serialize,
119    K: JweSecretSource + ?Sized,
120{
121    let current_secret = secret.current_jwe_secret()?;
122    let encryption_secret = derive_encryption_secret(&current_secret, salt)?;
123    let kid = jwk_thumbprint(&encryption_secret);
124    let claims = claims_with_registered_fields(payload, expires_in)?;
125
126    let mut header = JweHeader::new();
127    header.set_content_encryption(JWE_ENC);
128    header.set_key_id(kid);
129
130    let encrypter = Dir
131        .encrypter_from_bytes(encryption_secret)
132        .map_err(|error| OpenAuthError::Crypto(error.to_string()))?;
133    jwe::serialize_compact(&claims, &header, &encrypter)
134        .map_err(|error| OpenAuthError::Crypto(error.to_string()))
135}
136
137pub fn symmetric_decode_jwt_with_salt<T, K>(
138    token: &str,
139    secret: &K,
140    salt: &str,
141) -> Result<Option<T>, OpenAuthError>
142where
143    T: DeserializeOwned,
144    K: JweSecretSource + ?Sized,
145{
146    if token.is_empty() {
147        return Ok(None);
148    }
149    let Some(kid) = protected_header_kid(token) else {
150        return Ok(None);
151    };
152    let secrets = secret.all_jwe_secrets()?;
153
154    if let Some(kid) = kid {
155        let Some(secret) = secrets
156            .iter()
157            .find(|secret| secret_kid(&secret.value, salt).is_ok_and(|value| value == kid))
158        else {
159            return Ok(None);
160        };
161        return decrypt_with_secret(token, &secret.value, salt);
162    }
163
164    for secret in secrets {
165        if let Some(payload) = decrypt_with_secret(token, &secret.value, salt)? {
166            return Ok(Some(payload));
167        }
168    }
169    Ok(None)
170}
171
172fn derive_encryption_secret(secret: &str, salt: &str) -> Result<[u8; 64], OpenAuthError> {
173    let hkdf = Hkdf::<Sha256>::new(Some(salt.as_bytes()), secret.as_bytes());
174    let mut key = [0_u8; 64];
175    hkdf.expand(HKDF_INFO, &mut key)
176        .map_err(|error| OpenAuthError::Crypto(error.to_string()))?;
177    Ok(key)
178}
179
180fn claims_with_registered_fields<T>(payload: &T, expires_in: u64) -> Result<Vec<u8>, OpenAuthError>
181where
182    T: Serialize,
183{
184    let mut claims = match serde_json::to_value(payload)
185        .map_err(|error| OpenAuthError::Crypto(error.to_string()))?
186    {
187        Value::Object(claims) => claims,
188        _ => {
189            return Err(OpenAuthError::Crypto(
190                "JWE payload must serialize to a JSON object".to_owned(),
191            ));
192        }
193    };
194
195    let now = time::OffsetDateTime::now_utc().unix_timestamp();
196    claims.insert("iat".to_owned(), Value::Number(Number::from(now)));
197    claims.insert(
198        "exp".to_owned(),
199        Value::Number(Number::from(now + expires_in as i64)),
200    );
201    claims.insert("jti".to_owned(), Value::String(random_jti()));
202
203    serde_json::to_vec(&Value::Object(claims))
204        .map_err(|error| OpenAuthError::Crypto(error.to_string()))
205}
206
207fn decrypt_with_secret<T>(token: &str, secret: &str, salt: &str) -> Result<Option<T>, OpenAuthError>
208where
209    T: DeserializeOwned,
210{
211    let encryption_secret = derive_encryption_secret(secret, salt)?;
212    let decrypter = Dir
213        .decrypter_from_bytes(encryption_secret)
214        .map_err(|error| OpenAuthError::Crypto(error.to_string()))?;
215    let Ok((payload, header)) = jwe::deserialize_compact(token, &decrypter) else {
216        return Ok(None);
217    };
218    if header.content_encryption() != Some(JWE_ENC) {
219        return Ok(None);
220    }
221
222    let value: Value = serde_json::from_slice(&payload)
223        .map_err(|error| OpenAuthError::Crypto(format!("could not parse JWE payload: {error}")))?;
224    if is_expired(&value) {
225        return Ok(None);
226    }
227    serde_json::from_value(value)
228        .map(Some)
229        .map_err(|error| OpenAuthError::Crypto(error.to_string()))
230}
231
232fn is_expired(value: &Value) -> bool {
233    let Some(exp) = value.get("exp").and_then(Value::as_i64) else {
234        return false;
235    };
236    exp + CLOCK_TOLERANCE_SECONDS < time::OffsetDateTime::now_utc().unix_timestamp()
237}
238
239fn protected_header_kid(token: &str) -> Option<Option<String>> {
240    let protected = token.split('.').next()?;
241    let decoded = URL_SAFE_NO_PAD.decode(protected).ok()?;
242    let header: Map<String, Value> = serde_json::from_slice(&decoded).ok()?;
243    Some(header.get("kid").and_then(Value::as_str).map(str::to_owned))
244}
245
246fn secret_kid(secret: &str, salt: &str) -> Result<String, OpenAuthError> {
247    let encryption_secret = derive_encryption_secret(secret, salt)?;
248    Ok(jwk_thumbprint(&encryption_secret))
249}
250
251fn jwk_thumbprint(key: &[u8; 64]) -> String {
252    let key = URL_SAFE_NO_PAD.encode(key);
253    let canonical = format!(r#"{{"k":"{key}","kty":"oct"}}"#);
254    URL_SAFE_NO_PAD.encode(Sha256::digest(canonical.as_bytes()))
255}
256
257fn random_jti() -> String {
258    let mut bytes = [0_u8; 16];
259    OsRng.fill_bytes(&mut bytes);
260    bytes[6] = (bytes[6] & 0x0f) | 0x40;
261    bytes[8] = (bytes[8] & 0x3f) | 0x80;
262    format!(
263        "{:08x}-{:04x}-{:04x}-{:04x}-{:012x}",
264        u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]),
265        u16::from_be_bytes([bytes[4], bytes[5]]),
266        u16::from_be_bytes([bytes[6], bytes[7]]),
267        u16::from_be_bytes([bytes[8], bytes[9]]),
268        u64::from_be_bytes([
269            0, 0, bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], bytes[15],
270        ])
271    )
272}