#![doc = include_str!("../README.md")]
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use base64::Engine;
pub use errors::{BBError, BBResult};
use keystore::{BASE64_ENGINE, BBKey};
pub use keystore::{EcCurve, KeyAlgorithm, KeyStore};
use serde::Deserialize;
pub mod errors;
pub mod keystore;
mod pem;
pub mod tls_ext;
/// A single check that [`validate_jwt`] applies to a token's claims after its signature has
/// been verified.
///
/// Build a list by hand, or start from [`default_validations`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ValidationStep {
    /// The `iss` claim must be present and equal to this value.
    Issuer(String),
    /// The `aud` claim must be present and equal to this value (single-string form) or
    /// contain it (array form).
    Audience(String),
    /// The `nonce` claim must be present and equal to this value.
    Nonce(String),
    /// If an `exp` claim is present, it must not be in the past. A token without `exp`
    /// passes this step.
    NotExpired,
    /// The `sub` claim must be present.
    HasSubject,
    /// The `groups` claim must be present.
    HasGroups,
}
/// The decoded contents of a token that passed [`validate_jwt`].
///
/// Header and payload are kept as raw JSON, so callers can read any claim, including those
/// the validator does not check itself.
#[derive(Debug, Clone, PartialEq)]
pub struct JWTClaims {
    /// The decoded JOSE header of the token.
    pub headers: serde_json::Value,
    /// The decoded payload (claims set) of the token.
    pub claims: serde_json::Value,
}
/// The subset of the JOSE header that [`validate_jwt`] needs in order to choose and use a
/// verification key.
///
/// Any other header parameters are ignored when deserializing into this struct.
#[derive(Deserialize)]
struct JOSEHeader {
    /// Signing algorithm (`alg`). Tokens whose algorithm deserializes to
    /// `KeyAlgorithm::Other` are rejected by `validate_jwt`.
    alg: KeyAlgorithm,
    /// Optional key ID (`kid`), passed to `KeyStore::key_by_id` to select the public key.
    kid: Option<String>,
}
/// The `aud` claim, which a token may carry either as a single string or as an array of
/// strings.
///
/// With `#[serde(untagged)]`, serde tries the variants in declaration order. A JSON string
/// therefore becomes `Single`, and a JSON array of strings becomes `Multi`.
#[derive(Deserialize)]
#[serde(untagged)]
enum Audience {
    /// `"aud": "client-id"`
    Single(String),
    /// `"aud": ["client-id", "other"]`
    Multi(Vec<String>),
}
/// The claims that `validate_claim` knows how to check, deserialized from the token payload.
///
/// Every field is optional, so a token that lacks a claim still deserializes. Whether that
/// absence is an error is decided by the individual validation steps. Claims not listed here
/// are ignored.
#[derive(Deserialize)]
struct ValidationClaims {
    /// Issuer (`iss`), compared against `ValidationStep::Issuer`.
    iss: Option<String>,
    /// Subject (`sub`); only its presence is checked, by `ValidationStep::HasSubject`.
    sub: Option<String>,
    /// Expiration time (`exp`) in seconds since the UNIX epoch.
    // NOTE(review): RFC 7519 allows a non-integer NumericDate. A fractional `exp` will fail to
    // deserialize into `u64` and reject the whole token with a JSON error; confirm this is intended.
    exp: Option<u64>,
    /// Audience (`aud`), either a single string or an array of strings.
    aud: Option<Audience>,
    /// Nonce (`nonce`), compared against `ValidationStep::Nonce`.
    nonce: Option<String>,
    /// `groups` claim; only its presence is checked, by `ValidationStep::HasGroups`.
    groups: Option<Vec<String>>,
}
/// Builds the validation steps most callers want for an ID token.
///
/// The result always contains an issuer check against `issuer`, followed by an expiry check.
/// When `audience` is given an audience check is appended, and when `nonce` is given a nonce
/// check follows it.
pub fn default_validations(
    issuer: &str,
    audience: Option<&str>,
    nonce: Option<&str>,
) -> Vec<ValidationStep> {
    let mut steps = vec![ValidationStep::Issuer(issuer.to_string()), ValidationStep::NotExpired];
    // `Option` is an iterator of zero or one items, so `extend` only appends when a value is set.
    steps.extend(audience.map(|expected_aud| ValidationStep::Audience(expected_aud.to_string())));
    steps.extend(nonce.map(|expected_nonce| ValidationStep::Nonce(expected_nonce.to_string())));
    steps
}
/// Verifies a compact-serialized signed JWT and checks its claims.
///
/// Processing proceeds in this order:
/// 1. The token is split into its three `.`-separated segments: header, payload, signature.
/// 2. The header is decoded and its `kid` is used to look up the public key in `keystore`.
/// 3. The signature is verified.
/// 4. Every step in `validation_steps` is applied to the payload. All failing steps are
///    collected and reported together, rather than stopping at the first one.
///
/// On success, the decoded header and payload are returned as raw JSON.
///
/// # Errors
///
/// - [`BBError::TokenInvalid`] if the token does not have exactly three segments, or if the
///   header's `alg` is unsupported.
/// - [`BBError::JSONError`] if the header or payload is not the expected JSON.
/// - [`BBError::ClaimInvalid`] if one or more validation steps fail; the message lists each
///   failure on its own line.
/// - Errors from base64 decoding, the keystore lookup or signature verification are
///   propagated unchanged.
pub async fn validate_jwt(
    jwt: &str,
    validation_steps: &[ValidationStep],
    keystore: &KeyStore,
) -> BBResult<JWTClaims> {
    // Split on every '.' so that tokens with extra segments (e.g. a 5-part JWE) are rejected
    // here, before any keystore lookup, instead of failing later while decoding the signature.
    let parts: Vec<&str> = jwt.split('.').collect();
    if parts.len() != 3 {
        return Err(BBError::TokenInvalid("Could not split token in 3 parts.".to_string()));
    }

    let hdr_json = BASE64_ENGINE.decode(parts[0])?;
    let kid_hdr: JOSEHeader =
        serde_json::from_slice(&hdr_json).map_err(|e| BBError::JSONError(format!("{:?}", e)))?;
    // NOTE(review): the header `alg` is only screened for `Other` here. It is not compared
    // against the algorithm of the key returned below; confirm that `verify_signature` relies
    // on the key's own algorithm.
    if kid_hdr.alg == KeyAlgorithm::Other {
        return Err(BBError::TokenInvalid("Unsupported algorithm".to_string()));
    }
    let pubkey = keystore.key_by_id(kid_hdr.kid.as_deref()).await?;

    // Claims are only inspected after the signature has been verified.
    check_jwt_signature(&parts, &pubkey)?;

    let payload_json = BASE64_ENGINE.decode(parts[1])?;
    let claims: ValidationClaims = serde_json::from_slice(&payload_json)
        .map_err(|e| BBError::JSONError(format!("{:?}", e)))?;

    let validation_errors: Vec<String> = validation_steps
        .iter()
        .filter_map(|step| validate_claim(&claims, step))
        .collect();
    if !validation_errors.is_empty() {
        return Err(BBError::ClaimInvalid(format!(
            "One or more claims failed to validate:\n{}",
            validation_errors.join("\n")
        )));
    }

    Ok(JWTClaims {
        headers: serde_json::from_slice(&hdr_json)?,
        claims: serde_json::from_slice(&payload_json)?,
    })
}
/// Applies a single [`ValidationStep`] to the decoded claims.
///
/// Returns `None` when the step passes, or `Some(message)` describing the failure.
/// [`validate_jwt`] aggregates these messages into one error.
///
/// `NotExpired` behaves as follows:
/// - It only fails when an `exp` claim is present and the current time is not strictly
///   before it (RFC 7519 §4.1.4). A token without `exp` passes.
/// - If the system clock reads earlier than the UNIX epoch, the step fails rather than
///   panicking, because expiry cannot be checked.
fn validate_claim(claims: &ValidationClaims, step: &ValidationStep) -> Option<String> {
    match step {
        ValidationStep::Audience(aud) => match &claims.aud {
            None => Some("'aud' not set".to_string()),
            Some(Audience::Single(single)) if single != aud => Some(format!(
                "'aud' does not match; expected '{}', got '{}'",
                aud, single
            )),
            Some(Audience::Multi(multi)) if !multi.contains(aud) => Some(format!(
                "'aud' claims don't match: '{}' not found in '{:?}'",
                aud, multi
            )),
            Some(_) => None,
        },
        ValidationStep::Issuer(iss) => match &claims.iss {
            None => Some("'iss' is missing".to_string()),
            Some(claims_iss) if claims_iss != iss => Some(format!(
                "'iss' does not match; expected '{}', got '{}'",
                iss, claims_iss
            )),
            Some(_) => None,
        },
        ValidationStep::Nonce(nonce) => match &claims.nonce {
            None => Some("'nonce' is missing".to_string()),
            Some(claims_nonce) if claims_nonce != nonce => {
                Some("'nonce' does not match".to_string())
            }
            Some(_) => None,
        },
        ValidationStep::NotExpired => {
            // A missing `exp` is not treated as a failure by this step.
            let exp = claims.exp?;
            match SystemTime::now().duration_since(UNIX_EPOCH) {
                // The current time must be strictly before `exp`, so `exp == now` is expired.
                Ok(now) if Duration::from_secs(exp) <= now => {
                    Some("Token has expired.".to_string())
                }
                Ok(_) => None,
                // Fail closed instead of panicking inside library code.
                Err(_) => Some(
                    "Cannot check 'exp': system time is before the UNIX epoch".to_string(),
                ),
            }
        }
        ValidationStep::HasSubject => {
            if claims.sub.is_none() {
                Some("'sub' is missing".to_string())
            } else {
                None
            }
        }
        ValidationStep::HasGroups => {
            if claims.groups.is_none() {
                Some("'groups' is missing".to_string())
            } else {
                None
            }
        }
    }
}
/// Verifies the signature segment of a split JWT against `pubkey`.
///
/// `jwt_parts` must start with the header, payload and signature segments; any further
/// elements are ignored. The signing input is the header and payload segments joined by `.`,
/// exactly as they appeared in the token. The signature segment is decoded with
/// `BASE64_ENGINE` and then passed to `BBKey::verify_signature`.
///
/// # Errors
///
/// - [`BBError::TokenInvalid`] if fewer than three segments are supplied.
/// - [`BBError::DecodeError`] if the signature segment cannot be decoded.
/// - Whatever `verify_signature` returns for a signature that does not verify.
fn check_jwt_signature(jwt_parts: &[&str], pubkey: &BBKey) -> BBResult<()> {
    // Destructure instead of indexing so a short slice yields an error rather than a panic.
    let (header, payload, signature) = match jwt_parts {
        [header, payload, signature, ..] => (*header, *payload, *signature),
        _ => {
            return Err(BBError::TokenInvalid("Could not split token in 3 parts.".to_string()))
        }
    };
    let jwt_data = format!("{}.{}", header, payload);
    let sig = BASE64_ENGINE
        .decode(signature)
        .map_err(|e| BBError::DecodeError(format!("{:?}", e)))?;
    pubkey.verify_signature(jwt_data.as_bytes(), &sig)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Claims with every field unset; each test overrides only the fields it exercises.
    fn empty_validations() -> ValidationClaims {
        ValidationClaims {
            aud: None,
            iss: None,
            nonce: None,
            exp: None,
            sub: None,
            groups: None,
        }
    }

    #[test]
    fn validate_aud_claim() {
        let claims = ValidationClaims {
            aud: Some(Audience::Single("test".to_string())),
            ..empty_validations()
        };
        let step = ValidationStep::Audience("test".to_string());
        assert!(validate_claim(&claims, &step).is_none());
        let step = ValidationStep::Audience("test2".to_string());
        assert_eq!(
            validate_claim(&claims, &step).as_deref(),
            Some("'aud' does not match; expected 'test2', got 'test'"),
            "Invalid aud did not fail validation"
        );
    }

    #[test]
    fn validate_multi_aud_claim() {
        let claims = ValidationClaims {
            aud: Some(Audience::Multi(vec!["a".to_string(), "b".to_string()])),
            ..empty_validations()
        };
        assert!(validate_claim(&claims, &ValidationStep::Audience("b".to_string())).is_none());
        assert_eq!(
            validate_claim(&claims, &ValidationStep::Audience("c".to_string())).as_deref(),
            Some("'aud' claims don't match: 'c' not found in '[\"a\", \"b\"]'")
        );
    }

    #[test]
    fn missing_aud_fails() {
        let step = ValidationStep::Audience("test".to_string());
        assert_eq!(
            validate_claim(&empty_validations(), &step).as_deref(),
            Some("'aud' not set")
        );
    }

    #[test]
    fn validate_iss_claim() {
        let claims = ValidationClaims {
            iss: Some("test".to_string()),
            ..empty_validations()
        };
        let step = ValidationStep::Issuer("test".to_string());
        assert!(validate_claim(&claims, &step).is_none());
        let step = ValidationStep::Issuer("test2".to_string());
        assert_eq!(
            validate_claim(&claims, &step).as_deref(),
            Some("'iss' does not match; expected 'test2', got 'test'"),
            "Invalid iss did not fail validation"
        );
        assert_eq!(
            validate_claim(&empty_validations(), &step).as_deref(),
            Some("'iss' is missing")
        );
    }

    #[test]
    fn validate_nonce_claim() {
        let claims = ValidationClaims {
            nonce: Some("n1".to_string()),
            ..empty_validations()
        };
        assert!(validate_claim(&claims, &ValidationStep::Nonce("n1".to_string())).is_none());
        assert_eq!(
            validate_claim(&claims, &ValidationStep::Nonce("n2".to_string())).as_deref(),
            Some("'nonce' does not match")
        );
        assert_eq!(
            validate_claim(&empty_validations(), &ValidationStep::Nonce("n1".to_string()))
                .as_deref(),
            Some("'nonce' is missing")
        );
    }

    #[test]
    fn validate_not_expired() {
        let step = ValidationStep::NotExpired;
        let expired = ValidationClaims {
            exp: Some(0),
            ..empty_validations()
        };
        assert_eq!(validate_claim(&expired, &step).as_deref(), Some("Token has expired."));
        let valid = ValidationClaims {
            exp: Some(u64::MAX),
            ..empty_validations()
        };
        assert!(validate_claim(&valid, &step).is_none());
        // A token without `exp` is not rejected by this step.
        assert!(validate_claim(&empty_validations(), &step).is_none());
    }

    #[test]
    fn validate_presence_steps() {
        assert_eq!(
            validate_claim(&empty_validations(), &ValidationStep::HasSubject).as_deref(),
            Some("'sub' is missing")
        );
        assert_eq!(
            validate_claim(&empty_validations(), &ValidationStep::HasGroups).as_deref(),
            Some("'groups' is missing")
        );
        let claims = ValidationClaims {
            sub: Some("user".to_string()),
            groups: Some(vec!["admins".to_string()]),
            ..empty_validations()
        };
        assert!(validate_claim(&claims, &ValidationStep::HasSubject).is_none());
        assert!(validate_claim(&claims, &ValidationStep::HasGroups).is_none());
    }

    #[test]
    fn default_validations_builds_expected_steps() {
        let minimal = default_validations("iss", None, None);
        assert_eq!(minimal.len(), 2);
        assert!(matches!(&minimal[0], ValidationStep::Issuer(i) if i == "iss"));
        assert!(matches!(minimal[1], ValidationStep::NotExpired));

        let full = default_validations("iss", Some("aud"), Some("nonce"));
        assert_eq!(full.len(), 4);
        assert!(matches!(&full[2], ValidationStep::Audience(a) if a == "aud"));
        assert!(matches!(&full[3], ValidationStep::Nonce(n) if n == "nonce"));
    }
}