use crate::error::AuthiaError;
use crate::types::{TokenPayload, VerifyOptions};
use base64::prelude::*;
use ed25519_dalek::{Signature, Verifier, VerifyingKey};
use serde::de::DeserializeOwned;
use serde_json::Value;
use std::sync::OnceLock;
use subtle::ConstantTimeEq;
/// Clock-skew allowance in seconds (300 s = 5 minutes).
// NOTE(review): not referenced by any code visible in this file; presumably meant for
// time-based claim checks (`exp`/`nbf`) — confirm where it is used, or remove it.
const CLOCK_SKEW_SECONDS: i64 = 300;
/// Process-wide single-slot cache pairing a JWK source string with its decoded key.
///
/// `OnceLock` can be written only once, so this slot holds whichever key is stored
/// first and is never replaced afterwards.
static CACHED_PUBLIC_KEY: OnceLock<(String, VerifyingKey)> = OnceLock::new();
pub fn verify_jwt<T: DeserializeOwned + TokenPayload>(
token: &str,
options: &VerifyOptions,
) -> Result<T, AuthiaError> {
let parts: Vec<&str> = token.split('.').collect();
if parts.len() != 3 {
return Err(AuthiaError::malformed_token("JWT must have 3 parts"));
}
let header_b64 = parts[0];
let payload_b64 = parts[1];
let signature_b64 = parts[2];
let header_bytes = BASE64_URL_SAFE_NO_PAD
.decode(header_b64)
.map_err(|_| AuthiaError::malformed_token("Invalid header encoding"))?;
if !header_bytes.windows(5).any(|w| w == b"EdDSA") {
return Err(AuthiaError::invalid_signature());
}
let payload_bytes = BASE64_URL_SAFE_NO_PAD
.decode(payload_b64)
.map_err(|e| AuthiaError::malformed_token(format!("Invalid payload encoding: {}", e)))?;
let payload: T = serde_json::from_slice(&payload_bytes)
.map_err(|e| AuthiaError::invalid_claims(format!("Failed to parse payload: {}", e)))?;
let now = js_sys::Date::now() as i64 / 1000;
payload.validate(options, now)?;
let signature_bytes = BASE64_URL_SAFE_NO_PAD
.decode(signature_b64)
.map_err(|e| AuthiaError::malformed_token(format!("Invalid signature encoding: {}", e)))?;
let signature =
Signature::from_slice(&signature_bytes).map_err(|_| AuthiaError::invalid_signature())?;
let public_key = get_or_decode_public_key(options)?;
let mut message = Vec::with_capacity(header_b64.len() + 1 + payload_b64.len());
message.extend_from_slice(header_b64.as_bytes());
message.push(b'.');
message.extend_from_slice(payload_b64.as_bytes());
public_key
.verify(&message, &signature)
.map_err(|_| AuthiaError::invalid_signature())?;
Ok(payload)
}
fn get_or_decode_public_key(options: &VerifyOptions) -> Result<&'static VerifyingKey, AuthiaError> {
if let Some(raw_jwk) = &options.public_key_jwk_raw {
if let Some((cached_jwk, cached_key)) = CACHED_PUBLIC_KEY.get() {
if cached_jwk == raw_jwk {
return Ok(cached_key);
}
}
let key = decode_public_key_from_raw_json(raw_jwk)?;
let _ = CACHED_PUBLIC_KEY.set((raw_jwk.to_string(), key));
return CACHED_PUBLIC_KEY
.get()
.map(|(_, k)| k)
.ok_or_else(|| AuthiaError::invalid_public_key("Failed to cache public key"));
}
if let Some(jwk_base64) = &options.public_key_jwk {
if let Some((cached_jwk, cached_key)) = CACHED_PUBLIC_KEY.get() {
if cached_jwk == jwk_base64 {
return Ok(cached_key);
}
}
let key = decode_public_key_jwk(jwk_base64)?;
let _ = CACHED_PUBLIC_KEY.set((jwk_base64.to_string(), key));
return CACHED_PUBLIC_KEY
.get()
.map(|(_, k)| k)
.ok_or_else(|| AuthiaError::invalid_public_key("Failed to cache public key"));
}
Err(AuthiaError::invalid_public_key(
"No public key provided (either publicKeyJwk or publicKeyJwkRaw required)",
))
}
fn decode_public_key_from_raw_json(jwk_input: &str) -> Result<VerifyingKey, AuthiaError> {
let jwk_trimmed = jwk_input.trim();
let jwk_json_bytes = if jwk_trimmed.starts_with('{') {
jwk_trimmed.as_bytes().to_vec()
} else {
BASE64_STANDARD
.decode(jwk_trimmed)
.map_err(|e| AuthiaError::invalid_public_key(format!("Invalid public key format: expected JSON but could not decode as Base64: {}", e)))?
};
let jwk: Value = serde_json::from_slice(&jwk_json_bytes)
.map_err(|e| AuthiaError::invalid_public_key(format!("Invalid JWK JSON: {}", e)))?;
let kty = jwk
.get("kty")
.and_then(|v| v.as_str())
.ok_or_else(|| AuthiaError::invalid_public_key("Missing 'kty' parameter in JWK"))?;
if kty != "OKP" {
return Err(AuthiaError::invalid_public_key(format!(
"Invalid 'kty': expected 'OKP', got '{}'",
kty
)));
}
let crv = jwk
.get("crv")
.and_then(|v| v.as_str())
.ok_or_else(|| AuthiaError::invalid_public_key("Missing 'crv' parameter in JWK"))?;
if crv != "Ed25519" {
return Err(AuthiaError::invalid_public_key(format!(
"Invalid 'crv': expected 'Ed25519', got '{}'",
crv
)));
}
let x = jwk
.get("x")
.and_then(|v| v.as_str())
.ok_or_else(|| AuthiaError::invalid_public_key("Missing 'x' parameter in JWK"))?;
let key_bytes = BASE64_URL_SAFE_NO_PAD
.decode(x)
.map_err(|e| AuthiaError::invalid_public_key(format!("Invalid 'x' encoding: {}", e)))?;
VerifyingKey::from_bytes(
key_bytes.as_slice().try_into().map_err(|_| {
AuthiaError::invalid_public_key("Invalid key length (expected 32 bytes)")
})?,
)
.map_err(|e| AuthiaError::invalid_public_key(format!("Invalid Ed25519 key: {}", e)))
}
fn decode_public_key_jwk(jwk_base64: &str) -> Result<VerifyingKey, AuthiaError> {
let jwk_json = BASE64_STANDARD
.decode(jwk_base64)
.map_err(|e| AuthiaError::invalid_public_key(format!("Invalid base64: {}", e)))?;
let jwk: Value = serde_json::from_slice(&jwk_json)
.map_err(|e| AuthiaError::invalid_public_key(format!("Invalid JWK JSON: {}", e)))?;
let kty = jwk
.get("kty")
.and_then(|v| v.as_str())
.ok_or_else(|| AuthiaError::invalid_public_key("Missing 'kty' parameter in JWK"))?;
if kty != "OKP" {
return Err(AuthiaError::invalid_public_key(format!(
"Invalid 'kty': expected 'OKP', got '{}'",
kty
)));
}
let crv = jwk
.get("crv")
.and_then(|v| v.as_str())
.ok_or_else(|| AuthiaError::invalid_public_key("Missing 'crv' parameter in JWK"))?;
if crv != "Ed25519" {
return Err(AuthiaError::invalid_public_key(format!(
"Invalid 'crv': expected 'Ed25519', got '{}'",
crv
)));
}
let x = jwk
.get("x")
.and_then(|v| v.as_str())
.ok_or_else(|| AuthiaError::invalid_public_key("Missing 'x' parameter in JWK"))?;
let key_bytes = BASE64_URL_SAFE_NO_PAD
.decode(x)
.map_err(|e| AuthiaError::invalid_public_key(format!("Invalid 'x' encoding: {}", e)))?;
VerifyingKey::from_bytes(
key_bytes.as_slice().try_into().map_err(|_| {
AuthiaError::invalid_public_key("Invalid key length (expected 32 bytes)")
})?,
)
.map_err(|e| AuthiaError::invalid_public_key(format!("Invalid Ed25519 key: {}", e)))
}
#[cfg(test)]
mod tests {
    use super::{decode_public_key_from_raw_json, decode_public_key_jwk};
    use base64::prelude::*;

    /// Ed25519 public key from RFC 8037, Appendix A.2, as a JWK JSON object.
    const RFC8037_JWK: &str =
        r#"{"kty":"OKP","crv":"Ed25519","x":"11qYAYKxCrfVS_7TyWQHOg7hcvPapiMlrwIaaPcHURo"}"#;

    /// Leading bytes of the `x` coordinate above (the RFC 8032 "TEST 1" public key).
    const RFC8037_X_PREFIX: [u8; 4] = [0xd7, 0x5a, 0x98, 0x01];

    #[test]
    fn test_jwt_split() {
        let token = "header.payload.signature";
        let parts: Vec<&str> = token.split('.').collect();
        assert_eq!(parts.len(), 3);
    }

    #[test]
    fn raw_jwk_accepts_json_padded_json_and_base64() {
        let key = decode_public_key_from_raw_json(RFC8037_JWK)
            .ok()
            .expect("RFC 8037 A.2 key must decode");
        assert_eq!(&key.to_bytes()[..4], &RFC8037_X_PREFIX);

        let padded = format!("  {}\n", RFC8037_JWK);
        assert_eq!(decode_public_key_from_raw_json(&padded).ok(), Some(key));

        let encoded = BASE64_STANDARD.encode(RFC8037_JWK);
        assert_eq!(decode_public_key_from_raw_json(&encoded).ok(), Some(key));
    }

    #[test]
    fn base64_jwk_decodes_to_same_key_as_raw_json() {
        let encoded = BASE64_STANDARD.encode(RFC8037_JWK);
        let from_base64 = decode_public_key_jwk(&encoded).ok();
        assert!(from_base64.is_some());
        assert_eq!(from_base64, decode_public_key_from_raw_json(RFC8037_JWK).ok());
    }

    #[test]
    fn rejects_invalid_jwks() {
        let bad = [
            r#"{"crv":"Ed25519","x":"11qYAYKxCrfVS_7TyWQHOg7hcvPapiMlrwIaaPcHURo"}"#,
            r#"{"kty":"RSA","crv":"Ed25519","x":"11qYAYKxCrfVS_7TyWQHOg7hcvPapiMlrwIaaPcHURo"}"#,
            r#"{"kty":"OKP","crv":"X25519","x":"11qYAYKxCrfVS_7TyWQHOg7hcvPapiMlrwIaaPcHURo"}"#,
            r#"{"kty":"OKP","crv":"Ed25519"}"#,
            r#"{"kty":"OKP","crv":"Ed25519","x":"AAAA"}"#,
            "{not json",
        ];
        for jwk in bad {
            assert!(decode_public_key_from_raw_json(jwk).is_err(), "raw accepted {jwk}");
            let encoded = BASE64_STANDARD.encode(jwk);
            assert!(decode_public_key_jwk(&encoded).is_err(), "base64 accepted {jwk}");
        }
        assert!(decode_public_key_jwk("%%% not base64 %%%").is_err());
        assert!(decode_public_key_from_raw_json("%%% not base64 %%%").is_err());
    }
}