use super::claims::Claims;
use super::jwks::Jwk;
use crate::error::AuthError;
use base64::{engine::general_purpose, Engine as _};
use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
use serde::{Deserialize, Serialize};
/// JOSE header of a compact-serialized JWT, deserialized from the token's
/// first (base64url-encoded JSON) segment.
///
/// Header parameters not listed here are ignored during deserialization.
#[derive(Debug, Serialize, Deserialize)]
pub struct JwtHeader {
    /// Signing algorithm name; `JwtParser::decode_header` only accepts `"ES256"`.
    pub alg: String,
    /// Optional key identifier — presumably used by callers to pick the matching JWK; verify against caller.
    pub kid: Option<String>,
    /// Optional token type (typically `"JWT"`); not inspected by `JwtParser`.
    pub typ: Option<String>,
}
/// Stateless helper for parsing and verifying Supabase-issued JWTs.
///
/// Only the `ES256` algorithm (ECDSA on curve P-256) is accepted, both when
/// reading the token header and when building verification keys from a JWK.
pub struct JwtParser;

impl JwtParser {
    /// The only JWS `alg` value this parser accepts.
    const SUPABASE_ALGORITHM: &'static str = "ES256";

    /// Maximum number of characters of a malformed token segment echoed into logs.
    const LOG_PREVIEW_CHARS: usize = 20;

    /// Decodes the header of a compact-serialized JWT and checks its algorithm.
    ///
    /// The signature is NOT verified here; the returned header is untrusted and
    /// should only be used to select a key before calling [`Self::verify_and_decode`].
    ///
    /// # Errors
    ///
    /// - [`AuthError::InvalidToken`] if the token does not have exactly three
    ///   `.`-separated segments, the header segment is empty, or any segment
    ///   contains a character outside `[A-Za-z0-9_=-]`.
    /// - [`AuthError::DecodeHeader`] if the header segment is not valid unpadded
    ///   base64url or does not deserialize into [`JwtHeader`].
    /// - [`AuthError::InvalidAlgorithm`] if the header's `alg` is not `ES256`.
    pub fn decode_header(token: &str) -> Result<JwtHeader, AuthError> {
        let parts: Vec<&str> = token.split('.').collect();
        if parts.len() != 3 || parts[0].is_empty() {
            return Err(AuthError::InvalidToken);
        }
        tracing::debug!(
            "JWT parts lengths - header: {}, payload: {}, signature: {}",
            parts[0].len(),
            parts[1].len(),
            parts[2].len()
        );
        for (i, part) in parts.iter().enumerate() {
            // base64url alphabet; `=` is tolerated here, but the header segment is
            // decoded with the unpadded URL-safe engine below.
            if part
                .chars()
                .any(|c| !c.is_ascii_alphanumeric() && c != '-' && c != '_' && c != '=')
            {
                // `part` may contain multi-byte or control characters at this point:
                // truncate on a char boundary (raw byte slicing could panic) and log
                // with `{:?}` so the untrusted text is escaped instead of written raw.
                tracing::warn!(
                    "JWT part {} contains invalid characters: {:?}",
                    i,
                    Self::log_preview(part)
                );
                return Err(AuthError::InvalidToken);
            }
        }
        let header_bytes = general_purpose::URL_SAFE_NO_PAD
            .decode(parts[0])
            .map_err(|_| AuthError::DecodeHeader)?;
        let header: JwtHeader =
            serde_json::from_slice(&header_bytes).map_err(|_| AuthError::DecodeHeader)?;
        if header.alg != Self::SUPABASE_ALGORITHM {
            return Err(AuthError::InvalidAlgorithm);
        }
        Ok(header)
    }

    /// Returns at most the first [`Self::LOG_PREVIEW_CHARS`] characters of `s`,
    /// always cutting on a UTF-8 character boundary so it can never panic.
    fn log_preview(s: &str) -> &str {
        s.char_indices()
            .nth(Self::LOG_PREVIEW_CHARS)
            .map_or(s, |(end, _)| &s[..end])
    }

    /// Builds a verification key from a JSON Web Key.
    ///
    /// Only elliptic-curve keys (`kty = "EC"`) on curve `P-256` are supported.
    /// Both coordinates must be base64url (unpadded) encodings of exactly 32 bytes.
    ///
    /// # Errors
    ///
    /// - [`AuthError::InvalidKeyComponent`] if `x`, `y` or `crv` is missing, a
    ///   decoded coordinate is not 32 bytes long, or `jsonwebtoken` rejects the
    ///   components.
    /// - [`AuthError::UnsupportedCurve`] if `crv` is not `P-256`.
    /// - [`AuthError::Base64Decode`] if a coordinate is not valid unpadded base64url.
    /// - [`AuthError::UnsupportedKeyType`] for any `kty` other than `EC`.
    pub fn create_decoding_key(jwk: &Jwk) -> Result<DecodingKey, AuthError> {
        match jwk.kty.as_str() {
            "EC" => {
                let x = jwk.x.as_ref().ok_or_else(|| {
                    AuthError::InvalidKeyComponent("Missing x coordinate for EC key".to_string())
                })?;
                let y = jwk.y.as_ref().ok_or_else(|| {
                    AuthError::InvalidKeyComponent("Missing y coordinate for EC key".to_string())
                })?;
                let crv = jwk.crv.as_ref().ok_or_else(|| {
                    AuthError::InvalidKeyComponent("Missing curve type for EC key".to_string())
                })?;
                if crv != "P-256" {
                    return Err(AuthError::UnsupportedCurve(format!(
                        "Expected P-256, but got {crv}"
                    )));
                }
                // Decode only to validate the coordinate sizes; `from_ec_components`
                // below takes the original base64url strings.
                let x_bytes = general_purpose::URL_SAFE_NO_PAD
                    .decode(x)
                    .map_err(|e| AuthError::Base64Decode(format!("Failed to decode x: {e}")))?;
                let y_bytes = general_purpose::URL_SAFE_NO_PAD
                    .decode(y)
                    .map_err(|e| AuthError::Base64Decode(format!("Failed to decode y: {e}")))?;
                const P256_COORD_LEN: usize = 32;
                if x_bytes.len() != P256_COORD_LEN || y_bytes.len() != P256_COORD_LEN {
                    return Err(AuthError::InvalidKeyComponent(format!(
                        "Invalid P-256 coordinate length: got x={}, y={} (expected: {P256_COORD_LEN})",
                        x_bytes.len(),
                        y_bytes.len()
                    )));
                }
                DecodingKey::from_ec_components(x, y).map_err(|e| {
                    AuthError::InvalidKeyComponent(format!(
                        "Failed to create key from EC components: {e}"
                    ))
                })
            }
            unsupported_kty => Err(AuthError::UnsupportedKeyType(unsupported_kty.to_string())),
        }
    }

    /// Verifies `token`'s signature with `decoding_key` and returns its claims.
    ///
    /// `exp` and `nbf` are validated with 30 seconds of leeway; the audience
    /// (`aud`) is not validated.
    ///
    /// # Errors
    ///
    /// Returns [`AuthError::Verification`] on any decoding, signature or claim
    /// validation failure. The underlying error is logged at `warn` level.
    pub fn verify_and_decode(
        token: &str,
        decoding_key: &DecodingKey,
        algorithm: Algorithm,
    ) -> Result<Claims, AuthError> {
        let mut validation = Validation::new(algorithm);
        validation.validate_exp = true;
        validation.validate_nbf = true;
        // NOTE(review): audience is not checked, so any token signed by this key is
        // accepted regardless of its `aud` claim — confirm this is intended.
        validation.validate_aud = false;
        // Tolerate up to 30 seconds of clock skew on `exp` / `nbf`.
        validation.leeway = 30;
        let token_data = decode::<Claims>(token, decoding_key, &validation).map_err(|e| {
            tracing::warn!("JWT validation failed: {:?}", e);
            AuthError::Verification
        })?;
        tracing::debug!(
            "JWT validation successful for user: {}",
            token_data.claims.sub
        );
        Ok(token_data.claims)
    }

    /// Maps a header `alg` string to a [`jsonwebtoken::Algorithm`].
    ///
    /// # Errors
    ///
    /// Returns [`AuthError::InvalidAlgorithm`] (and logs a warning) for anything
    /// other than `ES256`.
    pub fn parse_algorithm(alg: &str) -> Result<Algorithm, AuthError> {
        if alg == Self::SUPABASE_ALGORITHM {
            Ok(Algorithm::ES256)
        } else {
            tracing::warn!(
                "Unsupported JWT algorithm for Supabase: {} (expected: {})",
                alg,
                Self::SUPABASE_ALGORITHM
            );
            Err(AuthError::InvalidAlgorithm)
        }
    }
}