use jsonwebtoken::{decode, decode_header, DecodingKey, Validation};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::error::SdkError;
/// Claim set decoded from a token by `TokenValidator::validate` and
/// `TokenValidator::validate_ignore_exp`.
///
/// Each field maps onto the JSON claim of the same name. The optional claims
/// are left out of serialized output when they are `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SignedByClaims {
/// `sub` claim.
pub sub: String,
/// `iss` claim; the validator requires it to match `ValidationConfig::issuer`.
pub iss: String,
/// `aud` claim; the validator requires it to match `ValidationConfig::audience`.
// NOTE(review): typed as a single string, so a token whose `aud` is a JSON
// array would fail to deserialize here — confirm the issuer never emits the
// array form.
pub aud: String,
/// `exp` claim. Presumably seconds since the Unix epoch (JWT convention) — confirm.
pub exp: u64,
/// `iat` claim. Presumably seconds since the Unix epoch (JWT convention) — confirm.
pub iat: u64,
/// Optional `nonce` claim.
#[serde(skip_serializing_if = "Option::is_none")]
pub nonce: Option<String>,
/// Optional `client_id` claim.
#[serde(skip_serializing_if = "Option::is_none")]
pub client_id: Option<String>,
/// Optional `amr` claim (list of strings).
#[serde(skip_serializing_if = "Option::is_none")]
pub amr: Option<Vec<String>>,
/// Optional `merkle_root` claim. Its meaning is not visible in this file.
#[serde(skip_serializing_if = "Option::is_none")]
pub merkle_root: Option<String>,
/// Optional `proof_verified` claim. Its meaning is not visible in this file.
#[serde(skip_serializing_if = "Option::is_none")]
pub proof_verified: Option<bool>,
}
/// Settings that control how a `TokenValidator` checks tokens.
#[derive(Debug, Clone)]
pub struct ValidationConfig {
/// Expected `iss` value. Also used as the base URL for the OIDC discovery
/// request made by `TokenValidator::fetch_keys`.
pub issuer: String,
/// Expected `aud` value, passed as the single accepted audience.
pub audience: String,
/// Optional JWKS document (JSON text). When present it is parsed by
/// `TokenValidator::new` to seed the key set.
pub jwks: Option<String>,
/// Clock-skew tolerance copied into `jsonwebtoken::Validation::leeway` by
/// `TokenValidator::validate` (seconds, per the jsonwebtoken documentation).
pub leeway: u64,
}
impl Default for ValidationConfig {
fn default() -> Self {
Self {
issuer: "https://api.beta.privacy-lion.com".into(),
audience: String::new(),
jwks: None,
leeway: 60,
}
}
}
/// Verifies tokens against a configured issuer/audience and a set of
/// public keys looked up by the token's `kid` header.
pub struct TokenValidator {
// Issuer, audience and leeway settings applied during validation.
config: ValidationConfig,
// Verification keys indexed by JWK `kid`.
keys: HashMap<String, DecodingKey>,
}
/// Deserialization target for a JWKS document: a JSON object with a `keys` array.
#[derive(Debug, Deserialize)]
struct Jwks {
// Every key entry found in the document, in document order.
keys: Vec<Jwk>,
}
/// One entry of a JWKS `keys` array.
///
/// `kid` and `kty` are mandatory; the remaining members are optional so that
/// RSA and EC keys can share one shape.
#[derive(Debug, Deserialize)]
struct Jwk {
    /// Key identifier; becomes the lookup key in the validator's key map.
    kid: String,
    /// Key type; `"RSA"` and `"EC"` are the values `parse_jwks` understands.
    kty: String,
    /// The JSON `use` member (renamed because `use` is a Rust keyword).
    #[serde(rename = "use")]
    use_: Option<String>,
    /// RSA modulus, as it appears in the JWK.
    n: Option<String>,
    /// RSA public exponent, as it appears in the JWK.
    e: Option<String>,
    /// EC public key x coordinate, as it appears in the JWK.
    x: Option<String>,
    /// EC public key y coordinate, as it appears in the JWK.
    y: Option<String>,
    /// The JSON `crv` member.
    crv: Option<String>,
}
impl TokenValidator {
    /// Builds a validator from `config`.
    ///
    /// When `config.jwks` is set it is parsed as a JWKS document and becomes
    /// the initial key set; otherwise the validator starts with no keys.
    ///
    /// # Errors
    ///
    /// Returns an error if the supplied JWKS document cannot be parsed.
    pub fn new(config: ValidationConfig) -> Result<Self, SdkError> {
        let keys = match config.jwks.as_deref() {
            Some(jwks_json) => parse_jwks(jwks_json)?,
            None => HashMap::new(),
        };
        Ok(Self { config, keys })
    }

    /// Downloads the issuer's signing keys via OIDC discovery and replaces the
    /// current key set with them.
    ///
    /// The discovery document is fetched from
    /// `<issuer>/.well-known/openid-configuration`; its `jwks_uri` member is
    /// then fetched and parsed. The existing keys are only replaced once the
    /// new document has been parsed successfully.
    ///
    /// # Errors
    ///
    /// Returns an error if either request fails or answers with a non-success
    /// HTTP status, if the discovery document has no string `jwks_uri`, or if
    /// the JWKS document cannot be parsed.
    #[cfg(feature = "oidc")]
    pub async fn fetch_keys(&mut self) -> Result<(), SdkError> {
        // OIDC Discovery requires any trailing '/' on the issuer to be removed
        // before the well-known path is appended; otherwise we would request
        // ".../​/.well-known/..." with a doubled slash.
        let discovery_url = format!(
            "{}/.well-known/openid-configuration",
            self.config.issuer.trim_end_matches('/')
        );
        let client = reqwest::Client::new();

        // `error_for_status` turns 4xx/5xx answers into errors instead of
        // feeding an error page to the JSON parser.
        let discovery: serde_json::Value = client
            .get(&discovery_url)
            .send()
            .await?
            .error_for_status()?
            .json()
            .await?;

        let jwks_uri = discovery["jwks_uri"]
            .as_str()
            .ok_or_else(|| SdkError::OidcError("No jwks_uri in discovery".into()))?;

        let jwks_json = client
            .get(jwks_uri)
            .send()
            .await?
            .error_for_status()?
            .text()
            .await?;

        self.keys = parse_jwks(&jwks_json)?;
        Ok(())
    }

    /// Verifies `token` (signature, issuer, audience and expiry) and returns
    /// its claims.
    ///
    /// # Errors
    ///
    /// Returns an error if the header cannot be decoded, the header has no
    /// `kid`, the `kid` is not in the key set, or `jsonwebtoken::decode`
    /// rejects the token.
    pub fn validate(&self, token: &str) -> Result<SignedByClaims, SdkError> {
        self.decode_claims(token, true)
    }

    /// Same as [`validate`](Self::validate) but with the expiry check
    /// switched off, so an expired token that is otherwise valid is accepted.
    ///
    /// # Errors
    ///
    /// Same conditions as [`validate`](Self::validate), minus expiry.
    pub fn validate_ignore_exp(&self, token: &str) -> Result<SignedByClaims, SdkError> {
        self.decode_claims(token, false)
    }

    /// Shared decode path for both public validation methods: picks the key by
    /// `kid`, configures validation, and decodes the claims.
    fn decode_claims(&self, token: &str, check_expiry: bool) -> Result<SignedByClaims, SdkError> {
        let header = decode_header(token)?;
        let kid = header
            .kid
            .ok_or_else(|| SdkError::JwtError("Token missing kid header".into()))?;
        let key = self
            .keys
            .get(&kid)
            .ok_or_else(|| SdkError::JwtError(format!("Unknown key ID: {}", kid)))?;

        // NOTE(review): the accepted algorithm is taken from the token header.
        // This relies on jsonwebtoken rejecting an algorithm whose family does
        // not match the selected key — confirm for the pinned crate version.
        let mut validation = Validation::new(header.alg);
        validation.set_issuer(&[&self.config.issuer]);
        validation.set_audience(&[&self.config.audience]);
        // Applied on both paths so the two entry points cannot drift apart.
        validation.leeway = self.config.leeway;
        validation.validate_exp = check_expiry;

        let token_data = decode::<SignedByClaims>(token, key, &validation)?;
        Ok(token_data.claims)
    }
}
/// Parses a JWKS document into a map from `kid` to verification key.
///
/// Only keys usable for signature verification are kept:
/// - entries whose `use` member is present and not `"sig"` are skipped, so an
///   encryption key is never trusted to verify a token;
/// - entries whose `kty` is neither `"RSA"` nor `"EC"` are skipped.
///
/// If two entries share a `kid`, the later one wins.
///
/// # Errors
///
/// Returns an error if the document is not valid JWKS JSON, if an RSA entry
/// lacks `n` or `e`, if an EC entry lacks `x` or `y`, or if the key material
/// cannot be turned into a `DecodingKey`.
fn parse_jwks(jwks_json: &str) -> Result<HashMap<String, DecodingKey>, SdkError> {
    let jwks: Jwks = serde_json::from_str(jwks_json)?;
    let mut keys = HashMap::with_capacity(jwks.keys.len());

    for jwk in jwks.keys {
        // A missing `use` is treated as "unrestricted"; an explicit value
        // other than "sig" (e.g. "enc") excludes the key.
        if matches!(jwk.use_.as_deref(), Some(usage) if usage != "sig") {
            continue;
        }

        let key = match jwk.kty.as_str() {
            "RSA" => {
                let n = jwk
                    .n
                    .ok_or_else(|| SdkError::InvalidInput("RSA key missing n".into()))?;
                let e = jwk
                    .e
                    .ok_or_else(|| SdkError::InvalidInput("RSA key missing e".into()))?;
                DecodingKey::from_rsa_components(&n, &e)?
            }
            "EC" => {
                let x = jwk
                    .x
                    .ok_or_else(|| SdkError::InvalidInput("EC key missing x".into()))?;
                let y = jwk
                    .y
                    .ok_or_else(|| SdkError::InvalidInput("EC key missing y".into()))?;
                DecodingKey::from_ec_components(&x, &y)?
            }
            // Unsupported key types are ignored rather than failing the whole set.
            _ => continue,
        };
        keys.insert(jwk.kid, key);
    }

    Ok(keys)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration must carry a leeway of 60.
    #[test]
    fn test_config_default() {
        let ValidationConfig { leeway, .. } = ValidationConfig::default();
        assert_eq!(leeway, 60);
    }
}