openauth-core 0.0.4

Core types and primitives for OpenAuth.
Documentation
//! JWE helpers for encrypted JWT-style payloads.

use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use base64::Engine;
use hkdf::Hkdf;
use josekit::jwe::{self, Dir, JweHeader};
use rand::rngs::OsRng;
use rand::RngCore;
use serde::de::DeserializeOwned;
use serde::Serialize;
use serde_json::{Map, Number, Value};
use sha2::{Digest, Sha256};

use crate::crypto::SecretConfig;
use crate::error::OpenAuthError;

const JWE_SALT: &str = "better-auth-session";
const JWE_ENC: &str = "A256CBC-HS512";
const HKDF_INFO: &[u8] = b"BetterAuth.js Generated Encryption Key";
const CLOCK_TOLERANCE_SECONDS: i64 = 15;

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct JweSecret {
    value: String,
}

/// Secret material accepted by Better Auth-compatible JWE helpers.
pub trait JweSecretSource {
    fn current_jwe_secret(&self) -> Result<String, OpenAuthError>;
    fn all_jwe_secrets(&self) -> Result<Vec<JweSecret>, OpenAuthError>;
}

impl JweSecretSource for str {
    fn current_jwe_secret(&self) -> Result<String, OpenAuthError> {
        Ok(self.to_owned())
    }

    fn all_jwe_secrets(&self) -> Result<Vec<JweSecret>, OpenAuthError> {
        Ok(vec![JweSecret {
            value: self.to_owned(),
        }])
    }
}

impl JweSecretSource for String {
    fn current_jwe_secret(&self) -> Result<String, OpenAuthError> {
        self.as_str().current_jwe_secret()
    }

    fn all_jwe_secrets(&self) -> Result<Vec<JweSecret>, OpenAuthError> {
        self.as_str().all_jwe_secrets()
    }
}

impl JweSecretSource for SecretConfig {
    fn current_jwe_secret(&self) -> Result<String, OpenAuthError> {
        self.keys
            .get(&self.current_version)
            .cloned()
            .ok_or_else(|| {
                OpenAuthError::InvalidSecretConfig(format!(
                    "secret version {} not found in keys",
                    self.current_version
                ))
            })
    }

    fn all_jwe_secrets(&self) -> Result<Vec<JweSecret>, OpenAuthError> {
        let mut secrets = Vec::new();
        secrets.push(JweSecret {
            value: self.current_jwe_secret()?,
        });
        for (version, value) in &self.keys {
            if *version != self.current_version {
                secrets.push(JweSecret {
                    value: value.clone(),
                });
            }
        }
        if let Some(legacy_secret) = &self.legacy_secret {
            if !secrets.iter().any(|secret| secret.value == *legacy_secret) {
                secrets.push(JweSecret {
                    value: legacy_secret.clone(),
                });
            }
        }
        Ok(secrets)
    }
}

pub fn symmetric_encode_jwt<T, K>(
    payload: &T,
    secret: &K,
    expires_in: u64,
) -> Result<String, OpenAuthError>
where
    T: Serialize,
    K: JweSecretSource + ?Sized,
{
    symmetric_encode_jwt_with_salt(payload, secret, JWE_SALT, expires_in)
}

pub fn symmetric_decode_jwt<T, K>(token: &str, secret: &K) -> Result<Option<T>, OpenAuthError>
where
    T: DeserializeOwned,
    K: JweSecretSource + ?Sized,
{
    symmetric_decode_jwt_with_salt(token, secret, JWE_SALT)
}

pub fn symmetric_encode_jwt_with_salt<T, K>(
    payload: &T,
    secret: &K,
    salt: &str,
    expires_in: u64,
) -> Result<String, OpenAuthError>
where
    T: Serialize,
    K: JweSecretSource + ?Sized,
{
    let current_secret = secret.current_jwe_secret()?;
    let encryption_secret = derive_encryption_secret(&current_secret, salt)?;
    let kid = jwk_thumbprint(&encryption_secret);
    let claims = claims_with_registered_fields(payload, expires_in)?;

    let mut header = JweHeader::new();
    header.set_content_encryption(JWE_ENC);
    header.set_key_id(kid);

    let encrypter = Dir
        .encrypter_from_bytes(encryption_secret)
        .map_err(|error| OpenAuthError::Crypto(error.to_string()))?;
    jwe::serialize_compact(&claims, &header, &encrypter)
        .map_err(|error| OpenAuthError::Crypto(error.to_string()))
}

pub fn symmetric_decode_jwt_with_salt<T, K>(
    token: &str,
    secret: &K,
    salt: &str,
) -> Result<Option<T>, OpenAuthError>
where
    T: DeserializeOwned,
    K: JweSecretSource + ?Sized,
{
    if token.is_empty() {
        return Ok(None);
    }
    let Some(kid) = protected_header_kid(token) else {
        return Ok(None);
    };
    let secrets = secret.all_jwe_secrets()?;

    if let Some(kid) = kid {
        let Some(secret) = secrets
            .iter()
            .find(|secret| secret_kid(&secret.value, salt).is_ok_and(|value| value == kid))
        else {
            return Ok(None);
        };
        return decrypt_with_secret(token, &secret.value, salt);
    }

    for secret in secrets {
        if let Some(payload) = decrypt_with_secret(token, &secret.value, salt)? {
            return Ok(Some(payload));
        }
    }
    Ok(None)
}

fn derive_encryption_secret(secret: &str, salt: &str) -> Result<[u8; 64], OpenAuthError> {
    let hkdf = Hkdf::<Sha256>::new(Some(salt.as_bytes()), secret.as_bytes());
    let mut key = [0_u8; 64];
    hkdf.expand(HKDF_INFO, &mut key)
        .map_err(|error| OpenAuthError::Crypto(error.to_string()))?;
    Ok(key)
}

fn claims_with_registered_fields<T>(payload: &T, expires_in: u64) -> Result<Vec<u8>, OpenAuthError>
where
    T: Serialize,
{
    let mut claims = match serde_json::to_value(payload)
        .map_err(|error| OpenAuthError::Crypto(error.to_string()))?
    {
        Value::Object(claims) => claims,
        _ => {
            return Err(OpenAuthError::Crypto(
                "JWE payload must serialize to a JSON object".to_owned(),
            ));
        }
    };

    let now = time::OffsetDateTime::now_utc().unix_timestamp();
    claims.insert("iat".to_owned(), Value::Number(Number::from(now)));
    claims.insert(
        "exp".to_owned(),
        Value::Number(Number::from(now + expires_in as i64)),
    );
    claims.insert("jti".to_owned(), Value::String(random_jti()));

    serde_json::to_vec(&Value::Object(claims))
        .map_err(|error| OpenAuthError::Crypto(error.to_string()))
}

fn decrypt_with_secret<T>(token: &str, secret: &str, salt: &str) -> Result<Option<T>, OpenAuthError>
where
    T: DeserializeOwned,
{
    let encryption_secret = derive_encryption_secret(secret, salt)?;
    let decrypter = Dir
        .decrypter_from_bytes(encryption_secret)
        .map_err(|error| OpenAuthError::Crypto(error.to_string()))?;
    let Ok((payload, header)) = jwe::deserialize_compact(token, &decrypter) else {
        return Ok(None);
    };
    if header.content_encryption() != Some(JWE_ENC) {
        return Ok(None);
    }

    let value: Value = serde_json::from_slice(&payload)
        .map_err(|error| OpenAuthError::Crypto(format!("could not parse JWE payload: {error}")))?;
    if is_expired(&value) {
        return Ok(None);
    }
    serde_json::from_value(value)
        .map(Some)
        .map_err(|error| OpenAuthError::Crypto(error.to_string()))
}

fn is_expired(value: &Value) -> bool {
    let Some(exp) = value.get("exp").and_then(Value::as_i64) else {
        return false;
    };
    exp + CLOCK_TOLERANCE_SECONDS < time::OffsetDateTime::now_utc().unix_timestamp()
}

fn protected_header_kid(token: &str) -> Option<Option<String>> {
    let protected = token.split('.').next()?;
    let decoded = URL_SAFE_NO_PAD.decode(protected).ok()?;
    let header: Map<String, Value> = serde_json::from_slice(&decoded).ok()?;
    Some(header.get("kid").and_then(Value::as_str).map(str::to_owned))
}

fn secret_kid(secret: &str, salt: &str) -> Result<String, OpenAuthError> {
    let encryption_secret = derive_encryption_secret(secret, salt)?;
    Ok(jwk_thumbprint(&encryption_secret))
}

fn jwk_thumbprint(key: &[u8; 64]) -> String {
    let key = URL_SAFE_NO_PAD.encode(key);
    let canonical = format!(r#"{{"k":"{key}","kty":"oct"}}"#);
    URL_SAFE_NO_PAD.encode(Sha256::digest(canonical.as_bytes()))
}

fn random_jti() -> String {
    let mut bytes = [0_u8; 16];
    OsRng.fill_bytes(&mut bytes);
    bytes[6] = (bytes[6] & 0x0f) | 0x40;
    bytes[8] = (bytes[8] & 0x3f) | 0x80;
    format!(
        "{:08x}-{:04x}-{:04x}-{:04x}-{:012x}",
        u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]),
        u16::from_be_bytes([bytes[4], bytes[5]]),
        u16::from_be_bytes([bytes[6], bytes[7]]),
        u16::from_be_bytes([bytes[8], bytes[9]]),
        u64::from_be_bytes([
            0, 0, bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], bytes[15],
        ])
    )
}